SNS: Forward MessageDeduplicationId to SQS queues (#6255)

This commit is contained in:
Niklas Janlert 2023-04-25 17:33:22 +02:00 committed by GitHub
parent 65611082be
commit 92da03b1dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 160 additions and 15 deletions

View File

@ -58,7 +58,14 @@ class Topic(CloudFormationModel):
self.fifo_topic = "false"
self.content_based_deduplication = "false"
def publish(self, message, subject=None, message_attributes=None, group_id=None):
def publish(
self,
message,
subject=None,
message_attributes=None,
group_id=None,
deduplication_id=None,
):
message_id = str(mock_random.uuid4())
subscriptions, _ = self.sns_backend.list_subscriptions(self.arn)
for subscription in subscriptions:
@ -68,6 +75,7 @@ class Topic(CloudFormationModel):
subject=subject,
message_attributes=message_attributes,
group_id=group_id,
deduplication_id=deduplication_id,
)
self.sent_notifications.append(
(message_id, message, subject, message_attributes, group_id)
@ -189,7 +197,13 @@ class Subscription(BaseModel):
self.confirmed = False
def publish(
self, message, message_id, subject=None, message_attributes=None, group_id=None
self,
message,
message_id,
subject=None,
message_attributes=None,
group_id=None,
deduplication_id=None,
):
if not self._matches_filter_policy(message_attributes):
return
@ -211,6 +225,7 @@ class Subscription(BaseModel):
indent=2,
separators=(",", ": "),
),
deduplication_id=deduplication_id,
group_id=group_id,
)
else:
@ -232,6 +247,7 @@ class Subscription(BaseModel):
queue_name,
message,
message_attributes=raw_message_attributes,
deduplication_id=deduplication_id,
group_id=group_id,
)
elif self.protocol in ["http", "https"]:
@ -640,6 +656,7 @@ class SNSBackend(BaseBackend):
subject=None,
message_attributes=None,
group_id=None,
deduplication_id=None,
):
if subject is not None and len(subject) > 100:
# Note that the AWS docs around length are wrong: https://github.com/getmoto/moto/issues/1503
@ -663,23 +680,29 @@ class SNSBackend(BaseBackend):
topic = self.get_topic(arn)
fifo_topic = topic.fifo_topic == "true"
if group_id is None:
if fifo_topic:
if not group_id:
# MessageGroupId is a mandatory parameter for all
# messages in a fifo queue
if fifo_topic:
raise MissingParameter("MessageGroupId")
else:
if not fifo_topic:
msg = (
f"Value {group_id} for parameter MessageGroupId is invalid. "
"Reason: The request include parameter that is not valid for this queue type."
deduplication_id_required = topic.content_based_deduplication == "false"
if not deduplication_id and deduplication_id_required:
raise InvalidParameterValue(
"The topic should either have ContentBasedDeduplication enabled or MessageDeduplicationId provided explicitly"
)
raise InvalidParameterValue(msg)
elif group_id or deduplication_id:
parameter = "MessageGroupId" if group_id else "MessageDeduplicationId"
raise InvalidParameterValue(
f"Invalid parameter: {parameter} "
f"Reason: The request includes {parameter} parameter that is not valid for this topic type"
)
message_id = topic.publish(
message,
subject=subject,
message_attributes=message_attributes,
group_id=group_id,
deduplication_id=deduplication_id,
)
except SNSNotFoundError:
endpoint = self.get_endpoint(arn)
@ -1023,7 +1046,7 @@ class SNSBackend(BaseBackend):
{
"Id": entry["Id"],
"Code": "InvalidParameter",
"Message": f"Invalid parameter: {e.message}",
"Message": e.message,
"SenderFault": True,
}
)

View File

@ -334,6 +334,7 @@ class SNSResponse(BaseResponse):
phone_number = self._get_param("PhoneNumber")
subject = self._get_param("Subject")
message_group_id = self._get_param("MessageGroupId")
message_deduplication_id = self._get_param("MessageDeduplicationId")
message_attributes = self._parse_message_attributes()
@ -362,6 +363,7 @@ class SNSResponse(BaseResponse):
subject=subject,
message_attributes=message_attributes,
group_id=message_group_id,
deduplication_id=message_deduplication_id,
)
except ValueError as err:
error_response = self._error("InvalidParameter", str(err))

View File

@ -100,7 +100,7 @@ def test_publish_batch_standard_with_message_group_id():
{
"Id": "id_2",
"Code": "InvalidParameter",
"Message": "Invalid parameter: Value mgid for parameter MessageGroupId is invalid. Reason: The request include parameter that is not valid for this queue type.",
"Message": "Invalid parameter: MessageGroupId Reason: The request includes MessageGroupId parameter that is not valid for this topic type",
"SenderFault": True,
}
)

