From c2e3d90fc96d3b47130a136c5f20806043922257 Mon Sep 17 00:00:00 2001
From: Bert Blommers
Date: Thu, 20 Apr 2023 16:47:39 +0000
Subject: [PATCH] Techdebt: MyPy S3 (#6235)

---
 moto/core/exceptions.py                       |   2 +-
 .../moto_api/_internal/managed_state_model.py |  10 +-
 moto/s3/cloud_formation.py                    |  13 +-
 moto/s3/config.py                             |  39 +-
 moto/s3/exceptions.py                         | 184 ++--
 moto/s3/models.py                             | 833 ++++++++++--------
 moto/s3/notifications.py                      |  21 +-
 moto/s3/responses.py                          | 280 +++---
 moto/s3/select_object_content.py              |  18 +-
 moto/s3/utils.py                              |  48 +-
 moto/s3bucket_path/utils.py                   |   5 +-
 moto/s3control/config.py                      |  33 +-
 moto/s3control/exceptions.py                  |   8 +-
 moto/s3control/models.py                      |  62 +-
 moto/s3control/responses.py                   |  44 +-
 moto/settings.py                              |   6 +-
 moto/utilities/utils.py                       |   2 +-
 setup.cfg                                     |   2 +-
 18 files changed, 880 insertions(+), 730 deletions(-)

diff --git a/moto/core/exceptions.py b/moto/core/exceptions.py
index ee37479ce..547ac0377 100644
--- a/moto/core/exceptions.py
+++ b/moto/core/exceptions.py
@@ -59,7 +59,7 @@ class RESTError(HTTPException):

         if template in self.templates.keys():
             env = Environment(loader=DictLoader(self.templates))
-            self.description = env.get_template(template).render(
+            self.description: str = env.get_template(template).render(  # type: ignore
                 error_type=error_type,
                 message=message,
                 request_id_tag=self.request_id_tag_name,
diff --git a/moto/moto_api/_internal/managed_state_model.py b/moto/moto_api/_internal/managed_state_model.py
index 2d25be598..5e78f76f3 100644
--- a/moto/moto_api/_internal/managed_state_model.py
+++ b/moto/moto_api/_internal/managed_state_model.py
@@ -1,6 +1,6 @@
 from datetime import datetime, timedelta
 from moto.moto_api import state_manager
-from typing import List, Tuple
+from typing import List, Tuple, Optional


 class ManagedState:
@@ -8,7 +8,7 @@ class ManagedState:
     Subclass this class to configure state-transitions
     """

-    def __init__(self, model_name: str, transitions: List[Tuple[str, str]]):
+    def __init__(self, model_name: str, transitions: List[Tuple[Optional[str], str]]):
         # Indicate the possible transitions for this model
         # Example: [(initializing,queued), (queued, starting), (starting, ready)]
         self._transitions = transitions
@@ -28,7 +28,7 @@ class ManagedState:
         self._tick += 1

     @property
-    def status(self) -> str:
+    def status(self) -> Optional[str]:
         """
         Transitions the status as appropriate before returning
         """
@@ -55,12 +55,12 @@ class ManagedState:
     def status(self, value: str) -> None:
         self._status = value

-    def _get_next_status(self, previous: str) -> str:
+    def _get_next_status(self, previous: Optional[str]) -> Optional[str]:
         return next(
             (nxt for prev, nxt in self._transitions if previous == prev), previous
         )

-    def _get_last_status(self, previous: str) -> str:
+    def _get_last_status(self, previous: Optional[str]) -> Optional[str]:
         next_state = self._get_next_status(previous)
         while next_state != previous:
             previous = next_state
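The widened signatures above mean `status` can legitimately be `None`, so downstream code has to narrow the Optional before calling string methods on it. A minimal sketch of a consumer, assuming only what the diff shows; the `JobModel` subclass and its transition table are invented for illustration, and the exact progression depends on how the state manager is configured:

    from typing import Optional

    from moto.moto_api._internal.managed_state_model import ManagedState


    class JobModel(ManagedState):
        # Hypothetical subclass: status is meant to walk PENDING -> RUNNING -> DONE
        def __init__(self) -> None:
            super().__init__(
                model_name="example::job",
                transitions=[("PENDING", "RUNNING"), ("RUNNING", "DONE")],
            )
            self.status = "PENDING"


    job = JobModel()
    current: Optional[str] = job.status  # Optional[str] after this patch
    if current is not None:  # mypy now requires this narrowing
        print(current.lower())
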
diff --git a/moto/s3/cloud_formation.py b/moto/s3/cloud_formation.py
index 0bf6022ef..6331e3721 100644
--- a/moto/s3/cloud_formation.py
+++ b/moto/s3/cloud_formation.py
@@ -1,7 +1,10 @@
 from collections import OrderedDict
+from typing import Any, Dict, List


-def cfn_to_api_encryption(bucket_encryption_properties):
+def cfn_to_api_encryption(
+    bucket_encryption_properties: Dict[str, Any]
+) -> Dict[str, Any]:

     sse_algorithm = bucket_encryption_properties["ServerSideEncryptionConfiguration"][
         0
@@ -16,14 +19,12 @@ def cfn_to_api_encryption(
     rule = OrderedDict(
         {"ApplyServerSideEncryptionByDefault": apply_server_side_encryption_by_default}
     )
-    bucket_encryption = OrderedDict(
-        {"@xmlns": "http://s3.amazonaws.com/doc/2006-03-01/"}
+    return OrderedDict(
+        {"@xmlns": "http://s3.amazonaws.com/doc/2006-03-01/", "Rule": rule}
     )
-    bucket_encryption["Rule"] = rule
-    return bucket_encryption


-def is_replacement_update(properties):
+def is_replacement_update(properties: List[str]) -> bool:
     properties_requiring_replacement_update = ["BucketName", "ObjectLockEnabled"]
     return any(
         [
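For orientation, this is roughly how the converted encryption dict is produced from CloudFormation-style bucket properties. The input keys below are a plausible shape inferred from the lookups in the diff, not something the patch itself ships:

    from moto.s3.cloud_formation import cfn_to_api_encryption

    # Hypothetical CloudFormation BucketEncryption properties
    cfn_props = {
        "ServerSideEncryptionConfiguration": [
            {
                "ServerSideEncryptionByDefault": {
                    "SSEAlgorithm": "aws:kms",
                    "KMSMasterKeyID": "1234abcd-hypothetical-key",
                }
            }
        ]
    }

    result = cfn_to_api_encryption(cfn_props)
    # The return value is an xmlns-tagged dict carrying a single "Rule"
    print(result["Rule"])
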
diff --git a/moto/s3/config.py b/moto/s3/config.py
index acb7d57e3..889c946cd 100644
--- a/moto/s3/config.py
+++ b/moto/s3/config.py
@@ -1,4 +1,5 @@
 import json
+from typing import Any, Dict, List, Optional, Tuple

 from moto.core.exceptions import InvalidNextTokenException
 from moto.core.common_models import ConfigQueryModel
@@ -8,15 +9,15 @@ from moto.s3 import s3_backends
 class S3ConfigQuery(ConfigQueryModel):
     def list_config_service_resources(
         self,
-        account_id,
-        resource_ids,
-        resource_name,
-        limit,
-        next_token,
-        backend_region=None,
-        resource_region=None,
-        aggregator=None,
-    ):
+        account_id: str,
+        resource_ids: Optional[List[str]],
+        resource_name: Optional[str],
+        limit: int,
+        next_token: Optional[str],
+        backend_region: Optional[str] = None,
+        resource_region: Optional[str] = None,
+        aggregator: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[List[Dict[str, Any]], Optional[str]]:
         # The resource_region only matters for aggregated queries as you can filter on bucket regions for them.
         # For other resource types, you would need to iterate appropriately for the backend_region.

@@ -37,7 +38,7 @@ class S3ConfigQuery(ConfigQueryModel):
             filter_buckets = [resource_name] if resource_name else resource_ids

             for bucket in self.backends[account_id]["global"].buckets.keys():
-                if bucket in filter_buckets:
+                if bucket in filter_buckets:  # type: ignore
                     bucket_list.append(bucket)

         # Filter on the proper region if supplied:
@@ -95,26 +96,26 @@ class S3ConfigQuery(ConfigQueryModel):

     def get_config_resource(
         self,
-        account_id,
-        resource_id,
-        resource_name=None,
-        backend_region=None,
-        resource_region=None,
-    ):
+        account_id: str,
+        resource_id: str,
+        resource_name: Optional[str] = None,
+        backend_region: Optional[str] = None,
+        resource_region: Optional[str] = None,
+    ) -> Optional[Dict[str, Any]]:

         # Get the bucket:
         bucket = self.backends[account_id]["global"].buckets.get(resource_id, {})

         if not bucket:
-            return
+            return None

         # Are we filtering based on region?
         region_filter = backend_region or resource_region
         if region_filter and bucket.region_name != region_filter:
-            return
+            return None

         # Are we also filtering on bucket name?
         if resource_name and bucket.name != resource_name:
-            return
+            return None

         # Format the bucket to the AWS Config format:
         config_data = bucket.to_config_dict()
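The explicit `return None` lines formalize that `get_config_resource` is `Optional[Dict[str, Any]]`, so callers branch on the result instead of catching a bare `return`. A hedged usage sketch; the module-level instance name `s3_config_query` is assumed from moto's usual layout, and the bucket name is made up:

    from moto.s3.config import s3_config_query  # assumed module-level query instance

    item = s3_config_query.get_config_resource(
        account_id="123456789012", resource_id="some-bucket"
    )
    if item is None:
        # Bucket missing, or filtered out by region/name
        print("no configuration item")
    else:
        print(item.get("resourceId"))
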
diff --git a/moto/s3/exceptions.py b/moto/s3/exceptions.py
index 7345f929f..e6571f7c7 100644
--- a/moto/s3/exceptions.py
+++ b/moto/s3/exceptions.py
@@ -1,3 +1,4 @@
+from typing import Any, Optional, Union
 from moto.core.exceptions import RESTError

 ERROR_WITH_BUCKET_NAME = """{% extends 'single_error' %}
@@ -35,7 +36,7 @@ class S3ClientError(RESTError):
     # S3 API uses <RequestID> as the XML tag in response messages
     request_id_tag_name = "RequestID"

-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any):
         kwargs.setdefault("template", "single_error")
         self.templates["bucket_error"] = ERROR_WITH_BUCKET_NAME
         super().__init__(*args, **kwargs)
@@ -44,7 +45,7 @@ class S3ClientError(RESTError):
 class InvalidArgumentError(S3ClientError):
     code = 400

-    def __init__(self, message, name, value, *args, **kwargs):
+    def __init__(self, message: str, name: str, value: str, *args: Any, **kwargs: Any):
         kwargs.setdefault("template", "argument_error")
         kwargs["name"] = name
         kwargs["value"] = value
@@ -60,7 +61,7 @@ class AccessForbidden(S3ClientError):


 class BucketError(S3ClientError):
-    def __init__(self, *args: str, **kwargs: str):
+    def __init__(self, *args: Any, **kwargs: Any):
         kwargs.setdefault("template", "bucket_error")
         self.templates["bucket_error"] = ERROR_WITH_BUCKET_NAME
         super().__init__(*args, **kwargs)
@@ -69,7 +70,7 @@ class BucketError(S3ClientError):
 class BucketAlreadyExists(BucketError):
     code = 409

-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any):
         kwargs.setdefault("template", "bucket_error")
         self.templates["bucket_error"] = ERROR_WITH_BUCKET_NAME
         super().__init__(
@@ -87,16 +88,16 @@ class BucketAlreadyExists(BucketError):
 class MissingBucket(BucketError):
     code = 404

-    def __init__(self, *args, **kwargs):
+    def __init__(self, bucket: str):
         super().__init__(
-            "NoSuchBucket", "The specified bucket does not exist", *args, **kwargs
+            "NoSuchBucket", "The specified bucket does not exist", bucket=bucket
         )


 class MissingKey(S3ClientError):
     code = 404

-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs: Any):
         kwargs.setdefault("template", "key_error")
         self.templates["key_error"] = ERROR_WITH_KEY_NAME
         super().__init__("NoSuchKey", "The specified key does not exist.", **kwargs)
@@ -105,16 +106,14 @@ class MissingKey(S3ClientError):
 class MissingVersion(S3ClientError):
     code = 404

-    def __init__(self, *args, **kwargs):
-        super().__init__(
-            "NoSuchVersion", "The specified version does not exist.", *args, **kwargs
-        )
+    def __init__(self) -> None:
+        super().__init__("NoSuchVersion", "The specified version does not exist.")


 class InvalidVersion(S3ClientError):
     code = 400

-    def __init__(self, version_id, *args, **kwargs):
+    def __init__(self, version_id: str, *args: Any, **kwargs: Any):
         kwargs.setdefault("template", "argument_error")
         kwargs["name"] = "versionId"
         kwargs["value"] = version_id
@@ -127,7 +126,7 @@ class InvalidVersion(S3ClientError):
 class ObjectNotInActiveTierError(S3ClientError):
     code = 403

-    def __init__(self, key_name):
+    def __init__(self, key_name: Any):
         super().__init__(
             "ObjectNotInActiveTierError",
             "The source object of the COPY operation is not in the active tier and is only stored in Amazon Glacier.",
@@ -138,105 +137,84 @@ class ObjectNotInActiveTierError(S3ClientError):
 class InvalidPartOrder(S3ClientError):
     code = 400

-    def __init__(self, *args, **kwargs):
+    def __init__(self) -> None:
         super().__init__(
             "InvalidPartOrder",
-            (
-                "The list of parts was not in ascending order. The parts "
-                "list must be specified in order by part number."
-            ),
-            *args,
-            **kwargs,
+            "The list of parts was not in ascending order. The parts list must be specified in order by part number.",
         )


 class InvalidPart(S3ClientError):
     code = 400

-    def __init__(self, *args, **kwargs):
+    def __init__(self) -> None:
         super().__init__(
             "InvalidPart",
-            (
-                "One or more of the specified parts could not be found. "
-                "The part might not have been uploaded, or the specified "
-                "entity tag might not have matched the part's entity tag."
-            ),
-            *args,
-            **kwargs,
+            "One or more of the specified parts could not be found. The part might not have been uploaded, or the specified entity tag might not have matched the part's entity tag.",
         )


 class EntityTooSmall(S3ClientError):
     code = 400

-    def __init__(self, *args, **kwargs):
+    def __init__(self) -> None:
         super().__init__(
             "EntityTooSmall",
             "Your proposed upload is smaller than the minimum allowed object size.",
-            *args,
-            **kwargs,
         )


 class InvalidRequest(S3ClientError):
     code = 400

-    def __init__(self, method, *args, **kwargs):
+    def __init__(self, method: str):
         super().__init__(
             "InvalidRequest",
             f"Found unsupported HTTP method in CORS config. Unsupported method is {method}",
-            *args,
-            **kwargs,
         )


 class IllegalLocationConstraintException(S3ClientError):
     code = 400

-    def __init__(self, *args, **kwargs):
+    def __init__(self) -> None:
         super().__init__(
             "IllegalLocationConstraintException",
             "The unspecified location constraint is incompatible for the region specific endpoint this request was sent to.",
-            *args,
-            **kwargs,
         )


 class MalformedXML(S3ClientError):
     code = 400

-    def __init__(self, *args, **kwargs):
+    def __init__(self) -> None:
         super().__init__(
             "MalformedXML",
             "The XML you provided was not well-formed or did not validate against our published schema",
-            *args,
-            **kwargs,
         )


 class MalformedACLError(S3ClientError):
     code = 400

-    def __init__(self, *args, **kwargs):
+    def __init__(self) -> None:
         super().__init__(
             "MalformedACLError",
             "The XML you provided was not well-formed or did not validate against our published schema",
-            *args,
-            **kwargs,
         )


 class InvalidTargetBucketForLogging(S3ClientError):
     code = 400

-    def __init__(self, msg):
+    def __init__(self, msg: str):
         super().__init__("InvalidTargetBucketForLogging", msg)


 class CrossLocationLoggingProhibitted(S3ClientError):
     code = 403

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__(
             "CrossLocationLoggingProhibitted", "Cross S3 location logging not allowed."
         )


@@ -245,7 +223,7 @@ class CrossLocationLoggingProhibitted(S3ClientError):
 class InvalidMaxPartArgument(S3ClientError):
     code = 400

-    def __init__(self, arg, min_val, max_val):
+    def __init__(self, arg: str, min_val: int, max_val: int):
         error = f"Argument {arg} must be an integer between {min_val} and {max_val}"
         super().__init__("InvalidArgument", error)

@@ -253,97 +231,83 @@ class InvalidMaxPartArgument(S3ClientError):
 class InvalidMaxPartNumberArgument(InvalidArgumentError):
     code = 400

-    def __init__(self, value, *args, **kwargs):
+    def __init__(self, value: int):
         error = "Part number must be an integer between 1 and 10000, inclusive"
-        super().__init__(message=error, name="partNumber", value=value, *args, **kwargs)
+        super().__init__(message=error, name="partNumber", value=value)  # type: ignore


 class NotAnIntegerException(InvalidArgumentError):
     code = 400

-    def __init__(self, name, value, *args, **kwargs):
+    def __init__(self, name: str, value: int):
         error = f"Provided {name} not an integer or within integer range"
-        super().__init__(message=error, name=name, value=value, *args, **kwargs)
+        super().__init__(message=error, name=name, value=value)  # type: ignore


 class InvalidNotificationARN(S3ClientError):
     code = 400

-    def __init__(self, *args, **kwargs):
-        super().__init__(
-            "InvalidArgument", "The ARN is not well formed", *args, **kwargs
-        )
+    def __init__(self) -> None:
+        super().__init__("InvalidArgument", "The ARN is not well formed")


 class InvalidNotificationDestination(S3ClientError):
     code = 400

-    def __init__(self, *args, **kwargs):
+    def __init__(self) -> None:
         super().__init__(
             "InvalidArgument",
             "The notification destination service region is not valid for the bucket location constraint",
-            *args,
-            **kwargs,
         )


 class InvalidNotificationEvent(S3ClientError):
     code = 400

-    def __init__(self, *args, **kwargs):
+    def __init__(self) -> None:
         super().__init__(
             "InvalidArgument",
             "The event is not supported for notifications",
-            *args,
-            **kwargs,
         )


 class InvalidStorageClass(S3ClientError):
     code = 400

-    def __init__(self, *args, **kwargs):
+    def __init__(self, storage: Optional[str]):
         super().__init__(
             "InvalidStorageClass",
             "The storage class you specified is not valid",
-            *args,
-            **kwargs,
+            storage=storage,
         )


 class InvalidBucketName(S3ClientError):
     code = 400

-    def __init__(self, *args, **kwargs):
-        super().__init__(
-            "InvalidBucketName", "The specified bucket is not valid.", *args, **kwargs
-        )
+    def __init__(self) -> None:
+        super().__init__("InvalidBucketName", "The specified bucket is not valid.")


 class DuplicateTagKeys(S3ClientError):
     code = 400

-    def __init__(self, *args, **kwargs):
-        super().__init__(
-            "InvalidTag",
-            "Cannot provide multiple Tags with the same key",
-            *args,
-            **kwargs,
-        )
+    def __init__(self) -> None:
+        super().__init__("InvalidTag", "Cannot provide multiple Tags with the same key")


 class S3AccessDeniedError(S3ClientError):
     code = 403

-    def __init__(self, *args: str, **kwargs: str):
-        super().__init__("AccessDenied", "Access Denied", *args, **kwargs)
+    def __init__(self) -> None:
+        super().__init__("AccessDenied", "Access Denied")


 class BucketAccessDeniedError(BucketError):
     code = 403

-    def __init__(self, *args: str, **kwargs: str):
-        super().__init__("AccessDenied", "Access Denied", *args, **kwargs)
+    def __init__(self, bucket: str):
+        super().__init__("AccessDenied", "Access Denied", bucket=bucket)


 class S3InvalidTokenError(S3ClientError):
@@ -368,12 +332,11 @@ class S3AclAndGrantError(S3ClientError):
 class BucketInvalidTokenError(BucketError):
     code = 400

-    def __init__(self, *args: str, **kwargs: str):
+    def __init__(self, bucket: str):
         super().__init__(
             "InvalidToken",
             "The provided token is malformed or otherwise invalid.",
-            *args,
-            **kwargs,
+            bucket=bucket,
         )


@@ -390,12 +353,11 @@ class S3InvalidAccessKeyIdError(S3ClientError):
 class BucketInvalidAccessKeyIdError(S3ClientError):
     code = 403

-    def __init__(self, *args: str, **kwargs: str):
+    def __init__(self, bucket: str):
         super().__init__(
             "InvalidAccessKeyId",
             "The AWS Access Key Id you provided does not exist in our records.",
-            *args,
-            **kwargs,
+            bucket=bucket,
         )


@@ -412,50 +374,45 @@ class S3SignatureDoesNotMatchError(S3ClientError):
 class BucketSignatureDoesNotMatchError(S3ClientError):
     code = 403

-    def __init__(self, *args: str, **kwargs: str):
+    def __init__(self, bucket: str):
         super().__init__(
             "SignatureDoesNotMatch",
             "The request signature we calculated does not match the signature you provided. Check your key and signing method.",
-            *args,
-            **kwargs,
+            bucket=bucket,
         )


 class NoSuchPublicAccessBlockConfiguration(S3ClientError):
     code = 404

-    def __init__(self, *args, **kwargs):
+    def __init__(self) -> None:
         super().__init__(
             "NoSuchPublicAccessBlockConfiguration",
             "The public access block configuration was not found",
-            *args,
-            **kwargs,
         )


 class InvalidPublicAccessBlockConfiguration(S3ClientError):
     code = 400

-    def __init__(self, *args, **kwargs):
+    def __init__(self) -> None:
         super().__init__(
             "InvalidRequest",
             "Must specify at least one configuration.",
-            *args,
-            **kwargs,
         )


 class WrongPublicAccessBlockAccountIdError(S3ClientError):
     code = 403

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__("AccessDenied", "Access Denied")


 class NoSystemTags(S3ClientError):
     code = 400

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__(
             "InvalidTag", "System tags cannot be added/updated by requester"
         )
@@ -464,7 +421,7 @@ class NoSystemTags(S3ClientError):
 class NoSuchUpload(S3ClientError):
     code = 404

-    def __init__(self, upload_id, *args, **kwargs):
+    def __init__(self, upload_id: Union[int, str], *args: Any, **kwargs: Any):
         kwargs.setdefault("template", "error_uploadid")
         kwargs["upload_id"] = upload_id
         self.templates["error_uploadid"] = ERROR_WITH_UPLOADID
@@ -479,7 +436,7 @@ class NoSuchUpload(S3ClientError):
 class PreconditionFailed(S3ClientError):
     code = 412

-    def __init__(self, failed_condition, **kwargs):
+    def __init__(self, failed_condition: str, **kwargs: Any):
         kwargs.setdefault("template", "condition_error")
         self.templates["condition_error"] = ERROR_WITH_CONDITION_NAME
         super().__init__(
@@ -493,7 +450,7 @@ class PreconditionFailed(S3ClientError):
 class InvalidRange(S3ClientError):
     code = 416

-    def __init__(self, range_requested, actual_size, **kwargs):
+    def __init__(self, range_requested: str, actual_size: str, **kwargs: Any):
         kwargs.setdefault("template", "range_error")
         self.templates["range_error"] = ERROR_WITH_RANGE
         super().__init__(
@@ -508,19 +465,16 @@ class InvalidRange(S3ClientError):
 class InvalidContinuationToken(S3ClientError):
     code = 400

-    def __init__(self, *args, **kwargs):
+    def __init__(self) -> None:
         super().__init__(
-            "InvalidArgument",
-            "The continuation token provided is incorrect",
-            *args,
-            **kwargs,
+            "InvalidArgument", "The continuation token provided is incorrect"
         )


 class InvalidObjectState(BucketError):
     code = 403

-    def __init__(self, storage_class, **kwargs):
+    def __init__(self, storage_class: Optional[str], **kwargs: Any):
         kwargs.setdefault("template", "storage_error")
self.templates["storage_error"] = ERROR_WITH_STORAGE_CLASS super().__init__( @@ -534,35 +488,35 @@ class InvalidObjectState(BucketError): class LockNotEnabled(S3ClientError): code = 400 - def __init__(self): + def __init__(self) -> None: super().__init__("InvalidRequest", "Bucket is missing ObjectLockConfiguration") class AccessDeniedByLock(S3ClientError): code = 400 - def __init__(self): + def __init__(self) -> None: super().__init__("AccessDenied", "Access Denied") class InvalidContentMD5(S3ClientError): code = 400 - def __init__(self): + def __init__(self) -> None: super().__init__("InvalidContentMD5", "Content MD5 header is invalid") class BucketNeedsToBeNew(S3ClientError): code = 400 - def __init__(self): + def __init__(self) -> None: super().__init__("InvalidBucket", "Bucket needs to be empty") class BucketMustHaveLockeEnabled(S3ClientError): code = 400 - def __init__(self): + def __init__(self) -> None: super().__init__( "InvalidBucketState", "Object Lock configuration cannot be enabled on existing buckets", @@ -572,7 +526,7 @@ class BucketMustHaveLockeEnabled(S3ClientError): class CopyObjectMustChangeSomething(S3ClientError): code = 400 - def __init__(self): + def __init__(self) -> None: super().__init__( "InvalidRequest", "This copy request is illegal because it is trying to copy an object to itself without changing the object's metadata, storage class, website redirect location or encryption attributes.", @@ -582,27 +536,25 @@ class CopyObjectMustChangeSomething(S3ClientError): class InvalidFilterRuleName(InvalidArgumentError): code = 400 - def __init__(self, value, *args, **kwargs): + def __init__(self, value: str): super().__init__( "filter rule name must be either prefix or suffix", "FilterRule.Name", value, - *args, - **kwargs, ) class InvalidTagError(S3ClientError): code = 400 - def __init__(self, value, *args, **kwargs): - super().__init__("InvalidTag", value, *args, **kwargs) + def __init__(self, value: str): + super().__init__("InvalidTag", value) class ObjectLockConfigurationNotFoundError(S3ClientError): code = 404 - def __init__(self): + def __init__(self) -> None: super().__init__( "ObjectLockConfigurationNotFoundError", "Object Lock configuration does not exist for this bucket", diff --git a/moto/s3/models.py b/moto/s3/models.py index 35fbcbee3..f4f269eec 100644 --- a/moto/s3/models.py +++ b/moto/s3/models.py @@ -12,7 +12,7 @@ import sys import urllib.parse from bisect import insort -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Set, Tuple, Iterator, Union from importlib import reload from moto.core import BaseBackend, BaseModel, BackendDict, CloudFormationModel from moto.core import CloudWatchMetricProvider @@ -53,7 +53,12 @@ from moto.s3.exceptions import ( from .cloud_formation import cfn_to_api_encryption, is_replacement_update from . 
diff --git a/moto/s3/models.py b/moto/s3/models.py
index 35fbcbee3..f4f269eec 100644
--- a/moto/s3/models.py
+++ b/moto/s3/models.py
@@ -12,7 +12,7 @@ import sys
 import urllib.parse

 from bisect import insort
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Set, Tuple, Iterator, Union
 from importlib import reload
 from moto.core import BaseBackend, BaseModel, BackendDict, CloudFormationModel
 from moto.core import CloudWatchMetricProvider
@@ -53,7 +53,12 @@ from moto.s3.exceptions import (
 from .cloud_formation import cfn_to_api_encryption, is_replacement_update
 from . import notifications
 from .select_object_content import parse_query
-from .utils import clean_key_name, _VersionedKeyStore, undo_clean_key_name
+from .utils import (
+    clean_key_name,
+    _VersionedKeyStore,
+    undo_clean_key_name,
+    CaseInsensitiveDict,
+)
 from .utils import ARCHIVE_STORAGE_CLASSES, STORAGE_CLASS
 from ..events.notifications import send_notification as events_send_notification
 from ..settings import get_s3_default_key_buffer_size, S3_UPLOAD_PART_MIN_SIZE
@@ -66,41 +71,41 @@ OWNER = "75aa57f09aa0c8caeab4f8c24e99d10f8e7faeebf76c078efc7c6caea54ba06a"


 class FakeDeleteMarker(BaseModel):
-    def __init__(self, key):
+    def __init__(self, key: "FakeKey"):
         self.key = key
         self.name = key.name
         self.last_modified = datetime.datetime.utcnow()
         self._version_id = str(random.uuid4())

     @property
-    def last_modified_ISO8601(self):
-        return iso_8601_datetime_without_milliseconds_s3(self.last_modified)
+    def last_modified_ISO8601(self) -> str:
+        return iso_8601_datetime_without_milliseconds_s3(self.last_modified)  # type: ignore

     @property
-    def version_id(self):
+    def version_id(self) -> str:
         return self._version_id


 class FakeKey(BaseModel, ManagedState):
     def __init__(
         self,
-        name,
-        value,
-        account_id=None,
-        storage="STANDARD",
-        etag=None,
-        is_versioned=False,
-        version_id=0,
-        max_buffer_size=None,
-        multipart=None,
-        bucket_name=None,
-        encryption=None,
-        kms_key_id=None,
-        bucket_key_enabled=None,
-        lock_mode=None,
-        lock_legal_status=None,
-        lock_until=None,
-        checksum_value=None,
+        name: str,
+        value: bytes,
+        account_id: Optional[str] = None,
+        storage: Optional[str] = "STANDARD",
+        etag: Optional[str] = None,
+        is_versioned: bool = False,
+        version_id: int = 0,
+        max_buffer_size: Optional[int] = None,
+        multipart: Optional["FakeMultipart"] = None,
+        bucket_name: Optional[str] = None,
+        encryption: Optional[str] = None,
+        kms_key_id: Optional[str] = None,
+        bucket_key_enabled: Any = None,
+        lock_mode: Optional[str] = None,
+        lock_legal_status: Optional[str] = None,
+        lock_until: Optional[str] = None,
+        checksum_value: Optional[str] = None,
     ):
         ManagedState.__init__(
             self,
@@ -113,12 +118,12 @@ class FakeKey(BaseModel, ManagedState):
         self.name = name
         self.account_id = account_id
         self.last_modified = datetime.datetime.utcnow()
-        self.acl = get_canned_acl("private")
+        self.acl: Optional[FakeAcl] = get_canned_acl("private")
         self.website_redirect_location = None
         self.checksum_algorithm = None
-        self._storage_class = storage if storage else "STANDARD"
+        self._storage_class: Optional[str] = storage if storage else "STANDARD"
         self._metadata = LowercaseDict()
-        self._expiry = None
+        self._expiry: Optional[datetime.datetime] = None
         self._etag = etag
         self._version_id = version_id
         self._is_versioned = is_versioned
@@ -130,7 +135,7 @@ class FakeKey(BaseModel, ManagedState):
         )
         self._value_buffer = tempfile.SpooledTemporaryFile(self._max_buffer_size)
         self.disposed = False
-        self.value = value
+        self.value = value  # type: ignore
         self.lock = threading.Lock()

         self.encryption = encryption
@@ -145,17 +150,17 @@ class FakeKey(BaseModel, ManagedState):
         # Default metadata values
         self._metadata["Content-Type"] = "binary/octet-stream"

-    def safe_name(self, encoding_type=None):
+    def safe_name(self, encoding_type: Optional[str] = None) -> str:
         if encoding_type == "url":
             return urllib.parse.quote(self.name)
         return self.name

     @property
-    def version_id(self):
+    def version_id(self) -> int:
         return self._version_id

     @property
-    def value(self):
+    def value(self) -> bytes:
         with self.lock:
             self._value_buffer.seek(0)
             r = self._value_buffer.read()
@@ -163,12 +168,12 @@ class FakeKey(BaseModel, ManagedState):

         return r

     @property
-    def arn(self):
+    def arn(self) -> str:
         # S3 Objects don't have an ARN, but we do need something unique when creating tags against this resource
         return f"arn:aws:s3:::{self.bucket_name}/{self.name}/{self.version_id}"

-    @value.setter
-    def value(self, new_value):
+    @value.setter  # type: ignore
+    def value(self, new_value: bytes) -> None:
         self._value_buffer.seek(0)
         self._value_buffer.truncate()

@@ -179,27 +184,27 @@ class FakeKey(BaseModel, ManagedState):
         self._value_buffer.write(new_value)
         self.contentsize = len(new_value)

-    def set_metadata(self, metadata, replace=False):
+    def set_metadata(self, metadata: Any, replace: bool = False) -> None:
         if replace:
-            self._metadata = {}
+            self._metadata = {}  # type: ignore
         self._metadata.update(metadata)

-    def set_storage_class(self, storage):
+    def set_storage_class(self, storage: Optional[str]) -> None:
         if storage is not None and storage not in STORAGE_CLASS:
             raise InvalidStorageClass(storage=storage)
         self._storage_class = storage

-    def set_expiry(self, expiry):
+    def set_expiry(self, expiry: Optional[datetime.datetime]) -> None:
         self._expiry = expiry

-    def set_acl(self, acl):
+    def set_acl(self, acl: Optional["FakeAcl"]) -> None:
         self.acl = acl

-    def restore(self, days):
+    def restore(self, days: int) -> None:
         self._expiry = datetime.datetime.utcnow() + datetime.timedelta(days)

     @property
-    def etag(self):
+    def etag(self) -> str:
         if self._etag is None:
             value_md5 = md5_hash()
             self._value_buffer.seek(0)
@@ -213,22 +218,22 @@ class FakeKey(BaseModel, ManagedState):
         return f'"{self._etag}"'

     @property
-    def last_modified_ISO8601(self):
-        return iso_8601_datetime_without_milliseconds_s3(self.last_modified)
+    def last_modified_ISO8601(self) -> str:
+        return iso_8601_datetime_without_milliseconds_s3(self.last_modified)  # type: ignore

     @property
-    def last_modified_RFC1123(self):
+    def last_modified_RFC1123(self) -> str:
         # Different datetime formats depending on how the key is obtained
         # https://github.com/boto/boto/issues/466
         return rfc_1123_datetime(self.last_modified)

     @property
-    def metadata(self):
+    def metadata(self) -> LowercaseDict:
         return self._metadata

     @property
-    def response_dict(self):
-        res = {
+    def response_dict(self) -> Dict[str, Any]:  # type: ignore[misc]
+        res: Dict[str, Any] = {
             "ETag": self.etag,
             "last-modified": self.last_modified_RFC1123,
             "content-length": str(self.size),
@@ -279,23 +284,24 @@ class FakeKey(BaseModel, ManagedState):
         return res

     @property
-    def size(self):
+    def size(self) -> int:
         return self.contentsize

     @property
-    def storage_class(self):
+    def storage_class(self) -> Optional[str]:
         return self._storage_class

     @property
-    def expiry_date(self):
+    def expiry_date(self) -> Optional[str]:
         if self._expiry is not None:
             return self._expiry.strftime("%a, %d %b %Y %H:%M:%S GMT")
+        return None

     # Keys need to be pickleable due to some implementation details of boto3.
     # Since file objects aren't pickleable, we need to override the default
     # behavior. The following is adapted from the Python docs:
     # https://docs.python.org/3/library/pickle.html#handling-stateful-objects
-    def __getstate__(self):
+    def __getstate__(self) -> Dict[str, Any]:
         state = self.__dict__.copy()
         try:
             state["value"] = self.value
@@ -307,17 +313,17 @@ class FakeKey(BaseModel, ManagedState):
         del state["lock"]
         return state

-    def __setstate__(self, state):
+    def __setstate__(self, state: Dict[str, Any]) -> None:
         self.__dict__.update({k: v for k, v in state.items() if k != "value"})

         self._value_buffer = tempfile.SpooledTemporaryFile(
             max_size=self._max_buffer_size
         )
-        self.value = state["value"]
+        self.value = state["value"]  # type: ignore
         self.lock = threading.Lock()

     @property
-    def is_locked(self):
+    def is_locked(self) -> bool:
         if self.lock_legal_status == "ON":
             return True

@@ -325,11 +331,11 @@ class FakeKey(BaseModel, ManagedState):
             now = datetime.datetime.utcnow()
             try:
                 until = datetime.datetime.strptime(
-                    self.lock_until, "%Y-%m-%dT%H:%M:%SZ"
+                    self.lock_until, "%Y-%m-%dT%H:%M:%SZ"  # type: ignore
                 )
             except ValueError:
                 until = datetime.datetime.strptime(
-                    self.lock_until, "%Y-%m-%dT%H:%M:%S.%fZ"
+                    self.lock_until, "%Y-%m-%dT%H:%M:%S.%fZ"  # type: ignore
                 )

             if until > now:
@@ -337,7 +343,7 @@ class FakeKey(BaseModel, ManagedState):

         return False

-    def dispose(self, garbage=False):
+    def dispose(self, garbage: bool = False) -> None:
         if garbage and not self.disposed:
             import warnings

@@ -350,28 +356,28 @@ class FakeKey(BaseModel, ManagedState):
                 pass
         self.disposed = True

-    def __del__(self):
+    def __del__(self) -> None:
         self.dispose(garbage=True)


 class FakeMultipart(BaseModel):
     def __init__(
         self,
-        key_name,
-        metadata,
-        storage=None,
-        tags=None,
-        acl=None,
-        sse_encryption=None,
-        kms_key_id=None,
+        key_name: str,
+        metadata: CaseInsensitiveDict,  # type: ignore
+        storage: Optional[str] = None,
+        tags: Optional[Dict[str, str]] = None,
+        acl: Optional["FakeAcl"] = None,
+        sse_encryption: Optional[str] = None,
+        kms_key_id: Optional[str] = None,
     ):
         self.key_name = key_name
         self.metadata = metadata
         self.storage = storage
         self.tags = tags
         self.acl = acl
-        self.parts = {}
-        self.partlist = []  # ordered list of part ID's
+        self.parts: Dict[int, FakeKey] = {}
+        self.partlist: List[int] = []  # ordered list of part ID's
         rand_b64 = base64.b64encode(os.urandom(UPLOAD_ID_BYTES))
         self.id = (
             rand_b64.decode("utf-8").replace("=", "").replace("+", "").replace("/", "")
@@ -379,7 +385,7 @@ class FakeMultipart(BaseModel):
         self.sse_encryption = sse_encryption
         self.kms_key_id = kms_key_id

-    def complete(self, body):
+    def complete(self, body: Iterator[Tuple[int, str]]) -> Tuple[bytes, str]:
         decode_hex = codecs.getdecoder("hex_codec")
         total = bytearray()
         md5s = bytearray()
@@ -396,7 +402,7 @@ class FakeMultipart(BaseModel):
                 raise InvalidPart()
             if last is not None and last.contentsize < S3_UPLOAD_PART_MIN_SIZE:
                 raise EntityTooSmall()
-            md5s.extend(decode_hex(part_etag)[0])
+            md5s.extend(decode_hex(part_etag)[0])  # type: ignore
             total.extend(part.value)
             last = part
             count += 1
@@ -404,16 +410,16 @@ class FakeMultipart(BaseModel):
         if count == 0:
             raise MalformedXML

-        etag = md5_hash()
-        etag.update(bytes(md5s))
-        return total, f"{etag.hexdigest()}-{count}"
+        full_etag = md5_hash()
+        full_etag.update(bytes(md5s))
+        return total, f"{full_etag.hexdigest()}-{count}"

-    def set_part(self, part_id, value):
+    def set_part(self, part_id: int, value: bytes) -> FakeKey:
         if part_id < 1:
             raise NoSuchUpload(upload_id=part_id)

         key = FakeKey(
-            part_id, value, encryption=self.sse_encryption, kms_key_id=self.kms_key_id
+            part_id, value, encryption=self.sse_encryption, kms_key_id=self.kms_key_id  # type: ignore
         )
         if part_id in self.parts:
             # We're overwriting the current part - dispose of it first
@@ -423,23 +429,23 @@ class FakeMultipart(BaseModel):
         self.parts[part_id] = key
         insort(self.partlist, part_id)
         return key

-    def list_parts(self, part_number_marker, max_parts):
+    def list_parts(self, part_number_marker: int, max_parts: int) -> Iterator[FakeKey]:
         max_marker = part_number_marker + max_parts
         for part_id in self.partlist[part_number_marker:max_marker]:
             yield self.parts[part_id]

-    def dispose(self):
+    def dispose(self) -> None:
         for part in self.parts.values():
             part.dispose()


 class FakeGrantee(BaseModel):
-    def __init__(self, grantee_id="", uri="", display_name=""):
+    def __init__(self, grantee_id: str = "", uri: str = "", display_name: str = ""):
         self.id = grantee_id
         self.uri = uri
         self.display_name = display_name

-    def __eq__(self, other):
+    def __eq__(self, other: Any) -> bool:
         if not isinstance(other, FakeGrantee):
             return False
         return (
@@ -449,10 +455,10 @@ class FakeGrantee(BaseModel):
         )

     @property
-    def type(self):
+    def type(self) -> str:
         return "Group" if self.uri else "CanonicalUser"

-    def __repr__(self):
+    def __repr__(self) -> str:
         return f"FakeGrantee(display_name: '{self.display_name}', id: '{self.id}', uri: '{self.uri}')"


@@ -478,21 +484,20 @@ CAMEL_CASED_PERMISSIONS = {


 class FakeGrant(BaseModel):
-    def __init__(self, grantees, permissions):
+    def __init__(self, grantees: List[FakeGrantee], permissions: List[str]):
         self.grantees = grantees
         self.permissions = permissions

-    def __repr__(self):
+    def __repr__(self) -> str:
         return f"FakeGrant(grantees: {self.grantees}, permissions: {self.permissions})"


 class FakeAcl(BaseModel):
-    def __init__(self, grants=None):
-        grants = grants or []
-        self.grants = grants
+    def __init__(self, grants: Optional[List[FakeGrant]] = None):
+        self.grants = grants or []

     @property
-    def public_read(self):
+    def public_read(self) -> bool:
         for grant in self.grants:
             if ALL_USERS_GRANTEE in grant.grantees:
                 if PERMISSION_READ in grant.permissions:
@@ -501,12 +506,12 @@ class FakeAcl(BaseModel):
                     return True
         return False

-    def __repr__(self):
+    def __repr__(self) -> str:
         return f"FakeAcl(grants: {self.grants})"

-    def to_config_dict(self):
+    def to_config_dict(self) -> Dict[str, Any]:
         """Returns the object into the format expected by AWS Config"""
-        data = {
+        data: Dict[str, Any] = {
             "grantSet": None,  # Always setting this to None. Feel free to change.
"owner": {"displayName": None, "id": OWNER}, } @@ -517,7 +522,7 @@ class FakeAcl(BaseModel): permissions = ( grant.permissions if isinstance(grant.permissions, list) - else [grant.permissions] + else [grant.permissions] # type: ignore ) for permission in permissions: for grantee in grant.grantees: @@ -533,7 +538,7 @@ class FakeAcl(BaseModel): else: grant_list.append( { - "grantee": { + "grantee": { # type: ignore "id": grantee.id, "displayName": None if not grantee.display_name @@ -549,7 +554,7 @@ class FakeAcl(BaseModel): return data -def get_canned_acl(acl): +def get_canned_acl(acl: str) -> FakeAcl: owner_grantee = FakeGrantee(grantee_id=OWNER) grants = [FakeGrant([owner_grantee], [PERMISSION_FULL_CONTROL])] if acl == "private": @@ -578,12 +583,17 @@ def get_canned_acl(acl): class LifecycleFilter(BaseModel): - def __init__(self, prefix=None, tag=None, and_filter=None): + def __init__( + self, + prefix: Optional[str] = None, + tag: Optional[Tuple[str, str]] = None, + and_filter: Optional["LifecycleAndFilter"] = None, + ): self.prefix = prefix (self.tag_key, self.tag_value) = tag if tag else (None, None) self.and_filter = and_filter - def to_config_dict(self): + def to_config_dict(self) -> Dict[str, Any]: if self.prefix is not None: return { "predicate": {"type": "LifecyclePrefixPredicate", "prefix": self.prefix} @@ -601,18 +611,20 @@ class LifecycleFilter(BaseModel): return { "predicate": { "type": "LifecycleAndOperator", - "operands": self.and_filter.to_config_dict(), + "operands": self.and_filter.to_config_dict(), # type: ignore } } class LifecycleAndFilter(BaseModel): - def __init__(self, prefix=None, tags=None): + def __init__( + self, prefix: Optional[str] = None, tags: Optional[Dict[str, str]] = None + ): self.prefix = prefix - self.tags = tags + self.tags = tags or {} - def to_config_dict(self): - data = [] + def to_config_dict(self) -> List[Dict[str, Any]]: + data: List[Dict[str, Any]] = [] if self.prefix is not None: data.append({"type": "LifecyclePrefixPredicate", "prefix": self.prefix}) @@ -628,20 +640,20 @@ class LifecycleAndFilter(BaseModel): class LifecycleRule(BaseModel): def __init__( self, - rule_id=None, - prefix=None, - lc_filter=None, - status=None, - expiration_days=None, - expiration_date=None, - transition_days=None, - transition_date=None, - storage_class=None, - expired_object_delete_marker=None, - nve_noncurrent_days=None, - nvt_noncurrent_days=None, - nvt_storage_class=None, - aimu_days=None, + rule_id: Optional[str] = None, + prefix: Optional[str] = None, + lc_filter: Optional[LifecycleFilter] = None, + status: Optional[str] = None, + expiration_days: Optional[str] = None, + expiration_date: Optional[str] = None, + transition_days: Optional[str] = None, + transition_date: Optional[str] = None, + storage_class: Optional[str] = None, + expired_object_delete_marker: Optional[str] = None, + nve_noncurrent_days: Optional[str] = None, + nvt_noncurrent_days: Optional[str] = None, + nvt_storage_class: Optional[str] = None, + aimu_days: Optional[str] = None, ): self.id = rule_id self.prefix = prefix @@ -658,7 +670,7 @@ class LifecycleRule(BaseModel): self.nvt_storage_class = nvt_storage_class self.aimu_days = aimu_days - def to_config_dict(self): + def to_config_dict(self) -> Dict[str, Any]: """Converts the object to the AWS Config data dict. 

         Note: The following are missing that should be added in the future:

         :return:
         """
-        lifecycle_dict = {
+        lifecycle_dict: Dict[str, Any] = {
             "id": self.id,
             "prefix": self.prefix,
             "status": self.status,
@@ -677,7 +689,7 @@ class LifecycleRule(BaseModel):
             if self.expiration_days
             else None,
             "expiredObjectDeleteMarker": self.expired_object_delete_marker,
-            "noncurrentVersionExpirationInDays": -1 or int(self.nve_noncurrent_days),
+            "noncurrentVersionExpirationInDays": -1 or int(self.nve_noncurrent_days),  # type: ignore
             "expirationDate": self.expiration_date,
             "transitions": None,  # Replace me with logic to fill in
             "noncurrentVersionTransitions": None,  # Replace me with logic to fill in
@@ -697,7 +709,7 @@ class LifecycleRule(BaseModel):
         elif self.prefix:
             lifecycle_dict["filter"] = None
         else:
-            lifecycle_dict["filter"] = self.filter.to_config_dict()
+            lifecycle_dict["filter"] = self.filter.to_config_dict()  # type: ignore

         return lifecycle_dict

@@ -705,11 +717,11 @@ class LifecycleRule(BaseModel):
 class CorsRule(BaseModel):
     def __init__(
         self,
-        allowed_methods,
-        allowed_origins,
-        allowed_headers=None,
-        expose_headers=None,
-        max_age_seconds=None,
+        allowed_methods: Any,
+        allowed_origins: Any,
+        allowed_headers: Any = None,
+        expose_headers: Any = None,
+        max_age_seconds: Any = None,
     ):
         self.allowed_methods = (
             [allowed_methods] if isinstance(allowed_methods, str) else allowed_methods
@@ -727,7 +739,13 @@ class CorsRule(BaseModel):


 class Notification(BaseModel):
-    def __init__(self, arn, events, filters=None, notification_id=None):
+    def __init__(
+        self,
+        arn: str,
+        events: List[str],
+        filters: Optional[Dict[str, Any]] = None,
+        notification_id: Optional[str] = None,
+    ):
         self.id = notification_id or "".join(
             random.choice(string.ascii_letters + string.digits) for _ in range(50)
         )
@@ -735,7 +753,7 @@ class Notification(BaseModel):
         self.events = events
         self.filters = filters if filters else {}

-    def _event_matches(self, event_name):
+    def _event_matches(self, event_name: str) -> bool:
         if event_name in self.events:
             return True
         # s3:ObjectCreated:Put --> s3:ObjectCreated:*
@@ -744,7 +762,7 @@ class Notification(BaseModel):
             return True
         return False

-    def _key_matches(self, key_name):
+    def _key_matches(self, key_name: str) -> bool:
         if "S3Key" not in self.filters:
             return True
         _filters = {f["Name"]: f["Value"] for f in self.filters["S3Key"]["FilterRule"]}
@@ -756,17 +774,15 @@ class Notification(BaseModel):
         )
         return prefix_matches and suffix_matches

-    def matches(self, event_name, key_name):
+    def matches(self, event_name: str, key_name: str) -> bool:
         if self._event_matches(event_name):
             if self._key_matches(key_name):
                 return True
         return False

-    def to_config_dict(self):
-        data = {}
-
+    def to_config_dict(self) -> Dict[str, Any]:
         # Type and ARN will be filled in by NotificationConfiguration's to_config_dict:
-        data["events"] = [event for event in self.events]
+        data: Dict[str, Any] = {"events": [event for event in self.events]}

         if self.filters:
             data["filter"] = {
@@ -787,7 +803,12 @@ class Notification(BaseModel):


 class NotificationConfiguration(BaseModel):
-    def __init__(self, topic=None, queue=None, cloud_function=None):
+    def __init__(
+        self,
+        topic: Optional[List[Dict[str, Any]]] = None,
+        queue: Optional[List[Dict[str, Any]]] = None,
+        cloud_function: Optional[List[Dict[str, Any]]] = None,
+    ):
         self.topic = (
             [
                 Notification(
@@ -828,8 +849,8 @@ class NotificationConfiguration(BaseModel):
             else []
         )

-    def to_config_dict(self):
-        data = {"configurations": {}}
+    def to_config_dict(self) -> Dict[str, Any]:
+        data: Dict[str, Any] = {"configurations": {}}
{"configurations": {}} + def to_config_dict(self) -> Dict[str, Any]: + data: Dict[str, Any] = {"configurations": {}} for topic in self.topic: topic_config = topic.to_config_dict() @@ -852,7 +873,7 @@ class NotificationConfiguration(BaseModel): return data -def convert_str_to_bool(item): +def convert_str_to_bool(item: Any) -> bool: """Converts a boolean string to a boolean value""" if isinstance(item, str): return item.lower() == "true" @@ -863,10 +884,10 @@ def convert_str_to_bool(item): class PublicAccessBlock(BaseModel): def __init__( self, - block_public_acls, - ignore_public_acls, - block_public_policy, - restrict_public_buckets, + block_public_acls: Optional[str], + ignore_public_acls: Optional[str], + block_public_policy: Optional[str], + restrict_public_buckets: Optional[str], ): # The boto XML appears to expect these values to exist as lowercase strings... self.block_public_acls = block_public_acls or "false" @@ -874,7 +895,7 @@ class PublicAccessBlock(BaseModel): self.block_public_policy = block_public_policy or "false" self.restrict_public_buckets = restrict_public_buckets or "false" - def to_config_dict(self): + def to_config_dict(self) -> Dict[str, bool]: # Need to make the string values booleans for Config: return { "blockPublicAcls": convert_str_to_bool(self.block_public_acls), @@ -884,52 +905,52 @@ class PublicAccessBlock(BaseModel): } -class MultipartDict(dict): - def __delitem__(self, key): +class MultipartDict(Dict[str, FakeMultipart]): + def __delitem__(self, key: str) -> None: if key in self: self[key].dispose() super().__delitem__(key) class FakeBucket(CloudFormationModel): - def __init__(self, name, account_id, region_name): + def __init__(self, name: str, account_id: str, region_name: str): self.name = name self.account_id = account_id self.region_name = region_name self.keys = _VersionedKeyStore() self.multiparts = MultipartDict() - self.versioning_status = None - self.rules = [] - self.policy = None - self.website_configuration = None - self.acl = get_canned_acl("private") - self.cors = [] - self.logging = {} - self.notification_configuration = None - self.accelerate_configuration = None + self.versioning_status: Optional[str] = None + self.rules: List[LifecycleRule] = [] + self.policy: Optional[bytes] = None + self.website_configuration: Optional[Dict[str, Any]] = None + self.acl: Optional[FakeAcl] = get_canned_acl("private") + self.cors: List[CorsRule] = [] + self.logging: Dict[str, Any] = {} + self.notification_configuration: Optional[NotificationConfiguration] = None + self.accelerate_configuration: Optional[str] = None self.payer = "BucketOwner" self.creation_date = datetime.datetime.now(tz=datetime.timezone.utc) - self.public_access_block = None - self.encryption = None + self.public_access_block: Optional[PublicAccessBlock] = None + self.encryption: Optional[Dict[str, Any]] = None self.object_lock_enabled = False - self.default_lock_mode = "" - self.default_lock_days = 0 - self.default_lock_years = 0 - self.ownership_rule = None + self.default_lock_mode: Optional[str] = "" + self.default_lock_days: Optional[int] = 0 + self.default_lock_years: Optional[int] = 0 + self.ownership_rule: Optional[Dict[str, Any]] = None @property - def location(self): + def location(self) -> str: return self.region_name @property - def creation_date_ISO8601(self): - return iso_8601_datetime_without_milliseconds_s3(self.creation_date) + def creation_date_ISO8601(self) -> str: + return iso_8601_datetime_without_milliseconds_s3(self.creation_date) # type: ignore @property - def 
@@ -884,52 +905,52 @@ class PublicAccessBlock(BaseModel):
         }


-class MultipartDict(dict):
-    def __delitem__(self, key):
+class MultipartDict(Dict[str, FakeMultipart]):
+    def __delitem__(self, key: str) -> None:
         if key in self:
             self[key].dispose()
         super().__delitem__(key)


 class FakeBucket(CloudFormationModel):
-    def __init__(self, name, account_id, region_name):
+    def __init__(self, name: str, account_id: str, region_name: str):
         self.name = name
         self.account_id = account_id
         self.region_name = region_name
         self.keys = _VersionedKeyStore()
         self.multiparts = MultipartDict()
-        self.versioning_status = None
-        self.rules = []
-        self.policy = None
-        self.website_configuration = None
-        self.acl = get_canned_acl("private")
-        self.cors = []
-        self.logging = {}
-        self.notification_configuration = None
-        self.accelerate_configuration = None
+        self.versioning_status: Optional[str] = None
+        self.rules: List[LifecycleRule] = []
+        self.policy: Optional[bytes] = None
+        self.website_configuration: Optional[Dict[str, Any]] = None
+        self.acl: Optional[FakeAcl] = get_canned_acl("private")
+        self.cors: List[CorsRule] = []
+        self.logging: Dict[str, Any] = {}
+        self.notification_configuration: Optional[NotificationConfiguration] = None
+        self.accelerate_configuration: Optional[str] = None
         self.payer = "BucketOwner"
         self.creation_date = datetime.datetime.now(tz=datetime.timezone.utc)
-        self.public_access_block = None
-        self.encryption = None
+        self.public_access_block: Optional[PublicAccessBlock] = None
+        self.encryption: Optional[Dict[str, Any]] = None
         self.object_lock_enabled = False
-        self.default_lock_mode = ""
-        self.default_lock_days = 0
-        self.default_lock_years = 0
-        self.ownership_rule = None
+        self.default_lock_mode: Optional[str] = ""
+        self.default_lock_days: Optional[int] = 0
+        self.default_lock_years: Optional[int] = 0
+        self.ownership_rule: Optional[Dict[str, Any]] = None

     @property
-    def location(self):
+    def location(self) -> str:
         return self.region_name

     @property
-    def creation_date_ISO8601(self):
-        return iso_8601_datetime_without_milliseconds_s3(self.creation_date)
+    def creation_date_ISO8601(self) -> str:
+        return iso_8601_datetime_without_milliseconds_s3(self.creation_date)  # type: ignore

     @property
-    def is_versioned(self):
+    def is_versioned(self) -> bool:
         return self.versioning_status == "Enabled"

-    def get_permission(self, action, resource):
+    def get_permission(self, action: str, resource: str) -> Any:
         from moto.iam.access_control import IAMPolicy, PermissionResult

         if self.policy is None:
@@ -938,7 +959,7 @@ class FakeBucket(CloudFormationModel):
         iam_policy = IAMPolicy(self.policy.decode())
         return iam_policy.is_action_permitted(action, resource)

-    def set_lifecycle(self, rules):
+    def set_lifecycle(self, rules: List[Dict[str, Any]]) -> None:
         self.rules = []
         for rule in rules:
             # Extract and validate actions from Lifecycle rule
@@ -1076,10 +1097,10 @@ class FakeBucket(CloudFormationModel):
                 )
             )

-    def delete_lifecycle(self):
+    def delete_lifecycle(self) -> None:
         self.rules = []

-    def set_cors(self, rules):
+    def set_cors(self, rules: List[Dict[str, Any]]) -> None:
         self.cors = []

         if len(rules) > 100:
@@ -1119,10 +1140,12 @@ class FakeBucket(CloudFormationModel):
             )
         )

-    def delete_cors(self):
+    def delete_cors(self) -> None:
         self.cors = []

-    def set_logging(self, logging_config, bucket_backend):
+    def set_logging(
+        self, logging_config: Optional[Dict[str, Any]], bucket_backend: "S3Backend"
+    ) -> None:
         if not logging_config:
             self.logging = {}
             return
@@ -1135,7 +1158,7 @@ class FakeBucket(CloudFormationModel):

         # Does the target bucket have the log-delivery WRITE and READ_ACP permissions?
         write = read_acp = False
-        for grant in bucket_backend.buckets[logging_config["TargetBucket"]].acl.grants:
+        for grant in bucket_backend.buckets[logging_config["TargetBucket"]].acl.grants:  # type: ignore
             # Must be granted to: http://acs.amazonaws.com/groups/s3/LogDelivery
             for grantee in grant.grantees:
                 if grantee.uri == "http://acs.amazonaws.com/groups/s3/LogDelivery":
@@ -1169,7 +1192,9 @@ class FakeBucket(CloudFormationModel):
         # Checks pass -- set the logging config:
         self.logging = logging_config

-    def set_notification_configuration(self, notification_config):
+    def set_notification_configuration(
+        self, notification_config: Optional[Dict[str, Any]]
+    ) -> None:
         if not notification_config:
             self.notification_configuration = None
             return
@@ -1190,7 +1215,7 @@ class FakeBucket(CloudFormationModel):
         # Send test events so the user can verify these notifications were set correctly
         notifications.send_test_event(account_id=self.account_id, bucket=self)

-    def set_accelerate_configuration(self, accelerate_config):
+    def set_accelerate_configuration(self, accelerate_config: str) -> None:
         if self.accelerate_configuration is None and accelerate_config == "Suspended":
             # Cannot "suspend" a not active acceleration. Leaves it undefined
             return

         self.accelerate_configuration = accelerate_config

     @classmethod
-    def has_cfn_attr(cls, attr):
+    def has_cfn_attr(cls, attr: str) -> bool:
         return attr in [
             "Arn",
             "DomainName",
@@ -1207,7 +1232,7 @@ class FakeBucket(CloudFormationModel):
             "WebsiteURL",
         ]

-    def get_cfn_attribute(self, attribute_name):
+    def get_cfn_attribute(self, attribute_name: str) -> Any:
         from moto.cloudformation.exceptions import UnformattedGetAttTemplateException

         if attribute_name == "Arn":
@@ -1222,46 +1247,51 @@ class FakeBucket(CloudFormationModel):
             return self.website_url
         raise UnformattedGetAttTemplateException()

-    def set_acl(self, acl):
+    def set_acl(self, acl: Optional[FakeAcl]) -> None:
         self.acl = acl

     @property
-    def arn(self):
+    def arn(self) -> str:
         return f"arn:aws:s3:::{self.name}"

     @property
-    def domain_name(self):
+    def domain_name(self) -> str:
         return f"{self.name}.s3.amazonaws.com"

     @property
-    def dual_stack_domain_name(self):
+    def dual_stack_domain_name(self) -> str:
         return f"{self.name}.s3.dualstack.{self.region_name}.amazonaws.com"

     @property
-    def regional_domain_name(self):
+    def regional_domain_name(self) -> str:
         return f"{self.name}.s3.{self.region_name}.amazonaws.com"

     @property
-    def website_url(self):
+    def website_url(self) -> str:
         return f"http://{self.name}.s3-website.{self.region_name}.amazonaws.com"

     @property
-    def physical_resource_id(self):
+    def physical_resource_id(self) -> str:
         return self.name

     @staticmethod
-    def cloudformation_name_type():
+    def cloudformation_name_type() -> str:
         return "BucketName"

     @staticmethod
-    def cloudformation_type():
+    def cloudformation_type() -> str:
         # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-s3-bucket.html
         return "AWS::S3::Bucket"

     @classmethod
-    def create_from_cloudformation_json(
-        cls, resource_name, cloudformation_json, account_id, region_name, **kwargs
-    ):
+    def create_from_cloudformation_json(  # type: ignore[misc]
+        cls,
+        resource_name: str,
+        cloudformation_json: Any,
+        account_id: str,
+        region_name: str,
+        **kwargs: Any,
+    ) -> "FakeBucket":
         bucket = s3_backends[account_id]["global"].create_bucket(
             resource_name, region_name
         )
@@ -1277,14 +1307,14 @@ class FakeBucket(CloudFormationModel):
         return bucket

     @classmethod
-    def update_from_cloudformation_json(
+    def update_from_cloudformation_json(  # type: ignore[misc]
        cls,
-        original_resource,
-        new_resource_name,
-        cloudformation_json,
-        account_id,
-        region_name,
-    ):
+        original_resource: Any,
+        new_resource_name: str,
+        cloudformation_json: Any,
+        account_id: str,
+        region_name: str,
+    ) -> "FakeBucket":
         properties = cloudformation_json["Properties"]

         if is_replacement_update(properties):
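Since `FakeBucket` implements the CloudFormation model hooks typed above, the usual mocked round-trip looks roughly like this. Bucket and stack names are invented, and the `mock_cloudformation`/`mock_s3` decorators are the moto 4.x API this commit targets:

    import json
    import os

    import boto3
    from moto import mock_cloudformation, mock_s3

    # Dummy credentials so botocore can sign requests against the mock
    os.environ.setdefault("AWS_ACCESS_KEY_ID", "testing")
    os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "testing")

    TEMPLATE = {
        "AWSTemplateFormatVersion": "2010-09-09",
        "Resources": {
            "MyBucket": {
                "Type": "AWS::S3::Bucket",
                "Properties": {"BucketName": "cf-made-bucket"},
            }
        },
    }

    @mock_cloudformation
    @mock_s3
    def roundtrip() -> None:
        cf = boto3.client("cloudformation", region_name="us-east-1")
        cf.create_stack(StackName="example-stack", TemplateBody=json.dumps(TEMPLATE))
        s3 = boto3.client("s3", region_name="us-east-1")
        names = [b["Name"] for b in s3.list_buckets()["Buckets"]]
        assert "cf-made-bucket" in names

    roundtrip()
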
@@ -1314,18 +1344,22 @@ class FakeBucket(CloudFormationModel):
         return original_resource

     @classmethod
-    def delete_from_cloudformation_json(
-        cls, resource_name, cloudformation_json, account_id, region_name
-    ):
+    def delete_from_cloudformation_json(  # type: ignore[misc]
+        cls,
+        resource_name: str,
+        cloudformation_json: Any,
+        account_id: str,
+        region_name: str,
+    ) -> None:
         s3_backends[account_id]["global"].delete_bucket(resource_name)

-    def to_config_dict(self):
+    def to_config_dict(self) -> Dict[str, Any]:
         """Return the AWS Config JSON format of this S3 bucket.

         Note: The following features are not implemented and will need to be if you care about them:
         - Bucket Accelerate Configuration
         """
-        config_dict = {
+        config_dict: Dict[str, Any] = {
             "version": "1.3",
             "configurationItemCaptureTime": str(self.creation_date),
             "configurationItemStatus": "ResourceDiscovered",
@@ -1352,8 +1386,8 @@ class FakeBucket(CloudFormationModel):

         # Make the supplementary configuration:
         # This is a dobule-wrapped JSON for some reason...
-        s_config = {
-            "AccessControlList": json.dumps(json.dumps(self.acl.to_config_dict()))
+        s_config: Dict[str, Any] = {
+            "AccessControlList": json.dumps(json.dumps(self.acl.to_config_dict()))  # type: ignore
         }

         if self.public_access_block:
@@ -1400,7 +1434,7 @@ class FakeBucket(CloudFormationModel):
         return config_dict

     @property
-    def has_default_lock(self):
+    def has_default_lock(self) -> bool:
         if not self.object_lock_enabled:
             return False

@@ -1409,10 +1443,10 @@ class FakeBucket(CloudFormationModel):

         return False

-    def default_retention(self):
+    def default_retention(self) -> str:
         now = datetime.datetime.utcnow()
-        now += datetime.timedelta(self.default_lock_days)
-        now += datetime.timedelta(self.default_lock_years * 365)
+        now += datetime.timedelta(self.default_lock_days)  # type: ignore
+        now += datetime.timedelta(self.default_lock_years * 365)  # type: ignore
         return now.strftime("%Y-%m-%dT%H:%M:%SZ")


@@ -1434,22 +1468,22 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
     Note that this only works if the environment variable is set **before** the mock is initialized.
     """

-    def __init__(self, region_name, account_id):
+    def __init__(self, region_name: str, account_id: str):
         super().__init__(region_name, account_id)
-        self.buckets = {}
+        self.buckets: Dict[str, FakeBucket] = {}
         self.tagger = TaggingService()

         state_manager.register_default_transition(
             "s3::keyrestore", transition={"progression": "immediate"}
         )

-    def reset(self):
+    def reset(self) -> None:
         # For every key and multipart, Moto opens a TemporaryFile to write the value of those keys
         # Ensure that these TemporaryFile-objects are closed, and leave no filehandles open
         #
         # First, check all known buckets/keys
         for bucket in self.buckets.values():
-            for key in bucket.keys.values():
+            for key in bucket.keys.values():  # type: ignore
                 if isinstance(key, FakeKey):
                     key.dispose()
             for part in bucket.multiparts.values():
@@ -1457,13 +1491,13 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
         #
         # Second, go through the list of instances
         # It may contain FakeKeys created earlier, which are no longer tracked
-        for mp in FakeMultipart.instances:
+        for mp in FakeMultipart.instances:  # type: ignore
             mp.dispose()
-        for key in FakeKey.instances:
+        for key in FakeKey.instances:  # type: ignore
             key.dispose()
         super().reset()

-    def log_incoming_request(self, request, bucket_name):
+    def log_incoming_request(self, request: Any, bucket_name: str) -> None:
         """
         Process incoming requests
         If the request is made to a bucket with logging enabled, logs will be persisted in the appropriate bucket
@@ -1488,14 +1522,14 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
             response = '200 - - 1 2 "-"'
             user_agent = f"{request.headers.get('User-Agent')} prompt/off command/s3api.put-object"
             content = f"{random.get_random_hex(64)} originbucket [{date}] {source_ip} {source_iam} {unknown_hex} {source} {key_name} {http_line} {response} {user_agent} - c29tZSB1bmtub3duIGRhdGE= SigV4 ECDHE-RSA-AES128-GCM-SHA256 AuthHeader {request.url.split('amazonaws.com')[0]}amazonaws.com TLSv1.2 - -"
-            self.put_object(target_bucket, prefix + file_name, value=content)
+            self.put_object(target_bucket, prefix + file_name, value=content)  # type: ignore
         except:  # noqa: E722 Do not use bare except
             # log delivery is not guaranteed in AWS, so if anything goes wrong, it's 'safe' to just ignore it
             # Realistically, we should only get here when the bucket does not exist, or logging is not enabled
             pass

     @property
-    def _url_module(self):
+    def _url_module(self) -> Any:  # type: ignore
         # The urls-property can be different depending on env variables
         # Force a reload, to retrieve the correct set of URLs
         import moto.s3.urls as backend_urls_module
@@ -1504,7 +1538,9 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
         return backend_urls_module

     @staticmethod
-    def default_vpc_endpoint_service(service_region, zones):
+    def default_vpc_endpoint_service(
+        service_region: str, zones: List[str]
+    ) -> List[Dict[str, str]]:
         """List of dicts representing default VPC endpoints for this service."""
         accesspoint = {
             "AcceptanceRequired": False,
@@ -1535,14 +1571,8 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
             + [accesspoint]
         )

-    # TODO: This is broken! DO NOT IMPORT MUTABLE DATA TYPES FROM OTHER AREAS -- THIS BREAKS UNMOCKING!
-    #  WRAP WITH A GETTER/SETTER FUNCTION
-    # Register this class as a CloudWatch Metric Provider
-    # Must provide a method 'get_cloudwatch_metrics' that will return a list of metrics, based on the data available
-    # metric_providers["S3"] = self
-
     @classmethod
-    def get_cloudwatch_metrics(cls, account_id):
+    def get_cloudwatch_metrics(cls, account_id: str) -> List[MetricDatum]:
         metrics = []
         for name, bucket in s3_backends[account_id]["global"].buckets.items():
             metrics.append(
@@ -1577,7 +1607,7 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
             )
         return metrics

-    def create_bucket(self, bucket_name, region_name):
+    def create_bucket(self, bucket_name: str, region_name: str) -> FakeBucket:
         if bucket_name in self.buckets:
             raise BucketAlreadyExists(bucket=bucket_name)
         if not MIN_BUCKET_NAME_LENGTH <= len(bucket_name) <= MAX_BUCKET_NAME_LENGTH:
@@ -1606,10 +1636,10 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):

         return new_bucket

-    def list_buckets(self):
-        return self.buckets.values()
+    def list_buckets(self) -> List[FakeBucket]:
+        return list(self.buckets.values())

-    def get_bucket(self, bucket_name) -> FakeBucket:
+    def get_bucket(self, bucket_name: str) -> FakeBucket:
         try:
             return self.buckets[bucket_name]
         except KeyError:
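One behavioral nuance formalized below: `delete_bucket` now returns `Optional[FakeBucket]` instead of sometimes returning `False`, so internal callers can test the result uniformly. A sketch against the backend directly; the bucket name is invented, and the `DEFAULT_ACCOUNT_ID` import is assumed from how moto's own tests reference it:

    from moto.core import DEFAULT_ACCOUNT_ID  # assumed export
    from moto.s3.models import s3_backends

    backend = s3_backends[DEFAULT_ACCOUNT_ID]["global"]
    backend.create_bucket("doomed-bucket", "us-east-1")

    removed = backend.delete_bucket("doomed-bucket")
    print(removed is not None)  # True: the FakeBucket object is handed back
    # A non-empty bucket would yield None where the old code returned False
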
prefix="" - ): + self, + bucket_name: str, + delimiter: Optional[str] = None, + key_marker: Optional[str] = None, + prefix: str = "", + ) -> Tuple[List[FakeKey], List[str], List[FakeDeleteMarker]]: bucket = self.get_bucket(bucket_name) - common_prefixes = [] - requested_versions = [] - delete_markers = [] - all_versions = itertools.chain( - *(copy.deepcopy(l) for key, l in bucket.keys.iterlists()) + common_prefixes: List[str] = [] + requested_versions: List[FakeKey] = [] + delete_markers: List[FakeDeleteMarker] = [] + all_versions = list( + itertools.chain(*(copy.deepcopy(l) for key, l in bucket.keys.iterlists())) ) - all_versions = list(all_versions) # sort by name, revert last-modified-date all_versions.sort(key=lambda r: (r.name, -unix_time_millis(r.last_modified))) last_name = None @@ -1682,10 +1715,10 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): return requested_versions, common_prefixes, delete_markers - def get_bucket_policy(self, bucket_name): + def get_bucket_policy(self, bucket_name: str) -> Optional[bytes]: return self.get_bucket(bucket_name).policy - def put_bucket_policy(self, bucket_name, policy): + def put_bucket_policy(self, bucket_name: str, policy: bytes) -> None: """ Basic policy enforcement is in place. @@ -1695,30 +1728,38 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): """ self.get_bucket(bucket_name).policy = policy - def delete_bucket_policy(self, bucket_name): + def delete_bucket_policy(self, bucket_name: str) -> None: bucket = self.get_bucket(bucket_name) bucket.policy = None - def put_bucket_encryption(self, bucket_name, encryption): + def put_bucket_encryption( + self, bucket_name: str, encryption: Dict[str, Any] + ) -> None: self.get_bucket(bucket_name).encryption = encryption - def delete_bucket_encryption(self, bucket_name): + def delete_bucket_encryption(self, bucket_name: str) -> None: self.get_bucket(bucket_name).encryption = None - def get_bucket_ownership_controls(self, bucket_name): + def get_bucket_ownership_controls( + self, bucket_name: str + ) -> Optional[Dict[str, Any]]: return self.get_bucket(bucket_name).ownership_rule - def put_bucket_ownership_controls(self, bucket_name, ownership): + def put_bucket_ownership_controls( + self, bucket_name: str, ownership: Dict[str, Any] + ) -> None: self.get_bucket(bucket_name).ownership_rule = ownership - def delete_bucket_ownership_controls(self, bucket_name): + def delete_bucket_ownership_controls(self, bucket_name: str) -> None: self.get_bucket(bucket_name).ownership_rule = None - def get_bucket_replication(self, bucket_name): + def get_bucket_replication(self, bucket_name: str) -> Optional[Dict[str, Any]]: bucket = self.get_bucket(bucket_name) return getattr(bucket, "replication", None) - def put_bucket_replication(self, bucket_name, replication): + def put_bucket_replication( + self, bucket_name: str, replication: Dict[str, Any] + ) -> None: if isinstance(replication["Rule"], dict): replication["Rule"] = [replication["Rule"]] for rule in replication["Rule"]: @@ -1730,33 +1771,39 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): for _ in range(30) ) bucket = self.get_bucket(bucket_name) - bucket.replication = replication + bucket.replication = replication # type: ignore - def delete_bucket_replication(self, bucket_name): + def delete_bucket_replication(self, bucket_name: str) -> None: bucket = self.get_bucket(bucket_name) - bucket.replication = None + bucket.replication = None # type: ignore - def put_bucket_lifecycle(self, bucket_name, rules): + def put_bucket_lifecycle( 
+ self, bucket_name: str, rules: List[Dict[str, Any]] + ) -> None: bucket = self.get_bucket(bucket_name) bucket.set_lifecycle(rules) - def delete_bucket_lifecycle(self, bucket_name): + def delete_bucket_lifecycle(self, bucket_name: str) -> None: bucket = self.get_bucket(bucket_name) bucket.delete_lifecycle() - def set_bucket_website_configuration(self, bucket_name, website_configuration): + def set_bucket_website_configuration( + self, bucket_name: str, website_configuration: Dict[str, Any] + ) -> None: bucket = self.get_bucket(bucket_name) bucket.website_configuration = website_configuration - def get_bucket_website_configuration(self, bucket_name): + def get_bucket_website_configuration( + self, bucket_name: str + ) -> Optional[Dict[str, Any]]: bucket = self.get_bucket(bucket_name) return bucket.website_configuration - def delete_bucket_website(self, bucket_name): + def delete_bucket_website(self, bucket_name: str) -> None: bucket = self.get_bucket(bucket_name) bucket.website_configuration = None - def get_public_access_block(self, bucket_name): + def get_public_access_block(self, bucket_name: str) -> PublicAccessBlock: bucket = self.get_bucket(bucket_name) if not bucket.public_access_block: @@ -1766,20 +1813,20 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): def put_object( self, - bucket_name, - key_name, - value, - storage=None, - etag=None, - multipart=None, - encryption=None, - kms_key_id=None, - bucket_key_enabled=None, - lock_mode=None, - lock_legal_status=None, - lock_until=None, - checksum_value=None, - ): + bucket_name: str, + key_name: str, + value: bytes, + storage: Optional[str] = None, + etag: Optional[str] = None, + multipart: Optional[FakeMultipart] = None, + encryption: Optional[str] = None, + kms_key_id: Optional[str] = None, + bucket_key_enabled: Any = None, + lock_mode: Optional[str] = None, + lock_legal_status: Optional[str] = None, + lock_until: Optional[str] = None, + checksum_value: Optional[str] = None, + ) -> FakeKey: key_name = clean_key_name(key_name) if storage is not None and storage not in STORAGE_CLASS: raise InvalidStorageClass(storage=storage) @@ -1809,7 +1856,7 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): storage=storage, etag=etag, is_versioned=bucket.is_versioned, - version_id=str(random.uuid4()) if bucket.is_versioned else "null", + version_id=str(random.uuid4()) if bucket.is_versioned else "null", # type: ignore multipart=multipart, encryption=encryption, kms_key_id=kms_key_id, @@ -1833,7 +1880,9 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): return new_key - def put_object_acl(self, bucket_name, key_name, acl): + def put_object_acl( + self, bucket_name: str, key_name: str, acl: Optional[FakeAcl] + ) -> None: key = self.get_object(bucket_name, key_name) # TODO: Support the XML-based ACL format if key is not None: @@ -1842,15 +1891,25 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): raise MissingKey(key=key_name) def put_object_legal_hold( - self, bucket_name, key_name, version_id, legal_hold_status - ): + self, + bucket_name: str, + key_name: str, + version_id: Optional[str], + legal_hold_status: Dict[str, Any], + ) -> None: key = self.get_object(bucket_name, key_name, version_id=version_id) - key.lock_legal_status = legal_hold_status + key.lock_legal_status = legal_hold_status # type: ignore - def put_object_retention(self, bucket_name, key_name, version_id, retention): + def put_object_retention( + self, + bucket_name: str, + key_name: str, + version_id: Optional[str], + retention: 
Tuple[Optional[str], Optional[str]], + ) -> None: key = self.get_object(bucket_name, key_name, version_id=version_id) - key.lock_mode = retention[0] - key.lock_until = retention[1] + key.lock_mode = retention[0] # type: ignore + key.lock_until = retention[1] # type: ignore def get_object_attributes( self, @@ -1860,7 +1919,7 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): """ The following attributes are not yet returned: DeleteMarker, RequestCharged, ObjectParts """ - response_keys = { + response_keys: Dict[str, Any] = { "etag": None, "checksum": None, "size": None, @@ -1878,11 +1937,11 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): def get_object( self, - bucket_name, - key_name, - version_id=None, - part_number=None, - key_is_clean=False, + bucket_name: str, + key_name: str, + version_id: Optional[str] = None, + part_number: Optional[str] = None, + key_is_clean: bool = False, ) -> Optional[FakeKey]: if not key_is_clean: key_name = clean_key_name(key_name) @@ -1908,16 +1967,24 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): else: return None - def head_object(self, bucket_name, key_name, version_id=None, part_number=None): + def head_object( + self, + bucket_name: str, + key_name: str, + version_id: Optional[str] = None, + part_number: Optional[str] = None, + ) -> Optional[FakeKey]: return self.get_object(bucket_name, key_name, version_id, part_number) - def get_object_acl(self, key): + def get_object_acl(self, key: FakeKey) -> Optional[FakeAcl]: return key.acl - def get_object_legal_hold(self, key): + def get_object_legal_hold(self, key: FakeKey) -> Optional[str]: return key.lock_legal_status - def get_object_lock_configuration(self, bucket_name): + def get_object_lock_configuration( + self, bucket_name: str + ) -> Tuple[bool, Optional[str], Optional[int], Optional[int]]: bucket = self.get_bucket(bucket_name) if not bucket.object_lock_enabled: raise ObjectLockConfigurationNotFoundError @@ -1928,10 +1995,15 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): bucket.default_lock_years, ) - def get_object_tagging(self, key): + def get_object_tagging(self, key: FakeKey) -> Dict[str, List[Dict[str, str]]]: return self.tagger.list_tags_for_resource(key.arn) - def set_key_tags(self, key, tags, key_name=None): + def set_key_tags( + self, + key: Optional[FakeKey], + tags: Optional[Dict[str, str]], + key_name: Optional[str] = None, + ) -> FakeKey: if key is None: raise MissingKey(key=key_name) boto_tags_dict = self.tagger.convert_dict_to_tags_input(tags) @@ -1942,11 +2014,11 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): self.tagger.tag_resource(key.arn, boto_tags_dict) return key - def get_bucket_tagging(self, bucket_name): + def get_bucket_tagging(self, bucket_name: str) -> Dict[str, List[Dict[str, str]]]: bucket = self.get_bucket(bucket_name) return self.tagger.list_tags_for_resource(bucket.arn) - def put_bucket_tagging(self, bucket_name, tags): + def put_bucket_tagging(self, bucket_name: str, tags: Dict[str, str]) -> None: bucket = self.get_bucket(bucket_name) self.tagger.delete_all_tags_for_resource(bucket.arn) self.tagger.tag_resource( @@ -1954,8 +2026,13 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): ) def put_object_lock_configuration( - self, bucket_name, lock_enabled, mode=None, days=None, years=None - ): + self, + bucket_name: str, + lock_enabled: bool, + mode: Optional[str] = None, + days: Optional[int] = None, + years: Optional[int] = None, + ) -> None: bucket = self.get_bucket(bucket_name) if bucket.keys.item_size() > 
0: @@ -1969,27 +2046,33 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): bucket.default_lock_days = days bucket.default_lock_years = years - def delete_bucket_tagging(self, bucket_name): + def delete_bucket_tagging(self, bucket_name: str) -> None: bucket = self.get_bucket(bucket_name) self.tagger.delete_all_tags_for_resource(bucket.arn) - def put_bucket_cors(self, bucket_name, cors_rules): + def put_bucket_cors( + self, bucket_name: str, cors_rules: List[Dict[str, Any]] + ) -> None: bucket = self.get_bucket(bucket_name) bucket.set_cors(cors_rules) - def put_bucket_logging(self, bucket_name, logging_config): + def put_bucket_logging( + self, bucket_name: str, logging_config: Dict[str, Any] + ) -> None: bucket = self.get_bucket(bucket_name) bucket.set_logging(logging_config, self) - def delete_bucket_cors(self, bucket_name): + def delete_bucket_cors(self, bucket_name: str) -> None: bucket = self.get_bucket(bucket_name) bucket.delete_cors() - def delete_public_access_block(self, bucket_name): + def delete_public_access_block(self, bucket_name: str) -> None: bucket = self.get_bucket(bucket_name) bucket.public_access_block = None - def put_bucket_notification_configuration(self, bucket_name, notification_config): + def put_bucket_notification_configuration( + self, bucket_name: str, notification_config: Dict[str, Any] + ) -> None: """ The configuration can be persisted, but at the moment we only send notifications to the following targets: @@ -2005,8 +2088,8 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): bucket.set_notification_configuration(notification_config) def put_bucket_accelerate_configuration( - self, bucket_name, accelerate_configuration - ): + self, bucket_name: str, accelerate_configuration: str + ) -> None: if accelerate_configuration not in ["Enabled", "Suspended"]: raise MalformedXML() @@ -2015,7 +2098,9 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): raise InvalidRequest("PutBucketAccelerateConfiguration") bucket.set_accelerate_configuration(accelerate_configuration) - def put_bucket_public_access_block(self, bucket_name, pub_block_config): + def put_bucket_public_access_block( + self, bucket_name: str, pub_block_config: Optional[Dict[str, Any]] + ) -> None: bucket = self.get_bucket(bucket_name) if not pub_block_config: @@ -2028,7 +2113,7 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): pub_block_config.get("RestrictPublicBuckets"), ) - def abort_multipart_upload(self, bucket_name, multipart_id): + def abort_multipart_upload(self, bucket_name: str, multipart_id: str) -> None: bucket = self.get_bucket(bucket_name) multipart_data = bucket.multiparts.get(multipart_id, None) if not multipart_data: @@ -2036,8 +2121,12 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): del bucket.multiparts[multipart_id] def list_parts( - self, bucket_name, multipart_id, part_number_marker=0, max_parts=1000 - ): + self, + bucket_name: str, + multipart_id: str, + part_number_marker: int = 0, + max_parts: int = 1000, + ) -> List[FakeKey]: bucket = self.get_bucket(bucket_name) if multipart_id not in bucket.multiparts: raise NoSuchUpload(upload_id=multipart_id) @@ -2045,21 +2134,23 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): bucket.multiparts[multipart_id].list_parts(part_number_marker, max_parts) ) - def is_truncated(self, bucket_name, multipart_id, next_part_number_marker): + def is_truncated( + self, bucket_name: str, multipart_id: str, next_part_number_marker: int + ) -> bool: bucket = self.get_bucket(bucket_name) return 
len(bucket.multiparts[multipart_id].parts) > next_part_number_marker def create_multipart_upload( self, - bucket_name, - key_name, - metadata, - storage_type, - tags, - acl, - sse_encryption, - kms_key_id, - ): + bucket_name: str, + key_name: str, + metadata: CaseInsensitiveDict, # type: ignore + storage_type: str, + tags: Dict[str, str], + acl: Optional[FakeAcl], + sse_encryption: str, + kms_key_id: str, + ) -> str: multipart = FakeMultipart( key_name, metadata, @@ -2074,7 +2165,9 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): bucket.multiparts[multipart.id] = multipart return multipart.id - def complete_multipart_upload(self, bucket_name, multipart_id, body): + def complete_multipart_upload( + self, bucket_name: str, multipart_id: str, body: Iterator[Tuple[int, str]] + ) -> Tuple[FakeMultipart, bytes, str]: bucket = self.get_bucket(bucket_name) multipart = bucket.multiparts[multipart_id] value, etag = multipart.complete(body) @@ -2082,41 +2175,45 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): del bucket.multiparts[multipart_id] return multipart, value, etag - def get_all_multiparts(self, bucket_name): + def get_all_multiparts(self, bucket_name: str) -> Dict[str, FakeMultipart]: bucket = self.get_bucket(bucket_name) return bucket.multiparts - def upload_part(self, bucket_name, multipart_id, part_id, value): + def upload_part( + self, bucket_name: str, multipart_id: str, part_id: int, value: bytes + ) -> FakeKey: bucket = self.get_bucket(bucket_name) multipart = bucket.multiparts[multipart_id] return multipart.set_part(part_id, value) def copy_part( self, - dest_bucket_name, - multipart_id, - part_id, - src_bucket_name, - src_key_name, - src_version_id, - start_byte, - end_byte, - ): + dest_bucket_name: str, + multipart_id: str, + part_id: int, + src_bucket_name: str, + src_key_name: str, + src_version_id: str, + start_byte: int, + end_byte: int, + ) -> FakeKey: dest_bucket = self.get_bucket(dest_bucket_name) multipart = dest_bucket.multiparts[multipart_id] - src_value = self.get_object( + src_value = self.get_object( # type: ignore src_bucket_name, src_key_name, version_id=src_version_id ).value if start_byte is not None: src_value = src_value[start_byte : end_byte + 1] return multipart.set_part(part_id, src_value) - def list_objects(self, bucket, prefix, delimiter): + def list_objects( + self, bucket: FakeBucket, prefix: Optional[str], delimiter: Optional[str] + ) -> Tuple[Set[FakeKey], Set[str]]: key_results = set() folder_results = set() if prefix: - for key_name, key in bucket.keys.items(): + for key_name, key in bucket.keys.items(): # type: ignore if key_name.startswith(prefix): key_without_prefix = key_name.replace(prefix, "", 1) if delimiter and delimiter in key_without_prefix: @@ -2128,48 +2225,58 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): else: key_results.add(key) else: - for key_name, key in bucket.keys.items(): + for key_name, key in bucket.keys.items(): # type: ignore if delimiter and delimiter in key_name: # If delimiter, we need to split out folder_results folder_results.add(key_name.split(delimiter)[0] + delimiter) else: key_results.add(key) - key_results = filter( + key_results = filter( # type: ignore lambda key: not isinstance(key, FakeDeleteMarker), key_results ) - key_results = sorted(key_results, key=lambda key: key.name) - folder_results = [ + key_results = sorted(key_results, key=lambda key: key.name) # type: ignore + folder_results = [ # type: ignore folder_name for folder_name in sorted(folder_results, key=lambda key: key) 
] return key_results, folder_results - def list_objects_v2(self, bucket, prefix, delimiter): + def list_objects_v2( + self, bucket: FakeBucket, prefix: Optional[str], delimiter: Optional[str] + ) -> Set[Union[FakeKey, str]]: result_keys, result_folders = self.list_objects(bucket, prefix, delimiter) # sort the combination of folders and keys into lexicographical order - all_keys = result_keys + result_folders + all_keys = result_keys + result_folders # type: ignore all_keys.sort(key=self._get_name) return all_keys @staticmethod - def _get_name(key): + def _get_name(key: Union[str, FakeKey]) -> str: if isinstance(key, FakeKey): return key.name else: return key - def _set_delete_marker(self, bucket_name, key_name): + def _set_delete_marker(self, bucket_name: str, key_name: str) -> FakeDeleteMarker: bucket = self.get_bucket(bucket_name) delete_marker = FakeDeleteMarker(key=bucket.keys[key_name]) bucket.keys[key_name] = delete_marker return delete_marker - def delete_object_tagging(self, bucket_name, key_name, version_id=None): + def delete_object_tagging( + self, bucket_name: str, key_name: str, version_id: Optional[str] = None + ) -> None: key = self.get_object(bucket_name, key_name, version_id=version_id) - self.tagger.delete_all_tags_for_resource(key.arn) + self.tagger.delete_all_tags_for_resource(key.arn) # type: ignore - def delete_object(self, bucket_name, key_name, version_id=None, bypass=False): + def delete_object( + self, + bucket_name: str, + key_name: str, + version_id: Optional[str] = None, + bypass: bool = False, + ) -> Tuple[bool, Optional[Dict[str, Any]]]: key_name = clean_key_name(key_name) bucket = self.get_bucket(bucket_name) @@ -2200,7 +2307,7 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): raise AccessDeniedByLock if type(key) is FakeDeleteMarker: - if type(key.key) is FakeDeleteMarker: + if type(key.key) is FakeDeleteMarker: # type: ignore # Our key is a DeleteMarker, that usually contains a link to the actual FakeKey # But: If we have deleted the FakeKey multiple times, # We have a DeleteMarker linking to a DeleteMarker (etc..) 
linking to a FakeKey @@ -2225,7 +2332,9 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): except KeyError: return False, None - def delete_objects(self, bucket_name, objects): + def delete_objects( + self, bucket_name: str, objects: List[Dict[str, Any]] + ) -> List[Tuple[str, Optional[str]]]: deleted_objects = [] for object_ in objects: key_name = object_["Key"] @@ -2239,16 +2348,16 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): def copy_object( self, - src_key, - dest_bucket_name, - dest_key_name, - storage=None, - acl=None, - encryption=None, - kms_key_id=None, - bucket_key_enabled=False, - mdirective=None, - ): + src_key: FakeKey, + dest_bucket_name: str, + dest_key_name: str, + storage: Optional[str] = None, + acl: Optional[FakeAcl] = None, + encryption: Optional[str] = None, + kms_key_id: Optional[str] = None, + bucket_key_enabled: bool = False, + mdirective: Optional[str] = None, + ) -> None: if ( src_key.name == dest_key_name and src_key.bucket_name == dest_bucket_name @@ -2289,32 +2398,34 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): self.account_id, notifications.S3_OBJECT_CREATE_COPY, bucket, new_key ) - def put_bucket_acl(self, bucket_name, acl): + def put_bucket_acl(self, bucket_name: str, acl: Optional[FakeAcl]) -> None: bucket = self.get_bucket(bucket_name) bucket.set_acl(acl) - def get_bucket_acl(self, bucket_name): + def get_bucket_acl(self, bucket_name: str) -> Optional[FakeAcl]: bucket = self.get_bucket(bucket_name) return bucket.acl - def get_bucket_cors(self, bucket_name): + def get_bucket_cors(self, bucket_name: str) -> List[CorsRule]: bucket = self.get_bucket(bucket_name) return bucket.cors - def get_bucket_lifecycle(self, bucket_name): + def get_bucket_lifecycle(self, bucket_name: str) -> List[LifecycleRule]: bucket = self.get_bucket(bucket_name) return bucket.rules - def get_bucket_location(self, bucket_name): + def get_bucket_location(self, bucket_name: str) -> str: bucket = self.get_bucket(bucket_name) return bucket.location - def get_bucket_logging(self, bucket_name): + def get_bucket_logging(self, bucket_name: str) -> Dict[str, Any]: bucket = self.get_bucket(bucket_name) return bucket.logging - def get_bucket_notification_configuration(self, bucket_name): + def get_bucket_notification_configuration( + self, bucket_name: str + ) -> Optional[NotificationConfiguration]: bucket = self.get_bucket(bucket_name) return bucket.notification_configuration @@ -2325,7 +2436,7 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): select_query: str, input_details: Dict[str, Any], output_details: Dict[str, Any], # pylint: disable=unused-argument - ): + ) -> List[bytes]: """ Highly experimental. Please raise an issue if you find any inconsistencies/bugs. 
@@ -2337,7 +2448,7 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
         """
         self.get_bucket(bucket_name)
         key = self.get_object(bucket_name, key_name)
-        query_input = key.value.decode("utf-8")
+        query_input = key.value.decode("utf-8")  # type: ignore
         if "CSV" in input_details:
             # input is in CSV - we need to convert it to JSON before parsing
             from py_partiql_parser._internal.csv_converter import (  # noqa # pylint: disable=unused-import
diff --git a/moto/s3/notifications.py b/moto/s3/notifications.py
index 12be09203..5258b4f11 100644
--- a/moto/s3/notifications.py
+++ b/moto/s3/notifications.py
@@ -1,5 +1,6 @@
 import json
 from datetime import datetime
+from typing import Any, Dict, List
 
 _EVENT_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%f"
 
@@ -7,7 +8,9 @@ S3_OBJECT_CREATE_COPY = "s3:ObjectCreated:Copy"
 S3_OBJECT_CREATE_PUT = "s3:ObjectCreated:Put"
 
 
-def _get_s3_event(event_name, bucket, key, notification_id):
+def _get_s3_event(
+    event_name: str, bucket: Any, key: Any, notification_id: str
+) -> Dict[str, List[Dict[str, Any]]]:
     etag = key.etag.replace('"', "")
     # s3:ObjectCreated:Put --> ObjectCreated:Put
     event_name = event_name[3:]
@@ -34,11 +37,11 @@ def _get_s3_event(event_name, bucket, key, notification_id):
     }
 
 
-def _get_region_from_arn(arn):
+def _get_region_from_arn(arn: str) -> str:
     return arn.split(":")[3]
 
 
-def send_event(account_id, event_name, bucket, key):
+def send_event(account_id: str, event_name: str, bucket: Any, key: Any) -> None:
     if bucket.notification_configuration is None:
         return
 
@@ -58,7 +61,9 @@ def send_event(account_id, event_name, bucket, key):
             _send_sqs_message(account_id, event_body, queue_name, region_name)
 
 
-def _send_sqs_message(account_id, event_body, queue_name, region_name):
+def _send_sqs_message(
+    account_id: str, event_body: Any, queue_name: str, region_name: str
+) -> None:
     try:
         from moto.sqs.models import sqs_backends
 
@@ -74,7 +79,9 @@ def _send_sqs_message(account_id, event_body, queue_name, region_name):
         pass
 
 
-def _invoke_awslambda(account_id, event_body, fn_arn, region_name):
+def _invoke_awslambda(
+    account_id: str, event_body: Any, fn_arn: str, region_name: str
+) -> None:
     try:
         from moto.awslambda.models import lambda_backends
 
@@ -89,7 +96,7 @@ def _invoke_awslambda(account_id, event_body, fn_arn, region_name):
         pass
 
 
-def _get_test_event(bucket_name):
+def _get_test_event(bucket_name: str) -> Dict[str, Any]:
     event_time = datetime.now().strftime(_EVENT_TIME_FORMAT)
     return {
         "Service": "Amazon S3",
@@ -99,7 +106,7 @@ def _get_test_event(bucket_name):
     }
 
 
-def send_test_event(account_id, bucket):
+def send_test_event(account_id: str, bucket: Any) -> None:
     arns = [n.arn for n in bucket.notification_configuration.queue]
     for arn in set(arns):
         region_name = _get_region_from_arn(arn)
diff --git a/moto/s3/responses.py b/moto/s3/responses.py
index 9b72d0bb5..755c02b62 100644
--- a/moto/s3/responses.py
+++ b/moto/s3/responses.py
@@ -1,7 +1,7 @@
 import io
 import os
 import re
-from typing import List, Union
+from typing import Any, Dict, List, Iterator, Union, Tuple, Optional, Type
 
 import urllib.parse
 
@@ -14,6 +14,7 @@ from urllib.parse import parse_qs, urlparse, unquote, urlencode, urlunparse
 
 import xmltodict
 
+from moto.core.common_types import TYPE_RESPONSE
 from moto.core.responses import BaseResponse
 from moto.core.utils import path_url
 
@@ -53,7 +54,7 @@ from .exceptions import (
     AccessForbidden,
 )
 from .models import s3_backends, S3Backend
-from .models import get_canned_acl, FakeGrantee, FakeGrant, FakeAcl, FakeKey
+from .models import get_canned_acl, FakeGrantee, FakeGrant, FakeAcl, FakeKey, FakeBucket
 from .select_object_content import serialize_select
 from .utils import (
     bucket_name_from_url,
@@ -146,13 +147,13 @@ ACTION_MAP = {
 }
 
 
-def parse_key_name(pth):
+def parse_key_name(pth: str) -> str:
     # strip the first '/' left by urlparse
     return pth[1:] if pth.startswith("/") else pth
 
 
 class S3Response(BaseResponse):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__(service_name="s3")
 
     @property
@@ -160,10 +161,10 @@ class S3Response(BaseResponse):
         return s3_backends[self.current_account]["global"]
 
     @property
-    def should_autoescape(self):
+    def should_autoescape(self) -> bool:
         return True
 
-    def all_buckets(self):
+    def all_buckets(self) -> str:
         self.data["Action"] = "ListAllMyBuckets"
         self._authenticate_and_authorize_s3_action()
 
@@ -172,7 +173,7 @@ class S3Response(BaseResponse):
         template = self.response_template(S3_ALL_BUCKETS)
         return template.render(buckets=all_buckets)
 
-    def subdomain_based_buckets(self, request):
+    def subdomain_based_buckets(self, request: Any) -> bool:
         if settings.S3_IGNORE_SUBDOMAIN_BUCKETNAME:
             return False
         host = request.headers.get("host", request.headers.get("Host"))
@@ -224,23 +225,25 @@ class S3Response(BaseResponse):
         )
         return not path_based
 
-    def is_delete_keys(self):
+    def is_delete_keys(self) -> bool:
         qs = parse_qs(urlparse(self.path).query, keep_blank_values=True)
         return "delete" in qs
 
-    def parse_bucket_name_from_url(self, request, url):
+    def parse_bucket_name_from_url(self, request: Any, url: str) -> str:
         if self.subdomain_based_buckets(request):
-            return bucket_name_from_url(url)
+            return bucket_name_from_url(url)  # type: ignore
         else:
-            return bucketpath_bucket_name_from_url(url)
+            return bucketpath_bucket_name_from_url(url)  # type: ignore
 
-    def parse_key_name(self, request, url):
+    def parse_key_name(self, request: Any, url: str) -> str:
         if self.subdomain_based_buckets(request):
             return parse_key_name(url)
         else:
             return bucketpath_parse_key_name(url)
 
-    def ambiguous_response(self, request, full_url, headers):
+    def ambiguous_response(
+        self, request: Any, full_url: str, headers: Any
+    ) -> TYPE_RESPONSE:
         # Depending on which calling format the client is using, we don't know
         # if this is a bucket or key request so we have to check
         if self.subdomain_based_buckets(request):
@@ -250,7 +253,7 @@ class S3Response(BaseResponse):
             return self.bucket_response(request, full_url, headers)
 
     @amzn_request_id
-    def bucket_response(self, request, full_url, headers):
+    def bucket_response(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE:  # type: ignore
         self.setup_class(request, full_url, headers, use_raw_body=True)
         bucket_name = self.parse_bucket_name_from_url(request, full_url)
         self.backend.log_incoming_request(request, bucket_name)
@@ -262,7 +265,7 @@ class S3Response(BaseResponse):
         return self._send_response(response)
 
     @staticmethod
-    def _send_response(response):
+    def _send_response(response: Any) -> TYPE_RESPONSE:  # type: ignore
         if isinstance(response, str):
             return 200, {}, response.encode("utf-8")
         else:
@@ -272,7 +275,9 @@ class S3Response(BaseResponse):
 
             return status_code, headers, response_content
 
-    def _bucket_response(self, request, full_url):
+    def _bucket_response(
+        self, request: Any, full_url: str
+    ) -> Union[str, TYPE_RESPONSE]:
         querystring = self._get_querystring(request, full_url)
         method = request.method
         region_name = parse_region_from_url(full_url, use_default_region=False)
@@ -309,7 +314,7 @@ class S3Response(BaseResponse):
             )
 
     @staticmethod
-    def _get_querystring(request, full_url):
+    def _get_querystring(request: Any, full_url: str) -> Dict[str, Any]:  # type: ignore[misc]
         # Flask's Request has the querystring already parsed
         # In ServerMode, we can use this, instead of manually parsing this
         if hasattr(request, "args"):
@@ -330,10 +335,11 @@ class S3Response(BaseResponse):
         # Workaround - manually reverse the encoding.
         # Keep the + encoded, ensuring that parse_qsl doesn't replace it, and parse_qsl will unquote it afterwards
         qs = (parsed_url.query or "").replace("+", "%2B")
-        querystring = parse_qs(qs, keep_blank_values=True)
-        return querystring
+        return parse_qs(qs, keep_blank_values=True)
 
-    def _bucket_response_head(self, bucket_name, querystring):
+    def _bucket_response_head(
+        self, bucket_name: str, querystring: Dict[str, Any]
+    ) -> TYPE_RESPONSE:
         self._set_action("BUCKET", "HEAD", querystring)
         self._authenticate_and_authorize_s3_action()
 
@@ -347,7 +353,7 @@ class S3Response(BaseResponse):
             return 404, {}, ""
         return 200, {"x-amz-bucket-region": bucket.region_name}, ""
 
-    def _set_cors_headers(self, headers, bucket):
+    def _set_cors_headers(self, headers: Dict[str, str], bucket: FakeBucket) -> None:
         """
         TODO: smarter way of matching the right CORS rule:
         See https://docs.aws.amazon.com/AmazonS3/latest/userguide/cors.html
@@ -372,8 +378,8 @@ class S3Response(BaseResponse):
                 )
             if cors_rule.allowed_origins is not None:
                 origin = headers.get("Origin")
-                if cors_matches_origin(origin, cors_rule.allowed_origins):
-                    self.response_headers["Access-Control-Allow-Origin"] = origin
+                if cors_matches_origin(origin, cors_rule.allowed_origins):  # type: ignore
+                    self.response_headers["Access-Control-Allow-Origin"] = origin  # type: ignore
                 else:
                     raise AccessForbidden(
                         "CORSResponse: This CORS request is not allowed. This is usually because the evaluation of Origin, request method / Access-Control-Request-Method or Access-Control-Request-Headers are not whitelisted by the resource's CORS spec."
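(A usage note on the CORS handling above, not part of the upstream patch: the rewritten _set_cors_headers echoes a matching Origin back and raises AccessForbidden for anything else. A minimal sketch of exercising it through boto3 against moto; the bucket name and origin below are illustrative assumptions, not values taken from this diff.)

    import boto3
    from moto import mock_s3

    @mock_s3
    def cors_preflight_sketch() -> None:
        s3 = boto3.client("s3", region_name="us-east-1")
        s3.create_bucket(Bucket="example-cors-bucket")  # hypothetical bucket name
        s3.put_bucket_cors(
            Bucket="example-cors-bucket",
            CORSConfiguration={
                "CORSRules": [
                    {
                        "AllowedMethods": ["GET"],
                        "AllowedOrigins": ["https://example.com"],
                        "AllowedHeaders": ["*"],
                        "MaxAgeSeconds": 3000,
                    }
                ]
            },
        )
        # An OPTIONS request carrying Origin "https://example.com" should now pass
        # the matching logic above; any other origin yields HTTP 403.
        print(s3.get_bucket_cors(Bucket="example-cors-bucket")["CORSRules"])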
@@ -391,23 +397,24 @@ class S3Response(BaseResponse):
                     cors_rule.max_age_seconds
                 )
 
-    def _response_options(self, headers, bucket_name):
+    def _response_options(
+        self, headers: Dict[str, str], bucket_name: str
+    ) -> TYPE_RESPONSE:
         # Return 200 with the headers from the bucket CORS configuration
         self._authenticate_and_authorize_s3_action()
         try:
             bucket = self.backend.head_bucket(bucket_name)
         except MissingBucket:
-            return (
-                403,
-                {},
-                "",
-            )  # AWS S3 seems to return 403 on OPTIONS and 404 on GET/HEAD
+            # AWS S3 seems to return 403 on OPTIONS and 404 on GET/HEAD
+            return 403, {}, ""
 
         self._set_cors_headers(headers, bucket)
 
         return 200, self.response_headers, ""
 
-    def _bucket_response_get(self, bucket_name, querystring):
+    def _bucket_response_get(
+        self, bucket_name: str, querystring: Dict[str, Any]
+    ) -> Union[str, TYPE_RESPONSE]:
         self._set_action("BUCKET", "GET", querystring)
         self._authenticate_and_authorize_s3_action()
 
@@ -445,7 +452,7 @@ class S3Response(BaseResponse):
                 account_id=self.current_account,
             )
         elif "location" in querystring:
-            location = self.backend.get_bucket_location(bucket_name)
+            location: Optional[str] = self.backend.get_bucket_location(bucket_name)
             template = self.response_template(S3_BUCKET_LOCATION)
 
             # us-east-1 is different - returns a None location
@@ -477,7 +484,7 @@ class S3Response(BaseResponse):
             if not website_configuration:
                 template = self.response_template(S3_NO_BUCKET_WEBSITE_CONFIG)
                 return 404, {}, template.render(bucket_name=bucket_name)
-            return 200, {}, website_configuration
+            return 200, {}, website_configuration  # type: ignore
         elif "acl" in querystring:
             acl = self.backend.get_bucket_acl(bucket_name)
             template = self.response_template(S3_OBJECT_ACL_RESPONSE)
@@ -615,7 +622,9 @@ class S3Response(BaseResponse):
             ),
         )
 
-    def _set_action(self, action_resource_type, method, querystring):
+    def _set_action(
+        self, action_resource_type: str, method: str, querystring: Dict[str, Any]
+    ) -> None:
         action_set = False
         for action_in_querystring, action in ACTION_MAP[action_resource_type][
             method
@@ -626,7 +635,9 @@ class S3Response(BaseResponse):
         if not action_set:
             self.data["Action"] = ACTION_MAP[action_resource_type][method]["DEFAULT"]
 
-    def _handle_list_objects_v2(self, bucket_name, querystring):
+    def _handle_list_objects_v2(
+        self, bucket_name: str, querystring: Dict[str, Any]
+    ) -> str:
         template = self.response_template(S3_BUCKET_GET_RESPONSE_V2)
         bucket = self.backend.get_bucket(bucket_name)
 
@@ -678,7 +689,7 @@ class S3Response(BaseResponse):
         )
 
     @staticmethod
-    def _split_truncated_keys(truncated_keys):
+    def _split_truncated_keys(truncated_keys: Any) -> Any:  # type: ignore[misc]
         result_keys = []
         result_folders = []
         for key in truncated_keys:
@@ -688,7 +699,7 @@ class S3Response(BaseResponse):
                 result_folders.append(key)
         return result_keys, result_folders
 
-    def _get_results_from_token(self, result_keys, token):
+    def _get_results_from_token(self, result_keys: Any, token: Any) -> Any:
         continuation_index = 0
         for key in result_keys:
             if (key.name if isinstance(key, FakeKey) else key) > token:
@@ -696,22 +707,22 @@ class S3Response(BaseResponse):
             continuation_index += 1
         return result_keys[continuation_index:]
 
-    def _truncate_result(self, result_keys, max_keys):
+    def _truncate_result(self, result_keys: Any, max_keys: int) -> Any:
         if max_keys == 0:
             result_keys = []
             is_truncated = True
             next_continuation_token = None
         elif len(result_keys) > max_keys:
-            is_truncated = "true"
+            is_truncated = "true"  # type: ignore
             result_keys = result_keys[:max_keys]
             item = result_keys[-1]
             next_continuation_token = item.name if isinstance(item, FakeKey) else item
         else:
-            is_truncated = "false"
+            is_truncated = "false"  # type: ignore
             next_continuation_token = None
         return result_keys, is_truncated, next_continuation_token
 
-    def _body_contains_location_constraint(self, body):
+    def _body_contains_location_constraint(self, body: bytes) -> bool:
         if body:
             try:
                 xmltodict.parse(body)["CreateBucketConfiguration"]["LocationConstraint"]
@@ -720,7 +731,7 @@ class S3Response(BaseResponse):
                 pass
         return False
 
-    def _create_bucket_configuration_is_empty(self, body):
+    def _create_bucket_configuration_is_empty(self, body: bytes) -> bool:
         if body:
             try:
                 create_bucket_configuration = xmltodict.parse(body)[
                     "CreateBucketConfiguration"
                 ]
@@ -733,13 +744,19 @@ class S3Response(BaseResponse):
                 pass
         return False
 
-    def _parse_pab_config(self):
+    def _parse_pab_config(self) -> Dict[str, Any]:
         parsed_xml = xmltodict.parse(self.body)
         parsed_xml["PublicAccessBlockConfiguration"].pop("@xmlns", None)
 
         return parsed_xml
 
-    def _bucket_response_put(self, request, region_name, bucket_name, querystring):
+    def _bucket_response_put(
+        self,
+        request: Any,
+        region_name: str,
+        bucket_name: str,
+        querystring: Dict[str, Any],
+    ) -> Union[str, TYPE_RESPONSE]:
         if querystring and not request.headers.get("Content-Length"):
             return 411, {}, "Content-Length required"
 
@@ -754,7 +771,7 @@ class S3Response(BaseResponse):
 
             self.backend.put_object_lock_configuration(
                 bucket_name,
-                config.get("enabled"),
+                config.get("enabled"),  # type: ignore
                 config.get("mode"),
                 config.get("days"),
                 config.get("years"),
@@ -765,7 +782,7 @@ class S3Response(BaseResponse):
             body = self.body.decode("utf-8")
             ver = re.search(r"<Status>([A-Za-z]+)</Status>", body)
             if ver:
-                self.backend.put_bucket_versioning(bucket_name, ver.group(1))
+                self.backend.put_bucket_versioning(bucket_name, ver.group(1))  # type: ignore
                 template = self.response_template(S3_BUCKET_VERSIONING)
                 return template.render(bucket_versioning_status=ver.group(1))
             else:
@@ -922,7 +939,9 @@ class S3Response(BaseResponse):
         template = self.response_template(S3_BUCKET_CREATE_RESPONSE)
         return 200, {}, template.render(bucket=new_bucket)
 
-    def _bucket_response_delete(self, bucket_name, querystring):
+    def _bucket_response_delete(
+        self, bucket_name: str, querystring: Dict[str, Any]
+    ) -> TYPE_RESPONSE:
         self._set_action("BUCKET", "DELETE", querystring)
         self._authenticate_and_authorize_s3_action()
 
@@ -965,7 +984,7 @@ class S3Response(BaseResponse):
             template = self.response_template(S3_DELETE_BUCKET_WITH_ITEMS_ERROR)
             return 409, {}, template.render(bucket=removed_bucket)
 
-    def _bucket_response_post(self, request, bucket_name):
+    def _bucket_response_post(self, request: Any, bucket_name: str) -> TYPE_RESPONSE:
         response_headers = {}
         if not request.headers.get("Content-Length"):
             return 411, {}, "Content-Length required"
@@ -999,7 +1018,7 @@ class S3Response(BaseResponse):
         if "success_action_redirect" in form:
             redirect = form["success_action_redirect"]
             parts = urlparse(redirect)
-            queryargs = parse_qs(parts.query)
+            queryargs: Dict[str, Any] = parse_qs(parts.query)
             queryargs["key"] = key
             queryargs["bucket"] = bucket_name
             redirect_queryargs = urlencode(queryargs, doseq=True)
@@ -1035,14 +1054,16 @@ class S3Response(BaseResponse):
         return status_code, response_headers, ""
 
     @staticmethod
-    def _get_path(request):
+    def _get_path(request: Any) -> str:  # type: ignore[misc]
         return (
             request.full_path
             if hasattr(request, "full_path")
             else path_url(request.url)
         )
 
-    def _bucket_response_delete_keys(self, bucket_name, authenticated=True):
+    def _bucket_response_delete_keys(
+        self, bucket_name: str, authenticated: bool = True
+    ) -> TYPE_RESPONSE:
         template = self.response_template(S3_DELETE_KEYS_RESPONSE)
         body_dict = xmltodict.parse(self.body, strip_whitespace=False)
 
@@ -1068,14 +1089,16 @@ class S3Response(BaseResponse):
             template.render(deleted=deleted_objects, delete_errors=errors),
         )
 
-    def _handle_range_header(self, request, response_headers, response_content):
+    def _handle_range_header(
+        self, request: Any, response_headers: Dict[str, Any], response_content: Any
+    ) -> TYPE_RESPONSE:
         length = len(response_content)
         last = length - 1
         _, rspec = request.headers.get("range").split("=")
         if "," in rspec:
             raise NotImplementedError("Multiple range specifiers not supported")
 
-        def toint(i):
+        def toint(i: Any) -> Optional[int]:
             return int(i) if i else None
 
         begin, end = map(toint, rspec.split("-"))
@@ -1095,7 +1118,7 @@ class S3Response(BaseResponse):
         response_headers["content-length"] = len(content)
         return 206, response_headers, content
 
-    def _handle_v4_chunk_signatures(self, body, content_length):
+    def _handle_v4_chunk_signatures(self, body: bytes, content_length: int) -> bytes:
         body_io = io.BytesIO(body)
         new_body = bytearray(content_length)
         pos = 0
@@ -1110,7 +1133,7 @@ class S3Response(BaseResponse):
             line = body_io.readline()
         return bytes(new_body)
 
-    def _handle_encoded_body(self, body, content_length):
+    def _handle_encoded_body(self, body: bytes, content_length: int) -> bytes:
         body_io = io.BytesIO(body)
         # first line should equal '{content_length}\r\n
         body_io.readline()
@@ -1120,12 +1143,12 @@ class S3Response(BaseResponse):
         # amz-checksum-sha256:<..>\r\n
 
     @amzn_request_id
-    def key_response(self, request, full_url, headers):
+    def key_response(self, request: Any, full_url: str, headers: Dict[str, Any]) -> TYPE_RESPONSE:  # type: ignore[misc]
         # Key and Control are lumped in because splitting out the regex is too much of a pain :/
         self.setup_class(request, full_url, headers, use_raw_body=True)
         bucket_name = self.parse_bucket_name_from_url(request, full_url)
         self.backend.log_incoming_request(request, bucket_name)
-        response_headers = {}
+        response_headers: Dict[str, Any] = {}
 
         try:
             response = self._key_response(request, full_url, self.headers)
@@ -1151,7 +1174,9 @@ class S3Response(BaseResponse):
                 return s3error.code, {}, s3error.description
         return status_code, response_headers, response_content
 
-    def _key_response(self, request, full_url, headers):
+    def _key_response(
+        self, request: Any, full_url: str, headers: Dict[str, Any]
+    ) -> TYPE_RESPONSE:
         parsed_url = urlparse(full_url)
         query = parse_qs(parsed_url.query, keep_blank_values=True)
         method = request.method
@@ -1182,7 +1207,7 @@ class S3Response(BaseResponse):
             from moto.iam.access_control import PermissionResult
 
             action = f"s3:{method.upper()[0]}{method.lower()[1:]}Object"
-            bucket_permissions = bucket.get_permission(action, resource)
+            bucket_permissions = bucket.get_permission(action, resource)  # type: ignore
             if bucket_permissions == PermissionResult.DENIED:
                 return 403, {}, ""
 
@@ -1255,11 +1280,17 @@ class S3Response(BaseResponse):
                 f"Method {method} has not been implemented in the S3 backend yet"
             )
 
-    def _key_response_get(self, bucket_name, query, key_name, headers):
+    def _key_response_get(
+        self,
+        bucket_name: str,
+        query: Dict[str, Any],
+        key_name: str,
+        headers: Dict[str, Any],
+    ) -> TYPE_RESPONSE:
         self._set_action("KEY", "GET", query)
         self._authenticate_and_authorize_s3_action()
 
-        response_headers = {}
+        response_headers: Dict[str, Any] = {}
         if query.get("uploadId"):
             upload_id = query["uploadId"][0]
 
@@ -1287,7 +1318,7 @@ class S3Response(BaseResponse):
             )
             next_part_number_marker = parts[-1].name if parts else 0
             is_truncated = len(parts) != 0 and self.backend.is_truncated(
-                bucket_name, upload_id, next_part_number_marker
+                bucket_name, upload_id, next_part_number_marker  # type: ignore
             )
 
             template = self.response_template(S3_MULTIPART_LIST_RESPONSE)
@@ -1355,7 +1386,7 @@ class S3Response(BaseResponse):
             attributes_to_get = headers.get("x-amz-object-attributes", "").split(",")
             response_keys = self.backend.get_object_attributes(key, attributes_to_get)
 
-            if key.version_id == "null":
+            if key.version_id == "null":  # type: ignore
                 response_headers.pop("x-amz-version-id")
             response_headers["Last-Modified"] = key.last_modified_ISO8601
 
@@ -1367,11 +1398,18 @@ class S3Response(BaseResponse):
         response_headers.update({"AcceptRanges": "bytes"})
         return 200, response_headers, key.value
 
-    def _key_response_put(self, request, body, bucket_name, query, key_name):
+    def _key_response_put(
+        self,
+        request: Any,
+        body: bytes,
+        bucket_name: str,
+        query: Dict[str, Any],
+        key_name: str,
+    ) -> TYPE_RESPONSE:
         self._set_action("KEY", "PUT", query)
         self._authenticate_and_authorize_s3_action()
 
-        response_headers = {}
+        response_headers: Dict[str, Any] = {}
         if query.get("uploadId") and query.get("partNumber"):
             upload_id = query["uploadId"][0]
             part_number = int(query["partNumber"][0])
@@ -1382,7 +1420,7 @@ class S3Response(BaseResponse):
                 copy_source_parsed = urlparse(copy_source)
                 src_bucket, src_key = copy_source_parsed.path.lstrip("/").split("/", 1)
                 src_version_id = parse_qs(copy_source_parsed.query).get(
-                    "versionId", [None]
+                    "versionId", [None]  # type: ignore
                 )[0]
                 src_range = request.headers.get("x-amz-copy-source-range", "").split(
                     "bytes="
@@ -1515,9 +1553,11 @@ class S3Response(BaseResponse):
                 version_id = query["versionId"][0]
             else:
                 version_id = None
-            key = self.backend.get_object(bucket_name, key_name, version_id=version_id)
+            key_to_tag = self.backend.get_object(
+                bucket_name, key_name, version_id=version_id
+            )
             tagging = self._tagging_from_xml(body)
-            self.backend.set_key_tags(key, tagging, key_name)
+            self.backend.set_key_tags(key_to_tag, tagging, key_name)
             return 200, response_headers, ""
 
         if "x-amz-copy-source" in request.headers:
@@ -1532,21 +1572,21 @@ class S3Response(BaseResponse):
                 unquote(copy_source_parsed.path).lstrip("/").split("/", 1)
             )
             src_version_id = parse_qs(copy_source_parsed.query).get(
-                "versionId", [None]
+                "versionId", [None]  # type: ignore
             )[0]
 
-            key = self.backend.get_object(
+            key_to_copy = self.backend.get_object(
                 src_bucket, src_key, version_id=src_version_id, key_is_clean=True
             )
 
-            if key is not None:
-                if key.storage_class in ARCHIVE_STORAGE_CLASSES:
-                    if key.response_dict.get(
+            if key_to_copy is not None:
+                if key_to_copy.storage_class in ARCHIVE_STORAGE_CLASSES:
+                    if key_to_copy.response_dict.get(
                         "x-amz-restore"
-                    ) is None or 'ongoing-request="true"' in key.response_dict.get(
+                    ) is None or 'ongoing-request="true"' in key_to_copy.response_dict.get(  # type: ignore
                         "x-amz-restore"
                     ):
-                        raise ObjectNotInActiveTierError(key)
+                        raise ObjectNotInActiveTierError(key_to_copy)
 
                 bucket_key_enabled = (
                     request.headers.get(
@@ -1558,7 +1598,7 @@ class S3Response(BaseResponse):
                 mdirective = request.headers.get("x-amz-metadata-directive")
 
                 self.backend.copy_object(
-                    key,
+                    key_to_copy,
                     bucket_name,
                     key_name,
                     storage=storage_class,
@@ -1571,7 +1611,7 @@ class S3Response(BaseResponse):
             else:
                 raise MissingKey(key=src_key)
 
-            new_key = self.backend.get_object(bucket_name, key_name)
+            new_key: FakeKey = self.backend.get_object(bucket_name, key_name)  # type: ignore
             if mdirective is not None and mdirective == "REPLACE":
                 metadata = metadata_from_headers(request.headers)
                 new_key.set_metadata(metadata, replace=True)
@@ -1612,11 +1652,17 @@ class S3Response(BaseResponse):
         response_headers.update(new_key.response_dict)
         return 200, response_headers, ""
 
-    def _key_response_head(self, bucket_name, query, key_name, headers):
+    def _key_response_head(
+        self,
+        bucket_name: str,
+        query: Dict[str, Any],
+        key_name: str,
+        headers: Dict[str, Any],
+    ) -> TYPE_RESPONSE:
         self._set_action("KEY", "HEAD", query)
         self._authenticate_and_authorize_s3_action()
 
-        response_headers = {}
+        response_headers: Dict[str, Any] = {}
         version_id = query.get("versionId", [None])[0]
         if version_id and not self.backend.get_bucket(bucket_name).is_versioned:
             return 400, response_headers, ""
@@ -1654,16 +1700,21 @@ class S3Response(BaseResponse):
 
             if part_number:
                 full_key = self.backend.head_object(bucket_name, key_name, version_id)
-                if full_key.multipart:
-                    mp_part_count = str(len(full_key.multipart.partlist))
+                if full_key.multipart:  # type: ignore
+                    mp_part_count = str(len(full_key.multipart.partlist))  # type: ignore
                     response_headers["x-amz-mp-parts-count"] = mp_part_count
 
             return 200, response_headers, ""
         else:
             return 404, response_headers, ""
 
-    def _lock_config_from_body(self):
-        response_dict = {"enabled": False, "mode": None, "days": None, "years": None}
+    def _lock_config_from_body(self) -> Dict[str, Any]:
+        response_dict: Dict[str, Any] = {
+            "enabled": False,
+            "mode": None,
+            "days": None,
+            "years": None,
+        }
         parsed_xml = xmltodict.parse(self.body)
         enabled = (
             parsed_xml["ObjectLockConfiguration"]["ObjectLockEnabled"] == "Enabled"
        )
@@ -1685,7 +1736,7 @@ class S3Response(BaseResponse):
 
         return response_dict
 
-    def _acl_from_body(self):
+    def _acl_from_body(self) -> Optional[FakeAcl]:
         parsed_xml = xmltodict.parse(self.body)
         if not parsed_xml.get("AccessControlPolicy"):
             raise MalformedACLError()
@@ -1697,7 +1748,7 @@ class S3Response(BaseResponse):
 
         # If empty, then no ACLs:
         if parsed_xml["AccessControlPolicy"].get("AccessControlList") is None:
-            return []
+            return None
 
         if not parsed_xml["AccessControlPolicy"]["AccessControlList"].get("Grant"):
             raise MalformedACLError()
@@ -1718,7 +1769,12 @@ class S3Response(BaseResponse):
         )
         return FakeAcl(grants)
 
-    def _get_grants_from_xml(self, grant_list, exception_type, permissions):
+    def _get_grants_from_xml(
+        self,
+        grant_list: List[Dict[str, Any]],
+        exception_type: Type[S3ClientError],
+        permissions: List[str],
+    ) -> List[FakeGrant]:
         grants = []
         for grant in grant_list:
             if grant.get("Permission", "") not in permissions:
@@ -1748,7 +1804,7 @@ class S3Response(BaseResponse):
 
         return grants
 
-    def _acl_from_headers(self, headers):
+    def _acl_from_headers(self, headers: Dict[str, str]) -> Optional[FakeAcl]:
         canned_acl = headers.get("x-amz-acl", "")
 
         grants = []
@@ -1767,7 +1823,7 @@ class S3Response(BaseResponse):
 
             grantees = []
             for key_and_value in value.split(","):
-                key, value = re.match(
+                key, value = re.match(  # type: ignore
                     '([^=]+)="?([^"]+)"?', key_and_value.strip()
                 ).groups()
                 if key.lower() == "id":
@@ -1785,7 +1841,7 @@ class S3Response(BaseResponse):
         else:
             return None
 
-    def _tagging_from_headers(self, headers):
+    def _tagging_from_headers(self, headers: Dict[str, Any]) -> Dict[str, str]:
         tags = {}
         if headers.get("x-amz-tagging"):
             parsed_header = parse_qs(headers["x-amz-tagging"], keep_blank_values=True)
@@ -1793,7 +1849,7 @@ class S3Response(BaseResponse):
                 tags[tag[0]] = tag[1][0]
         return tags
 
-    def _tagging_from_xml(self, xml):
+    def _tagging_from_xml(self, xml: bytes) -> Dict[str, str]:
         parsed_xml = xmltodict.parse(xml, force_list={"Tag": True})
 
         tags = {}
@@ -1802,7 +1858,7 @@ class S3Response(BaseResponse):
 
         return tags
 
-    def _bucket_tagging_from_body(self):
+    def _bucket_tagging_from_body(self) -> Dict[str, str]:
         parsed_xml = xmltodict.parse(self.body)
 
         tags = {}
@@ -1826,7 +1882,7 @@ class S3Response(BaseResponse):
 
         return tags
 
-    def _cors_from_body(self):
+    def _cors_from_body(self) -> List[Dict[str, Any]]:
         parsed_xml = xmltodict.parse(self.body)
 
         if isinstance(parsed_xml["CORSConfiguration"]["CORSRule"], list):
@@ -1834,18 +1890,18 @@ class S3Response(BaseResponse):
 
         return [parsed_xml["CORSConfiguration"]["CORSRule"]]
 
-    def _mode_until_from_body(self):
+    def _mode_until_from_body(self) -> Tuple[Optional[str], Optional[str]]:
         parsed_xml = xmltodict.parse(self.body)
         return (
             parsed_xml.get("Retention", None).get("Mode", None),
             parsed_xml.get("Retention", None).get("RetainUntilDate", None),
         )
 
-    def _legal_hold_status_from_xml(self, xml):
+    def _legal_hold_status_from_xml(self, xml: bytes) -> Dict[str, Any]:
         parsed_xml = xmltodict.parse(xml)
         return parsed_xml["LegalHold"]["Status"]
 
-    def _encryption_config_from_body(self):
+    def _encryption_config_from_body(self) -> Dict[str, Any]:
         parsed_xml = xmltodict.parse(self.body)
 
         if (
@@ -1861,7 +1917,7 @@ class S3Response(BaseResponse):
 
         return parsed_xml["ServerSideEncryptionConfiguration"]
 
-    def _ownership_rule_from_body(self):
+    def _ownership_rule_from_body(self) -> Dict[str, Any]:
         parsed_xml = xmltodict.parse(self.body)
 
         if not parsed_xml["OwnershipControls"]["Rule"].get("ObjectOwnership"):
@@ -1869,7 +1925,7 @@ class S3Response(BaseResponse):
 
         return parsed_xml["OwnershipControls"]["Rule"]["ObjectOwnership"]
 
-    def _logging_from_body(self):
+    def _logging_from_body(self) -> Dict[str, Any]:
         parsed_xml = xmltodict.parse(self.body)
 
         if not parsed_xml["BucketLoggingStatus"].get("LoggingEnabled"):
@@ -1914,7 +1970,7 @@ class S3Response(BaseResponse):
 
         return parsed_xml["BucketLoggingStatus"]["LoggingEnabled"]
 
-    def _notification_config_from_body(self):
+    def _notification_config_from_body(self) -> Dict[str, Any]:
         parsed_xml = xmltodict.parse(self.body)
 
         if not len(parsed_xml["NotificationConfiguration"]):
@@ -1989,17 +2045,19 @@ class S3Response(BaseResponse):
 
         return parsed_xml["NotificationConfiguration"]
 
-    def _accelerate_config_from_body(self):
+    def _accelerate_config_from_body(self) -> str:
         parsed_xml = xmltodict.parse(self.body)
         config = parsed_xml["AccelerateConfiguration"]
         return config["Status"]
 
-    def _replication_config_from_xml(self, xml):
+    def _replication_config_from_xml(self, xml: str) -> Dict[str, Any]:
         parsed_xml = xmltodict.parse(xml, dict_constructor=dict)
         config = parsed_xml["ReplicationConfiguration"]
         return config
 
-    def _key_response_delete(self, headers, bucket_name, query, key_name):
+    def _key_response_delete(
+        self, headers: Any, bucket_name: str, query: Dict[str, Any], key_name: str
+    ) -> TYPE_RESPONSE:
         self._set_action("KEY", "DELETE", query)
         self._authenticate_and_authorize_s3_action()
 
@@ -2024,7 +2082,7 @@ class S3Response(BaseResponse):
                 response_headers[f"x-amz-{k}"] = response_meta[k]
         return 204, response_headers, ""
 
-    def _complete_multipart_body(self, body):
+    def _complete_multipart_body(self, body: bytes) -> Iterator[Tuple[int, str]]:
         ps = minidom.parseString(body).getElementsByTagName("Part")
         prev = 0
         for p in ps:
@@ -2033,7 +2091,14 @@ class S3Response(BaseResponse):
                 raise InvalidPartOrder()
             yield (pn, p.getElementsByTagName("ETag")[0].firstChild.wholeText)
 
-    def _key_response_post(self, request, body, bucket_name, query, key_name):
+    def _key_response_post(
+        self,
+        request: Any,
+        body: bytes,
+        bucket_name: str,
+        query: Dict[str, Any],
+        key_name: str,
+    ) -> TYPE_RESPONSE:
         self._set_action("KEY", "POST", query)
         self._authenticate_and_authorize_s3_action()
 
@@ -2071,11 +2136,10 @@ class S3Response(BaseResponse):
             return 200, response_headers, response
 
         if query.get("uploadId"):
-            body = self._complete_multipart_body(body)
-            multipart_id = query["uploadId"][0]
+            multipart_id = query["uploadId"][0]  # type: ignore
 
             multipart, value, etag = self.backend.complete_multipart_upload(
-                bucket_name, multipart_id, body
+                bucket_name, multipart_id, self._complete_multipart_body(body)
             )
             if value is None:
                 return 400, {}, ""
@@ -2095,7 +2159,7 @@ class S3Response(BaseResponse):
                 self.backend.put_object_acl(bucket_name, key.name, multipart.acl)
 
             template = self.response_template(S3_MULTIPART_COMPLETE_RESPONSE)
-            headers = {}
+            headers: Dict[str, Any] = {}
             if key.version_id:
                 headers["x-amz-version-id"] = key.version_id
 
@@ -2116,7 +2180,7 @@ class S3Response(BaseResponse):
         elif "restore" in query:
             es = minidom.parseString(body).getElementsByTagName("Days")
             days = es[0].childNodes[0].wholeText
-            key = self.backend.get_object(bucket_name, key_name)
+            key = self.backend.get_object(bucket_name, key_name)  # type: ignore
             if key.storage_class not in ARCHIVE_STORAGE_CLASSES:
                 raise InvalidObjectState(storage_class=key.storage_class)
             r = 202
@@ -2139,7 +2203,7 @@ class S3Response(BaseResponse):
                 "Method POST had only been implemented for multipart uploads and restore operations, so far"
             )
 
-    def _invalid_headers(self, url, headers):
+    def _invalid_headers(self, url: str, headers: Dict[str, str]) -> bool:
         """
         Verify whether the provided metadata in the URL is also present in the headers
         :param url: .../file.txt&content-type=app%2Fjson&Signature=..
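(A short orientation before the next file, not part of the upstream patch: select_object_content.py frames S3 Select results in AWS's event-stream encoding, which _create_message below assembles as prelude, prelude CRC, headers, payload, and message CRC. The following is a hedged decoder sketch, assuming the standard layout of two big-endian uint32 lengths followed by CRC32 checksums; the function name is illustrative only.)

    import binascii
    import struct
    from typing import Tuple

    def decode_event_stream_message(buf: bytes) -> Tuple[bytes, bytes]:
        # Prelude: total message length and headers length, both big-endian uint32,
        # followed by a CRC32 over those first 8 bytes.
        total_length, headers_length = struct.unpack("!II", buf[:8])
        (prelude_crc,) = struct.unpack("!I", buf[8:12])
        assert prelude_crc == binascii.crc32(buf[:8])
        headers = buf[12 : 12 + headers_length]
        payload = buf[12 + headers_length : total_length - 4]
        # The trailing CRC32 covers everything before it.
        (message_crc,) = struct.unpack("!I", buf[total_length - 4 : total_length])
        assert message_crc == binascii.crc32(buf[: total_length - 4])
        return headers, payload

(A real client would additionally parse the headers into key/value pairs using the length-prefixed encoding that _create_header produces.)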
diff --git a/moto/s3/select_object_content.py b/moto/s3/select_object_content.py
index 58a02a8b2..e718df235 100644
--- a/moto/s3/select_object_content.py
+++ b/moto/s3/select_object_content.py
@@ -1,19 +1,21 @@
 import binascii
 import struct
-from typing import List
+from typing import Any, Dict, List, Optional
 
 
-def parse_query(text_input, query):
+def parse_query(text_input: str, query: str) -> List[Dict[str, Any]]:
     from py_partiql_parser import S3SelectParser
 
     return S3SelectParser(source_data={"s3object": text_input}).parse(query)
 
 
-def _create_header(key: bytes, value: bytes):
+def _create_header(key: bytes, value: bytes) -> bytes:
     return struct.pack("b", len(key)) + key + struct.pack("!bh", 7, len(value)) + value
 
 
-def _create_message(content_type, event_type, payload):
+def _create_message(
+    content_type: Optional[bytes], event_type: bytes, payload: bytes
+) -> bytes:
     headers = _create_header(b":message-type", b"event")
     headers += _create_header(b":event-type", event_type)
     if content_type is not None:
@@ -31,23 +33,23 @@
     return prelude + prelude_crc + headers + payload + message_crc
 
 
-def _create_stats_message():
+def _create_stats_message() -> bytes:
     stats = b"""<Stats><BytesScanned>24</BytesScanned><BytesProcessed>24</BytesProcessed><BytesReturned>22</BytesReturned></Stats>"""
     return _create_message(content_type=b"text/xml", event_type=b"Stats", payload=stats)
 
 
-def _create_data_message(payload: bytes):
+def _create_data_message(payload: bytes) -> bytes:
     # https://docs.aws.amazon.com/AmazonS3/latest/API/RESTSelectObjectAppendix.html
     return _create_message(
         content_type=b"application/octet-stream", event_type=b"Records", payload=payload
     )
 
 
-def _create_end_message():
+def _create_end_message() -> bytes:
     return _create_message(content_type=None, event_type=b"End", payload=b"")
 
 
-def serialize_select(data_list: List[bytes]):
+def serialize_select(data_list: List[bytes]) -> bytes:
     response = b""
     for data in data_list:
         response += _create_data_message(data + b",")
diff --git a/moto/s3/utils.py b/moto/s3/utils.py
index b7e7f41ba..855ce4a4f 100644
--- a/moto/s3/utils.py
+++ b/moto/s3/utils.py
@@ -5,7 +5,7 @@ import re
 import hashlib
 from urllib.parse import urlparse, unquote, quote
 from requests.structures import CaseInsensitiveDict
-from typing import List, Union, Tuple
+from typing import Any, Dict, List, Iterator, Union, Tuple, Optional
 import sys
 
 from moto.settings import S3_IGNORE_SUBDOMAIN_BUCKETNAME
@@ -38,7 +38,7 @@ STORAGE_CLASS = [
 ] + ARCHIVE_STORAGE_CLASSES
 
 
-def bucket_name_from_url(url):
+def bucket_name_from_url(url: str) -> Optional[str]:  # type: ignore
     if S3_IGNORE_SUBDOMAIN_BUCKETNAME:
         return None
     domain = urlparse(url).netloc
@@ -75,7 +75,7 @@ REGION_URL_REGEX = re.compile(
 )
 
 
-def parse_region_from_url(url, use_default_region=True):
+def parse_region_from_url(url: str, use_default_region: bool = True) -> str:
     match = REGION_URL_REGEX.search(url)
     if match:
         region = match.group("region1") or match.group("region2")
@@ -84,8 +84,8 @@
     return region
 
 
-def metadata_from_headers(headers):
-    metadata = CaseInsensitiveDict()
+def metadata_from_headers(headers: Dict[str, Any]) -> CaseInsensitiveDict:  # type: ignore
+    metadata = CaseInsensitiveDict()  # type: ignore
     meta_regex = re.compile(r"^x-amz-meta-([a-zA-Z0-9\-_.]+)$", flags=re.IGNORECASE)
     for header in headers.keys():
         if isinstance(header, str):
@@ -106,32 +106,32 @@
     return metadata
 
 
-def clean_key_name(key_name):
+def clean_key_name(key_name: str) -> str:
     return unquote(key_name)
 
 
-def undo_clean_key_name(key_name):
+def undo_clean_key_name(key_name: str) -> str:
     return quote(key_name)
 
 
-class _VersionedKeyStore(dict):
+class _VersionedKeyStore(dict):  # type: ignore
 
     """A simplified/modified version of Django's `MultiValueDict` taken from:
    https://github.com/django/django/blob/70576740b0bb5289873f5a9a9a4e1a26b2c330e5/django/utils/datastructures.py#L282
    """
 
-    def __sgetitem__(self, key):
+    def __sgetitem__(self, key: str) -> List[Any]:
         return super().__getitem__(key)
 
-    def pop(self, key):
+    def pop(self, key: str) -> None:  # type: ignore
         for version in self.getlist(key, []):
             version.dispose()
         super().pop(key)
 
-    def __getitem__(self, key):
+    def __getitem__(self, key: str) -> Any:
         return self.__sgetitem__(key)[-1]
 
-    def __setitem__(self, key, value):
+    def __setitem__(self, key: str, value: Any) -> Any:
         try:
             current = self.__sgetitem__(key)
             current.append(value)
@@ -140,21 +140,21 @@ class _VersionedKeyStore(dict):
 
         super().__setitem__(key, current)
 
-    def get(self, key, default=None):
+    def get(self, key: str, default: Any = None) -> Any:
         try:
             return self[key]
         except (KeyError, IndexError):
             pass
         return default
 
-    def getlist(self, key, default=None):
+    def getlist(self, key: str, default: Any = None) -> Any:
         try:
             return self.__sgetitem__(key)
         except (KeyError, IndexError):
             pass
         return default
 
-    def setlist(self, key, list_):
+    def setlist(self, key: Any, list_: Any) -> Any:
         if isinstance(list_, tuple):
             list_ = list(list_)
         elif not isinstance(list_, list):
@@ -168,35 +168,35 @@ class _VersionedKeyStore(dict):
 
         super().__setitem__(key, list_)
 
-    def _iteritems(self):
+    def _iteritems(self) -> Iterator[Tuple[str, Any]]:
         for key in self._self_iterable():
             yield key, self[key]
 
-    def _itervalues(self):
+    def _itervalues(self) -> Iterator[Any]:
         for key in self._self_iterable():
             yield self[key]
 
-    def _iterlists(self):
+    def _iterlists(self) -> Iterator[Tuple[str, List[Any]]]:
         for key in self._self_iterable():
             yield key, self.getlist(key)
 
-    def item_size(self):
+    def item_size(self) -> int:
         size = 0
         for val in self._self_iterable().values():
             size += sys.getsizeof(val)
         return size
 
-    def _self_iterable(self):
+    def _self_iterable(self) -> Dict[str, Any]:
         # to enable concurrency, return a copy, to avoid "dictionary changed size during iteration"
         # TODO: look into replacing with a locking mechanism, potentially
         return dict(self)
 
-    items = iteritems = _iteritems
+    items = iteritems = _iteritems  # type: ignore
     lists = iterlists = _iterlists
-    values = itervalues = _itervalues
+    values = itervalues = _itervalues  # type: ignore
 
 
-def compute_checksum(body, algorithm):
+def compute_checksum(body: bytes, algorithm: str) -> bytes:
     if algorithm == "SHA1":
         hashed_body = _hash(hashlib.sha1, (body,))
     elif algorithm == "CRC32" or algorithm == "CRC32C":
@@ -206,7 +206,7 @@
     return base64.b64encode(hashed_body)
 
 
-def _hash(fn, args) -> bytes:
+def _hash(fn: Any, args: Any) -> bytes:
     try:
         return fn(*args, usedforsecurity=False).hexdigest().encode("utf-8")
     except TypeError:
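For readers unfamiliar with the AWS event-stream framing typed above: each header is a 1-byte name length, the name, a value-type byte (7 = string), a 2-byte big-endian value length, and the value. A self-contained sketch of just the header layout, mirroring `_create_header`:

    import struct

    def create_header(key: bytes, value: bytes) -> bytes:
        # 1-byte name length | name | type 7 (string) | 2-byte value length | value
        return struct.pack("b", len(key)) + key + struct.pack("!bh", 7, len(value)) + value

    hdr = create_header(b":event-type", b"Records")
    assert hdr[0] == len(b":event-type")  # first byte is the name length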
diff --git a/moto/s3bucket_path/utils.py b/moto/s3bucket_path/utils.py
index f2b86c68e..c2f833757 100644
--- a/moto/s3bucket_path/utils.py
+++ b/moto/s3bucket_path/utils.py
@@ -1,7 +1,8 @@
+from typing import Optional
 from urllib.parse import urlparse
 
 
-def bucket_name_from_url(url):
+def bucket_name_from_url(url: str) -> Optional[str]:
     path = urlparse(url).path.lstrip("/")
 
     parts = path.lstrip("/").split("/")
@@ -10,5 +11,5 @@
     return parts[0]
 
 
-def parse_key_name(path):
+def parse_key_name(path: str) -> str:
     return "/".join(path.split("/")[2:])
diff --git a/moto/s3control/config.py b/moto/s3control/config.py
index 26329cdf7..2fdaa2ad8 100644
--- a/moto/s3control/config.py
+++ b/moto/s3control/config.py
@@ -2,6 +2,7 @@ import datetime
 import json
 
 from boto3 import Session
+from typing import Any, Dict, List, Optional, Tuple
 from moto.core.exceptions import InvalidNextTokenException
 from moto.core.common_models import ConfigQueryModel
@@ -12,15 +13,15 @@ from moto.s3control import s3control_backends
 class S3AccountPublicAccessBlockConfigQuery(ConfigQueryModel):
     def list_config_service_resources(
         self,
-        account_id,
-        resource_ids,
-        resource_name,
-        limit,
-        next_token,
-        backend_region=None,
-        resource_region=None,
-        aggregator=None,
-    ):
+        account_id: str,
+        resource_ids: Optional[List[str]],
+        resource_name: Optional[str],
+        limit: int,
+        next_token: Optional[str],
+        backend_region: Optional[str] = None,
+        resource_region: Optional[str] = None,
+        aggregator: Any = None,
+    ) -> Tuple[List[Dict[str, Any]], Optional[str]]:
         # For the Account Public Access Block, they are the same for all regions. The resource ID is the AWS account ID
         # There is no resource name -- it should be a blank string "" if provided.
 
@@ -95,12 +96,12 @@ class S3AccountPublicAccessBlockConfigQuery(ConfigQueryModel):
 
     def get_config_resource(
         self,
-        account_id,
-        resource_id,
-        resource_name=None,
-        backend_region=None,
-        resource_region=None,
-    ):
+        account_id: str,
+        resource_id: str,
+        resource_name: Optional[str] = None,
+        backend_region: Optional[str] = None,
+        resource_region: Optional[str] = None,
+    ) -> Optional[Dict[str, Any]]:
         # Do we even have this defined?
         backend = self.backends[account_id]["global"]
@@ -116,7 +117,7 @@
         # Is the resource ID correct?:
         if account_id == resource_id:
             if backend_region:
-                pab_region = backend_region
+                pab_region: Optional[str] = backend_region
 
             # Invalid region?
             elif resource_region not in regions:
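The `pab_region: Optional[str] = backend_region` change above shows a common mypy idiom: a variable assigned in several branches is annotated once, at its first binding, and later branches reuse the declared type. A minimal sketch with hypothetical names:

    from typing import Optional

    def pick_region(backend_region: Optional[str], resource_region: Optional[str]) -> Optional[str]:
        if backend_region:
            pab_region: Optional[str] = backend_region  # annotate at first binding
        else:
            pab_region = resource_region  # inferred from the annotation above
        return pab_region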
diff --git a/moto/s3control/exceptions.py b/moto/s3control/exceptions.py
index 8e051b300..9572ace0d 100644
--- a/moto/s3control/exceptions.py
+++ b/moto/s3control/exceptions.py
@@ -1,4 +1,4 @@
-"""Exceptions raised by the s3control service."""
+from typing import Any
 from moto.core.exceptions import RESTError
 
@@ -13,7 +13,7 @@ ERROR_WITH_ACCESS_POINT_POLICY = """{% extends 'wrapped_single_error' %}
 
 class S3ControlError(RESTError):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any):
         kwargs.setdefault("template", "single_error")
         super().__init__(*args, **kwargs)
 
@@ -21,7 +21,7 @@ class S3ControlError(RESTError):
 class AccessPointNotFound(S3ControlError):
     code = 404
 
-    def __init__(self, name, **kwargs):
+    def __init__(self, name: str, **kwargs: Any):
         kwargs.setdefault("template", "ap_not_found")
         kwargs["name"] = name
         self.templates["ap_not_found"] = ERROR_WITH_ACCESS_POINT_NAME
@@ -33,7 +33,7 @@
 class AccessPointPolicyNotFound(S3ControlError):
     code = 404
 
-    def __init__(self, name, **kwargs):
+    def __init__(self, name: str, **kwargs: Any):
         kwargs.setdefault("template", "apf_not_found")
         kwargs["name"] = name
         self.templates["apf_not_found"] = ERROR_WITH_ACCESS_POINT_POLICY
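`*args: Any, **kwargs: Any` is the usual way to type pass-through constructors like the ones above without constraining callers. A condensed sketch of the pattern, not the RESTError hierarchy itself:

    from typing import Any

    class ServiceError(Exception):
        def __init__(self, *args: Any, **kwargs: Any):
            # consume what this class understands, forward the rest
            self.template = kwargs.pop("template", "single_error")
            super().__init__(*args)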
diff --git a/moto/s3control/models.py b/moto/s3control/models.py
index 9f1b398fd..cb4bfd786 100644
--- a/moto/s3control/models.py
+++ b/moto/s3control/models.py
@@ -1,5 +1,7 @@
 from collections import defaultdict
 from datetime import datetime
+from typing import Any, Dict, Optional
+
 from moto.core import BaseBackend, BackendDict, BaseModel
 from moto.moto_api._internal import mock_random
 from moto.s3.exceptions import (
@@ -15,18 +17,18 @@ from .exceptions import AccessPointNotFound, AccessPointPolicyNotFound
 class AccessPoint(BaseModel):
     def __init__(
         self,
-        account_id,
-        name,
-        bucket,
-        vpc_configuration,
-        public_access_block_configuration,
+        account_id: str,
+        name: str,
+        bucket: str,
+        vpc_configuration: Dict[str, Any],
+        public_access_block_configuration: Dict[str, Any],
     ):
         self.name = name
         self.alias = f"{name}-{mock_random.get_random_hex(34)}-s3alias"
         self.bucket = bucket
         self.created = datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f")
         self.arn = f"arn:aws:s3:us-east-1:{account_id}:accesspoint/{name}"
-        self.policy = None
+        self.policy: Optional[str] = None
         self.network_origin = "VPC" if vpc_configuration else "Internet"
         self.vpc_id = (vpc_configuration or {}).get("VpcId")
         pubc = public_access_block_configuration or {}
@@ -37,23 +39,23 @@ class AccessPoint(BaseModel):
             "RestrictPublicBuckets": pubc.get("RestrictPublicBuckets", "true"),
         }
 
-    def delete_policy(self):
+    def delete_policy(self) -> None:
         self.policy = None
 
-    def set_policy(self, policy):
+    def set_policy(self, policy: str) -> None:
         self.policy = policy
 
-    def has_policy(self):
+    def has_policy(self) -> bool:
         return self.policy is not None
 
 
 class S3ControlBackend(BaseBackend):
-    def __init__(self, region_name, account_id):
+    def __init__(self, region_name: str, account_id: str):
         super().__init__(region_name, account_id)
-        self.public_access_block = None
-        self.access_points = defaultdict(dict)
+        self.public_access_block: Optional[PublicAccessBlock] = None
+        self.access_points: Dict[str, Dict[str, AccessPoint]] = defaultdict(dict)
 
-    def get_public_access_block(self, account_id):
+    def get_public_access_block(self, account_id: str) -> PublicAccessBlock:
         # The account ID should equal the account id that is set for Moto:
         if account_id != self.account_id:
             raise WrongPublicAccessBlockAccountIdError()
@@ -63,14 +65,16 @@ class S3ControlBackend(BaseBackend):
 
         return self.public_access_block
 
-    def delete_public_access_block(self, account_id):
+    def delete_public_access_block(self, account_id: str) -> None:
         # The account ID should equal the account id that is set for Moto:
         if account_id != self.account_id:
             raise WrongPublicAccessBlockAccountIdError()
 
         self.public_access_block = None
 
-    def put_public_access_block(self, account_id, pub_block_config):
+    def put_public_access_block(
+        self, account_id: str, pub_block_config: Dict[str, Any]
+    ) -> None:
         # The account ID should equal the account id that is set for Moto:
         if account_id != self.account_id:
             raise WrongPublicAccessBlockAccountIdError()
@@ -87,12 +91,12 @@ class S3ControlBackend(BaseBackend):
 
     def create_access_point(
         self,
-        account_id,
-        name,
-        bucket,
-        vpc_configuration,
-        public_access_block_configuration,
-    ):
+        account_id: str,
+        name: str,
+        bucket: str,
+        vpc_configuration: Dict[str, Any],
+        public_access_block_configuration: Dict[str, Any],
+    ) -> AccessPoint:
         access_point = AccessPoint(
             account_id,
             name,
@@ -103,29 +107,31 @@ class S3ControlBackend(BaseBackend):
         self.access_points[account_id][name] = access_point
         return access_point
 
-    def delete_access_point(self, account_id, name):
+    def delete_access_point(self, account_id: str, name: str) -> None:
         self.access_points[account_id].pop(name, None)
 
-    def get_access_point(self, account_id, name):
+    def get_access_point(self, account_id: str, name: str) -> AccessPoint:
         if name not in self.access_points[account_id]:
             raise AccessPointNotFound(name)
         return self.access_points[account_id][name]
 
-    def create_access_point_policy(self, account_id, name, policy):
+    def create_access_point_policy(
+        self, account_id: str, name: str, policy: str
+    ) -> None:
         access_point = self.get_access_point(account_id, name)
         access_point.set_policy(policy)
 
-    def get_access_point_policy(self, account_id, name):
+    def get_access_point_policy(self, account_id: str, name: str) -> str:
         access_point = self.get_access_point(account_id, name)
         if access_point.has_policy():
-            return access_point.policy
+            return access_point.policy  # type: ignore[return-value]
         raise AccessPointPolicyNotFound(name)
 
-    def delete_access_point_policy(self, account_id, name):
+    def delete_access_point_policy(self, account_id: str, name: str) -> None:
         access_point = self.get_access_point(account_id, name)
         access_point.delete_policy()
 
-    def get_access_point_policy_status(self, account_id, name):
+    def get_access_point_policy_status(self, account_id: str, name: str) -> bool:
         """
         We assume the policy status is always public
         """
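The `# type: ignore[return-value]` on `get_access_point_policy` above is needed because mypy cannot narrow `Optional[str]` through the `has_policy()` helper; narrowing only survives an inline `is not None` check. A minimal sketch of the alternative that would not need the ignore:

    from typing import Optional

    class AccessPointSketch:
        def __init__(self) -> None:
            self.policy: Optional[str] = None

        def get_policy(self) -> str:
            if self.policy is not None:  # inline check narrows Optional[str] to str
                return self.policy
            raise LookupError("no policy set")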
 s3control_backends[self.current_account]["global"]
 
     @amzn_request_id
-    def public_access_block(self, request, full_url, headers):
+    def public_access_block(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE:  # type: ignore
         self.setup_class(request, full_url, headers)
         try:
             if request.method == "GET":
@@ -29,7 +31,7 @@ class S3ControlResponse(BaseResponse):
         except S3ClientError as err:
             return err.code, {}, err.description
 
-    def get_public_access_block(self, request):
+    def get_public_access_block(self, request: Any) -> TYPE_RESPONSE:
         account_id = request.headers.get("x-amz-account-id")
         public_block_config = self.backend.get_public_access_block(
             account_id=account_id
@@ -37,7 +39,7 @@ class S3ControlResponse(BaseResponse):
         template = self.response_template(S3_PUBLIC_ACCESS_BLOCK_CONFIGURATION)
         return 200, {}, template.render(public_block_config=public_block_config)
 
-    def put_public_access_block(self, request):
+    def put_public_access_block(self, request: Any) -> TYPE_RESPONSE:
         account_id = request.headers.get("x-amz-account-id")
         data = request.body if hasattr(request, "body") else request.data
         pab_config = self._parse_pab_config(data)
@@ -46,18 +48,18 @@ class S3ControlResponse(BaseResponse):
         )
         return 201, {}, json.dumps({})
 
-    def delete_public_access_block(self, request):
+    def delete_public_access_block(self, request: Any) -> TYPE_RESPONSE:
         account_id = request.headers.get("x-amz-account-id")
         self.backend.delete_public_access_block(account_id=account_id)
         return 204, {}, json.dumps({})
 
-    def _parse_pab_config(self, body):
+    def _parse_pab_config(self, body: str) -> Dict[str, Any]:
         parsed_xml = xmltodict.parse(body)
         parsed_xml["PublicAccessBlockConfiguration"].pop("@xmlns", None)
 
         return parsed_xml
 
-    def access_point(self, request, full_url, headers):
+    def access_point(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE:  # type: ignore[return]
         self.setup_class(request, full_url, headers)
         if request.method == "PUT":
             return self.create_access_point(full_url)
@@ -66,7 +68,7 @@ class S3ControlResponse(BaseResponse):
         if request.method == "DELETE":
             return self.delete_access_point(full_url)
 
-    def access_point_policy(self, request, full_url, headers):
+    def access_point_policy(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE:  # type: ignore[return]
         self.setup_class(request, full_url, headers)
         if request.method == "PUT":
             return self.create_access_point_policy(full_url)
@@ -75,14 +77,14 @@ class S3ControlResponse(BaseResponse):
         if request.method == "DELETE":
             return self.delete_access_point_policy(full_url)
 
-    def access_point_policy_status(self, request, full_url, headers):
+    def access_point_policy_status(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE:  # type: ignore[return]
         self.setup_class(request, full_url, headers)
         if request.method == "PUT":
             return self.create_access_point(full_url)
         if request.method == "GET":
             return self.get_access_point_policy_status(full_url)
 
-    def create_access_point(self, full_url):
+    def create_access_point(self, full_url: str) -> TYPE_RESPONSE:
         account_id, name = self._get_accountid_and_name_from_accesspoint(full_url)
         params = xmltodict.parse(self.body)["CreateAccessPointRequest"]
         bucket = params["Bucket"]
@@ -98,43 +100,45 @@ class S3ControlResponse(BaseResponse):
         template = self.response_template(CREATE_ACCESS_POINT_TEMPLATE)
         return 200, {}, template.render(access_point=access_point)
 
-    def get_access_point(self, full_url):
+    def get_access_point(self, full_url: str) -> TYPE_RESPONSE:
         account_id, name = self._get_accountid_and_name_from_accesspoint(full_url)
         access_point = self.backend.get_access_point(account_id=account_id, name=name)
         template = self.response_template(GET_ACCESS_POINT_TEMPLATE)
         return 200, {}, template.render(access_point=access_point)
 
-    def delete_access_point(self, full_url):
+    def delete_access_point(self, full_url: str) -> TYPE_RESPONSE:
         account_id, name = self._get_accountid_and_name_from_accesspoint(full_url)
         self.backend.delete_access_point(account_id=account_id, name=name)
         return 204, {}, ""
 
-    def create_access_point_policy(self, full_url):
+    def create_access_point_policy(self, full_url: str) -> TYPE_RESPONSE:
         account_id, name = self._get_accountid_and_name_from_policy(full_url)
         params = xmltodict.parse(self.body)
         policy = params["PutAccessPointPolicyRequest"]["Policy"]
         self.backend.create_access_point_policy(account_id, name, policy)
         return 200, {}, ""
 
-    def get_access_point_policy(self, full_url):
+    def get_access_point_policy(self, full_url: str) -> TYPE_RESPONSE:
         account_id, name = self._get_accountid_and_name_from_policy(full_url)
         policy = self.backend.get_access_point_policy(account_id, name)
         template = self.response_template(GET_ACCESS_POINT_POLICY_TEMPLATE)
         return 200, {}, template.render(policy=policy)
 
-    def delete_access_point_policy(self, full_url):
+    def delete_access_point_policy(self, full_url: str) -> TYPE_RESPONSE:
         account_id, name = self._get_accountid_and_name_from_policy(full_url)
         self.backend.delete_access_point_policy(account_id=account_id, name=name)
         return 204, {}, ""
 
-    def get_access_point_policy_status(self, full_url):
+    def get_access_point_policy_status(self, full_url: str) -> TYPE_RESPONSE:
         account_id, name = self._get_accountid_and_name_from_policy(full_url)
         self.backend.get_access_point_policy_status(account_id, name)
         template = self.response_template(GET_ACCESS_POINT_POLICY_STATUS_TEMPLATE)
         return 200, {}, template.render()
 
-    def _get_accountid_and_name_from_accesspoint(self, full_url):
+    def _get_accountid_and_name_from_accesspoint(
+        self, full_url: str
+    ) -> Tuple[str, str]:
         url = full_url
         if full_url.startswith("http"):
             url = full_url.split("://")[1]
@@ -142,7 +146,7 @@ class S3ControlResponse(BaseResponse):
         name = url.split("v20180820/accesspoint/")[-1]
         return account_id, name
 
-    def _get_accountid_and_name_from_policy(self, full_url):
+    def _get_accountid_and_name_from_policy(self, full_url: str) -> Tuple[str, str]:
         url = full_url
         if full_url.startswith("http"):
             url = full_url.split("://")[1]
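Annotating the `backend` property with `S3ControlBackend`, as above, is what lets mypy check every `self.backend.<method>` call site in the response class. A minimal sketch of the idea with illustrative names:

    from typing import Dict

    class BackendSketch:
        def delete_access_point(self, account_id: str, name: str) -> None:
            pass

    class ResponseSketch:
        backends: Dict[str, BackendSketch] = {"global": BackendSketch()}

        @property
        def backend(self) -> BackendSketch:
            # the return annotation gives call sites full attribute checking
            return self.backends["global"]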
diff --git a/moto/settings.py b/moto/settings.py
index ce0dace60..2dbfe4bd8 100644
--- a/moto/settings.py
+++ b/moto/settings.py
@@ -3,7 +3,7 @@ import os
 import pathlib
 
 from functools import lru_cache
-from typing import Optional
+from typing import List, Optional
 
 TEST_SERVER_MODE = os.environ.get("TEST_SERVER_MODE", "0").lower() == "true"
@@ -47,7 +47,7 @@ def get_sf_execution_history_type():
     return os.environ.get("SF_EXECUTION_HISTORY_TYPE", "SUCCESS")
 
 
-def get_s3_custom_endpoints():
+def get_s3_custom_endpoints() -> List[str]:
     endpoints = os.environ.get("MOTO_S3_CUSTOM_ENDPOINTS")
     if endpoints:
         return endpoints.split(",")
@@ -57,7 +57,7 @@
 S3_UPLOAD_PART_MIN_SIZE = 5242880
 
 
-def get_s3_default_key_buffer_size():
+def get_s3_default_key_buffer_size() -> int:
     return int(
         os.environ.get(
             "MOTO_S3_DEFAULT_KEY_BUFFER_SIZE", S3_UPLOAD_PART_MIN_SIZE - 1024
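Return annotations on environment-driven settings such as `get_s3_custom_endpoints() -> List[str]` let callers rely on a concrete type whatever the environment contains. A condensed sketch; the empty-list fallback is an assumption, since the hunk above truncates the else branch:

    import os
    from typing import List

    def get_custom_endpoints() -> List[str]:
        endpoints = os.environ.get("MOTO_S3_CUSTOM_ENDPOINTS")
        return endpoints.split(",") if endpoints else []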
diff --git a/moto/utilities/utils.py b/moto/utilities/utils.py
index 971346c1d..ec0188493 100644
--- a/moto/utilities/utils.py
+++ b/moto/utilities/utils.py
@@ -77,7 +77,7 @@ def md5_hash(data: Any = None) -> Any:
 class LowercaseDict(MutableMapping):
     """A dictionary that lowercases all keys"""
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any):
         self.store = dict()
         self.update(dict(*args, **kwargs))  # use the free update to set keys
diff --git a/setup.cfg b/setup.cfg
index c5dfd6cf3..89748a342 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -239,7 +239,7 @@ disable = W,C,R,E
 enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
 
 [mypy]
-files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/scheduler
+files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s3*,moto/scheduler
 show_column_numbers=True
 show_error_codes = True
 disable_error_code=abstract
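For context on the `[mypy]` block above: `disable_error_code=abstract` silences mypy's "Cannot instantiate abstract class" diagnostic (error code `abstract`). A minimal illustration of what that error code would flag:

    from abc import ABC, abstractmethod

    class Base(ABC):
        @abstractmethod
        def run(self) -> None: ...

    # Base()  # mypy error code "abstract" (also a TypeError at runtime);
    #         # disable_error_code=abstract suppresses only the mypy report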