TechDebt: MyPy Core (#5653)

This commit is contained in:
Bert Blommers 2022-11-10 18:54:38 -01:00 committed by GitHub
parent 8c9838cc8c
commit 96b2eff1bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 286 additions and 222 deletions

View File

@ -2,15 +2,17 @@ from collections import defaultdict
from io import BytesIO from io import BytesIO
from botocore.awsrequest import AWSResponse from botocore.awsrequest import AWSResponse
from moto.core.exceptions import HTTPException from moto.core.exceptions import HTTPException
from typing import Any, Dict, Callable, List, Tuple, Union, Pattern
from .responses import TYPE_RESPONSE
class MockRawResponse(BytesIO): class MockRawResponse(BytesIO):
def __init__(self, response_input): def __init__(self, response_input: Union[str, bytes]):
if isinstance(response_input, str): if isinstance(response_input, str):
response_input = response_input.encode("utf-8") response_input = response_input.encode("utf-8")
super().__init__(response_input) super().__init__(response_input)
def stream(self, **kwargs): # pylint: disable=unused-argument def stream(self, **kwargs: Any) -> Any: # pylint: disable=unused-argument
contents = self.read() contents = self.read()
while contents: while contents:
yield contents yield contents
@ -18,18 +20,22 @@ class MockRawResponse(BytesIO):
class BotocoreStubber: class BotocoreStubber:
def __init__(self): def __init__(self) -> None:
self.enabled = False self.enabled = False
self.methods = defaultdict(list) self.methods: Dict[
str, List[Tuple[Pattern[str], Callable[..., TYPE_RESPONSE]]]
] = defaultdict(list)
def reset(self): def reset(self) -> None:
self.methods.clear() self.methods.clear()
def register_response(self, method, pattern, response): def register_response(
self, method: str, pattern: Pattern[str], response: Callable[..., TYPE_RESPONSE]
) -> None:
matchers = self.methods[method] matchers = self.methods[method]
matchers.append((pattern, response)) matchers.append((pattern, response))
def __call__(self, event_name, request, **kwargs): def __call__(self, event_name: str, request: Any, **kwargs: Any) -> AWSResponse:
if not self.enabled: if not self.enabled:
return None return None
@ -41,7 +47,7 @@ class BotocoreStubber:
matchers = self.methods.get(request.method) matchers = self.methods.get(request.method)
base_url = request.url.split("?", 1)[0] base_url = request.url.split("?", 1)[0]
for i, (pattern, callback) in enumerate(matchers): for i, (pattern, callback) in enumerate(matchers): # type: ignore[arg-type]
if pattern.match(base_url): if pattern.match(base_url):
if found_index is None: if found_index is None:
found_index = i found_index = i
@ -62,10 +68,10 @@ class BotocoreStubber:
) )
except HTTPException as e: except HTTPException as e:
status = e.code status = e.code # type: ignore[assignment]
headers = e.get_headers() headers = e.get_headers() # type: ignore[assignment]
body = e.get_body() body = e.get_body()
body = MockRawResponse(body) raw_response = MockRawResponse(body)
response = AWSResponse(request.url, status, headers, body) response = AWSResponse(request.url, status, headers, raw_response)
return response return response

View File

