diff --git a/moto/awslambda/models.py b/moto/awslambda/models.py index 8a09ef9a9..7120a71eb 100644 --- a/moto/awslambda/models.py +++ b/moto/awslambda/models.py @@ -479,61 +479,7 @@ class LambdaFunction(CloudFormationModel, DockerModel): datetime.datetime.utcnow() ) - if "ZipFile" in self.code: - ( - self.code_bytes, - self.code_size, - self.code_sha_256, - self.code_digest, - ) = _zipfile_content(self.code["ZipFile"]) - - # TODO: we should be putting this in a lambda bucket - self.code["UUID"] = str(random.uuid4()) - self.code["S3Key"] = f"{self.function_name}-{self.code['UUID']}" - elif "S3Bucket" in self.code: - key = _validate_s3_bucket_and_key(self.account_id, data=self.code) - if key: - ( - self.code_bytes, - self.code_size, - self.code_sha_256, - self.code_digest, - ) = _s3_content(key) - else: - self.code_bytes = b"" - self.code_size = 0 - self.code_sha_256 = "" - elif "ImageUri" in self.code: - if settings.lambda_stub_ecr(): - self.code_sha_256 = hashlib.sha256( - self.code["ImageUri"].encode("utf-8") - ).hexdigest() - self.code_size = 0 - else: - if "@" in self.code["ImageUri"]: - # deploying via digest - uri, digest = self.code["ImageUri"].split("@") - image_id = {"imageDigest": digest} - else: - # deploying via tag - uri, tag = self.code["ImageUri"].split(":") - image_id = {"imageTag": tag} - - repo_name = uri.split("/")[-1] - ecr_backend = ecr_backends[self.account_id][self.region] - registry_id = ecr_backend.describe_registry()["registryId"] - images = ecr_backend.batch_get_image( - repository_name=repo_name, image_ids=[image_id] - )["images"] - - if len(images) == 0: - raise ImageNotFoundException(image_id, repo_name, registry_id) # type: ignore - else: - manifest = json.loads(images[0]["imageManifest"]) - self.code_sha_256 = images[0]["imageId"]["imageDigest"].replace( - "sha256:", "" - ) - self.code_size = manifest["config"]["size"] + self._set_function_code(self.code) self.function_arn = make_function_arn( self.region, self.account_id, self.function_name @@ -711,12 +657,16 @@ class LambdaFunction(CloudFormationModel, DockerModel): return self.get_configuration() - def update_function_code(self, updated_spec: Dict[str, Any]) -> Dict[str, Any]: - if "DryRun" in updated_spec and updated_spec["DryRun"]: - return self.get_configuration() + def _set_function_code(self, updated_spec: Dict[str, Any]) -> None: + from_update = updated_spec is not self.code + + # "DryRun" is only used for UpdateFunctionCode + if from_update and "DryRun" in updated_spec and updated_spec["DryRun"]: + return if "ZipFile" in updated_spec: - self.code["ZipFile"] = updated_spec["ZipFile"] + if from_update: + self.code["ZipFile"] = updated_spec["ZipFile"] ( self.code_bytes, @@ -731,10 +681,13 @@ class LambdaFunction(CloudFormationModel, DockerModel): elif "S3Bucket" in updated_spec and "S3Key" in updated_spec: key = None try: - # FIXME: does not validate bucket region - key = s3_backends[self.account_id]["global"].get_object( - updated_spec["S3Bucket"], updated_spec["S3Key"] - ) + if from_update: + # FIXME: does not validate bucket region + key = s3_backends[self.account_id]["global"].get_object( + updated_spec["S3Bucket"], updated_spec["S3Key"] + ) + else: + key = _validate_s3_bucket_and_key(self.account_id, data=self.code) except MissingBucket: if do_validate_s3(): raise ValueError( @@ -754,9 +707,49 @@ class LambdaFunction(CloudFormationModel, DockerModel): self.code_sha_256, self.code_digest, ) = _s3_content(key) + else: + self.code_bytes = b"" + self.code_size = 0 + self.code_sha_256 = "" + if from_update: self.code["S3Bucket"] = updated_spec["S3Bucket"] self.code["S3Key"] = updated_spec["S3Key"] + elif "ImageUri" in updated_spec: + if settings.lambda_stub_ecr(): + self.code_sha_256 = hashlib.sha256( + updated_spec["ImageUri"].encode("utf-8") + ).hexdigest() + self.code_size = 0 + else: + if "@" in updated_spec["ImageUri"]: + # deploying via digest + uri, digest = updated_spec["ImageUri"].split("@") + image_id = {"imageDigest": digest} + else: + # deploying via tag + uri, tag = updated_spec["ImageUri"].split(":") + image_id = {"imageTag": tag} + repo_name = uri.split("/")[-1] + ecr_backend = ecr_backends[self.account_id][self.region] + registry_id = ecr_backend.describe_registry()["registryId"] + images = ecr_backend.batch_get_image( + repository_name=repo_name, image_ids=[image_id] + )["images"] + + if len(images) == 0: + raise ImageNotFoundException(image_id, repo_name, registry_id) # type: ignore + else: + manifest = json.loads(images[0]["imageManifest"]) + self.code_sha_256 = images[0]["imageId"]["imageDigest"].replace( + "sha256:", "" + ) + self.code_size = manifest["config"]["size"] + if from_update: + self.code["ImageUri"] = updated_spec["ImageUri"] + + def update_function_code(self, updated_spec: Dict[str, Any]) -> Dict[str, Any]: + self._set_function_code(updated_spec) return self.get_configuration() @staticmethod diff --git a/tests/test_awslambda/test_lambda.py b/tests/test_awslambda/test_lambda.py index 842af314b..6b86206bf 100644 --- a/tests/test_awslambda/test_lambda.py +++ b/tests/test_awslambda/test_lambda.py @@ -1459,6 +1459,62 @@ def test_update_function_s3(): assert config["LastUpdateStatus"] == "Successful" +@mock_lambda +def test_update_function_ecr(): + conn = boto3.client("lambda", _lambda_region) + function_name = str(uuid4())[0:6] + image_uri = f"{ACCOUNT_ID}.dkr.ecr.us-east-1.amazonaws.com/testlambdaecr:prod" + image_config = { + "EntryPoint": [ + "python", + ], + "Command": [ + "/opt/app.py", + ], + "WorkingDirectory": "/opt", + } + + conn.create_function( + FunctionName=function_name, + Role=get_role_name(), + Code={"ImageUri": image_uri}, + Description="test lambda function", + ImageConfig=image_config, + Timeout=3, + MemorySize=128, + Publish=True, + ) + + new_uri = image_uri.replace("prod", "newer") + + conn.update_function_code( + FunctionName=function_name, + ImageUri=new_uri, + Publish=True, + ) + + response = conn.get_function(FunctionName=function_name, Qualifier="2") + + assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 + assert len(response["Code"]) == 3 + assert response["Code"]["RepositoryType"] == "ECR" + assert response["Code"]["ImageUri"] == new_uri + assert response["Code"]["ResolvedImageUri"].endswith( + hashlib.sha256(new_uri.encode("utf-8")).hexdigest() + ) + + config = response["Configuration"] + assert config["CodeSize"] == 0 + assert config["Description"] == "test lambda function" + assert ( + config["FunctionArn"] + == f"arn:aws:lambda:{_lambda_region}:{ACCOUNT_ID}:function:{function_name}:2" + ) + assert config["FunctionName"] == function_name + assert config["Version"] == "2" + assert config["LastUpdateStatus"] == "Successful" + + @mock_lambda def test_create_function_with_invalid_arn(): err = create_invalid_lambda("test-iam-role")