# moto/moto/iot/models.py
#
# 1696 lines
# 61 KiB
# Python

import hashlib
import random
import re
import string
import time
import uuid
from collections import OrderedDict
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization, hashes
from datetime import datetime, timedelta
from .utils import PAGINATION_MODEL
from moto.core import get_account_id, BaseBackend, BaseModel
from moto.core.utils import BackendDict
from moto.utilities.utils import random_string
from moto.utilities.paginator import paginate
from .exceptions import (
CertificateStateException,
DeleteConflictException,
ResourceNotFoundException,
InvalidRequestException,
InvalidStateTransitionException,
VersionConflictException,
ResourceAlreadyExistsException,
VersionsLimitExceededException,
ThingStillAttached,
)
class FakeThing(BaseModel):
    """In-memory model of a single IoT thing."""

    def __init__(self, thing_name, thing_type, attributes, region_name):
        """Record the thing's identity and derive its ARN.

        :param thing_name: name of the thing; embedded in the ARN.
        :param thing_type: thing-type object or ``None``; ``to_dict`` reads
            its ``thing_type_name`` attribute.
        :param attributes: mapping of attribute names to values.
        :param region_name: region used to build the ARN.
        """
        self.region_name = region_name
        self.thing_name = thing_name
        self.thing_type = thing_type
        self.attributes = attributes
        self.arn = f"arn:aws:iot:{region_name}:{get_account_id()}:thing/{thing_name}"
        self.version = 1
        # TODO: we need to handle "version"?

        # for iot-data
        self.thing_shadow = None

    def matches(self, query_string):
        """Return ``True`` when this thing satisfies ``query_string``.

        Supported forms:

        * ``*`` matches every thing.
        * ``thingName:<pattern>`` where ``*`` stands for any run of characters
          and ``?`` for exactly one character; every other character is
          matched literally.
        * ``attributes.<key>:<value>`` compares one attribute for equality.
        * anything else is treated as a substring of the thing name.
        """
        if query_string == "*":
            return True
        if query_string.startswith("thingName:"):
            pattern = query_string[len("thingName:"):]
            # Escape first so regex metacharacters in the name (".", "+", ...)
            # are literal, then translate only the two supported wildcards.
            regex = re.escape(pattern).replace(r"\*", ".*").replace(r"\?", ".")
            return re.fullmatch(regex, self.thing_name) is not None
        if query_string.startswith("attributes."):
            # Split on the first colon only: the value itself may contain ":".
            key, separator, value = query_string[len("attributes."):].partition(":")
            if not separator:
                return False
            return self.attributes.get(key) == value
        return query_string in self.thing_name

    def to_dict(self, include_default_client_id=False):
        """Serialise the thing.

        ``thingTypeName`` is only present when a thing type is set and
        ``defaultClientId`` only when ``include_default_client_id`` is true.
        """
        obj = {
            "thingName": self.thing_name,
            "thingArn": self.arn,
            "attributes": self.attributes,
            "version": self.version,
        }
        if self.thing_type:
            obj["thingTypeName"] = self.thing_type.thing_type_name
        if include_default_client_id:
            obj["defaultClientId"] = self.thing_name
        return obj
class FakeThingType(BaseModel):
    """In-memory model of an IoT thing type."""

    def __init__(self, thing_type_name, thing_type_properties, region_name):
        """Store the type's name/properties and generate its id, metadata and ARN."""
        self.region_name = region_name
        self.thing_type_name = thing_type_name
        self.thing_type_properties = thing_type_properties
        # The id is simply a random UUID.
        self.thing_type_id = str(uuid.uuid4())
        created = time.time()
        # Creation time is truncated to millisecond precision.
        metadata = {"deprecated": False}
        metadata["creationDate"] = int(created * 1000) / 1000.0
        self.metadata = metadata
        # NOTE(review): the account id in this ARN is the literal "1".
        arn_parts = (self.region_name, thing_type_name)
        self.arn = "arn:aws:iot:%s:1:thingtype/%s" % arn_parts

    def to_dict(self):
        """Return the serialised thing-type description."""
        description = {}
        description["thingTypeName"] = self.thing_type_name
        description["thingTypeId"] = self.thing_type_id
        description["thingTypeProperties"] = self.thing_type_properties
        description["thingTypeMetadata"] = self.metadata
        description["thingTypeArn"] = self.arn
        return description
class FakeThingGroup(BaseModel):
    """In-memory model of an IoT thing group, including its ancestry chain."""

    def __init__(
        self,
        thing_group_name,
        parent_group_name,
        thing_group_properties,
        region_name,
        thing_groups,
    ):
        """Create the group and, when a parent is named, record its lineage.

        :param thing_group_name: name of the new group.
        :param parent_group_name: name of the parent group, or a falsy value
            for a root group.
        :param thing_group_properties: properties dict; ``None`` becomes ``{}``.
        :param region_name: region used to build the ARN.
        :param thing_groups: mapping of existing groups, searched by name to
            locate the parent.
        """
        self.region_name = region_name
        self.thing_group_name = thing_group_name
        self.thing_group_id = str(uuid.uuid4())  # I don't know the rule of id
        self.version = 1  # TODO: tmp
        self.parent_group_name = parent_group_name
        self.thing_group_properties = thing_group_properties or {}
        t = time.time()
        self.metadata = {"creationDate": int(t * 1000) / 1000.0}
        if parent_group_name:
            self.metadata["parentGroupName"] = parent_group_name
            # initialize rootToParentThingGroups
            lineage = []
            self.metadata["rootToParentThingGroups"] = lineage
            # Search for the parent by name. Defaulting to None means an
            # unknown parent leaves the lineage empty instead of raising
            # UnboundLocalError.
            parent = next(
                (
                    group
                    for group in thing_groups.values()
                    if group.thing_group_name == parent_group_name
                ),
                None,
            )
            if parent is not None:
                # copy parent's rootToParentThingGroups, then append the parent
                lineage.extend(parent.metadata.get("rootToParentThingGroups", []))
                lineage.append(
                    {"groupName": parent_group_name, "groupArn": parent.arn}
                )
        self.arn = "arn:aws:iot:%s:1:thinggroup/%s" % (
            self.region_name,
            thing_group_name,
        )
        self.things = OrderedDict()

    def to_dict(self):
        """Return the serialised thing-group description."""
        return {
            "thingGroupName": self.thing_group_name,
            "thingGroupId": self.thing_group_id,
            "version": self.version,
            "thingGroupProperties": self.thing_group_properties,
            "thingGroupMetadata": self.metadata,
            "thingGroupArn": self.arn,
        }
class FakeCertificate(BaseModel):
    """In-memory model of a device certificate."""

    def __init__(self, certificate_pem, status, region_name, ca_certificate_id=None):
        """Derive the certificate id (SHA-256 hex digest of the PEM) and ARN."""
        pem_bytes = certificate_pem.encode("utf-8")
        self.certificate_id = hashlib.sha256(pem_bytes).hexdigest()
        self.arn = (
            f"arn:aws:iot:{region_name}:{get_account_id()}:cert/{self.certificate_id}"
        )
        self.certificate_pem = certificate_pem
        self.status = status
        self.owner = get_account_id()
        self.transfer_data = {}
        created = time.time()
        self.creation_date = created
        self.last_modified_date = created
        # Validity window: one day (86400 s) either side of "now".
        self.validity_not_before = time.time() - 86400
        self.validity_not_after = time.time() + 86400
        self.ca_certificate_id = ca_certificate_id

    def to_dict(self):
        """Return the summary form of the certificate."""
        summary = {}
        summary["certificateArn"] = self.arn
        summary["certificateId"] = self.certificate_id
        summary["caCertificateId"] = self.ca_certificate_id
        summary["status"] = self.status
        summary["creationDate"] = self.creation_date
        return summary

    def to_description_dict(self):
        """
        Return the detailed description of the certificate.

        Keys that are not emitted here but may be needed in some situations:
        - caCertificateId
        - previousOwnedBy
        """
        validity = {
            "notBefore": self.validity_not_before,
            "notAfter": self.validity_not_after,
        }
        description = {
            "certificateArn": self.arn,
            "certificateId": self.certificate_id,
            "status": self.status,
            "certificatePem": self.certificate_pem,
            "ownedBy": self.owner,
            "creationDate": self.creation_date,
            "lastModifiedDate": self.last_modified_date,
            "validity": validity,
            "transferData": self.transfer_data,
        }
        return description
class FakeCaCertificate(FakeCertificate):
    """A CA certificate: a certificate plus a registration config."""

    def __init__(self, ca_certificate, status, region_name, registration_config):
        """Initialise the base certificate from the CA PEM, with no parent CA."""
        super().__init__(ca_certificate, status, region_name, ca_certificate_id=None)
        self.registration_config = registration_config
