S3: copy() using multiparts should respect ExtraArgs (#7110)
This commit is contained in:
parent
bd93d87134
commit
90850bc573
@ -6257,7 +6257,7 @@
|
|||||||
|
|
||||||
## s3
|
## s3
|
||||||
<details>
|
<details>
|
||||||
<summary>64% implemented</summary>
|
<summary>65% implemented</summary>
|
||||||
|
|
||||||
- [X] abort_multipart_upload
|
- [X] abort_multipart_upload
|
||||||
- [X] complete_multipart_upload
|
- [X] complete_multipart_upload
|
||||||
@ -6356,7 +6356,7 @@
|
|||||||
- [ ] restore_object
|
- [ ] restore_object
|
||||||
- [X] select_object_content
|
- [X] select_object_content
|
||||||
- [X] upload_part
|
- [X] upload_part
|
||||||
- [ ] upload_part_copy
|
- [X] upload_part_copy
|
||||||
- [ ] write_get_object_response
|
- [ ] write_get_object_response
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
@ -159,6 +159,6 @@ s3
|
|||||||
|
|
||||||
|
|
||||||
- [X] upload_part
|
- [X] upload_part
|
||||||
- [ ] upload_part_copy
|
- [X] upload_part_copy
|
||||||
- [ ] write_get_object_response
|
- [ ] write_get_object_response
|
||||||
|
|
||||||
|
@ -73,6 +73,7 @@ from .utils import (
|
|||||||
STORAGE_CLASS,
|
STORAGE_CLASS,
|
||||||
CaseInsensitiveDict,
|
CaseInsensitiveDict,
|
||||||
_VersionedKeyStore,
|
_VersionedKeyStore,
|
||||||
|
compute_checksum,
|
||||||
)
|
)
|
||||||
|
|
||||||
MAX_BUCKET_NAME_LENGTH = 63
|
MAX_BUCKET_NAME_LENGTH = 63
|
||||||
@ -399,10 +400,14 @@ class FakeMultipart(BaseModel):
|
|||||||
self.sse_encryption = sse_encryption
|
self.sse_encryption = sse_encryption
|
||||||
self.kms_key_id = kms_key_id
|
self.kms_key_id = kms_key_id
|
||||||
|
|
||||||
def complete(self, body: Iterator[Tuple[int, str]]) -> Tuple[bytes, str]:
|
def complete(
|
||||||
|
self, body: Iterator[Tuple[int, str]]
|
||||||
|
) -> Tuple[bytes, str, Optional[str]]:
|
||||||
|
checksum_algo = self.metadata.get("x-amz-checksum-algorithm")
|
||||||
decode_hex = codecs.getdecoder("hex_codec")
|
decode_hex = codecs.getdecoder("hex_codec")
|
||||||
total = bytearray()
|
total = bytearray()
|
||||||
md5s = bytearray()
|
md5s = bytearray()
|
||||||
|
checksum = bytearray()
|
||||||
|
|
||||||
last = None
|
last = None
|
||||||
count = 0
|
count = 0
|
||||||
@ -418,6 +423,10 @@ class FakeMultipart(BaseModel):
|
|||||||
raise EntityTooSmall()
|
raise EntityTooSmall()
|
||||||
md5s.extend(decode_hex(part_etag)[0]) # type: ignore
|
md5s.extend(decode_hex(part_etag)[0]) # type: ignore
|
||||||
total.extend(part.value)
|
total.extend(part.value)
|
||||||
|
if checksum_algo:
|
||||||
|
checksum.extend(
|
||||||
|
compute_checksum(part.value, checksum_algo, encode_base64=False)
|
||||||
|
)
|
||||||
last = part
|
last = part
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
@ -426,7 +435,11 @@ class FakeMultipart(BaseModel):
|
|||||||
|
|
||||||
full_etag = md5_hash()
|
full_etag = md5_hash()
|
||||||
full_etag.update(bytes(md5s))
|
full_etag.update(bytes(md5s))
|
||||||
return total, f"{full_etag.hexdigest()}-{count}"
|
if checksum_algo:
|
||||||
|
encoded_checksum = compute_checksum(checksum, checksum_algo).decode("utf-8")
|
||||||
|
else:
|
||||||
|
encoded_checksum = None
|
||||||
|
return total, f"{full_etag.hexdigest()}-{count}", encoded_checksum
|
||||||
|
|
||||||
def set_part(self, part_id: int, value: bytes) -> FakeKey:
|
def set_part(self, part_id: int, value: bytes) -> FakeKey:
|
||||||
if part_id < 1:
|
if part_id < 1:
|
||||||
@ -2319,13 +2332,13 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
|
|||||||
|
|
||||||
def complete_multipart_upload(
|
def complete_multipart_upload(
|
||||||
self, bucket_name: str, multipart_id: str, body: Iterator[Tuple[int, str]]
|
self, bucket_name: str, multipart_id: str, body: Iterator[Tuple[int, str]]
|
||||||
) -> Tuple[FakeMultipart, bytes, str]:
|
) -> Tuple[FakeMultipart, bytes, str, Optional[str]]:
|
||||||
bucket = self.get_bucket(bucket_name)
|
bucket = self.get_bucket(bucket_name)
|
||||||
multipart = bucket.multiparts[multipart_id]
|
multipart = bucket.multiparts[multipart_id]
|
||||||
value, etag = multipart.complete(body)
|
value, etag, checksum = multipart.complete(body)
|
||||||
if value is not None:
|
if value is not None:
|
||||||
del bucket.multiparts[multipart_id]
|
del bucket.multiparts[multipart_id]
|
||||||
return multipart, value, etag
|
return multipart, value, etag, checksum
|
||||||
|
|
||||||
def get_all_multiparts(self, bucket_name: str) -> Dict[str, FakeMultipart]:
|
def get_all_multiparts(self, bucket_name: str) -> Dict[str, FakeMultipart]:
|
||||||
bucket = self.get_bucket(bucket_name)
|
bucket = self.get_bucket(bucket_name)
|
||||||
@ -2338,7 +2351,7 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
|
|||||||
multipart = bucket.multiparts[multipart_id]
|
multipart = bucket.multiparts[multipart_id]
|
||||||
return multipart.set_part(part_id, value)
|
return multipart.set_part(part_id, value)
|
||||||
|
|
||||||
def copy_part(
|
def upload_part_copy(
|
||||||
self,
|
self,
|
||||||
dest_bucket_name: str,
|
dest_bucket_name: str,
|
||||||
multipart_id: str,
|
multipart_id: str,
|
||||||
|
@ -1507,7 +1507,7 @@ class S3Response(BaseResponse):
|
|||||||
if self.backend.get_object(
|
if self.backend.get_object(
|
||||||
src_bucket, src_key, version_id=src_version_id
|
src_bucket, src_key, version_id=src_version_id
|
||||||
):
|
):
|
||||||
key = self.backend.copy_part(
|
key = self.backend.upload_part_copy(
|
||||||
bucket_name,
|
bucket_name,
|
||||||
upload_id,
|
upload_id,
|
||||||
part_number,
|
part_number,
|
||||||
@ -1707,6 +1707,8 @@ class S3Response(BaseResponse):
|
|||||||
response_headers.update(
|
response_headers.update(
|
||||||
{"Checksum": {f"Checksum{checksum_algorithm}": checksum_value}}
|
{"Checksum": {f"Checksum{checksum_algorithm}": checksum_value}}
|
||||||
)
|
)
|
||||||
|
# By default, the checksum-details for the copy will be the same as the original
|
||||||
|
# But if another algorithm is provided during the copy-operation, we override the values
|
||||||
new_key.checksum_algorithm = checksum_algorithm
|
new_key.checksum_algorithm = checksum_algorithm
|
||||||
new_key.checksum_value = checksum_value
|
new_key.checksum_value = checksum_value
|
||||||
|
|
||||||
@ -2252,12 +2254,13 @@ class S3Response(BaseResponse):
|
|||||||
if query.get("uploadId"):
|
if query.get("uploadId"):
|
||||||
multipart_id = query["uploadId"][0]
|
multipart_id = query["uploadId"][0]
|
||||||
|
|
||||||
multipart, value, etag = self.backend.complete_multipart_upload(
|
multipart, value, etag, checksum = self.backend.complete_multipart_upload(
|
||||||
bucket_name, multipart_id, self._complete_multipart_body(body)
|
bucket_name, multipart_id, self._complete_multipart_body(body)
|
||||||
)
|
)
|
||||||
if value is None:
|
if value is None:
|
||||||
return 400, {}, ""
|
return 400, {}, ""
|
||||||
|
|
||||||
|
headers: Dict[str, Any] = {}
|
||||||
key = self.backend.put_object(
|
key = self.backend.put_object(
|
||||||
bucket_name,
|
bucket_name,
|
||||||
multipart.key_name,
|
multipart.key_name,
|
||||||
@ -2269,6 +2272,13 @@ class S3Response(BaseResponse):
|
|||||||
kms_key_id=multipart.kms_key_id,
|
kms_key_id=multipart.kms_key_id,
|
||||||
)
|
)
|
||||||
key.set_metadata(multipart.metadata)
|
key.set_metadata(multipart.metadata)
|
||||||
|
|
||||||
|
if checksum:
|
||||||
|
key.checksum_algorithm = multipart.metadata.get(
|
||||||
|
"x-amz-checksum-algorithm"
|
||||||
|
)
|
||||||
|
key.checksum_value = checksum
|
||||||
|
|
||||||
self.backend.set_key_tags(key, multipart.tags)
|
self.backend.set_key_tags(key, multipart.tags)
|
||||||
self.backend.put_object_acl(
|
self.backend.put_object_acl(
|
||||||
bucket_name=bucket_name,
|
bucket_name=bucket_name,
|
||||||
@ -2277,7 +2287,6 @@ class S3Response(BaseResponse):
|
|||||||
)
|
)
|
||||||
|
|
||||||
template = self.response_template(S3_MULTIPART_COMPLETE_RESPONSE)
|
template = self.response_template(S3_MULTIPART_COMPLETE_RESPONSE)
|
||||||
headers: Dict[str, Any] = {}
|
|
||||||
if key.version_id:
|
if key.version_id:
|
||||||
headers["x-amz-version-id"] = key.version_id
|
headers["x-amz-version-id"] = key.version_id
|
||||||
|
|
||||||
|
@ -24,6 +24,12 @@ user_settable_fields = {
|
|||||||
"expires",
|
"expires",
|
||||||
"content-disposition",
|
"content-disposition",
|
||||||
"x-robots-tag",
|
"x-robots-tag",
|
||||||
|
"x-amz-checksum-algorithm",
|
||||||
|
"x-amz-content-sha256",
|
||||||
|
"x-amz-content-crc32",
|
||||||
|
"x-amz-content-crc32c",
|
||||||
|
"x-amz-content-sha1",
|
||||||
|
"x-amz-website-redirect-location",
|
||||||
}
|
}
|
||||||
ARCHIVE_STORAGE_CLASSES = [
|
ARCHIVE_STORAGE_CLASSES = [
|
||||||
"GLACIER",
|
"GLACIER",
|
||||||
@ -190,7 +196,7 @@ class _VersionedKeyStore(dict): # type: ignore
|
|||||||
values = itervalues = _itervalues # type: ignore
|
values = itervalues = _itervalues # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def compute_checksum(body: bytes, algorithm: str) -> bytes:
|
def compute_checksum(body: bytes, algorithm: str, encode_base64: bool = True) -> bytes:
|
||||||
if algorithm == "SHA1":
|
if algorithm == "SHA1":
|
||||||
hashed_body = _hash(hashlib.sha1, (body,))
|
hashed_body = _hash(hashlib.sha1, (body,))
|
||||||
elif algorithm == "CRC32C":
|
elif algorithm == "CRC32C":
|
||||||
@ -205,7 +211,10 @@ def compute_checksum(body: bytes, algorithm: str) -> bytes:
|
|||||||
hashed_body = binascii.crc32(body).to_bytes(4, "big")
|
hashed_body = binascii.crc32(body).to_bytes(4, "big")
|
||||||
else:
|
else:
|
||||||
hashed_body = _hash(hashlib.sha256, (body,))
|
hashed_body = _hash(hashlib.sha256, (body,))
|
||||||
return base64.b64encode(hashed_body)
|
if encode_base64:
|
||||||
|
return base64.b64encode(hashed_body)
|
||||||
|
else:
|
||||||
|
return hashed_body
|
||||||
|
|
||||||
|
|
||||||
def _hash(fn: Any, args: Any) -> bytes:
|
def _hash(fn: Any, args: Any) -> bytes:
|
||||||
|
@ -7,6 +7,8 @@ from botocore.client import ClientError
|
|||||||
from moto import mock_kms, mock_s3
|
from moto import mock_kms, mock_s3
|
||||||
from moto.s3.responses import DEFAULT_REGION_NAME
|
from moto.s3.responses import DEFAULT_REGION_NAME
|
||||||
|
|
||||||
|
from . import s3_aws_verified
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"key_name",
|
"key_name",
|
||||||
@ -35,25 +37,27 @@ def test_copy_key_boto3(key_name):
|
|||||||
assert resp["Body"].read() == b"some value"
|
assert resp["Body"].read() == b"some value"
|
||||||
|
|
||||||
|
|
||||||
@mock_s3
|
@pytest.mark.aws_verified
|
||||||
def test_copy_key_boto3_with_sha256_checksum():
|
@s3_aws_verified
|
||||||
|
def test_copy_key_boto3_with_args(bucket=None):
|
||||||
# Setup
|
# Setup
|
||||||
s3_resource = boto3.resource("s3", region_name=DEFAULT_REGION_NAME)
|
s3_resource = boto3.resource("s3", region_name=DEFAULT_REGION_NAME)
|
||||||
client = boto3.client("s3", region_name=DEFAULT_REGION_NAME)
|
client = boto3.client("s3", region_name=DEFAULT_REGION_NAME)
|
||||||
key_name = "key"
|
key_name = "key"
|
||||||
new_key = "new_key"
|
new_key = "new_key"
|
||||||
bucket = "foobar"
|
|
||||||
expected_hash = "qz0H8xacy9DtbEtF3iFRn5+TjHLSQSSZiquUnOg7tRs="
|
expected_hash = "qz0H8xacy9DtbEtF3iFRn5+TjHLSQSSZiquUnOg7tRs="
|
||||||
|
|
||||||
s3_resource.create_bucket(Bucket=bucket)
|
key = s3_resource.Object(bucket, key_name)
|
||||||
key = s3_resource.Object("foobar", key_name)
|
|
||||||
key.put(Body=b"some value")
|
key.put(Body=b"some value")
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
key2 = s3_resource.Object(bucket, new_key)
|
key2 = s3_resource.Object(bucket, new_key)
|
||||||
key2.copy(
|
key2.copy(
|
||||||
CopySource={"Bucket": bucket, "Key": key_name},
|
CopySource={"Bucket": bucket, "Key": key_name},
|
||||||
ExtraArgs={"ChecksumAlgorithm": "SHA256"},
|
ExtraArgs={
|
||||||
|
"ChecksumAlgorithm": "SHA256",
|
||||||
|
"WebsiteRedirectLocation": "http://getmoto.org/",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify
|
# Verify
|
||||||
@ -65,6 +69,9 @@ def test_copy_key_boto3_with_sha256_checksum():
|
|||||||
assert "ChecksumSHA256" in resp["Checksum"]
|
assert "ChecksumSHA256" in resp["Checksum"]
|
||||||
assert resp["Checksum"]["ChecksumSHA256"] == expected_hash
|
assert resp["Checksum"]["ChecksumSHA256"] == expected_hash
|
||||||
|
|
||||||
|
obj = client.get_object(Bucket=bucket, Key=new_key)
|
||||||
|
assert obj["WebsiteRedirectLocation"] == "http://getmoto.org/"
|
||||||
|
|
||||||
# Verify in place
|
# Verify in place
|
||||||
copy_in_place = client.copy_object(
|
copy_in_place = client.copy_object(
|
||||||
Bucket=bucket,
|
Bucket=bucket,
|
||||||
@ -78,6 +85,43 @@ def test_copy_key_boto3_with_sha256_checksum():
|
|||||||
assert copy_in_place["CopyObjectResult"]["ChecksumSHA256"] == expected_hash
|
assert copy_in_place["CopyObjectResult"]["ChecksumSHA256"] == expected_hash
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.aws_verified
|
||||||
|
@s3_aws_verified
|
||||||
|
def test_copy_key_boto3_with_args__using_multipart(bucket=None):
|
||||||
|
# Setup
|
||||||
|
s3_resource = boto3.resource("s3", region_name=DEFAULT_REGION_NAME)
|
||||||
|
client = boto3.client("s3", region_name=DEFAULT_REGION_NAME)
|
||||||
|
key_name = "key"
|
||||||
|
new_key = "new_key"
|
||||||
|
expected_hash = "DnKotDi4EtYGwNMDKmnR6SqH3bWVOlo2BC+tsz9rHqw="
|
||||||
|
|
||||||
|
key = s3_resource.Object(bucket, key_name)
|
||||||
|
key.put(Body=b"some value")
|
||||||
|
|
||||||
|
# Execute
|
||||||
|
key2 = s3_resource.Object(bucket, new_key)
|
||||||
|
key2.copy(
|
||||||
|
CopySource={"Bucket": bucket, "Key": key_name},
|
||||||
|
ExtraArgs={
|
||||||
|
"ChecksumAlgorithm": "SHA256",
|
||||||
|
"WebsiteRedirectLocation": "http://getmoto.org/",
|
||||||
|
},
|
||||||
|
Config=boto3.s3.transfer.TransferConfig(multipart_threshold=1),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify
|
||||||
|
resp = client.get_object_attributes(
|
||||||
|
Bucket=bucket, Key=new_key, ObjectAttributes=["Checksum"]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "Checksum" in resp
|
||||||
|
assert "ChecksumSHA256" in resp["Checksum"]
|
||||||
|
assert resp["Checksum"]["ChecksumSHA256"] == expected_hash
|
||||||
|
|
||||||
|
obj = client.get_object(Bucket=bucket, Key=new_key)
|
||||||
|
assert obj["WebsiteRedirectLocation"] == "http://getmoto.org/"
|
||||||
|
|
||||||
|
|
||||||
@mock_s3
|
@mock_s3
|
||||||
def test_copy_key_with_version_boto3():
|
def test_copy_key_with_version_boto3():
|
||||||
s3_resource = boto3.resource("s3", region_name=DEFAULT_REGION_NAME)
|
s3_resource = boto3.resource("s3", region_name=DEFAULT_REGION_NAME)
|
||||||
|
Loading…
Reference in New Issue
Block a user