S3 - Fix corner cases multi upload (#4624)

This commit is contained in:
Bert Blommers 2021-11-23 18:47:48 -01:00 committed by GitHub
parent 74666c1271
commit 4be96719ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 120 additions and 33 deletions

View File

@ -253,6 +253,22 @@ class InvalidMaxPartArgument(S3ClientError):
super(InvalidMaxPartArgument, self).__init__("InvalidArgument", error)
class InvalidMaxPartNumberArgument(InvalidArgumentError):
code = 400
def __init__(self, value, *args, **kwargs):
error = "Part number must be an integer between 1 and 10000, inclusive"
super().__init__(message=error, name="partNumber", value=value, *args, **kwargs)
class NotAnIntegerException(InvalidArgumentError):
code = 400
def __init__(self, name, value, *args, **kwargs):
error = f"Provided {name} not an integer or within integer range"
super().__init__(message=error, name=name, value=value, *args, **kwargs)
class InvalidNotificationARN(S3ClientError):
code = 400

View File

@ -386,10 +386,9 @@ class FakeMultipart(BaseModel):
return key
def list_parts(self, part_number_marker, max_parts):
for part_id in self.partlist:
part = self.parts[part_id]
if part_number_marker <= part.name < part_number_marker + max_parts:
yield part
max_marker = part_number_marker + max_parts
for part_id in self.partlist[part_number_marker:max_marker]:
yield self.parts[part_id]
class FakeGrantee(BaseModel):
@ -1874,7 +1873,7 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
def is_truncated(self, bucket_name, multipart_id, next_part_number_marker):
bucket = self.get_bucket(bucket_name)
return len(bucket.multiparts[multipart_id].parts) >= next_part_number_marker
return len(bucket.multiparts[multipart_id].parts) > next_part_number_marker
def create_multipart_upload(
self, bucket_name, key_name, metadata, storage_type, tags

View File

@ -40,6 +40,8 @@ from .exceptions import (
MissingKey,
MissingVersion,
InvalidMaxPartArgument,
InvalidMaxPartNumberArgument,
NotAnIntegerException,
InvalidPartOrder,
MalformedXML,
MalformedACLError,
@ -1324,11 +1326,17 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
# 0 <= PartNumberMarker <= 2,147,483,647
part_number_marker = int(query.get("part-number-marker", [0])[0])
if part_number_marker > 2147483647:
raise NotAnIntegerException(
name="part-number-marker", value=part_number_marker
)
if not (0 <= part_number_marker <= 2147483647):
raise InvalidMaxPartArgument("part-number-marker", 0, 2147483647)
# 0 <= MaxParts <= 2,147,483,647 (default is 1,000)
max_parts = int(query.get("max-parts", [1000])[0])
if max_parts > 2147483647:
raise NotAnIntegerException(name="max-parts", value=max_parts)
if not (0 <= max_parts <= 2147483647):
raise InvalidMaxPartArgument("max-parts", 0, 2147483647)
@ -1338,7 +1346,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
part_number_marker=part_number_marker,
max_parts=max_parts,
)
next_part_number_marker = parts[-1].name + 1 if parts else 0
next_part_number_marker = parts[-1].name if parts else 0
is_truncated = parts and self.backend.is_truncated(
bucket_name, upload_id, next_part_number_marker
)
@ -1449,6 +1457,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
template = self.response_template(S3_MULTIPART_UPLOAD_RESPONSE)
response = template.render(part=key)
else:
if part_number > 10000:
raise InvalidMaxPartNumberArgument(part_number)
key = self.backend.upload_part(
bucket_name, upload_id, part_number, body
)

View File

@ -3870,41 +3870,53 @@ def test_boto3_multipart_version():
@mock_s3
def test_boto3_multipart_list_parts_invalid_argument():
@pytest.mark.parametrize(
"part_nr,msg,msg2",
[
(
-42,
"Argument max-parts must be an integer between 0 and 2147483647",
"Argument part-number-marker must be an integer between 0 and 2147483647",
),
(
2147483647 + 42,
"Provided max-parts not an integer or within integer range",
"Provided part-number-marker not an integer or within integer range",
),
],
)
def test_boto3_multipart_list_parts_invalid_argument(part_nr, msg, msg2):
s3 = boto3.client("s3", region_name="us-east-1")
s3.create_bucket(Bucket="mybucket")
bucket_name = "mybucketasdfljoqwerasdfas"
s3.create_bucket(Bucket=bucket_name)
mpu = s3.create_multipart_upload(Bucket="mybucket", Key="the-key")
mpu = s3.create_multipart_upload(Bucket=bucket_name, Key="the-key")
mpu_id = mpu["UploadId"]
def get_parts(**kwarg):
s3.list_parts(Bucket="mybucket", Key="the-key", UploadId=mpu_id, **kwarg)
s3.list_parts(Bucket=bucket_name, Key="the-key", UploadId=mpu_id, **kwarg)
for value in [-42, 2147483647 + 42]:
with pytest.raises(ClientError) as err:
get_parts(**{"MaxParts": value})
e = err.value.response["Error"]
e["Code"].should.equal("InvalidArgument")
e["Message"].should.equal(
"Argument max-parts must be an integer between 0 and 2147483647"
)
with pytest.raises(ClientError) as err:
get_parts(**{"MaxParts": part_nr})
e = err.value.response["Error"]
e["Code"].should.equal("InvalidArgument")
e["Message"].should.equal(msg)
with pytest.raises(ClientError) as err:
get_parts(**{"PartNumberMarker": value})
e = err.value.response["Error"]
e["Code"].should.equal("InvalidArgument")
e["Message"].should.equal(
"Argument part-number-marker must be an integer between 0 and 2147483647"
)
with pytest.raises(ClientError) as err:
get_parts(**{"PartNumberMarker": part_nr})
e = err.value.response["Error"]
e["Code"].should.equal("InvalidArgument")
e["Message"].should.equal(msg2)
@mock_s3
@reduced_min_part_size
def test_boto3_multipart_list_parts():
s3 = boto3.client("s3", region_name="us-east-1")
s3.create_bucket(Bucket="mybucket")
bucket_name = "mybucketasdfljoqwerasdfas"
s3.create_bucket(Bucket=bucket_name)
mpu = s3.create_multipart_upload(Bucket="mybucket", Key="the-key")
mpu = s3.create_multipart_upload(Bucket=bucket_name, Key="the-key")
mpu_id = mpu["UploadId"]
parts = []
@ -3914,7 +3926,7 @@ def test_boto3_multipart_list_parts():
# Get uploaded parts using default values
uploaded_parts = []
uploaded = s3.list_parts(Bucket="mybucket", Key="the-key", UploadId=mpu_id,)
uploaded = s3.list_parts(Bucket=bucket_name, Key="the-key", UploadId=mpu_id,)
assert uploaded["PartNumberMarker"] == 0
@ -3926,7 +3938,7 @@ def test_boto3_multipart_list_parts():
)
assert uploaded_parts == parts
next_part_number_marker = uploaded["Parts"][-1]["PartNumber"] + 1
next_part_number_marker = uploaded["Parts"][-1]["PartNumber"]
else:
next_part_number_marker = 0
@ -3941,7 +3953,7 @@ def test_boto3_multipart_list_parts():
while "there are parts":
uploaded = s3.list_parts(
Bucket="mybucket",
Bucket=bucket_name,
Key="the-key",
UploadId=mpu_id,
PartNumberMarker=part_number_marker,
@ -3979,7 +3991,7 @@ def test_boto3_multipart_list_parts():
part_size = REDUCED_PART_SIZE + i
body = b"1" * part_size
part = s3.upload_part(
Bucket="mybucket",
Bucket=bucket_name,
Key="the-key",
PartNumber=i,
UploadId=mpu_id,
@ -3997,7 +4009,7 @@ def test_boto3_multipart_list_parts():
get_parts_by_batch(11)
s3.complete_multipart_upload(
Bucket="mybucket",
Bucket=bucket_name,
Key="the-key",
UploadId=mpu_id,
MultipartUpload={"Parts": parts},

View File

@ -3,7 +3,7 @@ from moto import mock_s3
import boto3
import os
import pytest
import sure # pylint: disable=unused-import
import sure # noqa # pylint: disable=unused-import
from .test_s3 import DEFAULT_REGION_NAME
@ -80,3 +80,53 @@ def test_multipart_upload_with_tags():
response = client.get_object_tagging(Bucket=bucket, Key=key)
actual = {t["Key"]: t["Value"] for t in response.get("TagSet", [])}
actual.should.equal({"a": "b"})
@mock_s3
def test_multipart_upload_should_return_part_10000():
bucket = "dummybucket"
s3_client = boto3.client("s3", "us-east-1")
key = "test_file"
s3_client.create_bucket(Bucket=bucket)
mpu = s3_client.create_multipart_upload(Bucket=bucket, Key=key)
mpu_id = mpu["UploadId"]
s3_client.upload_part(
Bucket=bucket, Key=key, PartNumber=1, UploadId=mpu_id, Body="data"
)
s3_client.upload_part(
Bucket=bucket, Key=key, PartNumber=2, UploadId=mpu_id, Body="data"
)
s3_client.upload_part(
Bucket=bucket, Key=key, PartNumber=10000, UploadId=mpu_id, Body="data"
)
all_parts = s3_client.list_parts(Bucket=bucket, Key=key, UploadId=mpu_id)["Parts"]
part_nrs = [part["PartNumber"] for part in all_parts]
part_nrs.should.equal([1, 2, 10000])
@mock_s3
@pytest.mark.parametrize("part_nr", [10001, 10002, 20000])
def test_s3_multipart_upload_cannot_upload_part_over_10000(part_nr):
bucket = "dummy"
s3_client = boto3.client("s3", "us-east-1")
key = "test_file"
s3_client.create_bucket(Bucket=bucket)
mpu = s3_client.create_multipart_upload(Bucket=bucket, Key=key)
mpu_id = mpu["UploadId"]
with pytest.raises(ClientError) as exc:
s3_client.upload_part(
Bucket=bucket, Key=key, PartNumber=part_nr, UploadId=mpu_id, Body="data"
)
err = exc.value.response["Error"]
err["Code"].should.equal("InvalidArgument")
err["Message"].should.equal(
"Part number must be an integer between 1 and 10000, inclusive"
)
err["ArgumentName"].should.equal("partNumber")
err["ArgumentValue"].should.equal(f"{part_nr}")