AWSLambda: Layers are now loaded onto the Docker image (#6772)

This commit is contained in:
Bert Blommers 2023-09-05 17:31:13 +00:00 committed by GitHub
parent 7f3a69b7a5
commit bf9bbcc506
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 158 additions and 3 deletions

View File

@ -20,6 +20,7 @@ import tarfile
import calendar
import threading
import weakref
import warnings
import requests.exceptions
from moto.awslambda.policy import Policy
@ -149,6 +150,89 @@ class _DockerDataVolumeContext:
raise # multiple processes trying to use same volume?
class _DockerDataVolumeLayerContext:
_data_vol_map: Dict[str, _VolumeRefCount] = defaultdict(
lambda: _VolumeRefCount(0, None)
)
_lock = threading.Lock()
def __init__(self, lambda_func: "LambdaFunction"):
self._lambda_func = lambda_func
self._layers: List[Dict[str, str]] = self._lambda_func.layers
self._vol_ref: Optional[_VolumeRefCount] = None
@property
def name(self) -> str:
return self._vol_ref.volume.name # type: ignore[union-attr]
@property
def hash(self) -> str:
return "-".join(
[
layer["Arn"].split("layer:")[-1].replace(":", "_")
for layer in self._layers
]
)
def __enter__(self) -> "_DockerDataVolumeLayerContext":
# See if volume is already known
with self.__class__._lock:
self._vol_ref = self.__class__._data_vol_map[self.hash]
self._vol_ref.refcount += 1
if self._vol_ref.refcount > 1:
return self
# See if the volume already exists
for vol in self._lambda_func.docker_client.volumes.list():
if vol.name == self.hash:
self._vol_ref.volume = vol
return self
# It doesn't exist so we need to create it
self._vol_ref.volume = self._lambda_func.docker_client.volumes.create(
self.hash
)
# If we don't have any layers to apply, just return at this point
# When invoking the function, we will bind this empty volume
if len(self._layers) == 0:
return self
volumes = {self.name: {"bind": "/opt", "mode": "rw"}}
self._lambda_func.ensure_image_exists("busybox")
container = self._lambda_func.docker_client.containers.run(
"busybox", "sleep 100", volumes=volumes, detach=True
)
backend: "LambdaBackend" = lambda_backends[self._lambda_func.account_id][
self._lambda_func.region
]
try:
for layer in self._layers:
try:
layer_zip = backend.layers_versions_by_arn( # type: ignore[union-attr]
layer["Arn"]
).code_bytes
layer_tar = zip2tar(layer_zip)
container.put_archive("/opt", layer_tar)
except zipfile.BadZipfile as e:
warnings.warn(f"Error extracting layer to Lambda: {e}")
finally:
container.remove(force=True)
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
with self.__class__._lock:
self._vol_ref.refcount -= 1 # type: ignore[union-attr]
if self._vol_ref.refcount == 0: # type: ignore[union-attr]
try:
self._vol_ref.volume.remove() # type: ignore[union-attr]
except docker.errors.APIError as e:
if e.status_code != 409:
raise
raise # multiple processes trying to use same volume?
def _zipfile_content(zipfile_content: Union[str, bytes]) -> Tuple[bytes, int, str, str]:
try:
to_unzip_code = base64.b64decode(bytes(zipfile_content, "utf-8")) # type: ignore[arg-type]
@ -447,7 +531,9 @@ class LambdaFunction(CloudFormationModel, DockerModel):
self.package_type = spec.get("PackageType", None)
self.publish = spec.get("Publish", False) # this is ignored currently
self.timeout = spec.get("Timeout", 3)
self.layers = self._get_layers_data(spec.get("Layers", []))
self.layers: List[Dict[str, str]] = self._get_layers_data(
spec.get("Layers", [])
)
self.signing_profile_version_arn = spec.get("SigningProfileVersionArn")
self.signing_job_arn = spec.get("SigningJobArn")
self.code_signing_config_arn = spec.get("CodeSigningConfigArn")
@ -784,7 +870,9 @@ class LambdaFunction(CloudFormationModel, DockerModel):
container = exit_code = None
log_config = docker.types.LogConfig(type=docker.types.LogConfig.types.JSON)
with _DockerDataVolumeContext(self) as data_vol:
with _DockerDataVolumeContext(
self
) as data_vol, _DockerDataVolumeLayerContext(self) as layer_context:
try:
run_kwargs: Dict[str, Any] = dict()
network_name = settings.moto_network_name()
@ -826,12 +914,16 @@ class LambdaFunction(CloudFormationModel, DockerModel):
break
except docker.errors.NotFound:
pass
volumes = {
data_vol.name: {"bind": "/var/task", "mode": "rw"},
layer_context.name: {"bind": "/opt", "mode": "rw"},
}
container = self.docker_client.containers.run(
image_ref,
[self.handler, json.dumps(event)],
remove=False,
mem_limit=f"{self.memory_size}m",
volumes=[f"{data_vol.name}:/var/task"],
volumes=volumes,
environment=env_vars,
detach=True,
log_config=log_config,

View File

@ -0,0 +1,63 @@
import boto3
import pkgutil
from moto import mock_lambda
from uuid import uuid4
from .utilities import get_role_name, _process_lambda
PYTHON_VERSION = "python3.11"
_lambda_region = "us-west-2"
boto3.setup_default_session(region_name=_lambda_region)
def get_requests_zip_file():
pfunc = """
import requests
def lambda_handler(event, context):
return requests.__version__
"""
return _process_lambda(pfunc)
@mock_lambda
def test_invoke_local_lambda_layers():
conn = boto3.client("lambda", _lambda_region)
lambda_name = str(uuid4())[0:6]
# https://api.klayers.cloud/api/v2/p3.11/layers/latest/us-east-1/json
requests_location = (
"resources/Klayers-p311-requests-a637a171-679b-4057-8a62-0a274b260710.zip"
)
requests_layer = pkgutil.get_data(__name__, requests_location)
layer_arn = conn.publish_layer_version(
LayerName=str(uuid4())[0:6],
Content={"ZipFile": requests_layer},
CompatibleRuntimes=["python3.11"],
LicenseInfo="MIT",
)["LayerArn"]
bogus_layer_arn = conn.publish_layer_version(
LayerName=str(uuid4())[0:6],
Content={"ZipFile": b"zipfile"},
CompatibleRuntimes=["python3.11"],
LicenseInfo="MIT",
)["LayerArn"]
function_arn = conn.create_function(
FunctionName=lambda_name,
Runtime="python3.11",
Role=get_role_name(),
Handler="lambda_function.lambda_handler",
Code={"ZipFile": get_requests_zip_file()},
Timeout=3,
MemorySize=128,
Publish=True,
Layers=[f"{layer_arn}:1", f"{bogus_layer_arn}:1"],
)["FunctionArn"]
success_result = conn.invoke(
FunctionName=function_arn, Payload="{}", LogType="Tail"
)
msg = success_result["Payload"].read().decode("utf-8")
assert msg == '"2.31.0"'