class FakePolicy(BaseModel):
    """In-memory model of an IoT policy and its list of versions."""

    def __init__(self, name, document, region_name, default_version_id="1"):
        """Create the policy with a single initial version marked as default."""
        self.name = name
        self.document = document
        self.arn = f"arn:aws:iot:{region_name}:{get_account_id()}:policy/{name}"
        self.default_version_id = default_version_id
        initial_version = FakePolicyVersion(self.name, document, True, region_name)
        self.versions = [initial_version]

    def to_get_dict(self):
        """Return the policy as name, ARN, document and ``defaultVersionId``."""
        result = self.to_dict()
        result["policyDocument"] = self.document
        result["defaultVersionId"] = self.default_version_id
        return result

    def to_dict_at_creation(self):
        """Return the policy as name, ARN, document and ``policyVersionId``."""
        result = self.to_dict()
        result["policyDocument"] = self.document
        result["policyVersionId"] = self.default_version_id
        return result

    def to_dict(self):
        """Return the short form: name and ARN only."""
        return {"policyName": self.name, "policyArn": self.arn}
class FakePolicyVersion:
    """A single version of an IoT policy."""

    def __init__(self, policy_name, document, is_default, region_name):
        """Create version "1" of ``policy_name`` with fixed 2015 timestamps."""
        self.name = policy_name
        self.arn = f"arn:aws:iot:{region_name}:{get_account_id()}:policy/{policy_name}"
        self.document = document or {}
        self.is_default = is_default
        self.version_id = "1"
        created = datetime(2015, 1, 1)
        modified = datetime(2015, 1, 2)
        self.create_datetime = time.mktime(created.timetuple())
        self.last_modified_datetime = time.mktime(modified.timetuple())

    def to_get_dict(self):
        """Return the full description; ``generationId`` mirrors the version id."""
        description = {}
        description["policyName"] = self.name
        description["policyArn"] = self.arn
        description["policyDocument"] = self.document
        description["policyVersionId"] = self.version_id
        description["isDefaultVersion"] = self.is_default
        description["creationDate"] = self.create_datetime
        description["lastModifiedDate"] = self.last_modified_datetime
        description["generationId"] = self.version_id
        return description

    def to_dict_at_creation(self):
        """Return the fields reported right after the version is created."""
        created = {}
        created["policyArn"] = self.arn
        created["policyDocument"] = self.document
        created["policyVersionId"] = self.version_id
        created["isDefaultVersion"] = self.is_default
        return created

    def to_dict(self):
        """Return the summary form of the version."""
        summary = {}
        summary["versionId"] = self.version_id
        summary["isDefaultVersion"] = self.is_default
        summary["createDate"] = self.create_datetime
        return summary
class FakeJob(BaseModel):
    """In-memory model of an IoT job.

    Job ids must be 1-64 characters drawn from ``[a-zA-Z0-9_-]``.
    """

    JOB_ID_REGEX_PATTERN = "[a-zA-Z0-9_-]"
    # One or more allowed characters. It is combined with ``fullmatch`` in
    # ``_job_id_matcher`` so the whole id is validated, not just its first
    # character.
    JOB_ID_REGEX = re.compile(JOB_ID_REGEX_PATTERN + "+")

    def __init__(
        self,
        job_id,
        targets,
        document_source,
        document,
        description,
        presigned_url_config,
        target_selection,
        job_executions_rollout_config,
        document_parameters,
        region_name,
    ):
        """Validate ``job_id`` and store the job definition.

        The job starts in status ``QUEUED`` with fixed 2015-01-01 timestamps.

        :raises InvalidRequestException: if ``job_id`` is empty, longer than
            64 characters, or contains a character outside ``[a-zA-Z0-9_-]``.
        """
        if not self._job_id_matcher(self.JOB_ID_REGEX, job_id):
            raise InvalidRequestException()

        self.region_name = region_name
        self.job_id = job_id
        self.job_arn = "arn:aws:iot:%s:1:job/%s" % (self.region_name, job_id)
        self.targets = targets
        self.document_source = document_source
        self.document = document
        self.force = False
        self.description = description
        self.presigned_url_config = presigned_url_config
        self.target_selection = target_selection
        self.job_executions_rollout_config = job_executions_rollout_config
        self.status = "QUEUED"  # IN_PROGRESS | CANCELED | COMPLETED
        self.comment = None
        self.reason_code = None
        self.created_at = time.mktime(datetime(2015, 1, 1).timetuple())
        self.last_updated_at = time.mktime(datetime(2015, 1, 1).timetuple())
        self.completed_at = None
        self.job_process_details = {
            "processingTargets": targets,
            "numberOfQueuedThings": 1,
            "numberOfCanceledThings": 0,
            "numberOfSucceededThings": 0,
            "numberOfFailedThings": 0,
            "numberOfRejectedThings": 0,
            "numberOfInProgressThings": 0,
            "numberOfRemovedThings": 0,
        }
        self.document_parameters = document_parameters

    def to_dict(self):
        """Return the serialised job description."""
        obj = {
            "jobArn": self.job_arn,
            "jobId": self.job_id,
            "targets": self.targets,
            "description": self.description,
            "presignedUrlConfig": self.presigned_url_config,
            "targetSelection": self.target_selection,
            "jobExecutionsRolloutConfig": self.job_executions_rollout_config,
            "status": self.status,
            "comment": self.comment,
            "forceCanceled": self.force,
            "reasonCode": self.reason_code,
            "createdAt": self.created_at,
            "lastUpdatedAt": self.last_updated_at,
            "completedAt": self.completed_at,
            "jobProcessDetails": self.job_process_details,
            "documentParameters": self.document_parameters,
            "document": self.document,
            "documentSource": self.document_source,
        }
        return obj

    def _job_id_matcher(self, regex, argument):
        """Return ``True`` when ``argument`` fully matches ``regex`` and is at most 64 chars."""
        return regex.fullmatch(argument) is not None and len(argument) <= 64
class FakeJobExecution(BaseModel):
    """In-memory model of one execution of a job on one thing."""

    def __init__(
        self,
        job_id,
        thing_arn,
        status="QUEUED",
        force_canceled=False,
        status_details_map=None,
    ):
        """Store the execution with fixed 2015-01-01 timestamps and placeholder counters."""
        fixed_epoch = time.mktime(datetime(2015, 1, 1).timetuple())
        self.job_id = job_id
        self.status = status  # IN_PROGRESS | CANCELED | COMPLETED
        self.force_canceled = force_canceled
        self.status_details_map = status_details_map or {}
        self.thing_arn = thing_arn
        self.queued_at = fixed_epoch
        self.started_at = fixed_epoch
        self.last_updated_at = fixed_epoch
        # Placeholder values.
        self.execution_number = 123
        self.version_number = 123
        self.approximate_seconds_before_time_out = 123

    def to_get_dict(self):
        """Return the detailed description of the execution."""
        details = {"detailsMap": self.status_details_map}
        return {
            "jobId": self.job_id,
            "status": self.status,
            "forceCanceled": self.force_canceled,
            "statusDetails": details,
            "thingArn": self.thing_arn,
            "queuedAt": self.queued_at,
            "startedAt": self.started_at,
            "lastUpdatedAt": self.last_updated_at,
            "executionNumber": self.execution_number,
            "versionNumber": self.version_number,
            "approximateSecondsBeforeTimedOut": self.approximate_seconds_before_time_out,
        }

    def to_dict(self):
        """Return the summary form of the execution."""
        summary = {
            "status": self.status,
            "queuedAt": self.queued_at,
            "startedAt": self.started_at,
            "lastUpdatedAt": self.last_updated_at,
            "executionNumber": self.execution_number,
        }
        return {
            "jobId": self.job_id,
            "thingArn": self.thing_arn,
            "jobExecutionSummary": summary,
        }
class FakeEndpoint(BaseModel):
    """A randomly generated endpoint address for one endpoint type."""

    def __init__(self, endpoint_type, region_name):
        """Validate ``endpoint_type`` and build a random hostname for it.

        :raises InvalidRequestException: if ``endpoint_type`` is not one of
            the four recognised types.
        """
        recognised = [
            "iot:Data",
            "iot:Data-ATS",
            "iot:CredentialProvider",
            "iot:Jobs",
        ]
        if endpoint_type not in recognised:
            raise InvalidRequestException(
                " An error occurred (InvalidRequestException) when calling the DescribeEndpoint "
                "operation: Endpoint type %s not recognized." % endpoint_type
            )
        self.region_name = region_name
        identifier = random_string(14).lower()
        # Choose the hostname template for the endpoint type.
        template = None
        if endpoint_type == "iot:Data":
            template = "{i}.iot.{r}.amazonaws.com"
        elif "iot:Data-ATS" in endpoint_type:
            template = "{i}-ats.iot.{r}.amazonaws.com"
        elif "iot:CredentialProvider" in endpoint_type:
            template = "{i}.credentials.iot.{r}.amazonaws.com"
        elif "iot:Jobs" in endpoint_type:
            template = "{i}.jobs.iot.{r}.amazonaws.com"
        if template is not None:
            self.endpoint = template.format(i=identifier, r=self.region_name)
        self.endpoint_type = endpoint_type

    def to_get_dict(self):
        """Return the endpoint address."""
        return {"endpointAddress": self.endpoint}

    def to_dict(self):
        """Return the endpoint address (same payload as ``to_get_dict``)."""
        return {"endpointAddress": self.endpoint}
