S3: copy() using multiparts should respect ExtraArgs (#7110)
This commit is contained in:
		
							parent
							
								
									bd93d87134
								
							
						
					
					
						commit
						90850bc573
					
				| @ -6257,7 +6257,7 @@ | |||||||
| 
 | 
 | ||||||
| ## s3 | ## s3 | ||||||
| <details> | <details> | ||||||
| <summary>64% implemented</summary> | <summary>65% implemented</summary> | ||||||
| 
 | 
 | ||||||
| - [X] abort_multipart_upload | - [X] abort_multipart_upload | ||||||
| - [X] complete_multipart_upload | - [X] complete_multipart_upload | ||||||
| @ -6356,7 +6356,7 @@ | |||||||
| - [ ] restore_object | - [ ] restore_object | ||||||
| - [X] select_object_content | - [X] select_object_content | ||||||
| - [X] upload_part | - [X] upload_part | ||||||
| - [ ] upload_part_copy | - [X] upload_part_copy | ||||||
| - [ ] write_get_object_response | - [ ] write_get_object_response | ||||||
| </details> | </details> | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -159,6 +159,6 @@ s3 | |||||||
|          |          | ||||||
| 
 | 
 | ||||||
| - [X] upload_part | - [X] upload_part | ||||||
| - [ ] upload_part_copy | - [X] upload_part_copy | ||||||
| - [ ] write_get_object_response | - [ ] write_get_object_response | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -73,6 +73,7 @@ from .utils import ( | |||||||
|     STORAGE_CLASS, |     STORAGE_CLASS, | ||||||
|     CaseInsensitiveDict, |     CaseInsensitiveDict, | ||||||
|     _VersionedKeyStore, |     _VersionedKeyStore, | ||||||
|  |     compute_checksum, | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| MAX_BUCKET_NAME_LENGTH = 63 | MAX_BUCKET_NAME_LENGTH = 63 | ||||||
| @ -399,10 +400,14 @@ class FakeMultipart(BaseModel): | |||||||
|         self.sse_encryption = sse_encryption |         self.sse_encryption = sse_encryption | ||||||
|         self.kms_key_id = kms_key_id |         self.kms_key_id = kms_key_id | ||||||
| 
 | 
 | ||||||
|     def complete(self, body: Iterator[Tuple[int, str]]) -> Tuple[bytes, str]: |     def complete( | ||||||
|  |         self, body: Iterator[Tuple[int, str]] | ||||||
|  |     ) -> Tuple[bytes, str, Optional[str]]: | ||||||
|  |         checksum_algo = self.metadata.get("x-amz-checksum-algorithm") | ||||||
|         decode_hex = codecs.getdecoder("hex_codec") |         decode_hex = codecs.getdecoder("hex_codec") | ||||||
|         total = bytearray() |         total = bytearray() | ||||||
|         md5s = bytearray() |         md5s = bytearray() | ||||||
|  |         checksum = bytearray() | ||||||
| 
 | 
 | ||||||
|         last = None |         last = None | ||||||
|         count = 0 |         count = 0 | ||||||
| @ -418,6 +423,10 @@ class FakeMultipart(BaseModel): | |||||||
|                 raise EntityTooSmall() |                 raise EntityTooSmall() | ||||||
|             md5s.extend(decode_hex(part_etag)[0])  # type: ignore |             md5s.extend(decode_hex(part_etag)[0])  # type: ignore | ||||||
|             total.extend(part.value) |             total.extend(part.value) | ||||||
|  |             if checksum_algo: | ||||||
|  |                 checksum.extend( | ||||||
|  |                     compute_checksum(part.value, checksum_algo, encode_base64=False) | ||||||
|  |                 ) | ||||||
|             last = part |             last = part | ||||||
|             count += 1 |             count += 1 | ||||||
| 
 | 
 | ||||||
| @ -426,7 +435,11 @@ class FakeMultipart(BaseModel): | |||||||
| 
 | 
 | ||||||
|         full_etag = md5_hash() |         full_etag = md5_hash() | ||||||
|         full_etag.update(bytes(md5s)) |         full_etag.update(bytes(md5s)) | ||||||
|         return total, f"{full_etag.hexdigest()}-{count}" |         if checksum_algo: | ||||||
|  |             encoded_checksum = compute_checksum(checksum, checksum_algo).decode("utf-8") | ||||||
|  |         else: | ||||||
|  |             encoded_checksum = None | ||||||
|  |         return total, f"{full_etag.hexdigest()}-{count}", encoded_checksum | ||||||
| 
 | 
 | ||||||
|     def set_part(self, part_id: int, value: bytes) -> FakeKey: |     def set_part(self, part_id: int, value: bytes) -> FakeKey: | ||||||
|         if part_id < 1: |         if part_id < 1: | ||||||
| @ -2319,13 +2332,13 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): | |||||||
| 
 | 
 | ||||||
|     def complete_multipart_upload( |     def complete_multipart_upload( | ||||||
|         self, bucket_name: str, multipart_id: str, body: Iterator[Tuple[int, str]] |         self, bucket_name: str, multipart_id: str, body: Iterator[Tuple[int, str]] | ||||||
|     ) -> Tuple[FakeMultipart, bytes, str]: |     ) -> Tuple[FakeMultipart, bytes, str, Optional[str]]: | ||||||
|         bucket = self.get_bucket(bucket_name) |         bucket = self.get_bucket(bucket_name) | ||||||
|         multipart = bucket.multiparts[multipart_id] |         multipart = bucket.multiparts[multipart_id] | ||||||
|         value, etag = multipart.complete(body) |         value, etag, checksum = multipart.complete(body) | ||||||
|         if value is not None: |         if value is not None: | ||||||
|             del bucket.multiparts[multipart_id] |             del bucket.multiparts[multipart_id] | ||||||
|         return multipart, value, etag |         return multipart, value, etag, checksum | ||||||
| 
 | 
 | ||||||
|     def get_all_multiparts(self, bucket_name: str) -> Dict[str, FakeMultipart]: |     def get_all_multiparts(self, bucket_name: str) -> Dict[str, FakeMultipart]: | ||||||
|         bucket = self.get_bucket(bucket_name) |         bucket = self.get_bucket(bucket_name) | ||||||
| @ -2338,7 +2351,7 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): | |||||||
|         multipart = bucket.multiparts[multipart_id] |         multipart = bucket.multiparts[multipart_id] | ||||||
|         return multipart.set_part(part_id, value) |         return multipart.set_part(part_id, value) | ||||||
| 
 | 
 | ||||||
