TechDebt: MyPy Core (#5653)
This commit is contained in:
parent
8c9838cc8c
commit
96b2eff1bc
@ -2,15 +2,17 @@ from collections import defaultdict
|
||||
from io import BytesIO
|
||||
from botocore.awsrequest import AWSResponse
|
||||
from moto.core.exceptions import HTTPException
|
||||
from typing import Any, Dict, Callable, List, Tuple, Union, Pattern
|
||||
from .responses import TYPE_RESPONSE
|
||||
|
||||
|
||||
class MockRawResponse(BytesIO):
|
||||
def __init__(self, response_input):
|
||||
def __init__(self, response_input: Union[str, bytes]):
|
||||
if isinstance(response_input, str):
|
||||
response_input = response_input.encode("utf-8")
|
||||
super().__init__(response_input)
|
||||
|
||||
def stream(self, **kwargs): # pylint: disable=unused-argument
|
||||
def stream(self, **kwargs: Any) -> Any: # pylint: disable=unused-argument
|
||||
contents = self.read()
|
||||
while contents:
|
||||
yield contents
|
||||
@ -18,18 +20,22 @@ class MockRawResponse(BytesIO):
|
||||
|
||||
|
||||
class BotocoreStubber:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.enabled = False
|
||||
self.methods = defaultdict(list)
|
||||
self.methods: Dict[
|
||||
str, List[Tuple[Pattern[str], Callable[..., TYPE_RESPONSE]]]
|
||||
] = defaultdict(list)
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
self.methods.clear()
|
||||
|
||||
def register_response(self, method, pattern, response):
|
||||
def register_response(
|
||||
self, method: str, pattern: Pattern[str], response: Callable[..., TYPE_RESPONSE]
|
||||
) -> None:
|
||||
matchers = self.methods[method]
|
||||
matchers.append((pattern, response))
|
||||
|
||||
def __call__(self, event_name, request, **kwargs):
|
||||
def __call__(self, event_name: str, request: Any, **kwargs: Any) -> AWSResponse:
|
||||
if not self.enabled:
|
||||
return None
|
||||
|
||||
@ -41,7 +47,7 @@ class BotocoreStubber:
|
||||
matchers = self.methods.get(request.method)
|
||||
|
||||
base_url = request.url.split("?", 1)[0]
|
||||
for i, (pattern, callback) in enumerate(matchers):
|
||||
for i, (pattern, callback) in enumerate(matchers): # type: ignore[arg-type]
|
||||
if pattern.match(base_url):
|
||||
if found_index is None:
|
||||
found_index = i
|
||||
@ -62,10 +68,10 @@ class BotocoreStubber:
|
||||
)
|
||||
|
||||
except HTTPException as e:
|
||||
status = e.code
|
||||
headers = e.get_headers()
|
||||
status = e.code # type: ignore[assignment]
|
||||
headers = e.get_headers() # type: ignore[assignment]
|
||||
body = e.get_body()
|
||||
body = MockRawResponse(body)
|
||||
response = AWSResponse(request.url, status, headers, body)
|
||||
raw_response = MockRawResponse(body)
|
||||
response = AWSResponse(request.url, status, headers, raw_response)
|
||||
|
||||
return response
|
||||
|
@ -1,12 +1,14 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from .base_backend import InstanceTrackerMeta
|
||||
|
||||
|
||||
class BaseModel(metaclass=InstanceTrackerMeta):
|
||||
def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument
|
||||
def __new__(
|
||||
cls, *args: Any, **kwargs: Any # pylint: disable=unused-argument
|
||||
) -> "BaseModel":
|
||||
instance = super(BaseModel, cls).__new__(cls)
|
||||
cls.instances.append(instance)
|
||||
cls.instances.append(instance) # type: ignore[attr-defined]
|
||||
return instance
|
||||
|
||||
|
||||
@ -37,7 +39,7 @@ class CloudFormationModel(BaseModel):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def create_from_cloudformation_json(
|
||||
def create_from_cloudformation_json( # type: ignore[misc]
|
||||
cls,
|
||||
resource_name: str,
|
||||
cloudformation_json: Dict[str, Any],
|
||||
@ -53,7 +55,7 @@ class CloudFormationModel(BaseModel):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def update_from_cloudformation_json(
|
||||
def update_from_cloudformation_json( # type: ignore[misc]
|
||||
cls,
|
||||
original_resource: Any,
|
||||
new_resource_name: str,
|
||||
@ -70,7 +72,7 @@ class CloudFormationModel(BaseModel):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def delete_from_cloudformation_json(
|
||||
def delete_from_cloudformation_json( # type: ignore[misc]
|
||||
cls,
|
||||
resource_name: str,
|
||||
cloudformation_json: Dict[str, Any],
|
||||
@ -92,7 +94,7 @@ class CloudFormationModel(BaseModel):
|
||||
|
||||
|
||||
class ConfigQueryModel:
|
||||
def __init__(self, backends):
|
||||
def __init__(self, backends: Any):
|
||||
"""Inits based on the resource type's backends (1 for each region if applicable)"""
|
||||
self.backends = backends
|
||||
|
||||
@ -106,7 +108,7 @@ class ConfigQueryModel:
|
||||
backend_region: Optional[str] = None,
|
||||
resource_region: Optional[str] = None,
|
||||
aggregator: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
) -> Tuple[List[Dict[str, Any]], str]:
|
||||
"""For AWS Config. This will list all of the resources of the given type and optional resource name and region.
|
||||
|
||||
This supports both aggregated and non-aggregated listing. The following notes the difference:
|
||||
@ -195,5 +197,5 @@ class ConfigQueryModel:
|
||||
class CloudWatchMetricProvider(object):
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_cloudwatch_metrics(account_id: str) -> Any:
|
||||
def get_cloudwatch_metrics(account_id: str) -> Any: # type: ignore[misc]
|
||||
pass
|
||||
|
5
moto/core/common_types.py
Normal file
5
moto/core/common_types.py
Normal file
@ -0,0 +1,5 @@
|
||||
from typing import Dict, Tuple, TypeVar
|
||||
|
||||
|
||||
TYPE_RESPONSE = Tuple[int, Dict[str, str], str]
|
||||
TYPE_IF_NONE = TypeVar("TYPE_IF_NONE")
|
@ -2,8 +2,10 @@ import responses
|
||||
import types
|
||||
from io import BytesIO
|
||||
from http.client import responses as http_responses
|
||||
from typing import Any, Dict, List, Tuple, Optional
|
||||
from urllib.parse import urlparse
|
||||
from werkzeug.wrappers import Request
|
||||
from .responses import TYPE_RESPONSE
|
||||
|
||||
from moto.utilities.distutils_version import LooseVersion
|
||||
|
||||
@ -21,7 +23,7 @@ class CallbackResponse(responses.CallbackResponse):
|
||||
Need to subclass so we can change a couple things
|
||||
"""
|
||||
|
||||
def get_response(self, request):
|
||||
def get_response(self, request: Any) -> responses.HTTPResponse:
|
||||
"""
|
||||
Need to override this so we can pass decode_content=False
|
||||
"""
|
||||
@ -58,20 +60,22 @@ class CallbackResponse(responses.CallbackResponse):
|
||||
raise result
|
||||
|
||||
status, r_headers, body = result
|
||||
body = responses._handle_body(body)
|
||||
body_io = responses._handle_body(body)
|
||||
headers.update(r_headers)
|
||||
|
||||
return responses.HTTPResponse(
|
||||
status=status,
|
||||
reason=http_responses.get(status),
|
||||
body=body,
|
||||
body=body_io,
|
||||
headers=headers,
|
||||
preload_content=False,
|
||||
# Need to not decode_content to mimic requests
|
||||
decode_content=False,
|
||||
)
|
||||
|
||||
def _url_matches(self, url, other, match_querystring=False):
|
||||
def _url_matches(
|
||||
self, url: Any, other: Any, match_querystring: bool = False
|
||||
) -> bool:
|
||||
"""
|
||||
Need to override this so we can fix querystrings breaking regex matching
|
||||
"""
|
||||
@ -83,16 +87,18 @@ class CallbackResponse(responses.CallbackResponse):
|
||||
url = responses._clean_unicode(url)
|
||||
if not isinstance(other, str):
|
||||
other = other.encode("ascii").decode("utf8")
|
||||
return self._url_matches_strict(url, other)
|
||||
return self._url_matches_strict(url, other) # type: ignore[attr-defined]
|
||||
elif isinstance(url, responses.Pattern) and url.match(other):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def not_implemented_callback(request): # pylint: disable=unused-argument
|
||||
def not_implemented_callback(
|
||||
request: Any, # pylint: disable=unused-argument
|
||||
) -> TYPE_RESPONSE:
|
||||
status = 400
|
||||
headers = {}
|
||||
headers: Dict[str, str] = {}
|
||||
response = "The method is not implemented"
|
||||
|
||||
return status, headers, response
|
||||
@ -106,10 +112,12 @@ def not_implemented_callback(request): # pylint: disable=unused-argument
|
||||
#
|
||||
# Note that, due to an outdated API we no longer support Responses <= 0.12.1
|
||||
# This method should be used for Responses 0.12.1 < .. < 0.17.0
|
||||
def _find_first_match(self, request):
|
||||
def _find_first_match(
|
||||
self: Any, request: Any
|
||||
) -> Tuple[Optional[responses.BaseResponse], List[str]]:
|
||||
matches = []
|
||||
match_failed_reasons = []
|
||||
all_possibles = self._matches + responses._default_mock._matches
|
||||
all_possibles = self._matches + responses._default_mock._matches # type: ignore[attr-defined]
|
||||
for match in all_possibles:
|
||||
match_result, reason = match.matches(request)
|
||||
if match_result:
|
||||
@ -132,7 +140,7 @@ def _find_first_match(self, request):
|
||||
return None, match_failed_reasons
|
||||
|
||||
|
||||
def get_response_mock():
|
||||
def get_response_mock() -> responses.RequestsMock:
|
||||
"""
|
||||
The responses-library is crucial in ensuring that requests to AWS are intercepted, and routed to the right backend.
|
||||
However, as our usecase is different from a 'typical' responses-user, Moto always needs some custom logic to ensure responses behaves in a way that works for us.
|
||||
@ -152,13 +160,13 @@ def get_response_mock():
|
||||
)
|
||||
else:
|
||||
responses_mock = responses.RequestsMock(assert_all_requests_are_fired=False)
|
||||
responses_mock._find_match = types.MethodType(_find_first_match, responses_mock)
|
||||
responses_mock._find_match = types.MethodType(_find_first_match, responses_mock) # type: ignore[assignment]
|
||||
|
||||
responses_mock.add_passthru("http")
|
||||
return responses_mock
|
||||
|
||||
|
||||
def reset_responses_mock(responses_mock):
|
||||
def reset_responses_mock(responses_mock: responses.RequestsMock) -> None:
|
||||
if LooseVersion(RESPONSES_VERSION) >= LooseVersion("0.17.0"):
|
||||
from .responses_custom_registry import CustomRegistry
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
from werkzeug.exceptions import HTTPException
|
||||
from jinja2 import DictLoader, Environment
|
||||
from typing import Any, Optional
|
||||
from typing import Any, List, Tuple, Optional
|
||||
import json
|
||||
|
||||
# TODO: add "<Type>Sender</Type>" to error responses below?
|
||||
@ -67,14 +67,18 @@ class RESTError(HTTPException):
|
||||
)
|
||||
self.content_type = "application/xml"
|
||||
|
||||
def get_headers(self, *args, **kwargs): # pylint: disable=unused-argument
|
||||
return {
|
||||
"X-Amzn-ErrorType": self.error_type or "UnknownError",
|
||||
"Content-Type": self.content_type,
|
||||
}
|
||||
def get_headers(
|
||||
self, *args: Any, **kwargs: Any # pylint: disable=unused-argument
|
||||
) -> List[Tuple[str, str]]:
|
||||
return [
|
||||
("X-Amzn-ErrorType", self.error_type or "UnknownError"),
|
||||
("Content-Type", self.content_type),
|
||||
]
|
||||
|
||||
def get_body(self, *args, **kwargs): # pylint: disable=unused-argument
|
||||
return self.description
|
||||
def get_body(
|
||||
self, *args: Any, **kwargs: Any # pylint: disable=unused-argument
|
||||
) -> str:
|
||||
return self.description # type: ignore[return-value]
|
||||
|
||||
|
||||
class DryRunClientError(RESTError):
|
||||
@ -86,19 +90,19 @@ class JsonRESTError(RESTError):
|
||||
self, error_type: str, message: str, template: str = "error_json", **kwargs: Any
|
||||
):
|
||||
super().__init__(error_type, message, template, **kwargs)
|
||||
self.description = json.dumps(
|
||||
self.description: str = json.dumps(
|
||||
{"__type": self.error_type, "message": self.message}
|
||||
)
|
||||
self.content_type = "application/json"
|
||||
|
||||
def get_body(self, *args, **kwargs) -> str:
|
||||
def get_body(self, *args: Any, **kwargs: Any) -> str:
|
||||
return self.description
|
||||
|
||||
|
||||
class SignatureDoesNotMatchError(RESTError):
|
||||
code = 403
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
"SignatureDoesNotMatch",
|
||||
"The request signature we calculated does not match the signature you provided. Check your AWS Secret Access Key and signing method. Consult the service documentation for details.",
|
||||
@ -108,7 +112,7 @@ class SignatureDoesNotMatchError(RESTError):
|
||||
class InvalidClientTokenIdError(RESTError):
|
||||
code = 403
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
"InvalidClientTokenId",
|
||||
"The security token included in the request is invalid.",
|
||||
@ -118,7 +122,7 @@ class InvalidClientTokenIdError(RESTError):
|
||||
class AccessDeniedError(RESTError):
|
||||
code = 403
|
||||
|
||||
def __init__(self, user_arn, action):
|
||||
def __init__(self, user_arn: str, action: str):
|
||||
super().__init__(
|
||||
"AccessDenied",
|
||||
"User: {user_arn} is not authorized to perform: {operation}".format(
|
||||
@ -130,7 +134,7 @@ class AccessDeniedError(RESTError):
|
||||
class AuthFailureError(RESTError):
|
||||
code = 401
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
"AuthFailure",
|
||||
"AWS was not able to validate the provided access credentials",
|
||||
@ -142,9 +146,12 @@ class AWSError(JsonRESTError):
|
||||
STATUS = 400
|
||||
|
||||
def __init__(
|
||||
self, message: str, exception_type: str = None, status: Optional[int] = None
|
||||
self,
|
||||
message: str,
|
||||
exception_type: Optional[str] = None,
|
||||
status: Optional[int] = None,
|
||||
):
|
||||
super().__init__(exception_type or self.TYPE, message)
|
||||
super().__init__(exception_type or self.TYPE, message) # type: ignore[arg-type]
|
||||
self.code = status or self.STATUS
|
||||
|
||||
|
||||
@ -153,7 +160,7 @@ class InvalidNextTokenException(JsonRESTError):
|
||||
|
||||
code = 400
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
"InvalidNextTokenException", "The nextToken provided is invalid"
|
||||
)
|
||||
@ -162,5 +169,5 @@ class InvalidNextTokenException(JsonRESTError):
|
||||
class InvalidToken(AWSError):
|
||||
code = 400
|
||||
|
||||
def __init__(self, message="Invalid token"):
|
||||
def __init__(self, message: str = "Invalid token"):
|
||||
super().__init__("Invalid Token: {}".format(message), "InvalidToken")
|
||||
|
@ -5,6 +5,7 @@ import os
|
||||
import re
|
||||
import unittest
|
||||
from types import FunctionType
|
||||
from typing import Any, Callable, Dict, Optional, Set, TypeVar
|
||||
from unittest.mock import patch
|
||||
|
||||
import boto3
|
||||
@ -14,7 +15,7 @@ from botocore.config import Config
|
||||
from botocore.handlers import BUILTIN_HANDLERS
|
||||
|
||||
from moto import settings
|
||||
from moto.core.base_backend import BackendDict
|
||||
from .base_backend import BackendDict
|
||||
from .botocore_stubber import BotocoreStubber
|
||||
from .custom_responses_mock import (
|
||||
get_response_mock,
|
||||
@ -24,13 +25,14 @@ from .custom_responses_mock import (
|
||||
)
|
||||
|
||||
DEFAULT_ACCOUNT_ID = "123456789012"
|
||||
CALLABLE_RETURN = TypeVar("CALLABLE_RETURN")
|
||||
|
||||
|
||||
class BaseMockAWS:
|
||||
nested_count = 0
|
||||
mocks_active = False
|
||||
|
||||
def __init__(self, backends):
|
||||
def __init__(self, backends: BackendDict):
|
||||
from moto.instance_metadata import instance_metadata_backends
|
||||
from moto.moto_api._internal.models import moto_api_backend
|
||||
|
||||
@ -55,25 +57,25 @@ class BaseMockAWS:
|
||||
"AWS_ACCESS_KEY_ID": "foobar_key",
|
||||
"AWS_SECRET_ACCESS_KEY": "foobar_secret",
|
||||
}
|
||||
self.ORIG_KEYS = {}
|
||||
self.ORIG_KEYS: Dict[str, Optional[str]] = {}
|
||||
self.default_session_mock = patch("boto3.DEFAULT_SESSION", None)
|
||||
|
||||
if self.__class__.nested_count == 0:
|
||||
self.reset()
|
||||
self.reset() # type: ignore[attr-defined]
|
||||
|
||||
def __call__(self, func, reset=True):
|
||||
def __call__(self, func: Callable[..., Any], reset: bool = True) -> Any:
|
||||
if inspect.isclass(func):
|
||||
return self.decorate_class(func)
|
||||
return self.decorate_callable(func, reset)
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> "BaseMockAWS":
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
def __exit__(self, *args: Any) -> None:
|
||||
self.stop()
|
||||
|
||||
def start(self, reset=True):
|
||||
def start(self, reset: bool = True) -> None:
|
||||
if not self.__class__.mocks_active:
|
||||
self.default_session_mock.start()
|
||||
self.mock_env_variables()
|
||||
@ -84,9 +86,9 @@ class BaseMockAWS:
|
||||
for backend in self.backends.values():
|
||||
backend.reset()
|
||||
|
||||
self.enable_patching(reset)
|
||||
self.enable_patching(reset) # type: ignore[attr-defined]
|
||||
|
||||
def stop(self):
|
||||
def stop(self) -> None:
|
||||
self.__class__.nested_count -= 1
|
||||
|
||||
if self.__class__.nested_count < 0:
|
||||
@ -102,10 +104,12 @@ class BaseMockAWS:
|
||||
pass
|
||||
self.unmock_env_variables()
|
||||
self.__class__.mocks_active = False
|
||||
self.disable_patching()
|
||||
self.disable_patching() # type: ignore[attr-defined]
|
||||
|
||||
def decorate_callable(self, func, reset):
|
||||
def wrapper(*args, **kwargs):
|
||||
def decorate_callable(
|
||||
self, func: Callable[..., CALLABLE_RETURN], reset: bool
|
||||
) -> Callable[..., CALLABLE_RETURN]:
|
||||
def wrapper(*args: Any, **kwargs: Any) -> CALLABLE_RETURN:
|
||||
self.start(reset=reset)
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
@ -114,10 +118,10 @@ class BaseMockAWS:
|
||||
return result
|
||||
|
||||
functools.update_wrapper(wrapper, func)
|
||||
wrapper.__wrapped__ = func
|
||||
wrapper.__wrapped__ = func # type: ignore[attr-defined]
|
||||
return wrapper
|
||||
|
||||
def decorate_class(self, klass):
|
||||
def decorate_class(self, klass: type) -> object:
|
||||
direct_methods = get_direct_methods_of(klass)
|
||||
defined_classes = set(
|
||||
x for x, y in klass.__dict__.items() if inspect.isclass(y)
|
||||
@ -181,7 +185,7 @@ class BaseMockAWS:
|
||||
continue
|
||||
return klass
|
||||
|
||||
def mock_env_variables(self):
|
||||
def mock_env_variables(self) -> None:
|
||||
# "Mock" the AWS credentials as they can't be mocked in Botocore currently
|
||||
# self.env_variables_mocks = mock.patch.dict(os.environ, FAKE_KEYS)
|
||||
# self.env_variables_mocks.start()
|
||||
@ -189,7 +193,7 @@ class BaseMockAWS:
|
||||
self.ORIG_KEYS[k] = os.environ.get(k, None)
|
||||
os.environ[k] = v
|
||||
|
||||
def unmock_env_variables(self):
|
||||
def unmock_env_variables(self) -> None:
|
||||
# This doesn't work in Python2 - for some reason, unmocking clears the entire os.environ dict
|
||||
# Obviously bad user experience, and also breaks pytest - as it uses PYTEST_CURRENT_TEST as an env var
|
||||
# self.env_variables_mocks.stop()
|
||||
@ -200,7 +204,7 @@ class BaseMockAWS:
|
||||
del os.environ[k]
|
||||
|
||||
|
||||
def get_direct_methods_of(klass):
|
||||
def get_direct_methods_of(klass: object) -> Set[str]:
|
||||
return set(
|
||||
x
|
||||
for x, y in klass.__dict__.items()
|
||||
@ -232,7 +236,7 @@ botocore_stubber = BotocoreStubber()
|
||||
BUILTIN_HANDLERS.append(("before-send", botocore_stubber))
|
||||
|
||||
|
||||
def patch_client(client):
|
||||
def patch_client(client: botocore.client.BaseClient) -> None:
|
||||
"""
|
||||
Explicitly patch a boto3-client
|
||||
"""
|
||||
@ -254,7 +258,7 @@ def patch_client(client):
|
||||
raise Exception(f"Argument {client} should be of type boto3.client")
|
||||
|
||||
|
||||
def patch_resource(resource):
|
||||
def patch_resource(resource: Any) -> None:
|
||||
"""
|
||||
Explicitly patch a boto3-resource
|
||||
"""
|
||||
@ -267,11 +271,13 @@ def patch_resource(resource):
|
||||
|
||||
|
||||
class BotocoreEventMockAWS(BaseMockAWS):
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
botocore_stubber.reset()
|
||||
reset_responses_mock(responses_mock)
|
||||
|
||||
def enable_patching(self, reset=True): # pylint: disable=unused-argument
|
||||
def enable_patching(
|
||||
self, reset: bool = True # pylint: disable=unused-argument
|
||||
) -> None:
|
||||
# Circumvent circular imports
|
||||
from .utils import convert_flask_to_responses_response
|
||||
|
||||
@ -313,7 +319,7 @@ class BotocoreEventMockAWS(BaseMockAWS):
|
||||
)
|
||||
)
|
||||
|
||||
def disable_patching(self):
|
||||
def disable_patching(self) -> None:
|
||||
botocore_stubber.enabled = False
|
||||
self.reset()
|
||||
|
||||
@ -330,14 +336,13 @@ class ServerModeMockAWS(BaseMockAWS):
|
||||
|
||||
RESET_IN_PROGRESS = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args: Any, **kwargs: Any):
|
||||
self.test_server_mode_endpoint = settings.test_server_mode_endpoint()
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
call_reset_api = os.environ.get("MOTO_CALL_RESET_API")
|
||||
call_reset_api = not call_reset_api or call_reset_api.lower() != "false"
|
||||
if call_reset_api:
|
||||
if not call_reset_api or call_reset_api.lower() != "false":
|
||||
if not ServerModeMockAWS.RESET_IN_PROGRESS:
|
||||
ServerModeMockAWS.RESET_IN_PROGRESS = True
|
||||
import requests
|
||||
@ -345,14 +350,14 @@ class ServerModeMockAWS(BaseMockAWS):
|
||||
requests.post(f"{self.test_server_mode_endpoint}/moto-api/reset")
|
||||
ServerModeMockAWS.RESET_IN_PROGRESS = False
|
||||
|
||||
def enable_patching(self, reset=True):
|
||||
def enable_patching(self, reset: bool = True) -> None:
|
||||
if self.__class__.nested_count == 1 and reset:
|
||||
# Just started
|
||||
self.reset()
|
||||
|
||||
from boto3 import client as real_boto3_client, resource as real_boto3_resource
|
||||
|
||||
def fake_boto3_client(*args, **kwargs):
|
||||
def fake_boto3_client(*args: Any, **kwargs: Any) -> botocore.client.BaseClient:
|
||||
region = self._get_region(*args, **kwargs)
|
||||
if region:
|
||||
if "config" in kwargs:
|
||||
@ -364,7 +369,7 @@ class ServerModeMockAWS(BaseMockAWS):
|
||||
kwargs["endpoint_url"] = self.test_server_mode_endpoint
|
||||
return real_boto3_client(*args, **kwargs)
|
||||
|
||||
def fake_boto3_resource(*args, **kwargs):
|
||||
def fake_boto3_resource(*args: Any, **kwargs: Any) -> Any:
|
||||
if "endpoint_url" not in kwargs:
|
||||
kwargs["endpoint_url"] = self.test_server_mode_endpoint
|
||||
return real_boto3_resource(*args, **kwargs)
|
||||
@ -374,7 +379,7 @@ class ServerModeMockAWS(BaseMockAWS):
|
||||
self._client_patcher.start()
|
||||
self._resource_patcher.start()
|
||||
|
||||
def _get_region(self, *args, **kwargs):
|
||||
def _get_region(self, *args: Any, **kwargs: Any) -> Optional[str]:
|
||||
if "region_name" in kwargs:
|
||||
return kwargs["region_name"]
|
||||
if type(args) == tuple and len(args) == 2:
|
||||
@ -382,7 +387,7 @@ class ServerModeMockAWS(BaseMockAWS):
|
||||
return region
|
||||
return None
|
||||
|
||||
def disable_patching(self):
|
||||
def disable_patching(self) -> None:
|
||||
if self._client_patcher:
|
||||
self._client_patcher.stop()
|
||||
self._resource_patcher.stop()
|
||||
@ -394,9 +399,9 @@ class base_decorator:
|
||||
def __init__(self, backends: BackendDict):
|
||||
self.backends = backends
|
||||
|
||||
def __call__(self, func=None):
|
||||
def __call__(self, func: Optional[Callable[..., Any]] = None) -> BaseMockAWS:
|
||||
if settings.TEST_SERVER_MODE:
|
||||
mocked_backend = ServerModeMockAWS(self.backends)
|
||||
mocked_backend: BaseMockAWS = ServerModeMockAWS(self.backends)
|
||||
else:
|
||||
mocked_backend = self.mock_backend(self.backends)
|
||||
|
||||
|
@ -11,11 +11,22 @@ import xmltodict
|
||||
|
||||
from collections import defaultdict, OrderedDict
|
||||
from moto import settings
|
||||
from moto.core.common_types import TYPE_RESPONSE, TYPE_IF_NONE
|
||||
from moto.core.exceptions import DryRunClientError
|
||||
from moto.core.utils import camelcase_to_underscores, method_names_from_class
|
||||
from moto.utilities.utils import load_resource
|
||||
from jinja2 import Environment, DictLoader, Template
|
||||
from typing import Dict, Union, Any, Tuple, TypeVar
|
||||
from typing import (
|
||||
Dict,
|
||||
Union,
|
||||
Any,
|
||||
Tuple,
|
||||
Optional,
|
||||
List,
|
||||
Set,
|
||||
ClassVar,
|
||||
Callable,
|
||||
)
|
||||
from urllib.parse import parse_qs, parse_qsl, urlparse
|
||||
from werkzeug.exceptions import HTTPException
|
||||
from xml.dom.minidom import parseString as parseXML
|
||||
@ -23,29 +34,19 @@ from xml.dom.minidom import parseString as parseXML
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
JINJA_ENVS = {}
|
||||
|
||||
TYPE_RESPONSE = Tuple[int, Dict[str, str], str]
|
||||
TYPE_IF_NONE = TypeVar("TYPE_IF_NONE")
|
||||
JINJA_ENVS: Dict[type, Environment] = {}
|
||||
|
||||
|
||||
def _decode_dict(d):
|
||||
decoded = OrderedDict()
|
||||
def _decode_dict(d: Dict[Any, Any]) -> Dict[str, Any]:
|
||||
decoded: Dict[str, Any] = OrderedDict()
|
||||
for key, value in d.items():
|
||||
if isinstance(key, bytes):
|
||||
newkey = key.decode("utf-8")
|
||||
elif isinstance(key, (list, tuple)):
|
||||
newkey = []
|
||||
for k in key:
|
||||
if isinstance(k, bytes):
|
||||
newkey.append(k.decode("utf-8"))
|
||||
else:
|
||||
newkey.append(k)
|
||||
else:
|
||||
newkey = key
|
||||
|
||||
if isinstance(value, bytes):
|
||||
newvalue = value.decode("utf-8")
|
||||
decoded[newkey] = value.decode("utf-8")
|
||||
elif isinstance(value, (list, tuple)):
|
||||
newvalue = []
|
||||
for v in value:
|
||||
@ -53,18 +54,18 @@ def _decode_dict(d):
|
||||
newvalue.append(v.decode("utf-8"))
|
||||
else:
|
||||
newvalue.append(v)
|
||||
decoded[newkey] = newvalue
|
||||
else:
|
||||
newvalue = value
|
||||
decoded[newkey] = value
|
||||
|
||||
decoded[newkey] = newvalue
|
||||
return decoded
|
||||
|
||||
|
||||
class DynamicDictLoader(DictLoader):
|
||||
def update(self, mapping):
|
||||
self.mapping.update(mapping)
|
||||
def update(self, mapping: Dict[str, str]) -> None:
|
||||
self.mapping.update(mapping) # type: ignore[attr-defined]
|
||||
|
||||
def contains(self, template):
|
||||
def contains(self, template: str) -> bool:
|
||||
return bool(template in self.mapping)
|
||||
|
||||
|
||||
@ -73,12 +74,12 @@ class _TemplateEnvironmentMixin(object):
|
||||
RIGHT_PATTERN = re.compile(r">[\s\n]+")
|
||||
|
||||
@property
|
||||
def should_autoescape(self):
|
||||
def should_autoescape(self) -> bool:
|
||||
# Allow for subclass to overwrite
|
||||
return False
|
||||
|
||||
@property
|
||||
def environment(self):
|
||||
def environment(self) -> Environment:
|
||||
key = type(self)
|
||||
try:
|
||||
environment = JINJA_ENVS[key]
|
||||
@ -94,11 +95,11 @@ class _TemplateEnvironmentMixin(object):
|
||||
|
||||
return environment
|
||||
|
||||
def contains_template(self, template_id):
|
||||
return self.environment.loader.contains(template_id)
|
||||
def contains_template(self, template_id: str) -> bool:
|
||||
return self.environment.loader.contains(template_id) # type: ignore[union-attr]
|
||||
|
||||
@classmethod
|
||||
def _make_template_id(cls, source):
|
||||
def _make_template_id(cls, source: str) -> str:
|
||||
"""
|
||||
Return a numeric string that's unique for the lifetime of the source.
|
||||
|
||||
@ -117,47 +118,49 @@ class _TemplateEnvironmentMixin(object):
|
||||
xml = re.sub(
|
||||
self.RIGHT_PATTERN, ">", re.sub(self.LEFT_PATTERN, "<", source)
|
||||
)
|
||||
self.environment.loader.update({template_id: xml})
|
||||
self.environment.loader.update({template_id: xml}) # type: ignore[union-attr]
|
||||
return self.environment.get_template(template_id)
|
||||
|
||||
|
||||
class ActionAuthenticatorMixin(object):
|
||||
|
||||
request_count = 0
|
||||
request_count: ClassVar[int] = 0
|
||||
|
||||
def _authenticate_and_authorize_action(self, iam_request_cls):
|
||||
def _authenticate_and_authorize_action(self, iam_request_cls: type) -> None:
|
||||
if (
|
||||
ActionAuthenticatorMixin.request_count
|
||||
>= settings.INITIAL_NO_AUTH_ACTION_COUNT
|
||||
):
|
||||
iam_request = iam_request_cls(
|
||||
account_id=self.current_account,
|
||||
method=self.method,
|
||||
path=self.path,
|
||||
data=self.data,
|
||||
headers=self.headers,
|
||||
account_id=self.current_account, # type: ignore[attr-defined]
|
||||
method=self.method, # type: ignore[attr-defined]
|
||||
path=self.path, # type: ignore[attr-defined]
|
||||
data=self.data, # type: ignore[attr-defined]
|
||||
headers=self.headers, # type: ignore[attr-defined]
|
||||
)
|
||||
iam_request.check_signature()
|
||||
iam_request.check_action_permitted()
|
||||
else:
|
||||
ActionAuthenticatorMixin.request_count += 1
|
||||
|
||||
def _authenticate_and_authorize_normal_action(self):
|
||||
def _authenticate_and_authorize_normal_action(self) -> None:
|
||||
from moto.iam.access_control import IAMRequest
|
||||
|
||||
self._authenticate_and_authorize_action(IAMRequest)
|
||||
|
||||
def _authenticate_and_authorize_s3_action(self):
|
||||
def _authenticate_and_authorize_s3_action(self) -> None:
|
||||
from moto.iam.access_control import S3IAMRequest
|
||||
|
||||
self._authenticate_and_authorize_action(S3IAMRequest)
|
||||
|
||||
@staticmethod
|
||||
def set_initial_no_auth_action_count(initial_no_auth_action_count):
|
||||
def set_initial_no_auth_action_count(initial_no_auth_action_count: int) -> Callable[..., Callable[..., TYPE_RESPONSE]]: # type: ignore[misc]
|
||||
_test_server_mode_endpoint = settings.test_server_mode_endpoint()
|
||||
|
||||
def decorator(function):
|
||||
def wrapper(*args, **kwargs):
|
||||
def decorator(
|
||||
function: Callable[..., TYPE_RESPONSE]
|
||||
) -> Callable[..., TYPE_RESPONSE]:
|
||||
def wrapper(*args: Any, **kwargs: Any) -> TYPE_RESPONSE:
|
||||
if settings.TEST_SERVER_MODE:
|
||||
response = requests.post(
|
||||
f"{_test_server_mode_endpoint}/moto-api/reset-auth",
|
||||
@ -191,7 +194,7 @@ class ActionAuthenticatorMixin(object):
|
||||
return result
|
||||
|
||||
functools.update_wrapper(wrapper, function)
|
||||
wrapper.__wrapped__ = function
|
||||
wrapper.__wrapped__ = function # type: ignore[attr-defined]
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
@ -213,12 +216,12 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
||||
)
|
||||
aws_service_spec = None
|
||||
|
||||
def __init__(self, service_name=None) -> None:
|
||||
def __init__(self, service_name: Optional[str] = None):
|
||||
super().__init__()
|
||||
self.service_name = service_name
|
||||
|
||||
@classmethod
|
||||
def dispatch(cls, *args: Any, **kwargs: Any) -> Any:
|
||||
def dispatch(cls, *args: Any, **kwargs: Any) -> Any: # type: ignore[misc]
|
||||
return cls()._dispatch(*args, **kwargs)
|
||||
|
||||
def setup_class(
|
||||
@ -227,7 +230,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
||||
"""
|
||||
use_raw_body: Use incoming bytes if True, encode to string otherwise
|
||||
"""
|
||||
querystring = OrderedDict()
|
||||
querystring: Dict[str, Any] = OrderedDict()
|
||||
if hasattr(request, "body"):
|
||||
# Boto
|
||||
self.body = request.body
|
||||
@ -292,7 +295,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
||||
self.data = querystring
|
||||
self.method = request.method
|
||||
self.region = self.get_region_from_url(request, full_url)
|
||||
self.uri_match = None
|
||||
self.uri_match: Optional[re.Match[str]] = None
|
||||
|
||||
self.headers = request.headers
|
||||
if "host" not in self.headers:
|
||||
@ -307,11 +310,11 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
||||
mark_account_as_visited(
|
||||
account_id=self.current_account,
|
||||
access_key=self.access_key,
|
||||
service=self.service_name,
|
||||
service=self.service_name, # type: ignore[arg-type]
|
||||
region=self.region,
|
||||
)
|
||||
|
||||
def get_region_from_url(self, request, full_url):
|
||||
def get_region_from_url(self, request: Any, full_url: str) -> str:
|
||||
url_match = self.region_regex.search(full_url)
|
||||
user_agent_match = self.region_from_useragent_regex.search(
|
||||
request.headers.get("User-Agent", "")
|
||||
@ -329,7 +332,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
||||
region = self.default_region
|
||||
return region
|
||||
|
||||
def get_access_key(self):
|
||||
def get_access_key(self) -> str:
|
||||
"""
|
||||
Returns the access key id used in this request as the current user id
|
||||
"""
|
||||
@ -339,11 +342,11 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
||||
return match.group(1)
|
||||
|
||||
if self.querystring.get("AWSAccessKeyId"):
|
||||
return self.querystring.get("AWSAccessKeyId")[0]
|
||||
return self.querystring["AWSAccessKeyId"][0]
|
||||
else:
|
||||
return "AKIAEXAMPLE"
|
||||
|
||||
def get_current_account(self):
|
||||
def get_current_account(self) -> str:
|
||||
# PRIO 1: Check if we have a Environment Variable set
|
||||
if "MOTO_ACCOUNT_ID" in os.environ:
|
||||
return os.environ["MOTO_ACCOUNT_ID"]
|
||||
@ -358,11 +361,11 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
||||
|
||||
return get_account_id_from(self.get_access_key())
|
||||
|
||||
def _dispatch(self, request, full_url, headers):
|
||||
def _dispatch(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE:
|
||||
self.setup_class(request, full_url, headers)
|
||||
return self.call_action()
|
||||
|
||||
def uri_to_regexp(self, uri):
|
||||
def uri_to_regexp(self, uri: str) -> str:
|
||||
"""converts uri w/ placeholder to regexp
|
||||
'/cars/{carName}/drivers/{DriverName}'
|
||||
-> '^/cars/.*/drivers/[^/]*$'
|
||||
@ -372,7 +375,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
||||
|
||||
"""
|
||||
|
||||
def _convert(elem, is_last):
|
||||
def _convert(elem: str, is_last: bool) -> str:
|
||||
if not re.match("^{.*}$", elem):
|
||||
return elem
|
||||
name = (
|
||||
@ -394,7 +397,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
||||
)
|
||||
return regexp
|
||||
|
||||
def _get_action_from_method_and_request_uri(self, method, request_uri):
|
||||
def _get_action_from_method_and_request_uri(
|
||||
self, method: str, request_uri: str
|
||||
) -> str:
|
||||
"""basically used for `rest-json` APIs
|
||||
You can refer to example from link below
|
||||
https://github.com/boto/botocore/blob/develop/botocore/data/iot/2015-05-28/service-2.json
|
||||
@ -406,7 +411,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
||||
|
||||
# make cache if it does not exist yet
|
||||
if not hasattr(self, "method_urls"):
|
||||
self.method_urls = defaultdict(lambda: defaultdict(str))
|
||||
self.method_urls: Dict[str, Dict[str, str]] = defaultdict(
|
||||
lambda: defaultdict(str)
|
||||
)
|
||||
op_names = conn._service_model.operation_names
|
||||
for op_name in op_names:
|
||||
op_model = conn._service_model.operation_model(op_name)
|
||||
@ -419,9 +426,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
||||
self.uri_match = match
|
||||
if match:
|
||||
return name
|
||||
return None
|
||||
return None # type: ignore[return-value]
|
||||
|
||||
def _get_action(self):
|
||||
def _get_action(self) -> str:
|
||||
action = self.querystring.get("Action", [""])[0]
|
||||
if action:
|
||||
return action
|
||||
@ -433,7 +440,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
||||
# get action from method and uri
|
||||
return self._get_action_from_method_and_request_uri(self.method, self.path)
|
||||
|
||||
def call_action(self):
|
||||
def call_action(self) -> TYPE_RESPONSE:
|
||||
headers = self.response_headers
|
||||
|
||||
try:
|
||||
@ -449,9 +456,11 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
||||
try:
|
||||
response = method()
|
||||
except HTTPException as http_error:
|
||||
response_headers = dict(http_error.get_headers() or [])
|
||||
response_headers["status"] = http_error.code
|
||||
response = http_error.description, response_headers
|
||||
response_headers: Dict[str, Union[str, int]] = dict(
|
||||
http_error.get_headers() or []
|
||||
)
|
||||
response_headers["status"] = http_error.code # type: ignore[assignment]
|
||||
response = http_error.description, response_headers # type: ignore[assignment]
|
||||
|
||||
if isinstance(response, str):
|
||||
return 200, headers, response
|
||||
@ -466,7 +475,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _send_response(headers, response):
|
||||
def _send_response(headers: Dict[str, str], response: Any) -> Tuple[int, Dict[str, str], str]: # type: ignore[misc]
|
||||
if response is None:
|
||||
response = "", {}
|
||||
if len(response) == 2:
|
||||
@ -480,7 +489,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
||||
headers["status"] = str(headers["status"])
|
||||
return status, headers, body
|
||||
|
||||
def _get_param(self, param_name, if_none=None) -> Any:
|
||||
def _get_param(self, param_name: str, if_none: Any = None) -> Any:
|
||||
val = self.querystring.get(param_name)
|
||||
if val is not None:
|
||||
return val[0]
|
||||
@ -503,7 +512,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
||||
return if_none
|
||||
|
||||
def _get_int_param(
|
||||
self, param_name, if_none: TYPE_IF_NONE = None
|
||||
self, param_name: str, if_none: TYPE_IF_NONE = None # type: ignore[assignment]
|
||||
) -> Union[int, TYPE_IF_NONE]:
|
||||
val = self._get_param(param_name)
|
||||
if val is not None:
|
||||
@ -511,7 +520,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
||||
return if_none
|
||||
|
||||
def _get_bool_param(
|
||||
self, param_name, if_none: TYPE_IF_NONE = None
|
||||
self, param_name: str, if_none: TYPE_IF_NONE = None # type: ignore[assignment]
|
||||
) -> Union[bool, TYPE_IF_NONE]:
|
||||
val = self._get_param(param_name)
|
||||
if val is not None:
|
||||
@ -522,12 +531,15 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
||||
return False
|
||||
return if_none
|
||||
|
||||
def _get_multi_param_dict(self, param_prefix) -> Dict:
|
||||
def _get_multi_param_dict(self, param_prefix: str) -> Dict[str, Any]:
|
||||
return self._get_multi_param_helper(param_prefix, skip_result_conversion=True)
|
||||
|
||||
def _get_multi_param_helper(
|
||||
self, param_prefix, skip_result_conversion=False, tracked_prefixes=None
|
||||
):
|
||||
self,
|
||||
param_prefix: str,
|
||||
skip_result_conversion: bool = False,
|
||||
tracked_prefixes: Optional[Set[str]] = None,
|
||||
) -> Any:
|
||||
value_dict = dict()
|
||||
tracked_prefixes = (
|
||||
tracked_prefixes or set()
|
||||
@ -589,11 +601,13 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
||||
if len(parts) != 2 or parts[1] != "member":
|
||||
value_dict[parts[0]] = value_dict.pop(k)
|
||||
else:
|
||||
value_dict = list(value_dict.values())[0]
|
||||
value_dict = list(value_dict.values())[0] # type: ignore[assignment]
|
||||
|
||||
return value_dict
|
||||
|
||||
def _get_multi_param(self, param_prefix, skip_result_conversion=False) -> Any:
|
||||
def _get_multi_param(
|
||||
self, param_prefix: str, skip_result_conversion: bool = False
|
||||
) -> List[Any]:
|
||||
"""
|
||||
Given a querystring of ?LaunchConfigurationNames.member.1=my-test-1&LaunchConfigurationNames.member.2=my-test-2
|
||||
this will return ['my-test-1', 'my-test-2']
|
||||
@ -616,7 +630,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
||||
|
||||
return values
|
||||
|
||||
def _get_dict_param(self, param_prefix) -> Dict:
|
||||
def _get_dict_param(self, param_prefix: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Given a parameter dict of
|
||||
{
|
||||
@ -630,7 +644,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
||||
"instance_count": "1",
|
||||
}
|
||||
"""
|
||||
params = {}
|
||||
params: Dict[str, Any] = {}
|
||||
for key, value in self.querystring.items():
|
||||
if key.startswith(param_prefix):
|
||||
params[camelcase_to_underscores(key.replace(param_prefix, ""))] = value[
|
||||
@ -638,7 +652,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
||||
]
|
||||
return params
|
||||
|
||||
def _get_params(self) -> Any:
|
||||
def _get_params(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Given a querystring of
|
||||
{
|
||||
@ -674,12 +688,12 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
||||
]
|
||||
}
|
||||
"""
|
||||
params = {}
|
||||
params: Dict[str, Any] = {}
|
||||
for k, v in sorted(self.querystring.items()):
|
||||
self._parse_param(k, v[0], params)
|
||||
return params
|
||||
|
||||
def _parse_param(self, key, value, params):
|
||||
def _parse_param(self, key: str, value: str, params: Any) -> None:
|
||||
keylist = key.split(".")
|
||||
obj = params
|
||||
for i, key in enumerate(keylist[:-1]):
|
||||
@ -713,7 +727,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
||||
else:
|
||||
obj[keylist[-1]] = value
|
||||
|
||||
def _get_list_prefix(self, param_prefix: str) -> Any:
|
||||
def _get_list_prefix(self, param_prefix: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Given a query dict like
|
||||
{
|
||||
@ -752,7 +766,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
||||
param_index += 1
|
||||
return results
|
||||
|
||||
def _get_map_prefix(self, param_prefix, key_end=".key", value_end=".value"):
|
||||
def _get_map_prefix(
|
||||
self, param_prefix: str, key_end: str = ".key", value_end: str = ".value"
|
||||
) -> Dict[str, Any]:
|
||||
results = {}
|
||||
param_index = 1
|
||||
while 1:
|
||||
@ -774,7 +790,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
||||
|
||||
return results
|
||||
|
||||
def _get_object_map(self, prefix, name="Name", value="Value"):
|
||||
def _get_object_map(
|
||||
self, prefix: str, name: str = "Name", value: str = "Value"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Given a query dict like
|
||||
{
|
||||
@ -822,19 +840,16 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
||||
return object_map
|
||||
|
||||
@property
|
||||
def request_json(self):
|
||||
def request_json(self) -> bool:
|
||||
return "JSON" in self.querystring.get("ContentType", [])
|
||||
|
||||
def error_on_dryrun(self):
|
||||
def error_on_dryrun(self) -> None:
|
||||
self.is_not_dryrun()
|
||||
|
||||
def is_not_dryrun(self, action=None):
|
||||
action = action or self._get_param("Action")
|
||||
def is_not_dryrun(self, action: Optional[str] = None) -> bool:
|
||||
if "true" in self.querystring.get("DryRun", ["false"]):
|
||||
message = (
|
||||
"An error occurred (DryRunOperation) when calling the %s operation: Request would have succeeded, but DryRun flag is set"
|
||||
% action
|
||||
)
|
||||
a = action or self._get_param("Action")
|
||||
message = f"An error occurred (DryRunOperation) when calling the {a} operation: Request would have succeeded, but DryRun flag is set"
|
||||
raise DryRunClientError(error_type="DryRunOperation", message=message)
|
||||
return True
|
||||
|
||||
@ -842,20 +857,20 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
||||
class _RecursiveDictRef(object):
|
||||
"""Store a recursive reference to dict."""
|
||||
|
||||
def __init__(self):
|
||||
self.key = None
|
||||
self.dic = {}
|
||||
def __init__(self) -> None:
|
||||
self.key: Optional[str] = None
|
||||
self.dic: Dict[str, Any] = {}
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return "{!r}".format(self.dic)
|
||||
|
||||
def __getattr__(self, key):
|
||||
return self.dic.__getattr__(key)
|
||||
def __getattr__(self, key: str) -> Any:
|
||||
return self.dic.__getattr__(key) # type: ignore[attr-defined]
|
||||
|
||||
def __getitem__(self, key):
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return self.dic.__getitem__(key)
|
||||
|
||||
def set_reference(self, key, dic):
|
||||
def set_reference(self, key: str, dic: Dict[str, Any]) -> None:
|
||||
"""Set the RecursiveDictRef object to keep reference to dict object
|
||||
(dic) at the key.
|
||||
|
||||
@ -877,7 +892,7 @@ class AWSServiceSpec(object):
|
||||
self.operations = spec["operations"]
|
||||
self.shapes = spec["shapes"]
|
||||
|
||||
def input_spec(self, operation):
|
||||
def input_spec(self, operation: str) -> Dict[str, Any]:
|
||||
try:
|
||||
op = self.operations[operation]
|
||||
except KeyError:
|
||||
@ -887,7 +902,7 @@ class AWSServiceSpec(object):
|
||||
shape = self.shapes[op["input"]["shape"]]
|
||||
return self._expand(shape)
|
||||
|
||||
def output_spec(self, operation):
|
||||
def output_spec(self, operation: str) -> Dict[str, Any]:
|
||||
"""Produce a JSON with a valid API response syntax for operation, but
|
||||
with type information. Each node represented by a key has the
|
||||
value containing field type, e.g.,
|
||||
@ -904,11 +919,13 @@ class AWSServiceSpec(object):
|
||||
shape = self.shapes[op["output"]["shape"]]
|
||||
return self._expand(shape)
|
||||
|
||||
def _expand(self, shape):
|
||||
def expand(dic, seen=None):
|
||||
def _expand(self, shape: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def expand(
|
||||
dic: Dict[str, Any], seen: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
seen = seen or {}
|
||||
if dic["type"] == "structure":
|
||||
nodes = {}
|
||||
nodes: Dict[str, Any] = {}
|
||||
for k, v in dic["members"].items():
|
||||
seen_till_here = dict(seen)
|
||||
if k in seen_till_here:
|
||||
@ -932,7 +949,7 @@ class AWSServiceSpec(object):
|
||||
|
||||
elif dic["type"] == "map":
|
||||
seen_till_here = dict(seen)
|
||||
node = {"type": "map"}
|
||||
node: Dict[str, Any] = {"type": "map"}
|
||||
|
||||
if "shape" in dic["key"]:
|
||||
shape = dic["key"]["shape"]
|
||||
@ -958,12 +975,12 @@ class AWSServiceSpec(object):
|
||||
return expand(shape)
|
||||
|
||||
|
||||
def to_str(value, spec):
|
||||
def to_str(value: Any, spec: Dict[str, Any]) -> str:
|
||||
vtype = spec["type"]
|
||||
if vtype == "boolean":
|
||||
return "true" if value else "false"
|
||||
elif vtype == "long":
|
||||
return int(value)
|
||||
return int(value) # type: ignore[return-value]
|
||||
elif vtype == "integer":
|
||||
return str(value)
|
||||
elif vtype == "float":
|
||||
@ -984,7 +1001,7 @@ def to_str(value, spec):
|
||||
raise TypeError("Unknown type {}".format(vtype))
|
||||
|
||||
|
||||
def from_str(value, spec):
|
||||
def from_str(value: str, spec: Dict[str, Any]) -> Any:
|
||||
vtype = spec["type"]
|
||||
if vtype == "boolean":
|
||||
return True if value == "true" else False
|
||||
@ -1001,7 +1018,9 @@ def from_str(value, spec):
|
||||
raise TypeError("Unknown type {}".format(vtype))
|
||||
|
||||
|
||||
def flatten_json_request_body(prefix, dict_body, spec):
|
||||
def flatten_json_request_body(
|
||||
prefix: str, dict_body: Dict[str, Any], spec: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Convert a JSON request body into query params."""
|
||||
if len(spec) == 1 and "type" in spec:
|
||||
return {prefix: to_str(dict_body, spec)}
|
||||
@ -1030,10 +1049,12 @@ def flatten_json_request_body(prefix, dict_body, spec):
|
||||
return dict((prefix + k, v) for k, v in flat.items())
|
||||
|
||||
|
||||
def xml_to_json_response(service_spec, operation, xml, result_node=None):
|
||||
def xml_to_json_response(
|
||||
service_spec: Any, operation: str, xml: str, result_node: Any = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Convert rendered XML response to JSON for use with boto3."""
|
||||
|
||||
def transform(value, spec):
|
||||
def transform(value: Any, spec: Dict[str, Any]) -> Any:
|
||||
"""Apply transformations to make the output JSON comply with the
|
||||
expected form. This function applies:
|
||||
|
||||
@ -1047,7 +1068,7 @@ def xml_to_json_response(service_spec, operation, xml, result_node=None):
|
||||
if len(spec) == 1:
|
||||
return from_str(value, spec)
|
||||
|
||||
od = OrderedDict()
|
||||
od: Dict[str, Any] = OrderedDict()
|
||||
for k, v in value.items():
|
||||
if k.startswith("@"):
|
||||
continue
|
||||
@ -1099,7 +1120,7 @@ def xml_to_json_response(service_spec, operation, xml, result_node=None):
|
||||
for k in result_node or (operation + "Response", operation + "Result"):
|
||||
dic = dic[k]
|
||||
except KeyError:
|
||||
return None
|
||||
return None # type: ignore[return-value]
|
||||
else:
|
||||
return transform(dic, output_spec)
|
||||
return None
|
||||
|
@ -1,5 +1,6 @@
|
||||
# This will only exist in responses >= 0.17
|
||||
import responses
|
||||
from typing import Any, List, Tuple, Optional
|
||||
from .custom_responses_mock import CallbackResponse, not_implemented_callback
|
||||
|
||||
|
||||
@ -10,11 +11,12 @@ class CustomRegistry(responses.registries.FirstMatchRegistry):
|
||||
- CallbackResponses are not discarded after first use - users can mock the same URL as often as they like
|
||||
"""
|
||||
|
||||
def add(self, response):
|
||||
def add(self, response: responses.BaseResponse) -> responses.BaseResponse:
|
||||
if response not in self.registered:
|
||||
super().add(response)
|
||||
return response
|
||||
|
||||
def find(self, request):
|
||||
def find(self, request: Any) -> Tuple[Optional[responses.BaseResponse], List[str]]:
|
||||
all_possibles = responses._default_mock._registry.registered + self.registered
|
||||
found = []
|
||||
match_failed_reasons = []
|
||||
|
@ -2,11 +2,12 @@ import datetime
|
||||
import inspect
|
||||
import re
|
||||
from botocore.exceptions import ClientError
|
||||
from typing import Optional
|
||||
from typing import Any, Optional, List, Callable, Dict
|
||||
from urllib.parse import urlparse
|
||||
from .common_types import TYPE_RESPONSE
|
||||
|
||||
|
||||
def camelcase_to_underscores(argument: Optional[str]) -> str:
|
||||
def camelcase_to_underscores(argument: str) -> str:
|
||||
"""Converts a camelcase param like theNewAttribute to the equivalent
|
||||
python underscore variable like the_new_attribute"""
|
||||
result = ""
|
||||
@ -32,7 +33,7 @@ def camelcase_to_underscores(argument: Optional[str]) -> str:
|
||||
return result
|
||||
|
||||
|
||||
def underscores_to_camelcase(argument):
|
||||
def underscores_to_camelcase(argument: str) -> str:
|
||||
"""Converts a camelcase param like the_new_attribute to the equivalent
|
||||
camelcase version like theNewAttribute. Note that the first letter is
|
||||
NOT capitalized by this function"""
|
||||
@ -48,17 +49,17 @@ def underscores_to_camelcase(argument):
|
||||
return result
|
||||
|
||||
|
||||
def pascal_to_camelcase(argument):
|
||||
def pascal_to_camelcase(argument: str) -> str:
|
||||
"""Converts a PascalCase param to the camelCase equivalent"""
|
||||
return argument[0].lower() + argument[1:]
|
||||
|
||||
|
||||
def camelcase_to_pascal(argument):
|
||||
def camelcase_to_pascal(argument: str) -> str:
|
||||
"""Converts a camelCase param to the PascalCase equivalent"""
|
||||
return argument[0].upper() + argument[1:]
|
||||
|
||||
|
||||
def method_names_from_class(clazz):
|
||||
def method_names_from_class(clazz: object) -> List[str]:
|
||||
predicate = inspect.isfunction
|
||||
return [x[0] for x in inspect.getmembers(clazz, predicate=predicate)]
|
||||
|
||||
@ -70,7 +71,7 @@ def convert_regex_to_flask_path(url_path: str) -> str:
|
||||
for token in ["$"]:
|
||||
url_path = url_path.replace(token, "")
|
||||
|
||||
def caller(reg):
|
||||
def caller(reg: Any) -> str:
|
||||
match_name, match_pattern = reg.groups()
|
||||
return '<regex("{0}"):{1}>'.format(match_pattern, match_name)
|
||||
|
||||
@ -83,11 +84,11 @@ def convert_regex_to_flask_path(url_path: str) -> str:
|
||||
|
||||
|
||||
class convert_to_flask_response(object):
|
||||
def __init__(self, callback):
|
||||
def __init__(self, callback: Callable[..., Any]):
|
||||
self.callback = callback
|
||||
|
||||
@property
|
||||
def __name__(self):
|
||||
def __name__(self) -> str:
|
||||
# For instance methods, use class and method names. Otherwise
|
||||
# use module and method name
|
||||
if inspect.ismethod(self.callback):
|
||||
@ -96,7 +97,7 @@ class convert_to_flask_response(object):
|
||||
outer = self.callback.__module__
|
||||
return "{0}.{1}".format(outer, self.callback.__name__)
|
||||
|
||||
def __call__(self, args=None, **kwargs):
|
||||
def __call__(self, args: Any = None, **kwargs: Any) -> Any:
|
||||
from flask import request, Response
|
||||
from moto.moto_api import recorder
|
||||
|
||||
@ -118,11 +119,11 @@ class convert_to_flask_response(object):
|
||||
|
||||
|
||||
class convert_flask_to_responses_response(object):
|
||||
def __init__(self, callback):
|
||||
def __init__(self, callback: Callable[..., Any]):
|
||||
self.callback = callback
|
||||
|
||||
@property
|
||||
def __name__(self):
|
||||
def __name__(self) -> str:
|
||||
# For instance methods, use class and method names. Otherwise
|
||||
# use module and method name
|
||||
if inspect.ismethod(self.callback):
|
||||
@ -131,7 +132,7 @@ class convert_flask_to_responses_response(object):
|
||||
outer = self.callback.__module__
|
||||
return "{0}.{1}".format(outer, self.callback.__name__)
|
||||
|
||||
def __call__(self, request, *args, **kwargs):
|
||||
def __call__(self, request: Any, *args: Any, **kwargs: Any) -> TYPE_RESPONSE:
|
||||
for key, val in request.headers.items():
|
||||
if isinstance(val, bytes):
|
||||
request.headers[key] = val.decode("utf-8")
|
||||
@ -141,7 +142,7 @@ class convert_flask_to_responses_response(object):
|
||||
return status, headers, response
|
||||
|
||||
|
||||
def iso_8601_datetime_with_milliseconds(value: datetime) -> str:
|
||||
def iso_8601_datetime_with_milliseconds(value: datetime.datetime) -> str:
|
||||
return value.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
|
||||
|
||||
|
||||
@ -163,22 +164,22 @@ def iso_8601_datetime_without_milliseconds_s3(
|
||||
RFC1123 = "%a, %d %b %Y %H:%M:%S GMT"
|
||||
|
||||
|
||||
def rfc_1123_datetime(src):
|
||||
def rfc_1123_datetime(src: datetime.datetime) -> str:
|
||||
return src.strftime(RFC1123)
|
||||
|
||||
|
||||
def str_to_rfc_1123_datetime(value):
|
||||
def str_to_rfc_1123_datetime(value: str) -> datetime.datetime:
|
||||
return datetime.datetime.strptime(value, RFC1123)
|
||||
|
||||
|
||||
def unix_time(dt: datetime.datetime = None) -> int:
|
||||
def unix_time(dt: Optional[datetime.datetime] = None) -> float:
|
||||
dt = dt or datetime.datetime.utcnow()
|
||||
epoch = datetime.datetime.utcfromtimestamp(0)
|
||||
delta = dt - epoch
|
||||
return (delta.days * 86400) + (delta.seconds + (delta.microseconds / 1e6))
|
||||
|
||||
|
||||
def unix_time_millis(dt: datetime = None) -> int:
|
||||
def unix_time_millis(dt: Optional[datetime.datetime] = None) -> float:
|
||||
return unix_time(dt) * 1000.0
|
||||
|
||||
|
||||
@ -193,28 +194,33 @@ def path_url(url: str) -> str:
|
||||
|
||||
|
||||
def tags_from_query_string(
|
||||
querystring_dict, prefix="Tag", key_suffix="Key", value_suffix="Value"
|
||||
):
|
||||
querystring_dict: Dict[str, Any],
|
||||
prefix: str = "Tag",
|
||||
key_suffix: str = "Key",
|
||||
value_suffix: str = "Value",
|
||||
) -> Dict[str, str]:
|
||||
response_values = {}
|
||||
for key in querystring_dict.keys():
|
||||
if key.startswith(prefix) and key.endswith(key_suffix):
|
||||
tag_index = key.replace(prefix + ".", "").replace("." + key_suffix, "")
|
||||
tag_key = querystring_dict.get(
|
||||
tag_key = querystring_dict[
|
||||
"{prefix}.{index}.{key_suffix}".format(
|
||||
prefix=prefix, index=tag_index, key_suffix=key_suffix
|
||||
)
|
||||
)[0]
|
||||
][0]
|
||||
tag_value_key = "{prefix}.{index}.{value_suffix}".format(
|
||||
prefix=prefix, index=tag_index, value_suffix=value_suffix
|
||||
)
|
||||
if tag_value_key in querystring_dict:
|
||||
response_values[tag_key] = querystring_dict.get(tag_value_key)[0]
|
||||
response_values[tag_key] = querystring_dict[tag_value_key][0]
|
||||
else:
|
||||
response_values[tag_key] = None
|
||||
return response_values
|
||||
|
||||
|
||||
def tags_from_cloudformation_tags_list(tags_list):
|
||||
def tags_from_cloudformation_tags_list(
|
||||
tags_list: List[Dict[str, str]]
|
||||
) -> Dict[str, str]:
|
||||
"""Return tags in dict form from cloudformation resource tags form (list of dicts)"""
|
||||
tags = {}
|
||||
for entry in tags_list:
|
||||
@ -225,7 +231,7 @@ def tags_from_cloudformation_tags_list(tags_list):
|
||||
return tags
|
||||
|
||||
|
||||
def remap_nested_keys(root, key_transform):
|
||||
def remap_nested_keys(root: Any, key_transform: Callable[[str], str]) -> Any:
|
||||
"""This remap ("recursive map") function is used to traverse and
|
||||
transform the dictionary keys of arbitrarily nested structures.
|
||||
List comprehensions do not recurse, making it tedious to apply
|
||||
@ -252,7 +258,9 @@ def remap_nested_keys(root, key_transform):
|
||||
return root
|
||||
|
||||
|
||||
def merge_dicts(dict1, dict2, remove_nulls=False):
|
||||
def merge_dicts(
|
||||
dict1: Dict[str, Any], dict2: Dict[str, Any], remove_nulls: bool = False
|
||||
) -> None:
|
||||
"""Given two arbitrarily nested dictionaries, merge the second dict into the first.
|
||||
|
||||
:param dict dict1: the dictionary to be updated.
|
||||
@ -275,7 +283,7 @@ def merge_dicts(dict1, dict2, remove_nulls=False):
|
||||
dict1.pop(key)
|
||||
|
||||
|
||||
def aws_api_matches(pattern, string):
|
||||
def aws_api_matches(pattern: str, string: str) -> bool:
|
||||
"""
|
||||
AWS API can match a value based on a glob, or an exact match
|
||||
"""
|
||||
@ -296,7 +304,7 @@ def aws_api_matches(pattern, string):
|
||||
return False
|
||||
|
||||
|
||||
def extract_region_from_aws_authorization(string):
|
||||
def extract_region_from_aws_authorization(string: str) -> Optional[str]:
|
||||
auth = string or ""
|
||||
region = re.sub(r".*Credential=[^/]+/[^/]+/([^/]+)/.*", r"\1", auth)
|
||||
if region == auth:
|
||||
|
@ -18,7 +18,7 @@ disable = W,C,R,E
|
||||
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
|
||||
|
||||
[mypy]
|
||||
files= moto/a*,moto/b*,moto/ce,moto/cloudformation,moto/cloudfront,moto/cloudtrail,moto/codebuild,moto/cloudwatch,moto/codepipeline,moto/codecommit,moto/cognito*,moto/comprehend,moto/config,moto/core/base_backend.py
|
||||
files= moto/a*,moto/b*,moto/c*
|
||||
show_column_numbers=True
|
||||
show_error_codes = True
|
||||
disable_error_code=abstract
|
||||
|
Loading…
Reference in New Issue
Block a user