class FakeRule(BaseModel):
    """In-memory model of an IoT topic rule."""

    def __init__(
        self,
        rule_name,
        description,
        created_at,
        rule_disabled,
        topic_pattern,
        actions,
        error_action,
        sql,
        aws_iot_sql_version,
        region_name,
    ):
        """Store the rule, substituting defaults for falsy optional fields."""
        self.region_name = region_name
        self.rule_name = rule_name
        self.description = description if description else ""
        self.created_at = created_at
        self.rule_disabled = bool(rule_disabled)
        self.topic_pattern = topic_pattern
        self.actions = actions if actions else []
        self.error_action = error_action if error_action else {}
        self.sql = sql
        # Falls back to "2016-03-23" when no SQL version is supplied.
        self.aws_iot_sql_version = (
            aws_iot_sql_version if aws_iot_sql_version else "2016-03-23"
        )
        self.arn = "arn:aws:iot:%s:1:rule/%s" % (self.region_name, rule_name)

    def to_get_dict(self):
        """Return the full rule definition plus its ARN."""
        rule_body = {
            "actions": self.actions,
            "awsIotSqlVersion": self.aws_iot_sql_version,
            "createdAt": self.created_at,
            "description": self.description,
            "errorAction": self.error_action,
            "ruleDisabled": self.rule_disabled,
            "ruleName": self.rule_name,
            "sql": self.sql,
        }
        return {"rule": rule_body, "ruleArn": self.arn}

    def to_dict(self):
        """Return the summary form of the rule."""
        summary = {}
        summary["ruleName"] = self.rule_name
        summary["createdAt"] = self.created_at
        summary["ruleArn"] = self.arn
        summary["ruleDisabled"] = self.rule_disabled
        summary["topicPattern"] = self.topic_pattern
        return summary
class FakeDomainConfiguration(BaseModel):
    """In-memory model of an IoT domain configuration."""

    def __init__(
        self,
        region_name,
        domain_configuration_name,
        domain_name,
        server_certificate_arns,
        domain_configuration_status,
        service_type,
        authorizer_config,
        domain_type,
    ):
        """Validate the service type and store the configuration.

        :raises InvalidRequestException: if a truthy ``service_type`` is not
            ``DATA``, ``CREDENTIAL_PROVIDER`` or ``JOBS``.
        """
        allowed_service_types = ["DATA", "CREDENTIAL_PROVIDER", "JOBS"]
        if service_type and service_type not in allowed_service_types:
            raise InvalidRequestException(
                "An error occurred (InvalidRequestException) when calling the DescribeDomainConfiguration "
                "operation: Service type %s not recognized." % service_type
            )
        self.domain_configuration_name = domain_configuration_name
        # The ARN ends with a random 5-character suffix.
        arn_parts = (region_name, domain_configuration_name, random_string(5))
        self.domain_configuration_arn = (
            "arn:aws:iot:%s:1:domainconfiguration/%s/%s" % arn_parts
        )
        self.domain_name = domain_name
        # Every supplied server certificate is reported as VALID.
        self.server_certificates = [
            {"serverCertificateArn": arn, "serverCertificateStatus": "VALID"}
            for arn in (server_certificate_arns or [])
        ]
        self.domain_configuration_status = domain_configuration_status
        self.service_type = service_type
        self.authorizer_config = authorizer_config
        self.domain_type = domain_type
        self.last_status_change_date = time.time()

    def to_description_dict(self):
        """Return the detailed description of the configuration."""
        description = self.to_dict()
        description["domainName"] = self.domain_name
        description["serverCertificates"] = self.server_certificates
        description["authorizerConfig"] = self.authorizer_config
        description["domainConfigurationStatus"] = self.domain_configuration_status
        description["serviceType"] = self.service_type
        description["domainType"] = self.domain_type
        description["lastStatusChangeDate"] = self.last_status_change_date
        return description

    def to_dict(self):
        """Return the summary form: name and ARN."""
        return {
            "domainConfigurationName": self.domain_configuration_name,
            "domainConfigurationArn": self.domain_configuration_arn,
        }
class IoTBackend(BaseBackend):
def __init__(self, region_name, account_id):
    """Initialise the base backend and create empty stores for every resource kind."""
    super().__init__(region_name, account_id)

    # Registry: things, their types and groups.
    self.things = OrderedDict()
    self.thing_types = OrderedDict()
    self.thing_groups = OrderedDict()

    # Jobs and their executions.
    self.jobs = OrderedDict()
    self.job_executions = OrderedDict()

    # Certificates and policies, plus principal attachments keyed by tuple.
    self.ca_certificates = OrderedDict()
    self.certificates = OrderedDict()
    self.policies = OrderedDict()
    self.principal_policies = OrderedDict()
    self.principal_things = OrderedDict()

    # Rules, domain configurations and the most recently described endpoint.
    self.rules = OrderedDict()
    self.domain_configurations = OrderedDict()
    self.endpoint = None
@staticmethod
def default_vpc_endpoint_service(service_region, zones):
    """Default VPC endpoint service."""
    build_services = BaseBackend.default_vpc_endpoint_service_factory
    iot_services = build_services(service_region, zones, "iot")
    data_services = build_services(
        service_region,
        zones,
        "data.iot",
        private_dns_names=False,
        special_service_name="iot.data",
        policy_supported=False,
    )
    return iot_services + data_services
def create_certificate_from_csr(self, csr, set_as_active):
    """Generate a certificate for the subject of ``csr`` and register it.

    The certificate is registered without a CA; its status is ``INACTIVE``
    unless ``set_as_active`` is true.
    """
    csr_bytes = csr.encode("utf-8")
    signing_request = x509.load_pem_x509_csr(csr_bytes, default_backend())
    certificate_pem = self._generate_certificate_pem(
        domain_name="example.com", subject=signing_request.subject
    )
    return self.register_certificate(
        certificate_pem,
        ca_certificate_pem=None,
        set_as_active=set_as_active,
        status="INACTIVE",
    )
def _generate_certificate_pem(self, domain_name, subject):
    """Build a one-year certificate for ``subject`` and return it as a PEM string.

    A fresh 2048-bit RSA key is generated for every call; the certificate
    carries ``domain_name`` as its only subject alternative name and is
    issued under a fixed "Moto" issuer name.
    """
    alt_names = [x509.DNSName(domain_name)]
    private_key = rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )
    # C = US, O = Moto, OU = Server CA 1B, CN = Moto
    issuer_attributes = [
        x509.NameAttribute(x509.NameOID.COUNTRY_NAME, "US"),
        x509.NameAttribute(x509.NameOID.ORGANIZATION_NAME, "Moto"),
        x509.NameAttribute(x509.NameOID.ORGANIZATIONAL_UNIT_NAME, "Server CA 1B"),
        x509.NameAttribute(x509.NameOID.COMMON_NAME, "Moto"),
    ]
    issuer = x509.Name(issuer_attributes)

    builder = x509.CertificateBuilder()
    builder = builder.subject_name(subject)
    builder = builder.issuer_name(issuer)
    builder = builder.public_key(private_key.public_key())
    builder = builder.serial_number(x509.random_serial_number())
    builder = builder.not_valid_before(datetime.utcnow())
    builder = builder.not_valid_after(datetime.utcnow() + timedelta(days=365))
    builder = builder.add_extension(
        x509.SubjectAlternativeName(alt_names), critical=False
    )
    certificate = builder.sign(private_key, hashes.SHA512(), default_backend())
    pem_bytes = certificate.public_bytes(serialization.Encoding.PEM)
    return pem_bytes.decode("utf-8")
def create_thing(self, thing_name, thing_type_name, attribute_payload):
    """Create a thing, store it by ARN and return ``(name, arn)``.

    :raises ResourceNotFoundException: if ``thing_type_name`` is given but
        no thing type has exactly that name.
    :raises InvalidRequestException: if the named thing type is deprecated.
    """
    known_types = self.list_thing_types()
    thing_type = None
    if thing_type_name:
        for candidate in known_types:
            if candidate.thing_type_name == thing_type_name:
                thing_type = candidate
                break
        else:
            raise ResourceNotFoundException()
        if thing_type.metadata["deprecated"]:
            # note - typo (depreated) exists also in the original exception.
            raise InvalidRequestException(
                msg=f"Can not create new thing with depreated thing type:{thing_type_name}"
            )

    # A missing payload or a payload without "attributes" means no attributes.
    if attribute_payload is not None and "attributes" in attribute_payload:
        attributes = attribute_payload["attributes"]
    else:
        attributes = {}

    new_thing = FakeThing(thing_name, thing_type, attributes, self.region_name)
    self.things[new_thing.arn] = new_thing
    return new_thing.thing_name, new_thing.arn
