S3: Store and return ServerSideEncryption and KMS Key Id for Multiparts (#5393)

This commit is contained in:
Cristopher Pinzón 2022-08-23 16:08:37 -05:00 committed by GitHub
parent 7affaf3e52
commit 126ac1777a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 100 additions and 6 deletions

View File

@ -343,7 +343,16 @@ class FakeKey(BaseModel, ManagedState):
class FakeMultipart(BaseModel):
def __init__(self, key_name, metadata, storage=None, tags=None, acl=None):
def __init__(
self,
key_name,
metadata,
storage=None,
tags=None,
acl=None,
sse_encryption=None,
kms_key_id=None,
):
self.key_name = key_name
self.metadata = metadata
self.storage = storage
@ -355,6 +364,8 @@ class FakeMultipart(BaseModel):
self.id = (
rand_b64.decode("utf-8").replace("=", "").replace("+", "").replace("/", "")
)
self.sse_encryption = sse_encryption
self.kms_key_id = kms_key_id
def complete(self, body):
decode_hex = codecs.getdecoder("hex_codec")
@ -389,7 +400,9 @@ class FakeMultipart(BaseModel):
if part_id < 1:
raise NoSuchUpload(upload_id=part_id)
key = FakeKey(part_id, value)
key = FakeKey(
part_id, value, encryption=self.sse_encryption, kms_key_id=self.kms_key_id
)
self.parts[part_id] = key
if part_id not in self.partlist:
insort(self.partlist, part_id)
@ -1928,10 +1941,24 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
return len(bucket.multiparts[multipart_id].parts) > next_part_number_marker
def create_multipart_upload(
self, bucket_name, key_name, metadata, storage_type, tags, acl
self,
bucket_name,
key_name,
metadata,
storage_type,
tags,
acl,
sse_encryption,
kms_key_id,
):
multipart = FakeMultipart(
key_name, metadata, storage=storage_type, tags=tags, acl=acl
key_name,
metadata,
storage=storage_type,
tags=tags,
acl=acl,
sse_encryption=sse_encryption,
kms_key_id=kms_key_id,
)
bucket = self.get_bucket(bucket_name)

View File

@ -1929,20 +1929,38 @@ class S3Response(BaseResponse):
self._set_action("KEY", "POST", query)
self._authenticate_and_authorize_s3_action()
encryption = request.headers.get("x-amz-server-side-encryption")
kms_key_id = request.headers.get("x-amz-server-side-encryption-aws-kms-key-id")
if body == b"" and "uploads" in query:
response_headers = {}
metadata = metadata_from_headers(request.headers)
tagging = self._tagging_from_headers(request.headers)
storage_type = request.headers.get("x-amz-storage-class", "STANDARD")
acl = self._acl_from_headers(request.headers)
multipart_id = self.backend.create_multipart_upload(
bucket_name, key_name, metadata, storage_type, tagging, acl
bucket_name,
key_name,
metadata,
storage_type,
tagging,
acl,
encryption,
kms_key_id,
)
if encryption:
response_headers["x-amz-server-side-encryption"] = encryption
if kms_key_id:
response_headers[
"x-amz-server-side-encryption-aws-kms-key-id"
] = kms_key_id
template = self.response_template(S3_MULTIPART_INITIATE_RESPONSE)
response = template.render(
bucket_name=bucket_name, key_name=key_name, upload_id=multipart_id
)
return 200, {}, response
return 200, response_headers, response
if query.get("uploadId"):
body = self._complete_multipart_body(body)
@ -1961,6 +1979,8 @@ class S3Response(BaseResponse):
storage=multipart.storage,
etag=etag,
multipart=multipart,
encryption=multipart.sse_encryption,
kms_key_id=multipart.kms_key_id,
)
key.set_metadata(multipart.metadata)
self.backend.set_key_tags(key, multipart.tags)
@ -1970,6 +1990,13 @@ class S3Response(BaseResponse):
headers = {}
if key.version_id:
headers["x-amz-version-id"] = key.version_id
if key.encryption:
headers["x-amz-server-side-encryption"] = key.encryption
if key.kms_key_id:
headers["x-amz-server-side-encryption-aws-kms-key-id"] = key.kms_key_id
return (
200,
headers,

View File

@ -885,3 +885,43 @@ def test_complete_multipart_with_empty_partlist():
err["Message"].should.equal(
"The XML you provided was not well-formed or did not validate against our published schema"
)
@mock_s3
def test_ssm_key_headers_in_create_multipart():
s3_client = boto3.client("s3", region_name=DEFAULT_REGION_NAME)
bucket_name = "ssm-headers-bucket"
s3_client.create_bucket(Bucket=bucket_name)
kms_key_id = "random-id"
key_name = "test-file.txt"
create_multipart_response = s3_client.create_multipart_upload(
Bucket=bucket_name,
Key=key_name,
ServerSideEncryption="aws:kms",
SSEKMSKeyId=kms_key_id,
)
assert create_multipart_response["ServerSideEncryption"] == "aws:kms"
assert create_multipart_response["SSEKMSKeyId"] == kms_key_id
upload_part_response = s3_client.upload_part(
Body=b"bytes",
Bucket=bucket_name,
Key=key_name,
PartNumber=1,
UploadId=create_multipart_response["UploadId"],
)
assert upload_part_response["ServerSideEncryption"] == "aws:kms"
assert upload_part_response["SSEKMSKeyId"] == kms_key_id
parts = {"Parts": [{"PartNumber": 1, "ETag": upload_part_response["ETag"]}]}
complete_multipart_response = s3_client.complete_multipart_upload(
Bucket=bucket_name,
Key=key_name,
UploadId=create_multipart_response["UploadId"],
MultipartUpload=parts,
)
assert complete_multipart_response["ServerSideEncryption"] == "aws:kms"
assert complete_multipart_response["SSEKMSKeyId"] == kms_key_id