SNS: publish_batch() (#4741)
This commit is contained in:
parent
0eb8ec47ad
commit
526559e22c
@ -482,14 +482,8 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
||||
tracked_prefixes or set()
|
||||
) # prefixes which have already been processed
|
||||
|
||||
def is_tracked(name_param):
|
||||
for prefix_loop in tracked_prefixes:
|
||||
if name_param.startswith(prefix_loop):
|
||||
return True
|
||||
return False
|
||||
|
||||
for name, value in self.querystring.items():
|
||||
if is_tracked(name) or not name.startswith(param_prefix):
|
||||
if not name.startswith(param_prefix):
|
||||
continue
|
||||
|
||||
if len(name) > len(param_prefix) and not name[
|
||||
|
@ -8,6 +8,11 @@ class SNSNotFoundError(RESTError):
|
||||
super().__init__("NotFound", message, **kwargs)
|
||||
|
||||
|
||||
class TopicNotFound(SNSNotFoundError):
|
||||
def __init__(self):
|
||||
super().__init__(message="Topic does not exist")
|
||||
|
||||
|
||||
class ResourceNotFoundError(RESTError):
|
||||
code = 404
|
||||
|
||||
@ -60,3 +65,23 @@ class InternalError(RESTError):
|
||||
|
||||
def __init__(self, message):
|
||||
super(InternalError, self).__init__("InternalFailure", message)
|
||||
|
||||
|
||||
class TooManyEntriesInBatchRequest(RESTError):
|
||||
code = 400
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
"TooManyEntriesInBatchRequest",
|
||||
"The batch request contains more entries than permissible.",
|
||||
)
|
||||
|
||||
|
||||
class BatchEntryIdsNotDistinct(RESTError):
|
||||
code = 400
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
"BatchEntryIdsNotDistinct",
|
||||
"Two or more batch entries in the request have the same Id.",
|
||||
)
|
||||
|
@ -17,6 +17,7 @@ from moto.sqs.exceptions import MissingParameter
|
||||
|
||||
from .exceptions import (
|
||||
SNSNotFoundError,
|
||||
TopicNotFound,
|
||||
DuplicateSnsEndpointError,
|
||||
SnsEndpointDisabled,
|
||||
SNSInvalidParameter,
|
||||
@ -24,6 +25,8 @@ from .exceptions import (
|
||||
InternalError,
|
||||
ResourceNotFoundError,
|
||||
TagLimitExceededError,
|
||||
TooManyEntriesInBatchRequest,
|
||||
BatchEntryIdsNotDistinct,
|
||||
)
|
||||
from .utils import make_arn_for_topic, make_arn_for_subscription, is_e164
|
||||
|
||||
@ -843,6 +846,56 @@ class SNSBackend(BaseBackend):
|
||||
for key in tag_keys:
|
||||
self.topics[resource_arn]._tags.pop(key, None)
|
||||
|
||||
def publish_batch(self, topic_arn, publish_batch_request_entries):
|
||||
"""
|
||||
The MessageStructure and MessageDeduplicationId-parameters have not yet been implemented.
|
||||
"""
|
||||
try:
|
||||
topic = self.get_topic(topic_arn)
|
||||
except SNSNotFoundError:
|
||||
raise TopicNotFound
|
||||
|
||||
if len(publish_batch_request_entries) > 10:
|
||||
raise TooManyEntriesInBatchRequest
|
||||
|
||||
ids = [m["Id"] for m in publish_batch_request_entries]
|
||||
if len(set(ids)) != len(ids):
|
||||
raise BatchEntryIdsNotDistinct
|
||||
|
||||
fifo_topic = topic.fifo_topic == "true"
|
||||
if fifo_topic:
|
||||
if not all(
|
||||
["MessageGroupId" in entry for entry in publish_batch_request_entries]
|
||||
):
|
||||
raise SNSInvalidParameter(
|
||||
"Invalid parameter: The MessageGroupId parameter is required for FIFO topics"
|
||||
)
|
||||
|
||||
successful = []
|
||||
failed = []
|
||||
|
||||
for entry in publish_batch_request_entries:
|
||||
try:
|
||||
message_id = self.publish(
|
||||
message=entry["Message"],
|
||||
arn=topic_arn,
|
||||
subject=entry.get("Subject"),
|
||||
message_attributes=entry.get("MessageAttributes", []),
|
||||
group_id=entry.get("MessageGroupId"),
|
||||
)
|
||||
successful.append({"MessageId": message_id, "Id": entry["Id"]})
|
||||
except Exception as e:
|
||||
if isinstance(e, InvalidParameterValue):
|
||||
failed.append(
|
||||
{
|
||||
"Id": entry["Id"],
|
||||
"Code": "InvalidParameter",
|
||||
"Message": f"Invalid parameter: {e.message}",
|
||||
"SenderFault": True,
|
||||
}
|
||||
)
|
||||
return successful, failed
|
||||
|
||||
|
||||
sns_backends = BackendDict(SNSBackend, "sns")
|
||||
|
||||
|
@ -377,6 +377,18 @@ class SNSResponse(BaseResponse):
|
||||
template = self.response_template(PUBLISH_TEMPLATE)
|
||||
return template.render(message_id=message_id)
|
||||
|
||||
def publish_batch(self):
|
||||
topic_arn = self._get_param("TopicArn")
|
||||
publish_batch_request_entries = self._get_multi_param(
|
||||
"PublishBatchRequestEntries.member"
|
||||
)
|
||||
successful, failed = self.backend.publish_batch(
|
||||
topic_arn=topic_arn,
|
||||
publish_batch_request_entries=publish_batch_request_entries,
|
||||
)
|
||||
template = self.response_template(PUBLISH_BATCH_TEMPLATE)
|
||||
return template.render(successful=successful, failed=failed)
|
||||
|
||||
def create_platform_application(self):
|
||||
name = self._get_param("Name")
|
||||
platform = self._get_param("Platform")
|
||||
@ -1191,3 +1203,29 @@ UNTAG_RESOURCE_TEMPLATE = """<UntagResourceResponse xmlns="http://sns.amazonaws.
|
||||
<RequestId>14eb7b1a-4cbd-5a56-80db-2d06412df769</RequestId>
|
||||
</ResponseMetadata>
|
||||
</UntagResourceResponse>"""
|
||||
|
||||
PUBLISH_BATCH_TEMPLATE = """<PublishBatchResponse xmlns="http://sns.amazonaws.com/doc/2010-03-31/">
|
||||
<ResponseMetadata>
|
||||
<RequestId>1549581b-12b7-11e3-895e-1334aEXAMPLE</RequestId>
|
||||
</ResponseMetadata>
|
||||
<PublishBatchResult>
|
||||
<Successful>
|
||||
{% for successful in successful %}
|
||||
<member>
|
||||
<Id>{{ successful["Id"] }}</Id>
|
||||
<MessageId>{{ successful["MessageId"] }}</MessageId>
|
||||
</member>
|
||||
{% endfor %}
|
||||
</Successful>
|
||||
<Failed>
|
||||
{% for failed in failed %}
|
||||
<member>
|
||||
<Id>{{ failed["Id"] }}</Id>
|
||||
<Code>{{ failed["Code"] }}</Code>
|
||||
<Message>{{ failed["Message"] }}</Message>
|
||||
<SenderFault>{{'true' if failed["SenderFault"] else 'false'}}</SenderFault>
|
||||
</member>
|
||||
{% endfor %}
|
||||
</Failed>
|
||||
</PublishBatchResult>
|
||||
</PublishBatchResponse>"""
|
||||
|
154
tests/test_sns/test_publish_batch.py
Normal file
154
tests/test_sns/test_publish_batch.py
Normal file
@ -0,0 +1,154 @@
|
||||
import boto3
|
||||
import json
|
||||
import sure # noqa # pylint: disable=unused-import
|
||||
|
||||
from botocore.exceptions import ClientError
|
||||
import pytest
|
||||
from moto import mock_sns, mock_sqs
|
||||
from moto.core import ACCOUNT_ID
|
||||
|
||||
|
||||
@mock_sns
|
||||
def test_publish_batch_unknown_topic():
|
||||
client = boto3.client("sns", region_name="us-east-1")
|
||||
with pytest.raises(ClientError) as exc:
|
||||
client.publish_batch(
|
||||
TopicArn=f"arn:aws:sns:us-east-1:{ACCOUNT_ID}:unknown",
|
||||
PublishBatchRequestEntries=[{"Id": "id_1", "Message": "1"}],
|
||||
)
|
||||
err = exc.value.response["Error"]
|
||||
err["Code"].should.equal("NotFound")
|
||||
err["Message"].should.equal("Topic does not exist")
|
||||
|
||||
|
||||
@mock_sns
|
||||
def test_publish_batch_too_many_items():
|
||||
client = boto3.client("sns", region_name="eu-north-1")
|
||||
topic = client.create_topic(Name="some-topic")
|
||||
|
||||
with pytest.raises(ClientError) as exc:
|
||||
client.publish_batch(
|
||||
TopicArn=topic["TopicArn"],
|
||||
PublishBatchRequestEntries=[
|
||||
{"Id": f"id_{idx}", "Message": f"{idx}"} for idx in range(11)
|
||||
],
|
||||
)
|
||||
err = exc.value.response["Error"]
|
||||
err["Code"].should.equal("TooManyEntriesInBatchRequest")
|
||||
err["Message"].should.equal(
|
||||
"The batch request contains more entries than permissible."
|
||||
)
|
||||
|
||||
|
||||
@mock_sns
|
||||
def test_publish_batch_non_unique_ids():
|
||||
client = boto3.client("sns", region_name="us-west-2")
|
||||
topic = client.create_topic(Name="some-topic")
|
||||
|
||||
with pytest.raises(ClientError) as exc:
|
||||
client.publish_batch(
|
||||
TopicArn=topic["TopicArn"],
|
||||
PublishBatchRequestEntries=[
|
||||
{"Id": f"id", "Message": f"{idx}"} for idx in range(5)
|
||||
],
|
||||
)
|
||||
err = exc.value.response["Error"]
|
||||
err["Code"].should.equal("BatchEntryIdsNotDistinct")
|
||||
err["Message"].should.equal(
|
||||
"Two or more batch entries in the request have the same Id."
|
||||
)
|
||||
|
||||
|
||||
@mock_sns
|
||||
def test_publish_batch_fifo_without_message_group_id():
|
||||
client = boto3.client("sns", region_name="us-east-1")
|
||||
topic = client.create_topic(
|
||||
Name="fifo_without_msg.fifo",
|
||||
Attributes={"FifoTopic": "true", "ContentBasedDeduplication": "true"},
|
||||
)
|
||||
|
||||
with pytest.raises(ClientError) as exc:
|
||||
client.publish_batch(
|
||||
TopicArn=topic["TopicArn"],
|
||||
PublishBatchRequestEntries=[{"Id": f"id_2", "Message": f"2"}],
|
||||
)
|
||||
err = exc.value.response["Error"]
|
||||
err["Code"].should.equal("InvalidParameter")
|
||||
err["Message"].should.equal(
|
||||
"Invalid parameter: The MessageGroupId parameter is required for FIFO topics"
|
||||
)
|
||||
|
||||
|
||||
@mock_sns
|
||||
def test_publish_batch_standard_with_message_group_id():
|
||||
client = boto3.client("sns", region_name="us-east-1")
|
||||
topic_arn = client.create_topic(Name="standard_topic")["TopicArn"]
|
||||
entries = [
|
||||
{"Id": f"id_1", "Message": f"1"},
|
||||
{"Id": f"id_2", "Message": f"2", "MessageGroupId": "mgid"},
|
||||
{"Id": f"id_3", "Message": f"3"},
|
||||
]
|
||||
resp = client.publish_batch(TopicArn=topic_arn, PublishBatchRequestEntries=entries)
|
||||
|
||||
resp.should.have.key("Successful").length_of(2)
|
||||
for message_status in resp["Successful"]:
|
||||
message_status.should.have.key("MessageId")
|
||||
[m["Id"] for m in resp["Successful"]].should.equal(["id_1", "id_3"])
|
||||
|
||||
resp.should.have.key("Failed").length_of(1)
|
||||
resp["Failed"][0].should.equal(
|
||||
{
|
||||
"Id": "id_2",
|
||||
"Code": "InvalidParameter",
|
||||
"Message": "Invalid parameter: Value mgid for parameter MessageGroupId is invalid. Reason: The request include parameter that is not valid for this queue type.",
|
||||
"SenderFault": True,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@mock_sns
|
||||
@mock_sqs
|
||||
def test_publish_batch_to_sqs():
|
||||
client = boto3.client("sns", region_name="us-east-1")
|
||||
topic_arn = client.create_topic(Name="standard_topic")["TopicArn"]
|
||||
entries = [
|
||||
{"Id": f"id_1", "Message": f"1"},
|
||||
{"Id": f"id_2", "Message": f"2", "Subject": "subj2"},
|
||||
{
|
||||
"Id": f"id_3",
|
||||
"Message": f"3",
|
||||
"MessageAttributes": {"a": {"DataType": "String", "StringValue": "v"}},
|
||||
},
|
||||
]
|
||||
|
||||
sqs_conn = boto3.resource("sqs", region_name="us-east-1")
|
||||
queue = sqs_conn.create_queue(QueueName="test-queue")
|
||||
|
||||
queue_url = "arn:aws:sqs:us-east-1:{}:test-queue".format(ACCOUNT_ID)
|
||||
client.subscribe(
|
||||
TopicArn=topic_arn, Protocol="sqs", Endpoint=queue_url,
|
||||
)
|
||||
|
||||
resp = client.publish_batch(TopicArn=topic_arn, PublishBatchRequestEntries=entries)
|
||||
|
||||
resp.should.have.key("Successful").length_of(3)
|
||||
|
||||
messages = queue.receive_messages(MaxNumberOfMessages=3)
|
||||
messages.should.have.length_of(3)
|
||||
|
||||
messages = [json.loads(m.body) for m in messages]
|
||||
for m in messages:
|
||||
for key in list(m.keys()):
|
||||
if key not in ["Message", "Subject", "MessageAttributes"]:
|
||||
del m[key]
|
||||
|
||||
messages.should.contain({"Message": "1"})
|
||||
messages.should.contain({"Message": "2", "Subject": "subj2"})
|
||||
messages.should.contain(
|
||||
{
|
||||
"Message": "3",
|
||||
"MessageAttributes": [
|
||||
{"Name": "a", "Value": {"DataType": "String", "StringValue": "v"}}
|
||||
],
|
||||
}
|
||||
)
|
Loading…
Reference in New Issue
Block a user