def create_thing_type(self, thing_type_name, thing_type_properties):
    """Create a thing type, store it by ARN and return ``(name, arn)``."""
    properties = {} if thing_type_properties is None else thing_type_properties
    new_type = FakeThingType(thing_type_name, properties, self.region_name)
    self.thing_types[new_type.arn] = new_type
    return new_type.thing_type_name, new_type.arn
def list_thing_types(self, thing_type_name=None):
    """Return the thing types, optionally restricted to a name prefix.

    Without a name the dict's values view is returned; with a name the
    result is a list of the types whose name starts with it.
    """
    every_type = self.thing_types.values()
    if not thing_type_name:
        return every_type
    # It's weird but thing_type_name is filtered by forward match, not complete match
    matching = []
    for thing_type in every_type:
        if thing_type.thing_type_name.startswith(thing_type_name):
            matching.append(thing_type)
    return matching
def list_things(
    self, attribute_name, attribute_value, thing_type_name, max_results, token
):
    """Return one page of serialised things plus the next-page token.

    Things may be filtered by an attribute name/value pair, by thing-type
    name, by both, or not at all. ``token`` is the stringified offset of
    the page start (``None`` for the first page); the returned token is
    ``None`` when there are no further results.
    """

    def _attribute_matches(entry):
        if attribute_name is None:
            return True
        attrs = entry["attributes"]
        return attribute_name in attrs and attrs[attribute_name] == attribute_value

    def _type_matches(entry):
        if thing_type_name is None:
            return True
        return "thingTypeName" in entry and entry["thingTypeName"] == thing_type_name

    serialised = [thing.to_dict() for thing in self.things.values()]
    selected = [
        entry
        for entry in serialised
        if _attribute_matches(entry) and _type_matches(entry)
    ]

    start = 0 if token is None else int(token)
    stop = start + max_results
    page = selected[start:stop]
    next_token = str(stop) if len(selected) > stop else None
    return page, next_token
def describe_thing(self, thing_name):
    """Return the first thing named ``thing_name``.

    :raises ResourceNotFoundException: if no thing has that name.
    """
    for candidate in self.things.values():
        if candidate.thing_name == thing_name:
            return candidate
    raise ResourceNotFoundException()
def describe_thing_type(self, thing_type_name):
    """Return the first thing type whose name equals ``thing_type_name``.

    :raises ResourceNotFoundException: if no thing type has that name.
    """
    for candidate in self.thing_types.values():
        if candidate.thing_type_name == thing_type_name:
            return candidate
    raise ResourceNotFoundException()
def describe_endpoint(self, endpoint_type):
    """Create a new endpoint for ``endpoint_type``, remember it and return it."""
    endpoint = FakeEndpoint(endpoint_type, self.region_name)
    self.endpoint = endpoint
    return endpoint
def delete_thing(self, thing_name):
    """
    Delete a thing by name.

    The ExpectedVersion-parameter is not yet implemented.

    :raises ThingStillAttached: if any principal is attached to the thing.
    """
    # can raise ResourceNotFoundError
    thing = self.describe_thing(thing_name)

    # principal_things is keyed by (principal_arn, thing_name).
    still_attached = any(key[1] == thing_name for key in self.principal_things)
    if still_attached:
        raise ThingStillAttached(thing_name)

    del self.things[thing.arn]
def delete_thing_type(self, thing_type_name):
    """Remove the thing type named ``thing_type_name`` from the store."""
    # can raise ResourceNotFoundError
    doomed = self.describe_thing_type(thing_type_name)
    self.thing_types.pop(doomed.arn)
def deprecate_thing_type(self, thing_type_name, undo_deprecate):
    """Set or clear the ``deprecated`` flag of a thing type and return the type.

    :raises ResourceNotFoundException: if no thing type has that name.
    """
    for thing_type in self.thing_types.values():
        if thing_type.thing_type_name == thing_type_name:
            # undo_deprecate=True clears the flag; otherwise it is set.
            thing_type.metadata["deprecated"] = not undo_deprecate
            return thing_type
    raise ResourceNotFoundException()
def update_thing(
    self,
    thing_name,
    thing_type_name,
    attribute_payload,
    remove_thing_type,
):
    """
    Update the thing type and/or attributes of an existing thing.

    The ExpectedVersion-parameter is not yet implemented.

    ``attribute_payload`` may carry ``attributes`` (a dict) and ``merge``
    (a bool). With ``merge`` the attributes are added to the stored ones,
    otherwise they replace them. In both cases an attribute supplied with
    an empty value is removed rather than stored. A payload without an
    ``attributes`` key leaves the attributes untouched.

    :raises ResourceNotFoundException: if the thing or the thing type does
        not exist.
    :raises InvalidRequestException: if both ``remove_thing_type`` and
        ``thing_type_name`` are given, or the thing type is deprecated.
    """
    thing = self.describe_thing(thing_name)

    if remove_thing_type and thing_type_name:
        raise InvalidRequestException()

    # thing_type
    if thing_type_name:
        thing_type = next(
            (
                candidate
                for candidate in self.list_thing_types()
                if candidate.thing_type_name == thing_type_name
            ),
            None,
        )
        if thing_type is None:
            raise ResourceNotFoundException()
        if thing_type.metadata["deprecated"]:
            raise InvalidRequestException(
                msg=f"Can not update a thing to use deprecated thing type: {thing_type_name}"
            )
        thing.thing_type = thing_type

    if remove_thing_type:
        thing.thing_type = None

    # attribute
    if attribute_payload is not None and "attributes" in attribute_payload:
        incoming = attribute_payload["attributes"]
        if attribute_payload.get("merge", False):
            combined = dict(thing.attributes)
            combined.update(incoming)
        else:
            combined = dict(incoming)
        # An empty value is the documented way to remove an attribute.
        thing.attributes = {key: value for key, value in combined.items() if value}
def _random_string(self):
n = 20
random_str = "".join(
[random.choice(string.ascii_letters + string.digits) for i in range(n)]
)
return random_str
def create_keys_and_certificate(self, set_as_active):
    """Create a certificate with a random placeholder key pair.

    The keys and the PEM are random strings, not real key material.
    Returns ``(certificate, key_pair)``.
    """
    # implement here
    # caCertificate can be blank
    public_key = self._random_string()
    private_key = self._random_string()
    key_pair = {"PublicKey": public_key, "PrivateKey": private_key}
    certificate_pem = self._random_string()
    if set_as_active:
        status = "ACTIVE"
    else:
        status = "INACTIVE"
    certificate = FakeCertificate(certificate_pem, status, self.region_name)
    self.certificates[certificate.certificate_id] = certificate
    return certificate, key_pair
def delete_ca_certificate(self, certificate_id):
    """Delete a CA certificate after checking that deletion is allowed."""
    ca_certificate = self.describe_ca_certificate(certificate_id)
    self._validation_delete(ca_certificate)
    self.ca_certificates.pop(certificate_id)
def delete_certificate(self, certificate_id):
    """Delete a certificate after checking that deletion is allowed."""
    certificate = self.describe_certificate(certificate_id)
    self._validation_delete(certificate)
    self.certificates.pop(certificate_id)
def _validation_delete(self, cert):
if cert.status == "ACTIVE":
raise CertificateStateException(
"Certificate must be deactivated (not ACTIVE) before deletion.",
cert.certificate_id,
)
certs = [
k[0]
for k, v in self.principal_things.items()
if self._get_principal(k[0]).certificate_id == cert.certificate_id
]
if len(certs) > 0:
raise DeleteConflictException(
"Things must be detached before deletion (arn: %s)" % certs[0]
)
certs = [
k[0]
for k, v in self.principal_policies.items()
if self._get_principal(k[0]).certificate_id == cert.certificate_id
]
if len(certs) > 0:
raise DeleteConflictException(
"Certificate policies must be detached before deletion (arn: %s)"
% certs[0]
)
def describe_ca_certificate(self, certificate_id):
    """Return the CA certificate stored under ``certificate_id``.

    :raises ResourceNotFoundException: if the id is unknown.
    """
    if certificate_id in self.ca_certificates:
        return self.ca_certificates[certificate_id]
    raise ResourceNotFoundException()
def describe_certificate(self, certificate_id):
    """Return the first certificate whose id equals ``certificate_id``.

    :raises ResourceNotFoundException: if no certificate has that id.
    """
    for candidate in self.certificates.values():
        if candidate.certificate_id == certificate_id:
            return candidate
    raise ResourceNotFoundException()
