SNS Improvements (#4881)
This commit is contained in:
		
							parent
							
								
									412153aeb7
								
							
						
					
					
						commit
						c53dc4c21c
					
				@ -12,6 +12,8 @@
 | 
			
		||||
sns
 | 
			
		||||
===
 | 
			
		||||
 | 
			
		||||
.. autoclass:: moto.sns.models.SNSBackend
 | 
			
		||||
 | 
			
		||||
|start-h3| Example usage |end-h3|
 | 
			
		||||
 | 
			
		||||
.. sourcecode:: python
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,13 @@
 | 
			
		||||
from moto.core.exceptions import RESTError
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SNSNotFoundError(RESTError):
 | 
			
		||||
class SNSException(RESTError):
 | 
			
		||||
    def __init__(self, *args, **kwargs):
 | 
			
		||||
        kwargs["template"] = "wrapped_single_error"
 | 
			
		||||
        super().__init__(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SNSNotFoundError(SNSException):
 | 
			
		||||
    code = 404
 | 
			
		||||
 | 
			
		||||
    def __init__(self, message, **kwargs):
 | 
			
		||||
@ -13,42 +19,42 @@ class TopicNotFound(SNSNotFoundError):
 | 
			
		||||
        super().__init__(message="Topic does not exist")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ResourceNotFoundError(RESTError):
 | 
			
		||||
class ResourceNotFoundError(SNSException):
 | 
			
		||||
    code = 404
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        super().__init__("ResourceNotFound", "Resource does not exist")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DuplicateSnsEndpointError(RESTError):
 | 
			
		||||
class DuplicateSnsEndpointError(SNSException):
 | 
			
		||||
    code = 400
 | 
			
		||||
 | 
			
		||||
    def __init__(self, message):
 | 
			
		||||
        super().__init__("DuplicateEndpoint", message)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SnsEndpointDisabled(RESTError):
 | 
			
		||||
class SnsEndpointDisabled(SNSException):
 | 
			
		||||
    code = 400
 | 
			
		||||
 | 
			
		||||
    def __init__(self, message):
 | 
			
		||||
        super().__init__("EndpointDisabled", message)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SNSInvalidParameter(RESTError):
 | 
			
		||||
class SNSInvalidParameter(SNSException):
 | 
			
		||||
    code = 400
 | 
			
		||||
 | 
			
		||||
    def __init__(self, message):
 | 
			
		||||
        super().__init__("InvalidParameter", message)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InvalidParameterValue(RESTError):
 | 
			
		||||
class InvalidParameterValue(SNSException):
 | 
			
		||||
    code = 400
 | 
			
		||||
 | 
			
		||||
    def __init__(self, message):
 | 
			
		||||
        super().__init__("InvalidParameterValue", message)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TagLimitExceededError(RESTError):
 | 
			
		||||
class TagLimitExceededError(SNSException):
 | 
			
		||||
    code = 400
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
@ -58,14 +64,14 @@ class TagLimitExceededError(RESTError):
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InternalError(RESTError):
 | 
			
		||||
class InternalError(SNSException):
 | 
			
		||||
    code = 500
 | 
			
		||||
 | 
			
		||||
    def __init__(self, message):
 | 
			
		||||
        super().__init__("InternalFailure", message)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TooManyEntriesInBatchRequest(RESTError):
 | 
			
		||||
class TooManyEntriesInBatchRequest(SNSException):
 | 
			
		||||
    code = 400
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
@ -75,7 +81,7 @@ class TooManyEntriesInBatchRequest(RESTError):
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BatchEntryIdsNotDistinct(RESTError):
 | 
			
		||||
class BatchEntryIdsNotDistinct(SNSException):
 | 
			
		||||
    code = 400
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
 | 
			
		||||
@ -51,6 +51,7 @@ class Topic(CloudFormationModel):
 | 
			
		||||
        self.subscriptions_pending = 0
 | 
			
		||||
        self.subscriptions_confimed = 0
 | 
			
		||||
        self.subscriptions_deleted = 0
 | 
			
		||||
        self.sent_notifications = []
 | 
			
		||||
 | 
			
		||||
        self._policy_json = self._create_default_topic_policy(
 | 
			
		||||
            sns_backend.region_name, self.account_id, name
 | 
			
		||||
@ -70,6 +71,9 @@ class Topic(CloudFormationModel):
 | 
			
		||||
                message_attributes=message_attributes,
 | 
			
		||||
                group_id=group_id,
 | 
			
		||||
            )
 | 
			
		||||
        self.sent_notifications.append(
 | 
			
		||||
            (message_id, message, subject, message_attributes, group_id)
 | 
			
		||||
        )
 | 
			
		||||
        return message_id
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
@ -89,7 +93,7 @@ class Topic(CloudFormationModel):
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def policy(self):
 | 
			
		||||
        return json.dumps(self._policy_json)
 | 
			
		||||
        return json.dumps(self._policy_json, separators=(",", ":"))
 | 
			
		||||
 | 
			
		||||
    @policy.setter
 | 
			
		||||
    def policy(self, policy):
 | 
			
		||||
@ -215,7 +219,7 @@ class Subscription(BaseModel):
 | 
			
		||||
                    if value["Type"].startswith("Binary"):
 | 
			
		||||
                        attr_type = "binary_value"
 | 
			
		||||
                    elif value["Type"].startswith("Number"):
 | 
			
		||||
                        type_value = "{0:g}".format(value["Value"])
 | 
			
		||||
                        type_value = str(value["Value"])
 | 
			
		||||
 | 
			
		||||
                    raw_message_attributes[key] = {
 | 
			
		||||
                        "data_type": value["Type"],
 | 
			
		||||
@ -404,6 +408,19 @@ class PlatformEndpoint(BaseModel):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SNSBackend(BaseBackend):
 | 
			
		||||
    """
 | 
			
		||||
    Responsible for mocking calls to SNS. Integration with SQS/HTTP/etc is supported.
 | 
			
		||||
 | 
			
		||||
    Messages published to a topic are persisted in the backend. If you need to verify that a message was published successfully, you can use the internal API to check the message was published successfully:
 | 
			
		||||
 | 
			
		||||
    .. sourcecode:: python
 | 
			
		||||
 | 
			
		||||
        from moto.sns import sns_backend
 | 
			
		||||
        all_send_notifications = sns_backend.topics[topic_arn].sent_notifications
 | 
			
		||||
 | 
			
		||||
    Note that, as this is an internal API, the exact format may differ per versions.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, region_name):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.topics = OrderedDict()
 | 
			
		||||
@ -880,7 +897,7 @@ class SNSBackend(BaseBackend):
 | 
			
		||||
                    message=entry["Message"],
 | 
			
		||||
                    arn=topic_arn,
 | 
			
		||||
                    subject=entry.get("Subject"),
 | 
			
		||||
                    message_attributes=entry.get("MessageAttributes", []),
 | 
			
		||||
                    message_attributes=entry.get("MessageAttributes", {}),
 | 
			
		||||
                    group_id=entry.get("MessageGroupId"),
 | 
			
		||||
                )
 | 
			
		||||
                successful.append({"MessageId": message_id, "Id": entry["Id"]})
 | 
			
		||||
 | 
			
		||||
@ -31,10 +31,13 @@ class SNSResponse(BaseResponse):
 | 
			
		||||
        tags = self._get_list_prefix("Tags.member")
 | 
			
		||||
        return {tag["key"]: tag["value"] for tag in tags}
 | 
			
		||||
 | 
			
		||||
    def _parse_message_attributes(self, prefix="", value_namespace="Value."):
 | 
			
		||||
    def _parse_message_attributes(self):
 | 
			
		||||
        message_attributes = self._get_object_map(
 | 
			
		||||
            "MessageAttributes.entry", name="Name", value="Value"
 | 
			
		||||
        )
 | 
			
		||||
        return self._transform_message_attributes(message_attributes)
 | 
			
		||||
 | 
			
		||||
    def _transform_message_attributes(self, message_attributes):
 | 
			
		||||
        # SNS converts some key names before forwarding messages
 | 
			
		||||
        # DataType -> Type, StringValue -> Value, BinaryValue -> Value
 | 
			
		||||
        transformed_message_attributes = {}
 | 
			
		||||
@ -63,15 +66,18 @@ class SNSResponse(BaseResponse):
 | 
			
		||||
            if "StringValue" in value:
 | 
			
		||||
                if data_type == "Number":
 | 
			
		||||
                    try:
 | 
			
		||||
                        transform_value = float(value["StringValue"])
 | 
			
		||||
                        transform_value = int(value["StringValue"])
 | 
			
		||||
                    except ValueError:
 | 
			
		||||
                        raise InvalidParameterValue(
 | 
			
		||||
                            "An error occurred (ParameterValueInvalid) "
 | 
			
		||||
                            "when calling the Publish operation: "
 | 
			
		||||
                            "Could not cast message attribute '{0}' value to number.".format(
 | 
			
		||||
                                name
 | 
			
		||||
                        try:
 | 
			
		||||
                            transform_value = float(value["StringValue"])
 | 
			
		||||
                        except ValueError:
 | 
			
		||||
                            raise InvalidParameterValue(
 | 
			
		||||
                                "An error occurred (ParameterValueInvalid) "
 | 
			
		||||
                                "when calling the Publish operation: "
 | 
			
		||||
                                "Could not cast message attribute '{0}' value to number.".format(
 | 
			
		||||
                                    name
 | 
			
		||||
                                )
 | 
			
		||||
                            )
 | 
			
		||||
                        )
 | 
			
		||||
                else:
 | 
			
		||||
                    transform_value = value["StringValue"]
 | 
			
		||||
            elif "BinaryValue" in value:
 | 
			
		||||
@ -382,6 +388,16 @@ class SNSResponse(BaseResponse):
 | 
			
		||||
        publish_batch_request_entries = self._get_multi_param(
 | 
			
		||||
            "PublishBatchRequestEntries.member"
 | 
			
		||||
        )
 | 
			
		||||
        for entry in publish_batch_request_entries:
 | 
			
		||||
            if "MessageAttributes" in entry:
 | 
			
		||||
                # Convert into the same format as the regular publish-method
 | 
			
		||||
                # FROM: [{'Name': 'a', 'Value': {'DataType': 'String', 'StringValue': 'v'}}]
 | 
			
		||||
                # TO  : {'name': {'DataType': 'Number', 'StringValue': '123'}}
 | 
			
		||||
                msg_attrs = {y["Name"]: y["Value"] for y in entry["MessageAttributes"]}
 | 
			
		||||
                # Use the same validation/processing as the regular publish-method
 | 
			
		||||
                entry["MessageAttributes"] = self._transform_message_attributes(
 | 
			
		||||
                    msg_attrs
 | 
			
		||||
                )
 | 
			
		||||
        successful, failed = self.backend.publish_batch(
 | 
			
		||||
            topic_arn=topic_arn,
 | 
			
		||||
            publish_batch_request_entries=publish_batch_request_entries,
 | 
			
		||||
 | 
			
		||||
@ -87,6 +87,7 @@ TestAccAWSProvider
 | 
			
		||||
TestAccAWSRedshiftServiceAccount
 | 
			
		||||
TestAccAWSRolePolicyAttachment
 | 
			
		||||
TestAccAWSSNSSMSPreferences
 | 
			
		||||
TestAccAWSSNSTopicPolicy
 | 
			
		||||
TestAccAWSSageMakerPrebuiltECRImage
 | 
			
		||||
TestAccAWSServiceDiscovery
 | 
			
		||||
TestAccAWSSQSQueuePolicy
 | 
			
		||||
 | 
			
		||||
@ -145,10 +145,43 @@ def test_publish_batch_to_sqs():
 | 
			
		||||
    messages.should.contain({"Message": "1"})
 | 
			
		||||
    messages.should.contain({"Message": "2", "Subject": "subj2"})
 | 
			
		||||
    messages.should.contain(
 | 
			
		||||
        {
 | 
			
		||||
            "Message": "3",
 | 
			
		||||
            "MessageAttributes": [
 | 
			
		||||
                {"Name": "a", "Value": {"DataType": "String", "StringValue": "v"}}
 | 
			
		||||
            ],
 | 
			
		||||
        }
 | 
			
		||||
        {"Message": "3", "MessageAttributes": {"a": {"Type": "String", "Value": "v"}}}
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@mock_sqs
 | 
			
		||||
@mock_sns
 | 
			
		||||
def test_publish_batch_to_sqs_raw():
 | 
			
		||||
    client = boto3.client("sns", region_name="us-east-1")
 | 
			
		||||
    topic_arn = client.create_topic(Name="standard_topic")["TopicArn"]
 | 
			
		||||
    sqs = boto3.resource("sqs", region_name="us-east-1")
 | 
			
		||||
    queue = sqs.create_queue(QueueName="test-queue")
 | 
			
		||||
 | 
			
		||||
    queue_url = "arn:aws:sqs:us-east-1:{}:test-queue".format(ACCOUNT_ID)
 | 
			
		||||
    client.subscribe(
 | 
			
		||||
        TopicArn=topic_arn,
 | 
			
		||||
        Protocol="sqs",
 | 
			
		||||
        Endpoint=queue_url,
 | 
			
		||||
        Attributes={"RawMessageDelivery": "true"},
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    entries = [
 | 
			
		||||
        {"Id": "1", "Message": "foo",},
 | 
			
		||||
        {
 | 
			
		||||
            "Id": "2",
 | 
			
		||||
            "Message": "bar",
 | 
			
		||||
            "MessageAttributes": {"a": {"DataType": "String", "StringValue": "v"}},
 | 
			
		||||
        },
 | 
			
		||||
    ]
 | 
			
		||||
    resp = client.publish_batch(TopicArn=topic_arn, PublishBatchRequestEntries=entries,)
 | 
			
		||||
 | 
			
		||||
    resp.should.have.key("Successful").length_of(2)
 | 
			
		||||
 | 
			
		||||
    received = queue.receive_messages(
 | 
			
		||||
        MaxNumberOfMessages=10, MessageAttributeNames=["All"],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    messages = [(message.body, message.message_attributes) for message in received]
 | 
			
		||||
 | 
			
		||||
    messages.should.contain(("foo", None))
 | 
			
		||||
    messages.should.contain(("bar", {"a": {"StringValue": "v", "DataType": "String"}}))
 | 
			
		||||
 | 
			
		||||
@ -238,6 +238,49 @@ def test_publish_to_sqs_msg_attr_number_type():
 | 
			
		||||
    message.body.should.equal("test message")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@mock_sqs
 | 
			
		||||
@mock_sns
 | 
			
		||||
def test_publish_to_sqs_msg_attr_different_formats():
 | 
			
		||||
    """
 | 
			
		||||
    Verify different Number-formats are processed correctly
 | 
			
		||||
    """
 | 
			
		||||
    sns = boto3.resource("sns", region_name="us-east-1")
 | 
			
		||||
    topic = sns.create_topic(Name="test-topic")
 | 
			
		||||
    sqs = boto3.resource("sqs", region_name="us-east-1")
 | 
			
		||||
    sqs_client = boto3.client("sqs", region_name="us-east-1")
 | 
			
		||||
    queue_raw = sqs.create_queue(QueueName="test-queue-raw")
 | 
			
		||||
 | 
			
		||||
    topic.subscribe(
 | 
			
		||||
        Protocol="sqs",
 | 
			
		||||
        Endpoint=queue_raw.attributes["QueueArn"],
 | 
			
		||||
        Attributes={"RawMessageDelivery": "true"},
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    topic.publish(
 | 
			
		||||
        Message="test message",
 | 
			
		||||
        MessageAttributes={
 | 
			
		||||
            "integer": {"DataType": "Number", "StringValue": "123"},
 | 
			
		||||
            "float": {"DataType": "Number", "StringValue": "12.34"},
 | 
			
		||||
            "big-integer": {"DataType": "Number", "StringValue": "123456789"},
 | 
			
		||||
            "big-float": {"DataType": "Number", "StringValue": "123456.789"},
 | 
			
		||||
        },
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    messages_resp = sqs_client.receive_message(
 | 
			
		||||
        QueueUrl=queue_raw.url, MessageAttributeNames=["All"]
 | 
			
		||||
    )
 | 
			
		||||
    message = messages_resp["Messages"][0]
 | 
			
		||||
    message_attributes = message["MessageAttributes"]
 | 
			
		||||
    message_attributes.should.equal(
 | 
			
		||||
        {
 | 
			
		||||
            "integer": {"DataType": "Number", "StringValue": "123"},
 | 
			
		||||
            "float": {"DataType": "Number", "StringValue": "12.34"},
 | 
			
		||||
            "big-integer": {"DataType": "Number", "StringValue": "123456789"},
 | 
			
		||||
            "big-float": {"DataType": "Number", "StringValue": "123456.789"},
 | 
			
		||||
        }
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@mock_sns
 | 
			
		||||
def test_publish_sms():
 | 
			
		||||
    client = boto3.client("sns", region_name="us-east-1")
 | 
			
		||||
@ -374,9 +417,14 @@ def test_publish_to_http():
 | 
			
		||||
        TopicArn=topic_arn, Protocol="http", Endpoint="http://example.com/foobar"
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    response = conn.publish(
 | 
			
		||||
        TopicArn=topic_arn, Message="my message", Subject="my subject"
 | 
			
		||||
    )
 | 
			
		||||
    conn.publish(TopicArn=topic_arn, Message="my message", Subject="my subject")
 | 
			
		||||
 | 
			
		||||
    if not settings.TEST_SERVER_MODE:
 | 
			
		||||
        sns_backend.topics[topic_arn].sent_notifications.should.have.length_of(1)
 | 
			
		||||
        notification = sns_backend.topics[topic_arn].sent_notifications[0]
 | 
			
		||||
        _, msg, subject, _, _ = notification
 | 
			
		||||
        msg.should.equal("my message")
 | 
			
		||||
        subject.should.equal("my subject")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@mock_sqs
 | 
			
		||||
 | 
			
		||||
@ -214,7 +214,7 @@ def test_topic_attributes():
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    attributes = conn.get_topic_attributes(TopicArn=topic_arn)["Attributes"]
 | 
			
		||||
    attributes["Policy"].should.equal('{"foo": "bar"}')
 | 
			
		||||
    attributes["Policy"].should.equal('{"foo":"bar"}')
 | 
			
		||||
    attributes["DisplayName"].should.equal("My display name")
 | 
			
		||||
    attributes["DeliveryPolicy"].should.equal(
 | 
			
		||||
        '{"http": {"defaultHealthyRetryPolicy": {"numRetries": 5}}}'
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user