Techdebt: MyPy S3 (#6235)

Commit c2e3d90fc9 (parent ff48188362)
Author: Bert Blommers
Date: 2023-04-20 16:47:39 +00:00
Committed by: GitHub
18 changed files with 880 additions and 730 deletions

View File

@@ -59,7 +59,7 @@ class RESTError(HTTPException):
        if template in self.templates.keys():
            env = Environment(loader=DictLoader(self.templates))
-            self.description = env.get_template(template).render(
+            self.description: str = env.get_template(template).render(  # type: ignore
                error_type=error_type,
                message=message,
                request_id_tag=self.request_id_tag_name,

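For context, a minimal standalone sketch of the Jinja2 pattern being annotated above; the template text and message are illustrative, not moto's real error templates:

    from jinja2 import DictLoader, Environment

    # get_template(...).render(...) returns a str, which is what the new
    # "self.description: str" annotation in the hunk above relies on.
    templates = {"single_error": "<Error><Message>{{ message }}</Message></Error>"}
    env = Environment(loader=DictLoader(templates))
    description: str = env.get_template("single_error").render(message="boom")
    print(description)  # <Error><Message>boom</Message></Error>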
View File

@@ -1,6 +1,6 @@
from datetime import datetime, timedelta
from moto.moto_api import state_manager
-from typing import List, Tuple
+from typing import List, Tuple, Optional

class ManagedState:
@@ -8,7 +8,7 @@ class ManagedState:
    Subclass this class to configure state-transitions
    """

-    def __init__(self, model_name: str, transitions: List[Tuple[str, str]]):
+    def __init__(self, model_name: str, transitions: List[Tuple[Optional[str], str]]):
        # Indicate the possible transitions for this model
        # Example: [(initializing,queued), (queued, starting), (starting, ready)]
        self._transitions = transitions
@@ -28,7 +28,7 @@ class ManagedState:
        self._tick += 1

    @property
-    def status(self) -> str:
+    def status(self) -> Optional[str]:
        """
        Transitions the status as appropriate before returning
        """
@@ -55,12 +55,12 @@ class ManagedState:
    def status(self, value: str) -> None:
        self._status = value

-    def _get_next_status(self, previous: str) -> str:
+    def _get_next_status(self, previous: Optional[str]) -> Optional[str]:
        return next(
            (nxt for prev, nxt in self._transitions if previous == prev), previous
        )

-    def _get_last_status(self, previous: str) -> str:
+    def _get_last_status(self, previous: Optional[str]) -> Optional[str]:
        next_state = self._get_next_status(previous)
        while next_state != previous:
            previous = next_state

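The widened Tuple[Optional[str], str] lets a transition start from None, i.e. "no status set yet". A hedged sketch of the same lookup shown in the hunk above, with an illustrative transition table:

    from typing import List, Optional, Tuple

    transitions: List[Tuple[Optional[str], str]] = [(None, "CREATING"), ("CREATING", "AVAILABLE")]

    def get_next_status(previous: Optional[str]) -> Optional[str]:
        # Same lookup as _get_next_status: first transition whose source matches, else stay put.
        return next((nxt for prev, nxt in transitions if previous == prev), previous)

    print(get_next_status(None))         # CREATING
    print(get_next_status("CREATING"))   # AVAILABLE
    print(get_next_status("AVAILABLE"))  # AVAILABLE (terminal: no outgoing transition)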
View File

@@ -1,7 +1,10 @@
from collections import OrderedDict
+from typing import Any, Dict, List

-def cfn_to_api_encryption(bucket_encryption_properties):
+def cfn_to_api_encryption(
+    bucket_encryption_properties: Dict[str, Any]
+) -> Dict[str, Any]:
    sse_algorithm = bucket_encryption_properties["ServerSideEncryptionConfiguration"][
        0
@@ -16,14 +19,12 @@ def cfn_to_api_encryption(bucket_encryption_properties):
    rule = OrderedDict(
        {"ApplyServerSideEncryptionByDefault": apply_server_side_encryption_by_default}
    )
-    bucket_encryption = OrderedDict(
-        {"@xmlns": "http://s3.amazonaws.com/doc/2006-03-01/"}
+    return OrderedDict(
+        {"@xmlns": "http://s3.amazonaws.com/doc/2006-03-01/", "Rule": rule}
    )
-    bucket_encryption["Rule"] = rule
-    return bucket_encryption

-def is_replacement_update(properties):
+def is_replacement_update(properties: List[str]) -> bool:
    properties_requiring_replacement_update = ["BucketName", "ObjectLockEnabled"]
    return any(
        [

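A hedged sketch of calling the newly typed helper; the module path and the CloudFormation-style input shape are assumptions, not part of the diff:

    from moto.s3.cloud_formation import cfn_to_api_encryption  # import path assumed

    cfn_props = {
        "ServerSideEncryptionConfiguration": [
            {"ServerSideEncryptionByDefault": {"SSEAlgorithm": "AES256"}}
        ]
    }
    # Per the diff, the result is an OrderedDict with an "@xmlns" key and a "Rule" key.
    api_shape = cfn_to_api_encryption(cfn_props)
    print(api_shape["Rule"])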
View File

@@ -1,4 +1,5 @@
import json
+from typing import Any, Dict, List, Optional, Tuple

from moto.core.exceptions import InvalidNextTokenException
from moto.core.common_models import ConfigQueryModel
@@ -8,15 +9,15 @@ from moto.s3 import s3_backends

class S3ConfigQuery(ConfigQueryModel):
    def list_config_service_resources(
        self,
-        account_id,
-        resource_ids,
-        resource_name,
-        limit,
-        next_token,
-        backend_region=None,
-        resource_region=None,
-        aggregator=None,
-    ):
+        account_id: str,
+        resource_ids: Optional[List[str]],
+        resource_name: Optional[str],
+        limit: int,
+        next_token: Optional[str],
+        backend_region: Optional[str] = None,
+        resource_region: Optional[str] = None,
+        aggregator: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[List[Dict[str, Any]], Optional[str]]:
        # The resource_region only matters for aggregated queries as you can filter on bucket regions for them.
        # For other resource types, you would need to iterate appropriately for the backend_region.
@@ -37,7 +38,7 @@ class S3ConfigQuery(ConfigQueryModel):
        filter_buckets = [resource_name] if resource_name else resource_ids
        for bucket in self.backends[account_id]["global"].buckets.keys():
-            if bucket in filter_buckets:
+            if bucket in filter_buckets:  # type: ignore
                bucket_list.append(bucket)

        # Filter on the proper region if supplied:
@@ -95,26 +96,26 @@ class S3ConfigQuery(ConfigQueryModel):
    def get_config_resource(
        self,
-        account_id,
-        resource_id,
-        resource_name=None,
-        backend_region=None,
-        resource_region=None,
-    ):
+        account_id: str,
+        resource_id: str,
+        resource_name: Optional[str] = None,
+        backend_region: Optional[str] = None,
+        resource_region: Optional[str] = None,
+    ) -> Optional[Dict[str, Any]]:
        # Get the bucket:
        bucket = self.backends[account_id]["global"].buckets.get(resource_id, {})
        if not bucket:
-            return
+            return None

        # Are we filtering based on region?
        region_filter = backend_region or resource_region
        if region_filter and bucket.region_name != region_filter:
-            return
+            return None

        # Are we also filtering on bucket name?
        if resource_name and bucket.name != resource_name:
-            return
+            return None

        # Format the bucket to the AWS Config format:
        config_data = bucket.to_config_dict()

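The new Optional[Dict[str, Any]] return type makes the three early exits explicit. A hedged, self-contained mirror of that filter logic, with plain dicts standing in for moto's bucket objects:

    from typing import Any, Dict, Optional

    def get_bucket_config(
        buckets: Dict[str, Any], resource_id: str, region_filter: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        bucket = buckets.get(resource_id)
        if not bucket:
            return None  # missing bucket -> explicit None, as in the diff
        if region_filter and bucket["region"] != region_filter:
            return None  # filtered out by region
        return {"resourceId": resource_id, "awsRegion": bucket["region"]}

    buckets = {"my-bucket": {"region": "us-east-1"}}
    print(get_bucket_config(buckets, "my-bucket"))               # matches
    print(get_bucket_config(buckets, "my-bucket", "eu-west-1"))  # None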
View File

@@ -1,3 +1,4 @@
+from typing import Any, Optional, Union
from moto.core.exceptions import RESTError

ERROR_WITH_BUCKET_NAME = """{% extends 'single_error' %}
@@ -35,7 +36,7 @@ class S3ClientError(RESTError):
    # S3 API uses <RequestID> as the XML tag in response messages
    request_id_tag_name = "RequestID"

-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any):
        kwargs.setdefault("template", "single_error")
        self.templates["bucket_error"] = ERROR_WITH_BUCKET_NAME
        super().__init__(*args, **kwargs)
@@ -44,7 +45,7 @@ class S3ClientError(RESTError):
class InvalidArgumentError(S3ClientError):
    code = 400

-    def __init__(self, message, name, value, *args, **kwargs):
+    def __init__(self, message: str, name: str, value: str, *args: Any, **kwargs: Any):
        kwargs.setdefault("template", "argument_error")
        kwargs["name"] = name
        kwargs["value"] = value
@@ -60,7 +61,7 @@ class AccessForbidden(S3ClientError):
class BucketError(S3ClientError):
-    def __init__(self, *args: str, **kwargs: str):
+    def __init__(self, *args: Any, **kwargs: Any):
        kwargs.setdefault("template", "bucket_error")
        self.templates["bucket_error"] = ERROR_WITH_BUCKET_NAME
        super().__init__(*args, **kwargs)
@@ -69,7 +70,7 @@ class BucketError(S3ClientError):
class BucketAlreadyExists(BucketError):
    code = 409

-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any):
        kwargs.setdefault("template", "bucket_error")
        self.templates["bucket_error"] = ERROR_WITH_BUCKET_NAME
        super().__init__(
@@ -87,16 +88,16 @@ class BucketAlreadyExists(BucketError):
class MissingBucket(BucketError):
    code = 404

-    def __init__(self, *args, **kwargs):
+    def __init__(self, bucket: str):
        super().__init__(
-            "NoSuchBucket", "The specified bucket does not exist", *args, **kwargs
+            "NoSuchBucket", "The specified bucket does not exist", bucket=bucket
        )
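A hedged usage sketch of the narrowed constructor: the bucket name is now an explicit keyword argument instead of travelling through *args/**kwargs (import path assumed):

    from moto.s3.exceptions import MissingBucket  # import path assumed

    try:
        raise MissingBucket(bucket="my-bucket")
    except MissingBucket as err:
        print(err.code)  # 404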
class MissingKey(S3ClientError):
    code = 404

-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs: Any):
        kwargs.setdefault("template", "key_error")
        self.templates["key_error"] = ERROR_WITH_KEY_NAME
        super().__init__("NoSuchKey", "The specified key does not exist.", **kwargs)
@@ -105,16 +106,14 @@ class MissingKey(S3ClientError):
class MissingVersion(S3ClientError):
    code = 404

-    def __init__(self, *args, **kwargs):
-        super().__init__(
-            "NoSuchVersion", "The specified version does not exist.", *args, **kwargs
-        )
+    def __init__(self) -> None:
+        super().__init__("NoSuchVersion", "The specified version does not exist.")

class InvalidVersion(S3ClientError):
    code = 400

-    def __init__(self, version_id, *args, **kwargs):
+    def __init__(self, version_id: str, *args: Any, **kwargs: Any):
        kwargs.setdefault("template", "argument_error")
        kwargs["name"] = "versionId"
        kwargs["value"] = version_id
@@ -127,7 +126,7 @@ class InvalidVersion(S3ClientError):
class ObjectNotInActiveTierError(S3ClientError):
    code = 403

-    def __init__(self, key_name):
+    def __init__(self, key_name: Any):
        super().__init__(
            "ObjectNotInActiveTierError",
            "The source object of the COPY operation is not in the active tier and is only stored in Amazon Glacier.",
@@ -138,105 +137,84 @@ class ObjectNotInActiveTierError(S3ClientError):
class InvalidPartOrder(S3ClientError):
    code = 400

-    def __init__(self, *args, **kwargs):
+    def __init__(self) -> None:
        super().__init__(
            "InvalidPartOrder",
-            (
-                "The list of parts was not in ascending order. The parts "
-                "list must be specified in order by part number."
-            ),
-            *args,
-            **kwargs,
+            "The list of parts was not in ascending order. The parts list must be specified in order by part number.",
        )

class InvalidPart(S3ClientError):
    code = 400

-    def __init__(self, *args, **kwargs):
+    def __init__(self) -> None:
        super().__init__(
            "InvalidPart",
-            (
-                "One or more of the specified parts could not be found. "
-                "The part might not have been uploaded, or the specified "
-                "entity tag might not have matched the part's entity tag."
-            ),
-            *args,
-            **kwargs,
+            "One or more of the specified parts could not be found. The part might not have been uploaded, or the specified entity tag might not have matched the part's entity tag.",
        )

class EntityTooSmall(S3ClientError):
    code = 400

-    def __init__(self, *args, **kwargs):
+    def __init__(self) -> None:
        super().__init__(
            "EntityTooSmall",
            "Your proposed upload is smaller than the minimum allowed object size.",
-            *args,
-            **kwargs,
        )

class InvalidRequest(S3ClientError):
    code = 400

-    def __init__(self, method, *args, **kwargs):
+    def __init__(self, method: str):
        super().__init__(
            "InvalidRequest",
            f"Found unsupported HTTP method in CORS config. Unsupported method is {method}",
-            *args,
-            **kwargs,
        )

class IllegalLocationConstraintException(S3ClientError):
    code = 400

-    def __init__(self, *args, **kwargs):
+    def __init__(self) -> None:
        super().__init__(
            "IllegalLocationConstraintException",
            "The unspecified location constraint is incompatible for the region specific endpoint this request was sent to.",
-            *args,
-            **kwargs,
        )

class MalformedXML(S3ClientError):
    code = 400

-    def __init__(self, *args, **kwargs):
+    def __init__(self) -> None:
        super().__init__(
            "MalformedXML",
            "The XML you provided was not well-formed or did not validate against our published schema",
-            *args,
-            **kwargs,
        )

class MalformedACLError(S3ClientError):
    code = 400

-    def __init__(self, *args, **kwargs):
+    def __init__(self) -> None:
        super().__init__(
            "MalformedACLError",
            "The XML you provided was not well-formed or did not validate against our published schema",
-            *args,
-            **kwargs,
        )

class InvalidTargetBucketForLogging(S3ClientError):
    code = 400

-    def __init__(self, msg):
+    def __init__(self, msg: str):
        super().__init__("InvalidTargetBucketForLogging", msg)

class CrossLocationLoggingProhibitted(S3ClientError):
    code = 403

-    def __init__(self):
+    def __init__(self) -> None:
        super().__init__(
            "CrossLocationLoggingProhibitted", "Cross S3 location logging not allowed."
        )
@@ -245,7 +223,7 @@ class CrossLocationLoggingProhibitted(S3ClientError):
class InvalidMaxPartArgument(S3ClientError):
    code = 400

-    def __init__(self, arg, min_val, max_val):
+    def __init__(self, arg: str, min_val: int, max_val: int):
        error = f"Argument {arg} must be an integer between {min_val} and {max_val}"
        super().__init__("InvalidArgument", error)
@@ -253,97 +231,83 @@ class InvalidMaxPartArgument(S3ClientError):
class InvalidMaxPartNumberArgument(InvalidArgumentError):
    code = 400

-    def __init__(self, value, *args, **kwargs):
+    def __init__(self, value: int):
        error = "Part number must be an integer between 1 and 10000, inclusive"
-        super().__init__(message=error, name="partNumber", value=value, *args, **kwargs)
+        super().__init__(message=error, name="partNumber", value=value)  # type: ignore

class NotAnIntegerException(InvalidArgumentError):
    code = 400

-    def __init__(self, name, value, *args, **kwargs):
+    def __init__(self, name: str, value: int):
        error = f"Provided {name} not an integer or within integer range"
-        super().__init__(message=error, name=name, value=value, *args, **kwargs)
+        super().__init__(message=error, name=name, value=value)  # type: ignore

class InvalidNotificationARN(S3ClientError):
    code = 400

-    def __init__(self, *args, **kwargs):
-        super().__init__(
-            "InvalidArgument", "The ARN is not well formed", *args, **kwargs
-        )
+    def __init__(self) -> None:
+        super().__init__("InvalidArgument", "The ARN is not well formed")

class InvalidNotificationDestination(S3ClientError):
    code = 400

-    def __init__(self, *args, **kwargs):
+    def __init__(self) -> None:
        super().__init__(
            "InvalidArgument",
            "The notification destination service region is not valid for the bucket location constraint",
-            *args,
-            **kwargs,
        )

class InvalidNotificationEvent(S3ClientError):
    code = 400

-    def __init__(self, *args, **kwargs):
+    def __init__(self) -> None:
        super().__init__(
            "InvalidArgument",
            "The event is not supported for notifications",
-            *args,
-            **kwargs,
        )

class InvalidStorageClass(S3ClientError):
    code = 400

-    def __init__(self, *args, **kwargs):
+    def __init__(self, storage: Optional[str]):
        super().__init__(
            "InvalidStorageClass",
            "The storage class you specified is not valid",
-            *args,
-            **kwargs,
+            storage=storage,
        )

class InvalidBucketName(S3ClientError):
    code = 400

-    def __init__(self, *args, **kwargs):
-        super().__init__(
-            "InvalidBucketName", "The specified bucket is not valid.", *args, **kwargs
-        )
+    def __init__(self) -> None:
+        super().__init__("InvalidBucketName", "The specified bucket is not valid.")

class DuplicateTagKeys(S3ClientError):
    code = 400

-    def __init__(self, *args, **kwargs):
-        super().__init__(
-            "InvalidTag",
-            "Cannot provide multiple Tags with the same key",
-            *args,
-            **kwargs,
-        )
+    def __init__(self) -> None:
+        super().__init__("InvalidTag", "Cannot provide multiple Tags with the same key")

class S3AccessDeniedError(S3ClientError):
    code = 403

-    def __init__(self, *args: str, **kwargs: str):
-        super().__init__("AccessDenied", "Access Denied", *args, **kwargs)
+    def __init__(self) -> None:
+        super().__init__("AccessDenied", "Access Denied")

class BucketAccessDeniedError(BucketError):
    code = 403

-    def __init__(self, *args: str, **kwargs: str):
-        super().__init__("AccessDenied", "Access Denied", *args, **kwargs)
+    def __init__(self, bucket: str):
+        super().__init__("AccessDenied", "Access Denied", bucket=bucket)

class S3InvalidTokenError(S3ClientError):
@@ -368,12 +332,11 @@ class S3AclAndGrantError(S3ClientError):
class BucketInvalidTokenError(BucketError):
    code = 400

-    def __init__(self, *args: str, **kwargs: str):
+    def __init__(self, bucket: str):
        super().__init__(
            "InvalidToken",
            "The provided token is malformed or otherwise invalid.",
-            *args,
-            **kwargs,
+            bucket=bucket,
        )
@@ -390,12 +353,11 @@ class S3InvalidAccessKeyIdError(S3ClientError):
class BucketInvalidAccessKeyIdError(S3ClientError):
    code = 403

-    def __init__(self, *args: str, **kwargs: str):
+    def __init__(self, bucket: str):
        super().__init__(
            "InvalidAccessKeyId",
            "The AWS Access Key Id you provided does not exist in our records.",
-            *args,
-            **kwargs,
+            bucket=bucket,
        )
@@ -412,50 +374,45 @@ class S3SignatureDoesNotMatchError(S3ClientError):
class BucketSignatureDoesNotMatchError(S3ClientError):
    code = 403

-    def __init__(self, *args: str, **kwargs: str):
+    def __init__(self, bucket: str):
        super().__init__(
            "SignatureDoesNotMatch",
            "The request signature we calculated does not match the signature you provided. Check your key and signing method.",
-            *args,
-            **kwargs,
+            bucket=bucket,
        )

class NoSuchPublicAccessBlockConfiguration(S3ClientError):
    code = 404

-    def __init__(self, *args, **kwargs):
+    def __init__(self) -> None:
        super().__init__(
            "NoSuchPublicAccessBlockConfiguration",
            "The public access block configuration was not found",
-            *args,
-            **kwargs,
        )

class InvalidPublicAccessBlockConfiguration(S3ClientError):
    code = 400

-    def __init__(self, *args, **kwargs):
+    def __init__(self) -> None:
        super().__init__(
            "InvalidRequest",
            "Must specify at least one configuration.",
-            *args,
-            **kwargs,
        )

class WrongPublicAccessBlockAccountIdError(S3ClientError):
    code = 403

-    def __init__(self):
+    def __init__(self) -> None:
        super().__init__("AccessDenied", "Access Denied")

class NoSystemTags(S3ClientError):
    code = 400

-    def __init__(self):
+    def __init__(self) -> None:
        super().__init__(
            "InvalidTag", "System tags cannot be added/updated by requester"
        )
@@ -464,7 +421,7 @@ class NoSystemTags(S3ClientError):
class NoSuchUpload(S3ClientError):
    code = 404

-    def __init__(self, upload_id, *args, **kwargs):
+    def __init__(self, upload_id: Union[int, str], *args: Any, **kwargs: Any):
        kwargs.setdefault("template", "error_uploadid")
        kwargs["upload_id"] = upload_id
        self.templates["error_uploadid"] = ERROR_WITH_UPLOADID
@@ -479,7 +436,7 @@ class NoSuchUpload(S3ClientError):
class PreconditionFailed(S3ClientError):
    code = 412

-    def __init__(self, failed_condition, **kwargs):
+    def __init__(self, failed_condition: str, **kwargs: Any):
        kwargs.setdefault("template", "condition_error")
        self.templates["condition_error"] = ERROR_WITH_CONDITION_NAME
        super().__init__(
@@ -493,7 +450,7 @@ class PreconditionFailed(S3ClientError):
class InvalidRange(S3ClientError):
    code = 416

-    def __init__(self, range_requested, actual_size, **kwargs):
+    def __init__(self, range_requested: str, actual_size: str, **kwargs: Any):
        kwargs.setdefault("template", "range_error")
        self.templates["range_error"] = ERROR_WITH_RANGE
        super().__init__(
@@ -508,19 +465,16 @@ class InvalidRange(S3ClientError):
class InvalidContinuationToken(S3ClientError):
    code = 400

-    def __init__(self, *args, **kwargs):
+    def __init__(self) -> None:
        super().__init__(
-            "InvalidArgument",
-            "The continuation token provided is incorrect",
-            *args,
-            **kwargs,
+            "InvalidArgument", "The continuation token provided is incorrect"
        )

class InvalidObjectState(BucketError):
    code = 403

-    def __init__(self, storage_class, **kwargs):
+    def __init__(self, storage_class: Optional[str], **kwargs: Any):
        kwargs.setdefault("template", "storage_error")
        self.templates["storage_error"] = ERROR_WITH_STORAGE_CLASS
        super().__init__(
@@ -534,35 +488,35 @@ class InvalidObjectState(BucketError):
class LockNotEnabled(S3ClientError):
    code = 400

-    def __init__(self):
+    def __init__(self) -> None:
        super().__init__("InvalidRequest", "Bucket is missing ObjectLockConfiguration")

class AccessDeniedByLock(S3ClientError):
    code = 400

-    def __init__(self):
+    def __init__(self) -> None:
        super().__init__("AccessDenied", "Access Denied")

class InvalidContentMD5(S3ClientError):
    code = 400

-    def __init__(self):
+    def __init__(self) -> None:
        super().__init__("InvalidContentMD5", "Content MD5 header is invalid")

class BucketNeedsToBeNew(S3ClientError):
    code = 400

-    def __init__(self):
+    def __init__(self) -> None:
        super().__init__("InvalidBucket", "Bucket needs to be empty")

class BucketMustHaveLockeEnabled(S3ClientError):
    code = 400

-    def __init__(self):
+    def __init__(self) -> None:
        super().__init__(
            "InvalidBucketState",
            "Object Lock configuration cannot be enabled on existing buckets",
@@ -572,7 +526,7 @@ class BucketMustHaveLockeEnabled(S3ClientError):
class CopyObjectMustChangeSomething(S3ClientError):
    code = 400

-    def __init__(self):
+    def __init__(self) -> None:
        super().__init__(
            "InvalidRequest",
            "This copy request is illegal because it is trying to copy an object to itself without changing the object's metadata, storage class, website redirect location or encryption attributes.",
@@ -582,27 +536,25 @@ class CopyObjectMustChangeSomething(S3ClientError):
class InvalidFilterRuleName(InvalidArgumentError):
    code = 400

-    def __init__(self, value, *args, **kwargs):
+    def __init__(self, value: str):
        super().__init__(
            "filter rule name must be either prefix or suffix",
            "FilterRule.Name",
            value,
-            *args,
-            **kwargs,
        )

class InvalidTagError(S3ClientError):
    code = 400

-    def __init__(self, value, *args, **kwargs):
-        super().__init__("InvalidTag", value, *args, **kwargs)
+    def __init__(self, value: str):
+        super().__init__("InvalidTag", value)

class ObjectLockConfigurationNotFoundError(S3ClientError):
    code = 404

-    def __init__(self):
+    def __init__(self) -> None:
        super().__init__(
            "ObjectLockConfigurationNotFoundError",
            "Object Lock configuration does not exist for this bucket",

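Several of these exceptions now take no constructor arguments at all; a hedged sketch of a call site after the change (import path assumed, the validation itself is illustrative):

    from moto.s3.exceptions import InvalidPartOrder  # import path assumed

    def check_part_order(part_numbers):
        if part_numbers != sorted(part_numbers):
            raise InvalidPartOrder()  # previously forwarded *args/**kwargs; now a bare call

    check_part_order([1, 2, 3])   # ok
    # check_part_order([2, 1])    # would raise InvalidPartOrder (HTTP 400)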
File diff suppressed because it is too large.

View File

@@ -1,5 +1,6 @@
import json
from datetime import datetime
+from typing import Any, Dict, List

_EVENT_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%f"
@@ -7,7 +8,9 @@ S3_OBJECT_CREATE_COPY = "s3:ObjectCreated:Copy"
S3_OBJECT_CREATE_PUT = "s3:ObjectCreated:Put"

-def _get_s3_event(event_name, bucket, key, notification_id):
+def _get_s3_event(
+    event_name: str, bucket: Any, key: Any, notification_id: str
+) -> Dict[str, List[Dict[str, Any]]]:
    etag = key.etag.replace('"', "")
    # s3:ObjectCreated:Put --> ObjectCreated:Put
    event_name = event_name[3:]
@@ -34,11 +37,11 @@ def _get_s3_event(event_name, bucket, key, notification_id):
    }

-def _get_region_from_arn(arn):
+def _get_region_from_arn(arn: str) -> str:
    return arn.split(":")[3]

-def send_event(account_id, event_name, bucket, key):
+def send_event(account_id: str, event_name: str, bucket: Any, key: Any) -> None:
    if bucket.notification_configuration is None:
        return
@@ -58,7 +61,9 @@ def send_event(account_id, event_name, bucket, key):
        _send_sqs_message(account_id, event_body, queue_name, region_name)

-def _send_sqs_message(account_id, event_body, queue_name, region_name):
+def _send_sqs_message(
+    account_id: str, event_body: Any, queue_name: str, region_name: str
+) -> None:
    try:
        from moto.sqs.models import sqs_backends
@@ -74,7 +79,9 @@ def _send_sqs_message(account_id, event_body, queue_name, region_name):
        pass

-def _invoke_awslambda(account_id, event_body, fn_arn, region_name):
+def _invoke_awslambda(
+    account_id: str, event_body: Any, fn_arn: str, region_name: str
+) -> None:
    try:
        from moto.awslambda.models import lambda_backends
@@ -89,7 +96,7 @@ def _invoke_awslambda(account_id, event_body, fn_arn, region_name):
        pass

-def _get_test_event(bucket_name):
+def _get_test_event(bucket_name: str) -> Dict[str, Any]:
    event_time = datetime.now().strftime(_EVENT_TIME_FORMAT)
    return {
        "Service": "Amazon S3",
@@ -99,7 +106,7 @@ def _get_test_event(bucket_name):
    }

-def send_test_event(account_id, bucket):
+def send_test_event(account_id: str, bucket: Any) -> None:
    arns = [n.arn for n in bucket.notification_configuration.queue]
    for arn in set(arns):
        region_name = _get_region_from_arn(arn)

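A quick check of the ARN helper typed above: the region is the fourth colon-separated field (the example ARN is illustrative):

    def _get_region_from_arn(arn: str) -> str:
        return arn.split(":")[3]

    print(_get_region_from_arn("arn:aws:sqs:us-east-1:123456789012:my-queue"))  # us-east-1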
View File

@@ -1,7 +1,7 @@
import io
import os
import re
-from typing import List, Union
+from typing import Any, Dict, List, Iterator, Union, Tuple, Optional, Type

import urllib.parse
@@ -14,6 +14,7 @@ from urllib.parse import parse_qs, urlparse, unquote, urlencode, urlunparse
import xmltodict

+from moto.core.common_types import TYPE_RESPONSE
from moto.core.responses import BaseResponse
from moto.core.utils import path_url
@@ -53,7 +54,7 @@ from .exceptions import (
    AccessForbidden,
)
from .models import s3_backends, S3Backend
-from .models import get_canned_acl, FakeGrantee, FakeGrant, FakeAcl, FakeKey
+from .models import get_canned_acl, FakeGrantee, FakeGrant, FakeAcl, FakeKey, FakeBucket
from .select_object_content import serialize_select
from .utils import (
    bucket_name_from_url,
@@ -146,13 +147,13 @@ ACTION_MAP = {
}

-def parse_key_name(pth):
+def parse_key_name(pth: str) -> str:
    # strip the first '/' left by urlparse
    return pth[1:] if pth.startswith("/") else pth

class S3Response(BaseResponse):
-    def __init__(self):
+    def __init__(self) -> None:
        super().__init__(service_name="s3")

    @property
@@ -160,10 +161,10 @@ class S3Response(BaseResponse):
        return s3_backends[self.current_account]["global"]

    @property
-    def should_autoescape(self):
+    def should_autoescape(self) -> bool:
        return True

-    def all_buckets(self):
+    def all_buckets(self) -> str:
        self.data["Action"] = "ListAllMyBuckets"
        self._authenticate_and_authorize_s3_action()
@@ -172,7 +173,7 @@ class S3Response(BaseResponse):
        template = self.response_template(S3_ALL_BUCKETS)
        return template.render(buckets=all_buckets)

-    def subdomain_based_buckets(self, request):
+    def subdomain_based_buckets(self, request: Any) -> bool:
        if settings.S3_IGNORE_SUBDOMAIN_BUCKETNAME:
            return False
        host = request.headers.get("host", request.headers.get("Host"))
@@ -224,23 +225,25 @@ class S3Response(BaseResponse):
        )
        return not path_based

-    def is_delete_keys(self):
+    def is_delete_keys(self) -> bool:
        qs = parse_qs(urlparse(self.path).query, keep_blank_values=True)
        return "delete" in qs

-    def parse_bucket_name_from_url(self, request, url):
+    def parse_bucket_name_from_url(self, request: Any, url: str) -> str:
        if self.subdomain_based_buckets(request):
-            return bucket_name_from_url(url)
+            return bucket_name_from_url(url)  # type: ignore
        else:
-            return bucketpath_bucket_name_from_url(url)
+            return bucketpath_bucket_name_from_url(url)  # type: ignore

-    def parse_key_name(self, request, url):
+    def parse_key_name(self, request: Any, url: str) -> str:
        if self.subdomain_based_buckets(request):
            return parse_key_name(url)
        else:
            return bucketpath_parse_key_name(url)

-    def ambiguous_response(self, request, full_url, headers):
+    def ambiguous_response(
+        self, request: Any, full_url: str, headers: Any
+    ) -> TYPE_RESPONSE:
        # Depending on which calling format the client is using, we don't know
        # if this is a bucket or key request so we have to check
        if self.subdomain_based_buckets(request):
@@ -250,7 +253,7 @@ class S3Response(BaseResponse):
            return self.bucket_response(request, full_url, headers)

    @amzn_request_id
-    def bucket_response(self, request, full_url, headers):
+    def bucket_response(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE:  # type: ignore
        self.setup_class(request, full_url, headers, use_raw_body=True)
        bucket_name = self.parse_bucket_name_from_url(request, full_url)
        self.backend.log_incoming_request(request, bucket_name)
@@ -262,7 +265,7 @@ class S3Response(BaseResponse):
        return self._send_response(response)

    @staticmethod
-    def _send_response(response):
+    def _send_response(response: Any) -> TYPE_RESPONSE:  # type: ignore
        if isinstance(response, str):
            return 200, {}, response.encode("utf-8")
        else:
@@ -272,7 +275,9 @@ class S3Response(BaseResponse):
            return status_code, headers, response_content

-    def _bucket_response(self, request, full_url):
+    def _bucket_response(
+        self, request: Any, full_url: str
+    ) -> Union[str, TYPE_RESPONSE]:
        querystring = self._get_querystring(request, full_url)
        method = request.method
        region_name = parse_region_from_url(full_url, use_default_region=False)
@@ -309,7 +314,7 @@ class S3Response(BaseResponse):
        )

    @staticmethod
-    def _get_querystring(request, full_url):
+    def _get_querystring(request: Any, full_url: str) -> Dict[str, Any]:  # type: ignore[misc]
        # Flask's Request has the querystring already parsed
        # In ServerMode, we can use this, instead of manually parsing this
        if hasattr(request, "args"):
@@ -330,10 +335,11 @@ class S3Response(BaseResponse):
            # Workaround - manually reverse the encoding.
            # Keep the + encoded, ensuring that parse_qsl doesn't replace it, and parse_qsl will unquote it afterwards
            qs = (parsed_url.query or "").replace("+", "%2B")
-            querystring = parse_qs(qs, keep_blank_values=True)
-        return querystring
+            return parse_qs(qs, keep_blank_values=True)

-    def _bucket_response_head(self, bucket_name, querystring):
+    def _bucket_response_head(
+        self, bucket_name: str, querystring: Dict[str, Any]
+    ) -> TYPE_RESPONSE:
        self._set_action("BUCKET", "HEAD", querystring)
        self._authenticate_and_authorize_s3_action()
@@ -347,7 +353,7 @@ class S3Response(BaseResponse):
            return 404, {}, ""
        return 200, {"x-amz-bucket-region": bucket.region_name}, ""

-    def _set_cors_headers(self, headers, bucket):
+    def _set_cors_headers(self, headers: Dict[str, str], bucket: FakeBucket) -> None:
        """
        TODO: smarter way of matching the right CORS rule:
        See https://docs.aws.amazon.com/AmazonS3/latest/userguide/cors.html
@@ -372,8 +378,8 @@ class S3Response(BaseResponse):
                )
            if cors_rule.allowed_origins is not None:
                origin = headers.get("Origin")
-                if cors_matches_origin(origin, cors_rule.allowed_origins):
-                    self.response_headers["Access-Control-Allow-Origin"] = origin
+                if cors_matches_origin(origin, cors_rule.allowed_origins):  # type: ignore
+                    self.response_headers["Access-Control-Allow-Origin"] = origin  # type: ignore
                else:
                    raise AccessForbidden(
                        "CORSResponse: This CORS request is not allowed. This is usually because the evalution of Origin, request method / Access-Control-Request-Method or Access-Control-Request-Headers are not whitelisted by the resource's CORS spec."
@@ -391,23 +397,24 @@ class S3Response(BaseResponse):
                    cors_rule.max_age_seconds
                )

-    def _response_options(self, headers, bucket_name):
+    def _response_options(
+        self, headers: Dict[str, str], bucket_name: str
+    ) -> TYPE_RESPONSE:
        # Return 200 with the headers from the bucket CORS configuration
        self._authenticate_and_authorize_s3_action()
        try:
            bucket = self.backend.head_bucket(bucket_name)
        except MissingBucket:
-            return (
-                403,
-                {},
-                "",
-            )  # AWS S3 seems to return 403 on OPTIONS and 404 on GET/HEAD
+            # AWS S3 seems to return 403 on OPTIONS and 404 on GET/HEAD
+            return 403, {}, ""

        self._set_cors_headers(headers, bucket)

        return 200, self.response_headers, ""

-    def _bucket_response_get(self, bucket_name, querystring):
+    def _bucket_response_get(
+        self, bucket_name: str, querystring: Dict[str, Any]
+    ) -> Union[str, TYPE_RESPONSE]:
        self._set_action("BUCKET", "GET", querystring)
        self._authenticate_and_authorize_s3_action()
@@ -445,7 +452,7 @@ class S3Response(BaseResponse):
                account_id=self.current_account,
            )
        elif "location" in querystring:
-            location = self.backend.get_bucket_location(bucket_name)
+            location: Optional[str] = self.backend.get_bucket_location(bucket_name)
            template = self.response_template(S3_BUCKET_LOCATION)

            # us-east-1 is different - returns a None location
@@ -477,7 +484,7 @@ class S3Response(BaseResponse):
            if not website_configuration:
                template = self.response_template(S3_NO_BUCKET_WEBSITE_CONFIG)
                return 404, {}, template.render(bucket_name=bucket_name)
-            return 200, {}, website_configuration
+            return 200, {}, website_configuration  # type: ignore
        elif "acl" in querystring:
            acl = self.backend.get_bucket_acl(bucket_name)
            template = self.response_template(S3_OBJECT_ACL_RESPONSE)
@@ -615,7 +622,9 @@ class S3Response(BaseResponse):
            ),
        )

-    def _set_action(self, action_resource_type, method, querystring):
+    def _set_action(
+        self, action_resource_type: str, method: str, querystring: Dict[str, Any]
+    ) -> None:
        action_set = False
        for action_in_querystring, action in ACTION_MAP[action_resource_type][
            method
@@ -626,7 +635,9 @@ class S3Response(BaseResponse):
        if not action_set:
            self.data["Action"] = ACTION_MAP[action_resource_type][method]["DEFAULT"]

-    def _handle_list_objects_v2(self, bucket_name, querystring):
+    def _handle_list_objects_v2(
+        self, bucket_name: str, querystring: Dict[str, Any]
+    ) -> str:
        template = self.response_template(S3_BUCKET_GET_RESPONSE_V2)
        bucket = self.backend.get_bucket(bucket_name)
@@ -678,7 +689,7 @@ class S3Response(BaseResponse):
        )

    @staticmethod
-    def _split_truncated_keys(truncated_keys):
+    def _split_truncated_keys(truncated_keys: Any) -> Any:  # type: ignore[misc]
        result_keys = []
        result_folders = []
        for key in truncated_keys:
@@ -688,7 +699,7 @@ class S3Response(BaseResponse):
                result_folders.append(key)
        return result_keys, result_folders

-    def _get_results_from_token(self, result_keys, token):
+    def _get_results_from_token(self, result_keys: Any, token: Any) -> Any:
        continuation_index = 0
        for key in result_keys:
            if (key.name if isinstance(key, FakeKey) else key) > token:
@@ -696,22 +707,22 @@ class S3Response(BaseResponse):
            continuation_index += 1
        return result_keys[continuation_index:]

-    def _truncate_result(self, result_keys, max_keys):
+    def _truncate_result(self, result_keys: Any, max_keys: int) -> Any:
        if max_keys == 0:
            result_keys = []
            is_truncated = True
            next_continuation_token = None
        elif len(result_keys) > max_keys:
-            is_truncated = "true"
+            is_truncated = "true"  # type: ignore
            result_keys = result_keys[:max_keys]
            item = result_keys[-1]
            next_continuation_token = item.name if isinstance(item, FakeKey) else item
        else:
-            is_truncated = "false"
+            is_truncated = "false"  # type: ignore
            next_continuation_token = None
        return result_keys, is_truncated, next_continuation_token
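The truncation logic above is small enough to illustrate directly; a hedged sketch with plain strings standing in for FakeKey objects (the diff keeps the mixed bool/str is_truncated values and silences mypy with # type: ignore):

    def truncate(result_keys, max_keys):
        if max_keys == 0:
            return [], True, None
        if len(result_keys) > max_keys:
            kept = result_keys[:max_keys]
            return kept, "true", kept[-1]  # continuation token = last key returned
        return result_keys, "false", None

    print(truncate(["a", "b", "c"], 2))  # (['a', 'b'], 'true', 'b')
    print(truncate(["a"], 2))            # (['a'], 'false', None)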
-    def _body_contains_location_constraint(self, body):
+    def _body_contains_location_constraint(self, body: bytes) -> bool:
        if body:
            try:
                xmltodict.parse(body)["CreateBucketConfiguration"]["LocationConstraint"]
@@ -720,7 +731,7 @@ class S3Response(BaseResponse):
                pass
        return False

-    def _create_bucket_configuration_is_empty(self, body):
+    def _create_bucket_configuration_is_empty(self, body: bytes) -> bool:
        if body:
            try:
                create_bucket_configuration = xmltodict.parse(body)[
@@ -733,13 +744,19 @@ class S3Response(BaseResponse):
                pass
        return False

-    def _parse_pab_config(self):
+    def _parse_pab_config(self) -> Dict[str, Any]:
        parsed_xml = xmltodict.parse(self.body)
        parsed_xml["PublicAccessBlockConfiguration"].pop("@xmlns", None)

        return parsed_xml

-    def _bucket_response_put(self, request, region_name, bucket_name, querystring):
+    def _bucket_response_put(
+        self,
+        request: Any,
+        region_name: str,
+        bucket_name: str,
+        querystring: Dict[str, Any],
+    ) -> Union[str, TYPE_RESPONSE]:
        if querystring and not request.headers.get("Content-Length"):
            return 411, {}, "Content-Length required"
@@ -754,7 +771,7 @@ class S3Response(BaseResponse):
            self.backend.put_object_lock_configuration(
                bucket_name,
-                config.get("enabled"),
+                config.get("enabled"),  # type: ignore
                config.get("mode"),
                config.get("days"),
                config.get("years"),
@@ -765,7 +782,7 @@ class S3Response(BaseResponse):
            body = self.body.decode("utf-8")
            ver = re.search(r"<Status>([A-Za-z]+)</Status>", body)
            if ver:
-                self.backend.put_bucket_versioning(bucket_name, ver.group(1))
+                self.backend.put_bucket_versioning(bucket_name, ver.group(1))  # type: ignore
                template = self.response_template(S3_BUCKET_VERSIONING)
                return template.render(bucket_versioning_status=ver.group(1))
            else:
@@ -922,7 +939,9 @@ class S3Response(BaseResponse):
        template = self.response_template(S3_BUCKET_CREATE_RESPONSE)
        return 200, {}, template.render(bucket=new_bucket)

-    def _bucket_response_delete(self, bucket_name, querystring):
+    def _bucket_response_delete(
+        self, bucket_name: str, querystring: Dict[str, Any]
+    ) -> TYPE_RESPONSE:
        self._set_action("BUCKET", "DELETE", querystring)
        self._authenticate_and_authorize_s3_action()
@@ -965,7 +984,7 @@ class S3Response(BaseResponse):
        template = self.response_template(S3_DELETE_BUCKET_WITH_ITEMS_ERROR)
        return 409, {}, template.render(bucket=removed_bucket)

-    def _bucket_response_post(self, request, bucket_name):
+    def _bucket_response_post(self, request: Any, bucket_name: str) -> TYPE_RESPONSE:
        response_headers = {}
        if not request.headers.get("Content-Length"):
            return 411, {}, "Content-Length required"
@@ -999,7 +1018,7 @@ class S3Response(BaseResponse):
        if "success_action_redirect" in form:
            redirect = form["success_action_redirect"]
            parts = urlparse(redirect)
-            queryargs = parse_qs(parts.query)
+            queryargs: Dict[str, Any] = parse_qs(parts.query)
            queryargs["key"] = key
            queryargs["bucket"] = bucket_name
            redirect_queryargs = urlencode(queryargs, doseq=True)
@@ -1035,14 +1054,16 @@ class S3Response(BaseResponse):
        return status_code, response_headers, ""

    @staticmethod
-    def _get_path(request):
+    def _get_path(request: Any) -> str:  # type: ignore[misc]
        return (
            request.full_path
            if hasattr(request, "full_path")
            else path_url(request.url)
        )

-    def _bucket_response_delete_keys(self, bucket_name, authenticated=True):
+    def _bucket_response_delete_keys(
+        self, bucket_name: str, authenticated: bool = True
+    ) -> TYPE_RESPONSE:
        template = self.response_template(S3_DELETE_KEYS_RESPONSE)

        body_dict = xmltodict.parse(self.body, strip_whitespace=False)
@@ -1068,14 +1089,16 @@ class S3Response(BaseResponse):
            template.render(deleted=deleted_objects, delete_errors=errors),
        )

-    def _handle_range_header(self, request, response_headers, response_content):
+    def _handle_range_header(
+        self, request: Any, response_headers: Dict[str, Any], response_content: Any
+    ) -> TYPE_RESPONSE:
        length = len(response_content)
        last = length - 1
        _, rspec = request.headers.get("range").split("=")
        if "," in rspec:
            raise NotImplementedError("Multiple range specifiers not supported")

-        def toint(i):
+        def toint(i: Any) -> Optional[int]:
            return int(i) if i else None

        begin, end = map(toint, rspec.split("-"))
@@ -1095,7 +1118,7 @@ class S3Response(BaseResponse):
        response_headers["content-length"] = len(content)
        return 206, response_headers, content
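A hedged sketch of the "bytes=begin-end" parsing that _handle_range_header starts with; the suffix-range and clamping behaviour here follow the HTTP Range spec and are illustrative, not copied from moto:

    from typing import Optional, Tuple

    def parse_range(range_header: str, length: int) -> Tuple[int, int]:
        _, rspec = range_header.split("=")

        def toint(i: str) -> Optional[int]:
            return int(i) if i else None

        begin, end = map(toint, rspec.split("-"))
        if begin is None:  # suffix form: bytes=-N means the last N bytes
            begin, end = length - end, length - 1
        return begin, min(end if end is not None else length - 1, length - 1)

    print(parse_range("bytes=0-4", 10))  # (0, 4)
    print(parse_range("bytes=5-", 10))   # (5, 9)
    print(parse_range("bytes=-3", 10))   # (7, 9)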
-    def _handle_v4_chunk_signatures(self, body, content_length):
+    def _handle_v4_chunk_signatures(self, body: bytes, content_length: int) -> bytes:
        body_io = io.BytesIO(body)
        new_body = bytearray(content_length)
        pos = 0
@@ -1110,7 +1133,7 @@ class S3Response(BaseResponse):
            line = body_io.readline()
        return bytes(new_body)

-    def _handle_encoded_body(self, body, content_length):
+    def _handle_encoded_body(self, body: bytes, content_length: int) -> bytes:
        body_io = io.BytesIO(body)
        # first line should equal '{content_length}\r\n
        body_io.readline()
@@ -1120,12 +1143,12 @@ class S3Response(BaseResponse):
        # amz-checksum-sha256:<..>\r\n

    @amzn_request_id
-    def key_response(self, request, full_url, headers):
+    def key_response(self, request: Any, full_url: str, headers: Dict[str, Any]) -> TYPE_RESPONSE:  # type: ignore[misc]
        # Key and Control are lumped in because splitting out the regex is too much of a pain :/
        self.setup_class(request, full_url, headers, use_raw_body=True)
        bucket_name = self.parse_bucket_name_from_url(request, full_url)
        self.backend.log_incoming_request(request, bucket_name)
-        response_headers = {}
+        response_headers: Dict[str, Any] = {}

        try:
            response = self._key_response(request, full_url, self.headers)
@@ -1151,7 +1174,9 @@ class S3Response(BaseResponse):
            return s3error.code, {}, s3error.description
        return status_code, response_headers, response_content

-    def _key_response(self, request, full_url, headers):
+    def _key_response(
+        self, request: Any, full_url: str, headers: Dict[str, Any]
+    ) -> TYPE_RESPONSE:
        parsed_url = urlparse(full_url)
        query = parse_qs(parsed_url.query, keep_blank_values=True)
        method = request.method
@@ -1182,7 +1207,7 @@ class S3Response(BaseResponse):
            from moto.iam.access_control import PermissionResult

            action = f"s3:{method.upper()[0]}{method.lower()[1:]}Object"
-            bucket_permissions = bucket.get_permission(action, resource)
+            bucket_permissions = bucket.get_permission(action, resource)  # type: ignore
            if bucket_permissions == PermissionResult.DENIED:
                return 403, {}, ""
@@ -1255,11 +1280,17 @@ class S3Response(BaseResponse):
            f"Method {method} has not been implemented in the S3 backend yet"
        )

-    def _key_response_get(self, bucket_name, query, key_name, headers):
+    def _key_response_get(
+        self,
+        bucket_name: str,
+        query: Dict[str, Any],
+        key_name: str,
+        headers: Dict[str, Any],
+    ) -> TYPE_RESPONSE:
        self._set_action("KEY", "GET", query)
        self._authenticate_and_authorize_s3_action()

-        response_headers = {}
+        response_headers: Dict[str, Any] = {}
        if query.get("uploadId"):
            upload_id = query["uploadId"][0]
@@ -1287,7 +1318,7 @@ class S3Response(BaseResponse):
            )
            next_part_number_marker = parts[-1].name if parts else 0
            is_truncated = len(parts) != 0 and self.backend.is_truncated(
-                bucket_name, upload_id, next_part_number_marker
+                bucket_name, upload_id, next_part_number_marker  # type: ignore
            )

            template = self.response_template(S3_MULTIPART_LIST_RESPONSE)
@@ -1355,7 +1386,7 @@ class S3Response(BaseResponse):
            attributes_to_get = headers.get("x-amz-object-attributes", "").split(",")
            response_keys = self.backend.get_object_attributes(key, attributes_to_get)

-            if key.version_id == "null":
+            if key.version_id == "null":  # type: ignore
                response_headers.pop("x-amz-version-id")
            response_headers["Last-Modified"] = key.last_modified_ISO8601
@@ -1367,11 +1398,18 @@ class S3Response(BaseResponse):
        response_headers.update({"AcceptRanges": "bytes"})
        return 200, response_headers, key.value

-    def _key_response_put(self, request, body, bucket_name, query, key_name):
+    def _key_response_put(
+        self,
+        request: Any,
+        body: bytes,
+        bucket_name: str,
+        query: Dict[str, Any],
+        key_name: str,
+    ) -> TYPE_RESPONSE:
        self._set_action("KEY", "PUT", query)
        self._authenticate_and_authorize_s3_action()

-        response_headers = {}
+        response_headers: Dict[str, Any] = {}
        if query.get("uploadId") and query.get("partNumber"):
            upload_id = query["uploadId"][0]
            part_number = int(query["partNumber"][0])
@@ -1382,7 +1420,7 @@ class S3Response(BaseResponse):
            copy_source_parsed = urlparse(copy_source)
            src_bucket, src_key = copy_source_parsed.path.lstrip("/").split("/", 1)
            src_version_id = parse_qs(copy_source_parsed.query).get(
-                "versionId", [None]
+                "versionId", [None]  # type: ignore
            )[0]
            src_range = request.headers.get("x-amz-copy-source-range", "").split(
                "bytes="
@@ -1515,9 +1553,11 @@ class S3Response(BaseResponse):
                version_id = query["versionId"][0]
            else:
                version_id = None
-            key = self.backend.get_object(bucket_name, key_name, version_id=version_id)
+            key_to_tag = self.backend.get_object(
+                bucket_name, key_name, version_id=version_id
+            )
            tagging = self._tagging_from_xml(body)
-            self.backend.set_key_tags(key, tagging, key_name)
+            self.backend.set_key_tags(key_to_tag, tagging, key_name)
            return 200, response_headers, ""

        if "x-amz-copy-source" in request.headers:
@@ -1532,21 +1572,21 @@ class S3Response(BaseResponse):
                unquote(copy_source_parsed.path).lstrip("/").split("/", 1)
            )
            src_version_id = parse_qs(copy_source_parsed.query).get(
-                "versionId", [None]
+                "versionId", [None]  # type: ignore
)[0] )[0]
key = self.backend.get_object( key_to_copy = self.backend.get_object(
src_bucket, src_key, version_id=src_version_id, key_is_clean=True src_bucket, src_key, version_id=src_version_id, key_is_clean=True
) )
if key is not None: if key_to_copy is not None:
if key.storage_class in ARCHIVE_STORAGE_CLASSES: if key_to_copy.storage_class in ARCHIVE_STORAGE_CLASSES:
if key.response_dict.get( if key_to_copy.response_dict.get(
"x-amz-restore" "x-amz-restore"
) is None or 'ongoing-request="true"' in key.response_dict.get( ) is None or 'ongoing-request="true"' in key_to_copy.response_dict.get( # type: ignore
"x-amz-restore" "x-amz-restore"
): ):
raise ObjectNotInActiveTierError(key) raise ObjectNotInActiveTierError(key_to_copy)
bucket_key_enabled = ( bucket_key_enabled = (
request.headers.get( request.headers.get(
@ -1558,7 +1598,7 @@ class S3Response(BaseResponse):
mdirective = request.headers.get("x-amz-metadata-directive") mdirective = request.headers.get("x-amz-metadata-directive")
self.backend.copy_object( self.backend.copy_object(
key, key_to_copy,
bucket_name, bucket_name,
key_name, key_name,
storage=storage_class, storage=storage_class,
@ -1571,7 +1611,7 @@ class S3Response(BaseResponse):
else: else:
raise MissingKey(key=src_key) raise MissingKey(key=src_key)
new_key = self.backend.get_object(bucket_name, key_name) new_key: FakeKey = self.backend.get_object(bucket_name, key_name) # type: ignore
if mdirective is not None and mdirective == "REPLACE": if mdirective is not None and mdirective == "REPLACE":
metadata = metadata_from_headers(request.headers) metadata = metadata_from_headers(request.headers)
new_key.set_metadata(metadata, replace=True) new_key.set_metadata(metadata, replace=True)
@ -1612,11 +1652,17 @@ class S3Response(BaseResponse):
response_headers.update(new_key.response_dict) response_headers.update(new_key.response_dict)
return 200, response_headers, "" return 200, response_headers, ""
def _key_response_head(self, bucket_name, query, key_name, headers): def _key_response_head(
self,
bucket_name: str,
query: Dict[str, Any],
key_name: str,
headers: Dict[str, Any],
) -> TYPE_RESPONSE:
self._set_action("KEY", "HEAD", query) self._set_action("KEY", "HEAD", query)
self._authenticate_and_authorize_s3_action() self._authenticate_and_authorize_s3_action()
response_headers = {} response_headers: Dict[str, Any] = {}
version_id = query.get("versionId", [None])[0] version_id = query.get("versionId", [None])[0]
if version_id and not self.backend.get_bucket(bucket_name).is_versioned: if version_id and not self.backend.get_bucket(bucket_name).is_versioned:
return 400, response_headers, "" return 400, response_headers, ""
@ -1654,16 +1700,21 @@ class S3Response(BaseResponse):
if part_number: if part_number:
full_key = self.backend.head_object(bucket_name, key_name, version_id) full_key = self.backend.head_object(bucket_name, key_name, version_id)
if full_key.multipart: if full_key.multipart: # type: ignore
mp_part_count = str(len(full_key.multipart.partlist)) mp_part_count = str(len(full_key.multipart.partlist)) # type: ignore
response_headers["x-amz-mp-parts-count"] = mp_part_count response_headers["x-amz-mp-parts-count"] = mp_part_count
return 200, response_headers, "" return 200, response_headers, ""
else: else:
return 404, response_headers, "" return 404, response_headers, ""
def _lock_config_from_body(self): def _lock_config_from_body(self) -> Dict[str, Any]:
response_dict = {"enabled": False, "mode": None, "days": None, "years": None} response_dict: Dict[str, Any] = {
"enabled": False,
"mode": None,
"days": None,
"years": None,
}
parsed_xml = xmltodict.parse(self.body) parsed_xml = xmltodict.parse(self.body)
enabled = ( enabled = (
parsed_xml["ObjectLockConfiguration"]["ObjectLockEnabled"] == "Enabled" parsed_xml["ObjectLockConfiguration"]["ObjectLockEnabled"] == "Enabled"
@ -1685,7 +1736,7 @@ class S3Response(BaseResponse):
return response_dict return response_dict
def _acl_from_body(self): def _acl_from_body(self) -> Optional[FakeAcl]:
parsed_xml = xmltodict.parse(self.body) parsed_xml = xmltodict.parse(self.body)
if not parsed_xml.get("AccessControlPolicy"): if not parsed_xml.get("AccessControlPolicy"):
raise MalformedACLError() raise MalformedACLError()
@ -1697,7 +1748,7 @@ class S3Response(BaseResponse):
# If empty, then no ACLs: # If empty, then no ACLs:
if parsed_xml["AccessControlPolicy"].get("AccessControlList") is None: if parsed_xml["AccessControlPolicy"].get("AccessControlList") is None:
return [] return None
if not parsed_xml["AccessControlPolicy"]["AccessControlList"].get("Grant"): if not parsed_xml["AccessControlPolicy"]["AccessControlList"].get("Grant"):
raise MalformedACLError() raise MalformedACLError()
@ -1718,7 +1769,12 @@ class S3Response(BaseResponse):
) )
return FakeAcl(grants) return FakeAcl(grants)
def _get_grants_from_xml(self, grant_list, exception_type, permissions): def _get_grants_from_xml(
self,
grant_list: List[Dict[str, Any]],
exception_type: Type[S3ClientError],
permissions: List[str],
) -> List[FakeGrant]:
grants = [] grants = []
for grant in grant_list: for grant in grant_list:
if grant.get("Permission", "") not in permissions: if grant.get("Permission", "") not in permissions:
@ -1748,7 +1804,7 @@ class S3Response(BaseResponse):
return grants return grants
def _acl_from_headers(self, headers): def _acl_from_headers(self, headers: Dict[str, str]) -> Optional[FakeAcl]:
canned_acl = headers.get("x-amz-acl", "") canned_acl = headers.get("x-amz-acl", "")
grants = [] grants = []
@ -1767,7 +1823,7 @@ class S3Response(BaseResponse):
grantees = [] grantees = []
for key_and_value in value.split(","): for key_and_value in value.split(","):
key, value = re.match( key, value = re.match( # type: ignore
'([^=]+)="?([^"]+)"?', key_and_value.strip() '([^=]+)="?([^"]+)"?', key_and_value.strip()
).groups() ).groups()
if key.lower() == "id": if key.lower() == "id":
@ -1785,7 +1841,7 @@ class S3Response(BaseResponse):
else: else:
return None return None
def _tagging_from_headers(self, headers): def _tagging_from_headers(self, headers: Dict[str, Any]) -> Dict[str, str]:
tags = {} tags = {}
if headers.get("x-amz-tagging"): if headers.get("x-amz-tagging"):
parsed_header = parse_qs(headers["x-amz-tagging"], keep_blank_values=True) parsed_header = parse_qs(headers["x-amz-tagging"], keep_blank_values=True)
@ -1793,7 +1849,7 @@ class S3Response(BaseResponse):
tags[tag[0]] = tag[1][0] tags[tag[0]] = tag[1][0]
return tags return tags
def _tagging_from_xml(self, xml): def _tagging_from_xml(self, xml: bytes) -> Dict[str, str]:
parsed_xml = xmltodict.parse(xml, force_list={"Tag": True}) parsed_xml = xmltodict.parse(xml, force_list={"Tag": True})
tags = {} tags = {}
@ -1802,7 +1858,7 @@ class S3Response(BaseResponse):
return tags return tags
def _bucket_tagging_from_body(self): def _bucket_tagging_from_body(self) -> Dict[str, str]:
parsed_xml = xmltodict.parse(self.body) parsed_xml = xmltodict.parse(self.body)
tags = {} tags = {}
@ -1826,7 +1882,7 @@ class S3Response(BaseResponse):
return tags return tags
def _cors_from_body(self): def _cors_from_body(self) -> List[Dict[str, Any]]:
parsed_xml = xmltodict.parse(self.body) parsed_xml = xmltodict.parse(self.body)
if isinstance(parsed_xml["CORSConfiguration"]["CORSRule"], list): if isinstance(parsed_xml["CORSConfiguration"]["CORSRule"], list):
@ -1834,18 +1890,18 @@ class S3Response(BaseResponse):
return [parsed_xml["CORSConfiguration"]["CORSRule"]] return [parsed_xml["CORSConfiguration"]["CORSRule"]]
def _mode_until_from_body(self): def _mode_until_from_body(self) -> Tuple[Optional[str], Optional[str]]:
parsed_xml = xmltodict.parse(self.body) parsed_xml = xmltodict.parse(self.body)
return ( return (
parsed_xml.get("Retention", None).get("Mode", None), parsed_xml.get("Retention", None).get("Mode", None),
parsed_xml.get("Retention", None).get("RetainUntilDate", None), parsed_xml.get("Retention", None).get("RetainUntilDate", None),
) )
def _legal_hold_status_from_xml(self, xml): def _legal_hold_status_from_xml(self, xml: bytes) -> Dict[str, Any]:
parsed_xml = xmltodict.parse(xml) parsed_xml = xmltodict.parse(xml)
return parsed_xml["LegalHold"]["Status"] return parsed_xml["LegalHold"]["Status"]
def _encryption_config_from_body(self): def _encryption_config_from_body(self) -> Dict[str, Any]:
parsed_xml = xmltodict.parse(self.body) parsed_xml = xmltodict.parse(self.body)
if ( if (
@ -1861,7 +1917,7 @@ class S3Response(BaseResponse):
return parsed_xml["ServerSideEncryptionConfiguration"] return parsed_xml["ServerSideEncryptionConfiguration"]
def _ownership_rule_from_body(self): def _ownership_rule_from_body(self) -> Dict[str, Any]:
parsed_xml = xmltodict.parse(self.body) parsed_xml = xmltodict.parse(self.body)
if not parsed_xml["OwnershipControls"]["Rule"].get("ObjectOwnership"): if not parsed_xml["OwnershipControls"]["Rule"].get("ObjectOwnership"):
@ -1869,7 +1925,7 @@ class S3Response(BaseResponse):
return parsed_xml["OwnershipControls"]["Rule"]["ObjectOwnership"] return parsed_xml["OwnershipControls"]["Rule"]["ObjectOwnership"]
def _logging_from_body(self): def _logging_from_body(self) -> Dict[str, Any]:
parsed_xml = xmltodict.parse(self.body) parsed_xml = xmltodict.parse(self.body)
if not parsed_xml["BucketLoggingStatus"].get("LoggingEnabled"): if not parsed_xml["BucketLoggingStatus"].get("LoggingEnabled"):
@ -1914,7 +1970,7 @@ class S3Response(BaseResponse):
return parsed_xml["BucketLoggingStatus"]["LoggingEnabled"] return parsed_xml["BucketLoggingStatus"]["LoggingEnabled"]
def _notification_config_from_body(self): def _notification_config_from_body(self) -> Dict[str, Any]:
parsed_xml = xmltodict.parse(self.body) parsed_xml = xmltodict.parse(self.body)
if not len(parsed_xml["NotificationConfiguration"]): if not len(parsed_xml["NotificationConfiguration"]):
@ -1989,17 +2045,19 @@ class S3Response(BaseResponse):
return parsed_xml["NotificationConfiguration"] return parsed_xml["NotificationConfiguration"]
def _accelerate_config_from_body(self): def _accelerate_config_from_body(self) -> str:
parsed_xml = xmltodict.parse(self.body) parsed_xml = xmltodict.parse(self.body)
config = parsed_xml["AccelerateConfiguration"] config = parsed_xml["AccelerateConfiguration"]
return config["Status"] return config["Status"]
def _replication_config_from_xml(self, xml): def _replication_config_from_xml(self, xml: str) -> Dict[str, Any]:
parsed_xml = xmltodict.parse(xml, dict_constructor=dict) parsed_xml = xmltodict.parse(xml, dict_constructor=dict)
config = parsed_xml["ReplicationConfiguration"] config = parsed_xml["ReplicationConfiguration"]
return config return config
def _key_response_delete(self, headers, bucket_name, query, key_name): def _key_response_delete(
self, headers: Any, bucket_name: str, query: Dict[str, Any], key_name: str
) -> TYPE_RESPONSE:
self._set_action("KEY", "DELETE", query) self._set_action("KEY", "DELETE", query)
self._authenticate_and_authorize_s3_action() self._authenticate_and_authorize_s3_action()
@ -2024,7 +2082,7 @@ class S3Response(BaseResponse):
response_headers[f"x-amz-{k}"] = response_meta[k] response_headers[f"x-amz-{k}"] = response_meta[k]
return 204, response_headers, "" return 204, response_headers, ""
def _complete_multipart_body(self, body): def _complete_multipart_body(self, body: bytes) -> Iterator[Tuple[int, str]]:
ps = minidom.parseString(body).getElementsByTagName("Part") ps = minidom.parseString(body).getElementsByTagName("Part")
prev = 0 prev = 0
for p in ps: for p in ps:
@ -2033,7 +2091,14 @@ class S3Response(BaseResponse):
raise InvalidPartOrder() raise InvalidPartOrder()
yield (pn, p.getElementsByTagName("ETag")[0].firstChild.wholeText) yield (pn, p.getElementsByTagName("ETag")[0].firstChild.wholeText)
def _key_response_post(self, request, body, bucket_name, query, key_name): def _key_response_post(
self,
request: Any,
body: bytes,
bucket_name: str,
query: Dict[str, Any],
key_name: str,
) -> TYPE_RESPONSE:
self._set_action("KEY", "POST", query) self._set_action("KEY", "POST", query)
self._authenticate_and_authorize_s3_action() self._authenticate_and_authorize_s3_action()
@ -2071,11 +2136,10 @@ class S3Response(BaseResponse):
return 200, response_headers, response return 200, response_headers, response
if query.get("uploadId"): if query.get("uploadId"):
body = self._complete_multipart_body(body) multipart_id = query["uploadId"][0] # type: ignore
multipart_id = query["uploadId"][0]
multipart, value, etag = self.backend.complete_multipart_upload( multipart, value, etag = self.backend.complete_multipart_upload(
bucket_name, multipart_id, body bucket_name, multipart_id, self._complete_multipart_body(body)
) )
if value is None: if value is None:
return 400, {}, "" return 400, {}, ""
@ -2095,7 +2159,7 @@ class S3Response(BaseResponse):
self.backend.put_object_acl(bucket_name, key.name, multipart.acl) self.backend.put_object_acl(bucket_name, key.name, multipart.acl)
template = self.response_template(S3_MULTIPART_COMPLETE_RESPONSE) template = self.response_template(S3_MULTIPART_COMPLETE_RESPONSE)
headers = {} headers: Dict[str, Any] = {}
if key.version_id: if key.version_id:
headers["x-amz-version-id"] = key.version_id headers["x-amz-version-id"] = key.version_id
@ -2116,7 +2180,7 @@ class S3Response(BaseResponse):
elif "restore" in query: elif "restore" in query:
es = minidom.parseString(body).getElementsByTagName("Days") es = minidom.parseString(body).getElementsByTagName("Days")
days = es[0].childNodes[0].wholeText days = es[0].childNodes[0].wholeText
key = self.backend.get_object(bucket_name, key_name) key = self.backend.get_object(bucket_name, key_name) # type: ignore
if key.storage_class not in ARCHIVE_STORAGE_CLASSES: if key.storage_class not in ARCHIVE_STORAGE_CLASSES:
raise InvalidObjectState(storage_class=key.storage_class) raise InvalidObjectState(storage_class=key.storage_class)
r = 202 r = 202
@ -2139,7 +2203,7 @@ class S3Response(BaseResponse):
"Method POST had only been implemented for multipart uploads and restore operations, so far" "Method POST had only been implemented for multipart uploads and restore operations, so far"
) )
def _invalid_headers(self, url, headers): def _invalid_headers(self, url: str, headers: Dict[str, str]) -> bool:
""" """
Verify whether the provided metadata in the URL is also present in the headers Verify whether the provided metadata in the URL is also present in the headers
:param url: .../file.txt&content-type=app%2Fjson&Signature=.. :param url: .../file.txt&content-type=app%2Fjson&Signature=..
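
For context on the return annotations added throughout this file, here is a minimal sketch of the (status, headers, body) tuple these handlers now declare. The local TYPE_RESPONSE alias below is an assumption written out for illustration only; the real alias is imported from moto.core.common_types and may differ in detail.

    from typing import Any, Dict, Tuple

    # Assumed shape of the TYPE_RESPONSE alias: HTTP status, response headers, body.
    TYPE_RESPONSE = Tuple[int, Dict[str, Any], str]

    def fake_key_head(found: bool) -> TYPE_RESPONSE:
        # Mirrors the pattern used by _key_response_head above: build a header
        # dict, then return a status code, the headers, and an (often empty) body.
        response_headers: Dict[str, Any] = {"AcceptRanges": "bytes"}
        if found:
            return 200, response_headers, ""
        return 404, response_headers, ""

    status, headers, body = fake_key_head(found=True)
    assert status == 200 and headers["AcceptRanges"] == "bytes" and body == ""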

View File

@ -1,19 +1,21 @@
import binascii import binascii
import struct import struct
from typing import List from typing import Any, Dict, List, Optional
def parse_query(text_input, query): def parse_query(text_input: str, query: str) -> List[Dict[str, Any]]:
from py_partiql_parser import S3SelectParser from py_partiql_parser import S3SelectParser
return S3SelectParser(source_data={"s3object": text_input}).parse(query) return S3SelectParser(source_data={"s3object": text_input}).parse(query)
def _create_header(key: bytes, value: bytes): def _create_header(key: bytes, value: bytes) -> bytes:
return struct.pack("b", len(key)) + key + struct.pack("!bh", 7, len(value)) + value return struct.pack("b", len(key)) + key + struct.pack("!bh", 7, len(value)) + value
def _create_message(content_type, event_type, payload): def _create_message(
content_type: Optional[bytes], event_type: bytes, payload: bytes
) -> bytes:
headers = _create_header(b":message-type", b"event") headers = _create_header(b":message-type", b"event")
headers += _create_header(b":event-type", event_type) headers += _create_header(b":event-type", event_type)
if content_type is not None: if content_type is not None:
@ -31,23 +33,23 @@ def _create_message(content_type, event_type, payload):
return prelude + prelude_crc + headers + payload + message_crc return prelude + prelude_crc + headers + payload + message_crc
def _create_stats_message(): def _create_stats_message() -> bytes:
stats = b"""<Stats><BytesScanned>24</BytesScanned><BytesProcessed>24</BytesProcessed><BytesReturned>22</BytesReturned></Stats>""" stats = b"""<Stats><BytesScanned>24</BytesScanned><BytesProcessed>24</BytesProcessed><BytesReturned>22</BytesReturned></Stats>"""
return _create_message(content_type=b"text/xml", event_type=b"Stats", payload=stats) return _create_message(content_type=b"text/xml", event_type=b"Stats", payload=stats)
def _create_data_message(payload: bytes): def _create_data_message(payload: bytes) -> bytes:
# https://docs.aws.amazon.com/AmazonS3/latest/API/RESTSelectObjectAppendix.html # https://docs.aws.amazon.com/AmazonS3/latest/API/RESTSelectObjectAppendix.html
return _create_message( return _create_message(
content_type=b"application/octet-stream", event_type=b"Records", payload=payload content_type=b"application/octet-stream", event_type=b"Records", payload=payload
) )
def _create_end_message(): def _create_end_message() -> bytes:
return _create_message(content_type=None, event_type=b"End", payload=b"") return _create_message(content_type=None, event_type=b"End", payload=b"")
def serialize_select(data_list: List[bytes]): def serialize_select(data_list: List[bytes]) -> bytes:
response = b"" response = b""
for data in data_list: for data in data_list:
response += _create_data_message(data + b",") response += _create_data_message(data + b",")
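
The prelude construction is elided by the hunk above, so the sketch below reconstructs the event-stream framing from the visible pieces and the linked RESTSelectObjectAppendix documentation; treat the exact byte layout as an assumption rather than the code hidden by the diff.

    import binascii
    import struct

    def create_header(key: bytes, value: bytes) -> bytes:
        # Same packing as _create_header: 1-byte key length, the key, header-value
        # type 7 (string) with a 2-byte value length, then the value itself.
        return struct.pack("b", len(key)) + key + struct.pack("!bh", 7, len(value)) + value

    def frame_message(headers: bytes, payload: bytes) -> bytes:
        # Assumed framing: total length and headers length (4 bytes each, big-endian),
        # a CRC32 of that prelude, the headers, the payload, and a final CRC32 of
        # everything that precedes it.
        total_length = 4 + 4 + 4 + len(headers) + len(payload) + 4
        prelude = struct.pack("!II", total_length, len(headers))
        prelude_crc = struct.pack("!I", binascii.crc32(prelude))
        message = prelude + prelude_crc + headers + payload
        return message + struct.pack("!I", binascii.crc32(message))

    record = frame_message(create_header(b":event-type", b"Records"), b'{"name": "Jane"},')
    assert struct.unpack("!I", record[:4])[0] == len(record)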

View File

@ -5,7 +5,7 @@ import re
import hashlib import hashlib
from urllib.parse import urlparse, unquote, quote from urllib.parse import urlparse, unquote, quote
from requests.structures import CaseInsensitiveDict from requests.structures import CaseInsensitiveDict
from typing import List, Union, Tuple from typing import Any, Dict, List, Iterator, Union, Tuple, Optional
import sys import sys
from moto.settings import S3_IGNORE_SUBDOMAIN_BUCKETNAME from moto.settings import S3_IGNORE_SUBDOMAIN_BUCKETNAME
@ -38,7 +38,7 @@ STORAGE_CLASS = [
] + ARCHIVE_STORAGE_CLASSES ] + ARCHIVE_STORAGE_CLASSES
def bucket_name_from_url(url): def bucket_name_from_url(url: str) -> Optional[str]: # type: ignore
if S3_IGNORE_SUBDOMAIN_BUCKETNAME: if S3_IGNORE_SUBDOMAIN_BUCKETNAME:
return None return None
domain = urlparse(url).netloc domain = urlparse(url).netloc
@ -75,7 +75,7 @@ REGION_URL_REGEX = re.compile(
) )
def parse_region_from_url(url, use_default_region=True): def parse_region_from_url(url: str, use_default_region: bool = True) -> str:
match = REGION_URL_REGEX.search(url) match = REGION_URL_REGEX.search(url)
if match: if match:
region = match.group("region1") or match.group("region2") region = match.group("region1") or match.group("region2")
@ -84,8 +84,8 @@ def parse_region_from_url(url, use_default_region=True):
return region return region
def metadata_from_headers(headers): def metadata_from_headers(headers: Dict[str, Any]) -> CaseInsensitiveDict: # type: ignore
metadata = CaseInsensitiveDict() metadata = CaseInsensitiveDict() # type: ignore
meta_regex = re.compile(r"^x-amz-meta-([a-zA-Z0-9\-_.]+)$", flags=re.IGNORECASE) meta_regex = re.compile(r"^x-amz-meta-([a-zA-Z0-9\-_.]+)$", flags=re.IGNORECASE)
for header in headers.keys(): for header in headers.keys():
if isinstance(header, str): if isinstance(header, str):
@ -106,32 +106,32 @@ def metadata_from_headers(headers):
return metadata return metadata
def clean_key_name(key_name): def clean_key_name(key_name: str) -> str:
return unquote(key_name) return unquote(key_name)
def undo_clean_key_name(key_name): def undo_clean_key_name(key_name: str) -> str:
return quote(key_name) return quote(key_name)
class _VersionedKeyStore(dict): class _VersionedKeyStore(dict): # type: ignore
"""A simplified/modified version of Django's `MultiValueDict` taken from: """A simplified/modified version of Django's `MultiValueDict` taken from:
https://github.com/django/django/blob/70576740b0bb5289873f5a9a9a4e1a26b2c330e5/django/utils/datastructures.py#L282 https://github.com/django/django/blob/70576740b0bb5289873f5a9a9a4e1a26b2c330e5/django/utils/datastructures.py#L282
""" """
def __sgetitem__(self, key): def __sgetitem__(self, key: str) -> List[Any]:
return super().__getitem__(key) return super().__getitem__(key)
def pop(self, key): def pop(self, key: str) -> None: # type: ignore
for version in self.getlist(key, []): for version in self.getlist(key, []):
version.dispose() version.dispose()
super().pop(key) super().pop(key)
def __getitem__(self, key): def __getitem__(self, key: str) -> Any:
return self.__sgetitem__(key)[-1] return self.__sgetitem__(key)[-1]
def __setitem__(self, key, value): def __setitem__(self, key: str, value: Any) -> Any:
try: try:
current = self.__sgetitem__(key) current = self.__sgetitem__(key)
current.append(value) current.append(value)
@ -140,21 +140,21 @@ class _VersionedKeyStore(dict):
super().__setitem__(key, current) super().__setitem__(key, current)
def get(self, key, default=None): def get(self, key: str, default: Any = None) -> Any:
try: try:
return self[key] return self[key]
except (KeyError, IndexError): except (KeyError, IndexError):
pass pass
return default return default
def getlist(self, key, default=None): def getlist(self, key: str, default: Any = None) -> Any:
try: try:
return self.__sgetitem__(key) return self.__sgetitem__(key)
except (KeyError, IndexError): except (KeyError, IndexError):
pass pass
return default return default
def setlist(self, key, list_): def setlist(self, key: Any, list_: Any) -> Any:
if isinstance(list_, tuple): if isinstance(list_, tuple):
list_ = list(list_) list_ = list(list_)
elif not isinstance(list_, list): elif not isinstance(list_, list):
@ -168,35 +168,35 @@ class _VersionedKeyStore(dict):
super().__setitem__(key, list_) super().__setitem__(key, list_)
def _iteritems(self): def _iteritems(self) -> Iterator[Tuple[str, Any]]:
for key in self._self_iterable(): for key in self._self_iterable():
yield key, self[key] yield key, self[key]
def _itervalues(self): def _itervalues(self) -> Iterator[Any]:
for key in self._self_iterable(): for key in self._self_iterable():
yield self[key] yield self[key]
def _iterlists(self): def _iterlists(self) -> Iterator[Tuple[str, List[Any]]]:
for key in self._self_iterable(): for key in self._self_iterable():
yield key, self.getlist(key) yield key, self.getlist(key)
def item_size(self): def item_size(self) -> int:
size = 0 size = 0
for val in self._self_iterable().values(): for val in self._self_iterable().values():
size += sys.getsizeof(val) size += sys.getsizeof(val)
return size return size
def _self_iterable(self): def _self_iterable(self) -> Dict[str, Any]:
# to enable concurrency, return a copy, to avoid "dictionary changed size during iteration" # to enable concurrency, return a copy, to avoid "dictionary changed size during iteration"
# TODO: look into replacing with a locking mechanism, potentially # TODO: look into replacing with a locking mechanism, potentially
return dict(self) return dict(self)
items = iteritems = _iteritems items = iteritems = _iteritems # type: ignore
lists = iterlists = _iterlists lists = iterlists = _iterlists
values = itervalues = _itervalues values = itervalues = _itervalues # type: ignore
def compute_checksum(body, algorithm): def compute_checksum(body: bytes, algorithm: str) -> bytes:
if algorithm == "SHA1": if algorithm == "SHA1":
hashed_body = _hash(hashlib.sha1, (body,)) hashed_body = _hash(hashlib.sha1, (body,))
elif algorithm == "CRC32" or algorithm == "CRC32C": elif algorithm == "CRC32" or algorithm == "CRC32C":
@ -206,7 +206,7 @@ def compute_checksum(body, algorithm):
return base64.b64encode(hashed_body) return base64.b64encode(hashed_body)
def _hash(fn, args) -> bytes: def _hash(fn: Any, args: Any) -> bytes:
try: try:
return fn(*args, usedforsecurity=False).hexdigest().encode("utf-8") return fn(*args, usedforsecurity=False).hexdigest().encode("utf-8")
except TypeError: except TypeError:
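
A short usage sketch of the _VersionedKeyStore behaviour described in its docstring, inferred from the visible __setitem__/getlist code: assigning to an existing key stacks another version instead of replacing the previous one.

    from moto.s3.utils import _VersionedKeyStore  # module path assumed from the hunks above

    keys = _VersionedKeyStore()
    keys["my-key"] = "version-1"
    keys["my-key"] = "version-2"  # appends a second version rather than overwriting

    # Plain lookups return the newest version; getlist returns every stored version.
    assert keys["my-key"] == "version-2"
    assert keys.getlist("my-key") == ["version-1", "version-2"]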

View File

@ -1,7 +1,8 @@
from typing import Optional
from urllib.parse import urlparse from urllib.parse import urlparse
def bucket_name_from_url(url): def bucket_name_from_url(url: str) -> Optional[str]:
path = urlparse(url).path.lstrip("/") path = urlparse(url).path.lstrip("/")
parts = path.lstrip("/").split("/") parts = path.lstrip("/").split("/")
@ -10,5 +11,5 @@ def bucket_name_from_url(url):
return parts[0] return parts[0]
def parse_key_name(path): def parse_key_name(path: str) -> str:
return "/".join(path.split("/")[2:]) return "/".join(path.split("/")[2:])

View File

@ -2,6 +2,7 @@ import datetime
import json import json
from boto3 import Session from boto3 import Session
from typing import Any, Dict, List, Optional, Tuple
from moto.core.exceptions import InvalidNextTokenException from moto.core.exceptions import InvalidNextTokenException
from moto.core.common_models import ConfigQueryModel from moto.core.common_models import ConfigQueryModel
@ -12,15 +13,15 @@ from moto.s3control import s3control_backends
class S3AccountPublicAccessBlockConfigQuery(ConfigQueryModel): class S3AccountPublicAccessBlockConfigQuery(ConfigQueryModel):
def list_config_service_resources( def list_config_service_resources(
self, self,
account_id, account_id: str,
resource_ids, resource_ids: Optional[List[str]],
resource_name, resource_name: Optional[str],
limit, limit: int,
next_token, next_token: Optional[str],
backend_region=None, backend_region: Optional[str] = None,
resource_region=None, resource_region: Optional[str] = None,
aggregator=None, aggregator: Any = None,
): ) -> Tuple[List[Dict[str, Any]], Optional[str]]:
# For the Account Public Access Block, they are the same for all regions. The resource ID is the AWS account ID # For the Account Public Access Block, they are the same for all regions. The resource ID is the AWS account ID
# There is no resource name -- it should be a blank string "" if provided. # There is no resource name -- it should be a blank string "" if provided.
@ -95,12 +96,12 @@ class S3AccountPublicAccessBlockConfigQuery(ConfigQueryModel):
def get_config_resource( def get_config_resource(
self, self,
account_id, account_id: str,
resource_id, resource_id: str,
resource_name=None, resource_name: Optional[str] = None,
backend_region=None, backend_region: Optional[str] = None,
resource_region=None, resource_region: Optional[str] = None,
): ) -> Optional[Dict[str, Any]]:
# Do we even have this defined? # Do we even have this defined?
backend = self.backends[account_id]["global"] backend = self.backends[account_id]["global"]
@ -116,7 +117,7 @@ class S3AccountPublicAccessBlockConfigQuery(ConfigQueryModel):
# Is the resource ID correct?: # Is the resource ID correct?:
if account_id == resource_id: if account_id == resource_id:
if backend_region: if backend_region:
pab_region = backend_region pab_region: Optional[str] = backend_region
# Invalid region? # Invalid region?
elif resource_region not in regions: elif resource_region not in regions:
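
A hypothetical sketch of the region resolution that get_config_resource performs for the account-wide public access block; the branch hidden by the hunk above is assumed to reject unknown regions, so this illustrates the comment's intent rather than the elided code.

    from typing import Optional
    from boto3 import Session

    regions = Session().get_available_regions("config")

    def resolve_pab_region(
        backend_region: Optional[str] = None, resource_region: Optional[str] = None
    ) -> Optional[str]:
        # The public access block is account-wide, so the "region" is simply
        # whichever of backend_region / resource_region was supplied, provided
        # the latter is a real region.
        if backend_region:
            return backend_region
        if resource_region in regions:
            return resource_region
        return None

    assert resolve_pab_region(backend_region="us-east-1") == "us-east-1"
    assert resolve_pab_region(resource_region="not-a-region") is None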

View File

@ -1,4 +1,4 @@
"""Exceptions raised by the s3control service.""" from typing import Any
from moto.core.exceptions import RESTError from moto.core.exceptions import RESTError
@ -13,7 +13,7 @@ ERROR_WITH_ACCESS_POINT_POLICY = """{% extends 'wrapped_single_error' %}
class S3ControlError(RESTError): class S3ControlError(RESTError):
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any):
kwargs.setdefault("template", "single_error") kwargs.setdefault("template", "single_error")
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -21,7 +21,7 @@ class S3ControlError(RESTError):
class AccessPointNotFound(S3ControlError): class AccessPointNotFound(S3ControlError):
code = 404 code = 404
def __init__(self, name, **kwargs): def __init__(self, name: str, **kwargs: Any):
kwargs.setdefault("template", "ap_not_found") kwargs.setdefault("template", "ap_not_found")
kwargs["name"] = name kwargs["name"] = name
self.templates["ap_not_found"] = ERROR_WITH_ACCESS_POINT_NAME self.templates["ap_not_found"] = ERROR_WITH_ACCESS_POINT_NAME
@ -33,7 +33,7 @@ class AccessPointNotFound(S3ControlError):
class AccessPointPolicyNotFound(S3ControlError): class AccessPointPolicyNotFound(S3ControlError):
code = 404 code = 404
def __init__(self, name, **kwargs): def __init__(self, name: str, **kwargs: Any):
kwargs.setdefault("template", "apf_not_found") kwargs.setdefault("template", "apf_not_found")
kwargs["name"] = name kwargs["name"] = name
self.templates["apf_not_found"] = ERROR_WITH_ACCESS_POINT_POLICY self.templates["apf_not_found"] = ERROR_WITH_ACCESS_POINT_POLICY

View File

@ -1,5 +1,7 @@
from collections import defaultdict from collections import defaultdict
from datetime import datetime from datetime import datetime
from typing import Any, Dict, Optional
from moto.core import BaseBackend, BackendDict, BaseModel from moto.core import BaseBackend, BackendDict, BaseModel
from moto.moto_api._internal import mock_random from moto.moto_api._internal import mock_random
from moto.s3.exceptions import ( from moto.s3.exceptions import (
@ -15,18 +17,18 @@ from .exceptions import AccessPointNotFound, AccessPointPolicyNotFound
class AccessPoint(BaseModel): class AccessPoint(BaseModel):
def __init__( def __init__(
self, self,
account_id, account_id: str,
name, name: str,
bucket, bucket: str,
vpc_configuration, vpc_configuration: Dict[str, Any],
public_access_block_configuration, public_access_block_configuration: Dict[str, Any],
): ):
self.name = name self.name = name
self.alias = f"{name}-{mock_random.get_random_hex(34)}-s3alias" self.alias = f"{name}-{mock_random.get_random_hex(34)}-s3alias"
self.bucket = bucket self.bucket = bucket
self.created = datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f") self.created = datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f")
self.arn = f"arn:aws:s3:us-east-1:{account_id}:accesspoint/{name}" self.arn = f"arn:aws:s3:us-east-1:{account_id}:accesspoint/{name}"
self.policy = None self.policy: Optional[str] = None
self.network_origin = "VPC" if vpc_configuration else "Internet" self.network_origin = "VPC" if vpc_configuration else "Internet"
self.vpc_id = (vpc_configuration or {}).get("VpcId") self.vpc_id = (vpc_configuration or {}).get("VpcId")
pubc = public_access_block_configuration or {} pubc = public_access_block_configuration or {}
@ -37,23 +39,23 @@ class AccessPoint(BaseModel):
"RestrictPublicBuckets": pubc.get("RestrictPublicBuckets", "true"), "RestrictPublicBuckets": pubc.get("RestrictPublicBuckets", "true"),
} }
def delete_policy(self): def delete_policy(self) -> None:
self.policy = None self.policy = None
def set_policy(self, policy): def set_policy(self, policy: str) -> None:
self.policy = policy self.policy = policy
def has_policy(self): def has_policy(self) -> bool:
return self.policy is not None return self.policy is not None
class S3ControlBackend(BaseBackend): class S3ControlBackend(BaseBackend):
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.public_access_block = None self.public_access_block: Optional[PublicAccessBlock] = None
self.access_points = defaultdict(dict) self.access_points: Dict[str, Dict[str, AccessPoint]] = defaultdict(dict)
def get_public_access_block(self, account_id): def get_public_access_block(self, account_id: str) -> PublicAccessBlock:
# The account ID should equal the account id that is set for Moto: # The account ID should equal the account id that is set for Moto:
if account_id != self.account_id: if account_id != self.account_id:
raise WrongPublicAccessBlockAccountIdError() raise WrongPublicAccessBlockAccountIdError()
@ -63,14 +65,16 @@ class S3ControlBackend(BaseBackend):
return self.public_access_block return self.public_access_block
def delete_public_access_block(self, account_id): def delete_public_access_block(self, account_id: str) -> None:
# The account ID should equal the account id that is set for Moto: # The account ID should equal the account id that is set for Moto:
if account_id != self.account_id: if account_id != self.account_id:
raise WrongPublicAccessBlockAccountIdError() raise WrongPublicAccessBlockAccountIdError()
self.public_access_block = None self.public_access_block = None
def put_public_access_block(self, account_id, pub_block_config): def put_public_access_block(
self, account_id: str, pub_block_config: Dict[str, Any]
) -> None:
# The account ID should equal the account id that is set for Moto: # The account ID should equal the account id that is set for Moto:
if account_id != self.account_id: if account_id != self.account_id:
raise WrongPublicAccessBlockAccountIdError() raise WrongPublicAccessBlockAccountIdError()
@ -87,12 +91,12 @@ class S3ControlBackend(BaseBackend):
def create_access_point( def create_access_point(
self, self,
account_id, account_id: str,
name, name: str,
bucket, bucket: str,
vpc_configuration, vpc_configuration: Dict[str, Any],
public_access_block_configuration, public_access_block_configuration: Dict[str, Any],
): ) -> AccessPoint:
access_point = AccessPoint( access_point = AccessPoint(
account_id, account_id,
name, name,
@ -103,29 +107,31 @@ class S3ControlBackend(BaseBackend):
self.access_points[account_id][name] = access_point self.access_points[account_id][name] = access_point
return access_point return access_point
def delete_access_point(self, account_id, name): def delete_access_point(self, account_id: str, name: str) -> None:
self.access_points[account_id].pop(name, None) self.access_points[account_id].pop(name, None)
def get_access_point(self, account_id, name): def get_access_point(self, account_id: str, name: str) -> AccessPoint:
if name not in self.access_points[account_id]: if name not in self.access_points[account_id]:
raise AccessPointNotFound(name) raise AccessPointNotFound(name)
return self.access_points[account_id][name] return self.access_points[account_id][name]
def create_access_point_policy(self, account_id, name, policy): def create_access_point_policy(
self, account_id: str, name: str, policy: str
) -> None:
access_point = self.get_access_point(account_id, name) access_point = self.get_access_point(account_id, name)
access_point.set_policy(policy) access_point.set_policy(policy)
def get_access_point_policy(self, account_id, name): def get_access_point_policy(self, account_id: str, name: str) -> str:
access_point = self.get_access_point(account_id, name) access_point = self.get_access_point(account_id, name)
if access_point.has_policy(): if access_point.has_policy():
return access_point.policy return access_point.policy # type: ignore[return-value]
raise AccessPointPolicyNotFound(name) raise AccessPointPolicyNotFound(name)
def delete_access_point_policy(self, account_id, name): def delete_access_point_policy(self, account_id: str, name: str) -> None:
access_point = self.get_access_point(account_id, name) access_point = self.get_access_point(account_id, name)
access_point.delete_policy() access_point.delete_policy()
def get_access_point_policy_status(self, account_id, name): def get_access_point_policy_status(self, account_id: str, name: str) -> bool:
""" """
We assume the policy status is always public We assume the policy status is always public
""" """

View File

@ -1,23 +1,25 @@
import json import json
import xmltodict import xmltodict
from typing import Any, Dict, Tuple
from moto.core.common_types import TYPE_RESPONSE
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.s3.exceptions import S3ClientError from moto.s3.exceptions import S3ClientError
from moto.s3.responses import S3_PUBLIC_ACCESS_BLOCK_CONFIGURATION from moto.s3.responses import S3_PUBLIC_ACCESS_BLOCK_CONFIGURATION
from moto.utilities.aws_headers import amzn_request_id from moto.utilities.aws_headers import amzn_request_id
from .models import s3control_backends from .models import s3control_backends, S3ControlBackend
class S3ControlResponse(BaseResponse): class S3ControlResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="s3control") super().__init__(service_name="s3control")
@property @property
def backend(self): def backend(self) -> S3ControlBackend:
return s3control_backends[self.current_account]["global"] return s3control_backends[self.current_account]["global"]
@amzn_request_id @amzn_request_id
def public_access_block(self, request, full_url, headers): def public_access_block(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
try: try:
if request.method == "GET": if request.method == "GET":
@ -29,7 +31,7 @@ class S3ControlResponse(BaseResponse):
except S3ClientError as err: except S3ClientError as err:
return err.code, {}, err.description return err.code, {}, err.description
def get_public_access_block(self, request): def get_public_access_block(self, request: Any) -> TYPE_RESPONSE:
account_id = request.headers.get("x-amz-account-id") account_id = request.headers.get("x-amz-account-id")
public_block_config = self.backend.get_public_access_block( public_block_config = self.backend.get_public_access_block(
account_id=account_id account_id=account_id
@ -37,7 +39,7 @@ class S3ControlResponse(BaseResponse):
template = self.response_template(S3_PUBLIC_ACCESS_BLOCK_CONFIGURATION) template = self.response_template(S3_PUBLIC_ACCESS_BLOCK_CONFIGURATION)
return 200, {}, template.render(public_block_config=public_block_config) return 200, {}, template.render(public_block_config=public_block_config)
def put_public_access_block(self, request): def put_public_access_block(self, request: Any) -> TYPE_RESPONSE:
account_id = request.headers.get("x-amz-account-id") account_id = request.headers.get("x-amz-account-id")
data = request.body if hasattr(request, "body") else request.data data = request.body if hasattr(request, "body") else request.data
pab_config = self._parse_pab_config(data) pab_config = self._parse_pab_config(data)
@ -46,18 +48,18 @@ class S3ControlResponse(BaseResponse):
) )
return 201, {}, json.dumps({}) return 201, {}, json.dumps({})
def delete_public_access_block(self, request): def delete_public_access_block(self, request: Any) -> TYPE_RESPONSE:
account_id = request.headers.get("x-amz-account-id") account_id = request.headers.get("x-amz-account-id")
self.backend.delete_public_access_block(account_id=account_id) self.backend.delete_public_access_block(account_id=account_id)
return 204, {}, json.dumps({}) return 204, {}, json.dumps({})
def _parse_pab_config(self, body): def _parse_pab_config(self, body: str) -> Dict[str, Any]:
parsed_xml = xmltodict.parse(body) parsed_xml = xmltodict.parse(body)
parsed_xml["PublicAccessBlockConfiguration"].pop("@xmlns", None) parsed_xml["PublicAccessBlockConfiguration"].pop("@xmlns", None)
return parsed_xml return parsed_xml
def access_point(self, request, full_url, headers): def access_point(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "PUT": if request.method == "PUT":
return self.create_access_point(full_url) return self.create_access_point(full_url)
@ -66,7 +68,7 @@ class S3ControlResponse(BaseResponse):
if request.method == "DELETE": if request.method == "DELETE":
return self.delete_access_point(full_url) return self.delete_access_point(full_url)
def access_point_policy(self, request, full_url, headers): def access_point_policy(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "PUT": if request.method == "PUT":
return self.create_access_point_policy(full_url) return self.create_access_point_policy(full_url)
@ -75,14 +77,14 @@ class S3ControlResponse(BaseResponse):
if request.method == "DELETE": if request.method == "DELETE":
return self.delete_access_point_policy(full_url) return self.delete_access_point_policy(full_url)
def access_point_policy_status(self, request, full_url, headers): def access_point_policy_status(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "PUT": if request.method == "PUT":
return self.create_access_point(full_url) return self.create_access_point(full_url)
if request.method == "GET": if request.method == "GET":
return self.get_access_point_policy_status(full_url) return self.get_access_point_policy_status(full_url)
def create_access_point(self, full_url): def create_access_point(self, full_url: str) -> TYPE_RESPONSE:
account_id, name = self._get_accountid_and_name_from_accesspoint(full_url) account_id, name = self._get_accountid_and_name_from_accesspoint(full_url)
params = xmltodict.parse(self.body)["CreateAccessPointRequest"] params = xmltodict.parse(self.body)["CreateAccessPointRequest"]
bucket = params["Bucket"] bucket = params["Bucket"]
@ -98,43 +100,45 @@ class S3ControlResponse(BaseResponse):
template = self.response_template(CREATE_ACCESS_POINT_TEMPLATE) template = self.response_template(CREATE_ACCESS_POINT_TEMPLATE)
return 200, {}, template.render(access_point=access_point) return 200, {}, template.render(access_point=access_point)
def get_access_point(self, full_url): def get_access_point(self, full_url: str) -> TYPE_RESPONSE:
account_id, name = self._get_accountid_and_name_from_accesspoint(full_url) account_id, name = self._get_accountid_and_name_from_accesspoint(full_url)
access_point = self.backend.get_access_point(account_id=account_id, name=name) access_point = self.backend.get_access_point(account_id=account_id, name=name)
template = self.response_template(GET_ACCESS_POINT_TEMPLATE) template = self.response_template(GET_ACCESS_POINT_TEMPLATE)
return 200, {}, template.render(access_point=access_point) return 200, {}, template.render(access_point=access_point)
def delete_access_point(self, full_url): def delete_access_point(self, full_url: str) -> TYPE_RESPONSE:
account_id, name = self._get_accountid_and_name_from_accesspoint(full_url) account_id, name = self._get_accountid_and_name_from_accesspoint(full_url)
self.backend.delete_access_point(account_id=account_id, name=name) self.backend.delete_access_point(account_id=account_id, name=name)
return 204, {}, "" return 204, {}, ""
def create_access_point_policy(self, full_url): def create_access_point_policy(self, full_url: str) -> TYPE_RESPONSE:
account_id, name = self._get_accountid_and_name_from_policy(full_url) account_id, name = self._get_accountid_and_name_from_policy(full_url)
params = xmltodict.parse(self.body) params = xmltodict.parse(self.body)
policy = params["PutAccessPointPolicyRequest"]["Policy"] policy = params["PutAccessPointPolicyRequest"]["Policy"]
self.backend.create_access_point_policy(account_id, name, policy) self.backend.create_access_point_policy(account_id, name, policy)
return 200, {}, "" return 200, {}, ""
def get_access_point_policy(self, full_url): def get_access_point_policy(self, full_url: str) -> TYPE_RESPONSE:
account_id, name = self._get_accountid_and_name_from_policy(full_url) account_id, name = self._get_accountid_and_name_from_policy(full_url)
policy = self.backend.get_access_point_policy(account_id, name) policy = self.backend.get_access_point_policy(account_id, name)
template = self.response_template(GET_ACCESS_POINT_POLICY_TEMPLATE) template = self.response_template(GET_ACCESS_POINT_POLICY_TEMPLATE)
return 200, {}, template.render(policy=policy) return 200, {}, template.render(policy=policy)
def delete_access_point_policy(self, full_url): def delete_access_point_policy(self, full_url: str) -> TYPE_RESPONSE:
account_id, name = self._get_accountid_and_name_from_policy(full_url) account_id, name = self._get_accountid_and_name_from_policy(full_url)
self.backend.delete_access_point_policy(account_id=account_id, name=name) self.backend.delete_access_point_policy(account_id=account_id, name=name)
return 204, {}, "" return 204, {}, ""
def get_access_point_policy_status(self, full_url): def get_access_point_policy_status(self, full_url: str) -> TYPE_RESPONSE:
account_id, name = self._get_accountid_and_name_from_policy(full_url) account_id, name = self._get_accountid_and_name_from_policy(full_url)
self.backend.get_access_point_policy_status(account_id, name) self.backend.get_access_point_policy_status(account_id, name)
template = self.response_template(GET_ACCESS_POINT_POLICY_STATUS_TEMPLATE) template = self.response_template(GET_ACCESS_POINT_POLICY_STATUS_TEMPLATE)
return 200, {}, template.render() return 200, {}, template.render()
def _get_accountid_and_name_from_accesspoint(self, full_url): def _get_accountid_and_name_from_accesspoint(
self, full_url: str
) -> Tuple[str, str]:
url = full_url url = full_url
if full_url.startswith("http"): if full_url.startswith("http"):
url = full_url.split("://")[1] url = full_url.split("://")[1]
@ -142,7 +146,7 @@ class S3ControlResponse(BaseResponse):
name = url.split("v20180820/accesspoint/")[-1] name = url.split("v20180820/accesspoint/")[-1]
return account_id, name return account_id, name
def _get_accountid_and_name_from_policy(self, full_url): def _get_accountid_and_name_from_policy(self, full_url: str) -> Tuple[str, str]:
url = full_url url = full_url
if full_url.startswith("http"): if full_url.startswith("http"):
url = full_url.split("://")[1] url = full_url.split("://")[1]
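
An illustrative walk-through of the URL parsing done by the two helpers above, using a hypothetical S3 Control endpoint; the account-id extraction step is hidden by the hunk, so the host-based split below is an assumption.

    full_url = "https://123456789012.s3-control.amazonaws.com/v20180820/accesspoint/my-ap"

    url = full_url.split("://")[1]  # strip the scheme, as in the helper
    account_id = url.split(".")[0]  # assumed: the account id is the first host label
    name = url.split("v20180820/accesspoint/")[-1]

    assert (account_id, name) == ("123456789012", "my-ap")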

View File

@ -3,7 +3,7 @@ import os
import pathlib import pathlib
from functools import lru_cache from functools import lru_cache
from typing import Optional from typing import List, Optional
TEST_SERVER_MODE = os.environ.get("TEST_SERVER_MODE", "0").lower() == "true" TEST_SERVER_MODE = os.environ.get("TEST_SERVER_MODE", "0").lower() == "true"
@ -47,7 +47,7 @@ def get_sf_execution_history_type():
return os.environ.get("SF_EXECUTION_HISTORY_TYPE", "SUCCESS") return os.environ.get("SF_EXECUTION_HISTORY_TYPE", "SUCCESS")
def get_s3_custom_endpoints(): def get_s3_custom_endpoints() -> List[str]:
endpoints = os.environ.get("MOTO_S3_CUSTOM_ENDPOINTS") endpoints = os.environ.get("MOTO_S3_CUSTOM_ENDPOINTS")
if endpoints: if endpoints:
return endpoints.split(",") return endpoints.split(",")
@ -57,7 +57,7 @@ def get_s3_custom_endpoints():
S3_UPLOAD_PART_MIN_SIZE = 5242880 S3_UPLOAD_PART_MIN_SIZE = 5242880
def get_s3_default_key_buffer_size(): def get_s3_default_key_buffer_size() -> int:
return int( return int(
os.environ.get( os.environ.get(
"MOTO_S3_DEFAULT_KEY_BUFFER_SIZE", S3_UPLOAD_PART_MIN_SIZE - 1024 "MOTO_S3_DEFAULT_KEY_BUFFER_SIZE", S3_UPLOAD_PART_MIN_SIZE - 1024

View File

@ -77,7 +77,7 @@ def md5_hash(data: Any = None) -> Any:
class LowercaseDict(MutableMapping): class LowercaseDict(MutableMapping):
"""A dictionary that lowercases all keys""" """A dictionary that lowercases all keys"""
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any):
self.store = dict() self.store = dict()
self.update(dict(*args, **kwargs)) # use the free update to set keys self.update(dict(*args, **kwargs)) # use the free update to set keys
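
A small usage sketch of the LowercaseDict described by the docstring above; the import path is assumed from the surrounding hunks.

    from moto.utilities.utils import LowercaseDict

    headers = LowercaseDict({"Content-Type": "application/json"})
    headers["X-Amz-Meta-Name"] = "value"

    # Keys are stored lowercased, so lookups against the lowercase form succeed.
    assert headers["content-type"] == "application/json"
    assert "x-amz-meta-name" in headers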

View File

@ -239,7 +239,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy] [mypy]
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/scheduler files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s3*,moto/scheduler
show_column_numbers=True show_column_numbers=True
show_error_codes = True show_error_codes = True
disable_error_code=abstract disable_error_code=abstract