def get_registration_code(self):
    """Return a new random UUID4 string to serve as the registration code."""
    code = uuid.uuid4()
    return str(code)
def list_certificates(self):
    """
    Return a view of every registered certificate.

    Pagination is not yet implemented.
    """
    store = self.certificates
    return store.values()
def list_certificates_by_ca(self, ca_certificate_id):
    """
    Return the certificates whose CA certificate id equals ``ca_certificate_id``.

    Pagination is not yet implemented.
    """
    issued = []
    for certificate in self.certificates.values():
        if certificate.ca_certificate_id == ca_certificate_id:
            issued.append(certificate)
    return issued
def __raise_if_certificate_already_exists(self, certificate_id, certificate_arn):
if certificate_id in self.certificates:
raise ResourceAlreadyExistsException(
"The certificate is already provisioned or registered",
certificate_id,
certificate_arn,
)
def register_ca_certificate(
    self,
    ca_certificate,
    set_as_active,
    registration_config,
):
    """
    Register a CA certificate, store it by certificate id and return it.

    The VerificationCertificate-parameter is not yet implemented.
    """
    if set_as_active:
        status = "ACTIVE"
    else:
        status = "INACTIVE"
    ca_cert = FakeCaCertificate(
        ca_certificate=ca_certificate,
        status=status,
        region_name=self.region_name,
        registration_config=registration_config,
    )
    self.ca_certificates[ca_cert.certificate_id] = ca_cert
    return ca_cert
def _find_ca_certificate(self, ca_certificate_pem):
for ca_cert in self.ca_certificates.values():
if ca_cert.certificate_pem == ca_certificate_pem:
return ca_cert.certificate_id
return None
def register_certificate(
    self, certificate_pem, ca_certificate_pem, set_as_active, status
):
    """Register a certificate, linking it to a known CA when the CA PEM matches.

    ``set_as_active`` forces status ``ACTIVE``; otherwise ``status`` is used.

    :raises ResourceAlreadyExistsException: if the same certificate is
        already registered.
    """
    ca_id = self._find_ca_certificate(ca_certificate_pem)
    effective_status = "ACTIVE" if set_as_active else status
    new_cert = FakeCertificate(
        certificate_pem,
        effective_status,
        self.region_name,
        ca_id,
    )
    self.__raise_if_certificate_already_exists(
        new_cert.certificate_id, certificate_arn=new_cert.arn
    )
    self.certificates[new_cert.certificate_id] = new_cert
    return new_cert
def register_certificate_without_ca(self, certificate_pem, status):
    """Register a certificate that has no CA.

    :raises ResourceAlreadyExistsException: if the same certificate is
        already registered.
    """
    new_cert = FakeCertificate(certificate_pem, status, self.region_name)
    cert_id = new_cert.certificate_id
    self.__raise_if_certificate_already_exists(cert_id, certificate_arn=new_cert.arn)
    self.certificates[cert_id] = new_cert
    return new_cert
def update_ca_certificate(self, certificate_id, new_status, config):
    """
    Update the status and/or registration config of a CA certificate.

    Arguments passed as ``None`` leave the corresponding field unchanged.
    The newAutoRegistrationStatus and removeAutoRegistration-parameters are not yet implemented.
    """
    ca_cert = self.describe_ca_certificate(certificate_id)
    if new_status is not None:
        ca_cert.status = new_status
    if config is not None:
        ca_cert.registration_config = config
def update_certificate(self, certificate_id, new_status):
    """Set the status of an existing certificate to ``new_status``."""
    certificate = self.describe_certificate(certificate_id)
    # TODO: validate new_status
    setattr(certificate, "status", new_status)
def create_policy(self, policy_name, policy_document):
    """Create a policy, store it under its name and return it."""
    new_policy = FakePolicy(policy_name, policy_document, self.region_name)
    self.policies[new_policy.name] = new_policy
    return new_policy
def attach_policy(self, policy_name, target):
    """Attach a policy to a target principal; a repeated attach is a no-op.

    Both the principal and the policy are resolved first, so either lookup
    may raise before anything is stored.
    """
    principal = self._get_principal(target)
    policy = self.get_policy(policy_name)
    attachment_key = (target, policy_name)
    if attachment_key not in self.principal_policies:
        self.principal_policies[attachment_key] = (principal, policy)
def detach_policy(self, policy_name, target):
    """Detach a policy from a target principal.

    :raises ResourceNotFoundException: if the policy is not attached to the
        target (the principal and policy lookups may raise it as well).
    """
    # Both lookups are done only for validation.
    self._get_principal(target)
    self.get_policy(policy_name)
    attachment_key = (target, policy_name)
    if attachment_key in self.principal_policies:
        del self.principal_policies[attachment_key]
        return
    raise ResourceNotFoundException()
def list_attached_policies(self, target):
    """Return the policies attached to ``target``."""
    attached = []
    for (principal_arn, _policy_name), (_principal, policy) in (
        self.principal_policies.items()
    ):
        if principal_arn == target:
            attached.append(policy)
    return attached
def list_policies(self):
    """Return a view of every stored policy."""
    return self.policies.values()
def get_policy(self, policy_name):
    """Return the first policy named ``policy_name``.

    :raises ResourceNotFoundException: if no policy has that name.
    """
    for candidate in self.policies.values():
        if candidate.name == policy_name:
            return candidate
    raise ResourceNotFoundException()
def delete_policy(self, policy_name):
    """Delete a policy that is unattached and has only one version.

    :raises DeleteConflictException: if the policy is attached to any
        principal, or has more than one version.
    """
    # principal_policies is keyed by (principal_arn, policy_name).
    is_attached = any(key[1] == policy_name for key in self.principal_policies)
    if is_attached:
        raise DeleteConflictException(
            "The policy cannot be deleted as the policy is attached to one or more principals (name=%s)"
            % policy_name
        )

    policy = self.get_policy(policy_name)
    version_count = len(policy.versions)
    if version_count > 1:
        raise DeleteConflictException(
            "Cannot delete the policy because it has one or more policy versions attached to it (name=%s)"
            % policy_name
        )
    del self.policies[policy.name]
def create_policy_version(self, policy_name, policy_document, set_as_default):
    """Add a new version to a policy and return it.

    Version ids come from a per-policy counter instead of the current
    number of versions, so an id is never handed out twice even after
    versions have been deleted.

    :raises ResourceNotFoundException: if the policy does not exist
        (raised by ``get_policy``).
    :raises VersionsLimitExceededException: if the policy already has five
        versions.
    """
    policy = self.get_policy(policy_name)
    if len(policy.versions) >= 5:
        raise VersionsLimitExceededException(policy_name)

    # The counter is created lazily on the policy object; the first time
    # through it is seeded from the highest id currently present.
    last_id = getattr(policy, "_max_version_id", None)
    if last_id is None:
        last_id = max((int(v.version_id) for v in policy.versions), default=0)
    next_id = last_id + 1
    policy._max_version_id = next_id

    version = FakePolicyVersion(
        policy_name, policy_document, set_as_default, self.region_name
    )
    version.version_id = str(next_id)
    policy.versions.append(version)
    if set_as_default:
        self.set_default_policy_version(policy_name, version.version_id)
    return version
def set_default_policy_version(self, policy_name, version_id):
    """Make ``version_id`` the default version of a policy.

    The policy's ``default_version_id`` and ``document`` are updated to
    match the chosen version, and every other version loses its default
    flag.

    :raises ResourceNotFoundException: if the policy does not exist, or has
        no version with that id. In the latter case nothing is modified.
    """
    policy = self.get_policy(policy_name)
    if not policy:
        raise ResourceNotFoundException()

    # Validate before mutating: otherwise an unknown id would clear the
    # default flag on every version while default_version_id stays stale.
    if not any(version.version_id == version_id for version in policy.versions):
        raise ResourceNotFoundException()

    for version in policy.versions:
        if version.version_id == version_id:
            version.is_default = True
            policy.default_version_id = version.version_id
            policy.document = version.document
        else:
            version.is_default = False
def get_policy_version(self, policy_name, version_id):
    """Return the version of ``policy_name`` whose id equals ``version_id``.

    :raises ResourceNotFoundException: if the policy or the version is
        missing.
    """
    policy = self.get_policy(policy_name)
    if not policy:
        raise ResourceNotFoundException()
    found = next(
        (candidate for candidate in policy.versions if candidate.version_id == version_id),
        None,
    )
    if found is None:
        raise ResourceNotFoundException()
    return found
def list_policy_versions(self, policy_name):
    """Return the list of versions of ``policy_name``.

    :raises ResourceNotFoundException: if the policy is missing.
    """
    policy = self.get_policy(policy_name)
    if policy:
        return policy.versions
    raise ResourceNotFoundException()
