TechDebt: MyPy Core (#5653)

This commit is contained in:
Bert Blommers 2022-11-10 18:54:38 -01:00 committed by GitHub
parent 8c9838cc8c
commit 96b2eff1bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 286 additions and 222 deletions

View File

@ -2,15 +2,17 @@ from collections import defaultdict
from io import BytesIO
from botocore.awsrequest import AWSResponse
from moto.core.exceptions import HTTPException
from typing import Any, Dict, Callable, List, Tuple, Union, Pattern
from .responses import TYPE_RESPONSE
class MockRawResponse(BytesIO):
def __init__(self, response_input):
def __init__(self, response_input: Union[str, bytes]):
if isinstance(response_input, str):
response_input = response_input.encode("utf-8")
super().__init__(response_input)
def stream(self, **kwargs): # pylint: disable=unused-argument
def stream(self, **kwargs: Any) -> Any: # pylint: disable=unused-argument
contents = self.read()
while contents:
yield contents
@ -18,18 +20,22 @@ class MockRawResponse(BytesIO):
class BotocoreStubber:
def __init__(self):
def __init__(self) -> None:
self.enabled = False
self.methods = defaultdict(list)
self.methods: Dict[
str, List[Tuple[Pattern[str], Callable[..., TYPE_RESPONSE]]]
] = defaultdict(list)
def reset(self):
def reset(self) -> None:
self.methods.clear()
def register_response(self, method, pattern, response):
def register_response(
self, method: str, pattern: Pattern[str], response: Callable[..., TYPE_RESPONSE]
) -> None:
matchers = self.methods[method]
matchers.append((pattern, response))
def __call__(self, event_name, request, **kwargs):
def __call__(self, event_name: str, request: Any, **kwargs: Any) -> AWSResponse:
if not self.enabled:
return None
@ -41,7 +47,7 @@ class BotocoreStubber:
matchers = self.methods.get(request.method)
base_url = request.url.split("?", 1)[0]
for i, (pattern, callback) in enumerate(matchers):
for i, (pattern, callback) in enumerate(matchers): # type: ignore[arg-type]
if pattern.match(base_url):
if found_index is None:
found_index = i
@ -62,10 +68,10 @@ class BotocoreStubber:
)
except HTTPException as e:
status = e.code
headers = e.get_headers()
status = e.code # type: ignore[assignment]
headers = e.get_headers() # type: ignore[assignment]
body = e.get_body()
body = MockRawResponse(body)
response = AWSResponse(request.url, status, headers, body)
raw_response = MockRawResponse(body)
response = AWSResponse(request.url, status, headers, raw_response)
return response

View File

