Techdebt: MyPy all the things (#6406)

This commit is contained in:
Bert Blommers 2023-06-15 11:03:58 +00:00 committed by GitHub
parent 2e3b06bbe5
commit 0247719554
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 275 additions and 202 deletions

View File

@ -1,17 +1,27 @@
import importlib
import sys
from contextlib import ContextDecorator
from moto.core.models import BaseMockAWS
from typing import Any, Callable, List, Optional, TypeVar
def lazy_load(module_name, element, boto3_name=None, backend=None):
def f(*args, **kwargs):
TEST_METHOD = TypeVar("TEST_METHOD", bound=Callable[..., Any])
def lazy_load(
module_name: str,
element: str,
boto3_name: Optional[str] = None,
backend: Optional[str] = None,
) -> Callable[..., BaseMockAWS]:
def f(*args: Any, **kwargs: Any) -> Any:
module = importlib.import_module(module_name, "moto")
return getattr(module, element)(*args, **kwargs)
setattr(f, "name", module_name.replace(".", ""))
setattr(f, "element", element)
setattr(f, "boto3_name", boto3_name or f.name)
setattr(f, "backend", backend or f"{f.name}_backends")
setattr(f, "boto3_name", boto3_name or f.name) # type: ignore[attr-defined]
setattr(f, "backend", backend or f"{f.name}_backends") # type: ignore[attr-defined]
return f
@ -176,17 +186,17 @@ mock_textract = lazy_load(".textract", "mock_textract")
class MockAll(ContextDecorator):
def __init__(self):
self.mocks = []
def __init__(self) -> None:
self.mocks: List[Any] = []
for mock in dir(sys.modules["moto"]):
if mock.startswith("mock_") and not mock == ("mock_all"):
if mock.startswith("mock_") and not mock == "mock_all":
self.mocks.append(globals()[mock]())
def __enter__(self):
def __enter__(self) -> None:
for mock in self.mocks:
mock.start()
def __exit__(self, *exc):
def __exit__(self, *exc: Any) -> None:
for mock in self.mocks:
mock.stop()

View File

@ -626,7 +626,7 @@ class Stage(BaseModel):
self.tags = tags
self.tracing_enabled = tracing_enabled
self.access_log_settings: Optional[Dict[str, Any]] = None
self.web_acl_arn = None
self.web_acl_arn: Optional[str] = None
def to_json(self) -> Dict[str, Any]:
dct: Dict[str, Any] = {

View File

@ -5,7 +5,8 @@ import os
import re
import unittest
from types import FunctionType
from typing import Any, Callable, Dict, Optional, Set, TypeVar
from typing import Any, Callable, Dict, Optional, Set, TypeVar, Union
from typing import ContextManager
from unittest.mock import patch
import boto3
@ -29,7 +30,7 @@ DEFAULT_ACCOUNT_ID = "123456789012"
CALLABLE_RETURN = TypeVar("CALLABLE_RETURN")
class BaseMockAWS:
class BaseMockAWS(ContextManager["BaseMockAWS"]):
nested_count = 0
mocks_active = False
@ -65,10 +66,13 @@ class BaseMockAWS:
self.reset() # type: ignore[attr-defined]
def __call__(
self, func: Callable[..., Any], reset: bool = True, remove_data: bool = True
) -> Any:
self,
func: Callable[..., "BaseMockAWS"],
reset: bool = True,
remove_data: bool = True,
) -> Callable[..., "BaseMockAWS"]:
if inspect.isclass(func):
return self.decorate_class(func)
return self.decorate_class(func) # type: ignore
return self.decorate_callable(func, reset, remove_data)
def __enter__(self) -> "BaseMockAWS":
@ -116,9 +120,9 @@ class BaseMockAWS:
self.disable_patching() # type: ignore[attr-defined]
def decorate_callable(
self, func: Callable[..., CALLABLE_RETURN], reset: bool, remove_data: bool
) -> Callable[..., CALLABLE_RETURN]:
def wrapper(*args: Any, **kwargs: Any) -> CALLABLE_RETURN:
self, func: Callable[..., "BaseMockAWS"], reset: bool, remove_data: bool
) -> Callable[..., "BaseMockAWS"]:
def wrapper(*args: Any, **kwargs: Any) -> "BaseMockAWS":
self.start(reset=reset)
try:
result = func(*args, **kwargs)
@ -361,7 +365,6 @@ MockAWS = BotocoreEventMockAWS
class ServerModeMockAWS(BaseMockAWS):
RESET_IN_PROGRESS = False
def __init__(self, *args: Any, **kwargs: Any):
@ -427,7 +430,9 @@ class base_decorator:
def __init__(self, backends: BackendDict):
self.backends = backends
def __call__(self, func: Optional[Callable[..., Any]] = None) -> BaseMockAWS:
def __call__(
self, func: Optional[Callable[..., Any]] = None
) -> Union[BaseMockAWS, Callable[..., BaseMockAWS]]:
if settings.TEST_SERVER_MODE:
mocked_backend: BaseMockAWS = ServerModeMockAWS(self.backends)
else:

View File