def delete_policy_version(self, policy_name, version_id):
    """Delete one non-default version of a policy.

    :raises InvalidRequestException: if ``version_id`` is the policy's
        default version.
    :raises ResourceNotFoundException: if the policy or the version is
        missing.
    """
    policy = self.get_policy(policy_name)
    if not policy:
        raise ResourceNotFoundException()
    if version_id == policy.default_version_id:
        raise InvalidRequestException(
            "Cannot delete the default version of a policy"
        )
    # Only the first version with a matching id is removed.
    position = next(
        (
            index
            for index, version in enumerate(policy.versions)
            if version.version_id == version_id
        ),
        None,
    )
    if position is None:
        raise ResourceNotFoundException()
    del policy.versions[position]
def _get_principal(self, principal_arn):
"""
raise ResourceNotFoundException
"""
if ":cert/" in principal_arn:
certs = [_ for _ in self.certificates.values() if _.arn == principal_arn]
if len(certs) == 0:
raise ResourceNotFoundException()
principal = certs[0]
return principal
if ":thinggroup/" in principal_arn:
try:
return self.thing_groups[principal_arn]
except KeyError:
raise ResourceNotFoundException(
f"No thing group with ARN {principal_arn} exists"
)
from moto.cognitoidentity import cognitoidentity_backends
cognito = cognitoidentity_backends[self.region_name]
identities = []
for identity_pool in cognito.identity_pools:
pool_identities = cognito.pools_identities.get(identity_pool, None)
identities.extend(
[pi["IdentityId"] for pi in pool_identities.get("Identities", [])]
)
if principal_arn in identities:
return {"IdentityId": principal_arn}
raise ResourceNotFoundException()
def attach_principal_policy(self, policy_name, principal_arn):
    """Attach a policy to a principal; an existing attachment is left as is."""
    principal = self._get_principal(principal_arn)
    policy = self.get_policy(policy_name)
    # setdefault only stores the pair when the key is not present yet.
    self.principal_policies.setdefault(
        (principal_arn, policy_name), (principal, policy)
    )
def detach_principal_policy(self, policy_name, principal_arn):
    """Detach a policy from a principal.

    :raises ResourceNotFoundException: if the policy is not attached to the
        principal (the principal and policy lookups may raise it as well).
    """
    # Both lookups are done only for validation.
    self._get_principal(principal_arn)
    self.get_policy(policy_name)
    attachment_key = (principal_arn, policy_name)
    if attachment_key in self.principal_policies:
        del self.principal_policies[attachment_key]
        return
    raise ResourceNotFoundException()
def list_principal_policies(self, principal_arn):
    """Return the policies attached to ``principal_arn``."""
    attached = []
    for key, value in self.principal_policies.items():
        # key is (principal_arn, policy_name); value is (principal, policy).
        if key[0] == principal_arn:
            attached.append(value[1])
    return attached
def list_policy_principals(self, policy_name):
    """Return the principal ARNs that have the named policy attached."""
    # Keys are (principal_arn, policy_name) tuples; the values are not needed.
    return [key[0] for key in self.principal_policies if key[1] == policy_name]
def attach_thing_principal(self, thing_name, principal_arn):
    """
    Record that a principal is attached to a thing.

    Both the principal and the thing are looked up first. An attachment that is
    already recorded for the (principal_arn, thing_name) pair is left untouched.
    """
    principal = self._get_principal(principal_arn)
    thing = self.describe_thing(thing_name)
    # setdefault only stores the pair when the key is not present yet.
    self.principal_things.setdefault((principal_arn, thing_name), (principal, thing))
def detach_thing_principal(self, thing_name, principal_arn):
    """
    Remove the recorded attachment between a principal and a thing.

    :raises ResourceNotFoundException: the pair is not attached (the two
        lookups below may raise as well)
    """
    # The lookups are made only for their validation effect; results are unused.
    self._get_principal(principal_arn)
    self.describe_thing(thing_name)
    attachment = (principal_arn, thing_name)
    if attachment in self.principal_things:
        del self.principal_things[attachment]
        return
    raise ResourceNotFoundException()
def list_principal_things(self, principal_arn):
    """Return the names of the things the given principal is attached to."""
    # Keys are (principal_arn, thing_name) tuples; the values are not needed.
    return [key[1] for key in self.principal_things if key[0] == principal_arn]
def list_thing_principals(self, thing_name):
    """
    Return the principal ARNs attached to the named thing.

    :raises ResourceNotFoundException: no stored thing has this name
    """
    thing_exists = any(
        thing.thing_name == thing_name for thing in self.things.values()
    )
    if not thing_exists:
        raise ResourceNotFoundException(
            "Failed to list principals for thing %s because the thing does not exist in your account"
            % thing_name
        )
    # Keys are (principal_arn, thing_name) tuples.
    return [key[0] for key in self.principal_things if key[1] == thing_name]
def describe_thing_group(self, thing_group_name):
    """
    Return the first stored thing group whose name equals ``thing_group_name``.

    :raises ResourceNotFoundException: no group has that name
    """
    for group in self.thing_groups.values():
        if group.thing_group_name == thing_group_name:
            return group
    raise ResourceNotFoundException()
def create_thing_group(
    self, thing_group_name, parent_group_name, thing_group_properties
):
    """
    Create a thing group and store it under its ARN.

    :return: tuple of (thing group name, thing group ARN, thing group id)
    """
    new_group = FakeThingGroup(
        thing_group_name,
        parent_group_name,
        thing_group_properties,
        self.region_name,
        self.thing_groups,
    )
    group_arn = new_group.arn
    self.thing_groups[group_arn] = new_group
    return new_group.thing_group_name, group_arn, new_group.thing_group_id
def delete_thing_group(self, thing_group_name):
    """
    Delete a thing group by name.

    The ExpectedVersion-parameter is not yet implemented

    :raises InvalidRequestException: another group names this one as its parent
    """
    has_children = any(
        group.parent_group_name == thing_group_name
        for group in self.thing_groups.values()
    )
    if has_children:
        raise InvalidRequestException(
            " Cannot delete thing group : "
            + thing_group_name
            + " when there are still child groups attached to it"
        )
    try:
        target = self.describe_thing_group(thing_group_name)
    except ResourceNotFoundException:
        # AWS returns success even if the thing group does not exist.
        return
    del self.thing_groups[target.arn]
def list_thing_groups(self, parent_group, name_prefix_filter, recursive):
    """
    List thing groups below ``parent_group``, filtered by a name prefix.

    ``recursive`` defaults to True and ``name_prefix_filter`` to "" when None.
    The groups whose parent is ``parent_group`` are collected first (for
    ``parent_group=None`` these are the groups without a parent); with
    ``recursive`` set, the children of every collected group are appended in
    turn. The prefix filter is applied once, to the final list.

    :raises ResourceNotFoundException: a truthy ``parent_group`` names no
        existing group
    """
    recursive = True if recursive is None else recursive
    prefix = "" if name_prefix_filter is None else name_prefix_filter

    if parent_group and not any(
        group.thing_group_name == parent_group
        for group in self.thing_groups.values()
    ):
        raise ResourceNotFoundException()

    collected = [
        group
        for group in self.thing_groups.values()
        if group.parent_group_name == parent_group
    ]
    if recursive:
        # Walk the list while it grows: each visited group contributes its
        # direct children, which are then visited themselves.
        cursor = 0
        while cursor < len(collected):
            collected.extend(
                self.list_thing_groups(
                    parent_group=collected[cursor].thing_group_name,
                    name_prefix_filter=None,
                    recursive=False,
                )
            )
            cursor += 1

    return [group for group in collected if group.thing_group_name.startswith(prefix)]
def update_thing_group(
    self, thing_group_name, thing_group_properties, expected_version
):
    """
    Update the attributes of a thing group and bump its version.

    When ``thing_group_properties["attributePayload"]`` contains "attributes",
    they replace the stored attributes, or are merged into them when the
    payload's "merge" flag is truthy.

    :return: the new version number of the thing group
    :raises VersionConflictException: a truthy ``expected_version`` differs
        from the group's current version
    """
    thing_group = self.describe_thing_group(thing_group_name)
    if expected_version and expected_version != thing_group.version:
        raise VersionConflictException(thing_group_name)

    attribute_payload = thing_group_properties.get("attributePayload", None)
    if attribute_payload is not None:
        if "attributes" in attribute_payload:
            new_attributes = attribute_payload["attributes"]
            # The stored properties may not contain an attribute payload yet;
            # create the containers on demand instead of raising KeyError.
            stored_payload = thing_group.thing_group_properties.setdefault(
                "attributePayload", {}
            )
            if attribute_payload.get("merge", False):
                stored_payload.setdefault("attributes", {}).update(new_attributes)
            else:
                stored_payload["attributes"] = new_attributes
        else:
            # NOTE(review): this sets a separate `attributes` attribute on the
            # group rather than clearing the stored attributePayload; kept
            # unchanged - confirm the intended behaviour.
            thing_group.attributes = {}

    thing_group.version = thing_group.version + 1
    return thing_group.version