View File

@ -103,6 +103,87 @@ def test_publish_to_sqs_fifo():
topic.publish(Message="message", MessageGroupId="message_group_id")
@mock_sns
@mock_sqs
def test_publish_to_sqs_fifo_with_deduplication_id():
sns = boto3.resource("sns", region_name="us-east-1")
topic = sns.create_topic(
Name="topic.fifo",
Attributes={"FifoTopic": "true"},
)
sqs = boto3.resource("sqs", region_name="us-east-1")
queue = sqs.create_queue(
QueueName="queue.fifo",
Attributes={"FifoQueue": "true"},
)
topic.subscribe(
Protocol="sqs",
Endpoint=queue.attributes["QueueArn"],
Attributes={"RawMessageDelivery": "true"},
)
message = '{"msg": "hello"}'
with freeze_time("2015-01-01 12:00:00"):
topic.publish(
Message=message,
MessageGroupId="message_group_id",
MessageDeduplicationId="message_deduplication_id",
)
with freeze_time("2015-01-01 12:00:01"):
messages = queue.receive_messages(
MaxNumberOfMessages=1,
AttributeNames=["MessageDeduplicationId", "MessageGroupId"],
)
messages[0].attributes["MessageGroupId"].should.equal("message_group_id")
messages[0].attributes["MessageDeduplicationId"].should.equal(
"message_deduplication_id"
)
@mock_sns
@mock_sqs
def test_publish_to_sqs_fifo_raw_with_deduplication_id():
sns = boto3.resource("sns", region_name="us-east-1")
topic = sns.create_topic(
Name="topic.fifo",
Attributes={"FifoTopic": "true"},
)
sqs = boto3.resource("sqs", region_name="us-east-1")
queue = sqs.create_queue(
QueueName="queue.fifo",
Attributes={"FifoQueue": "true"},
)
subscription = topic.subscribe(
Protocol="sqs", Endpoint=queue.attributes["QueueArn"]
)
subscription.set_attributes(
AttributeName="RawMessageDelivery", AttributeValue="true"
)
message = "my message"
with freeze_time("2015-01-01 12:00:00"):
topic.publish(
Message=message,
MessageGroupId="message_group_id",
MessageDeduplicationId="message_deduplication_id",
)
with freeze_time("2015-01-01 12:00:01"):
messages = queue.receive_messages(
MaxNumberOfMessages=1,
AttributeNames=["MessageDeduplicationId", "MessageGroupId"],
)
messages[0].attributes["MessageGroupId"].should.equal("message_group_id")
messages[0].attributes["MessageDeduplicationId"].should.equal(
"message_deduplication_id"
)
@mock_sqs
@mock_sns
def test_publish_to_sqs_bad():
@ -528,7 +609,7 @@ def test_publish_group_id_to_non_fifo():
with pytest.raises(
ClientError,
match="The request include parameter that is not valid for this queue type",
match="The request includes MessageGroupId parameter that is not valid for this topic type",
):
topic.publish(Message="message", MessageGroupId="message_group_id")
@ -536,6 +617,45 @@ def test_publish_group_id_to_non_fifo():
topic.publish(Message="message")
@mock_sns
def test_publish_fifo_needs_deduplication_id():
sns = boto3.resource("sns", region_name="us-east-1")
topic = sns.create_topic(
Name="topic.fifo",
Attributes={"FifoTopic": "true"},
)
with pytest.raises(
ClientError,
match="The topic should either have ContentBasedDeduplication enabled or MessageDeduplicationId provided explicitly",
):
topic.publish(Message="message", MessageGroupId="message_group_id")
# message deduplication id included - OK
topic.publish(
Message="message",
MessageGroupId="message_group_id",
MessageDeduplicationId="message_deduplication_id",
)
@mock_sns
def test_publish_deduplication_id_to_non_fifo():
sns = boto3.resource("sns", region_name="us-east-1")
topic = sns.create_topic(Name="topic")
with pytest.raises(
ClientError,
match="The request includes MessageDeduplicationId parameter that is not valid for this topic type",
):
topic.publish(
Message="message", MessageDeduplicationId="message_deduplication_id"
)
# message group not included - OK
topic.publish(Message="message")
def _setup_filter_policy_test(filter_policy):
sns = boto3.resource("sns", region_name="us-east-1")
topic = sns.create_topic(Name="some-topic")