From f401c60825872fe64f18a2a59d3a4ff89563028f Mon Sep 17 00:00:00 2001 From: wblack Date: Tue, 17 Apr 2018 16:27:48 +0000 Subject: [PATCH 1/2] Include SNS message attributes with message body when delivering to SQS. --- moto/core/responses.py | 48 +++++++++++++++++++++++ moto/sns/models.py | 12 ++++-- moto/sns/responses.py | 51 +++++++++++++++++++++++-- tests/test_sns/test_publishing_boto3.py | 30 ++++++++++++++- 4 files changed, 132 insertions(+), 9 deletions(-) diff --git a/moto/core/responses.py b/moto/core/responses.py index ca5b9f7d2..ed4792083 100644 --- a/moto/core/responses.py +++ b/moto/core/responses.py @@ -492,6 +492,54 @@ class BaseResponse(_TemplateEnvironmentMixin): return results + def _get_object_map(self, prefix, name='Name', value='Value'): + """ + Given a query dict like + { + Prefix.1.Name: [u'event'], + Prefix.1.Value.StringValue: [u'order_cancelled'], + Prefix.1.Value.DataType: [u'String'], + Prefix.2.Name: [u'store'], + Prefix.2.Value.StringValue: [u'example_corp'], + Prefix.2.Value.DataType [u'String'], + } + + returns + { + 'event': { + 'DataType': 'String', + 'StringValue': 'example_corp' + }, + 'store': { + 'DataType': 'String', + 'StringValue': 'order_cancelled' + } + } + """ + object_map = {} + index = 1 + while True: + # Loop through looking for keys representing object name + name_key = '{0}.{1}.{2}'.format(prefix, index, name) + obj_name = self.querystring.get(name_key) + if not obj_name: + # Found all keys + break + + obj = {} + value_key_prefix = '{0}.{1}.{2}.'.format( + prefix, index, value) + for k, v in self.querystring.items(): + if k.startswith(value_key_prefix): + _, value_key = k.split(value_key_prefix, 1) + obj[value_key] = v[0] + + object_map[obj_name[0]] = obj + + index += 1 + + return object_map + @property def request_json(self): return 'JSON' in self.querystring.get('ContentType', []) diff --git a/moto/sns/models.py b/moto/sns/models.py index acfbac550..562e9c106 100644 --- a/moto/sns/models.py +++ b/moto/sns/models.py @@ -93,7 +93,7 @@ class Subscription(BaseModel): if self.protocol == 'sqs': queue_name = self.endpoint.split(":")[-1] region = self.endpoint.split(":")[3] - enveloped_message = json.dumps(self.get_post_data(message, message_id, subject), sort_keys=True, indent=2, separators=(',', ': ')) + enveloped_message = json.dumps(self.get_post_data(message, message_id, subject, message_attributes=message_attributes), sort_keys=True, indent=2, separators=(',', ': ')) sqs_backends[region].send_message(queue_name, enveloped_message) elif self.protocol in ['http', 'https']: post_data = self.get_post_data(message, message_id, subject) @@ -131,15 +131,16 @@ class Subscription(BaseModel): for rule in rules: if isinstance(rule, six.string_types): # only string value matching is supported - if message_attributes[field] == rule: + if message_attributes[field]['Value'] == rule: return True return False return all(_field_match(field, rules, message_attributes) for field, rules in six.iteritems(self._filter_policy)) - def get_post_data(self, message, message_id, subject): - return { + def get_post_data( + self, message, message_id, subject, message_attributes=None): + post_data = { "Type": "Notification", "MessageId": message_id, "TopicArn": self.topic.arn, @@ -151,6 +152,9 @@ class Subscription(BaseModel): "SigningCertURL": "https://sns.us-east-1.amazonaws.com/SimpleNotificationService-f3ecfb7224c7233fe7bb5f59f96de52f.pem", "UnsubscribeURL": "https://sns.us-east-1.amazonaws.com/?Action=Unsubscribe&SubscriptionArn=arn:aws:sns:us-east-1:123456789012:some-topic:2bcfbf39-05c3-41de-beaa-fcfcc21c8f55" } + if message_attributes: + post_data["MessageAttributes"] = message_attributes + return post_data class PlatformApplication(BaseModel): diff --git a/moto/sns/responses.py b/moto/sns/responses.py index 7f23214cf..9c6f64f91 100644 --- a/moto/sns/responses.py +++ b/moto/sns/responses.py @@ -6,7 +6,7 @@ from collections import defaultdict from moto.core.responses import BaseResponse from moto.core.utils import camelcase_to_underscores from .models import sns_backends -from .exceptions import SNSNotFoundError +from .exceptions import SNSNotFoundError, InvalidParameterValue from .utils import is_e164 @@ -30,6 +30,48 @@ class SNSResponse(BaseResponse): in attributes ) + def _parse_message_attributes(self, prefix='', value_namespace='Value.'): + message_attributes = self._get_object_map( + 'MessageAttributes.entry', + name='Name', + value='Value' + ) + # SNS converts some key names before forwarding messages + # DataType -> Type, StringValue -> Value, BinaryValue -> Value + transformed_message_attributes = {} + for name, value in message_attributes.items(): + # validation + data_type = value['DataType'] + if not data_type: + raise InvalidParameterValue( + "The message attribute '{0}' must contain non-empty " + "message attribute value.".format(name)) + + data_type_parts = data_type.split('.') + if (len(data_type_parts) > 2 or + data_type_parts[0] not in ['String', 'Binary', 'Number']): + raise InvalidParameterValue( + "The message attribute '{0}' has an invalid message " + "attribute type, the set of supported type prefixes is " + "Binary, Number, and String.".format(name)) + + if 'StringValue' in value: + value = value['StringValue'] + elif 'BinaryValue' in 'Value': + value = value['BinaryValue'] + if not value: + raise InvalidParameterValue( + "The message attribute '{0}' must contain non-empty " + "message attribute value for message attribute " + "type '{1}'.".format(name, data_type[0])) + + # transformation + transformed_message_attributes[name] = { + 'Type': data_type, 'Value': value + } + + return transformed_message_attributes + def create_topic(self): name = self._get_param('Name') topic = self.backend.create_topic(name) @@ -241,9 +283,10 @@ class SNSResponse(BaseResponse): phone_number = self._get_param('PhoneNumber') subject = self._get_param('Subject') - message_attributes = self._get_map_prefix('MessageAttributes.entry', - key_end='Name', - value_end='Value') + try: + message_attributes = self._parse_message_attributes() + except InvalidParameterValue as e: + return self._error(e.description), dict(status=e.code) if phone_number is not None: # Check phone is correct syntax (e164) diff --git a/tests/test_sns/test_publishing_boto3.py b/tests/test_sns/test_publishing_boto3.py index 52347cc15..9a2403034 100644 --- a/tests/test_sns/test_publishing_boto3.py +++ b/tests/test_sns/test_publishing_boto3.py @@ -239,6 +239,11 @@ def test_filtering_exact_string(): messages = queue.receive_messages(MaxNumberOfMessages=5) message_bodies = [json.loads(m.body)['Message'] for m in messages] message_bodies.should.equal(['match']) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal( + [{'store': {'Type': 'String', 'Value': 'example_corp'}}]) + @mock_sqs @mock_sns @@ -256,6 +261,11 @@ def test_filtering_exact_string_multiple_message_attributes(): messages = queue.receive_messages(MaxNumberOfMessages=5) message_bodies = [json.loads(m.body)['Message'] for m in messages] message_bodies.should.equal(['match']) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal([{ + 'store': {'Type': 'String', 'Value': 'example_corp'}, + 'event': {'Type': 'String', 'Value': 'order_cancelled'}}]) @mock_sqs @mock_sns @@ -275,6 +285,11 @@ def test_filtering_exact_string_OR_matching(): message_bodies = [json.loads(m.body)['Message'] for m in messages] message_bodies.should.equal( ['match example_corp', 'match different_corp']) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal([ + {'store': {'Type': 'String', 'Value': 'example_corp'}}, + {'store': {'Type': 'String', 'Value': 'different_corp'}}]) @mock_sqs @mock_sns @@ -294,6 +309,11 @@ def test_filtering_exact_string_AND_matching_positive(): message_bodies = [json.loads(m.body)['Message'] for m in messages] message_bodies.should.equal( ['match example_corp order_cancelled']) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal([{ + 'store': {'Type': 'String', 'Value': 'example_corp'}, + 'event': {'Type': 'String', 'Value': 'order_cancelled'}}]) @mock_sqs @mock_sns @@ -312,7 +332,9 @@ def test_filtering_exact_string_AND_matching_no_match(): messages = queue.receive_messages(MaxNumberOfMessages=5) message_bodies = [json.loads(m.body)['Message'] for m in messages] message_bodies.should.equal([]) - + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal([]) @mock_sqs @mock_sns @@ -328,6 +350,9 @@ def test_filtering_exact_string_no_match(): messages = queue.receive_messages(MaxNumberOfMessages=5) message_bodies = [json.loads(m.body)['Message'] for m in messages] message_bodies.should.equal([]) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal([]) @mock_sqs @mock_sns @@ -340,3 +365,6 @@ def test_filtering_exact_string_no_attributes_no_match(): messages = queue.receive_messages(MaxNumberOfMessages=5) message_bodies = [json.loads(m.body)['Message'] for m in messages] message_bodies.should.equal([]) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal([]) From 0b36f06df101f0f7a15d2a70851e0992afd0c275 Mon Sep 17 00:00:00 2001 From: wblack Date: Wed, 18 Apr 2018 09:54:15 +0000 Subject: [PATCH 2/2] Fixes for linter warnings --- tests/test_sns/test_publishing_boto3.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/test_sns/test_publishing_boto3.py b/tests/test_sns/test_publishing_boto3.py index 9a2403034..6ea29c986 100644 --- a/tests/test_sns/test_publishing_boto3.py +++ b/tests/test_sns/test_publishing_boto3.py @@ -2,8 +2,6 @@ from __future__ import unicode_literals import json -from six.moves.urllib.parse import parse_qs - import boto3 import re from freezegun import freeze_time @@ -12,7 +10,6 @@ import sure # noqa import responses from botocore.exceptions import ClientError from moto import mock_sns, mock_sqs -from freezegun import freeze_time MESSAGE_FROM_SQS_TEMPLATE = '{\n "Message": "%s",\n "MessageId": "%s",\n "Signature": "EXAMPLElDMXvB8r9R83tGoNn0ecwd5UjllzsvSvbItzfaMpN2nk5HVSw7XnOn/49IkxDKz8YrlH2qJXj2iZB0Zo2O71c4qQk1fMUDi3LGpij7RCW7AW9vYYsSqIKRnFS94ilu7NFhUzLiieYr4BKHpdTmdD6c0esKEYBpabxDSc=",\n "SignatureVersion": "1",\n "SigningCertURL": "https://sns.us-east-1.amazonaws.com/SimpleNotificationService-f3ecfb7224c7233fe7bb5f59f96de52f.pem",\n "Subject": "my subject",\n "Timestamp": "2015-01-01T12:00:00.000Z",\n "TopicArn": "arn:aws:sns:%s:123456789012:some-topic",\n "Type": "Notification",\n "UnsubscribeURL": "https://sns.us-east-1.amazonaws.com/?Action=Unsubscribe&SubscriptionArn=arn:aws:sns:us-east-1:123456789012:some-topic:2bcfbf39-05c3-41de-beaa-fcfcc21c8f55"\n}' @@ -176,7 +173,6 @@ def test_publish_to_http(): response = conn.publish( TopicArn=topic_arn, Message="my message", Subject="my subject") - message_id = response['MessageId'] @mock_sqs