def _identify_thing_group(self, thing_group_name, thing_group_arn):
    """
    Resolve a thing group from its name and/or ARN.

    The name takes precedence; when an ARN is supplied alongside it, the ARN
    must match the group found by name.

    :raises InvalidRequestException: both identifiers are None, the ARN does
        not match the named group, or the ARN alone is unknown
    """
    if thing_group_name is None and thing_group_arn is None:
        raise InvalidRequestException(
            " Both thingGroupArn and thingGroupName are empty. Need to specify at least one of them"
        )
    if thing_group_name is None:
        # Only the ARN was supplied: it is the key of self.thing_groups.
        if thing_group_arn not in self.thing_groups:
            raise InvalidRequestException()
        return self.thing_groups[thing_group_arn]

    group = self.describe_thing_group(thing_group_name)
    if thing_group_arn and group.arn != thing_group_arn:
        raise InvalidRequestException(
            "ThingGroupName thingGroupArn does not match specified thingGroupName in request"
        )
    return group
def _identify_thing(self, thing_name, thing_arn):
    """
    Resolve a thing from its name and/or ARN.

    The name takes precedence; when an ARN is supplied alongside it, the ARN
    must match the thing found by name.

    :raises InvalidRequestException: both identifiers are None, the ARN does
        not match the named thing, or the ARN alone is unknown
    """
    if thing_name is None and thing_arn is None:
        raise InvalidRequestException(
            "Both thingArn and thingName are empty. Need to specify at least one of them"
        )
    if thing_name is None:
        # Only the ARN was supplied: it is used as the key of self.things.
        if thing_arn not in self.things:
            raise InvalidRequestException()
        return self.things[thing_arn]

    found = self.describe_thing(thing_name)
    if thing_arn and found.arn != thing_arn:
        raise InvalidRequestException(
            "ThingName thingArn does not match specified thingName in request"
        )
    return found
def add_thing_to_thing_group(
    self, thing_group_name, thing_group_arn, thing_name, thing_arn
):
    """
    Register a thing as a member of a thing group.

    Group and thing may each be identified by name and/or ARN. A thing that is
    already a member is left as it is.
    """
    group = self._identify_thing_group(thing_group_name, thing_group_arn)
    member = self._identify_thing(thing_name, thing_arn)
    # aws ignores duplicate registration
    group.things.setdefault(member.arn, member)
def remove_thing_from_thing_group(
    self, thing_group_name, thing_group_arn, thing_name, thing_arn
):
    """
    Remove a thing from a thing group's members.

    Group and thing may each be identified by name and/or ARN. Removing a
    thing that is not a member is a no-op.
    """
    group = self._identify_thing_group(thing_group_name, thing_group_arn)
    member = self._identify_thing(thing_name, thing_arn)
    # aws ignores non-registered thing
    group.things.pop(member.arn, None)
def list_things_in_thing_group(self, thing_group_name):
    """
    Return the things stored as members of the named thing group.

    Pagination and the recursive-parameter is not yet implemented
    """
    members = self.describe_thing_group(thing_group_name).things
    return members.values()
def list_thing_groups_for_thing(self, thing_name):
    """
    Return name and ARN of every thing group that contains the named thing.

    Pagination is not yet implemented

    :return: list of ``{"groupName": ..., "groupArn": ...}`` dicts
    """
    thing = self.describe_thing(thing_name)
    # None for all three arguments applies list_thing_groups' defaults.
    candidates = self.list_thing_groups(None, None, None)
    return [
        {"groupName": group.thing_group_name, "groupArn": group.arn}
        for group in candidates
        if thing.arn in group.things
    ]
def update_thing_groups_for_thing(
    self, thing_name, thing_groups_to_add, thing_groups_to_remove
):
    """
    Add a thing to some thing groups and remove it from others.

    All additions are processed before any removal. Every group name is
    resolved through describe_thing_group before the membership change.
    """
    thing = self.describe_thing(thing_name)
    steps = (
        (thing_groups_to_add, self.add_thing_to_thing_group),
        (thing_groups_to_remove, self.remove_thing_from_thing_group),
    )
    for group_names, change_membership in steps:
        for group_name in group_names:
            group = self.describe_thing_group(group_name)
            change_membership(group.thing_group_name, None, thing.thing_name, None)
def create_job(
    self,
    job_id,
    targets,
    document_source,
    document,
    description,
    presigned_url_config,
    target_selection,
    job_executions_rollout_config,
    document_parameters,
):
    """
    Create a job and one job execution per target.

    The job is stored under ``job_id``; each execution is stored under
    ``(job_id, <name taken from the target ARN>)``.

    :return: tuple of (job ARN, job id, description)
    """
    new_job = FakeJob(
        job_id,
        targets,
        document_source,
        document,
        description,
        presigned_url_config,
        target_selection,
        job_executions_rollout_config,
        document_parameters,
        self.region_name,
    )
    self.jobs[job_id] = new_job

    for target_arn in targets:
        # The name is what follows the last ":" and, within that, the last "/".
        target_name = target_arn.rpartition(":")[2].rpartition("/")[2]
        self.job_executions[(job_id, target_name)] = FakeJobExecution(
            job_id, target_arn
        )
    return new_job.job_arn, job_id, description
def describe_job(self, job_id):
    """
    Return the first stored job whose ``job_id`` attribute equals ``job_id``.

    :raises ResourceNotFoundException: no such job exists
    """
    for job in self.jobs.values():
        if job.job_id == job_id:
            return job
    raise ResourceNotFoundException()
def delete_job(self, job_id, force):
    """
    Delete a job.

    A job whose status is "IN_PROGRESS" is only deleted when ``force`` is
    truthy.

    :raises InvalidStateTransitionException: the job is in progress and
        ``force`` is falsy
    """
    job = self.jobs[job_id]
    if job.status == "IN_PROGRESS" and not force:
        raise InvalidStateTransitionException()
    del self.jobs[job_id]
def cancel_job(self, job_id, reason_code, comment, force):
    """
    Mark a job as canceled and return it.

    ``reason_code``, ``comment`` and ``force`` overwrite the job's fields only
    when they are not None.
    """
    job = self.jobs[job_id]
    job.reason_code = job.reason_code if reason_code is None else reason_code
    job.comment = job.comment if comment is None else comment
    job.force = job.force if force is None or force == job.force else force
    job.status = "CANCELED"
    # NOTE(review): the status has just been overwritten, so this guard can no
    # longer observe "IN_PROGRESS"; kept to preserve the existing behaviour.
    if job.status == "IN_PROGRESS" and not force:
        raise InvalidStateTransitionException()
    self.jobs[job_id] = job
    return job
def get_job_document(self, job_id):
    """Return the job stored under ``job_id`` (the job object itself)."""
    job = self.jobs[job_id]
    return job
def list_jobs(self, max_results, token):
    """
    Return one page of job dicts and the token for the next page.

    The following parameter are not yet implemented: Status, TargetSelection, ThingGroupName, ThingGroupId

    ``token`` is the stringified offset of the page's first element; None
    starts at the beginning. The returned token is None on the last page.
    """
    summaries = [job.to_dict() for job in self.jobs.values()]
    start = 0 if token is None else int(token)
    end = start + max_results
    page = summaries[start:end]
    next_token = str(end) if len(summaries) > end else None
    return page, next_token
def describe_job_execution(self, job_id, thing_name, execution_number):
    """
    Return the job execution stored for ``(job_id, thing_name)``.

    :raises ResourceNotFoundException: no execution is stored for the pair, or
        ``execution_number`` is given and differs from the stored one
    """
    execution = self.job_executions.get((job_id, thing_name))
    if execution is None:
        raise ResourceNotFoundException()
    if (
        execution_number is not None
        and execution.execution_number != execution_number
    ):
        raise ResourceNotFoundException()
    return execution
def cancel_job_execution(self, job_id, thing_name, force):
    """
    Cancel the job execution stored for ``(job_id, thing_name)``.

    The parameters ExpectedVersion and StatusDetails are not yet implemented

    An execution whose status is "IN_PROGRESS" is only canceled when ``force``
    is truthy.

    :raises ResourceNotFoundException: no execution is stored for the pair
    :raises InvalidStateTransitionException: the execution is in progress and
        ``force`` is falsy
    """
    # .get() makes the not-found guard below reachable; direct indexing would
    # raise KeyError before the guard could run.
    job_execution = self.job_executions.get((job_id, thing_name))
    if job_execution is None:
        raise ResourceNotFoundException()

    if force is not None:
        job_execution.force_canceled = force

    # TODO: implement expected_version and status_details (at most 10 can be specified)

    if job_execution.status == "IN_PROGRESS" and not force:
        raise InvalidStateTransitionException()
    job_execution.status = "CANCELED"
