Techdebt: MyPy all the things (#6406)
This commit is contained in:
		
							parent
							
								
									2e3b06bbe5
								
							
						
					
					
						commit
						0247719554
					
				@ -1,17 +1,27 @@
 | 
			
		||||
import importlib
 | 
			
		||||
import sys
 | 
			
		||||
from contextlib import ContextDecorator
 | 
			
		||||
from moto.core.models import BaseMockAWS
 | 
			
		||||
from typing import Any, Callable, List, Optional, TypeVar
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def lazy_load(module_name, element, boto3_name=None, backend=None):
 | 
			
		||||
    def f(*args, **kwargs):
 | 
			
		||||
TEST_METHOD = TypeVar("TEST_METHOD", bound=Callable[..., Any])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def lazy_load(
 | 
			
		||||
    module_name: str,
 | 
			
		||||
    element: str,
 | 
			
		||||
    boto3_name: Optional[str] = None,
 | 
			
		||||
    backend: Optional[str] = None,
 | 
			
		||||
) -> Callable[..., BaseMockAWS]:
 | 
			
		||||
    def f(*args: Any, **kwargs: Any) -> Any:
 | 
			
		||||
        module = importlib.import_module(module_name, "moto")
 | 
			
		||||
        return getattr(module, element)(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
    setattr(f, "name", module_name.replace(".", ""))
 | 
			
		||||
    setattr(f, "element", element)
 | 
			
		||||
    setattr(f, "boto3_name", boto3_name or f.name)
 | 
			
		||||
    setattr(f, "backend", backend or f"{f.name}_backends")
 | 
			
		||||
    setattr(f, "boto3_name", boto3_name or f.name)  # type: ignore[attr-defined]
 | 
			
		||||
    setattr(f, "backend", backend or f"{f.name}_backends")  # type: ignore[attr-defined]
 | 
			
		||||
    return f
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -176,17 +186,17 @@ mock_textract = lazy_load(".textract", "mock_textract")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MockAll(ContextDecorator):
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.mocks = []
 | 
			
		||||
    def __init__(self) -> None:
 | 
			
		||||
        self.mocks: List[Any] = []
 | 
			
		||||
        for mock in dir(sys.modules["moto"]):
 | 
			
		||||
            if mock.startswith("mock_") and not mock == ("mock_all"):
 | 
			
		||||
            if mock.startswith("mock_") and not mock == "mock_all":
 | 
			
		||||
                self.mocks.append(globals()[mock]())
 | 
			
		||||
 | 
			
		||||
    def __enter__(self):
 | 
			
		||||
    def __enter__(self) -> None:
 | 
			
		||||
        for mock in self.mocks:
 | 
			
		||||
            mock.start()
 | 
			
		||||
 | 
			
		||||
    def __exit__(self, *exc):
 | 
			
		||||
    def __exit__(self, *exc: Any) -> None:
 | 
			
		||||
        for mock in self.mocks:
 | 
			
		||||
            mock.stop()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -626,7 +626,7 @@ class Stage(BaseModel):
 | 
			
		||||
        self.tags = tags
 | 
			
		||||
        self.tracing_enabled = tracing_enabled
 | 
			
		||||
        self.access_log_settings: Optional[Dict[str, Any]] = None
 | 
			
		||||
        self.web_acl_arn = None
 | 
			
		||||
        self.web_acl_arn: Optional[str] = None
 | 
			
		||||
 | 
			
		||||
    def to_json(self) -> Dict[str, Any]:
 | 
			
		||||
        dct: Dict[str, Any] = {
 | 
			
		||||
 | 
			
		||||
@ -5,7 +5,8 @@ import os
 | 
			
		||||
import re
 | 
			
		||||
import unittest
 | 
			
		||||
from types import FunctionType
 | 
			
		||||
from typing import Any, Callable, Dict, Optional, Set, TypeVar
 | 
			
		||||
from typing import Any, Callable, Dict, Optional, Set, TypeVar, Union
 | 
			
		||||
from typing import ContextManager
 | 
			
		||||
from unittest.mock import patch
 | 
			
		||||
 | 
			
		||||
import boto3
 | 
			
		||||
@ -29,7 +30,7 @@ DEFAULT_ACCOUNT_ID = "123456789012"
 | 
			
		||||
CALLABLE_RETURN = TypeVar("CALLABLE_RETURN")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BaseMockAWS:
 | 
			
		||||
class BaseMockAWS(ContextManager["BaseMockAWS"]):
 | 
			
		||||
    nested_count = 0
 | 
			
		||||
    mocks_active = False
 | 
			
		||||
 | 
			
		||||
@ -65,10 +66,13 @@ class BaseMockAWS:
 | 
			
		||||
            self.reset()  # type: ignore[attr-defined]
 | 
			
		||||
 | 
			
		||||
    def __call__(
 | 
			
		||||
        self, func: Callable[..., Any], reset: bool = True, remove_data: bool = True
 | 
			
		||||
    ) -> Any:
 | 
			
		||||
        self,
 | 
			
		||||
        func: Callable[..., "BaseMockAWS"],
 | 
			
		||||
        reset: bool = True,
 | 
			
		||||
        remove_data: bool = True,
 | 
			
		||||
    ) -> Callable[..., "BaseMockAWS"]:
 | 
			
		||||
        if inspect.isclass(func):
 | 
			
		||||
            return self.decorate_class(func)
 | 
			
		||||
            return self.decorate_class(func)  # type: ignore
 | 
			
		||||
        return self.decorate_callable(func, reset, remove_data)
 | 
			
		||||
 | 
			
		||||
    def __enter__(self) -> "BaseMockAWS":
 | 
			
		||||
@ -116,9 +120,9 @@ class BaseMockAWS:
 | 
			
		||||
            self.disable_patching()  # type: ignore[attr-defined]
 | 
			
		||||
 | 
			
		||||
    def decorate_callable(
 | 
			
		||||
        self, func: Callable[..., CALLABLE_RETURN], reset: bool, remove_data: bool
 | 
			
		||||
    ) -> Callable[..., CALLABLE_RETURN]:
 | 
			
		||||
        def wrapper(*args: Any, **kwargs: Any) -> CALLABLE_RETURN:
 | 
			
		||||
        self, func: Callable[..., "BaseMockAWS"], reset: bool, remove_data: bool
 | 
			
		||||
    ) -> Callable[..., "BaseMockAWS"]:
 | 
			
		||||
        def wrapper(*args: Any, **kwargs: Any) -> "BaseMockAWS":
 | 
			
		||||
            self.start(reset=reset)
 | 
			
		||||
            try:
 | 
			
		||||
                result = func(*args, **kwargs)
 | 
			
		||||
@ -361,7 +365,6 @@ MockAWS = BotocoreEventMockAWS
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ServerModeMockAWS(BaseMockAWS):
 | 
			
		||||
 | 
			
		||||
    RESET_IN_PROGRESS = False
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *args: Any, **kwargs: Any):
 | 
			
		||||
@ -427,7 +430,9 @@ class base_decorator:
 | 
			
		||||
    def __init__(self, backends: BackendDict):
 | 
			
		||||
        self.backends = backends
 | 
			
		||||
 | 
			
		||||
    def __call__(self, func: Optional[Callable[..., Any]] = None) -> BaseMockAWS:
 | 
			
		||||
    def __call__(
 | 
			
		||||
        self, func: Optional[Callable[..., Any]] = None
 | 
			
		||||
    ) -> Union[BaseMockAWS, Callable[..., BaseMockAWS]]:
 | 
			
		||||
        if settings.TEST_SERVER_MODE:
 | 
			
		||||
            mocked_backend: BaseMockAWS = ServerModeMockAWS(self.backends)
 | 
			
		||||
        else:
 | 
			
		||||
 | 
			
		||||
