S3 - Fix corner cases multi upload (#4624)

This commit is contained in:
Bert Blommers 2021-11-23 18:47:48 -01:00 committed by GitHub
parent 74666c1271
commit 4be96719ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 120 additions and 33 deletions

View File

@ -253,6 +253,22 @@ class InvalidMaxPartArgument(S3ClientError):
super(InvalidMaxPartArgument, self).__init__("InvalidArgument", error) super(InvalidMaxPartArgument, self).__init__("InvalidArgument", error)
class InvalidMaxPartNumberArgument(InvalidArgumentError):
code = 400
def __init__(self, value, *args, **kwargs):
error = "Part number must be an integer between 1 and 10000, inclusive"
super().__init__(message=error, name="partNumber", value=value, *args, **kwargs)
class NotAnIntegerException(InvalidArgumentError):
code = 400
def __init__(self, name, value, *args, **kwargs):
error = f"Provided {name} not an integer or within integer range"
super().__init__(message=error, name=name, value=value, *args, **kwargs)
class InvalidNotificationARN(S3ClientError): class InvalidNotificationARN(S3ClientError):
code = 400 code = 400

View File

@ -386,10 +386,9 @@ class FakeMultipart(BaseModel):
return key return key
def list_parts(self, part_number_marker, max_parts): def list_parts(self, part_number_marker, max_parts):
for part_id in self.partlist: max_marker = part_number_marker + max_parts
part = self.parts[part_id] for part_id in self.partlist[part_number_marker:max_marker]:
if part_number_marker <= part.name < part_number_marker + max_parts: yield self.parts[part_id]
yield part
class FakeGrantee(BaseModel): class FakeGrantee(BaseModel):
@ -1874,7 +1873,7 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
def is_truncated(self, bucket_name, multipart_id, next_part_number_marker): def is_truncated(self, bucket_name, multipart_id, next_part_number_marker):
bucket = self.get_bucket(bucket_name) bucket = self.get_bucket(bucket_name)
return len(bucket.multiparts[multipart_id].parts) >= next_part_number_marker return len(bucket.multiparts[multipart_id].parts) > next_part_number_marker
def create_multipart_upload( def create_multipart_upload(
self, bucket_name, key_name, metadata, storage_type, tags self, bucket_name, key_name, metadata, storage_type, tags

View File

@ -40,6 +40,8 @@ from .exceptions import (
MissingKey, MissingKey,
MissingVersion, MissingVersion,
InvalidMaxPartArgument, InvalidMaxPartArgument,
InvalidMaxPartNumberArgument,
NotAnIntegerException,
InvalidPartOrder, InvalidPartOrder,
MalformedXML, MalformedXML,
MalformedACLError, MalformedACLError,
@ -1324,11 +1326,17 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
# 0 <= PartNumberMarker <= 2,147,483,647 # 0 <= PartNumberMarker <= 2,147,483,647
part_number_marker = int(query.get("part-number-marker", [0])[0]) part_number_marker = int(query.get("part-number-marker", [0])[0])
if part_number_marker > 2147483647:
raise NotAnIntegerException(
name="part-number-marker", value=part_number_marker
)
if not (0 <= part_number_marker <= 2147483647): if not (0 <= part_number_marker <= 2147483647):
raise InvalidMaxPartArgument("part-number-marker", 0, 2147483647) raise InvalidMaxPartArgument("part-number-marker", 0, 2147483647)
# 0 <= MaxParts <= 2,147,483,647 (default is 1,000) # 0 <= MaxParts <= 2,147,483,647 (default is 1,000)
max_parts = int(query.get("max-parts", [1000])[0]) max_parts = int(query.get("max-parts", [1000])[0])
if max_parts > 2147483647:
raise NotAnIntegerException(name="max-parts", value=max_parts)
if not (0 <= max_parts <= 2147483647): if not (0 <= max_parts <= 2147483647):
raise InvalidMaxPartArgument("max-parts", 0, 2147483647) raise InvalidMaxPartArgument("max-parts", 0, 2147483647)
@ -1338,7 +1346,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
part_number_marker=part_number_marker, part_number_marker=part_number_marker,
max_parts=max_parts, max_parts=max_parts,
) )
next_part_number_marker = parts[-1].name + 1 if parts else 0 next_part_number_marker = parts[-1].name if parts else 0
is_truncated = parts and self.backend.is_truncated( is_truncated = parts and self.backend.is_truncated(
bucket_name, upload_id, next_part_number_marker bucket_name, upload_id, next_part_number_marker
) )
@ -1449,6 +1457,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
template = self.response_template(S3_MULTIPART_UPLOAD_RESPONSE) template = self.response_template(S3_MULTIPART_UPLOAD_RESPONSE)
response = template.render(part=key) response = template.render(part=key)
else: else:
if part_number > 10000:
raise InvalidMaxPartNumberArgument(part_number)
key = self.backend.upload_part( key = self.backend.upload_part(
bucket_name, upload_id, part_number, body bucket_name, upload_id, part_number, body
) )

View File

@ -3870,41 +3870,53 @@ def test_boto3_multipart_version():
@mock_s3 @mock_s3
def test_boto3_multipart_list_parts_invalid_argument(): @pytest.mark.parametrize(
"part_nr,msg,msg2",
[
(
-42,
"Argument max-parts must be an integer between 0 and 2147483647",
"Argument part-number-marker must be an integer between 0 and 2147483647",
),
(
2147483647 + 42,
"Provided max-parts not an integer or within integer range",
"Provided part-number-marker not an integer or within integer range",
),
],
)
def test_boto3_multipart_list_parts_invalid_argument(part_nr, msg, msg2):
s3 = boto3.client("s3", region_name="us-east-1") s3 = boto3.client("s3", region_name="us-east-1")
s3.create_bucket(Bucket="mybucket") bucket_name = "mybucketasdfljoqwerasdfas"
s3.create_bucket(Bucket=bucket_name)
mpu = s3.create_multipart_upload(Bucket="mybucket", Key="the-key") mpu = s3.create_multipart_upload(Bucket=bucket_name, Key="the-key")
mpu_id = mpu["UploadId"] mpu_id = mpu["UploadId"]
def get_parts(**kwarg): def get_parts(**kwarg):
s3.list_parts(Bucket="mybucket", Key="the-key", UploadId=mpu_id, **kwarg) s3.list_parts(Bucket=bucket_name, Key="the-key", UploadId=mpu_id, **kwarg)
for value in [-42, 2147483647 + 42]: with pytest.raises(ClientError) as err:
with pytest.raises(ClientError) as err: get_parts(**{"MaxParts": part_nr})
get_parts(**{"MaxParts": value}) e = err.value.response["Error"]
e = err.value.response["Error"] e["Code"].should.equal("InvalidArgument")
e["Code"].should.equal("InvalidArgument") e["Message"].should.equal(msg)
e["Message"].should.equal(
"Argument max-parts must be an integer between 0 and 2147483647"
)
with pytest.raises(ClientError) as err: with pytest.raises(ClientError) as err:
get_parts(**{"PartNumberMarker": value}) get_parts(**{"PartNumberMarker": part_nr})
e = err.value.response["Error"] e = err.value.response["Error"]
e["Code"].should.equal("InvalidArgument") e["Code"].should.equal("InvalidArgument")
e["Message"].should.equal( e["Message"].should.equal(msg2)
"Argument part-number-marker must be an integer between 0 and 2147483647"
)
@mock_s3 @mock_s3
@reduced_min_part_size @reduced_min_part_size
def test_boto3_multipart_list_parts(): def test_boto3_multipart_list_parts():
s3 = boto3.client("s3", region_name="us-east-1") s3 = boto3.client("s3", region_name="us-east-1")
s3.create_bucket(Bucket="mybucket") bucket_name = "mybucketasdfljoqwerasdfas"
s3.create_bucket(Bucket=bucket_name)
mpu = s3.create_multipart_upload(Bucket="mybucket", Key="the-key") mpu = s3.create_multipart_upload(Bucket=bucket_name, Key="the-key")
mpu_id = mpu["UploadId"] mpu_id = mpu["UploadId"]
parts = [] parts = []
@ -3914,7 +3926,7 @@ def test_boto3_multipart_list_parts():
# Get uploaded parts using default values # Get uploaded parts using default values
uploaded_parts = [] uploaded_parts = []
uploaded = s3.list_parts(Bucket="mybucket", Key="the-key", UploadId=mpu_id,) uploaded = s3.list_parts(Bucket=bucket_name, Key="the-key", UploadId=mpu_id,)
assert uploaded["PartNumberMarker"] == 0 assert uploaded["PartNumberMarker"] == 0
@ -3926,7 +3938,7 @@ def test_boto3_multipart_list_parts():
) )
assert uploaded_parts == parts assert uploaded_parts == parts
next_part_number_marker = uploaded["Parts"][-1]["PartNumber"] + 1 next_part_number_marker = uploaded["Parts"][-1]["PartNumber"]
else: else:
next_part_number_marker = 0 next_part_number_marker = 0
@ -3941,7 +3953,7 @@ def test_boto3_multipart_list_parts():
while "there are parts": while "there are parts":
uploaded = s3.list_parts( uploaded = s3.list_parts(
Bucket="mybucket", Bucket=bucket_name,
Key="the-key", Key="the-key",
UploadId=mpu_id, UploadId=mpu_id,
PartNumberMarker=part_number_marker, PartNumberMarker=part_number_marker,
@ -3979,7 +3991,7 @@ def test_boto3_multipart_list_parts():
part_size = REDUCED_PART_SIZE + i part_size = REDUCED_PART_SIZE + i
body = b"1" * part_size body = b"1" * part_size
part = s3.upload_part( part = s3.upload_part(
Bucket="mybucket", Bucket=bucket_name,
Key="the-key", Key="the-key",
PartNumber=i, PartNumber=i,
UploadId=mpu_id, UploadId=mpu_id,
@ -3997,7 +4009,7 @@ def test_boto3_multipart_list_parts():
get_parts_by_batch(11) get_parts_by_batch(11)
s3.complete_multipart_upload( s3.complete_multipart_upload(
Bucket="mybucket", Bucket=bucket_name,
Key="the-key", Key="the-key",
UploadId=mpu_id, UploadId=mpu_id,
MultipartUpload={"Parts": parts}, MultipartUpload={"Parts": parts},

View File

@ -3,7 +3,7 @@ from moto import mock_s3
import boto3 import boto3
import os import os
import pytest import pytest
import sure # pylint: disable=unused-import import sure # noqa # pylint: disable=unused-import
from .test_s3 import DEFAULT_REGION_NAME from .test_s3 import DEFAULT_REGION_NAME
@ -80,3 +80,53 @@ def test_multipart_upload_with_tags():
response = client.get_object_tagging(Bucket=bucket, Key=key) response = client.get_object_tagging(Bucket=bucket, Key=key)
actual = {t["Key"]: t["Value"] for t in response.get("TagSet", [])} actual = {t["Key"]: t["Value"] for t in response.get("TagSet", [])}
actual.should.equal({"a": "b"}) actual.should.equal({"a": "b"})
@mock_s3
def test_multipart_upload_should_return_part_10000():
bucket = "dummybucket"
s3_client = boto3.client("s3", "us-east-1")
key = "test_file"
s3_client.create_bucket(Bucket=bucket)
mpu = s3_client.create_multipart_upload(Bucket=bucket, Key=key)
mpu_id = mpu["UploadId"]
s3_client.upload_part(
Bucket=bucket, Key=key, PartNumber=1, UploadId=mpu_id, Body="data"
)
s3_client.upload_part(
Bucket=bucket, Key=key, PartNumber=2, UploadId=mpu_id, Body="data"
)
s3_client.upload_part(
Bucket=bucket, Key=key, PartNumber=10000, UploadId=mpu_id, Body="data"
)
all_parts = s3_client.list_parts(Bucket=bucket, Key=key, UploadId=mpu_id)["Parts"]
part_nrs = [part["PartNumber"] for part in all_parts]
part_nrs.should.equal([1, 2, 10000])
@mock_s3
@pytest.mark.parametrize("part_nr", [10001, 10002, 20000])
def test_s3_multipart_upload_cannot_upload_part_over_10000(part_nr):
bucket = "dummy"
s3_client = boto3.client("s3", "us-east-1")
key = "test_file"
s3_client.create_bucket(Bucket=bucket)
mpu = s3_client.create_multipart_upload(Bucket=bucket, Key=key)
mpu_id = mpu["UploadId"]
with pytest.raises(ClientError) as exc:
s3_client.upload_part(
Bucket=bucket, Key=key, PartNumber=part_nr, UploadId=mpu_id, Body="data"
)
err = exc.value.response["Error"]
err["Code"].should.equal("InvalidArgument")
err["Message"].should.equal(
"Part number must be an integer between 1 and 10000, inclusive"
)
err["ArgumentName"].should.equal("partNumber")
err["ArgumentValue"].should.equal(f"{part_nr}")