Techdebt: MyPy S (#6261)

This commit is contained in:
Bert Blommers 2023-04-26 22:20:28 +00:00 committed by GitHub
parent 37f1456747
commit f38babb026
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 480 additions and 406 deletions

View File

@ -1,4 +1,5 @@
"""Exceptions raised by the sdb service.""" """Exceptions raised by the sdb service."""
from typing import Any
from moto.core.exceptions import RESTError from moto.core.exceptions import RESTError
@ -18,7 +19,7 @@ SDB_ERROR = """<?xml version="1.0"?>
class InvalidParameterError(RESTError): class InvalidParameterError(RESTError):
code = 400 code = 400
def __init__(self, **kwargs): def __init__(self, **kwargs: Any):
kwargs.setdefault("template", "sdb_error") kwargs.setdefault("template", "sdb_error")
self.templates["sdb_error"] = SDB_ERROR self.templates["sdb_error"] = SDB_ERROR
kwargs["error_type"] = "InvalidParameterValue" kwargs["error_type"] = "InvalidParameterValue"
@ -28,7 +29,7 @@ class InvalidParameterError(RESTError):
class InvalidDomainName(InvalidParameterError): class InvalidDomainName(InvalidParameterError):
code = 400 code = 400
def __init__(self, domain_name): def __init__(self, domain_name: str):
super().__init__( super().__init__(
message=f"Value ({domain_name}) for parameter DomainName is invalid. " message=f"Value ({domain_name}) for parameter DomainName is invalid. "
) )
@ -37,7 +38,7 @@ class InvalidDomainName(InvalidParameterError):
class UnknownDomainName(RESTError): class UnknownDomainName(RESTError):
code = 400 code = 400
def __init__(self, **kwargs): def __init__(self, **kwargs: Any):
kwargs.setdefault("template", "sdb_error") kwargs.setdefault("template", "sdb_error")
self.templates["sdb_error"] = SDB_ERROR self.templates["sdb_error"] = SDB_ERROR
kwargs["error_type"] = "NoSuchDomain" kwargs["error_type"] = "NoSuchDomain"

View File

@ -1,23 +1,24 @@
"""SimpleDBBackend class with methods for supported APIs.""" """SimpleDBBackend class with methods for supported APIs."""
import re import re
from collections import defaultdict from collections import defaultdict
from moto.core import BaseBackend, BackendDict, BaseModel
from threading import Lock from threading import Lock
from typing import Any, Dict, List, Iterable, Optional
from moto.core import BaseBackend, BackendDict, BaseModel
from .exceptions import InvalidDomainName, UnknownDomainName from .exceptions import InvalidDomainName, UnknownDomainName
class FakeItem(BaseModel): class FakeItem(BaseModel):
def __init__(self): def __init__(self) -> None:
self.attributes = [] self.attributes: List[Dict[str, Any]] = []
self.lock = Lock() self.lock = Lock()
def get_attributes(self, names): def get_attributes(self, names: Optional[List[str]]) -> List[Dict[str, Any]]:
if not names: if not names:
return self.attributes return self.attributes
return [attr for attr in self.attributes if attr["name"] in names] return [attr for attr in self.attributes if attr["name"] in names]
def put_attributes(self, attributes): def put_attributes(self, attributes: List[Dict[str, Any]]) -> None:
# Replacing attributes involves quite a few loops # Replacing attributes involves quite a few loops
# Lock this, so we know noone else touches this list while we're operating on it # Lock this, so we know noone else touches this list while we're operating on it
with self.lock: with self.lock:
@ -26,56 +27,58 @@ class FakeItem(BaseModel):
self._remove_attributes(attr["name"]) self._remove_attributes(attr["name"])
self.attributes.append(attr) self.attributes.append(attr)
def _remove_attributes(self, name): def _remove_attributes(self, name: str) -> None:
self.attributes = [attr for attr in self.attributes if attr["name"] != name] self.attributes = [attr for attr in self.attributes if attr["name"] != name]
class FakeDomain(BaseModel): class FakeDomain(BaseModel):
def __init__(self, name): def __init__(self, name: str):
self.name = name self.name = name
self.items = defaultdict(FakeItem) self.items: Dict[str, FakeItem] = defaultdict(FakeItem)
def get(self, item_name, attribute_names): def get(self, item_name: str, attribute_names: List[str]) -> List[Dict[str, Any]]:
item = self.items[item_name] item = self.items[item_name]
return item.get_attributes(attribute_names) return item.get_attributes(attribute_names)
def put(self, item_name, attributes): def put(self, item_name: str, attributes: List[Dict[str, Any]]) -> None:
item = self.items[item_name] item = self.items[item_name]
item.put_attributes(attributes) item.put_attributes(attributes)
class SimpleDBBackend(BaseBackend): class SimpleDBBackend(BaseBackend):
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.domains = dict() self.domains: Dict[str, FakeDomain] = dict()
def create_domain(self, domain_name): def create_domain(self, domain_name: str) -> None:
self._validate_domain_name(domain_name) self._validate_domain_name(domain_name)
self.domains[domain_name] = FakeDomain(name=domain_name) self.domains[domain_name] = FakeDomain(name=domain_name)
def list_domains(self): def list_domains(self) -> Iterable[str]:
""" """
The `max_number_of_domains` and `next_token` parameter have not been implemented yet - we simply return all domains. The `max_number_of_domains` and `next_token` parameter have not been implemented yet - we simply return all domains.
""" """
return self.domains.keys() return self.domains.keys()
def delete_domain(self, domain_name): def delete_domain(self, domain_name: str) -> None:
self._validate_domain_name(domain_name) self._validate_domain_name(domain_name)
# Ignore unknown domains - AWS does the same # Ignore unknown domains - AWS does the same
self.domains.pop(domain_name, None) self.domains.pop(domain_name, None)
def _validate_domain_name(self, domain_name): def _validate_domain_name(self, domain_name: str) -> None:
# Domain Name needs to have at least 3 chars # Domain Name needs to have at least 3 chars
# Can only contain characters: a-z, A-Z, 0-9, '_', '-', and '.' # Can only contain characters: a-z, A-Z, 0-9, '_', '-', and '.'
if not re.match("^[a-zA-Z0-9-_.]{3,}$", domain_name): if not re.match("^[a-zA-Z0-9-_.]{3,}$", domain_name):
raise InvalidDomainName(domain_name) raise InvalidDomainName(domain_name)
def _get_domain(self, domain_name): def _get_domain(self, domain_name: str) -> FakeDomain:
if domain_name not in self.domains: if domain_name not in self.domains:
raise UnknownDomainName() raise UnknownDomainName()
return self.domains[domain_name] return self.domains[domain_name]
def get_attributes(self, domain_name, item_name, attribute_names): def get_attributes(
self, domain_name: str, item_name: str, attribute_names: List[str]
) -> List[Dict[str, Any]]:
""" """
Behaviour for the consistent_read-attribute is not yet implemented Behaviour for the consistent_read-attribute is not yet implemented
""" """
@ -83,7 +86,9 @@ class SimpleDBBackend(BaseBackend):
domain = self._get_domain(domain_name) domain = self._get_domain(domain_name)
return domain.get(item_name, attribute_names) return domain.get(item_name, attribute_names)
def put_attributes(self, domain_name, item_name, attributes): def put_attributes(
self, domain_name: str, item_name: str, attributes: List[Dict[str, Any]]
) -> None:
""" """
Behaviour for the expected-attribute is not yet implemented. Behaviour for the expected-attribute is not yet implemented.
""" """

View File

@ -1,33 +1,33 @@
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import sdb_backends from .models import sdb_backends, SimpleDBBackend
class SimpleDBResponse(BaseResponse): class SimpleDBResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="sdb") super().__init__(service_name="sdb")
@property @property
def sdb_backend(self): def sdb_backend(self) -> SimpleDBBackend:
return sdb_backends[self.current_account][self.region] return sdb_backends[self.current_account][self.region]
def create_domain(self): def create_domain(self) -> str:
domain_name = self._get_param("DomainName") domain_name = self._get_param("DomainName")
self.sdb_backend.create_domain(domain_name=domain_name) self.sdb_backend.create_domain(domain_name=domain_name)
template = self.response_template(CREATE_DOMAIN_TEMPLATE) template = self.response_template(CREATE_DOMAIN_TEMPLATE)
return template.render() return template.render()
def delete_domain(self): def delete_domain(self) -> str:
domain_name = self._get_param("DomainName") domain_name = self._get_param("DomainName")
self.sdb_backend.delete_domain(domain_name=domain_name) self.sdb_backend.delete_domain(domain_name=domain_name)
template = self.response_template(DELETE_DOMAIN_TEMPLATE) template = self.response_template(DELETE_DOMAIN_TEMPLATE)
return template.render() return template.render()
def list_domains(self): def list_domains(self) -> str:
domain_names = self.sdb_backend.list_domains() domain_names = self.sdb_backend.list_domains()
template = self.response_template(LIST_DOMAINS_TEMPLATE) template = self.response_template(LIST_DOMAINS_TEMPLATE)
return template.render(domain_names=domain_names, next_token=None) return template.render(domain_names=domain_names, next_token=None)
def get_attributes(self): def get_attributes(self) -> str:
domain_name = self._get_param("DomainName") domain_name = self._get_param("DomainName")
item_name = self._get_param("ItemName") item_name = self._get_param("ItemName")
attribute_names = self._get_multi_param("AttributeName.") attribute_names = self._get_multi_param("AttributeName.")
@ -39,7 +39,7 @@ class SimpleDBResponse(BaseResponse):
template = self.response_template(GET_ATTRIBUTES_TEMPLATE) template = self.response_template(GET_ATTRIBUTES_TEMPLATE)
return template.render(attributes=attributes) return template.render(attributes=attributes)
def put_attributes(self): def put_attributes(self) -> str:
domain_name = self._get_param("DomainName") domain_name = self._get_param("DomainName")
item_name = self._get_param("ItemName") item_name = self._get_param("ItemName")
attributes = self._get_list_prefix("Attribute") attributes = self._get_list_prefix("Attribute")

View File

@ -3,6 +3,7 @@ import os
import signal import signal
import sys import sys
import warnings import warnings
from typing import Any, List, Optional
from werkzeug.serving import run_simple from werkzeug.serving import run_simple
@ -15,11 +16,11 @@ from moto.moto_server.threaded_moto_server import ( # noqa # pylint: disable=un
) )
def signal_handler(signum, frame): # pylint: disable=unused-argument def signal_handler(signum: Any, frame: Any) -> None: # pylint: disable=unused-argument
sys.exit(0) sys.exit(0)
def main(argv=None): def main(argv: Optional[List[str]] = None) -> None:
argv = argv or sys.argv[1:] argv = argv or sys.argv[1:]
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -79,9 +80,9 @@ def main(argv=None):
# Wrap the main application # Wrap the main application
main_app = DomainDispatcherApplication(create_backend_app, service=args.service) main_app = DomainDispatcherApplication(create_backend_app, service=args.service)
main_app.debug = True main_app.debug = True # type: ignore
ssl_context = None ssl_context: Any = None
if args.ssl_key and args.ssl_cert: if args.ssl_key and args.ssl_cert:
ssl_context = (args.ssl_cert, args.ssl_key) ssl_context = (args.ssl_cert, args.ssl_key)
elif args.ssl: elif args.ssl:

View File

@ -3,20 +3,20 @@ from moto.core.exceptions import JsonRESTError
class OperationNotFound(JsonRESTError): class OperationNotFound(JsonRESTError):
def __init__(self): def __init__(self) -> None:
super().__init__("OperationNotFound", "") super().__init__("OperationNotFound", "")
class NamespaceNotFound(JsonRESTError): class NamespaceNotFound(JsonRESTError):
def __init__(self, ns_id): def __init__(self, ns_id: str):
super().__init__("NamespaceNotFound", f"{ns_id}") super().__init__("NamespaceNotFound", f"{ns_id}")
class ServiceNotFound(JsonRESTError): class ServiceNotFound(JsonRESTError):
def __init__(self, ns_id): def __init__(self, ns_id: str):
super().__init__("ServiceNotFound", f"{ns_id}") super().__init__("ServiceNotFound", f"{ns_id}")
class ConflictingDomainExists(JsonRESTError): class ConflictingDomainExists(JsonRESTError):
def __init__(self, vpc_id): def __init__(self, vpc_id: str):
super().__init__("ConflictingDomainExists", f"{vpc_id}") super().__init__("ConflictingDomainExists", f"{vpc_id}")

View File

@ -1,4 +1,5 @@
import string import string
from typing import Any, Dict, Iterable, List, Optional
from moto.core import BaseBackend, BackendDict, BaseModel from moto.core import BaseBackend, BackendDict, BaseModel
from moto.core.utils import unix_time from moto.core.utils import unix_time
@ -13,7 +14,7 @@ from .exceptions import (
) )
def random_id(size): def random_id(size: int) -> str:
return "".join( return "".join(
[random.choice(string.ascii_lowercase + string.digits) for _ in range(size)] [random.choice(string.ascii_lowercase + string.digits) for _ in range(size)]
) )
@ -22,17 +23,16 @@ def random_id(size):
class Namespace(BaseModel): class Namespace(BaseModel):
def __init__( def __init__(
self, self,
account_id, account_id: str,
region, region: str,
name, name: str,
ns_type, ns_type: str,
creator_request_id, creator_request_id: str,
description, description: str,
dns_properties, dns_properties: Dict[str, Any],
http_properties, http_properties: Dict[str, Any],
vpc=None, vpc: Optional[str] = None,
): ):
super().__init__()
self.id = f"ns-{random_id(20)}" self.id = f"ns-{random_id(20)}"
self.arn = f"arn:aws:servicediscovery:{region}:{account_id}:namespace/{self.id}" self.arn = f"arn:aws:servicediscovery:{region}:{account_id}:namespace/{self.id}"
self.name = name self.name = name
@ -45,7 +45,7 @@ class Namespace(BaseModel):
self.created = unix_time() self.created = unix_time()
self.updated = unix_time() self.updated = unix_time()
def to_json(self): def to_json(self) -> Dict[str, Any]:
return { return {
"Arn": self.arn, "Arn": self.arn,
"Id": self.id, "Id": self.id,
@ -65,31 +65,30 @@ class Namespace(BaseModel):
class Service(BaseModel): class Service(BaseModel):
def __init__( def __init__(
self, self,
account_id, account_id: str,
region, region: str,
name, name: str,
namespace_id, namespace_id: str,
description, description: str,
creator_request_id, creator_request_id: str,
dns_config, dns_config: Dict[str, Any],
health_check_config, health_check_config: Dict[str, Any],
health_check_custom_config, health_check_custom_config: Dict[str, int],
service_type, service_type: str,
): ):
super().__init__()
self.id = f"srv-{random_id(8)}" self.id = f"srv-{random_id(8)}"
self.arn = f"arn:aws:servicediscovery:{region}:{account_id}:service/{self.id}" self.arn = f"arn:aws:servicediscovery:{region}:{account_id}:service/{self.id}"
self.name = name self.name = name
self.namespace_id = namespace_id self.namespace_id = namespace_id
self.description = description self.description = description
self.creator_request_id = creator_request_id self.creator_request_id = creator_request_id
self.dns_config = dns_config self.dns_config: Optional[Dict[str, Any]] = dns_config
self.health_check_config = health_check_config self.health_check_config = health_check_config
self.health_check_custom_config = health_check_custom_config self.health_check_custom_config = health_check_custom_config
self.service_type = service_type self.service_type = service_type
self.created = unix_time() self.created = unix_time()
def update(self, details): def update(self, details: Dict[str, Any]) -> None:
if "Description" in details: if "Description" in details:
self.description = details["Description"] self.description = details["Description"]
if "DnsConfig" in details: if "DnsConfig" in details:
@ -104,7 +103,7 @@ class Service(BaseModel):
if "HealthCheckConfig" in details: if "HealthCheckConfig" in details:
self.health_check_config = details["HealthCheckConfig"] self.health_check_config = details["HealthCheckConfig"]
def to_json(self): def to_json(self) -> Dict[str, Any]:
return { return {
"Arn": self.arn, "Arn": self.arn,
"Id": self.id, "Id": self.id,
@ -121,7 +120,7 @@ class Service(BaseModel):
class Operation(BaseModel): class Operation(BaseModel):
def __init__(self, operation_type, targets): def __init__(self, operation_type: str, targets: Dict[str, str]):
super().__init__() super().__init__()
self.id = f"{random_id(32)}-{random_id(8)}" self.id = f"{random_id(32)}-{random_id(8)}"
self.status = "SUCCESS" self.status = "SUCCESS"
@ -130,7 +129,7 @@ class Operation(BaseModel):
self.updated = unix_time() self.updated = unix_time()
self.targets = targets self.targets = targets
def to_json(self, short=False): def to_json(self, short: bool = False) -> Dict[str, Any]:
if short: if short:
return {"Id": self.id, "Status": self.status} return {"Id": self.id, "Status": self.status}
else: else:
@ -147,20 +146,26 @@ class Operation(BaseModel):
class ServiceDiscoveryBackend(BaseBackend): class ServiceDiscoveryBackend(BaseBackend):
"""Implementation of ServiceDiscovery APIs.""" """Implementation of ServiceDiscovery APIs."""
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.operations = dict() self.operations: Dict[str, Operation] = dict()
self.namespaces = dict() self.namespaces: Dict[str, Namespace] = dict()
self.services = dict() self.services: Dict[str, Service] = dict()
self.tagger = TaggingService() self.tagger = TaggingService()
def list_namespaces(self): def list_namespaces(self) -> Iterable[Namespace]:
""" """
Pagination or the Filters-parameter is not yet implemented Pagination or the Filters-parameter is not yet implemented
""" """
return self.namespaces.values() return self.namespaces.values()
def create_http_namespace(self, name, creator_request_id, description, tags): def create_http_namespace(
self,
name: str,
creator_request_id: str,
description: str,
tags: List[Dict[str, str]],
) -> str:
namespace = Namespace( namespace = Namespace(
account_id=self.account_id, account_id=self.account_id,
region=self.region_name, region=self.region_name,
@ -179,13 +184,12 @@ class ServiceDiscoveryBackend(BaseBackend):
) )
return operation_id return operation_id
def _create_operation(self, op_type, targets): def _create_operation(self, op_type: str, targets: Dict[str, str]) -> str:
operation = Operation(operation_type=op_type, targets=targets) operation = Operation(operation_type=op_type, targets=targets)
self.operations[operation.id] = operation self.operations[operation.id] = operation
operation_id = operation.id return operation.id
return operation_id
def delete_namespace(self, namespace_id): def delete_namespace(self, namespace_id: str) -> str:
if namespace_id not in self.namespaces: if namespace_id not in self.namespaces:
raise NamespaceNotFound(namespace_id) raise NamespaceNotFound(namespace_id)
del self.namespaces[namespace_id] del self.namespaces[namespace_id]
@ -194,12 +198,12 @@ class ServiceDiscoveryBackend(BaseBackend):
) )
return operation_id return operation_id
def get_namespace(self, namespace_id): def get_namespace(self, namespace_id: str) -> Namespace:
if namespace_id not in self.namespaces: if namespace_id not in self.namespaces:
raise NamespaceNotFound(namespace_id) raise NamespaceNotFound(namespace_id)
return self.namespaces[namespace_id] return self.namespaces[namespace_id]
def list_operations(self): def list_operations(self) -> Iterable[Operation]:
""" """
Pagination or the Filters-argument is not yet implemented Pagination or the Filters-argument is not yet implemented
""" """
@ -211,23 +215,31 @@ class ServiceDiscoveryBackend(BaseBackend):
} }
return self.operations.values() return self.operations.values()
def get_operation(self, operation_id): def get_operation(self, operation_id: str) -> Operation:
if operation_id not in self.operations: if operation_id not in self.operations:
raise OperationNotFound() raise OperationNotFound()
return self.operations[operation_id] return self.operations[operation_id]
def tag_resource(self, resource_arn, tags): def tag_resource(self, resource_arn: str, tags: List[Dict[str, str]]) -> None:
self.tagger.tag_resource(resource_arn, tags) self.tagger.tag_resource(resource_arn, tags)
def untag_resource(self, resource_arn, tag_keys): def untag_resource(self, resource_arn: str, tag_keys: List[str]) -> None:
self.tagger.untag_resource_using_names(resource_arn, tag_keys) self.tagger.untag_resource_using_names(resource_arn, tag_keys)
def list_tags_for_resource(self, resource_arn): def list_tags_for_resource(
self, resource_arn: str
) -> Dict[str, List[Dict[str, str]]]:
return self.tagger.list_tags_for_resource(resource_arn) return self.tagger.list_tags_for_resource(resource_arn)
def create_private_dns_namespace( def create_private_dns_namespace(
self, name, creator_request_id, description, vpc, tags, properties self,
): name: str,
creator_request_id: str,
description: str,
vpc: str,
tags: List[Dict[str, str]],
properties: Dict[str, Any],
) -> str:
for namespace in self.namespaces.values(): for namespace in self.namespaces.values():
if namespace.vpc == vpc: if namespace.vpc == vpc:
raise ConflictingDomainExists(vpc) raise ConflictingDomainExists(vpc)
@ -253,8 +265,13 @@ class ServiceDiscoveryBackend(BaseBackend):
return operation_id return operation_id
def create_public_dns_namespace( def create_public_dns_namespace(
self, name, creator_request_id, description, tags, properties self,
): name: str,
creator_request_id: str,
description: str,
tags: List[Dict[str, str]],
properties: Dict[str, Any],
) -> str:
dns_properties = (properties or {}).get("DnsProperties", {}) dns_properties = (properties or {}).get("DnsProperties", {})
dns_properties["HostedZoneId"] = "hzi" dns_properties["HostedZoneId"] = "hzi"
namespace = Namespace( namespace = Namespace(
@ -277,16 +294,16 @@ class ServiceDiscoveryBackend(BaseBackend):
def create_service( def create_service(
self, self,
name, name: str,
namespace_id, namespace_id: str,
creator_request_id, creator_request_id: str,
description, description: str,
dns_config, dns_config: Dict[str, Any],
health_check_config, health_check_config: Dict[str, Any],
health_check_custom_config, health_check_custom_config: Dict[str, Any],
tags, tags: List[Dict[str, str]],
service_type, service_type: str,
): ) -> Service:
service = Service( service = Service(
account_id=self.account_id, account_id=self.account_id,
region=self.region_name, region=self.region_name,
@ -304,21 +321,21 @@ class ServiceDiscoveryBackend(BaseBackend):
self.tagger.tag_resource(service.arn, tags) self.tagger.tag_resource(service.arn, tags)
return service return service
def get_service(self, service_id): def get_service(self, service_id: str) -> Service:
if service_id not in self.services: if service_id not in self.services:
raise ServiceNotFound(service_id) raise ServiceNotFound(service_id)
return self.services[service_id] return self.services[service_id]
def delete_service(self, service_id): def delete_service(self, service_id: str) -> None:
self.services.pop(service_id, None) self.services.pop(service_id, None)
def list_services(self): def list_services(self) -> Iterable[Service]:
""" """
Pagination or the Filters-argument is not yet implemented Pagination or the Filters-argument is not yet implemented
""" """
return self.services.values() return self.services.values()
def update_service(self, service_id, details): def update_service(self, service_id: str, details: Dict[str, Any]) -> str:
service = self.get_service(service_id) service = self.get_service(service_id)
service.update(details=details) service.update(details=details)
operation_id = self._create_operation( operation_id = self._create_operation(

View File

@ -1,24 +1,25 @@
"""Handles incoming servicediscovery requests, invokes methods, returns responses.""" """Handles incoming servicediscovery requests, invokes methods, returns responses."""
import json import json
from moto.core.common_types import TYPE_RESPONSE
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import servicediscovery_backends from .models import servicediscovery_backends, ServiceDiscoveryBackend
class ServiceDiscoveryResponse(BaseResponse): class ServiceDiscoveryResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="servicediscovery") super().__init__(service_name="servicediscovery")
@property @property
def servicediscovery_backend(self): def servicediscovery_backend(self) -> ServiceDiscoveryBackend:
"""Return backend instance specific for this region.""" """Return backend instance specific for this region."""
return servicediscovery_backends[self.current_account][self.region] return servicediscovery_backends[self.current_account][self.region]
def list_namespaces(self): def list_namespaces(self) -> TYPE_RESPONSE:
namespaces = self.servicediscovery_backend.list_namespaces() namespaces = self.servicediscovery_backend.list_namespaces()
return 200, {}, json.dumps({"Namespaces": [ns.to_json() for ns in namespaces]}) return 200, {}, json.dumps({"Namespaces": [ns.to_json() for ns in namespaces]})
def create_http_namespace(self): def create_http_namespace(self) -> str:
params = json.loads(self.body) params = json.loads(self.body)
name = params.get("Name") name = params.get("Name")
creator_request_id = params.get("CreatorRequestId") creator_request_id = params.get("CreatorRequestId")
@ -32,7 +33,7 @@ class ServiceDiscoveryResponse(BaseResponse):
) )
return json.dumps(dict(OperationId=operation_id)) return json.dumps(dict(OperationId=operation_id))
def delete_namespace(self): def delete_namespace(self) -> str:
params = json.loads(self.body) params = json.loads(self.body)
namespace_id = params.get("Id") namespace_id = params.get("Id")
operation_id = self.servicediscovery_backend.delete_namespace( operation_id = self.servicediscovery_backend.delete_namespace(
@ -40,7 +41,7 @@ class ServiceDiscoveryResponse(BaseResponse):
) )
return json.dumps(dict(OperationId=operation_id)) return json.dumps(dict(OperationId=operation_id))
def list_operations(self): def list_operations(self) -> TYPE_RESPONSE:
operations = self.servicediscovery_backend.list_operations() operations = self.servicediscovery_backend.list_operations()
return ( return (
200, 200,
@ -48,7 +49,7 @@ class ServiceDiscoveryResponse(BaseResponse):
json.dumps({"Operations": [o.to_json(short=True) for o in operations]}), json.dumps({"Operations": [o.to_json(short=True) for o in operations]}),
) )
def get_operation(self): def get_operation(self) -> str:
params = json.loads(self.body) params = json.loads(self.body)
operation_id = params.get("OperationId") operation_id = params.get("OperationId")
operation = self.servicediscovery_backend.get_operation( operation = self.servicediscovery_backend.get_operation(
@ -56,7 +57,7 @@ class ServiceDiscoveryResponse(BaseResponse):
) )
return json.dumps(dict(Operation=operation.to_json())) return json.dumps(dict(Operation=operation.to_json()))
def get_namespace(self): def get_namespace(self) -> str:
params = json.loads(self.body) params = json.loads(self.body)
namespace_id = params.get("Id") namespace_id = params.get("Id")
namespace = self.servicediscovery_backend.get_namespace( namespace = self.servicediscovery_backend.get_namespace(
@ -64,23 +65,23 @@ class ServiceDiscoveryResponse(BaseResponse):
) )
return json.dumps(dict(Namespace=namespace.to_json())) return json.dumps(dict(Namespace=namespace.to_json()))
def tag_resource(self): def tag_resource(self) -> str:
params = json.loads(self.body) params = json.loads(self.body)
resource_arn = params.get("ResourceARN") resource_arn = params.get("ResourceARN")
tags = params.get("Tags") tags = params.get("Tags")
self.servicediscovery_backend.tag_resource(resource_arn=resource_arn, tags=tags) self.servicediscovery_backend.tag_resource(resource_arn=resource_arn, tags=tags)
return json.dumps(dict()) return "{}"
def untag_resource(self): def untag_resource(self) -> str:
params = json.loads(self.body) params = json.loads(self.body)
resource_arn = params.get("ResourceARN") resource_arn = params.get("ResourceARN")
tag_keys = params.get("TagKeys") tag_keys = params.get("TagKeys")
self.servicediscovery_backend.untag_resource( self.servicediscovery_backend.untag_resource(
resource_arn=resource_arn, tag_keys=tag_keys resource_arn=resource_arn, tag_keys=tag_keys
) )
return json.dumps(dict()) return "{}"
def list_tags_for_resource(self): def list_tags_for_resource(self) -> TYPE_RESPONSE:
params = json.loads(self.body) params = json.loads(self.body)
resource_arn = params.get("ResourceARN") resource_arn = params.get("ResourceARN")
tags = self.servicediscovery_backend.list_tags_for_resource( tags = self.servicediscovery_backend.list_tags_for_resource(
@ -88,7 +89,7 @@ class ServiceDiscoveryResponse(BaseResponse):
) )
return 200, {}, json.dumps(tags) return 200, {}, json.dumps(tags)
def create_private_dns_namespace(self): def create_private_dns_namespace(self) -> str:
params = json.loads(self.body) params = json.loads(self.body)
name = params.get("Name") name = params.get("Name")
creator_request_id = params.get("CreatorRequestId") creator_request_id = params.get("CreatorRequestId")
@ -106,7 +107,7 @@ class ServiceDiscoveryResponse(BaseResponse):
) )
return json.dumps(dict(OperationId=operation_id)) return json.dumps(dict(OperationId=operation_id))
def create_public_dns_namespace(self): def create_public_dns_namespace(self) -> str:
params = json.loads(self.body) params = json.loads(self.body)
name = params.get("Name") name = params.get("Name")
creator_request_id = params.get("CreatorRequestId") creator_request_id = params.get("CreatorRequestId")
@ -122,7 +123,7 @@ class ServiceDiscoveryResponse(BaseResponse):
) )
return json.dumps(dict(OperationId=operation_id)) return json.dumps(dict(OperationId=operation_id))
def create_service(self): def create_service(self) -> str:
params = json.loads(self.body) params = json.loads(self.body)
name = params.get("Name") name = params.get("Name")
namespace_id = params.get("NamespaceId") namespace_id = params.get("NamespaceId")
@ -146,23 +147,23 @@ class ServiceDiscoveryResponse(BaseResponse):
) )
return json.dumps(dict(Service=service.to_json())) return json.dumps(dict(Service=service.to_json()))
def get_service(self): def get_service(self) -> str:
params = json.loads(self.body) params = json.loads(self.body)
service_id = params.get("Id") service_id = params.get("Id")
service = self.servicediscovery_backend.get_service(service_id=service_id) service = self.servicediscovery_backend.get_service(service_id=service_id)
return json.dumps(dict(Service=service.to_json())) return json.dumps(dict(Service=service.to_json()))
def delete_service(self): def delete_service(self) -> str:
params = json.loads(self.body) params = json.loads(self.body)
service_id = params.get("Id") service_id = params.get("Id")
self.servicediscovery_backend.delete_service(service_id=service_id) self.servicediscovery_backend.delete_service(service_id=service_id)
return json.dumps(dict()) return "{}"
def list_services(self): def list_services(self) -> str:
services = self.servicediscovery_backend.list_services() services = self.servicediscovery_backend.list_services()
return json.dumps(dict(Services=[s.to_json() for s in services])) return json.dumps(dict(Services=[s.to_json() for s in services]))
def update_service(self): def update_service(self) -> str:
params = json.loads(self.body) params = json.loads(self.body)
service_id = params.get("Id") service_id = params.get("Id")
details = params.get("Service") details = params.get("Service")

View File

@ -38,7 +38,7 @@ SKIP_REQUIRES_DOCKER = bool(os.environ.get("TESTS_SKIP_REQUIRES_DOCKER", False))
LAMBDA_DATA_DIR = os.environ.get("MOTO_LAMBDA_DATA_DIR", "/tmp/data") LAMBDA_DATA_DIR = os.environ.get("MOTO_LAMBDA_DATA_DIR", "/tmp/data")
def get_sf_execution_history_type(): def get_sf_execution_history_type() -> str:
""" """
Determines which execution history events `get_execution_history` returns Determines which execution history events `get_execution_history` returns
:returns: str representing the type of Step Function Execution Type events should be :returns: str representing the type of Step Function Execution Type events should be
@ -97,11 +97,11 @@ def moto_lambda_image() -> str:
return os.environ.get("MOTO_DOCKER_LAMBDA_IMAGE", "mlupin/docker-lambda") return os.environ.get("MOTO_DOCKER_LAMBDA_IMAGE", "mlupin/docker-lambda")
def moto_network_name() -> str: def moto_network_name() -> Optional[str]:
return os.environ.get("MOTO_DOCKER_NETWORK_NAME") return os.environ.get("MOTO_DOCKER_NETWORK_NAME")
def moto_network_mode() -> str: def moto_network_mode() -> Optional[str]:
return os.environ.get("MOTO_DOCKER_NETWORK_MODE") return os.environ.get("MOTO_DOCKER_NETWORK_MODE")

View File

@ -1,10 +1,18 @@
from typing import Any, Dict, List, Optional
from moto.core import BaseBackend, BackendDict, BaseModel from moto.core import BaseBackend, BackendDict, BaseModel
from moto.moto_api._internal import mock_random from moto.moto_api._internal import mock_random
class SigningProfile(BaseModel): class SigningProfile(BaseModel):
def __init__( def __init__(
self, account_id, region, name, platform_id, signature_validity_period, tags self,
account_id: str,
region: str,
name: str,
platform_id: str,
signature_validity_period: Optional[Dict[str, Any]],
tags: Dict[str, str],
): ):
self.name = name self.name = name
self.platform_id = platform_id self.platform_id = platform_id
@ -19,11 +27,11 @@ class SigningProfile(BaseModel):
self.profile_version = mock_random.get_random_hex(10) self.profile_version = mock_random.get_random_hex(10)
self.profile_version_arn = f"{self.arn}/{self.profile_version}" self.profile_version_arn = f"{self.arn}/{self.profile_version}"
def cancel(self): def cancel(self) -> None:
self.status = "Canceled" self.status = "Canceled"
def to_dict(self, full=True): def to_dict(self, full: bool = True) -> Dict[str, Any]:
small = { small: Dict[str, Any] = {
"arn": self.arn, "arn": self.arn,
"profileVersion": self.profile_version, "profileVersion": self.profile_version,
"profileVersionArn": self.profile_version_arn, "profileVersionArn": self.profile_version_arn,
@ -149,22 +157,22 @@ class SignerBackend(BaseBackend):
}, },
] ]
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.signing_profiles: [str, SigningProfile] = dict() self.signing_profiles: Dict[str, SigningProfile] = dict()
def cancel_signing_profile(self, profile_name) -> None: def cancel_signing_profile(self, profile_name: str) -> None:
self.signing_profiles[profile_name].cancel() self.signing_profiles[profile_name].cancel()
def get_signing_profile(self, profile_name) -> SigningProfile: def get_signing_profile(self, profile_name: str) -> SigningProfile:
return self.signing_profiles[profile_name] return self.signing_profiles[profile_name]
def put_signing_profile( def put_signing_profile(
self, self,
profile_name, profile_name: str,
signature_validity_period, signature_validity_period: Optional[Dict[str, Any]],
platform_id, platform_id: str,
tags, tags: Dict[str, str],
) -> SigningProfile: ) -> SigningProfile:
""" """
The following parameters are not yet implemented: SigningMaterial, Overrides, SigningParamaters The following parameters are not yet implemented: SigningMaterial, Overrides, SigningParamaters
@ -180,7 +188,7 @@ class SignerBackend(BaseBackend):
self.signing_profiles[profile_name] = profile self.signing_profiles[profile_name] = profile
return profile return profile
def list_signing_platforms(self): def list_signing_platforms(self) -> List[Dict[str, Any]]:
""" """
Pagination is not yet implemented. The parameters category, partner, target are not yet implemented Pagination is not yet implemented. The parameters category, partner, target are not yet implemented
""" """
@ -189,4 +197,4 @@ class SignerBackend(BaseBackend):
# Using the lambda-regions # Using the lambda-regions
# boto3.Session().get_available_regions("signer") still returns an empty list # boto3.Session().get_available_regions("signer") still returns an empty list
signer_backends: [str, [str, SignerBackend]] = BackendDict(SignerBackend, "lambda") signer_backends = BackendDict(SignerBackend, "lambda")

View File

@ -6,7 +6,7 @@ from .models import signer_backends, SignerBackend
class signerResponse(BaseResponse): class signerResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="signer") super().__init__(service_name="signer")
@property @property
@ -14,17 +14,17 @@ class signerResponse(BaseResponse):
"""Return backend instance specific for this region.""" """Return backend instance specific for this region."""
return signer_backends[self.current_account][self.region] return signer_backends[self.current_account][self.region]
def cancel_signing_profile(self): def cancel_signing_profile(self) -> str:
profile_name = self.path.split("/")[-1] profile_name = self.path.split("/")[-1]
self.signer_backend.cancel_signing_profile(profile_name=profile_name) self.signer_backend.cancel_signing_profile(profile_name=profile_name)
return "{}" return "{}"
def get_signing_profile(self): def get_signing_profile(self) -> str:
profile_name = self.path.split("/")[-1] profile_name = self.path.split("/")[-1]
profile = self.signer_backend.get_signing_profile(profile_name=profile_name) profile = self.signer_backend.get_signing_profile(profile_name=profile_name)
return json.dumps(profile.to_dict()) return json.dumps(profile.to_dict())
def put_signing_profile(self): def put_signing_profile(self) -> str:
params = json.loads(self.body) params = json.loads(self.body)
profile_name = self.path.split("/")[-1] profile_name = self.path.split("/")[-1]
signature_validity_period = params.get("signatureValidityPeriod") signature_validity_period = params.get("signatureValidityPeriod")
@ -38,6 +38,6 @@ class signerResponse(BaseResponse):
) )
return json.dumps(profile.to_dict(full=False)) return json.dumps(profile.to_dict(full=False))
def list_signing_platforms(self): def list_signing_platforms(self) -> str:
platforms = self.signer_backend.list_signing_platforms() platforms = self.signer_backend.list_signing_platforms()
return json.dumps(dict(platforms=platforms)) return json.dumps(dict(platforms=platforms))

View File

@ -3,5 +3,5 @@ from moto.core.exceptions import JsonRESTError
class ResourceNotFound(JsonRESTError): class ResourceNotFound(JsonRESTError):
def __init__(self): def __init__(self) -> None:
super().__init__("ResourceNotFound", "Account not found") super().__init__("ResourceNotFound", "Account not found")

View File

@ -1,21 +1,22 @@
from .exceptions import ResourceNotFound from typing import Any, Dict, List
from moto.core import BaseBackend, BackendDict, BaseModel from moto.core import BaseBackend, BackendDict, BaseModel
from moto.core.utils import unix_time from moto.core.utils import unix_time
from moto.moto_api._internal import mock_random as random from moto.moto_api._internal import mock_random as random
from moto.utilities.paginator import paginate from moto.utilities.paginator import paginate
from .exceptions import ResourceNotFound
from .utils import PAGINATION_MODEL from .utils import PAGINATION_MODEL
class AccountAssignment(BaseModel): class AccountAssignment(BaseModel):
def __init__( def __init__(
self, self,
instance_arn, instance_arn: str,
target_id, target_id: str,
target_type, target_type: str,
permission_set_arn, permission_set_arn: str,
principal_type, principal_type: str,
principal_id, principal_id: str,
): ):
self.request_id = str(random.uuid4()) self.request_id = str(random.uuid4())
self.instance_arn = instance_arn self.instance_arn = instance_arn
@ -26,8 +27,8 @@ class AccountAssignment(BaseModel):
self.principal_id = principal_id self.principal_id = principal_id
self.created_date = unix_time() self.created_date = unix_time()
def to_json(self, include_creation_date=False): def to_json(self, include_creation_date: bool = False) -> Dict[str, Any]:
summary = { summary: Dict[str, Any] = {
"TargetId": self.target_id, "TargetId": self.target_id,
"TargetType": self.target_type, "TargetType": self.target_type,
"PermissionSetArn": self.permission_set_arn, "PermissionSetArn": self.permission_set_arn,
@ -42,12 +43,12 @@ class AccountAssignment(BaseModel):
class PermissionSet(BaseModel): class PermissionSet(BaseModel):
def __init__( def __init__(
self, self,
name, name: str,
description, description: str,
instance_arn, instance_arn: str,
session_duration, session_duration: str,
relay_state, relay_state: str,
tags, tags: List[Dict[str, str]],
): ):
self.name = name self.name = name
self.description = description self.description = description
@ -58,8 +59,8 @@ class PermissionSet(BaseModel):
self.tags = tags self.tags = tags
self.created_date = unix_time() self.created_date = unix_time()
def to_json(self, include_creation_date=False): def to_json(self, include_creation_date: bool = False) -> Dict[str, Any]:
summary = { summary: Dict[str, Any] = {
"Name": self.name, "Name": self.name,
"Description": self.description, "Description": self.description,
"PermissionSetArn": self.permission_set_arn, "PermissionSetArn": self.permission_set_arn,
@ -71,7 +72,7 @@ class PermissionSet(BaseModel):
return summary return summary
@staticmethod @staticmethod
def generate_id(instance_arn): def generate_id(instance_arn: str) -> str:
chars = list(range(10)) + ["a", "b", "c", "d", "e", "f"] chars = list(range(10)) + ["a", "b", "c", "d", "e", "f"]
return ( return (
instance_arn instance_arn
@ -83,20 +84,20 @@ class PermissionSet(BaseModel):
class SSOAdminBackend(BaseBackend): class SSOAdminBackend(BaseBackend):
"""Implementation of SSOAdmin APIs.""" """Implementation of SSOAdmin APIs."""
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.account_assignments = list() self.account_assignments: List[AccountAssignment] = list()
self.permission_sets = list() self.permission_sets: List[PermissionSet] = list()
def create_account_assignment( def create_account_assignment(
self, self,
instance_arn, instance_arn: str,
target_id, target_id: str,
target_type, target_type: str,
permission_set_arn, permission_set_arn: str,
principal_type, principal_type: str,
principal_id, principal_id: str,
): ) -> Dict[str, Any]:
assignment = AccountAssignment( assignment = AccountAssignment(
instance_arn, instance_arn,
target_id, target_id,
@ -110,13 +111,13 @@ class SSOAdminBackend(BaseBackend):
def delete_account_assignment( def delete_account_assignment(
self, self,
instance_arn, instance_arn: str,
target_id, target_id: str,
target_type, target_type: str,
permission_set_arn, permission_set_arn: str,
principal_type, principal_type: str,
principal_id, principal_id: str,
): ) -> Dict[str, Any]:
account = self._find_account( account = self._find_account(
instance_arn, instance_arn,
target_id, target_id,
@ -130,13 +131,13 @@ class SSOAdminBackend(BaseBackend):
def _find_account( def _find_account(
self, self,
instance_arn, instance_arn: str,
target_id, target_id: str,
target_type, target_type: str,
permission_set_arn, permission_set_arn: str,
principal_type, principal_type: str,
principal_id, principal_id: str,
): ) -> AccountAssignment:
for account in self.account_assignments: for account in self.account_assignments:
instance_arn_match = account.instance_arn == instance_arn instance_arn_match = account.instance_arn == instance_arn
target_id_match = account.target_id == target_id target_id_match = account.target_id == target_id
@ -155,7 +156,9 @@ class SSOAdminBackend(BaseBackend):
return account return account
raise ResourceNotFound raise ResourceNotFound
def list_account_assignments(self, instance_arn, account_id, permission_set_arn): def list_account_assignments(
self, instance_arn: str, account_id: str, permission_set_arn: str
) -> List[Dict[str, Any]]:
""" """
Pagination has not yet been implemented Pagination has not yet been implemented
""" """
@ -178,13 +181,13 @@ class SSOAdminBackend(BaseBackend):
def create_permission_set( def create_permission_set(
self, self,
name, name: str,
description, description: str,
instance_arn, instance_arn: str,
session_duration, session_duration: str,
relay_state, relay_state: str,
tags, tags: List[Dict[str, str]],
): ) -> Dict[str, Any]:
permission_set = PermissionSet( permission_set = PermissionSet(
name, name,
description, description,
@ -198,12 +201,12 @@ class SSOAdminBackend(BaseBackend):
def update_permission_set( def update_permission_set(
self, self,
instance_arn, instance_arn: str,
permission_set_arn, permission_set_arn: str,
description, description: str,
session_duration, session_duration: str,
relay_state, relay_state: str,
): ) -> Dict[str, Any]:
permission_set = self._find_permission_set( permission_set = self._find_permission_set(
instance_arn, instance_arn,
permission_set_arn, permission_set_arn,
@ -216,10 +219,8 @@ class SSOAdminBackend(BaseBackend):
return permission_set.to_json(True) return permission_set.to_json(True)
def describe_permission_set( def describe_permission_set(
self, self, instance_arn: str, permission_set_arn: str
instance_arn, ) -> Dict[str, Any]:
permission_set_arn,
):
permission_set = self._find_permission_set( permission_set = self._find_permission_set(
instance_arn, instance_arn,
permission_set_arn, permission_set_arn,
@ -227,10 +228,8 @@ class SSOAdminBackend(BaseBackend):
return permission_set.to_json(True) return permission_set.to_json(True)
def delete_permission_set( def delete_permission_set(
self, self, instance_arn: str, permission_set_arn: str
instance_arn, ) -> Dict[str, Any]:
permission_set_arn,
):
permission_set = self._find_permission_set( permission_set = self._find_permission_set(
instance_arn, instance_arn,
permission_set_arn, permission_set_arn,
@ -239,10 +238,8 @@ class SSOAdminBackend(BaseBackend):
return permission_set.to_json(include_creation_date=True) return permission_set.to_json(include_creation_date=True)
def _find_permission_set( def _find_permission_set(
self, self, instance_arn: str, permission_set_arn: str
instance_arn, ) -> PermissionSet:
permission_set_arn,
):
for permission_set in self.permission_sets: for permission_set in self.permission_sets:
instance_arn_match = permission_set.instance_arn == instance_arn instance_arn_match = permission_set.instance_arn == instance_arn
permission_set_match = ( permission_set_match = (
@ -253,7 +250,7 @@ class SSOAdminBackend(BaseBackend):
raise ResourceNotFound raise ResourceNotFound
@paginate(pagination_model=PAGINATION_MODEL) @paginate(pagination_model=PAGINATION_MODEL)
def list_permission_sets(self, instance_arn): def list_permission_sets(self, instance_arn: str) -> List[PermissionSet]: # type: ignore[misc]
permission_sets = [] permission_sets = []
for permission_set in self.permission_sets: for permission_set in self.permission_sets:
if permission_set.instance_arn == instance_arn: if permission_set.instance_arn == instance_arn:

View File

@ -3,21 +3,21 @@ import json
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.moto_api._internal import mock_random from moto.moto_api._internal import mock_random
from .models import ssoadmin_backends from .models import ssoadmin_backends, SSOAdminBackend
class SSOAdminResponse(BaseResponse): class SSOAdminResponse(BaseResponse):
"""Handler for SSOAdmin requests and responses.""" """Handler for SSOAdmin requests and responses."""
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="sso-admin") super().__init__(service_name="sso-admin")
@property @property
def ssoadmin_backend(self): def ssoadmin_backend(self) -> SSOAdminBackend:
"""Return backend instance specific for this region.""" """Return backend instance specific for this region."""
return ssoadmin_backends[self.current_account][self.region] return ssoadmin_backends[self.current_account][self.region]
def create_account_assignment(self): def create_account_assignment(self) -> str:
params = json.loads(self.body) params = json.loads(self.body)
instance_arn = params.get("InstanceArn") instance_arn = params.get("InstanceArn")
target_id = params.get("TargetId") target_id = params.get("TargetId")
@ -37,7 +37,7 @@ class SSOAdminResponse(BaseResponse):
summary["RequestId"] = str(mock_random.uuid4()) summary["RequestId"] = str(mock_random.uuid4())
return json.dumps({"AccountAssignmentCreationStatus": summary}) return json.dumps({"AccountAssignmentCreationStatus": summary})
def delete_account_assignment(self): def delete_account_assignment(self) -> str:
params = json.loads(self.body) params = json.loads(self.body)
instance_arn = params.get("InstanceArn") instance_arn = params.get("InstanceArn")
target_id = params.get("TargetId") target_id = params.get("TargetId")
@ -57,7 +57,7 @@ class SSOAdminResponse(BaseResponse):
summary["RequestId"] = str(mock_random.uuid4()) summary["RequestId"] = str(mock_random.uuid4())
return json.dumps({"AccountAssignmentDeletionStatus": summary}) return json.dumps({"AccountAssignmentDeletionStatus": summary})
def list_account_assignments(self): def list_account_assignments(self) -> str:
params = json.loads(self.body) params = json.loads(self.body)
instance_arn = params.get("InstanceArn") instance_arn = params.get("InstanceArn")
account_id = params.get("AccountId") account_id = params.get("AccountId")
@ -69,7 +69,7 @@ class SSOAdminResponse(BaseResponse):
) )
return json.dumps({"AccountAssignments": assignments}) return json.dumps({"AccountAssignments": assignments})
def create_permission_set(self): def create_permission_set(self) -> str:
name = self._get_param("Name") name = self._get_param("Name")
description = self._get_param("Description") description = self._get_param("Description")
instance_arn = self._get_param("InstanceArn") instance_arn = self._get_param("InstanceArn")
@ -88,7 +88,7 @@ class SSOAdminResponse(BaseResponse):
return json.dumps({"PermissionSet": permission_set}) return json.dumps({"PermissionSet": permission_set})
def delete_permission_set(self): def delete_permission_set(self) -> str:
params = json.loads(self.body) params = json.loads(self.body)
instance_arn = params.get("InstanceArn") instance_arn = params.get("InstanceArn")
permission_set_arn = params.get("PermissionSetArn") permission_set_arn = params.get("PermissionSetArn")
@ -96,8 +96,9 @@ class SSOAdminResponse(BaseResponse):
instance_arn=instance_arn, instance_arn=instance_arn,
permission_set_arn=permission_set_arn, permission_set_arn=permission_set_arn,
) )
return "{}"
def update_permission_set(self): def update_permission_set(self) -> str:
instance_arn = self._get_param("InstanceArn") instance_arn = self._get_param("InstanceArn")
permission_set_arn = self._get_param("PermissionSetArn") permission_set_arn = self._get_param("PermissionSetArn")
description = self._get_param("Description") description = self._get_param("Description")
@ -111,8 +112,9 @@ class SSOAdminResponse(BaseResponse):
session_duration=session_duration, session_duration=session_duration,
relay_state=relay_state, relay_state=relay_state,
) )
return "{}"
def describe_permission_set(self): def describe_permission_set(self) -> str:
instance_arn = self._get_param("InstanceArn") instance_arn = self._get_param("InstanceArn")
permission_set_arn = self._get_param("PermissionSetArn") permission_set_arn = self._get_param("PermissionSetArn")
@ -122,7 +124,7 @@ class SSOAdminResponse(BaseResponse):
) )
return json.dumps({"PermissionSet": permission_set}) return json.dumps({"PermissionSet": permission_set})
def list_permission_sets(self): def list_permission_sets(self) -> str:
instance_arn = self._get_param("InstanceArn") instance_arn = self._get_param("InstanceArn")
max_results = self._get_int_param("MaxResults") max_results = self._get_int_param("MaxResults")
next_token = self._get_param("NextToken") next_token = self._get_param("NextToken")

View File

@ -35,7 +35,7 @@ class InvalidToken(AWSError):
TYPE = "InvalidToken" TYPE = "InvalidToken"
STATUS = 400 STATUS = 400
def __init__(self, message="Invalid token"): def __init__(self, message: str = "Invalid token"):
super().__init__(f"Invalid Token: {message}") super().__init__(f"Invalid Token: {message}")
@ -43,5 +43,5 @@ class ResourceNotFound(AWSError):
TYPE = "ResourceNotFound" TYPE = "ResourceNotFound"
STATUS = 400 STATUS = 400
def __init__(self, arn): def __init__(self, arn: str):
super().__init__(f"Resource not found: '{arn}'") super().__init__(f"Resource not found: '{arn}'")

View File

@ -2,6 +2,7 @@ import json
import re import re
from datetime import datetime from datetime import datetime
from dateutil.tz import tzlocal from dateutil.tz import tzlocal
from typing import Any, Dict, List, Iterable, Optional, Pattern
from moto.core import BaseBackend, BackendDict, CloudFormationModel from moto.core import BaseBackend, BackendDict, CloudFormationModel
from moto.core.utils import iso_8601_datetime_with_milliseconds from moto.core.utils import iso_8601_datetime_with_milliseconds
@ -21,19 +22,32 @@ from moto.utilities.paginator import paginate
class StateMachine(CloudFormationModel): class StateMachine(CloudFormationModel):
def __init__(self, arn, name, definition, roleArn, tags=None): def __init__(
self,
arn: str,
name: str,
definition: str,
roleArn: str,
tags: Optional[List[Dict[str, str]]] = None,
):
self.creation_date = iso_8601_datetime_with_milliseconds(datetime.now()) self.creation_date = iso_8601_datetime_with_milliseconds(datetime.now())
self.update_date = self.creation_date self.update_date = self.creation_date
self.arn = arn self.arn = arn
self.name = name self.name = name
self.definition = definition self.definition = definition
self.roleArn = roleArn self.roleArn = roleArn
self.executions = [] self.executions: List[Execution] = []
self.tags = [] self.tags: List[Dict[str, str]] = []
if tags: if tags:
self.add_tags(tags) self.add_tags(tags)
def start_execution(self, region_name, account_id, execution_name, execution_input): def start_execution(
self,
region_name: str,
account_id: str,
execution_name: str,
execution_input: str,
) -> "Execution":
self._ensure_execution_name_doesnt_exist(execution_name) self._ensure_execution_name_doesnt_exist(execution_name)
self._validate_execution_input(execution_input) self._validate_execution_input(execution_input)
execution = Execution( execution = Execution(
@ -47,7 +61,7 @@ class StateMachine(CloudFormationModel):
self.executions.append(execution) self.executions.append(execution)
return execution return execution
def stop_execution(self, execution_arn): def stop_execution(self, execution_arn: str) -> "Execution":
execution = next( execution = next(
(x for x in self.executions if x.execution_arn == execution_arn), None (x for x in self.executions if x.execution_arn == execution_arn), None
) )
@ -58,14 +72,14 @@ class StateMachine(CloudFormationModel):
execution.stop() execution.stop()
return execution return execution
def _ensure_execution_name_doesnt_exist(self, name): def _ensure_execution_name_doesnt_exist(self, name: str) -> None:
for execution in self.executions: for execution in self.executions:
if execution.name == name: if execution.name == name:
raise ExecutionAlreadyExists( raise ExecutionAlreadyExists(
"Execution Already Exists: '" + execution.execution_arn + "'" "Execution Already Exists: '" + execution.execution_arn + "'"
) )
def _validate_execution_input(self, execution_input): def _validate_execution_input(self, execution_input: str) -> None:
try: try:
json.loads(execution_input) json.loads(execution_input)
except Exception as ex: except Exception as ex:
@ -73,13 +87,13 @@ class StateMachine(CloudFormationModel):
"Invalid State Machine Execution Input: '" + str(ex) + "'" "Invalid State Machine Execution Input: '" + str(ex) + "'"
) )
def update(self, **kwargs): def update(self, **kwargs: Any) -> None:
for key, value in kwargs.items(): for key, value in kwargs.items():
if value is not None: if value is not None:
setattr(self, key, value) setattr(self, key, value)
self.update_date = iso_8601_datetime_with_milliseconds(datetime.now()) self.update_date = iso_8601_datetime_with_milliseconds(datetime.now())
def add_tags(self, tags): def add_tags(self, tags: List[Dict[str, str]]) -> List[Dict[str, str]]:
merged_tags = [] merged_tags = []
for tag in self.tags: for tag in self.tags:
replacement_index = next( replacement_index = next(
@ -96,15 +110,15 @@ class StateMachine(CloudFormationModel):
self.tags = merged_tags self.tags = merged_tags
return self.tags return self.tags
def remove_tags(self, tag_keys): def remove_tags(self, tag_keys: List[str]) -> List[Dict[str, str]]:
self.tags = [tag_set for tag_set in self.tags if tag_set["key"] not in tag_keys] self.tags = [tag_set for tag_set in self.tags if tag_set["key"] not in tag_keys]
return self.tags return self.tags
@property @property
def physical_resource_id(self): def physical_resource_id(self) -> str:
return self.arn return self.arn
def get_cfn_properties(self, prop_overrides): def get_cfn_properties(self, prop_overrides: Dict[str, Any]) -> Dict[str, Any]:
property_names = [ property_names = [
"DefinitionString", "DefinitionString",
"RoleArn", "RoleArn",
@ -124,7 +138,7 @@ class StateMachine(CloudFormationModel):
return properties return properties
@classmethod @classmethod
def has_cfn_attr(cls, attr): def has_cfn_attr(cls, attr: str) -> bool:
return attr in [ return attr in [
"Name", "Name",
"DefinitionString", "DefinitionString",
@ -133,7 +147,7 @@ class StateMachine(CloudFormationModel):
"Tags", "Tags",
] ]
def get_cfn_attribute(self, attribute_name): def get_cfn_attribute(self, attribute_name: str) -> Any:
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
if attribute_name == "Name": if attribute_name == "Name":
@ -150,17 +164,22 @@ class StateMachine(CloudFormationModel):
raise UnformattedGetAttTemplateException() raise UnformattedGetAttTemplateException()
@staticmethod @staticmethod
def cloudformation_name_type(): def cloudformation_name_type() -> str:
return "StateMachine" return "StateMachine"
@staticmethod @staticmethod
def cloudformation_type(): def cloudformation_type() -> str:
return "AWS::StepFunctions::StateMachine" return "AWS::StepFunctions::StateMachine"
@classmethod @classmethod
def create_from_cloudformation_json( def create_from_cloudformation_json( # type: ignore[misc]
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs cls,
): resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
**kwargs: Any,
) -> "StateMachine":
properties = cloudformation_json["Properties"] properties = cloudformation_json["Properties"]
name = properties.get("StateMachineName", resource_name) name = properties.get("StateMachineName", resource_name)
definition = properties.get("DefinitionString", "") definition = properties.get("DefinitionString", "")
@ -170,19 +189,25 @@ class StateMachine(CloudFormationModel):
return sf_backend.create_state_machine(name, definition, role_arn, tags=tags) return sf_backend.create_state_machine(name, definition, role_arn, tags=tags)
@classmethod @classmethod
def delete_from_cloudformation_json(cls, resource_name, _, account_id, region_name): def delete_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
) -> None:
sf_backend = stepfunction_backends[account_id][region_name] sf_backend = stepfunction_backends[account_id][region_name]
sf_backend.delete_state_machine(resource_name) sf_backend.delete_state_machine(resource_name)
@classmethod @classmethod
def update_from_cloudformation_json( def update_from_cloudformation_json( # type: ignore[misc]
cls, cls,
original_resource, original_resource: Any,
new_resource_name, new_resource_name: str,
cloudformation_json, cloudformation_json: Any,
account_id, account_id: str,
region_name, region_name: str,
): ) -> "StateMachine":
properties = cloudformation_json.get("Properties", {}) properties = cloudformation_json.get("Properties", {})
name = properties.get("StateMachineName", original_resource.name) name = properties.get("StateMachineName", original_resource.name)
@ -214,12 +239,12 @@ class StateMachine(CloudFormationModel):
class Execution: class Execution:
def __init__( def __init__(
self, self,
region_name, region_name: str,
account_id, account_id: str,
state_machine_name, state_machine_name: str,
execution_name, execution_name: str,
state_machine_arn, state_machine_arn: str,
execution_input, execution_input: str,
): ):
execution_arn = "arn:aws:states:{}:{}:execution:{}:{}" execution_arn = "arn:aws:states:{}:{}:execution:{}:{}"
execution_arn = execution_arn.format( execution_arn = execution_arn.format(
@ -235,9 +260,9 @@ class Execution:
if settings.get_sf_execution_history_type() == "SUCCESS" if settings.get_sf_execution_history_type() == "SUCCESS"
else "FAILED" else "FAILED"
) )
self.stop_date = None self.stop_date: Optional[str] = None
def get_execution_history(self, roleArn): def get_execution_history(self, roleArn: str) -> List[Dict[str, Any]]:
sf_execution_history_type = settings.get_sf_execution_history_type() sf_execution_history_type = settings.get_sf_execution_history_type()
if sf_execution_history_type == "SUCCESS": if sf_execution_history_type == "SUCCESS":
return [ return [
@ -334,8 +359,9 @@ class Execution:
}, },
}, },
] ]
return []
def stop(self): def stop(self) -> None:
self.status = "ABORTED" self.status = "ABORTED"
self.stop_date = iso_8601_datetime_with_milliseconds(datetime.now()) self.stop_date = iso_8601_datetime_with_milliseconds(datetime.now())
@ -451,13 +477,19 @@ class StepFunctionBackend(BaseBackend):
"arn:aws:states:[-0-9a-zA-Z]+:(?P<account_id>[0-9]{12}):execution:.+" "arn:aws:states:[-0-9a-zA-Z]+:(?P<account_id>[0-9]{12}):execution:.+"
) )
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.state_machines = [] self.state_machines: List[StateMachine] = []
self.executions = [] self.executions: List[Execution] = []
self._account_id = None self._account_id = None
def create_state_machine(self, name, definition, roleArn, tags=None): def create_state_machine(
self,
name: str,
definition: str,
roleArn: str,
tags: Optional[List[Dict[str, str]]] = None,
) -> StateMachine:
self._validate_name(name) self._validate_name(name)
self._validate_role_arn(roleArn) self._validate_role_arn(roleArn)
arn = f"arn:aws:states:{self.region_name}:{self.account_id}:stateMachine:{name}" arn = f"arn:aws:states:{self.region_name}:{self.account_id}:stateMachine:{name}"
@ -469,11 +501,10 @@ class StepFunctionBackend(BaseBackend):
return state_machine return state_machine
@paginate(pagination_model=PAGINATION_MODEL) @paginate(pagination_model=PAGINATION_MODEL)
def list_state_machines(self): def list_state_machines(self) -> Iterable[StateMachine]: # type: ignore[misc]
state_machines = sorted(self.state_machines, key=lambda x: x.creation_date) return sorted(self.state_machines, key=lambda x: x.creation_date)
return state_machines
def describe_state_machine(self, arn): def describe_state_machine(self, arn: str) -> StateMachine:
self._validate_machine_arn(arn) self._validate_machine_arn(arn)
sm = next((x for x in self.state_machines if x.arn == arn), None) sm = next((x for x in self.state_machines if x.arn == arn), None)
if not sm: if not sm:
@ -482,13 +513,15 @@ class StepFunctionBackend(BaseBackend):
) )
return sm return sm
def delete_state_machine(self, arn): def delete_state_machine(self, arn: str) -> None:
self._validate_machine_arn(arn) self._validate_machine_arn(arn)
sm = next((x for x in self.state_machines if x.arn == arn), None) sm = next((x for x in self.state_machines if x.arn == arn), None)
if sm: if sm:
self.state_machines.remove(sm) self.state_machines.remove(sm)
def update_state_machine(self, arn, definition=None, role_arn=None): def update_state_machine(
self, arn: str, definition: Optional[str] = None, role_arn: Optional[str] = None
) -> StateMachine:
sm = self.describe_state_machine(arn) sm = self.describe_state_machine(arn)
updates = { updates = {
"definition": definition, "definition": definition,
@ -497,23 +530,24 @@ class StepFunctionBackend(BaseBackend):
sm.update(**updates) sm.update(**updates)
return sm return sm
def start_execution(self, state_machine_arn, name=None, execution_input=None): def start_execution(
self, state_machine_arn: str, name: str, execution_input: str
) -> Execution:
state_machine = self.describe_state_machine(state_machine_arn) state_machine = self.describe_state_machine(state_machine_arn)
execution = state_machine.start_execution( return state_machine.start_execution(
region_name=self.region_name, region_name=self.region_name,
account_id=self.account_id, account_id=self.account_id,
execution_name=name or str(mock_random.uuid4()), execution_name=name or str(mock_random.uuid4()),
execution_input=execution_input, execution_input=execution_input,
) )
return execution
def stop_execution(self, execution_arn): def stop_execution(self, execution_arn: str) -> Execution:
self._validate_execution_arn(execution_arn) self._validate_execution_arn(execution_arn)
state_machine = self._get_state_machine_for_execution(execution_arn) state_machine = self._get_state_machine_for_execution(execution_arn)
return state_machine.stop_execution(execution_arn) return state_machine.stop_execution(execution_arn)
@paginate(pagination_model=PAGINATION_MODEL) @paginate(pagination_model=PAGINATION_MODEL)
def list_executions(self, state_machine_arn, status_filter=None): def list_executions(self, state_machine_arn: str, status_filter: Optional[str] = None) -> Iterable[Execution]: # type: ignore[misc]
""" """
The status of every execution is set to 'RUNNING' by default. The status of every execution is set to 'RUNNING' by default.
Set the following environment variable if you want to get a FAILED status back: Set the following environment variable if you want to get a FAILED status back:
@ -530,7 +564,7 @@ class StepFunctionBackend(BaseBackend):
executions = sorted(executions, key=lambda x: x.start_date, reverse=True) executions = sorted(executions, key=lambda x: x.start_date, reverse=True)
return executions return executions
def describe_execution(self, execution_arn): def describe_execution(self, execution_arn: str) -> Execution:
""" """
The status of every execution is set to 'RUNNING' by default. The status of every execution is set to 'RUNNING' by default.
Set the following environment variable if you want to get a FAILED status back: Set the following environment variable if you want to get a FAILED status back:
@ -551,7 +585,7 @@ class StepFunctionBackend(BaseBackend):
) )
return exctn return exctn
def get_execution_history(self, execution_arn): def get_execution_history(self, execution_arn: str) -> List[Dict[str, Any]]:
""" """
A static list of successful events is returned by default. A static list of successful events is returned by default.
Set the following environment variable if you want to get a static list of events for a failed execution: Set the following environment variable if you want to get a static list of events for a failed execution:
@ -572,61 +606,61 @@ class StepFunctionBackend(BaseBackend):
) )
return execution.get_execution_history(state_machine.roleArn) return execution.get_execution_history(state_machine.roleArn)
def list_tags_for_resource(self, arn): def list_tags_for_resource(self, arn: str) -> List[Dict[str, str]]:
try: try:
state_machine = self.describe_state_machine(arn) state_machine = self.describe_state_machine(arn)
return state_machine.tags or [] return state_machine.tags or []
except StateMachineDoesNotExist: except StateMachineDoesNotExist:
return [] return []
def tag_resource(self, resource_arn, tags): def tag_resource(self, resource_arn: str, tags: List[Dict[str, str]]) -> None:
try: try:
state_machine = self.describe_state_machine(resource_arn) state_machine = self.describe_state_machine(resource_arn)
state_machine.add_tags(tags) state_machine.add_tags(tags)
except StateMachineDoesNotExist: except StateMachineDoesNotExist:
raise ResourceNotFound(resource_arn) raise ResourceNotFound(resource_arn)
def untag_resource(self, resource_arn, tag_keys): def untag_resource(self, resource_arn: str, tag_keys: List[str]) -> None:
try: try:
state_machine = self.describe_state_machine(resource_arn) state_machine = self.describe_state_machine(resource_arn)
state_machine.remove_tags(tag_keys) state_machine.remove_tags(tag_keys)
except StateMachineDoesNotExist: except StateMachineDoesNotExist:
raise ResourceNotFound(resource_arn) raise ResourceNotFound(resource_arn)
def _validate_name(self, name): def _validate_name(self, name: str) -> None:
if any(invalid_char in name for invalid_char in self.invalid_chars_for_name): if any(invalid_char in name for invalid_char in self.invalid_chars_for_name):
raise InvalidName("Invalid Name: '" + name + "'") raise InvalidName("Invalid Name: '" + name + "'")
if any(name.find(char) >= 0 for char in self.invalid_unicodes_for_name): if any(name.find(char) >= 0 for char in self.invalid_unicodes_for_name):
raise InvalidName("Invalid Name: '" + name + "'") raise InvalidName("Invalid Name: '" + name + "'")
def _validate_role_arn(self, role_arn): def _validate_role_arn(self, role_arn: str) -> None:
self._validate_arn( self._validate_arn(
arn=role_arn, arn=role_arn,
regex=self.accepted_role_arn_format, regex=self.accepted_role_arn_format,
invalid_msg="Invalid Role Arn: '" + role_arn + "'", invalid_msg="Invalid Role Arn: '" + role_arn + "'",
) )
def _validate_machine_arn(self, machine_arn): def _validate_machine_arn(self, machine_arn: str) -> None:
self._validate_arn( self._validate_arn(
arn=machine_arn, arn=machine_arn,
regex=self.accepted_mchn_arn_format, regex=self.accepted_mchn_arn_format,
invalid_msg="Invalid State Machine Arn: '" + machine_arn + "'", invalid_msg="Invalid State Machine Arn: '" + machine_arn + "'",
) )
def _validate_execution_arn(self, execution_arn): def _validate_execution_arn(self, execution_arn: str) -> None:
self._validate_arn( self._validate_arn(
arn=execution_arn, arn=execution_arn,
regex=self.accepted_exec_arn_format, regex=self.accepted_exec_arn_format,
invalid_msg="Execution Does Not Exist: '" + execution_arn + "'", invalid_msg="Execution Does Not Exist: '" + execution_arn + "'",
) )
def _validate_arn(self, arn, regex, invalid_msg): def _validate_arn(self, arn: str, regex: Pattern[str], invalid_msg: str) -> None:
match = regex.match(arn) match = regex.match(arn)
if not arn or not match: if not arn or not match:
raise InvalidArn(invalid_msg) raise InvalidArn(invalid_msg)
def _get_state_machine_for_execution(self, execution_arn): def _get_state_machine_for_execution(self, execution_arn: str) -> StateMachine:
state_machine_name = execution_arn.split(":")[6] state_machine_name = execution_arn.split(":")[6]
state_machine_arn = next( state_machine_arn = next(
(x.arn for x in self.state_machines if x.name == state_machine_name), None (x.arn for x in self.state_machines if x.name == state_machine_name), None

View File

@ -1,20 +1,21 @@
import json import json
from moto.core.common_types import TYPE_RESPONSE
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.utilities.aws_headers import amzn_request_id from moto.utilities.aws_headers import amzn_request_id
from .models import stepfunction_backends from .models import stepfunction_backends, StepFunctionBackend
class StepFunctionResponse(BaseResponse): class StepFunctionResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="stepfunctions") super().__init__(service_name="stepfunctions")
@property @property
def stepfunction_backend(self): def stepfunction_backend(self) -> StepFunctionBackend:
return stepfunction_backends[self.current_account][self.region] return stepfunction_backends[self.current_account][self.region]
@amzn_request_id @amzn_request_id
def create_state_machine(self): def create_state_machine(self) -> TYPE_RESPONSE:
name = self._get_param("name") name = self._get_param("name")
definition = self._get_param("definition") definition = self._get_param("definition")
roleArn = self._get_param("roleArn") roleArn = self._get_param("roleArn")
@ -29,7 +30,7 @@ class StepFunctionResponse(BaseResponse):
return 200, {}, json.dumps(response) return 200, {}, json.dumps(response)
@amzn_request_id @amzn_request_id
def list_state_machines(self): def list_state_machines(self) -> TYPE_RESPONSE:
max_results = self._get_int_param("maxResults") max_results = self._get_int_param("maxResults")
next_token = self._get_param("nextToken") next_token = self._get_param("nextToken")
results, next_token = self.stepfunction_backend.list_state_machines( results, next_token = self.stepfunction_backend.list_state_machines(
@ -49,12 +50,12 @@ class StepFunctionResponse(BaseResponse):
return 200, {}, json.dumps(response) return 200, {}, json.dumps(response)
@amzn_request_id @amzn_request_id
def describe_state_machine(self): def describe_state_machine(self) -> TYPE_RESPONSE:
arn = self._get_param("stateMachineArn") arn = self._get_param("stateMachineArn")
return self._describe_state_machine(arn) return self._describe_state_machine(arn)
@amzn_request_id @amzn_request_id
def _describe_state_machine(self, state_machine_arn): def _describe_state_machine(self, state_machine_arn: str) -> TYPE_RESPONSE:
state_machine = self.stepfunction_backend.describe_state_machine( state_machine = self.stepfunction_backend.describe_state_machine(
state_machine_arn state_machine_arn
) )
@ -69,13 +70,13 @@ class StepFunctionResponse(BaseResponse):
return 200, {}, json.dumps(response) return 200, {}, json.dumps(response)
@amzn_request_id @amzn_request_id
def delete_state_machine(self): def delete_state_machine(self) -> TYPE_RESPONSE:
arn = self._get_param("stateMachineArn") arn = self._get_param("stateMachineArn")
self.stepfunction_backend.delete_state_machine(arn) self.stepfunction_backend.delete_state_machine(arn)
return 200, {}, json.dumps("{}") return 200, {}, json.dumps("{}")
@amzn_request_id @amzn_request_id
def update_state_machine(self): def update_state_machine(self) -> TYPE_RESPONSE:
arn = self._get_param("stateMachineArn") arn = self._get_param("stateMachineArn")
definition = self._get_param("definition") definition = self._get_param("definition")
role_arn = self._get_param("roleArn") role_arn = self._get_param("roleArn")
@ -88,28 +89,28 @@ class StepFunctionResponse(BaseResponse):
return 200, {}, json.dumps(response) return 200, {}, json.dumps(response)
@amzn_request_id @amzn_request_id
def list_tags_for_resource(self): def list_tags_for_resource(self) -> TYPE_RESPONSE:
arn = self._get_param("resourceArn") arn = self._get_param("resourceArn")
tags = self.stepfunction_backend.list_tags_for_resource(arn) tags = self.stepfunction_backend.list_tags_for_resource(arn)
response = {"tags": tags} response = {"tags": tags}
return 200, {}, json.dumps(response) return 200, {}, json.dumps(response)
@amzn_request_id @amzn_request_id
def tag_resource(self): def tag_resource(self) -> TYPE_RESPONSE:
arn = self._get_param("resourceArn") arn = self._get_param("resourceArn")
tags = self._get_param("tags", []) tags = self._get_param("tags", [])
self.stepfunction_backend.tag_resource(arn, tags) self.stepfunction_backend.tag_resource(arn, tags)
return 200, {}, json.dumps({}) return 200, {}, json.dumps({})
@amzn_request_id @amzn_request_id
def untag_resource(self): def untag_resource(self) -> TYPE_RESPONSE:
arn = self._get_param("resourceArn") arn = self._get_param("resourceArn")
tag_keys = self._get_param("tagKeys", []) tag_keys = self._get_param("tagKeys", [])
self.stepfunction_backend.untag_resource(arn, tag_keys) self.stepfunction_backend.untag_resource(arn, tag_keys)
return 200, {}, json.dumps({}) return 200, {}, json.dumps({})
@amzn_request_id @amzn_request_id
def start_execution(self): def start_execution(self) -> TYPE_RESPONSE:
arn = self._get_param("stateMachineArn") arn = self._get_param("stateMachineArn")
name = self._get_param("name") name = self._get_param("name")
execution_input = self._get_param("input", if_none="{}") execution_input = self._get_param("input", if_none="{}")
@ -123,7 +124,7 @@ class StepFunctionResponse(BaseResponse):
return 200, {}, json.dumps(response) return 200, {}, json.dumps(response)
@amzn_request_id @amzn_request_id
def list_executions(self): def list_executions(self) -> TYPE_RESPONSE:
max_results = self._get_int_param("maxResults") max_results = self._get_int_param("maxResults")
next_token = self._get_param("nextToken") next_token = self._get_param("nextToken")
arn = self._get_param("stateMachineArn") arn = self._get_param("stateMachineArn")
@ -151,7 +152,7 @@ class StepFunctionResponse(BaseResponse):
return 200, {}, json.dumps(response) return 200, {}, json.dumps(response)
@amzn_request_id @amzn_request_id
def describe_execution(self): def describe_execution(self) -> TYPE_RESPONSE:
arn = self._get_param("executionArn") arn = self._get_param("executionArn")
execution = self.stepfunction_backend.describe_execution(arn) execution = self.stepfunction_backend.describe_execution(arn)
response = { response = {
@ -166,20 +167,20 @@ class StepFunctionResponse(BaseResponse):
return 200, {}, json.dumps(response) return 200, {}, json.dumps(response)
@amzn_request_id @amzn_request_id
def describe_state_machine_for_execution(self): def describe_state_machine_for_execution(self) -> TYPE_RESPONSE:
arn = self._get_param("executionArn") arn = self._get_param("executionArn")
execution = self.stepfunction_backend.describe_execution(arn) execution = self.stepfunction_backend.describe_execution(arn)
return self._describe_state_machine(execution.state_machine_arn) return self._describe_state_machine(execution.state_machine_arn)
@amzn_request_id @amzn_request_id
def stop_execution(self): def stop_execution(self) -> TYPE_RESPONSE:
arn = self._get_param("executionArn") arn = self._get_param("executionArn")
execution = self.stepfunction_backend.stop_execution(arn) execution = self.stepfunction_backend.stop_execution(arn)
response = {"stopDate": execution.stop_date} response = {"stopDate": execution.stop_date}
return 200, {}, json.dumps(response) return 200, {}, json.dumps(response)
@amzn_request_id @amzn_request_id
def get_execution_history(self): def get_execution_history(self) -> TYPE_RESPONSE:
execution_arn = self._get_param("executionArn") execution_arn = self._get_param("executionArn")
execution_history = self.stepfunction_backend.get_execution_history( execution_history = self.stepfunction_backend.get_execution_history(
execution_arn execution_arn

View File

@ -1,3 +1,6 @@
from typing import Dict, List
PAGINATION_MODEL = { PAGINATION_MODEL = {
"list_executions": { "list_executions": {
"input_token": "next_token", "input_token": "next_token",
@ -14,11 +17,9 @@ PAGINATION_MODEL = {
} }
def cfn_to_api_tags(cfn_tags_entry): def cfn_to_api_tags(cfn_tags_entry: List[Dict[str, str]]) -> List[Dict[str, str]]:
api_tags = [{k.lower(): v for k, v in d.items()} for d in cfn_tags_entry] return [{k.lower(): v for k, v in d.items()} for d in cfn_tags_entry]
return api_tags
def api_to_cfn_tags(api_tags): def api_to_cfn_tags(api_tags: List[Dict[str, str]]) -> List[Dict[str, str]]:
cfn_tags_entry = [{k.capitalize(): v for k, v in d.items()} for d in api_tags] return [{k.capitalize(): v for k, v in d.items()} for d in api_tags]
return cfn_tags_entry

View File

@ -1,3 +1,4 @@
from typing import Any
from moto.core.exceptions import RESTError from moto.core.exceptions import RESTError
@ -6,5 +7,5 @@ class STSClientError(RESTError):
class STSValidationError(STSClientError): class STSValidationError(STSClientError):
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any):
super().__init__("ValidationError", *args, **kwargs) super().__init__("ValidationError", *args, **kwargs)

View File

@ -1,10 +1,12 @@
from base64 import b64decode from base64 import b64decode
from typing import Any, Dict, List, Optional, Tuple
import datetime import datetime
import re import re
import xmltodict import xmltodict
from moto.core import BaseBackend, BaseModel, BackendDict from moto.core import BaseBackend, BaseModel, BackendDict
from moto.core.utils import iso_8601_datetime_with_milliseconds from moto.core.utils import iso_8601_datetime_with_milliseconds
from moto.iam import iam_backends from moto.iam.models import iam_backends, AccessKey
from moto.sts.utils import ( from moto.sts.utils import (
random_session_token, random_session_token,
DEFAULT_STS_SESSION_DURATION, DEFAULT_STS_SESSION_DURATION,
@ -13,27 +15,27 @@ from moto.sts.utils import (
class Token(BaseModel): class Token(BaseModel):
def __init__(self, duration, name=None): def __init__(self, duration: int, name: Optional[str] = None):
now = datetime.datetime.utcnow() now = datetime.datetime.utcnow()
self.expiration = now + datetime.timedelta(seconds=duration) self.expiration = now + datetime.timedelta(seconds=duration)
self.name = name self.name = name
self.policy = None self.policy = None
@property @property
def expiration_ISO8601(self): def expiration_ISO8601(self) -> str:
return iso_8601_datetime_with_milliseconds(self.expiration) return iso_8601_datetime_with_milliseconds(self.expiration)
class AssumedRole(BaseModel): class AssumedRole(BaseModel):
def __init__( def __init__(
self, self,
account_id, account_id: str,
access_key, access_key: AccessKey,
role_session_name, role_session_name: str,
role_arn, role_arn: str,
policy, policy: str,
duration, duration: int,
external_id, external_id: str,
): ):
self.account_id = account_id self.account_id = account_id
self.session_name = role_session_name self.session_name = role_session_name
@ -48,11 +50,11 @@ class AssumedRole(BaseModel):
self.session_token = random_session_token() self.session_token = random_session_token()
@property @property
def expiration_ISO8601(self): def expiration_ISO8601(self) -> str:
return iso_8601_datetime_with_milliseconds(self.expiration) return iso_8601_datetime_with_milliseconds(self.expiration)
@property @property
def user_id(self): def user_id(self) -> str:
iam_backend = iam_backends[self.account_id]["global"] iam_backend = iam_backends[self.account_id]["global"]
try: try:
role_id = iam_backend.get_role_by_arn(arn=self.role_arn).id role_id = iam_backend.get_role_by_arn(arn=self.role_arn).id
@ -61,31 +63,38 @@ class AssumedRole(BaseModel):
return role_id + ":" + self.session_name return role_id + ":" + self.session_name
@property @property
def arn(self): def arn(self) -> str:
return f"arn:aws:sts::{self.account_id}:assumed-role/{self.role_arn.split('/')[-1]}/{self.session_name}" return f"arn:aws:sts::{self.account_id}:assumed-role/{self.role_arn.split('/')[-1]}/{self.session_name}"
class STSBackend(BaseBackend): class STSBackend(BaseBackend):
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.assumed_roles = [] self.assumed_roles: List[AssumedRole] = []
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(
service_region: str, zones: List[str]
) -> List[Dict[str, str]]:
"""Default VPC endpoint service.""" """Default VPC endpoint service."""
return BaseBackend.default_vpc_endpoint_service_factory( return BaseBackend.default_vpc_endpoint_service_factory(
service_region, zones, "sts" service_region, zones, "sts"
) )
def get_session_token(self, duration): def get_session_token(self, duration: int) -> Token:
token = Token(duration=duration) return Token(duration=duration)
return token
def get_federation_token(self, name, duration): def get_federation_token(self, name: Optional[str], duration: int) -> Token:
token = Token(duration=duration, name=name) return Token(duration=duration, name=name)
return token
def assume_role(self, role_session_name, role_arn, policy, duration, external_id): def assume_role(
self,
role_session_name: str,
role_arn: str,
policy: str,
duration: int,
external_id: str,
) -> AssumedRole:
""" """
Assume an IAM Role. Note that the role does not need to exist. The ARN can point to another account, providing an opportunity to switch accounts. Assume an IAM Role. Note that the role does not need to exist. The ARN can point to another account, providing an opportunity to switch accounts.
""" """
@ -103,16 +112,18 @@ class STSBackend(BaseBackend):
account_backend.assumed_roles.append(role) account_backend.assumed_roles.append(role)
return role return role
def get_assumed_role_from_access_key(self, access_key_id): def get_assumed_role_from_access_key(
self, access_key_id: str
) -> Optional[AssumedRole]:
for assumed_role in self.assumed_roles: for assumed_role in self.assumed_roles:
if assumed_role.access_key_id == access_key_id: if assumed_role.access_key_id == access_key_id:
return assumed_role return assumed_role
return None return None
def assume_role_with_web_identity(self, **kwargs): def assume_role_with_web_identity(self, **kwargs: Any) -> AssumedRole:
return self.assume_role(**kwargs) return self.assume_role(**kwargs)
def assume_role_with_saml(self, **kwargs): def assume_role_with_saml(self, **kwargs: Any) -> AssumedRole:
del kwargs["principal_arn"] del kwargs["principal_arn"]
saml_assertion_encoded = kwargs.pop("saml_assertion") saml_assertion_encoded = kwargs.pop("saml_assertion")
saml_assertion_decoded = b64decode(saml_assertion_encoded) saml_assertion_decoded = b64decode(saml_assertion_encoded)
@ -150,7 +161,7 @@ class STSBackend(BaseBackend):
if "duration" not in kwargs: if "duration" not in kwargs:
kwargs["duration"] = DEFAULT_STS_SESSION_DURATION kwargs["duration"] = DEFAULT_STS_SESSION_DURATION
account_id, access_key = self._create_access_key(role=target_role) account_id, access_key = self._create_access_key(role=target_role) # type: ignore
kwargs["account_id"] = account_id kwargs["account_id"] = account_id
kwargs["access_key"] = access_key kwargs["access_key"] = access_key
@ -160,7 +171,7 @@ class STSBackend(BaseBackend):
self.assumed_roles.append(role) self.assumed_roles.append(role)
return role return role
def get_caller_identity(self, access_key_id): def get_caller_identity(self, access_key_id: str) -> Tuple[str, str, str]:
assumed_role = self.get_assumed_role_from_access_key(access_key_id) assumed_role = self.get_assumed_role_from_access_key(access_key_id)
if assumed_role: if assumed_role:
return assumed_role.user_id, assumed_role.arn, assumed_role.account_id return assumed_role.user_id, assumed_role.arn, assumed_role.account_id
@ -175,7 +186,7 @@ class STSBackend(BaseBackend):
arn = f"arn:aws:sts::{self.account_id}:user/moto" arn = f"arn:aws:sts::{self.account_id}:user/moto"
return user_id, arn, self.account_id return user_id, arn, self.account_id
def _create_access_key(self, role): def _create_access_key(self, role: str) -> Tuple[str, AccessKey]:
account_id_match = re.search(r"arn:aws:iam::([0-9]+).+", role) account_id_match = re.search(r"arn:aws:iam::([0-9]+).+", role)
if account_id_match: if account_id_match:
account_id = account_id_match.group(1) account_id = account_id_match.group(1)

View File

@ -1,25 +1,25 @@
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .exceptions import STSValidationError from .exceptions import STSValidationError
from .models import sts_backends from .models import sts_backends, STSBackend
MAX_FEDERATION_TOKEN_POLICY_LENGTH = 2048 MAX_FEDERATION_TOKEN_POLICY_LENGTH = 2048
class TokenResponse(BaseResponse): class TokenResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="sts") super().__init__(service_name="sts")
@property @property
def backend(self): def backend(self) -> STSBackend:
return sts_backends[self.current_account]["global"] return sts_backends[self.current_account]["global"]
def get_session_token(self): def get_session_token(self) -> str:
duration = int(self.querystring.get("DurationSeconds", [43200])[0]) duration = int(self.querystring.get("DurationSeconds", [43200])[0])
token = self.backend.get_session_token(duration=duration) token = self.backend.get_session_token(duration=duration)
template = self.response_template(GET_SESSION_TOKEN_RESPONSE) template = self.response_template(GET_SESSION_TOKEN_RESPONSE)
return template.render(token=token) return template.render(token=token)
def get_federation_token(self): def get_federation_token(self) -> str:
duration = int(self.querystring.get("DurationSeconds", [43200])[0]) duration = int(self.querystring.get("DurationSeconds", [43200])[0])
policy = self.querystring.get("Policy", [None])[0] policy = self.querystring.get("Policy", [None])[0]
@ -31,14 +31,14 @@ class TokenResponse(BaseResponse):
f" equal to {MAX_FEDERATION_TOKEN_POLICY_LENGTH}" f" equal to {MAX_FEDERATION_TOKEN_POLICY_LENGTH}"
) )
name = self.querystring.get("Name")[0] name = self.querystring.get("Name")[0] # type: ignore
token = self.backend.get_federation_token(duration=duration, name=name) token = self.backend.get_federation_token(duration=duration, name=name)
template = self.response_template(GET_FEDERATION_TOKEN_RESPONSE) template = self.response_template(GET_FEDERATION_TOKEN_RESPONSE)
return template.render(token=token, account_id=self.current_account) return template.render(token=token, account_id=self.current_account)
def assume_role(self): def assume_role(self) -> str:
role_session_name = self.querystring.get("RoleSessionName")[0] role_session_name = self.querystring.get("RoleSessionName")[0] # type: ignore
role_arn = self.querystring.get("RoleArn")[0] role_arn = self.querystring.get("RoleArn")[0] # type: ignore
policy = self.querystring.get("Policy", [None])[0] policy = self.querystring.get("Policy", [None])[0]
duration = int(self.querystring.get("DurationSeconds", [3600])[0]) duration = int(self.querystring.get("DurationSeconds", [3600])[0])
@ -54,9 +54,9 @@ class TokenResponse(BaseResponse):
template = self.response_template(ASSUME_ROLE_RESPONSE) template = self.response_template(ASSUME_ROLE_RESPONSE)
return template.render(role=role) return template.render(role=role)
def assume_role_with_web_identity(self): def assume_role_with_web_identity(self) -> str:
role_session_name = self.querystring.get("RoleSessionName")[0] role_session_name = self.querystring.get("RoleSessionName")[0] # type: ignore
role_arn = self.querystring.get("RoleArn")[0] role_arn = self.querystring.get("RoleArn")[0] # type: ignore
policy = self.querystring.get("Policy", [None])[0] policy = self.querystring.get("Policy", [None])[0]
duration = int(self.querystring.get("DurationSeconds", [3600])[0]) duration = int(self.querystring.get("DurationSeconds", [3600])[0])
@ -72,10 +72,10 @@ class TokenResponse(BaseResponse):
template = self.response_template(ASSUME_ROLE_WITH_WEB_IDENTITY_RESPONSE) template = self.response_template(ASSUME_ROLE_WITH_WEB_IDENTITY_RESPONSE)
return template.render(role=role) return template.render(role=role)
def assume_role_with_saml(self): def assume_role_with_saml(self) -> str:
role_arn = self.querystring.get("RoleArn")[0] role_arn = self.querystring.get("RoleArn")[0] # type: ignore
principal_arn = self.querystring.get("PrincipalArn")[0] principal_arn = self.querystring.get("PrincipalArn")[0] # type: ignore
saml_assertion = self.querystring.get("SAMLAssertion")[0] saml_assertion = self.querystring.get("SAMLAssertion")[0] # type: ignore
role = self.backend.assume_role_with_saml( role = self.backend.assume_role_with_saml(
role_arn=role_arn, role_arn=role_arn,
@ -85,7 +85,7 @@ class TokenResponse(BaseResponse):
template = self.response_template(ASSUME_ROLE_WITH_SAML_RESPONSE) template = self.response_template(ASSUME_ROLE_WITH_SAML_RESPONSE)
return template.render(role=role) return template.render(role=role)
def get_caller_identity(self): def get_caller_identity(self) -> str:
template = self.response_template(GET_CALLER_IDENTITY_RESPONSE) template = self.response_template(GET_CALLER_IDENTITY_RESPONSE)
access_key_id = self.get_access_key() access_key_id = self.get_access_key()

View File

@ -16,13 +16,13 @@ def random_session_token() -> str:
) )
def random_assumed_role_id(): def random_assumed_role_id() -> str:
return ( return (
ACCOUNT_SPECIFIC_ASSUMED_ROLE_ID_PREFIX + _random_uppercase_or_digit_sequence(9) ACCOUNT_SPECIFIC_ASSUMED_ROLE_ID_PREFIX + _random_uppercase_or_digit_sequence(9)
) )
def _random_uppercase_or_digit_sequence(length): def _random_uppercase_or_digit_sequence(length: int) -> str:
return "".join( return "".join(
str(random.choice(string.ascii_uppercase + string.digits)) str(random.choice(string.ascii_uppercase + string.digits))
for _ in range(length) for _ in range(length)

View File

@ -4,6 +4,7 @@ from moto.moto_api._internal.managed_state_model import ManagedState
from moto.moto_api._internal import mock_random as random from moto.moto_api._internal import mock_random as random
from moto.utilities.utils import load_resource from moto.utilities.utils import load_resource
import datetime import datetime
from typing import Any, Dict, List, Optional
checks_json = "resources/describe_trusted_advisor_checks.json" checks_json = "resources/describe_trusted_advisor_checks.json"
@ -11,7 +12,7 @@ ADVISOR_CHECKS = load_resource(__name__, checks_json)
class SupportCase(ManagedState): class SupportCase(ManagedState):
def __init__(self, **kwargs): def __init__(self, **kwargs: Any):
# Configure ManagedState # Configure ManagedState
super().__init__( super().__init__(
"support::case", "support::case",
@ -56,21 +57,21 @@ class SupportCase(ManagedState):
} }
} }
def get_datetime(self): def get_datetime(self) -> str:
return str(datetime.datetime.now().isoformat()) return str(datetime.datetime.now().isoformat())
class SupportBackend(BaseBackend): class SupportBackend(BaseBackend):
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.check_status = {} self.check_status: Dict[str, str] = {}
self.cases = {} self.cases: Dict[str, SupportCase] = {}
state_manager.register_default_transition( state_manager.register_default_transition(
model_name="support::case", transition={"progression": "manual", "times": 1} model_name="support::case", transition={"progression": "manual", "times": 1}
) )
def describe_trusted_advisor_checks(self): def describe_trusted_advisor_checks(self) -> List[Dict[str, Any]]:
""" """
The Language-parameter is not yet implemented The Language-parameter is not yet implemented
""" """
@ -78,18 +79,17 @@ class SupportBackend(BaseBackend):
checks = ADVISOR_CHECKS["checks"] checks = ADVISOR_CHECKS["checks"]
return checks return checks
def refresh_trusted_advisor_check(self, check_id): def refresh_trusted_advisor_check(self, check_id: str) -> Dict[str, Any]:
self.advance_check_status(check_id) self.advance_check_status(check_id)
status = { return {
"status": { "status": {
"checkId": check_id, "checkId": check_id,
"status": self.check_status[check_id], "status": self.check_status[check_id],
"millisUntilNextRefreshable": 123, "millisUntilNextRefreshable": 123,
} }
} }
return status
def advance_check_status(self, check_id): def advance_check_status(self, check_id: str) -> None:
""" """
Fake an advancement through statuses on refreshing TA checks Fake an advancement through statuses on refreshing TA checks
""" """
@ -111,14 +111,13 @@ class SupportBackend(BaseBackend):
elif self.check_status[check_id] == "abandoned": elif self.check_status[check_id] == "abandoned":
self.check_status[check_id] = "none" self.check_status[check_id] = "none"
def advance_case_status(self, case_id): def advance_case_status(self, case_id: str) -> None:
""" """
Fake an advancement through case statuses Fake an advancement through case statuses
""" """
self.cases[case_id].advance() self.cases[case_id].advance()
def advance_case_severity_codes(self, case_id): def advance_case_severity_codes(self, case_id: str) -> None:
""" """
Fake an advancement through case status severities Fake an advancement through case status severities
""" """
@ -137,28 +136,26 @@ class SupportBackend(BaseBackend):
elif self.cases[case_id].severity_code == "critical": elif self.cases[case_id].severity_code == "critical":
self.cases[case_id].severity_code = "low" self.cases[case_id].severity_code = "low"
def resolve_case(self, case_id): def resolve_case(self, case_id: str) -> Dict[str, Optional[str]]:
self.advance_case_status(case_id) self.advance_case_status(case_id)
resolved_case = { return {
"initialCaseStatus": self.cases[case_id].status, "initialCaseStatus": self.cases[case_id].status,
"finalCaseStatus": "resolved", "finalCaseStatus": "resolved",
} }
return resolved_case
# persist case details to self.cases # persist case details to self.cases
def create_case( def create_case(
self, self,
subject, subject: str,
service_code, service_code: str,
severity_code, severity_code: str,
category_code, category_code: str,
communication_body, communication_body: str,
cc_email_addresses, cc_email_addresses: List[str],
language, language: str,
attachment_set_id, attachment_set_id: str,
): ) -> Dict[str, str]:
""" """
The IssueType-parameter is not yet implemented The IssueType-parameter is not yet implemented
""" """
@ -184,11 +181,11 @@ class SupportBackend(BaseBackend):
def describe_cases( def describe_cases(
self, self,
case_id_list, case_id_list: List[str],
include_resolved_cases, include_resolved_cases: bool,
next_token, next_token: Optional[str],
include_communications, include_communications: bool,
): ) -> Dict[str, Any]:
""" """
The following parameters have not yet been implemented: The following parameters have not yet been implemented:
DisplayID, AfterTime, BeforeTime, MaxResults, Language DisplayID, AfterTime, BeforeTime, MaxResults, Language
@ -223,10 +220,7 @@ class SupportBackend(BaseBackend):
continue continue
cases.append(formatted_case) cases.append(formatted_case)
case_values = {"cases": cases} return {"cases": cases, "nextToken": next_token}
case_values.update({"nextToken": next_token})
return case_values
support_backends = BackendDict( support_backends = BackendDict(

View File

@ -1,33 +1,33 @@
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import support_backends from .models import support_backends, SupportBackend
import json import json
class SupportResponse(BaseResponse): class SupportResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="support") super().__init__(service_name="support")
@property @property
def support_backend(self): def support_backend(self) -> SupportBackend:
return support_backends[self.current_account][self.region] return support_backends[self.current_account][self.region]
def describe_trusted_advisor_checks(self): def describe_trusted_advisor_checks(self) -> str:
checks = self.support_backend.describe_trusted_advisor_checks() checks = self.support_backend.describe_trusted_advisor_checks()
return json.dumps({"checks": checks}) return json.dumps({"checks": checks})
def refresh_trusted_advisor_check(self): def refresh_trusted_advisor_check(self) -> str:
check_id = self._get_param("checkId") check_id = self._get_param("checkId")
status = self.support_backend.refresh_trusted_advisor_check(check_id=check_id) status = self.support_backend.refresh_trusted_advisor_check(check_id=check_id)
return json.dumps(status) return json.dumps(status)
def resolve_case(self): def resolve_case(self) -> str:
case_id = self._get_param("caseId") case_id = self._get_param("caseId")
resolve_case_response = self.support_backend.resolve_case(case_id=case_id) resolve_case_response = self.support_backend.resolve_case(case_id=case_id)
return json.dumps(resolve_case_response) return json.dumps(resolve_case_response)
def create_case(self): def create_case(self) -> str:
subject = self._get_param("subject") subject = self._get_param("subject")
service_code = self._get_param("serviceCode") service_code = self._get_param("serviceCode")
severity_code = self._get_param("severityCode") severity_code = self._get_param("severityCode")
@ -49,7 +49,7 @@ class SupportResponse(BaseResponse):
return json.dumps(create_case_response) return json.dumps(create_case_response)
def describe_cases(self): def describe_cases(self) -> str:
case_id_list = self._get_param("caseIdList") case_id_list = self._get_param("caseIdList")
include_resolved_cases = self._get_param("includeResolvedCases", False) include_resolved_cases = self._get_param("includeResolvedCases", False)
next_token = self._get_param("nextToken") next_token = self._get_param("nextToken")

View File

@ -239,7 +239,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy] [mypy]
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s3*,moto/sagemaker,moto/secretsmanager,moto/ses,moto/sqs,moto/ssm,moto/scheduler,moto/swf,moto/sns files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s*
show_column_numbers=True show_column_numbers=True
show_error_codes = True show_error_codes = True
disable_error_code=abstract disable_error_code=abstract