@ -6,7 +6,7 @@ class WAFv2ClientError(JsonRESTError):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class WAFV2DuplicateItemException(WAFv2ClientError):
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
    def __init__(self) -> None:
 | 
			
		||||
        super().__init__(
 | 
			
		||||
            "WafV2DuplicateItem",
 | 
			
		||||
            "AWS WAF could not perform the operation because some resource in your request is a duplicate of an existing one.",
 | 
			
		||||
@ -14,7 +14,7 @@ class WAFV2DuplicateItemException(WAFv2ClientError):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class WAFNonexistentItemException(WAFv2ClientError):
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
    def __init__(self) -> None:
 | 
			
		||||
        super().__init__(
 | 
			
		||||
            "WAFNonexistentItemException",
 | 
			
		||||
            "AWS WAF couldn’t perform the operation because your resource doesn’t exist.",
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,6 @@
 | 
			
		||||
import datetime
 | 
			
		||||
import re
 | 
			
		||||
from typing import Dict
 | 
			
		||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING
 | 
			
		||||
from moto.core import BaseBackend, BackendDict, BaseModel
 | 
			
		||||
 | 
			
		||||
from .utils import make_arn_for_wacl
 | 
			
		||||
@ -10,6 +10,9 @@ from moto.moto_api._internal import mock_random
 | 
			
		||||
from moto.utilities.tagging_service import TaggingService
 | 
			
		||||
from collections import OrderedDict
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from moto.apigateway.models import Stage
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
US_EAST_1_REGION = "us-east-1"
 | 
			
		||||
GLOBAL_REGION = "global"
 | 
			
		||||
@ -25,7 +28,14 @@ class FakeWebACL(BaseModel):
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self, name, arn, wacl_id, visibility_config, default_action, description, rules
 | 
			
		||||
        self,
 | 
			
		||||
        name: str,
 | 
			
		||||
        arn: str,
 | 
			
		||||
        wacl_id: str,
 | 
			
		||||
        visibility_config: Dict[str, Any],
 | 
			
		||||
        default_action: Dict[str, Any],
 | 
			
		||||
        description: Optional[str],
 | 
			
		||||
        rules: List[Dict[str, Any]],
 | 
			
		||||
    ):
 | 
			
		||||
        self.name = name
 | 
			
		||||
        self.created_time = iso_8601_datetime_with_milliseconds(datetime.datetime.now())
 | 
			
		||||
@ -38,7 +48,13 @@ class FakeWebACL(BaseModel):
 | 
			
		||||
        self.default_action = default_action
 | 
			
		||||
        self.lock_token = str(mock_random.uuid4())[0:6]
 | 
			
		||||
 | 
			
		||||
    def update(self, default_action, rules, description, visibility_config):
 | 
			
		||||
    def update(
 | 
			
		||||
        self,
 | 
			
		||||
        default_action: Optional[Dict[str, Any]],
 | 
			
		||||
        rules: Optional[List[Dict[str, Any]]],
 | 
			
		||||
        description: Optional[str],
 | 
			
		||||
        visibility_config: Optional[Dict[str, Any]],
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        if default_action is not None:
 | 
			
		||||
            self.default_action = default_action
 | 
			
		||||
        if rules is not None:
 | 
			
		||||
@ -49,7 +65,7 @@ class FakeWebACL(BaseModel):
 | 
			
		||||
            self.visibility_config = visibility_config
 | 
			
		||||
        self.lock_token = str(mock_random.uuid4())[0:6]
 | 
			
		||||
 | 
			
		||||
    def to_dict(self):
 | 
			
		||||
    def to_dict(self) -> Dict[str, Any]:
 | 
			
		||||
        # Format for summary https://docs.aws.amazon.com/waf/latest/APIReference/API_CreateWebACL.html (response syntax section)
 | 
			
		||||
        return {
 | 
			
		||||
            "ARN": self.arn,
 | 
			
		||||
@ -67,13 +83,13 @@ class WAFV2Backend(BaseBackend):
 | 
			
		||||
    https://docs.aws.amazon.com/waf/latest/APIReference/API_Operations_AWS_WAFV2.html
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, region_name, account_id):
 | 
			
		||||
    def __init__(self, region_name: str, account_id: str):
 | 
			
		||||
        super().__init__(region_name, account_id)
 | 
			
		||||
        self.wacls: Dict[str, FakeWebACL] = OrderedDict()
 | 
			
		||||
        self.tagging_service = TaggingService()
 | 
			
		||||
        # TODO: self.load_balancers = OrderedDict()
 | 
			
		||||
 | 
			
		||||
    def associate_web_acl(self, web_acl_arn, resource_arn):
 | 
			
		||||
    def associate_web_acl(self, web_acl_arn: str, resource_arn: str) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Only APIGateway Stages can be associated at the moment.
 | 
			
		||||
        """
 | 
			
		||||
@ -83,19 +99,19 @@ class WAFV2Backend(BaseBackend):
 | 
			
		||||
        if stage:
 | 
			
		||||
            stage.web_acl_arn = web_acl_arn
 | 
			
		||||
 | 
			
		||||
    def disassociate_web_acl(self, resource_arn):
 | 
			
		||||
    def disassociate_web_acl(self, resource_arn: str) -> None:
 | 
			
		||||
        stage = self._find_apigw_stage(resource_arn)
 | 
			
		||||
        if stage:
 | 
			
		||||
            stage.web_acl_arn = None
 | 
			
		||||
 | 
			
		||||
    def get_web_acl_for_resource(self, resource_arn):
 | 
			
		||||
    def get_web_acl_for_resource(self, resource_arn: str) -> Optional[FakeWebACL]:
 | 
			
		||||
        stage = self._find_apigw_stage(resource_arn)
 | 
			
		||||
        if stage and stage.web_acl_arn is not None:
 | 
			
		||||
            wacl_arn = stage.web_acl_arn
 | 
			
		||||
            return self.wacls.get(wacl_arn)
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    def _find_apigw_stage(self, resource_arn):
 | 
			
		||||
    def _find_apigw_stage(self, resource_arn: str) -> Optional["Stage"]:  # type: ignore
 | 
			
		||||
        try:
 | 
			
		||||
            if re.search(APIGATEWAY_REGEX, resource_arn):
 | 
			
		||||
                region = resource_arn.split(":")[3]
 | 
			
		||||
@ -110,8 +126,15 @@ class WAFV2Backend(BaseBackend):
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
    def create_web_acl(
 | 
			
		||||
        self, name, visibility_config, default_action, scope, description, tags, rules
 | 
			
		||||
    ):
 | 
			
		||||
        self,
 | 
			
		||||
        name: str,
 | 
			
		||||
        visibility_config: Dict[str, Any],
 | 
			
		||||
        default_action: Dict[str, Any],
 | 
			
		||||
        scope: str,
 | 
			
		||||
        description: str,
 | 
			
		||||
        tags: List[Dict[str, str]],
 | 
			
		||||
        rules: List[Dict[str, Any]],
 | 
			
		||||
    ) -> FakeWebACL:
 | 
			
		||||
        """
 | 
			
		||||
        The following parameters are not yet implemented: CustomResponseBodies, CaptchaConfig
 | 
			
		||||
        """
 | 
			
		||||
@ -132,7 +155,7 @@ class WAFV2Backend(BaseBackend):
 | 
			
		||||
        self.tag_resource(arn, tags)
 | 
			
		||||
        return new_wacl
 | 
			
		||||
 | 
			
		||||
    def delete_web_acl(self, name, _id):
 | 
			
		||||
    def delete_web_acl(self, name: str, _id: str) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        The LockToken-parameter is not yet implemented
 | 
			
		||||
        """
 | 
			
		||||
@ -142,37 +165,43 @@ class WAFV2Backend(BaseBackend):
 | 
			
		||||
            if wacl.name != name and wacl.id != _id
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    def get_web_acl(self, name, _id) -> FakeWebACL:
 | 
			
		||||
    def get_web_acl(self, name: str, _id: str) -> FakeWebACL:
 | 
			
		||||
        for wacl in self.wacls.values():
 | 
			
		||||
            if wacl.name == name and wacl.id == _id:
 | 
			
		||||
                return wacl
 | 
			
		||||
        raise WAFNonexistentItemException
 | 
			
		||||
 | 
			
		||||
    def list_web_acls(self):
 | 
			
		||||
    def list_web_acls(self) -> List[Dict[str, Any]]:
 | 
			
		||||
        return [wacl.to_dict() for wacl in self.wacls.values()]
 | 
			
		||||
 | 
			
		||||
    def _is_duplicate_name(self, name):
 | 
			
		||||
    def _is_duplicate_name(self, name: str) -> bool:
 | 
			
		||||
        allWaclNames = set(wacl.name for wacl in self.wacls.values())
 | 
			
		||||
        return name in allWaclNames
 | 
			
		||||
 | 
			
		||||
    def list_rule_groups(self):
 | 
			
		||||
    def list_rule_groups(self) -> List[Any]:
 | 
			
		||||
        return []
 | 
			
		||||
 | 
			
		||||
    def list_tags_for_resource(self, arn):
 | 
			
		||||
    def list_tags_for_resource(self, arn: str) -> List[Dict[str, str]]:
 | 
			
		||||
        """
 | 
			
		||||
        Pagination is not yet implemented
 | 
			
		||||
        """
 | 
			
		||||
        return self.tagging_service.list_tags_for_resource(arn)["Tags"]
 | 
			
		||||
 | 
			
		||||
    def tag_resource(self, arn, tags):
 | 
			
		||||
    def tag_resource(self, arn: str, tags: List[Dict[str, str]]) -> None:
 | 
			
		||||
        self.tagging_service.tag_resource(arn, tags)
 | 
			
		||||
 | 
			
		||||
    def untag_resource(self, arn, tag_keys):
 | 
			
		||||
    def untag_resource(self, arn: str, tag_keys: List[str]) -> None:
 | 
			
		||||
        self.tagging_service.untag_resource_using_names(arn, tag_keys)
 | 
			
		||||
 | 
			
		||||
    def update_web_acl(
 | 
			
		||||
        self, name, _id, default_action, rules, description, visibility_config
 | 
			
		||||
    ):
 | 
			
		||||
        self,
 | 
			
		||||
        name: str,
 | 
			
		||||
        _id: str,
 | 
			
		||||
        default_action: Optional[Dict[str, Any]],
 | 
			
		||||
        rules: Optional[List[Dict[str, Any]]],
 | 
			
		||||
        description: Optional[str],
 | 
			
		||||
        visibility_config: Optional[Dict[str, Any]],
 | 
			
		||||
    ) -> str:
 | 
			
		||||
        """
 | 
			
		||||
        The following parameters are not yet implemented: LockToken, CustomResponseBodies, CaptchaConfig
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
@ -1,19 +1,20 @@
 | 
			
		||||
import json
 | 
			
		||||
from moto.core.common_types import TYPE_RESPONSE
 | 
			
		||||
from moto.core.responses import BaseResponse
 | 
			
		||||
from moto.utilities.aws_headers import amzn_request_id
 | 
			
		||||
from .models import GLOBAL_REGION, wafv2_backends
 | 
			
		||||
from .models import GLOBAL_REGION, wafv2_backends, WAFV2Backend
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class WAFV2Response(BaseResponse):
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
    def __init__(self) -> None:
 | 
			
		||||
        super().__init__(service_name="wafv2")
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def wafv2_backend(self):
 | 
			
		||||
    def wafv2_backend(self) -> WAFV2Backend:
 | 
			
		||||
        return wafv2_backends[self.current_account][self.region]
 | 
			
		||||
 | 
			
		||||
    @amzn_request_id
 | 
			
		||||
    def associate_web_acl(self):
 | 
			
		||||
    def associate_web_acl(self) -> TYPE_RESPONSE:
 | 
			
		||||
        body = json.loads(self.body)
 | 
			
		||||
        web_acl_arn = body["WebACLArn"]
 | 
			
		||||
        resource_arn = body["ResourceArn"]
 | 
			
		||||
@ -21,14 +22,14 @@ class WAFV2Response(BaseResponse):
 | 
			
		||||
        return 200, {}, "{}"
 | 
			
		||||
 | 
			
		||||
    @amzn_request_id
 | 
			
		||||
    def disassociate_web_acl(self):
 | 
			
		||||
    def disassociate_web_acl(self) -> TYPE_RESPONSE:
 | 
			
		||||
        body = json.loads(self.body)
 | 
			
		||||
        resource_arn = body["ResourceArn"]
 | 
			
		||||
        self.wafv2_backend.disassociate_web_acl(resource_arn)
 | 
			
		||||
        return 200, {}, "{}"
 | 
			
		||||
 | 
			
		||||
    @amzn_request_id
 | 
			
		||||
    def get_web_acl_for_resource(self):
 | 
			
		||||
    def get_web_acl_for_resource(self) -> TYPE_RESPONSE:
 | 
			
		||||
        body = json.loads(self.body)
 | 
			
		||||
        resource_arn = body["ResourceArn"]
 | 
			
		||||
        web_acl = self.wafv2_backend.get_web_acl_for_resource(resource_arn)
 | 
			
		||||
@ -37,7 +38,7 @@ class WAFV2Response(BaseResponse):
 | 
			
		||||
        return 200, response_headers, json.dumps(response)
 | 
			
		||||
 | 
			
		||||
    @amzn_request_id
 | 
			
		||||
    def create_web_acl(self):
 | 
			
		||||
    def create_web_acl(self) -> TYPE_RESPONSE:
 | 
			
		||||
        """https://docs.aws.amazon.com/waf/latest/APIReference/API_CreateWebACL.html (response syntax section)"""
 | 
			
		||||
 | 
			
		||||
        scope = self._get_param("Scope")
 | 
			
		||||
@ -62,7 +63,7 @@ class WAFV2Response(BaseResponse):
 | 
			
		||||
        return 200, response_headers, json.dumps(response)
 | 
			
		||||
 | 
			
		||||
    @amzn_request_id
 | 
			
		||||
    def delete_web_acl(self):
 | 
			
		||||
    def delete_web_acl(self) -> TYPE_RESPONSE:
 | 
			
		||||
        scope = self._get_param("Scope")
 | 
			
		||||
        if scope == "CLOUDFRONT":
 | 
			
		||||
            self.region = GLOBAL_REGION
 | 
			
		||||
@ -73,7 +74,7 @@ class WAFV2Response(BaseResponse):
 | 
			
		||||
        return 200, response_headers, "{}"
 | 
			
		||||
 | 
			
		||||
    @amzn_request_id
 | 
			
		||||
    def get_web_acl(self):
 | 
			
		||||
    def get_web_acl(self) -> TYPE_RESPONSE:
 | 
			
		||||
        scope = self._get_param("Scope")
 | 
			
		||||
        if scope == "CLOUDFRONT":
 | 
			
		||||
            self.region = GLOBAL_REGION
 | 
			
		||||
@ -85,7 +86,7 @@ class WAFV2Response(BaseResponse):
 | 
			
		||||
        return 200, response_headers, json.dumps(response)
 | 
			
		||||
 | 
			
		||||
    @amzn_request_id
 | 
			
		||||
    def list_web_ac_ls(self):
 | 
			
		||||
    def list_web_ac_ls(self) -> TYPE_RESPONSE:
 | 
			
		||||
        """https://docs.aws.amazon.com/waf/latest/APIReference/API_ListWebACLs.html (response syntax section)"""
 | 
			
		||||
 | 
			
		||||
        scope = self._get_param("Scope")
 | 
			
		||||
@ -97,7 +98,7 @@ class WAFV2Response(BaseResponse):
 | 
			
		||||
        return 200, response_headers, json.dumps(response)
 | 
			
		||||
 | 
			
		||||
    @amzn_request_id
 | 
			
		||||
    def list_rule_groups(self):
 | 
			
		||||
    def list_rule_groups(self) -> TYPE_RESPONSE:
 | 
			
		||||
        scope = self._get_param("Scope")
 | 
			
		||||
        if scope == "CLOUDFRONT":
 | 
			
		||||
            self.region = GLOBAL_REGION
 | 
			
		||||
@ -107,7 +108,7 @@ class WAFV2Response(BaseResponse):
 | 
			
		||||
        return 200, response_headers, json.dumps(response)
 | 
			
		||||
 | 
			
		||||
    @amzn_request_id
 | 
			
		||||
    def list_tags_for_resource(self):
 | 
			
		||||
    def list_tags_for_resource(self) -> TYPE_RESPONSE:
 | 
			
		||||
        arn = self._get_param("ResourceARN")
 | 
			
		||||
        self.region = arn.split(":")[3]
 | 
			
		||||
        tags = self.wafv2_backend.list_tags_for_resource(arn)
 | 
			
		||||
@ -116,7 +117,7 @@ class WAFV2Response(BaseResponse):
 | 
			
		||||
        return 200, response_headers, json.dumps(response)
 | 
			
		||||
 | 
			
		||||
    @amzn_request_id
 | 
			
		||||
    def tag_resource(self):
 | 
			
		||||
    def tag_resource(self) -> TYPE_RESPONSE:
 | 
			
		||||
        body = json.loads(self.body)
 | 
			
		||||
        arn = body.get("ResourceARN")
 | 
			
		||||
        self.region = arn.split(":")[3]
 | 
			
		||||
@ -125,7 +126,7 @@ class WAFV2Response(BaseResponse):
 | 
			
		||||
        return 200, {}, "{}"
 | 
			
		||||
 | 
			
		||||
    @amzn_request_id
 | 
			
		||||
    def untag_resource(self):
 | 
			
		||||
    def untag_resource(self) -> TYPE_RESPONSE:
 | 
			
		||||
        body = json.loads(self.body)
 | 
			
		||||
        arn = body.get("ResourceARN")
 | 
			
		||||
        self.region = arn.split(":")[3]
 | 
			
		||||
@ -134,7 +135,7 @@ class WAFV2Response(BaseResponse):
 | 
			
		||||
        return 200, {}, "{}"
 | 
			
		||||
 | 
			
		||||
    @amzn_request_id
 | 
			
		||||
    def update_web_acl(self):
 | 
			
		||||
    def update_web_acl(self) -> TYPE_RESPONSE:
 | 
			
		||||
        body = json.loads(self.body)
 | 
			
		||||
        name = body.get("Name")
 | 
			
		||||
        _id = body.get("Id")
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,6 @@
 | 
			
		||||
def make_arn_for_wacl(name, account_id, region_name, wacl_id, scope):
 | 
			
		||||
def make_arn_for_wacl(
 | 
			
		||||
    name: str, account_id: str, region_name: str, wacl_id: str, scope: str
 | 
			
		||||
) -> str:
 | 
			
		||||
    """https://docs.aws.amazon.com/waf/latest/developerguide/how-aws-waf-works.html - explains --scope (cloudfront vs regional)"""
 | 
			
		||||
 | 
			
		||||
    if scope == "REGIONAL":
 | 
			
		||||
 | 
			
		||||
@ -1,14 +1,22 @@
 | 
			
		||||
from typing import Any, Dict, Optional
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BadSegmentException(Exception):
 | 
			
		||||
    def __init__(self, seg_id=None, code=None, message=None):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        seg_id: Optional[str] = None,
 | 
			
		||||
        code: Optional[str] = None,
 | 
			
		||||
        message: Optional[str] = None,
 | 
			
		||||
    ):
 | 
			
		||||
        self.id = seg_id
 | 
			
		||||
        self.code = code
 | 
			
		||||
        self.message = message
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
    def __repr__(self) -> str:
 | 
			
		||||
        return f"<BadSegment {self.id}-{self.code}-{self.message}>"
 | 
			
		||||
 | 
			
		||||
    def to_dict(self):
 | 
			
		||||
        result = {}
 | 
			
		||||
    def to_dict(self) -> Dict[str, Any]:
 | 
			
		||||
        result: Dict[str, Any] = {}
 | 
			
		||||
        if self.id is not None:
 | 
			
		||||
            result["Id"] = self.id
 | 
			
		||||
        if self.code is not None:
 | 
			
		||||
 | 
			
		||||
