From 02477195543e9789362d93265c2426784d61034d Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Thu, 15 Jun 2023 11:03:58 +0000 Subject: [PATCH] Techdebt: MyPy all the things (#6406) --- moto/__init__.py | 28 ++++-- moto/apigateway/models.py | 2 +- moto/core/models.py | 25 +++-- moto/wafv2/exceptions.py | 4 +- moto/wafv2/models.py | 71 ++++++++++---- moto/wafv2/responses.py | 31 +++--- moto/wafv2/utils.py | 4 +- moto/xray/exceptions.py | 16 +++- moto/xray/mock_client.py | 31 +++--- moto/xray/models.py | 109 ++++++++++++--------- moto/xray/responses.py | 29 +++--- setup.cfg | 2 +- tests/test_core/test_decorator_calls.py | 121 ++++++++++++------------ tests/test_core/test_mock_all.py | 4 +- 14 files changed, 275 insertions(+), 202 deletions(-) diff --git a/moto/__init__.py b/moto/__init__.py index f1baa7596..7cd7e70ac 100644 --- a/moto/__init__.py +++ b/moto/__init__.py @@ -1,17 +1,27 @@ import importlib import sys from contextlib import ContextDecorator +from moto.core.models import BaseMockAWS +from typing import Any, Callable, List, Optional, TypeVar -def lazy_load(module_name, element, boto3_name=None, backend=None): - def f(*args, **kwargs): +TEST_METHOD = TypeVar("TEST_METHOD", bound=Callable[..., Any]) + + +def lazy_load( + module_name: str, + element: str, + boto3_name: Optional[str] = None, + backend: Optional[str] = None, +) -> Callable[..., BaseMockAWS]: + def f(*args: Any, **kwargs: Any) -> Any: module = importlib.import_module(module_name, "moto") return getattr(module, element)(*args, **kwargs) setattr(f, "name", module_name.replace(".", "")) setattr(f, "element", element) - setattr(f, "boto3_name", boto3_name or f.name) - setattr(f, "backend", backend or f"{f.name}_backends") + setattr(f, "boto3_name", boto3_name or f.name) # type: ignore[attr-defined] + setattr(f, "backend", backend or f"{f.name}_backends") # type: ignore[attr-defined] return f @@ -176,17 +186,17 @@ mock_textract = lazy_load(".textract", "mock_textract") class MockAll(ContextDecorator): - def __init__(self): - self.mocks = [] + def __init__(self) -> None: + self.mocks: List[Any] = [] for mock in dir(sys.modules["moto"]): - if mock.startswith("mock_") and not mock == ("mock_all"): + if mock.startswith("mock_") and not mock == "mock_all": self.mocks.append(globals()[mock]()) - def __enter__(self): + def __enter__(self) -> None: for mock in self.mocks: mock.start() - def __exit__(self, *exc): + def __exit__(self, *exc: Any) -> None: for mock in self.mocks: mock.stop() diff --git a/moto/apigateway/models.py b/moto/apigateway/models.py index c70b26f9b..87ba2d9ac 100644 --- a/moto/apigateway/models.py +++ b/moto/apigateway/models.py @@ -626,7 +626,7 @@ class Stage(BaseModel): self.tags = tags self.tracing_enabled = tracing_enabled self.access_log_settings: Optional[Dict[str, Any]] = None - self.web_acl_arn = None + self.web_acl_arn: Optional[str] = None def to_json(self) -> Dict[str, Any]: dct: Dict[str, Any] = { diff --git a/moto/core/models.py b/moto/core/models.py index 25db72e5e..7431e1348 100644 --- a/moto/core/models.py +++ b/moto/core/models.py @@ -5,7 +5,8 @@ import os import re import unittest from types import FunctionType -from typing import Any, Callable, Dict, Optional, Set, TypeVar +from typing import Any, Callable, Dict, Optional, Set, TypeVar, Union +from typing import ContextManager from unittest.mock import patch import boto3 @@ -29,7 +30,7 @@ DEFAULT_ACCOUNT_ID = "123456789012" CALLABLE_RETURN = TypeVar("CALLABLE_RETURN") -class BaseMockAWS: +class BaseMockAWS(ContextManager["BaseMockAWS"]): nested_count = 0 mocks_active = False @@ -65,10 +66,13 @@ class BaseMockAWS: self.reset() # type: ignore[attr-defined] def __call__( - self, func: Callable[..., Any], reset: bool = True, remove_data: bool = True - ) -> Any: + self, + func: Callable[..., "BaseMockAWS"], + reset: bool = True, + remove_data: bool = True, + ) -> Callable[..., "BaseMockAWS"]: if inspect.isclass(func): - return self.decorate_class(func) + return self.decorate_class(func) # type: ignore return self.decorate_callable(func, reset, remove_data) def __enter__(self) -> "BaseMockAWS": @@ -116,9 +120,9 @@ class BaseMockAWS: self.disable_patching() # type: ignore[attr-defined] def decorate_callable( - self, func: Callable[..., CALLABLE_RETURN], reset: bool, remove_data: bool - ) -> Callable[..., CALLABLE_RETURN]: - def wrapper(*args: Any, **kwargs: Any) -> CALLABLE_RETURN: + self, func: Callable[..., "BaseMockAWS"], reset: bool, remove_data: bool + ) -> Callable[..., "BaseMockAWS"]: + def wrapper(*args: Any, **kwargs: Any) -> "BaseMockAWS": self.start(reset=reset) try: result = func(*args, **kwargs) @@ -361,7 +365,6 @@ MockAWS = BotocoreEventMockAWS class ServerModeMockAWS(BaseMockAWS): - RESET_IN_PROGRESS = False def __init__(self, *args: Any, **kwargs: Any): @@ -427,7 +430,9 @@ class base_decorator: def __init__(self, backends: BackendDict): self.backends = backends - def __call__(self, func: Optional[Callable[..., Any]] = None) -> BaseMockAWS: + def __call__( + self, func: Optional[Callable[..., Any]] = None + ) -> Union[BaseMockAWS, Callable[..., BaseMockAWS]]: if settings.TEST_SERVER_MODE: mocked_backend: BaseMockAWS = ServerModeMockAWS(self.backends) else: diff --git a/moto/wafv2/exceptions.py b/moto/wafv2/exceptions.py index 30b5d5789..218816d94 100644 --- a/moto/wafv2/exceptions.py +++ b/moto/wafv2/exceptions.py @@ -6,7 +6,7 @@ class WAFv2ClientError(JsonRESTError): class WAFV2DuplicateItemException(WAFv2ClientError): - def __init__(self): + def __init__(self) -> None: super().__init__( "WafV2DuplicateItem", "AWS WAF could not perform the operation because some resource in your request is a duplicate of an existing one.", @@ -14,7 +14,7 @@ class WAFV2DuplicateItemException(WAFv2ClientError): class WAFNonexistentItemException(WAFv2ClientError): - def __init__(self): + def __init__(self) -> None: super().__init__( "WAFNonexistentItemException", "AWS WAF couldn’t perform the operation because your resource doesn’t exist.", diff --git a/moto/wafv2/models.py b/moto/wafv2/models.py index 788fc98cc..3ff346f57 100644 --- a/moto/wafv2/models.py +++ b/moto/wafv2/models.py @@ -1,6 +1,6 @@ import datetime import re -from typing import Dict +from typing import Any, Dict, List, Optional, TYPE_CHECKING from moto.core import BaseBackend, BackendDict, BaseModel from .utils import make_arn_for_wacl @@ -10,6 +10,9 @@ from moto.moto_api._internal import mock_random from moto.utilities.tagging_service import TaggingService from collections import OrderedDict +if TYPE_CHECKING: + from moto.apigateway.models import Stage + US_EAST_1_REGION = "us-east-1" GLOBAL_REGION = "global" @@ -25,7 +28,14 @@ class FakeWebACL(BaseModel): """ def __init__( - self, name, arn, wacl_id, visibility_config, default_action, description, rules + self, + name: str, + arn: str, + wacl_id: str, + visibility_config: Dict[str, Any], + default_action: Dict[str, Any], + description: Optional[str], + rules: List[Dict[str, Any]], ): self.name = name self.created_time = iso_8601_datetime_with_milliseconds(datetime.datetime.now()) @@ -38,7 +48,13 @@ class FakeWebACL(BaseModel): self.default_action = default_action self.lock_token = str(mock_random.uuid4())[0:6] - def update(self, default_action, rules, description, visibility_config): + def update( + self, + default_action: Optional[Dict[str, Any]], + rules: Optional[List[Dict[str, Any]]], + description: Optional[str], + visibility_config: Optional[Dict[str, Any]], + ) -> None: if default_action is not None: self.default_action = default_action if rules is not None: @@ -49,7 +65,7 @@ class FakeWebACL(BaseModel): self.visibility_config = visibility_config self.lock_token = str(mock_random.uuid4())[0:6] - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: # Format for summary https://docs.aws.amazon.com/waf/latest/APIReference/API_CreateWebACL.html (response syntax section) return { "ARN": self.arn, @@ -67,13 +83,13 @@ class WAFV2Backend(BaseBackend): https://docs.aws.amazon.com/waf/latest/APIReference/API_Operations_AWS_WAFV2.html """ - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) self.wacls: Dict[str, FakeWebACL] = OrderedDict() self.tagging_service = TaggingService() # TODO: self.load_balancers = OrderedDict() - def associate_web_acl(self, web_acl_arn, resource_arn): + def associate_web_acl(self, web_acl_arn: str, resource_arn: str) -> None: """ Only APIGateway Stages can be associated at the moment. """ @@ -83,19 +99,19 @@ class WAFV2Backend(BaseBackend): if stage: stage.web_acl_arn = web_acl_arn - def disassociate_web_acl(self, resource_arn): + def disassociate_web_acl(self, resource_arn: str) -> None: stage = self._find_apigw_stage(resource_arn) if stage: stage.web_acl_arn = None - def get_web_acl_for_resource(self, resource_arn): + def get_web_acl_for_resource(self, resource_arn: str) -> Optional[FakeWebACL]: stage = self._find_apigw_stage(resource_arn) if stage and stage.web_acl_arn is not None: wacl_arn = stage.web_acl_arn return self.wacls.get(wacl_arn) return None - def _find_apigw_stage(self, resource_arn): + def _find_apigw_stage(self, resource_arn: str) -> Optional["Stage"]: # type: ignore try: if re.search(APIGATEWAY_REGEX, resource_arn): region = resource_arn.split(":")[3] @@ -110,8 +126,15 @@ class WAFV2Backend(BaseBackend): return None def create_web_acl( - self, name, visibility_config, default_action, scope, description, tags, rules - ): + self, + name: str, + visibility_config: Dict[str, Any], + default_action: Dict[str, Any], + scope: str, + description: str, + tags: List[Dict[str, str]], + rules: List[Dict[str, Any]], + ) -> FakeWebACL: """ The following parameters are not yet implemented: CustomResponseBodies, CaptchaConfig """ @@ -132,7 +155,7 @@ class WAFV2Backend(BaseBackend): self.tag_resource(arn, tags) return new_wacl - def delete_web_acl(self, name, _id): + def delete_web_acl(self, name: str, _id: str) -> None: """ The LockToken-parameter is not yet implemented """ @@ -142,37 +165,43 @@ class WAFV2Backend(BaseBackend): if wacl.name != name and wacl.id != _id } - def get_web_acl(self, name, _id) -> FakeWebACL: + def get_web_acl(self, name: str, _id: str) -> FakeWebACL: for wacl in self.wacls.values(): if wacl.name == name and wacl.id == _id: return wacl raise WAFNonexistentItemException - def list_web_acls(self): + def list_web_acls(self) -> List[Dict[str, Any]]: return [wacl.to_dict() for wacl in self.wacls.values()] - def _is_duplicate_name(self, name): + def _is_duplicate_name(self, name: str) -> bool: allWaclNames = set(wacl.name for wacl in self.wacls.values()) return name in allWaclNames - def list_rule_groups(self): + def list_rule_groups(self) -> List[Any]: return [] - def list_tags_for_resource(self, arn): + def list_tags_for_resource(self, arn: str) -> List[Dict[str, str]]: """ Pagination is not yet implemented """ return self.tagging_service.list_tags_for_resource(arn)["Tags"] - def tag_resource(self, arn, tags): + def tag_resource(self, arn: str, tags: List[Dict[str, str]]) -> None: self.tagging_service.tag_resource(arn, tags) - def untag_resource(self, arn, tag_keys): + def untag_resource(self, arn: str, tag_keys: List[str]) -> None: self.tagging_service.untag_resource_using_names(arn, tag_keys) def update_web_acl( - self, name, _id, default_action, rules, description, visibility_config - ): + self, + name: str, + _id: str, + default_action: Optional[Dict[str, Any]], + rules: Optional[List[Dict[str, Any]]], + description: Optional[str], + visibility_config: Optional[Dict[str, Any]], + ) -> str: """ The following parameters are not yet implemented: LockToken, CustomResponseBodies, CaptchaConfig """ diff --git a/moto/wafv2/responses.py b/moto/wafv2/responses.py index be327ea94..bcaaf92dd 100644 --- a/moto/wafv2/responses.py +++ b/moto/wafv2/responses.py @@ -1,19 +1,20 @@ import json +from moto.core.common_types import TYPE_RESPONSE from moto.core.responses import BaseResponse from moto.utilities.aws_headers import amzn_request_id -from .models import GLOBAL_REGION, wafv2_backends +from .models import GLOBAL_REGION, wafv2_backends, WAFV2Backend class WAFV2Response(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="wafv2") @property - def wafv2_backend(self): + def wafv2_backend(self) -> WAFV2Backend: return wafv2_backends[self.current_account][self.region] @amzn_request_id - def associate_web_acl(self): + def associate_web_acl(self) -> TYPE_RESPONSE: body = json.loads(self.body) web_acl_arn = body["WebACLArn"] resource_arn = body["ResourceArn"] @@ -21,14 +22,14 @@ class WAFV2Response(BaseResponse): return 200, {}, "{}" @amzn_request_id - def disassociate_web_acl(self): + def disassociate_web_acl(self) -> TYPE_RESPONSE: body = json.loads(self.body) resource_arn = body["ResourceArn"] self.wafv2_backend.disassociate_web_acl(resource_arn) return 200, {}, "{}" @amzn_request_id - def get_web_acl_for_resource(self): + def get_web_acl_for_resource(self) -> TYPE_RESPONSE: body = json.loads(self.body) resource_arn = body["ResourceArn"] web_acl = self.wafv2_backend.get_web_acl_for_resource(resource_arn) @@ -37,7 +38,7 @@ class WAFV2Response(BaseResponse): return 200, response_headers, json.dumps(response) @amzn_request_id - def create_web_acl(self): + def create_web_acl(self) -> TYPE_RESPONSE: """https://docs.aws.amazon.com/waf/latest/APIReference/API_CreateWebACL.html (response syntax section)""" scope = self._get_param("Scope") @@ -62,7 +63,7 @@ class WAFV2Response(BaseResponse): return 200, response_headers, json.dumps(response) @amzn_request_id - def delete_web_acl(self): + def delete_web_acl(self) -> TYPE_RESPONSE: scope = self._get_param("Scope") if scope == "CLOUDFRONT": self.region = GLOBAL_REGION @@ -73,7 +74,7 @@ class WAFV2Response(BaseResponse): return 200, response_headers, "{}" @amzn_request_id - def get_web_acl(self): + def get_web_acl(self) -> TYPE_RESPONSE: scope = self._get_param("Scope") if scope == "CLOUDFRONT": self.region = GLOBAL_REGION @@ -85,7 +86,7 @@ class WAFV2Response(BaseResponse): return 200, response_headers, json.dumps(response) @amzn_request_id - def list_web_ac_ls(self): + def list_web_ac_ls(self) -> TYPE_RESPONSE: """https://docs.aws.amazon.com/waf/latest/APIReference/API_ListWebACLs.html (response syntax section)""" scope = self._get_param("Scope") @@ -97,7 +98,7 @@ class WAFV2Response(BaseResponse): return 200, response_headers, json.dumps(response) @amzn_request_id - def list_rule_groups(self): + def list_rule_groups(self) -> TYPE_RESPONSE: scope = self._get_param("Scope") if scope == "CLOUDFRONT": self.region = GLOBAL_REGION @@ -107,7 +108,7 @@ class WAFV2Response(BaseResponse): return 200, response_headers, json.dumps(response) @amzn_request_id - def list_tags_for_resource(self): + def list_tags_for_resource(self) -> TYPE_RESPONSE: arn = self._get_param("ResourceARN") self.region = arn.split(":")[3] tags = self.wafv2_backend.list_tags_for_resource(arn) @@ -116,7 +117,7 @@ class WAFV2Response(BaseResponse): return 200, response_headers, json.dumps(response) @amzn_request_id - def tag_resource(self): + def tag_resource(self) -> TYPE_RESPONSE: body = json.loads(self.body) arn = body.get("ResourceARN") self.region = arn.split(":")[3] @@ -125,7 +126,7 @@ class WAFV2Response(BaseResponse): return 200, {}, "{}" @amzn_request_id - def untag_resource(self): + def untag_resource(self) -> TYPE_RESPONSE: body = json.loads(self.body) arn = body.get("ResourceARN") self.region = arn.split(":")[3] @@ -134,7 +135,7 @@ class WAFV2Response(BaseResponse): return 200, {}, "{}" @amzn_request_id - def update_web_acl(self): + def update_web_acl(self) -> TYPE_RESPONSE: body = json.loads(self.body) name = body.get("Name") _id = body.get("Id") diff --git a/moto/wafv2/utils.py b/moto/wafv2/utils.py index 9be18951f..c746ff614 100644 --- a/moto/wafv2/utils.py +++ b/moto/wafv2/utils.py @@ -1,4 +1,6 @@ -def make_arn_for_wacl(name, account_id, region_name, wacl_id, scope): +def make_arn_for_wacl( + name: str, account_id: str, region_name: str, wacl_id: str, scope: str +) -> str: """https://docs.aws.amazon.com/waf/latest/developerguide/how-aws-waf-works.html - explains --scope (cloudfront vs regional)""" if scope == "REGIONAL": diff --git a/moto/xray/exceptions.py b/moto/xray/exceptions.py index 5462bcf33..f56295f77 100644 --- a/moto/xray/exceptions.py +++ b/moto/xray/exceptions.py @@ -1,14 +1,22 @@ +from typing import Any, Dict, Optional + + class BadSegmentException(Exception): - def __init__(self, seg_id=None, code=None, message=None): + def __init__( + self, + seg_id: Optional[str] = None, + code: Optional[str] = None, + message: Optional[str] = None, + ): self.id = seg_id self.code = code self.message = message - def __repr__(self): + def __repr__(self) -> str: return f"" - def to_dict(self): - result = {} + def to_dict(self) -> Dict[str, Any]: + result: Dict[str, Any] = {} if self.id is not None: result["Id"] = self.id if self.code is not None: diff --git a/moto/xray/mock_client.py b/moto/xray/mock_client.py index eca1ffe5b..ffad067ea 100644 --- a/moto/xray/mock_client.py +++ b/moto/xray/mock_client.py @@ -1,25 +1,26 @@ import os -from moto.xray import xray_backends +from typing import Any +from moto.xray.models import xray_backends, XRayBackend import aws_xray_sdk.core from aws_xray_sdk.core.context import Context as AWSContext from aws_xray_sdk.core.emitters.udp_emitter import UDPEmitter -class MockEmitter(UDPEmitter): +class MockEmitter(UDPEmitter): # type: ignore """ Replaces the code that sends UDP to local X-Ray daemon """ - def __init__(self, daemon_address="127.0.0.1:2000"): + def __init__(self, daemon_address: str = "127.0.0.1:2000"): address = os.getenv( "AWS_XRAY_DAEMON_ADDRESS_YEAH_NOT_TODAY_MATE", daemon_address ) self._ip, self._port = self._parse_address(address) - def _xray_backend(self, region): - return xray_backends[region] + def _xray_backend(self, account_id: str, region: str) -> XRayBackend: + return xray_backends[account_id][region] - def send_entity(self, entity): + def send_entity(self, entity: Any) -> None: # Hack to get region # region = entity.subsegments[0].aws['region'] # xray = self._xray_backend(region) @@ -27,7 +28,7 @@ class MockEmitter(UDPEmitter): # TODO store X-Ray data, pretty sure X-Ray needs refactor for this pass - def _send_data(self, data): + def _send_data(self, data: Any) -> None: raise RuntimeError("Should not be running this") @@ -42,11 +43,11 @@ class MockXrayClient: that itno the recorder instance. """ - def __call__(self, f=None): + def __call__(self, f: Any = None) -> Any: if not f: return self - def wrapped_f(*args, **kwargs): + def wrapped_f(*args: Any, **kwargs: Any) -> Any: self.start() try: f(*args, **kwargs) @@ -55,7 +56,7 @@ class MockXrayClient: return wrapped_f - def start(self): + def start(self) -> None: print("Starting X-Ray Patch") # noqa self.old_xray_context_var = os.environ.get("AWS_XRAY_CONTEXT_MISSING") os.environ["AWS_XRAY_CONTEXT_MISSING"] = "LOG_ERROR" @@ -64,7 +65,7 @@ class MockXrayClient: aws_xray_sdk.core.xray_recorder._context = AWSContext() aws_xray_sdk.core.xray_recorder._emitter = MockEmitter() - def stop(self): + def stop(self) -> None: if self.old_xray_context_var is None: del os.environ["AWS_XRAY_CONTEXT_MISSING"] else: @@ -73,11 +74,11 @@ class MockXrayClient: aws_xray_sdk.core.xray_recorder._emitter = self.old_xray_emitter aws_xray_sdk.core.xray_recorder._context = self.old_xray_context - def __enter__(self): + def __enter__(self) -> "MockXrayClient": self.start() return self - def __exit__(self, *args): + def __exit__(self, *args: Any) -> None: self.stop() @@ -91,12 +92,12 @@ class XRaySegment(object): During testing we're going to have to control the start and end of a segment via context managers. """ - def __enter__(self): + def __enter__(self) -> "XRaySegment": aws_xray_sdk.core.xray_recorder.begin_segment( name="moto_mock", traceid=None, parent_id=None, sampling=1 ) return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: aws_xray_sdk.core.xray_recorder.end_segment() diff --git a/moto/xray/models.py b/moto/xray/models.py index a1eb40e49..76bee1dd9 100644 --- a/moto/xray/models.py +++ b/moto/xray/models.py @@ -1,6 +1,7 @@ import bisect import datetime from collections import defaultdict +from typing import Any, Dict, List, Optional, Tuple import json from moto.core import BaseBackend, BackendDict, BaseModel from moto.core.exceptions import AWSError @@ -8,54 +9,60 @@ from .exceptions import BadSegmentException class TelemetryRecords(BaseModel): - def __init__(self, instance_id, hostname, resource_arn, records): + def __init__( + self, + instance_id: str, + hostname: str, + resource_arn: str, + records: List[Dict[str, Any]], + ): self.instance_id = instance_id self.hostname = hostname self.resource_arn = resource_arn self.records = records @classmethod - def from_json(cls, src): + def from_json(cls, src: Dict[str, Any]) -> "TelemetryRecords": # type: ignore[misc] instance_id = src.get("EC2InstanceId", None) hostname = src.get("Hostname") resource_arn = src.get("ResourceARN") telemetry_records = src["TelemetryRecords"] - return cls(instance_id, hostname, resource_arn, telemetry_records) + return cls(instance_id, hostname, resource_arn, telemetry_records) # type: ignore # https://docs.aws.amazon.com/xray/latest/devguide/xray-api-segmentdocuments.html class TraceSegment(BaseModel): def __init__( self, - name, - segment_id, - trace_id, - start_time, - raw, - end_time=None, - in_progress=False, - service=None, - user=None, - origin=None, - parent_id=None, - http=None, - aws=None, - metadata=None, - annotations=None, - subsegments=None, - **kwargs + name: str, + segment_id: str, + trace_id: str, + start_time: float, + raw: Any, + end_time: Optional[float] = None, + in_progress: bool = False, + service: Any = None, + user: Any = None, + origin: Any = None, + parent_id: Any = None, + http: Any = None, + aws: Any = None, + metadata: Any = None, + annotations: Any = None, + subsegments: Any = None, + **kwargs: Any ): self.name = name self.id = segment_id self.trace_id = trace_id - self._trace_version = None - self._original_request_start_time = None + self._trace_version: Optional[int] = None + self._original_request_start_time: Optional[datetime.datetime] = None self._trace_identifier = None self.start_time = start_time - self._start_date = None + self._start_date: Optional[datetime.datetime] = None self.end_time = end_time - self._end_date = None + self._end_date: Optional[datetime.datetime] = None self.in_progress = in_progress self.service = service self.user = user @@ -71,17 +78,17 @@ class TraceSegment(BaseModel): # Raw json string self.raw = raw - def __lt__(self, other): + def __lt__(self, other: Any) -> bool: return self.start_date < other.start_date @property - def trace_version(self): + def trace_version(self) -> int: if self._trace_version is None: self._trace_version = int(self.trace_id.split("-", 1)[0]) return self._trace_version @property - def request_start_date(self): + def request_start_date(self) -> datetime.datetime: if self._original_request_start_time is None: start_time = int(self.trace_id.split("-")[1], 16) self._original_request_start_time = datetime.datetime.fromtimestamp( @@ -90,19 +97,19 @@ class TraceSegment(BaseModel): return self._original_request_start_time @property - def start_date(self): + def start_date(self) -> datetime.datetime: if self._start_date is None: self._start_date = datetime.datetime.fromtimestamp(self.start_time) return self._start_date @property - def end_date(self): + def end_date(self) -> datetime.datetime: if self._end_date is None: - self._end_date = datetime.datetime.fromtimestamp(self.end_time) + self._end_date = datetime.datetime.fromtimestamp(self.end_time) # type: ignore return self._end_date @classmethod - def from_dict(cls, data, raw): + def from_dict(cls, data: Dict[str, Any], raw: Any) -> "TraceSegment": # type: ignore[misc] # Check manditory args if "id" not in data: raise BadSegmentException(code="MissingParam", message="Missing segment ID") @@ -130,11 +137,11 @@ class TraceSegment(BaseModel): class SegmentCollection(object): - def __init__(self): - self._traces = defaultdict(self._new_trace_item) + def __init__(self) -> None: + self._traces: Dict[str, Dict[str, Any]] = defaultdict(self._new_trace_item) @staticmethod - def _new_trace_item(): + def _new_trace_item() -> Dict[str, Any]: # type: ignore[misc] return { "start_date": datetime.datetime(1970, 1, 1), "end_date": datetime.datetime(1970, 1, 1), @@ -143,7 +150,7 @@ class SegmentCollection(object): "segments": [], } - def put_segment(self, segment): + def put_segment(self, segment: Any) -> None: # insert into a sorted list bisect.insort_left(self._traces[segment.trace_id]["segments"], segment) @@ -160,11 +167,15 @@ class SegmentCollection(object): # Todo consolidate trace segments into a trace. # not enough working knowledge of xray to do this - def summary(self, start_time, end_time, filter_expression=None): + def summary( + self, start_time: str, end_time: str, filter_expression: Any = None + ) -> Dict[str, Any]: # This beast https://docs.aws.amazon.com/xray/latest/api/API_GetTraceSummaries.html#API_GetTraceSummaries_ResponseSyntax if filter_expression is not None: raise AWSError( - "Not implemented yet - moto", code="InternalFailure", status=500 + "Not implemented yet - moto", + exception_type="InternalFailure", + status=500, ) summaries = [] @@ -213,7 +224,9 @@ class SegmentCollection(object): return result - def get_trace_ids(self, trace_ids): + def get_trace_ids( + self, trace_ids: List[str] + ) -> Tuple[List[Dict[str, Any]], List[str]]: traces = [] unprocessed = [] @@ -229,22 +242,24 @@ class SegmentCollection(object): class XRayBackend(BaseBackend): - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self._telemetry_records = [] + self._telemetry_records: List[TelemetryRecords] = [] self._segment_collection = SegmentCollection() @staticmethod - def default_vpc_endpoint_service(service_region, zones): + def default_vpc_endpoint_service( + service_region: str, zones: List[str] + ) -> List[Dict[str, str]]: """Default VPC endpoint service.""" return BaseBackend.default_vpc_endpoint_service_factory( service_region, zones, "xray" ) - def add_telemetry_records(self, src): + def add_telemetry_records(self, src: Any) -> None: self._telemetry_records.append(TelemetryRecords.from_json(src)) - def process_segment(self, doc): + def process_segment(self, doc: Any) -> None: try: data = json.loads(doc) except ValueError: @@ -264,13 +279,15 @@ class XRayBackend(BaseBackend): seg_id=segment.id, code="InternalFailure", message=str(err) ) - def get_trace_summary(self, start_time, end_time, filter_expression): + def get_trace_summary( + self, start_time: str, end_time: str, filter_expression: Any + ) -> Dict[str, Any]: return self._segment_collection.summary(start_time, end_time, filter_expression) - def get_trace_ids(self, trace_ids): + def get_trace_ids(self, trace_ids: List[str]) -> Dict[str, Any]: traces, unprocessed_ids = self._segment_collection.get_trace_ids(trace_ids) - result = {"Traces": [], "UnprocessedTraceIds": unprocessed_ids} + result: Dict[str, Any] = {"Traces": [], "UnprocessedTraceIds": unprocessed_ids} for trace in traces: segments = [] diff --git a/moto/xray/responses.py b/moto/xray/responses.py index 4b1d1c548..660e9559e 100644 --- a/moto/xray/responses.py +++ b/moto/xray/responses.py @@ -1,49 +1,50 @@ import json import datetime +from typing import Any, Dict, Tuple, Union from moto.core.responses import BaseResponse from moto.core.exceptions import AWSError from urllib.parse import urlsplit -from .models import xray_backends +from .models import xray_backends, XRayBackend from .exceptions import BadSegmentException class XRayResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="xray") - def _error(self, code, message): + def _error(self, code: str, message: str) -> Tuple[str, Dict[str, int]]: return json.dumps({"__type": code, "message": message}), dict(status=400) @property - def xray_backend(self): + def xray_backend(self) -> XRayBackend: return xray_backends[self.current_account][self.region] @property - def request_params(self): + def request_params(self) -> Any: # type: ignore[misc] try: return json.loads(self.body) except ValueError: return {} - def _get_param(self, param_name, if_none=None): + def _get_param(self, param_name: str, if_none: Any = None) -> Any: return self.request_params.get(param_name, if_none) - def _get_action(self): + def _get_action(self) -> str: # Amazon is just calling urls like /TelemetryRecords etc... # This uses the value after / as the camalcase action, which then # gets converted in call_action to find the following methods return urlsplit(self.uri).path.lstrip("/") # PutTelemetryRecords - def telemetry_records(self): + def telemetry_records(self) -> str: self.xray_backend.add_telemetry_records(self.request_params) return "" # PutTraceSegments - def trace_segments(self): + def trace_segments(self) -> Union[str, Tuple[str, Dict[str, int]]]: docs = self._get_param("TraceSegmentDocuments") if docs is None: @@ -71,7 +72,7 @@ class XRayResponse(BaseResponse): return json.dumps(result) # GetTraceSummaries - def trace_summaries(self): + def trace_summaries(self) -> Union[str, Tuple[str, Dict[str, int]]]: start_time = self._get_param("StartTime") end_time = self._get_param("EndTime") if start_time is None: @@ -119,7 +120,7 @@ class XRayResponse(BaseResponse): return json.dumps(result) # BatchGetTraces - def traces(self): + def traces(self) -> Union[str, Tuple[str, Dict[str, int]]]: trace_ids = self._get_param("TraceIds") if trace_ids is None: @@ -142,7 +143,7 @@ class XRayResponse(BaseResponse): return json.dumps(result) # GetServiceGraph - just a dummy response for now - def service_graph(self): + def service_graph(self) -> Union[str, Tuple[str, Dict[str, int]]]: start_time = self._get_param("StartTime") end_time = self._get_param("EndTime") # next_token = self._get_param('NextToken') # not implemented yet @@ -164,7 +165,7 @@ class XRayResponse(BaseResponse): return json.dumps(result) # GetTraceGraph - just a dummy response for now - def trace_graph(self): + def trace_graph(self) -> Union[str, Tuple[str, Dict[str, int]]]: trace_ids = self._get_param("TraceIds") # next_token = self._get_param('NextToken') # not implemented yet @@ -175,5 +176,5 @@ class XRayResponse(BaseResponse): dict(status=400), ) - result = {"Services": []} + result: Dict[str, Any] = {"Services": []} return json.dumps(result) diff --git a/setup.cfg b/setup.cfg index 103fb04d5..b90875a0c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -239,7 +239,7 @@ disable = W,C,R,E enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import [mypy] -files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s*,moto/u*,moto/t* +files= moto, tests/test_core/test_mock_all.py, tests/test_core/test_decorator_calls.py show_column_numbers=True show_error_codes = True disable_error_code=abstract diff --git a/tests/test_core/test_decorator_calls.py b/tests/test_core/test_decorator_calls.py index 844f11b5c..303ee94ea 100644 --- a/tests/test_core/test_decorator_calls.py +++ b/tests/test_core/test_decorator_calls.py @@ -1,10 +1,11 @@ import boto3 import pytest -import sure # noqa # pylint: disable=unused-import import unittest from botocore.exceptions import ClientError +from typing import Any from moto import mock_ec2, mock_kinesis, mock_s3, settings +from moto.core.models import BaseMockAWS from unittest import SkipTest """ @@ -13,13 +14,13 @@ Test the different ways that the decorator can be used @mock_ec2 -def test_basic_decorator(): +def test_basic_decorator() -> None: client = boto3.client("ec2", region_name="us-west-1") - client.describe_addresses()["Addresses"].should.equal([]) + assert client.describe_addresses()["Addresses"] == [] @pytest.fixture(name="aws_credentials") -def fixture_aws_credentials(monkeypatch): +def fixture_aws_credentials(monkeypatch: Any) -> None: # type: ignore[misc] """Mocked AWS Credentials for moto.""" monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing") monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing") @@ -28,97 +29,95 @@ def fixture_aws_credentials(monkeypatch): @pytest.mark.network -def test_context_manager(aws_credentials): # pylint: disable=unused-argument +def test_context_manager(aws_credentials: Any) -> None: # type: ignore[misc] # pylint: disable=unused-argument client = boto3.client("ec2", region_name="us-west-1") with pytest.raises(ClientError) as exc: client.describe_addresses() err = exc.value.response["Error"] - err["Code"].should.equal("AuthFailure") - err["Message"].should.equal( - "AWS was not able to validate the provided access credentials" + assert err["Code"] == "AuthFailure" + assert ( + err["Message"] == "AWS was not able to validate the provided access credentials" ) with mock_ec2(): client = boto3.client("ec2", region_name="us-west-1") - client.describe_addresses()["Addresses"].should.equal([]) + assert client.describe_addresses()["Addresses"] == [] @pytest.mark.network -def test_decorator_start_and_stop(): +def test_decorator_start_and_stop() -> None: if settings.TEST_SERVER_MODE: raise SkipTest("Authentication always works in ServerMode") - mock = mock_ec2() + mock: BaseMockAWS = mock_ec2() mock.start() client = boto3.client("ec2", region_name="us-west-1") - client.describe_addresses()["Addresses"].should.equal([]) + assert client.describe_addresses()["Addresses"] == [] mock.stop() with pytest.raises(ClientError) as exc: client.describe_addresses() err = exc.value.response["Error"] - err["Code"].should.equal("AuthFailure") - err["Message"].should.equal( - "AWS was not able to validate the provided access credentials" + assert err["Code"] == "AuthFailure" + assert ( + err["Message"] == "AWS was not able to validate the provided access credentials" ) @mock_ec2 -def test_decorater_wrapped_gets_set(): +def test_decorater_wrapped_gets_set() -> None: """ Moto decorator's __wrapped__ should get set to the tests function """ - test_decorater_wrapped_gets_set.__wrapped__.__name__.should.equal( - "test_decorater_wrapped_gets_set" - ) + assert test_decorater_wrapped_gets_set.__wrapped__.__name__ == "test_decorater_wrapped_gets_set" # type: ignore @mock_ec2 -class Tester(object): - def test_the_class(self): +class Tester: + def test_the_class(self) -> None: client = boto3.client("ec2", region_name="us-west-1") - client.describe_addresses()["Addresses"].should.equal([]) + assert client.describe_addresses()["Addresses"] == [] - def test_still_the_same(self): + def test_still_the_same(self) -> None: client = boto3.client("ec2", region_name="us-west-1") - client.describe_addresses()["Addresses"].should.equal([]) + assert client.describe_addresses()["Addresses"] == [] @mock_s3 class TesterWithSetup(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.client = boto3.client("s3") self.client.create_bucket(Bucket="mybucket") - def test_still_the_same(self): + def test_still_the_same(self) -> None: buckets = self.client.list_buckets()["Buckets"] bucket_names = [b["Name"] for b in buckets] # There is a potential bug in the class-decorator, where the reset API is not called on start. # This leads to a situation where 'bucket_names' may contain buckets created by earlier tests - bucket_names.should.contain("mybucket") + assert "mybucket" in bucket_names @mock_s3 -class TesterWithStaticmethod(object): +class TesterWithStaticmethod: @staticmethod - def static(*args): + def static(*args: Any) -> None: # type: ignore[misc] assert not args or not isinstance(args[0], TesterWithStaticmethod) - def test_no_instance_sent_to_staticmethod(self): + def test_no_instance_sent_to_staticmethod(self) -> None: self.static() @mock_s3 class TestWithSetup_UppercaseU(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: # This method will be executed automatically, provided we extend the TestCase-class s3 = boto3.client("s3", region_name="us-east-1") s3.create_bucket(Bucket="mybucket") - def test_should_find_bucket(self): + def test_should_find_bucket(self) -> None: s3 = boto3.client("s3", region_name="us-east-1") self.assertIsNotNone(s3.head_bucket(Bucket="mybucket")) - def test_should_not_find_unknown_bucket(self): + def test_should_not_find_unknown_bucket(self) -> None: s3 = boto3.client("s3", region_name="us-east-1") with pytest.raises(ClientError): s3.head_bucket(Bucket="unknown_bucket") @@ -126,16 +125,16 @@ class TestWithSetup_UppercaseU(unittest.TestCase): @mock_s3 class TestWithSetup_LowercaseU: - def setup_method(self, *args): # pylint: disable=unused-argument + def setup_method(self, *args: Any) -> None: # pylint: disable=unused-argument # This method will be executed automatically using pytest s3 = boto3.client("s3", region_name="us-east-1") s3.create_bucket(Bucket="mybucket") - def test_should_find_bucket(self): + def test_should_find_bucket(self) -> None: s3 = boto3.client("s3", region_name="us-east-1") assert s3.head_bucket(Bucket="mybucket") is not None - def test_should_not_find_unknown_bucket(self): + def test_should_not_find_unknown_bucket(self) -> None: s3 = boto3.client("s3", region_name="us-east-1") with pytest.raises(ClientError): s3.head_bucket(Bucket="unknown_bucket") @@ -143,16 +142,16 @@ class TestWithSetup_LowercaseU: @mock_s3 class TestWithSetupMethod: - def setup_method(self, *args): # pylint: disable=unused-argument + def setup_method(self, *args: Any) -> None: # pylint: disable=unused-argument # This method will be executed automatically using pytest s3 = boto3.client("s3", region_name="us-east-1") s3.create_bucket(Bucket="mybucket") - def test_should_find_bucket(self): + def test_should_find_bucket(self) -> None: s3 = boto3.client("s3", region_name="us-east-1") assert s3.head_bucket(Bucket="mybucket") is not None - def test_should_not_find_unknown_bucket(self): + def test_should_not_find_unknown_bucket(self) -> None: s3 = boto3.client("s3", region_name="us-east-1") with pytest.raises(ClientError): s3.head_bucket(Bucket="unknown_bucket") @@ -160,17 +159,17 @@ class TestWithSetupMethod: @mock_kinesis class TestKinesisUsingSetupMethod: - def setup_method(self, *args): # pylint: disable=unused-argument + def setup_method(self, *args: Any) -> None: # pylint: disable=unused-argument self.stream_name = "test_stream" self.boto3_kinesis_client = boto3.client("kinesis", region_name="us-east-1") self.boto3_kinesis_client.create_stream( StreamName=self.stream_name, ShardCount=1 ) - def test_stream_creation(self): + def test_stream_creation(self) -> None: pass - def test_stream_recreation(self): + def test_stream_recreation(self) -> None: # The setup-method will run again for this test # The fact that it passes, means the state was reset # Otherwise it would complain about a stream already existing @@ -179,11 +178,11 @@ class TestKinesisUsingSetupMethod: @mock_s3 class TestWithInvalidSetupMethod: - def setupmethod(self): + def setupmethod(self) -> None: s3 = boto3.client("s3", region_name="us-east-1") s3.create_bucket(Bucket="mybucket") - def test_should_not_find_bucket(self): + def test_should_not_find_bucket(self) -> None: # Name of setupmethod is not recognized, so it will not be executed s3 = boto3.client("s3", region_name="us-east-1") with pytest.raises(ClientError): @@ -192,17 +191,17 @@ class TestWithInvalidSetupMethod: @mock_s3 class TestWithPublicMethod(unittest.TestCase): - def ensure_bucket_exists(self): + def ensure_bucket_exists(self) -> None: s3 = boto3.client("s3", region_name="us-east-1") s3.create_bucket(Bucket="mybucket") - def test_should_find_bucket(self): + def test_should_find_bucket(self) -> None: self.ensure_bucket_exists() s3 = boto3.client("s3", region_name="us-east-1") s3.head_bucket(Bucket="mybucket").shouldnt.equal(None) - def test_should_not_find_bucket(self): + def test_should_not_find_bucket(self) -> None: s3 = boto3.client("s3", region_name="us-east-1") with pytest.raises(ClientError): s3.head_bucket(Bucket="mybucket") @@ -210,16 +209,16 @@ class TestWithPublicMethod(unittest.TestCase): @mock_s3 class TestWithPseudoPrivateMethod(unittest.TestCase): - def _ensure_bucket_exists(self): + def _ensure_bucket_exists(self) -> None: s3 = boto3.client("s3", region_name="us-east-1") s3.create_bucket(Bucket="mybucket") - def test_should_find_bucket(self): + def test_should_find_bucket(self) -> None: self._ensure_bucket_exists() s3 = boto3.client("s3", region_name="us-east-1") s3.head_bucket(Bucket="mybucket").shouldnt.equal(None) - def test_should_not_find_bucket(self): + def test_should_not_find_bucket(self) -> None: s3 = boto3.client("s3", region_name="us-east-1") with pytest.raises(ClientError): s3.head_bucket(Bucket="mybucket") @@ -227,20 +226,20 @@ class TestWithPseudoPrivateMethod(unittest.TestCase): @mock_s3 class Baseclass(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.s3 = boto3.resource("s3", region_name="us-east-1") self.client = boto3.client("s3", region_name="us-east-1") self.test_bucket = self.s3.Bucket("testbucket") self.test_bucket.create() - def tearDown(self): + def tearDown(self) -> None: # The bucket will still exist at this point self.test_bucket.delete() @mock_s3 class TestSetUpInBaseClass(Baseclass): - def test_a_thing(self): + def test_a_thing(self) -> None: # Verify that we can 'see' the setUp-method in the parent class self.client.head_bucket(Bucket="testbucket").shouldnt.equal(None) @@ -248,42 +247,42 @@ class TestSetUpInBaseClass(Baseclass): @mock_s3 class TestWithNestedClasses: class NestedClass(unittest.TestCase): - def _ensure_bucket_exists(self): + def _ensure_bucket_exists(self) -> None: s3 = boto3.client("s3", region_name="us-east-1") s3.create_bucket(Bucket="bucketclass1") - def test_should_find_bucket(self): + def test_should_find_bucket(self) -> None: self._ensure_bucket_exists() s3 = boto3.client("s3", region_name="us-east-1") s3.head_bucket(Bucket="bucketclass1") class NestedClass2(unittest.TestCase): - def _ensure_bucket_exists(self): + def _ensure_bucket_exists(self) -> None: s3 = boto3.client("s3", region_name="us-east-1") s3.create_bucket(Bucket="bucketclass2") - def test_should_find_bucket(self): + def test_should_find_bucket(self) -> None: self._ensure_bucket_exists() s3 = boto3.client("s3", region_name="us-east-1") s3.head_bucket(Bucket="bucketclass2") - def test_should_not_find_bucket_from_different_class(self): + def test_should_not_find_bucket_from_different_class(self) -> None: s3 = boto3.client("s3", region_name="us-east-1") with pytest.raises(ClientError): s3.head_bucket(Bucket="bucketclass1") class TestWithSetup(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: s3 = boto3.client("s3", region_name="us-east-1") s3.create_bucket(Bucket="mybucket") - def test_should_find_bucket(self): + def test_should_find_bucket(self) -> None: s3 = boto3.client("s3", region_name="us-east-1") s3.head_bucket(Bucket="mybucket") s3.create_bucket(Bucket="bucketinsidetest") - def test_should_not_find_bucket_from_test_method(self): + def test_should_not_find_bucket_from_test_method(self) -> None: s3 = boto3.client("s3", region_name="us-east-1") s3.head_bucket(Bucket="mybucket") diff --git a/tests/test_core/test_mock_all.py b/tests/test_core/test_mock_all.py index 1a7e07f83..6e7f227f8 100644 --- a/tests/test_core/test_mock_all.py +++ b/tests/test_core/test_mock_all.py @@ -6,7 +6,7 @@ from moto import mock_all @mock_all() -def test_decorator(): +def test_decorator() -> None: rgn = "us-east-1" sqs = boto3.client("sqs", region_name=rgn) r = sqs.list_queues() @@ -21,7 +21,7 @@ def test_decorator(): r["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) -def test_context_manager(): +def test_context_manager() -> None: rgn = "us-east-1" with mock_all():