Fix publishing to fifo sns->sqs subscription (#4738)

This commit is contained in:
Sindri Guðmundsson 2022-01-06 16:04:55 +00:00 committed by GitHub
parent ee6b2bfff8
commit 4e5180a9ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 87 additions and 5 deletions

View File

@ -13,6 +13,7 @@ from moto.core.utils import (
BackendDict, BackendDict,
) )
from moto.sqs import sqs_backends from moto.sqs import sqs_backends
from moto.sqs.exceptions import MissingParameter
from .exceptions import ( from .exceptions import (
SNSNotFoundError, SNSNotFoundError,
@ -55,7 +56,7 @@ class Topic(CloudFormationModel):
self.fifo_topic = "false" self.fifo_topic = "false"
self.content_based_deduplication = "false" self.content_based_deduplication = "false"
def publish(self, message, subject=None, message_attributes=None): def publish(self, message, subject=None, message_attributes=None, group_id=None):
message_id = str(uuid.uuid4()) message_id = str(uuid.uuid4())
subscriptions, _ = self.sns_backend.list_subscriptions(self.arn) subscriptions, _ = self.sns_backend.list_subscriptions(self.arn)
for subscription in subscriptions: for subscription in subscriptions:
@ -64,6 +65,7 @@ class Topic(CloudFormationModel):
message_id, message_id,
subject=subject, subject=subject,
message_attributes=message_attributes, message_attributes=message_attributes,
group_id=group_id,
) )
return message_id return message_id
@ -177,7 +179,9 @@ class Subscription(BaseModel):
self._filter_policy = None # filter policy as a dict, not json. self._filter_policy = None # filter policy as a dict, not json.
self.confirmed = False self.confirmed = False
def publish(self, message, message_id, subject=None, message_attributes=None): def publish(
self, message, message_id, subject=None, message_attributes=None, group_id=None
):
if not self._matches_filter_policy(message_attributes): if not self._matches_filter_policy(message_attributes):
return return
@ -198,6 +202,7 @@ class Subscription(BaseModel):
indent=2, indent=2,
separators=(",", ": "), separators=(",", ": "),
), ),
group_id=group_id,
) )
else: else:
raw_message_attributes = {} raw_message_attributes = {}
@ -215,7 +220,10 @@ class Subscription(BaseModel):
} }
sqs_backends[region].send_message( sqs_backends[region].send_message(
queue_name, message, message_attributes=raw_message_attributes queue_name,
message,
message_attributes=raw_message_attributes,
group_id=group_id,
) )
elif self.protocol in ["http", "https"]: elif self.protocol in ["http", "https"]:
post_data = self.get_post_data(message, message_id, subject) post_data = self.get_post_data(message, message_id, subject)
@ -568,6 +576,7 @@ class SNSBackend(BaseBackend):
phone_number=None, phone_number=None,
subject=None, subject=None,
message_attributes=None, message_attributes=None,
group_id=None,
): ):
if subject is not None and len(subject) > 100: if subject is not None and len(subject) > 100:
# Note that the AWS docs around length are wrong: https://github.com/spulec/moto/issues/1503 # Note that the AWS docs around length are wrong: https://github.com/spulec/moto/issues/1503
@ -589,8 +598,25 @@ class SNSBackend(BaseBackend):
try: try:
topic = self.get_topic(arn) topic = self.get_topic(arn)
fifo_topic = topic.fifo_topic == "true"
if group_id is None:
# MessageGroupId is a mandatory parameter for all
# messages in a fifo queue
if fifo_topic:
raise MissingParameter("MessageGroupId")
else:
if not fifo_topic:
msg = (
"Value {} for parameter MessageGroupId is invalid. "
"Reason: The request include parameter that is not valid for this queue type."
).format(group_id)
raise InvalidParameterValue(msg)
message_id = topic.publish( message_id = topic.publish(
message, subject=subject, message_attributes=message_attributes message,
subject=subject,
message_attributes=message_attributes,
group_id=group_id,
) )
except SNSNotFoundError: except SNSNotFoundError:
endpoint = self.get_endpoint(arn) endpoint = self.get_endpoint(arn)

