Techdebt: MyPy all the things (#6406)
This commit is contained in:
parent
2e3b06bbe5
commit
0247719554
@ -1,17 +1,27 @@
|
||||
import importlib
|
||||
import sys
|
||||
from contextlib import ContextDecorator
|
||||
from moto.core.models import BaseMockAWS
|
||||
from typing import Any, Callable, List, Optional, TypeVar
|
||||
|
||||
|
||||
def lazy_load(module_name, element, boto3_name=None, backend=None):
|
||||
def f(*args, **kwargs):
|
||||
TEST_METHOD = TypeVar("TEST_METHOD", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def lazy_load(
|
||||
module_name: str,
|
||||
element: str,
|
||||
boto3_name: Optional[str] = None,
|
||||
backend: Optional[str] = None,
|
||||
) -> Callable[..., BaseMockAWS]:
|
||||
def f(*args: Any, **kwargs: Any) -> Any:
|
||||
module = importlib.import_module(module_name, "moto")
|
||||
return getattr(module, element)(*args, **kwargs)
|
||||
|
||||
setattr(f, "name", module_name.replace(".", ""))
|
||||
setattr(f, "element", element)
|
||||
setattr(f, "boto3_name", boto3_name or f.name)
|
||||
setattr(f, "backend", backend or f"{f.name}_backends")
|
||||
setattr(f, "boto3_name", boto3_name or f.name) # type: ignore[attr-defined]
|
||||
setattr(f, "backend", backend or f"{f.name}_backends") # type: ignore[attr-defined]
|
||||
return f
|
||||
|
||||
|
||||
@ -176,17 +186,17 @@ mock_textract = lazy_load(".textract", "mock_textract")
|
||||
|
||||
|
||||
class MockAll(ContextDecorator):
|
||||
def __init__(self):
|
||||
self.mocks = []
|
||||
def __init__(self) -> None:
|
||||
self.mocks: List[Any] = []
|
||||
for mock in dir(sys.modules["moto"]):
|
||||
if mock.startswith("mock_") and not mock == ("mock_all"):
|
||||
if mock.startswith("mock_") and not mock == "mock_all":
|
||||
self.mocks.append(globals()[mock]())
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> None:
|
||||
for mock in self.mocks:
|
||||
mock.start()
|
||||
|
||||
def __exit__(self, *exc):
|
||||
def __exit__(self, *exc: Any) -> None:
|
||||
for mock in self.mocks:
|
||||
mock.stop()
|
||||
|
||||
|
@ -626,7 +626,7 @@ class Stage(BaseModel):
|
||||
self.tags = tags
|
||||
self.tracing_enabled = tracing_enabled
|
||||
self.access_log_settings: Optional[Dict[str, Any]] = None
|
||||
self.web_acl_arn = None
|
||||
self.web_acl_arn: Optional[str] = None
|
||||
|
||||
def to_json(self) -> Dict[str, Any]:
|
||||
dct: Dict[str, Any] = {
|
||||
|
@ -5,7 +5,8 @@ import os
|
||||
import re
|
||||
import unittest
|
||||
from types import FunctionType
|
||||
from typing import Any, Callable, Dict, Optional, Set, TypeVar
|
||||
from typing import Any, Callable, Dict, Optional, Set, TypeVar, Union
|
||||
from typing import ContextManager
|
||||
from unittest.mock import patch
|
||||
|
||||
import boto3
|
||||
@ -29,7 +30,7 @@ DEFAULT_ACCOUNT_ID = "123456789012"
|
||||
CALLABLE_RETURN = TypeVar("CALLABLE_RETURN")
|
||||
|
||||
|
||||
class BaseMockAWS:
|
||||
class BaseMockAWS(ContextManager["BaseMockAWS"]):
|
||||
nested_count = 0
|
||||
mocks_active = False
|
||||
|
||||
@ -65,10 +66,13 @@ class BaseMockAWS:
|
||||
self.reset() # type: ignore[attr-defined]
|
||||
|
||||
def __call__(
|
||||
self, func: Callable[..., Any], reset: bool = True, remove_data: bool = True
|
||||
) -> Any:
|
||||
self,
|
||||
func: Callable[..., "BaseMockAWS"],
|
||||
reset: bool = True,
|
||||
remove_data: bool = True,
|
||||
) -> Callable[..., "BaseMockAWS"]:
|
||||
if inspect.isclass(func):
|
||||
return self.decorate_class(func)
|
||||
return self.decorate_class(func) # type: ignore
|
||||
return self.decorate_callable(func, reset, remove_data)
|
||||
|
||||
def __enter__(self) -> "BaseMockAWS":
|
||||
@ -116,9 +120,9 @@ class BaseMockAWS:
|
||||
self.disable_patching() # type: ignore[attr-defined]
|
||||
|
||||
def decorate_callable(
|
||||
self, func: Callable[..., CALLABLE_RETURN], reset: bool, remove_data: bool
|
||||
) -> Callable[..., CALLABLE_RETURN]:
|
||||
def wrapper(*args: Any, **kwargs: Any) -> CALLABLE_RETURN:
|
||||
self, func: Callable[..., "BaseMockAWS"], reset: bool, remove_data: bool
|
||||
) -> Callable[..., "BaseMockAWS"]:
|
||||
def wrapper(*args: Any, **kwargs: Any) -> "BaseMockAWS":
|
||||
self.start(reset=reset)
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
@ -361,7 +365,6 @@ MockAWS = BotocoreEventMockAWS
|
||||
|
||||
|
||||
class ServerModeMockAWS(BaseMockAWS):
|
||||
|
||||
RESET_IN_PROGRESS = False
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any):
|
||||
@ -427,7 +430,9 @@ class base_decorator:
|
||||
def __init__(self, backends: BackendDict):
|
||||
self.backends = backends
|
||||
|
||||
def __call__(self, func: Optional[Callable[..., Any]] = None) -> BaseMockAWS:
|
||||
def __call__(
|
||||
self, func: Optional[Callable[..., Any]] = None
|
||||
) -> Union[BaseMockAWS, Callable[..., BaseMockAWS]]:
|
||||
if settings.TEST_SERVER_MODE:
|
||||
mocked_backend: BaseMockAWS = ServerModeMockAWS(self.backends)
|
||||
else:
|
||||
|
@ -6,7 +6,7 @@ class WAFv2ClientError(JsonRESTError):
|
||||
|
||||
|
||||
class WAFV2DuplicateItemException(WAFv2ClientError):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
"WafV2DuplicateItem",
|
||||
"AWS WAF could not perform the operation because some resource in your request is a duplicate of an existing one.",
|
||||
@ -14,7 +14,7 @@ class WAFV2DuplicateItemException(WAFv2ClientError):
|
||||
|
||||
|
||||
class WAFNonexistentItemException(WAFv2ClientError):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
"WAFNonexistentItemException",
|
||||
"AWS WAF couldn’t perform the operation because your resource doesn’t exist.",
|
||||
|
@ -1,6 +1,6 @@
|
||||
import datetime
|
||||
import re
|
||||
from typing import Dict
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING
|
||||
from moto.core import BaseBackend, BackendDict, BaseModel
|
||||
|
||||
from .utils import make_arn_for_wacl
|
||||
@ -10,6 +10,9 @@ from moto.moto_api._internal import mock_random
|
||||
from moto.utilities.tagging_service import TaggingService
|
||||
from collections import OrderedDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from moto.apigateway.models import Stage
|
||||
|
||||
|
||||
US_EAST_1_REGION = "us-east-1"
|
||||
GLOBAL_REGION = "global"
|
||||
@ -25,7 +28,14 @@ class FakeWebACL(BaseModel):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, name, arn, wacl_id, visibility_config, default_action, description, rules
|
||||
self,
|
||||
name: str,
|
||||
arn: str,
|
||||
wacl_id: str,
|
||||
visibility_config: Dict[str, Any],
|
||||
default_action: Dict[str, Any],
|
||||
description: Optional[str],
|
||||
rules: List[Dict[str, Any]],
|
||||
):
|
||||
self.name = name
|
||||
self.created_time = iso_8601_datetime_with_milliseconds(datetime.datetime.now())
|
||||
@ -38,7 +48,13 @@ class FakeWebACL(BaseModel):
|
||||
self.default_action = default_action
|
||||
self.lock_token = str(mock_random.uuid4())[0:6]
|
||||
|
||||
def update(self, default_action, rules, description, visibility_config):
|
||||
def update(
|
||||
self,
|
||||
default_action: Optional[Dict[str, Any]],
|
||||
rules: Optional[List[Dict[str, Any]]],
|
||||
description: Optional[str],
|
||||
visibility_config: Optional[Dict[str, Any]],
|
||||
) -> None:
|
||||
if default_action is not None:
|
||||
self.default_action = default_action
|
||||
if rules is not None:
|
||||
@ -49,7 +65,7 @@ class FakeWebACL(BaseModel):
|
||||
self.visibility_config = visibility_config
|
||||
self.lock_token = str(mock_random.uuid4())[0:6]
|
||||
|
||||
def to_dict(self):
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
# Format for summary https://docs.aws.amazon.com/waf/latest/APIReference/API_CreateWebACL.html (response syntax section)
|
||||
return {
|
||||
"ARN": self.arn,
|
||||
@ -67,13 +83,13 @@ class WAFV2Backend(BaseBackend):
|
||||
https://docs.aws.amazon.com/waf/latest/APIReference/API_Operations_AWS_WAFV2.html
|
||||
"""
|
||||
|
||||
def __init__(self, region_name, account_id):
|
||||
def __init__(self, region_name: str, account_id: str):
|
||||
super().__init__(region_name, account_id)
|
||||
self.wacls: Dict[str, FakeWebACL] = OrderedDict()
|
||||
self.tagging_service = TaggingService()
|
||||
# TODO: self.load_balancers = OrderedDict()
|
||||
|
||||
def associate_web_acl(self, web_acl_arn, resource_arn):
|
||||
def associate_web_acl(self, web_acl_arn: str, resource_arn: str) -> None:
|
||||
"""
|
||||
Only APIGateway Stages can be associated at the moment.
|
||||
"""
|
||||
@ -83,19 +99,19 @@ class WAFV2Backend(BaseBackend):
|
||||
if stage:
|
||||
stage.web_acl_arn = web_acl_arn
|
||||
|
||||
def disassociate_web_acl(self, resource_arn):
|
||||
def disassociate_web_acl(self, resource_arn: str) -> None:
|
||||
stage = self._find_apigw_stage(resource_arn)
|
||||
if stage:
|
||||
stage.web_acl_arn = None
|
||||
|
||||
def get_web_acl_for_resource(self, resource_arn):
|
||||
def get_web_acl_for_resource(self, resource_arn: str) -> Optional[FakeWebACL]:
|
||||
stage = self._find_apigw_stage(resource_arn)
|
||||
if stage and stage.web_acl_arn is not None:
|
||||
wacl_arn = stage.web_acl_arn
|
||||
return self.wacls.get(wacl_arn)
|
||||
return None
|
||||
|
||||
def _find_apigw_stage(self, resource_arn):
|
||||
def _find_apigw_stage(self, resource_arn: str) -> Optional["Stage"]: # type: ignore
|
||||
try:
|
||||
if re.search(APIGATEWAY_REGEX, resource_arn):
|
||||
region = resource_arn.split(":")[3]
|
||||
@ -110,8 +126,15 @@ class WAFV2Backend(BaseBackend):
|
||||
return None
|
||||
|
||||
def create_web_acl(
|
||||
self, name, visibility_config, default_action, scope, description, tags, rules
|
||||
):
|
||||
self,
|
||||
name: str,
|
||||
visibility_config: Dict[str, Any],
|
||||
default_action: Dict[str, Any],
|
||||
scope: str,
|
||||
description: str,
|
||||
tags: List[Dict[str, str]],
|
||||
rules: List[Dict[str, Any]],
|
||||
) -> FakeWebACL:
|
||||
"""
|
||||
The following parameters are not yet implemented: CustomResponseBodies, CaptchaConfig
|
||||
"""
|
||||
@ -132,7 +155,7 @@ class WAFV2Backend(BaseBackend):
|
||||
self.tag_resource(arn, tags)
|
||||
return new_wacl
|
||||
|
||||
def delete_web_acl(self, name, _id):
|
||||
def delete_web_acl(self, name: str, _id: str) -> None:
|
||||
"""
|
||||
The LockToken-parameter is not yet implemented
|
||||
"""
|
||||
@ -142,37 +165,43 @@ class WAFV2Backend(BaseBackend):
|
||||
if wacl.name != name and wacl.id != _id
|
||||
}
|
||||
|
||||
def get_web_acl(self, name, _id) -> FakeWebACL:
|
||||
def get_web_acl(self, name: str, _id: str) -> FakeWebACL:
|
||||
for wacl in self.wacls.values():
|
||||
if wacl.name == name and wacl.id == _id:
|
||||
return wacl
|
||||
raise WAFNonexistentItemException
|
||||
|
||||
def list_web_acls(self):
|
||||
def list_web_acls(self) -> List[Dict[str, Any]]:
|
||||
return [wacl.to_dict() for wacl in self.wacls.values()]
|
||||
|
||||
def _is_duplicate_name(self, name):
|
||||
def _is_duplicate_name(self, name: str) -> bool:
|
||||
allWaclNames = set(wacl.name for wacl in self.wacls.values())
|
||||
return name in allWaclNames
|
||||
|
||||
def list_rule_groups(self):
|
||||
def list_rule_groups(self) -> List[Any]:
|
||||
return []
|
||||
|
||||
def list_tags_for_resource(self, arn):
|
||||
def list_tags_for_resource(self, arn: str) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Pagination is not yet implemented
|
||||
"""
|
||||
return self.tagging_service.list_tags_for_resource(arn)["Tags"]
|
||||
|
||||
def tag_resource(self, arn, tags):
|
||||
def tag_resource(self, arn: str, tags: List[Dict[str, str]]) -> None:
|
||||
self.tagging_service.tag_resource(arn, tags)
|
||||
|
||||
def untag_resource(self, arn, tag_keys):
|
||||
def untag_resource(self, arn: str, tag_keys: List[str]) -> None:
|
||||
self.tagging_service.untag_resource_using_names(arn, tag_keys)
|
||||
|
||||
def update_web_acl(
|
||||
self, name, _id, default_action, rules, description, visibility_config
|
||||
):
|
||||
self,
|
||||
name: str,
|
||||
_id: str,
|
||||
default_action: Optional[Dict[str, Any]],
|
||||
rules: Optional[List[Dict[str, Any]]],
|
||||
description: Optional[str],
|
||||
visibility_config: Optional[Dict[str, Any]],
|
||||
) -> str:
|
||||
"""
|
||||
The following parameters are not yet implemented: LockToken, CustomResponseBodies, CaptchaConfig
|
||||
"""
|
||||
|
@ -1,19 +1,20 @@
|
||||
import json
|
||||
from moto.core.common_types import TYPE_RESPONSE
|
||||
from moto.core.responses import BaseResponse
|
||||
from moto.utilities.aws_headers import amzn_request_id
|
||||
from .models import GLOBAL_REGION, wafv2_backends
|
||||
from .models import GLOBAL_REGION, wafv2_backends, WAFV2Backend
|
||||
|
||||
|
||||
class WAFV2Response(BaseResponse):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(service_name="wafv2")
|
||||
|
||||
@property
|
||||
def wafv2_backend(self):
|
||||
def wafv2_backend(self) -> WAFV2Backend:
|
||||
return wafv2_backends[self.current_account][self.region]
|
||||
|
||||
@amzn_request_id
|
||||
def associate_web_acl(self):
|
||||
def associate_web_acl(self) -> TYPE_RESPONSE:
|
||||
body = json.loads(self.body)
|
||||
web_acl_arn = body["WebACLArn"]
|
||||
resource_arn = body["ResourceArn"]
|
||||
@ -21,14 +22,14 @@ class WAFV2Response(BaseResponse):
|
||||
return 200, {}, "{}"
|
||||
|
||||
@amzn_request_id
|
||||
def disassociate_web_acl(self):
|
||||
def disassociate_web_acl(self) -> TYPE_RESPONSE:
|
||||
body = json.loads(self.body)
|
||||
resource_arn = body["ResourceArn"]
|
||||
self.wafv2_backend.disassociate_web_acl(resource_arn)
|
||||
return 200, {}, "{}"
|
||||
|
||||
@amzn_request_id
|
||||
def get_web_acl_for_resource(self):
|
||||
def get_web_acl_for_resource(self) -> TYPE_RESPONSE:
|
||||
body = json.loads(self.body)
|
||||
resource_arn = body["ResourceArn"]
|
||||
web_acl = self.wafv2_backend.get_web_acl_for_resource(resource_arn)
|
||||
@ -37,7 +38,7 @@ class WAFV2Response(BaseResponse):
|
||||
return 200, response_headers, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def create_web_acl(self):
|
||||
def create_web_acl(self) -> TYPE_RESPONSE:
|
||||
"""https://docs.aws.amazon.com/waf/latest/APIReference/API_CreateWebACL.html (response syntax section)"""
|
||||
|
||||
scope = self._get_param("Scope")
|
||||
@ -62,7 +63,7 @@ class WAFV2Response(BaseResponse):
|
||||
return 200, response_headers, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def delete_web_acl(self):
|
||||
def delete_web_acl(self) -> TYPE_RESPONSE:
|
||||
scope = self._get_param("Scope")
|
||||
if scope == "CLOUDFRONT":
|
||||
self.region = GLOBAL_REGION
|
||||
@ -73,7 +74,7 @@ class WAFV2Response(BaseResponse):
|
||||
return 200, response_headers, "{}"
|
||||
|
||||
@amzn_request_id
|
||||
def get_web_acl(self):
|
||||
def get_web_acl(self) -> TYPE_RESPONSE:
|
||||
scope = self._get_param("Scope")
|
||||
if scope == "CLOUDFRONT":
|
||||
self.region = GLOBAL_REGION
|
||||
@ -85,7 +86,7 @@ class WAFV2Response(BaseResponse):
|
||||
return 200, response_headers, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def list_web_ac_ls(self):
|
||||
def list_web_ac_ls(self) -> TYPE_RESPONSE:
|
||||
"""https://docs.aws.amazon.com/waf/latest/APIReference/API_ListWebACLs.html (response syntax section)"""
|
||||
|
||||
scope = self._get_param("Scope")
|
||||
@ -97,7 +98,7 @@ class WAFV2Response(BaseResponse):
|
||||
return 200, response_headers, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def list_rule_groups(self):
|
||||
def list_rule_groups(self) -> TYPE_RESPONSE:
|
||||
scope = self._get_param("Scope")
|
||||
if scope == "CLOUDFRONT":
|
||||
self.region = GLOBAL_REGION
|
||||
@ -107,7 +108,7 @@ class WAFV2Response(BaseResponse):
|
||||
return 200, response_headers, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def list_tags_for_resource(self):
|
||||
def list_tags_for_resource(self) -> TYPE_RESPONSE:
|
||||
arn = self._get_param("ResourceARN")
|
||||
self.region = arn.split(":")[3]
|
||||
tags = self.wafv2_backend.list_tags_for_resource(arn)
|
||||
@ -116,7 +117,7 @@ class WAFV2Response(BaseResponse):
|
||||
return 200, response_headers, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def tag_resource(self):
|
||||
def tag_resource(self) -> TYPE_RESPONSE:
|
||||
body = json.loads(self.body)
|
||||
arn = body.get("ResourceARN")
|
||||
self.region = arn.split(":")[3]
|
||||
@ -125,7 +126,7 @@ class WAFV2Response(BaseResponse):
|
||||
return 200, {}, "{}"
|
||||
|
||||
@amzn_request_id
|
||||
def untag_resource(self):
|
||||
def untag_resource(self) -> TYPE_RESPONSE:
|
||||
body = json.loads(self.body)
|
||||
arn = body.get("ResourceARN")
|
||||
self.region = arn.split(":")[3]
|
||||
@ -134,7 +135,7 @@ class WAFV2Response(BaseResponse):
|
||||
return 200, {}, "{}"
|
||||
|
||||
@amzn_request_id
|
||||
def update_web_acl(self):
|
||||
def update_web_acl(self) -> TYPE_RESPONSE:
|
||||
body = json.loads(self.body)
|
||||
name = body.get("Name")
|
||||
_id = body.get("Id")
|
||||
|
@ -1,4 +1,6 @@
|
||||
def make_arn_for_wacl(name, account_id, region_name, wacl_id, scope):
|
||||
def make_arn_for_wacl(
|
||||
name: str, account_id: str, region_name: str, wacl_id: str, scope: str
|
||||
) -> str:
|
||||
"""https://docs.aws.amazon.com/waf/latest/developerguide/how-aws-waf-works.html - explains --scope (cloudfront vs regional)"""
|
||||
|
||||
if scope == "REGIONAL":
|
||||
|
@ -1,14 +1,22 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
class BadSegmentException(Exception):
|
||||
def __init__(self, seg_id=None, code=None, message=None):
|
||||
def __init__(
|
||||
self,
|
||||
seg_id: Optional[str] = None,
|
||||
code: Optional[str] = None,
|
||||
message: Optional[str] = None,
|
||||
):
|
||||
self.id = seg_id
|
||||
self.code = code
|
||||
self.message = message
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"<BadSegment {self.id}-{self.code}-{self.message}>"
|
||||
|
||||
def to_dict(self):
|
||||
result = {}
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
result: Dict[str, Any] = {}
|
||||
if self.id is not None:
|
||||
result["Id"] = self.id
|
||||
if self.code is not None:
|
||||
|
@ -1,25 +1,26 @@
|
||||
import os
|
||||
from moto.xray import xray_backends
|
||||
from typing import Any
|
||||
from moto.xray.models import xray_backends, XRayBackend
|
||||
import aws_xray_sdk.core
|
||||
from aws_xray_sdk.core.context import Context as AWSContext
|
||||
from aws_xray_sdk.core.emitters.udp_emitter import UDPEmitter
|
||||
|
||||
|
||||
class MockEmitter(UDPEmitter):
|
||||
class MockEmitter(UDPEmitter): # type: ignore
|
||||
"""
|
||||
Replaces the code that sends UDP to local X-Ray daemon
|
||||
"""
|
||||
|
||||
def __init__(self, daemon_address="127.0.0.1:2000"):
|
||||
def __init__(self, daemon_address: str = "127.0.0.1:2000"):
|
||||
address = os.getenv(
|
||||
"AWS_XRAY_DAEMON_ADDRESS_YEAH_NOT_TODAY_MATE", daemon_address
|
||||
)
|
||||
self._ip, self._port = self._parse_address(address)
|
||||
|
||||
def _xray_backend(self, region):
|
||||
return xray_backends[region]
|
||||
def _xray_backend(self, account_id: str, region: str) -> XRayBackend:
|
||||
return xray_backends[account_id][region]
|
||||
|
||||
def send_entity(self, entity):
|
||||
def send_entity(self, entity: Any) -> None:
|
||||
# Hack to get region
|
||||
# region = entity.subsegments[0].aws['region']
|
||||
# xray = self._xray_backend(region)
|
||||
@ -27,7 +28,7 @@ class MockEmitter(UDPEmitter):
|
||||
# TODO store X-Ray data, pretty sure X-Ray needs refactor for this
|
||||
pass
|
||||
|
||||
def _send_data(self, data):
|
||||
def _send_data(self, data: Any) -> None:
|
||||
raise RuntimeError("Should not be running this")
|
||||
|
||||
|
||||
@ -42,11 +43,11 @@ class MockXrayClient:
|
||||
that itno the recorder instance.
|
||||
"""
|
||||
|
||||
def __call__(self, f=None):
|
||||
def __call__(self, f: Any = None) -> Any:
|
||||
if not f:
|
||||
return self
|
||||
|
||||
def wrapped_f(*args, **kwargs):
|
||||
def wrapped_f(*args: Any, **kwargs: Any) -> Any:
|
||||
self.start()
|
||||
try:
|
||||
f(*args, **kwargs)
|
||||
@ -55,7 +56,7 @@ class MockXrayClient:
|
||||
|
||||
return wrapped_f
|
||||
|
||||
def start(self):
|
||||
def start(self) -> None:
|
||||
print("Starting X-Ray Patch") # noqa
|
||||
self.old_xray_context_var = os.environ.get("AWS_XRAY_CONTEXT_MISSING")
|
||||
os.environ["AWS_XRAY_CONTEXT_MISSING"] = "LOG_ERROR"
|
||||
@ -64,7 +65,7 @@ class MockXrayClient:
|
||||
aws_xray_sdk.core.xray_recorder._context = AWSContext()
|
||||
aws_xray_sdk.core.xray_recorder._emitter = MockEmitter()
|
||||
|
||||
def stop(self):
|
||||
def stop(self) -> None:
|
||||
if self.old_xray_context_var is None:
|
||||
del os.environ["AWS_XRAY_CONTEXT_MISSING"]
|
||||
else:
|
||||
@ -73,11 +74,11 @@ class MockXrayClient:
|
||||
aws_xray_sdk.core.xray_recorder._emitter = self.old_xray_emitter
|
||||
aws_xray_sdk.core.xray_recorder._context = self.old_xray_context
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> "MockXrayClient":
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
def __exit__(self, *args: Any) -> None:
|
||||
self.stop()
|
||||
|
||||
|
||||
@ -91,12 +92,12 @@ class XRaySegment(object):
|
||||
During testing we're going to have to control the start and end of a segment via context managers.
|
||||
"""
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> "XRaySegment":
|
||||
aws_xray_sdk.core.xray_recorder.begin_segment(
|
||||
name="moto_mock", traceid=None, parent_id=None, sampling=1
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
aws_xray_sdk.core.xray_recorder.end_segment()
|
||||
|
@ -1,6 +1,7 @@
|
||||
import bisect
|
||||
import datetime
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
import json
|
||||
from moto.core import BaseBackend, BackendDict, BaseModel
|
||||
from moto.core.exceptions import AWSError
|
||||
@ -8,54 +9,60 @@ from .exceptions import BadSegmentException
|
||||
|
||||
|
||||
class TelemetryRecords(BaseModel):
|
||||
def __init__(self, instance_id, hostname, resource_arn, records):
|
||||
def __init__(
|
||||
self,
|
||||
instance_id: str,
|
||||
hostname: str,
|
||||
resource_arn: str,
|
||||
records: List[Dict[str, Any]],
|
||||
):
|
||||
self.instance_id = instance_id
|
||||
self.hostname = hostname
|
||||
self.resource_arn = resource_arn
|
||||
self.records = records
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, src):
|
||||
def from_json(cls, src: Dict[str, Any]) -> "TelemetryRecords": # type: ignore[misc]
|
||||
instance_id = src.get("EC2InstanceId", None)
|
||||
hostname = src.get("Hostname")
|
||||
resource_arn = src.get("ResourceARN")
|
||||
telemetry_records = src["TelemetryRecords"]
|
||||
|
||||
return cls(instance_id, hostname, resource_arn, telemetry_records)
|
||||
return cls(instance_id, hostname, resource_arn, telemetry_records) # type: ignore
|
||||
|
||||
|
||||
# https://docs.aws.amazon.com/xray/latest/devguide/xray-api-segmentdocuments.html
|
||||
class TraceSegment(BaseModel):
|
||||
def __init__(
|
||||
self,
|
||||
name,
|
||||
segment_id,
|
||||
trace_id,
|
||||
start_time,
|
||||
raw,
|
||||
end_time=None,
|
||||
in_progress=False,
|
||||
service=None,
|
||||
user=None,
|
||||
origin=None,
|
||||
parent_id=None,
|
||||
http=None,
|
||||
aws=None,
|
||||
metadata=None,
|
||||
annotations=None,
|
||||
subsegments=None,
|
||||
**kwargs
|
||||
name: str,
|
||||
segment_id: str,
|
||||
trace_id: str,
|
||||
start_time: float,
|
||||
raw: Any,
|
||||
end_time: Optional[float] = None,
|
||||
in_progress: bool = False,
|
||||
service: Any = None,
|
||||
user: Any = None,
|
||||
origin: Any = None,
|
||||
parent_id: Any = None,
|
||||
http: Any = None,
|
||||
aws: Any = None,
|
||||
metadata: Any = None,
|
||||
annotations: Any = None,
|
||||
subsegments: Any = None,
|
||||
**kwargs: Any
|
||||
):
|
||||
self.name = name
|
||||
self.id = segment_id
|
||||
self.trace_id = trace_id
|
||||
self._trace_version = None
|
||||
self._original_request_start_time = None
|
||||
self._trace_version: Optional[int] = None
|
||||
self._original_request_start_time: Optional[datetime.datetime] = None
|
||||
self._trace_identifier = None
|
||||
self.start_time = start_time
|
||||
self._start_date = None
|
||||
self._start_date: Optional[datetime.datetime] = None
|
||||
self.end_time = end_time
|
||||
self._end_date = None
|
||||
self._end_date: Optional[datetime.datetime] = None
|
||||
self.in_progress = in_progress
|
||||
self.service = service
|
||||
self.user = user
|
||||
@ -71,17 +78,17 @@ class TraceSegment(BaseModel):
|
||||
# Raw json string
|
||||
self.raw = raw
|
||||
|
||||
def __lt__(self, other):
|
||||
def __lt__(self, other: Any) -> bool:
|
||||
return self.start_date < other.start_date
|
||||
|
||||
@property
|
||||
def trace_version(self):
|
||||
def trace_version(self) -> int:
|
||||
if self._trace_version is None:
|
||||
self._trace_version = int(self.trace_id.split("-", 1)[0])
|
||||
return self._trace_version
|
||||
|
||||
@property
|
||||
def request_start_date(self):
|
||||
def request_start_date(self) -> datetime.datetime:
|
||||
if self._original_request_start_time is None:
|
||||
start_time = int(self.trace_id.split("-")[1], 16)
|
||||
self._original_request_start_time = datetime.datetime.fromtimestamp(
|
||||
@ -90,19 +97,19 @@ class TraceSegment(BaseModel):
|
||||
return self._original_request_start_time
|
||||
|
||||
@property
|
||||
def start_date(self):
|
||||
def start_date(self) -> datetime.datetime:
|
||||
if self._start_date is None:
|
||||
self._start_date = datetime.datetime.fromtimestamp(self.start_time)
|
||||
return self._start_date
|
||||
|
||||
@property
|
||||
def end_date(self):
|
||||
def end_date(self) -> datetime.datetime:
|
||||
if self._end_date is None:
|
||||
self._end_date = datetime.datetime.fromtimestamp(self.end_time)
|
||||
self._end_date = datetime.datetime.fromtimestamp(self.end_time) # type: ignore
|
||||
return self._end_date
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data, raw):
|
||||
def from_dict(cls, data: Dict[str, Any], raw: Any) -> "TraceSegment": # type: ignore[misc]
|
||||
# Check manditory args
|
||||
if "id" not in data:
|
||||
raise BadSegmentException(code="MissingParam", message="Missing segment ID")
|
||||
@ -130,11 +137,11 @@ class TraceSegment(BaseModel):
|
||||
|
||||
|
||||
class SegmentCollection(object):
|
||||
def __init__(self):
|
||||
self._traces = defaultdict(self._new_trace_item)
|
||||
def __init__(self) -> None:
|
||||
self._traces: Dict[str, Dict[str, Any]] = defaultdict(self._new_trace_item)
|
||||
|
||||
@staticmethod
|
||||
def _new_trace_item():
|
||||
def _new_trace_item() -> Dict[str, Any]: # type: ignore[misc]
|
||||
return {
|
||||
"start_date": datetime.datetime(1970, 1, 1),
|
||||
"end_date": datetime.datetime(1970, 1, 1),
|
||||
@ -143,7 +150,7 @@ class SegmentCollection(object):
|
||||
"segments": [],
|
||||
}
|
||||
|
||||
def put_segment(self, segment):
|
||||
def put_segment(self, segment: Any) -> None:
|
||||
# insert into a sorted list
|
||||
bisect.insort_left(self._traces[segment.trace_id]["segments"], segment)
|
||||
|
||||
@ -160,11 +167,15 @@ class SegmentCollection(object):
|
||||
# Todo consolidate trace segments into a trace.
|
||||
# not enough working knowledge of xray to do this
|
||||
|
||||
def summary(self, start_time, end_time, filter_expression=None):
|
||||
def summary(
|
||||
self, start_time: str, end_time: str, filter_expression: Any = None
|
||||
) -> Dict[str, Any]:
|
||||
# This beast https://docs.aws.amazon.com/xray/latest/api/API_GetTraceSummaries.html#API_GetTraceSummaries_ResponseSyntax
|
||||
if filter_expression is not None:
|
||||
raise AWSError(
|
||||
"Not implemented yet - moto", code="InternalFailure", status=500
|
||||
"Not implemented yet - moto",
|
||||
exception_type="InternalFailure",
|
||||
status=500,
|
||||
)
|
||||
|
||||
summaries = []
|
||||
@ -213,7 +224,9 @@ class SegmentCollection(object):
|
||||
|
||||
return result
|
||||
|
||||
def get_trace_ids(self, trace_ids):
|
||||
def get_trace_ids(
|
||||
self, trace_ids: List[str]
|
||||
) -> Tuple[List[Dict[str, Any]], List[str]]:
|
||||
traces = []
|
||||
unprocessed = []
|
||||
|
||||
@ -229,22 +242,24 @@ class SegmentCollection(object):
|
||||
|
||||
|
||||
class XRayBackend(BaseBackend):
|
||||
def __init__(self, region_name, account_id):
|
||||
def __init__(self, region_name: str, account_id: str):
|
||||
super().__init__(region_name, account_id)
|
||||
self._telemetry_records = []
|
||||
self._telemetry_records: List[TelemetryRecords] = []
|
||||
self._segment_collection = SegmentCollection()
|
||||
|
||||
@staticmethod
|
||||
def default_vpc_endpoint_service(service_region, zones):
|
||||
def default_vpc_endpoint_service(
|
||||
service_region: str, zones: List[str]
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Default VPC endpoint service."""
|
||||
return BaseBackend.default_vpc_endpoint_service_factory(
|
||||
service_region, zones, "xray"
|
||||
)
|
||||
|
||||
def add_telemetry_records(self, src):
|
||||
def add_telemetry_records(self, src: Any) -> None:
|
||||
self._telemetry_records.append(TelemetryRecords.from_json(src))
|
||||
|
||||
def process_segment(self, doc):
|
||||
def process_segment(self, doc: Any) -> None:
|
||||
try:
|
||||
data = json.loads(doc)
|
||||
except ValueError:
|
||||
@ -264,13 +279,15 @@ class XRayBackend(BaseBackend):
|
||||
seg_id=segment.id, code="InternalFailure", message=str(err)
|
||||
)
|
||||
|
||||
def get_trace_summary(self, start_time, end_time, filter_expression):
|
||||
def get_trace_summary(
|
||||
self, start_time: str, end_time: str, filter_expression: Any
|
||||
) -> Dict[str, Any]:
|
||||
return self._segment_collection.summary(start_time, end_time, filter_expression)
|
||||
|
||||
def get_trace_ids(self, trace_ids):
|
||||
def get_trace_ids(self, trace_ids: List[str]) -> Dict[str, Any]:
|
||||
traces, unprocessed_ids = self._segment_collection.get_trace_ids(trace_ids)
|
||||
|
||||
result = {"Traces": [], "UnprocessedTraceIds": unprocessed_ids}
|
||||
result: Dict[str, Any] = {"Traces": [], "UnprocessedTraceIds": unprocessed_ids}
|
||||
|
||||
for trace in traces:
|
||||
segments = []
|
||||
|
@ -1,49 +1,50 @@
|
||||
import json
|
||||
import datetime
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
from moto.core.responses import BaseResponse
|
||||
from moto.core.exceptions import AWSError
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
from .models import xray_backends
|
||||
from .models import xray_backends, XRayBackend
|
||||
from .exceptions import BadSegmentException
|
||||
|
||||
|
||||
class XRayResponse(BaseResponse):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(service_name="xray")
|
||||
|
||||
def _error(self, code, message):
|
||||
def _error(self, code: str, message: str) -> Tuple[str, Dict[str, int]]:
|
||||
return json.dumps({"__type": code, "message": message}), dict(status=400)
|
||||
|
||||
@property
|
||||
def xray_backend(self):
|
||||
def xray_backend(self) -> XRayBackend:
|
||||
return xray_backends[self.current_account][self.region]
|
||||
|
||||
@property
|
||||
def request_params(self):
|
||||
def request_params(self) -> Any: # type: ignore[misc]
|
||||
try:
|
||||
return json.loads(self.body)
|
||||
except ValueError:
|
||||
return {}
|
||||
|
||||
def _get_param(self, param_name, if_none=None):
|
||||
def _get_param(self, param_name: str, if_none: Any = None) -> Any:
|
||||
return self.request_params.get(param_name, if_none)
|
||||
|
||||
def _get_action(self):
|
||||
def _get_action(self) -> str:
|
||||
# Amazon is just calling urls like /TelemetryRecords etc...
|
||||
# This uses the value after / as the camalcase action, which then
|
||||
# gets converted in call_action to find the following methods
|
||||
return urlsplit(self.uri).path.lstrip("/")
|
||||
|
||||
# PutTelemetryRecords
|
||||
def telemetry_records(self):
|
||||
def telemetry_records(self) -> str:
|
||||
self.xray_backend.add_telemetry_records(self.request_params)
|
||||
|
||||
return ""
|
||||
|
||||
# PutTraceSegments
|
||||
def trace_segments(self):
|
||||
def trace_segments(self) -> Union[str, Tuple[str, Dict[str, int]]]:
|
||||
docs = self._get_param("TraceSegmentDocuments")
|
||||
|
||||
if docs is None:
|
||||
@ -71,7 +72,7 @@ class XRayResponse(BaseResponse):
|
||||
return json.dumps(result)
|
||||
|
||||
# GetTraceSummaries
|
||||
def trace_summaries(self):
|
||||
def trace_summaries(self) -> Union[str, Tuple[str, Dict[str, int]]]:
|
||||
start_time = self._get_param("StartTime")
|
||||
end_time = self._get_param("EndTime")
|
||||
if start_time is None:
|
||||
@ -119,7 +120,7 @@ class XRayResponse(BaseResponse):
|
||||
return json.dumps(result)
|
||||
|
||||
# BatchGetTraces
|
||||
def traces(self):
|
||||
def traces(self) -> Union[str, Tuple[str, Dict[str, int]]]:
|
||||
trace_ids = self._get_param("TraceIds")
|
||||
|
||||
if trace_ids is None:
|
||||
@ -142,7 +143,7 @@ class XRayResponse(BaseResponse):
|
||||
return json.dumps(result)
|
||||
|
||||
# GetServiceGraph - just a dummy response for now
|
||||
def service_graph(self):
|
||||
def service_graph(self) -> Union[str, Tuple[str, Dict[str, int]]]:
|
||||
start_time = self._get_param("StartTime")
|
||||
end_time = self._get_param("EndTime")
|
||||
# next_token = self._get_param('NextToken') # not implemented yet
|
||||
@ -164,7 +165,7 @@ class XRayResponse(BaseResponse):
|
||||
return json.dumps(result)
|
||||
|
||||
# GetTraceGraph - just a dummy response for now
|
||||
def trace_graph(self):
|
||||
def trace_graph(self) -> Union[str, Tuple[str, Dict[str, int]]]:
|
||||
trace_ids = self._get_param("TraceIds")
|
||||
# next_token = self._get_param('NextToken') # not implemented yet
|
||||
|
||||
@ -175,5 +176,5 @@ class XRayResponse(BaseResponse):
|
||||
dict(status=400),
|
||||
)
|
||||
|
||||
result = {"Services": []}
|
||||
result: Dict[str, Any] = {"Services": []}
|
||||
return json.dumps(result)
|
||||
|
@ -239,7 +239,7 @@ disable = W,C,R,E
|
||||
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
|
||||
|
||||
[mypy]
|
||||
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s*,moto/u*,moto/t*
|
||||
files= moto, tests/test_core/test_mock_all.py, tests/test_core/test_decorator_calls.py
|
||||
show_column_numbers=True
|
||||
show_error_codes = True
|
||||
disable_error_code=abstract
|
||||
|
@ -1,10 +1,11 @@
|
||||
import boto3
|
||||
import pytest
|
||||
import sure # noqa # pylint: disable=unused-import
|
||||
import unittest
|
||||
|
||||
from botocore.exceptions import ClientError
|
||||
from typing import Any
|
||||
from moto import mock_ec2, mock_kinesis, mock_s3, settings
|
||||
from moto.core.models import BaseMockAWS
|
||||
from unittest import SkipTest
|
||||
|
||||
"""
|
||||
@ -13,13 +14,13 @@ Test the different ways that the decorator can be used
|
||||
|
||||
|
||||
@mock_ec2
|
||||
def test_basic_decorator():
|
||||
def test_basic_decorator() -> None:
|
||||
client = boto3.client("ec2", region_name="us-west-1")
|
||||
client.describe_addresses()["Addresses"].should.equal([])
|
||||
assert client.describe_addresses()["Addresses"] == []
|
||||
|
||||
|
||||
@pytest.fixture(name="aws_credentials")
|
||||
def fixture_aws_credentials(monkeypatch):
|
||||
def fixture_aws_credentials(monkeypatch: Any) -> None: # type: ignore[misc]
|
||||
"""Mocked AWS Credentials for moto."""
|
||||
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
|
||||
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
|
||||
@ -28,97 +29,95 @@ def fixture_aws_credentials(monkeypatch):
|
||||
|
||||
|
||||
@pytest.mark.network
|
||||
def test_context_manager(aws_credentials): # pylint: disable=unused-argument
|
||||
def test_context_manager(aws_credentials: Any) -> None: # type: ignore[misc] # pylint: disable=unused-argument
|
||||
client = boto3.client("ec2", region_name="us-west-1")
|
||||
with pytest.raises(ClientError) as exc:
|
||||
client.describe_addresses()
|
||||
err = exc.value.response["Error"]
|
||||
err["Code"].should.equal("AuthFailure")
|
||||
err["Message"].should.equal(
|
||||
"AWS was not able to validate the provided access credentials"
|
||||
assert err["Code"] == "AuthFailure"
|
||||
assert (
|
||||
err["Message"] == "AWS was not able to validate the provided access credentials"
|
||||
)
|
||||
|
||||
with mock_ec2():
|
||||
client = boto3.client("ec2", region_name="us-west-1")
|
||||
client.describe_addresses()["Addresses"].should.equal([])
|
||||
assert client.describe_addresses()["Addresses"] == []
|
||||
|
||||
|
||||
@pytest.mark.network
|
||||
def test_decorator_start_and_stop():
|
||||
def test_decorator_start_and_stop() -> None:
|
||||
if settings.TEST_SERVER_MODE:
|
||||
raise SkipTest("Authentication always works in ServerMode")
|
||||
mock = mock_ec2()
|
||||
mock: BaseMockAWS = mock_ec2()
|
||||
mock.start()
|
||||
client = boto3.client("ec2", region_name="us-west-1")
|
||||
client.describe_addresses()["Addresses"].should.equal([])
|
||||
assert client.describe_addresses()["Addresses"] == []
|
||||
mock.stop()
|
||||
|
||||
with pytest.raises(ClientError) as exc:
|
||||
client.describe_addresses()
|
||||
err = exc.value.response["Error"]
|
||||
err["Code"].should.equal("AuthFailure")
|
||||
err["Message"].should.equal(
|
||||
"AWS was not able to validate the provided access credentials"
|
||||
assert err["Code"] == "AuthFailure"
|
||||
assert (
|
||||
err["Message"] == "AWS was not able to validate the provided access credentials"
|
||||
)
|
||||
|
||||
|
||||
@mock_ec2
|
||||
def test_decorater_wrapped_gets_set():
|
||||
def test_decorater_wrapped_gets_set() -> None:
|
||||
"""
|
||||
Moto decorator's __wrapped__ should get set to the tests function
|
||||
"""
|
||||
test_decorater_wrapped_gets_set.__wrapped__.__name__.should.equal(
|
||||
"test_decorater_wrapped_gets_set"
|
||||
)
|
||||
assert test_decorater_wrapped_gets_set.__wrapped__.__name__ == "test_decorater_wrapped_gets_set" # type: ignore
|
||||
|
||||
|
||||
@mock_ec2
|
||||
class Tester(object):
|
||||
def test_the_class(self):
|
||||
class Tester:
|
||||
def test_the_class(self) -> None:
|
||||
client = boto3.client("ec2", region_name="us-west-1")
|
||||
client.describe_addresses()["Addresses"].should.equal([])
|
||||
assert client.describe_addresses()["Addresses"] == []
|
||||
|
||||
def test_still_the_same(self):
|
||||
def test_still_the_same(self) -> None:
|
||||
client = boto3.client("ec2", region_name="us-west-1")
|
||||
client.describe_addresses()["Addresses"].should.equal([])
|
||||
assert client.describe_addresses()["Addresses"] == []
|
||||
|
||||
|
||||
@mock_s3
|
||||
class TesterWithSetup(unittest.TestCase):
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
self.client = boto3.client("s3")
|
||||
self.client.create_bucket(Bucket="mybucket")
|
||||
|
||||
def test_still_the_same(self):
|
||||
def test_still_the_same(self) -> None:
|
||||
buckets = self.client.list_buckets()["Buckets"]
|
||||
bucket_names = [b["Name"] for b in buckets]
|
||||
# There is a potential bug in the class-decorator, where the reset API is not called on start.
|
||||
# This leads to a situation where 'bucket_names' may contain buckets created by earlier tests
|
||||
bucket_names.should.contain("mybucket")
|
||||
assert "mybucket" in bucket_names
|
||||
|
||||
|
||||
@mock_s3
|
||||
class TesterWithStaticmethod(object):
|
||||
class TesterWithStaticmethod:
|
||||
@staticmethod
|
||||
def static(*args):
|
||||
def static(*args: Any) -> None: # type: ignore[misc]
|
||||
assert not args or not isinstance(args[0], TesterWithStaticmethod)
|
||||
|
||||
def test_no_instance_sent_to_staticmethod(self):
|
||||
def test_no_instance_sent_to_staticmethod(self) -> None:
|
||||
self.static()
|
||||
|
||||
|
||||
@mock_s3
|
||||
class TestWithSetup_UppercaseU(unittest.TestCase):
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
# This method will be executed automatically, provided we extend the TestCase-class
|
||||
s3 = boto3.client("s3", region_name="us-east-1")
|
||||
s3.create_bucket(Bucket="mybucket")
|
||||
|
||||
def test_should_find_bucket(self):
|
||||
def test_should_find_bucket(self) -> None:
|
||||
s3 = boto3.client("s3", region_name="us-east-1")
|
||||
self.assertIsNotNone(s3.head_bucket(Bucket="mybucket"))
|
||||
|
||||
def test_should_not_find_unknown_bucket(self):
|
||||
def test_should_not_find_unknown_bucket(self) -> None:
|
||||
s3 = boto3.client("s3", region_name="us-east-1")
|
||||
with pytest.raises(ClientError):
|
||||
s3.head_bucket(Bucket="unknown_bucket")
|
||||
@ -126,16 +125,16 @@ class TestWithSetup_UppercaseU(unittest.TestCase):
|
||||
|
||||
@mock_s3
|
||||
class TestWithSetup_LowercaseU:
|
||||
def setup_method(self, *args): # pylint: disable=unused-argument
|
||||
def setup_method(self, *args: Any) -> None: # pylint: disable=unused-argument
|
||||
# This method will be executed automatically using pytest
|
||||
s3 = boto3.client("s3", region_name="us-east-1")
|
||||
s3.create_bucket(Bucket="mybucket")
|
||||
|
||||
def test_should_find_bucket(self):
|
||||
def test_should_find_bucket(self) -> None:
|
||||
s3 = boto3.client("s3", region_name="us-east-1")
|
||||
assert s3.head_bucket(Bucket="mybucket") is not None
|
||||
|
||||
def test_should_not_find_unknown_bucket(self):
|
||||
def test_should_not_find_unknown_bucket(self) -> None:
|
||||
s3 = boto3.client("s3", region_name="us-east-1")
|
||||
with pytest.raises(ClientError):
|
||||
s3.head_bucket(Bucket="unknown_bucket")
|
||||
@ -143,16 +142,16 @@ class TestWithSetup_LowercaseU:
|
||||
|
||||
@mock_s3
|
||||
class TestWithSetupMethod:
|
||||
def setup_method(self, *args): # pylint: disable=unused-argument
|
||||
def setup_method(self, *args: Any) -> None: # pylint: disable=unused-argument
|
||||
# This method will be executed automatically using pytest
|
||||
s3 = boto3.client("s3", region_name="us-east-1")
|
||||
s3.create_bucket(Bucket="mybucket")
|
||||
|
||||
def test_should_find_bucket(self):
|
||||
def test_should_find_bucket(self) -> None:
|
||||
s3 = boto3.client("s3", region_name="us-east-1")
|
||||
assert s3.head_bucket(Bucket="mybucket") is not None
|
||||
|
||||
def test_should_not_find_unknown_bucket(self):
|
||||
def test_should_not_find_unknown_bucket(self) -> None:
|
||||
s3 = boto3.client("s3", region_name="us-east-1")
|
||||
with pytest.raises(ClientError):
|
||||
s3.head_bucket(Bucket="unknown_bucket")
|
||||
@ -160,17 +159,17 @@ class TestWithSetupMethod:
|
||||
|
||||
@mock_kinesis
|
||||
class TestKinesisUsingSetupMethod:
|
||||
def setup_method(self, *args): # pylint: disable=unused-argument
|
||||
def setup_method(self, *args: Any) -> None: # pylint: disable=unused-argument
|
||||
self.stream_name = "test_stream"
|
||||
self.boto3_kinesis_client = boto3.client("kinesis", region_name="us-east-1")
|
||||
self.boto3_kinesis_client.create_stream(
|
||||
StreamName=self.stream_name, ShardCount=1
|
||||
)
|
||||
|
||||
def test_stream_creation(self):
|
||||
def test_stream_creation(self) -> None:
|
||||
pass
|
||||
|
||||
def test_stream_recreation(self):
|
||||
def test_stream_recreation(self) -> None:
|
||||
# The setup-method will run again for this test
|
||||
# The fact that it passes, means the state was reset
|
||||
# Otherwise it would complain about a stream already existing
|
||||
@ -179,11 +178,11 @@ class TestKinesisUsingSetupMethod:
|
||||
|
||||
@mock_s3
|
||||
class TestWithInvalidSetupMethod:
|
||||
def setupmethod(self):
|
||||
def setupmethod(self) -> None:
|
||||
s3 = boto3.client("s3", region_name="us-east-1")
|
||||
s3.create_bucket(Bucket="mybucket")
|
||||
|
||||
def test_should_not_find_bucket(self):
|
||||
def test_should_not_find_bucket(self) -> None:
|
||||
# Name of setupmethod is not recognized, so it will not be executed
|
||||
s3 = boto3.client("s3", region_name="us-east-1")
|
||||
with pytest.raises(ClientError):
|
||||
@ -192,17 +191,17 @@ class TestWithInvalidSetupMethod:
|
||||
|
||||
@mock_s3
|
||||
class TestWithPublicMethod(unittest.TestCase):
|
||||
def ensure_bucket_exists(self):
|
||||
def ensure_bucket_exists(self) -> None:
|
||||
s3 = boto3.client("s3", region_name="us-east-1")
|
||||
s3.create_bucket(Bucket="mybucket")
|
||||
|
||||
def test_should_find_bucket(self):
|
||||
def test_should_find_bucket(self) -> None:
|
||||
self.ensure_bucket_exists()
|
||||
|
||||
s3 = boto3.client("s3", region_name="us-east-1")
|
||||
s3.head_bucket(Bucket="mybucket").shouldnt.equal(None)
|
||||
|
||||
def test_should_not_find_bucket(self):
|
||||
def test_should_not_find_bucket(self) -> None:
|
||||
s3 = boto3.client("s3", region_name="us-east-1")
|
||||
with pytest.raises(ClientError):
|
||||
s3.head_bucket(Bucket="mybucket")
|
||||
@ -210,16 +209,16 @@ class TestWithPublicMethod(unittest.TestCase):
|
||||
|
||||
@mock_s3
|
||||
class TestWithPseudoPrivateMethod(unittest.TestCase):
|
||||
def _ensure_bucket_exists(self):
|
||||
def _ensure_bucket_exists(self) -> None:
|
||||
s3 = boto3.client("s3", region_name="us-east-1")
|
||||
s3.create_bucket(Bucket="mybucket")
|
||||
|
||||
def test_should_find_bucket(self):
|
||||
def test_should_find_bucket(self) -> None:
|
||||
self._ensure_bucket_exists()
|
||||
s3 = boto3.client("s3", region_name="us-east-1")
|
||||
s3.head_bucket(Bucket="mybucket").shouldnt.equal(None)
|
||||
|
||||
def test_should_not_find_bucket(self):
|
||||
def test_should_not_find_bucket(self) -> None:
|
||||
s3 = boto3.client("s3", region_name="us-east-1")
|
||||
with pytest.raises(ClientError):
|
||||
s3.head_bucket(Bucket="mybucket")
|
||||
@ -227,20 +226,20 @@ class TestWithPseudoPrivateMethod(unittest.TestCase):
|
||||
|
||||
@mock_s3
|
||||
class Baseclass(unittest.TestCase):
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
self.s3 = boto3.resource("s3", region_name="us-east-1")
|
||||
self.client = boto3.client("s3", region_name="us-east-1")
|
||||
self.test_bucket = self.s3.Bucket("testbucket")
|
||||
self.test_bucket.create()
|
||||
|
||||
def tearDown(self):
|
||||
def tearDown(self) -> None:
|
||||
# The bucket will still exist at this point
|
||||
self.test_bucket.delete()
|
||||
|
||||
|
||||
@mock_s3
|
||||
class TestSetUpInBaseClass(Baseclass):
|
||||
def test_a_thing(self):
|
||||
def test_a_thing(self) -> None:
|
||||
# Verify that we can 'see' the setUp-method in the parent class
|
||||
self.client.head_bucket(Bucket="testbucket").shouldnt.equal(None)
|
||||
|
||||
@ -248,42 +247,42 @@ class TestSetUpInBaseClass(Baseclass):
|
||||
@mock_s3
|
||||
class TestWithNestedClasses:
|
||||
class NestedClass(unittest.TestCase):
|
||||
def _ensure_bucket_exists(self):
|
||||
def _ensure_bucket_exists(self) -> None:
|
||||
s3 = boto3.client("s3", region_name="us-east-1")
|
||||
s3.create_bucket(Bucket="bucketclass1")
|
||||
|
||||
def test_should_find_bucket(self):
|
||||
def test_should_find_bucket(self) -> None:
|
||||
self._ensure_bucket_exists()
|
||||
s3 = boto3.client("s3", region_name="us-east-1")
|
||||
s3.head_bucket(Bucket="bucketclass1")
|
||||
|
||||
class NestedClass2(unittest.TestCase):
|
||||
def _ensure_bucket_exists(self):
|
||||
def _ensure_bucket_exists(self) -> None:
|
||||
s3 = boto3.client("s3", region_name="us-east-1")
|
||||
s3.create_bucket(Bucket="bucketclass2")
|
||||
|
||||
def test_should_find_bucket(self):
|
||||
def test_should_find_bucket(self) -> None:
|
||||
self._ensure_bucket_exists()
|
||||
s3 = boto3.client("s3", region_name="us-east-1")
|
||||
s3.head_bucket(Bucket="bucketclass2")
|
||||
|
||||
def test_should_not_find_bucket_from_different_class(self):
|
||||
def test_should_not_find_bucket_from_different_class(self) -> None:
|
||||
s3 = boto3.client("s3", region_name="us-east-1")
|
||||
with pytest.raises(ClientError):
|
||||
s3.head_bucket(Bucket="bucketclass1")
|
||||
|
||||
class TestWithSetup(unittest.TestCase):
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
s3 = boto3.client("s3", region_name="us-east-1")
|
||||
s3.create_bucket(Bucket="mybucket")
|
||||
|
||||
def test_should_find_bucket(self):
|
||||
def test_should_find_bucket(self) -> None:
|
||||
s3 = boto3.client("s3", region_name="us-east-1")
|
||||
s3.head_bucket(Bucket="mybucket")
|
||||
|
||||
s3.create_bucket(Bucket="bucketinsidetest")
|
||||
|
||||
def test_should_not_find_bucket_from_test_method(self):
|
||||
def test_should_not_find_bucket_from_test_method(self) -> None:
|
||||
s3 = boto3.client("s3", region_name="us-east-1")
|
||||
s3.head_bucket(Bucket="mybucket")
|
||||
|
||||
|
@ -6,7 +6,7 @@ from moto import mock_all
|
||||
|
||||
|
||||
@mock_all()
|
||||
def test_decorator():
|
||||
def test_decorator() -> None:
|
||||
rgn = "us-east-1"
|
||||
sqs = boto3.client("sqs", region_name=rgn)
|
||||
r = sqs.list_queues()
|
||||
@ -21,7 +21,7 @@ def test_decorator():
|
||||
r["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
|
||||
|
||||
|
||||
def test_context_manager():
|
||||
def test_context_manager() -> None:
|
||||
rgn = "us-east-1"
|
||||
|
||||
with mock_all():
|
||||
|
Loading…
Reference in New Issue
Block a user