diff --git a/moto/core/botocore_stubber.py b/moto/core/botocore_stubber.py index 387fd41f7..1338f8a9c 100644 --- a/moto/core/botocore_stubber.py +++ b/moto/core/botocore_stubber.py @@ -2,15 +2,17 @@ from collections import defaultdict from io import BytesIO from botocore.awsrequest import AWSResponse from moto.core.exceptions import HTTPException +from typing import Any, Dict, Callable, List, Tuple, Union, Pattern +from .responses import TYPE_RESPONSE class MockRawResponse(BytesIO): - def __init__(self, response_input): + def __init__(self, response_input: Union[str, bytes]): if isinstance(response_input, str): response_input = response_input.encode("utf-8") super().__init__(response_input) - def stream(self, **kwargs): # pylint: disable=unused-argument + def stream(self, **kwargs: Any) -> Any: # pylint: disable=unused-argument contents = self.read() while contents: yield contents @@ -18,18 +20,22 @@ class MockRawResponse(BytesIO): class BotocoreStubber: - def __init__(self): + def __init__(self) -> None: self.enabled = False - self.methods = defaultdict(list) + self.methods: Dict[ + str, List[Tuple[Pattern[str], Callable[..., TYPE_RESPONSE]]] + ] = defaultdict(list) - def reset(self): + def reset(self) -> None: self.methods.clear() - def register_response(self, method, pattern, response): + def register_response( + self, method: str, pattern: Pattern[str], response: Callable[..., TYPE_RESPONSE] + ) -> None: matchers = self.methods[method] matchers.append((pattern, response)) - def __call__(self, event_name, request, **kwargs): + def __call__(self, event_name: str, request: Any, **kwargs: Any) -> AWSResponse: if not self.enabled: return None @@ -41,7 +47,7 @@ class BotocoreStubber: matchers = self.methods.get(request.method) base_url = request.url.split("?", 1)[0] - for i, (pattern, callback) in enumerate(matchers): + for i, (pattern, callback) in enumerate(matchers): # type: ignore[arg-type] if pattern.match(base_url): if found_index is None: found_index = i @@ -62,10 +68,10 @@ class BotocoreStubber: ) except HTTPException as e: - status = e.code - headers = e.get_headers() + status = e.code # type: ignore[assignment] + headers = e.get_headers() # type: ignore[assignment] body = e.get_body() - body = MockRawResponse(body) - response = AWSResponse(request.url, status, headers, body) + raw_response = MockRawResponse(body) + response = AWSResponse(request.url, status, headers, raw_response) return response diff --git a/moto/core/common_models.py b/moto/core/common_models.py index 8c875aad7..613c9f3ad 100644 --- a/moto/core/common_models.py +++ b/moto/core/common_models.py @@ -1,12 +1,14 @@ from abc import abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from .base_backend import InstanceTrackerMeta class BaseModel(metaclass=InstanceTrackerMeta): - def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument + def __new__( + cls, *args: Any, **kwargs: Any # pylint: disable=unused-argument + ) -> "BaseModel": instance = super(BaseModel, cls).__new__(cls) - cls.instances.append(instance) + cls.instances.append(instance) # type: ignore[attr-defined] return instance @@ -37,7 +39,7 @@ class CloudFormationModel(BaseModel): @classmethod @abstractmethod - def create_from_cloudformation_json( + def create_from_cloudformation_json( # type: ignore[misc] cls, resource_name: str, cloudformation_json: Dict[str, Any], @@ -53,7 +55,7 @@ class CloudFormationModel(BaseModel): @classmethod @abstractmethod - def update_from_cloudformation_json( + def update_from_cloudformation_json( # type: ignore[misc] cls, original_resource: Any, new_resource_name: str, @@ -70,7 +72,7 @@ class CloudFormationModel(BaseModel): @classmethod @abstractmethod - def delete_from_cloudformation_json( + def delete_from_cloudformation_json( # type: ignore[misc] cls, resource_name: str, cloudformation_json: Dict[str, Any], @@ -92,7 +94,7 @@ class CloudFormationModel(BaseModel): class ConfigQueryModel: - def __init__(self, backends): + def __init__(self, backends: Any): """Inits based on the resource type's backends (1 for each region if applicable)""" self.backends = backends @@ -106,7 +108,7 @@ class ConfigQueryModel: backend_region: Optional[str] = None, resource_region: Optional[str] = None, aggregator: Optional[Dict[str, Any]] = None, - ): + ) -> Tuple[List[Dict[str, Any]], str]: """For AWS Config. This will list all of the resources of the given type and optional resource name and region. This supports both aggregated and non-aggregated listing. The following notes the difference: @@ -195,5 +197,5 @@ class ConfigQueryModel: class CloudWatchMetricProvider(object): @staticmethod @abstractmethod - def get_cloudwatch_metrics(account_id: str) -> Any: + def get_cloudwatch_metrics(account_id: str) -> Any: # type: ignore[misc] pass diff --git a/moto/core/common_types.py b/moto/core/common_types.py new file mode 100644 index 000000000..7a27c8404 --- /dev/null +++ b/moto/core/common_types.py @@ -0,0 +1,5 @@ +from typing import Dict, Tuple, TypeVar + + +TYPE_RESPONSE = Tuple[int, Dict[str, str], str] +TYPE_IF_NONE = TypeVar("TYPE_IF_NONE") diff --git a/moto/core/custom_responses_mock.py b/moto/core/custom_responses_mock.py index 3fff7ddf7..7402303e2 100644 --- a/moto/core/custom_responses_mock.py +++ b/moto/core/custom_responses_mock.py @@ -2,8 +2,10 @@ import responses import types from io import BytesIO from http.client import responses as http_responses +from typing import Any, Dict, List, Tuple, Optional from urllib.parse import urlparse from werkzeug.wrappers import Request +from .responses import TYPE_RESPONSE from moto.utilities.distutils_version import LooseVersion @@ -21,7 +23,7 @@ class CallbackResponse(responses.CallbackResponse): Need to subclass so we can change a couple things """ - def get_response(self, request): + def get_response(self, request: Any) -> responses.HTTPResponse: """ Need to override this so we can pass decode_content=False """ @@ -58,20 +60,22 @@ class CallbackResponse(responses.CallbackResponse): raise result status, r_headers, body = result - body = responses._handle_body(body) + body_io = responses._handle_body(body) headers.update(r_headers) return responses.HTTPResponse( status=status, reason=http_responses.get(status), - body=body, + body=body_io, headers=headers, preload_content=False, # Need to not decode_content to mimic requests decode_content=False, ) - def _url_matches(self, url, other, match_querystring=False): + def _url_matches( + self, url: Any, other: Any, match_querystring: bool = False + ) -> bool: """ Need to override this so we can fix querystrings breaking regex matching """ @@ -83,16 +87,18 @@ class CallbackResponse(responses.CallbackResponse): url = responses._clean_unicode(url) if not isinstance(other, str): other = other.encode("ascii").decode("utf8") - return self._url_matches_strict(url, other) + return self._url_matches_strict(url, other) # type: ignore[attr-defined] elif isinstance(url, responses.Pattern) and url.match(other): return True else: return False -def not_implemented_callback(request): # pylint: disable=unused-argument +def not_implemented_callback( + request: Any, # pylint: disable=unused-argument +) -> TYPE_RESPONSE: status = 400 - headers = {} + headers: Dict[str, str] = {} response = "The method is not implemented" return status, headers, response @@ -106,10 +112,12 @@ def not_implemented_callback(request): # pylint: disable=unused-argument # # Note that, due to an outdated API we no longer support Responses <= 0.12.1 # This method should be used for Responses 0.12.1 < .. < 0.17.0 -def _find_first_match(self, request): +def _find_first_match( + self: Any, request: Any +) -> Tuple[Optional[responses.BaseResponse], List[str]]: matches = [] match_failed_reasons = [] - all_possibles = self._matches + responses._default_mock._matches + all_possibles = self._matches + responses._default_mock._matches # type: ignore[attr-defined] for match in all_possibles: match_result, reason = match.matches(request) if match_result: @@ -132,7 +140,7 @@ def _find_first_match(self, request): return None, match_failed_reasons -def get_response_mock(): +def get_response_mock() -> responses.RequestsMock: """ The responses-library is crucial in ensuring that requests to AWS are intercepted, and routed to the right backend. However, as our usecase is different from a 'typical' responses-user, Moto always needs some custom logic to ensure responses behaves in a way that works for us. @@ -152,13 +160,13 @@ def get_response_mock(): ) else: responses_mock = responses.RequestsMock(assert_all_requests_are_fired=False) - responses_mock._find_match = types.MethodType(_find_first_match, responses_mock) + responses_mock._find_match = types.MethodType(_find_first_match, responses_mock) # type: ignore[assignment] responses_mock.add_passthru("http") return responses_mock -def reset_responses_mock(responses_mock): +def reset_responses_mock(responses_mock: responses.RequestsMock) -> None: if LooseVersion(RESPONSES_VERSION) >= LooseVersion("0.17.0"): from .responses_custom_registry import CustomRegistry diff --git a/moto/core/exceptions.py b/moto/core/exceptions.py index a5b0630f3..afc7e98fa 100644 --- a/moto/core/exceptions.py +++ b/moto/core/exceptions.py @@ -1,6 +1,6 @@ from werkzeug.exceptions import HTTPException from jinja2 import DictLoader, Environment -from typing import Any, Optional +from typing import Any, List, Tuple, Optional import json # TODO: add "Sender" to error responses below? @@ -67,14 +67,18 @@ class RESTError(HTTPException): ) self.content_type = "application/xml" - def get_headers(self, *args, **kwargs): # pylint: disable=unused-argument - return { - "X-Amzn-ErrorType": self.error_type or "UnknownError", - "Content-Type": self.content_type, - } + def get_headers( + self, *args: Any, **kwargs: Any # pylint: disable=unused-argument + ) -> List[Tuple[str, str]]: + return [ + ("X-Amzn-ErrorType", self.error_type or "UnknownError"), + ("Content-Type", self.content_type), + ] - def get_body(self, *args, **kwargs): # pylint: disable=unused-argument - return self.description + def get_body( + self, *args: Any, **kwargs: Any # pylint: disable=unused-argument + ) -> str: + return self.description # type: ignore[return-value] class DryRunClientError(RESTError): @@ -86,19 +90,19 @@ class JsonRESTError(RESTError): self, error_type: str, message: str, template: str = "error_json", **kwargs: Any ): super().__init__(error_type, message, template, **kwargs) - self.description = json.dumps( + self.description: str = json.dumps( {"__type": self.error_type, "message": self.message} ) self.content_type = "application/json" - def get_body(self, *args, **kwargs) -> str: + def get_body(self, *args: Any, **kwargs: Any) -> str: return self.description class SignatureDoesNotMatchError(RESTError): code = 403 - def __init__(self): + def __init__(self) -> None: super().__init__( "SignatureDoesNotMatch", "The request signature we calculated does not match the signature you provided. Check your AWS Secret Access Key and signing method. Consult the service documentation for details.", @@ -108,7 +112,7 @@ class SignatureDoesNotMatchError(RESTError): class InvalidClientTokenIdError(RESTError): code = 403 - def __init__(self): + def __init__(self) -> None: super().__init__( "InvalidClientTokenId", "The security token included in the request is invalid.", @@ -118,7 +122,7 @@ class InvalidClientTokenIdError(RESTError): class AccessDeniedError(RESTError): code = 403 - def __init__(self, user_arn, action): + def __init__(self, user_arn: str, action: str): super().__init__( "AccessDenied", "User: {user_arn} is not authorized to perform: {operation}".format( @@ -130,7 +134,7 @@ class AccessDeniedError(RESTError): class AuthFailureError(RESTError): code = 401 - def __init__(self): + def __init__(self) -> None: super().__init__( "AuthFailure", "AWS was not able to validate the provided access credentials", @@ -142,9 +146,12 @@ class AWSError(JsonRESTError): STATUS = 400 def __init__( - self, message: str, exception_type: str = None, status: Optional[int] = None + self, + message: str, + exception_type: Optional[str] = None, + status: Optional[int] = None, ): - super().__init__(exception_type or self.TYPE, message) + super().__init__(exception_type or self.TYPE, message) # type: ignore[arg-type] self.code = status or self.STATUS @@ -153,7 +160,7 @@ class InvalidNextTokenException(JsonRESTError): code = 400 - def __init__(self): + def __init__(self) -> None: super().__init__( "InvalidNextTokenException", "The nextToken provided is invalid" ) @@ -162,5 +169,5 @@ class InvalidNextTokenException(JsonRESTError): class InvalidToken(AWSError): code = 400 - def __init__(self, message="Invalid token"): + def __init__(self, message: str = "Invalid token"): super().__init__("Invalid Token: {}".format(message), "InvalidToken") diff --git a/moto/core/models.py b/moto/core/models.py index 862164b1e..a0789bb5a 100644 --- a/moto/core/models.py +++ b/moto/core/models.py @@ -5,6 +5,7 @@ import os import re import unittest from types import FunctionType +from typing import Any, Callable, Dict, Optional, Set, TypeVar from unittest.mock import patch import boto3 @@ -14,7 +15,7 @@ from botocore.config import Config from botocore.handlers import BUILTIN_HANDLERS from moto import settings -from moto.core.base_backend import BackendDict +from .base_backend import BackendDict from .botocore_stubber import BotocoreStubber from .custom_responses_mock import ( get_response_mock, @@ -24,13 +25,14 @@ from .custom_responses_mock import ( ) DEFAULT_ACCOUNT_ID = "123456789012" +CALLABLE_RETURN = TypeVar("CALLABLE_RETURN") class BaseMockAWS: nested_count = 0 mocks_active = False - def __init__(self, backends): + def __init__(self, backends: BackendDict): from moto.instance_metadata import instance_metadata_backends from moto.moto_api._internal.models import moto_api_backend @@ -55,25 +57,25 @@ class BaseMockAWS: "AWS_ACCESS_KEY_ID": "foobar_key", "AWS_SECRET_ACCESS_KEY": "foobar_secret", } - self.ORIG_KEYS = {} + self.ORIG_KEYS: Dict[str, Optional[str]] = {} self.default_session_mock = patch("boto3.DEFAULT_SESSION", None) if self.__class__.nested_count == 0: - self.reset() + self.reset() # type: ignore[attr-defined] - def __call__(self, func, reset=True): + def __call__(self, func: Callable[..., Any], reset: bool = True) -> Any: if inspect.isclass(func): return self.decorate_class(func) return self.decorate_callable(func, reset) - def __enter__(self): + def __enter__(self) -> "BaseMockAWS": self.start() return self - def __exit__(self, *args): + def __exit__(self, *args: Any) -> None: self.stop() - def start(self, reset=True): + def start(self, reset: bool = True) -> None: if not self.__class__.mocks_active: self.default_session_mock.start() self.mock_env_variables() @@ -84,9 +86,9 @@ class BaseMockAWS: for backend in self.backends.values(): backend.reset() - self.enable_patching(reset) + self.enable_patching(reset) # type: ignore[attr-defined] - def stop(self): + def stop(self) -> None: self.__class__.nested_count -= 1 if self.__class__.nested_count < 0: @@ -102,10 +104,12 @@ class BaseMockAWS: pass self.unmock_env_variables() self.__class__.mocks_active = False - self.disable_patching() + self.disable_patching() # type: ignore[attr-defined] - def decorate_callable(self, func, reset): - def wrapper(*args, **kwargs): + def decorate_callable( + self, func: Callable[..., CALLABLE_RETURN], reset: bool + ) -> Callable[..., CALLABLE_RETURN]: + def wrapper(*args: Any, **kwargs: Any) -> CALLABLE_RETURN: self.start(reset=reset) try: result = func(*args, **kwargs) @@ -114,10 +118,10 @@ class BaseMockAWS: return result functools.update_wrapper(wrapper, func) - wrapper.__wrapped__ = func + wrapper.__wrapped__ = func # type: ignore[attr-defined] return wrapper - def decorate_class(self, klass): + def decorate_class(self, klass: type) -> object: direct_methods = get_direct_methods_of(klass) defined_classes = set( x for x, y in klass.__dict__.items() if inspect.isclass(y) @@ -181,7 +185,7 @@ class BaseMockAWS: continue return klass - def mock_env_variables(self): + def mock_env_variables(self) -> None: # "Mock" the AWS credentials as they can't be mocked in Botocore currently # self.env_variables_mocks = mock.patch.dict(os.environ, FAKE_KEYS) # self.env_variables_mocks.start() @@ -189,7 +193,7 @@ class BaseMockAWS: self.ORIG_KEYS[k] = os.environ.get(k, None) os.environ[k] = v - def unmock_env_variables(self): + def unmock_env_variables(self) -> None: # This doesn't work in Python2 - for some reason, unmocking clears the entire os.environ dict # Obviously bad user experience, and also breaks pytest - as it uses PYTEST_CURRENT_TEST as an env var # self.env_variables_mocks.stop() @@ -200,7 +204,7 @@ class BaseMockAWS: del os.environ[k] -def get_direct_methods_of(klass): +def get_direct_methods_of(klass: object) -> Set[str]: return set( x for x, y in klass.__dict__.items() @@ -232,7 +236,7 @@ botocore_stubber = BotocoreStubber() BUILTIN_HANDLERS.append(("before-send", botocore_stubber)) -def patch_client(client): +def patch_client(client: botocore.client.BaseClient) -> None: """ Explicitly patch a boto3-client """ @@ -254,7 +258,7 @@ def patch_client(client): raise Exception(f"Argument {client} should be of type boto3.client") -def patch_resource(resource): +def patch_resource(resource: Any) -> None: """ Explicitly patch a boto3-resource """ @@ -267,11 +271,13 @@ def patch_resource(resource): class BotocoreEventMockAWS(BaseMockAWS): - def reset(self): + def reset(self) -> None: botocore_stubber.reset() reset_responses_mock(responses_mock) - def enable_patching(self, reset=True): # pylint: disable=unused-argument + def enable_patching( + self, reset: bool = True # pylint: disable=unused-argument + ) -> None: # Circumvent circular imports from .utils import convert_flask_to_responses_response @@ -313,7 +319,7 @@ class BotocoreEventMockAWS(BaseMockAWS): ) ) - def disable_patching(self): + def disable_patching(self) -> None: botocore_stubber.enabled = False self.reset() @@ -330,14 +336,13 @@ class ServerModeMockAWS(BaseMockAWS): RESET_IN_PROGRESS = False - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): self.test_server_mode_endpoint = settings.test_server_mode_endpoint() super().__init__(*args, **kwargs) - def reset(self): + def reset(self) -> None: call_reset_api = os.environ.get("MOTO_CALL_RESET_API") - call_reset_api = not call_reset_api or call_reset_api.lower() != "false" - if call_reset_api: + if not call_reset_api or call_reset_api.lower() != "false": if not ServerModeMockAWS.RESET_IN_PROGRESS: ServerModeMockAWS.RESET_IN_PROGRESS = True import requests @@ -345,14 +350,14 @@ class ServerModeMockAWS(BaseMockAWS): requests.post(f"{self.test_server_mode_endpoint}/moto-api/reset") ServerModeMockAWS.RESET_IN_PROGRESS = False - def enable_patching(self, reset=True): + def enable_patching(self, reset: bool = True) -> None: if self.__class__.nested_count == 1 and reset: # Just started self.reset() from boto3 import client as real_boto3_client, resource as real_boto3_resource - def fake_boto3_client(*args, **kwargs): + def fake_boto3_client(*args: Any, **kwargs: Any) -> botocore.client.BaseClient: region = self._get_region(*args, **kwargs) if region: if "config" in kwargs: @@ -364,7 +369,7 @@ class ServerModeMockAWS(BaseMockAWS): kwargs["endpoint_url"] = self.test_server_mode_endpoint return real_boto3_client(*args, **kwargs) - def fake_boto3_resource(*args, **kwargs): + def fake_boto3_resource(*args: Any, **kwargs: Any) -> Any: if "endpoint_url" not in kwargs: kwargs["endpoint_url"] = self.test_server_mode_endpoint return real_boto3_resource(*args, **kwargs) @@ -374,7 +379,7 @@ class ServerModeMockAWS(BaseMockAWS): self._client_patcher.start() self._resource_patcher.start() - def _get_region(self, *args, **kwargs): + def _get_region(self, *args: Any, **kwargs: Any) -> Optional[str]: if "region_name" in kwargs: return kwargs["region_name"] if type(args) == tuple and len(args) == 2: @@ -382,7 +387,7 @@ class ServerModeMockAWS(BaseMockAWS): return region return None - def disable_patching(self): + def disable_patching(self) -> None: if self._client_patcher: self._client_patcher.stop() self._resource_patcher.stop() @@ -394,9 +399,9 @@ class base_decorator: def __init__(self, backends: BackendDict): self.backends = backends - def __call__(self, func=None): + def __call__(self, func: Optional[Callable[..., Any]] = None) -> BaseMockAWS: if settings.TEST_SERVER_MODE: - mocked_backend = ServerModeMockAWS(self.backends) + mocked_backend: BaseMockAWS = ServerModeMockAWS(self.backends) else: mocked_backend = self.mock_backend(self.backends) diff --git a/moto/core/responses.py b/moto/core/responses.py index 76c42908e..5776a1e33 100644 --- a/moto/core/responses.py +++ b/moto/core/responses.py @@ -11,11 +11,22 @@ import xmltodict from collections import defaultdict, OrderedDict from moto import settings +from moto.core.common_types import TYPE_RESPONSE, TYPE_IF_NONE from moto.core.exceptions import DryRunClientError from moto.core.utils import camelcase_to_underscores, method_names_from_class from moto.utilities.utils import load_resource from jinja2 import Environment, DictLoader, Template -from typing import Dict, Union, Any, Tuple, TypeVar +from typing import ( + Dict, + Union, + Any, + Tuple, + Optional, + List, + Set, + ClassVar, + Callable, +) from urllib.parse import parse_qs, parse_qsl, urlparse from werkzeug.exceptions import HTTPException from xml.dom.minidom import parseString as parseXML @@ -23,29 +34,19 @@ from xml.dom.minidom import parseString as parseXML log = logging.getLogger(__name__) -JINJA_ENVS = {} - -TYPE_RESPONSE = Tuple[int, Dict[str, str], str] -TYPE_IF_NONE = TypeVar("TYPE_IF_NONE") +JINJA_ENVS: Dict[type, Environment] = {} -def _decode_dict(d): - decoded = OrderedDict() +def _decode_dict(d: Dict[Any, Any]) -> Dict[str, Any]: + decoded: Dict[str, Any] = OrderedDict() for key, value in d.items(): if isinstance(key, bytes): newkey = key.decode("utf-8") - elif isinstance(key, (list, tuple)): - newkey = [] - for k in key: - if isinstance(k, bytes): - newkey.append(k.decode("utf-8")) - else: - newkey.append(k) else: newkey = key if isinstance(value, bytes): - newvalue = value.decode("utf-8") + decoded[newkey] = value.decode("utf-8") elif isinstance(value, (list, tuple)): newvalue = [] for v in value: @@ -53,18 +54,18 @@ def _decode_dict(d): newvalue.append(v.decode("utf-8")) else: newvalue.append(v) + decoded[newkey] = newvalue else: - newvalue = value + decoded[newkey] = value - decoded[newkey] = newvalue return decoded class DynamicDictLoader(DictLoader): - def update(self, mapping): - self.mapping.update(mapping) + def update(self, mapping: Dict[str, str]) -> None: + self.mapping.update(mapping) # type: ignore[attr-defined] - def contains(self, template): + def contains(self, template: str) -> bool: return bool(template in self.mapping) @@ -73,12 +74,12 @@ class _TemplateEnvironmentMixin(object): RIGHT_PATTERN = re.compile(r">[\s\n]+") @property - def should_autoescape(self): + def should_autoescape(self) -> bool: # Allow for subclass to overwrite return False @property - def environment(self): + def environment(self) -> Environment: key = type(self) try: environment = JINJA_ENVS[key] @@ -94,11 +95,11 @@ class _TemplateEnvironmentMixin(object): return environment - def contains_template(self, template_id): - return self.environment.loader.contains(template_id) + def contains_template(self, template_id: str) -> bool: + return self.environment.loader.contains(template_id) # type: ignore[union-attr] @classmethod - def _make_template_id(cls, source): + def _make_template_id(cls, source: str) -> str: """ Return a numeric string that's unique for the lifetime of the source. @@ -117,47 +118,49 @@ class _TemplateEnvironmentMixin(object): xml = re.sub( self.RIGHT_PATTERN, ">", re.sub(self.LEFT_PATTERN, "<", source) ) - self.environment.loader.update({template_id: xml}) + self.environment.loader.update({template_id: xml}) # type: ignore[union-attr] return self.environment.get_template(template_id) class ActionAuthenticatorMixin(object): - request_count = 0 + request_count: ClassVar[int] = 0 - def _authenticate_and_authorize_action(self, iam_request_cls): + def _authenticate_and_authorize_action(self, iam_request_cls: type) -> None: if ( ActionAuthenticatorMixin.request_count >= settings.INITIAL_NO_AUTH_ACTION_COUNT ): iam_request = iam_request_cls( - account_id=self.current_account, - method=self.method, - path=self.path, - data=self.data, - headers=self.headers, + account_id=self.current_account, # type: ignore[attr-defined] + method=self.method, # type: ignore[attr-defined] + path=self.path, # type: ignore[attr-defined] + data=self.data, # type: ignore[attr-defined] + headers=self.headers, # type: ignore[attr-defined] ) iam_request.check_signature() iam_request.check_action_permitted() else: ActionAuthenticatorMixin.request_count += 1 - def _authenticate_and_authorize_normal_action(self): + def _authenticate_and_authorize_normal_action(self) -> None: from moto.iam.access_control import IAMRequest self._authenticate_and_authorize_action(IAMRequest) - def _authenticate_and_authorize_s3_action(self): + def _authenticate_and_authorize_s3_action(self) -> None: from moto.iam.access_control import S3IAMRequest self._authenticate_and_authorize_action(S3IAMRequest) @staticmethod - def set_initial_no_auth_action_count(initial_no_auth_action_count): + def set_initial_no_auth_action_count(initial_no_auth_action_count: int) -> Callable[..., Callable[..., TYPE_RESPONSE]]: # type: ignore[misc] _test_server_mode_endpoint = settings.test_server_mode_endpoint() - def decorator(function): - def wrapper(*args, **kwargs): + def decorator( + function: Callable[..., TYPE_RESPONSE] + ) -> Callable[..., TYPE_RESPONSE]: + def wrapper(*args: Any, **kwargs: Any) -> TYPE_RESPONSE: if settings.TEST_SERVER_MODE: response = requests.post( f"{_test_server_mode_endpoint}/moto-api/reset-auth", @@ -191,7 +194,7 @@ class ActionAuthenticatorMixin(object): return result functools.update_wrapper(wrapper, function) - wrapper.__wrapped__ = function + wrapper.__wrapped__ = function # type: ignore[attr-defined] return wrapper return decorator @@ -213,12 +216,12 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): ) aws_service_spec = None - def __init__(self, service_name=None) -> None: + def __init__(self, service_name: Optional[str] = None): super().__init__() self.service_name = service_name @classmethod - def dispatch(cls, *args: Any, **kwargs: Any) -> Any: + def dispatch(cls, *args: Any, **kwargs: Any) -> Any: # type: ignore[misc] return cls()._dispatch(*args, **kwargs) def setup_class( @@ -227,7 +230,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): """ use_raw_body: Use incoming bytes if True, encode to string otherwise """ - querystring = OrderedDict() + querystring: Dict[str, Any] = OrderedDict() if hasattr(request, "body"): # Boto self.body = request.body @@ -292,7 +295,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): self.data = querystring self.method = request.method self.region = self.get_region_from_url(request, full_url) - self.uri_match = None + self.uri_match: Optional[re.Match[str]] = None self.headers = request.headers if "host" not in self.headers: @@ -307,11 +310,11 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): mark_account_as_visited( account_id=self.current_account, access_key=self.access_key, - service=self.service_name, + service=self.service_name, # type: ignore[arg-type] region=self.region, ) - def get_region_from_url(self, request, full_url): + def get_region_from_url(self, request: Any, full_url: str) -> str: url_match = self.region_regex.search(full_url) user_agent_match = self.region_from_useragent_regex.search( request.headers.get("User-Agent", "") @@ -329,7 +332,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): region = self.default_region return region - def get_access_key(self): + def get_access_key(self) -> str: """ Returns the access key id used in this request as the current user id """ @@ -339,11 +342,11 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return match.group(1) if self.querystring.get("AWSAccessKeyId"): - return self.querystring.get("AWSAccessKeyId")[0] + return self.querystring["AWSAccessKeyId"][0] else: return "AKIAEXAMPLE" - def get_current_account(self): + def get_current_account(self) -> str: # PRIO 1: Check if we have a Environment Variable set if "MOTO_ACCOUNT_ID" in os.environ: return os.environ["MOTO_ACCOUNT_ID"] @@ -358,11 +361,11 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return get_account_id_from(self.get_access_key()) - def _dispatch(self, request, full_url, headers): + def _dispatch(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: self.setup_class(request, full_url, headers) return self.call_action() - def uri_to_regexp(self, uri): + def uri_to_regexp(self, uri: str) -> str: """converts uri w/ placeholder to regexp '/cars/{carName}/drivers/{DriverName}' -> '^/cars/.*/drivers/[^/]*$' @@ -372,7 +375,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): """ - def _convert(elem, is_last): + def _convert(elem: str, is_last: bool) -> str: if not re.match("^{.*}$", elem): return elem name = ( @@ -394,7 +397,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): ) return regexp - def _get_action_from_method_and_request_uri(self, method, request_uri): + def _get_action_from_method_and_request_uri( + self, method: str, request_uri: str + ) -> str: """basically used for `rest-json` APIs You can refer to example from link below https://github.com/boto/botocore/blob/develop/botocore/data/iot/2015-05-28/service-2.json @@ -406,7 +411,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): # make cache if it does not exist yet if not hasattr(self, "method_urls"): - self.method_urls = defaultdict(lambda: defaultdict(str)) + self.method_urls: Dict[str, Dict[str, str]] = defaultdict( + lambda: defaultdict(str) + ) op_names = conn._service_model.operation_names for op_name in op_names: op_model = conn._service_model.operation_model(op_name) @@ -419,9 +426,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): self.uri_match = match if match: return name - return None + return None # type: ignore[return-value] - def _get_action(self): + def _get_action(self) -> str: action = self.querystring.get("Action", [""])[0] if action: return action @@ -433,7 +440,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): # get action from method and uri return self._get_action_from_method_and_request_uri(self.method, self.path) - def call_action(self): + def call_action(self) -> TYPE_RESPONSE: headers = self.response_headers try: @@ -449,9 +456,11 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): try: response = method() except HTTPException as http_error: - response_headers = dict(http_error.get_headers() or []) - response_headers["status"] = http_error.code - response = http_error.description, response_headers + response_headers: Dict[str, Union[str, int]] = dict( + http_error.get_headers() or [] + ) + response_headers["status"] = http_error.code # type: ignore[assignment] + response = http_error.description, response_headers # type: ignore[assignment] if isinstance(response, str): return 200, headers, response @@ -466,7 +475,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): ) @staticmethod - def _send_response(headers, response): + def _send_response(headers: Dict[str, str], response: Any) -> Tuple[int, Dict[str, str], str]: # type: ignore[misc] if response is None: response = "", {} if len(response) == 2: @@ -480,7 +489,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): headers["status"] = str(headers["status"]) return status, headers, body - def _get_param(self, param_name, if_none=None) -> Any: + def _get_param(self, param_name: str, if_none: Any = None) -> Any: val = self.querystring.get(param_name) if val is not None: return val[0] @@ -503,7 +512,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return if_none def _get_int_param( - self, param_name, if_none: TYPE_IF_NONE = None + self, param_name: str, if_none: TYPE_IF_NONE = None # type: ignore[assignment] ) -> Union[int, TYPE_IF_NONE]: val = self._get_param(param_name) if val is not None: @@ -511,7 +520,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return if_none def _get_bool_param( - self, param_name, if_none: TYPE_IF_NONE = None + self, param_name: str, if_none: TYPE_IF_NONE = None # type: ignore[assignment] ) -> Union[bool, TYPE_IF_NONE]: val = self._get_param(param_name) if val is not None: @@ -522,12 +531,15 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return False return if_none - def _get_multi_param_dict(self, param_prefix) -> Dict: + def _get_multi_param_dict(self, param_prefix: str) -> Dict[str, Any]: return self._get_multi_param_helper(param_prefix, skip_result_conversion=True) def _get_multi_param_helper( - self, param_prefix, skip_result_conversion=False, tracked_prefixes=None - ): + self, + param_prefix: str, + skip_result_conversion: bool = False, + tracked_prefixes: Optional[Set[str]] = None, + ) -> Any: value_dict = dict() tracked_prefixes = ( tracked_prefixes or set() @@ -589,11 +601,13 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): if len(parts) != 2 or parts[1] != "member": value_dict[parts[0]] = value_dict.pop(k) else: - value_dict = list(value_dict.values())[0] + value_dict = list(value_dict.values())[0] # type: ignore[assignment] return value_dict - def _get_multi_param(self, param_prefix, skip_result_conversion=False) -> Any: + def _get_multi_param( + self, param_prefix: str, skip_result_conversion: bool = False + ) -> List[Any]: """ Given a querystring of ?LaunchConfigurationNames.member.1=my-test-1&LaunchConfigurationNames.member.2=my-test-2 this will return ['my-test-1', 'my-test-2'] @@ -616,7 +630,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return values - def _get_dict_param(self, param_prefix) -> Dict: + def _get_dict_param(self, param_prefix: str) -> Dict[str, Any]: """ Given a parameter dict of { @@ -630,7 +644,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): "instance_count": "1", } """ - params = {} + params: Dict[str, Any] = {} for key, value in self.querystring.items(): if key.startswith(param_prefix): params[camelcase_to_underscores(key.replace(param_prefix, ""))] = value[ @@ -638,7 +652,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): ] return params - def _get_params(self) -> Any: + def _get_params(self) -> Dict[str, Any]: """ Given a querystring of { @@ -674,12 +688,12 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): ] } """ - params = {} + params: Dict[str, Any] = {} for k, v in sorted(self.querystring.items()): self._parse_param(k, v[0], params) return params - def _parse_param(self, key, value, params): + def _parse_param(self, key: str, value: str, params: Any) -> None: keylist = key.split(".") obj = params for i, key in enumerate(keylist[:-1]): @@ -713,7 +727,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): else: obj[keylist[-1]] = value - def _get_list_prefix(self, param_prefix: str) -> Any: + def _get_list_prefix(self, param_prefix: str) -> List[Dict[str, Any]]: """ Given a query dict like { @@ -752,7 +766,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): param_index += 1 return results - def _get_map_prefix(self, param_prefix, key_end=".key", value_end=".value"): + def _get_map_prefix( + self, param_prefix: str, key_end: str = ".key", value_end: str = ".value" + ) -> Dict[str, Any]: results = {} param_index = 1 while 1: @@ -774,7 +790,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return results - def _get_object_map(self, prefix, name="Name", value="Value"): + def _get_object_map( + self, prefix: str, name: str = "Name", value: str = "Value" + ) -> Dict[str, Any]: """ Given a query dict like { @@ -822,19 +840,16 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return object_map @property - def request_json(self): + def request_json(self) -> bool: return "JSON" in self.querystring.get("ContentType", []) - def error_on_dryrun(self): + def error_on_dryrun(self) -> None: self.is_not_dryrun() - def is_not_dryrun(self, action=None): - action = action or self._get_param("Action") + def is_not_dryrun(self, action: Optional[str] = None) -> bool: if "true" in self.querystring.get("DryRun", ["false"]): - message = ( - "An error occurred (DryRunOperation) when calling the %s operation: Request would have succeeded, but DryRun flag is set" - % action - ) + a = action or self._get_param("Action") + message = f"An error occurred (DryRunOperation) when calling the {a} operation: Request would have succeeded, but DryRun flag is set" raise DryRunClientError(error_type="DryRunOperation", message=message) return True @@ -842,20 +857,20 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): class _RecursiveDictRef(object): """Store a recursive reference to dict.""" - def __init__(self): - self.key = None - self.dic = {} + def __init__(self) -> None: + self.key: Optional[str] = None + self.dic: Dict[str, Any] = {} - def __repr__(self): + def __repr__(self) -> str: return "{!r}".format(self.dic) - def __getattr__(self, key): - return self.dic.__getattr__(key) + def __getattr__(self, key: str) -> Any: + return self.dic.__getattr__(key) # type: ignore[attr-defined] - def __getitem__(self, key): + def __getitem__(self, key: str) -> Any: return self.dic.__getitem__(key) - def set_reference(self, key, dic): + def set_reference(self, key: str, dic: Dict[str, Any]) -> None: """Set the RecursiveDictRef object to keep reference to dict object (dic) at the key. @@ -877,7 +892,7 @@ class AWSServiceSpec(object): self.operations = spec["operations"] self.shapes = spec["shapes"] - def input_spec(self, operation): + def input_spec(self, operation: str) -> Dict[str, Any]: try: op = self.operations[operation] except KeyError: @@ -887,7 +902,7 @@ class AWSServiceSpec(object): shape = self.shapes[op["input"]["shape"]] return self._expand(shape) - def output_spec(self, operation): + def output_spec(self, operation: str) -> Dict[str, Any]: """Produce a JSON with a valid API response syntax for operation, but with type information. Each node represented by a key has the value containing field type, e.g., @@ -904,11 +919,13 @@ class AWSServiceSpec(object): shape = self.shapes[op["output"]["shape"]] return self._expand(shape) - def _expand(self, shape): - def expand(dic, seen=None): + def _expand(self, shape: Dict[str, Any]) -> Dict[str, Any]: + def expand( + dic: Dict[str, Any], seen: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: seen = seen or {} if dic["type"] == "structure": - nodes = {} + nodes: Dict[str, Any] = {} for k, v in dic["members"].items(): seen_till_here = dict(seen) if k in seen_till_here: @@ -932,7 +949,7 @@ class AWSServiceSpec(object): elif dic["type"] == "map": seen_till_here = dict(seen) - node = {"type": "map"} + node: Dict[str, Any] = {"type": "map"} if "shape" in dic["key"]: shape = dic["key"]["shape"] @@ -958,12 +975,12 @@ class AWSServiceSpec(object): return expand(shape) -def to_str(value, spec): +def to_str(value: Any, spec: Dict[str, Any]) -> str: vtype = spec["type"] if vtype == "boolean": return "true" if value else "false" elif vtype == "long": - return int(value) + return int(value) # type: ignore[return-value] elif vtype == "integer": return str(value) elif vtype == "float": @@ -984,7 +1001,7 @@ def to_str(value, spec): raise TypeError("Unknown type {}".format(vtype)) -def from_str(value, spec): +def from_str(value: str, spec: Dict[str, Any]) -> Any: vtype = spec["type"] if vtype == "boolean": return True if value == "true" else False @@ -1001,7 +1018,9 @@ def from_str(value, spec): raise TypeError("Unknown type {}".format(vtype)) -def flatten_json_request_body(prefix, dict_body, spec): +def flatten_json_request_body( + prefix: str, dict_body: Dict[str, Any], spec: Dict[str, Any] +) -> Dict[str, Any]: """Convert a JSON request body into query params.""" if len(spec) == 1 and "type" in spec: return {prefix: to_str(dict_body, spec)} @@ -1030,10 +1049,12 @@ def flatten_json_request_body(prefix, dict_body, spec): return dict((prefix + k, v) for k, v in flat.items()) -def xml_to_json_response(service_spec, operation, xml, result_node=None): +def xml_to_json_response( + service_spec: Any, operation: str, xml: str, result_node: Any = None +) -> Dict[str, Any]: """Convert rendered XML response to JSON for use with boto3.""" - def transform(value, spec): + def transform(value: Any, spec: Dict[str, Any]) -> Any: """Apply transformations to make the output JSON comply with the expected form. This function applies: @@ -1047,7 +1068,7 @@ def xml_to_json_response(service_spec, operation, xml, result_node=None): if len(spec) == 1: return from_str(value, spec) - od = OrderedDict() + od: Dict[str, Any] = OrderedDict() for k, v in value.items(): if k.startswith("@"): continue @@ -1099,7 +1120,7 @@ def xml_to_json_response(service_spec, operation, xml, result_node=None): for k in result_node or (operation + "Response", operation + "Result"): dic = dic[k] except KeyError: - return None + return None # type: ignore[return-value] else: return transform(dic, output_spec) return None diff --git a/moto/core/responses_custom_registry.py b/moto/core/responses_custom_registry.py index 2e87dd6ab..30f800b51 100644 --- a/moto/core/responses_custom_registry.py +++ b/moto/core/responses_custom_registry.py @@ -1,5 +1,6 @@ # This will only exist in responses >= 0.17 import responses +from typing import Any, List, Tuple, Optional from .custom_responses_mock import CallbackResponse, not_implemented_callback @@ -10,11 +11,12 @@ class CustomRegistry(responses.registries.FirstMatchRegistry): - CallbackResponses are not discarded after first use - users can mock the same URL as often as they like """ - def add(self, response): + def add(self, response: responses.BaseResponse) -> responses.BaseResponse: if response not in self.registered: super().add(response) + return response - def find(self, request): + def find(self, request: Any) -> Tuple[Optional[responses.BaseResponse], List[str]]: all_possibles = responses._default_mock._registry.registered + self.registered found = [] match_failed_reasons = [] diff --git a/moto/core/utils.py b/moto/core/utils.py index d803fb6d1..29d8fabb8 100644 --- a/moto/core/utils.py +++ b/moto/core/utils.py @@ -2,11 +2,12 @@ import datetime import inspect import re from botocore.exceptions import ClientError -from typing import Optional +from typing import Any, Optional, List, Callable, Dict from urllib.parse import urlparse +from .common_types import TYPE_RESPONSE -def camelcase_to_underscores(argument: Optional[str]) -> str: +def camelcase_to_underscores(argument: str) -> str: """Converts a camelcase param like theNewAttribute to the equivalent python underscore variable like the_new_attribute""" result = "" @@ -32,7 +33,7 @@ def camelcase_to_underscores(argument: Optional[str]) -> str: return result -def underscores_to_camelcase(argument): +def underscores_to_camelcase(argument: str) -> str: """Converts a camelcase param like the_new_attribute to the equivalent camelcase version like theNewAttribute. Note that the first letter is NOT capitalized by this function""" @@ -48,17 +49,17 @@ def underscores_to_camelcase(argument): return result -def pascal_to_camelcase(argument): +def pascal_to_camelcase(argument: str) -> str: """Converts a PascalCase param to the camelCase equivalent""" return argument[0].lower() + argument[1:] -def camelcase_to_pascal(argument): +def camelcase_to_pascal(argument: str) -> str: """Converts a camelCase param to the PascalCase equivalent""" return argument[0].upper() + argument[1:] -def method_names_from_class(clazz): +def method_names_from_class(clazz: object) -> List[str]: predicate = inspect.isfunction return [x[0] for x in inspect.getmembers(clazz, predicate=predicate)] @@ -70,7 +71,7 @@ def convert_regex_to_flask_path(url_path: str) -> str: for token in ["$"]: url_path = url_path.replace(token, "") - def caller(reg): + def caller(reg: Any) -> str: match_name, match_pattern = reg.groups() return ''.format(match_pattern, match_name) @@ -83,11 +84,11 @@ def convert_regex_to_flask_path(url_path: str) -> str: class convert_to_flask_response(object): - def __init__(self, callback): + def __init__(self, callback: Callable[..., Any]): self.callback = callback @property - def __name__(self): + def __name__(self) -> str: # For instance methods, use class and method names. Otherwise # use module and method name if inspect.ismethod(self.callback): @@ -96,7 +97,7 @@ class convert_to_flask_response(object): outer = self.callback.__module__ return "{0}.{1}".format(outer, self.callback.__name__) - def __call__(self, args=None, **kwargs): + def __call__(self, args: Any = None, **kwargs: Any) -> Any: from flask import request, Response from moto.moto_api import recorder @@ -118,11 +119,11 @@ class convert_to_flask_response(object): class convert_flask_to_responses_response(object): - def __init__(self, callback): + def __init__(self, callback: Callable[..., Any]): self.callback = callback @property - def __name__(self): + def __name__(self) -> str: # For instance methods, use class and method names. Otherwise # use module and method name if inspect.ismethod(self.callback): @@ -131,7 +132,7 @@ class convert_flask_to_responses_response(object): outer = self.callback.__module__ return "{0}.{1}".format(outer, self.callback.__name__) - def __call__(self, request, *args, **kwargs): + def __call__(self, request: Any, *args: Any, **kwargs: Any) -> TYPE_RESPONSE: for key, val in request.headers.items(): if isinstance(val, bytes): request.headers[key] = val.decode("utf-8") @@ -141,7 +142,7 @@ class convert_flask_to_responses_response(object): return status, headers, response -def iso_8601_datetime_with_milliseconds(value: datetime) -> str: +def iso_8601_datetime_with_milliseconds(value: datetime.datetime) -> str: return value.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z" @@ -163,22 +164,22 @@ def iso_8601_datetime_without_milliseconds_s3( RFC1123 = "%a, %d %b %Y %H:%M:%S GMT" -def rfc_1123_datetime(src): +def rfc_1123_datetime(src: datetime.datetime) -> str: return src.strftime(RFC1123) -def str_to_rfc_1123_datetime(value): +def str_to_rfc_1123_datetime(value: str) -> datetime.datetime: return datetime.datetime.strptime(value, RFC1123) -def unix_time(dt: datetime.datetime = None) -> int: +def unix_time(dt: Optional[datetime.datetime] = None) -> float: dt = dt or datetime.datetime.utcnow() epoch = datetime.datetime.utcfromtimestamp(0) delta = dt - epoch return (delta.days * 86400) + (delta.seconds + (delta.microseconds / 1e6)) -def unix_time_millis(dt: datetime = None) -> int: +def unix_time_millis(dt: Optional[datetime.datetime] = None) -> float: return unix_time(dt) * 1000.0 @@ -193,28 +194,33 @@ def path_url(url: str) -> str: def tags_from_query_string( - querystring_dict, prefix="Tag", key_suffix="Key", value_suffix="Value" -): + querystring_dict: Dict[str, Any], + prefix: str = "Tag", + key_suffix: str = "Key", + value_suffix: str = "Value", +) -> Dict[str, str]: response_values = {} for key in querystring_dict.keys(): if key.startswith(prefix) and key.endswith(key_suffix): tag_index = key.replace(prefix + ".", "").replace("." + key_suffix, "") - tag_key = querystring_dict.get( + tag_key = querystring_dict[ "{prefix}.{index}.{key_suffix}".format( prefix=prefix, index=tag_index, key_suffix=key_suffix ) - )[0] + ][0] tag_value_key = "{prefix}.{index}.{value_suffix}".format( prefix=prefix, index=tag_index, value_suffix=value_suffix ) if tag_value_key in querystring_dict: - response_values[tag_key] = querystring_dict.get(tag_value_key)[0] + response_values[tag_key] = querystring_dict[tag_value_key][0] else: response_values[tag_key] = None return response_values -def tags_from_cloudformation_tags_list(tags_list): +def tags_from_cloudformation_tags_list( + tags_list: List[Dict[str, str]] +) -> Dict[str, str]: """Return tags in dict form from cloudformation resource tags form (list of dicts)""" tags = {} for entry in tags_list: @@ -225,7 +231,7 @@ def tags_from_cloudformation_tags_list(tags_list): return tags -def remap_nested_keys(root, key_transform): +def remap_nested_keys(root: Any, key_transform: Callable[[str], str]) -> Any: """This remap ("recursive map") function is used to traverse and transform the dictionary keys of arbitrarily nested structures. List comprehensions do not recurse, making it tedious to apply @@ -252,7 +258,9 @@ def remap_nested_keys(root, key_transform): return root -def merge_dicts(dict1, dict2, remove_nulls=False): +def merge_dicts( + dict1: Dict[str, Any], dict2: Dict[str, Any], remove_nulls: bool = False +) -> None: """Given two arbitrarily nested dictionaries, merge the second dict into the first. :param dict dict1: the dictionary to be updated. @@ -275,7 +283,7 @@ def merge_dicts(dict1, dict2, remove_nulls=False): dict1.pop(key) -def aws_api_matches(pattern, string): +def aws_api_matches(pattern: str, string: str) -> bool: """ AWS API can match a value based on a glob, or an exact match """ @@ -296,7 +304,7 @@ def aws_api_matches(pattern, string): return False -def extract_region_from_aws_authorization(string): +def extract_region_from_aws_authorization(string: str) -> Optional[str]: auth = string or "" region = re.sub(r".*Credential=[^/]+/[^/]+/([^/]+)/.*", r"\1", auth) if region == auth: diff --git a/setup.cfg b/setup.cfg index f7cf76cbe..65755c267 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,7 +18,7 @@ disable = W,C,R,E enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import [mypy] -files= moto/a*,moto/b*,moto/ce,moto/cloudformation,moto/cloudfront,moto/cloudtrail,moto/codebuild,moto/cloudwatch,moto/codepipeline,moto/codecommit,moto/cognito*,moto/comprehend,moto/config,moto/core/base_backend.py +files= moto/a*,moto/b*,moto/c* show_column_numbers=True show_error_codes = True disable_error_code=abstract