@ -6,7 +6,7 @@ class WAFv2ClientError(JsonRESTError):
class WAFV2DuplicateItemException(WAFv2ClientError):
def __init__(self):
def __init__(self) -> None:
super().__init__(
"WafV2DuplicateItem",
"AWS WAF could not perform the operation because some resource in your request is a duplicate of an existing one.",
@ -14,7 +14,7 @@ class WAFV2DuplicateItemException(WAFv2ClientError):
class WAFNonexistentItemException(WAFv2ClientError):
def __init__(self):
def __init__(self) -> None:
super().__init__(
"WAFNonexistentItemException",
"AWS WAF couldnt perform the operation because your resource doesnt exist.",

View File

@ -1,6 +1,6 @@
import datetime
import re
from typing import Dict
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from moto.core import BaseBackend, BackendDict, BaseModel
from .utils import make_arn_for_wacl
@ -10,6 +10,9 @@ from moto.moto_api._internal import mock_random
from moto.utilities.tagging_service import TaggingService
from collections import OrderedDict
if TYPE_CHECKING:
from moto.apigateway.models import Stage
US_EAST_1_REGION = "us-east-1"
GLOBAL_REGION = "global"
@ -25,7 +28,14 @@ class FakeWebACL(BaseModel):
"""
def __init__(
self, name, arn, wacl_id, visibility_config, default_action, description, rules
self,
name: str,
arn: str,
wacl_id: str,
visibility_config: Dict[str, Any],
default_action: Dict[str, Any],
description: Optional[str],
rules: List[Dict[str, Any]],
):
self.name = name
self.created_time = iso_8601_datetime_with_milliseconds(datetime.datetime.now())
@ -38,7 +48,13 @@ class FakeWebACL(BaseModel):
self.default_action = default_action
self.lock_token = str(mock_random.uuid4())[0:6]
def update(self, default_action, rules, description, visibility_config):
def update(
self,
default_action: Optional[Dict[str, Any]],
rules: Optional[List[Dict[str, Any]]],
description: Optional[str],
visibility_config: Optional[Dict[str, Any]],
) -> None:
if default_action is not None:
self.default_action = default_action
if rules is not None:
@ -49,7 +65,7 @@ class FakeWebACL(BaseModel):
self.visibility_config = visibility_config
self.lock_token = str(mock_random.uuid4())[0:6]
def to_dict(self):
def to_dict(self) -> Dict[str, Any]:
# Format for summary https://docs.aws.amazon.com/waf/latest/APIReference/API_CreateWebACL.html (response syntax section)
return {
"ARN": self.arn,
@ -67,13 +83,13 @@ class WAFV2Backend(BaseBackend):
https://docs.aws.amazon.com/waf/latest/APIReference/API_Operations_AWS_WAFV2.html
"""
def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self.wacls: Dict[str, FakeWebACL] = OrderedDict()
self.tagging_service = TaggingService()
# TODO: self.load_balancers = OrderedDict()
def associate_web_acl(self, web_acl_arn, resource_arn):
def associate_web_acl(self, web_acl_arn: str, resource_arn: str) -> None:
"""
Only APIGateway Stages can be associated at the moment.
"""
@ -83,19 +99,19 @@ class WAFV2Backend(BaseBackend):
if stage:
stage.web_acl_arn = web_acl_arn
def disassociate_web_acl(self, resource_arn):
def disassociate_web_acl(self, resource_arn: str) -> None:
stage = self._find_apigw_stage(resource_arn)
if stage:
stage.web_acl_arn = None
def get_web_acl_for_resource(self, resource_arn):
def get_web_acl_for_resource(self, resource_arn: str) -> Optional[FakeWebACL]:
stage = self._find_apigw_stage(resource_arn)
if stage and stage.web_acl_arn is not None:
wacl_arn = stage.web_acl_arn
return self.wacls.get(wacl_arn)
return None
def _find_apigw_stage(self, resource_arn):
def _find_apigw_stage(self, resource_arn: str) -> Optional["Stage"]: # type: ignore
try:
if re.search(APIGATEWAY_REGEX, resource_arn):
region = resource_arn.split(":")[3]
@ -110,8 +126,15 @@ class WAFV2Backend(BaseBackend):
return None
def create_web_acl(
self, name, visibility_config, default_action, scope, description, tags, rules
):
self,
name: str,
visibility_config: Dict[str, Any],
default_action: Dict[str, Any],
scope: str,
description: str,
tags: List[Dict[str, str]],
rules: List[Dict[str, Any]],
) -> FakeWebACL:
"""
The following parameters are not yet implemented: CustomResponseBodies, CaptchaConfig
"""
@ -132,7 +155,7 @@ class WAFV2Backend(BaseBackend):
self.tag_resource(arn, tags)
return new_wacl
def delete_web_acl(self, name, _id):
def delete_web_acl(self, name: str, _id: str) -> None:
"""
The LockToken-parameter is not yet implemented
"""
@ -142,37 +165,43 @@ class WAFV2Backend(BaseBackend):
if wacl.name != name and wacl.id != _id
}
def get_web_acl(self, name, _id) -> FakeWebACL:
def get_web_acl(self, name: str, _id: str) -> FakeWebACL:
for wacl in self.wacls.values():
if wacl.name == name and wacl.id == _id:
return wacl
raise WAFNonexistentItemException
def list_web_acls(self):
def list_web_acls(self) -> List[Dict[str, Any]]:
return [wacl.to_dict() for wacl in self.wacls.values()]
def _is_duplicate_name(self, name):
def _is_duplicate_name(self, name: str) -> bool:
allWaclNames = set(wacl.name for wacl in self.wacls.values())
return name in allWaclNames
def list_rule_groups(self):
def list_rule_groups(self) -> List[Any]:
return []
def list_tags_for_resource(self, arn):
def list_tags_for_resource(self, arn: str) -> List[Dict[str, str]]:
"""
Pagination is not yet implemented
"""
return self.tagging_service.list_tags_for_resource(arn)["Tags"]
def tag_resource(self, arn, tags):
def tag_resource(self, arn: str, tags: List[Dict[str, str]]) -> None:
self.tagging_service.tag_resource(arn, tags)
def untag_resource(self, arn, tag_keys):
def untag_resource(self, arn: str, tag_keys: List[str]) -> None:
self.tagging_service.untag_resource_using_names(arn, tag_keys)
def update_web_acl(
self, name, _id, default_action, rules, description, visibility_config
):
self,
name: str,
_id: str,
default_action: Optional[Dict[str, Any]],
rules: Optional[List[Dict[str, Any]]],
description: Optional[str],
visibility_config: Optional[Dict[str, Any]],
) -> str:
"""
The following parameters are not yet implemented: LockToken, CustomResponseBodies, CaptchaConfig
"""