def delete_job_execution(self, job_id, thing_name, execution_number, force):
    """
    Delete the job execution stored for ``(job_id, thing_name)``.

    An execution whose status is "IN_PROGRESS" is only deleted when ``force``
    is truthy.

    :raises ResourceNotFoundException: no execution is stored for the pair, or
        its execution number differs from ``execution_number``
    :raises InvalidStateTransitionException: the execution is in progress and
        ``force`` is falsy
    """
    key = (job_id, thing_name)
    # Report an unknown key as ResourceNotFoundException, as
    # describe_job_execution does, rather than leaking a bare KeyError.
    job_execution = self.job_executions.get(key)
    if job_execution is None:
        raise ResourceNotFoundException()
    if job_execution.execution_number != execution_number:
        raise ResourceNotFoundException()

    if job_execution.status == "IN_PROGRESS" and not force:
        raise InvalidStateTransitionException()
    del self.job_executions[key]
def list_job_executions_for_job(self, job_id, status, max_results, next_token):
    """
    Return one page of execution dicts for a job and the next-page token.

    :param status: when not None, only executions whose
        ``["jobExecutionSummary"]["status"]`` equals it are listed
    :param max_results: maximum number of entries in the returned page
    :param next_token: stringified offset of the page's first element, or None
        to start at the beginning
    :return: tuple of (page, token); the token is None on the last page
    """
    matching = [
        execution.to_dict()
        for key, execution in self.job_executions.items()
        if key[0] == job_id
    ]
    if status is not None:
        matching = [
            summary
            for summary in matching
            if summary["jobExecutionSummary"].get("status") == status
        ]

    start = 0 if next_token is None else int(next_token)
    end = start + max_results
    # Compare against the complete filtered list: measuring the already-sliced
    # page can never exceed max_results, so no token would ever be returned.
    following_token = str(end) if len(matching) > end else None
    return matching[start:end], following_token
@paginate(PAGINATION_MODEL)
def list_job_executions_for_thing(self, thing_name, status):
    """
    Return the execution dicts recorded for a thing.

    When ``status`` is not None, only executions whose
    ``["jobExecutionSummary"]["status"]`` equals it are returned.
    """
    summaries = [
        execution.to_dict()
        for key, execution in self.job_executions.items()
        if key[1] == thing_name
    ]
    if status is None:
        return summaries
    return [
        summary
        for summary in summaries
        if summary["jobExecutionSummary"].get("status") == status
    ]
def list_topic_rules(self):
    """Return the ``to_dict()`` form of every stored topic rule."""
    return [rule.to_dict() for rule in self.rules.values()]
def get_topic_rule(self, rule_name):
    """
    Return the ``to_get_dict()`` form of the named topic rule.

    :raises ResourceNotFoundException: no rule with that name exists
    """
    if rule_name in self.rules:
        return self.rules[rule_name].to_get_dict()
    raise ResourceNotFoundException()
def create_topic_rule(self, rule_name, sql, **kwargs):
    """
    Create a topic rule and store it under ``rule_name``.

    The topic pattern is the first whitespace-free token after ``FROM`` in
    ``sql`` with surrounding single quotes stripped, or None when ``sql``
    contains no such clause. Remaining keyword arguments are passed on to
    FakeRule.

    :raises ResourceAlreadyExistsException: a rule with this name exists
    """
    if rule_name in self.rules:
        raise ResourceAlreadyExistsException(
            "Rule with given name already exists", "", self.rules[rule_name].arn
        )
    from_clause = re.search(r"FROM\s+([^\s]*)", sql)
    if from_clause:
        topic = from_clause.group(1).strip("'")
    else:
        topic = None
    self.rules[rule_name] = FakeRule(
        rule_name=rule_name,
        created_at=int(time.time()),
        topic_pattern=topic,
        sql=sql,
        region_name=self.region_name,
        **kwargs,
    )
def replace_topic_rule(self, rule_name, **kwargs):
    """
    Replace a topic rule by deleting it and creating it anew.

    The keyword arguments are forwarded unchanged to create_topic_rule.
    """
    # Delete first: create_topic_rule rejects a name that is still in use.
    self.delete_topic_rule(rule_name)
    self.create_topic_rule(rule_name, **kwargs)
def delete_topic_rule(self, rule_name):
    """
    Delete the named topic rule.

    :raises ResourceNotFoundException: no rule with that name exists
    """
    if rule_name in self.rules:
        del self.rules[rule_name]
        return
    raise ResourceNotFoundException()
def enable_topic_rule(self, rule_name):
    """
    Enable the named topic rule by clearing its ``rule_disabled`` flag.

    :raises ResourceNotFoundException: no rule with that name exists
    """
    if rule_name in self.rules:
        self.rules[rule_name].rule_disabled = False
        return
    raise ResourceNotFoundException()
def disable_topic_rule(self, rule_name):
    """
    Disable the named topic rule by setting its ``rule_disabled`` flag.

    :raises ResourceNotFoundException: no rule with that name exists
    """
    if rule_name in self.rules:
        self.rules[rule_name].rule_disabled = True
        return
    raise ResourceNotFoundException()
def create_domain_configuration(
    self,
    domain_configuration_name,
    domain_name,
    server_certificate_arns,
    authorizer_config,
    service_type,
):
    """
    Create a domain configuration and store it under its name.

    The ValidationCertificateArn-parameter is not yet implemented

    :return: the newly stored domain configuration object
    :raises ResourceAlreadyExistsException: the name is already in use
    """
    if domain_configuration_name in self.domain_configurations:
        existing = self.domain_configurations[domain_configuration_name]
        raise ResourceAlreadyExistsException(
            "Domain configuration with given name already exists.",
            existing.domain_configuration_name,
            existing.domain_configuration_arn,
        )
    configuration = FakeDomainConfiguration(
        self.region_name,
        domain_configuration_name,
        domain_name,
        server_certificate_arns,
        "ENABLED",
        service_type,
        authorizer_config,
        "CUSTOMER_MANAGED",
    )
    self.domain_configurations[domain_configuration_name] = configuration
    return configuration
def delete_domain_configuration(self, domain_configuration_name):
    """
    Delete the named domain configuration.

    :raises ResourceNotFoundException: no configuration with that name exists
    """
    if domain_configuration_name in self.domain_configurations:
        del self.domain_configurations[domain_configuration_name]
        return
    raise ResourceNotFoundException("The specified resource does not exist.")
def describe_domain_configuration(self, domain_configuration_name):
    """
    Return the domain configuration stored under the given name.

    :raises ResourceNotFoundException: no configuration with that name exists
    """
    if domain_configuration_name in self.domain_configurations:
        return self.domain_configurations[domain_configuration_name]
    raise ResourceNotFoundException("The specified resource does not exist.")
def list_domain_configurations(self):
    """Return the ``to_dict()`` form of every stored domain configuration."""
    return [
        configuration.to_dict()
        for configuration in self.domain_configurations.values()
    ]
def update_domain_configuration(
    self,
    domain_configuration_name,
    authorizer_config,
    domain_configuration_status,
    remove_authorizer_config,
):
    """
    Update selected fields of a stored domain configuration and return it.

    ``authorizer_config`` and ``domain_configuration_status`` are applied only
    when not None. ``remove_authorizer_config=True`` clears the authorizer
    config and is applied last, so it wins over a supplied
    ``authorizer_config``.

    :raises ResourceNotFoundException: no configuration with that name exists
    """
    if domain_configuration_name not in self.domain_configurations:
        raise ResourceNotFoundException("The specified resource does not exist.")
    configuration = self.domain_configurations[domain_configuration_name]

    if authorizer_config is not None:
        configuration.authorizer_config = authorizer_config
    if domain_configuration_status is not None:
        configuration.domain_configuration_status = domain_configuration_status
    # Only the literal True triggers removal; other truthy values do not.
    if remove_authorizer_config is True:
        configuration.authorizer_config = None
    return configuration
def search_index(self, query_string):
    """
    Return the things matching ``query_string`` plus an empty group list.

    Pagination is not yet implemented. Only basic search queries are supported for now.

    :return: tuple of (list of thing dicts, list of groups - always empty)
    """
    hits = list(
        filter(lambda thing: thing.matches(query_string), self.things.values())
    )
    return [hit.to_dict() for hit in hits], []
# Module-level registry built by BackendDict from the IoTBackend class and the
# service name "iot".
# NOTE(review): presumably indexed by region name, the way _get_principal
# indexes cognitoidentity_backends[self.region_name] - confirm against BackendDict.
iot_backends = BackendDict(IoTBackend, "iot")