diff --git a/moto/s3/models.py b/moto/s3/models.py index c0f11809b..68e188fec 100644 --- a/moto/s3/models.py +++ b/moto/s3/models.py @@ -351,11 +351,12 @@ class FakeKey(BaseModel, ManagedState): class FakeMultipart(BaseModel): - def __init__(self, key_name, metadata, storage=None, tags=None): + def __init__(self, key_name, metadata, storage=None, tags=None, acl=None): self.key_name = key_name self.metadata = metadata self.storage = storage self.tags = tags + self.acl = acl self.parts = {} self.partlist = [] # ordered list of part ID's rand_b64 = base64.b64encode(os.urandom(UPLOAD_ID_BYTES)) @@ -1881,13 +1882,6 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): pub_block_config.get("RestrictPublicBuckets"), ) - def initiate_multipart(self, bucket_name, key_name, metadata): - bucket = self.get_bucket(bucket_name) - new_multipart = FakeMultipart(key_name, metadata) - bucket.multiparts[new_multipart.id] = new_multipart - - return new_multipart - def complete_multipart(self, bucket_name, multipart_id, body): bucket = self.get_bucket(bucket_name) multipart = bucket.multiparts[multipart_id] @@ -1924,9 +1918,11 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): return len(bucket.multiparts[multipart_id].parts) > next_part_number_marker def create_multipart_upload( - self, bucket_name, key_name, metadata, storage_type, tags + self, bucket_name, key_name, metadata, storage_type, tags, acl ): - multipart = FakeMultipart(key_name, metadata, storage=storage_type, tags=tags) + multipart = FakeMultipart( + key_name, metadata, storage=storage_type, tags=tags, acl=acl + ) bucket = self.get_bucket(bucket_name) bucket.multiparts[multipart.id] = multipart diff --git a/moto/s3/responses.py b/moto/s3/responses.py index 64fc2e181..6d88bae44 100644 --- a/moto/s3/responses.py +++ b/moto/s3/responses.py @@ -1917,8 +1917,9 @@ class S3Response(BaseResponse): metadata = metadata_from_headers(request.headers) tagging = self._tagging_from_headers(request.headers) storage_type = request.headers.get("x-amz-storage-class", "STANDARD") + acl = self._acl_from_headers(request.headers) multipart_id = self.backend.create_multipart_upload( - bucket_name, key_name, metadata, storage_type, tagging + bucket_name, key_name, metadata, storage_type, tagging, acl ) template = self.response_template(S3_MULTIPART_INITIATE_RESPONSE) @@ -1947,6 +1948,7 @@ class S3Response(BaseResponse): ) key.set_metadata(multipart.metadata) self.backend.set_key_tags(key, multipart.tags) + self.backend.put_object_acl(bucket_name, key.name, multipart.acl) template = self.response_template(S3_MULTIPART_COMPLETE_RESPONSE) headers = {} diff --git a/tests/test_s3/test_s3_multipart.py b/tests/test_s3/test_s3_multipart.py index 3af3ecafb..d8a0fd82c 100644 --- a/tests/test_s3/test_s3_multipart.py +++ b/tests/test_s3/test_s3_multipart.py @@ -193,29 +193,48 @@ def test_multipart_upload_out_of_order(): def test_multipart_upload_with_headers(): s3 = boto3.resource("s3", region_name=DEFAULT_REGION_NAME) client = boto3.client("s3", region_name=DEFAULT_REGION_NAME) - s3.create_bucket(Bucket="foobar") + bucket_name = "fancymultiparttest" + key_name = "the-key" + s3.create_bucket(Bucket=bucket_name) part1 = b"0" * REDUCED_PART_SIZE mp = client.create_multipart_upload( - Bucket="foobar", Key="the-key", Metadata={"meta": "data"} + Bucket=bucket_name, + Key=key_name, + Metadata={"meta": "data"}, + StorageClass="STANDARD_IA", + ACL="authenticated-read", ) up1 = client.upload_part( Body=BytesIO(part1), PartNumber=1, - Bucket="foobar", - Key="the-key", + Bucket=bucket_name, + Key=key_name, UploadId=mp["UploadId"], ) client.complete_multipart_upload( - Bucket="foobar", - Key="the-key", + Bucket=bucket_name, + Key=key_name, MultipartUpload={"Parts": [{"ETag": up1["ETag"], "PartNumber": 1}]}, UploadId=mp["UploadId"], ) # we should get both parts as the key contents - response = client.get_object(Bucket="foobar", Key="the-key") + response = client.get_object(Bucket=bucket_name, Key=key_name) response["Metadata"].should.equal({"meta": "data"}) + response["StorageClass"].should.equal("STANDARD_IA") + + grants = client.get_object_acl(Bucket=bucket_name, Key=key_name)["Grants"] + grants.should.have.length_of(2) + grants.should.contain( + { + "Grantee": { + "Type": "Group", + "URI": "http://acs.amazonaws.com/groups/global/AuthenticatedUsers", + }, + "Permission": "READ", + } + ) @pytest.mark.parametrize(