View File

@ -1,19 +1,20 @@
import json
from moto.core.common_types import TYPE_RESPONSE
from moto.core.responses import BaseResponse
from moto.utilities.aws_headers import amzn_request_id
from .models import GLOBAL_REGION, wafv2_backends
from .models import GLOBAL_REGION, wafv2_backends, WAFV2Backend
class WAFV2Response(BaseResponse):
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="wafv2")
@property
def wafv2_backend(self):
def wafv2_backend(self) -> WAFV2Backend:
return wafv2_backends[self.current_account][self.region]
@amzn_request_id
def associate_web_acl(self):
def associate_web_acl(self) -> TYPE_RESPONSE:
body = json.loads(self.body)
web_acl_arn = body["WebACLArn"]
resource_arn = body["ResourceArn"]
@ -21,14 +22,14 @@ class WAFV2Response(BaseResponse):
return 200, {}, "{}"
@amzn_request_id
def disassociate_web_acl(self):
def disassociate_web_acl(self) -> TYPE_RESPONSE:
body = json.loads(self.body)
resource_arn = body["ResourceArn"]
self.wafv2_backend.disassociate_web_acl(resource_arn)
return 200, {}, "{}"
@amzn_request_id
def get_web_acl_for_resource(self):
def get_web_acl_for_resource(self) -> TYPE_RESPONSE:
body = json.loads(self.body)
resource_arn = body["ResourceArn"]
web_acl = self.wafv2_backend.get_web_acl_for_resource(resource_arn)
@ -37,7 +38,7 @@ class WAFV2Response(BaseResponse):
return 200, response_headers, json.dumps(response)
@amzn_request_id
def create_web_acl(self):
def create_web_acl(self) -> TYPE_RESPONSE:
"""https://docs.aws.amazon.com/waf/latest/APIReference/API_CreateWebACL.html (response syntax section)"""
scope = self._get_param("Scope")
@ -62,7 +63,7 @@ class WAFV2Response(BaseResponse):
return 200, response_headers, json.dumps(response)
@amzn_request_id
def delete_web_acl(self):
def delete_web_acl(self) -> TYPE_RESPONSE:
scope = self._get_param("Scope")
if scope == "CLOUDFRONT":
self.region = GLOBAL_REGION
@ -73,7 +74,7 @@ class WAFV2Response(BaseResponse):
return 200, response_headers, "{}"
@amzn_request_id
def get_web_acl(self):
def get_web_acl(self) -> TYPE_RESPONSE:
scope = self._get_param("Scope")
if scope == "CLOUDFRONT":
self.region = GLOBAL_REGION
@ -85,7 +86,7 @@ class WAFV2Response(BaseResponse):
return 200, response_headers, json.dumps(response)
@amzn_request_id
def list_web_ac_ls(self):
def list_web_ac_ls(self) -> TYPE_RESPONSE:
"""https://docs.aws.amazon.com/waf/latest/APIReference/API_ListWebACLs.html (response syntax section)"""
scope = self._get_param("Scope")
@ -97,7 +98,7 @@ class WAFV2Response(BaseResponse):
return 200, response_headers, json.dumps(response)
@amzn_request_id
def list_rule_groups(self):
def list_rule_groups(self) -> TYPE_RESPONSE:
scope = self._get_param("Scope")
if scope == "CLOUDFRONT":
self.region = GLOBAL_REGION
@ -107,7 +108,7 @@ class WAFV2Response(BaseResponse):
return 200, response_headers, json.dumps(response)
@amzn_request_id
def list_tags_for_resource(self):
def list_tags_for_resource(self) -> TYPE_RESPONSE:
arn = self._get_param("ResourceARN")
self.region = arn.split(":")[3]
tags = self.wafv2_backend.list_tags_for_resource(arn)
@ -116,7 +117,7 @@ class WAFV2Response(BaseResponse):
return 200, response_headers, json.dumps(response)
@amzn_request_id
def tag_resource(self):
def tag_resource(self) -> TYPE_RESPONSE:
body = json.loads(self.body)
arn = body.get("ResourceARN")
self.region = arn.split(":")[3]
@ -125,7 +126,7 @@ class WAFV2Response(BaseResponse):
return 200, {}, "{}"
@amzn_request_id
def untag_resource(self):
def untag_resource(self) -> TYPE_RESPONSE:
body = json.loads(self.body)
arn = body.get("ResourceARN")
self.region = arn.split(":")[3]
@ -134,7 +135,7 @@ class WAFV2Response(BaseResponse):
return 200, {}, "{}"
@amzn_request_id
def update_web_acl(self):
def update_web_acl(self) -> TYPE_RESPONSE:
body = json.loads(self.body)
name = body.get("Name")
_id = body.get("Id")

View File

@ -1,4 +1,6 @@
def make_arn_for_wacl(name, account_id, region_name, wacl_id, scope):
def make_arn_for_wacl(
name: str, account_id: str, region_name: str, wacl_id: str, scope: str
) -> str:
"""https://docs.aws.amazon.com/waf/latest/developerguide/how-aws-waf-works.html - explains --scope (cloudfront vs regional)"""
if scope == "REGIONAL":

View File

