diff --git a/.coveragerc b/.coveragerc index 2258101db..629809fe6 100644 --- a/.coveragerc +++ b/.coveragerc @@ -5,6 +5,8 @@ exclude_lines = raise NotImplemented. return NotImplemented def __repr__ + if TYPE_CHECKING: + ^\s*\.\.\.$ [run] include = moto/* diff --git a/moto/__init__.py b/moto/__init__.py index 0cdaf4e60..32b721550 100644 --- a/moto/__init__.py +++ b/moto/__init__.py @@ -1,11 +1,19 @@ import importlib import sys from contextlib import ContextDecorator -from typing import Any, Callable, List, Optional, TypeVar -from moto.core.models import BaseMockAWS +from moto.core.models import BaseMockAWS, base_decorator, BaseDecorator +from typing import Any, Callable, List, Optional, TypeVar, Union, overload +from typing import TYPE_CHECKING -TEST_METHOD = TypeVar("TEST_METHOD", bound=Callable[..., Any]) +if TYPE_CHECKING: + from moto.xray import XRaySegment as xray_segment_type + from typing_extensions import ParamSpec + + P = ParamSpec("P") + + +T = TypeVar("T") def lazy_load( @@ -13,10 +21,21 @@ def lazy_load( element: str, boto3_name: Optional[str] = None, backend: Optional[str] = None, -) -> Callable[..., BaseMockAWS]: - def f(*args: Any, **kwargs: Any) -> Any: +) -> BaseDecorator: + @overload + def f(func: None = None) -> BaseMockAWS: + ... + + @overload + def f(func: "Callable[P, T]") -> "Callable[P, T]": + ... + + def f( + func: "Optional[Callable[P, T]]" = None, + ) -> "Union[BaseMockAWS, Callable[P, T]]": module = importlib.import_module(module_name, "moto") - return getattr(module, element)(*args, **kwargs) + decorator: base_decorator = getattr(module, element) + return decorator(func) setattr(f, "name", module_name.replace(".", "")) setattr(f, "element", element) @@ -25,6 +44,18 @@ def lazy_load( return f +def load_xray_segment() -> Callable[[], "xray_segment_type"]: + def f() -> "xray_segment_type": + # We can't use `lazy_load` here + # XRaySegment will always be run as a context manager + # I.e.: no function is passed directly: `with XRaySegment()` + from moto.xray import XRaySegment as xray_segment + + return xray_segment() + + return f + + mock_acm = lazy_load(".acm", "mock_acm") mock_acmpca = lazy_load(".acmpca", "mock_acmpca", boto3_name="acm-pca") mock_amp = lazy_load(".amp", "mock_amp") @@ -190,7 +221,7 @@ mock_timestreamwrite = lazy_load( ".timestreamwrite", "mock_timestreamwrite", boto3_name="timestream-write" ) mock_transcribe = lazy_load(".transcribe", "mock_transcribe") -XRaySegment = lazy_load(".xray", "XRaySegment") +XRaySegment = load_xray_segment() mock_xray = lazy_load(".xray", "mock_xray") mock_xray_client = lazy_load(".xray", "mock_xray_client") mock_wafv2 = lazy_load(".wafv2", "mock_wafv2") diff --git a/moto/core/models.py b/moto/core/models.py index 9957c6f35..5cae8286b 100644 --- a/moto/core/models.py +++ b/moto/core/models.py @@ -5,8 +5,8 @@ import os import re import unittest from types import FunctionType -from typing import Any, Callable, Dict, Optional, Set, TypeVar, Union -from typing import ContextManager +from typing import Any, Callable, Dict, Optional, Set, TypeVar, Union, overload +from typing import ContextManager, TYPE_CHECKING from unittest.mock import patch import boto3 @@ -26,8 +26,16 @@ from .custom_responses_mock import ( ) from .model_instances import reset_model_data +if TYPE_CHECKING: + from typing_extensions import ParamSpec, Protocol + + P = ParamSpec("P") +else: + Protocol = object + + DEFAULT_ACCOUNT_ID = "123456789012" -CALLABLE_RETURN = TypeVar("CALLABLE_RETURN") +T = TypeVar("T") class BaseMockAWS(ContextManager["BaseMockAWS"]): @@ -67,10 +75,10 @@ class BaseMockAWS(ContextManager["BaseMockAWS"]): def __call__( self, - func: Callable[..., "BaseMockAWS"], + func: "Callable[P, T]", reset: bool = True, remove_data: bool = True, - ) -> Callable[..., "BaseMockAWS"]: + ) -> "Callable[P, T]": if inspect.isclass(func): return self.decorate_class(func) # type: ignore return self.decorate_callable(func, reset, remove_data) @@ -120,9 +128,12 @@ class BaseMockAWS(ContextManager["BaseMockAWS"]): self.disable_patching() # type: ignore[attr-defined] def decorate_callable( - self, func: Callable[..., "BaseMockAWS"], reset: bool, remove_data: bool - ) -> Callable[..., "BaseMockAWS"]: - def wrapper(*args: Any, **kwargs: Any) -> "BaseMockAWS": + self, + func: "Callable[P, T]", + reset: bool, + remove_data: bool, + ) -> "Callable[P, T]": + def wrapper(*args: "P.args", **kwargs: "P.kwargs") -> T: self.start(reset=reset) try: result = func(*args, **kwargs) @@ -433,7 +444,7 @@ class ServerModeMockAWS(BaseMockAWS): def _get_region(self, *args: Any, **kwargs: Any) -> Optional[str]: if "region_name" in kwargs: return kwargs["region_name"] - if type(args) == tuple and len(args) == 2: + if type(args) is tuple and len(args) == 2: _, region = args return region return None @@ -513,13 +524,21 @@ class base_decorator: def __init__(self, backends: BackendDict): self.backends = backends + @overload + def __call__(self, func: None = None) -> BaseMockAWS: + ... + + @overload + def __call__(self, func: "Callable[P, T]") -> "Callable[P, T]": + ... + def __call__( - self, func: Optional[Callable[..., Any]] = None - ) -> Union[BaseMockAWS, Callable[..., BaseMockAWS]]: + self, func: "Optional[Callable[P, T]]" = None + ) -> "Union[BaseMockAWS, Callable[P, T]]": if settings.test_proxy_mode(): mocked_backend: BaseMockAWS = ProxyModeMockAWS(self.backends) elif settings.TEST_SERVER_MODE: - mocked_backend: BaseMockAWS = ServerModeMockAWS(self.backends) # type: ignore + mocked_backend = ServerModeMockAWS(self.backends) else: mocked_backend = self.mock_backend(self.backends) @@ -527,3 +546,18 @@ class base_decorator: return mocked_backend(func) else: return mocked_backend + + +class BaseDecorator(Protocol): + """A protocol for base_decorator's signature. + + This enables typing of callables with the same behavior as base_decorator. + """ + + @overload + def __call__(self, func: None = None) -> BaseMockAWS: + ... + + @overload + def __call__(self, func: "Callable[P, T]") -> "Callable[P, T]": + ... diff --git a/moto/identitystore/models.py b/moto/identitystore/models.py index f0f5443dd..d5d38d6e7 100644 --- a/moto/identitystore/models.py +++ b/moto/identitystore/models.py @@ -1,5 +1,4 @@ -from typing import Dict, Tuple, List, Any, NamedTuple, Optional -from typing_extensions import Self +from typing import Dict, Tuple, List, Any, NamedTuple, Optional, TYPE_CHECKING from moto.utilities.paginator import paginate from botocore.exceptions import ParamValidationError @@ -13,6 +12,9 @@ from .exceptions import ( ) import warnings +if TYPE_CHECKING: + from typing_extensions import Self + class Group(NamedTuple): GroupId: str @@ -31,7 +33,7 @@ class Name(NamedTuple): HonorificSuffix: Optional[str] @classmethod - def from_dict(cls, name_dict: Dict[str, str]) -> Optional[Self]: + def from_dict(cls, name_dict: Dict[str, str]) -> "Optional[Self]": if not name_dict: return None return cls( diff --git a/requirements-dev.txt b/requirements-dev.txt index 437420f60..ec4768281 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -7,6 +7,12 @@ inflection lxml mypy typing-extensions<=4.5.0; python_version < '3.8' +typing-extensions; python_version >= '3.8' packaging build prompt_toolkit + +# typing_extensions is currently used for: +# Protocol (3.8+) +# ParamSpec (3.10+) +# Self (3.11+) diff --git a/setup.cfg b/setup.cfg index ae691bb29..7d837a4aa 100644 --- a/setup.cfg +++ b/setup.cfg @@ -274,7 +274,7 @@ disable = W,C,R,E enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import [mypy] -files= moto, tests/test_core/test_mock_all.py, tests/test_core/test_decorator_calls.py, tests/test_core/test_responses_module.py +files= moto, tests/test_core/test_mock_all.py, tests/test_core/test_decorator_calls.py, tests/test_core/test_responses_module.py, tests/test_core/test_mypy.py show_column_numbers=True show_error_codes = True disable_error_code=abstract diff --git a/tests/test_core/test_decorator_calls.py b/tests/test_core/test_decorator_calls.py index 0ec4eab0e..b04256d30 100644 --- a/tests/test_core/test_decorator_calls.py +++ b/tests/test_core/test_decorator_calls.py @@ -5,7 +5,6 @@ import unittest from botocore.exceptions import ClientError from typing import Any from moto import mock_ec2, mock_kinesis, mock_s3, settings -from moto.core.models import BaseMockAWS from unittest import SkipTest """ @@ -48,7 +47,7 @@ def test_context_manager(aws_credentials: Any) -> None: # type: ignore[misc] # def test_decorator_start_and_stop() -> None: if settings.TEST_SERVER_MODE: raise SkipTest("Authentication always works in ServerMode") - mock: BaseMockAWS = mock_ec2() + mock = mock_ec2() mock.start() client = boto3.client("ec2", region_name="us-west-1") assert client.describe_addresses()["Addresses"] == [] diff --git a/tests/test_core/test_mypy.py b/tests/test_core/test_mypy.py new file mode 100644 index 000000000..593e1a8da --- /dev/null +++ b/tests/test_core/test_mypy.py @@ -0,0 +1,42 @@ +import boto3 + +from moto import mock_s3 +from moto.core.models import BaseMockAWS + + +@mock_s3 +def test_without_parentheses() -> int: + assert boto3.client("s3").list_buckets()["Buckets"] == [] + return 123 + + +@mock_s3() +def test_with_parentheses() -> int: + assert boto3.client("s3").list_buckets()["Buckets"] == [] + return 456 + + +@mock_s3 +def test_no_return() -> None: + assert boto3.client("s3").list_buckets()["Buckets"] == [] + + +def test_with_context_manager() -> None: + with mock_s3(): + assert boto3.client("s3").list_buckets()["Buckets"] == [] + + +def test_manual() -> None: + # this has the explicit type not because it's necessary but so that mypy will + # complain if it's wrong + m: BaseMockAWS = mock_s3() + m.start() + assert boto3.client("s3").list_buckets()["Buckets"] == [] + m.stop() + + +x: int = test_with_parentheses() +assert x == 456 + +y: int = test_without_parentheses() +assert y == 123 diff --git a/tests/test_core/test_responses_module.py b/tests/test_core/test_responses_module.py index c6f7992cf..fc22ee79a 100644 --- a/tests/test_core/test_responses_module.py +++ b/tests/test_core/test_responses_module.py @@ -19,7 +19,7 @@ class TestResponsesModule(TestCase): @mock_s3 @responses.activate - def test_moto_first(self) -> None: + def test_moto_first(self) -> None: # type: ignore """ Verify we can activate a user-defined `responses` on top of our Moto mocks """