From c53dc4c21c9a0713d2e08a3961ce32331c57bc6a Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Mon, 21 Feb 2022 21:01:38 -0100 Subject: [PATCH] SNS Improvements (#4881) --- docs/docs/services/sns.rst | 2 + moto/sns/exceptions.py | 26 +++++++----- moto/sns/models.py | 23 +++++++++-- moto/sns/responses.py | 32 +++++++++++---- tests/terraform-tests.success.txt | 1 + tests/test_sns/test_publish_batch.py | 45 ++++++++++++++++++--- tests/test_sns/test_publishing_boto3.py | 54 +++++++++++++++++++++++-- tests/test_sns/test_topics_boto3.py | 2 +- 8 files changed, 154 insertions(+), 31 deletions(-) diff --git a/docs/docs/services/sns.rst b/docs/docs/services/sns.rst index 500930f02..1e96fab5e 100644 --- a/docs/docs/services/sns.rst +++ b/docs/docs/services/sns.rst @@ -12,6 +12,8 @@ sns === +.. autoclass:: moto.sns.models.SNSBackend + |start-h3| Example usage |end-h3| .. sourcecode:: python diff --git a/moto/sns/exceptions.py b/moto/sns/exceptions.py index 41a30e051..954d28d5a 100644 --- a/moto/sns/exceptions.py +++ b/moto/sns/exceptions.py @@ -1,7 +1,13 @@ from moto.core.exceptions import RESTError -class SNSNotFoundError(RESTError): +class SNSException(RESTError): + def __init__(self, *args, **kwargs): + kwargs["template"] = "wrapped_single_error" + super().__init__(*args, **kwargs) + + +class SNSNotFoundError(SNSException): code = 404 def __init__(self, message, **kwargs): @@ -13,42 +19,42 @@ class TopicNotFound(SNSNotFoundError): super().__init__(message="Topic does not exist") -class ResourceNotFoundError(RESTError): +class ResourceNotFoundError(SNSException): code = 404 def __init__(self): super().__init__("ResourceNotFound", "Resource does not exist") -class DuplicateSnsEndpointError(RESTError): +class DuplicateSnsEndpointError(SNSException): code = 400 def __init__(self, message): super().__init__("DuplicateEndpoint", message) -class SnsEndpointDisabled(RESTError): +class SnsEndpointDisabled(SNSException): code = 400 def __init__(self, message): super().__init__("EndpointDisabled", message) -class SNSInvalidParameter(RESTError): +class SNSInvalidParameter(SNSException): code = 400 def __init__(self, message): super().__init__("InvalidParameter", message) -class InvalidParameterValue(RESTError): +class InvalidParameterValue(SNSException): code = 400 def __init__(self, message): super().__init__("InvalidParameterValue", message) -class TagLimitExceededError(RESTError): +class TagLimitExceededError(SNSException): code = 400 def __init__(self): @@ -58,14 +64,14 @@ class TagLimitExceededError(RESTError): ) -class InternalError(RESTError): +class InternalError(SNSException): code = 500 def __init__(self, message): super().__init__("InternalFailure", message) -class TooManyEntriesInBatchRequest(RESTError): +class TooManyEntriesInBatchRequest(SNSException): code = 400 def __init__(self): @@ -75,7 +81,7 @@ class TooManyEntriesInBatchRequest(RESTError): ) -class BatchEntryIdsNotDistinct(RESTError): +class BatchEntryIdsNotDistinct(SNSException): code = 400 def __init__(self): diff --git a/moto/sns/models.py b/moto/sns/models.py index 55d783d4d..9f4f700c4 100644 --- a/moto/sns/models.py +++ b/moto/sns/models.py @@ -51,6 +51,7 @@ class Topic(CloudFormationModel): self.subscriptions_pending = 0 self.subscriptions_confimed = 0 self.subscriptions_deleted = 0 + self.sent_notifications = [] self._policy_json = self._create_default_topic_policy( sns_backend.region_name, self.account_id, name @@ -70,6 +71,9 @@ class Topic(CloudFormationModel): message_attributes=message_attributes, group_id=group_id, ) + self.sent_notifications.append( + (message_id, message, subject, message_attributes, group_id) + ) return message_id @classmethod @@ -89,7 +93,7 @@ class Topic(CloudFormationModel): @property def policy(self): - return json.dumps(self._policy_json) + return json.dumps(self._policy_json, separators=(",", ":")) @policy.setter def policy(self, policy): @@ -215,7 +219,7 @@ class Subscription(BaseModel): if value["Type"].startswith("Binary"): attr_type = "binary_value" elif value["Type"].startswith("Number"): - type_value = "{0:g}".format(value["Value"]) + type_value = str(value["Value"]) raw_message_attributes[key] = { "data_type": value["Type"], @@ -404,6 +408,19 @@ class PlatformEndpoint(BaseModel): class SNSBackend(BaseBackend): + """ + Responsible for mocking calls to SNS. Integration with SQS/HTTP/etc is supported. + + Messages published to a topic are persisted in the backend. If you need to verify that a message was published successfully, you can use the internal API to check the message was published successfully: + + .. sourcecode:: python + + from moto.sns import sns_backend + all_send_notifications = sns_backend.topics[topic_arn].sent_notifications + + Note that, as this is an internal API, the exact format may differ per versions. + """ + def __init__(self, region_name): super().__init__() self.topics = OrderedDict() @@ -880,7 +897,7 @@ class SNSBackend(BaseBackend): message=entry["Message"], arn=topic_arn, subject=entry.get("Subject"), - message_attributes=entry.get("MessageAttributes", []), + message_attributes=entry.get("MessageAttributes", {}), group_id=entry.get("MessageGroupId"), ) successful.append({"MessageId": message_id, "Id": entry["Id"]}) diff --git a/moto/sns/responses.py b/moto/sns/responses.py index 8291586cc..bd9fbc8e0 100644 --- a/moto/sns/responses.py +++ b/moto/sns/responses.py @@ -31,10 +31,13 @@ class SNSResponse(BaseResponse): tags = self._get_list_prefix("Tags.member") return {tag["key"]: tag["value"] for tag in tags} - def _parse_message_attributes(self, prefix="", value_namespace="Value."): + def _parse_message_attributes(self): message_attributes = self._get_object_map( "MessageAttributes.entry", name="Name", value="Value" ) + return self._transform_message_attributes(message_attributes) + + def _transform_message_attributes(self, message_attributes): # SNS converts some key names before forwarding messages # DataType -> Type, StringValue -> Value, BinaryValue -> Value transformed_message_attributes = {} @@ -63,15 +66,18 @@ class SNSResponse(BaseResponse): if "StringValue" in value: if data_type == "Number": try: - transform_value = float(value["StringValue"]) + transform_value = int(value["StringValue"]) except ValueError: - raise InvalidParameterValue( - "An error occurred (ParameterValueInvalid) " - "when calling the Publish operation: " - "Could not cast message attribute '{0}' value to number.".format( - name + try: + transform_value = float(value["StringValue"]) + except ValueError: + raise InvalidParameterValue( + "An error occurred (ParameterValueInvalid) " + "when calling the Publish operation: " + "Could not cast message attribute '{0}' value to number.".format( + name + ) ) - ) else: transform_value = value["StringValue"] elif "BinaryValue" in value: @@ -382,6 +388,16 @@ class SNSResponse(BaseResponse): publish_batch_request_entries = self._get_multi_param( "PublishBatchRequestEntries.member" ) + for entry in publish_batch_request_entries: + if "MessageAttributes" in entry: + # Convert into the same format as the regular publish-method + # FROM: [{'Name': 'a', 'Value': {'DataType': 'String', 'StringValue': 'v'}}] + # TO : {'name': {'DataType': 'Number', 'StringValue': '123'}} + msg_attrs = {y["Name"]: y["Value"] for y in entry["MessageAttributes"]} + # Use the same validation/processing as the regular publish-method + entry["MessageAttributes"] = self._transform_message_attributes( + msg_attrs + ) successful, failed = self.backend.publish_batch( topic_arn=topic_arn, publish_batch_request_entries=publish_batch_request_entries, diff --git a/tests/terraform-tests.success.txt b/tests/terraform-tests.success.txt index 6c0b11bb8..8dbe512bf 100644 --- a/tests/terraform-tests.success.txt +++ b/tests/terraform-tests.success.txt @@ -87,6 +87,7 @@ TestAccAWSProvider TestAccAWSRedshiftServiceAccount TestAccAWSRolePolicyAttachment TestAccAWSSNSSMSPreferences +TestAccAWSSNSTopicPolicy TestAccAWSSageMakerPrebuiltECRImage TestAccAWSServiceDiscovery TestAccAWSSQSQueuePolicy diff --git a/tests/test_sns/test_publish_batch.py b/tests/test_sns/test_publish_batch.py index 53dc37f5f..930580ac4 100644 --- a/tests/test_sns/test_publish_batch.py +++ b/tests/test_sns/test_publish_batch.py @@ -145,10 +145,43 @@ def test_publish_batch_to_sqs(): messages.should.contain({"Message": "1"}) messages.should.contain({"Message": "2", "Subject": "subj2"}) messages.should.contain( - { - "Message": "3", - "MessageAttributes": [ - {"Name": "a", "Value": {"DataType": "String", "StringValue": "v"}} - ], - } + {"Message": "3", "MessageAttributes": {"a": {"Type": "String", "Value": "v"}}} ) + + +@mock_sqs +@mock_sns +def test_publish_batch_to_sqs_raw(): + client = boto3.client("sns", region_name="us-east-1") + topic_arn = client.create_topic(Name="standard_topic")["TopicArn"] + sqs = boto3.resource("sqs", region_name="us-east-1") + queue = sqs.create_queue(QueueName="test-queue") + + queue_url = "arn:aws:sqs:us-east-1:{}:test-queue".format(ACCOUNT_ID) + client.subscribe( + TopicArn=topic_arn, + Protocol="sqs", + Endpoint=queue_url, + Attributes={"RawMessageDelivery": "true"}, + ) + + entries = [ + {"Id": "1", "Message": "foo",}, + { + "Id": "2", + "Message": "bar", + "MessageAttributes": {"a": {"DataType": "String", "StringValue": "v"}}, + }, + ] + resp = client.publish_batch(TopicArn=topic_arn, PublishBatchRequestEntries=entries,) + + resp.should.have.key("Successful").length_of(2) + + received = queue.receive_messages( + MaxNumberOfMessages=10, MessageAttributeNames=["All"], + ) + + messages = [(message.body, message.message_attributes) for message in received] + + messages.should.contain(("foo", None)) + messages.should.contain(("bar", {"a": {"StringValue": "v", "DataType": "String"}})) diff --git a/tests/test_sns/test_publishing_boto3.py b/tests/test_sns/test_publishing_boto3.py index ce2d98dee..fde6f8a71 100644 --- a/tests/test_sns/test_publishing_boto3.py +++ b/tests/test_sns/test_publishing_boto3.py @@ -238,6 +238,49 @@ def test_publish_to_sqs_msg_attr_number_type(): message.body.should.equal("test message") +@mock_sqs +@mock_sns +def test_publish_to_sqs_msg_attr_different_formats(): + """ + Verify different Number-formats are processed correctly + """ + sns = boto3.resource("sns", region_name="us-east-1") + topic = sns.create_topic(Name="test-topic") + sqs = boto3.resource("sqs", region_name="us-east-1") + sqs_client = boto3.client("sqs", region_name="us-east-1") + queue_raw = sqs.create_queue(QueueName="test-queue-raw") + + topic.subscribe( + Protocol="sqs", + Endpoint=queue_raw.attributes["QueueArn"], + Attributes={"RawMessageDelivery": "true"}, + ) + + topic.publish( + Message="test message", + MessageAttributes={ + "integer": {"DataType": "Number", "StringValue": "123"}, + "float": {"DataType": "Number", "StringValue": "12.34"}, + "big-integer": {"DataType": "Number", "StringValue": "123456789"}, + "big-float": {"DataType": "Number", "StringValue": "123456.789"}, + }, + ) + + messages_resp = sqs_client.receive_message( + QueueUrl=queue_raw.url, MessageAttributeNames=["All"] + ) + message = messages_resp["Messages"][0] + message_attributes = message["MessageAttributes"] + message_attributes.should.equal( + { + "integer": {"DataType": "Number", "StringValue": "123"}, + "float": {"DataType": "Number", "StringValue": "12.34"}, + "big-integer": {"DataType": "Number", "StringValue": "123456789"}, + "big-float": {"DataType": "Number", "StringValue": "123456.789"}, + } + ) + + @mock_sns def test_publish_sms(): client = boto3.client("sns", region_name="us-east-1") @@ -374,9 +417,14 @@ def test_publish_to_http(): TopicArn=topic_arn, Protocol="http", Endpoint="http://example.com/foobar" ) - response = conn.publish( - TopicArn=topic_arn, Message="my message", Subject="my subject" - ) + conn.publish(TopicArn=topic_arn, Message="my message", Subject="my subject") + + if not settings.TEST_SERVER_MODE: + sns_backend.topics[topic_arn].sent_notifications.should.have.length_of(1) + notification = sns_backend.topics[topic_arn].sent_notifications[0] + _, msg, subject, _, _ = notification + msg.should.equal("my message") + subject.should.equal("my subject") @mock_sqs diff --git a/tests/test_sns/test_topics_boto3.py b/tests/test_sns/test_topics_boto3.py index 4b38d615d..13307fc44 100644 --- a/tests/test_sns/test_topics_boto3.py +++ b/tests/test_sns/test_topics_boto3.py @@ -214,7 +214,7 @@ def test_topic_attributes(): ) attributes = conn.get_topic_attributes(TopicArn=topic_arn)["Attributes"] - attributes["Policy"].should.equal('{"foo": "bar"}') + attributes["Policy"].should.equal('{"foo":"bar"}') attributes["DisplayName"].should.equal("My display name") attributes["DeliveryPolicy"].should.equal( '{"http": {"defaultHealthyRetryPolicy": {"numRetries": 5}}}'