diff --git a/moto/sns/models.py b/moto/sns/models.py index 97e4d837f..7cc3f461b 100644 --- a/moto/sns/models.py +++ b/moto/sns/models.py @@ -58,7 +58,14 @@ class Topic(CloudFormationModel): self.fifo_topic = "false" self.content_based_deduplication = "false" - def publish(self, message, subject=None, message_attributes=None, group_id=None): + def publish( + self, + message, + subject=None, + message_attributes=None, + group_id=None, + deduplication_id=None, + ): message_id = str(mock_random.uuid4()) subscriptions, _ = self.sns_backend.list_subscriptions(self.arn) for subscription in subscriptions: @@ -68,6 +75,7 @@ class Topic(CloudFormationModel): subject=subject, message_attributes=message_attributes, group_id=group_id, + deduplication_id=deduplication_id, ) self.sent_notifications.append( (message_id, message, subject, message_attributes, group_id) @@ -189,7 +197,13 @@ class Subscription(BaseModel): self.confirmed = False def publish( - self, message, message_id, subject=None, message_attributes=None, group_id=None + self, + message, + message_id, + subject=None, + message_attributes=None, + group_id=None, + deduplication_id=None, ): if not self._matches_filter_policy(message_attributes): return @@ -211,6 +225,7 @@ class Subscription(BaseModel): indent=2, separators=(",", ": "), ), + deduplication_id=deduplication_id, group_id=group_id, ) else: @@ -232,6 +247,7 @@ class Subscription(BaseModel): queue_name, message, message_attributes=raw_message_attributes, + deduplication_id=deduplication_id, group_id=group_id, ) elif self.protocol in ["http", "https"]: @@ -640,6 +656,7 @@ class SNSBackend(BaseBackend): subject=None, message_attributes=None, group_id=None, + deduplication_id=None, ): if subject is not None and len(subject) > 100: # Note that the AWS docs around length are wrong: https://github.com/getmoto/moto/issues/1503 @@ -663,23 +680,29 @@ class SNSBackend(BaseBackend): topic = self.get_topic(arn) fifo_topic = topic.fifo_topic == "true" - if group_id is None: - # MessageGroupId is a mandatory parameter for all - # messages in a fifo queue - if fifo_topic: + if fifo_topic: + if not group_id: + # MessageGroupId is a mandatory parameter for all + # messages in a fifo queue raise MissingParameter("MessageGroupId") - else: - if not fifo_topic: - msg = ( - f"Value {group_id} for parameter MessageGroupId is invalid. " - "Reason: The request include parameter that is not valid for this queue type." + deduplication_id_required = topic.content_based_deduplication == "false" + if not deduplication_id and deduplication_id_required: + raise InvalidParameterValue( + "The topic should either have ContentBasedDeduplication enabled or MessageDeduplicationId provided explicitly" ) - raise InvalidParameterValue(msg) + elif group_id or deduplication_id: + parameter = "MessageGroupId" if group_id else "MessageDeduplicationId" + raise InvalidParameterValue( + f"Invalid parameter: {parameter} " + f"Reason: The request includes {parameter} parameter that is not valid for this topic type" + ) + message_id = topic.publish( message, subject=subject, message_attributes=message_attributes, group_id=group_id, + deduplication_id=deduplication_id, ) except SNSNotFoundError: endpoint = self.get_endpoint(arn) @@ -1023,7 +1046,7 @@ class SNSBackend(BaseBackend): { "Id": entry["Id"], "Code": "InvalidParameter", - "Message": f"Invalid parameter: {e.message}", + "Message": e.message, "SenderFault": True, } ) diff --git a/moto/sns/responses.py b/moto/sns/responses.py index 9e6d26c9a..21372218e 100644 --- a/moto/sns/responses.py +++ b/moto/sns/responses.py @@ -334,6 +334,7 @@ class SNSResponse(BaseResponse): phone_number = self._get_param("PhoneNumber") subject = self._get_param("Subject") message_group_id = self._get_param("MessageGroupId") + message_deduplication_id = self._get_param("MessageDeduplicationId") message_attributes = self._parse_message_attributes() @@ -362,6 +363,7 @@ class SNSResponse(BaseResponse): subject=subject, message_attributes=message_attributes, group_id=message_group_id, + deduplication_id=message_deduplication_id, ) except ValueError as err: error_response = self._error("InvalidParameter", str(err)) diff --git a/tests/test_sns/test_publish_batch.py b/tests/test_sns/test_publish_batch.py index a0d65162a..517d1f2f9 100644 --- a/tests/test_sns/test_publish_batch.py +++ b/tests/test_sns/test_publish_batch.py @@ -100,7 +100,7 @@ def test_publish_batch_standard_with_message_group_id(): { "Id": "id_2", "Code": "InvalidParameter", - "Message": "Invalid parameter: Value mgid for parameter MessageGroupId is invalid. Reason: The request include parameter that is not valid for this queue type.", + "Message": "Invalid parameter: MessageGroupId Reason: The request includes MessageGroupId parameter that is not valid for this topic type", "SenderFault": True, } ) diff --git a/tests/test_sns/test_publishing_boto3.py b/tests/test_sns/test_publishing_boto3.py index 76c441167..40c7e6a6e 100644 --- a/tests/test_sns/test_publishing_boto3.py +++ b/tests/test_sns/test_publishing_boto3.py @@ -103,6 +103,87 @@ def test_publish_to_sqs_fifo(): topic.publish(Message="message", MessageGroupId="message_group_id") +@mock_sns +@mock_sqs +def test_publish_to_sqs_fifo_with_deduplication_id(): + sns = boto3.resource("sns", region_name="us-east-1") + topic = sns.create_topic( + Name="topic.fifo", + Attributes={"FifoTopic": "true"}, + ) + + sqs = boto3.resource("sqs", region_name="us-east-1") + queue = sqs.create_queue( + QueueName="queue.fifo", + Attributes={"FifoQueue": "true"}, + ) + + topic.subscribe( + Protocol="sqs", + Endpoint=queue.attributes["QueueArn"], + Attributes={"RawMessageDelivery": "true"}, + ) + + message = '{"msg": "hello"}' + with freeze_time("2015-01-01 12:00:00"): + topic.publish( + Message=message, + MessageGroupId="message_group_id", + MessageDeduplicationId="message_deduplication_id", + ) + + with freeze_time("2015-01-01 12:00:01"): + messages = queue.receive_messages( + MaxNumberOfMessages=1, + AttributeNames=["MessageDeduplicationId", "MessageGroupId"], + ) + messages[0].attributes["MessageGroupId"].should.equal("message_group_id") + messages[0].attributes["MessageDeduplicationId"].should.equal( + "message_deduplication_id" + ) + + +@mock_sns +@mock_sqs +def test_publish_to_sqs_fifo_raw_with_deduplication_id(): + sns = boto3.resource("sns", region_name="us-east-1") + topic = sns.create_topic( + Name="topic.fifo", + Attributes={"FifoTopic": "true"}, + ) + + sqs = boto3.resource("sqs", region_name="us-east-1") + queue = sqs.create_queue( + QueueName="queue.fifo", + Attributes={"FifoQueue": "true"}, + ) + + subscription = topic.subscribe( + Protocol="sqs", Endpoint=queue.attributes["QueueArn"] + ) + subscription.set_attributes( + AttributeName="RawMessageDelivery", AttributeValue="true" + ) + + message = "my message" + with freeze_time("2015-01-01 12:00:00"): + topic.publish( + Message=message, + MessageGroupId="message_group_id", + MessageDeduplicationId="message_deduplication_id", + ) + + with freeze_time("2015-01-01 12:00:01"): + messages = queue.receive_messages( + MaxNumberOfMessages=1, + AttributeNames=["MessageDeduplicationId", "MessageGroupId"], + ) + messages[0].attributes["MessageGroupId"].should.equal("message_group_id") + messages[0].attributes["MessageDeduplicationId"].should.equal( + "message_deduplication_id" + ) + + @mock_sqs @mock_sns def test_publish_to_sqs_bad(): @@ -528,7 +609,7 @@ def test_publish_group_id_to_non_fifo(): with pytest.raises( ClientError, - match="The request include parameter that is not valid for this queue type", + match="The request includes MessageGroupId parameter that is not valid for this topic type", ): topic.publish(Message="message", MessageGroupId="message_group_id") @@ -536,6 +617,45 @@ def test_publish_group_id_to_non_fifo(): topic.publish(Message="message") +@mock_sns +def test_publish_fifo_needs_deduplication_id(): + sns = boto3.resource("sns", region_name="us-east-1") + topic = sns.create_topic( + Name="topic.fifo", + Attributes={"FifoTopic": "true"}, + ) + + with pytest.raises( + ClientError, + match="The topic should either have ContentBasedDeduplication enabled or MessageDeduplicationId provided explicitly", + ): + topic.publish(Message="message", MessageGroupId="message_group_id") + + # message deduplication id included - OK + topic.publish( + Message="message", + MessageGroupId="message_group_id", + MessageDeduplicationId="message_deduplication_id", + ) + + +@mock_sns +def test_publish_deduplication_id_to_non_fifo(): + sns = boto3.resource("sns", region_name="us-east-1") + topic = sns.create_topic(Name="topic") + + with pytest.raises( + ClientError, + match="The request includes MessageDeduplicationId parameter that is not valid for this topic type", + ): + topic.publish( + Message="message", MessageDeduplicationId="message_deduplication_id" + ) + + # message group not included - OK + topic.publish(Message="message") + + def _setup_filter_policy_test(filter_policy): sns = boto3.resource("sns", region_name="us-east-1") topic = sns.create_topic(Name="some-topic")