S3: Store and return ServerSideEncryption and KMS Key Id for Multiparts (#5393)

This commit is contained in:
Cristopher Pinzón 2022-08-23 16:08:37 -05:00 committed by GitHub
parent 7affaf3e52
commit 126ac1777a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 100 additions and 6 deletions

View File

@ -343,7 +343,16 @@ class FakeKey(BaseModel, ManagedState):
class FakeMultipart(BaseModel): class FakeMultipart(BaseModel):
def __init__(self, key_name, metadata, storage=None, tags=None, acl=None): def __init__(
self,
key_name,
metadata,
storage=None,
tags=None,
acl=None,
sse_encryption=None,
kms_key_id=None,
):
self.key_name = key_name self.key_name = key_name
self.metadata = metadata self.metadata = metadata
self.storage = storage self.storage = storage
@ -355,6 +364,8 @@ class FakeMultipart(BaseModel):
self.id = ( self.id = (
rand_b64.decode("utf-8").replace("=", "").replace("+", "").replace("/", "") rand_b64.decode("utf-8").replace("=", "").replace("+", "").replace("/", "")
) )
self.sse_encryption = sse_encryption
self.kms_key_id = kms_key_id
def complete(self, body): def complete(self, body):
decode_hex = codecs.getdecoder("hex_codec") decode_hex = codecs.getdecoder("hex_codec")
@ -389,7 +400,9 @@ class FakeMultipart(BaseModel):
if part_id < 1: if part_id < 1:
raise NoSuchUpload(upload_id=part_id) raise NoSuchUpload(upload_id=part_id)
key = FakeKey(part_id, value) key = FakeKey(
part_id, value, encryption=self.sse_encryption, kms_key_id=self.kms_key_id
)
self.parts[part_id] = key self.parts[part_id] = key
if part_id not in self.partlist: if part_id not in self.partlist:
insort(self.partlist, part_id) insort(self.partlist, part_id)
@ -1928,10 +1941,24 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
return len(bucket.multiparts[multipart_id].parts) > next_part_number_marker return len(bucket.multiparts[multipart_id].parts) > next_part_number_marker
def create_multipart_upload( def create_multipart_upload(
self, bucket_name, key_name, metadata, storage_type, tags, acl self,
bucket_name,
key_name,
metadata,
storage_type,
tags,
acl,
sse_encryption,
kms_key_id,
): ):
multipart = FakeMultipart( multipart = FakeMultipart(
key_name, metadata, storage=storage_type, tags=tags, acl=acl key_name,
metadata,
storage=storage_type,
tags=tags,
acl=acl,
sse_encryption=sse_encryption,
kms_key_id=kms_key_id,
) )
bucket = self.get_bucket(bucket_name) bucket = self.get_bucket(bucket_name)

View File

@ -1929,20 +1929,38 @@ class S3Response(BaseResponse):
self._set_action("KEY", "POST", query) self._set_action("KEY", "POST", query)
self._authenticate_and_authorize_s3_action() self._authenticate_and_authorize_s3_action()
encryption = request.headers.get("x-amz-server-side-encryption")
kms_key_id = request.headers.get("x-amz-server-side-encryption-aws-kms-key-id")
if body == b"" and "uploads" in query: if body == b"" and "uploads" in query:
response_headers = {}
metadata = metadata_from_headers(request.headers) metadata = metadata_from_headers(request.headers)
tagging = self._tagging_from_headers(request.headers) tagging = self._tagging_from_headers(request.headers)
storage_type = request.headers.get("x-amz-storage-class", "STANDARD") storage_type = request.headers.get("x-amz-storage-class", "STANDARD")
acl = self._acl_from_headers(request.headers) acl = self._acl_from_headers(request.headers)
multipart_id = self.backend.create_multipart_upload( multipart_id = self.backend.create_multipart_upload(
bucket_name, key_name, metadata, storage_type, tagging, acl bucket_name,
key_name,
metadata,
storage_type,
tagging,
acl,
encryption,
kms_key_id,
) )
if encryption:
response_headers["x-amz-server-side-encryption"] = encryption
if kms_key_id:
response_headers[
"x-amz-server-side-encryption-aws-kms-key-id"
] = kms_key_id
template = self.response_template(S3_MULTIPART_INITIATE_RESPONSE) template = self.response_template(S3_MULTIPART_INITIATE_RESPONSE)
response = template.render( response = template.render(
bucket_name=bucket_name, key_name=key_name, upload_id=multipart_id bucket_name=bucket_name, key_name=key_name, upload_id=multipart_id
) )
return 200, {}, response return 200, response_headers, response
if query.get("uploadId"): if query.get("uploadId"):
body = self._complete_multipart_body(body) body = self._complete_multipart_body(body)
@ -1961,6 +1979,8 @@ class S3Response(BaseResponse):
storage=multipart.storage, storage=multipart.storage,
etag=etag, etag=etag,
multipart=multipart, multipart=multipart,
encryption=multipart.sse_encryption,
kms_key_id=multipart.kms_key_id,
) )
key.set_metadata(multipart.metadata) key.set_metadata(multipart.metadata)
self.backend.set_key_tags(key, multipart.tags) self.backend.set_key_tags(key, multipart.tags)
@ -1970,6 +1990,13 @@ class S3Response(BaseResponse):
headers = {} headers = {}
if key.version_id: if key.version_id:
headers["x-amz-version-id"] = key.version_id headers["x-amz-version-id"] = key.version_id
if key.encryption:
headers["x-amz-server-side-encryption"] = key.encryption
if key.kms_key_id:
headers["x-amz-server-side-encryption-aws-kms-key-id"] = key.kms_key_id
return ( return (
200, 200,
headers, headers,

View File

@ -885,3 +885,43 @@ def test_complete_multipart_with_empty_partlist():
err["Message"].should.equal( err["Message"].should.equal(
"The XML you provided was not well-formed or did not validate against our published schema" "The XML you provided was not well-formed or did not validate against our published schema"
) )
@mock_s3
def test_ssm_key_headers_in_create_multipart():
s3_client = boto3.client("s3", region_name=DEFAULT_REGION_NAME)
bucket_name = "ssm-headers-bucket"
s3_client.create_bucket(Bucket=bucket_name)
kms_key_id = "random-id"
key_name = "test-file.txt"
create_multipart_response = s3_client.create_multipart_upload(
Bucket=bucket_name,
Key=key_name,
ServerSideEncryption="aws:kms",
SSEKMSKeyId=kms_key_id,
)
assert create_multipart_response["ServerSideEncryption"] == "aws:kms"
assert create_multipart_response["SSEKMSKeyId"] == kms_key_id
upload_part_response = s3_client.upload_part(
Body=b"bytes",
Bucket=bucket_name,
Key=key_name,
PartNumber=1,
UploadId=create_multipart_response["UploadId"],
)
assert upload_part_response["ServerSideEncryption"] == "aws:kms"
assert upload_part_response["SSEKMSKeyId"] == kms_key_id
parts = {"Parts": [{"PartNumber": 1, "ETag": upload_part_response["ETag"]}]}
complete_multipart_response = s3_client.complete_multipart_upload(
Bucket=bucket_name,
Key=key_name,
UploadId=create_multipart_response["UploadId"],
MultipartUpload=parts,
)
assert complete_multipart_response["ServerSideEncryption"] == "aws:kms"
assert complete_multipart_response["SSEKMSKeyId"] == kms_key_id