From f38babb026809f560d1b2dec915693cd6c230081 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Wed, 26 Apr 2023 22:20:28 +0000 Subject: [PATCH] Techdebt: MyPy S (#6261) --- moto/sdb/exceptions.py | 7 +- moto/sdb/models.py | 43 ++++---- moto/sdb/responses.py | 16 +-- moto/server.py | 9 +- moto/servicediscovery/exceptions.py | 8 +- moto/servicediscovery/models.py | 141 ++++++++++++++----------- moto/servicediscovery/responses.py | 45 ++++---- moto/settings.py | 6 +- moto/signer/models.py | 36 ++++--- moto/signer/responses.py | 10 +- moto/ssoadmin/exceptions.py | 2 +- moto/ssoadmin/models.py | 135 ++++++++++++------------ moto/ssoadmin/responses.py | 24 +++-- moto/stepfunctions/exceptions.py | 4 +- moto/stepfunctions/models.py | 158 +++++++++++++++++----------- moto/stepfunctions/responses.py | 37 +++---- moto/stepfunctions/utils.py | 13 +-- moto/sts/exceptions.py | 3 +- moto/sts/models.py | 69 +++++++----- moto/sts/responses.py | 34 +++--- moto/sts/utils.py | 4 +- moto/support/models.py | 64 +++++------ moto/support/responses.py | 16 +-- setup.cfg | 2 +- 24 files changed, 480 insertions(+), 406 deletions(-) diff --git a/moto/sdb/exceptions.py b/moto/sdb/exceptions.py index ec588d2f3..6ab6f9a3a 100644 --- a/moto/sdb/exceptions.py +++ b/moto/sdb/exceptions.py @@ -1,4 +1,5 @@ """Exceptions raised by the sdb service.""" +from typing import Any from moto.core.exceptions import RESTError @@ -18,7 +19,7 @@ SDB_ERROR = """ class InvalidParameterError(RESTError): code = 400 - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): kwargs.setdefault("template", "sdb_error") self.templates["sdb_error"] = SDB_ERROR kwargs["error_type"] = "InvalidParameterValue" @@ -28,7 +29,7 @@ class InvalidParameterError(RESTError): class InvalidDomainName(InvalidParameterError): code = 400 - def __init__(self, domain_name): + def __init__(self, domain_name: str): super().__init__( message=f"Value ({domain_name}) for parameter DomainName is invalid. " ) @@ -37,7 +38,7 @@ class InvalidDomainName(InvalidParameterError): class UnknownDomainName(RESTError): code = 400 - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): kwargs.setdefault("template", "sdb_error") self.templates["sdb_error"] = SDB_ERROR kwargs["error_type"] = "NoSuchDomain" diff --git a/moto/sdb/models.py b/moto/sdb/models.py index 1831add4b..f093e30a8 100644 --- a/moto/sdb/models.py +++ b/moto/sdb/models.py @@ -1,23 +1,24 @@ """SimpleDBBackend class with methods for supported APIs.""" import re from collections import defaultdict -from moto.core import BaseBackend, BackendDict, BaseModel from threading import Lock +from typing import Any, Dict, List, Iterable, Optional +from moto.core import BaseBackend, BackendDict, BaseModel from .exceptions import InvalidDomainName, UnknownDomainName class FakeItem(BaseModel): - def __init__(self): - self.attributes = [] + def __init__(self) -> None: + self.attributes: List[Dict[str, Any]] = [] self.lock = Lock() - def get_attributes(self, names): + def get_attributes(self, names: Optional[List[str]]) -> List[Dict[str, Any]]: if not names: return self.attributes return [attr for attr in self.attributes if attr["name"] in names] - def put_attributes(self, attributes): + def put_attributes(self, attributes: List[Dict[str, Any]]) -> None: # Replacing attributes involves quite a few loops # Lock this, so we know noone else touches this list while we're operating on it with self.lock: @@ -26,56 +27,58 @@ class FakeItem(BaseModel): self._remove_attributes(attr["name"]) self.attributes.append(attr) - def _remove_attributes(self, name): + def _remove_attributes(self, name: str) -> None: self.attributes = [attr for attr in self.attributes if attr["name"] != name] class FakeDomain(BaseModel): - def __init__(self, name): + def __init__(self, name: str): self.name = name - self.items = defaultdict(FakeItem) + self.items: Dict[str, FakeItem] = defaultdict(FakeItem) - def get(self, item_name, attribute_names): + def get(self, item_name: str, attribute_names: List[str]) -> List[Dict[str, Any]]: item = self.items[item_name] return item.get_attributes(attribute_names) - def put(self, item_name, attributes): + def put(self, item_name: str, attributes: List[Dict[str, Any]]) -> None: item = self.items[item_name] item.put_attributes(attributes) class SimpleDBBackend(BaseBackend): - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.domains = dict() + self.domains: Dict[str, FakeDomain] = dict() - def create_domain(self, domain_name): + def create_domain(self, domain_name: str) -> None: self._validate_domain_name(domain_name) self.domains[domain_name] = FakeDomain(name=domain_name) - def list_domains(self): + def list_domains(self) -> Iterable[str]: """ The `max_number_of_domains` and `next_token` parameter have not been implemented yet - we simply return all domains. """ return self.domains.keys() - def delete_domain(self, domain_name): + def delete_domain(self, domain_name: str) -> None: self._validate_domain_name(domain_name) # Ignore unknown domains - AWS does the same self.domains.pop(domain_name, None) - def _validate_domain_name(self, domain_name): + def _validate_domain_name(self, domain_name: str) -> None: # Domain Name needs to have at least 3 chars # Can only contain characters: a-z, A-Z, 0-9, '_', '-', and '.' if not re.match("^[a-zA-Z0-9-_.]{3,}$", domain_name): raise InvalidDomainName(domain_name) - def _get_domain(self, domain_name): + def _get_domain(self, domain_name: str) -> FakeDomain: if domain_name not in self.domains: raise UnknownDomainName() return self.domains[domain_name] - def get_attributes(self, domain_name, item_name, attribute_names): + def get_attributes( + self, domain_name: str, item_name: str, attribute_names: List[str] + ) -> List[Dict[str, Any]]: """ Behaviour for the consistent_read-attribute is not yet implemented """ @@ -83,7 +86,9 @@ class SimpleDBBackend(BaseBackend): domain = self._get_domain(domain_name) return domain.get(item_name, attribute_names) - def put_attributes(self, domain_name, item_name, attributes): + def put_attributes( + self, domain_name: str, item_name: str, attributes: List[Dict[str, Any]] + ) -> None: """ Behaviour for the expected-attribute is not yet implemented. """ diff --git a/moto/sdb/responses.py b/moto/sdb/responses.py index 1e03a7d61..11b197428 100644 --- a/moto/sdb/responses.py +++ b/moto/sdb/responses.py @@ -1,33 +1,33 @@ from moto.core.responses import BaseResponse -from .models import sdb_backends +from .models import sdb_backends, SimpleDBBackend class SimpleDBResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="sdb") @property - def sdb_backend(self): + def sdb_backend(self) -> SimpleDBBackend: return sdb_backends[self.current_account][self.region] - def create_domain(self): + def create_domain(self) -> str: domain_name = self._get_param("DomainName") self.sdb_backend.create_domain(domain_name=domain_name) template = self.response_template(CREATE_DOMAIN_TEMPLATE) return template.render() - def delete_domain(self): + def delete_domain(self) -> str: domain_name = self._get_param("DomainName") self.sdb_backend.delete_domain(domain_name=domain_name) template = self.response_template(DELETE_DOMAIN_TEMPLATE) return template.render() - def list_domains(self): + def list_domains(self) -> str: domain_names = self.sdb_backend.list_domains() template = self.response_template(LIST_DOMAINS_TEMPLATE) return template.render(domain_names=domain_names, next_token=None) - def get_attributes(self): + def get_attributes(self) -> str: domain_name = self._get_param("DomainName") item_name = self._get_param("ItemName") attribute_names = self._get_multi_param("AttributeName.") @@ -39,7 +39,7 @@ class SimpleDBResponse(BaseResponse): template = self.response_template(GET_ATTRIBUTES_TEMPLATE) return template.render(attributes=attributes) - def put_attributes(self): + def put_attributes(self) -> str: domain_name = self._get_param("DomainName") item_name = self._get_param("ItemName") attributes = self._get_list_prefix("Attribute") diff --git a/moto/server.py b/moto/server.py index b451943a1..0e727d611 100644 --- a/moto/server.py +++ b/moto/server.py @@ -3,6 +3,7 @@ import os import signal import sys import warnings +from typing import Any, List, Optional from werkzeug.serving import run_simple @@ -15,11 +16,11 @@ from moto.moto_server.threaded_moto_server import ( # noqa # pylint: disable=un ) -def signal_handler(signum, frame): # pylint: disable=unused-argument +def signal_handler(signum: Any, frame: Any) -> None: # pylint: disable=unused-argument sys.exit(0) -def main(argv=None): +def main(argv: Optional[List[str]] = None) -> None: argv = argv or sys.argv[1:] parser = argparse.ArgumentParser() @@ -79,9 +80,9 @@ def main(argv=None): # Wrap the main application main_app = DomainDispatcherApplication(create_backend_app, service=args.service) - main_app.debug = True + main_app.debug = True # type: ignore - ssl_context = None + ssl_context: Any = None if args.ssl_key and args.ssl_cert: ssl_context = (args.ssl_cert, args.ssl_key) elif args.ssl: diff --git a/moto/servicediscovery/exceptions.py b/moto/servicediscovery/exceptions.py index f816fcc2c..4615306f3 100644 --- a/moto/servicediscovery/exceptions.py +++ b/moto/servicediscovery/exceptions.py @@ -3,20 +3,20 @@ from moto.core.exceptions import JsonRESTError class OperationNotFound(JsonRESTError): - def __init__(self): + def __init__(self) -> None: super().__init__("OperationNotFound", "") class NamespaceNotFound(JsonRESTError): - def __init__(self, ns_id): + def __init__(self, ns_id: str): super().__init__("NamespaceNotFound", f"{ns_id}") class ServiceNotFound(JsonRESTError): - def __init__(self, ns_id): + def __init__(self, ns_id: str): super().__init__("ServiceNotFound", f"{ns_id}") class ConflictingDomainExists(JsonRESTError): - def __init__(self, vpc_id): + def __init__(self, vpc_id: str): super().__init__("ConflictingDomainExists", f"{vpc_id}") diff --git a/moto/servicediscovery/models.py b/moto/servicediscovery/models.py index a97ab294a..2841011ec 100644 --- a/moto/servicediscovery/models.py +++ b/moto/servicediscovery/models.py @@ -1,4 +1,5 @@ import string +from typing import Any, Dict, Iterable, List, Optional from moto.core import BaseBackend, BackendDict, BaseModel from moto.core.utils import unix_time @@ -13,7 +14,7 @@ from .exceptions import ( ) -def random_id(size): +def random_id(size: int) -> str: return "".join( [random.choice(string.ascii_lowercase + string.digits) for _ in range(size)] ) @@ -22,17 +23,16 @@ def random_id(size): class Namespace(BaseModel): def __init__( self, - account_id, - region, - name, - ns_type, - creator_request_id, - description, - dns_properties, - http_properties, - vpc=None, + account_id: str, + region: str, + name: str, + ns_type: str, + creator_request_id: str, + description: str, + dns_properties: Dict[str, Any], + http_properties: Dict[str, Any], + vpc: Optional[str] = None, ): - super().__init__() self.id = f"ns-{random_id(20)}" self.arn = f"arn:aws:servicediscovery:{region}:{account_id}:namespace/{self.id}" self.name = name @@ -45,7 +45,7 @@ class Namespace(BaseModel): self.created = unix_time() self.updated = unix_time() - def to_json(self): + def to_json(self) -> Dict[str, Any]: return { "Arn": self.arn, "Id": self.id, @@ -65,31 +65,30 @@ class Namespace(BaseModel): class Service(BaseModel): def __init__( self, - account_id, - region, - name, - namespace_id, - description, - creator_request_id, - dns_config, - health_check_config, - health_check_custom_config, - service_type, + account_id: str, + region: str, + name: str, + namespace_id: str, + description: str, + creator_request_id: str, + dns_config: Dict[str, Any], + health_check_config: Dict[str, Any], + health_check_custom_config: Dict[str, int], + service_type: str, ): - super().__init__() self.id = f"srv-{random_id(8)}" self.arn = f"arn:aws:servicediscovery:{region}:{account_id}:service/{self.id}" self.name = name self.namespace_id = namespace_id self.description = description self.creator_request_id = creator_request_id - self.dns_config = dns_config + self.dns_config: Optional[Dict[str, Any]] = dns_config self.health_check_config = health_check_config self.health_check_custom_config = health_check_custom_config self.service_type = service_type self.created = unix_time() - def update(self, details): + def update(self, details: Dict[str, Any]) -> None: if "Description" in details: self.description = details["Description"] if "DnsConfig" in details: @@ -104,7 +103,7 @@ class Service(BaseModel): if "HealthCheckConfig" in details: self.health_check_config = details["HealthCheckConfig"] - def to_json(self): + def to_json(self) -> Dict[str, Any]: return { "Arn": self.arn, "Id": self.id, @@ -121,7 +120,7 @@ class Service(BaseModel): class Operation(BaseModel): - def __init__(self, operation_type, targets): + def __init__(self, operation_type: str, targets: Dict[str, str]): super().__init__() self.id = f"{random_id(32)}-{random_id(8)}" self.status = "SUCCESS" @@ -130,7 +129,7 @@ class Operation(BaseModel): self.updated = unix_time() self.targets = targets - def to_json(self, short=False): + def to_json(self, short: bool = False) -> Dict[str, Any]: if short: return {"Id": self.id, "Status": self.status} else: @@ -147,20 +146,26 @@ class Operation(BaseModel): class ServiceDiscoveryBackend(BaseBackend): """Implementation of ServiceDiscovery APIs.""" - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.operations = dict() - self.namespaces = dict() - self.services = dict() + self.operations: Dict[str, Operation] = dict() + self.namespaces: Dict[str, Namespace] = dict() + self.services: Dict[str, Service] = dict() self.tagger = TaggingService() - def list_namespaces(self): + def list_namespaces(self) -> Iterable[Namespace]: """ Pagination or the Filters-parameter is not yet implemented """ return self.namespaces.values() - def create_http_namespace(self, name, creator_request_id, description, tags): + def create_http_namespace( + self, + name: str, + creator_request_id: str, + description: str, + tags: List[Dict[str, str]], + ) -> str: namespace = Namespace( account_id=self.account_id, region=self.region_name, @@ -179,13 +184,12 @@ class ServiceDiscoveryBackend(BaseBackend): ) return operation_id - def _create_operation(self, op_type, targets): + def _create_operation(self, op_type: str, targets: Dict[str, str]) -> str: operation = Operation(operation_type=op_type, targets=targets) self.operations[operation.id] = operation - operation_id = operation.id - return operation_id + return operation.id - def delete_namespace(self, namespace_id): + def delete_namespace(self, namespace_id: str) -> str: if namespace_id not in self.namespaces: raise NamespaceNotFound(namespace_id) del self.namespaces[namespace_id] @@ -194,12 +198,12 @@ class ServiceDiscoveryBackend(BaseBackend): ) return operation_id - def get_namespace(self, namespace_id): + def get_namespace(self, namespace_id: str) -> Namespace: if namespace_id not in self.namespaces: raise NamespaceNotFound(namespace_id) return self.namespaces[namespace_id] - def list_operations(self): + def list_operations(self) -> Iterable[Operation]: """ Pagination or the Filters-argument is not yet implemented """ @@ -211,23 +215,31 @@ class ServiceDiscoveryBackend(BaseBackend): } return self.operations.values() - def get_operation(self, operation_id): + def get_operation(self, operation_id: str) -> Operation: if operation_id not in self.operations: raise OperationNotFound() return self.operations[operation_id] - def tag_resource(self, resource_arn, tags): + def tag_resource(self, resource_arn: str, tags: List[Dict[str, str]]) -> None: self.tagger.tag_resource(resource_arn, tags) - def untag_resource(self, resource_arn, tag_keys): + def untag_resource(self, resource_arn: str, tag_keys: List[str]) -> None: self.tagger.untag_resource_using_names(resource_arn, tag_keys) - def list_tags_for_resource(self, resource_arn): + def list_tags_for_resource( + self, resource_arn: str + ) -> Dict[str, List[Dict[str, str]]]: return self.tagger.list_tags_for_resource(resource_arn) def create_private_dns_namespace( - self, name, creator_request_id, description, vpc, tags, properties - ): + self, + name: str, + creator_request_id: str, + description: str, + vpc: str, + tags: List[Dict[str, str]], + properties: Dict[str, Any], + ) -> str: for namespace in self.namespaces.values(): if namespace.vpc == vpc: raise ConflictingDomainExists(vpc) @@ -253,8 +265,13 @@ class ServiceDiscoveryBackend(BaseBackend): return operation_id def create_public_dns_namespace( - self, name, creator_request_id, description, tags, properties - ): + self, + name: str, + creator_request_id: str, + description: str, + tags: List[Dict[str, str]], + properties: Dict[str, Any], + ) -> str: dns_properties = (properties or {}).get("DnsProperties", {}) dns_properties["HostedZoneId"] = "hzi" namespace = Namespace( @@ -277,16 +294,16 @@ class ServiceDiscoveryBackend(BaseBackend): def create_service( self, - name, - namespace_id, - creator_request_id, - description, - dns_config, - health_check_config, - health_check_custom_config, - tags, - service_type, - ): + name: str, + namespace_id: str, + creator_request_id: str, + description: str, + dns_config: Dict[str, Any], + health_check_config: Dict[str, Any], + health_check_custom_config: Dict[str, Any], + tags: List[Dict[str, str]], + service_type: str, + ) -> Service: service = Service( account_id=self.account_id, region=self.region_name, @@ -304,21 +321,21 @@ class ServiceDiscoveryBackend(BaseBackend): self.tagger.tag_resource(service.arn, tags) return service - def get_service(self, service_id): + def get_service(self, service_id: str) -> Service: if service_id not in self.services: raise ServiceNotFound(service_id) return self.services[service_id] - def delete_service(self, service_id): + def delete_service(self, service_id: str) -> None: self.services.pop(service_id, None) - def list_services(self): + def list_services(self) -> Iterable[Service]: """ Pagination or the Filters-argument is not yet implemented """ return self.services.values() - def update_service(self, service_id, details): + def update_service(self, service_id: str, details: Dict[str, Any]) -> str: service = self.get_service(service_id) service.update(details=details) operation_id = self._create_operation( diff --git a/moto/servicediscovery/responses.py b/moto/servicediscovery/responses.py index 55d43fe7a..972b9d532 100644 --- a/moto/servicediscovery/responses.py +++ b/moto/servicediscovery/responses.py @@ -1,24 +1,25 @@ """Handles incoming servicediscovery requests, invokes methods, returns responses.""" import json +from moto.core.common_types import TYPE_RESPONSE from moto.core.responses import BaseResponse -from .models import servicediscovery_backends +from .models import servicediscovery_backends, ServiceDiscoveryBackend class ServiceDiscoveryResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="servicediscovery") @property - def servicediscovery_backend(self): + def servicediscovery_backend(self) -> ServiceDiscoveryBackend: """Return backend instance specific for this region.""" return servicediscovery_backends[self.current_account][self.region] - def list_namespaces(self): + def list_namespaces(self) -> TYPE_RESPONSE: namespaces = self.servicediscovery_backend.list_namespaces() return 200, {}, json.dumps({"Namespaces": [ns.to_json() for ns in namespaces]}) - def create_http_namespace(self): + def create_http_namespace(self) -> str: params = json.loads(self.body) name = params.get("Name") creator_request_id = params.get("CreatorRequestId") @@ -32,7 +33,7 @@ class ServiceDiscoveryResponse(BaseResponse): ) return json.dumps(dict(OperationId=operation_id)) - def delete_namespace(self): + def delete_namespace(self) -> str: params = json.loads(self.body) namespace_id = params.get("Id") operation_id = self.servicediscovery_backend.delete_namespace( @@ -40,7 +41,7 @@ class ServiceDiscoveryResponse(BaseResponse): ) return json.dumps(dict(OperationId=operation_id)) - def list_operations(self): + def list_operations(self) -> TYPE_RESPONSE: operations = self.servicediscovery_backend.list_operations() return ( 200, @@ -48,7 +49,7 @@ class ServiceDiscoveryResponse(BaseResponse): json.dumps({"Operations": [o.to_json(short=True) for o in operations]}), ) - def get_operation(self): + def get_operation(self) -> str: params = json.loads(self.body) operation_id = params.get("OperationId") operation = self.servicediscovery_backend.get_operation( @@ -56,7 +57,7 @@ class ServiceDiscoveryResponse(BaseResponse): ) return json.dumps(dict(Operation=operation.to_json())) - def get_namespace(self): + def get_namespace(self) -> str: params = json.loads(self.body) namespace_id = params.get("Id") namespace = self.servicediscovery_backend.get_namespace( @@ -64,23 +65,23 @@ class ServiceDiscoveryResponse(BaseResponse): ) return json.dumps(dict(Namespace=namespace.to_json())) - def tag_resource(self): + def tag_resource(self) -> str: params = json.loads(self.body) resource_arn = params.get("ResourceARN") tags = params.get("Tags") self.servicediscovery_backend.tag_resource(resource_arn=resource_arn, tags=tags) - return json.dumps(dict()) + return "{}" - def untag_resource(self): + def untag_resource(self) -> str: params = json.loads(self.body) resource_arn = params.get("ResourceARN") tag_keys = params.get("TagKeys") self.servicediscovery_backend.untag_resource( resource_arn=resource_arn, tag_keys=tag_keys ) - return json.dumps(dict()) + return "{}" - def list_tags_for_resource(self): + def list_tags_for_resource(self) -> TYPE_RESPONSE: params = json.loads(self.body) resource_arn = params.get("ResourceARN") tags = self.servicediscovery_backend.list_tags_for_resource( @@ -88,7 +89,7 @@ class ServiceDiscoveryResponse(BaseResponse): ) return 200, {}, json.dumps(tags) - def create_private_dns_namespace(self): + def create_private_dns_namespace(self) -> str: params = json.loads(self.body) name = params.get("Name") creator_request_id = params.get("CreatorRequestId") @@ -106,7 +107,7 @@ class ServiceDiscoveryResponse(BaseResponse): ) return json.dumps(dict(OperationId=operation_id)) - def create_public_dns_namespace(self): + def create_public_dns_namespace(self) -> str: params = json.loads(self.body) name = params.get("Name") creator_request_id = params.get("CreatorRequestId") @@ -122,7 +123,7 @@ class ServiceDiscoveryResponse(BaseResponse): ) return json.dumps(dict(OperationId=operation_id)) - def create_service(self): + def create_service(self) -> str: params = json.loads(self.body) name = params.get("Name") namespace_id = params.get("NamespaceId") @@ -146,23 +147,23 @@ class ServiceDiscoveryResponse(BaseResponse): ) return json.dumps(dict(Service=service.to_json())) - def get_service(self): + def get_service(self) -> str: params = json.loads(self.body) service_id = params.get("Id") service = self.servicediscovery_backend.get_service(service_id=service_id) return json.dumps(dict(Service=service.to_json())) - def delete_service(self): + def delete_service(self) -> str: params = json.loads(self.body) service_id = params.get("Id") self.servicediscovery_backend.delete_service(service_id=service_id) - return json.dumps(dict()) + return "{}" - def list_services(self): + def list_services(self) -> str: services = self.servicediscovery_backend.list_services() return json.dumps(dict(Services=[s.to_json() for s in services])) - def update_service(self): + def update_service(self) -> str: params = json.loads(self.body) service_id = params.get("Id") details = params.get("Service") diff --git a/moto/settings.py b/moto/settings.py index a84803df1..5ca0e6b13 100644 --- a/moto/settings.py +++ b/moto/settings.py @@ -38,7 +38,7 @@ SKIP_REQUIRES_DOCKER = bool(os.environ.get("TESTS_SKIP_REQUIRES_DOCKER", False)) LAMBDA_DATA_DIR = os.environ.get("MOTO_LAMBDA_DATA_DIR", "/tmp/data") -def get_sf_execution_history_type(): +def get_sf_execution_history_type() -> str: """ Determines which execution history events `get_execution_history` returns :returns: str representing the type of Step Function Execution Type events should be @@ -97,11 +97,11 @@ def moto_lambda_image() -> str: return os.environ.get("MOTO_DOCKER_LAMBDA_IMAGE", "mlupin/docker-lambda") -def moto_network_name() -> str: +def moto_network_name() -> Optional[str]: return os.environ.get("MOTO_DOCKER_NETWORK_NAME") -def moto_network_mode() -> str: +def moto_network_mode() -> Optional[str]: return os.environ.get("MOTO_DOCKER_NETWORK_MODE") diff --git a/moto/signer/models.py b/moto/signer/models.py index 7289362f6..59395db54 100644 --- a/moto/signer/models.py +++ b/moto/signer/models.py @@ -1,10 +1,18 @@ +from typing import Any, Dict, List, Optional + from moto.core import BaseBackend, BackendDict, BaseModel from moto.moto_api._internal import mock_random class SigningProfile(BaseModel): def __init__( - self, account_id, region, name, platform_id, signature_validity_period, tags + self, + account_id: str, + region: str, + name: str, + platform_id: str, + signature_validity_period: Optional[Dict[str, Any]], + tags: Dict[str, str], ): self.name = name self.platform_id = platform_id @@ -19,11 +27,11 @@ class SigningProfile(BaseModel): self.profile_version = mock_random.get_random_hex(10) self.profile_version_arn = f"{self.arn}/{self.profile_version}" - def cancel(self): + def cancel(self) -> None: self.status = "Canceled" - def to_dict(self, full=True): - small = { + def to_dict(self, full: bool = True) -> Dict[str, Any]: + small: Dict[str, Any] = { "arn": self.arn, "profileVersion": self.profile_version, "profileVersionArn": self.profile_version_arn, @@ -149,22 +157,22 @@ class SignerBackend(BaseBackend): }, ] - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.signing_profiles: [str, SigningProfile] = dict() + self.signing_profiles: Dict[str, SigningProfile] = dict() - def cancel_signing_profile(self, profile_name) -> None: + def cancel_signing_profile(self, profile_name: str) -> None: self.signing_profiles[profile_name].cancel() - def get_signing_profile(self, profile_name) -> SigningProfile: + def get_signing_profile(self, profile_name: str) -> SigningProfile: return self.signing_profiles[profile_name] def put_signing_profile( self, - profile_name, - signature_validity_period, - platform_id, - tags, + profile_name: str, + signature_validity_period: Optional[Dict[str, Any]], + platform_id: str, + tags: Dict[str, str], ) -> SigningProfile: """ The following parameters are not yet implemented: SigningMaterial, Overrides, SigningParamaters @@ -180,7 +188,7 @@ class SignerBackend(BaseBackend): self.signing_profiles[profile_name] = profile return profile - def list_signing_platforms(self): + def list_signing_platforms(self) -> List[Dict[str, Any]]: """ Pagination is not yet implemented. The parameters category, partner, target are not yet implemented """ @@ -189,4 +197,4 @@ class SignerBackend(BaseBackend): # Using the lambda-regions # boto3.Session().get_available_regions("signer") still returns an empty list -signer_backends: [str, [str, SignerBackend]] = BackendDict(SignerBackend, "lambda") +signer_backends = BackendDict(SignerBackend, "lambda") diff --git a/moto/signer/responses.py b/moto/signer/responses.py index ad42aa587..86073dc48 100644 --- a/moto/signer/responses.py +++ b/moto/signer/responses.py @@ -6,7 +6,7 @@ from .models import signer_backends, SignerBackend class signerResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="signer") @property @@ -14,17 +14,17 @@ class signerResponse(BaseResponse): """Return backend instance specific for this region.""" return signer_backends[self.current_account][self.region] - def cancel_signing_profile(self): + def cancel_signing_profile(self) -> str: profile_name = self.path.split("/")[-1] self.signer_backend.cancel_signing_profile(profile_name=profile_name) return "{}" - def get_signing_profile(self): + def get_signing_profile(self) -> str: profile_name = self.path.split("/")[-1] profile = self.signer_backend.get_signing_profile(profile_name=profile_name) return json.dumps(profile.to_dict()) - def put_signing_profile(self): + def put_signing_profile(self) -> str: params = json.loads(self.body) profile_name = self.path.split("/")[-1] signature_validity_period = params.get("signatureValidityPeriod") @@ -38,6 +38,6 @@ class signerResponse(BaseResponse): ) return json.dumps(profile.to_dict(full=False)) - def list_signing_platforms(self): + def list_signing_platforms(self) -> str: platforms = self.signer_backend.list_signing_platforms() return json.dumps(dict(platforms=platforms)) diff --git a/moto/ssoadmin/exceptions.py b/moto/ssoadmin/exceptions.py index a133e1d5e..990eaa5de 100644 --- a/moto/ssoadmin/exceptions.py +++ b/moto/ssoadmin/exceptions.py @@ -3,5 +3,5 @@ from moto.core.exceptions import JsonRESTError class ResourceNotFound(JsonRESTError): - def __init__(self): + def __init__(self) -> None: super().__init__("ResourceNotFound", "Account not found") diff --git a/moto/ssoadmin/models.py b/moto/ssoadmin/models.py index 0aeba130c..d83c96147 100644 --- a/moto/ssoadmin/models.py +++ b/moto/ssoadmin/models.py @@ -1,21 +1,22 @@ -from .exceptions import ResourceNotFound +from typing import Any, Dict, List from moto.core import BaseBackend, BackendDict, BaseModel from moto.core.utils import unix_time from moto.moto_api._internal import mock_random as random from moto.utilities.paginator import paginate +from .exceptions import ResourceNotFound from .utils import PAGINATION_MODEL class AccountAssignment(BaseModel): def __init__( self, - instance_arn, - target_id, - target_type, - permission_set_arn, - principal_type, - principal_id, + instance_arn: str, + target_id: str, + target_type: str, + permission_set_arn: str, + principal_type: str, + principal_id: str, ): self.request_id = str(random.uuid4()) self.instance_arn = instance_arn @@ -26,8 +27,8 @@ class AccountAssignment(BaseModel): self.principal_id = principal_id self.created_date = unix_time() - def to_json(self, include_creation_date=False): - summary = { + def to_json(self, include_creation_date: bool = False) -> Dict[str, Any]: + summary: Dict[str, Any] = { "TargetId": self.target_id, "TargetType": self.target_type, "PermissionSetArn": self.permission_set_arn, @@ -42,12 +43,12 @@ class AccountAssignment(BaseModel): class PermissionSet(BaseModel): def __init__( self, - name, - description, - instance_arn, - session_duration, - relay_state, - tags, + name: str, + description: str, + instance_arn: str, + session_duration: str, + relay_state: str, + tags: List[Dict[str, str]], ): self.name = name self.description = description @@ -58,8 +59,8 @@ class PermissionSet(BaseModel): self.tags = tags self.created_date = unix_time() - def to_json(self, include_creation_date=False): - summary = { + def to_json(self, include_creation_date: bool = False) -> Dict[str, Any]: + summary: Dict[str, Any] = { "Name": self.name, "Description": self.description, "PermissionSetArn": self.permission_set_arn, @@ -71,7 +72,7 @@ class PermissionSet(BaseModel): return summary @staticmethod - def generate_id(instance_arn): + def generate_id(instance_arn: str) -> str: chars = list(range(10)) + ["a", "b", "c", "d", "e", "f"] return ( instance_arn @@ -83,20 +84,20 @@ class PermissionSet(BaseModel): class SSOAdminBackend(BaseBackend): """Implementation of SSOAdmin APIs.""" - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.account_assignments = list() - self.permission_sets = list() + self.account_assignments: List[AccountAssignment] = list() + self.permission_sets: List[PermissionSet] = list() def create_account_assignment( self, - instance_arn, - target_id, - target_type, - permission_set_arn, - principal_type, - principal_id, - ): + instance_arn: str, + target_id: str, + target_type: str, + permission_set_arn: str, + principal_type: str, + principal_id: str, + ) -> Dict[str, Any]: assignment = AccountAssignment( instance_arn, target_id, @@ -110,13 +111,13 @@ class SSOAdminBackend(BaseBackend): def delete_account_assignment( self, - instance_arn, - target_id, - target_type, - permission_set_arn, - principal_type, - principal_id, - ): + instance_arn: str, + target_id: str, + target_type: str, + permission_set_arn: str, + principal_type: str, + principal_id: str, + ) -> Dict[str, Any]: account = self._find_account( instance_arn, target_id, @@ -130,13 +131,13 @@ class SSOAdminBackend(BaseBackend): def _find_account( self, - instance_arn, - target_id, - target_type, - permission_set_arn, - principal_type, - principal_id, - ): + instance_arn: str, + target_id: str, + target_type: str, + permission_set_arn: str, + principal_type: str, + principal_id: str, + ) -> AccountAssignment: for account in self.account_assignments: instance_arn_match = account.instance_arn == instance_arn target_id_match = account.target_id == target_id @@ -155,7 +156,9 @@ class SSOAdminBackend(BaseBackend): return account raise ResourceNotFound - def list_account_assignments(self, instance_arn, account_id, permission_set_arn): + def list_account_assignments( + self, instance_arn: str, account_id: str, permission_set_arn: str + ) -> List[Dict[str, Any]]: """ Pagination has not yet been implemented """ @@ -178,13 +181,13 @@ class SSOAdminBackend(BaseBackend): def create_permission_set( self, - name, - description, - instance_arn, - session_duration, - relay_state, - tags, - ): + name: str, + description: str, + instance_arn: str, + session_duration: str, + relay_state: str, + tags: List[Dict[str, str]], + ) -> Dict[str, Any]: permission_set = PermissionSet( name, description, @@ -198,12 +201,12 @@ class SSOAdminBackend(BaseBackend): def update_permission_set( self, - instance_arn, - permission_set_arn, - description, - session_duration, - relay_state, - ): + instance_arn: str, + permission_set_arn: str, + description: str, + session_duration: str, + relay_state: str, + ) -> Dict[str, Any]: permission_set = self._find_permission_set( instance_arn, permission_set_arn, @@ -216,10 +219,8 @@ class SSOAdminBackend(BaseBackend): return permission_set.to_json(True) def describe_permission_set( - self, - instance_arn, - permission_set_arn, - ): + self, instance_arn: str, permission_set_arn: str + ) -> Dict[str, Any]: permission_set = self._find_permission_set( instance_arn, permission_set_arn, @@ -227,10 +228,8 @@ class SSOAdminBackend(BaseBackend): return permission_set.to_json(True) def delete_permission_set( - self, - instance_arn, - permission_set_arn, - ): + self, instance_arn: str, permission_set_arn: str + ) -> Dict[str, Any]: permission_set = self._find_permission_set( instance_arn, permission_set_arn, @@ -239,10 +238,8 @@ class SSOAdminBackend(BaseBackend): return permission_set.to_json(include_creation_date=True) def _find_permission_set( - self, - instance_arn, - permission_set_arn, - ): + self, instance_arn: str, permission_set_arn: str + ) -> PermissionSet: for permission_set in self.permission_sets: instance_arn_match = permission_set.instance_arn == instance_arn permission_set_match = ( @@ -253,7 +250,7 @@ class SSOAdminBackend(BaseBackend): raise ResourceNotFound @paginate(pagination_model=PAGINATION_MODEL) - def list_permission_sets(self, instance_arn): + def list_permission_sets(self, instance_arn: str) -> List[PermissionSet]: # type: ignore[misc] permission_sets = [] for permission_set in self.permission_sets: if permission_set.instance_arn == instance_arn: diff --git a/moto/ssoadmin/responses.py b/moto/ssoadmin/responses.py index 83ad84388..7276a71ce 100644 --- a/moto/ssoadmin/responses.py +++ b/moto/ssoadmin/responses.py @@ -3,21 +3,21 @@ import json from moto.core.responses import BaseResponse from moto.moto_api._internal import mock_random -from .models import ssoadmin_backends +from .models import ssoadmin_backends, SSOAdminBackend class SSOAdminResponse(BaseResponse): """Handler for SSOAdmin requests and responses.""" - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="sso-admin") @property - def ssoadmin_backend(self): + def ssoadmin_backend(self) -> SSOAdminBackend: """Return backend instance specific for this region.""" return ssoadmin_backends[self.current_account][self.region] - def create_account_assignment(self): + def create_account_assignment(self) -> str: params = json.loads(self.body) instance_arn = params.get("InstanceArn") target_id = params.get("TargetId") @@ -37,7 +37,7 @@ class SSOAdminResponse(BaseResponse): summary["RequestId"] = str(mock_random.uuid4()) return json.dumps({"AccountAssignmentCreationStatus": summary}) - def delete_account_assignment(self): + def delete_account_assignment(self) -> str: params = json.loads(self.body) instance_arn = params.get("InstanceArn") target_id = params.get("TargetId") @@ -57,7 +57,7 @@ class SSOAdminResponse(BaseResponse): summary["RequestId"] = str(mock_random.uuid4()) return json.dumps({"AccountAssignmentDeletionStatus": summary}) - def list_account_assignments(self): + def list_account_assignments(self) -> str: params = json.loads(self.body) instance_arn = params.get("InstanceArn") account_id = params.get("AccountId") @@ -69,7 +69,7 @@ class SSOAdminResponse(BaseResponse): ) return json.dumps({"AccountAssignments": assignments}) - def create_permission_set(self): + def create_permission_set(self) -> str: name = self._get_param("Name") description = self._get_param("Description") instance_arn = self._get_param("InstanceArn") @@ -88,7 +88,7 @@ class SSOAdminResponse(BaseResponse): return json.dumps({"PermissionSet": permission_set}) - def delete_permission_set(self): + def delete_permission_set(self) -> str: params = json.loads(self.body) instance_arn = params.get("InstanceArn") permission_set_arn = params.get("PermissionSetArn") @@ -96,8 +96,9 @@ class SSOAdminResponse(BaseResponse): instance_arn=instance_arn, permission_set_arn=permission_set_arn, ) + return "{}" - def update_permission_set(self): + def update_permission_set(self) -> str: instance_arn = self._get_param("InstanceArn") permission_set_arn = self._get_param("PermissionSetArn") description = self._get_param("Description") @@ -111,8 +112,9 @@ class SSOAdminResponse(BaseResponse): session_duration=session_duration, relay_state=relay_state, ) + return "{}" - def describe_permission_set(self): + def describe_permission_set(self) -> str: instance_arn = self._get_param("InstanceArn") permission_set_arn = self._get_param("PermissionSetArn") @@ -122,7 +124,7 @@ class SSOAdminResponse(BaseResponse): ) return json.dumps({"PermissionSet": permission_set}) - def list_permission_sets(self): + def list_permission_sets(self) -> str: instance_arn = self._get_param("InstanceArn") max_results = self._get_int_param("MaxResults") next_token = self._get_param("NextToken") diff --git a/moto/stepfunctions/exceptions.py b/moto/stepfunctions/exceptions.py index eae9dc1e8..b7e449a54 100644 --- a/moto/stepfunctions/exceptions.py +++ b/moto/stepfunctions/exceptions.py @@ -35,7 +35,7 @@ class InvalidToken(AWSError): TYPE = "InvalidToken" STATUS = 400 - def __init__(self, message="Invalid token"): + def __init__(self, message: str = "Invalid token"): super().__init__(f"Invalid Token: {message}") @@ -43,5 +43,5 @@ class ResourceNotFound(AWSError): TYPE = "ResourceNotFound" STATUS = 400 - def __init__(self, arn): + def __init__(self, arn: str): super().__init__(f"Resource not found: '{arn}'") diff --git a/moto/stepfunctions/models.py b/moto/stepfunctions/models.py index 64d4eea45..5628c95dd 100644 --- a/moto/stepfunctions/models.py +++ b/moto/stepfunctions/models.py @@ -2,6 +2,7 @@ import json import re from datetime import datetime from dateutil.tz import tzlocal +from typing import Any, Dict, List, Iterable, Optional, Pattern from moto.core import BaseBackend, BackendDict, CloudFormationModel from moto.core.utils import iso_8601_datetime_with_milliseconds @@ -21,19 +22,32 @@ from moto.utilities.paginator import paginate class StateMachine(CloudFormationModel): - def __init__(self, arn, name, definition, roleArn, tags=None): + def __init__( + self, + arn: str, + name: str, + definition: str, + roleArn: str, + tags: Optional[List[Dict[str, str]]] = None, + ): self.creation_date = iso_8601_datetime_with_milliseconds(datetime.now()) self.update_date = self.creation_date self.arn = arn self.name = name self.definition = definition self.roleArn = roleArn - self.executions = [] - self.tags = [] + self.executions: List[Execution] = [] + self.tags: List[Dict[str, str]] = [] if tags: self.add_tags(tags) - def start_execution(self, region_name, account_id, execution_name, execution_input): + def start_execution( + self, + region_name: str, + account_id: str, + execution_name: str, + execution_input: str, + ) -> "Execution": self._ensure_execution_name_doesnt_exist(execution_name) self._validate_execution_input(execution_input) execution = Execution( @@ -47,7 +61,7 @@ class StateMachine(CloudFormationModel): self.executions.append(execution) return execution - def stop_execution(self, execution_arn): + def stop_execution(self, execution_arn: str) -> "Execution": execution = next( (x for x in self.executions if x.execution_arn == execution_arn), None ) @@ -58,14 +72,14 @@ class StateMachine(CloudFormationModel): execution.stop() return execution - def _ensure_execution_name_doesnt_exist(self, name): + def _ensure_execution_name_doesnt_exist(self, name: str) -> None: for execution in self.executions: if execution.name == name: raise ExecutionAlreadyExists( "Execution Already Exists: '" + execution.execution_arn + "'" ) - def _validate_execution_input(self, execution_input): + def _validate_execution_input(self, execution_input: str) -> None: try: json.loads(execution_input) except Exception as ex: @@ -73,13 +87,13 @@ class StateMachine(CloudFormationModel): "Invalid State Machine Execution Input: '" + str(ex) + "'" ) - def update(self, **kwargs): + def update(self, **kwargs: Any) -> None: for key, value in kwargs.items(): if value is not None: setattr(self, key, value) self.update_date = iso_8601_datetime_with_milliseconds(datetime.now()) - def add_tags(self, tags): + def add_tags(self, tags: List[Dict[str, str]]) -> List[Dict[str, str]]: merged_tags = [] for tag in self.tags: replacement_index = next( @@ -96,15 +110,15 @@ class StateMachine(CloudFormationModel): self.tags = merged_tags return self.tags - def remove_tags(self, tag_keys): + def remove_tags(self, tag_keys: List[str]) -> List[Dict[str, str]]: self.tags = [tag_set for tag_set in self.tags if tag_set["key"] not in tag_keys] return self.tags @property - def physical_resource_id(self): + def physical_resource_id(self) -> str: return self.arn - def get_cfn_properties(self, prop_overrides): + def get_cfn_properties(self, prop_overrides: Dict[str, Any]) -> Dict[str, Any]: property_names = [ "DefinitionString", "RoleArn", @@ -124,7 +138,7 @@ class StateMachine(CloudFormationModel): return properties @classmethod - def has_cfn_attr(cls, attr): + def has_cfn_attr(cls, attr: str) -> bool: return attr in [ "Name", "DefinitionString", @@ -133,7 +147,7 @@ class StateMachine(CloudFormationModel): "Tags", ] - def get_cfn_attribute(self, attribute_name): + def get_cfn_attribute(self, attribute_name: str) -> Any: from moto.cloudformation.exceptions import UnformattedGetAttTemplateException if attribute_name == "Name": @@ -150,17 +164,22 @@ class StateMachine(CloudFormationModel): raise UnformattedGetAttTemplateException() @staticmethod - def cloudformation_name_type(): + def cloudformation_name_type() -> str: return "StateMachine" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: return "AWS::StepFunctions::StateMachine" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "StateMachine": properties = cloudformation_json["Properties"] name = properties.get("StateMachineName", resource_name) definition = properties.get("DefinitionString", "") @@ -170,19 +189,25 @@ class StateMachine(CloudFormationModel): return sf_backend.create_state_machine(name, definition, role_arn, tags=tags) @classmethod - def delete_from_cloudformation_json(cls, resource_name, _, account_id, region_name): + def delete_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + ) -> None: sf_backend = stepfunction_backends[account_id][region_name] sf_backend.delete_state_machine(resource_name) @classmethod - def update_from_cloudformation_json( + def update_from_cloudformation_json( # type: ignore[misc] cls, - original_resource, - new_resource_name, - cloudformation_json, - account_id, - region_name, - ): + original_resource: Any, + new_resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + ) -> "StateMachine": properties = cloudformation_json.get("Properties", {}) name = properties.get("StateMachineName", original_resource.name) @@ -214,12 +239,12 @@ class StateMachine(CloudFormationModel): class Execution: def __init__( self, - region_name, - account_id, - state_machine_name, - execution_name, - state_machine_arn, - execution_input, + region_name: str, + account_id: str, + state_machine_name: str, + execution_name: str, + state_machine_arn: str, + execution_input: str, ): execution_arn = "arn:aws:states:{}:{}:execution:{}:{}" execution_arn = execution_arn.format( @@ -235,9 +260,9 @@ class Execution: if settings.get_sf_execution_history_type() == "SUCCESS" else "FAILED" ) - self.stop_date = None + self.stop_date: Optional[str] = None - def get_execution_history(self, roleArn): + def get_execution_history(self, roleArn: str) -> List[Dict[str, Any]]: sf_execution_history_type = settings.get_sf_execution_history_type() if sf_execution_history_type == "SUCCESS": return [ @@ -334,8 +359,9 @@ class Execution: }, }, ] + return [] - def stop(self): + def stop(self) -> None: self.status = "ABORTED" self.stop_date = iso_8601_datetime_with_milliseconds(datetime.now()) @@ -451,13 +477,19 @@ class StepFunctionBackend(BaseBackend): "arn:aws:states:[-0-9a-zA-Z]+:(?P[0-9]{12}):execution:.+" ) - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.state_machines = [] - self.executions = [] + self.state_machines: List[StateMachine] = [] + self.executions: List[Execution] = [] self._account_id = None - def create_state_machine(self, name, definition, roleArn, tags=None): + def create_state_machine( + self, + name: str, + definition: str, + roleArn: str, + tags: Optional[List[Dict[str, str]]] = None, + ) -> StateMachine: self._validate_name(name) self._validate_role_arn(roleArn) arn = f"arn:aws:states:{self.region_name}:{self.account_id}:stateMachine:{name}" @@ -469,11 +501,10 @@ class StepFunctionBackend(BaseBackend): return state_machine @paginate(pagination_model=PAGINATION_MODEL) - def list_state_machines(self): - state_machines = sorted(self.state_machines, key=lambda x: x.creation_date) - return state_machines + def list_state_machines(self) -> Iterable[StateMachine]: # type: ignore[misc] + return sorted(self.state_machines, key=lambda x: x.creation_date) - def describe_state_machine(self, arn): + def describe_state_machine(self, arn: str) -> StateMachine: self._validate_machine_arn(arn) sm = next((x for x in self.state_machines if x.arn == arn), None) if not sm: @@ -482,13 +513,15 @@ class StepFunctionBackend(BaseBackend): ) return sm - def delete_state_machine(self, arn): + def delete_state_machine(self, arn: str) -> None: self._validate_machine_arn(arn) sm = next((x for x in self.state_machines if x.arn == arn), None) if sm: self.state_machines.remove(sm) - def update_state_machine(self, arn, definition=None, role_arn=None): + def update_state_machine( + self, arn: str, definition: Optional[str] = None, role_arn: Optional[str] = None + ) -> StateMachine: sm = self.describe_state_machine(arn) updates = { "definition": definition, @@ -497,23 +530,24 @@ class StepFunctionBackend(BaseBackend): sm.update(**updates) return sm - def start_execution(self, state_machine_arn, name=None, execution_input=None): + def start_execution( + self, state_machine_arn: str, name: str, execution_input: str + ) -> Execution: state_machine = self.describe_state_machine(state_machine_arn) - execution = state_machine.start_execution( + return state_machine.start_execution( region_name=self.region_name, account_id=self.account_id, execution_name=name or str(mock_random.uuid4()), execution_input=execution_input, ) - return execution - def stop_execution(self, execution_arn): + def stop_execution(self, execution_arn: str) -> Execution: self._validate_execution_arn(execution_arn) state_machine = self._get_state_machine_for_execution(execution_arn) return state_machine.stop_execution(execution_arn) @paginate(pagination_model=PAGINATION_MODEL) - def list_executions(self, state_machine_arn, status_filter=None): + def list_executions(self, state_machine_arn: str, status_filter: Optional[str] = None) -> Iterable[Execution]: # type: ignore[misc] """ The status of every execution is set to 'RUNNING' by default. Set the following environment variable if you want to get a FAILED status back: @@ -530,7 +564,7 @@ class StepFunctionBackend(BaseBackend): executions = sorted(executions, key=lambda x: x.start_date, reverse=True) return executions - def describe_execution(self, execution_arn): + def describe_execution(self, execution_arn: str) -> Execution: """ The status of every execution is set to 'RUNNING' by default. Set the following environment variable if you want to get a FAILED status back: @@ -551,7 +585,7 @@ class StepFunctionBackend(BaseBackend): ) return exctn - def get_execution_history(self, execution_arn): + def get_execution_history(self, execution_arn: str) -> List[Dict[str, Any]]: """ A static list of successful events is returned by default. Set the following environment variable if you want to get a static list of events for a failed execution: @@ -572,61 +606,61 @@ class StepFunctionBackend(BaseBackend): ) return execution.get_execution_history(state_machine.roleArn) - def list_tags_for_resource(self, arn): + def list_tags_for_resource(self, arn: str) -> List[Dict[str, str]]: try: state_machine = self.describe_state_machine(arn) return state_machine.tags or [] except StateMachineDoesNotExist: return [] - def tag_resource(self, resource_arn, tags): + def tag_resource(self, resource_arn: str, tags: List[Dict[str, str]]) -> None: try: state_machine = self.describe_state_machine(resource_arn) state_machine.add_tags(tags) except StateMachineDoesNotExist: raise ResourceNotFound(resource_arn) - def untag_resource(self, resource_arn, tag_keys): + def untag_resource(self, resource_arn: str, tag_keys: List[str]) -> None: try: state_machine = self.describe_state_machine(resource_arn) state_machine.remove_tags(tag_keys) except StateMachineDoesNotExist: raise ResourceNotFound(resource_arn) - def _validate_name(self, name): + def _validate_name(self, name: str) -> None: if any(invalid_char in name for invalid_char in self.invalid_chars_for_name): raise InvalidName("Invalid Name: '" + name + "'") if any(name.find(char) >= 0 for char in self.invalid_unicodes_for_name): raise InvalidName("Invalid Name: '" + name + "'") - def _validate_role_arn(self, role_arn): + def _validate_role_arn(self, role_arn: str) -> None: self._validate_arn( arn=role_arn, regex=self.accepted_role_arn_format, invalid_msg="Invalid Role Arn: '" + role_arn + "'", ) - def _validate_machine_arn(self, machine_arn): + def _validate_machine_arn(self, machine_arn: str) -> None: self._validate_arn( arn=machine_arn, regex=self.accepted_mchn_arn_format, invalid_msg="Invalid State Machine Arn: '" + machine_arn + "'", ) - def _validate_execution_arn(self, execution_arn): + def _validate_execution_arn(self, execution_arn: str) -> None: self._validate_arn( arn=execution_arn, regex=self.accepted_exec_arn_format, invalid_msg="Execution Does Not Exist: '" + execution_arn + "'", ) - def _validate_arn(self, arn, regex, invalid_msg): + def _validate_arn(self, arn: str, regex: Pattern[str], invalid_msg: str) -> None: match = regex.match(arn) if not arn or not match: raise InvalidArn(invalid_msg) - def _get_state_machine_for_execution(self, execution_arn): + def _get_state_machine_for_execution(self, execution_arn: str) -> StateMachine: state_machine_name = execution_arn.split(":")[6] state_machine_arn = next( (x.arn for x in self.state_machines if x.name == state_machine_name), None diff --git a/moto/stepfunctions/responses.py b/moto/stepfunctions/responses.py index d13b14ee5..cf68b83bf 100644 --- a/moto/stepfunctions/responses.py +++ b/moto/stepfunctions/responses.py @@ -1,20 +1,21 @@ import json +from moto.core.common_types import TYPE_RESPONSE from moto.core.responses import BaseResponse from moto.utilities.aws_headers import amzn_request_id -from .models import stepfunction_backends +from .models import stepfunction_backends, StepFunctionBackend class StepFunctionResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="stepfunctions") @property - def stepfunction_backend(self): + def stepfunction_backend(self) -> StepFunctionBackend: return stepfunction_backends[self.current_account][self.region] @amzn_request_id - def create_state_machine(self): + def create_state_machine(self) -> TYPE_RESPONSE: name = self._get_param("name") definition = self._get_param("definition") roleArn = self._get_param("roleArn") @@ -29,7 +30,7 @@ class StepFunctionResponse(BaseResponse): return 200, {}, json.dumps(response) @amzn_request_id - def list_state_machines(self): + def list_state_machines(self) -> TYPE_RESPONSE: max_results = self._get_int_param("maxResults") next_token = self._get_param("nextToken") results, next_token = self.stepfunction_backend.list_state_machines( @@ -49,12 +50,12 @@ class StepFunctionResponse(BaseResponse): return 200, {}, json.dumps(response) @amzn_request_id - def describe_state_machine(self): + def describe_state_machine(self) -> TYPE_RESPONSE: arn = self._get_param("stateMachineArn") return self._describe_state_machine(arn) @amzn_request_id - def _describe_state_machine(self, state_machine_arn): + def _describe_state_machine(self, state_machine_arn: str) -> TYPE_RESPONSE: state_machine = self.stepfunction_backend.describe_state_machine( state_machine_arn ) @@ -69,13 +70,13 @@ class StepFunctionResponse(BaseResponse): return 200, {}, json.dumps(response) @amzn_request_id - def delete_state_machine(self): + def delete_state_machine(self) -> TYPE_RESPONSE: arn = self._get_param("stateMachineArn") self.stepfunction_backend.delete_state_machine(arn) return 200, {}, json.dumps("{}") @amzn_request_id - def update_state_machine(self): + def update_state_machine(self) -> TYPE_RESPONSE: arn = self._get_param("stateMachineArn") definition = self._get_param("definition") role_arn = self._get_param("roleArn") @@ -88,28 +89,28 @@ class StepFunctionResponse(BaseResponse): return 200, {}, json.dumps(response) @amzn_request_id - def list_tags_for_resource(self): + def list_tags_for_resource(self) -> TYPE_RESPONSE: arn = self._get_param("resourceArn") tags = self.stepfunction_backend.list_tags_for_resource(arn) response = {"tags": tags} return 200, {}, json.dumps(response) @amzn_request_id - def tag_resource(self): + def tag_resource(self) -> TYPE_RESPONSE: arn = self._get_param("resourceArn") tags = self._get_param("tags", []) self.stepfunction_backend.tag_resource(arn, tags) return 200, {}, json.dumps({}) @amzn_request_id - def untag_resource(self): + def untag_resource(self) -> TYPE_RESPONSE: arn = self._get_param("resourceArn") tag_keys = self._get_param("tagKeys", []) self.stepfunction_backend.untag_resource(arn, tag_keys) return 200, {}, json.dumps({}) @amzn_request_id - def start_execution(self): + def start_execution(self) -> TYPE_RESPONSE: arn = self._get_param("stateMachineArn") name = self._get_param("name") execution_input = self._get_param("input", if_none="{}") @@ -123,7 +124,7 @@ class StepFunctionResponse(BaseResponse): return 200, {}, json.dumps(response) @amzn_request_id - def list_executions(self): + def list_executions(self) -> TYPE_RESPONSE: max_results = self._get_int_param("maxResults") next_token = self._get_param("nextToken") arn = self._get_param("stateMachineArn") @@ -151,7 +152,7 @@ class StepFunctionResponse(BaseResponse): return 200, {}, json.dumps(response) @amzn_request_id - def describe_execution(self): + def describe_execution(self) -> TYPE_RESPONSE: arn = self._get_param("executionArn") execution = self.stepfunction_backend.describe_execution(arn) response = { @@ -166,20 +167,20 @@ class StepFunctionResponse(BaseResponse): return 200, {}, json.dumps(response) @amzn_request_id - def describe_state_machine_for_execution(self): + def describe_state_machine_for_execution(self) -> TYPE_RESPONSE: arn = self._get_param("executionArn") execution = self.stepfunction_backend.describe_execution(arn) return self._describe_state_machine(execution.state_machine_arn) @amzn_request_id - def stop_execution(self): + def stop_execution(self) -> TYPE_RESPONSE: arn = self._get_param("executionArn") execution = self.stepfunction_backend.stop_execution(arn) response = {"stopDate": execution.stop_date} return 200, {}, json.dumps(response) @amzn_request_id - def get_execution_history(self): + def get_execution_history(self) -> TYPE_RESPONSE: execution_arn = self._get_param("executionArn") execution_history = self.stepfunction_backend.get_execution_history( execution_arn diff --git a/moto/stepfunctions/utils.py b/moto/stepfunctions/utils.py index 20881771f..db9aa1f36 100644 --- a/moto/stepfunctions/utils.py +++ b/moto/stepfunctions/utils.py @@ -1,3 +1,6 @@ +from typing import Dict, List + + PAGINATION_MODEL = { "list_executions": { "input_token": "next_token", @@ -14,11 +17,9 @@ PAGINATION_MODEL = { } -def cfn_to_api_tags(cfn_tags_entry): - api_tags = [{k.lower(): v for k, v in d.items()} for d in cfn_tags_entry] - return api_tags +def cfn_to_api_tags(cfn_tags_entry: List[Dict[str, str]]) -> List[Dict[str, str]]: + return [{k.lower(): v for k, v in d.items()} for d in cfn_tags_entry] -def api_to_cfn_tags(api_tags): - cfn_tags_entry = [{k.capitalize(): v for k, v in d.items()} for d in api_tags] - return cfn_tags_entry +def api_to_cfn_tags(api_tags: List[Dict[str, str]]) -> List[Dict[str, str]]: + return [{k.capitalize(): v for k, v in d.items()} for d in api_tags] diff --git a/moto/sts/exceptions.py b/moto/sts/exceptions.py index 021945f9f..6f136af2d 100644 --- a/moto/sts/exceptions.py +++ b/moto/sts/exceptions.py @@ -1,3 +1,4 @@ +from typing import Any from moto.core.exceptions import RESTError @@ -6,5 +7,5 @@ class STSClientError(RESTError): class STSValidationError(STSClientError): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): super().__init__("ValidationError", *args, **kwargs) diff --git a/moto/sts/models.py b/moto/sts/models.py index bf0641b55..c60525415 100644 --- a/moto/sts/models.py +++ b/moto/sts/models.py @@ -1,10 +1,12 @@ from base64 import b64decode +from typing import Any, Dict, List, Optional, Tuple import datetime import re import xmltodict + from moto.core import BaseBackend, BaseModel, BackendDict from moto.core.utils import iso_8601_datetime_with_milliseconds -from moto.iam import iam_backends +from moto.iam.models import iam_backends, AccessKey from moto.sts.utils import ( random_session_token, DEFAULT_STS_SESSION_DURATION, @@ -13,27 +15,27 @@ from moto.sts.utils import ( class Token(BaseModel): - def __init__(self, duration, name=None): + def __init__(self, duration: int, name: Optional[str] = None): now = datetime.datetime.utcnow() self.expiration = now + datetime.timedelta(seconds=duration) self.name = name self.policy = None @property - def expiration_ISO8601(self): + def expiration_ISO8601(self) -> str: return iso_8601_datetime_with_milliseconds(self.expiration) class AssumedRole(BaseModel): def __init__( self, - account_id, - access_key, - role_session_name, - role_arn, - policy, - duration, - external_id, + account_id: str, + access_key: AccessKey, + role_session_name: str, + role_arn: str, + policy: str, + duration: int, + external_id: str, ): self.account_id = account_id self.session_name = role_session_name @@ -48,11 +50,11 @@ class AssumedRole(BaseModel): self.session_token = random_session_token() @property - def expiration_ISO8601(self): + def expiration_ISO8601(self) -> str: return iso_8601_datetime_with_milliseconds(self.expiration) @property - def user_id(self): + def user_id(self) -> str: iam_backend = iam_backends[self.account_id]["global"] try: role_id = iam_backend.get_role_by_arn(arn=self.role_arn).id @@ -61,31 +63,38 @@ class AssumedRole(BaseModel): return role_id + ":" + self.session_name @property - def arn(self): + def arn(self) -> str: return f"arn:aws:sts::{self.account_id}:assumed-role/{self.role_arn.split('/')[-1]}/{self.session_name}" class STSBackend(BaseBackend): - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.assumed_roles = [] + self.assumed_roles: List[AssumedRole] = [] @staticmethod - def default_vpc_endpoint_service(service_region, zones): + def default_vpc_endpoint_service( + service_region: str, zones: List[str] + ) -> List[Dict[str, str]]: """Default VPC endpoint service.""" return BaseBackend.default_vpc_endpoint_service_factory( service_region, zones, "sts" ) - def get_session_token(self, duration): - token = Token(duration=duration) - return token + def get_session_token(self, duration: int) -> Token: + return Token(duration=duration) - def get_federation_token(self, name, duration): - token = Token(duration=duration, name=name) - return token + def get_federation_token(self, name: Optional[str], duration: int) -> Token: + return Token(duration=duration, name=name) - def assume_role(self, role_session_name, role_arn, policy, duration, external_id): + def assume_role( + self, + role_session_name: str, + role_arn: str, + policy: str, + duration: int, + external_id: str, + ) -> AssumedRole: """ Assume an IAM Role. Note that the role does not need to exist. The ARN can point to another account, providing an opportunity to switch accounts. """ @@ -103,16 +112,18 @@ class STSBackend(BaseBackend): account_backend.assumed_roles.append(role) return role - def get_assumed_role_from_access_key(self, access_key_id): + def get_assumed_role_from_access_key( + self, access_key_id: str + ) -> Optional[AssumedRole]: for assumed_role in self.assumed_roles: if assumed_role.access_key_id == access_key_id: return assumed_role return None - def assume_role_with_web_identity(self, **kwargs): + def assume_role_with_web_identity(self, **kwargs: Any) -> AssumedRole: return self.assume_role(**kwargs) - def assume_role_with_saml(self, **kwargs): + def assume_role_with_saml(self, **kwargs: Any) -> AssumedRole: del kwargs["principal_arn"] saml_assertion_encoded = kwargs.pop("saml_assertion") saml_assertion_decoded = b64decode(saml_assertion_encoded) @@ -150,7 +161,7 @@ class STSBackend(BaseBackend): if "duration" not in kwargs: kwargs["duration"] = DEFAULT_STS_SESSION_DURATION - account_id, access_key = self._create_access_key(role=target_role) + account_id, access_key = self._create_access_key(role=target_role) # type: ignore kwargs["account_id"] = account_id kwargs["access_key"] = access_key @@ -160,7 +171,7 @@ class STSBackend(BaseBackend): self.assumed_roles.append(role) return role - def get_caller_identity(self, access_key_id): + def get_caller_identity(self, access_key_id: str) -> Tuple[str, str, str]: assumed_role = self.get_assumed_role_from_access_key(access_key_id) if assumed_role: return assumed_role.user_id, assumed_role.arn, assumed_role.account_id @@ -175,7 +186,7 @@ class STSBackend(BaseBackend): arn = f"arn:aws:sts::{self.account_id}:user/moto" return user_id, arn, self.account_id - def _create_access_key(self, role): + def _create_access_key(self, role: str) -> Tuple[str, AccessKey]: account_id_match = re.search(r"arn:aws:iam::([0-9]+).+", role) if account_id_match: account_id = account_id_match.group(1) diff --git a/moto/sts/responses.py b/moto/sts/responses.py index d6099566d..c54481652 100644 --- a/moto/sts/responses.py +++ b/moto/sts/responses.py @@ -1,25 +1,25 @@ from moto.core.responses import BaseResponse from .exceptions import STSValidationError -from .models import sts_backends +from .models import sts_backends, STSBackend MAX_FEDERATION_TOKEN_POLICY_LENGTH = 2048 class TokenResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="sts") @property - def backend(self): + def backend(self) -> STSBackend: return sts_backends[self.current_account]["global"] - def get_session_token(self): + def get_session_token(self) -> str: duration = int(self.querystring.get("DurationSeconds", [43200])[0]) token = self.backend.get_session_token(duration=duration) template = self.response_template(GET_SESSION_TOKEN_RESPONSE) return template.render(token=token) - def get_federation_token(self): + def get_federation_token(self) -> str: duration = int(self.querystring.get("DurationSeconds", [43200])[0]) policy = self.querystring.get("Policy", [None])[0] @@ -31,14 +31,14 @@ class TokenResponse(BaseResponse): f" equal to {MAX_FEDERATION_TOKEN_POLICY_LENGTH}" ) - name = self.querystring.get("Name")[0] + name = self.querystring.get("Name")[0] # type: ignore token = self.backend.get_federation_token(duration=duration, name=name) template = self.response_template(GET_FEDERATION_TOKEN_RESPONSE) return template.render(token=token, account_id=self.current_account) - def assume_role(self): - role_session_name = self.querystring.get("RoleSessionName")[0] - role_arn = self.querystring.get("RoleArn")[0] + def assume_role(self) -> str: + role_session_name = self.querystring.get("RoleSessionName")[0] # type: ignore + role_arn = self.querystring.get("RoleArn")[0] # type: ignore policy = self.querystring.get("Policy", [None])[0] duration = int(self.querystring.get("DurationSeconds", [3600])[0]) @@ -54,9 +54,9 @@ class TokenResponse(BaseResponse): template = self.response_template(ASSUME_ROLE_RESPONSE) return template.render(role=role) - def assume_role_with_web_identity(self): - role_session_name = self.querystring.get("RoleSessionName")[0] - role_arn = self.querystring.get("RoleArn")[0] + def assume_role_with_web_identity(self) -> str: + role_session_name = self.querystring.get("RoleSessionName")[0] # type: ignore + role_arn = self.querystring.get("RoleArn")[0] # type: ignore policy = self.querystring.get("Policy", [None])[0] duration = int(self.querystring.get("DurationSeconds", [3600])[0]) @@ -72,10 +72,10 @@ class TokenResponse(BaseResponse): template = self.response_template(ASSUME_ROLE_WITH_WEB_IDENTITY_RESPONSE) return template.render(role=role) - def assume_role_with_saml(self): - role_arn = self.querystring.get("RoleArn")[0] - principal_arn = self.querystring.get("PrincipalArn")[0] - saml_assertion = self.querystring.get("SAMLAssertion")[0] + def assume_role_with_saml(self) -> str: + role_arn = self.querystring.get("RoleArn")[0] # type: ignore + principal_arn = self.querystring.get("PrincipalArn")[0] # type: ignore + saml_assertion = self.querystring.get("SAMLAssertion")[0] # type: ignore role = self.backend.assume_role_with_saml( role_arn=role_arn, @@ -85,7 +85,7 @@ class TokenResponse(BaseResponse): template = self.response_template(ASSUME_ROLE_WITH_SAML_RESPONSE) return template.render(role=role) - def get_caller_identity(self): + def get_caller_identity(self) -> str: template = self.response_template(GET_CALLER_IDENTITY_RESPONSE) access_key_id = self.get_access_key() diff --git a/moto/sts/utils.py b/moto/sts/utils.py index afd095796..c2149146f 100644 --- a/moto/sts/utils.py +++ b/moto/sts/utils.py @@ -16,13 +16,13 @@ def random_session_token() -> str: ) -def random_assumed_role_id(): +def random_assumed_role_id() -> str: return ( ACCOUNT_SPECIFIC_ASSUMED_ROLE_ID_PREFIX + _random_uppercase_or_digit_sequence(9) ) -def _random_uppercase_or_digit_sequence(length): +def _random_uppercase_or_digit_sequence(length: int) -> str: return "".join( str(random.choice(string.ascii_uppercase + string.digits)) for _ in range(length) diff --git a/moto/support/models.py b/moto/support/models.py index 96ada0392..852e94a8c 100644 --- a/moto/support/models.py +++ b/moto/support/models.py @@ -4,6 +4,7 @@ from moto.moto_api._internal.managed_state_model import ManagedState from moto.moto_api._internal import mock_random as random from moto.utilities.utils import load_resource import datetime +from typing import Any, Dict, List, Optional checks_json = "resources/describe_trusted_advisor_checks.json" @@ -11,7 +12,7 @@ ADVISOR_CHECKS = load_resource(__name__, checks_json) class SupportCase(ManagedState): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): # Configure ManagedState super().__init__( "support::case", @@ -56,21 +57,21 @@ class SupportCase(ManagedState): } } - def get_datetime(self): + def get_datetime(self) -> str: return str(datetime.datetime.now().isoformat()) class SupportBackend(BaseBackend): - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.check_status = {} - self.cases = {} + self.check_status: Dict[str, str] = {} + self.cases: Dict[str, SupportCase] = {} state_manager.register_default_transition( model_name="support::case", transition={"progression": "manual", "times": 1} ) - def describe_trusted_advisor_checks(self): + def describe_trusted_advisor_checks(self) -> List[Dict[str, Any]]: """ The Language-parameter is not yet implemented """ @@ -78,18 +79,17 @@ class SupportBackend(BaseBackend): checks = ADVISOR_CHECKS["checks"] return checks - def refresh_trusted_advisor_check(self, check_id): + def refresh_trusted_advisor_check(self, check_id: str) -> Dict[str, Any]: self.advance_check_status(check_id) - status = { + return { "status": { "checkId": check_id, "status": self.check_status[check_id], "millisUntilNextRefreshable": 123, } } - return status - def advance_check_status(self, check_id): + def advance_check_status(self, check_id: str) -> None: """ Fake an advancement through statuses on refreshing TA checks """ @@ -111,14 +111,13 @@ class SupportBackend(BaseBackend): elif self.check_status[check_id] == "abandoned": self.check_status[check_id] = "none" - def advance_case_status(self, case_id): + def advance_case_status(self, case_id: str) -> None: """ Fake an advancement through case statuses """ - self.cases[case_id].advance() - def advance_case_severity_codes(self, case_id): + def advance_case_severity_codes(self, case_id: str) -> None: """ Fake an advancement through case status severities """ @@ -137,28 +136,26 @@ class SupportBackend(BaseBackend): elif self.cases[case_id].severity_code == "critical": self.cases[case_id].severity_code = "low" - def resolve_case(self, case_id): + def resolve_case(self, case_id: str) -> Dict[str, Optional[str]]: self.advance_case_status(case_id) - resolved_case = { + return { "initialCaseStatus": self.cases[case_id].status, "finalCaseStatus": "resolved", } - return resolved_case - # persist case details to self.cases def create_case( self, - subject, - service_code, - severity_code, - category_code, - communication_body, - cc_email_addresses, - language, - attachment_set_id, - ): + subject: str, + service_code: str, + severity_code: str, + category_code: str, + communication_body: str, + cc_email_addresses: List[str], + language: str, + attachment_set_id: str, + ) -> Dict[str, str]: """ The IssueType-parameter is not yet implemented """ @@ -184,11 +181,11 @@ class SupportBackend(BaseBackend): def describe_cases( self, - case_id_list, - include_resolved_cases, - next_token, - include_communications, - ): + case_id_list: List[str], + include_resolved_cases: bool, + next_token: Optional[str], + include_communications: bool, + ) -> Dict[str, Any]: """ The following parameters have not yet been implemented: DisplayID, AfterTime, BeforeTime, MaxResults, Language @@ -223,10 +220,7 @@ class SupportBackend(BaseBackend): continue cases.append(formatted_case) - case_values = {"cases": cases} - case_values.update({"nextToken": next_token}) - - return case_values + return {"cases": cases, "nextToken": next_token} support_backends = BackendDict( diff --git a/moto/support/responses.py b/moto/support/responses.py index a400dc36c..fc5507cfd 100644 --- a/moto/support/responses.py +++ b/moto/support/responses.py @@ -1,33 +1,33 @@ from moto.core.responses import BaseResponse -from .models import support_backends +from .models import support_backends, SupportBackend import json class SupportResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="support") @property - def support_backend(self): + def support_backend(self) -> SupportBackend: return support_backends[self.current_account][self.region] - def describe_trusted_advisor_checks(self): + def describe_trusted_advisor_checks(self) -> str: checks = self.support_backend.describe_trusted_advisor_checks() return json.dumps({"checks": checks}) - def refresh_trusted_advisor_check(self): + def refresh_trusted_advisor_check(self) -> str: check_id = self._get_param("checkId") status = self.support_backend.refresh_trusted_advisor_check(check_id=check_id) return json.dumps(status) - def resolve_case(self): + def resolve_case(self) -> str: case_id = self._get_param("caseId") resolve_case_response = self.support_backend.resolve_case(case_id=case_id) return json.dumps(resolve_case_response) - def create_case(self): + def create_case(self) -> str: subject = self._get_param("subject") service_code = self._get_param("serviceCode") severity_code = self._get_param("severityCode") @@ -49,7 +49,7 @@ class SupportResponse(BaseResponse): return json.dumps(create_case_response) - def describe_cases(self): + def describe_cases(self) -> str: case_id_list = self._get_param("caseIdList") include_resolved_cases = self._get_param("includeResolvedCases", False) next_token = self._get_param("nextToken") diff --git a/setup.cfg b/setup.cfg index 7d7b3bef9..32597d466 100644 --- a/setup.cfg +++ b/setup.cfg @@ -239,7 +239,7 @@ disable = W,C,R,E enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import [mypy] -files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s3*,moto/sagemaker,moto/secretsmanager,moto/ses,moto/sqs,moto/ssm,moto/scheduler,moto/swf,moto/sns +files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s* show_column_numbers=True show_error_codes = True disable_error_code=abstract