From 90850bc57314b2c7fa62b4917577b1a24a528f25 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Sun, 10 Dec 2023 15:26:26 -0100 Subject: [PATCH] S3: copy() using multiparts should respect ExtraArgs (#7110) --- IMPLEMENTATION_COVERAGE.md | 4 +-- docs/docs/services/s3.rst | 2 +- moto/s3/models.py | 25 +++++++++---- moto/s3/responses.py | 15 ++++++-- moto/s3/utils.py | 13 +++++-- tests/test_s3/test_s3_copyobject.py | 56 +++++++++++++++++++++++++---- 6 files changed, 95 insertions(+), 20 deletions(-) diff --git a/IMPLEMENTATION_COVERAGE.md b/IMPLEMENTATION_COVERAGE.md index af9ffc855..73285ec2e 100644 --- a/IMPLEMENTATION_COVERAGE.md +++ b/IMPLEMENTATION_COVERAGE.md @@ -6257,7 +6257,7 @@ ## s3
-64% implemented +65% implemented - [X] abort_multipart_upload - [X] complete_multipart_upload @@ -6356,7 +6356,7 @@ - [ ] restore_object - [X] select_object_content - [X] upload_part -- [ ] upload_part_copy +- [X] upload_part_copy - [ ] write_get_object_response
diff --git a/docs/docs/services/s3.rst b/docs/docs/services/s3.rst index 89796119d..67639282f 100644 --- a/docs/docs/services/s3.rst +++ b/docs/docs/services/s3.rst @@ -159,6 +159,6 @@ s3 - [X] upload_part -- [ ] upload_part_copy +- [X] upload_part_copy - [ ] write_get_object_response diff --git a/moto/s3/models.py b/moto/s3/models.py index cde316dfd..4665e96d0 100644 --- a/moto/s3/models.py +++ b/moto/s3/models.py @@ -73,6 +73,7 @@ from .utils import ( STORAGE_CLASS, CaseInsensitiveDict, _VersionedKeyStore, + compute_checksum, ) MAX_BUCKET_NAME_LENGTH = 63 @@ -399,10 +400,14 @@ class FakeMultipart(BaseModel): self.sse_encryption = sse_encryption self.kms_key_id = kms_key_id - def complete(self, body: Iterator[Tuple[int, str]]) -> Tuple[bytes, str]: + def complete( + self, body: Iterator[Tuple[int, str]] + ) -> Tuple[bytes, str, Optional[str]]: + checksum_algo = self.metadata.get("x-amz-checksum-algorithm") decode_hex = codecs.getdecoder("hex_codec") total = bytearray() md5s = bytearray() + checksum = bytearray() last = None count = 0 @@ -418,6 +423,10 @@ class FakeMultipart(BaseModel): raise EntityTooSmall() md5s.extend(decode_hex(part_etag)[0]) # type: ignore total.extend(part.value) + if checksum_algo: + checksum.extend( + compute_checksum(part.value, checksum_algo, encode_base64=False) + ) last = part count += 1 @@ -426,7 +435,11 @@ class FakeMultipart(BaseModel): full_etag = md5_hash() full_etag.update(bytes(md5s)) - return total, f"{full_etag.hexdigest()}-{count}" + if checksum_algo: + encoded_checksum = compute_checksum(checksum, checksum_algo).decode("utf-8") + else: + encoded_checksum = None + return total, f"{full_etag.hexdigest()}-{count}", encoded_checksum def set_part(self, part_id: int, value: bytes) -> FakeKey: if part_id < 1: @@ -2319,13 +2332,13 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): def complete_multipart_upload( self, bucket_name: str, multipart_id: str, body: Iterator[Tuple[int, str]] - ) -> Tuple[FakeMultipart, bytes, str]: + ) -> Tuple[FakeMultipart, bytes, str, Optional[str]]: bucket = self.get_bucket(bucket_name) multipart = bucket.multiparts[multipart_id] - value, etag = multipart.complete(body) + value, etag, checksum = multipart.complete(body) if value is not None: del bucket.multiparts[multipart_id] - return multipart, value, etag + return multipart, value, etag, checksum def get_all_multiparts(self, bucket_name: str) -> Dict[str, FakeMultipart]: bucket = self.get_bucket(bucket_name) @@ -2338,7 +2351,7 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): multipart = bucket.multiparts[multipart_id] return multipart.set_part(part_id, value) - def copy_part( + def upload_part_copy( self, dest_bucket_name: str, multipart_id: str, diff --git a/moto/s3/responses.py b/moto/s3/responses.py index d2575fe44..fc589caec 100644 --- a/moto/s3/responses.py +++ b/moto/s3/responses.py @@ -1507,7 +1507,7 @@ class S3Response(BaseResponse): if self.backend.get_object( src_bucket, src_key, version_id=src_version_id ): - key = self.backend.copy_part( + key = self.backend.upload_part_copy( bucket_name, upload_id, part_number, @@ -1707,6 +1707,8 @@ class S3Response(BaseResponse): response_headers.update( {"Checksum": {f"Checksum{checksum_algorithm}": checksum_value}} ) + # By default, the checksum-details for the copy will be the same as the original + # But if another algorithm is provided during the copy-operation, we override the values new_key.checksum_algorithm = checksum_algorithm new_key.checksum_value = checksum_value @@ -2252,12 +2254,13 @@ class S3Response(BaseResponse): if query.get("uploadId"): multipart_id = query["uploadId"][0] - multipart, value, etag = self.backend.complete_multipart_upload( + multipart, value, etag, checksum = self.backend.complete_multipart_upload( bucket_name, multipart_id, self._complete_multipart_body(body) ) if value is None: return 400, {}, "" + headers: Dict[str, Any] = {} key = self.backend.put_object( bucket_name, multipart.key_name, @@ -2269,6 +2272,13 @@ class S3Response(BaseResponse): kms_key_id=multipart.kms_key_id, ) key.set_metadata(multipart.metadata) + + if checksum: + key.checksum_algorithm = multipart.metadata.get( + "x-amz-checksum-algorithm" + ) + key.checksum_value = checksum + self.backend.set_key_tags(key, multipart.tags) self.backend.put_object_acl( bucket_name=bucket_name, @@ -2277,7 +2287,6 @@ class S3Response(BaseResponse): ) template = self.response_template(S3_MULTIPART_COMPLETE_RESPONSE) - headers: Dict[str, Any] = {} if key.version_id: headers["x-amz-version-id"] = key.version_id diff --git a/moto/s3/utils.py b/moto/s3/utils.py index cae502ba8..5946dbc8f 100644 --- a/moto/s3/utils.py +++ b/moto/s3/utils.py @@ -24,6 +24,12 @@ user_settable_fields = { "expires", "content-disposition", "x-robots-tag", + "x-amz-checksum-algorithm", + "x-amz-content-sha256", + "x-amz-content-crc32", + "x-amz-content-crc32c", + "x-amz-content-sha1", + "x-amz-website-redirect-location", } ARCHIVE_STORAGE_CLASSES = [ "GLACIER", @@ -190,7 +196,7 @@ class _VersionedKeyStore(dict): # type: ignore values = itervalues = _itervalues # type: ignore -def compute_checksum(body: bytes, algorithm: str) -> bytes: +def compute_checksum(body: bytes, algorithm: str, encode_base64: bool = True) -> bytes: if algorithm == "SHA1": hashed_body = _hash(hashlib.sha1, (body,)) elif algorithm == "CRC32C": @@ -205,7 +211,10 @@ def compute_checksum(body: bytes, algorithm: str) -> bytes: hashed_body = binascii.crc32(body).to_bytes(4, "big") else: hashed_body = _hash(hashlib.sha256, (body,)) - return base64.b64encode(hashed_body) + if encode_base64: + return base64.b64encode(hashed_body) + else: + return hashed_body def _hash(fn: Any, args: Any) -> bytes: diff --git a/tests/test_s3/test_s3_copyobject.py b/tests/test_s3/test_s3_copyobject.py index a0fca248f..90bfde46f 100644 --- a/tests/test_s3/test_s3_copyobject.py +++ b/tests/test_s3/test_s3_copyobject.py @@ -7,6 +7,8 @@ from botocore.client import ClientError from moto import mock_kms, mock_s3 from moto.s3.responses import DEFAULT_REGION_NAME +from . import s3_aws_verified + @pytest.mark.parametrize( "key_name", @@ -35,25 +37,27 @@ def test_copy_key_boto3(key_name): assert resp["Body"].read() == b"some value" -@mock_s3 -def test_copy_key_boto3_with_sha256_checksum(): +@pytest.mark.aws_verified +@s3_aws_verified +def test_copy_key_boto3_with_args(bucket=None): # Setup s3_resource = boto3.resource("s3", region_name=DEFAULT_REGION_NAME) client = boto3.client("s3", region_name=DEFAULT_REGION_NAME) key_name = "key" new_key = "new_key" - bucket = "foobar" expected_hash = "qz0H8xacy9DtbEtF3iFRn5+TjHLSQSSZiquUnOg7tRs=" - s3_resource.create_bucket(Bucket=bucket) - key = s3_resource.Object("foobar", key_name) + key = s3_resource.Object(bucket, key_name) key.put(Body=b"some value") # Execute key2 = s3_resource.Object(bucket, new_key) key2.copy( CopySource={"Bucket": bucket, "Key": key_name}, - ExtraArgs={"ChecksumAlgorithm": "SHA256"}, + ExtraArgs={ + "ChecksumAlgorithm": "SHA256", + "WebsiteRedirectLocation": "http://getmoto.org/", + }, ) # Verify @@ -65,6 +69,9 @@ def test_copy_key_boto3_with_sha256_checksum(): assert "ChecksumSHA256" in resp["Checksum"] assert resp["Checksum"]["ChecksumSHA256"] == expected_hash + obj = client.get_object(Bucket=bucket, Key=new_key) + assert obj["WebsiteRedirectLocation"] == "http://getmoto.org/" + # Verify in place copy_in_place = client.copy_object( Bucket=bucket, @@ -78,6 +85,43 @@ def test_copy_key_boto3_with_sha256_checksum(): assert copy_in_place["CopyObjectResult"]["ChecksumSHA256"] == expected_hash +@pytest.mark.aws_verified +@s3_aws_verified +def test_copy_key_boto3_with_args__using_multipart(bucket=None): + # Setup + s3_resource = boto3.resource("s3", region_name=DEFAULT_REGION_NAME) + client = boto3.client("s3", region_name=DEFAULT_REGION_NAME) + key_name = "key" + new_key = "new_key" + expected_hash = "DnKotDi4EtYGwNMDKmnR6SqH3bWVOlo2BC+tsz9rHqw=" + + key = s3_resource.Object(bucket, key_name) + key.put(Body=b"some value") + + # Execute + key2 = s3_resource.Object(bucket, new_key) + key2.copy( + CopySource={"Bucket": bucket, "Key": key_name}, + ExtraArgs={ + "ChecksumAlgorithm": "SHA256", + "WebsiteRedirectLocation": "http://getmoto.org/", + }, + Config=boto3.s3.transfer.TransferConfig(multipart_threshold=1), + ) + + # Verify + resp = client.get_object_attributes( + Bucket=bucket, Key=new_key, ObjectAttributes=["Checksum"] + ) + + assert "Checksum" in resp + assert "ChecksumSHA256" in resp["Checksum"] + assert resp["Checksum"]["ChecksumSHA256"] == expected_hash + + obj = client.get_object(Bucket=bucket, Key=new_key) + assert obj["WebsiteRedirectLocation"] == "http://getmoto.org/" + + @mock_s3 def test_copy_key_with_version_boto3(): s3_resource = boto3.resource("s3", region_name=DEFAULT_REGION_NAME)