@ -1,14 +1,22 @@
from typing import Any, Dict, Optional
class BadSegmentException(Exception):
def __init__(self, seg_id=None, code=None, message=None):
def __init__(
self,
seg_id: Optional[str] = None,
code: Optional[str] = None,
message: Optional[str] = None,
):
self.id = seg_id
self.code = code
self.message = message
def __repr__(self):
def __repr__(self) -> str:
return f"<BadSegment {self.id}-{self.code}-{self.message}>"
def to_dict(self):
result = {}
def to_dict(self) -> Dict[str, Any]:
result: Dict[str, Any] = {}
if self.id is not None:
result["Id"] = self.id
if self.code is not None:

View File

@ -1,25 +1,26 @@
import os
from moto.xray import xray_backends
from typing import Any
from moto.xray.models import xray_backends, XRayBackend
import aws_xray_sdk.core
from aws_xray_sdk.core.context import Context as AWSContext
from aws_xray_sdk.core.emitters.udp_emitter import UDPEmitter
class MockEmitter(UDPEmitter):
class MockEmitter(UDPEmitter): # type: ignore
"""
Replaces the code that sends UDP to local X-Ray daemon
"""
def __init__(self, daemon_address="127.0.0.1:2000"):
def __init__(self, daemon_address: str = "127.0.0.1:2000"):
address = os.getenv(
"AWS_XRAY_DAEMON_ADDRESS_YEAH_NOT_TODAY_MATE", daemon_address
)
self._ip, self._port = self._parse_address(address)
def _xray_backend(self, region):
return xray_backends[region]
def _xray_backend(self, account_id: str, region: str) -> XRayBackend:
return xray_backends[account_id][region]
def send_entity(self, entity):
def send_entity(self, entity: Any) -> None:
# Hack to get region
# region = entity.subsegments[0].aws['region']
# xray = self._xray_backend(region)
@ -27,7 +28,7 @@ class MockEmitter(UDPEmitter):
# TODO store X-Ray data, pretty sure X-Ray needs refactor for this
pass
def _send_data(self, data):
def _send_data(self, data: Any) -> None:
raise RuntimeError("Should not be running this")
@ -42,11 +43,11 @@ class MockXrayClient:
that itno the recorder instance.
"""
def __call__(self, f=None):
def __call__(self, f: Any = None) -> Any:
if not f:
return self
def wrapped_f(*args, **kwargs):
def wrapped_f(*args: Any, **kwargs: Any) -> Any:
self.start()
try:
f(*args, **kwargs)
@ -55,7 +56,7 @@ class MockXrayClient:
return wrapped_f
def start(self):
def start(self) -> None:
print("Starting X-Ray Patch") # noqa
self.old_xray_context_var = os.environ.get("AWS_XRAY_CONTEXT_MISSING")
os.environ["AWS_XRAY_CONTEXT_MISSING"] = "LOG_ERROR"
@ -64,7 +65,7 @@ class MockXrayClient:
aws_xray_sdk.core.xray_recorder._context = AWSContext()
aws_xray_sdk.core.xray_recorder._emitter = MockEmitter()
def stop(self):
def stop(self) -> None:
if self.old_xray_context_var is None:
del os.environ["AWS_XRAY_CONTEXT_MISSING"]
else:
@ -73,11 +74,11 @@ class MockXrayClient:
aws_xray_sdk.core.xray_recorder._emitter = self.old_xray_emitter
aws_xray_sdk.core.xray_recorder._context = self.old_xray_context
def __enter__(self):
def __enter__(self) -> "MockXrayClient":
self.start()
return self
def __exit__(self, *args):
def __exit__(self, *args: Any) -> None:
self.stop()
@ -91,12 +92,12 @@ class XRaySegment(object):
During testing we're going to have to control the start and end of a segment via context managers.
"""
def __enter__(self):
def __enter__(self) -> "XRaySegment":
aws_xray_sdk.core.xray_recorder.begin_segment(
name="moto_mock", traceid=None, parent_id=None, sampling=1
)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
aws_xray_sdk.core.xray_recorder.end_segment()

View File

