S3: copy() using multiparts should respect ExtraArgs (#7110)
This commit is contained in:
parent
bd93d87134
commit
90850bc573
@ -6257,7 +6257,7 @@
|
||||
|
||||
## s3
|
||||
<details>
|
||||
<summary>64% implemented</summary>
|
||||
<summary>65% implemented</summary>
|
||||
|
||||
- [X] abort_multipart_upload
|
||||
- [X] complete_multipart_upload
|
||||
@ -6356,7 +6356,7 @@
|
||||
- [ ] restore_object
|
||||
- [X] select_object_content
|
||||
- [X] upload_part
|
||||
- [ ] upload_part_copy
|
||||
- [X] upload_part_copy
|
||||
- [ ] write_get_object_response
|
||||
</details>
|
||||
|
||||
|
@ -159,6 +159,6 @@ s3
|
||||
|
||||
|
||||
- [X] upload_part
|
||||
- [ ] upload_part_copy
|
||||
- [X] upload_part_copy
|
||||
- [ ] write_get_object_response
|
||||
|
||||
|
@ -73,6 +73,7 @@ from .utils import (
|
||||
STORAGE_CLASS,
|
||||
CaseInsensitiveDict,
|
||||
_VersionedKeyStore,
|
||||
compute_checksum,
|
||||
)
|
||||
|
||||
MAX_BUCKET_NAME_LENGTH = 63
|
||||
@ -399,10 +400,14 @@ class FakeMultipart(BaseModel):
|
||||
self.sse_encryption = sse_encryption
|
||||
self.kms_key_id = kms_key_id
|
||||
|
||||
def complete(self, body: Iterator[Tuple[int, str]]) -> Tuple[bytes, str]:
|
||||
def complete(
|
||||
self, body: Iterator[Tuple[int, str]]
|
||||
) -> Tuple[bytes, str, Optional[str]]:
|
||||
checksum_algo = self.metadata.get("x-amz-checksum-algorithm")
|
||||
decode_hex = codecs.getdecoder("hex_codec")
|
||||
total = bytearray()
|
||||
md5s = bytearray()
|
||||
checksum = bytearray()
|
||||
|
||||
last = None
|
||||
count = 0
|
||||
@ -418,6 +423,10 @@ class FakeMultipart(BaseModel):
|
||||
raise EntityTooSmall()
|
||||
md5s.extend(decode_hex(part_etag)[0]) # type: ignore
|
||||
total.extend(part.value)
|
||||
if checksum_algo:
|
||||
checksum.extend(
|
||||
compute_checksum(part.value, checksum_algo, encode_base64=False)
|
||||
)
|
||||
last = part
|
||||
count += 1
|
||||
|
||||
@ -426,7 +435,11 @@ class FakeMultipart(BaseModel):
|
||||
|
||||
full_etag = md5_hash()
|
||||
full_etag.update(bytes(md5s))
|
||||
return total, f"{full_etag.hexdigest()}-{count}"
|
||||
if checksum_algo:
|
||||
encoded_checksum = compute_checksum(checksum, checksum_algo).decode("utf-8")
|
||||
else:
|
||||
encoded_checksum = None
|
||||
return total, f"{full_etag.hexdigest()}-{count}", encoded_checksum
|
||||
|
||||
def set_part(self, part_id: int, value: bytes) -> FakeKey:
|
||||
if part_id < 1:
|
||||
@ -2319,13 +2332,13 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
|
||||
|
||||
def complete_multipart_upload(
|
||||
self, bucket_name: str, multipart_id: str, body: Iterator[Tuple[int, str]]
|
||||
) -> Tuple[FakeMultipart, bytes, str]:
|
||||
) -> Tuple[FakeMultipart, bytes, str, Optional[str]]:
|
||||
bucket = self.get_bucket(bucket_name)
|
||||
multipart = bucket.multiparts[multipart_id]
|
||||
value, etag = multipart.complete(body)
|
||||
value, etag, checksum = multipart.complete(body)
|
||||
if value is not None:
|
||||
del bucket.multiparts[multipart_id]
|
||||
return multipart, value, etag
|
||||
return multipart, value, etag, checksum
|
||||
|
||||
def get_all_multiparts(self, bucket_name: str) -> Dict[str, FakeMultipart]:
|
||||
bucket = self.get_bucket(bucket_name)
|
||||
@ -2338,7 +2351,7 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
|
||||
multipart = bucket.multiparts[multipart_id]
|
||||
return multipart.set_part(part_id, value)
|
||||
|
||||
def copy_part(
|
||||
def upload_part_copy(
|
||||
self,
|
||||
dest_bucket_name: str,
|
||||
multipart_id: str,
|
||||
|
@ -1507,7 +1507,7 @@ class S3Response(BaseResponse):
|
||||
if self.backend.get_object(
|
||||
src_bucket, src_key, version_id=src_version_id
|
||||
):
|
||||
key = self.backend.copy_part(
|
||||
key = self.backend.upload_part_copy(
|
||||
bucket_name,
|
||||
upload_id,
|
||||
part_number,
|
||||
@ -1707,6 +1707,8 @@ class S3Response(BaseResponse):
|
||||
response_headers.update(
|
||||
{"Checksum": {f"Checksum{checksum_algorithm}": checksum_value}}
|
||||
)
|
||||
# By default, the checksum-details for the copy will be the same as the original
|
||||
# But if another algorithm is provided during the copy-operation, we override the values
|
||||
new_key.checksum_algorithm = checksum_algorithm
|
||||
new_key.checksum_value = checksum_value
|
||||
|
||||
@ -2252,12 +2254,13 @@ class S3Response(BaseResponse):
|
||||
if query.get("uploadId"):
|
||||
multipart_id = query["uploadId"][0]
|
||||
|
||||
multipart, value, etag = self.backend.complete_multipart_upload(
|
||||
multipart, value, etag, checksum = self.backend.complete_multipart_upload(
|
||||
bucket_name, multipart_id, self._complete_multipart_body(body)
|
||||
)
|
||||
if value is None:
|
||||
return 400, {}, ""
|
||||
|
||||
headers: Dict[str, Any] = {}
|
||||
key = self.backend.put_object(
|
||||
bucket_name,
|
||||
multipart.key_name,
|
||||
@ -2269,6 +2272,13 @@ class S3Response(BaseResponse):
|
||||
kms_key_id=multipart.kms_key_id,
|
||||
)
|
||||
key.set_metadata(multipart.metadata)
|
||||
|
||||
if checksum:
|
||||
key.checksum_algorithm = multipart.metadata.get(
|
||||
"x-amz-checksum-algorithm"
|
||||
)
|
||||
key.checksum_value = checksum
|
||||
|
||||
self.backend.set_key_tags(key, multipart.tags)
|
||||
self.backend.put_object_acl(
|
||||
bucket_name=bucket_name,
|
||||
@ -2277,7 +2287,6 @@ class S3Response(BaseResponse):
|
||||
)
|
||||
|
||||
template = self.response_template(S3_MULTIPART_COMPLETE_RESPONSE)
|
||||
headers: Dict[str, Any] = {}
|
||||
if key.version_id:
|
||||
headers["x-amz-version-id"] = key.version_id
|
||||
|
||||
|
@ -24,6 +24,12 @@ user_settable_fields = {
|
||||
"expires",
|
||||
"content-disposition",
|
||||
"x-robots-tag",
|
||||
"x-amz-checksum-algorithm",
|
||||
"x-amz-content-sha256",
|
||||
"x-amz-content-crc32",
|
||||
"x-amz-content-crc32c",
|
||||
"x-amz-content-sha1",
|
||||
"x-amz-website-redirect-location",
|
||||
}
|
||||
ARCHIVE_STORAGE_CLASSES = [
|
||||
"GLACIER",
|
||||
@ -190,7 +196,7 @@ class _VersionedKeyStore(dict): # type: ignore
|
||||
values = itervalues = _itervalues # type: ignore
|
||||
|
||||
|
||||
def compute_checksum(body: bytes, algorithm: str) -> bytes:
|
||||
def compute_checksum(body: bytes, algorithm: str, encode_base64: bool = True) -> bytes:
|
||||
if algorithm == "SHA1":
|
||||
hashed_body = _hash(hashlib.sha1, (body,))
|
||||
elif algorithm == "CRC32C":
|
||||
@ -205,7 +211,10 @@ def compute_checksum(body: bytes, algorithm: str) -> bytes:
|
||||
hashed_body = binascii.crc32(body).to_bytes(4, "big")
|
||||
else:
|
||||
hashed_body = _hash(hashlib.sha256, (body,))
|
||||
return base64.b64encode(hashed_body)
|
||||
if encode_base64:
|
||||
return base64.b64encode(hashed_body)
|
||||
else:
|
||||
return hashed_body
|
||||
|
||||
|
||||
def _hash(fn: Any, args: Any) -> bytes:
|
||||
|
@ -7,6 +7,8 @@ from botocore.client import ClientError
|
||||
from moto import mock_kms, mock_s3
|
||||
from moto.s3.responses import DEFAULT_REGION_NAME
|
||||
|
||||
from . import s3_aws_verified
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"key_name",
|
||||
@ -35,25 +37,27 @@ def test_copy_key_boto3(key_name):
|
||||
assert resp["Body"].read() == b"some value"
|
||||
|
||||
|
||||
@mock_s3
|
||||
def test_copy_key_boto3_with_sha256_checksum():
|
||||
@pytest.mark.aws_verified
|
||||
@s3_aws_verified
|
||||
def test_copy_key_boto3_with_args(bucket=None):
|
||||
# Setup
|
||||
s3_resource = boto3.resource("s3", region_name=DEFAULT_REGION_NAME)
|
||||
client = boto3.client("s3", region_name=DEFAULT_REGION_NAME)
|
||||
key_name = "key"
|
||||
new_key = "new_key"
|
||||
bucket = "foobar"
|
||||
expected_hash = "qz0H8xacy9DtbEtF3iFRn5+TjHLSQSSZiquUnOg7tRs="
|
||||
|
||||
s3_resource.create_bucket(Bucket=bucket)
|
||||
key = s3_resource.Object("foobar", key_name)
|
||||
key = s3_resource.Object(bucket, key_name)
|
||||
key.put(Body=b"some value")
|
||||
|
||||
# Execute
|
||||
key2 = s3_resource.Object(bucket, new_key)
|
||||
key2.copy(
|
||||
CopySource={"Bucket": bucket, "Key": key_name},
|
||||
ExtraArgs={"ChecksumAlgorithm": "SHA256"},
|
||||
ExtraArgs={
|
||||
"ChecksumAlgorithm": "SHA256",
|
||||
"WebsiteRedirectLocation": "http://getmoto.org/",
|
||||
},
|
||||
)
|
||||
|
||||
# Verify
|
||||
@ -65,6 +69,9 @@ def test_copy_key_boto3_with_sha256_checksum():
|
||||
assert "ChecksumSHA256" in resp["Checksum"]
|
||||
assert resp["Checksum"]["ChecksumSHA256"] == expected_hash
|
||||
|
||||
obj = client.get_object(Bucket=bucket, Key=new_key)
|
||||
assert obj["WebsiteRedirectLocation"] == "http://getmoto.org/"
|
||||
|
||||
# Verify in place
|
||||
copy_in_place = client.copy_object(
|
||||
Bucket=bucket,
|
||||
@ -78,6 +85,43 @@ def test_copy_key_boto3_with_sha256_checksum():
|
||||
assert copy_in_place["CopyObjectResult"]["ChecksumSHA256"] == expected_hash
|
||||
|
||||
|
||||
@pytest.mark.aws_verified
|
||||
@s3_aws_verified
|
||||
def test_copy_key_boto3_with_args__using_multipart(bucket=None):
|
||||
# Setup
|
||||
s3_resource = boto3.resource("s3", region_name=DEFAULT_REGION_NAME)
|
||||
client = boto3.client("s3", region_name=DEFAULT_REGION_NAME)
|
||||
key_name = "key"
|
||||
new_key = "new_key"
|
||||
expected_hash = "DnKotDi4EtYGwNMDKmnR6SqH3bWVOlo2BC+tsz9rHqw="
|
||||
|
||||
key = s3_resource.Object(bucket, key_name)
|
||||
key.put(Body=b"some value")
|
||||
|
||||
# Execute
|
||||
key2 = s3_resource.Object(bucket, new_key)
|
||||
key2.copy(
|
||||
CopySource={"Bucket": bucket, "Key": key_name},
|
||||
ExtraArgs={
|
||||
"ChecksumAlgorithm": "SHA256",
|
||||
"WebsiteRedirectLocation": "http://getmoto.org/",
|
||||
},
|
||||
Config=boto3.s3.transfer.TransferConfig(multipart_threshold=1),
|
||||
)
|
||||
|
||||
# Verify
|
||||
resp = client.get_object_attributes(
|
||||
Bucket=bucket, Key=new_key, ObjectAttributes=["Checksum"]
|
||||
)
|
||||
|
||||
assert "Checksum" in resp
|
||||
assert "ChecksumSHA256" in resp["Checksum"]
|
||||
assert resp["Checksum"]["ChecksumSHA256"] == expected_hash
|
||||
|
||||
obj = client.get_object(Bucket=bucket, Key=new_key)
|
||||
assert obj["WebsiteRedirectLocation"] == "http://getmoto.org/"
|
||||
|
||||
|
||||
@mock_s3
|
||||
def test_copy_key_with_version_boto3():
|
||||
s3_resource = boto3.resource("s3", region_name=DEFAULT_REGION_NAME)
|
||||
|
Loading…
Reference in New Issue
Block a user