@ -1,12 +1,14 @@
from abc import abstractmethod from abc import abstractmethod
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Tuple
from .base_backend import InstanceTrackerMeta from .base_backend import InstanceTrackerMeta
class BaseModel(metaclass=InstanceTrackerMeta): class BaseModel(metaclass=InstanceTrackerMeta):
def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument def __new__(
cls, *args: Any, **kwargs: Any # pylint: disable=unused-argument
) -> "BaseModel":
instance = super(BaseModel, cls).__new__(cls) instance = super(BaseModel, cls).__new__(cls)
cls.instances.append(instance) cls.instances.append(instance) # type: ignore[attr-defined]
return instance return instance
@ -37,7 +39,7 @@ class CloudFormationModel(BaseModel):
@classmethod @classmethod
@abstractmethod @abstractmethod
def create_from_cloudformation_json( def create_from_cloudformation_json( # type: ignore[misc]
cls, cls,
resource_name: str, resource_name: str,
cloudformation_json: Dict[str, Any], cloudformation_json: Dict[str, Any],
@ -53,7 +55,7 @@ class CloudFormationModel(BaseModel):
@classmethod @classmethod
@abstractmethod @abstractmethod
def update_from_cloudformation_json( def update_from_cloudformation_json( # type: ignore[misc]
cls, cls,
original_resource: Any, original_resource: Any,
new_resource_name: str, new_resource_name: str,
@ -70,7 +72,7 @@ class CloudFormationModel(BaseModel):
@classmethod @classmethod
@abstractmethod @abstractmethod
def delete_from_cloudformation_json( def delete_from_cloudformation_json( # type: ignore[misc]
cls, cls,
resource_name: str, resource_name: str,
cloudformation_json: Dict[str, Any], cloudformation_json: Dict[str, Any],
@ -92,7 +94,7 @@ class CloudFormationModel(BaseModel):
class ConfigQueryModel: class ConfigQueryModel:
def __init__(self, backends): def __init__(self, backends: Any):
"""Inits based on the resource type's backends (1 for each region if applicable)""" """Inits based on the resource type's backends (1 for each region if applicable)"""
self.backends = backends self.backends = backends
@ -106,7 +108,7 @@ class ConfigQueryModel:
backend_region: Optional[str] = None, backend_region: Optional[str] = None,
resource_region: Optional[str] = None, resource_region: Optional[str] = None,
aggregator: Optional[Dict[str, Any]] = None, aggregator: Optional[Dict[str, Any]] = None,
): ) -> Tuple[List[Dict[str, Any]], str]:
"""For AWS Config. This will list all of the resources of the given type and optional resource name and region. """For AWS Config. This will list all of the resources of the given type and optional resource name and region.
This supports both aggregated and non-aggregated listing. The following notes the difference: This supports both aggregated and non-aggregated listing. The following notes the difference:
@ -195,5 +197,5 @@ class ConfigQueryModel:
class CloudWatchMetricProvider(object): class CloudWatchMetricProvider(object):
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def get_cloudwatch_metrics(account_id: str) -> Any: def get_cloudwatch_metrics(account_id: str) -> Any: # type: ignore[misc]
pass pass

View File

@ -0,0 +1,5 @@
from typing import Dict, Tuple, TypeVar
TYPE_RESPONSE = Tuple[int, Dict[str, str], str]
TYPE_IF_NONE = TypeVar("TYPE_IF_NONE")

View File

@ -2,8 +2,10 @@ import responses
import types import types
from io import BytesIO from io import BytesIO
from http.client import responses as http_responses from http.client import responses as http_responses
from typing import Any, Dict, List, Tuple, Optional
from urllib.parse import urlparse from urllib.parse import urlparse
from werkzeug.wrappers import Request from werkzeug.wrappers import Request
from .responses import TYPE_RESPONSE
from moto.utilities.distutils_version import LooseVersion from moto.utilities.distutils_version import LooseVersion
@ -21,7 +23,7 @@ class CallbackResponse(responses.CallbackResponse):
Need to subclass so we can change a couple things Need to subclass so we can change a couple things
""" """
def get_response(self, request): def get_response(self, request: Any) -> responses.HTTPResponse:
""" """
Need to override this so we can pass decode_content=False Need to override this so we can pass decode_content=False
""" """
@ -58,20 +60,22 @@ class CallbackResponse(responses.CallbackResponse):
raise result raise result
status, r_headers, body = result status, r_headers, body = result
body = responses._handle_body(body) body_io = responses._handle_body(body)
headers.update(r_headers) headers.update(r_headers)
return responses.HTTPResponse( return responses.HTTPResponse(
status=status, status=status,
reason=http_responses.get(status), reason=http_responses.get(status),
body=body, body=body_io,
headers=headers, headers=headers,
preload_content=False, preload_content=False,
# Need to not decode_content to mimic requests # Need to not decode_content to mimic requests
decode_content=False, decode_content=False,
) )
def _url_matches(self, url, other, match_querystring=False): def _url_matches(
self, url: Any, other: Any, match_querystring: bool = False
) -> bool:
""" """
Need to override this so we can fix querystrings breaking regex matching Need to override this so we can fix querystrings breaking regex matching
""" """
@ -83,16 +87,18 @@ class CallbackResponse(responses.CallbackResponse):
url = responses._clean_unicode(url) url = responses._clean_unicode(url)
if not isinstance(other, str): if not isinstance(other, str):
other = other.encode("ascii").decode("utf8") other = other.encode("ascii").decode("utf8")
return self._url_matches_strict(url, other) return self._url_matches_strict(url, other) # type: ignore[attr-defined]
elif isinstance(url, responses.Pattern) and url.match(other): elif isinstance(url, responses.Pattern) and url.match(other):
return True return True
else: else:
return False return False
def not_implemented_callback(request): # pylint: disable=unused-argument def not_implemented_callback(
request: Any, # pylint: disable=unused-argument
) -> TYPE_RESPONSE:
status = 400 status = 400
headers = {} headers: Dict[str, str] = {}
response = "The method is not implemented" response = "The method is not implemented"
return status, headers, response return status, headers, response
@ -106,10 +112,12 @@ def not_implemented_callback(request): # pylint: disable=unused-argument
# #
# Note that, due to an outdated API we no longer support Responses <= 0.12.1 # Note that, due to an outdated API we no longer support Responses <= 0.12.1
# This method should be used for Responses 0.12.1 < .. < 0.17.0 # This method should be used for Responses 0.12.1 < .. < 0.17.0
def _find_first_match(self, request): def _find_first_match(
self: Any, request: Any
) -> Tuple[Optional[responses.BaseResponse], List[str]]:
matches = [] matches = []
match_failed_reasons = [] match_failed_reasons = []
all_possibles = self._matches + responses._default_mock._matches all_possibles = self._matches + responses._default_mock._matches # type: ignore[attr-defined]
for match in all_possibles: for match in all_possibles:
match_result, reason = match.matches(request) match_result, reason = match.matches(request)
if match_result: if match_result:
@ -132,7 +140,7 @@ def _find_first_match(self, request):
return None, match_failed_reasons return None, match_failed_reasons
def get_response_mock(): def get_response_mock() -> responses.RequestsMock:
""" """
The responses-library is crucial in ensuring that requests to AWS are intercepted, and routed to the right backend. The responses-library is crucial in ensuring that requests to AWS are intercepted, and routed to the right backend.
However, as our usecase is different from a 'typical' responses-user, Moto always needs some custom logic to ensure responses behaves in a way that works for us. However, as our usecase is different from a 'typical' responses-user, Moto always needs some custom logic to ensure responses behaves in a way that works for us.
@ -152,13 +160,13 @@ def get_response_mock():
) )
else: else:
responses_mock = responses.RequestsMock(assert_all_requests_are_fired=False) responses_mock = responses.RequestsMock(assert_all_requests_are_fired=False)
responses_mock._find_match = types.MethodType(_find_first_match, responses_mock) responses_mock._find_match = types.MethodType(_find_first_match, responses_mock) # type: ignore[assignment]
responses_mock.add_passthru("http") responses_mock.add_passthru("http")
return responses_mock return responses_mock
def reset_responses_mock(responses_mock): def reset_responses_mock(responses_mock: responses.RequestsMock) -> None:
if LooseVersion(RESPONSES_VERSION) >= LooseVersion("0.17.0"): if LooseVersion(RESPONSES_VERSION) >= LooseVersion("0.17.0"):
from .responses_custom_registry import CustomRegistry from .responses_custom_registry import CustomRegistry

View File

@ -1,6 +1,6 @@
from werkzeug.exceptions import HTTPException from werkzeug.exceptions import HTTPException
from jinja2 import DictLoader, Environment from jinja2 import DictLoader, Environment
from typing import Any, Optional from typing import Any, List, Tuple, Optional
import json import json
# TODO: add "<Type>Sender</Type>" to error responses below? # TODO: add "<Type>Sender</Type>" to error responses below?
@ -67,14 +67,18 @@ class RESTError(HTTPException):
) )
self.content_type = "application/xml" self.content_type = "application/xml"
def get_headers(self, *args, **kwargs): # pylint: disable=unused-argument def get_headers(
return { self, *args: Any, **kwargs: Any # pylint: disable=unused-argument
"X-Amzn-ErrorType": self.error_type or "UnknownError", ) -> List[Tuple[str, str]]:
"Content-Type": self.content_type, return [
} ("X-Amzn-ErrorType", self.error_type or "UnknownError"),
("Content-Type", self.content_type),
]
def get_body(self, *args, **kwargs): # pylint: disable=unused-argument def get_body(
return self.description self, *args: Any, **kwargs: Any # pylint: disable=unused-argument
) -> str:
return self.description # type: ignore[return-value]
class DryRunClientError(RESTError): class DryRunClientError(RESTError):
@ -86,19 +90,19 @@ class JsonRESTError(RESTError):
self, error_type: str, message: str, template: str = "error_json", **kwargs: Any self, error_type: str, message: str, template: str = "error_json", **kwargs: Any
): ):
super().__init__(error_type, message, template, **kwargs) super().__init__(error_type, message, template, **kwargs)
self.description = json.dumps( self.description: str = json.dumps(
{"__type": self.error_type, "message": self.message} {"__type": self.error_type, "message": self.message}
) )
self.content_type = "application/json" self.content_type = "application/json"
def get_body(self, *args, **kwargs) -> str: def get_body(self, *args: Any, **kwargs: Any) -> str:
return self.description return self.description
class SignatureDoesNotMatchError(RESTError): class SignatureDoesNotMatchError(RESTError):
code = 403 code = 403
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
"SignatureDoesNotMatch", "SignatureDoesNotMatch",
"The request signature we calculated does not match the signature you provided. Check your AWS Secret Access Key and signing method. Consult the service documentation for details.", "The request signature we calculated does not match the signature you provided. Check your AWS Secret Access Key and signing method. Consult the service documentation for details.",
@ -108,7 +112,7 @@ class SignatureDoesNotMatchError(RESTError):
class InvalidClientTokenIdError(RESTError): class InvalidClientTokenIdError(RESTError):
code = 403 code = 403
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
"InvalidClientTokenId", "InvalidClientTokenId",
"The security token included in the request is invalid.", "The security token included in the request is invalid.",
@ -118,7 +122,7 @@ class InvalidClientTokenIdError(RESTError):
class AccessDeniedError(RESTError): class AccessDeniedError(RESTError):
code = 403 code = 403
def __init__(self, user_arn, action): def __init__(self, user_arn: str, action: str):
super().__init__( super().__init__(
"AccessDenied", "AccessDenied",
"User: {user_arn} is not authorized to perform: {operation}".format( "User: {user_arn} is not authorized to perform: {operation}".format(
@ -130,7 +134,7 @@ class AccessDeniedError(RESTError):
class AuthFailureError(RESTError): class AuthFailureError(RESTError):
code = 401 code = 401
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
"AuthFailure", "AuthFailure",
"AWS was not able to validate the provided access credentials", "AWS was not able to validate the provided access credentials",
@ -142,9 +146,12 @@ class AWSError(JsonRESTError):
STATUS = 400 STATUS = 400
def __init__( def __init__(
self, message: str, exception_type: str = None, status: Optional[int] = None self,
message: str,
exception_type: Optional[str] = None,
status: Optional[int] = None,
): ):
super().__init__(exception_type or self.TYPE, message) super().__init__(exception_type or self.TYPE, message) # type: ignore[arg-type]
self.code = status or self.STATUS self.code = status or self.STATUS
@ -153,7 +160,7 @@ class InvalidNextTokenException(JsonRESTError):
code = 400 code = 400
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
"InvalidNextTokenException", "The nextToken provided is invalid" "InvalidNextTokenException", "The nextToken provided is invalid"
) )
@ -162,5 +169,5 @@ class InvalidNextTokenException(JsonRESTError):
class InvalidToken(AWSError): class InvalidToken(AWSError):
code = 400 code = 400
def __init__(self, message="Invalid token"): def __init__(self, message: str = "Invalid token"):
super().__init__("Invalid Token: {}".format(message), "InvalidToken") super().__init__("Invalid Token: {}".format(message), "InvalidToken")

View File

@ -5,6 +5,7 @@ import os
import re import re
import unittest import unittest
from types import FunctionType from types import FunctionType
from typing import Any, Callable, Dict, Optional, Set, TypeVar
from unittest.mock import patch from unittest.mock import patch
import boto3 import boto3
@ -14,7 +15,7 @@ from botocore.config import Config
from botocore.handlers import BUILTIN_HANDLERS from botocore.handlers import BUILTIN_HANDLERS
from moto import settings from moto import settings
from moto.core.base_backend import BackendDict from .base_backend import BackendDict
from .botocore_stubber import BotocoreStubber from .botocore_stubber import BotocoreStubber
from .custom_responses_mock import ( from .custom_responses_mock import (
get_response_mock, get_response_mock,
@ -24,13 +25,14 @@ from .custom_responses_mock import (
) )
DEFAULT_ACCOUNT_ID = "123456789012" DEFAULT_ACCOUNT_ID = "123456789012"
CALLABLE_RETURN = TypeVar("CALLABLE_RETURN")
class BaseMockAWS: class BaseMockAWS:
nested_count = 0 nested_count = 0
mocks_active = False mocks_active = False
def __init__(self, backends): def __init__(self, backends: BackendDict):
from moto.instance_metadata import instance_metadata_backends from moto.instance_metadata import instance_metadata_backends
from moto.moto_api._internal.models import moto_api_backend from moto.moto_api._internal.models import moto_api_backend
@ -55,25 +57,25 @@ class BaseMockAWS:
"AWS_ACCESS_KEY_ID": "foobar_key", "AWS_ACCESS_KEY_ID": "foobar_key",
"AWS_SECRET_ACCESS_KEY": "foobar_secret", "AWS_SECRET_ACCESS_KEY": "foobar_secret",
} }
self.ORIG_KEYS = {} self.ORIG_KEYS: Dict[str, Optional[str]] = {}
self.default_session_mock = patch("boto3.DEFAULT_SESSION", None) self.default_session_mock = patch("boto3.DEFAULT_SESSION", None)
if self.__class__.nested_count == 0: if self.__class__.nested_count == 0:
self.reset() self.reset() # type: ignore[attr-defined]
def __call__(self, func, reset=True): def __call__(self, func: Callable[..., Any], reset: bool = True) -> Any:
if inspect.isclass(func): if inspect.isclass(func):
return self.decorate_class(func) return self.decorate_class(func)
return self.decorate_callable(func, reset) return self.decorate_callable(func, reset)
def __enter__(self): def __enter__(self) -> "BaseMockAWS":
self.start() self.start()
return self return self
def __exit__(self, *args): def __exit__(self, *args: Any) -> None:
self.stop() self.stop()
def start(self, reset=True): def start(self, reset: bool = True) -> None:
if not self.__class__.mocks_active: if not self.__class__.mocks_active:
self.default_session_mock.start() self.default_session_mock.start()
self.mock_env_variables() self.mock_env_variables()
@ -84,9 +86,9 @@ class BaseMockAWS:
for backend in self.backends.values(): for backend in self.backends.values():
backend.reset() backend.reset()
self.enable_patching(reset) self.enable_patching(reset) # type: ignore[attr-defined]
def stop(self): def stop(self) -> None:
self.__class__.nested_count -= 1 self.__class__.nested_count -= 1
if self.__class__.nested_count < 0: if self.__class__.nested_count < 0:
@ -102,10 +104,12 @@ class BaseMockAWS:
pass pass
self.unmock_env_variables() self.unmock_env_variables()
self.__class__.mocks_active = False self.__class__.mocks_active = False
self.disable_patching() self.disable_patching() # type: ignore[attr-defined]
def decorate_callable(self, func, reset): def decorate_callable(
def wrapper(*args, **kwargs): self, func: Callable[..., CALLABLE_RETURN], reset: bool
) -> Callable[..., CALLABLE_RETURN]:
def wrapper(*args: Any, **kwargs: Any) -> CALLABLE_RETURN:
self.start(reset=reset) self.start(reset=reset)
try: try:
result = func(*args, **kwargs) result = func(*args, **kwargs)
@ -114,10 +118,10 @@ class BaseMockAWS:
return result return result
functools.update_wrapper(wrapper, func) functools.update_wrapper(wrapper, func)
wrapper.__wrapped__ = func wrapper.__wrapped__ = func # type: ignore[attr-defined]
return wrapper return wrapper
def decorate_class(self, klass): def decorate_class(self, klass: type) -> object:
direct_methods = get_direct_methods_of(klass) direct_methods = get_direct_methods_of(klass)
defined_classes = set( defined_classes = set(
x for x, y in klass.__dict__.items() if inspect.isclass(y) x for x, y in klass.__dict__.items() if inspect.isclass(y)
@ -181,7 +185,7 @@ class BaseMockAWS:
continue continue
return klass return klass
def mock_env_variables(self): def mock_env_variables(self) -> None:
# "Mock" the AWS credentials as they can't be mocked in Botocore currently # "Mock" the AWS credentials as they can't be mocked in Botocore currently
# self.env_variables_mocks = mock.patch.dict(os.environ, FAKE_KEYS) # self.env_variables_mocks = mock.patch.dict(os.environ, FAKE_KEYS)
# self.env_variables_mocks.start() # self.env_variables_mocks.start()
@ -189,7 +193,7 @@ class BaseMockAWS:
self.ORIG_KEYS[k] = os.environ.get(k, None) self.ORIG_KEYS[k] = os.environ.get(k, None)
os.environ[k] = v os.environ[k] = v
def unmock_env_variables(self): def unmock_env_variables(self) -> None:
# This doesn't work in Python2 - for some reason, unmocking clears the entire os.environ dict # This doesn't work in Python2 - for some reason, unmocking clears the entire os.environ dict
# Obviously bad user experience, and also breaks pytest - as it uses PYTEST_CURRENT_TEST as an env var # Obviously bad user experience, and also breaks pytest - as it uses PYTEST_CURRENT_TEST as an env var
# self.env_variables_mocks.stop() # self.env_variables_mocks.stop()
@ -200,7 +204,7 @@ class BaseMockAWS:
del os.environ[k] del os.environ[k]
def get_direct_methods_of(klass): def get_direct_methods_of(klass: object) -> Set[str]:
return set( return set(
x x
for x, y in klass.__dict__.items() for x, y in klass.__dict__.items()
@ -232,7 +236,7 @@ botocore_stubber = BotocoreStubber()
BUILTIN_HANDLERS.append(("before-send", botocore_stubber)) BUILTIN_HANDLERS.append(("before-send", botocore_stubber))
def patch_client(client): def patch_client(client: botocore.client.BaseClient) -> None:
""" """
Explicitly patch a boto3-client Explicitly patch a boto3-client
""" """
@ -254,7 +258,7 @@ def patch_client(client):
raise Exception(f"Argument {client} should be of type boto3.client") raise Exception(f"Argument {client} should be of type boto3.client")
def patch_resource(resource): def patch_resource(resource: Any) -> None:
""" """
Explicitly patch a boto3-resource Explicitly patch a boto3-resource
""" """
@ -267,11 +271,13 @@ def patch_resource(resource):
class BotocoreEventMockAWS(BaseMockAWS): class BotocoreEventMockAWS(BaseMockAWS):
def reset(self): def reset(self) -> None:
botocore_stubber.reset() botocore_stubber.reset()
reset_responses_mock(responses_mock) reset_responses_mock(responses_mock)
def enable_patching(self, reset=True): # pylint: disable=unused-argument def enable_patching(
self, reset: bool = True # pylint: disable=unused-argument
) -> None:
# Circumvent circular imports # Circumvent circular imports
from .utils import convert_flask_to_responses_response from .utils import convert_flask_to_responses_response
@ -313,7 +319,7 @@ class BotocoreEventMockAWS(BaseMockAWS):
) )
) )
def disable_patching(self): def disable_patching(self) -> None:
botocore_stubber.enabled = False botocore_stubber.enabled = False
self.reset() self.reset()
@ -330,14 +336,13 @@ class ServerModeMockAWS(BaseMockAWS):
RESET_IN_PROGRESS = False RESET_IN_PROGRESS = False
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any):
self.test_server_mode_endpoint = settings.test_server_mode_endpoint() self.test_server_mode_endpoint = settings.test_server_mode_endpoint()
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def reset(self): def reset(self) -> None:
call_reset_api = os.environ.get("MOTO_CALL_RESET_API") call_reset_api = os.environ.get("MOTO_CALL_RESET_API")
call_reset_api = not call_reset_api or call_reset_api.lower() != "false" if not call_reset_api or call_reset_api.lower() != "false":
if call_reset_api:
if not ServerModeMockAWS.RESET_IN_PROGRESS: if not ServerModeMockAWS.RESET_IN_PROGRESS:
ServerModeMockAWS.RESET_IN_PROGRESS = True ServerModeMockAWS.RESET_IN_PROGRESS = True
import requests import requests
@ -345,14 +350,14 @@ class ServerModeMockAWS(BaseMockAWS):
requests.post(f"{self.test_server_mode_endpoint}/moto-api/reset") requests.post(f"{self.test_server_mode_endpoint}/moto-api/reset")
ServerModeMockAWS.RESET_IN_PROGRESS = False ServerModeMockAWS.RESET_IN_PROGRESS = False
def enable_patching(self, reset=True): def enable_patching(self, reset: bool = True) -> None:
if self.__class__.nested_count == 1 and reset: if self.__class__.nested_count == 1 and reset:
# Just started # Just started
self.reset() self.reset()
from boto3 import client as real_boto3_client, resource as real_boto3_resource from boto3 import client as real_boto3_client, resource as real_boto3_resource
def fake_boto3_client(*args, **kwargs): def fake_boto3_client(*args: Any, **kwargs: Any) -> botocore.client.BaseClient:
region = self._get_region(*args, **kwargs) region = self._get_region(*args, **kwargs)
if region: if region:
if "config" in kwargs: if "config" in kwargs:
@ -364,7 +369,7 @@ class ServerModeMockAWS(BaseMockAWS):
kwargs["endpoint_url"] = self.test_server_mode_endpoint kwargs["endpoint_url"] = self.test_server_mode_endpoint
return real_boto3_client(*args, **kwargs) return real_boto3_client(*args, **kwargs)
def fake_boto3_resource(*args, **kwargs): def fake_boto3_resource(*args: Any, **kwargs: Any) -> Any:
if "endpoint_url" not in kwargs: if "endpoint_url" not in kwargs:
kwargs["endpoint_url"] = self.test_server_mode_endpoint kwargs["endpoint_url"] = self.test_server_mode_endpoint
return real_boto3_resource(*args, **kwargs) return real_boto3_resource(*args, **kwargs)
@ -374,7 +379,7 @@ class ServerModeMockAWS(BaseMockAWS):
self._client_patcher.start() self._client_patcher.start()
self._resource_patcher.start() self._resource_patcher.start()
def _get_region(self, *args, **kwargs): def _get_region(self, *args: Any, **kwargs: Any) -> Optional[str]:
if "region_name" in kwargs: if "region_name" in kwargs:
return kwargs["region_name"] return kwargs["region_name"]
if type(args) == tuple and len(args) == 2: if type(args) == tuple and len(args) == 2:
@ -382,7 +387,7 @@ class ServerModeMockAWS(BaseMockAWS):
return region return region
return None return None
def disable_patching(self): def disable_patching(self) -> None:
if self._client_patcher: if self._client_patcher:
self._client_patcher.stop() self._client_patcher.stop()
self._resource_patcher.stop() self._resource_patcher.stop()
@ -394,9 +399,9 @@ class base_decorator:
def __init__(self, backends: BackendDict): def __init__(self, backends: BackendDict):
self.backends = backends self.backends = backends
def __call__(self, func=None): def __call__(self, func: Optional[Callable[..., Any]] = None) -> BaseMockAWS:
if settings.TEST_SERVER_MODE: if settings.TEST_SERVER_MODE:
mocked_backend = ServerModeMockAWS(self.backends) mocked_backend: BaseMockAWS = ServerModeMockAWS(self.backends)
else: else:
mocked_backend = self.mock_backend(self.backends) mocked_backend = self.mock_backend(self.backends)