@ -1,6 +1,7 @@
import bisect
import datetime
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple
import json
from moto.core import BaseBackend, BackendDict, BaseModel
from moto.core.exceptions import AWSError
@ -8,54 +9,60 @@ from .exceptions import BadSegmentException
class TelemetryRecords(BaseModel):
def __init__(self, instance_id, hostname, resource_arn, records):
def __init__(
self,
instance_id: str,
hostname: str,
resource_arn: str,
records: List[Dict[str, Any]],
):
self.instance_id = instance_id
self.hostname = hostname
self.resource_arn = resource_arn
self.records = records
@classmethod
def from_json(cls, src):
def from_json(cls, src: Dict[str, Any]) -> "TelemetryRecords": # type: ignore[misc]
instance_id = src.get("EC2InstanceId", None)
hostname = src.get("Hostname")
resource_arn = src.get("ResourceARN")
telemetry_records = src["TelemetryRecords"]
return cls(instance_id, hostname, resource_arn, telemetry_records)
return cls(instance_id, hostname, resource_arn, telemetry_records) # type: ignore
# https://docs.aws.amazon.com/xray/latest/devguide/xray-api-segmentdocuments.html
class TraceSegment(BaseModel):
def __init__(
self,
name,
segment_id,
trace_id,
start_time,
raw,
end_time=None,
in_progress=False,
service=None,
user=None,
origin=None,
parent_id=None,
http=None,
aws=None,
metadata=None,
annotations=None,
subsegments=None,
**kwargs
name: str,
segment_id: str,
trace_id: str,
start_time: float,
raw: Any,
end_time: Optional[float] = None,
in_progress: bool = False,
service: Any = None,
user: Any = None,
origin: Any = None,
parent_id: Any = None,
http: Any = None,
aws: Any = None,
metadata: Any = None,
annotations: Any = None,
subsegments: Any = None,
**kwargs: Any
):
self.name = name
self.id = segment_id
self.trace_id = trace_id
self._trace_version = None
self._original_request_start_time = None
self._trace_version: Optional[int] = None
self._original_request_start_time: Optional[datetime.datetime] = None
self._trace_identifier = None
self.start_time = start_time
self._start_date = None
self._start_date: Optional[datetime.datetime] = None
self.end_time = end_time
self._end_date = None
self._end_date: Optional[datetime.datetime] = None
self.in_progress = in_progress
self.service = service
self.user = user
@ -71,17 +78,17 @@ class TraceSegment(BaseModel):
# Raw json string
self.raw = raw
def __lt__(self, other):
def __lt__(self, other: Any) -> bool:
return self.start_date < other.start_date
@property
def trace_version(self):
def trace_version(self) -> int:
if self._trace_version is None:
self._trace_version = int(self.trace_id.split("-", 1)[0])
return self._trace_version
@property
def request_start_date(self):
def request_start_date(self) -> datetime.datetime:
if self._original_request_start_time is None:
start_time = int(self.trace_id.split("-")[1], 16)
self._original_request_start_time = datetime.datetime.fromtimestamp(
@ -90,19 +97,19 @@ class TraceSegment(BaseModel):
return self._original_request_start_time
@property
def start_date(self):
def start_date(self) -> datetime.datetime:
if self._start_date is None:
self._start_date = datetime.datetime.fromtimestamp(self.start_time)
return self._start_date
@property
def end_date(self):
def end_date(self) -> datetime.datetime:
if self._end_date is None:
self._end_date = datetime.datetime.fromtimestamp(self.end_time)
self._end_date = datetime.datetime.fromtimestamp(self.end_time) # type: ignore
return self._end_date
@classmethod
def from_dict(cls, data, raw):
def from_dict(cls, data: Dict[str, Any], raw: Any) -> "TraceSegment": # type: ignore[misc]
# Check manditory args
if "id" not in data:
raise BadSegmentException(code="MissingParam", message="Missing segment ID")
@ -130,11 +137,11 @@ class TraceSegment(BaseModel):
class SegmentCollection(object):
def __init__(self):
self._traces = defaultdict(self._new_trace_item)
def __init__(self) -> None:
self._traces: Dict[str, Dict[str, Any]] = defaultdict(self._new_trace_item)
@staticmethod
def _new_trace_item():
def _new_trace_item() -> Dict[str, Any]: # type: ignore[misc]
return {
"start_date": datetime.datetime(1970, 1, 1),
"end_date": datetime.datetime(1970, 1, 1),
@ -143,7 +150,7 @@ class SegmentCollection(object):
"segments": [],
}
def put_segment(self, segment):
def put_segment(self, segment: Any) -> None:
# insert into a sorted list
bisect.insort_left(self._traces[segment.trace_id]["segments"], segment)
@ -160,11 +167,15 @@ class SegmentCollection(object):
# Todo consolidate trace segments into a trace.
# not enough working knowledge of xray to do this
def summary(self, start_time, end_time, filter_expression=None):
def summary(
self, start_time: str, end_time: str, filter_expression: Any = None
) -> Dict[str, Any]:
# This beast https://docs.aws.amazon.com/xray/latest/api/API_GetTraceSummaries.html#API_GetTraceSummaries_ResponseSyntax
if filter_expression is not None:
raise AWSError(
"Not implemented yet - moto", code="InternalFailure", status=500
"Not implemented yet - moto",
exception_type="InternalFailure",
status=500,
)
summaries = []
@ -213,7 +224,9 @@ class SegmentCollection(object):
return result
def get_trace_ids(self, trace_ids):
def get_trace_ids(
self, trace_ids: List[str]
) -> Tuple[List[Dict[str, Any]], List[str]]:
traces = []
unprocessed = []
@ -229,22 +242,24 @@ class SegmentCollection(object):
class XRayBackend(BaseBackend):
def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self._telemetry_records = []
self._telemetry_records: List[TelemetryRecords] = []
self._segment_collection = SegmentCollection()
@staticmethod
def default_vpc_endpoint_service(service_region, zones):
def default_vpc_endpoint_service(
service_region: str, zones: List[str]
) -> List[Dict[str, str]]:
"""Default VPC endpoint service."""
return BaseBackend.default_vpc_endpoint_service_factory(
service_region, zones, "xray"
)
def add_telemetry_records(self, src):
def add_telemetry_records(self, src: Any) -> None:
self._telemetry_records.append(TelemetryRecords.from_json(src))
def process_segment(self, doc):
def process_segment(self, doc: Any) -> None:
try:
data = json.loads(doc)
except ValueError:
@ -264,13 +279,15 @@ class XRayBackend(BaseBackend):
seg_id=segment.id, code="InternalFailure", message=str(err)
)
def get_trace_summary(self, start_time, end_time, filter_expression):
def get_trace_summary(
self, start_time: str, end_time: str, filter_expression: Any
) -> Dict[str, Any]:
return self._segment_collection.summary(start_time, end_time, filter_expression)
def get_trace_ids(self, trace_ids):
def get_trace_ids(self, trace_ids: List[str]) -> Dict[str, Any]:
traces, unprocessed_ids = self._segment_collection.get_trace_ids(trace_ids)
result = {"Traces": [], "UnprocessedTraceIds": unprocessed_ids}
result: Dict[str, Any] = {"Traces": [], "UnprocessedTraceIds": unprocessed_ids}
for trace in traces:
segments = []

View File

