From 6087a203fd491d11060e9f8795762299cbfc45e4 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Mon, 20 Mar 2023 13:58:49 -0100 Subject: [PATCH] Techdebt: MyPy I (#6092) --- moto/identitystore/models.py | 14 +- moto/identitystore/responses.py | 20 +- moto/instance_metadata/responses.py | 13 +- moto/iot/exceptions.py | 19 +- moto/iot/models.py | 625 ++++++++++++++++------------ moto/iot/responses.py | 178 ++++---- moto/iotdata/exceptions.py | 6 +- moto/iotdata/models.py | 61 +-- moto/iotdata/responses.py | 14 +- setup.cfg | 2 +- 10 files changed, 535 insertions(+), 417 deletions(-) diff --git a/moto/identitystore/models.py b/moto/identitystore/models.py index 287610f21..397150095 100644 --- a/moto/identitystore/models.py +++ b/moto/identitystore/models.py @@ -1,19 +1,19 @@ -"""IdentityStoreBackend class with methods for supported APIs.""" -from moto.moto_api._internal import mock_random +from typing import Dict, Tuple +from moto.moto_api._internal import mock_random from moto.core import BaseBackend, BackendDict class IdentityStoreBackend(BaseBackend): """Implementation of IdentityStore APIs.""" - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.groups = {} + self.groups: Dict[str, Dict[str, str]] = {} - # add methods from here - - def create_group(self, identity_store_id, display_name, description): + def create_group( + self, identity_store_id: str, display_name: str, description: str + ) -> Tuple[str, str]: group_id = str(mock_random.uuid4()) group_dict = { "GroupId": group_id, diff --git a/moto/identitystore/responses.py b/moto/identitystore/responses.py index 0308631a2..40d764559 100644 --- a/moto/identitystore/responses.py +++ b/moto/identitystore/responses.py @@ -2,33 +2,27 @@ import json from moto.core.responses import BaseResponse -from .models import identitystore_backends +from .models import identitystore_backends, IdentityStoreBackend class IdentityStoreResponse(BaseResponse): """Handler for IdentityStore requests and responses.""" - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="identitystore") @property - def identitystore_backend(self): + def identitystore_backend(self) -> IdentityStoreBackend: """Return backend instance specific for this region.""" return identitystore_backends[self.current_account][self.region] - # add methods from here - - def create_group(self): - params = self._get_params() - identity_store_id = params.get("IdentityStoreId") - display_name = params.get("DisplayName") - description = params.get("Description") + def create_group(self) -> str: + identity_store_id = self._get_param("IdentityStoreId") + display_name = self._get_param("DisplayName") + description = self._get_param("Description") group_id, identity_store_id = self.identitystore_backend.create_group( identity_store_id=identity_store_id, display_name=display_name, description=description, ) return json.dumps(dict(GroupId=group_id, IdentityStoreId=identity_store_id)) - - def _get_params(self): - return json.loads(self.body) diff --git a/moto/instance_metadata/responses.py b/moto/instance_metadata/responses.py index 6b11aeedd..91c74ce66 100644 --- a/moto/instance_metadata/responses.py +++ b/moto/instance_metadata/responses.py @@ -1,20 +1,25 @@ import datetime import json +from typing import Any from urllib.parse import urlparse +from moto.core.common_types import TYPE_RESPONSE from moto.core.responses import BaseResponse class InstanceMetadataResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name=None) - def backends(self): + def backends(self) -> None: pass def metadata_response( - self, request, full_url, headers - ): # pylint: disable=unused-argument + self, + request: Any, # pylint: disable=unused-argument + full_url: str, + headers: Any, + ) -> TYPE_RESPONSE: """ Mock response for localhost metadata diff --git a/moto/iot/exceptions.py b/moto/iot/exceptions.py index 00e613bfe..b12c481e4 100644 --- a/moto/iot/exceptions.py +++ b/moto/iot/exceptions.py @@ -1,4 +1,5 @@ import json +from typing import Optional from moto.core.exceptions import JsonRESTError @@ -8,7 +9,7 @@ class IoTClientError(JsonRESTError): class ResourceNotFoundException(IoTClientError): - def __init__(self, msg=None): + def __init__(self, msg: Optional[str] = None): self.code = 404 super().__init__( "ResourceNotFoundException", msg or "The specified resource does not exist" @@ -16,13 +17,13 @@ class ResourceNotFoundException(IoTClientError): class InvalidRequestException(IoTClientError): - def __init__(self, msg=None): + def __init__(self, msg: Optional[str] = None): self.code = 400 super().__init__("InvalidRequestException", msg or "The request is not valid.") class InvalidStateTransitionException(IoTClientError): - def __init__(self, msg=None): + def __init__(self, msg: Optional[str] = None): self.code = 409 super().__init__( "InvalidStateTransitionException", @@ -31,7 +32,7 @@ class InvalidStateTransitionException(IoTClientError): class VersionConflictException(IoTClientError): - def __init__(self, name): + def __init__(self, name: str): self.code = 409 super().__init__( "VersionConflictException", @@ -40,19 +41,19 @@ class VersionConflictException(IoTClientError): class CertificateStateException(IoTClientError): - def __init__(self, msg, cert_id): + def __init__(self, msg: str, cert_id: str): self.code = 406 super().__init__("CertificateStateException", f"{msg} Id: {cert_id}") class DeleteConflictException(IoTClientError): - def __init__(self, msg): + def __init__(self, msg: str): self.code = 409 super().__init__("DeleteConflictException", msg) class ResourceAlreadyExistsException(IoTClientError): - def __init__(self, msg, resource_id, resource_arn): + def __init__(self, msg: str, resource_id: str, resource_arn: str): self.code = 409 super().__init__( "ResourceAlreadyExistsException", msg or "The resource already exists." @@ -67,7 +68,7 @@ class ResourceAlreadyExistsException(IoTClientError): class VersionsLimitExceededException(IoTClientError): - def __init__(self, name): + def __init__(self, name: str): self.code = 409 super().__init__( "VersionsLimitExceededException", @@ -76,7 +77,7 @@ class VersionsLimitExceededException(IoTClientError): class ThingStillAttached(IoTClientError): - def __init__(self, name): + def __init__(self, name: str): self.code = 409 super().__init__( "InvalidRequestException", diff --git a/moto/iot/models.py b/moto/iot/models.py index 68656778b..69715bf14 100644 --- a/moto/iot/models.py +++ b/moto/iot/models.py @@ -7,8 +7,8 @@ from cryptography.hazmat._oid import NameOID from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives import serialization, hashes - from datetime import datetime, timedelta +from typing import Any, Dict, List, Tuple, Optional, Pattern, Iterable from .utils import PAGINATION_MODEL @@ -29,7 +29,14 @@ from .exceptions import ( class FakeThing(BaseModel): - def __init__(self, thing_name, thing_type, attributes, account_id, region_name): + def __init__( + self, + thing_name: str, + thing_type: Optional["FakeThingType"], + attributes: Dict[str, Any], + account_id: str, + region_name: str, + ): self.region_name = region_name self.thing_name = thing_name self.thing_type = thing_type @@ -39,20 +46,20 @@ class FakeThing(BaseModel): # TODO: we need to handle "version"? # for iot-data - self.thing_shadow = None + self.thing_shadow: Any = None - def matches(self, query_string): + def matches(self, query_string: str) -> bool: if query_string == "*": return True if query_string.startswith("thingName:"): qs = query_string[10:].replace("*", ".*").replace("?", ".") - return re.search(f"^{qs}$", self.thing_name) + return re.search(f"^{qs}$", self.thing_name) is not None if query_string.startswith("attributes."): k, v = query_string[11:].split(":") return self.attributes.get(k) == v return query_string in self.thing_name - def to_dict(self, include_default_client_id=False): + def to_dict(self, include_default_client_id: bool = False) -> Dict[str, Any]: obj = { "thingName": self.thing_name, "thingArn": self.arn, @@ -67,7 +74,12 @@ class FakeThing(BaseModel): class FakeThingType(BaseModel): - def __init__(self, thing_type_name, thing_type_properties, region_name): + def __init__( + self, + thing_type_name: str, + thing_type_properties: Optional[Dict[str, Any]], + region_name: str, + ): self.region_name = region_name self.thing_type_name = thing_type_name self.thing_type_properties = thing_type_properties @@ -76,7 +88,7 @@ class FakeThingType(BaseModel): self.metadata = {"deprecated": False, "creationDate": int(t * 1000) / 1000.0} self.arn = f"arn:aws:iot:{self.region_name}:1:thingtype/{thing_type_name}" - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { "thingTypeName": self.thing_type_name, "thingTypeId": self.thing_type_id, @@ -89,11 +101,11 @@ class FakeThingType(BaseModel): class FakeThingGroup(BaseModel): def __init__( self, - thing_group_name, - parent_group_name, - thing_group_properties, - region_name, - thing_groups, + thing_group_name: str, + parent_group_name: str, + thing_group_properties: Dict[str, str], + region_name: str, + thing_groups: Dict[str, "FakeThingGroup"], ): self.region_name = region_name self.thing_group_name = thing_group_name @@ -102,7 +114,7 @@ class FakeThingGroup(BaseModel): self.parent_group_name = parent_group_name self.thing_group_properties = thing_group_properties or {} t = time.time() - self.metadata = {"creationDate": int(t * 1000) / 1000.0} + self.metadata: Dict[str, Any] = {"creationDate": int(t * 1000) / 1000.0} if parent_group_name: self.metadata["parentGroupName"] = parent_group_name # initilize rootToParentThingGroups @@ -124,14 +136,14 @@ class FakeThingGroup(BaseModel): [ { "groupName": parent_group_name, - "groupArn": parent_thing_group_structure.arn, + "groupArn": parent_thing_group_structure.arn, # type: ignore } ] ) self.arn = f"arn:aws:iot:{self.region_name}:1:thinggroup/{thing_group_name}" - self.things = OrderedDict() + self.things: Dict[str, FakeThing] = OrderedDict() - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { "thingGroupName": self.thing_group_name, "thingGroupId": self.thing_group_id, @@ -144,7 +156,12 @@ class FakeThingGroup(BaseModel): class FakeCertificate(BaseModel): def __init__( - self, certificate_pem, status, account_id, region_name, ca_certificate_id=None + self, + certificate_pem: str, + status: str, + account_id: str, + region_name: str, + ca_certificate_id: Optional[str] = None, ): m = hashlib.sha256() m.update(certificate_pem.encode("utf-8")) @@ -154,14 +171,14 @@ class FakeCertificate(BaseModel): self.status = status self.owner = account_id - self.transfer_data = {} + self.transfer_data: Dict[str, str] = {} self.creation_date = time.time() self.last_modified_date = self.creation_date self.validity_not_before = time.time() - 86400 self.validity_not_after = time.time() + 86400 self.ca_certificate_id = ca_certificate_id - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { "certificateArn": self.arn, "certificateId": self.certificate_id, @@ -170,7 +187,7 @@ class FakeCertificate(BaseModel): "creationDate": self.creation_date, } - def to_description_dict(self): + def to_description_dict(self) -> Dict[str, Any]: """ You might need keys below in some situation - caCertificateId @@ -194,7 +211,12 @@ class FakeCertificate(BaseModel): class FakeCaCertificate(FakeCertificate): def __init__( - self, ca_certificate, status, account_id, region_name, registration_config + self, + ca_certificate: str, + status: str, + account_id: str, + region_name: str, + registration_config: Dict[str, str], ): super().__init__( certificate_pem=ca_certificate, @@ -207,7 +229,14 @@ class FakeCaCertificate(FakeCertificate): class FakePolicy(BaseModel): - def __init__(self, name, document, account_id, region_name, default_version_id="1"): + def __init__( + self, + name: str, + document: Dict[str, Any], + account_id: str, + region_name: str, + default_version_id: str = "1", + ): self.name = name self.document = document self.arn = f"arn:aws:iot:{region_name}:{account_id}:policy/{name}" @@ -217,7 +246,7 @@ class FakePolicy(BaseModel): ] self._max_version_id = self.versions[0]._version_id - def to_get_dict(self): + def to_get_dict(self) -> Dict[str, Any]: return { "policyName": self.name, "policyArn": self.arn, @@ -225,7 +254,7 @@ class FakePolicy(BaseModel): "defaultVersionId": self.default_version_id, } - def to_dict_at_creation(self): + def to_dict_at_creation(self) -> Dict[str, Any]: return { "policyName": self.name, "policyArn": self.arn, @@ -233,13 +262,19 @@ class FakePolicy(BaseModel): "policyVersionId": self.default_version_id, } - def to_dict(self): + def to_dict(self) -> Dict[str, str]: return {"policyName": self.name, "policyArn": self.arn} -class FakePolicyVersion(object): +class FakePolicyVersion: def __init__( - self, policy_name, document, is_default, account_id, region_name, version_id=1 + self, + policy_name: str, + document: Dict[str, Any], + is_default: bool, + account_id: str, + region_name: str, + version_id: int = 1, ): self.name = policy_name self.arn = f"arn:aws:iot:{region_name}:{account_id}:policy/{policy_name}" @@ -251,10 +286,10 @@ class FakePolicyVersion(object): self.last_modified_datetime = time.mktime(datetime(2015, 1, 2).timetuple()) @property - def version_id(self): + def version_id(self) -> str: return str(self._version_id) - def to_get_dict(self): + def to_get_dict(self) -> Dict[str, Any]: return { "policyName": self.name, "policyArn": self.arn, @@ -266,7 +301,7 @@ class FakePolicyVersion(object): "generationId": self.version_id, } - def to_dict_at_creation(self): + def to_dict_at_creation(self) -> Dict[str, Any]: return { "policyArn": self.arn, "policyDocument": self.document, @@ -274,7 +309,7 @@ class FakePolicyVersion(object): "isDefaultVersion": self.is_default, } - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { "versionId": self.version_id, "isDefaultVersion": self.is_default, @@ -288,16 +323,16 @@ class FakeJob(BaseModel): def __init__( self, - job_id, - targets, - document_source, - document, - description, - presigned_url_config, - target_selection, - job_executions_rollout_config, - document_parameters, - region_name, + job_id: str, + targets: List[str], + document_source: str, + document: str, + description: str, + presigned_url_config: Dict[str, Any], + target_selection: str, + job_executions_rollout_config: Dict[str, Any], + document_parameters: Dict[str, str], + region_name: str, ): if not self._job_id_matcher(self.JOB_ID_REGEX, job_id): raise InvalidRequestException() @@ -314,8 +349,8 @@ class FakeJob(BaseModel): self.target_selection = target_selection self.job_executions_rollout_config = job_executions_rollout_config self.status = "QUEUED" # IN_PROGRESS | CANCELED | COMPLETED - self.comment = None - self.reason_code = None + self.comment: Optional[str] = None + self.reason_code: Optional[str] = None self.created_at = time.mktime(datetime(2015, 1, 1).timetuple()) self.last_updated_at = time.mktime(datetime(2015, 1, 1).timetuple()) self.completed_at = None @@ -331,7 +366,7 @@ class FakeJob(BaseModel): } self.document_parameters = document_parameters - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: obj = { "jobArn": self.job_arn, "jobId": self.job_id, @@ -355,8 +390,8 @@ class FakeJob(BaseModel): return obj - def _job_id_matcher(self, regex, argument): - regex_match = regex.match(argument) + def _job_id_matcher(self, regex: Pattern[str], argument: str) -> bool: + regex_match = regex.match(argument) is not None length_match = len(argument) <= 64 return regex_match and length_match @@ -364,11 +399,11 @@ class FakeJob(BaseModel): class FakeJobExecution(BaseModel): def __init__( self, - job_id, - thing_arn, - status="QUEUED", - force_canceled=False, - status_details_map=None, + job_id: str, + thing_arn: str, + status: str = "QUEUED", + force_canceled: bool = False, + status_details_map: Optional[Dict[str, Any]] = None, ): self.job_id = job_id self.status = status # IN_PROGRESS | CANCELED | COMPLETED @@ -382,8 +417,8 @@ class FakeJobExecution(BaseModel): self.version_number = 123 self.approximate_seconds_before_time_out = 123 - def to_get_dict(self): - obj = { + def to_get_dict(self) -> Dict[str, Any]: + return { "jobId": self.job_id, "status": self.status, "forceCanceled": self.force_canceled, @@ -397,10 +432,8 @@ class FakeJobExecution(BaseModel): "approximateSecondsBeforeTimedOut": self.approximate_seconds_before_time_out, } - return obj - - def to_dict(self): - obj = { + def to_dict(self) -> Dict[str, Any]: + return { "jobId": self.job_id, "thingArn": self.thing_arn, "jobExecutionSummary": { @@ -412,11 +445,9 @@ class FakeJobExecution(BaseModel): }, } - return obj - class FakeEndpoint(BaseModel): - def __init__(self, endpoint_type, region_name): + def __init__(self, endpoint_type: str, region_name: str): if endpoint_type not in [ "iot:Data", "iot:Data-ATS", @@ -441,34 +472,30 @@ class FakeEndpoint(BaseModel): self.endpoint = f"{identifier}.jobs.iot.{self.region_name}.amazonaws.com" self.endpoint_type = endpoint_type - def to_get_dict(self): - obj = { + def to_get_dict(self) -> Dict[str, str]: + return { "endpointAddress": self.endpoint, } - return obj - - def to_dict(self): - obj = { + def to_dict(self) -> Dict[str, str]: + return { "endpointAddress": self.endpoint, } - return obj - class FakeRule(BaseModel): def __init__( self, - rule_name, - description, - created_at, - rule_disabled, - topic_pattern, - actions, - error_action, - sql, - aws_iot_sql_version, - region_name, + rule_name: str, + description: str, + created_at: int, + rule_disabled: bool, + topic_pattern: Optional[str], + actions: List[Dict[str, Any]], + error_action: Dict[str, Any], + sql: str, + aws_iot_sql_version: str, + region_name: str, ): self.region_name = region_name self.rule_name = rule_name @@ -482,7 +509,7 @@ class FakeRule(BaseModel): self.aws_iot_sql_version = aws_iot_sql_version or "2016-03-23" self.arn = f"arn:aws:iot:{self.region_name}:1:rule/{rule_name}" - def to_get_dict(self): + def to_get_dict(self) -> Dict[str, Any]: return { "rule": { "actions": self.actions, @@ -497,7 +524,7 @@ class FakeRule(BaseModel): "ruleArn": self.arn, } - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { "ruleName": self.rule_name, "createdAt": self.created_at, @@ -510,14 +537,14 @@ class FakeRule(BaseModel): class FakeDomainConfiguration(BaseModel): def __init__( self, - region_name, - domain_configuration_name, - domain_name, - server_certificate_arns, - domain_configuration_status, - service_type, - authorizer_config, - domain_type, + region_name: str, + domain_configuration_name: str, + domain_name: str, + server_certificate_arns: List[str], + domain_configuration_status: str, + service_type: str, + authorizer_config: Optional[Dict[str, Any]], + domain_type: str, ): if service_type and service_type not in ["DATA", "CREDENTIAL_PROVIDER", "JOBS"]: raise InvalidRequestException( @@ -539,7 +566,7 @@ class FakeDomainConfiguration(BaseModel): self.domain_type = domain_type self.last_status_change_date = time.time() - def to_description_dict(self): + def to_description_dict(self) -> Dict[str, Any]: return { "domainConfigurationName": self.domain_configuration_name, "domainConfigurationArn": self.domain_configuration_arn, @@ -552,7 +579,7 @@ class FakeDomainConfiguration(BaseModel): "lastStatusChangeDate": self.last_status_change_date, } - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { "domainConfigurationName": self.domain_configuration_name, "domainConfigurationArn": self.domain_configuration_arn, @@ -560,24 +587,30 @@ class FakeDomainConfiguration(BaseModel): class IoTBackend(BaseBackend): - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.things = OrderedDict() - self.jobs = OrderedDict() - self.job_executions = OrderedDict() - self.thing_types = OrderedDict() - self.thing_groups = OrderedDict() - self.ca_certificates = OrderedDict() - self.certificates = OrderedDict() - self.policies = OrderedDict() - self.principal_policies = OrderedDict() - self.principal_things = OrderedDict() - self.rules = OrderedDict() - self.endpoint = None - self.domain_configurations = OrderedDict() + self.things: Dict[str, FakeThing] = OrderedDict() + self.jobs: Dict[str, FakeJob] = OrderedDict() + self.job_executions: Dict[Tuple[str, str], FakeJobExecution] = OrderedDict() + self.thing_types: Dict[str, FakeThingType] = OrderedDict() + self.thing_groups: Dict[str, FakeThingGroup] = OrderedDict() + self.ca_certificates: Dict[str, FakeCaCertificate] = OrderedDict() + self.certificates: Dict[str, FakeCertificate] = OrderedDict() + self.policies: Dict[str, FakePolicy] = OrderedDict() + self.principal_policies: Dict[ + Tuple[str, str], Tuple[str, FakePolicy] + ] = OrderedDict() + self.principal_things: Dict[ + Tuple[str, str], Tuple[str, FakeThing] + ] = OrderedDict() + self.rules: Dict[str, FakeRule] = OrderedDict() + self.endpoint: Optional[FakeEndpoint] = None + self.domain_configurations: Dict[str, FakeDomainConfiguration] = OrderedDict() @staticmethod - def default_vpc_endpoint_service(service_region, zones): + def default_vpc_endpoint_service( + service_region: str, zones: List[str] + ) -> List[Dict[str, str]]: """Default VPC endpoint service.""" return BaseBackend.default_vpc_endpoint_service_factory( service_region, zones, "iot" @@ -590,7 +623,9 @@ class IoTBackend(BaseBackend): policy_supported=False, ) - def create_certificate_from_csr(self, csr, set_as_active): + def create_certificate_from_csr( + self, csr: str, set_as_active: bool + ) -> FakeCertificate: cert = x509.load_pem_x509_csr(csr.encode("utf-8"), default_backend()) pem = self._generate_certificate_pem( domain_name="example.com", subject=cert.subject @@ -599,11 +634,13 @@ class IoTBackend(BaseBackend): pem, ca_certificate_pem=None, set_as_active=set_as_active, status="INACTIVE" ) - def _generate_certificate_pem(self, domain_name, subject, key=None): - sans = set() - - sans.add(domain_name) - sans = [x509.DNSName(item) for item in sans] + def _generate_certificate_pem( + self, + domain_name: str, + subject: x509.Name, + key: Optional[rsa.RSAPrivateKey] = None, + ) -> str: + sans = [x509.DNSName(domain_name)] key = key or rsa.generate_private_key( public_exponent=65537, key_size=2048, backend=default_backend() @@ -632,7 +669,12 @@ class IoTBackend(BaseBackend): return cert.public_bytes(serialization.Encoding.PEM).decode("utf-8") - def create_thing(self, thing_name, thing_type_name, attribute_payload): + def create_thing( + self, + thing_name: str, + thing_type_name: str, + attribute_payload: Optional[Dict[str, Any]], + ) -> Tuple[str, str]: thing_types = self.list_thing_types() thing_type = None if thing_type_name: @@ -649,7 +691,7 @@ class IoTBackend(BaseBackend): msg=f"Can not create new thing with depreated thing type:{thing_type_name}" ) if attribute_payload is None: - attributes = {} + attributes: Dict[str, Any] = {} elif "attributes" not in attribute_payload: attributes = {} else: @@ -660,7 +702,9 @@ class IoTBackend(BaseBackend): self.things[thing.arn] = thing return thing.thing_name, thing.arn - def create_thing_type(self, thing_type_name, thing_type_properties): + def create_thing_type( + self, thing_type_name: str, thing_type_properties: Dict[str, Any] + ) -> Tuple[str, str]: if thing_type_properties is None: thing_type_properties = {} thing_type = FakeThingType( @@ -669,7 +713,9 @@ class IoTBackend(BaseBackend): self.thing_types[thing_type.arn] = thing_type return thing_type.thing_type_name, thing_type.arn - def list_thing_types(self, thing_type_name=None): + def list_thing_types( + self, thing_type_name: Optional[str] = None + ) -> Iterable[FakeThingType]: if thing_type_name: # It's weird but thing_type_name is filtered by forward match, not complete match return [ @@ -680,8 +726,13 @@ class IoTBackend(BaseBackend): return self.thing_types.values() def list_things( - self, attribute_name, attribute_value, thing_type_name, max_results, token - ): + self, + attribute_name: str, + attribute_value: str, + thing_type_name: str, + max_results: int, + token: Optional[str], + ) -> Tuple[Iterable[FakeThing], Optional[str]]: all_things = [_.to_dict() for _ in self.things.values()] if attribute_name is not None and thing_type_name is not None: filtered_things = list( @@ -718,23 +769,23 @@ class IoTBackend(BaseBackend): str(max_results) if len(filtered_things) > max_results else None ) else: - token = int(token) - things = filtered_things[token : token + max_results] + int_token = int(token) + things = filtered_things[int_token : int_token + max_results] next_token = ( - str(token + max_results) - if len(filtered_things) > token + max_results + str(int_token + max_results) + if len(filtered_things) > int_token + max_results else None ) return things, next_token - def describe_thing(self, thing_name): + def describe_thing(self, thing_name: str) -> FakeThing: things = [_ for _ in self.things.values() if _.thing_name == thing_name] if len(things) == 0: raise ResourceNotFoundException() return things[0] - def describe_thing_type(self, thing_type_name): + def describe_thing_type(self, thing_type_name: str) -> FakeThingType: thing_types = [ _ for _ in self.thing_types.values() if _.thing_type_name == thing_type_name ] @@ -742,11 +793,11 @@ class IoTBackend(BaseBackend): raise ResourceNotFoundException() return thing_types[0] - def describe_endpoint(self, endpoint_type): + def describe_endpoint(self, endpoint_type: str) -> FakeEndpoint: self.endpoint = FakeEndpoint(endpoint_type, self.region_name) return self.endpoint - def delete_thing(self, thing_name): + def delete_thing(self, thing_name: str) -> None: """ The ExpectedVersion-parameter is not yet implemented """ @@ -760,12 +811,14 @@ class IoTBackend(BaseBackend): del self.things[thing.arn] - def delete_thing_type(self, thing_type_name): + def delete_thing_type(self, thing_type_name: str) -> None: # can raise ResourceNotFoundError thing_type = self.describe_thing_type(thing_type_name) del self.thing_types[thing_type.arn] - def deprecate_thing_type(self, thing_type_name, undo_deprecate): + def deprecate_thing_type( + self, thing_type_name: str, undo_deprecate: bool + ) -> FakeThingType: thing_types = [ _ for _ in self.thing_types.values() if _.thing_type_name == thing_type_name ] @@ -776,17 +829,16 @@ class IoTBackend(BaseBackend): def update_thing( self, - thing_name, - thing_type_name, - attribute_payload, - remove_thing_type, - ): + thing_name: str, + thing_type_name: str, + attribute_payload: Optional[Dict[str, Any]], + remove_thing_type: bool, + ) -> None: """ The ExpectedVersion-parameter is not yet implemented """ # if attributes payload = {}, nothing thing = self.describe_thing(thing_name) - thing_type = None if remove_thing_type and thing_type_name: raise InvalidRequestException() @@ -820,7 +872,9 @@ class IoTBackend(BaseBackend): else: thing.attributes.update(attributes) - def create_keys_and_certificate(self, set_as_active): + def create_keys_and_certificate( + self, set_as_active: bool + ) -> Tuple[FakeCertificate, Dict[str, str]]: # implement here # caCertificate can be blank private_key = rsa.generate_private_key( @@ -852,17 +906,19 @@ class IoTBackend(BaseBackend): self.certificates[certificate.certificate_id] = certificate return certificate, key_pair - def delete_ca_certificate(self, certificate_id): + def delete_ca_certificate(self, certificate_id: str) -> None: cert = self.describe_ca_certificate(certificate_id) self._validation_delete(cert) del self.ca_certificates[certificate_id] - def delete_certificate(self, certificate_id, force_delete): + def delete_certificate(self, certificate_id: str, force_delete: bool) -> None: cert = self.describe_certificate(certificate_id) self._validation_delete(cert, force_delete) del self.certificates[certificate_id] - def _validation_delete(self, cert, force_delete: bool = False): + def _validation_delete( + self, cert: FakeCertificate, force_delete: bool = False + ) -> None: if cert.status == "ACTIVE": raise CertificateStateException( "Certificate must be deactivated (not ACTIVE) before deletion.", @@ -890,12 +946,12 @@ class IoTBackend(BaseBackend): % certs[0] ) - def describe_ca_certificate(self, certificate_id): + def describe_ca_certificate(self, certificate_id: str) -> FakeCaCertificate: if certificate_id not in self.ca_certificates: raise ResourceNotFoundException() return self.ca_certificates[certificate_id] - def describe_certificate(self, certificate_id): + def describe_certificate(self, certificate_id: str) -> FakeCertificate: certs = [ _ for _ in self.certificates.values() if _.certificate_id == certificate_id ] @@ -903,16 +959,16 @@ class IoTBackend(BaseBackend): raise ResourceNotFoundException() return certs[0] - def get_registration_code(self): + def get_registration_code(self) -> str: return str(random.uuid4()) - def list_certificates(self): + def list_certificates(self) -> Iterable[FakeCertificate]: """ Pagination is not yet implemented """ return self.certificates.values() - def list_certificates_by_ca(self, ca_certificate_id): + def list_certificates_by_ca(self, ca_certificate_id: str) -> List[FakeCertificate]: """ Pagination is not yet implemented """ @@ -922,7 +978,9 @@ class IoTBackend(BaseBackend): if cert.ca_certificate_id == ca_certificate_id ] - def __raise_if_certificate_already_exists(self, certificate_id, certificate_arn): + def __raise_if_certificate_already_exists( + self, certificate_id: str, certificate_arn: str + ) -> None: if certificate_id in self.certificates: raise ResourceAlreadyExistsException( "The certificate is already provisioned or registered", @@ -932,10 +990,10 @@ class IoTBackend(BaseBackend): def register_ca_certificate( self, - ca_certificate, - set_as_active, - registration_config, - ): + ca_certificate: str, + set_as_active: bool, + registration_config: Dict[str, str], + ) -> FakeCaCertificate: """ The VerificationCertificate-parameter is not yet implemented """ @@ -950,15 +1008,19 @@ class IoTBackend(BaseBackend): self.ca_certificates[certificate.certificate_id] = certificate return certificate - def _find_ca_certificate(self, ca_certificate_pem): + def _find_ca_certificate(self, ca_certificate_pem: Optional[str]) -> Optional[str]: for ca_cert in self.ca_certificates.values(): if ca_cert.certificate_pem == ca_certificate_pem: return ca_cert.certificate_id return None def register_certificate( - self, certificate_pem, ca_certificate_pem, set_as_active, status - ): + self, + certificate_pem: str, + ca_certificate_pem: Optional[str], + set_as_active: bool, + status: str, + ) -> FakeCertificate: ca_certificate_id = self._find_ca_certificate(ca_certificate_pem) certificate = FakeCertificate( certificate_pem, @@ -974,7 +1036,9 @@ class IoTBackend(BaseBackend): self.certificates[certificate.certificate_id] = certificate return certificate - def register_certificate_without_ca(self, certificate_pem, status): + def register_certificate_without_ca( + self, certificate_pem: str, status: str + ) -> FakeCertificate: certificate = FakeCertificate( certificate_pem, status, self.account_id, self.region_name ) @@ -985,7 +1049,12 @@ class IoTBackend(BaseBackend): self.certificates[certificate.certificate_id] = certificate return certificate - def update_ca_certificate(self, certificate_id, new_status, config): + def update_ca_certificate( + self, + certificate_id: str, + new_status: Optional[str], + config: Optional[Dict[str, str]], + ) -> None: """ The newAutoRegistrationStatus and removeAutoRegistration-parameters are not yet implemented """ @@ -995,12 +1064,14 @@ class IoTBackend(BaseBackend): if config is not None: cert.registration_config = config - def update_certificate(self, certificate_id, new_status): + def update_certificate(self, certificate_id: str, new_status: str) -> None: cert = self.describe_certificate(certificate_id) # TODO: validate new_status cert.status = new_status - def create_policy(self, policy_name, policy_document): + def create_policy( + self, policy_name: str, policy_document: Dict[str, Any] + ) -> FakePolicy: if policy_name in self.policies: current_policy = self.policies[policy_name] raise ResourceAlreadyExistsException( @@ -1014,7 +1085,7 @@ class IoTBackend(BaseBackend): self.policies[policy.name] = policy return policy - def attach_policy(self, policy_name, target): + def attach_policy(self, policy_name: str, target: str) -> None: principal = self._get_principal(target) policy = self.get_policy(policy_name) k = (target, policy_name) @@ -1022,8 +1093,8 @@ class IoTBackend(BaseBackend): return self.principal_policies[k] = (principal, policy) - def detach_policy(self, policy_name, target): - # this may raises ResourceNotFoundException + def detach_policy(self, policy_name: str, target: str) -> None: + # this may raise ResourceNotFoundException self._get_principal(target) self.get_policy(policy_name) @@ -1032,21 +1103,19 @@ class IoTBackend(BaseBackend): raise ResourceNotFoundException() del self.principal_policies[k] - def list_attached_policies(self, target): - policies = [v[1] for k, v in self.principal_policies.items() if k[0] == target] - return policies + def list_attached_policies(self, target: str) -> List[FakePolicy]: + return [v[1] for k, v in self.principal_policies.items() if k[0] == target] - def list_policies(self): - policies = self.policies.values() - return policies + def list_policies(self) -> Iterable[FakePolicy]: + return self.policies.values() - def get_policy(self, policy_name): + def get_policy(self, policy_name: str) -> FakePolicy: policies = [_ for _ in self.policies.values() if _.name == policy_name] if len(policies) == 0: raise ResourceNotFoundException() return policies[0] - def delete_policy(self, policy_name): + def delete_policy(self, policy_name: str) -> None: policies = [ k[1] for k, v in self.principal_policies.items() if k[1] == policy_name ] @@ -1064,7 +1133,9 @@ class IoTBackend(BaseBackend): ) del self.policies[policy.name] - def create_policy_version(self, policy_name, policy_document, set_as_default): + def create_policy_version( + self, policy_name: str, policy_document: Dict[str, Any], set_as_default: bool + ) -> FakePolicyVersion: policy = self.get_policy(policy_name) if not policy: raise ResourceNotFoundException() @@ -1085,7 +1156,7 @@ class IoTBackend(BaseBackend): self.set_default_policy_version(policy_name, version.version_id) return version - def set_default_policy_version(self, policy_name, version_id): + def set_default_policy_version(self, policy_name: str, version_id: str) -> None: policy = self.get_policy(policy_name) if not policy: raise ResourceNotFoundException() @@ -1093,11 +1164,13 @@ class IoTBackend(BaseBackend): if version.version_id == version_id: version.is_default = True policy.default_version_id = version.version_id - policy.document = version.document + policy.document = version.document # type: ignore else: version.is_default = False - def get_policy_version(self, policy_name, version_id): + def get_policy_version( + self, policy_name: str, version_id: str + ) -> FakePolicyVersion: policy = self.get_policy(policy_name) if not policy: raise ResourceNotFoundException() @@ -1106,13 +1179,13 @@ class IoTBackend(BaseBackend): return version raise ResourceNotFoundException() - def list_policy_versions(self, policy_name): + def list_policy_versions(self, policy_name: str) -> Iterable[FakePolicyVersion]: policy = self.get_policy(policy_name) if not policy: raise ResourceNotFoundException() return policy.versions - def delete_policy_version(self, policy_name, version_id): + def delete_policy_version(self, policy_name: str, version_id: str) -> None: policy = self.get_policy(policy_name) if not policy: raise ResourceNotFoundException() @@ -1126,7 +1199,7 @@ class IoTBackend(BaseBackend): return raise ResourceNotFoundException() - def _get_principal(self, principal_arn): + def _get_principal(self, principal_arn: str) -> Any: """ raise ResourceNotFoundException """ @@ -1157,7 +1230,7 @@ class IoTBackend(BaseBackend): raise ResourceNotFoundException() - def attach_principal_policy(self, policy_name, principal_arn): + def attach_principal_policy(self, policy_name: str, principal_arn: str) -> None: principal = self._get_principal(principal_arn) policy = self.get_policy(policy_name) k = (principal_arn, policy_name) @@ -1165,7 +1238,7 @@ class IoTBackend(BaseBackend): return self.principal_policies[k] = (principal, policy) - def detach_principal_policy(self, policy_name, principal_arn): + def detach_principal_policy(self, policy_name: str, principal_arn: str) -> None: # this may raises ResourceNotFoundException self._get_principal(principal_arn) self.get_policy(policy_name) @@ -1175,13 +1248,13 @@ class IoTBackend(BaseBackend): raise ResourceNotFoundException() del self.principal_policies[k] - def list_principal_policies(self, principal_arn): + def list_principal_policies(self, principal_arn: str) -> List[FakePolicy]: policies = [ v[1] for k, v in self.principal_policies.items() if k[0] == principal_arn ] return policies - def list_policy_principals(self, policy_name): + def list_policy_principals(self, policy_name: str) -> List[str]: # this action is deprecated # https://docs.aws.amazon.com/iot/latest/apireference/API_ListTargetsForPolicy.html # should use ListTargetsForPolicy instead @@ -1190,13 +1263,13 @@ class IoTBackend(BaseBackend): ] return principals - def list_targets_for_policy(self, policy_name): + def list_targets_for_policy(self, policy_name: str) -> List[str]: # This behaviour is different to list_policy_principals which will just return an empty list if policy_name not in self.policies: raise ResourceNotFoundException("Policy not found") return self.list_policy_principals(policy_name=policy_name) - def attach_thing_principal(self, thing_name, principal_arn): + def attach_thing_principal(self, thing_name: str, principal_arn: str) -> None: principal = self._get_principal(principal_arn) thing = self.describe_thing(thing_name) k = (principal_arn, thing_name) @@ -1204,7 +1277,7 @@ class IoTBackend(BaseBackend): return self.principal_things[k] = (principal, thing) - def detach_thing_principal(self, thing_name, principal_arn): + def detach_thing_principal(self, thing_name: str, principal_arn: str) -> None: # this may raises ResourceNotFoundException self._get_principal(principal_arn) self.describe_thing(thing_name) @@ -1214,13 +1287,13 @@ class IoTBackend(BaseBackend): raise ResourceNotFoundException() del self.principal_things[k] - def list_principal_things(self, principal_arn): + def list_principal_things(self, principal_arn: str) -> List[str]: thing_names = [ k[1] for k, v in self.principal_things.items() if k[0] == principal_arn ] return thing_names - def list_thing_principals(self, thing_name): + def list_thing_principals(self, thing_name: str) -> List[str]: things = [_ for _ in self.things.values() if _.thing_name == thing_name] if len(things) == 0: @@ -1234,7 +1307,7 @@ class IoTBackend(BaseBackend): ] return principals - def describe_thing_group(self, thing_group_name): + def describe_thing_group(self, thing_group_name: str) -> FakeThingGroup: thing_groups = [ _ for _ in self.thing_groups.values() @@ -1245,8 +1318,11 @@ class IoTBackend(BaseBackend): return thing_groups[0] def create_thing_group( - self, thing_group_name, parent_group_name, thing_group_properties - ): + self, + thing_group_name: str, + parent_group_name: str, + thing_group_properties: Dict[str, Any], + ) -> Tuple[str, str, str]: thing_group = FakeThingGroup( thing_group_name, parent_group_name, @@ -1271,7 +1347,7 @@ class IoTBackend(BaseBackend): self.thing_groups[thing_group.arn] = thing_group return thing_group.thing_group_name, thing_group.arn, thing_group.thing_group_id - def delete_thing_group(self, thing_group_name): + def delete_thing_group(self, thing_group_name: str) -> None: """ The ExpectedVersion-parameter is not yet implemented """ @@ -1293,7 +1369,12 @@ class IoTBackend(BaseBackend): # AWS returns success even if the thing group does not exist. pass - def list_thing_groups(self, parent_group, name_prefix_filter, recursive): + def list_thing_groups( + self, + parent_group: Optional[str], + name_prefix_filter: Optional[str], + recursive: Optional[bool], + ) -> List[FakeThingGroup]: if recursive is None: recursive = True if name_prefix_filter is None: @@ -1320,8 +1401,11 @@ class IoTBackend(BaseBackend): ] def update_thing_group( - self, thing_group_name, thing_group_properties, expected_version - ): + self, + thing_group_name: str, + thing_group_properties: Dict[str, Any], + expected_version: int, + ) -> int: thing_group = self.describe_thing_group(thing_group_name) if expected_version and expected_version != thing_group.version: raise VersionConflictException(thing_group_name) @@ -1331,17 +1415,15 @@ class IoTBackend(BaseBackend): attributes = attribute_payload["attributes"] if attributes: # might not exist yet, for example when the thing group was created without attributes - current_attribute_payload = ( - thing_group.thing_group_properties.setdefault( - "attributePayload", {"attributes": {}} - ) + current_attribute_payload = thing_group.thing_group_properties.setdefault( + "attributePayload", {"attributes": {}} # type: ignore ) if not do_merge: - current_attribute_payload["attributes"] = attributes + current_attribute_payload["attributes"] = attributes # type: ignore else: - current_attribute_payload["attributes"].update(attributes) + current_attribute_payload["attributes"].update(attributes) # type: ignore elif attribute_payload is not None and "attributes" not in attribute_payload: - thing_group.attributes = {} + thing_group.attributes = {} # type: ignore if "thingGroupDescription" in thing_group_properties: thing_group.thing_group_properties[ "thingGroupDescription" @@ -1349,7 +1431,9 @@ class IoTBackend(BaseBackend): thing_group.version = thing_group.version + 1 return thing_group.version - def _identify_thing_group(self, thing_group_name, thing_group_arn): + def _identify_thing_group( + self, thing_group_name: Optional[str], thing_group_arn: Optional[str] + ) -> FakeThingGroup: # identify thing group if thing_group_name is None and thing_group_arn is None: raise InvalidRequestException( @@ -1367,7 +1451,9 @@ class IoTBackend(BaseBackend): thing_group = self.thing_groups[thing_group_arn] return thing_group - def _identify_thing(self, thing_name, thing_arn): + def _identify_thing( + self, thing_name: Optional[str], thing_arn: Optional[str] + ) -> FakeThing: # identify thing if thing_name is None and thing_arn is None: raise InvalidRequestException( @@ -1386,8 +1472,12 @@ class IoTBackend(BaseBackend): return thing def add_thing_to_thing_group( - self, thing_group_name, thing_group_arn, thing_name, thing_arn - ): + self, + thing_group_name: str, + thing_group_arn: Optional[str], + thing_name: str, + thing_arn: Optional[str], + ) -> None: thing_group = self._identify_thing_group(thing_group_name, thing_group_arn) thing = self._identify_thing(thing_name, thing_arn) if thing.arn in thing_group.things: @@ -1396,8 +1486,12 @@ class IoTBackend(BaseBackend): thing_group.things[thing.arn] = thing def remove_thing_from_thing_group( - self, thing_group_name, thing_group_arn, thing_name, thing_arn - ): + self, + thing_group_name: str, + thing_group_arn: Optional[str], + thing_name: str, + thing_arn: Optional[str], + ) -> None: thing_group = self._identify_thing_group(thing_group_name, thing_group_arn) thing = self._identify_thing(thing_name, thing_arn) if thing.arn not in thing_group.things: @@ -1405,14 +1499,14 @@ class IoTBackend(BaseBackend): return del thing_group.things[thing.arn] - def list_things_in_thing_group(self, thing_group_name): + def list_things_in_thing_group(self, thing_group_name: str) -> Iterable[FakeThing]: """ Pagination and the recursive-parameter is not yet implemented """ thing_group = self.describe_thing_group(thing_group_name) return thing_group.things.values() - def list_thing_groups_for_thing(self, thing_name): + def list_thing_groups_for_thing(self, thing_name: str) -> List[Dict[str, str]]: """ Pagination is not yet implemented """ @@ -1430,8 +1524,11 @@ class IoTBackend(BaseBackend): return ret def update_thing_groups_for_thing( - self, thing_name, thing_groups_to_add, thing_groups_to_remove - ): + self, + thing_name: str, + thing_groups_to_add: List[str], + thing_groups_to_remove: List[str], + ) -> None: thing = self.describe_thing(thing_name) for thing_group_name in thing_groups_to_add: thing_group = self.describe_thing_group(thing_group_name) @@ -1446,16 +1543,16 @@ class IoTBackend(BaseBackend): def create_job( self, - job_id, - targets, - document_source, - document, - description, - presigned_url_config, - target_selection, - job_executions_rollout_config, - document_parameters, - ): + job_id: str, + targets: List[str], + document_source: str, + document: str, + description: str, + presigned_url_config: Dict[str, Any], + target_selection: str, + job_executions_rollout_config: Dict[str, Any], + document_parameters: Dict[str, str], + ) -> Tuple[str, str, str]: job = FakeJob( job_id, targets, @@ -1476,13 +1573,13 @@ class IoTBackend(BaseBackend): self.job_executions[(job_id, thing_name)] = job_execution return job.job_arn, job_id, description - def describe_job(self, job_id): + def describe_job(self, job_id: str) -> FakeJob: jobs = [_ for _ in self.jobs.values() if _.job_id == job_id] if len(jobs) == 0: raise ResourceNotFoundException() return jobs[0] - def delete_job(self, job_id, force): + def delete_job(self, job_id: str, force: bool) -> None: job = self.jobs[job_id] if job.status == "IN_PROGRESS" and force: @@ -1492,7 +1589,9 @@ class IoTBackend(BaseBackend): else: raise InvalidStateTransitionException() - def cancel_job(self, job_id, reason_code, comment, force): + def cancel_job( + self, job_id: str, reason_code: str, comment: str, force: bool + ) -> FakeJob: job = self.jobs[job_id] job.reason_code = reason_code if reason_code is not None else job.reason_code @@ -1509,10 +1608,12 @@ class IoTBackend(BaseBackend): return job - def get_job_document(self, job_id): + def get_job_document(self, job_id: str) -> FakeJob: return self.jobs[job_id] - def list_jobs(self, max_results, token): + def list_jobs( + self, max_results: int, token: Optional[str] + ) -> Tuple[List[Dict[str, Any]], Optional[str]]: """ The following parameter are not yet implemented: Status, TargetSelection, ThingGroupName, ThingGroupId """ @@ -1523,17 +1624,19 @@ class IoTBackend(BaseBackend): jobs = filtered_jobs[0:max_results] next_token = str(max_results) if len(filtered_jobs) > max_results else None else: - token = int(token) - jobs = filtered_jobs[token : token + max_results] + int_token = int(token) + jobs = filtered_jobs[int_token : int_token + max_results] next_token = ( - str(token + max_results) - if len(filtered_jobs) > token + max_results + str(int_token + max_results) + if len(filtered_jobs) > int_token + max_results else None ) return jobs, next_token - def describe_job_execution(self, job_id, thing_name, execution_number): + def describe_job_execution( + self, job_id: str, thing_name: str, execution_number: int + ) -> FakeJobExecution: try: job_execution = self.job_executions[(job_id, thing_name)] except KeyError: @@ -1547,7 +1650,7 @@ class IoTBackend(BaseBackend): return job_execution - def cancel_job_execution(self, job_id, thing_name, force): + def cancel_job_execution(self, job_id: str, thing_name: str, force: bool) -> None: """ The parameters ExpectedVersion and StatusDetails are not yet implemented """ @@ -1570,7 +1673,9 @@ class IoTBackend(BaseBackend): else: raise InvalidStateTransitionException() - def delete_job_execution(self, job_id, thing_name, execution_number, force): + def delete_job_execution( + self, job_id: str, thing_name: str, execution_number: int, force: bool + ) -> None: job_execution = self.job_executions[(job_id, thing_name)] if job_execution.execution_number != execution_number: @@ -1583,7 +1688,9 @@ class IoTBackend(BaseBackend): else: raise InvalidStateTransitionException() - def list_job_executions_for_job(self, job_id, status, max_results, next_token): + def list_job_executions_for_job( + self, job_id: str, status: str, max_results: int, token: Optional[str] + ) -> Tuple[List[Dict[str, Any]], Optional[str]]: job_executions = [ self.job_executions[je].to_dict() for je in self.job_executions @@ -1598,23 +1705,22 @@ class IoTBackend(BaseBackend): ) ) - token = next_token if token is None: job_executions = job_executions[0:max_results] next_token = str(max_results) if len(job_executions) > max_results else None else: - token = int(token) - job_executions = job_executions[token : token + max_results] + int_token = int(token) + job_executions = job_executions[int_token : int_token + max_results] next_token = ( - str(token + max_results) - if len(job_executions) > token + max_results + str(int_token + max_results) + if len(job_executions) > int_token + max_results else None ) return job_executions, next_token @paginate(PAGINATION_MODEL) - def list_job_executions_for_thing(self, thing_name, status): + def list_job_executions_for_thing(self, thing_name: str, status: Optional[str]) -> List[Dict[str, Any]]: # type: ignore[misc] job_executions = [ self.job_executions[je].to_dict() for je in self.job_executions @@ -1631,15 +1737,15 @@ class IoTBackend(BaseBackend): return job_executions - def list_topic_rules(self): + def list_topic_rules(self) -> List[Dict[str, Any]]: return [r.to_dict() for r in self.rules.values()] - def get_topic_rule(self, rule_name): + def get_topic_rule(self, rule_name: str) -> Dict[str, Any]: if rule_name not in self.rules: raise ResourceNotFoundException() return self.rules[rule_name].to_get_dict() - def create_topic_rule(self, rule_name, sql, **kwargs): + def create_topic_rule(self, rule_name: str, sql: str, **kwargs: Any) -> None: if rule_name in self.rules: raise ResourceAlreadyExistsException( "Rule with given name already exists", "", self.rules[rule_name].arn @@ -1655,33 +1761,33 @@ class IoTBackend(BaseBackend): **kwargs, ) - def replace_topic_rule(self, rule_name, **kwargs): + def replace_topic_rule(self, rule_name: str, **kwargs: Any) -> None: self.delete_topic_rule(rule_name) self.create_topic_rule(rule_name, **kwargs) - def delete_topic_rule(self, rule_name): + def delete_topic_rule(self, rule_name: str) -> None: if rule_name not in self.rules: raise ResourceNotFoundException() del self.rules[rule_name] - def enable_topic_rule(self, rule_name): + def enable_topic_rule(self, rule_name: str) -> None: if rule_name not in self.rules: raise ResourceNotFoundException() self.rules[rule_name].rule_disabled = False - def disable_topic_rule(self, rule_name): + def disable_topic_rule(self, rule_name: str) -> None: if rule_name not in self.rules: raise ResourceNotFoundException() self.rules[rule_name].rule_disabled = True def create_domain_configuration( self, - domain_configuration_name, - domain_name, - server_certificate_arns, - authorizer_config, - service_type, - ): + domain_configuration_name: str, + domain_name: str, + server_certificate_arns: List[str], + authorizer_config: Dict[str, Any], + service_type: str, + ) -> FakeDomainConfiguration: """ The ValidationCertificateArn-parameter is not yet implemented """ @@ -1707,26 +1813,28 @@ class IoTBackend(BaseBackend): ) return self.domain_configurations[domain_configuration_name] - def delete_domain_configuration(self, domain_configuration_name): + def delete_domain_configuration(self, domain_configuration_name: str) -> None: if domain_configuration_name not in self.domain_configurations: raise ResourceNotFoundException("The specified resource does not exist.") del self.domain_configurations[domain_configuration_name] - def describe_domain_configuration(self, domain_configuration_name): + def describe_domain_configuration( + self, domain_configuration_name: str + ) -> FakeDomainConfiguration: if domain_configuration_name not in self.domain_configurations: raise ResourceNotFoundException("The specified resource does not exist.") return self.domain_configurations[domain_configuration_name] - def list_domain_configurations(self): + def list_domain_configurations(self) -> List[Dict[str, Any]]: return [_.to_dict() for _ in self.domain_configurations.values()] def update_domain_configuration( self, - domain_configuration_name, - authorizer_config, - domain_configuration_status, - remove_authorizer_config, - ): + domain_configuration_name: str, + authorizer_config: Dict[str, Any], + domain_configuration_status: str, + remove_authorizer_config: Optional[bool], + ) -> FakeDomainConfiguration: if domain_configuration_name not in self.domain_configurations: raise ResourceNotFoundException("The specified resource does not exist.") domain_configuration = self.domain_configurations[domain_configuration_name] @@ -1740,15 +1848,14 @@ class IoTBackend(BaseBackend): domain_configuration.authorizer_config = None return domain_configuration - def search_index(self, query_string): + def search_index(self, query_string: str) -> List[Dict[str, Any]]: """ Pagination is not yet implemented. Only basic search queries are supported for now. """ things = [ thing for thing in self.things.values() if thing.matches(query_string) ] - groups = [] - return [t.to_dict() for t in things], groups + return [t.to_dict() for t in things] iot_backends = BackendDict(IoTBackend, "iot") diff --git a/moto/iot/responses.py b/moto/iot/responses.py index 5138677cc..8f6f52521 100644 --- a/moto/iot/responses.py +++ b/moto/iot/responses.py @@ -1,19 +1,21 @@ import json +from typing import Any from urllib.parse import unquote +from moto.core.common_types import TYPE_RESPONSE from moto.core.responses import BaseResponse -from .models import iot_backends +from .models import iot_backends, IoTBackend class IoTResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="iot") @property - def iot_backend(self): + def iot_backend(self) -> IoTBackend: return iot_backends[self.current_account][self.region] - def create_certificate_from_csr(self): + def create_certificate_from_csr(self) -> str: certificate_signing_request = self._get_param("certificateSigningRequest") set_as_active = self._get_param("setAsActive") cert = self.iot_backend.create_certificate_from_csr( @@ -27,7 +29,7 @@ class IoTResponse(BaseResponse): } ) - def create_thing(self): + def create_thing(self) -> str: thing_name = self._get_param("thingName") thing_type_name = self._get_param("thingTypeName") attribute_payload = self._get_param("attributePayload") @@ -38,7 +40,7 @@ class IoTResponse(BaseResponse): ) return json.dumps(dict(thingName=thing_name, thingArn=thing_arn)) - def create_thing_type(self): + def create_thing_type(self) -> str: thing_type_name = self._get_param("thingTypeName") thing_type_properties = self._get_param("thingTypeProperties") thing_type_name, thing_type_arn = self.iot_backend.create_thing_type( @@ -48,7 +50,7 @@ class IoTResponse(BaseResponse): dict(thingTypeName=thing_type_name, thingTypeArn=thing_type_arn) ) - def list_thing_types(self): + def list_thing_types(self) -> str: previous_next_token = self._get_param("nextToken") max_results = self._get_int_param( "maxResults", 50 @@ -71,7 +73,7 @@ class IoTResponse(BaseResponse): return json.dumps(dict(thingTypes=result, nextToken=next_token)) - def list_things(self): + def list_things(self) -> str: previous_next_token = self._get_param("nextToken") max_results = self._get_int_param( "maxResults", 50 @@ -89,34 +91,34 @@ class IoTResponse(BaseResponse): return json.dumps(dict(things=things, nextToken=next_token)) - def describe_thing(self): + def describe_thing(self) -> str: thing_name = self._get_param("thingName") thing = self.iot_backend.describe_thing(thing_name=thing_name) return json.dumps(thing.to_dict(include_default_client_id=True)) - def describe_thing_type(self): + def describe_thing_type(self) -> str: thing_type_name = self._get_param("thingTypeName") thing_type = self.iot_backend.describe_thing_type( thing_type_name=thing_type_name ) return json.dumps(thing_type.to_dict()) - def describe_endpoint(self): + def describe_endpoint(self) -> str: endpoint_type = self._get_param("endpointType", "iot:Data-ATS") endpoint = self.iot_backend.describe_endpoint(endpoint_type=endpoint_type) return json.dumps(endpoint.to_dict()) - def delete_thing(self): + def delete_thing(self) -> str: thing_name = self._get_param("thingName") self.iot_backend.delete_thing(thing_name=thing_name) return json.dumps(dict()) - def delete_thing_type(self): + def delete_thing_type(self) -> str: thing_type_name = self._get_param("thingTypeName") self.iot_backend.delete_thing_type(thing_type_name=thing_type_name) return json.dumps(dict()) - def deprecate_thing_type(self): + def deprecate_thing_type(self) -> str: thing_type_name = self._get_param("thingTypeName") undo_deprecate = self._get_param("undoDeprecate") thing_type = self.iot_backend.deprecate_thing_type( @@ -124,7 +126,7 @@ class IoTResponse(BaseResponse): ) return json.dumps(thing_type.to_dict()) - def update_thing(self): + def update_thing(self) -> str: thing_name = self._get_param("thingName") thing_type_name = self._get_param("thingTypeName") attribute_payload = self._get_param("attributePayload") @@ -137,7 +139,7 @@ class IoTResponse(BaseResponse): ) return json.dumps(dict()) - def create_job(self): + def create_job(self) -> str: job_arn, job_id, description = self.iot_backend.create_job( job_id=self._get_param("jobId"), targets=self._get_param("targets"), @@ -152,7 +154,7 @@ class IoTResponse(BaseResponse): return json.dumps(dict(jobArn=job_arn, jobId=job_id, description=description)) - def describe_job(self): + def describe_job(self) -> str: job = self.iot_backend.describe_job(job_id=self._get_param("jobId")) return json.dumps( dict( @@ -178,7 +180,7 @@ class IoTResponse(BaseResponse): ) ) - def delete_job(self): + def delete_job(self) -> str: job_id = self._get_param("jobId") force = self._get_bool_param("force") @@ -186,7 +188,7 @@ class IoTResponse(BaseResponse): return json.dumps(dict()) - def cancel_job(self): + def cancel_job(self) -> str: job_id = self._get_param("jobId") reason_code = self._get_param("reasonCode") comment = self._get_param("comment") @@ -198,7 +200,7 @@ class IoTResponse(BaseResponse): return json.dumps(job.to_dict()) - def get_job_document(self): + def get_job_document(self) -> str: job = self.iot_backend.get_job_document(job_id=self._get_param("jobId")) if job.document is not None: @@ -208,7 +210,7 @@ class IoTResponse(BaseResponse): # TODO: needs to be implemented to get document_source's content from S3 return json.dumps({"document": ""}) - def list_jobs(self): + def list_jobs(self) -> str: # not the default, but makes testing easier max_results = self._get_int_param("maxResults", 50) previous_next_token = self._get_param("nextToken") @@ -218,7 +220,7 @@ class IoTResponse(BaseResponse): return json.dumps(dict(jobs=jobs, nextToken=next_token)) - def describe_job_execution(self): + def describe_job_execution(self) -> str: job_id = self._get_param("jobId") thing_name = self._get_param("thingName") execution_number = self._get_int_param("executionNumber") @@ -228,7 +230,7 @@ class IoTResponse(BaseResponse): return json.dumps(dict(execution=job_execution.to_get_dict())) - def cancel_job_execution(self): + def cancel_job_execution(self) -> str: job_id = self._get_param("jobId") thing_name = self._get_param("thingName") force = self._get_bool_param("force") @@ -239,7 +241,7 @@ class IoTResponse(BaseResponse): return json.dumps(dict()) - def delete_job_execution(self): + def delete_job_execution(self) -> str: job_id = self._get_param("jobId") thing_name = self._get_param("thingName") execution_number = self._get_int_param("executionNumber") @@ -254,7 +256,7 @@ class IoTResponse(BaseResponse): return json.dumps(dict()) - def list_job_executions_for_job(self): + def list_job_executions_for_job(self) -> str: job_id = self._get_param("jobId") status = self._get_param("status") max_results = self._get_int_param( @@ -262,12 +264,12 @@ class IoTResponse(BaseResponse): ) # not the default, but makes testing easier next_token = self._get_param("nextToken") job_executions, next_token = self.iot_backend.list_job_executions_for_job( - job_id=job_id, status=status, max_results=max_results, next_token=next_token + job_id=job_id, status=status, max_results=max_results, token=next_token ) return json.dumps(dict(executionSummaries=job_executions, nextToken=next_token)) - def list_job_executions_for_thing(self): + def list_job_executions_for_thing(self) -> str: thing_name = self._get_param("thingName") status = self._get_param("status") max_results = self._get_int_param( @@ -283,7 +285,7 @@ class IoTResponse(BaseResponse): return json.dumps(dict(executionSummaries=job_executions, nextToken=next_token)) - def create_keys_and_certificate(self): + def create_keys_and_certificate(self) -> str: set_as_active = self._get_bool_param("setAsActive") cert, key_pair = self.iot_backend.create_keys_and_certificate( set_as_active=set_as_active @@ -297,18 +299,18 @@ class IoTResponse(BaseResponse): ) ) - def delete_ca_certificate(self): + def delete_ca_certificate(self) -> str: certificate_id = self.path.split("/")[-1] self.iot_backend.delete_ca_certificate(certificate_id=certificate_id) return json.dumps(dict()) - def delete_certificate(self): + def delete_certificate(self) -> str: certificate_id = self._get_param("certificateId") force_delete = self._get_bool_param("forceDelete", False) self.iot_backend.delete_certificate(certificate_id, force_delete) return json.dumps(dict()) - def describe_ca_certificate(self): + def describe_ca_certificate(self) -> str: certificate_id = self.path.split("/")[-1] certificate = self.iot_backend.describe_ca_certificate( certificate_id=certificate_id @@ -320,7 +322,7 @@ class IoTResponse(BaseResponse): } ) - def describe_certificate(self): + def describe_certificate(self) -> str: certificate_id = self._get_param("certificateId") certificate = self.iot_backend.describe_certificate( certificate_id=certificate_id @@ -329,23 +331,23 @@ class IoTResponse(BaseResponse): dict(certificateDescription=certificate.to_description_dict()) ) - def get_registration_code(self): + def get_registration_code(self) -> str: code = self.iot_backend.get_registration_code() return json.dumps(dict(registrationCode=code)) - def list_certificates(self): + def list_certificates(self) -> str: # page_size = self._get_int_param("pageSize") # marker = self._get_param("marker") # ascending_order = self._get_param("ascendingOrder") certificates = self.iot_backend.list_certificates() return json.dumps(dict(certificates=[_.to_dict() for _ in certificates])) - def list_certificates_by_ca(self): + def list_certificates_by_ca(self) -> str: ca_certificate_id = self._get_param("caCertificateId") certificates = self.iot_backend.list_certificates_by_ca(ca_certificate_id) return json.dumps(dict(certificates=[_.to_dict() for _ in certificates])) - def register_ca_certificate(self): + def register_ca_certificate(self) -> str: ca_certificate = self._get_param("caCertificate") set_as_active = self._get_bool_param("setAsActive") registration_config = self._get_param("registrationConfig") @@ -359,7 +361,7 @@ class IoTResponse(BaseResponse): dict(certificateId=cert.certificate_id, certificateArn=cert.arn) ) - def register_certificate(self): + def register_certificate(self) -> str: certificate_pem = self._get_param("certificatePem") ca_certificate_pem = self._get_param("caCertificatePem") set_as_active = self._get_bool_param("setAsActive") @@ -375,7 +377,7 @@ class IoTResponse(BaseResponse): dict(certificateId=cert.certificate_id, certificateArn=cert.arn) ) - def register_certificate_without_ca(self): + def register_certificate_without_ca(self) -> str: certificate_pem = self._get_param("certificatePem") status = self._get_param("status") @@ -386,7 +388,7 @@ class IoTResponse(BaseResponse): dict(certificateId=cert.certificate_id, certificateArn=cert.arn) ) - def update_ca_certificate(self): + def update_ca_certificate(self) -> str: certificate_id = self.path.split("/")[-1] new_status = self._get_param("newStatus") config = self._get_param("registrationConfig") @@ -395,7 +397,7 @@ class IoTResponse(BaseResponse): ) return json.dumps(dict()) - def update_certificate(self): + def update_certificate(self) -> str: certificate_id = self._get_param("certificateId") new_status = self._get_param("newStatus") self.iot_backend.update_certificate( @@ -403,7 +405,7 @@ class IoTResponse(BaseResponse): ) return json.dumps(dict()) - def create_policy(self): + def create_policy(self) -> str: policy_name = self._get_param("policyName") policy_document = self._get_param("policyDocument") policy = self.iot_backend.create_policy( @@ -411,7 +413,7 @@ class IoTResponse(BaseResponse): ) return json.dumps(policy.to_dict_at_creation()) - def list_policies(self): + def list_policies(self) -> str: # marker = self._get_param("marker") # page_size = self._get_int_param("pageSize") # ascending_order = self._get_param("ascendingOrder") @@ -420,17 +422,17 @@ class IoTResponse(BaseResponse): # TODO: implement pagination in the future return json.dumps(dict(policies=[_.to_dict() for _ in policies])) - def get_policy(self): + def get_policy(self) -> str: policy_name = self._get_param("policyName") policy = self.iot_backend.get_policy(policy_name=policy_name) return json.dumps(policy.to_get_dict()) - def delete_policy(self): + def delete_policy(self) -> str: policy_name = self._get_param("policyName") self.iot_backend.delete_policy(policy_name=policy_name) return json.dumps(dict()) - def create_policy_version(self): + def create_policy_version(self) -> str: policy_name = self._get_param("policyName") policy_document = self._get_param("policyDocument") set_as_default = self._get_bool_param("setAsDefault") @@ -440,20 +442,20 @@ class IoTResponse(BaseResponse): return json.dumps(dict(policy_version.to_dict_at_creation())) - def set_default_policy_version(self): + def set_default_policy_version(self) -> str: policy_name = self._get_param("policyName") version_id = self._get_param("policyVersionId") self.iot_backend.set_default_policy_version(policy_name, version_id) return json.dumps(dict()) - def get_policy_version(self): + def get_policy_version(self) -> str: policy_name = self._get_param("policyName") version_id = self._get_param("policyVersionId") policy_version = self.iot_backend.get_policy_version(policy_name, version_id) return json.dumps(dict(policy_version.to_get_dict())) - def list_policy_versions(self): + def list_policy_versions(self) -> str: policy_name = self._get_param("policyName") policiy_versions = self.iot_backend.list_policy_versions( policy_name=policy_name @@ -461,20 +463,22 @@ class IoTResponse(BaseResponse): return json.dumps(dict(policyVersions=[_.to_dict() for _ in policiy_versions])) - def delete_policy_version(self): + def delete_policy_version(self) -> str: policy_name = self._get_param("policyName") version_id = self._get_param("policyVersionId") self.iot_backend.delete_policy_version(policy_name, version_id) return json.dumps(dict()) - def attach_policy(self): + def attach_policy(self) -> str: policy_name = self._get_param("policyName") target = self._get_param("target") self.iot_backend.attach_policy(policy_name=policy_name, target=target) return json.dumps(dict()) - def dispatch_attached_policies(self, request, full_url, headers): + def dispatch_attached_policies( + self, request: Any, full_url: str, headers: Any + ) -> TYPE_RESPONSE: # This endpoint requires specialized handling because it has # a uri parameter containing forward slashes that is not # correctly url encoded when we're running in server mode. @@ -485,7 +489,7 @@ class IoTResponse(BaseResponse): self.querystring["target"] = [unquote(target)] if "%" in target else [target] return self.call_action() - def list_attached_policies(self): + def list_attached_policies(self) -> str: principal = self._get_param("target") # marker = self._get_param("marker") # page_size = self._get_int_param("pageSize") @@ -496,7 +500,7 @@ class IoTResponse(BaseResponse): dict(policies=[_.to_dict() for _ in policies], nextMarker=next_marker) ) - def attach_principal_policy(self): + def attach_principal_policy(self) -> str: policy_name = self._get_param("policyName") principal = self.headers.get("x-amzn-iot-principal") self.iot_backend.attach_principal_policy( @@ -504,13 +508,13 @@ class IoTResponse(BaseResponse): ) return json.dumps(dict()) - def detach_policy(self): + def detach_policy(self) -> str: policy_name = self._get_param("policyName") target = self._get_param("target") self.iot_backend.detach_policy(policy_name=policy_name, target=target) return json.dumps(dict()) - def detach_principal_policy(self): + def detach_principal_policy(self) -> str: policy_name = self._get_param("policyName") principal = self.headers.get("x-amzn-iot-principal") self.iot_backend.detach_principal_policy( @@ -518,7 +522,7 @@ class IoTResponse(BaseResponse): ) return json.dumps(dict()) - def list_principal_policies(self): + def list_principal_policies(self) -> str: principal = self.headers.get("x-amzn-iot-principal") # marker = self._get_param("marker") # page_size = self._get_int_param("pageSize") @@ -530,7 +534,7 @@ class IoTResponse(BaseResponse): dict(policies=[_.to_dict() for _ in policies], nextMarker=next_marker) ) - def list_policy_principals(self): + def list_policy_principals(self) -> str: policy_name = self.headers.get("x-amzn-iot-policy") # marker = self._get_param("marker") # page_size = self._get_int_param("pageSize") @@ -540,13 +544,13 @@ class IoTResponse(BaseResponse): next_marker = None return json.dumps(dict(principals=principals, nextMarker=next_marker)) - def list_targets_for_policy(self): + def list_targets_for_policy(self) -> str: """https://docs.aws.amazon.com/iot/latest/apireference/API_ListTargetsForPolicy.html""" policy_name = self._get_param("policyName") principals = self.iot_backend.list_targets_for_policy(policy_name=policy_name) return json.dumps(dict(targets=principals, nextMarker=None)) - def attach_thing_principal(self): + def attach_thing_principal(self) -> str: thing_name = self._get_param("thingName") principal = self.headers.get("x-amzn-principal") self.iot_backend.attach_thing_principal( @@ -554,7 +558,7 @@ class IoTResponse(BaseResponse): ) return json.dumps(dict()) - def detach_thing_principal(self): + def detach_thing_principal(self) -> str: thing_name = self._get_param("thingName") principal = self.headers.get("x-amzn-principal") self.iot_backend.detach_thing_principal( @@ -562,7 +566,7 @@ class IoTResponse(BaseResponse): ) return json.dumps(dict()) - def list_principal_things(self): + def list_principal_things(self) -> str: next_token = self._get_param("nextToken") # max_results = self._get_int_param("maxResults") principal = self.headers.get("x-amzn-principal") @@ -571,19 +575,19 @@ class IoTResponse(BaseResponse): next_token = None return json.dumps(dict(things=things, nextToken=next_token)) - def list_thing_principals(self): + def list_thing_principals(self) -> str: thing_name = self._get_param("thingName") principals = self.iot_backend.list_thing_principals(thing_name=thing_name) return json.dumps(dict(principals=principals)) - def describe_thing_group(self): + def describe_thing_group(self) -> str: thing_group_name = self._get_param("thingGroupName") thing_group = self.iot_backend.describe_thing_group( thing_group_name=thing_group_name ) return json.dumps(thing_group.to_dict()) - def create_thing_group(self): + def create_thing_group(self) -> str: thing_group_name = self._get_param("thingGroupName") parent_group_name = self._get_param("parentGroupName") thing_group_properties = self._get_param("thingGroupProperties") @@ -604,12 +608,12 @@ class IoTResponse(BaseResponse): ) ) - def delete_thing_group(self): + def delete_thing_group(self) -> str: thing_group_name = self._get_param("thingGroupName") self.iot_backend.delete_thing_group(thing_group_name=thing_group_name) return json.dumps(dict()) - def list_thing_groups(self): + def list_thing_groups(self) -> str: # next_token = self._get_param("nextToken") # max_results = self._get_int_param("maxResults") parent_group = self._get_param("parentGroup") @@ -627,7 +631,7 @@ class IoTResponse(BaseResponse): # TODO: implement pagination in the future return json.dumps(dict(thingGroups=rets, nextToken=next_token)) - def update_thing_group(self): + def update_thing_group(self) -> str: thing_group_name = self._get_param("thingGroupName") thing_group_properties = self._get_param("thingGroupProperties") expected_version = self._get_param("expectedVersion") @@ -638,7 +642,7 @@ class IoTResponse(BaseResponse): ) return json.dumps(dict(version=version)) - def add_thing_to_thing_group(self): + def add_thing_to_thing_group(self) -> str: thing_group_name = self._get_param("thingGroupName") thing_group_arn = self._get_param("thingGroupArn") thing_name = self._get_param("thingName") @@ -651,7 +655,7 @@ class IoTResponse(BaseResponse): ) return json.dumps(dict()) - def remove_thing_from_thing_group(self): + def remove_thing_from_thing_group(self) -> str: thing_group_name = self._get_param("thingGroupName") thing_group_arn = self._get_param("thingGroupArn") thing_name = self._get_param("thingName") @@ -664,7 +668,7 @@ class IoTResponse(BaseResponse): ) return json.dumps(dict()) - def list_things_in_thing_group(self): + def list_things_in_thing_group(self) -> str: thing_group_name = self._get_param("thingGroupName") things = self.iot_backend.list_things_in_thing_group( thing_group_name=thing_group_name @@ -673,7 +677,7 @@ class IoTResponse(BaseResponse): thing_names = [_.thing_name for _ in things] return json.dumps(dict(things=thing_names, nextToken=next_token)) - def list_thing_groups_for_thing(self): + def list_thing_groups_for_thing(self) -> str: thing_name = self._get_param("thingName") # next_token = self._get_param("nextToken") # max_results = self._get_int_param("maxResults") @@ -683,7 +687,7 @@ class IoTResponse(BaseResponse): next_token = None return json.dumps(dict(thingGroups=thing_groups, nextToken=next_token)) - def update_thing_groups_for_thing(self): + def update_thing_groups_for_thing(self) -> str: thing_name = self._get_param("thingName") thing_groups_to_add = self._get_param("thingGroupsToAdd") or [] thing_groups_to_remove = self._get_param("thingGroupsToRemove") or [] @@ -694,15 +698,15 @@ class IoTResponse(BaseResponse): ) return json.dumps(dict()) - def list_topic_rules(self): + def list_topic_rules(self) -> str: return json.dumps(dict(rules=self.iot_backend.list_topic_rules())) - def get_topic_rule(self): + def get_topic_rule(self) -> str: return json.dumps( self.iot_backend.get_topic_rule(rule_name=self._get_param("ruleName")) ) - def create_topic_rule(self): + def create_topic_rule(self) -> str: self.iot_backend.create_topic_rule( rule_name=self._get_param("ruleName"), description=self._get_param("description"), @@ -714,7 +718,7 @@ class IoTResponse(BaseResponse): ) return json.dumps(dict()) - def replace_topic_rule(self): + def replace_topic_rule(self) -> str: self.iot_backend.replace_topic_rule( rule_name=self._get_param("ruleName"), description=self._get_param("description"), @@ -726,19 +730,19 @@ class IoTResponse(BaseResponse): ) return json.dumps(dict()) - def delete_topic_rule(self): + def delete_topic_rule(self) -> str: self.iot_backend.delete_topic_rule(rule_name=self._get_param("ruleName")) return json.dumps(dict()) - def enable_topic_rule(self): + def enable_topic_rule(self) -> str: self.iot_backend.enable_topic_rule(rule_name=self._get_param("ruleName")) return json.dumps(dict()) - def disable_topic_rule(self): + def disable_topic_rule(self) -> str: self.iot_backend.disable_topic_rule(rule_name=self._get_param("ruleName")) return json.dumps(dict()) - def create_domain_configuration(self): + def create_domain_configuration(self) -> str: domain_configuration = self.iot_backend.create_domain_configuration( domain_configuration_name=self._get_param("domainConfigurationName"), domain_name=self._get_param("domainName"), @@ -748,24 +752,24 @@ class IoTResponse(BaseResponse): ) return json.dumps(domain_configuration.to_dict()) - def delete_domain_configuration(self): + def delete_domain_configuration(self) -> str: self.iot_backend.delete_domain_configuration( domain_configuration_name=self._get_param("domainConfigurationName") ) return json.dumps(dict()) - def describe_domain_configuration(self): + def describe_domain_configuration(self) -> str: domain_configuration = self.iot_backend.describe_domain_configuration( domain_configuration_name=self._get_param("domainConfigurationName") ) return json.dumps(domain_configuration.to_description_dict()) - def list_domain_configurations(self): + def list_domain_configurations(self) -> str: return json.dumps( dict(domainConfigurations=self.iot_backend.list_domain_configurations()) ) - def update_domain_configuration(self): + def update_domain_configuration(self) -> str: domain_configuration = self.iot_backend.update_domain_configuration( domain_configuration_name=self._get_param("domainConfigurationName"), authorizer_config=self._get_param("authorizerConfig"), @@ -774,7 +778,7 @@ class IoTResponse(BaseResponse): ) return json.dumps(domain_configuration.to_dict()) - def search_index(self): + def search_index(self) -> str: query = self._get_param("queryString") - things, groups = self.iot_backend.search_index(query) - return json.dumps({"things": things, "thingGroups": groups}) + things = self.iot_backend.search_index(query) + return json.dumps({"things": things, "thingGroups": []}) diff --git a/moto/iotdata/exceptions.py b/moto/iotdata/exceptions.py index 04c37edfc..cb515902f 100644 --- a/moto/iotdata/exceptions.py +++ b/moto/iotdata/exceptions.py @@ -6,7 +6,7 @@ class IoTDataPlaneClientError(JsonRESTError): class ResourceNotFoundException(IoTDataPlaneClientError): - def __init__(self): + def __init__(self) -> None: self.code = 404 super().__init__( "ResourceNotFoundException", "The specified resource does not exist" @@ -14,12 +14,12 @@ class ResourceNotFoundException(IoTDataPlaneClientError): class InvalidRequestException(IoTDataPlaneClientError): - def __init__(self, message): + def __init__(self, message: str): self.code = 400 super().__init__("InvalidRequestException", message) class ConflictException(IoTDataPlaneClientError): - def __init__(self, message): + def __init__(self, message: str): self.code = 409 super().__init__("ConflictException", message) diff --git a/moto/iotdata/models.py b/moto/iotdata/models.py index 507284b09..b8624528b 100644 --- a/moto/iotdata/models.py +++ b/moto/iotdata/models.py @@ -1,10 +1,11 @@ import json import time import jsondiff +from typing import Any, Dict, List, Tuple, Optional from moto.core import BaseBackend, BackendDict, BaseModel from moto.core.utils import merge_dicts -from moto.iot import iot_backends +from moto.iot.models import iot_backends, IoTBackend from .exceptions import ( ConflictException, ResourceNotFoundException, @@ -17,7 +18,14 @@ class FakeShadow(BaseModel): http://docs.aws.amazon.com/iot/latest/developerguide/thing-shadow-document-syntax.html """ - def __init__(self, desired, reported, requested_payload, version, deleted=False): + def __init__( + self, + desired: Optional[str], + reported: Optional[str], + requested_payload: Optional[Dict[str, Any]], + version: int, + deleted: bool = False, + ): self.desired = desired self.reported = reported self.requested_payload = requested_payload @@ -33,7 +41,7 @@ class FakeShadow(BaseModel): ) @classmethod - def create_from_previous_version(cls, previous_shadow, payload): + def create_from_previous_version(cls, previous_shadow: Optional["FakeShadow"], payload: Optional[Dict[str, Any]]) -> "FakeShadow": # type: ignore[misc] """ set None to payload when you want to delete shadow """ @@ -55,11 +63,10 @@ class FakeShadow(BaseModel): merge_dicts(state_document, payload, remove_nulls=True) desired = state_document.get("state", {}).get("desired") reported = state_document.get("state", {}).get("reported") - shadow = FakeShadow(desired, reported, payload, version) - return shadow + return FakeShadow(desired, reported, payload, version) @classmethod - def parse_payload(cls, desired, reported): + def parse_payload(cls, desired: Optional[str], reported: Optional[str]) -> Any: # type: ignore[misc] if desired is None: delta = reported elif reported is None: @@ -68,15 +75,15 @@ class FakeShadow(BaseModel): delta = jsondiff.diff(desired, reported) return delta - def _create_metadata_from_state(self, state, ts): + def _create_metadata_from_state(self, state: Any, ts: Any) -> Any: """ - state must be disired or reported stype dict object - replces primitive type with {"timestamp": ts} in dict + state must be desired or reported stype dict object + replaces primitive type with {"timestamp": ts} in dict """ if state is None: return None - def _f(elem, ts): + def _f(elem: Any, ts: Any) -> Any: if isinstance(elem, dict): return {_: _f(elem[_], ts) for _ in elem.keys()} if isinstance(elem, list): @@ -85,9 +92,9 @@ class FakeShadow(BaseModel): return _f(state, ts) - def to_response_dict(self): - desired = self.requested_payload["state"].get("desired", None) - reported = self.requested_payload["state"].get("reported", None) + def to_response_dict(self) -> Dict[str, Any]: + desired = self.requested_payload["state"].get("desired", None) # type: ignore + reported = self.requested_payload["state"].get("reported", None) # type: ignore payload = {} if desired is not None: @@ -111,7 +118,7 @@ class FakeShadow(BaseModel): "version": self.version, } - def to_dict(self, include_delta=True): + def to_dict(self, include_delta: bool = True) -> Dict[str, Any]: """returning nothing except for just top-level keys for now.""" if self.deleted: return {"timestamp": self.timestamp, "version": self.version} @@ -139,15 +146,15 @@ class FakeShadow(BaseModel): class IoTDataPlaneBackend(BaseBackend): - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.published_payloads = list() + self.published_payloads: List[Tuple[str, str]] = list() @property - def iot_backend(self): + def iot_backend(self) -> IoTBackend: return iot_backends[self.account_id][self.region_name] - def update_thing_shadow(self, thing_name, payload): + def update_thing_shadow(self, thing_name: str, payload: str) -> FakeShadow: """ spec of payload: - need node `state` @@ -158,32 +165,32 @@ class IoTDataPlaneBackend(BaseBackend): # validate try: - payload = json.loads(payload) + _payload = json.loads(payload) except ValueError: raise InvalidRequestException("invalid json") - if "state" not in payload: + if "state" not in _payload: raise InvalidRequestException("need node `state`") - if not isinstance(payload["state"], dict): + if not isinstance(_payload["state"], dict): raise InvalidRequestException("state node must be an Object") - if any(_ for _ in payload["state"].keys() if _ not in ["desired", "reported"]): + if any(_ for _ in _payload["state"].keys() if _ not in ["desired", "reported"]): raise InvalidRequestException("State contains an invalid node") - if "version" in payload and thing.thing_shadow.version != payload["version"]: + if "version" in _payload and thing.thing_shadow.version != _payload["version"]: raise ConflictException("Version conflict") new_shadow = FakeShadow.create_from_previous_version( - thing.thing_shadow, payload + thing.thing_shadow, _payload ) thing.thing_shadow = new_shadow return thing.thing_shadow - def get_thing_shadow(self, thing_name): + def get_thing_shadow(self, thing_name: str) -> FakeShadow: thing = self.iot_backend.describe_thing(thing_name) if thing.thing_shadow is None or thing.thing_shadow.deleted: raise ResourceNotFoundException() return thing.thing_shadow - def delete_thing_shadow(self, thing_name): + def delete_thing_shadow(self, thing_name: str) -> FakeShadow: thing = self.iot_backend.describe_thing(thing_name) if thing.thing_shadow is None: raise ResourceNotFoundException() @@ -194,7 +201,7 @@ class IoTDataPlaneBackend(BaseBackend): thing.thing_shadow = new_shadow return thing.thing_shadow - def publish(self, topic, payload): + def publish(self, topic: str, payload: str) -> None: self.published_payloads.append((topic, payload)) diff --git a/moto/iotdata/responses.py b/moto/iotdata/responses.py index db4dbfe99..fe085c6cf 100644 --- a/moto/iotdata/responses.py +++ b/moto/iotdata/responses.py @@ -1,11 +1,11 @@ from moto.core.responses import BaseResponse -from .models import iotdata_backends +from .models import iotdata_backends, IoTDataPlaneBackend import json from urllib.parse import unquote class IoTDataPlaneResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="iot-data") def _get_action(self) -> str: @@ -15,10 +15,10 @@ class IoTDataPlaneResponse(BaseResponse): return super()._get_action() @property - def iotdata_backend(self): + def iotdata_backend(self) -> IoTDataPlaneBackend: return iotdata_backends[self.current_account][self.region] - def update_thing_shadow(self): + def update_thing_shadow(self) -> str: thing_name = self._get_param("thingName") payload = self.body payload = self.iotdata_backend.update_thing_shadow( @@ -26,17 +26,17 @@ class IoTDataPlaneResponse(BaseResponse): ) return json.dumps(payload.to_response_dict()) - def get_thing_shadow(self): + def get_thing_shadow(self) -> str: thing_name = self._get_param("thingName") payload = self.iotdata_backend.get_thing_shadow(thing_name=thing_name) return json.dumps(payload.to_dict()) - def delete_thing_shadow(self): + def delete_thing_shadow(self) -> str: thing_name = self._get_param("thingName") payload = self.iotdata_backend.delete_thing_shadow(thing_name=thing_name) return json.dumps(payload.to_dict()) - def publish(self): + def publish(self) -> str: topic = self.path.split("/topics/")[-1] # a uri parameter containing forward slashes is not correctly url encoded when we're running in server mode. # https://github.com/pallets/flask/issues/900 diff --git a/setup.cfg b/setup.cfg index da18a49b3..6f2cb44c0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -230,7 +230,7 @@ disable = W,C,R,E enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import [mypy] -files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/iam,moto/moto_api,moto/neptune +files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/moto_api,moto/neptune show_column_numbers=True show_error_codes = True disable_error_code=abstract