SNS: publish_batch() (#4741)

This commit is contained in:
Bert Blommers 2022-01-06 22:09:16 -01:00 committed by GitHub
parent 0eb8ec47ad
commit 526559e22c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 271 additions and 7 deletions

View File

@ -482,14 +482,8 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
tracked_prefixes or set()
) # prefixes which have already been processed
def is_tracked(name_param):
for prefix_loop in tracked_prefixes:
if name_param.startswith(prefix_loop):
return True
return False
for name, value in self.querystring.items():
if is_tracked(name) or not name.startswith(param_prefix):
if not name.startswith(param_prefix):
continue
if len(name) > len(param_prefix) and not name[

View File

@ -8,6 +8,11 @@ class SNSNotFoundError(RESTError):
super().__init__("NotFound", message, **kwargs)
class TopicNotFound(SNSNotFoundError):
def __init__(self):
super().__init__(message="Topic does not exist")
class ResourceNotFoundError(RESTError):
code = 404
@ -60,3 +65,23 @@ class InternalError(RESTError):
def __init__(self, message):
super(InternalError, self).__init__("InternalFailure", message)
class TooManyEntriesInBatchRequest(RESTError):
code = 400
def __init__(self):
super().__init__(
"TooManyEntriesInBatchRequest",
"The batch request contains more entries than permissible.",
)
class BatchEntryIdsNotDistinct(RESTError):
code = 400
def __init__(self):
super().__init__(
"BatchEntryIdsNotDistinct",
"Two or more batch entries in the request have the same Id.",
)

View File

@ -17,6 +17,7 @@ from moto.sqs.exceptions import MissingParameter
from .exceptions import (
SNSNotFoundError,
TopicNotFound,
DuplicateSnsEndpointError,
SnsEndpointDisabled,
SNSInvalidParameter,
@ -24,6 +25,8 @@ from .exceptions import (
InternalError,
ResourceNotFoundError,
TagLimitExceededError,
TooManyEntriesInBatchRequest,
BatchEntryIdsNotDistinct,
)
from .utils import make_arn_for_topic, make_arn_for_subscription, is_e164
@ -843,6 +846,56 @@ class SNSBackend(BaseBackend):
for key in tag_keys:
self.topics[resource_arn]._tags.pop(key, None)
def publish_batch(self, topic_arn, publish_batch_request_entries):
"""
The MessageStructure and MessageDeduplicationId-parameters have not yet been implemented.
"""
try:
topic = self.get_topic(topic_arn)
except SNSNotFoundError:
raise TopicNotFound
if len(publish_batch_request_entries) > 10:
raise TooManyEntriesInBatchRequest
ids = [m["Id"] for m in publish_batch_request_entries]
if len(set(ids)) != len(ids):
raise BatchEntryIdsNotDistinct
fifo_topic = topic.fifo_topic == "true"
if fifo_topic:
if not all(
["MessageGroupId" in entry for entry in publish_batch_request_entries]
):
raise SNSInvalidParameter(
"Invalid parameter: The MessageGroupId parameter is required for FIFO topics"
)
successful = []
failed = []
for entry in publish_batch_request_entries:
try:
message_id = self.publish(
message=entry["Message"],
arn=topic_arn,
subject=entry.get("Subject"),
message_attributes=entry.get("MessageAttributes", []),
group_id=entry.get("MessageGroupId"),
)
successful.append({"MessageId": message_id, "Id": entry["Id"]})
except Exception as e:
if isinstance(e, InvalidParameterValue):
failed.append(
{
"Id": entry["Id"],
"Code": "InvalidParameter",
"Message": f"Invalid parameter: {e.message}",
"SenderFault": True,
}
)
return successful, failed
sns_backends = BackendDict(SNSBackend, "sns")

View File

@ -377,6 +377,18 @@ class SNSResponse(BaseResponse):
template = self.response_template(PUBLISH_TEMPLATE)
return template.render(message_id=message_id)
def publish_batch(self):
topic_arn = self._get_param("TopicArn")
publish_batch_request_entries = self._get_multi_param(
"PublishBatchRequestEntries.member"
)
successful, failed = self.backend.publish_batch(
topic_arn=topic_arn,
publish_batch_request_entries=publish_batch_request_entries,
)
template = self.response_template(PUBLISH_BATCH_TEMPLATE)
return template.render(successful=successful, failed=failed)
def create_platform_application(self):
name = self._get_param("Name")
platform = self._get_param("Platform")
@ -1191,3 +1203,29 @@ UNTAG_RESOURCE_TEMPLATE = """<UntagResourceResponse xmlns="http://sns.amazonaws.
<RequestId>14eb7b1a-4cbd-5a56-80db-2d06412df769</RequestId>
</ResponseMetadata>
</UntagResourceResponse>"""
PUBLISH_BATCH_TEMPLATE = """<PublishBatchResponse xmlns="http://sns.amazonaws.com/doc/2010-03-31/">
<ResponseMetadata>
<RequestId>1549581b-12b7-11e3-895e-1334aEXAMPLE</RequestId>
</ResponseMetadata>
<PublishBatchResult>
<Successful>
{% for successful in successful %}
<member>
<Id>{{ successful["Id"] }}</Id>
<MessageId>{{ successful["MessageId"] }}</MessageId>
</member>
{% endfor %}
</Successful>
<Failed>
{% for failed in failed %}
<member>
<Id>{{ failed["Id"] }}</Id>
<Code>{{ failed["Code"] }}</Code>
<Message>{{ failed["Message"] }}</Message>
<SenderFault>{{'true' if failed["SenderFault"] else 'false'}}</SenderFault>
</member>
{% endfor %}
</Failed>
</PublishBatchResult>
</PublishBatchResponse>"""

View File

@ -0,0 +1,154 @@
import boto3
import json
import sure # noqa # pylint: disable=unused-import
from botocore.exceptions import ClientError
import pytest
from moto import mock_sns, mock_sqs
from moto.core import ACCOUNT_ID
@mock_sns
def test_publish_batch_unknown_topic():
client = boto3.client("sns", region_name="us-east-1")
with pytest.raises(ClientError) as exc:
client.publish_batch(
TopicArn=f"arn:aws:sns:us-east-1:{ACCOUNT_ID}:unknown",
PublishBatchRequestEntries=[{"Id": "id_1", "Message": "1"}],
)
err = exc.value.response["Error"]
err["Code"].should.equal("NotFound")
err["Message"].should.equal("Topic does not exist")
@mock_sns
def test_publish_batch_too_many_items():
client = boto3.client("sns", region_name="eu-north-1")
topic = client.create_topic(Name="some-topic")
with pytest.raises(ClientError) as exc:
client.publish_batch(
TopicArn=topic["TopicArn"],
PublishBatchRequestEntries=[
{"Id": f"id_{idx}", "Message": f"{idx}"} for idx in range(11)
],
)
err = exc.value.response["Error"]
err["Code"].should.equal("TooManyEntriesInBatchRequest")
err["Message"].should.equal(
"The batch request contains more entries than permissible."
)
@mock_sns
def test_publish_batch_non_unique_ids():
client = boto3.client("sns", region_name="us-west-2")
topic = client.create_topic(Name="some-topic")
with pytest.raises(ClientError) as exc:
client.publish_batch(
TopicArn=topic["TopicArn"],
PublishBatchRequestEntries=[
{"Id": f"id", "Message": f"{idx}"} for idx in range(5)
],
)
err = exc.value.response["Error"]
err["Code"].should.equal("BatchEntryIdsNotDistinct")
err["Message"].should.equal(
"Two or more batch entries in the request have the same Id."
)
@mock_sns
def test_publish_batch_fifo_without_message_group_id():
client = boto3.client("sns", region_name="us-east-1")
topic = client.create_topic(
Name="fifo_without_msg.fifo",
Attributes={"FifoTopic": "true", "ContentBasedDeduplication": "true"},
)
with pytest.raises(ClientError) as exc:
client.publish_batch(
TopicArn=topic["TopicArn"],
PublishBatchRequestEntries=[{"Id": f"id_2", "Message": f"2"}],
)
err = exc.value.response["Error"]
err["Code"].should.equal("InvalidParameter")
err["Message"].should.equal(
"Invalid parameter: The MessageGroupId parameter is required for FIFO topics"
)
@mock_sns
def test_publish_batch_standard_with_message_group_id():
client = boto3.client("sns", region_name="us-east-1")
topic_arn = client.create_topic(Name="standard_topic")["TopicArn"]
entries = [
{"Id": f"id_1", "Message": f"1"},
{"Id": f"id_2", "Message": f"2", "MessageGroupId": "mgid"},
{"Id": f"id_3", "Message": f"3"},
]
resp = client.publish_batch(TopicArn=topic_arn, PublishBatchRequestEntries=entries)
resp.should.have.key("Successful").length_of(2)
for message_status in resp["Successful"]:
message_status.should.have.key("MessageId")
[m["Id"] for m in resp["Successful"]].should.equal(["id_1", "id_3"])
resp.should.have.key("Failed").length_of(1)
resp["Failed"][0].should.equal(
{
"Id": "id_2",
"Code": "InvalidParameter",
"Message": "Invalid parameter: Value mgid for parameter MessageGroupId is invalid. Reason: The request include parameter that is not valid for this queue type.",
"SenderFault": True,
}
)
@mock_sns
@mock_sqs
def test_publish_batch_to_sqs():
client = boto3.client("sns", region_name="us-east-1")
topic_arn = client.create_topic(Name="standard_topic")["TopicArn"]
entries = [
{"Id": f"id_1", "Message": f"1"},
{"Id": f"id_2", "Message": f"2", "Subject": "subj2"},
{
"Id": f"id_3",
"Message": f"3",
"MessageAttributes": {"a": {"DataType": "String", "StringValue": "v"}},
},
]
sqs_conn = boto3.resource("sqs", region_name="us-east-1")
queue = sqs_conn.create_queue(QueueName="test-queue")
queue_url = "arn:aws:sqs:us-east-1:{}:test-queue".format(ACCOUNT_ID)
client.subscribe(
TopicArn=topic_arn, Protocol="sqs", Endpoint=queue_url,
)
resp = client.publish_batch(TopicArn=topic_arn, PublishBatchRequestEntries=entries)
resp.should.have.key("Successful").length_of(3)
messages = queue.receive_messages(MaxNumberOfMessages=3)
messages.should.have.length_of(3)
messages = [json.loads(m.body) for m in messages]
for m in messages:
for key in list(m.keys()):
if key not in ["Message", "Subject", "MessageAttributes"]:
del m[key]
messages.should.contain({"Message": "1"})
messages.should.contain({"Message": "2", "Subject": "subj2"})
messages.should.contain(
{
"Message": "3",
"MessageAttributes": [
{"Name": "a", "Value": {"DataType": "String", "StringValue": "v"}}
],
}
)