Techdebt: MyPy S3 (#6235)

This commit is contained in:
Bert Blommers 2023-04-20 16:47:39 +00:00 committed by GitHub
parent ff48188362
commit c2e3d90fc9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 880 additions and 730 deletions

View File

@ -59,7 +59,7 @@ class RESTError(HTTPException):
if template in self.templates.keys():
env = Environment(loader=DictLoader(self.templates))
self.description = env.get_template(template).render(
self.description: str = env.get_template(template).render( # type: ignore
error_type=error_type,
message=message,
request_id_tag=self.request_id_tag_name,

View File

@ -1,6 +1,6 @@
from datetime import datetime, timedelta
from moto.moto_api import state_manager
from typing import List, Tuple
from typing import List, Tuple, Optional
class ManagedState:
@ -8,7 +8,7 @@ class ManagedState:
Subclass this class to configure state-transitions
"""
def __init__(self, model_name: str, transitions: List[Tuple[str, str]]):
def __init__(self, model_name: str, transitions: List[Tuple[Optional[str], str]]):
# Indicate the possible transitions for this model
# Example: [(initializing,queued), (queued, starting), (starting, ready)]
self._transitions = transitions
@ -28,7 +28,7 @@ class ManagedState:
self._tick += 1
@property
def status(self) -> str:
def status(self) -> Optional[str]:
"""
Transitions the status as appropriate before returning
"""
@ -55,12 +55,12 @@ class ManagedState:
def status(self, value: str) -> None:
self._status = value
def _get_next_status(self, previous: str) -> str:
def _get_next_status(self, previous: Optional[str]) -> Optional[str]:
return next(
(nxt for prev, nxt in self._transitions if previous == prev), previous
)
def _get_last_status(self, previous: str) -> str:
def _get_last_status(self, previous: Optional[str]) -> Optional[str]:
next_state = self._get_next_status(previous)
while next_state != previous:
previous = next_state

View File

@ -1,7 +1,10 @@
from collections import OrderedDict
from typing import Any, Dict, List
def cfn_to_api_encryption(bucket_encryption_properties):
def cfn_to_api_encryption(
bucket_encryption_properties: Dict[str, Any]
) -> Dict[str, Any]:
sse_algorithm = bucket_encryption_properties["ServerSideEncryptionConfiguration"][
0
@ -16,14 +19,12 @@ def cfn_to_api_encryption(bucket_encryption_properties):
rule = OrderedDict(
{"ApplyServerSideEncryptionByDefault": apply_server_side_encryption_by_default}
)
bucket_encryption = OrderedDict(
{"@xmlns": "http://s3.amazonaws.com/doc/2006-03-01/"}
return OrderedDict(
{"@xmlns": "http://s3.amazonaws.com/doc/2006-03-01/", "Rule": rule}
)
bucket_encryption["Rule"] = rule
return bucket_encryption
def is_replacement_update(properties):
def is_replacement_update(properties: List[str]) -> bool:
properties_requiring_replacement_update = ["BucketName", "ObjectLockEnabled"]
return any(
[

View File

@ -1,4 +1,5 @@
import json
from typing import Any, Dict, List, Optional, Tuple
from moto.core.exceptions import InvalidNextTokenException
from moto.core.common_models import ConfigQueryModel
@ -8,15 +9,15 @@ from moto.s3 import s3_backends
class S3ConfigQuery(ConfigQueryModel):
def list_config_service_resources(
self,
account_id,
resource_ids,
resource_name,
limit,
next_token,
backend_region=None,
resource_region=None,
aggregator=None,
):
account_id: str,
resource_ids: Optional[List[str]],
resource_name: Optional[str],
limit: int,
next_token: Optional[str],
backend_region: Optional[str] = None,
resource_region: Optional[str] = None,
aggregator: Optional[Dict[str, Any]] = None,
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
# The resource_region only matters for aggregated queries as you can filter on bucket regions for them.
# For other resource types, you would need to iterate appropriately for the backend_region.
@ -37,7 +38,7 @@ class S3ConfigQuery(ConfigQueryModel):
filter_buckets = [resource_name] if resource_name else resource_ids
for bucket in self.backends[account_id]["global"].buckets.keys():
if bucket in filter_buckets:
if bucket in filter_buckets: # type: ignore
bucket_list.append(bucket)
# Filter on the proper region if supplied:
@ -95,26 +96,26 @@ class S3ConfigQuery(ConfigQueryModel):
def get_config_resource(
self,
account_id,
resource_id,
resource_name=None,
backend_region=None,
resource_region=None,
):
account_id: str,
resource_id: str,
resource_name: Optional[str] = None,
backend_region: Optional[str] = None,
resource_region: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
# Get the bucket:
bucket = self.backends[account_id]["global"].buckets.get(resource_id, {})
if not bucket:
return
return None
# Are we filtering based on region?
region_filter = backend_region or resource_region
if region_filter and bucket.region_name != region_filter:
return
return None
# Are we also filtering on bucket name?
if resource_name and bucket.name != resource_name:
return
return None
# Format the bucket to the AWS Config format:
config_data = bucket.to_config_dict()

View File

@ -1,3 +1,4 @@
from typing import Any, Optional, Union
from moto.core.exceptions import RESTError
ERROR_WITH_BUCKET_NAME = """{% extends 'single_error' %}
@ -35,7 +36,7 @@ class S3ClientError(RESTError):
# S3 API uses <RequestID> as the XML tag in response messages
request_id_tag_name = "RequestID"
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
kwargs.setdefault("template", "single_error")
self.templates["bucket_error"] = ERROR_WITH_BUCKET_NAME
super().__init__(*args, **kwargs)
@ -44,7 +45,7 @@ class S3ClientError(RESTError):
class InvalidArgumentError(S3ClientError):
code = 400
def __init__(self, message, name, value, *args, **kwargs):
def __init__(self, message: str, name: str, value: str, *args: Any, **kwargs: Any):
kwargs.setdefault("template", "argument_error")
kwargs["name"] = name
kwargs["value"] = value
@ -60,7 +61,7 @@ class AccessForbidden(S3ClientError):
class BucketError(S3ClientError):
def __init__(self, *args: str, **kwargs: str):
def __init__(self, *args: Any, **kwargs: Any):
kwargs.setdefault("template", "bucket_error")
self.templates["bucket_error"] = ERROR_WITH_BUCKET_NAME
super().__init__(*args, **kwargs)
@ -69,7 +70,7 @@ class BucketError(S3ClientError):
class BucketAlreadyExists(BucketError):
code = 409
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
kwargs.setdefault("template", "bucket_error")
self.templates["bucket_error"] = ERROR_WITH_BUCKET_NAME
super().__init__(
@ -87,16 +88,16 @@ class BucketAlreadyExists(BucketError):
class MissingBucket(BucketError):
code = 404
def __init__(self, *args, **kwargs):
def __init__(self, bucket: str):
super().__init__(
"NoSuchBucket", "The specified bucket does not exist", *args, **kwargs
"NoSuchBucket", "The specified bucket does not exist", bucket=bucket
)
class MissingKey(S3ClientError):
code = 404
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any):
kwargs.setdefault("template", "key_error")
self.templates["key_error"] = ERROR_WITH_KEY_NAME
super().__init__("NoSuchKey", "The specified key does not exist.", **kwargs)
@ -105,16 +106,14 @@ class MissingKey(S3ClientError):
class MissingVersion(S3ClientError):
code = 404
def __init__(self, *args, **kwargs):
super().__init__(
"NoSuchVersion", "The specified version does not exist.", *args, **kwargs
)
def __init__(self) -> None:
super().__init__("NoSuchVersion", "The specified version does not exist.")
class InvalidVersion(S3ClientError):
code = 400
def __init__(self, version_id, *args, **kwargs):
def __init__(self, version_id: str, *args: Any, **kwargs: Any):
kwargs.setdefault("template", "argument_error")
kwargs["name"] = "versionId"
kwargs["value"] = version_id
@ -127,7 +126,7 @@ class InvalidVersion(S3ClientError):
class ObjectNotInActiveTierError(S3ClientError):
code = 403
def __init__(self, key_name):
def __init__(self, key_name: Any):
super().__init__(
"ObjectNotInActiveTierError",
"The source object of the COPY operation is not in the active tier and is only stored in Amazon Glacier.",
@ -138,105 +137,84 @@ class ObjectNotInActiveTierError(S3ClientError):
class InvalidPartOrder(S3ClientError):
code = 400
def __init__(self, *args, **kwargs):
def __init__(self) -> None:
super().__init__(
"InvalidPartOrder",
(
"The list of parts was not in ascending order. The parts "
"list must be specified in order by part number."
),
*args,
**kwargs,
"The list of parts was not in ascending order. The parts list must be specified in order by part number.",
)
class InvalidPart(S3ClientError):
code = 400
def __init__(self, *args, **kwargs):
def __init__(self) -> None:
super().__init__(
"InvalidPart",
(
"One or more of the specified parts could not be found. "
"The part might not have been uploaded, or the specified "
"entity tag might not have matched the part's entity tag."
),
*args,
**kwargs,
"One or more of the specified parts could not be found. The part might not have been uploaded, or the specified entity tag might not have matched the part's entity tag.",
)
class EntityTooSmall(S3ClientError):
code = 400
def __init__(self, *args, **kwargs):
def __init__(self) -> None:
super().__init__(
"EntityTooSmall",
"Your proposed upload is smaller than the minimum allowed object size.",
*args,
**kwargs,
)
class InvalidRequest(S3ClientError):
code = 400
def __init__(self, method, *args, **kwargs):
def __init__(self, method: str):
super().__init__(
"InvalidRequest",
f"Found unsupported HTTP method in CORS config. Unsupported method is {method}",
*args,
**kwargs,
)
class IllegalLocationConstraintException(S3ClientError):
code = 400
def __init__(self, *args, **kwargs):
def __init__(self) -> None:
super().__init__(
"IllegalLocationConstraintException",
"The unspecified location constraint is incompatible for the region specific endpoint this request was sent to.",
*args,
**kwargs,
)
class MalformedXML(S3ClientError):
code = 400
def __init__(self, *args, **kwargs):
def __init__(self) -> None:
super().__init__(
"MalformedXML",
"The XML you provided was not well-formed or did not validate against our published schema",
*args,
**kwargs,
)
class MalformedACLError(S3ClientError):
code = 400
def __init__(self, *args, **kwargs):
def __init__(self) -> None:
super().__init__(
"MalformedACLError",
"The XML you provided was not well-formed or did not validate against our published schema",
*args,
**kwargs,
)
class InvalidTargetBucketForLogging(S3ClientError):
code = 400
def __init__(self, msg):
def __init__(self, msg: str):
super().__init__("InvalidTargetBucketForLogging", msg)
class CrossLocationLoggingProhibitted(S3ClientError):
code = 403
def __init__(self):
def __init__(self) -> None:
super().__init__(
"CrossLocationLoggingProhibitted", "Cross S3 location logging not allowed."
)
@ -245,7 +223,7 @@ class CrossLocationLoggingProhibitted(S3ClientError):
class InvalidMaxPartArgument(S3ClientError):
code = 400
def __init__(self, arg, min_val, max_val):
def __init__(self, arg: str, min_val: int, max_val: int):
error = f"Argument {arg} must be an integer between {min_val} and {max_val}"
super().__init__("InvalidArgument", error)
@ -253,97 +231,83 @@ class InvalidMaxPartArgument(S3ClientError):
class InvalidMaxPartNumberArgument(InvalidArgumentError):
code = 400
def __init__(self, value, *args, **kwargs):
def __init__(self, value: int):
error = "Part number must be an integer between 1 and 10000, inclusive"
super().__init__(message=error, name="partNumber", value=value, *args, **kwargs)
super().__init__(message=error, name="partNumber", value=value) # type: ignore
class NotAnIntegerException(InvalidArgumentError):
code = 400
def __init__(self, name, value, *args, **kwargs):
def __init__(self, name: str, value: int):
error = f"Provided {name} not an integer or within integer range"
super().__init__(message=error, name=name, value=value, *args, **kwargs)
super().__init__(message=error, name=name, value=value) # type: ignore
class InvalidNotificationARN(S3ClientError):
code = 400
def __init__(self, *args, **kwargs):
super().__init__(
"InvalidArgument", "The ARN is not well formed", *args, **kwargs
)
def __init__(self) -> None:
super().__init__("InvalidArgument", "The ARN is not well formed")
class InvalidNotificationDestination(S3ClientError):
code = 400
def __init__(self, *args, **kwargs):
def __init__(self) -> None:
super().__init__(
"InvalidArgument",
"The notification destination service region is not valid for the bucket location constraint",
*args,
**kwargs,
)
class InvalidNotificationEvent(S3ClientError):
code = 400
def __init__(self, *args, **kwargs):
def __init__(self) -> None:
super().__init__(
"InvalidArgument",
"The event is not supported for notifications",
*args,
**kwargs,
)
class InvalidStorageClass(S3ClientError):
code = 400
def __init__(self, *args, **kwargs):
def __init__(self, storage: Optional[str]):
super().__init__(
"InvalidStorageClass",
"The storage class you specified is not valid",
*args,
**kwargs,
storage=storage,
)
class InvalidBucketName(S3ClientError):
code = 400
def __init__(self, *args, **kwargs):
super().__init__(
"InvalidBucketName", "The specified bucket is not valid.", *args, **kwargs
)
def __init__(self) -> None:
super().__init__("InvalidBucketName", "The specified bucket is not valid.")
class DuplicateTagKeys(S3ClientError):
code = 400
def __init__(self, *args, **kwargs):
super().__init__(
"InvalidTag",
"Cannot provide multiple Tags with the same key",
*args,
**kwargs,
)
def __init__(self) -> None:
super().__init__("InvalidTag", "Cannot provide multiple Tags with the same key")
class S3AccessDeniedError(S3ClientError):
code = 403
def __init__(self, *args: str, **kwargs: str):
super().__init__("AccessDenied", "Access Denied", *args, **kwargs)
def __init__(self) -> None:
super().__init__("AccessDenied", "Access Denied")
class BucketAccessDeniedError(BucketError):
code = 403
def __init__(self, *args: str, **kwargs: str):
super().__init__("AccessDenied", "Access Denied", *args, **kwargs)
def __init__(self, bucket: str):
super().__init__("AccessDenied", "Access Denied", bucket=bucket)
class S3InvalidTokenError(S3ClientError):
@ -368,12 +332,11 @@ class S3AclAndGrantError(S3ClientError):
class BucketInvalidTokenError(BucketError):
code = 400
def __init__(self, *args: str, **kwargs: str):
def __init__(self, bucket: str):
super().__init__(
"InvalidToken",
"The provided token is malformed or otherwise invalid.",
*args,
**kwargs,
bucket=bucket,
)
@ -390,12 +353,11 @@ class S3InvalidAccessKeyIdError(S3ClientError):
class BucketInvalidAccessKeyIdError(S3ClientError):
code = 403
def __init__(self, *args: str, **kwargs: str):
def __init__(self, bucket: str):
super().__init__(
"InvalidAccessKeyId",
"The AWS Access Key Id you provided does not exist in our records.",
*args,
**kwargs,
bucket=bucket,
)
@ -412,50 +374,45 @@ class S3SignatureDoesNotMatchError(S3ClientError):
class BucketSignatureDoesNotMatchError(S3ClientError):
code = 403
def __init__(self, *args: str, **kwargs: str):
def __init__(self, bucket: str):
super().__init__(
"SignatureDoesNotMatch",
"The request signature we calculated does not match the signature you provided. Check your key and signing method.",
*args,
**kwargs,
bucket=bucket,
)
class NoSuchPublicAccessBlockConfiguration(S3ClientError):
code = 404
def __init__(self, *args, **kwargs):
def __init__(self) -> None:
super().__init__(
"NoSuchPublicAccessBlockConfiguration",
"The public access block configuration was not found",
*args,
**kwargs,
)
class InvalidPublicAccessBlockConfiguration(S3ClientError):
code = 400
def __init__(self, *args, **kwargs):
def __init__(self) -> None:
super().__init__(
"InvalidRequest",
"Must specify at least one configuration.",
*args,
**kwargs,
)
class WrongPublicAccessBlockAccountIdError(S3ClientError):
code = 403
def __init__(self):
def __init__(self) -> None:
super().__init__("AccessDenied", "Access Denied")
class NoSystemTags(S3ClientError):
code = 400
def __init__(self):
def __init__(self) -> None:
super().__init__(
"InvalidTag", "System tags cannot be added/updated by requester"
)
@ -464,7 +421,7 @@ class NoSystemTags(S3ClientError):
class NoSuchUpload(S3ClientError):
code = 404
def __init__(self, upload_id, *args, **kwargs):
def __init__(self, upload_id: Union[int, str], *args: Any, **kwargs: Any):
kwargs.setdefault("template", "error_uploadid")
kwargs["upload_id"] = upload_id
self.templates["error_uploadid"] = ERROR_WITH_UPLOADID
@ -479,7 +436,7 @@ class NoSuchUpload(S3ClientError):
class PreconditionFailed(S3ClientError):
code = 412
def __init__(self, failed_condition, **kwargs):
def __init__(self, failed_condition: str, **kwargs: Any):
kwargs.setdefault("template", "condition_error")
self.templates["condition_error"] = ERROR_WITH_CONDITION_NAME
super().__init__(
@ -493,7 +450,7 @@ class PreconditionFailed(S3ClientError):
class InvalidRange(S3ClientError):
code = 416
def __init__(self, range_requested, actual_size, **kwargs):
def __init__(self, range_requested: str, actual_size: str, **kwargs: Any):
kwargs.setdefault("template", "range_error")
self.templates["range_error"] = ERROR_WITH_RANGE
super().__init__(
@ -508,19 +465,16 @@ class InvalidRange(S3ClientError):
class InvalidContinuationToken(S3ClientError):
code = 400
def __init__(self, *args, **kwargs):
def __init__(self) -> None:
super().__init__(
"InvalidArgument",
"The continuation token provided is incorrect",
*args,
**kwargs,
"InvalidArgument", "The continuation token provided is incorrect"
)
class InvalidObjectState(BucketError):
code = 403
def __init__(self, storage_class, **kwargs):
def __init__(self, storage_class: Optional[str], **kwargs: Any):
kwargs.setdefault("template", "storage_error")
self.templates["storage_error"] = ERROR_WITH_STORAGE_CLASS
super().__init__(
@ -534,35 +488,35 @@ class InvalidObjectState(BucketError):
class LockNotEnabled(S3ClientError):
code = 400
def __init__(self):
def __init__(self) -> None:
super().__init__("InvalidRequest", "Bucket is missing ObjectLockConfiguration")
class AccessDeniedByLock(S3ClientError):
code = 400
def __init__(self):
def __init__(self) -> None:
super().__init__("AccessDenied", "Access Denied")
class InvalidContentMD5(S3ClientError):
code = 400
def __init__(self):
def __init__(self) -> None:
super().__init__("InvalidContentMD5", "Content MD5 header is invalid")
class BucketNeedsToBeNew(S3ClientError):
code = 400
def __init__(self):
def __init__(self) -> None:
super().__init__("InvalidBucket", "Bucket needs to be empty")
class BucketMustHaveLockeEnabled(S3ClientError):
code = 400
def __init__(self):
def __init__(self) -> None:
super().__init__(
"InvalidBucketState",
"Object Lock configuration cannot be enabled on existing buckets",
@ -572,7 +526,7 @@ class BucketMustHaveLockeEnabled(S3ClientError):
class CopyObjectMustChangeSomething(S3ClientError):
code = 400
def __init__(self):
def __init__(self) -> None:
super().__init__(
"InvalidRequest",
"This copy request is illegal because it is trying to copy an object to itself without changing the object's metadata, storage class, website redirect location or encryption attributes.",
@ -582,27 +536,25 @@ class CopyObjectMustChangeSomething(S3ClientError):
class InvalidFilterRuleName(InvalidArgumentError):
code = 400
def __init__(self, value, *args, **kwargs):
def __init__(self, value: str):
super().__init__(
"filter rule name must be either prefix or suffix",
"FilterRule.Name",
value,
*args,
**kwargs,
)
class InvalidTagError(S3ClientError):
code = 400
def __init__(self, value, *args, **kwargs):
super().__init__("InvalidTag", value, *args, **kwargs)
def __init__(self, value: str):
super().__init__("InvalidTag", value)
class ObjectLockConfigurationNotFoundError(S3ClientError):
code = 404
def __init__(self):
def __init__(self) -> None:
super().__init__(
"ObjectLockConfigurationNotFoundError",
"Object Lock configuration does not exist for this bucket",

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,6 @@
import json
from datetime import datetime
from typing import Any, Dict, List
_EVENT_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%f"
@ -7,7 +8,9 @@ S3_OBJECT_CREATE_COPY = "s3:ObjectCreated:Copy"
S3_OBJECT_CREATE_PUT = "s3:ObjectCreated:Put"
def _get_s3_event(event_name, bucket, key, notification_id):
def _get_s3_event(
event_name: str, bucket: Any, key: Any, notification_id: str
) -> Dict[str, List[Dict[str, Any]]]:
etag = key.etag.replace('"', "")
# s3:ObjectCreated:Put --> ObjectCreated:Put
event_name = event_name[3:]
@ -34,11 +37,11 @@ def _get_s3_event(event_name, bucket, key, notification_id):
}
def _get_region_from_arn(arn):
def _get_region_from_arn(arn: str) -> str:
return arn.split(":")[3]
def send_event(account_id, event_name, bucket, key):
def send_event(account_id: str, event_name: str, bucket: Any, key: Any) -> None:
if bucket.notification_configuration is None:
return
@ -58,7 +61,9 @@ def send_event(account_id, event_name, bucket, key):
_send_sqs_message(account_id, event_body, queue_name, region_name)
def _send_sqs_message(account_id, event_body, queue_name, region_name):
def _send_sqs_message(
account_id: str, event_body: Any, queue_name: str, region_name: str
) -> None:
try:
from moto.sqs.models import sqs_backends
@ -74,7 +79,9 @@ def _send_sqs_message(account_id, event_body, queue_name, region_name):
pass
def _invoke_awslambda(account_id, event_body, fn_arn, region_name):
def _invoke_awslambda(
account_id: str, event_body: Any, fn_arn: str, region_name: str
) -> None:
try:
from moto.awslambda.models import lambda_backends
@ -89,7 +96,7 @@ def _invoke_awslambda(account_id, event_body, fn_arn, region_name):
pass
def _get_test_event(bucket_name):
def _get_test_event(bucket_name: str) -> Dict[str, Any]:
event_time = datetime.now().strftime(_EVENT_TIME_FORMAT)
return {
"Service": "Amazon S3",
@ -99,7 +106,7 @@ def _get_test_event(bucket_name):
}
def send_test_event(account_id, bucket):
def send_test_event(account_id: str, bucket: Any) -> None:
arns = [n.arn for n in bucket.notification_configuration.queue]
for arn in set(arns):
region_name = _get_region_from_arn(arn)

View File

@ -1,7 +1,7 @@
import io
import os
import re
from typing import List, Union
from typing import Any, Dict, List, Iterator, Union, Tuple, Optional, Type
import urllib.parse
@ -14,6 +14,7 @@ from urllib.parse import parse_qs, urlparse, unquote, urlencode, urlunparse
import xmltodict
from moto.core.common_types import TYPE_RESPONSE
from moto.core.responses import BaseResponse
from moto.core.utils import path_url
@ -53,7 +54,7 @@ from .exceptions import (
AccessForbidden,
)
from .models import s3_backends, S3Backend
from .models import get_canned_acl, FakeGrantee, FakeGrant, FakeAcl, FakeKey
from .models import get_canned_acl, FakeGrantee, FakeGrant, FakeAcl, FakeKey, FakeBucket
from .select_object_content import serialize_select
from .utils import (
bucket_name_from_url,
@ -146,13 +147,13 @@ ACTION_MAP = {
}
def parse_key_name(pth):
def parse_key_name(pth: str) -> str:
# strip the first '/' left by urlparse
return pth[1:] if pth.startswith("/") else pth
class S3Response(BaseResponse):
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="s3")
@property
@ -160,10 +161,10 @@ class S3Response(BaseResponse):
return s3_backends[self.current_account]["global"]
@property
def should_autoescape(self):
def should_autoescape(self) -> bool:
return True
def all_buckets(self):
def all_buckets(self) -> str:
self.data["Action"] = "ListAllMyBuckets"
self._authenticate_and_authorize_s3_action()
@ -172,7 +173,7 @@ class S3Response(BaseResponse):
template = self.response_template(S3_ALL_BUCKETS)
return template.render(buckets=all_buckets)
def subdomain_based_buckets(self, request):
def subdomain_based_buckets(self, request: Any) -> bool:
if settings.S3_IGNORE_SUBDOMAIN_BUCKETNAME:
return False
host = request.headers.get("host", request.headers.get("Host"))
@ -224,23 +225,25 @@ class S3Response(BaseResponse):
)
return not path_based
def is_delete_keys(self):
def is_delete_keys(self) -> bool:
qs = parse_qs(urlparse(self.path).query, keep_blank_values=True)
return "delete" in qs
def parse_bucket_name_from_url(self, request, url):
def parse_bucket_name_from_url(self, request: Any, url: str) -> str:
if self.subdomain_based_buckets(request):
return bucket_name_from_url(url)
return bucket_name_from_url(url) # type: ignore
else:
return bucketpath_bucket_name_from_url(url)
return bucketpath_bucket_name_from_url(url) # type: ignore
def parse_key_name(self, request, url):
def parse_key_name(self, request: Any, url: str) -> str:
if self.subdomain_based_buckets(request):
return parse_key_name(url)
else:
return bucketpath_parse_key_name(url)
def ambiguous_response(self, request, full_url, headers):
def ambiguous_response(
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
# Depending on which calling format the client is using, we don't know
# if this is a bucket or key request so we have to check
if self.subdomain_based_buckets(request):
@ -250,7 +253,7 @@ class S3Response(BaseResponse):
return self.bucket_response(request, full_url, headers)
@amzn_request_id
def bucket_response(self, request, full_url, headers):
def bucket_response(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore
self.setup_class(request, full_url, headers, use_raw_body=True)
bucket_name = self.parse_bucket_name_from_url(request, full_url)
self.backend.log_incoming_request(request, bucket_name)
@ -262,7 +265,7 @@ class S3Response(BaseResponse):
return self._send_response(response)
@staticmethod
def _send_response(response):
def _send_response(response: Any) -> TYPE_RESPONSE: # type: ignore
if isinstance(response, str):
return 200, {}, response.encode("utf-8")
else:
@ -272,7 +275,9 @@ class S3Response(BaseResponse):
return status_code, headers, response_content
def _bucket_response(self, request, full_url):
def _bucket_response(
self, request: Any, full_url: str
) -> Union[str, TYPE_RESPONSE]:
querystring = self._get_querystring(request, full_url)
method = request.method
region_name = parse_region_from_url(full_url, use_default_region=False)
@ -309,7 +314,7 @@ class S3Response(BaseResponse):
)
@staticmethod
def _get_querystring(request, full_url):
def _get_querystring(request: Any, full_url: str) -> Dict[str, Any]: # type: ignore[misc]
# Flask's Request has the querystring already parsed
# In ServerMode, we can use this, instead of manually parsing this
if hasattr(request, "args"):
@ -330,10 +335,11 @@ class S3Response(BaseResponse):
# Workaround - manually reverse the encoding.
# Keep the + encoded, ensuring that parse_qsl doesn't replace it, and parse_qsl will unquote it afterwards
qs = (parsed_url.query or "").replace("+", "%2B")
querystring = parse_qs(qs, keep_blank_values=True)
return querystring
return parse_qs(qs, keep_blank_values=True)
def _bucket_response_head(self, bucket_name, querystring):
def _bucket_response_head(
self, bucket_name: str, querystring: Dict[str, Any]
) -> TYPE_RESPONSE:
self._set_action("BUCKET", "HEAD", querystring)
self._authenticate_and_authorize_s3_action()
@ -347,7 +353,7 @@ class S3Response(BaseResponse):
return 404, {}, ""
return 200, {"x-amz-bucket-region": bucket.region_name}, ""
def _set_cors_headers(self, headers, bucket):
def _set_cors_headers(self, headers: Dict[str, str], bucket: FakeBucket) -> None:
"""
TODO: smarter way of matching the right CORS rule:
See https://docs.aws.amazon.com/AmazonS3/latest/userguide/cors.html
@ -372,8 +378,8 @@ class S3Response(BaseResponse):
)
if cors_rule.allowed_origins is not None:
origin = headers.get("Origin")
if cors_matches_origin(origin, cors_rule.allowed_origins):
self.response_headers["Access-Control-Allow-Origin"] = origin
if cors_matches_origin(origin, cors_rule.allowed_origins): # type: ignore
self.response_headers["Access-Control-Allow-Origin"] = origin # type: ignore
else:
raise AccessForbidden(
"CORSResponse: This CORS request is not allowed. This is usually because the evalution of Origin, request method / Access-Control-Request-Method or Access-Control-Request-Headers are not whitelisted by the resource's CORS spec."
@ -391,23 +397,24 @@ class S3Response(BaseResponse):
cors_rule.max_age_seconds
)
def _response_options(self, headers, bucket_name):
def _response_options(
self, headers: Dict[str, str], bucket_name: str
) -> TYPE_RESPONSE:
# Return 200 with the headers from the bucket CORS configuration
self._authenticate_and_authorize_s3_action()
try:
bucket = self.backend.head_bucket(bucket_name)
except MissingBucket:
return (
403,
{},
"",
) # AWS S3 seems to return 403 on OPTIONS and 404 on GET/HEAD
# AWS S3 seems to return 403 on OPTIONS and 404 on GET/HEAD
return 403, {}, ""
self._set_cors_headers(headers, bucket)
return 200, self.response_headers, ""
def _bucket_response_get(self, bucket_name, querystring):
def _bucket_response_get(
self, bucket_name: str, querystring: Dict[str, Any]
) -> Union[str, TYPE_RESPONSE]:
self._set_action("BUCKET", "GET", querystring)
self._authenticate_and_authorize_s3_action()
@ -445,7 +452,7 @@ class S3Response(BaseResponse):
account_id=self.current_account,
)
elif "location" in querystring:
location = self.backend.get_bucket_location(bucket_name)
location: Optional[str] = self.backend.get_bucket_location(bucket_name)
template = self.response_template(S3_BUCKET_LOCATION)
# us-east-1 is different - returns a None location
@ -477,7 +484,7 @@ class S3Response(BaseResponse):
if not website_configuration:
template = self.response_template(S3_NO_BUCKET_WEBSITE_CONFIG)
return 404, {}, template.render(bucket_name=bucket_name)
return 200, {}, website_configuration
return 200, {}, website_configuration # type: ignore
elif "acl" in querystring:
acl = self.backend.get_bucket_acl(bucket_name)
template = self.response_template(S3_OBJECT_ACL_RESPONSE)
@ -615,7 +622,9 @@ class S3Response(BaseResponse):
),
)
def _set_action(self, action_resource_type, method, querystring):
def _set_action(
self, action_resource_type: str, method: str, querystring: Dict[str, Any]
) -> None:
action_set = False
for action_in_querystring, action in ACTION_MAP[action_resource_type][
method
@ -626,7 +635,9 @@ class S3Response(BaseResponse):
if not action_set:
self.data["Action"] = ACTION_MAP[action_resource_type][method]["DEFAULT"]
def _handle_list_objects_v2(self, bucket_name, querystring):
def _handle_list_objects_v2(
self, bucket_name: str, querystring: Dict[str, Any]
) -> str:
template = self.response_template(S3_BUCKET_GET_RESPONSE_V2)
bucket = self.backend.get_bucket(bucket_name)
@ -678,7 +689,7 @@ class S3Response(BaseResponse):
)
@staticmethod
def _split_truncated_keys(truncated_keys):
def _split_truncated_keys(truncated_keys: Any) -> Any: # type: ignore[misc]
result_keys = []
result_folders = []
for key in truncated_keys:
@ -688,7 +699,7 @@ class S3Response(BaseResponse):
result_folders.append(key)
return result_keys, result_folders
def _get_results_from_token(self, result_keys, token):
def _get_results_from_token(self, result_keys: Any, token: Any) -> Any:
continuation_index = 0
for key in result_keys:
if (key.name if isinstance(key, FakeKey) else key) > token:
@ -696,22 +707,22 @@ class S3Response(BaseResponse):
continuation_index += 1
return result_keys[continuation_index:]
def _truncate_result(self, result_keys, max_keys):
def _truncate_result(self, result_keys: Any, max_keys: int) -> Any:
if max_keys == 0:
result_keys = []
is_truncated = True
next_continuation_token = None
elif len(result_keys) > max_keys:
is_truncated = "true"
is_truncated = "true" # type: ignore
result_keys = result_keys[:max_keys]
item = result_keys[-1]
next_continuation_token = item.name if isinstance(item, FakeKey) else item
else:
is_truncated = "false"
is_truncated = "false" # type: ignore
next_continuation_token = None
return result_keys, is_truncated, next_continuation_token
def _body_contains_location_constraint(self, body):
def _body_contains_location_constraint(self, body: bytes) -> bool:
if body:
try:
xmltodict.parse(body)["CreateBucketConfiguration"]["LocationConstraint"]
@ -720,7 +731,7 @@ class S3Response(BaseResponse):
pass
return False
def _create_bucket_configuration_is_empty(self, body):
def _create_bucket_configuration_is_empty(self, body: bytes) -> bool:
if body:
try:
create_bucket_configuration = xmltodict.parse(body)[
@ -733,13 +744,19 @@ class S3Response(BaseResponse):
pass
return False
def _parse_pab_config(self):
def _parse_pab_config(self) -> Dict[str, Any]:
parsed_xml = xmltodict.parse(self.body)
parsed_xml["PublicAccessBlockConfiguration"].pop("@xmlns", None)
return parsed_xml
def _bucket_response_put(self, request, region_name, bucket_name, querystring):
def _bucket_response_put(
self,
request: Any,
region_name: str,
bucket_name: str,
querystring: Dict[str, Any],
) -> Union[str, TYPE_RESPONSE]:
if querystring and not request.headers.get("Content-Length"):
return 411, {}, "Content-Length required"
@ -754,7 +771,7 @@ class S3Response(BaseResponse):
self.backend.put_object_lock_configuration(
bucket_name,
config.get("enabled"),
config.get("enabled"), # type: ignore
config.get("mode"),
config.get("days"),
config.get("years"),
@ -765,7 +782,7 @@ class S3Response(BaseResponse):
body = self.body.decode("utf-8")
ver = re.search(r"<Status>([A-Za-z]+)</Status>", body)
if ver:
self.backend.put_bucket_versioning(bucket_name, ver.group(1))
self.backend.put_bucket_versioning(bucket_name, ver.group(1)) # type: ignore
template = self.response_template(S3_BUCKET_VERSIONING)
return template.render(bucket_versioning_status=ver.group(1))
else:
@ -922,7 +939,9 @@ class S3Response(BaseResponse):
template = self.response_template(S3_BUCKET_CREATE_RESPONSE)
return 200, {}, template.render(bucket=new_bucket)
def _bucket_response_delete(self, bucket_name, querystring):
def _bucket_response_delete(
self, bucket_name: str, querystring: Dict[str, Any]
) -> TYPE_RESPONSE:
self._set_action("BUCKET", "DELETE", querystring)
self._authenticate_and_authorize_s3_action()
@ -965,7 +984,7 @@ class S3Response(BaseResponse):
template = self.response_template(S3_DELETE_BUCKET_WITH_ITEMS_ERROR)
return 409, {}, template.render(bucket=removed_bucket)
def _bucket_response_post(self, request, bucket_name):
def _bucket_response_post(self, request: Any, bucket_name: str) -> TYPE_RESPONSE:
response_headers = {}
if not request.headers.get("Content-Length"):
return 411, {}, "Content-Length required"
@ -999,7 +1018,7 @@ class S3Response(BaseResponse):
if "success_action_redirect" in form:
redirect = form["success_action_redirect"]
parts = urlparse(redirect)
queryargs = parse_qs(parts.query)
queryargs: Dict[str, Any] = parse_qs(parts.query)
queryargs["key"] = key
queryargs["bucket"] = bucket_name
redirect_queryargs = urlencode(queryargs, doseq=True)
@ -1035,14 +1054,16 @@ class S3Response(BaseResponse):
return status_code, response_headers, ""
@staticmethod
def _get_path(request):
def _get_path(request: Any) -> str: # type: ignore[misc]
return (
request.full_path
if hasattr(request, "full_path")
else path_url(request.url)
)
def _bucket_response_delete_keys(self, bucket_name, authenticated=True):
def _bucket_response_delete_keys(
self, bucket_name: str, authenticated: bool = True
) -> TYPE_RESPONSE:
template = self.response_template(S3_DELETE_KEYS_RESPONSE)
body_dict = xmltodict.parse(self.body, strip_whitespace=False)
@ -1068,14 +1089,16 @@ class S3Response(BaseResponse):
template.render(deleted=deleted_objects, delete_errors=errors),
)
def _handle_range_header(self, request, response_headers, response_content):
def _handle_range_header(
self, request: Any, response_headers: Dict[str, Any], response_content: Any
) -> TYPE_RESPONSE:
length = len(response_content)
last = length - 1
_, rspec = request.headers.get("range").split("=")
if "," in rspec:
raise NotImplementedError("Multiple range specifiers not supported")
def toint(i):
def toint(i: Any) -> Optional[int]:
return int(i) if i else None
begin, end = map(toint, rspec.split("-"))
@ -1095,7 +1118,7 @@ class S3Response(BaseResponse):
response_headers["content-length"] = len(content)
return 206, response_headers, content
def _handle_v4_chunk_signatures(self, body, content_length):
def _handle_v4_chunk_signatures(self, body: bytes, content_length: int) -> bytes:
body_io = io.BytesIO(body)
new_body = bytearray(content_length)
pos = 0
@ -1110,7 +1133,7 @@ class S3Response(BaseResponse):
line = body_io.readline()
return bytes(new_body)
def _handle_encoded_body(self, body, content_length):
def _handle_encoded_body(self, body: bytes, content_length: int) -> bytes:
body_io = io.BytesIO(body)
# first line should equal '{content_length}\r\n
body_io.readline()
@ -1120,12 +1143,12 @@ class S3Response(BaseResponse):
# amz-checksum-sha256:<..>\r\n
@amzn_request_id
def key_response(self, request, full_url, headers):
def key_response(self, request: Any, full_url: str, headers: Dict[str, Any]) -> TYPE_RESPONSE: # type: ignore[misc]
# Key and Control are lumped in because splitting out the regex is too much of a pain :/
self.setup_class(request, full_url, headers, use_raw_body=True)
bucket_name = self.parse_bucket_name_from_url(request, full_url)
self.backend.log_incoming_request(request, bucket_name)
response_headers = {}
response_headers: Dict[str, Any] = {}
try:
response = self._key_response(request, full_url, self.headers)
@ -1151,7 +1174,9 @@ class S3Response(BaseResponse):
return s3error.code, {}, s3error.description
return status_code, response_headers, response_content
def _key_response(self, request, full_url, headers):
def _key_response(
self, request: Any, full_url: str, headers: Dict[str, Any]
) -> TYPE_RESPONSE:
parsed_url = urlparse(full_url)
query = parse_qs(parsed_url.query, keep_blank_values=True)
method = request.method
@ -1182,7 +1207,7 @@ class S3Response(BaseResponse):
from moto.iam.access_control import PermissionResult
action = f"s3:{method.upper()[0]}{method.lower()[1:]}Object"
bucket_permissions = bucket.get_permission(action, resource)
bucket_permissions = bucket.get_permission(action, resource) # type: ignore
if bucket_permissions == PermissionResult.DENIED:
return 403, {}, ""
@ -1255,11 +1280,17 @@ class S3Response(BaseResponse):
f"Method {method} has not been implemented in the S3 backend yet"
)
def _key_response_get(self, bucket_name, query, key_name, headers):
def _key_response_get(
self,
bucket_name: str,
query: Dict[str, Any],
key_name: str,
headers: Dict[str, Any],
) -> TYPE_RESPONSE:
self._set_action("KEY", "GET", query)
self._authenticate_and_authorize_s3_action()
response_headers = {}
response_headers: Dict[str, Any] = {}
if query.get("uploadId"):
upload_id = query["uploadId"][0]
@ -1287,7 +1318,7 @@ class S3Response(BaseResponse):
)
next_part_number_marker = parts[-1].name if parts else 0
is_truncated = len(parts) != 0 and self.backend.is_truncated(
bucket_name, upload_id, next_part_number_marker
bucket_name, upload_id, next_part_number_marker # type: ignore
)
template = self.response_template(S3_MULTIPART_LIST_RESPONSE)
@ -1355,7 +1386,7 @@ class S3Response(BaseResponse):
attributes_to_get = headers.get("x-amz-object-attributes", "").split(",")
response_keys = self.backend.get_object_attributes(key, attributes_to_get)
if key.version_id == "null":
if key.version_id == "null": # type: ignore
response_headers.pop("x-amz-version-id")
response_headers["Last-Modified"] = key.last_modified_ISO8601
@ -1367,11 +1398,18 @@ class S3Response(BaseResponse):
response_headers.update({"AcceptRanges": "bytes"})
return 200, response_headers, key.value
def _key_response_put(self, request, body, bucket_name, query, key_name):
def _key_response_put(
self,
request: Any,
body: bytes,
bucket_name: str,
query: Dict[str, Any],
key_name: str,
) -> TYPE_RESPONSE:
self._set_action("KEY", "PUT", query)
self._authenticate_and_authorize_s3_action()
response_headers = {}
response_headers: Dict[str, Any] = {}
if query.get("uploadId") and query.get("partNumber"):
upload_id = query["uploadId"][0]
part_number = int(query["partNumber"][0])
@ -1382,7 +1420,7 @@ class S3Response(BaseResponse):
copy_source_parsed = urlparse(copy_source)
src_bucket, src_key = copy_source_parsed.path.lstrip("/").split("/", 1)
src_version_id = parse_qs(copy_source_parsed.query).get(
"versionId", [None]
"versionId", [None] # type: ignore
)[0]
src_range = request.headers.get("x-amz-copy-source-range", "").split(
"bytes="
@ -1515,9 +1553,11 @@ class S3Response(BaseResponse):
version_id = query["versionId"][0]
else:
version_id = None
key = self.backend.get_object(bucket_name, key_name, version_id=version_id)
key_to_tag = self.backend.get_object(
bucket_name, key_name, version_id=version_id
)
tagging = self._tagging_from_xml(body)
self.backend.set_key_tags(key, tagging, key_name)
self.backend.set_key_tags(key_to_tag, tagging, key_name)
return 200, response_headers, ""
if "x-amz-copy-source" in request.headers:
@ -1532,21 +1572,21 @@ class S3Response(BaseResponse):
unquote(copy_source_parsed.path).lstrip("/").split("/", 1)
)
src_version_id = parse_qs(copy_source_parsed.query).get(
"versionId", [None]
"versionId", [None] # type: ignore
)[0]
key = self.backend.get_object(
key_to_copy = self.backend.get_object(
src_bucket, src_key, version_id=src_version_id, key_is_clean=True
)
if key is not None:
if key.storage_class in ARCHIVE_STORAGE_CLASSES:
if key.response_dict.get(
if key_to_copy is not None:
if key_to_copy.storage_class in ARCHIVE_STORAGE_CLASSES:
if key_to_copy.response_dict.get(
"x-amz-restore"
) is None or 'ongoing-request="true"' in key.response_dict.get(
) is None or 'ongoing-request="true"' in key_to_copy.response_dict.get( # type: ignore
"x-amz-restore"
):
raise ObjectNotInActiveTierError(key)
raise ObjectNotInActiveTierError(key_to_copy)
bucket_key_enabled = (
request.headers.get(
@ -1558,7 +1598,7 @@ class S3Response(BaseResponse):
mdirective = request.headers.get("x-amz-metadata-directive")
self.backend.copy_object(
key,
key_to_copy,
bucket_name,
key_name,
storage=storage_class,
@ -1571,7 +1611,7 @@ class S3Response(BaseResponse):
else:
raise MissingKey(key=src_key)
new_key = self.backend.get_object(bucket_name, key_name)
new_key: FakeKey = self.backend.get_object(bucket_name, key_name) # type: ignore
if mdirective is not None and mdirective == "REPLACE":
metadata = metadata_from_headers(request.headers)
new_key.set_metadata(metadata, replace=True)
@ -1612,11 +1652,17 @@ class S3Response(BaseResponse):
response_headers.update(new_key.response_dict)
return 200, response_headers, ""
def _key_response_head(self, bucket_name, query, key_name, headers):
def _key_response_head(
self,
bucket_name: str,
query: Dict[str, Any],
key_name: str,
headers: Dict[str, Any],
) -> TYPE_RESPONSE:
self._set_action("KEY", "HEAD", query)
self._authenticate_and_authorize_s3_action()
response_headers = {}
response_headers: Dict[str, Any] = {}
version_id = query.get("versionId", [None])[0]
if version_id and not self.backend.get_bucket(bucket_name).is_versioned:
return 400, response_headers, ""
@ -1654,16 +1700,21 @@ class S3Response(BaseResponse):
if part_number:
full_key = self.backend.head_object(bucket_name, key_name, version_id)
if full_key.multipart:
mp_part_count = str(len(full_key.multipart.partlist))
if full_key.multipart: # type: ignore
mp_part_count = str(len(full_key.multipart.partlist)) # type: ignore
response_headers["x-amz-mp-parts-count"] = mp_part_count
return 200, response_headers, ""
else:
return 404, response_headers, ""
def _lock_config_from_body(self):
response_dict = {"enabled": False, "mode": None, "days": None, "years": None}
def _lock_config_from_body(self) -> Dict[str, Any]:
response_dict: Dict[str, Any] = {
"enabled": False,
"mode": None,
"days": None,
"years": None,
}
parsed_xml = xmltodict.parse(self.body)
enabled = (
parsed_xml["ObjectLockConfiguration"]["ObjectLockEnabled"] == "Enabled"
@ -1685,7 +1736,7 @@ class S3Response(BaseResponse):
return response_dict
def _acl_from_body(self):
def _acl_from_body(self) -> Optional[FakeAcl]:
parsed_xml = xmltodict.parse(self.body)
if not parsed_xml.get("AccessControlPolicy"):
raise MalformedACLError()
@ -1697,7 +1748,7 @@ class S3Response(BaseResponse):
# If empty, then no ACLs:
if parsed_xml["AccessControlPolicy"].get("AccessControlList") is None:
return []
return None
if not parsed_xml["AccessControlPolicy"]["AccessControlList"].get("Grant"):
raise MalformedACLError()
@ -1718,7 +1769,12 @@ class S3Response(BaseResponse):
)
return FakeAcl(grants)
def _get_grants_from_xml(self, grant_list, exception_type, permissions):
def _get_grants_from_xml(
self,
grant_list: List[Dict[str, Any]],
exception_type: Type[S3ClientError],
permissions: List[str],
) -> List[FakeGrant]:
grants = []
for grant in grant_list:
if grant.get("Permission", "") not in permissions:
@ -1748,7 +1804,7 @@ class S3Response(BaseResponse):
return grants
def _acl_from_headers(self, headers):
def _acl_from_headers(self, headers: Dict[str, str]) -> Optional[FakeAcl]:
canned_acl = headers.get("x-amz-acl", "")
grants = []
@ -1767,7 +1823,7 @@ class S3Response(BaseResponse):
grantees = []
for key_and_value in value.split(","):
key, value = re.match(
key, value = re.match( # type: ignore
'([^=]+)="?([^"]+)"?', key_and_value.strip()
).groups()
if key.lower() == "id":
@ -1785,7 +1841,7 @@ class S3Response(BaseResponse):
else:
return None
def _tagging_from_headers(self, headers):
def _tagging_from_headers(self, headers: Dict[str, Any]) -> Dict[str, str]:
tags = {}
if headers.get("x-amz-tagging"):
parsed_header = parse_qs(headers["x-amz-tagging"], keep_blank_values=True)
@ -1793,7 +1849,7 @@ class S3Response(BaseResponse):
tags[tag[0]] = tag[1][0]
return tags
def _tagging_from_xml(self, xml):
def _tagging_from_xml(self, xml: bytes) -> Dict[str, str]:
parsed_xml = xmltodict.parse(xml, force_list={"Tag": True})
tags = {}
@ -1802,7 +1858,7 @@ class S3Response(BaseResponse):
return tags
def _bucket_tagging_from_body(self):
def _bucket_tagging_from_body(self) -> Dict[str, str]:
parsed_xml = xmltodict.parse(self.body)
tags = {}
@ -1826,7 +1882,7 @@ class S3Response(BaseResponse):
return tags
def _cors_from_body(self):
def _cors_from_body(self) -> List[Dict[str, Any]]:
parsed_xml = xmltodict.parse(self.body)
if isinstance(parsed_xml["CORSConfiguration"]["CORSRule"], list):
@ -1834,18 +1890,18 @@ class S3Response(BaseResponse):
return [parsed_xml["CORSConfiguration"]["CORSRule"]]
def _mode_until_from_body(self):
def _mode_until_from_body(self) -> Tuple[Optional[str], Optional[str]]:
parsed_xml = xmltodict.parse(self.body)
return (
parsed_xml.get("Retention", None).get("Mode", None),
parsed_xml.get("Retention", None).get("RetainUntilDate", None),
)
def _legal_hold_status_from_xml(self, xml):
def _legal_hold_status_from_xml(self, xml: bytes) -> Dict[str, Any]:
parsed_xml = xmltodict.parse(xml)
return parsed_xml["LegalHold"]["Status"]
def _encryption_config_from_body(self):
def _encryption_config_from_body(self) -> Dict[str, Any]:
parsed_xml = xmltodict.parse(self.body)
if (
@ -1861,7 +1917,7 @@ class S3Response(BaseResponse):
return parsed_xml["ServerSideEncryptionConfiguration"]
def _ownership_rule_from_body(self):
def _ownership_rule_from_body(self) -> Dict[str, Any]:
parsed_xml = xmltodict.parse(self.body)
if not parsed_xml["OwnershipControls"]["Rule"].get("ObjectOwnership"):
@ -1869,7 +1925,7 @@ class S3Response(BaseResponse):
return parsed_xml["OwnershipControls"]["Rule"]["ObjectOwnership"]
def _logging_from_body(self):
def _logging_from_body(self) -> Dict[str, Any]:
parsed_xml = xmltodict.parse(self.body)
if not parsed_xml["BucketLoggingStatus"].get("LoggingEnabled"):
@ -1914,7 +1970,7 @@ class S3Response(BaseResponse):
return parsed_xml["BucketLoggingStatus"]["LoggingEnabled"]
def _notification_config_from_body(self):
def _notification_config_from_body(self) -> Dict[str, Any]:
parsed_xml = xmltodict.parse(self.body)
if not len(parsed_xml["NotificationConfiguration"]):
@ -1989,17 +2045,19 @@ class S3Response(BaseResponse):
return parsed_xml["NotificationConfiguration"]
def _accelerate_config_from_body(self):
def _accelerate_config_from_body(self) -> str:
parsed_xml = xmltodict.parse(self.body)
config = parsed_xml["AccelerateConfiguration"]
return config["Status"]
def _replication_config_from_xml(self, xml):
def _replication_config_from_xml(self, xml: str) -> Dict[str, Any]:
parsed_xml = xmltodict.parse(xml, dict_constructor=dict)
config = parsed_xml["ReplicationConfiguration"]
return config
def _key_response_delete(self, headers, bucket_name, query, key_name):
def _key_response_delete(
self, headers: Any, bucket_name: str, query: Dict[str, Any], key_name: str
) -> TYPE_RESPONSE:
self._set_action("KEY", "DELETE", query)
self._authenticate_and_authorize_s3_action()
@ -2024,7 +2082,7 @@ class S3Response(BaseResponse):
response_headers[f"x-amz-{k}"] = response_meta[k]
return 204, response_headers, ""
def _complete_multipart_body(self, body):
def _complete_multipart_body(self, body: bytes) -> Iterator[Tuple[int, str]]:
ps = minidom.parseString(body).getElementsByTagName("Part")
prev = 0
for p in ps:
@ -2033,7 +2091,14 @@ class S3Response(BaseResponse):
raise InvalidPartOrder()
yield (pn, p.getElementsByTagName("ETag")[0].firstChild.wholeText)
def _key_response_post(self, request, body, bucket_name, query, key_name):
def _key_response_post(
self,
request: Any,
body: bytes,
bucket_name: str,
query: Dict[str, Any],
key_name: str,
) -> TYPE_RESPONSE:
self._set_action("KEY", "POST", query)
self._authenticate_and_authorize_s3_action()
@ -2071,11 +2136,10 @@ class S3Response(BaseResponse):
return 200, response_headers, response
if query.get("uploadId"):
body = self._complete_multipart_body(body)
multipart_id = query["uploadId"][0]
multipart_id = query["uploadId"][0] # type: ignore
multipart, value, etag = self.backend.complete_multipart_upload(
bucket_name, multipart_id, body
bucket_name, multipart_id, self._complete_multipart_body(body)
)
if value is None:
return 400, {}, ""
@ -2095,7 +2159,7 @@ class S3Response(BaseResponse):
self.backend.put_object_acl(bucket_name, key.name, multipart.acl)
template = self.response_template(S3_MULTIPART_COMPLETE_RESPONSE)
headers = {}
headers: Dict[str, Any] = {}
if key.version_id:
headers["x-amz-version-id"] = key.version_id
@ -2116,7 +2180,7 @@ class S3Response(BaseResponse):
elif "restore" in query:
es = minidom.parseString(body).getElementsByTagName("Days")
days = es[0].childNodes[0].wholeText
key = self.backend.get_object(bucket_name, key_name)
key = self.backend.get_object(bucket_name, key_name) # type: ignore
if key.storage_class not in ARCHIVE_STORAGE_CLASSES:
raise InvalidObjectState(storage_class=key.storage_class)
r = 202
@ -2139,7 +2203,7 @@ class S3Response(BaseResponse):
"Method POST had only been implemented for multipart uploads and restore operations, so far"
)
def _invalid_headers(self, url, headers):
def _invalid_headers(self, url: str, headers: Dict[str, str]) -> bool:
"""
Verify whether the provided metadata in the URL is also present in the headers
:param url: .../file.txt&content-type=app%2Fjson&Signature=..

View File

@ -1,19 +1,21 @@
import binascii
import struct
from typing import List
from typing import Any, Dict, List, Optional
def parse_query(text_input, query):
def parse_query(text_input: str, query: str) -> List[Dict[str, Any]]:
from py_partiql_parser import S3SelectParser
return S3SelectParser(source_data={"s3object": text_input}).parse(query)
def _create_header(key: bytes, value: bytes):
def _create_header(key: bytes, value: bytes) -> bytes:
return struct.pack("b", len(key)) + key + struct.pack("!bh", 7, len(value)) + value
def _create_message(content_type, event_type, payload):
def _create_message(
content_type: Optional[bytes], event_type: bytes, payload: bytes
) -> bytes:
headers = _create_header(b":message-type", b"event")
headers += _create_header(b":event-type", event_type)
if content_type is not None:
@ -31,23 +33,23 @@ def _create_message(content_type, event_type, payload):
return prelude + prelude_crc + headers + payload + message_crc
def _create_stats_message():
def _create_stats_message() -> bytes:
stats = b"""<Stats><BytesScanned>24</BytesScanned><BytesProcessed>24</BytesProcessed><BytesReturned>22</BytesReturned></Stats>"""
return _create_message(content_type=b"text/xml", event_type=b"Stats", payload=stats)
def _create_data_message(payload: bytes):
def _create_data_message(payload: bytes) -> bytes:
# https://docs.aws.amazon.com/AmazonS3/latest/API/RESTSelectObjectAppendix.html
return _create_message(
content_type=b"application/octet-stream", event_type=b"Records", payload=payload
)
def _create_end_message():
def _create_end_message() -> bytes:
return _create_message(content_type=None, event_type=b"End", payload=b"")
def serialize_select(data_list: List[bytes]):
def serialize_select(data_list: List[bytes]) -> bytes:
response = b""
for data in data_list:
response += _create_data_message(data + b",")

View File

@ -5,7 +5,7 @@ import re
import hashlib
from urllib.parse import urlparse, unquote, quote
from requests.structures import CaseInsensitiveDict
from typing import List, Union, Tuple
from typing import Any, Dict, List, Iterator, Union, Tuple, Optional
import sys
from moto.settings import S3_IGNORE_SUBDOMAIN_BUCKETNAME
@ -38,7 +38,7 @@ STORAGE_CLASS = [
] + ARCHIVE_STORAGE_CLASSES
def bucket_name_from_url(url):
def bucket_name_from_url(url: str) -> Optional[str]: # type: ignore
if S3_IGNORE_SUBDOMAIN_BUCKETNAME:
return None
domain = urlparse(url).netloc
@ -75,7 +75,7 @@ REGION_URL_REGEX = re.compile(
)
def parse_region_from_url(url, use_default_region=True):
def parse_region_from_url(url: str, use_default_region: bool = True) -> str:
match = REGION_URL_REGEX.search(url)
if match:
region = match.group("region1") or match.group("region2")
@ -84,8 +84,8 @@ def parse_region_from_url(url, use_default_region=True):
return region
def metadata_from_headers(headers):
metadata = CaseInsensitiveDict()
def metadata_from_headers(headers: Dict[str, Any]) -> CaseInsensitiveDict: # type: ignore
metadata = CaseInsensitiveDict() # type: ignore
meta_regex = re.compile(r"^x-amz-meta-([a-zA-Z0-9\-_.]+)$", flags=re.IGNORECASE)
for header in headers.keys():
if isinstance(header, str):
@ -106,32 +106,32 @@ def metadata_from_headers(headers):
return metadata
def clean_key_name(key_name):
def clean_key_name(key_name: str) -> str:
return unquote(key_name)
def undo_clean_key_name(key_name):
def undo_clean_key_name(key_name: str) -> str:
return quote(key_name)
class _VersionedKeyStore(dict):
class _VersionedKeyStore(dict): # type: ignore
"""A simplified/modified version of Django's `MultiValueDict` taken from:
https://github.com/django/django/blob/70576740b0bb5289873f5a9a9a4e1a26b2c330e5/django/utils/datastructures.py#L282
"""
def __sgetitem__(self, key):
def __sgetitem__(self, key: str) -> List[Any]:
return super().__getitem__(key)
def pop(self, key):
def pop(self, key: str) -> None: # type: ignore
for version in self.getlist(key, []):
version.dispose()
super().pop(key)
def __getitem__(self, key):
def __getitem__(self, key: str) -> Any:
return self.__sgetitem__(key)[-1]
def __setitem__(self, key, value):
def __setitem__(self, key: str, value: Any) -> Any:
try:
current = self.__sgetitem__(key)
current.append(value)
@ -140,21 +140,21 @@ class _VersionedKeyStore(dict):
super().__setitem__(key, current)
def get(self, key, default=None):
def get(self, key: str, default: Any = None) -> Any:
try:
return self[key]
except (KeyError, IndexError):
pass
return default
def getlist(self, key, default=None):
def getlist(self, key: str, default: Any = None) -> Any:
try:
return self.__sgetitem__(key)
except (KeyError, IndexError):
pass
return default
def setlist(self, key, list_):
def setlist(self, key: Any, list_: Any) -> Any:
if isinstance(list_, tuple):
list_ = list(list_)
elif not isinstance(list_, list):
@ -168,35 +168,35 @@ class _VersionedKeyStore(dict):
super().__setitem__(key, list_)
def _iteritems(self):
def _iteritems(self) -> Iterator[Tuple[str, Any]]:
for key in self._self_iterable():
yield key, self[key]
def _itervalues(self):
def _itervalues(self) -> Iterator[Any]:
for key in self._self_iterable():
yield self[key]
def _iterlists(self):
def _iterlists(self) -> Iterator[Tuple[str, List[Any]]]:
for key in self._self_iterable():
yield key, self.getlist(key)
def item_size(self):
def item_size(self) -> int:
size = 0
for val in self._self_iterable().values():
size += sys.getsizeof(val)
return size
def _self_iterable(self):
def _self_iterable(self) -> Dict[str, Any]:
# to enable concurrency, return a copy, to avoid "dictionary changed size during iteration"
# TODO: look into replacing with a locking mechanism, potentially
return dict(self)
items = iteritems = _iteritems
items = iteritems = _iteritems # type: ignore
lists = iterlists = _iterlists
values = itervalues = _itervalues
values = itervalues = _itervalues # type: ignore
def compute_checksum(body, algorithm):
def compute_checksum(body: bytes, algorithm: str) -> bytes:
if algorithm == "SHA1":
hashed_body = _hash(hashlib.sha1, (body,))
elif algorithm == "CRC32" or algorithm == "CRC32C":
@ -206,7 +206,7 @@ def compute_checksum(body, algorithm):
return base64.b64encode(hashed_body)
def _hash(fn, args) -> bytes:
def _hash(fn: Any, args: Any) -> bytes:
try:
return fn(*args, usedforsecurity=False).hexdigest().encode("utf-8")
except TypeError:

View File

@ -1,7 +1,8 @@
from typing import Optional
from urllib.parse import urlparse
def bucket_name_from_url(url):
def bucket_name_from_url(url: str) -> Optional[str]:
path = urlparse(url).path.lstrip("/")
parts = path.lstrip("/").split("/")
@ -10,5 +11,5 @@ def bucket_name_from_url(url):
return parts[0]
def parse_key_name(path):
def parse_key_name(path: str) -> str:
return "/".join(path.split("/")[2:])

View File

@ -2,6 +2,7 @@ import datetime
import json
from boto3 import Session
from typing import Any, Dict, List, Optional, Tuple
from moto.core.exceptions import InvalidNextTokenException
from moto.core.common_models import ConfigQueryModel
@ -12,15 +13,15 @@ from moto.s3control import s3control_backends
class S3AccountPublicAccessBlockConfigQuery(ConfigQueryModel):
def list_config_service_resources(
self,
account_id,
resource_ids,
resource_name,
limit,
next_token,
backend_region=None,
resource_region=None,
aggregator=None,
):
account_id: str,
resource_ids: Optional[List[str]],
resource_name: Optional[str],
limit: int,
next_token: Optional[str],
backend_region: Optional[str] = None,
resource_region: Optional[str] = None,
aggregator: Any = None,
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
# For the Account Public Access Block, they are the same for all regions. The resource ID is the AWS account ID
# There is no resource name -- it should be a blank string "" if provided.
@ -95,12 +96,12 @@ class S3AccountPublicAccessBlockConfigQuery(ConfigQueryModel):
def get_config_resource(
self,
account_id,
resource_id,
resource_name=None,
backend_region=None,
resource_region=None,
):
account_id: str,
resource_id: str,
resource_name: Optional[str] = None,
backend_region: Optional[str] = None,
resource_region: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
# Do we even have this defined?
backend = self.backends[account_id]["global"]
@ -116,7 +117,7 @@ class S3AccountPublicAccessBlockConfigQuery(ConfigQueryModel):
# Is the resource ID correct?:
if account_id == resource_id:
if backend_region:
pab_region = backend_region
pab_region: Optional[str] = backend_region
# Invalid region?
elif resource_region not in regions:

View File

@ -1,4 +1,4 @@
"""Exceptions raised by the s3control service."""
from typing import Any
from moto.core.exceptions import RESTError
@ -13,7 +13,7 @@ ERROR_WITH_ACCESS_POINT_POLICY = """{% extends 'wrapped_single_error' %}
class S3ControlError(RESTError):
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
kwargs.setdefault("template", "single_error")
super().__init__(*args, **kwargs)
@ -21,7 +21,7 @@ class S3ControlError(RESTError):
class AccessPointNotFound(S3ControlError):
code = 404
def __init__(self, name, **kwargs):
def __init__(self, name: str, **kwargs: Any):
kwargs.setdefault("template", "ap_not_found")
kwargs["name"] = name
self.templates["ap_not_found"] = ERROR_WITH_ACCESS_POINT_NAME
@ -33,7 +33,7 @@ class AccessPointNotFound(S3ControlError):
class AccessPointPolicyNotFound(S3ControlError):
code = 404
def __init__(self, name, **kwargs):
def __init__(self, name: str, **kwargs: Any):
kwargs.setdefault("template", "apf_not_found")
kwargs["name"] = name
self.templates["apf_not_found"] = ERROR_WITH_ACCESS_POINT_POLICY

View File

@ -1,5 +1,7 @@
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, Optional
from moto.core import BaseBackend, BackendDict, BaseModel
from moto.moto_api._internal import mock_random
from moto.s3.exceptions import (
@ -15,18 +17,18 @@ from .exceptions import AccessPointNotFound, AccessPointPolicyNotFound
class AccessPoint(BaseModel):
def __init__(
self,
account_id,
name,
bucket,
vpc_configuration,
public_access_block_configuration,
account_id: str,
name: str,
bucket: str,
vpc_configuration: Dict[str, Any],
public_access_block_configuration: Dict[str, Any],
):
self.name = name
self.alias = f"{name}-{mock_random.get_random_hex(34)}-s3alias"
self.bucket = bucket
self.created = datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f")
self.arn = f"arn:aws:s3:us-east-1:{account_id}:accesspoint/{name}"
self.policy = None
self.policy: Optional[str] = None
self.network_origin = "VPC" if vpc_configuration else "Internet"
self.vpc_id = (vpc_configuration or {}).get("VpcId")
pubc = public_access_block_configuration or {}
@ -37,23 +39,23 @@ class AccessPoint(BaseModel):
"RestrictPublicBuckets": pubc.get("RestrictPublicBuckets", "true"),
}
def delete_policy(self):
def delete_policy(self) -> None:
self.policy = None
def set_policy(self, policy):
def set_policy(self, policy: str) -> None:
self.policy = policy
def has_policy(self):
def has_policy(self) -> bool:
return self.policy is not None
class S3ControlBackend(BaseBackend):
def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self.public_access_block = None
self.access_points = defaultdict(dict)
self.public_access_block: Optional[PublicAccessBlock] = None
self.access_points: Dict[str, Dict[str, AccessPoint]] = defaultdict(dict)
def get_public_access_block(self, account_id):
def get_public_access_block(self, account_id: str) -> PublicAccessBlock:
# The account ID should equal the account id that is set for Moto:
if account_id != self.account_id:
raise WrongPublicAccessBlockAccountIdError()
@ -63,14 +65,16 @@ class S3ControlBackend(BaseBackend):
return self.public_access_block
def delete_public_access_block(self, account_id):
def delete_public_access_block(self, account_id: str) -> None:
# The account ID should equal the account id that is set for Moto:
if account_id != self.account_id:
raise WrongPublicAccessBlockAccountIdError()
self.public_access_block = None
def put_public_access_block(self, account_id, pub_block_config):
def put_public_access_block(
self, account_id: str, pub_block_config: Dict[str, Any]
) -> None:
# The account ID should equal the account id that is set for Moto:
if account_id != self.account_id:
raise WrongPublicAccessBlockAccountIdError()
@ -87,12 +91,12 @@ class S3ControlBackend(BaseBackend):
def create_access_point(
self,
account_id,
name,
bucket,
vpc_configuration,
public_access_block_configuration,
):
account_id: str,
name: str,
bucket: str,
vpc_configuration: Dict[str, Any],
public_access_block_configuration: Dict[str, Any],
) -> AccessPoint:
access_point = AccessPoint(
account_id,
name,
@ -103,29 +107,31 @@ class S3ControlBackend(BaseBackend):
self.access_points[account_id][name] = access_point
return access_point
def delete_access_point(self, account_id, name):
def delete_access_point(self, account_id: str, name: str) -> None:
self.access_points[account_id].pop(name, None)
def get_access_point(self, account_id, name):
def get_access_point(self, account_id: str, name: str) -> AccessPoint:
if name not in self.access_points[account_id]:
raise AccessPointNotFound(name)
return self.access_points[account_id][name]
def create_access_point_policy(self, account_id, name, policy):
def create_access_point_policy(
self, account_id: str, name: str, policy: str
) -> None:
access_point = self.get_access_point(account_id, name)
access_point.set_policy(policy)
def get_access_point_policy(self, account_id, name):
def get_access_point_policy(self, account_id: str, name: str) -> str:
access_point = self.get_access_point(account_id, name)
if access_point.has_policy():
return access_point.policy
return access_point.policy # type: ignore[return-value]
raise AccessPointPolicyNotFound(name)
def delete_access_point_policy(self, account_id, name):
def delete_access_point_policy(self, account_id: str, name: str) -> None:
access_point = self.get_access_point(account_id, name)
access_point.delete_policy()
def get_access_point_policy_status(self, account_id, name):
def get_access_point_policy_status(self, account_id: str, name: str) -> bool:
"""
We assume the policy status is always public
"""

View File

@ -1,23 +1,25 @@
import json
import xmltodict
from typing import Any, Dict, Tuple
from moto.core.common_types import TYPE_RESPONSE
from moto.core.responses import BaseResponse
from moto.s3.exceptions import S3ClientError
from moto.s3.responses import S3_PUBLIC_ACCESS_BLOCK_CONFIGURATION
from moto.utilities.aws_headers import amzn_request_id
from .models import s3control_backends
from .models import s3control_backends, S3ControlBackend
class S3ControlResponse(BaseResponse):
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="s3control")
@property
def backend(self):
def backend(self) -> S3ControlBackend:
return s3control_backends[self.current_account]["global"]
@amzn_request_id
def public_access_block(self, request, full_url, headers):
def public_access_block(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore
self.setup_class(request, full_url, headers)
try:
if request.method == "GET":
@ -29,7 +31,7 @@ class S3ControlResponse(BaseResponse):
except S3ClientError as err:
return err.code, {}, err.description
def get_public_access_block(self, request):
def get_public_access_block(self, request: Any) -> TYPE_RESPONSE:
account_id = request.headers.get("x-amz-account-id")
public_block_config = self.backend.get_public_access_block(
account_id=account_id
@ -37,7 +39,7 @@ class S3ControlResponse(BaseResponse):
template = self.response_template(S3_PUBLIC_ACCESS_BLOCK_CONFIGURATION)
return 200, {}, template.render(public_block_config=public_block_config)
def put_public_access_block(self, request):
def put_public_access_block(self, request: Any) -> TYPE_RESPONSE:
account_id = request.headers.get("x-amz-account-id")
data = request.body if hasattr(request, "body") else request.data
pab_config = self._parse_pab_config(data)
@ -46,18 +48,18 @@ class S3ControlResponse(BaseResponse):
)
return 201, {}, json.dumps({})
def delete_public_access_block(self, request):
def delete_public_access_block(self, request: Any) -> TYPE_RESPONSE:
account_id = request.headers.get("x-amz-account-id")
self.backend.delete_public_access_block(account_id=account_id)
return 204, {}, json.dumps({})
def _parse_pab_config(self, body):
def _parse_pab_config(self, body: str) -> Dict[str, Any]:
parsed_xml = xmltodict.parse(body)
parsed_xml["PublicAccessBlockConfiguration"].pop("@xmlns", None)
return parsed_xml
def access_point(self, request, full_url, headers):
def access_point(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if request.method == "PUT":
return self.create_access_point(full_url)
@ -66,7 +68,7 @@ class S3ControlResponse(BaseResponse):
if request.method == "DELETE":
return self.delete_access_point(full_url)
def access_point_policy(self, request, full_url, headers):
def access_point_policy(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if request.method == "PUT":
return self.create_access_point_policy(full_url)
@ -75,14 +77,14 @@ class S3ControlResponse(BaseResponse):
if request.method == "DELETE":
return self.delete_access_point_policy(full_url)
def access_point_policy_status(self, request, full_url, headers):
def access_point_policy_status(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if request.method == "PUT":
return self.create_access_point(full_url)
if request.method == "GET":
return self.get_access_point_policy_status(full_url)
def create_access_point(self, full_url):
def create_access_point(self, full_url: str) -> TYPE_RESPONSE:
account_id, name = self._get_accountid_and_name_from_accesspoint(full_url)
params = xmltodict.parse(self.body)["CreateAccessPointRequest"]
bucket = params["Bucket"]
@ -98,43 +100,45 @@ class S3ControlResponse(BaseResponse):
template = self.response_template(CREATE_ACCESS_POINT_TEMPLATE)
return 200, {}, template.render(access_point=access_point)
def get_access_point(self, full_url):
def get_access_point(self, full_url: str) -> TYPE_RESPONSE:
account_id, name = self._get_accountid_and_name_from_accesspoint(full_url)
access_point = self.backend.get_access_point(account_id=account_id, name=name)
template = self.response_template(GET_ACCESS_POINT_TEMPLATE)
return 200, {}, template.render(access_point=access_point)
def delete_access_point(self, full_url):
def delete_access_point(self, full_url: str) -> TYPE_RESPONSE:
account_id, name = self._get_accountid_and_name_from_accesspoint(full_url)
self.backend.delete_access_point(account_id=account_id, name=name)
return 204, {}, ""
def create_access_point_policy(self, full_url):
def create_access_point_policy(self, full_url: str) -> TYPE_RESPONSE:
account_id, name = self._get_accountid_and_name_from_policy(full_url)
params = xmltodict.parse(self.body)
policy = params["PutAccessPointPolicyRequest"]["Policy"]
self.backend.create_access_point_policy(account_id, name, policy)
return 200, {}, ""
def get_access_point_policy(self, full_url):
def get_access_point_policy(self, full_url: str) -> TYPE_RESPONSE:
account_id, name = self._get_accountid_and_name_from_policy(full_url)
policy = self.backend.get_access_point_policy(account_id, name)
template = self.response_template(GET_ACCESS_POINT_POLICY_TEMPLATE)
return 200, {}, template.render(policy=policy)
def delete_access_point_policy(self, full_url):
def delete_access_point_policy(self, full_url: str) -> TYPE_RESPONSE:
account_id, name = self._get_accountid_and_name_from_policy(full_url)
self.backend.delete_access_point_policy(account_id=account_id, name=name)
return 204, {}, ""
def get_access_point_policy_status(self, full_url):
def get_access_point_policy_status(self, full_url: str) -> TYPE_RESPONSE:
account_id, name = self._get_accountid_and_name_from_policy(full_url)
self.backend.get_access_point_policy_status(account_id, name)
template = self.response_template(GET_ACCESS_POINT_POLICY_STATUS_TEMPLATE)
return 200, {}, template.render()
def _get_accountid_and_name_from_accesspoint(self, full_url):
def _get_accountid_and_name_from_accesspoint(
self, full_url: str
) -> Tuple[str, str]:
url = full_url
if full_url.startswith("http"):
url = full_url.split("://")[1]
@ -142,7 +146,7 @@ class S3ControlResponse(BaseResponse):
name = url.split("v20180820/accesspoint/")[-1]
return account_id, name
def _get_accountid_and_name_from_policy(self, full_url):
def _get_accountid_and_name_from_policy(self, full_url: str) -> Tuple[str, str]:
url = full_url
if full_url.startswith("http"):
url = full_url.split("://")[1]

View File

@ -3,7 +3,7 @@ import os
import pathlib
from functools import lru_cache
from typing import Optional
from typing import List, Optional
TEST_SERVER_MODE = os.environ.get("TEST_SERVER_MODE", "0").lower() == "true"
@ -47,7 +47,7 @@ def get_sf_execution_history_type():
return os.environ.get("SF_EXECUTION_HISTORY_TYPE", "SUCCESS")
def get_s3_custom_endpoints():
def get_s3_custom_endpoints() -> List[str]:
endpoints = os.environ.get("MOTO_S3_CUSTOM_ENDPOINTS")
if endpoints:
return endpoints.split(",")
@ -57,7 +57,7 @@ def get_s3_custom_endpoints():
S3_UPLOAD_PART_MIN_SIZE = 5242880
def get_s3_default_key_buffer_size():
def get_s3_default_key_buffer_size() -> int:
return int(
os.environ.get(
"MOTO_S3_DEFAULT_KEY_BUFFER_SIZE", S3_UPLOAD_PART_MIN_SIZE - 1024

View File

@ -77,7 +77,7 @@ def md5_hash(data: Any = None) -> Any:
class LowercaseDict(MutableMapping):
"""A dictionary that lowercases all keys"""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
self.store = dict()
self.update(dict(*args, **kwargs)) # use the free update to set keys

View File

@ -239,7 +239,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy]
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/scheduler
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s3*,moto/scheduler
show_column_numbers=True
show_error_codes = True
disable_error_code=abstract