View File

@ -11,11 +11,22 @@ import xmltodict
from collections import defaultdict, OrderedDict from collections import defaultdict, OrderedDict
from moto import settings from moto import settings
from moto.core.common_types import TYPE_RESPONSE, TYPE_IF_NONE
from moto.core.exceptions import DryRunClientError from moto.core.exceptions import DryRunClientError
from moto.core.utils import camelcase_to_underscores, method_names_from_class from moto.core.utils import camelcase_to_underscores, method_names_from_class
from moto.utilities.utils import load_resource from moto.utilities.utils import load_resource
from jinja2 import Environment, DictLoader, Template from jinja2 import Environment, DictLoader, Template
from typing import Dict, Union, Any, Tuple, TypeVar from typing import (
Dict,
Union,
Any,
Tuple,
Optional,
List,
Set,
ClassVar,
Callable,
)
from urllib.parse import parse_qs, parse_qsl, urlparse from urllib.parse import parse_qs, parse_qsl, urlparse
from werkzeug.exceptions import HTTPException from werkzeug.exceptions import HTTPException
from xml.dom.minidom import parseString as parseXML from xml.dom.minidom import parseString as parseXML
@ -23,29 +34,19 @@ from xml.dom.minidom import parseString as parseXML
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
JINJA_ENVS = {} JINJA_ENVS: Dict[type, Environment] = {}
TYPE_RESPONSE = Tuple[int, Dict[str, str], str]
TYPE_IF_NONE = TypeVar("TYPE_IF_NONE")
def _decode_dict(d): def _decode_dict(d: Dict[Any, Any]) -> Dict[str, Any]:
decoded = OrderedDict() decoded: Dict[str, Any] = OrderedDict()
for key, value in d.items(): for key, value in d.items():
if isinstance(key, bytes): if isinstance(key, bytes):
newkey = key.decode("utf-8") newkey = key.decode("utf-8")
elif isinstance(key, (list, tuple)):
newkey = []
for k in key:
if isinstance(k, bytes):
newkey.append(k.decode("utf-8"))
else:
newkey.append(k)
else: else:
newkey = key newkey = key
if isinstance(value, bytes): if isinstance(value, bytes):
newvalue = value.decode("utf-8") decoded[newkey] = value.decode("utf-8")
elif isinstance(value, (list, tuple)): elif isinstance(value, (list, tuple)):
newvalue = [] newvalue = []
for v in value: for v in value:
@ -53,18 +54,18 @@ def _decode_dict(d):
newvalue.append(v.decode("utf-8")) newvalue.append(v.decode("utf-8"))
else: else:
newvalue.append(v) newvalue.append(v)
decoded[newkey] = newvalue
else: else:
newvalue = value decoded[newkey] = value
decoded[newkey] = newvalue
return decoded return decoded
class DynamicDictLoader(DictLoader): class DynamicDictLoader(DictLoader):
def update(self, mapping): def update(self, mapping: Dict[str, str]) -> None:
self.mapping.update(mapping) self.mapping.update(mapping) # type: ignore[attr-defined]
def contains(self, template): def contains(self, template: str) -> bool:
return bool(template in self.mapping) return bool(template in self.mapping)
@ -73,12 +74,12 @@ class _TemplateEnvironmentMixin(object):
RIGHT_PATTERN = re.compile(r">[\s\n]+") RIGHT_PATTERN = re.compile(r">[\s\n]+")
@property @property
def should_autoescape(self): def should_autoescape(self) -> bool:
# Allow for subclass to overwrite # Allow for subclass to overwrite
return False return False
@property @property
def environment(self): def environment(self) -> Environment:
key = type(self) key = type(self)
try: try:
environment = JINJA_ENVS[key] environment = JINJA_ENVS[key]
@ -94,11 +95,11 @@ class _TemplateEnvironmentMixin(object):
return environment return environment
def contains_template(self, template_id): def contains_template(self, template_id: str) -> bool:
return self.environment.loader.contains(template_id) return self.environment.loader.contains(template_id) # type: ignore[union-attr]
@classmethod @classmethod
def _make_template_id(cls, source): def _make_template_id(cls, source: str) -> str:
""" """
Return a numeric string that's unique for the lifetime of the source. Return a numeric string that's unique for the lifetime of the source.
@ -117,47 +118,49 @@ class _TemplateEnvironmentMixin(object):
xml = re.sub( xml = re.sub(
self.RIGHT_PATTERN, ">", re.sub(self.LEFT_PATTERN, "<", source) self.RIGHT_PATTERN, ">", re.sub(self.LEFT_PATTERN, "<", source)
) )
self.environment.loader.update({template_id: xml}) self.environment.loader.update({template_id: xml}) # type: ignore[union-attr]
return self.environment.get_template(template_id) return self.environment.get_template(template_id)
class ActionAuthenticatorMixin(object): class ActionAuthenticatorMixin(object):
request_count = 0 request_count: ClassVar[int] = 0
def _authenticate_and_authorize_action(self, iam_request_cls): def _authenticate_and_authorize_action(self, iam_request_cls: type) -> None:
if ( if (
ActionAuthenticatorMixin.request_count ActionAuthenticatorMixin.request_count
>= settings.INITIAL_NO_AUTH_ACTION_COUNT >= settings.INITIAL_NO_AUTH_ACTION_COUNT
): ):
iam_request = iam_request_cls( iam_request = iam_request_cls(
account_id=self.current_account, account_id=self.current_account, # type: ignore[attr-defined]
method=self.method, method=self.method, # type: ignore[attr-defined]
path=self.path, path=self.path, # type: ignore[attr-defined]
data=self.data, data=self.data, # type: ignore[attr-defined]
headers=self.headers, headers=self.headers, # type: ignore[attr-defined]
) )
iam_request.check_signature() iam_request.check_signature()
iam_request.check_action_permitted() iam_request.check_action_permitted()
else: else:
ActionAuthenticatorMixin.request_count += 1 ActionAuthenticatorMixin.request_count += 1
def _authenticate_and_authorize_normal_action(self): def _authenticate_and_authorize_normal_action(self) -> None:
from moto.iam.access_control import IAMRequest from moto.iam.access_control import IAMRequest
self._authenticate_and_authorize_action(IAMRequest) self._authenticate_and_authorize_action(IAMRequest)
def _authenticate_and_authorize_s3_action(self): def _authenticate_and_authorize_s3_action(self) -> None:
from moto.iam.access_control import S3IAMRequest from moto.iam.access_control import S3IAMRequest
self._authenticate_and_authorize_action(S3IAMRequest) self._authenticate_and_authorize_action(S3IAMRequest)
@staticmethod @staticmethod
def set_initial_no_auth_action_count(initial_no_auth_action_count): def set_initial_no_auth_action_count(initial_no_auth_action_count: int) -> Callable[..., Callable[..., TYPE_RESPONSE]]: # type: ignore[misc]
_test_server_mode_endpoint = settings.test_server_mode_endpoint() _test_server_mode_endpoint = settings.test_server_mode_endpoint()
def decorator(function): def decorator(
def wrapper(*args, **kwargs): function: Callable[..., TYPE_RESPONSE]
) -> Callable[..., TYPE_RESPONSE]:
def wrapper(*args: Any, **kwargs: Any) -> TYPE_RESPONSE:
if settings.TEST_SERVER_MODE: if settings.TEST_SERVER_MODE:
response = requests.post( response = requests.post(
f"{_test_server_mode_endpoint}/moto-api/reset-auth", f"{_test_server_mode_endpoint}/moto-api/reset-auth",
@ -191,7 +194,7 @@ class ActionAuthenticatorMixin(object):
return result return result
functools.update_wrapper(wrapper, function) functools.update_wrapper(wrapper, function)
wrapper.__wrapped__ = function wrapper.__wrapped__ = function # type: ignore[attr-defined]
return wrapper return wrapper
return decorator return decorator
@ -213,12 +216,12 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
) )
aws_service_spec = None aws_service_spec = None
def __init__(self, service_name=None) -> None: def __init__(self, service_name: Optional[str] = None):
super().__init__() super().__init__()
self.service_name = service_name self.service_name = service_name
@classmethod @classmethod
def dispatch(cls, *args: Any, **kwargs: Any) -> Any: def dispatch(cls, *args: Any, **kwargs: Any) -> Any: # type: ignore[misc]
return cls()._dispatch(*args, **kwargs) return cls()._dispatch(*args, **kwargs)
def setup_class( def setup_class(
@ -227,7 +230,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
""" """
use_raw_body: Use incoming bytes if True, encode to string otherwise use_raw_body: Use incoming bytes if True, encode to string otherwise
""" """
querystring = OrderedDict() querystring: Dict[str, Any] = OrderedDict()
if hasattr(request, "body"): if hasattr(request, "body"):
# Boto # Boto
self.body = request.body self.body = request.body
@ -292,7 +295,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
self.data = querystring self.data = querystring
self.method = request.method self.method = request.method
self.region = self.get_region_from_url(request, full_url) self.region = self.get_region_from_url(request, full_url)
self.uri_match = None self.uri_match: Optional[re.Match[str]] = None
self.headers = request.headers self.headers = request.headers
if "host" not in self.headers: if "host" not in self.headers:
@ -307,11 +310,11 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
mark_account_as_visited( mark_account_as_visited(
account_id=self.current_account, account_id=self.current_account,
access_key=self.access_key, access_key=self.access_key,
service=self.service_name, service=self.service_name, # type: ignore[arg-type]
region=self.region, region=self.region,
) )
def get_region_from_url(self, request, full_url): def get_region_from_url(self, request: Any, full_url: str) -> str:
url_match = self.region_regex.search(full_url) url_match = self.region_regex.search(full_url)
user_agent_match = self.region_from_useragent_regex.search( user_agent_match = self.region_from_useragent_regex.search(
request.headers.get("User-Agent", "") request.headers.get("User-Agent", "")
@ -329,7 +332,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
region = self.default_region region = self.default_region
return region return region
def get_access_key(self): def get_access_key(self) -> str:
""" """
Returns the access key id used in this request as the current user id Returns the access key id used in this request as the current user id
""" """
@ -339,11 +342,11 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return match.group(1) return match.group(1)
if self.querystring.get("AWSAccessKeyId"): if self.querystring.get("AWSAccessKeyId"):
return self.querystring.get("AWSAccessKeyId")[0] return self.querystring["AWSAccessKeyId"][0]
else: else:
return "AKIAEXAMPLE" return "AKIAEXAMPLE"
def get_current_account(self): def get_current_account(self) -> str:
# PRIO 1: Check if we have a Environment Variable set # PRIO 1: Check if we have a Environment Variable set
if "MOTO_ACCOUNT_ID" in os.environ: if "MOTO_ACCOUNT_ID" in os.environ:
return os.environ["MOTO_ACCOUNT_ID"] return os.environ["MOTO_ACCOUNT_ID"]
@ -358,11 +361,11 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return get_account_id_from(self.get_access_key()) return get_account_id_from(self.get_access_key())
def _dispatch(self, request, full_url, headers): def _dispatch(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
return self.call_action() return self.call_action()
def uri_to_regexp(self, uri): def uri_to_regexp(self, uri: str) -> str:
"""converts uri w/ placeholder to regexp """converts uri w/ placeholder to regexp
'/cars/{carName}/drivers/{DriverName}' '/cars/{carName}/drivers/{DriverName}'
-> '^/cars/.*/drivers/[^/]*$' -> '^/cars/.*/drivers/[^/]*$'
@ -372,7 +375,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
""" """
def _convert(elem, is_last): def _convert(elem: str, is_last: bool) -> str:
if not re.match("^{.*}$", elem): if not re.match("^{.*}$", elem):
return elem return elem
name = ( name = (
@ -394,7 +397,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
) )
return regexp return regexp
def _get_action_from_method_and_request_uri(self, method, request_uri): def _get_action_from_method_and_request_uri(
self, method: str, request_uri: str
) -> str:
"""basically used for `rest-json` APIs """basically used for `rest-json` APIs
You can refer to example from link below You can refer to example from link below
https://github.com/boto/botocore/blob/develop/botocore/data/iot/2015-05-28/service-2.json https://github.com/boto/botocore/blob/develop/botocore/data/iot/2015-05-28/service-2.json
@ -406,7 +411,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
# make cache if it does not exist yet # make cache if it does not exist yet
if not hasattr(self, "method_urls"): if not hasattr(self, "method_urls"):
self.method_urls = defaultdict(lambda: defaultdict(str)) self.method_urls: Dict[str, Dict[str, str]] = defaultdict(
lambda: defaultdict(str)
)
op_names = conn._service_model.operation_names op_names = conn._service_model.operation_names
for op_name in op_names: for op_name in op_names:
op_model = conn._service_model.operation_model(op_name) op_model = conn._service_model.operation_model(op_name)
@ -419,9 +426,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
self.uri_match = match self.uri_match = match
if match: if match:
return name return name
return None return None # type: ignore[return-value]
def _get_action(self): def _get_action(self) -> str:
action = self.querystring.get("Action", [""])[0] action = self.querystring.get("Action", [""])[0]
if action: if action:
return action return action
@ -433,7 +440,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
# get action from method and uri # get action from method and uri
return self._get_action_from_method_and_request_uri(self.method, self.path) return self._get_action_from_method_and_request_uri(self.method, self.path)
def call_action(self): def call_action(self) -> TYPE_RESPONSE:
headers = self.response_headers headers = self.response_headers
try: try:
@ -449,9 +456,11 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
try: try:
response = method() response = method()
except HTTPException as http_error: except HTTPException as http_error:
response_headers = dict(http_error.get_headers() or []) response_headers: Dict[str, Union[str, int]] = dict(
response_headers["status"] = http_error.code http_error.get_headers() or []
response = http_error.description, response_headers )
response_headers["status"] = http_error.code # type: ignore[assignment]
response = http_error.description, response_headers # type: ignore[assignment]
if isinstance(response, str): if isinstance(response, str):
return 200, headers, response return 200, headers, response
@ -466,7 +475,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
) )
@staticmethod @staticmethod
def _send_response(headers, response): def _send_response(headers: Dict[str, str], response: Any) -> Tuple[int, Dict[str, str], str]: # type: ignore[misc]
if response is None: if response is None:
response = "", {} response = "", {}
if len(response) == 2: if len(response) == 2:
@ -480,7 +489,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
headers["status"] = str(headers["status"]) headers["status"] = str(headers["status"])
return status, headers, body return status, headers, body
def _get_param(self, param_name, if_none=None) -> Any: def _get_param(self, param_name: str, if_none: Any = None) -> Any:
val = self.querystring.get(param_name) val = self.querystring.get(param_name)
if val is not None: if val is not None:
return val[0] return val[0]
@ -503,7 +512,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return if_none return if_none
def _get_int_param( def _get_int_param(
self, param_name, if_none: TYPE_IF_NONE = None self, param_name: str, if_none: TYPE_IF_NONE = None # type: ignore[assignment]
) -> Union[int, TYPE_IF_NONE]: ) -> Union[int, TYPE_IF_NONE]:
val = self._get_param(param_name) val = self._get_param(param_name)
if val is not None: if val is not None:
@ -511,7 +520,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return if_none return if_none
def _get_bool_param( def _get_bool_param(
self, param_name, if_none: TYPE_IF_NONE = None self, param_name: str, if_none: TYPE_IF_NONE = None # type: ignore[assignment]
) -> Union[bool, TYPE_IF_NONE]: ) -> Union[bool, TYPE_IF_NONE]:
val = self._get_param(param_name) val = self._get_param(param_name)
if val is not None: if val is not None:
@ -522,12 +531,15 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return False return False
return if_none return if_none
def _get_multi_param_dict(self, param_prefix) -> Dict: def _get_multi_param_dict(self, param_prefix: str) -> Dict[str, Any]:
return self._get_multi_param_helper(param_prefix, skip_result_conversion=True) return self._get_multi_param_helper(param_prefix, skip_result_conversion=True)
def _get_multi_param_helper( def _get_multi_param_helper(
self, param_prefix, skip_result_conversion=False, tracked_prefixes=None self,
): param_prefix: str,
skip_result_conversion: bool = False,
tracked_prefixes: Optional[Set[str]] = None,
) -> Any:
value_dict = dict() value_dict = dict()
tracked_prefixes = ( tracked_prefixes = (
tracked_prefixes or set() tracked_prefixes or set()
@ -589,11 +601,13 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
if len(parts) != 2 or parts[1] != "member": if len(parts) != 2 or parts[1] != "member":
value_dict[parts[0]] = value_dict.pop(k) value_dict[parts[0]] = value_dict.pop(k)
else: else:
value_dict = list(value_dict.values())[0] value_dict = list(value_dict.values())[0] # type: ignore[assignment]
return value_dict return value_dict
def _get_multi_param(self, param_prefix, skip_result_conversion=False) -> Any: def _get_multi_param(
self, param_prefix: str, skip_result_conversion: bool = False
) -> List[Any]:
""" """
Given a querystring of ?LaunchConfigurationNames.member.1=my-test-1&LaunchConfigurationNames.member.2=my-test-2 Given a querystring of ?LaunchConfigurationNames.member.1=my-test-1&LaunchConfigurationNames.member.2=my-test-2
this will return ['my-test-1', 'my-test-2'] this will return ['my-test-1', 'my-test-2']
@ -616,7 +630,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return values return values
def _get_dict_param(self, param_prefix) -> Dict: def _get_dict_param(self, param_prefix: str) -> Dict[str, Any]:
""" """
Given a parameter dict of Given a parameter dict of
{ {
@ -630,7 +644,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
"instance_count": "1", "instance_count": "1",
} }
""" """
params = {} params: Dict[str, Any] = {}
for key, value in self.querystring.items(): for key, value in self.querystring.items():
if key.startswith(param_prefix): if key.startswith(param_prefix):
params[camelcase_to_underscores(key.replace(param_prefix, ""))] = value[ params[camelcase_to_underscores(key.replace(param_prefix, ""))] = value[
@ -638,7 +652,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
] ]
return params return params
def _get_params(self) -> Any: def _get_params(self) -> Dict[str, Any]:
""" """
Given a querystring of Given a querystring of
{ {
@ -674,12 +688,12 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
] ]
} }
""" """
params = {} params: Dict[str, Any] = {}
for k, v in sorted(self.querystring.items()): for k, v in sorted(self.querystring.items()):
self._parse_param(k, v[0], params) self._parse_param(k, v[0], params)
return params return params
def _parse_param(self, key, value, params): def _parse_param(self, key: str, value: str, params: Any) -> None:
keylist = key.split(".") keylist = key.split(".")
obj = params obj = params
for i, key in enumerate(keylist[:-1]): for i, key in enumerate(keylist[:-1]):
@ -713,7 +727,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
else: else:
obj[keylist[-1]] = value obj[keylist[-1]] = value
def _get_list_prefix(self, param_prefix: str) -> Any: def _get_list_prefix(self, param_prefix: str) -> List[Dict[str, Any]]:
""" """
Given a query dict like Given a query dict like
{ {
@ -752,7 +766,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
param_index += 1 param_index += 1
return results return results
def _get_map_prefix(self, param_prefix, key_end=".key", value_end=".value"): def _get_map_prefix(
self, param_prefix: str, key_end: str = ".key", value_end: str = ".value"
) -> Dict[str, Any]:
results = {} results = {}
param_index = 1 param_index = 1
while 1: while 1:
@ -774,7 +790,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return results return results
def _get_object_map(self, prefix, name="Name", value="Value"): def _get_object_map(
self, prefix: str, name: str = "Name", value: str = "Value"
) -> Dict[str, Any]:
""" """
Given a query dict like Given a query dict like
{ {
@ -822,19 +840,16 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return object_map return object_map
@property @property
def request_json(self): def request_json(self) -> bool:
return "JSON" in self.querystring.get("ContentType", []) return "JSON" in self.querystring.get("ContentType", [])
def error_on_dryrun(self): def error_on_dryrun(self) -> None:
self.is_not_dryrun() self.is_not_dryrun()
def is_not_dryrun(self, action=None): def is_not_dryrun(self, action: Optional[str] = None) -> bool:
action = action or self._get_param("Action")
if "true" in self.querystring.get("DryRun", ["false"]): if "true" in self.querystring.get("DryRun", ["false"]):
message = ( a = action or self._get_param("Action")
"An error occurred (DryRunOperation) when calling the %s operation: Request would have succeeded, but DryRun flag is set" message = f"An error occurred (DryRunOperation) when calling the {a} operation: Request would have succeeded, but DryRun flag is set"
% action
)
raise DryRunClientError(error_type="DryRunOperation", message=message) raise DryRunClientError(error_type="DryRunOperation", message=message)
return True return True
@ -842,20 +857,20 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
class _RecursiveDictRef(object): class _RecursiveDictRef(object):
"""Store a recursive reference to dict.""" """Store a recursive reference to dict."""
def __init__(self): def __init__(self) -> None:
self.key = None self.key: Optional[str] = None
self.dic = {} self.dic: Dict[str, Any] = {}
def __repr__(self): def __repr__(self) -> str:
return "{!r}".format(self.dic) return "{!r}".format(self.dic)
def __getattr__(self, key): def __getattr__(self, key: str) -> Any:
return self.dic.__getattr__(key) return self.dic.__getattr__(key) # type: ignore[attr-defined]
def __getitem__(self, key): def __getitem__(self, key: str) -> Any:
return self.dic.__getitem__(key) return self.dic.__getitem__(key)
def set_reference(self, key, dic): def set_reference(self, key: str, dic: Dict[str, Any]) -> None:
"""Set the RecursiveDictRef object to keep reference to dict object """Set the RecursiveDictRef object to keep reference to dict object
(dic) at the key. (dic) at the key.
@ -877,7 +892,7 @@ class AWSServiceSpec(object):
self.operations = spec["operations"] self.operations = spec["operations"]
self.shapes = spec["shapes"] self.shapes = spec["shapes"]
def input_spec(self, operation): def input_spec(self, operation: str) -> Dict[str, Any]:
try: try:
op = self.operations[operation] op = self.operations[operation]
except KeyError: except KeyError:
@ -887,7 +902,7 @@ class AWSServiceSpec(object):
shape = self.shapes[op["input"]["shape"]] shape = self.shapes[op["input"]["shape"]]
return self._expand(shape) return self._expand(shape)
def output_spec(self, operation): def output_spec(self, operation: str) -> Dict[str, Any]:
"""Produce a JSON with a valid API response syntax for operation, but """Produce a JSON with a valid API response syntax for operation, but
with type information. Each node represented by a key has the with type information. Each node represented by a key has the
value containing field type, e.g., value containing field type, e.g.,
@ -904,11 +919,13 @@ class AWSServiceSpec(object):
shape = self.shapes[op["output"]["shape"]] shape = self.shapes[op["output"]["shape"]]
return self._expand(shape) return self._expand(shape)
def _expand(self, shape): def _expand(self, shape: Dict[str, Any]) -> Dict[str, Any]:
def expand(dic, seen=None): def expand(
dic: Dict[str, Any], seen: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
seen = seen or {} seen = seen or {}
if dic["type"] == "structure": if dic["type"] == "structure":
nodes = {} nodes: Dict[str, Any] = {}
for k, v in dic["members"].items(): for k, v in dic["members"].items():
seen_till_here = dict(seen) seen_till_here = dict(seen)
if k in seen_till_here: if k in seen_till_here:
@ -932,7 +949,7 @@ class AWSServiceSpec(object):
elif dic["type"] == "map": elif dic["type"] == "map":
seen_till_here = dict(seen) seen_till_here = dict(seen)
node = {"type": "map"} node: Dict[str, Any] = {"type": "map"}
if "shape" in dic["key"]: if "shape" in dic["key"]:
shape = dic["key"]["shape"] shape = dic["key"]["shape"]
@ -958,12 +975,12 @@ class AWSServiceSpec(object):
return expand(shape) return expand(shape)
def to_str(value, spec): def to_str(value: Any, spec: Dict[str, Any]) -> str:
vtype = spec["type"] vtype = spec["type"]
if vtype == "boolean": if vtype == "boolean":
return "true" if value else "false" return "true" if value else "false"
elif vtype == "long": elif vtype == "long":
return int(value) return int(value) # type: ignore[return-value]
elif vtype == "integer": elif vtype == "integer":
return str(value) return str(value)
elif vtype == "float": elif vtype == "float":
@ -984,7 +1001,7 @@ def to_str(value, spec):
raise TypeError("Unknown type {}".format(vtype)) raise TypeError("Unknown type {}".format(vtype))
def from_str(value, spec): def from_str(value: str, spec: Dict[str, Any]) -> Any:
vtype = spec["type"] vtype = spec["type"]
if vtype == "boolean": if vtype == "boolean":
return True if value == "true" else False return True if value == "true" else False
@ -1001,7 +1018,9 @@ def from_str(value, spec):
raise TypeError("Unknown type {}".format(vtype)) raise TypeError("Unknown type {}".format(vtype))
def flatten_json_request_body(prefix, dict_body, spec): def flatten_json_request_body(
prefix: str, dict_body: Dict[str, Any], spec: Dict[str, Any]
) -> Dict[str, Any]:
"""Convert a JSON request body into query params.""" """Convert a JSON request body into query params."""
if len(spec) == 1 and "type" in spec: if len(spec) == 1 and "type" in spec:
return {prefix: to_str(dict_body, spec)} return {prefix: to_str(dict_body, spec)}
@ -1030,10 +1049,12 @@ def flatten_json_request_body(prefix, dict_body, spec):
return dict((prefix + k, v) for k, v in flat.items()) return dict((prefix + k, v) for k, v in flat.items())
def xml_to_json_response(service_spec, operation, xml, result_node=None): def xml_to_json_response(
service_spec: Any, operation: str, xml: str, result_node: Any = None
) -> Dict[str, Any]:
"""Convert rendered XML response to JSON for use with boto3.""" """Convert rendered XML response to JSON for use with boto3."""
def transform(value, spec): def transform(value: Any, spec: Dict[str, Any]) -> Any:
"""Apply transformations to make the output JSON comply with the """Apply transformations to make the output JSON comply with the
expected form. This function applies: expected form. This function applies:
@ -1047,7 +1068,7 @@ def xml_to_json_response(service_spec, operation, xml, result_node=None):
if len(spec) == 1: if len(spec) == 1:
return from_str(value, spec) return from_str(value, spec)
od = OrderedDict() od: Dict[str, Any] = OrderedDict()
for k, v in value.items(): for k, v in value.items():
if k.startswith("@"): if k.startswith("@"):
continue continue
@ -1099,7 +1120,7 @@ def xml_to_json_response(service_spec, operation, xml, result_node=None):
for k in result_node or (operation + "Response", operation + "Result"): for k in result_node or (operation + "Response", operation + "Result"):
dic = dic[k] dic = dic[k]
except KeyError: except KeyError:
return None return None # type: ignore[return-value]
else: else:
return transform(dic, output_spec) return transform(dic, output_spec)
return None return None

View File

@ -1,5 +1,6 @@
# This will only exist in responses >= 0.17 # This will only exist in responses >= 0.17
import responses import responses
from typing import Any, List, Tuple, Optional
from .custom_responses_mock import CallbackResponse, not_implemented_callback from .custom_responses_mock import CallbackResponse, not_implemented_callback
@ -10,11 +11,12 @@ class CustomRegistry(responses.registries.FirstMatchRegistry):
- CallbackResponses are not discarded after first use - users can mock the same URL as often as they like - CallbackResponses are not discarded after first use - users can mock the same URL as often as they like
""" """
def add(self, response): def add(self, response: responses.BaseResponse) -> responses.BaseResponse:
if response not in self.registered: if response not in self.registered:
super().add(response) super().add(response)
return response
def find(self, request): def find(self, request: Any) -> Tuple[Optional[responses.BaseResponse], List[str]]:
all_possibles = responses._default_mock._registry.registered + self.registered all_possibles = responses._default_mock._registry.registered + self.registered
found = [] found = []
match_failed_reasons = [] match_failed_reasons = []

View File

@ -2,11 +2,12 @@ import datetime
import inspect import inspect
import re import re
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
from typing import Optional from typing import Any, Optional, List, Callable, Dict
from urllib.parse import urlparse from urllib.parse import urlparse
from .common_types import TYPE_RESPONSE
def camelcase_to_underscores(argument: Optional[str]) -> str: def camelcase_to_underscores(argument: str) -> str:
"""Converts a camelcase param like theNewAttribute to the equivalent """Converts a camelcase param like theNewAttribute to the equivalent
python underscore variable like the_new_attribute""" python underscore variable like the_new_attribute"""
result = "" result = ""
@ -32,7 +33,7 @@ def camelcase_to_underscores(argument: Optional[str]) -> str:
return result return result
def underscores_to_camelcase(argument): def underscores_to_camelcase(argument: str) -> str:
"""Converts a camelcase param like the_new_attribute to the equivalent """Converts a camelcase param like the_new_attribute to the equivalent
camelcase version like theNewAttribute. Note that the first letter is camelcase version like theNewAttribute. Note that the first letter is
NOT capitalized by this function""" NOT capitalized by this function"""
@ -48,17 +49,17 @@ def underscores_to_camelcase(argument):
return result return result
def pascal_to_camelcase(argument): def pascal_to_camelcase(argument: str) -> str:
"""Converts a PascalCase param to the camelCase equivalent""" """Converts a PascalCase param to the camelCase equivalent"""
return argument[0].lower() + argument[1:] return argument[0].lower() + argument[1:]
def camelcase_to_pascal(argument): def camelcase_to_pascal(argument: str) -> str:
"""Converts a camelCase param to the PascalCase equivalent""" """Converts a camelCase param to the PascalCase equivalent"""
return argument[0].upper() + argument[1:] return argument[0].upper() + argument[1:]
def method_names_from_class(clazz): def method_names_from_class(clazz: object) -> List[str]:
predicate = inspect.isfunction predicate = inspect.isfunction
return [x[0] for x in inspect.getmembers(clazz, predicate=predicate)] return [x[0] for x in inspect.getmembers(clazz, predicate=predicate)]
@ -70,7 +71,7 @@ def convert_regex_to_flask_path(url_path: str) -> str:
for token in ["$"]: for token in ["$"]:
url_path = url_path.replace(token, "") url_path = url_path.replace(token, "")
def caller(reg): def caller(reg: Any) -> str:
match_name, match_pattern = reg.groups() match_name, match_pattern = reg.groups()
return '<regex("{0}"):{1}>'.format(match_pattern, match_name) return '<regex("{0}"):{1}>'.format(match_pattern, match_name)
@ -83,11 +84,11 @@ def convert_regex_to_flask_path(url_path: str) -> str:
class convert_to_flask_response(object): class convert_to_flask_response(object):
def __init__(self, callback): def __init__(self, callback: Callable[..., Any]):
self.callback = callback self.callback = callback
@property @property
def __name__(self): def __name__(self) -> str:
# For instance methods, use class and method names. Otherwise # For instance methods, use class and method names. Otherwise
# use module and method name # use module and method name
if inspect.ismethod(self.callback): if inspect.ismethod(self.callback):
@ -96,7 +97,7 @@ class convert_to_flask_response(object):
outer = self.callback.__module__ outer = self.callback.__module__
return "{0}.{1}".format(outer, self.callback.__name__) return "{0}.{1}".format(outer, self.callback.__name__)
def __call__(self, args=None, **kwargs): def __call__(self, args: Any = None, **kwargs: Any) -> Any:
from flask import request, Response from flask import request, Response
from moto.moto_api import recorder from moto.moto_api import recorder
@ -118,11 +119,11 @@ class convert_to_flask_response(object):
class convert_flask_to_responses_response(object): class convert_flask_to_responses_response(object):
def __init__(self, callback): def __init__(self, callback: Callable[..., Any]):
self.callback = callback self.callback = callback
@property @property
def __name__(self): def __name__(self) -> str:
# For instance methods, use class and method names. Otherwise # For instance methods, use class and method names. Otherwise
# use module and method name # use module and method name
if inspect.ismethod(self.callback): if inspect.ismethod(self.callback):
@ -131,7 +132,7 @@ class convert_flask_to_responses_response(object):
outer = self.callback.__module__ outer = self.callback.__module__
return "{0}.{1}".format(outer, self.callback.__name__) return "{0}.{1}".format(outer, self.callback.__name__)
def __call__(self, request, *args, **kwargs): def __call__(self, request: Any, *args: Any, **kwargs: Any) -> TYPE_RESPONSE:
for key, val in request.headers.items(): for key, val in request.headers.items():
if isinstance(val, bytes): if isinstance(val, bytes):
request.headers[key] = val.decode("utf-8") request.headers[key] = val.decode("utf-8")
@ -141,7 +142,7 @@ class convert_flask_to_responses_response(object):
return status, headers, response return status, headers, response
def iso_8601_datetime_with_milliseconds(value: datetime) -> str: def iso_8601_datetime_with_milliseconds(value: datetime.datetime) -> str:
return value.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z" return value.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
@ -163,22 +164,22 @@ def iso_8601_datetime_without_milliseconds_s3(
RFC1123 = "%a, %d %b %Y %H:%M:%S GMT" RFC1123 = "%a, %d %b %Y %H:%M:%S GMT"
def rfc_1123_datetime(src): def rfc_1123_datetime(src: datetime.datetime) -> str:
return src.strftime(RFC1123) return src.strftime(RFC1123)
def str_to_rfc_1123_datetime(value): def str_to_rfc_1123_datetime(value: str) -> datetime.datetime:
return datetime.datetime.strptime(value, RFC1123) return datetime.datetime.strptime(value, RFC1123)
def unix_time(dt: datetime.datetime = None) -> int: def unix_time(dt: Optional[datetime.datetime] = None) -> float:
dt = dt or datetime.datetime.utcnow() dt = dt or datetime.datetime.utcnow()
epoch = datetime.datetime.utcfromtimestamp(0) epoch = datetime.datetime.utcfromtimestamp(0)
delta = dt - epoch delta = dt - epoch
return (delta.days * 86400) + (delta.seconds + (delta.microseconds / 1e6)) return (delta.days * 86400) + (delta.seconds + (delta.microseconds / 1e6))
def unix_time_millis(dt: datetime = None) -> int: def unix_time_millis(dt: Optional[datetime.datetime] = None) -> float:
return unix_time(dt) * 1000.0 return unix_time(dt) * 1000.0
@ -193,28 +194,33 @@ def path_url(url: str) -> str:
def tags_from_query_string( def tags_from_query_string(
querystring_dict, prefix="Tag", key_suffix="Key", value_suffix="Value" querystring_dict: Dict[str, Any],
): prefix: str = "Tag",
key_suffix: str = "Key",
value_suffix: str = "Value",
) -> Dict[str, str]:
response_values = {} response_values = {}
for key in querystring_dict.keys(): for key in querystring_dict.keys():
if key.startswith(prefix) and key.endswith(key_suffix): if key.startswith(prefix) and key.endswith(key_suffix):
tag_index = key.replace(prefix + ".", "").replace("." + key_suffix, "") tag_index = key.replace(prefix + ".", "").replace("." + key_suffix, "")
tag_key = querystring_dict.get( tag_key = querystring_dict[
"{prefix}.{index}.{key_suffix}".format( "{prefix}.{index}.{key_suffix}".format(
prefix=prefix, index=tag_index, key_suffix=key_suffix prefix=prefix, index=tag_index, key_suffix=key_suffix
) )
)[0] ][0]
tag_value_key = "{prefix}.{index}.{value_suffix}".format( tag_value_key = "{prefix}.{index}.{value_suffix}".format(
prefix=prefix, index=tag_index, value_suffix=value_suffix prefix=prefix, index=tag_index, value_suffix=value_suffix
) )
if tag_value_key in querystring_dict: if tag_value_key in querystring_dict:
response_values[tag_key] = querystring_dict.get(tag_value_key)[0] response_values[tag_key] = querystring_dict[tag_value_key][0]
else: else:
response_values[tag_key] = None response_values[tag_key] = None
return response_values return response_values
def tags_from_cloudformation_tags_list(tags_list): def tags_from_cloudformation_tags_list(
tags_list: List[Dict[str, str]]
) -> Dict[str, str]:
"""Return tags in dict form from cloudformation resource tags form (list of dicts)""" """Return tags in dict form from cloudformation resource tags form (list of dicts)"""
tags = {} tags = {}
for entry in tags_list: for entry in tags_list:
@ -225,7 +231,7 @@ def tags_from_cloudformation_tags_list(tags_list):
return tags return tags
def remap_nested_keys(root, key_transform): def remap_nested_keys(root: Any, key_transform: Callable[[str], str]) -> Any:
"""This remap ("recursive map") function is used to traverse and """This remap ("recursive map") function is used to traverse and
transform the dictionary keys of arbitrarily nested structures. transform the dictionary keys of arbitrarily nested structures.
List comprehensions do not recurse, making it tedious to apply List comprehensions do not recurse, making it tedious to apply
@ -252,7 +258,9 @@ def remap_nested_keys(root, key_transform):
return root return root
def merge_dicts(dict1, dict2, remove_nulls=False): def merge_dicts(
dict1: Dict[str, Any], dict2: Dict[str, Any], remove_nulls: bool = False
) -> None:
"""Given two arbitrarily nested dictionaries, merge the second dict into the first. """Given two arbitrarily nested dictionaries, merge the second dict into the first.
:param dict dict1: the dictionary to be updated. :param dict dict1: the dictionary to be updated.
@ -275,7 +283,7 @@ def merge_dicts(dict1, dict2, remove_nulls=False):
dict1.pop(key) dict1.pop(key)
def aws_api_matches(pattern, string): def aws_api_matches(pattern: str, string: str) -> bool:
""" """
AWS API can match a value based on a glob, or an exact match AWS API can match a value based on a glob, or an exact match
""" """
@ -296,7 +304,7 @@ def aws_api_matches(pattern, string):
return False return False
def extract_region_from_aws_authorization(string): def extract_region_from_aws_authorization(string: str) -> Optional[str]:
auth = string or "" auth = string or ""
region = re.sub(r".*Credential=[^/]+/[^/]+/([^/]+)/.*", r"\1", auth) region = re.sub(r".*Credential=[^/]+/[^/]+/([^/]+)/.*", r"\1", auth)
if region == auth: if region == auth:

View File

@ -18,7 +18,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy] [mypy]
files= moto/a*,moto/b*,moto/ce,moto/cloudformation,moto/cloudfront,moto/cloudtrail,moto/codebuild,moto/cloudwatch,moto/codepipeline,moto/codecommit,moto/cognito*,moto/comprehend,moto/config,moto/core/base_backend.py files= moto/a*,moto/b*,moto/c*
show_column_numbers=True show_column_numbers=True
show_error_codes = True show_error_codes = True
disable_error_code=abstract disable_error_code=abstract