View File

@ -328,6 +328,7 @@ class SNSResponse(BaseResponse):
topic_arn = self._get_param("TopicArn") topic_arn = self._get_param("TopicArn")
phone_number = self._get_param("PhoneNumber") phone_number = self._get_param("PhoneNumber")
subject = self._get_param("Subject") subject = self._get_param("Subject")
message_group_id = self._get_param("MessageGroupId")
message_attributes = self._parse_message_attributes() message_attributes = self._parse_message_attributes()
@ -355,6 +356,7 @@ class SNSResponse(BaseResponse):
phone_number=phone_number, phone_number=phone_number,
subject=subject, subject=subject,
message_attributes=message_attributes, message_attributes=message_attributes,
group_id=message_group_id,
) )
except ValueError as err: except ValueError as err:
error_response = self._error("InvalidParameter", str(err)) error_response = self._error("InvalidParameter", str(err))

View File

@ -83,6 +83,27 @@ def test_publish_to_sqs_raw():
messages[0].body.should.equal(message) messages[0].body.should.equal(message)
@mock_sns
@mock_sqs
def test_publish_to_sqs_fifo():
sns = boto3.resource("sns", region_name="us-east-1")
topic = sns.create_topic(
Name="topic.fifo",
Attributes={"FifoTopic": "true", "ContentBasedDeduplication": "true",},
)
sqs = boto3.resource("sqs", region_name="us-east-1")
queue = sqs.create_queue(
QueueName="queue.fifo",
Attributes={"FifoQueue": "true", "ContentBasedDeduplication": "true",},
)
topic.subscribe(
Protocol="sqs", Endpoint=queue.attributes["QueueArn"],
)
topic.publish(Message="message", MessageGroupId="message_group_id")
@mock_sqs @mock_sqs
@mock_sns @mock_sns
def test_publish_to_sqs_bad(): def test_publish_to_sqs_bad():
@ -154,7 +175,7 @@ def test_publish_to_sqs_msg_attr_byte_value():
sqs = boto3.resource("sqs", region_name="us-east-1") sqs = boto3.resource("sqs", region_name="us-east-1")
queue = sqs.create_queue(QueueName="test-queue") queue = sqs.create_queue(QueueName="test-queue")
conn.subscribe( conn.subscribe(
TopicArn=topic_arn, Protocol="sqs", Endpoint=queue.attributes["QueueArn"], TopicArn=topic_arn, Protocol="sqs", Endpoint=queue.attributes["QueueArn"]
) )
queue_raw = sqs.create_queue(QueueName="test-queue-raw") queue_raw = sqs.create_queue(QueueName="test-queue-raw")
conn.subscribe( conn.subscribe(
@ -431,6 +452,39 @@ def test_publish_message_too_long():
topic.publish(Message="".join(["." for i in range(0, 262144)])) topic.publish(Message="".join(["." for i in range(0, 262144)]))
@mock_sns
def test_publish_fifo_needs_group_id():
sns = boto3.resource("sns", region_name="us-east-1")
topic = sns.create_topic(
Name="topic.fifo",
Attributes={"FifoTopic": "true", "ContentBasedDeduplication": "true",},
)
with pytest.raises(
ClientError, match="The request must contain the parameter MessageGroupId"
):
topic.publish(Message="message")
# message group included - OK
topic.publish(Message="message", MessageGroupId="message_group_id")
@mock_sns
@mock_sqs
def test_publish_group_id_to_non_fifo():
sns = boto3.resource("sns", region_name="us-east-1")
topic = sns.create_topic(Name="topic")
with pytest.raises(
ClientError,
match="The request include parameter that is not valid for this queue type",
):
topic.publish(Message="message", MessageGroupId="message_group_id")
# message group not included - OK
topic.publish(Message="message")
def _setup_filter_policy_test(filter_policy): def _setup_filter_policy_test(filter_policy):
sns = boto3.resource("sns", region_name="us-east-1") sns = boto3.resource("sns", region_name="us-east-1")
topic = sns.create_topic(Name="some-topic") topic = sns.create_topic(Name="some-topic")