From 4e5180a9ba7f6ba622184dc5b4d51e936c49d9d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sindri=20Gu=C3=B0mundsson?= Date: Thu, 6 Jan 2022 16:04:55 +0000 Subject: [PATCH] Fix publishing to fifo sns->sqs subscription (#4738) --- moto/sns/models.py | 34 +++++++++++++-- moto/sns/responses.py | 2 + tests/test_sns/test_publishing_boto3.py | 56 ++++++++++++++++++++++++- 3 files changed, 87 insertions(+), 5 deletions(-) diff --git a/moto/sns/models.py b/moto/sns/models.py index b2a8a891f..da20b3555 100644 --- a/moto/sns/models.py +++ b/moto/sns/models.py @@ -13,6 +13,7 @@ from moto.core.utils import ( BackendDict, ) from moto.sqs import sqs_backends +from moto.sqs.exceptions import MissingParameter from .exceptions import ( SNSNotFoundError, @@ -55,7 +56,7 @@ class Topic(CloudFormationModel): self.fifo_topic = "false" self.content_based_deduplication = "false" - def publish(self, message, subject=None, message_attributes=None): + def publish(self, message, subject=None, message_attributes=None, group_id=None): message_id = str(uuid.uuid4()) subscriptions, _ = self.sns_backend.list_subscriptions(self.arn) for subscription in subscriptions: @@ -64,6 +65,7 @@ class Topic(CloudFormationModel): message_id, subject=subject, message_attributes=message_attributes, + group_id=group_id, ) return message_id @@ -177,7 +179,9 @@ class Subscription(BaseModel): self._filter_policy = None # filter policy as a dict, not json. self.confirmed = False - def publish(self, message, message_id, subject=None, message_attributes=None): + def publish( + self, message, message_id, subject=None, message_attributes=None, group_id=None + ): if not self._matches_filter_policy(message_attributes): return @@ -198,6 +202,7 @@ class Subscription(BaseModel): indent=2, separators=(",", ": "), ), + group_id=group_id, ) else: raw_message_attributes = {} @@ -215,7 +220,10 @@ class Subscription(BaseModel): } sqs_backends[region].send_message( - queue_name, message, message_attributes=raw_message_attributes + queue_name, + message, + message_attributes=raw_message_attributes, + group_id=group_id, ) elif self.protocol in ["http", "https"]: post_data = self.get_post_data(message, message_id, subject) @@ -568,6 +576,7 @@ class SNSBackend(BaseBackend): phone_number=None, subject=None, message_attributes=None, + group_id=None, ): if subject is not None and len(subject) > 100: # Note that the AWS docs around length are wrong: https://github.com/spulec/moto/issues/1503 @@ -589,8 +598,25 @@ class SNSBackend(BaseBackend): try: topic = self.get_topic(arn) + + fifo_topic = topic.fifo_topic == "true" + if group_id is None: + # MessageGroupId is a mandatory parameter for all + # messages in a fifo queue + if fifo_topic: + raise MissingParameter("MessageGroupId") + else: + if not fifo_topic: + msg = ( + "Value {} for parameter MessageGroupId is invalid. " + "Reason: The request include parameter that is not valid for this queue type." + ).format(group_id) + raise InvalidParameterValue(msg) message_id = topic.publish( - message, subject=subject, message_attributes=message_attributes + message, + subject=subject, + message_attributes=message_attributes, + group_id=group_id, ) except SNSNotFoundError: endpoint = self.get_endpoint(arn) diff --git a/moto/sns/responses.py b/moto/sns/responses.py index 2c088288c..416d9704f 100644 --- a/moto/sns/responses.py +++ b/moto/sns/responses.py @@ -328,6 +328,7 @@ class SNSResponse(BaseResponse): topic_arn = self._get_param("TopicArn") phone_number = self._get_param("PhoneNumber") subject = self._get_param("Subject") + message_group_id = self._get_param("MessageGroupId") message_attributes = self._parse_message_attributes() @@ -355,6 +356,7 @@ class SNSResponse(BaseResponse): phone_number=phone_number, subject=subject, message_attributes=message_attributes, + group_id=message_group_id, ) except ValueError as err: error_response = self._error("InvalidParameter", str(err)) diff --git a/tests/test_sns/test_publishing_boto3.py b/tests/test_sns/test_publishing_boto3.py index 8dc6c5eb9..ce2d98dee 100644 --- a/tests/test_sns/test_publishing_boto3.py +++ b/tests/test_sns/test_publishing_boto3.py @@ -83,6 +83,27 @@ def test_publish_to_sqs_raw(): messages[0].body.should.equal(message) +@mock_sns +@mock_sqs +def test_publish_to_sqs_fifo(): + sns = boto3.resource("sns", region_name="us-east-1") + topic = sns.create_topic( + Name="topic.fifo", + Attributes={"FifoTopic": "true", "ContentBasedDeduplication": "true",}, + ) + + sqs = boto3.resource("sqs", region_name="us-east-1") + queue = sqs.create_queue( + QueueName="queue.fifo", + Attributes={"FifoQueue": "true", "ContentBasedDeduplication": "true",}, + ) + topic.subscribe( + Protocol="sqs", Endpoint=queue.attributes["QueueArn"], + ) + + topic.publish(Message="message", MessageGroupId="message_group_id") + + @mock_sqs @mock_sns def test_publish_to_sqs_bad(): @@ -154,7 +175,7 @@ def test_publish_to_sqs_msg_attr_byte_value(): sqs = boto3.resource("sqs", region_name="us-east-1") queue = sqs.create_queue(QueueName="test-queue") conn.subscribe( - TopicArn=topic_arn, Protocol="sqs", Endpoint=queue.attributes["QueueArn"], + TopicArn=topic_arn, Protocol="sqs", Endpoint=queue.attributes["QueueArn"] ) queue_raw = sqs.create_queue(QueueName="test-queue-raw") conn.subscribe( @@ -431,6 +452,39 @@ def test_publish_message_too_long(): topic.publish(Message="".join(["." for i in range(0, 262144)])) +@mock_sns +def test_publish_fifo_needs_group_id(): + sns = boto3.resource("sns", region_name="us-east-1") + topic = sns.create_topic( + Name="topic.fifo", + Attributes={"FifoTopic": "true", "ContentBasedDeduplication": "true",}, + ) + + with pytest.raises( + ClientError, match="The request must contain the parameter MessageGroupId" + ): + topic.publish(Message="message") + + # message group included - OK + topic.publish(Message="message", MessageGroupId="message_group_id") + + +@mock_sns +@mock_sqs +def test_publish_group_id_to_non_fifo(): + sns = boto3.resource("sns", region_name="us-east-1") + topic = sns.create_topic(Name="topic") + + with pytest.raises( + ClientError, + match="The request include parameter that is not valid for this queue type", + ): + topic.publish(Message="message", MessageGroupId="message_group_id") + + # message group not included - OK + topic.publish(Message="message") + + def _setup_filter_policy_test(filter_policy): sns = boto3.resource("sns", region_name="us-east-1") topic = sns.create_topic(Name="some-topic")