Include SNS message attributes with message body when delivering to SQS.

This commit is contained in:
wblack 2018-04-17 16:27:48 +00:00
parent 783504c897
commit f401c60825
4 changed files with 132 additions and 9 deletions

View File

@ -492,6 +492,54 @@ class BaseResponse(_TemplateEnvironmentMixin):
return results
def _get_object_map(self, prefix, name='Name', value='Value'):
"""
Given a query dict like
{
Prefix.1.Name: [u'event'],
Prefix.1.Value.StringValue: [u'order_cancelled'],
Prefix.1.Value.DataType: [u'String'],
Prefix.2.Name: [u'store'],
Prefix.2.Value.StringValue: [u'example_corp'],
Prefix.2.Value.DataType [u'String'],
}
returns
{
'event': {
'DataType': 'String',
'StringValue': 'example_corp'
},
'store': {
'DataType': 'String',
'StringValue': 'order_cancelled'
}
}
"""
object_map = {}
index = 1
while True:
# Loop through looking for keys representing object name
name_key = '{0}.{1}.{2}'.format(prefix, index, name)
obj_name = self.querystring.get(name_key)
if not obj_name:
# Found all keys
break
obj = {}
value_key_prefix = '{0}.{1}.{2}.'.format(
prefix, index, value)
for k, v in self.querystring.items():
if k.startswith(value_key_prefix):
_, value_key = k.split(value_key_prefix, 1)
obj[value_key] = v[0]
object_map[obj_name[0]] = obj
index += 1
return object_map
@property
def request_json(self):
return 'JSON' in self.querystring.get('ContentType', [])

View File

@ -93,7 +93,7 @@ class Subscription(BaseModel):
if self.protocol == 'sqs':
queue_name = self.endpoint.split(":")[-1]
region = self.endpoint.split(":")[3]
enveloped_message = json.dumps(self.get_post_data(message, message_id, subject), sort_keys=True, indent=2, separators=(',', ': '))
enveloped_message = json.dumps(self.get_post_data(message, message_id, subject, message_attributes=message_attributes), sort_keys=True, indent=2, separators=(',', ': '))
sqs_backends[region].send_message(queue_name, enveloped_message)
elif self.protocol in ['http', 'https']:
post_data = self.get_post_data(message, message_id, subject)
@ -131,15 +131,16 @@ class Subscription(BaseModel):
for rule in rules:
if isinstance(rule, six.string_types):
# only string value matching is supported
if message_attributes[field] == rule:
if message_attributes[field]['Value'] == rule:
return True
return False
return all(_field_match(field, rules, message_attributes)
for field, rules in six.iteritems(self._filter_policy))
def get_post_data(self, message, message_id, subject):
return {
def get_post_data(
self, message, message_id, subject, message_attributes=None):
post_data = {
"Type": "Notification",
"MessageId": message_id,
"TopicArn": self.topic.arn,
@ -151,6 +152,9 @@ class Subscription(BaseModel):
"SigningCertURL": "https://sns.us-east-1.amazonaws.com/SimpleNotificationService-f3ecfb7224c7233fe7bb5f59f96de52f.pem",
"UnsubscribeURL": "https://sns.us-east-1.amazonaws.com/?Action=Unsubscribe&SubscriptionArn=arn:aws:sns:us-east-1:123456789012:some-topic:2bcfbf39-05c3-41de-beaa-fcfcc21c8f55"
}
if message_attributes:
post_data["MessageAttributes"] = message_attributes
return post_data
class PlatformApplication(BaseModel):

View File

@ -6,7 +6,7 @@ from collections import defaultdict
from moto.core.responses import BaseResponse
from moto.core.utils import camelcase_to_underscores
from .models import sns_backends
from .exceptions import SNSNotFoundError
from .exceptions import SNSNotFoundError, InvalidParameterValue
from .utils import is_e164
@ -30,6 +30,48 @@ class SNSResponse(BaseResponse):
in attributes
)
def _parse_message_attributes(self, prefix='', value_namespace='Value.'):
message_attributes = self._get_object_map(
'MessageAttributes.entry',
name='Name',
value='Value'
)
# SNS converts some key names before forwarding messages
# DataType -> Type, StringValue -> Value, BinaryValue -> Value
transformed_message_attributes = {}
for name, value in message_attributes.items():
# validation
data_type = value['DataType']
if not data_type:
raise InvalidParameterValue(
"The message attribute '{0}' must contain non-empty "
"message attribute value.".format(name))
data_type_parts = data_type.split('.')
if (len(data_type_parts) > 2 or
data_type_parts[0] not in ['String', 'Binary', 'Number']):
raise InvalidParameterValue(
"The message attribute '{0}' has an invalid message "
"attribute type, the set of supported type prefixes is "
"Binary, Number, and String.".format(name))
if 'StringValue' in value:
value = value['StringValue']
elif 'BinaryValue' in 'Value':
value = value['BinaryValue']
if not value:
raise InvalidParameterValue(
"The message attribute '{0}' must contain non-empty "
"message attribute value for message attribute "
"type '{1}'.".format(name, data_type[0]))
# transformation
transformed_message_attributes[name] = {
'Type': data_type, 'Value': value
}
return transformed_message_attributes
def create_topic(self):
name = self._get_param('Name')
topic = self.backend.create_topic(name)
@ -241,9 +283,10 @@ class SNSResponse(BaseResponse):
phone_number = self._get_param('PhoneNumber')
subject = self._get_param('Subject')
message_attributes = self._get_map_prefix('MessageAttributes.entry',
key_end='Name',
value_end='Value')
try:
message_attributes = self._parse_message_attributes()
except InvalidParameterValue as e:
return self._error(e.description), dict(status=e.code)
if phone_number is not None:
# Check phone is correct syntax (e164)

View File

@ -239,6 +239,11 @@ def test_filtering_exact_string():
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal(['match'])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal(
[{'store': {'Type': 'String', 'Value': 'example_corp'}}])
@mock_sqs
@mock_sns
@ -256,6 +261,11 @@ def test_filtering_exact_string_multiple_message_attributes():
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal(['match'])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal([{
'store': {'Type': 'String', 'Value': 'example_corp'},
'event': {'Type': 'String', 'Value': 'order_cancelled'}}])
@mock_sqs
@mock_sns
@ -275,6 +285,11 @@ def test_filtering_exact_string_OR_matching():
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal(
['match example_corp', 'match different_corp'])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal([
{'store': {'Type': 'String', 'Value': 'example_corp'}},
{'store': {'Type': 'String', 'Value': 'different_corp'}}])
@mock_sqs
@mock_sns
@ -294,6 +309,11 @@ def test_filtering_exact_string_AND_matching_positive():
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal(
['match example_corp order_cancelled'])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal([{
'store': {'Type': 'String', 'Value': 'example_corp'},
'event': {'Type': 'String', 'Value': 'order_cancelled'}}])
@mock_sqs
@mock_sns
@ -312,7 +332,9 @@ def test_filtering_exact_string_AND_matching_no_match():
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal([])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal([])
@mock_sqs
@mock_sns
@ -328,6 +350,9 @@ def test_filtering_exact_string_no_match():
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal([])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal([])
@mock_sqs
@mock_sns
@ -340,3 +365,6 @@ def test_filtering_exact_string_no_attributes_no_match():
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal([])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal([])