Fix publishing to fifo sns->sqs subscription (#4738)
This commit is contained in:
parent
ee6b2bfff8
commit
4e5180a9ba
@ -13,6 +13,7 @@ from moto.core.utils import (
|
||||
BackendDict,
|
||||
)
|
||||
from moto.sqs import sqs_backends
|
||||
from moto.sqs.exceptions import MissingParameter
|
||||
|
||||
from .exceptions import (
|
||||
SNSNotFoundError,
|
||||
@ -55,7 +56,7 @@ class Topic(CloudFormationModel):
|
||||
self.fifo_topic = "false"
|
||||
self.content_based_deduplication = "false"
|
||||
|
||||
def publish(self, message, subject=None, message_attributes=None):
|
||||
def publish(self, message, subject=None, message_attributes=None, group_id=None):
|
||||
message_id = str(uuid.uuid4())
|
||||
subscriptions, _ = self.sns_backend.list_subscriptions(self.arn)
|
||||
for subscription in subscriptions:
|
||||
@ -64,6 +65,7 @@ class Topic(CloudFormationModel):
|
||||
message_id,
|
||||
subject=subject,
|
||||
message_attributes=message_attributes,
|
||||
group_id=group_id,
|
||||
)
|
||||
return message_id
|
||||
|
||||
@ -177,7 +179,9 @@ class Subscription(BaseModel):
|
||||
self._filter_policy = None # filter policy as a dict, not json.
|
||||
self.confirmed = False
|
||||
|
||||
def publish(self, message, message_id, subject=None, message_attributes=None):
|
||||
def publish(
|
||||
self, message, message_id, subject=None, message_attributes=None, group_id=None
|
||||
):
|
||||
if not self._matches_filter_policy(message_attributes):
|
||||
return
|
||||
|
||||
@ -198,6 +202,7 @@ class Subscription(BaseModel):
|
||||
indent=2,
|
||||
separators=(",", ": "),
|
||||
),
|
||||
group_id=group_id,
|
||||
)
|
||||
else:
|
||||
raw_message_attributes = {}
|
||||
@ -215,7 +220,10 @@ class Subscription(BaseModel):
|
||||
}
|
||||
|
||||
sqs_backends[region].send_message(
|
||||
queue_name, message, message_attributes=raw_message_attributes
|
||||
queue_name,
|
||||
message,
|
||||
message_attributes=raw_message_attributes,
|
||||
group_id=group_id,
|
||||
)
|
||||
elif self.protocol in ["http", "https"]:
|
||||
post_data = self.get_post_data(message, message_id, subject)
|
||||
@ -568,6 +576,7 @@ class SNSBackend(BaseBackend):
|
||||
phone_number=None,
|
||||
subject=None,
|
||||
message_attributes=None,
|
||||
group_id=None,
|
||||
):
|
||||
if subject is not None and len(subject) > 100:
|
||||
# Note that the AWS docs around length are wrong: https://github.com/spulec/moto/issues/1503
|
||||
@ -589,8 +598,25 @@ class SNSBackend(BaseBackend):
|
||||
|
||||
try:
|
||||
topic = self.get_topic(arn)
|
||||
|
||||
fifo_topic = topic.fifo_topic == "true"
|
||||
if group_id is None:
|
||||
# MessageGroupId is a mandatory parameter for all
|
||||
# messages in a fifo queue
|
||||
if fifo_topic:
|
||||
raise MissingParameter("MessageGroupId")
|
||||
else:
|
||||
if not fifo_topic:
|
||||
msg = (
|
||||
"Value {} for parameter MessageGroupId is invalid. "
|
||||
"Reason: The request include parameter that is not valid for this queue type."
|
||||
).format(group_id)
|
||||
raise InvalidParameterValue(msg)
|
||||
message_id = topic.publish(
|
||||
message, subject=subject, message_attributes=message_attributes
|
||||
message,
|
||||
subject=subject,
|
||||
message_attributes=message_attributes,
|
||||
group_id=group_id,
|
||||
)
|
||||
except SNSNotFoundError:
|
||||
endpoint = self.get_endpoint(arn)
|
||||
|
@ -328,6 +328,7 @@ class SNSResponse(BaseResponse):
|
||||
topic_arn = self._get_param("TopicArn")
|
||||
phone_number = self._get_param("PhoneNumber")
|
||||
subject = self._get_param("Subject")
|
||||
message_group_id = self._get_param("MessageGroupId")
|
||||
|
||||
message_attributes = self._parse_message_attributes()
|
||||
|
||||
@ -355,6 +356,7 @@ class SNSResponse(BaseResponse):
|
||||
phone_number=phone_number,
|
||||
subject=subject,
|
||||
message_attributes=message_attributes,
|
||||
group_id=message_group_id,
|
||||
)
|
||||
except ValueError as err:
|
||||
error_response = self._error("InvalidParameter", str(err))
|
||||
|
@ -83,6 +83,27 @@ def test_publish_to_sqs_raw():
|
||||
messages[0].body.should.equal(message)
|
||||
|
||||
|
||||
@mock_sns
|
||||
@mock_sqs
|
||||
def test_publish_to_sqs_fifo():
|
||||
sns = boto3.resource("sns", region_name="us-east-1")
|
||||
topic = sns.create_topic(
|
||||
Name="topic.fifo",
|
||||
Attributes={"FifoTopic": "true", "ContentBasedDeduplication": "true",},
|
||||
)
|
||||
|
||||
sqs = boto3.resource("sqs", region_name="us-east-1")
|
||||
queue = sqs.create_queue(
|
||||
QueueName="queue.fifo",
|
||||
Attributes={"FifoQueue": "true", "ContentBasedDeduplication": "true",},
|
||||
)
|
||||
topic.subscribe(
|
||||
Protocol="sqs", Endpoint=queue.attributes["QueueArn"],
|
||||
)
|
||||
|
||||
topic.publish(Message="message", MessageGroupId="message_group_id")
|
||||
|
||||
|
||||
@mock_sqs
|
||||
@mock_sns
|
||||
def test_publish_to_sqs_bad():
|
||||
@ -154,7 +175,7 @@ def test_publish_to_sqs_msg_attr_byte_value():
|
||||
sqs = boto3.resource("sqs", region_name="us-east-1")
|
||||
queue = sqs.create_queue(QueueName="test-queue")
|
||||
conn.subscribe(
|
||||
TopicArn=topic_arn, Protocol="sqs", Endpoint=queue.attributes["QueueArn"],
|
||||
TopicArn=topic_arn, Protocol="sqs", Endpoint=queue.attributes["QueueArn"]
|
||||
)
|
||||
queue_raw = sqs.create_queue(QueueName="test-queue-raw")
|
||||
conn.subscribe(
|
||||
@ -431,6 +452,39 @@ def test_publish_message_too_long():
|
||||
topic.publish(Message="".join(["." for i in range(0, 262144)]))
|
||||
|
||||
|
||||
@mock_sns
|
||||
def test_publish_fifo_needs_group_id():
|
||||
sns = boto3.resource("sns", region_name="us-east-1")
|
||||
topic = sns.create_topic(
|
||||
Name="topic.fifo",
|
||||
Attributes={"FifoTopic": "true", "ContentBasedDeduplication": "true",},
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ClientError, match="The request must contain the parameter MessageGroupId"
|
||||
):
|
||||
topic.publish(Message="message")
|
||||
|
||||
# message group included - OK
|
||||
topic.publish(Message="message", MessageGroupId="message_group_id")
|
||||
|
||||
|
||||
@mock_sns
|
||||
@mock_sqs
|
||||
def test_publish_group_id_to_non_fifo():
|
||||
sns = boto3.resource("sns", region_name="us-east-1")
|
||||
topic = sns.create_topic(Name="topic")
|
||||
|
||||
with pytest.raises(
|
||||
ClientError,
|
||||
match="The request include parameter that is not valid for this queue type",
|
||||
):
|
||||
topic.publish(Message="message", MessageGroupId="message_group_id")
|
||||
|
||||
# message group not included - OK
|
||||
topic.publish(Message="message")
|
||||
|
||||
|
||||
def _setup_filter_policy_test(filter_policy):
|
||||
sns = boto3.resource("sns", region_name="us-east-1")
|
||||
topic = sns.create_topic(Name="some-topic")
|
||||
|
Loading…
Reference in New Issue
Block a user