From 303b1b92cbfaf6d8ec034d50ec4edfc15fb588d3 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Fri, 11 Aug 2023 07:20:44 +0000 Subject: [PATCH] IOTData: Support named shadows (#6633) --- moto/iot/models.py | 7 ++-- moto/iotdata/models.py | 45 +++++++++++++++--------- moto/iotdata/responses.py | 21 +++++++++--- tests/test_iotdata/test_iotdata.py | 55 ++++++++++++++++++++++++++++++ 4 files changed, 105 insertions(+), 23 deletions(-) diff --git a/moto/iot/models.py b/moto/iot/models.py index 9fe9014a7..f03bbe104 100644 --- a/moto/iot/models.py +++ b/moto/iot/models.py @@ -8,7 +8,7 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives import serialization, hashes from datetime import datetime, timedelta -from typing import Any, Dict, List, Tuple, Optional, Pattern, Iterable +from typing import Any, Dict, List, Tuple, Optional, Pattern, Iterable, TYPE_CHECKING from .utils import PAGINATION_MODEL @@ -27,6 +27,9 @@ from .exceptions import ( ThingStillAttached, ) +if TYPE_CHECKING: + from moto.iotdata.models import FakeShadow + class FakeThing(BaseModel): def __init__( @@ -46,7 +49,7 @@ class FakeThing(BaseModel): # TODO: we need to handle "version"? # for iot-data - self.thing_shadow: Any = None + self.thing_shadows: Dict[Optional[str], FakeShadow] = {} def matches(self, query_string: str) -> bool: if query_string == "*": diff --git a/moto/iotdata/models.py b/moto/iotdata/models.py index b8624528b..e8a283fe9 100644 --- a/moto/iotdata/models.py +++ b/moto/iotdata/models.py @@ -154,7 +154,9 @@ class IoTDataPlaneBackend(BaseBackend): def iot_backend(self) -> IoTBackend: return iot_backends[self.account_id][self.region_name] - def update_thing_shadow(self, thing_name: str, payload: str) -> FakeShadow: + def update_thing_shadow( + self, thing_name: str, payload: str, shadow_name: Optional[str] + ) -> FakeShadow: """ spec of payload: - need node `state` @@ -175,34 +177,43 @@ class IoTDataPlaneBackend(BaseBackend): if any(_ for _ in _payload["state"].keys() if _ not in ["desired", "reported"]): raise InvalidRequestException("State contains an invalid node") - if "version" in _payload and thing.thing_shadow.version != _payload["version"]: + thing_shadow = thing.thing_shadows.get(shadow_name) + if "version" in _payload and thing_shadow.version != _payload["version"]: # type: ignore raise ConflictException("Version conflict") - new_shadow = FakeShadow.create_from_previous_version( - thing.thing_shadow, _payload - ) - thing.thing_shadow = new_shadow - return thing.thing_shadow + new_shadow = FakeShadow.create_from_previous_version(thing_shadow, _payload) + thing.thing_shadows[shadow_name] = new_shadow + return new_shadow - def get_thing_shadow(self, thing_name: str) -> FakeShadow: + def get_thing_shadow( + self, thing_name: str, shadow_name: Optional[str] + ) -> FakeShadow: thing = self.iot_backend.describe_thing(thing_name) + thing_shadow = thing.thing_shadows.get(shadow_name) - if thing.thing_shadow is None or thing.thing_shadow.deleted: + if thing_shadow is None or thing_shadow.deleted: raise ResourceNotFoundException() - return thing.thing_shadow + return thing_shadow - def delete_thing_shadow(self, thing_name: str) -> FakeShadow: + def delete_thing_shadow( + self, thing_name: str, shadow_name: Optional[str] + ) -> FakeShadow: thing = self.iot_backend.describe_thing(thing_name) - if thing.thing_shadow is None: + thing_shadow = thing.thing_shadows.get(shadow_name) + if thing_shadow is None: raise ResourceNotFoundException() payload = None - new_shadow = FakeShadow.create_from_previous_version( - thing.thing_shadow, payload - ) - thing.thing_shadow = new_shadow - return thing.thing_shadow + new_shadow = FakeShadow.create_from_previous_version(thing_shadow, payload) + thing.thing_shadows[shadow_name] = new_shadow + return new_shadow def publish(self, topic: str, payload: str) -> None: self.published_payloads.append((topic, payload)) + def list_named_shadows_for_thing(self, thing_name: str) -> List[FakeShadow]: + thing = self.iot_backend.describe_thing(thing_name) + return [ + shadow for name, shadow in thing.thing_shadows.items() if name is not None + ] + iotdata_backends = BackendDict(IoTDataPlaneBackend, "iot") diff --git a/moto/iotdata/responses.py b/moto/iotdata/responses.py index ca35653d9..e39ef5a8c 100644 --- a/moto/iotdata/responses.py +++ b/moto/iotdata/responses.py @@ -26,20 +26,28 @@ class IoTDataPlaneResponse(BaseResponse): def update_thing_shadow(self) -> str: thing_name = self._get_param("thingName") - payload = self.body + shadow_name = self.querystring.get("name", [None])[0] payload = self.iotdata_backend.update_thing_shadow( - thing_name=thing_name, payload=payload + thing_name=thing_name, + payload=self.body, + shadow_name=shadow_name, ) return json.dumps(payload.to_response_dict()) def get_thing_shadow(self) -> str: thing_name = self._get_param("thingName") - payload = self.iotdata_backend.get_thing_shadow(thing_name=thing_name) + shadow_name = self.querystring.get("name", [None])[0] + payload = self.iotdata_backend.get_thing_shadow( + thing_name=thing_name, shadow_name=shadow_name + ) return json.dumps(payload.to_dict()) def delete_thing_shadow(self) -> str: thing_name = self._get_param("thingName") - payload = self.iotdata_backend.delete_thing_shadow(thing_name=thing_name) + shadow_name = self.querystring.get("name", [None])[0] + payload = self.iotdata_backend.delete_thing_shadow( + thing_name=thing_name, shadow_name=shadow_name + ) return json.dumps(payload.to_dict()) def publish(self) -> str: @@ -49,3 +57,8 @@ class IoTDataPlaneResponse(BaseResponse): topic = unquote(topic) if "%" in topic else topic self.iotdata_backend.publish(topic=topic, payload=self.body) return json.dumps(dict()) + + def list_named_shadows_for_thing(self) -> str: + thing_name = self._get_param("thingName") + shadows = self.iotdata_backend.list_named_shadows_for_thing(thing_name) + return json.dumps({"results": [shadow.to_dict() for shadow in shadows]}) diff --git a/tests/test_iotdata/test_iotdata.py b/tests/test_iotdata/test_iotdata.py index c51bffe39..983b177c2 100644 --- a/tests/test_iotdata/test_iotdata.py +++ b/tests/test_iotdata/test_iotdata.py @@ -93,6 +93,61 @@ def test_update(): assert ex.value.response["Error"]["Message"] == "Version conflict" +@mock_iot +@mock_iotdata +def test_create_named_shadows(): + iot_client = boto3.client("iot", region_name="ap-northeast-1") + client = boto3.client("iot-data", region_name="ap-northeast-1") + thing_name = "my-thing" + iot_client.create_thing(thingName=thing_name) + + # default shadow + default_payload = json.dumps({"state": {"desired": {"name": "default"}}}) + res = client.update_thing_shadow(thingName=thing_name, payload=default_payload) + payload = json.loads(res["payload"].read()) + assert payload["state"] == {"desired": {"name": "default"}} + + # Create named shadows + for name in ["shadow1", "shadow2"]: + named_payload = json.dumps({"state": {"reported": {"name": name}}}).encode( + "utf-8" + ) + client.update_thing_shadow( + thingName=thing_name, payload=named_payload, shadowName=name + ) + + res = client.get_thing_shadow(thingName=thing_name, shadowName=name) + payload = json.loads(res["payload"].read()) + assert payload["state"]["reported"] == {"name": name} + + # List named shadows + shadows = client.list_named_shadows_for_thing(thingName=thing_name)["results"] + assert len(shadows) == 2 + + for shadow in shadows: + shadow.pop("metadata") + shadow.pop("timestamp") + shadow.pop("version") + + # Verify both named shadows are present + for name in ["shadow1", "shadow2"]: + assert { + "state": {"reported": {"name": name}, "delta": {"name": name}} + } in shadows + + # Verify we can delete a named shadow + client.delete_thing_shadow(thingName=thing_name, shadowName="shadow2") + + with pytest.raises(ClientError): + client.get_thing_shadow(thingName="shadow1") + + # The default and other named shadow are still there + assert "payload" in client.get_thing_shadow(thingName=thing_name) + assert "payload" in client.get_thing_shadow( + thingName=thing_name, shadowName="shadow1" + ) + + @mock_iotdata def test_publish(): region_name = "ap-northeast-1"