|     def copy_part( |     def upload_part_copy( | ||||||
|         self, |         self, | ||||||
|         dest_bucket_name: str, |         dest_bucket_name: str, | ||||||
|         multipart_id: str, |         multipart_id: str, | ||||||
|  | |||||||
| @ -1507,7 +1507,7 @@ class S3Response(BaseResponse): | |||||||
|                 if self.backend.get_object( |                 if self.backend.get_object( | ||||||
|                     src_bucket, src_key, version_id=src_version_id |                     src_bucket, src_key, version_id=src_version_id | ||||||
|                 ): |                 ): | ||||||
|                     key = self.backend.copy_part( |                     key = self.backend.upload_part_copy( | ||||||
|                         bucket_name, |                         bucket_name, | ||||||
|                         upload_id, |                         upload_id, | ||||||
|                         part_number, |                         part_number, | ||||||
| @ -1707,6 +1707,8 @@ class S3Response(BaseResponse): | |||||||
|                 response_headers.update( |                 response_headers.update( | ||||||
|                     {"Checksum": {f"Checksum{checksum_algorithm}": checksum_value}} |                     {"Checksum": {f"Checksum{checksum_algorithm}": checksum_value}} | ||||||
|                 ) |                 ) | ||||||
|  |                 # By default, the checksum-details for the copy will be the same as the original | ||||||
|  |                 # But if another algorithm is provided during the copy-operation, we override the values | ||||||
|                 new_key.checksum_algorithm = checksum_algorithm |                 new_key.checksum_algorithm = checksum_algorithm | ||||||
|                 new_key.checksum_value = checksum_value |                 new_key.checksum_value = checksum_value | ||||||
| 
 | 
 | ||||||
| @ -2252,12 +2254,13 @@ class S3Response(BaseResponse): | |||||||
|         if query.get("uploadId"): |         if query.get("uploadId"): | ||||||
|             multipart_id = query["uploadId"][0] |             multipart_id = query["uploadId"][0] | ||||||
| 
 | 
 | ||||||
|             multipart, value, etag = self.backend.complete_multipart_upload( |             multipart, value, etag, checksum = self.backend.complete_multipart_upload( | ||||||
|                 bucket_name, multipart_id, self._complete_multipart_body(body) |                 bucket_name, multipart_id, self._complete_multipart_body(body) | ||||||
|             ) |             ) | ||||||
|             if value is None: |             if value is None: | ||||||
|                 return 400, {}, "" |                 return 400, {}, "" | ||||||
| 
 | 
 | ||||||
|  |             headers: Dict[str, Any] = {} | ||||||
|             key = self.backend.put_object( |             key = self.backend.put_object( | ||||||
|                 bucket_name, |                 bucket_name, | ||||||
|                 multipart.key_name, |                 multipart.key_name, | ||||||
| @ -2269,6 +2272,13 @@ class S3Response(BaseResponse): | |||||||
|                 kms_key_id=multipart.kms_key_id, |                 kms_key_id=multipart.kms_key_id, | ||||||
|             ) |             ) | ||||||
|             key.set_metadata(multipart.metadata) |             key.set_metadata(multipart.metadata) | ||||||
|  | 
 | ||||||
|  |             if checksum: | ||||||
|  |                 key.checksum_algorithm = multipart.metadata.get( | ||||||
|  |                     "x-amz-checksum-algorithm" | ||||||
|  |                 ) | ||||||
|  |                 key.checksum_value = checksum | ||||||
|  | 
 | ||||||
|             self.backend.set_key_tags(key, multipart.tags) |             self.backend.set_key_tags(key, multipart.tags) | ||||||
|             self.backend.put_object_acl( |             self.backend.put_object_acl( | ||||||
|                 bucket_name=bucket_name, |                 bucket_name=bucket_name, | ||||||
| @ -2277,7 +2287,6 @@ class S3Response(BaseResponse): | |||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|             template = self.response_template(S3_MULTIPART_COMPLETE_RESPONSE) |             template = self.response_template(S3_MULTIPART_COMPLETE_RESPONSE) | ||||||
|             headers: Dict[str, Any] = {} |  | ||||||
|             if key.version_id: |             if key.version_id: | ||||||
|                 headers["x-amz-version-id"] = key.version_id |                 headers["x-amz-version-id"] = key.version_id | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -24,6 +24,12 @@ user_settable_fields = { | |||||||
|     "expires", |     "expires", | ||||||
|     "content-disposition", |     "content-disposition", | ||||||
|     "x-robots-tag", |     "x-robots-tag", | ||||||
|  |     "x-amz-checksum-algorithm", | ||||||
|  |     "x-amz-content-sha256", | ||||||
|  |     "x-amz-content-crc32", | ||||||
|  |     "x-amz-content-crc32c", | ||||||
|  |     "x-amz-content-sha1", | ||||||
|  |     "x-amz-website-redirect-location", | ||||||
| } | } | ||||||
| ARCHIVE_STORAGE_CLASSES = [ | ARCHIVE_STORAGE_CLASSES = [ | ||||||
|     "GLACIER", |     "GLACIER", | ||||||
| @ -190,7 +196,7 @@ class _VersionedKeyStore(dict):  # type: ignore | |||||||
|     values = itervalues = _itervalues  # type: ignore |     values = itervalues = _itervalues  # type: ignore | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def compute_checksum(body: bytes, algorithm: str) -> bytes: | def compute_checksum(body: bytes, algorithm: str, encode_base64: bool = True) -> bytes: | ||||||
|     if algorithm == "SHA1": |     if algorithm == "SHA1": | ||||||
|         hashed_body = _hash(hashlib.sha1, (body,)) |         hashed_body = _hash(hashlib.sha1, (body,)) | ||||||
|     elif algorithm == "CRC32C": |     elif algorithm == "CRC32C": | ||||||
| @ -205,7 +211,10 @@ def compute_checksum(body: bytes, algorithm: str) -> bytes: | |||||||
|         hashed_body = binascii.crc32(body).to_bytes(4, "big") |         hashed_body = binascii.crc32(body).to_bytes(4, "big") | ||||||
|     else: |     else: | ||||||
|         hashed_body = _hash(hashlib.sha256, (body,)) |         hashed_body = _hash(hashlib.sha256, (body,)) | ||||||
|     return base64.b64encode(hashed_body) |     if encode_base64: | ||||||
|  |         return base64.b64encode(hashed_body) | ||||||
|  |     else: | ||||||
|  |         return hashed_body | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def _hash(fn: Any, args: Any) -> bytes: | def _hash(fn: Any, args: Any) -> bytes: | ||||||
|  | |||||||
| @ -7,6 +7,8 @@ from botocore.client import ClientError | |||||||
| from moto import mock_kms, mock_s3 | from moto import mock_kms, mock_s3 | ||||||
| from moto.s3.responses import DEFAULT_REGION_NAME | from moto.s3.responses import DEFAULT_REGION_NAME | ||||||
| 
 | 
 | ||||||
|  | from . import s3_aws_verified | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||||
|     "key_name", |     "key_name", | ||||||
| @ -35,25 +37,27 @@ def test_copy_key_boto3(key_name): | |||||||
|     assert resp["Body"].read() == b"some value" |     assert resp["Body"].read() == b"some value" | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @mock_s3 | @pytest.mark.aws_verified | ||||||
| def test_copy_key_boto3_with_sha256_checksum(): | @s3_aws_verified | ||||||
|  | def test_copy_key_boto3_with_args(bucket=None): | ||||||
|     # Setup |     # Setup | ||||||
|     s3_resource = boto3.resource("s3", region_name=DEFAULT_REGION_NAME) |     s3_resource = boto3.resource("s3", region_name=DEFAULT_REGION_NAME) | ||||||
|     client = boto3.client("s3", region_name=DEFAULT_REGION_NAME) |     client = boto3.client("s3", region_name=DEFAULT_REGION_NAME) | ||||||
|     key_name = "key" |     key_name = "key" | ||||||
|     new_key = "new_key" |     new_key = "new_key" | ||||||
|     bucket = "foobar" |  | ||||||
|     expected_hash = "qz0H8xacy9DtbEtF3iFRn5+TjHLSQSSZiquUnOg7tRs=" |     expected_hash = "qz0H8xacy9DtbEtF3iFRn5+TjHLSQSSZiquUnOg7tRs=" | ||||||
| 
 | 
 | ||||||
|     s3_resource.create_bucket(Bucket=bucket) |     key = s3_resource.Object(bucket, key_name) | ||||||
|     key = s3_resource.Object("foobar", key_name) |  | ||||||
|     key.put(Body=b"some value") |     key.put(Body=b"some value") | ||||||
| 
 | 
 | ||||||
|     # Execute |     # Execute | ||||||
|     key2 = s3_resource.Object(bucket, new_key) |     key2 = s3_resource.Object(bucket, new_key) | ||||||
|     key2.copy( |     key2.copy( | ||||||
|         CopySource={"Bucket": bucket, "Key": key_name}, |         CopySource={"Bucket": bucket, "Key": key_name}, | ||||||
|         ExtraArgs={"ChecksumAlgorithm": "SHA256"}, |         ExtraArgs={ | ||||||
|  |             "ChecksumAlgorithm": "SHA256", | ||||||
|  |             "WebsiteRedirectLocation": "http://getmoto.org/", | ||||||
|  |         }, | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
|     # Verify |     # Verify | ||||||
| @ -65,6 +69,9 @@ def test_copy_key_boto3_with_sha256_checksum(): | |||||||
|     assert "ChecksumSHA256" in resp["Checksum"] |     assert "ChecksumSHA256" in resp["Checksum"] | ||||||
|     assert resp["Checksum"]["ChecksumSHA256"] == expected_hash |     assert resp["Checksum"]["ChecksumSHA256"] == expected_hash | ||||||
| 
 | 
 | ||||||
|  |     obj = client.get_object(Bucket=bucket, Key=new_key) | ||||||
|  |     assert obj["WebsiteRedirectLocation"] == "http://getmoto.org/" | ||||||
|  | 
 | ||||||
|     # Verify in place |     # Verify in place | ||||||
|     copy_in_place = client.copy_object( |     copy_in_place = client.copy_object( | ||||||
|         Bucket=bucket, |         Bucket=bucket, | ||||||
| @ -78,6 +85,43 @@ def test_copy_key_boto3_with_sha256_checksum(): | |||||||
|     assert copy_in_place["CopyObjectResult"]["ChecksumSHA256"] == expected_hash |     assert copy_in_place["CopyObjectResult"]["ChecksumSHA256"] == expected_hash | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @pytest.mark.aws_verified | ||||||
|  | @s3_aws_verified | ||||||
|  | def test_copy_key_boto3_with_args__using_multipart(bucket=None): | ||||||
|  |     # Setup | ||||||
|  |     s3_resource = boto3.resource("s3", region_name=DEFAULT_REGION_NAME) | ||||||
|  |     client = boto3.client("s3", region_name=DEFAULT_REGION_NAME) | ||||||
|  |     key_name = "key" | ||||||
|  |     new_key = "new_key" | ||||||
|  |     expected_hash = "DnKotDi4EtYGwNMDKmnR6SqH3bWVOlo2BC+tsz9rHqw=" | ||||||
|  | 
 | ||||||
|  |     key = s3_resource.Object(bucket, key_name) | ||||||
|  |     key.put(Body=b"some value") | ||||||
|  | 
 | ||||||
|  |     # Execute | ||||||
|  |     key2 = s3_resource.Object(bucket, new_key) | ||||||
|  |     key2.copy( | ||||||
|  |         CopySource={"Bucket": bucket, "Key": key_name}, | ||||||
|  |         ExtraArgs={ | ||||||
|  |             "ChecksumAlgorithm": "SHA256", | ||||||
|  |             "WebsiteRedirectLocation": "http://getmoto.org/", | ||||||
|  |         }, | ||||||
|  |         Config=boto3.s3.transfer.TransferConfig(multipart_threshold=1), | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     # Verify | ||||||
|  |     resp = client.get_object_attributes( | ||||||
|  |         Bucket=bucket, Key=new_key, ObjectAttributes=["Checksum"] | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     assert "Checksum" in resp | ||||||
|  |     assert "ChecksumSHA256" in resp["Checksum"] | ||||||
|  |     assert resp["Checksum"]["ChecksumSHA256"] == expected_hash | ||||||
|  | 
 | ||||||
|  |     obj = client.get_object(Bucket=bucket, Key=new_key) | ||||||
|  |     assert obj["WebsiteRedirectLocation"] == "http://getmoto.org/" | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| @mock_s3 | @mock_s3 | ||||||
| def test_copy_key_with_version_boto3(): | def test_copy_key_with_version_boto3(): | ||||||
|     s3_resource = boto3.resource("s3", region_name=DEFAULT_REGION_NAME) |     s3_resource = boto3.resource("s3", region_name=DEFAULT_REGION_NAME) | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user