@ -1,12 +1,14 @@
from abc import abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple
from .base_backend import InstanceTrackerMeta
class BaseModel(metaclass=InstanceTrackerMeta):
def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument
def __new__(
cls, *args: Any, **kwargs: Any # pylint: disable=unused-argument
) -> "BaseModel":
instance = super(BaseModel, cls).__new__(cls)
cls.instances.append(instance)
cls.instances.append(instance) # type: ignore[attr-defined]
return instance
@ -37,7 +39,7 @@ class CloudFormationModel(BaseModel):
@classmethod
@abstractmethod
def create_from_cloudformation_json(
def create_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Dict[str, Any],
@ -53,7 +55,7 @@ class CloudFormationModel(BaseModel):
@classmethod
@abstractmethod
def update_from_cloudformation_json(
def update_from_cloudformation_json( # type: ignore[misc]
cls,
original_resource: Any,
new_resource_name: str,
@ -70,7 +72,7 @@ class CloudFormationModel(BaseModel):
@classmethod
@abstractmethod
def delete_from_cloudformation_json(
def delete_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Dict[str, Any],
@ -92,7 +94,7 @@ class CloudFormationModel(BaseModel):
class ConfigQueryModel:
def __init__(self, backends):
def __init__(self, backends: Any):
"""Inits based on the resource type's backends (1 for each region if applicable)"""
self.backends = backends
@ -106,7 +108,7 @@ class ConfigQueryModel:
backend_region: Optional[str] = None,
resource_region: Optional[str] = None,
aggregator: Optional[Dict[str, Any]] = None,
):
) -> Tuple[List[Dict[str, Any]], str]:
"""For AWS Config. This will list all of the resources of the given type and optional resource name and region.
This supports both aggregated and non-aggregated listing. The following notes the difference:
@ -195,5 +197,5 @@ class ConfigQueryModel:
class CloudWatchMetricProvider(object):
@staticmethod
@abstractmethod
def get_cloudwatch_metrics(account_id: str) -> Any:
def get_cloudwatch_metrics(account_id: str) -> Any: # type: ignore[misc]
pass

View File

@ -0,0 +1,5 @@
from typing import Dict, Tuple, TypeVar
TYPE_RESPONSE = Tuple[int, Dict[str, str], str]
TYPE_IF_NONE = TypeVar("TYPE_IF_NONE")

View File

@ -2,8 +2,10 @@ import responses
import types
from io import BytesIO
from http.client import responses as http_responses
from typing import Any, Dict, List, Tuple, Optional
from urllib.parse import urlparse
from werkzeug.wrappers import Request
from .responses import TYPE_RESPONSE
from moto.utilities.distutils_version import LooseVersion
@ -21,7 +23,7 @@ class CallbackResponse(responses.CallbackResponse):
Need to subclass so we can change a couple things
"""
def get_response(self, request):
def get_response(self, request: Any) -> responses.HTTPResponse:
"""
Need to override this so we can pass decode_content=False
"""
@ -58,20 +60,22 @@ class CallbackResponse(responses.CallbackResponse):
raise result
status, r_headers, body = result
body = responses._handle_body(body)
body_io = responses._handle_body(body)
headers.update(r_headers)
return responses.HTTPResponse(
status=status,
reason=http_responses.get(status),
body=body,
body=body_io,
headers=headers,
preload_content=False,
# Need to not decode_content to mimic requests
decode_content=False,
)
def _url_matches(self, url, other, match_querystring=False):
def _url_matches(
self, url: Any, other: Any, match_querystring: bool = False
) -> bool:
"""
Need to override this so we can fix querystrings breaking regex matching
"""
@ -83,16 +87,18 @@ class CallbackResponse(responses.CallbackResponse):
url = responses._clean_unicode(url)
if not isinstance(other, str):
other = other.encode("ascii").decode("utf8")
return self._url_matches_strict(url, other)
return self._url_matches_strict(url, other) # type: ignore[attr-defined]
elif isinstance(url, responses.Pattern) and url.match(other):
return True
else:
return False
def not_implemented_callback(request): # pylint: disable=unused-argument
def not_implemented_callback(
request: Any, # pylint: disable=unused-argument
) -> TYPE_RESPONSE:
status = 400
headers = {}
headers: Dict[str, str] = {}
response = "The method is not implemented"
return status, headers, response
@ -106,10 +112,12 @@ def not_implemented_callback(request): # pylint: disable=unused-argument
#
# Note that, due to an outdated API we no longer support Responses <= 0.12.1
# This method should be used for Responses 0.12.1 < .. < 0.17.0
def _find_first_match(self, request):
def _find_first_match(
self: Any, request: Any
) -> Tuple[Optional[responses.BaseResponse], List[str]]:
matches = []
match_failed_reasons = []
all_possibles = self._matches + responses._default_mock._matches
all_possibles = self._matches + responses._default_mock._matches # type: ignore[attr-defined]
for match in all_possibles:
match_result, reason = match.matches(request)
if match_result:
@ -132,7 +140,7 @@ def _find_first_match(self, request):
return None, match_failed_reasons
def get_response_mock():
def get_response_mock() -> responses.RequestsMock:
"""
The responses-library is crucial in ensuring that requests to AWS are intercepted, and routed to the right backend.
However, as our usecase is different from a 'typical' responses-user, Moto always needs some custom logic to ensure responses behaves in a way that works for us.
@ -152,13 +160,13 @@ def get_response_mock():
)
else:
responses_mock = responses.RequestsMock(assert_all_requests_are_fired=False)
responses_mock._find_match = types.MethodType(_find_first_match, responses_mock)
responses_mock._find_match = types.MethodType(_find_first_match, responses_mock) # type: ignore[assignment]
responses_mock.add_passthru("http")
return responses_mock
def reset_responses_mock(responses_mock):
def reset_responses_mock(responses_mock: responses.RequestsMock) -> None:
if LooseVersion(RESPONSES_VERSION) >= LooseVersion("0.17.0"):
from .responses_custom_registry import CustomRegistry

View File

@ -1,6 +1,6 @@
from werkzeug.exceptions import HTTPException
from jinja2 import DictLoader, Environment
from typing import Any, Optional
from typing import Any, List, Tuple, Optional
import json
# TODO: add "<Type>Sender</Type>" to error responses below?
@ -67,14 +67,18 @@ class RESTError(HTTPException):
)
self.content_type = "application/xml"
def get_headers(self, *args, **kwargs): # pylint: disable=unused-argument
return {
"X-Amzn-ErrorType": self.error_type or "UnknownError",
"Content-Type": self.content_type,
}
def get_headers(
self, *args: Any, **kwargs: Any # pylint: disable=unused-argument
) -> List[Tuple[str, str]]:
return [
("X-Amzn-ErrorType", self.error_type or "UnknownError"),
("Content-Type", self.content_type),
]
def get_body(self, *args, **kwargs): # pylint: disable=unused-argument
return self.description
def get_body(
self, *args: Any, **kwargs: Any # pylint: disable=unused-argument
) -> str:
return self.description # type: ignore[return-value]
class DryRunClientError(RESTError):
@ -86,19 +90,19 @@ class JsonRESTError(RESTError):
self, error_type: str, message: str, template: str = "error_json", **kwargs: Any
):
super().__init__(error_type, message, template, **kwargs)
self.description = json.dumps(
self.description: str = json.dumps(
{"__type": self.error_type, "message": self.message}
)
self.content_type = "application/json"
def get_body(self, *args, **kwargs) -> str:
def get_body(self, *args: Any, **kwargs: Any) -> str:
return self.description
class SignatureDoesNotMatchError(RESTError):
code = 403
def __init__(self):
def __init__(self) -> None:
super().__init__(
"SignatureDoesNotMatch",
"The request signature we calculated does not match the signature you provided. Check your AWS Secret Access Key and signing method. Consult the service documentation for details.",
@ -108,7 +112,7 @@ class SignatureDoesNotMatchError(RESTError):
class InvalidClientTokenIdError(RESTError):
code = 403
def __init__(self):
def __init__(self) -> None:
super().__init__(
"InvalidClientTokenId",
"The security token included in the request is invalid.",
@ -118,7 +122,7 @@ class InvalidClientTokenIdError(RESTError):
class AccessDeniedError(RESTError):
code = 403
def __init__(self, user_arn, action):
def __init__(self, user_arn: str, action: str):
super().__init__(
"AccessDenied",
"User: {user_arn} is not authorized to perform: {operation}".format(
@ -130,7 +134,7 @@ class AccessDeniedError(RESTError):
class AuthFailureError(RESTError):
code = 401
def __init__(self):
def __init__(self) -> None:
super().__init__(
"AuthFailure",
"AWS was not able to validate the provided access credentials",
@ -142,9 +146,12 @@ class AWSError(JsonRESTError):
STATUS = 400
def __init__(
self, message: str, exception_type: str = None, status: Optional[int] = None
self,
message: str,
exception_type: Optional[str] = None,
status: Optional[int] = None,
):
super().__init__(exception_type or self.TYPE, message)
super().__init__(exception_type or self.TYPE, message) # type: ignore[arg-type]
self.code = status or self.STATUS
@ -153,7 +160,7 @@ class InvalidNextTokenException(JsonRESTError):
code = 400
def __init__(self):
def __init__(self) -> None:
super().__init__(
"InvalidNextTokenException", "The nextToken provided is invalid"
)
@ -162,5 +169,5 @@ class InvalidNextTokenException(JsonRESTError):
class InvalidToken(AWSError):
code = 400
def __init__(self, message="Invalid token"):
def __init__(self, message: str = "Invalid token"):
super().__init__("Invalid Token: {}".format(message), "InvalidToken")

View File

@ -5,6 +5,7 @@ import os
import re
import unittest
from types import FunctionType
from typing import Any, Callable, Dict, Optional, Set, TypeVar
from unittest.mock import patch
import boto3
@ -14,7 +15,7 @@ from botocore.config import Config
from botocore.handlers import BUILTIN_HANDLERS
from moto import settings
from moto.core.base_backend import BackendDict
from .base_backend import BackendDict
from .botocore_stubber import BotocoreStubber
from .custom_responses_mock import (
get_response_mock,
@ -24,13 +25,14 @@ from .custom_responses_mock import (
)
DEFAULT_ACCOUNT_ID = "123456789012"
CALLABLE_RETURN = TypeVar("CALLABLE_RETURN")
class BaseMockAWS:
nested_count = 0
mocks_active = False
def __init__(self, backends):
def __init__(self, backends: BackendDict):
from moto.instance_metadata import instance_metadata_backends
from moto.moto_api._internal.models import moto_api_backend
@ -55,25 +57,25 @@ class BaseMockAWS:
"AWS_ACCESS_KEY_ID": "foobar_key",
"AWS_SECRET_ACCESS_KEY": "foobar_secret",
}
self.ORIG_KEYS = {}
self.ORIG_KEYS: Dict[str, Optional[str]] = {}
self.default_session_mock = patch("boto3.DEFAULT_SESSION", None)
if self.__class__.nested_count == 0:
self.reset()
self.reset() # type: ignore[attr-defined]
def __call__(self, func, reset=True):
def __call__(self, func: Callable[..., Any], reset: bool = True) -> Any:
if inspect.isclass(func):
return self.decorate_class(func)
return self.decorate_callable(func, reset)
def __enter__(self):
def __enter__(self) -> "BaseMockAWS":
self.start()
return self
def __exit__(self, *args):
def __exit__(self, *args: Any) -> None:
self.stop()
def start(self, reset=True):
def start(self, reset: bool = True) -> None:
if not self.__class__.mocks_active:
self.default_session_mock.start()
self.mock_env_variables()
@ -84,9 +86,9 @@ class BaseMockAWS:
for backend in self.backends.values():
backend.reset()
self.enable_patching(reset)
self.enable_patching(reset) # type: ignore[attr-defined]
def stop(self):
def stop(self) -> None:
self.__class__.nested_count -= 1
if self.__class__.nested_count < 0:
@ -102,10 +104,12 @@ class BaseMockAWS:
pass
self.unmock_env_variables()
self.__class__.mocks_active = False
self.disable_patching()
self.disable_patching() # type: ignore[attr-defined]
def decorate_callable(self, func, reset):
def wrapper(*args, **kwargs):
def decorate_callable(
self, func: Callable[..., CALLABLE_RETURN], reset: bool
) -> Callable[..., CALLABLE_RETURN]:
def wrapper(*args: Any, **kwargs: Any) -> CALLABLE_RETURN:
self.start(reset=reset)
try:
result = func(*args, **kwargs)
@ -114,10 +118,10 @@ class BaseMockAWS:
return result
functools.update_wrapper(wrapper, func)
wrapper.__wrapped__ = func
wrapper.__wrapped__ = func # type: ignore[attr-defined]
return wrapper
def decorate_class(self, klass):
def decorate_class(self, klass: type) -> object:
direct_methods = get_direct_methods_of(klass)
defined_classes = set(
x for x, y in klass.__dict__.items() if inspect.isclass(y)
@ -181,7 +185,7 @@ class BaseMockAWS:
continue
return klass
def mock_env_variables(self):
def mock_env_variables(self) -> None:
# "Mock" the AWS credentials as they can't be mocked in Botocore currently
# self.env_variables_mocks = mock.patch.dict(os.environ, FAKE_KEYS)
# self.env_variables_mocks.start()
@ -189,7 +193,7 @@ class BaseMockAWS:
self.ORIG_KEYS[k] = os.environ.get(k, None)
os.environ[k] = v
def unmock_env_variables(self):
def unmock_env_variables(self) -> None:
# This doesn't work in Python2 - for some reason, unmocking clears the entire os.environ dict
# Obviously bad user experience, and also breaks pytest - as it uses PYTEST_CURRENT_TEST as an env var
# self.env_variables_mocks.stop()
@ -200,7 +204,7 @@ class BaseMockAWS:
del os.environ[k]
def get_direct_methods_of(klass):
def get_direct_methods_of(klass: object) -> Set[str]:
return set(
x
for x, y in klass.__dict__.items()
@ -232,7 +236,7 @@ botocore_stubber = BotocoreStubber()
BUILTIN_HANDLERS.append(("before-send", botocore_stubber))
def patch_client(client):
def patch_client(client: botocore.client.BaseClient) -> None:
"""
Explicitly patch a boto3-client
"""
@ -254,7 +258,7 @@ def patch_client(client):
raise Exception(f"Argument {client} should be of type boto3.client")
def patch_resource(resource):
def patch_resource(resource: Any) -> None:
"""
Explicitly patch a boto3-resource
"""
@ -267,11 +271,13 @@ def patch_resource(resource):
class BotocoreEventMockAWS(BaseMockAWS):
def reset(self):
def reset(self) -> None:
botocore_stubber.reset()
reset_responses_mock(responses_mock)
def enable_patching(self, reset=True): # pylint: disable=unused-argument
def enable_patching(
self, reset: bool = True # pylint: disable=unused-argument
) -> None:
# Circumvent circular imports
from .utils import convert_flask_to_responses_response
@ -313,7 +319,7 @@ class BotocoreEventMockAWS(BaseMockAWS):
)
)
def disable_patching(self):
def disable_patching(self) -> None:
botocore_stubber.enabled = False
self.reset()
@ -330,14 +336,13 @@ class ServerModeMockAWS(BaseMockAWS):
RESET_IN_PROGRESS = False
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
self.test_server_mode_endpoint = settings.test_server_mode_endpoint()
super().__init__(*args, **kwargs)
def reset(self):
def reset(self) -> None:
call_reset_api = os.environ.get("MOTO_CALL_RESET_API")
call_reset_api = not call_reset_api or call_reset_api.lower() != "false"
if call_reset_api:
if not call_reset_api or call_reset_api.lower() != "false":
if not ServerModeMockAWS.RESET_IN_PROGRESS:
ServerModeMockAWS.RESET_IN_PROGRESS = True
import requests
@ -345,14 +350,14 @@ class ServerModeMockAWS(BaseMockAWS):
requests.post(f"{self.test_server_mode_endpoint}/moto-api/reset")
ServerModeMockAWS.RESET_IN_PROGRESS = False
def enable_patching(self, reset=True):
def enable_patching(self, reset: bool = True) -> None:
if self.__class__.nested_count == 1 and reset:
# Just started
self.reset()
from boto3 import client as real_boto3_client, resource as real_boto3_resource
def fake_boto3_client(*args, **kwargs):
def fake_boto3_client(*args: Any, **kwargs: Any) -> botocore.client.BaseClient:
region = self._get_region(*args, **kwargs)
if region:
if "config" in kwargs:
@ -364,7 +369,7 @@ class ServerModeMockAWS(BaseMockAWS):
kwargs["endpoint_url"] = self.test_server_mode_endpoint
return real_boto3_client(*args, **kwargs)
def fake_boto3_resource(*args, **kwargs):
def fake_boto3_resource(*args: Any, **kwargs: Any) -> Any:
if "endpoint_url" not in kwargs:
kwargs["endpoint_url"] = self.test_server_mode_endpoint
return real_boto3_resource(*args, **kwargs)
@ -374,7 +379,7 @@ class ServerModeMockAWS(BaseMockAWS):
self._client_patcher.start()
self._resource_patcher.start()
def _get_region(self, *args, **kwargs):
def _get_region(self, *args: Any, **kwargs: Any) -> Optional[str]:
if "region_name" in kwargs:
return kwargs["region_name"]
if type(args) == tuple and len(args) == 2:
@ -382,7 +387,7 @@ class ServerModeMockAWS(BaseMockAWS):
return region
return None
def disable_patching(self):
def disable_patching(self) -> None:
if self._client_patcher:
self._client_patcher.stop()
self._resource_patcher.stop()
@ -394,9 +399,9 @@ class base_decorator:
def __init__(self, backends: BackendDict):
self.backends = backends
def __call__(self, func=None):
def __call__(self, func: Optional[Callable[..., Any]] = None) -> BaseMockAWS:
if settings.TEST_SERVER_MODE:
mocked_backend = ServerModeMockAWS(self.backends)
mocked_backend: BaseMockAWS = ServerModeMockAWS(self.backends)
else:
mocked_backend = self.mock_backend(self.backends)

View File

@ -11,11 +11,22 @@ import xmltodict
from collections import defaultdict, OrderedDict
from moto import settings
from moto.core.common_types import TYPE_RESPONSE, TYPE_IF_NONE
from moto.core.exceptions import DryRunClientError
from moto.core.utils import camelcase_to_underscores, method_names_from_class
from moto.utilities.utils import load_resource
from jinja2 import Environment, DictLoader, Template
from typing import Dict, Union, Any, Tuple, TypeVar
from typing import (
Dict,
Union,
Any,
Tuple,
Optional,
List,
Set,
ClassVar,
Callable,
)
from urllib.parse import parse_qs, parse_qsl, urlparse
from werkzeug.exceptions import HTTPException
from xml.dom.minidom import parseString as parseXML
@ -23,29 +34,19 @@ from xml.dom.minidom import parseString as parseXML
log = logging.getLogger(__name__)
JINJA_ENVS = {}
TYPE_RESPONSE = Tuple[int, Dict[str, str], str]
TYPE_IF_NONE = TypeVar("TYPE_IF_NONE")
JINJA_ENVS: Dict[type, Environment] = {}
def _decode_dict(d):
decoded = OrderedDict()
def _decode_dict(d: Dict[Any, Any]) -> Dict[str, Any]:
decoded: Dict[str, Any] = OrderedDict()
for key, value in d.items():
if isinstance(key, bytes):
newkey = key.decode("utf-8")
elif isinstance(key, (list, tuple)):
newkey = []
for k in key:
if isinstance(k, bytes):
newkey.append(k.decode("utf-8"))
else:
newkey.append(k)
else:
newkey = key
if isinstance(value, bytes):
newvalue = value.decode("utf-8")
decoded[newkey] = value.decode("utf-8")
elif isinstance(value, (list, tuple)):
newvalue = []
for v in value:
@ -53,18 +54,18 @@ def _decode_dict(d):
newvalue.append(v.decode("utf-8"))
else:
newvalue.append(v)
decoded[newkey] = newvalue
else:
newvalue = value
decoded[newkey] = value
decoded[newkey] = newvalue
return decoded
class DynamicDictLoader(DictLoader):
def update(self, mapping):
self.mapping.update(mapping)
def update(self, mapping: Dict[str, str]) -> None:
self.mapping.update(mapping) # type: ignore[attr-defined]
def contains(self, template):
def contains(self, template: str) -> bool:
return bool(template in self.mapping)
@ -73,12 +74,12 @@ class _TemplateEnvironmentMixin(object):
RIGHT_PATTERN = re.compile(r">[\s\n]+")
@property
def should_autoescape(self):
def should_autoescape(self) -> bool:
# Allow for subclass to overwrite
return False
@property
def environment(self):
def environment(self) -> Environment:
key = type(self)
try:
environment = JINJA_ENVS[key]
@ -94,11 +95,11 @@ class _TemplateEnvironmentMixin(object):
return environment
def contains_template(self, template_id):
return self.environment.loader.contains(template_id)
def contains_template(self, template_id: str) -> bool:
return self.environment.loader.contains(template_id) # type: ignore[union-attr]
@classmethod
def _make_template_id(cls, source):
def _make_template_id(cls, source: str) -> str:
"""
Return a numeric string that's unique for the lifetime of the source.
@ -117,47 +118,49 @@ class _TemplateEnvironmentMixin(object):
xml = re.sub(
self.RIGHT_PATTERN, ">", re.sub(self.LEFT_PATTERN, "<", source)
)
self.environment.loader.update({template_id: xml})
self.environment.loader.update({template_id: xml}) # type: ignore[union-attr]
return self.environment.get_template(template_id)
class ActionAuthenticatorMixin(object):
request_count = 0
request_count: ClassVar[int] = 0
def _authenticate_and_authorize_action(self, iam_request_cls):
def _authenticate_and_authorize_action(self, iam_request_cls: type) -> None:
if (
ActionAuthenticatorMixin.request_count
>= settings.INITIAL_NO_AUTH_ACTION_COUNT
):
iam_request = iam_request_cls(
account_id=self.current_account,
method=self.method,
path=self.path,
data=self.data,
headers=self.headers,
account_id=self.current_account, # type: ignore[attr-defined]
method=self.method, # type: ignore[attr-defined]
path=self.path, # type: ignore[attr-defined]
data=self.data, # type: ignore[attr-defined]
headers=self.headers, # type: ignore[attr-defined]
)
iam_request.check_signature()
iam_request.check_action_permitted()
else:
ActionAuthenticatorMixin.request_count += 1
def _authenticate_and_authorize_normal_action(self):
def _authenticate_and_authorize_normal_action(self) -> None:
from moto.iam.access_control import IAMRequest
self._authenticate_and_authorize_action(IAMRequest)
def _authenticate_and_authorize_s3_action(self):
def _authenticate_and_authorize_s3_action(self) -> None:
from moto.iam.access_control import S3IAMRequest
self._authenticate_and_authorize_action(S3IAMRequest)
@staticmethod
def set_initial_no_auth_action_count(initial_no_auth_action_count):
def set_initial_no_auth_action_count(initial_no_auth_action_count: int) -> Callable[..., Callable[..., TYPE_RESPONSE]]: # type: ignore[misc]
_test_server_mode_endpoint = settings.test_server_mode_endpoint()
def decorator(function):
def wrapper(*args, **kwargs):
def decorator(
function: Callable[..., TYPE_RESPONSE]
) -> Callable[..., TYPE_RESPONSE]:
def wrapper(*args: Any, **kwargs: Any) -> TYPE_RESPONSE:
if settings.TEST_SERVER_MODE:
response = requests.post(
f"{_test_server_mode_endpoint}/moto-api/reset-auth",
@ -191,7 +194,7 @@ class ActionAuthenticatorMixin(object):
return result
functools.update_wrapper(wrapper, function)
wrapper.__wrapped__ = function
wrapper.__wrapped__ = function # type: ignore[attr-defined]
return wrapper
return decorator
@ -213,12 +216,12 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
)
aws_service_spec = None
def __init__(self, service_name=None) -> None:
def __init__(self, service_name: Optional[str] = None):
super().__init__()
self.service_name = service_name
@classmethod
def dispatch(cls, *args: Any, **kwargs: Any) -> Any:
def dispatch(cls, *args: Any, **kwargs: Any) -> Any: # type: ignore[misc]
return cls()._dispatch(*args, **kwargs)
def setup_class(
@ -227,7 +230,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
"""
use_raw_body: Use incoming bytes if True, encode to string otherwise
"""
querystring = OrderedDict()
querystring: Dict[str, Any] = OrderedDict()
if hasattr(request, "body"):
# Boto
self.body = request.body
@ -292,7 +295,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
self.data = querystring
self.method = request.method
self.region = self.get_region_from_url(request, full_url)
self.uri_match = None
self.uri_match: Optional[re.Match[str]] = None
self.headers = request.headers
if "host" not in self.headers:
@ -307,11 +310,11 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
mark_account_as_visited(
account_id=self.current_account,
access_key=self.access_key,
service=self.service_name,
service=self.service_name, # type: ignore[arg-type]
region=self.region,
)
def get_region_from_url(self, request, full_url):
def get_region_from_url(self, request: Any, full_url: str) -> str:
url_match = self.region_regex.search(full_url)
user_agent_match = self.region_from_useragent_regex.search(
request.headers.get("User-Agent", "")
@ -329,7 +332,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
region = self.default_region
return region
def get_access_key(self):
def get_access_key(self) -> str:
"""
Returns the access key id used in this request as the current user id
"""
@ -339,11 +342,11 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return match.group(1)
if self.querystring.get("AWSAccessKeyId"):
return self.querystring.get("AWSAccessKeyId")[0]
return self.querystring["AWSAccessKeyId"][0]
else:
return "AKIAEXAMPLE"
def get_current_account(self):
def get_current_account(self) -> str:
# PRIO 1: Check if we have a Environment Variable set
if "MOTO_ACCOUNT_ID" in os.environ:
return os.environ["MOTO_ACCOUNT_ID"]
@ -358,11 +361,11 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return get_account_id_from(self.get_access_key())
def _dispatch(self, request, full_url, headers):
def _dispatch(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers)
return self.call_action()
def uri_to_regexp(self, uri):
def uri_to_regexp(self, uri: str) -> str:
"""converts uri w/ placeholder to regexp
'/cars/{carName}/drivers/{DriverName}'
-> '^/cars/.*/drivers/[^/]*$'
@ -372,7 +375,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
"""
def _convert(elem, is_last):
def _convert(elem: str, is_last: bool) -> str:
if not re.match("^{.*}$", elem):
return elem
name = (
@ -394,7 +397,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
)
return regexp
def _get_action_from_method_and_request_uri(self, method, request_uri):
def _get_action_from_method_and_request_uri(
self, method: str, request_uri: str
) -> str:
"""basically used for `rest-json` APIs
You can refer to example from link below
https://github.com/boto/botocore/blob/develop/botocore/data/iot/2015-05-28/service-2.json
@ -406,7 +411,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
# make cache if it does not exist yet
if not hasattr(self, "method_urls"):
self.method_urls = defaultdict(lambda: defaultdict(str))
self.method_urls: Dict[str, Dict[str, str]] = defaultdict(
lambda: defaultdict(str)
)
op_names = conn._service_model.operation_names
for op_name in op_names:
op_model = conn._service_model.operation_model(op_name)
@ -419,9 +426,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
self.uri_match = match
if match:
return name
return None
return None # type: ignore[return-value]
def _get_action(self):
def _get_action(self) -> str:
action = self.querystring.get("Action", [""])[0]
if action:
return action
@ -433,7 +440,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
# get action from method and uri
return self._get_action_from_method_and_request_uri(self.method, self.path)
def call_action(self):
def call_action(self) -> TYPE_RESPONSE:
headers = self.response_headers
try:
@ -449,9 +456,11 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
try:
response = method()
except HTTPException as http_error:
response_headers = dict(http_error.get_headers() or [])
response_headers["status"] = http_error.code
response = http_error.description, response_headers
response_headers: Dict[str, Union[str, int]] = dict(
http_error.get_headers() or []
)
response_headers["status"] = http_error.code # type: ignore[assignment]
response = http_error.description, response_headers # type: ignore[assignment]
if isinstance(response, str):
return 200, headers, response
@ -466,7 +475,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
)
@staticmethod
def _send_response(headers, response):
def _send_response(headers: Dict[str, str], response: Any) -> Tuple[int, Dict[str, str], str]: # type: ignore[misc]
if response is None:
response = "", {}
if len(response) == 2:
@ -480,7 +489,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
headers["status"] = str(headers["status"])
return status, headers, body
def _get_param(self, param_name, if_none=None) -> Any:
def _get_param(self, param_name: str, if_none: Any = None) -> Any:
val = self.querystring.get(param_name)
if val is not None:
return val[0]
@ -503,7 +512,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return if_none
def _get_int_param(
self, param_name, if_none: TYPE_IF_NONE = None
self, param_name: str, if_none: TYPE_IF_NONE = None # type: ignore[assignment]
) -> Union[int, TYPE_IF_NONE]:
val = self._get_param(param_name)
if val is not None:
@ -511,7 +520,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return if_none
def _get_bool_param(
self, param_name, if_none: TYPE_IF_NONE = None
self, param_name: str, if_none: TYPE_IF_NONE = None # type: ignore[assignment]
) -> Union[bool, TYPE_IF_NONE]:
val = self._get_param(param_name)
if val is not None:
@ -522,12 +531,15 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return False
return if_none
def _get_multi_param_dict(self, param_prefix) -> Dict:
def _get_multi_param_dict(self, param_prefix: str) -> Dict[str, Any]:
return self._get_multi_param_helper(param_prefix, skip_result_conversion=True)
def _get_multi_param_helper(
self, param_prefix, skip_result_conversion=False, tracked_prefixes=None
):
self,
param_prefix: str,
skip_result_conversion: bool = False,
tracked_prefixes: Optional[Set[str]] = None,
) -> Any:
value_dict = dict()
tracked_prefixes = (
tracked_prefixes or set()
@ -589,11 +601,13 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
if len(parts) != 2 or parts[1] != "member":
value_dict[parts[0]] = value_dict.pop(k)
else:
value_dict = list(value_dict.values())[0]
value_dict = list(value_dict.values())[0] # type: ignore[assignment]
return value_dict
def _get_multi_param(self, param_prefix, skip_result_conversion=False) -> Any:
def _get_multi_param(
self, param_prefix: str, skip_result_conversion: bool = False
) -> List[Any]:
"""
Given a querystring of ?LaunchConfigurationNames.member.1=my-test-1&LaunchConfigurationNames.member.2=my-test-2
this will return ['my-test-1', 'my-test-2']
@ -616,7 +630,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return values
def _get_dict_param(self, param_prefix) -> Dict:
def _get_dict_param(self, param_prefix: str) -> Dict[str, Any]:
"""
Given a parameter dict of
{
@ -630,7 +644,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
"instance_count": "1",
}
"""
params = {}
params: Dict[str, Any] = {}
for key, value in self.querystring.items():
if key.startswith(param_prefix):
params[camelcase_to_underscores(key.replace(param_prefix, ""))] = value[
@ -638,7 +652,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
]
return params
def _get_params(self) -> Any:
def _get_params(self) -> Dict[str, Any]:
"""
Given a querystring of
{
@ -674,12 +688,12 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
]
}
"""
params = {}
params: Dict[str, Any] = {}
for k, v in sorted(self.querystring.items()):
self._parse_param(k, v[0], params)
return params
def _parse_param(self, key, value, params):
def _parse_param(self, key: str, value: str, params: Any) -> None:
keylist = key.split(".")
obj = params
for i, key in enumerate(keylist[:-1]):
@ -713,7 +727,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
else:
obj[keylist[-1]] = value
def _get_list_prefix(self, param_prefix: str) -> Any:
def _get_list_prefix(self, param_prefix: str) -> List[Dict[str, Any]]:
"""
Given a query dict like
{
@ -752,7 +766,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
param_index += 1
return results
def _get_map_prefix(self, param_prefix, key_end=".key", value_end=".value"):
def _get_map_prefix(
self, param_prefix: str, key_end: str = ".key", value_end: str = ".value"
) -> Dict[str, Any]:
results = {}
param_index = 1
while 1:
@ -774,7 +790,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return results
def _get_object_map(self, prefix, name="Name", value="Value"):
def _get_object_map(
self, prefix: str, name: str = "Name", value: str = "Value"
) -> Dict[str, Any]:
"""
Given a query dict like
{
@ -822,19 +840,16 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return object_map
@property
def request_json(self):
def request_json(self) -> bool:
return "JSON" in self.querystring.get("ContentType", [])
def error_on_dryrun(self):
def error_on_dryrun(self) -> None:
self.is_not_dryrun()
def is_not_dryrun(self, action=None):
action = action or self._get_param("Action")
def is_not_dryrun(self, action: Optional[str] = None) -> bool:
if "true" in self.querystring.get("DryRun", ["false"]):
message = (
"An error occurred (DryRunOperation) when calling the %s operation: Request would have succeeded, but DryRun flag is set"
% action
)
a = action or self._get_param("Action")
message = f"An error occurred (DryRunOperation) when calling the {a} operation: Request would have succeeded, but DryRun flag is set"
raise DryRunClientError(error_type="DryRunOperation", message=message)
return True
@ -842,20 +857,20 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
class _RecursiveDictRef(object):
"""Store a recursive reference to dict."""
def __init__(self):
self.key = None
self.dic = {}
def __init__(self) -> None:
self.key: Optional[str] = None
self.dic: Dict[str, Any] = {}
def __repr__(self):
def __repr__(self) -> str:
return "{!r}".format(self.dic)
def __getattr__(self, key):
return self.dic.__getattr__(key)
def __getattr__(self, key: str) -> Any:
return self.dic.__getattr__(key) # type: ignore[attr-defined]
def __getitem__(self, key):
def __getitem__(self, key: str) -> Any:
return self.dic.__getitem__(key)
def set_reference(self, key, dic):
def set_reference(self, key: str, dic: Dict[str, Any]) -> None:
"""Set the RecursiveDictRef object to keep reference to dict object
(dic) at the key.
@ -877,7 +892,7 @@ class AWSServiceSpec(object):
self.operations = spec["operations"]
self.shapes = spec["shapes"]
def input_spec(self, operation):
def input_spec(self, operation: str) -> Dict[str, Any]:
try:
op = self.operations[operation]
except KeyError:
@ -887,7 +902,7 @@ class AWSServiceSpec(object):
shape = self.shapes[op["input"]["shape"]]
return self._expand(shape)
def output_spec(self, operation):
def output_spec(self, operation: str) -> Dict[str, Any]:
"""Produce a JSON with a valid API response syntax for operation, but
with type information. Each node represented by a key has the
value containing field type, e.g.,
@ -904,11 +919,13 @@ class AWSServiceSpec(object):
shape = self.shapes[op["output"]["shape"]]
return self._expand(shape)
def _expand(self, shape):
def expand(dic, seen=None):
def _expand(self, shape: Dict[str, Any]) -> Dict[str, Any]:
def expand(
dic: Dict[str, Any], seen: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
seen = seen or {}
if dic["type"] == "structure":
nodes = {}
nodes: Dict[str, Any] = {}
for k, v in dic["members"].items():
seen_till_here = dict(seen)
if k in seen_till_here:
@ -932,7 +949,7 @@ class AWSServiceSpec(object):
elif dic["type"] == "map":
seen_till_here = dict(seen)
node = {"type": "map"}
node: Dict[str, Any] = {"type": "map"}
if "shape" in dic["key"]:
shape = dic["key"]["shape"]
@ -958,12 +975,12 @@ class AWSServiceSpec(object):
return expand(shape)
def to_str(value, spec):
def to_str(value: Any, spec: Dict[str, Any]) -> str:
vtype = spec["type"]
if vtype == "boolean":
return "true" if value else "false"
elif vtype == "long":
return int(value)
return int(value) # type: ignore[return-value]
elif vtype == "integer":
return str(value)
elif vtype == "float":
@ -984,7 +1001,7 @@ def to_str(value, spec):
raise TypeError("Unknown type {}".format(vtype))
def from_str(value, spec):
def from_str(value: str, spec: Dict[str, Any]) -> Any:
vtype = spec["type"]
if vtype == "boolean":
return True if value == "true" else False
@ -1001,7 +1018,9 @@ def from_str(value, spec):
raise TypeError("Unknown type {}".format(vtype))
def flatten_json_request_body(prefix, dict_body, spec):
def flatten_json_request_body(
prefix: str, dict_body: Dict[str, Any], spec: Dict[str, Any]
) -> Dict[str, Any]:
"""Convert a JSON request body into query params."""
if len(spec) == 1 and "type" in spec:
return {prefix: to_str(dict_body, spec)}
@ -1030,10 +1049,12 @@ def flatten_json_request_body(prefix, dict_body, spec):
return dict((prefix + k, v) for k, v in flat.items())
def xml_to_json_response(service_spec, operation, xml, result_node=None):
def xml_to_json_response(
service_spec: Any, operation: str, xml: str, result_node: Any = None
) -> Dict[str, Any]:
"""Convert rendered XML response to JSON for use with boto3."""
def transform(value, spec):
def transform(value: Any, spec: Dict[str, Any]) -> Any:
"""Apply transformations to make the output JSON comply with the
expected form. This function applies:
@ -1047,7 +1068,7 @@ def xml_to_json_response(service_spec, operation, xml, result_node=None):
if len(spec) == 1:
return from_str(value, spec)
od = OrderedDict()
od: Dict[str, Any] = OrderedDict()
for k, v in value.items():
if k.startswith("@"):
continue
@ -1099,7 +1120,7 @@ def xml_to_json_response(service_spec, operation, xml, result_node=None):
for k in result_node or (operation + "Response", operation + "Result"):
dic = dic[k]
except KeyError:
return None
return None # type: ignore[return-value]
else:
return transform(dic, output_spec)
return None

View File

@ -1,5 +1,6 @@
# This will only exist in responses >= 0.17
import responses
from typing import Any, List, Tuple, Optional
from .custom_responses_mock import CallbackResponse, not_implemented_callback
@ -10,11 +11,12 @@ class CustomRegistry(responses.registries.FirstMatchRegistry):
- CallbackResponses are not discarded after first use - users can mock the same URL as often as they like
"""
def add(self, response):
def add(self, response: responses.BaseResponse) -> responses.BaseResponse:
if response not in self.registered:
super().add(response)
return response
def find(self, request):
def find(self, request: Any) -> Tuple[Optional[responses.BaseResponse], List[str]]:
all_possibles = responses._default_mock._registry.registered + self.registered
found = []
match_failed_reasons = []

View File

@ -2,11 +2,12 @@ import datetime
import inspect
import re
from botocore.exceptions import ClientError
from typing import Optional
from typing import Any, Optional, List, Callable, Dict
from urllib.parse import urlparse
from .common_types import TYPE_RESPONSE
def camelcase_to_underscores(argument: Optional[str]) -> str:
def camelcase_to_underscores(argument: str) -> str:
"""Converts a camelcase param like theNewAttribute to the equivalent
python underscore variable like the_new_attribute"""
result = ""
@ -32,7 +33,7 @@ def camelcase_to_underscores(argument: Optional[str]) -> str:
return result
def underscores_to_camelcase(argument):
def underscores_to_camelcase(argument: str) -> str:
"""Converts a camelcase param like the_new_attribute to the equivalent
camelcase version like theNewAttribute. Note that the first letter is
NOT capitalized by this function"""
@ -48,17 +49,17 @@ def underscores_to_camelcase(argument):
return result
def pascal_to_camelcase(argument):
def pascal_to_camelcase(argument: str) -> str:
"""Converts a PascalCase param to the camelCase equivalent"""
return argument[0].lower() + argument[1:]
def camelcase_to_pascal(argument):
def camelcase_to_pascal(argument: str) -> str:
"""Converts a camelCase param to the PascalCase equivalent"""
return argument[0].upper() + argument[1:]
def method_names_from_class(clazz):
def method_names_from_class(clazz: object) -> List[str]:
predicate = inspect.isfunction
return [x[0] for x in inspect.getmembers(clazz, predicate=predicate)]
@ -70,7 +71,7 @@ def convert_regex_to_flask_path(url_path: str) -> str:
for token in ["$"]:
url_path = url_path.replace(token, "")
def caller(reg):
def caller(reg: Any) -> str:
match_name, match_pattern = reg.groups()
return '<regex("{0}"):{1}>'.format(match_pattern, match_name)
@ -83,11 +84,11 @@ def convert_regex_to_flask_path(url_path: str) -> str:
class convert_to_flask_response(object):
def __init__(self, callback):
def __init__(self, callback: Callable[..., Any]):
self.callback = callback
@property
def __name__(self):
def __name__(self) -> str:
# For instance methods, use class and method names. Otherwise
# use module and method name
if inspect.ismethod(self.callback):
@ -96,7 +97,7 @@ class convert_to_flask_response(object):
outer = self.callback.__module__
return "{0}.{1}".format(outer, self.callback.__name__)
def __call__(self, args=None, **kwargs):
def __call__(self, args: Any = None, **kwargs: Any) -> Any:
from flask import request, Response
from moto.moto_api import recorder
@ -118,11 +119,11 @@ class convert_to_flask_response(object):
class convert_flask_to_responses_response(object):
def __init__(self, callback):
def __init__(self, callback: Callable[..., Any]):
self.callback = callback
@property
def __name__(self):
def __name__(self) -> str:
# For instance methods, use class and method names. Otherwise
# use module and method name
if inspect.ismethod(self.callback):
@ -131,7 +132,7 @@ class convert_flask_to_responses_response(object):
outer = self.callback.__module__
return "{0}.{1}".format(outer, self.callback.__name__)
def __call__(self, request, *args, **kwargs):
def __call__(self, request: Any, *args: Any, **kwargs: Any) -> TYPE_RESPONSE:
for key, val in request.headers.items():
if isinstance(val, bytes):
request.headers[key] = val.decode("utf-8")
@ -141,7 +142,7 @@ class convert_flask_to_responses_response(object):
return status, headers, response
def iso_8601_datetime_with_milliseconds(value: datetime) -> str:
def iso_8601_datetime_with_milliseconds(value: datetime.datetime) -> str:
return value.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
@ -163,22 +164,22 @@ def iso_8601_datetime_without_milliseconds_s3(
RFC1123 = "%a, %d %b %Y %H:%M:%S GMT"
def rfc_1123_datetime(src):
def rfc_1123_datetime(src: datetime.datetime) -> str:
return src.strftime(RFC1123)
def str_to_rfc_1123_datetime(value):
def str_to_rfc_1123_datetime(value: str) -> datetime.datetime:
return datetime.datetime.strptime(value, RFC1123)
def unix_time(dt: datetime.datetime = None) -> int:
def unix_time(dt: Optional[datetime.datetime] = None) -> float:
dt = dt or datetime.datetime.utcnow()
epoch = datetime.datetime.utcfromtimestamp(0)
delta = dt - epoch
return (delta.days * 86400) + (delta.seconds + (delta.microseconds / 1e6))
def unix_time_millis(dt: datetime = None) -> int:
def unix_time_millis(dt: Optional[datetime.datetime] = None) -> float:
return unix_time(dt) * 1000.0
@ -193,28 +194,33 @@ def path_url(url: str) -> str:
def tags_from_query_string(
querystring_dict, prefix="Tag", key_suffix="Key", value_suffix="Value"
):
querystring_dict: Dict[str, Any],
prefix: str = "Tag",
key_suffix: str = "Key",
value_suffix: str = "Value",
) -> Dict[str, str]:
response_values = {}
for key in querystring_dict.keys():
if key.startswith(prefix) and key.endswith(key_suffix):
tag_index = key.replace(prefix + ".", "").replace("." + key_suffix, "")
tag_key = querystring_dict.get(
tag_key = querystring_dict[
"{prefix}.{index}.{key_suffix}".format(
prefix=prefix, index=tag_index, key_suffix=key_suffix
)
)[0]
][0]
tag_value_key = "{prefix}.{index}.{value_suffix}".format(
prefix=prefix, index=tag_index, value_suffix=value_suffix
)
if tag_value_key in querystring_dict:
response_values[tag_key] = querystring_dict.get(tag_value_key)[0]
response_values[tag_key] = querystring_dict[tag_value_key][0]
else:
response_values[tag_key] = None
return response_values
def tags_from_cloudformation_tags_list(tags_list):
def tags_from_cloudformation_tags_list(
tags_list: List[Dict[str, str]]
) -> Dict[str, str]:
"""Return tags in dict form from cloudformation resource tags form (list of dicts)"""
tags = {}
for entry in tags_list:
@ -225,7 +231,7 @@ def tags_from_cloudformation_tags_list(tags_list):
return tags
def remap_nested_keys(root, key_transform):
def remap_nested_keys(root: Any, key_transform: Callable[[str], str]) -> Any:
"""This remap ("recursive map") function is used to traverse and
transform the dictionary keys of arbitrarily nested structures.
List comprehensions do not recurse, making it tedious to apply
@ -252,7 +258,9 @@ def remap_nested_keys(root, key_transform):
return root
def merge_dicts(dict1, dict2, remove_nulls=False):
def merge_dicts(
dict1: Dict[str, Any], dict2: Dict[str, Any], remove_nulls: bool = False
) -> None:
"""Given two arbitrarily nested dictionaries, merge the second dict into the first.
:param dict dict1: the dictionary to be updated.
@ -275,7 +283,7 @@ def merge_dicts(dict1, dict2, remove_nulls=False):
dict1.pop(key)
def aws_api_matches(pattern, string):
def aws_api_matches(pattern: str, string: str) -> bool:
"""
AWS API can match a value based on a glob, or an exact match
"""
@ -296,7 +304,7 @@ def aws_api_matches(pattern, string):
return False
def extract_region_from_aws_authorization(string):
def extract_region_from_aws_authorization(string: str) -> Optional[str]:
auth = string or ""
region = re.sub(r".*Credential=[^/]+/[^/]+/([^/]+)/.*", r"\1", auth)
if region == auth:

View File

@ -18,7 +18,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy]
files= moto/a*,moto/b*,moto/ce,moto/cloudformation,moto/cloudfront,moto/cloudtrail,moto/codebuild,moto/cloudwatch,moto/codepipeline,moto/codecommit,moto/cognito*,moto/comprehend,moto/config,moto/core/base_backend.py
files= moto/a*,moto/b*,moto/c*
show_column_numbers=True
show_error_codes = True
disable_error_code=abstract