S3: copy() using multiparts should respect ExtraArgs (#7110)

This commit is contained in:
Bert Blommers 2023-12-10 15:26:26 -01:00 committed by GitHub
parent bd93d87134
commit 90850bc573
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 95 additions and 20 deletions

View File

@ -6257,7 +6257,7 @@
## s3 ## s3
<details> <details>
<summary>64% implemented</summary> <summary>65% implemented</summary>
- [X] abort_multipart_upload - [X] abort_multipart_upload
- [X] complete_multipart_upload - [X] complete_multipart_upload
@ -6356,7 +6356,7 @@
- [ ] restore_object - [ ] restore_object
- [X] select_object_content - [X] select_object_content
- [X] upload_part - [X] upload_part
- [ ] upload_part_copy - [X] upload_part_copy
- [ ] write_get_object_response - [ ] write_get_object_response
</details> </details>

View File

@ -159,6 +159,6 @@ s3
- [X] upload_part - [X] upload_part
- [ ] upload_part_copy - [X] upload_part_copy
- [ ] write_get_object_response - [ ] write_get_object_response

View File

@ -73,6 +73,7 @@ from .utils import (
STORAGE_CLASS, STORAGE_CLASS,
CaseInsensitiveDict, CaseInsensitiveDict,
_VersionedKeyStore, _VersionedKeyStore,
compute_checksum,
) )
MAX_BUCKET_NAME_LENGTH = 63 MAX_BUCKET_NAME_LENGTH = 63
@ -399,10 +400,14 @@ class FakeMultipart(BaseModel):
self.sse_encryption = sse_encryption self.sse_encryption = sse_encryption
self.kms_key_id = kms_key_id self.kms_key_id = kms_key_id
def complete(self, body: Iterator[Tuple[int, str]]) -> Tuple[bytes, str]: def complete(
self, body: Iterator[Tuple[int, str]]
) -> Tuple[bytes, str, Optional[str]]:
checksum_algo = self.metadata.get("x-amz-checksum-algorithm")
decode_hex = codecs.getdecoder("hex_codec") decode_hex = codecs.getdecoder("hex_codec")
total = bytearray() total = bytearray()
md5s = bytearray() md5s = bytearray()
checksum = bytearray()
last = None last = None
count = 0 count = 0
@ -418,6 +423,10 @@ class FakeMultipart(BaseModel):
raise EntityTooSmall() raise EntityTooSmall()
md5s.extend(decode_hex(part_etag)[0]) # type: ignore md5s.extend(decode_hex(part_etag)[0]) # type: ignore
total.extend(part.value) total.extend(part.value)
if checksum_algo:
checksum.extend(
compute_checksum(part.value, checksum_algo, encode_base64=False)
)
last = part last = part
count += 1 count += 1
@ -426,7 +435,11 @@ class FakeMultipart(BaseModel):
full_etag = md5_hash() full_etag = md5_hash()
full_etag.update(bytes(md5s)) full_etag.update(bytes(md5s))
return total, f"{full_etag.hexdigest()}-{count}" if checksum_algo:
encoded_checksum = compute_checksum(checksum, checksum_algo).decode("utf-8")
else:
encoded_checksum = None
return total, f"{full_etag.hexdigest()}-{count}", encoded_checksum
def set_part(self, part_id: int, value: bytes) -> FakeKey: def set_part(self, part_id: int, value: bytes) -> FakeKey:
if part_id < 1: if part_id < 1:
@ -2319,13 +2332,13 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
def complete_multipart_upload( def complete_multipart_upload(
self, bucket_name: str, multipart_id: str, body: Iterator[Tuple[int, str]] self, bucket_name: str, multipart_id: str, body: Iterator[Tuple[int, str]]
) -> Tuple[FakeMultipart, bytes, str]: ) -> Tuple[FakeMultipart, bytes, str, Optional[str]]:
bucket = self.get_bucket(bucket_name) bucket = self.get_bucket(bucket_name)
multipart = bucket.multiparts[multipart_id] multipart = bucket.multiparts[multipart_id]
value, etag = multipart.complete(body) value, etag, checksum = multipart.complete(body)
if value is not None: if value is not None:
del bucket.multiparts[multipart_id] del bucket.multiparts[multipart_id]
return multipart, value, etag return multipart, value, etag, checksum
def get_all_multiparts(self, bucket_name: str) -> Dict[str, FakeMultipart]: def get_all_multiparts(self, bucket_name: str) -> Dict[str, FakeMultipart]:
bucket = self.get_bucket(bucket_name) bucket = self.get_bucket(bucket_name)
@ -2338,7 +2351,7 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
multipart = bucket.multiparts[multipart_id] multipart = bucket.multiparts[multipart_id]
return multipart.set_part(part_id, value) return multipart.set_part(part_id, value)
def copy_part( def upload_part_copy(
self, self,
dest_bucket_name: str, dest_bucket_name: str,
multipart_id: str, multipart_id: str,

View File

@ -1507,7 +1507,7 @@ class S3Response(BaseResponse):
if self.backend.get_object( if self.backend.get_object(
src_bucket, src_key, version_id=src_version_id src_bucket, src_key, version_id=src_version_id
): ):
key = self.backend.copy_part( key = self.backend.upload_part_copy(
bucket_name, bucket_name,
upload_id, upload_id,
part_number, part_number,
@ -1707,6 +1707,8 @@ class S3Response(BaseResponse):
response_headers.update( response_headers.update(
{"Checksum": {f"Checksum{checksum_algorithm}": checksum_value}} {"Checksum": {f"Checksum{checksum_algorithm}": checksum_value}}
) )
# By default, the checksum details of the copy are the same as those of the original.
# If a different algorithm is provided during the copy operation, we override these values.
new_key.checksum_algorithm = checksum_algorithm new_key.checksum_algorithm = checksum_algorithm
new_key.checksum_value = checksum_value new_key.checksum_value = checksum_value
@ -2252,12 +2254,13 @@ class S3Response(BaseResponse):
if query.get("uploadId"): if query.get("uploadId"):
multipart_id = query["uploadId"][0] multipart_id = query["uploadId"][0]
multipart, value, etag = self.backend.complete_multipart_upload( multipart, value, etag, checksum = self.backend.complete_multipart_upload(
bucket_name, multipart_id, self._complete_multipart_body(body) bucket_name, multipart_id, self._complete_multipart_body(body)
) )
if value is None: if value is None:
return 400, {}, "" return 400, {}, ""
headers: Dict[str, Any] = {}
key = self.backend.put_object( key = self.backend.put_object(
bucket_name, bucket_name,
multipart.key_name, multipart.key_name,
@ -2269,6 +2272,13 @@ class S3Response(BaseResponse):
kms_key_id=multipart.kms_key_id, kms_key_id=multipart.kms_key_id,
) )
key.set_metadata(multipart.metadata) key.set_metadata(multipart.metadata)
if checksum:
key.checksum_algorithm = multipart.metadata.get(
"x-amz-checksum-algorithm"
)
key.checksum_value = checksum
self.backend.set_key_tags(key, multipart.tags) self.backend.set_key_tags(key, multipart.tags)
self.backend.put_object_acl( self.backend.put_object_acl(
bucket_name=bucket_name, bucket_name=bucket_name,
@ -2277,7 +2287,6 @@ class S3Response(BaseResponse):
) )
template = self.response_template(S3_MULTIPART_COMPLETE_RESPONSE) template = self.response_template(S3_MULTIPART_COMPLETE_RESPONSE)
headers: Dict[str, Any] = {}
if key.version_id: if key.version_id:
headers["x-amz-version-id"] = key.version_id headers["x-amz-version-id"] = key.version_id

View File

@ -24,6 +24,12 @@ user_settable_fields = {
"expires", "expires",
"content-disposition", "content-disposition",
"x-robots-tag", "x-robots-tag",
"x-amz-checksum-algorithm",
"x-amz-content-sha256",
"x-amz-content-crc32",
"x-amz-content-crc32c",
"x-amz-content-sha1",
"x-amz-website-redirect-location",
} }
ARCHIVE_STORAGE_CLASSES = [ ARCHIVE_STORAGE_CLASSES = [
"GLACIER", "GLACIER",
@ -190,7 +196,7 @@ class _VersionedKeyStore(dict): # type: ignore
values = itervalues = _itervalues # type: ignore values = itervalues = _itervalues # type: ignore
def compute_checksum(body: bytes, algorithm: str) -> bytes: def compute_checksum(body: bytes, algorithm: str, encode_base64: bool = True) -> bytes:
if algorithm == "SHA1": if algorithm == "SHA1":
hashed_body = _hash(hashlib.sha1, (body,)) hashed_body = _hash(hashlib.sha1, (body,))
elif algorithm == "CRC32C": elif algorithm == "CRC32C":
@ -205,7 +211,10 @@ def compute_checksum(body: bytes, algorithm: str) -> bytes:
hashed_body = binascii.crc32(body).to_bytes(4, "big") hashed_body = binascii.crc32(body).to_bytes(4, "big")
else: else:
hashed_body = _hash(hashlib.sha256, (body,)) hashed_body = _hash(hashlib.sha256, (body,))
return base64.b64encode(hashed_body) if encode_base64:
return base64.b64encode(hashed_body)
else:
return hashed_body
def _hash(fn: Any, args: Any) -> bytes: def _hash(fn: Any, args: Any) -> bytes:

View File

@ -7,6 +7,8 @@ from botocore.client import ClientError
from moto import mock_kms, mock_s3 from moto import mock_kms, mock_s3
from moto.s3.responses import DEFAULT_REGION_NAME from moto.s3.responses import DEFAULT_REGION_NAME
from . import s3_aws_verified
@pytest.mark.parametrize( @pytest.mark.parametrize(
"key_name", "key_name",
@ -35,25 +37,27 @@ def test_copy_key_boto3(key_name):
assert resp["Body"].read() == b"some value" assert resp["Body"].read() == b"some value"
@mock_s3 @pytest.mark.aws_verified
def test_copy_key_boto3_with_sha256_checksum(): @s3_aws_verified
def test_copy_key_boto3_with_args(bucket=None):
# Setup # Setup
s3_resource = boto3.resource("s3", region_name=DEFAULT_REGION_NAME) s3_resource = boto3.resource("s3", region_name=DEFAULT_REGION_NAME)
client = boto3.client("s3", region_name=DEFAULT_REGION_NAME) client = boto3.client("s3", region_name=DEFAULT_REGION_NAME)
key_name = "key" key_name = "key"
new_key = "new_key" new_key = "new_key"
bucket = "foobar"
expected_hash = "qz0H8xacy9DtbEtF3iFRn5+TjHLSQSSZiquUnOg7tRs=" expected_hash = "qz0H8xacy9DtbEtF3iFRn5+TjHLSQSSZiquUnOg7tRs="
s3_resource.create_bucket(Bucket=bucket) key = s3_resource.Object(bucket, key_name)
key = s3_resource.Object("foobar", key_name)
key.put(Body=b"some value") key.put(Body=b"some value")
# Execute # Execute
key2 = s3_resource.Object(bucket, new_key) key2 = s3_resource.Object(bucket, new_key)
key2.copy( key2.copy(
CopySource={"Bucket": bucket, "Key": key_name}, CopySource={"Bucket": bucket, "Key": key_name},
ExtraArgs={"ChecksumAlgorithm": "SHA256"}, ExtraArgs={
"ChecksumAlgorithm": "SHA256",
"WebsiteRedirectLocation": "http://getmoto.org/",
},
) )
# Verify # Verify
@ -65,6 +69,9 @@ def test_copy_key_boto3_with_sha256_checksum():
assert "ChecksumSHA256" in resp["Checksum"] assert "ChecksumSHA256" in resp["Checksum"]
assert resp["Checksum"]["ChecksumSHA256"] == expected_hash assert resp["Checksum"]["ChecksumSHA256"] == expected_hash
obj = client.get_object(Bucket=bucket, Key=new_key)
assert obj["WebsiteRedirectLocation"] == "http://getmoto.org/"
# Verify in place # Verify in place
copy_in_place = client.copy_object( copy_in_place = client.copy_object(
Bucket=bucket, Bucket=bucket,
@ -78,6 +85,43 @@ def test_copy_key_boto3_with_sha256_checksum():
assert copy_in_place["CopyObjectResult"]["ChecksumSHA256"] == expected_hash assert copy_in_place["CopyObjectResult"]["ChecksumSHA256"] == expected_hash
@pytest.mark.aws_verified
@s3_aws_verified
def test_copy_key_boto3_with_args__using_multipart(bucket=None):
# Setup
s3_resource = boto3.resource("s3", region_name=DEFAULT_REGION_NAME)
client = boto3.client("s3", region_name=DEFAULT_REGION_NAME)
key_name = "key"
new_key = "new_key"
expected_hash = "DnKotDi4EtYGwNMDKmnR6SqH3bWVOlo2BC+tsz9rHqw="
key = s3_resource.Object(bucket, key_name)
key.put(Body=b"some value")
# Execute
key2 = s3_resource.Object(bucket, new_key)
key2.copy(
CopySource={"Bucket": bucket, "Key": key_name},
ExtraArgs={
"ChecksumAlgorithm": "SHA256",
"WebsiteRedirectLocation": "http://getmoto.org/",
},
Config=boto3.s3.transfer.TransferConfig(multipart_threshold=1),
)
# Verify
resp = client.get_object_attributes(
Bucket=bucket, Key=new_key, ObjectAttributes=["Checksum"]
)
assert "Checksum" in resp
assert "ChecksumSHA256" in resp["Checksum"]
assert resp["Checksum"]["ChecksumSHA256"] == expected_hash
obj = client.get_object(Bucket=bucket, Key=new_key)
assert obj["WebsiteRedirectLocation"] == "http://getmoto.org/"
@mock_s3 @mock_s3
def test_copy_key_with_version_boto3(): def test_copy_key_with_version_boto3():
s3_resource = boto3.resource("s3", region_name=DEFAULT_REGION_NAME) s3_resource = boto3.resource("s3", region_name=DEFAULT_REGION_NAME)