@ -1,49 +1,50 @@
import json
import datetime
from typing import Any, Dict, Tuple, Union
from moto.core.responses import BaseResponse
from moto.core.exceptions import AWSError
from urllib.parse import urlsplit
from .models import xray_backends
from .models import xray_backends, XRayBackend
from .exceptions import BadSegmentException
class XRayResponse(BaseResponse):
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="xray")
def _error(self, code, message):
def _error(self, code: str, message: str) -> Tuple[str, Dict[str, int]]:
return json.dumps({"__type": code, "message": message}), dict(status=400)
@property
def xray_backend(self):
def xray_backend(self) -> XRayBackend:
return xray_backends[self.current_account][self.region]
@property
def request_params(self):
def request_params(self) -> Any: # type: ignore[misc]
try:
return json.loads(self.body)
except ValueError:
return {}
def _get_param(self, param_name, if_none=None):
def _get_param(self, param_name: str, if_none: Any = None) -> Any:
return self.request_params.get(param_name, if_none)
def _get_action(self):
def _get_action(self) -> str:
# Amazon is just calling urls like /TelemetryRecords etc...
# This uses the value after / as the camalcase action, which then
# gets converted in call_action to find the following methods
return urlsplit(self.uri).path.lstrip("/")
# PutTelemetryRecords
def telemetry_records(self):
def telemetry_records(self) -> str:
self.xray_backend.add_telemetry_records(self.request_params)
return ""
# PutTraceSegments
def trace_segments(self):
def trace_segments(self) -> Union[str, Tuple[str, Dict[str, int]]]:
docs = self._get_param("TraceSegmentDocuments")
if docs is None:
@ -71,7 +72,7 @@ class XRayResponse(BaseResponse):
return json.dumps(result)
# GetTraceSummaries
def trace_summaries(self):
def trace_summaries(self) -> Union[str, Tuple[str, Dict[str, int]]]:
start_time = self._get_param("StartTime")
end_time = self._get_param("EndTime")
if start_time is None:
@ -119,7 +120,7 @@ class XRayResponse(BaseResponse):
return json.dumps(result)
# BatchGetTraces
def traces(self):
def traces(self) -> Union[str, Tuple[str, Dict[str, int]]]:
trace_ids = self._get_param("TraceIds")
if trace_ids is None:
@ -142,7 +143,7 @@ class XRayResponse(BaseResponse):
return json.dumps(result)
# GetServiceGraph - just a dummy response for now
def service_graph(self):
def service_graph(self) -> Union[str, Tuple[str, Dict[str, int]]]:
start_time = self._get_param("StartTime")
end_time = self._get_param("EndTime")
# next_token = self._get_param('NextToken') # not implemented yet
@ -164,7 +165,7 @@ class XRayResponse(BaseResponse):
return json.dumps(result)
# GetTraceGraph - just a dummy response for now
def trace_graph(self):
def trace_graph(self) -> Union[str, Tuple[str, Dict[str, int]]]:
trace_ids = self._get_param("TraceIds")
# next_token = self._get_param('NextToken') # not implemented yet
@ -175,5 +176,5 @@ class XRayResponse(BaseResponse):
dict(status=400),
)
result = {"Services": []}
result: Dict[str, Any] = {"Services": []}
return json.dumps(result)

View File

@ -239,7 +239,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy]
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s*,moto/u*,moto/t*
files= moto, tests/test_core/test_mock_all.py, tests/test_core/test_decorator_calls.py
show_column_numbers=True
show_error_codes = True
disable_error_code=abstract

View File

