diff --git a/moto/s3/exceptions.py b/moto/s3/exceptions.py index dfea0dec6..7f12ae885 100644 --- a/moto/s3/exceptions.py +++ b/moto/s3/exceptions.py @@ -584,3 +584,23 @@ class HeadOnDeleteMarker(Exception): def __init__(self, marker: "FakeDeleteMarker"): self.marker = marker + + +class DaysMustNotProvidedForSelectRequest(S3ClientError): + code = 400 + + def __init__(self) -> None: + super().__init__( + "DaysMustNotProvidedForSelectRequest", + "`Days` must not be provided for select requests", + ) + + +class DaysMustProvidedExceptForSelectRequest(S3ClientError): + code = 400 + + def __init__(self) -> None: + super().__init__( + "DaysMustProvidedExceptForSelectRequest", + "`Days` must be provided except for select requests", + ) diff --git a/moto/s3/models.py b/moto/s3/models.py index b96a1581c..7f3722a6d 100644 --- a/moto/s3/models.py +++ b/moto/s3/models.py @@ -37,6 +37,8 @@ from moto.s3.exceptions import ( BucketNeedsToBeNew, CopyObjectMustChangeSomething, CrossLocationLoggingProhibitted, + DaysMustNotProvidedForSelectRequest, + DaysMustProvidedExceptForSelectRequest, EntityTooSmall, HeadOnDeleteMarker, InvalidBucketName, @@ -2882,14 +2884,24 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): for x in query_result ] - def restore_object(self, bucket_name: str, key_name: str, days: str) -> bool: + def restore_object( + self, bucket_name: str, key_name: str, days: Optional[str], type_: Optional[str] + ) -> bool: key = self.get_object(bucket_name, key_name) if not key: raise MissingKey + + if days is None and type_ is None: + raise DaysMustProvidedExceptForSelectRequest() + + if days and type_: + raise DaysMustNotProvidedForSelectRequest() + if key.storage_class not in ARCHIVE_STORAGE_CLASSES: raise InvalidObjectState(storage_class=key.storage_class) had_expiry_date = key.expiry_date is not None - key.restore(int(days)) + if days: + key.restore(int(days)) return had_expiry_date def upload_file(self) -> None: diff --git a/moto/s3/responses.py b/moto/s3/responses.py index a737f74df..7e72a1c15 100644 --- a/moto/s3/responses.py +++ b/moto/s3/responses.py @@ -2272,10 +2272,12 @@ class S3Response(BaseResponse): ) elif "restore" in query: - es = minidom.parseString(body).getElementsByTagName("Days") - days = es[0].childNodes[0].wholeText + params = xmltodict.parse(body)["RestoreRequest"] previously_restored = self.backend.restore_object( - bucket_name, key_name, days + bucket_name, + key_name, + params.get("Days", None), + params.get("Type", None), ) status_code = 200 if previously_restored else 202 return status_code, {}, "" diff --git a/tests/test_s3/test_s3.py b/tests/test_s3/test_s3.py index cbb5c1d16..7fb2664ab 100644 --- a/tests/test_s3/test_s3.py +++ b/tests/test_s3/test_s3.py @@ -627,6 +627,32 @@ def test_cannot_restore_standard_class_object(): ) +@mock_aws +def test_restore_object_invalid_request_params(): + if not settings.TEST_DECORATOR_MODE: + raise SkipTest("Can't set transition directly in ServerMode") + + s3_resource = boto3.resource("s3", region_name=DEFAULT_REGION_NAME) + bucket = s3_resource.Bucket("foobar") + bucket.create() + + key = bucket.put_object(Key="the-key", Body=b"somedata", StorageClass="GLACIER") + + # `Days` must be provided except for select requests + with pytest.raises(ClientError) as exc: + key.restore_object(RestoreRequest={}) + err = exc.value.response["Error"] + assert err["Code"] == "DaysMustProvidedExceptForSelectRequest" + assert err["Message"] == "`Days` must be provided except for select requests" + + # `Days` must not be provided for select requests + with pytest.raises(ClientError) as exc: + key.restore_object(RestoreRequest={"Days": 1, "Type": "SELECT"}) + err = exc.value.response["Error"] + assert err["Code"] == "DaysMustNotProvidedForSelectRequest" + assert err["Message"] == "`Days` must not be provided for select requests" + + @mock_aws def test_get_versioning_status(): s3_resource = boto3.resource("s3", region_name=DEFAULT_REGION_NAME)