Fix publishing to fifo sns->sqs subscription (#4738)
This commit is contained in:
parent
ee6b2bfff8
commit
4e5180a9ba
@ -13,6 +13,7 @@ from moto.core.utils import (
|
|||||||
BackendDict,
|
BackendDict,
|
||||||
)
|
)
|
||||||
from moto.sqs import sqs_backends
|
from moto.sqs import sqs_backends
|
||||||
|
from moto.sqs.exceptions import MissingParameter
|
||||||
|
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
SNSNotFoundError,
|
SNSNotFoundError,
|
||||||
@ -55,7 +56,7 @@ class Topic(CloudFormationModel):
|
|||||||
self.fifo_topic = "false"
|
self.fifo_topic = "false"
|
||||||
self.content_based_deduplication = "false"
|
self.content_based_deduplication = "false"
|
||||||
|
|
||||||
def publish(self, message, subject=None, message_attributes=None):
|
def publish(self, message, subject=None, message_attributes=None, group_id=None):
|
||||||
message_id = str(uuid.uuid4())
|
message_id = str(uuid.uuid4())
|
||||||
subscriptions, _ = self.sns_backend.list_subscriptions(self.arn)
|
subscriptions, _ = self.sns_backend.list_subscriptions(self.arn)
|
||||||
for subscription in subscriptions:
|
for subscription in subscriptions:
|
||||||
@ -64,6 +65,7 @@ class Topic(CloudFormationModel):
|
|||||||
message_id,
|
message_id,
|
||||||
subject=subject,
|
subject=subject,
|
||||||
message_attributes=message_attributes,
|
message_attributes=message_attributes,
|
||||||
|
group_id=group_id,
|
||||||
)
|
)
|
||||||
return message_id
|
return message_id
|
||||||
|
|
||||||
@ -177,7 +179,9 @@ class Subscription(BaseModel):
|
|||||||
self._filter_policy = None # filter policy as a dict, not json.
|
self._filter_policy = None # filter policy as a dict, not json.
|
||||||
self.confirmed = False
|
self.confirmed = False
|
||||||
|
|
||||||
def publish(self, message, message_id, subject=None, message_attributes=None):
|
def publish(
|
||||||
|
self, message, message_id, subject=None, message_attributes=None, group_id=None
|
||||||
|
):
|
||||||
if not self._matches_filter_policy(message_attributes):
|
if not self._matches_filter_policy(message_attributes):
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -198,6 +202,7 @@ class Subscription(BaseModel):
|
|||||||
indent=2,
|
indent=2,
|
||||||
separators=(",", ": "),
|
separators=(",", ": "),
|
||||||
),
|
),
|
||||||
|
group_id=group_id,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raw_message_attributes = {}
|
raw_message_attributes = {}
|
||||||
@ -215,7 +220,10 @@ class Subscription(BaseModel):
|
|||||||
}
|
}
|
||||||
|
|
||||||
sqs_backends[region].send_message(
|
sqs_backends[region].send_message(
|
||||||
queue_name, message, message_attributes=raw_message_attributes
|
queue_name,
|
||||||
|
message,
|
||||||
|
message_attributes=raw_message_attributes,
|
||||||
|
group_id=group_id,
|
||||||
)
|
)
|
||||||
elif self.protocol in ["http", "https"]:
|
elif self.protocol in ["http", "https"]:
|
||||||
post_data = self.get_post_data(message, message_id, subject)
|
post_data = self.get_post_data(message, message_id, subject)
|
||||||
@ -568,6 +576,7 @@ class SNSBackend(BaseBackend):
|
|||||||
phone_number=None,
|
phone_number=None,
|
||||||
subject=None,
|
subject=None,
|
||||||
message_attributes=None,
|
message_attributes=None,
|
||||||
|
group_id=None,
|
||||||
):
|
):
|
||||||
if subject is not None and len(subject) > 100:
|
if subject is not None and len(subject) > 100:
|
||||||
# Note that the AWS docs around length are wrong: https://github.com/spulec/moto/issues/1503
|
# Note that the AWS docs around length are wrong: https://github.com/spulec/moto/issues/1503
|
||||||
@ -589,8 +598,25 @@ class SNSBackend(BaseBackend):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
topic = self.get_topic(arn)
|
topic = self.get_topic(arn)
|
||||||
|
|
||||||
|
fifo_topic = topic.fifo_topic == "true"
|
||||||
|
if group_id is None:
|
||||||
|
# MessageGroupId is a mandatory parameter for all
|
||||||
|
# messages in a fifo queue
|
||||||
|
if fifo_topic:
|
||||||
|
raise MissingParameter("MessageGroupId")
|
||||||
|
else:
|
||||||
|
if not fifo_topic:
|
||||||
|
msg = (
|
||||||
|
"Value {} for parameter MessageGroupId is invalid. "
|
||||||
|
"Reason: The request include parameter that is not valid for this queue type."
|
||||||
|
).format(group_id)
|
||||||
|
raise InvalidParameterValue(msg)
|
||||||
message_id = topic.publish(
|
message_id = topic.publish(
|
||||||
message, subject=subject, message_attributes=message_attributes
|
message,
|
||||||
|
subject=subject,
|
||||||
|
message_attributes=message_attributes,
|
||||||
|
group_id=group_id,
|
||||||
)
|
)
|
||||||
except SNSNotFoundError:
|
except SNSNotFoundError:
|
||||||
endpoint = self.get_endpoint(arn)
|
endpoint = self.get_endpoint(arn)
|
||||||
|
@ -328,6 +328,7 @@ class SNSResponse(BaseResponse):
|
|||||||
topic_arn = self._get_param("TopicArn")
|
topic_arn = self._get_param("TopicArn")
|
||||||
phone_number = self._get_param("PhoneNumber")
|
phone_number = self._get_param("PhoneNumber")
|
||||||
subject = self._get_param("Subject")
|
subject = self._get_param("Subject")
|
||||||
|
message_group_id = self._get_param("MessageGroupId")
|
||||||
|
|
||||||
message_attributes = self._parse_message_attributes()
|
message_attributes = self._parse_message_attributes()
|
||||||
|
|
||||||
@ -355,6 +356,7 @@ class SNSResponse(BaseResponse):
|
|||||||
phone_number=phone_number,
|
phone_number=phone_number,
|
||||||
subject=subject,
|
subject=subject,
|
||||||
message_attributes=message_attributes,
|
message_attributes=message_attributes,
|
||||||
|
group_id=message_group_id,
|
||||||
)
|
)
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
error_response = self._error("InvalidParameter", str(err))
|
error_response = self._error("InvalidParameter", str(err))
|
||||||
|
@ -83,6 +83,27 @@ def test_publish_to_sqs_raw():
|
|||||||
messages[0].body.should.equal(message)
|
messages[0].body.should.equal(message)
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sns
|
||||||
|
@mock_sqs
|
||||||
|
def test_publish_to_sqs_fifo():
|
||||||
|
sns = boto3.resource("sns", region_name="us-east-1")
|
||||||
|
topic = sns.create_topic(
|
||||||
|
Name="topic.fifo",
|
||||||
|
Attributes={"FifoTopic": "true", "ContentBasedDeduplication": "true",},
|
||||||
|
)
|
||||||
|
|
||||||
|
sqs = boto3.resource("sqs", region_name="us-east-1")
|
||||||
|
queue = sqs.create_queue(
|
||||||
|
QueueName="queue.fifo",
|
||||||
|
Attributes={"FifoQueue": "true", "ContentBasedDeduplication": "true",},
|
||||||
|
)
|
||||||
|
topic.subscribe(
|
||||||
|
Protocol="sqs", Endpoint=queue.attributes["QueueArn"],
|
||||||
|
)
|
||||||
|
|
||||||
|
topic.publish(Message="message", MessageGroupId="message_group_id")
|
||||||
|
|
||||||
|
|
||||||
@mock_sqs
|
@mock_sqs
|
||||||
@mock_sns
|
@mock_sns
|
||||||
def test_publish_to_sqs_bad():
|
def test_publish_to_sqs_bad():
|
||||||
@ -154,7 +175,7 @@ def test_publish_to_sqs_msg_attr_byte_value():
|
|||||||
sqs = boto3.resource("sqs", region_name="us-east-1")
|
sqs = boto3.resource("sqs", region_name="us-east-1")
|
||||||
queue = sqs.create_queue(QueueName="test-queue")
|
queue = sqs.create_queue(QueueName="test-queue")
|
||||||
conn.subscribe(
|
conn.subscribe(
|
||||||
TopicArn=topic_arn, Protocol="sqs", Endpoint=queue.attributes["QueueArn"],
|
TopicArn=topic_arn, Protocol="sqs", Endpoint=queue.attributes["QueueArn"]
|
||||||
)
|
)
|
||||||
queue_raw = sqs.create_queue(QueueName="test-queue-raw")
|
queue_raw = sqs.create_queue(QueueName="test-queue-raw")
|
||||||
conn.subscribe(
|
conn.subscribe(
|
||||||
@ -431,6 +452,39 @@ def test_publish_message_too_long():
|
|||||||
topic.publish(Message="".join(["." for i in range(0, 262144)]))
|
topic.publish(Message="".join(["." for i in range(0, 262144)]))
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sns
|
||||||
|
def test_publish_fifo_needs_group_id():
|
||||||
|
sns = boto3.resource("sns", region_name="us-east-1")
|
||||||
|
topic = sns.create_topic(
|
||||||
|
Name="topic.fifo",
|
||||||
|
Attributes={"FifoTopic": "true", "ContentBasedDeduplication": "true",},
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ClientError, match="The request must contain the parameter MessageGroupId"
|
||||||
|
):
|
||||||
|
topic.publish(Message="message")
|
||||||
|
|
||||||
|
# message group included - OK
|
||||||
|
topic.publish(Message="message", MessageGroupId="message_group_id")
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sns
|
||||||
|
@mock_sqs
|
||||||
|
def test_publish_group_id_to_non_fifo():
|
||||||
|
sns = boto3.resource("sns", region_name="us-east-1")
|
||||||
|
topic = sns.create_topic(Name="topic")
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ClientError,
|
||||||
|
match="The request include parameter that is not valid for this queue type",
|
||||||
|
):
|
||||||
|
topic.publish(Message="message", MessageGroupId="message_group_id")
|
||||||
|
|
||||||
|
# message group not included - OK
|
||||||
|
topic.publish(Message="message")
|
||||||
|
|
||||||
|
|
||||||
def _setup_filter_policy_test(filter_policy):
|
def _setup_filter_policy_test(filter_policy):
|
||||||
sns = boto3.resource("sns", region_name="us-east-1")
|
sns = boto3.resource("sns", region_name="us-east-1")
|
||||||
topic = sns.create_topic(Name="some-topic")
|
topic = sns.create_topic(Name="some-topic")
|
||||||
|
Loading…
Reference in New Issue
Block a user