@ -1,10 +1,11 @@
import boto3
import pytest
import sure # noqa # pylint: disable=unused-import
import unittest
from botocore.exceptions import ClientError
from typing import Any
from moto import mock_ec2, mock_kinesis, mock_s3, settings
from moto.core.models import BaseMockAWS
from unittest import SkipTest
"""
@ -13,13 +14,13 @@ Test the different ways that the decorator can be used
@mock_ec2
def test_basic_decorator():
def test_basic_decorator() -> None:
client = boto3.client("ec2", region_name="us-west-1")
client.describe_addresses()["Addresses"].should.equal([])
assert client.describe_addresses()["Addresses"] == []
@pytest.fixture(name="aws_credentials")
def fixture_aws_credentials(monkeypatch):
def fixture_aws_credentials(monkeypatch: Any) -> None: # type: ignore[misc]
"""Mocked AWS Credentials for moto."""
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
@ -28,97 +29,95 @@ def fixture_aws_credentials(monkeypatch):
@pytest.mark.network
def test_context_manager(aws_credentials): # pylint: disable=unused-argument
def test_context_manager(aws_credentials: Any) -> None: # type: ignore[misc] # pylint: disable=unused-argument
client = boto3.client("ec2", region_name="us-west-1")
with pytest.raises(ClientError) as exc:
client.describe_addresses()
err = exc.value.response["Error"]
err["Code"].should.equal("AuthFailure")
err["Message"].should.equal(
"AWS was not able to validate the provided access credentials"
assert err["Code"] == "AuthFailure"
assert (
err["Message"] == "AWS was not able to validate the provided access credentials"
)
with mock_ec2():
client = boto3.client("ec2", region_name="us-west-1")
client.describe_addresses()["Addresses"].should.equal([])
assert client.describe_addresses()["Addresses"] == []
@pytest.mark.network
def test_decorator_start_and_stop():
def test_decorator_start_and_stop() -> None:
if settings.TEST_SERVER_MODE:
raise SkipTest("Authentication always works in ServerMode")
mock = mock_ec2()
mock: BaseMockAWS = mock_ec2()
mock.start()
client = boto3.client("ec2", region_name="us-west-1")
client.describe_addresses()["Addresses"].should.equal([])
assert client.describe_addresses()["Addresses"] == []
mock.stop()
with pytest.raises(ClientError) as exc:
client.describe_addresses()
err = exc.value.response["Error"]
err["Code"].should.equal("AuthFailure")
err["Message"].should.equal(
"AWS was not able to validate the provided access credentials"
assert err["Code"] == "AuthFailure"
assert (
err["Message"] == "AWS was not able to validate the provided access credentials"
)
@mock_ec2
def test_decorater_wrapped_gets_set():
def test_decorater_wrapped_gets_set() -> None:
"""
Moto decorator's __wrapped__ should get set to the tests function
"""
test_decorater_wrapped_gets_set.__wrapped__.__name__.should.equal(
"test_decorater_wrapped_gets_set"
)
assert test_decorater_wrapped_gets_set.__wrapped__.__name__ == "test_decorater_wrapped_gets_set" # type: ignore
@mock_ec2
class Tester(object):
def test_the_class(self):
class Tester:
def test_the_class(self) -> None:
client = boto3.client("ec2", region_name="us-west-1")
client.describe_addresses()["Addresses"].should.equal([])
assert client.describe_addresses()["Addresses"] == []
def test_still_the_same(self):
def test_still_the_same(self) -> None:
client = boto3.client("ec2", region_name="us-west-1")
client.describe_addresses()["Addresses"].should.equal([])
assert client.describe_addresses()["Addresses"] == []
@mock_s3
class TesterWithSetup(unittest.TestCase):
def setUp(self):
def setUp(self) -> None:
self.client = boto3.client("s3")
self.client.create_bucket(Bucket="mybucket")
def test_still_the_same(self):
def test_still_the_same(self) -> None:
buckets = self.client.list_buckets()["Buckets"]
bucket_names = [b["Name"] for b in buckets]
# There is a potential bug in the class-decorator, where the reset API is not called on start.
# This leads to a situation where 'bucket_names' may contain buckets created by earlier tests
bucket_names.should.contain("mybucket")
assert "mybucket" in bucket_names
@mock_s3
class TesterWithStaticmethod(object):
class TesterWithStaticmethod:
@staticmethod
def static(*args):
def static(*args: Any) -> None: # type: ignore[misc]
assert not args or not isinstance(args[0], TesterWithStaticmethod)
def test_no_instance_sent_to_staticmethod(self):
def test_no_instance_sent_to_staticmethod(self) -> None:
self.static()
@mock_s3
class TestWithSetup_UppercaseU(unittest.TestCase):
def setUp(self):
def setUp(self) -> None:
# This method will be executed automatically, provided we extend the TestCase-class
s3 = boto3.client("s3", region_name="us-east-1")
s3.create_bucket(Bucket="mybucket")
def test_should_find_bucket(self):
def test_should_find_bucket(self) -> None:
s3 = boto3.client("s3", region_name="us-east-1")
self.assertIsNotNone(s3.head_bucket(Bucket="mybucket"))
def test_should_not_find_unknown_bucket(self):
def test_should_not_find_unknown_bucket(self) -> None:
s3 = boto3.client("s3", region_name="us-east-1")
with pytest.raises(ClientError):
s3.head_bucket(Bucket="unknown_bucket")
@ -126,16 +125,16 @@ class TestWithSetup_UppercaseU(unittest.TestCase):
@mock_s3
class TestWithSetup_LowercaseU:
def setup_method(self, *args): # pylint: disable=unused-argument
def setup_method(self, *args: Any) -> None: # pylint: disable=unused-argument
# This method will be executed automatically using pytest
s3 = boto3.client("s3", region_name="us-east-1")
s3.create_bucket(Bucket="mybucket")
def test_should_find_bucket(self):
def test_should_find_bucket(self) -> None:
s3 = boto3.client("s3", region_name="us-east-1")
assert s3.head_bucket(Bucket="mybucket") is not None
def test_should_not_find_unknown_bucket(self):
def test_should_not_find_unknown_bucket(self) -> None:
s3 = boto3.client("s3", region_name="us-east-1")
with pytest.raises(ClientError):
s3.head_bucket(Bucket="unknown_bucket")
@ -143,16 +142,16 @@ class TestWithSetup_LowercaseU:
@mock_s3
class TestWithSetupMethod:
def setup_method(self, *args): # pylint: disable=unused-argument
def setup_method(self, *args: Any) -> None: # pylint: disable=unused-argument
# This method will be executed automatically using pytest
s3 = boto3.client("s3", region_name="us-east-1")
s3.create_bucket(Bucket="mybucket")
def test_should_find_bucket(self):
def test_should_find_bucket(self) -> None:
s3 = boto3.client("s3", region_name="us-east-1")
assert s3.head_bucket(Bucket="mybucket") is not None
def test_should_not_find_unknown_bucket(self):
def test_should_not_find_unknown_bucket(self) -> None:
s3 = boto3.client("s3", region_name="us-east-1")
with pytest.raises(ClientError):
s3.head_bucket(Bucket="unknown_bucket")
@ -160,17 +159,17 @@ class TestWithSetupMethod:
@mock_kinesis
class TestKinesisUsingSetupMethod:
def setup_method(self, *args): # pylint: disable=unused-argument
def setup_method(self, *args: Any) -> None: # pylint: disable=unused-argument
self.stream_name = "test_stream"
self.boto3_kinesis_client = boto3.client("kinesis", region_name="us-east-1")
self.boto3_kinesis_client.create_stream(
StreamName=self.stream_name, ShardCount=1
)
def test_stream_creation(self):
def test_stream_creation(self) -> None:
pass
def test_stream_recreation(self):
def test_stream_recreation(self) -> None:
# The setup-method will run again for this test
# The fact that it passes, means the state was reset
# Otherwise it would complain about a stream already existing
@ -179,11 +178,11 @@ class TestKinesisUsingSetupMethod:
@mock_s3
class TestWithInvalidSetupMethod:
def setupmethod(self):
def setupmethod(self) -> None:
s3 = boto3.client("s3", region_name="us-east-1")
s3.create_bucket(Bucket="mybucket")
def test_should_not_find_bucket(self):
def test_should_not_find_bucket(self) -> None:
# Name of setupmethod is not recognized, so it will not be executed
s3 = boto3.client("s3", region_name="us-east-1")
with pytest.raises(ClientError):
@ -192,17 +191,17 @@ class TestWithInvalidSetupMethod:
@mock_s3
class TestWithPublicMethod(unittest.TestCase):
def ensure_bucket_exists(self):
def ensure_bucket_exists(self) -> None:
s3 = boto3.client("s3", region_name="us-east-1")
s3.create_bucket(Bucket="mybucket")
def test_should_find_bucket(self):
def test_should_find_bucket(self) -> None:
self.ensure_bucket_exists()
s3 = boto3.client("s3", region_name="us-east-1")
s3.head_bucket(Bucket="mybucket").shouldnt.equal(None)
def test_should_not_find_bucket(self):
def test_should_not_find_bucket(self) -> None:
s3 = boto3.client("s3", region_name="us-east-1")
with pytest.raises(ClientError):
s3.head_bucket(Bucket="mybucket")
@ -210,16 +209,16 @@ class TestWithPublicMethod(unittest.TestCase):
@mock_s3
class TestWithPseudoPrivateMethod(unittest.TestCase):
def _ensure_bucket_exists(self):
def _ensure_bucket_exists(self) -> None:
s3 = boto3.client("s3", region_name="us-east-1")
s3.create_bucket(Bucket="mybucket")
def test_should_find_bucket(self):
def test_should_find_bucket(self) -> None:
self._ensure_bucket_exists()
s3 = boto3.client("s3", region_name="us-east-1")
s3.head_bucket(Bucket="mybucket").shouldnt.equal(None)
def test_should_not_find_bucket(self):
def test_should_not_find_bucket(self) -> None:
s3 = boto3.client("s3", region_name="us-east-1")
with pytest.raises(ClientError):
s3.head_bucket(Bucket="mybucket")
@ -227,20 +226,20 @@ class TestWithPseudoPrivateMethod(unittest.TestCase):
@mock_s3
class Baseclass(unittest.TestCase):
def setUp(self):
def setUp(self) -> None:
self.s3 = boto3.resource("s3", region_name="us-east-1")
self.client = boto3.client("s3", region_name="us-east-1")
self.test_bucket = self.s3.Bucket("testbucket")
self.test_bucket.create()
def tearDown(self):
def tearDown(self) -> None:
# The bucket will still exist at this point
self.test_bucket.delete()
@mock_s3
class TestSetUpInBaseClass(Baseclass):
def test_a_thing(self):
def test_a_thing(self) -> None:
# Verify that we can 'see' the setUp-method in the parent class
self.client.head_bucket(Bucket="testbucket").shouldnt.equal(None)
@ -248,42 +247,42 @@ class TestSetUpInBaseClass(Baseclass):
@mock_s3
class TestWithNestedClasses:
class NestedClass(unittest.TestCase):
def _ensure_bucket_exists(self):
def _ensure_bucket_exists(self) -> None:
s3 = boto3.client("s3", region_name="us-east-1")
s3.create_bucket(Bucket="bucketclass1")
def test_should_find_bucket(self):
def test_should_find_bucket(self) -> None:
self._ensure_bucket_exists()
s3 = boto3.client("s3", region_name="us-east-1")
s3.head_bucket(Bucket="bucketclass1")
class NestedClass2(unittest.TestCase):
def _ensure_bucket_exists(self):
def _ensure_bucket_exists(self) -> None:
s3 = boto3.client("s3", region_name="us-east-1")
s3.create_bucket(Bucket="bucketclass2")
def test_should_find_bucket(self):
def test_should_find_bucket(self) -> None:
self._ensure_bucket_exists()
s3 = boto3.client("s3", region_name="us-east-1")
s3.head_bucket(Bucket="bucketclass2")
def test_should_not_find_bucket_from_different_class(self):
def test_should_not_find_bucket_from_different_class(self) -> None:
s3 = boto3.client("s3", region_name="us-east-1")
with pytest.raises(ClientError):
s3.head_bucket(Bucket="bucketclass1")
class TestWithSetup(unittest.TestCase):
def setUp(self):
def setUp(self) -> None:
s3 = boto3.client("s3", region_name="us-east-1")
s3.create_bucket(Bucket="mybucket")
def test_should_find_bucket(self):
def test_should_find_bucket(self) -> None:
s3 = boto3.client("s3", region_name="us-east-1")
s3.head_bucket(Bucket="mybucket")
s3.create_bucket(Bucket="bucketinsidetest")
def test_should_not_find_bucket_from_test_method(self):
def test_should_not_find_bucket_from_test_method(self) -> None:
s3 = boto3.client("s3", region_name="us-east-1")
s3.head_bucket(Bucket="mybucket")

View File

@ -6,7 +6,7 @@ from moto import mock_all
@mock_all()
def test_decorator():
def test_decorator() -> None:
rgn = "us-east-1"
sqs = boto3.client("sqs", region_name=rgn)
r = sqs.list_queues()
@ -21,7 +21,7 @@ def test_decorator():
r["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
def test_context_manager():
def test_context_manager() -> None:
rgn = "us-east-1"
with mock_all():