@ -1,25 +1,26 @@
 | 
			
		||||
import os
 | 
			
		||||
from moto.xray import xray_backends
 | 
			
		||||
from typing import Any
 | 
			
		||||
from moto.xray.models import xray_backends, XRayBackend
 | 
			
		||||
import aws_xray_sdk.core
 | 
			
		||||
from aws_xray_sdk.core.context import Context as AWSContext
 | 
			
		||||
from aws_xray_sdk.core.emitters.udp_emitter import UDPEmitter
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MockEmitter(UDPEmitter):
 | 
			
		||||
class MockEmitter(UDPEmitter):  # type: ignore
 | 
			
		||||
    """
 | 
			
		||||
    Replaces the code that sends UDP to local X-Ray daemon
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, daemon_address="127.0.0.1:2000"):
 | 
			
		||||
    def __init__(self, daemon_address: str = "127.0.0.1:2000"):
 | 
			
		||||
        address = os.getenv(
 | 
			
		||||
            "AWS_XRAY_DAEMON_ADDRESS_YEAH_NOT_TODAY_MATE", daemon_address
 | 
			
		||||
        )
 | 
			
		||||
        self._ip, self._port = self._parse_address(address)
 | 
			
		||||
 | 
			
		||||
    def _xray_backend(self, region):
 | 
			
		||||
        return xray_backends[region]
 | 
			
		||||
    def _xray_backend(self, account_id: str, region: str) -> XRayBackend:
 | 
			
		||||
        return xray_backends[account_id][region]
 | 
			
		||||
 | 
			
		||||
    def send_entity(self, entity):
 | 
			
		||||
    def send_entity(self, entity: Any) -> None:
 | 
			
		||||
        # Hack to get region
 | 
			
		||||
        # region = entity.subsegments[0].aws['region']
 | 
			
		||||
        # xray = self._xray_backend(region)
 | 
			
		||||
@ -27,7 +28,7 @@ class MockEmitter(UDPEmitter):
 | 
			
		||||
        # TODO store X-Ray data, pretty sure X-Ray needs refactor for this
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def _send_data(self, data):
 | 
			
		||||
    def _send_data(self, data: Any) -> None:
 | 
			
		||||
        raise RuntimeError("Should not be running this")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -42,11 +43,11 @@ class MockXrayClient:
 | 
			
		||||
    that itno the recorder instance.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __call__(self, f=None):
 | 
			
		||||
    def __call__(self, f: Any = None) -> Any:
 | 
			
		||||
        if not f:
 | 
			
		||||
            return self
 | 
			
		||||
 | 
			
		||||
        def wrapped_f(*args, **kwargs):
 | 
			
		||||
        def wrapped_f(*args: Any, **kwargs: Any) -> Any:
 | 
			
		||||
            self.start()
 | 
			
		||||
            try:
 | 
			
		||||
                f(*args, **kwargs)
 | 
			
		||||
@ -55,7 +56,7 @@ class MockXrayClient:
 | 
			
		||||
 | 
			
		||||
        return wrapped_f
 | 
			
		||||
 | 
			
		||||
    def start(self):
 | 
			
		||||
    def start(self) -> None:
 | 
			
		||||
        print("Starting X-Ray Patch")  # noqa
 | 
			
		||||
        self.old_xray_context_var = os.environ.get("AWS_XRAY_CONTEXT_MISSING")
 | 
			
		||||
        os.environ["AWS_XRAY_CONTEXT_MISSING"] = "LOG_ERROR"
 | 
			
		||||
@ -64,7 +65,7 @@ class MockXrayClient:
 | 
			
		||||
        aws_xray_sdk.core.xray_recorder._context = AWSContext()
 | 
			
		||||
        aws_xray_sdk.core.xray_recorder._emitter = MockEmitter()
 | 
			
		||||
 | 
			
		||||
    def stop(self):
 | 
			
		||||
    def stop(self) -> None:
 | 
			
		||||
        if self.old_xray_context_var is None:
 | 
			
		||||
            del os.environ["AWS_XRAY_CONTEXT_MISSING"]
 | 
			
		||||
        else:
 | 
			
		||||
@ -73,11 +74,11 @@ class MockXrayClient:
 | 
			
		||||
        aws_xray_sdk.core.xray_recorder._emitter = self.old_xray_emitter
 | 
			
		||||
        aws_xray_sdk.core.xray_recorder._context = self.old_xray_context
 | 
			
		||||
 | 
			
		||||
    def __enter__(self):
 | 
			
		||||
    def __enter__(self) -> "MockXrayClient":
 | 
			
		||||
        self.start()
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    def __exit__(self, *args):
 | 
			
		||||
    def __exit__(self, *args: Any) -> None:
 | 
			
		||||
        self.stop()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -91,12 +92,12 @@ class XRaySegment(object):
 | 
			
		||||
    During testing we're going to have to control the start and end of a segment via context managers.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __enter__(self):
 | 
			
		||||
    def __enter__(self) -> "XRaySegment":
 | 
			
		||||
        aws_xray_sdk.core.xray_recorder.begin_segment(
 | 
			
		||||
            name="moto_mock", traceid=None, parent_id=None, sampling=1
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    def __exit__(self, exc_type, exc_val, exc_tb):
 | 
			
		||||
    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
 | 
			
		||||
        aws_xray_sdk.core.xray_recorder.end_segment()
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,7 @@
 | 
			
		||||
import bisect
 | 
			
		||||
import datetime
 | 
			
		||||
from collections import defaultdict
 | 
			
		||||
from typing import Any, Dict, List, Optional, Tuple
 | 
			
		||||
import json
 | 
			
		||||
from moto.core import BaseBackend, BackendDict, BaseModel
 | 
			
		||||
from moto.core.exceptions import AWSError
 | 
			
		||||
@ -8,54 +9,60 @@ from .exceptions import BadSegmentException
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TelemetryRecords(BaseModel):
 | 
			
		||||
    def __init__(self, instance_id, hostname, resource_arn, records):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        instance_id: str,
 | 
			
		||||
        hostname: str,
 | 
			
		||||
        resource_arn: str,
 | 
			
		||||
        records: List[Dict[str, Any]],
 | 
			
		||||
    ):
 | 
			
		||||
        self.instance_id = instance_id
 | 
			
		||||
        self.hostname = hostname
 | 
			
		||||
        self.resource_arn = resource_arn
 | 
			
		||||
        self.records = records
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_json(cls, src):
 | 
			
		||||
    def from_json(cls, src: Dict[str, Any]) -> "TelemetryRecords":  # type: ignore[misc]
 | 
			
		||||
        instance_id = src.get("EC2InstanceId", None)
 | 
			
		||||
        hostname = src.get("Hostname")
 | 
			
		||||
        resource_arn = src.get("ResourceARN")
 | 
			
		||||
        telemetry_records = src["TelemetryRecords"]
 | 
			
		||||
 | 
			
		||||
        return cls(instance_id, hostname, resource_arn, telemetry_records)
 | 
			
		||||
        return cls(instance_id, hostname, resource_arn, telemetry_records)  # type: ignore
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# https://docs.aws.amazon.com/xray/latest/devguide/xray-api-segmentdocuments.html
 | 
			
		||||
class TraceSegment(BaseModel):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        name,
 | 
			
		||||
        segment_id,
 | 
			
		||||
        trace_id,
 | 
			
		||||
        start_time,
 | 
			
		||||
        raw,
 | 
			
		||||
        end_time=None,
 | 
			
		||||
        in_progress=False,
 | 
			
		||||
        service=None,
 | 
			
		||||
        user=None,
 | 
			
		||||
        origin=None,
 | 
			
		||||
        parent_id=None,
 | 
			
		||||
        http=None,
 | 
			
		||||
        aws=None,
 | 
			
		||||
        metadata=None,
 | 
			
		||||
        annotations=None,
 | 
			
		||||
        subsegments=None,
 | 
			
		||||
        **kwargs
 | 
			
		||||
        name: str,
 | 
			
		||||
        segment_id: str,
 | 
			
		||||
        trace_id: str,
 | 
			
		||||
        start_time: float,
 | 
			
		||||
        raw: Any,
 | 
			
		||||
        end_time: Optional[float] = None,
 | 
			
		||||
        in_progress: bool = False,
 | 
			
		||||
        service: Any = None,
 | 
			
		||||
        user: Any = None,
 | 
			
		||||
        origin: Any = None,
 | 
			
		||||
        parent_id: Any = None,
 | 
			
		||||
        http: Any = None,
 | 
			
		||||
        aws: Any = None,
 | 
			
		||||
        metadata: Any = None,
 | 
			
		||||
        annotations: Any = None,
 | 
			
		||||
        subsegments: Any = None,
 | 
			
		||||
        **kwargs: Any
 | 
			
		||||
    ):
 | 
			
		||||
        self.name = name
 | 
			
		||||
        self.id = segment_id
 | 
			
		||||
        self.trace_id = trace_id
 | 
			
		||||
        self._trace_version = None
 | 
			
		||||
        self._original_request_start_time = None
 | 
			
		||||
        self._trace_version: Optional[int] = None
 | 
			
		||||
        self._original_request_start_time: Optional[datetime.datetime] = None
 | 
			
		||||
        self._trace_identifier = None
 | 
			
		||||
        self.start_time = start_time
 | 
			
		||||
        self._start_date = None
 | 
			
		||||
        self._start_date: Optional[datetime.datetime] = None
 | 
			
		||||
        self.end_time = end_time
 | 
			
		||||
        self._end_date = None
 | 
			
		||||
        self._end_date: Optional[datetime.datetime] = None
 | 
			
		||||
        self.in_progress = in_progress
 | 
			
		||||
        self.service = service
 | 
			
		||||
        self.user = user
 | 
			
		||||
@ -71,17 +78,17 @@ class TraceSegment(BaseModel):
 | 
			
		||||
        # Raw json string
 | 
			
		||||
        self.raw = raw
 | 
			
		||||
 | 
			
		||||
    def __lt__(self, other):
 | 
			
		||||
    def __lt__(self, other: Any) -> bool:
 | 
			
		||||
        return self.start_date < other.start_date
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def trace_version(self):
 | 
			
		||||
    def trace_version(self) -> int:
 | 
			
		||||
        if self._trace_version is None:
 | 
			
		||||
            self._trace_version = int(self.trace_id.split("-", 1)[0])
 | 
			
		||||
        return self._trace_version
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def request_start_date(self):
 | 
			
		||||
    def request_start_date(self) -> datetime.datetime:
 | 
			
		||||
        if self._original_request_start_time is None:
 | 
			
		||||
            start_time = int(self.trace_id.split("-")[1], 16)
 | 
			
		||||
            self._original_request_start_time = datetime.datetime.fromtimestamp(
 | 
			
		||||
@ -90,19 +97,19 @@ class TraceSegment(BaseModel):
 | 
			
		||||
        return self._original_request_start_time
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def start_date(self):
 | 
			
		||||
    def start_date(self) -> datetime.datetime:
 | 
			
		||||
        if self._start_date is None:
 | 
			
		||||
            self._start_date = datetime.datetime.fromtimestamp(self.start_time)
 | 
			
		||||
        return self._start_date
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def end_date(self):
 | 
			
		||||
    def end_date(self) -> datetime.datetime:
 | 
			
		||||
        if self._end_date is None:
 | 
			
		||||
            self._end_date = datetime.datetime.fromtimestamp(self.end_time)
 | 
			
		||||
            self._end_date = datetime.datetime.fromtimestamp(self.end_time)  # type: ignore
 | 
			
		||||
        return self._end_date
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_dict(cls, data, raw):
 | 
			
		||||
    def from_dict(cls, data: Dict[str, Any], raw: Any) -> "TraceSegment":  # type: ignore[misc]
 | 
			
		||||
        # Check manditory args
 | 
			
		||||
        if "id" not in data:
 | 
			
		||||
            raise BadSegmentException(code="MissingParam", message="Missing segment ID")
 | 
			
		||||
@ -130,11 +137,11 @@ class TraceSegment(BaseModel):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SegmentCollection(object):
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self._traces = defaultdict(self._new_trace_item)
 | 
			
		||||
    def __init__(self) -> None:
 | 
			
		||||
        self._traces: Dict[str, Dict[str, Any]] = defaultdict(self._new_trace_item)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _new_trace_item():
 | 
			
		||||
    def _new_trace_item() -> Dict[str, Any]:  # type: ignore[misc]
 | 
			
		||||
        return {
 | 
			
		||||
            "start_date": datetime.datetime(1970, 1, 1),
 | 
			
		||||
            "end_date": datetime.datetime(1970, 1, 1),
 | 
			
		||||
@ -143,7 +150,7 @@ class SegmentCollection(object):
 | 
			
		||||
            "segments": [],
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    def put_segment(self, segment):
 | 
			
		||||
    def put_segment(self, segment: Any) -> None:
 | 
			
		||||
        # insert into a sorted list
 | 
			
		||||
        bisect.insort_left(self._traces[segment.trace_id]["segments"], segment)
 | 
			
		||||
 | 
			
		||||
@ -160,11 +167,15 @@ class SegmentCollection(object):
 | 
			
		||||
            # Todo consolidate trace segments into a trace.
 | 
			
		||||
            # not enough working knowledge of xray to do this
 | 
			
		||||
 | 
			
		||||
    def summary(self, start_time, end_time, filter_expression=None):
 | 
			
		||||
    def summary(
 | 
			
		||||
        self, start_time: str, end_time: str, filter_expression: Any = None
 | 
			
		||||
    ) -> Dict[str, Any]:
 | 
			
		||||
        # This beast https://docs.aws.amazon.com/xray/latest/api/API_GetTraceSummaries.html#API_GetTraceSummaries_ResponseSyntax
 | 
			
		||||
        if filter_expression is not None:
 | 
			
		||||
            raise AWSError(
 | 
			
		||||
                "Not implemented yet - moto", code="InternalFailure", status=500
 | 
			
		||||
                "Not implemented yet - moto",
 | 
			
		||||
                exception_type="InternalFailure",
 | 
			
		||||
                status=500,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        summaries = []
 | 
			
		||||
@ -213,7 +224,9 @@ class SegmentCollection(object):
 | 
			
		||||
 | 
			
		||||
        return result
 | 
			
		||||
 | 
			
		||||
    def get_trace_ids(self, trace_ids):
 | 
			
		||||
    def get_trace_ids(
 | 
			
		||||
        self, trace_ids: List[str]
 | 
			
		||||
    ) -> Tuple[List[Dict[str, Any]], List[str]]:
 | 
			
		||||
        traces = []
 | 
			
		||||
        unprocessed = []
 | 
			
		||||
 | 
			
		||||
@ -229,22 +242,24 @@ class SegmentCollection(object):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class XRayBackend(BaseBackend):
 | 
			
		||||
    def __init__(self, region_name, account_id):
 | 
			
		||||
    def __init__(self, region_name: str, account_id: str):
 | 
			
		||||
        super().__init__(region_name, account_id)
 | 
			
		||||
        self._telemetry_records = []
 | 
			
		||||
        self._telemetry_records: List[TelemetryRecords] = []
 | 
			
		||||
        self._segment_collection = SegmentCollection()
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def default_vpc_endpoint_service(service_region, zones):
 | 
			
		||||
    def default_vpc_endpoint_service(
 | 
			
		||||
        service_region: str, zones: List[str]
 | 
			
		||||
    ) -> List[Dict[str, str]]:
 | 
			
		||||
        """Default VPC endpoint service."""
 | 
			
		||||
        return BaseBackend.default_vpc_endpoint_service_factory(
 | 
			
		||||
            service_region, zones, "xray"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def add_telemetry_records(self, src):
 | 
			
		||||
    def add_telemetry_records(self, src: Any) -> None:
 | 
			
		||||
        self._telemetry_records.append(TelemetryRecords.from_json(src))
 | 
			
		||||
 | 
			
		||||
    def process_segment(self, doc):
 | 
			
		||||
    def process_segment(self, doc: Any) -> None:
 | 
			
		||||
        try:
 | 
			
		||||
            data = json.loads(doc)
 | 
			
		||||
        except ValueError:
 | 
			
		||||
@ -264,13 +279,15 @@ class XRayBackend(BaseBackend):
 | 
			
		||||
                seg_id=segment.id, code="InternalFailure", message=str(err)
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def get_trace_summary(self, start_time, end_time, filter_expression):
 | 
			
		||||
    def get_trace_summary(
 | 
			
		||||
        self, start_time: str, end_time: str, filter_expression: Any
 | 
			
		||||
    ) -> Dict[str, Any]:
 | 
			
		||||
        return self._segment_collection.summary(start_time, end_time, filter_expression)
 | 
			
		||||
 | 
			
		||||
    def get_trace_ids(self, trace_ids):
 | 
			
		||||
    def get_trace_ids(self, trace_ids: List[str]) -> Dict[str, Any]:
 | 
			
		||||
        traces, unprocessed_ids = self._segment_collection.get_trace_ids(trace_ids)
 | 
			
		||||
 | 
			
		||||
        result = {"Traces": [], "UnprocessedTraceIds": unprocessed_ids}
 | 
			
		||||
        result: Dict[str, Any] = {"Traces": [], "UnprocessedTraceIds": unprocessed_ids}
 | 
			
		||||
 | 
			
		||||
        for trace in traces:
 | 
			
		||||
            segments = []
 | 
			
		||||
 | 
			
		||||
@ -1,49 +1,50 @@
 | 
			
		||||
import json
 | 
			
		||||
import datetime
 | 
			
		||||
from typing import Any, Dict, Tuple, Union
 | 
			
		||||
 | 
			
		||||
from moto.core.responses import BaseResponse
 | 
			
		||||
from moto.core.exceptions import AWSError
 | 
			
		||||
from urllib.parse import urlsplit
 | 
			
		||||
 | 
			
		||||
from .models import xray_backends
 | 
			
		||||
from .models import xray_backends, XRayBackend
 | 
			
		||||
from .exceptions import BadSegmentException
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class XRayResponse(BaseResponse):
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
    def __init__(self) -> None:
 | 
			
		||||
        super().__init__(service_name="xray")
 | 
			
		||||
 | 
			
		||||
    def _error(self, code, message):
 | 
			
		||||
    def _error(self, code: str, message: str) -> Tuple[str, Dict[str, int]]:
 | 
			
		||||
        return json.dumps({"__type": code, "message": message}), dict(status=400)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def xray_backend(self):
 | 
			
		||||
    def xray_backend(self) -> XRayBackend:
 | 
			
		||||
        return xray_backends[self.current_account][self.region]
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def request_params(self):
 | 
			
		||||
    def request_params(self) -> Any:  # type: ignore[misc]
 | 
			
		||||
        try:
 | 
			
		||||
            return json.loads(self.body)
 | 
			
		||||
        except ValueError:
 | 
			
		||||
            return {}
 | 
			
		||||
 | 
			
		||||
    def _get_param(self, param_name, if_none=None):
 | 
			
		||||
    def _get_param(self, param_name: str, if_none: Any = None) -> Any:
 | 
			
		||||
        return self.request_params.get(param_name, if_none)
 | 
			
		||||
 | 
			
		||||
    def _get_action(self):
 | 
			
		||||
    def _get_action(self) -> str:
 | 
			
		||||
        # Amazon is just calling urls like /TelemetryRecords etc...
 | 
			
		||||
        # This uses the value after / as the camalcase action, which then
 | 
			
		||||
        # gets converted in call_action to find the following methods
 | 
			
		||||
        return urlsplit(self.uri).path.lstrip("/")
 | 
			
		||||
 | 
			
		||||
    # PutTelemetryRecords
 | 
			
		||||
    def telemetry_records(self):
 | 
			
		||||
    def telemetry_records(self) -> str:
 | 
			
		||||
        self.xray_backend.add_telemetry_records(self.request_params)
 | 
			
		||||
 | 
			
		||||
        return ""
 | 
			
		||||
 | 
			
		||||
    # PutTraceSegments
 | 
			
		||||
    def trace_segments(self):
 | 
			
		||||
    def trace_segments(self) -> Union[str, Tuple[str, Dict[str, int]]]:
 | 
			
		||||
        docs = self._get_param("TraceSegmentDocuments")
 | 
			
		||||
 | 
			
		||||
        if docs is None:
 | 
			
		||||
@ -71,7 +72,7 @@ class XRayResponse(BaseResponse):
 | 
			
		||||
        return json.dumps(result)
 | 
			
		||||
 | 
			
		||||
    # GetTraceSummaries
 | 
			
		||||
    def trace_summaries(self):
 | 
			
		||||
    def trace_summaries(self) -> Union[str, Tuple[str, Dict[str, int]]]:
 | 
			
		||||
        start_time = self._get_param("StartTime")
 | 
			
		||||
        end_time = self._get_param("EndTime")
 | 
			
		||||
        if start_time is None:
 | 
			
		||||
@ -119,7 +120,7 @@ class XRayResponse(BaseResponse):
 | 
			
		||||
        return json.dumps(result)
 | 
			
		||||
 | 
			
		||||
    # BatchGetTraces
 | 
			
		||||
    def traces(self):
 | 
			
		||||
    def traces(self) -> Union[str, Tuple[str, Dict[str, int]]]:
 | 
			
		||||
        trace_ids = self._get_param("TraceIds")
 | 
			
		||||
 | 
			
		||||
        if trace_ids is None:
 | 
			
		||||
@ -142,7 +143,7 @@ class XRayResponse(BaseResponse):
 | 
			
		||||
        return json.dumps(result)
 | 
			
		||||
 | 
			
		||||
    # GetServiceGraph - just a dummy response for now
 | 
			
		||||
    def service_graph(self):
 | 
			
		||||
    def service_graph(self) -> Union[str, Tuple[str, Dict[str, int]]]:
 | 
			
		||||
        start_time = self._get_param("StartTime")
 | 
			
		||||
        end_time = self._get_param("EndTime")
 | 
			
		||||
        # next_token = self._get_param('NextToken')  # not implemented yet
 | 
			
		||||
@ -164,7 +165,7 @@ class XRayResponse(BaseResponse):
 | 
			
		||||
        return json.dumps(result)
 | 
			
		||||
 | 
			
		||||
    # GetTraceGraph - just a dummy response for now
 | 
			
		||||
    def trace_graph(self):
 | 
			
		||||
    def trace_graph(self) -> Union[str, Tuple[str, Dict[str, int]]]:
 | 
			
		||||
        trace_ids = self._get_param("TraceIds")
 | 
			
		||||
        # next_token = self._get_param('NextToken')  # not implemented yet
 | 
			
		||||
 | 
			
		||||
@ -175,5 +176,5 @@ class XRayResponse(BaseResponse):
 | 
			
		||||
                dict(status=400),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        result = {"Services": []}
 | 
			
		||||
        result: Dict[str, Any] = {"Services": []}
 | 
			
		||||
        return json.dumps(result)
 | 
			
		||||
 | 
			
		||||
@ -239,7 +239,7 @@ disable = W,C,R,E
 | 
			
		||||
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
 | 
			
		||||
 | 
			
		||||
[mypy]
 | 
			
		||||
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s*,moto/u*,moto/t*
 | 
			
		||||
files= moto, tests/test_core/test_mock_all.py, tests/test_core/test_decorator_calls.py
 | 
			
		||||
show_column_numbers=True
 | 
			
		||||
show_error_codes = True
 | 
			
		||||
disable_error_code=abstract
 | 
			
		||||
 | 
			
		||||
@ -1,10 +1,11 @@
 | 
			
		||||
import boto3
 | 
			
		||||
import pytest
 | 
			
		||||
import sure  # noqa # pylint: disable=unused-import
 | 
			
		||||
import unittest
 | 
			
		||||
 | 
			
		||||
from botocore.exceptions import ClientError
 | 
			
		||||
from typing import Any
 | 
			
		||||
from moto import mock_ec2, mock_kinesis, mock_s3, settings
 | 
			
		||||
from moto.core.models import BaseMockAWS
 | 
			
		||||
from unittest import SkipTest
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
@ -13,13 +14,13 @@ Test the different ways that the decorator can be used
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@mock_ec2
 | 
			
		||||
def test_basic_decorator():
 | 
			
		||||
def test_basic_decorator() -> None:
 | 
			
		||||
    client = boto3.client("ec2", region_name="us-west-1")
 | 
			
		||||
    client.describe_addresses()["Addresses"].should.equal([])
 | 
			
		||||
    assert client.describe_addresses()["Addresses"] == []
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture(name="aws_credentials")
 | 
			
		||||
def fixture_aws_credentials(monkeypatch):
 | 
			
		||||
def fixture_aws_credentials(monkeypatch: Any) -> None:  # type: ignore[misc]
 | 
			
		||||
    """Mocked AWS Credentials for moto."""
 | 
			
		||||
    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
 | 
			
		||||
    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
 | 
			
		||||
@ -28,97 +29,95 @@ def fixture_aws_credentials(monkeypatch):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.network
 | 
			
		||||
def test_context_manager(aws_credentials):  # pylint: disable=unused-argument
 | 
			
		||||
def test_context_manager(aws_credentials: Any) -> None:  # type: ignore[misc]  # pylint: disable=unused-argument
 | 
			
		||||
    client = boto3.client("ec2", region_name="us-west-1")
 | 
			
		||||
    with pytest.raises(ClientError) as exc:
 | 
			
		||||
        client.describe_addresses()
 | 
			
		||||
    err = exc.value.response["Error"]
 | 
			
		||||
    err["Code"].should.equal("AuthFailure")
 | 
			
		||||
    err["Message"].should.equal(
 | 
			
		||||
        "AWS was not able to validate the provided access credentials"
 | 
			
		||||
    assert err["Code"] == "AuthFailure"
 | 
			
		||||
    assert (
 | 
			
		||||
        err["Message"] == "AWS was not able to validate the provided access credentials"
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    with mock_ec2():
 | 
			
		||||
        client = boto3.client("ec2", region_name="us-west-1")
 | 
			
		||||
        client.describe_addresses()["Addresses"].should.equal([])
 | 
			
		||||
        assert client.describe_addresses()["Addresses"] == []
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.network
 | 
			
		||||
def test_decorator_start_and_stop():
 | 
			
		||||
def test_decorator_start_and_stop() -> None:
 | 
			
		||||
    if settings.TEST_SERVER_MODE:
 | 
			
		||||
        raise SkipTest("Authentication always works in ServerMode")
 | 
			
		||||
    mock = mock_ec2()
 | 
			
		||||
    mock: BaseMockAWS = mock_ec2()
 | 
			
		||||
    mock.start()
 | 
			
		||||
    client = boto3.client("ec2", region_name="us-west-1")
 | 
			
		||||
    client.describe_addresses()["Addresses"].should.equal([])
 | 
			
		||||
    assert client.describe_addresses()["Addresses"] == []
 | 
			
		||||
    mock.stop()
 | 
			
		||||
 | 
			
		||||
    with pytest.raises(ClientError) as exc:
 | 
			
		||||
        client.describe_addresses()
 | 
			
		||||
    err = exc.value.response["Error"]
 | 
			
		||||
    err["Code"].should.equal("AuthFailure")
 | 
			
		||||
    err["Message"].should.equal(
 | 
			
		||||
        "AWS was not able to validate the provided access credentials"
 | 
			
		||||
    assert err["Code"] == "AuthFailure"
 | 
			
		||||
    assert (
 | 
			
		||||
        err["Message"] == "AWS was not able to validate the provided access credentials"
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@mock_ec2
 | 
			
		||||
def test_decorater_wrapped_gets_set():
 | 
			
		||||
def test_decorater_wrapped_gets_set() -> None:
 | 
			
		||||
    """
 | 
			
		||||
    Moto decorator's __wrapped__ should get set to the tests function
 | 
			
		||||
    """
 | 
			
		||||
    test_decorater_wrapped_gets_set.__wrapped__.__name__.should.equal(
 | 
			
		||||
        "test_decorater_wrapped_gets_set"
 | 
			
		||||
    )
 | 
			
		||||
    assert test_decorater_wrapped_gets_set.__wrapped__.__name__ == "test_decorater_wrapped_gets_set"  # type: ignore
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@mock_ec2
 | 
			
		||||
class Tester(object):
 | 
			
		||||
    def test_the_class(self):
 | 
			
		||||
class Tester:
 | 
			
		||||
    def test_the_class(self) -> None:
 | 
			
		||||
        client = boto3.client("ec2", region_name="us-west-1")
 | 
			
		||||
        client.describe_addresses()["Addresses"].should.equal([])
 | 
			
		||||
        assert client.describe_addresses()["Addresses"] == []
 | 
			
		||||
 | 
			
		||||
    def test_still_the_same(self):
 | 
			
		||||
    def test_still_the_same(self) -> None:
 | 
			
		||||
        client = boto3.client("ec2", region_name="us-west-1")
 | 
			
		||||
        client.describe_addresses()["Addresses"].should.equal([])
 | 
			
		||||
        assert client.describe_addresses()["Addresses"] == []
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@mock_s3
 | 
			
		||||
class TesterWithSetup(unittest.TestCase):
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
    def setUp(self) -> None:
 | 
			
		||||
        self.client = boto3.client("s3")
 | 
			
		||||
        self.client.create_bucket(Bucket="mybucket")
 | 
			
		||||
 | 
			
		||||
    def test_still_the_same(self):
 | 
			
		||||
    def test_still_the_same(self) -> None:
 | 
			
		||||
        buckets = self.client.list_buckets()["Buckets"]
 | 
			
		||||
        bucket_names = [b["Name"] for b in buckets]
 | 
			
		||||
        # There is a potential bug in the class-decorator, where the reset API is not called on start.
 | 
			
		||||
        # This leads to a situation where 'bucket_names' may contain buckets created by earlier tests
 | 
			
		||||
        bucket_names.should.contain("mybucket")
 | 
			
		||||
        assert "mybucket" in bucket_names
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@mock_s3
 | 
			
		||||
class TesterWithStaticmethod(object):
 | 
			
		||||
class TesterWithStaticmethod:
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def static(*args):
 | 
			
		||||
    def static(*args: Any) -> None:  # type: ignore[misc]
 | 
			
		||||
        assert not args or not isinstance(args[0], TesterWithStaticmethod)
 | 
			
		||||
 | 
			
		||||
    def test_no_instance_sent_to_staticmethod(self):
 | 
			
		||||
    def test_no_instance_sent_to_staticmethod(self) -> None:
 | 
			
		||||
        self.static()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@mock_s3
 | 
			
		||||
class TestWithSetup_UppercaseU(unittest.TestCase):
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
    def setUp(self) -> None:
 | 
			
		||||
        # This method will be executed automatically, provided we extend the TestCase-class
 | 
			
		||||
        s3 = boto3.client("s3", region_name="us-east-1")
 | 
			
		||||
        s3.create_bucket(Bucket="mybucket")
 | 
			
		||||
 | 
			
		||||
    def test_should_find_bucket(self):
 | 
			
		||||
    def test_should_find_bucket(self) -> None:
 | 
			
		||||
        s3 = boto3.client("s3", region_name="us-east-1")
 | 
			
		||||
        self.assertIsNotNone(s3.head_bucket(Bucket="mybucket"))
 | 
			
		||||
 | 
			
		||||
    def test_should_not_find_unknown_bucket(self):
 | 
			
		||||
    def test_should_not_find_unknown_bucket(self) -> None:
 | 
			
		||||
        s3 = boto3.client("s3", region_name="us-east-1")
 | 
			
		||||
        with pytest.raises(ClientError):
 | 
			
		||||
            s3.head_bucket(Bucket="unknown_bucket")
 | 
			
		||||
@ -126,16 +125,16 @@ class TestWithSetup_UppercaseU(unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
@mock_s3
 | 
			
		||||
class TestWithSetup_LowercaseU:
 | 
			
		||||
    def setup_method(self, *args):  # pylint: disable=unused-argument
 | 
			
		||||
    def setup_method(self, *args: Any) -> None:  # pylint: disable=unused-argument
 | 
			
		||||
        # This method will be executed automatically using pytest
 | 
			
		||||
        s3 = boto3.client("s3", region_name="us-east-1")
 | 
			
		||||
        s3.create_bucket(Bucket="mybucket")
 | 
			
		||||
 | 
			
		||||
    def test_should_find_bucket(self):
 | 
			
		||||
    def test_should_find_bucket(self) -> None:
 | 
			
		||||
        s3 = boto3.client("s3", region_name="us-east-1")
 | 
			
		||||
        assert s3.head_bucket(Bucket="mybucket") is not None
 | 
			
		||||
 | 
			
		||||
    def test_should_not_find_unknown_bucket(self):
 | 
			
		||||
    def test_should_not_find_unknown_bucket(self) -> None:
 | 
			
		||||
        s3 = boto3.client("s3", region_name="us-east-1")
 | 
			
		||||
        with pytest.raises(ClientError):
 | 
			
		||||
            s3.head_bucket(Bucket="unknown_bucket")
 | 
			
		||||
@ -143,16 +142,16 @@ class TestWithSetup_LowercaseU:
 | 
			
		||||
 | 
			
		||||
@mock_s3
 | 
			
		||||
class TestWithSetupMethod:
 | 
			
		||||
    def setup_method(self, *args):  # pylint: disable=unused-argument
 | 
			
		||||
    def setup_method(self, *args: Any) -> None:  # pylint: disable=unused-argument
 | 
			
		||||
        # This method will be executed automatically using pytest
 | 
			
		||||
        s3 = boto3.client("s3", region_name="us-east-1")
 | 
			
		||||
        s3.create_bucket(Bucket="mybucket")
 | 
			
		||||
 | 
			
		||||
    def test_should_find_bucket(self):
 | 
			
		||||
    def test_should_find_bucket(self) -> None:
 | 
			
		||||
        s3 = boto3.client("s3", region_name="us-east-1")
 | 
			
		||||
        assert s3.head_bucket(Bucket="mybucket") is not None
 | 
			
		||||
 | 
			
		||||
    def test_should_not_find_unknown_bucket(self):
 | 
			
		||||
    def test_should_not_find_unknown_bucket(self) -> None:
 | 
			
		||||
        s3 = boto3.client("s3", region_name="us-east-1")
 | 
			
		||||
        with pytest.raises(ClientError):
 | 
			
		||||
            s3.head_bucket(Bucket="unknown_bucket")
 | 
			
		||||
@ -160,17 +159,17 @@ class TestWithSetupMethod:
 | 
			
		||||
 | 
			
		||||
@mock_kinesis
 | 
			
		||||
class TestKinesisUsingSetupMethod:
 | 
			
		||||
    def setup_method(self, *args):  # pylint: disable=unused-argument
 | 
			
		||||
    def setup_method(self, *args: Any) -> None:  # pylint: disable=unused-argument
 | 
			
		||||
        self.stream_name = "test_stream"
 | 
			
		||||
        self.boto3_kinesis_client = boto3.client("kinesis", region_name="us-east-1")
 | 
			
		||||
        self.boto3_kinesis_client.create_stream(
 | 
			
		||||
            StreamName=self.stream_name, ShardCount=1
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_stream_creation(self):
 | 
			
		||||
    def test_stream_creation(self) -> None:
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def test_stream_recreation(self):
 | 
			
		||||
    def test_stream_recreation(self) -> None:
 | 
			
		||||
        # The setup-method will run again for this test
 | 
			
		||||
        # The fact that it passes, means the state was reset
 | 
			
		||||
        # Otherwise it would complain about a stream already existing
 | 
			
		||||
@ -179,11 +178,11 @@ class TestKinesisUsingSetupMethod:
 | 
			
		||||
 | 
			
		||||
@mock_s3
 | 
			
		||||
class TestWithInvalidSetupMethod:
 | 
			
		||||
    def setupmethod(self):
 | 
			
		||||
    def setupmethod(self) -> None:
 | 
			
		||||
        s3 = boto3.client("s3", region_name="us-east-1")
 | 
			
		||||
        s3.create_bucket(Bucket="mybucket")
 | 
			
		||||
 | 
			
		||||
    def test_should_not_find_bucket(self):
 | 
			
		||||
    def test_should_not_find_bucket(self) -> None:
 | 
			
		||||
        # Name of setupmethod is not recognized, so it will not be executed
 | 
			
		||||
        s3 = boto3.client("s3", region_name="us-east-1")
 | 
			
		||||
        with pytest.raises(ClientError):
 | 
			
		||||
@ -192,17 +191,17 @@ class TestWithInvalidSetupMethod:
 | 
			
		||||
 | 
			
		||||
@mock_s3
 | 
			
		||||
class TestWithPublicMethod(unittest.TestCase):
 | 
			
		||||
    def ensure_bucket_exists(self):
 | 
			
		||||
    def ensure_bucket_exists(self) -> None:
 | 
			
		||||
        s3 = boto3.client("s3", region_name="us-east-1")
 | 
			
		||||
        s3.create_bucket(Bucket="mybucket")
 | 
			
		||||
 | 
			
		||||
    def test_should_find_bucket(self):
 | 
			
		||||
    def test_should_find_bucket(self) -> None:
 | 
			
		||||
        self.ensure_bucket_exists()
 | 
			
		||||
 | 
			
		||||
        s3 = boto3.client("s3", region_name="us-east-1")
 | 
			
		||||
        s3.head_bucket(Bucket="mybucket").shouldnt.equal(None)
 | 
			
		||||
 | 
			
		||||
    def test_should_not_find_bucket(self):
 | 
			
		||||
    def test_should_not_find_bucket(self) -> None:
 | 
			
		||||
        s3 = boto3.client("s3", region_name="us-east-1")
 | 
			
		||||
        with pytest.raises(ClientError):
 | 
			
		||||
            s3.head_bucket(Bucket="mybucket")
 | 
			
		||||
@ -210,16 +209,16 @@ class TestWithPublicMethod(unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
@mock_s3
 | 
			
		||||
class TestWithPseudoPrivateMethod(unittest.TestCase):
 | 
			
		||||
    def _ensure_bucket_exists(self):
 | 
			
		||||
    def _ensure_bucket_exists(self) -> None:
 | 
			
		||||
        s3 = boto3.client("s3", region_name="us-east-1")
 | 
			
		||||
        s3.create_bucket(Bucket="mybucket")
 | 
			
		||||
 | 
			
		||||
    def test_should_find_bucket(self):
 | 
			
		||||
    def test_should_find_bucket(self) -> None:
 | 
			
		||||
        self._ensure_bucket_exists()
 | 
			
		||||
        s3 = boto3.client("s3", region_name="us-east-1")
 | 
			
		||||
        s3.head_bucket(Bucket="mybucket").shouldnt.equal(None)
 | 
			
		||||
 | 
			
		||||
    def test_should_not_find_bucket(self):
 | 
			
		||||
    def test_should_not_find_bucket(self) -> None:
 | 
			
		||||
        s3 = boto3.client("s3", region_name="us-east-1")
 | 
			
		||||
        with pytest.raises(ClientError):
 | 
			
		||||
            s3.head_bucket(Bucket="mybucket")
 | 
			
		||||
@ -227,20 +226,20 @@ class TestWithPseudoPrivateMethod(unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
@mock_s3
 | 
			
		||||
class Baseclass(unittest.TestCase):
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
    def setUp(self) -> None:
 | 
			
		||||
        self.s3 = boto3.resource("s3", region_name="us-east-1")
 | 
			
		||||
        self.client = boto3.client("s3", region_name="us-east-1")
 | 
			
		||||
        self.test_bucket = self.s3.Bucket("testbucket")
 | 
			
		||||
        self.test_bucket.create()
 | 
			
		||||
 | 
			
		||||
    def tearDown(self):
 | 
			
		||||
    def tearDown(self) -> None:
 | 
			
		||||
        # The bucket will still exist at this point
 | 
			
		||||
        self.test_bucket.delete()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@mock_s3
 | 
			
		||||
class TestSetUpInBaseClass(Baseclass):
 | 
			
		||||
    def test_a_thing(self):
 | 
			
		||||
    def test_a_thing(self) -> None:
 | 
			
		||||
        # Verify that we can 'see' the setUp-method in the parent class
 | 
			
		||||
        self.client.head_bucket(Bucket="testbucket").shouldnt.equal(None)
 | 
			
		||||
 | 
			
		||||
@ -248,42 +247,42 @@ class TestSetUpInBaseClass(Baseclass):
 | 
			
		||||
@mock_s3
 | 
			
		||||
class TestWithNestedClasses:
 | 
			
		||||
    class NestedClass(unittest.TestCase):
 | 
			
		||||
        def _ensure_bucket_exists(self):
 | 
			
		||||
        def _ensure_bucket_exists(self) -> None:
 | 
			
		||||
            s3 = boto3.client("s3", region_name="us-east-1")
 | 
			
		||||
            s3.create_bucket(Bucket="bucketclass1")
 | 
			
		||||
 | 
			
		||||
        def test_should_find_bucket(self):
 | 
			
		||||
        def test_should_find_bucket(self) -> None:
 | 
			
		||||
            self._ensure_bucket_exists()
 | 
			
		||||
            s3 = boto3.client("s3", region_name="us-east-1")
 | 
			
		||||
            s3.head_bucket(Bucket="bucketclass1")
 | 
			
		||||
 | 
			
		||||
    class NestedClass2(unittest.TestCase):
 | 
			
		||||
        def _ensure_bucket_exists(self):
 | 
			
		||||
        def _ensure_bucket_exists(self) -> None:
 | 
			
		||||
            s3 = boto3.client("s3", region_name="us-east-1")
 | 
			
		||||
            s3.create_bucket(Bucket="bucketclass2")
 | 
			
		||||
 | 
			
		||||
        def test_should_find_bucket(self):
 | 
			
		||||
        def test_should_find_bucket(self) -> None:
 | 
			
		||||
            self._ensure_bucket_exists()
 | 
			
		||||
            s3 = boto3.client("s3", region_name="us-east-1")
 | 
			
		||||
            s3.head_bucket(Bucket="bucketclass2")
 | 
			
		||||
 | 
			
		||||
        def test_should_not_find_bucket_from_different_class(self):
 | 
			
		||||
        def test_should_not_find_bucket_from_different_class(self) -> None:
 | 
			
		||||
            s3 = boto3.client("s3", region_name="us-east-1")
 | 
			
		||||
            with pytest.raises(ClientError):
 | 
			
		||||
                s3.head_bucket(Bucket="bucketclass1")
 | 
			
		||||
 | 
			
		||||
    class TestWithSetup(unittest.TestCase):
 | 
			
		||||
        def setUp(self):
 | 
			
		||||
        def setUp(self) -> None:
 | 
			
		||||
            s3 = boto3.client("s3", region_name="us-east-1")
 | 
			
		||||
            s3.create_bucket(Bucket="mybucket")
 | 
			
		||||
 | 
			
		||||
        def test_should_find_bucket(self):
 | 
			
		||||
        def test_should_find_bucket(self) -> None:
 | 
			
		||||
            s3 = boto3.client("s3", region_name="us-east-1")
 | 
			
		||||
            s3.head_bucket(Bucket="mybucket")
 | 
			
		||||
 | 
			
		||||
            s3.create_bucket(Bucket="bucketinsidetest")
 | 
			
		||||
 | 
			
		||||
        def test_should_not_find_bucket_from_test_method(self):
 | 
			
		||||
        def test_should_not_find_bucket_from_test_method(self) -> None:
 | 
			
		||||
            s3 = boto3.client("s3", region_name="us-east-1")
 | 
			
		||||
            s3.head_bucket(Bucket="mybucket")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -6,7 +6,7 @@ from moto import mock_all
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@mock_all()
 | 
			
		||||
def test_decorator():
 | 
			
		||||
def test_decorator() -> None:
 | 
			
		||||
    rgn = "us-east-1"
 | 
			
		||||
    sqs = boto3.client("sqs", region_name=rgn)
 | 
			
		||||
    r = sqs.list_queues()
 | 
			
		||||
@ -21,7 +21,7 @@ def test_decorator():
 | 
			
		||||
    r["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_context_manager():
 | 
			
		||||
def test_context_manager() -> None:
 | 
			
		||||
    rgn = "us-east-1"
 | 
			
		||||
 | 
			
		||||
    with mock_all():
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user