diff --git a/moto/core/responses.py b/moto/core/responses.py index 0438fe78d..136646e3b 100644 --- a/moto/core/responses.py +++ b/moto/core/responses.py @@ -482,14 +482,8 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): tracked_prefixes or set() ) # prefixes which have already been processed - def is_tracked(name_param): - for prefix_loop in tracked_prefixes: - if name_param.startswith(prefix_loop): - return True - return False - for name, value in self.querystring.items(): - if is_tracked(name) or not name.startswith(param_prefix): + if not name.startswith(param_prefix): continue if len(name) > len(param_prefix) and not name[ diff --git a/moto/sns/exceptions.py b/moto/sns/exceptions.py index 5ba6eef6c..211e32924 100644 --- a/moto/sns/exceptions.py +++ b/moto/sns/exceptions.py @@ -8,6 +8,11 @@ class SNSNotFoundError(RESTError): super().__init__("NotFound", message, **kwargs) +class TopicNotFound(SNSNotFoundError): + def __init__(self): + super().__init__(message="Topic does not exist") + + class ResourceNotFoundError(RESTError): code = 404 @@ -60,3 +65,23 @@ class InternalError(RESTError): def __init__(self, message): super(InternalError, self).__init__("InternalFailure", message) + + +class TooManyEntriesInBatchRequest(RESTError): + code = 400 + + def __init__(self): + super().__init__( + "TooManyEntriesInBatchRequest", + "The batch request contains more entries than permissible.", + ) + + +class BatchEntryIdsNotDistinct(RESTError): + code = 400 + + def __init__(self): + super().__init__( + "BatchEntryIdsNotDistinct", + "Two or more batch entries in the request have the same Id.", + ) diff --git a/moto/sns/models.py b/moto/sns/models.py index da20b3555..aba60d814 100644 --- a/moto/sns/models.py +++ b/moto/sns/models.py @@ -17,6 +17,7 @@ from moto.sqs.exceptions import MissingParameter from .exceptions import ( SNSNotFoundError, + TopicNotFound, DuplicateSnsEndpointError, SnsEndpointDisabled, SNSInvalidParameter, @@ -24,6 +25,8 @@ from .exceptions import ( InternalError, ResourceNotFoundError, TagLimitExceededError, + TooManyEntriesInBatchRequest, + BatchEntryIdsNotDistinct, ) from .utils import make_arn_for_topic, make_arn_for_subscription, is_e164 @@ -843,6 +846,56 @@ class SNSBackend(BaseBackend): for key in tag_keys: self.topics[resource_arn]._tags.pop(key, None) + def publish_batch(self, topic_arn, publish_batch_request_entries): + """ + The MessageStructure and MessageDeduplicationId-parameters have not yet been implemented. + """ + try: + topic = self.get_topic(topic_arn) + except SNSNotFoundError: + raise TopicNotFound + + if len(publish_batch_request_entries) > 10: + raise TooManyEntriesInBatchRequest + + ids = [m["Id"] for m in publish_batch_request_entries] + if len(set(ids)) != len(ids): + raise BatchEntryIdsNotDistinct + + fifo_topic = topic.fifo_topic == "true" + if fifo_topic: + if not all( + ["MessageGroupId" in entry for entry in publish_batch_request_entries] + ): + raise SNSInvalidParameter( + "Invalid parameter: The MessageGroupId parameter is required for FIFO topics" + ) + + successful = [] + failed = [] + + for entry in publish_batch_request_entries: + try: + message_id = self.publish( + message=entry["Message"], + arn=topic_arn, + subject=entry.get("Subject"), + message_attributes=entry.get("MessageAttributes", []), + group_id=entry.get("MessageGroupId"), + ) + successful.append({"MessageId": message_id, "Id": entry["Id"]}) + except Exception as e: + if isinstance(e, InvalidParameterValue): + failed.append( + { + "Id": entry["Id"], + "Code": "InvalidParameter", + "Message": f"Invalid parameter: {e.message}", + "SenderFault": True, + } + ) + return successful, failed + sns_backends = BackendDict(SNSBackend, "sns") diff --git a/moto/sns/responses.py b/moto/sns/responses.py index 416d9704f..8291586cc 100644 --- a/moto/sns/responses.py +++ b/moto/sns/responses.py @@ -377,6 +377,18 @@ class SNSResponse(BaseResponse): template = self.response_template(PUBLISH_TEMPLATE) return template.render(message_id=message_id) + def publish_batch(self): + topic_arn = self._get_param("TopicArn") + publish_batch_request_entries = self._get_multi_param( + "PublishBatchRequestEntries.member" + ) + successful, failed = self.backend.publish_batch( + topic_arn=topic_arn, + publish_batch_request_entries=publish_batch_request_entries, + ) + template = self.response_template(PUBLISH_BATCH_TEMPLATE) + return template.render(successful=successful, failed=failed) + def create_platform_application(self): name = self._get_param("Name") platform = self._get_param("Platform") @@ -1191,3 +1203,29 @@ UNTAG_RESOURCE_TEMPLATE = """ + + 1549581b-12b7-11e3-895e-1334aEXAMPLE + + + +{% for successful in successful %} + + {{ successful["Id"] }} + {{ successful["MessageId"] }} + +{% endfor %} + + +{% for failed in failed %} + + {{ failed["Id"] }} + {{ failed["Code"] }} + {{ failed["Message"] }} + {{'true' if failed["SenderFault"] else 'false'}} + +{% endfor %} + + +""" diff --git a/tests/test_sns/test_publish_batch.py b/tests/test_sns/test_publish_batch.py new file mode 100644 index 000000000..53dc37f5f --- /dev/null +++ b/tests/test_sns/test_publish_batch.py @@ -0,0 +1,154 @@ +import boto3 +import json +import sure # noqa # pylint: disable=unused-import + +from botocore.exceptions import ClientError +import pytest +from moto import mock_sns, mock_sqs +from moto.core import ACCOUNT_ID + + +@mock_sns +def test_publish_batch_unknown_topic(): + client = boto3.client("sns", region_name="us-east-1") + with pytest.raises(ClientError) as exc: + client.publish_batch( + TopicArn=f"arn:aws:sns:us-east-1:{ACCOUNT_ID}:unknown", + PublishBatchRequestEntries=[{"Id": "id_1", "Message": "1"}], + ) + err = exc.value.response["Error"] + err["Code"].should.equal("NotFound") + err["Message"].should.equal("Topic does not exist") + + +@mock_sns +def test_publish_batch_too_many_items(): + client = boto3.client("sns", region_name="eu-north-1") + topic = client.create_topic(Name="some-topic") + + with pytest.raises(ClientError) as exc: + client.publish_batch( + TopicArn=topic["TopicArn"], + PublishBatchRequestEntries=[ + {"Id": f"id_{idx}", "Message": f"{idx}"} for idx in range(11) + ], + ) + err = exc.value.response["Error"] + err["Code"].should.equal("TooManyEntriesInBatchRequest") + err["Message"].should.equal( + "The batch request contains more entries than permissible." + ) + + +@mock_sns +def test_publish_batch_non_unique_ids(): + client = boto3.client("sns", region_name="us-west-2") + topic = client.create_topic(Name="some-topic") + + with pytest.raises(ClientError) as exc: + client.publish_batch( + TopicArn=topic["TopicArn"], + PublishBatchRequestEntries=[ + {"Id": f"id", "Message": f"{idx}"} for idx in range(5) + ], + ) + err = exc.value.response["Error"] + err["Code"].should.equal("BatchEntryIdsNotDistinct") + err["Message"].should.equal( + "Two or more batch entries in the request have the same Id." + ) + + +@mock_sns +def test_publish_batch_fifo_without_message_group_id(): + client = boto3.client("sns", region_name="us-east-1") + topic = client.create_topic( + Name="fifo_without_msg.fifo", + Attributes={"FifoTopic": "true", "ContentBasedDeduplication": "true"}, + ) + + with pytest.raises(ClientError) as exc: + client.publish_batch( + TopicArn=topic["TopicArn"], + PublishBatchRequestEntries=[{"Id": f"id_2", "Message": f"2"}], + ) + err = exc.value.response["Error"] + err["Code"].should.equal("InvalidParameter") + err["Message"].should.equal( + "Invalid parameter: The MessageGroupId parameter is required for FIFO topics" + ) + + +@mock_sns +def test_publish_batch_standard_with_message_group_id(): + client = boto3.client("sns", region_name="us-east-1") + topic_arn = client.create_topic(Name="standard_topic")["TopicArn"] + entries = [ + {"Id": f"id_1", "Message": f"1"}, + {"Id": f"id_2", "Message": f"2", "MessageGroupId": "mgid"}, + {"Id": f"id_3", "Message": f"3"}, + ] + resp = client.publish_batch(TopicArn=topic_arn, PublishBatchRequestEntries=entries) + + resp.should.have.key("Successful").length_of(2) + for message_status in resp["Successful"]: + message_status.should.have.key("MessageId") + [m["Id"] for m in resp["Successful"]].should.equal(["id_1", "id_3"]) + + resp.should.have.key("Failed").length_of(1) + resp["Failed"][0].should.equal( + { + "Id": "id_2", + "Code": "InvalidParameter", + "Message": "Invalid parameter: Value mgid for parameter MessageGroupId is invalid. Reason: The request include parameter that is not valid for this queue type.", + "SenderFault": True, + } + ) + + +@mock_sns +@mock_sqs +def test_publish_batch_to_sqs(): + client = boto3.client("sns", region_name="us-east-1") + topic_arn = client.create_topic(Name="standard_topic")["TopicArn"] + entries = [ + {"Id": f"id_1", "Message": f"1"}, + {"Id": f"id_2", "Message": f"2", "Subject": "subj2"}, + { + "Id": f"id_3", + "Message": f"3", + "MessageAttributes": {"a": {"DataType": "String", "StringValue": "v"}}, + }, + ] + + sqs_conn = boto3.resource("sqs", region_name="us-east-1") + queue = sqs_conn.create_queue(QueueName="test-queue") + + queue_url = "arn:aws:sqs:us-east-1:{}:test-queue".format(ACCOUNT_ID) + client.subscribe( + TopicArn=topic_arn, Protocol="sqs", Endpoint=queue_url, + ) + + resp = client.publish_batch(TopicArn=topic_arn, PublishBatchRequestEntries=entries) + + resp.should.have.key("Successful").length_of(3) + + messages = queue.receive_messages(MaxNumberOfMessages=3) + messages.should.have.length_of(3) + + messages = [json.loads(m.body) for m in messages] + for m in messages: + for key in list(m.keys()): + if key not in ["Message", "Subject", "MessageAttributes"]: + del m[key] + + messages.should.contain({"Message": "1"}) + messages.should.contain({"Message": "2", "Subject": "subj2"}) + messages.should.contain( + { + "Message": "3", + "MessageAttributes": [ + {"Name": "a", "Value": {"DataType": "String", "StringValue": "v"}} + ], + } + )