SNS Improvements (#4881)
This commit is contained in:
parent
412153aeb7
commit
c53dc4c21c
@ -12,6 +12,8 @@
|
|||||||
sns
|
sns
|
||||||
===
|
===
|
||||||
|
|
||||||
|
.. autoclass:: moto.sns.models.SNSBackend
|
||||||
|
|
||||||
|start-h3| Example usage |end-h3|
|
|start-h3| Example usage |end-h3|
|
||||||
|
|
||||||
.. sourcecode:: python
|
.. sourcecode:: python
|
||||||
|
@ -1,7 +1,13 @@
|
|||||||
from moto.core.exceptions import RESTError
|
from moto.core.exceptions import RESTError
|
||||||
|
|
||||||
|
|
||||||
class SNSNotFoundError(RESTError):
|
class SNSException(RESTError):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
kwargs["template"] = "wrapped_single_error"
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class SNSNotFoundError(SNSException):
|
||||||
code = 404
|
code = 404
|
||||||
|
|
||||||
def __init__(self, message, **kwargs):
|
def __init__(self, message, **kwargs):
|
||||||
@ -13,42 +19,42 @@ class TopicNotFound(SNSNotFoundError):
|
|||||||
super().__init__(message="Topic does not exist")
|
super().__init__(message="Topic does not exist")
|
||||||
|
|
||||||
|
|
||||||
class ResourceNotFoundError(RESTError):
|
class ResourceNotFoundError(SNSException):
|
||||||
code = 404
|
code = 404
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__("ResourceNotFound", "Resource does not exist")
|
super().__init__("ResourceNotFound", "Resource does not exist")
|
||||||
|
|
||||||
|
|
||||||
class DuplicateSnsEndpointError(RESTError):
|
class DuplicateSnsEndpointError(SNSException):
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
def __init__(self, message):
|
def __init__(self, message):
|
||||||
super().__init__("DuplicateEndpoint", message)
|
super().__init__("DuplicateEndpoint", message)
|
||||||
|
|
||||||
|
|
||||||
class SnsEndpointDisabled(RESTError):
|
class SnsEndpointDisabled(SNSException):
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
def __init__(self, message):
|
def __init__(self, message):
|
||||||
super().__init__("EndpointDisabled", message)
|
super().__init__("EndpointDisabled", message)
|
||||||
|
|
||||||
|
|
||||||
class SNSInvalidParameter(RESTError):
|
class SNSInvalidParameter(SNSException):
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
def __init__(self, message):
|
def __init__(self, message):
|
||||||
super().__init__("InvalidParameter", message)
|
super().__init__("InvalidParameter", message)
|
||||||
|
|
||||||
|
|
||||||
class InvalidParameterValue(RESTError):
|
class InvalidParameterValue(SNSException):
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
def __init__(self, message):
|
def __init__(self, message):
|
||||||
super().__init__("InvalidParameterValue", message)
|
super().__init__("InvalidParameterValue", message)
|
||||||
|
|
||||||
|
|
||||||
class TagLimitExceededError(RESTError):
|
class TagLimitExceededError(SNSException):
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -58,14 +64,14 @@ class TagLimitExceededError(RESTError):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class InternalError(RESTError):
|
class InternalError(SNSException):
|
||||||
code = 500
|
code = 500
|
||||||
|
|
||||||
def __init__(self, message):
|
def __init__(self, message):
|
||||||
super().__init__("InternalFailure", message)
|
super().__init__("InternalFailure", message)
|
||||||
|
|
||||||
|
|
||||||
class TooManyEntriesInBatchRequest(RESTError):
|
class TooManyEntriesInBatchRequest(SNSException):
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -75,7 +81,7 @@ class TooManyEntriesInBatchRequest(RESTError):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class BatchEntryIdsNotDistinct(RESTError):
|
class BatchEntryIdsNotDistinct(SNSException):
|
||||||
code = 400
|
code = 400
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -51,6 +51,7 @@ class Topic(CloudFormationModel):
|
|||||||
self.subscriptions_pending = 0
|
self.subscriptions_pending = 0
|
||||||
self.subscriptions_confimed = 0
|
self.subscriptions_confimed = 0
|
||||||
self.subscriptions_deleted = 0
|
self.subscriptions_deleted = 0
|
||||||
|
self.sent_notifications = []
|
||||||
|
|
||||||
self._policy_json = self._create_default_topic_policy(
|
self._policy_json = self._create_default_topic_policy(
|
||||||
sns_backend.region_name, self.account_id, name
|
sns_backend.region_name, self.account_id, name
|
||||||
@ -70,6 +71,9 @@ class Topic(CloudFormationModel):
|
|||||||
message_attributes=message_attributes,
|
message_attributes=message_attributes,
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
)
|
)
|
||||||
|
self.sent_notifications.append(
|
||||||
|
(message_id, message, subject, message_attributes, group_id)
|
||||||
|
)
|
||||||
return message_id
|
return message_id
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -89,7 +93,7 @@ class Topic(CloudFormationModel):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def policy(self):
|
def policy(self):
|
||||||
return json.dumps(self._policy_json)
|
return json.dumps(self._policy_json, separators=(",", ":"))
|
||||||
|
|
||||||
@policy.setter
|
@policy.setter
|
||||||
def policy(self, policy):
|
def policy(self, policy):
|
||||||
@ -215,7 +219,7 @@ class Subscription(BaseModel):
|
|||||||
if value["Type"].startswith("Binary"):
|
if value["Type"].startswith("Binary"):
|
||||||
attr_type = "binary_value"
|
attr_type = "binary_value"
|
||||||
elif value["Type"].startswith("Number"):
|
elif value["Type"].startswith("Number"):
|
||||||
type_value = "{0:g}".format(value["Value"])
|
type_value = str(value["Value"])
|
||||||
|
|
||||||
raw_message_attributes[key] = {
|
raw_message_attributes[key] = {
|
||||||
"data_type": value["Type"],
|
"data_type": value["Type"],
|
||||||
@ -404,6 +408,19 @@ class PlatformEndpoint(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class SNSBackend(BaseBackend):
|
class SNSBackend(BaseBackend):
|
||||||
|
"""
|
||||||
|
Responsible for mocking calls to SNS. Integration with SQS/HTTP/etc is supported.
|
||||||
|
|
||||||
|
Messages published to a topic are persisted in the backend. If you need to verify that a message was published successfully, you can use the internal API to check the message was published successfully:
|
||||||
|
|
||||||
|
.. sourcecode:: python
|
||||||
|
|
||||||
|
from moto.sns import sns_backend
|
||||||
|
all_send_notifications = sns_backend.topics[topic_arn].sent_notifications
|
||||||
|
|
||||||
|
Note that, as this is an internal API, the exact format may differ per versions.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, region_name):
|
def __init__(self, region_name):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.topics = OrderedDict()
|
self.topics = OrderedDict()
|
||||||
@ -880,7 +897,7 @@ class SNSBackend(BaseBackend):
|
|||||||
message=entry["Message"],
|
message=entry["Message"],
|
||||||
arn=topic_arn,
|
arn=topic_arn,
|
||||||
subject=entry.get("Subject"),
|
subject=entry.get("Subject"),
|
||||||
message_attributes=entry.get("MessageAttributes", []),
|
message_attributes=entry.get("MessageAttributes", {}),
|
||||||
group_id=entry.get("MessageGroupId"),
|
group_id=entry.get("MessageGroupId"),
|
||||||
)
|
)
|
||||||
successful.append({"MessageId": message_id, "Id": entry["Id"]})
|
successful.append({"MessageId": message_id, "Id": entry["Id"]})
|
||||||
|
@ -31,10 +31,13 @@ class SNSResponse(BaseResponse):
|
|||||||
tags = self._get_list_prefix("Tags.member")
|
tags = self._get_list_prefix("Tags.member")
|
||||||
return {tag["key"]: tag["value"] for tag in tags}
|
return {tag["key"]: tag["value"] for tag in tags}
|
||||||
|
|
||||||
def _parse_message_attributes(self, prefix="", value_namespace="Value."):
|
def _parse_message_attributes(self):
|
||||||
message_attributes = self._get_object_map(
|
message_attributes = self._get_object_map(
|
||||||
"MessageAttributes.entry", name="Name", value="Value"
|
"MessageAttributes.entry", name="Name", value="Value"
|
||||||
)
|
)
|
||||||
|
return self._transform_message_attributes(message_attributes)
|
||||||
|
|
||||||
|
def _transform_message_attributes(self, message_attributes):
|
||||||
# SNS converts some key names before forwarding messages
|
# SNS converts some key names before forwarding messages
|
||||||
# DataType -> Type, StringValue -> Value, BinaryValue -> Value
|
# DataType -> Type, StringValue -> Value, BinaryValue -> Value
|
||||||
transformed_message_attributes = {}
|
transformed_message_attributes = {}
|
||||||
@ -63,15 +66,18 @@ class SNSResponse(BaseResponse):
|
|||||||
if "StringValue" in value:
|
if "StringValue" in value:
|
||||||
if data_type == "Number":
|
if data_type == "Number":
|
||||||
try:
|
try:
|
||||||
transform_value = float(value["StringValue"])
|
transform_value = int(value["StringValue"])
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise InvalidParameterValue(
|
try:
|
||||||
"An error occurred (ParameterValueInvalid) "
|
transform_value = float(value["StringValue"])
|
||||||
"when calling the Publish operation: "
|
except ValueError:
|
||||||
"Could not cast message attribute '{0}' value to number.".format(
|
raise InvalidParameterValue(
|
||||||
name
|
"An error occurred (ParameterValueInvalid) "
|
||||||
|
"when calling the Publish operation: "
|
||||||
|
"Could not cast message attribute '{0}' value to number.".format(
|
||||||
|
name
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
transform_value = value["StringValue"]
|
transform_value = value["StringValue"]
|
||||||
elif "BinaryValue" in value:
|
elif "BinaryValue" in value:
|
||||||
@ -382,6 +388,16 @@ class SNSResponse(BaseResponse):
|
|||||||
publish_batch_request_entries = self._get_multi_param(
|
publish_batch_request_entries = self._get_multi_param(
|
||||||
"PublishBatchRequestEntries.member"
|
"PublishBatchRequestEntries.member"
|
||||||
)
|
)
|
||||||
|
for entry in publish_batch_request_entries:
|
||||||
|
if "MessageAttributes" in entry:
|
||||||
|
# Convert into the same format as the regular publish-method
|
||||||
|
# FROM: [{'Name': 'a', 'Value': {'DataType': 'String', 'StringValue': 'v'}}]
|
||||||
|
# TO : {'name': {'DataType': 'Number', 'StringValue': '123'}}
|
||||||
|
msg_attrs = {y["Name"]: y["Value"] for y in entry["MessageAttributes"]}
|
||||||
|
# Use the same validation/processing as the regular publish-method
|
||||||
|
entry["MessageAttributes"] = self._transform_message_attributes(
|
||||||
|
msg_attrs
|
||||||
|
)
|
||||||
successful, failed = self.backend.publish_batch(
|
successful, failed = self.backend.publish_batch(
|
||||||
topic_arn=topic_arn,
|
topic_arn=topic_arn,
|
||||||
publish_batch_request_entries=publish_batch_request_entries,
|
publish_batch_request_entries=publish_batch_request_entries,
|
||||||
|
@ -87,6 +87,7 @@ TestAccAWSProvider
|
|||||||
TestAccAWSRedshiftServiceAccount
|
TestAccAWSRedshiftServiceAccount
|
||||||
TestAccAWSRolePolicyAttachment
|
TestAccAWSRolePolicyAttachment
|
||||||
TestAccAWSSNSSMSPreferences
|
TestAccAWSSNSSMSPreferences
|
||||||
|
TestAccAWSSNSTopicPolicy
|
||||||
TestAccAWSSageMakerPrebuiltECRImage
|
TestAccAWSSageMakerPrebuiltECRImage
|
||||||
TestAccAWSServiceDiscovery
|
TestAccAWSServiceDiscovery
|
||||||
TestAccAWSSQSQueuePolicy
|
TestAccAWSSQSQueuePolicy
|
||||||
|
@ -145,10 +145,43 @@ def test_publish_batch_to_sqs():
|
|||||||
messages.should.contain({"Message": "1"})
|
messages.should.contain({"Message": "1"})
|
||||||
messages.should.contain({"Message": "2", "Subject": "subj2"})
|
messages.should.contain({"Message": "2", "Subject": "subj2"})
|
||||||
messages.should.contain(
|
messages.should.contain(
|
||||||
{
|
{"Message": "3", "MessageAttributes": {"a": {"Type": "String", "Value": "v"}}}
|
||||||
"Message": "3",
|
|
||||||
"MessageAttributes": [
|
|
||||||
{"Name": "a", "Value": {"DataType": "String", "StringValue": "v"}}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sqs
|
||||||
|
@mock_sns
|
||||||
|
def test_publish_batch_to_sqs_raw():
|
||||||
|
client = boto3.client("sns", region_name="us-east-1")
|
||||||
|
topic_arn = client.create_topic(Name="standard_topic")["TopicArn"]
|
||||||
|
sqs = boto3.resource("sqs", region_name="us-east-1")
|
||||||
|
queue = sqs.create_queue(QueueName="test-queue")
|
||||||
|
|
||||||
|
queue_url = "arn:aws:sqs:us-east-1:{}:test-queue".format(ACCOUNT_ID)
|
||||||
|
client.subscribe(
|
||||||
|
TopicArn=topic_arn,
|
||||||
|
Protocol="sqs",
|
||||||
|
Endpoint=queue_url,
|
||||||
|
Attributes={"RawMessageDelivery": "true"},
|
||||||
|
)
|
||||||
|
|
||||||
|
entries = [
|
||||||
|
{"Id": "1", "Message": "foo",},
|
||||||
|
{
|
||||||
|
"Id": "2",
|
||||||
|
"Message": "bar",
|
||||||
|
"MessageAttributes": {"a": {"DataType": "String", "StringValue": "v"}},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
resp = client.publish_batch(TopicArn=topic_arn, PublishBatchRequestEntries=entries,)
|
||||||
|
|
||||||
|
resp.should.have.key("Successful").length_of(2)
|
||||||
|
|
||||||
|
received = queue.receive_messages(
|
||||||
|
MaxNumberOfMessages=10, MessageAttributeNames=["All"],
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = [(message.body, message.message_attributes) for message in received]
|
||||||
|
|
||||||
|
messages.should.contain(("foo", None))
|
||||||
|
messages.should.contain(("bar", {"a": {"StringValue": "v", "DataType": "String"}}))
|
||||||
|
@ -238,6 +238,49 @@ def test_publish_to_sqs_msg_attr_number_type():
|
|||||||
message.body.should.equal("test message")
|
message.body.should.equal("test message")
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sqs
|
||||||
|
@mock_sns
|
||||||
|
def test_publish_to_sqs_msg_attr_different_formats():
|
||||||
|
"""
|
||||||
|
Verify different Number-formats are processed correctly
|
||||||
|
"""
|
||||||
|
sns = boto3.resource("sns", region_name="us-east-1")
|
||||||
|
topic = sns.create_topic(Name="test-topic")
|
||||||
|
sqs = boto3.resource("sqs", region_name="us-east-1")
|
||||||
|
sqs_client = boto3.client("sqs", region_name="us-east-1")
|
||||||
|
queue_raw = sqs.create_queue(QueueName="test-queue-raw")
|
||||||
|
|
||||||
|
topic.subscribe(
|
||||||
|
Protocol="sqs",
|
||||||
|
Endpoint=queue_raw.attributes["QueueArn"],
|
||||||
|
Attributes={"RawMessageDelivery": "true"},
|
||||||
|
)
|
||||||
|
|
||||||
|
topic.publish(
|
||||||
|
Message="test message",
|
||||||
|
MessageAttributes={
|
||||||
|
"integer": {"DataType": "Number", "StringValue": "123"},
|
||||||
|
"float": {"DataType": "Number", "StringValue": "12.34"},
|
||||||
|
"big-integer": {"DataType": "Number", "StringValue": "123456789"},
|
||||||
|
"big-float": {"DataType": "Number", "StringValue": "123456.789"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
messages_resp = sqs_client.receive_message(
|
||||||
|
QueueUrl=queue_raw.url, MessageAttributeNames=["All"]
|
||||||
|
)
|
||||||
|
message = messages_resp["Messages"][0]
|
||||||
|
message_attributes = message["MessageAttributes"]
|
||||||
|
message_attributes.should.equal(
|
||||||
|
{
|
||||||
|
"integer": {"DataType": "Number", "StringValue": "123"},
|
||||||
|
"float": {"DataType": "Number", "StringValue": "12.34"},
|
||||||
|
"big-integer": {"DataType": "Number", "StringValue": "123456789"},
|
||||||
|
"big-float": {"DataType": "Number", "StringValue": "123456.789"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@mock_sns
|
@mock_sns
|
||||||
def test_publish_sms():
|
def test_publish_sms():
|
||||||
client = boto3.client("sns", region_name="us-east-1")
|
client = boto3.client("sns", region_name="us-east-1")
|
||||||
@ -374,9 +417,14 @@ def test_publish_to_http():
|
|||||||
TopicArn=topic_arn, Protocol="http", Endpoint="http://example.com/foobar"
|
TopicArn=topic_arn, Protocol="http", Endpoint="http://example.com/foobar"
|
||||||
)
|
)
|
||||||
|
|
||||||
response = conn.publish(
|
conn.publish(TopicArn=topic_arn, Message="my message", Subject="my subject")
|
||||||
TopicArn=topic_arn, Message="my message", Subject="my subject"
|
|
||||||
)
|
if not settings.TEST_SERVER_MODE:
|
||||||
|
sns_backend.topics[topic_arn].sent_notifications.should.have.length_of(1)
|
||||||
|
notification = sns_backend.topics[topic_arn].sent_notifications[0]
|
||||||
|
_, msg, subject, _, _ = notification
|
||||||
|
msg.should.equal("my message")
|
||||||
|
subject.should.equal("my subject")
|
||||||
|
|
||||||
|
|
||||||
@mock_sqs
|
@mock_sqs
|
||||||
|
@ -214,7 +214,7 @@ def test_topic_attributes():
|
|||||||
)
|
)
|
||||||
|
|
||||||
attributes = conn.get_topic_attributes(TopicArn=topic_arn)["Attributes"]
|
attributes = conn.get_topic_attributes(TopicArn=topic_arn)["Attributes"]
|
||||||
attributes["Policy"].should.equal('{"foo": "bar"}')
|
attributes["Policy"].should.equal('{"foo":"bar"}')
|
||||||
attributes["DisplayName"].should.equal("My display name")
|
attributes["DisplayName"].should.equal("My display name")
|
||||||
attributes["DeliveryPolicy"].should.equal(
|
attributes["DeliveryPolicy"].should.equal(
|
||||||
'{"http": {"defaultHealthyRetryPolicy": {"numRetries": 5}}}'
|
'{"http": {"defaultHealthyRetryPolicy": {"numRetries": 5}}}'
|
||||||
|
Loading…
x
Reference in New Issue
Block a user