From 126ac1777a75dc0597f7b37ab4c94e6612e11a5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristopher=20Pinz=C3=B3n?= Date: Tue, 23 Aug 2022 16:08:37 -0500 Subject: [PATCH] S3: Store and return ServerSideEncryption and KMS Key Id for Multiparts (#5393) --- moto/s3/models.py | 35 +++++++++++++++++++++++--- moto/s3/responses.py | 31 +++++++++++++++++++++-- tests/test_s3/test_s3_multipart.py | 40 ++++++++++++++++++++++++++++++ 3 files changed, 100 insertions(+), 6 deletions(-) diff --git a/moto/s3/models.py b/moto/s3/models.py index eb599e0bf..bf7b48939 100644 --- a/moto/s3/models.py +++ b/moto/s3/models.py @@ -343,7 +343,16 @@ class FakeKey(BaseModel, ManagedState): class FakeMultipart(BaseModel): - def __init__(self, key_name, metadata, storage=None, tags=None, acl=None): + def __init__( + self, + key_name, + metadata, + storage=None, + tags=None, + acl=None, + sse_encryption=None, + kms_key_id=None, + ): self.key_name = key_name self.metadata = metadata self.storage = storage @@ -355,6 +364,8 @@ class FakeMultipart(BaseModel): self.id = ( rand_b64.decode("utf-8").replace("=", "").replace("+", "").replace("/", "") ) + self.sse_encryption = sse_encryption + self.kms_key_id = kms_key_id def complete(self, body): decode_hex = codecs.getdecoder("hex_codec") @@ -389,7 +400,9 @@ class FakeMultipart(BaseModel): if part_id < 1: raise NoSuchUpload(upload_id=part_id) - key = FakeKey(part_id, value) + key = FakeKey( + part_id, value, encryption=self.sse_encryption, kms_key_id=self.kms_key_id + ) self.parts[part_id] = key if part_id not in self.partlist: insort(self.partlist, part_id) @@ -1928,10 +1941,24 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): return len(bucket.multiparts[multipart_id].parts) > next_part_number_marker def create_multipart_upload( - self, bucket_name, key_name, metadata, storage_type, tags, acl + self, + bucket_name, + key_name, + metadata, + storage_type, + tags, + acl, + sse_encryption, + kms_key_id, ): multipart = FakeMultipart( - key_name, metadata, storage=storage_type, tags=tags, acl=acl + key_name, + metadata, + storage=storage_type, + tags=tags, + acl=acl, + sse_encryption=sse_encryption, + kms_key_id=kms_key_id, ) bucket = self.get_bucket(bucket_name) diff --git a/moto/s3/responses.py b/moto/s3/responses.py index e825faa56..50a753b8c 100644 --- a/moto/s3/responses.py +++ b/moto/s3/responses.py @@ -1929,20 +1929,38 @@ class S3Response(BaseResponse): self._set_action("KEY", "POST", query) self._authenticate_and_authorize_s3_action() + encryption = request.headers.get("x-amz-server-side-encryption") + kms_key_id = request.headers.get("x-amz-server-side-encryption-aws-kms-key-id") + if body == b"" and "uploads" in query: + response_headers = {} metadata = metadata_from_headers(request.headers) tagging = self._tagging_from_headers(request.headers) storage_type = request.headers.get("x-amz-storage-class", "STANDARD") acl = self._acl_from_headers(request.headers) + multipart_id = self.backend.create_multipart_upload( - bucket_name, key_name, metadata, storage_type, tagging, acl + bucket_name, + key_name, + metadata, + storage_type, + tagging, + acl, + encryption, + kms_key_id, ) + if encryption: + response_headers["x-amz-server-side-encryption"] = encryption + if kms_key_id: + response_headers[ + "x-amz-server-side-encryption-aws-kms-key-id" + ] = kms_key_id template = self.response_template(S3_MULTIPART_INITIATE_RESPONSE) response = template.render( bucket_name=bucket_name, key_name=key_name, upload_id=multipart_id ) - return 200, {}, response + return 200, response_headers, response if query.get("uploadId"): body = self._complete_multipart_body(body) @@ -1961,6 +1979,8 @@ class S3Response(BaseResponse): storage=multipart.storage, etag=etag, multipart=multipart, + encryption=multipart.sse_encryption, + kms_key_id=multipart.kms_key_id, ) key.set_metadata(multipart.metadata) self.backend.set_key_tags(key, multipart.tags) @@ -1970,6 +1990,13 @@ class S3Response(BaseResponse): headers = {} if key.version_id: headers["x-amz-version-id"] = key.version_id + + if key.encryption: + headers["x-amz-server-side-encryption"] = key.encryption + + if key.kms_key_id: + headers["x-amz-server-side-encryption-aws-kms-key-id"] = key.kms_key_id + return ( 200, headers, diff --git a/tests/test_s3/test_s3_multipart.py b/tests/test_s3/test_s3_multipart.py index d8a0fd82c..cee620a91 100644 --- a/tests/test_s3/test_s3_multipart.py +++ b/tests/test_s3/test_s3_multipart.py @@ -885,3 +885,43 @@ def test_complete_multipart_with_empty_partlist(): err["Message"].should.equal( "The XML you provided was not well-formed or did not validate against our published schema" ) + + +@mock_s3 +def test_ssm_key_headers_in_create_multipart(): + s3_client = boto3.client("s3", region_name=DEFAULT_REGION_NAME) + + bucket_name = "ssm-headers-bucket" + s3_client.create_bucket(Bucket=bucket_name) + + kms_key_id = "random-id" + key_name = "test-file.txt" + + create_multipart_response = s3_client.create_multipart_upload( + Bucket=bucket_name, + Key=key_name, + ServerSideEncryption="aws:kms", + SSEKMSKeyId=kms_key_id, + ) + assert create_multipart_response["ServerSideEncryption"] == "aws:kms" + assert create_multipart_response["SSEKMSKeyId"] == kms_key_id + + upload_part_response = s3_client.upload_part( + Body=b"bytes", + Bucket=bucket_name, + Key=key_name, + PartNumber=1, + UploadId=create_multipart_response["UploadId"], + ) + assert upload_part_response["ServerSideEncryption"] == "aws:kms" + assert upload_part_response["SSEKMSKeyId"] == kms_key_id + + parts = {"Parts": [{"PartNumber": 1, "ETag": upload_part_response["ETag"]}]} + complete_multipart_response = s3_client.complete_multipart_upload( + Bucket=bucket_name, + Key=key_name, + UploadId=create_multipart_response["UploadId"], + MultipartUpload=parts, + ) + assert complete_multipart_response["ServerSideEncryption"] == "aws:kms" + assert complete_multipart_response["SSEKMSKeyId"] == kms_key_id