diff --git a/moto/sns/models.py b/moto/sns/models.py index 70587d980..9afc28f46 100644 --- a/moto/sns/models.py +++ b/moto/sns/models.py @@ -42,11 +42,12 @@ class Topic(BaseModel): self.subscriptions_confimed = 0 self.subscriptions_deleted = 0 - def publish(self, message, subject=None): + def publish(self, message, subject=None, message_attributes=None): message_id = six.text_type(uuid.uuid4()) subscriptions, _ = self.sns_backend.list_subscriptions(self.arn) for subscription in subscriptions: - subscription.publish(message, message_id, subject=subject) + subscription.publish(message, message_id, subject=subject, + message_attributes=message_attributes) return message_id def get_cfn_attribute(self, attribute_name): @@ -81,9 +82,14 @@ class Subscription(BaseModel): self.protocol = protocol self.arn = make_arn_for_subscription(self.topic.arn) self.attributes = {} + self._filter_policy = None # filter policy as a dict, not json. self.confirmed = False - def publish(self, message, message_id, subject=None): + def publish(self, message, message_id, subject=None, + message_attributes=None): + if not self._matches_filter_policy(message_attributes): + return + if self.protocol == 'sqs': queue_name = self.endpoint.split(":")[-1] region = self.endpoint.split(":")[3] @@ -98,6 +104,28 @@ class Subscription(BaseModel): region = self.arn.split(':')[3] lambda_backends[region].send_message(function_name, message, subject=subject) + def _matches_filter_policy(self, message_attributes): + # TODO: support Anything-but matching, prefix matching and + # numeric value matching. + if not self._filter_policy: + return True + + if message_attributes is None: + message_attributes = {} + + def _field_match(field, rules, message_attributes): + if field not in message_attributes: + return False + for rule in rules: + if isinstance(rule, six.string_types): + # only string value matching is supported + if message_attributes[field] == rule: + return True + return False + + return all(_field_match(field, rules, message_attributes) + for field, rules in six.iteritems(self._filter_policy)) + def get_post_data(self, message, message_id, subject): return { "Type": "Notification", @@ -274,13 +302,14 @@ class SNSBackend(BaseBackend): else: return self._get_values_nexttoken(self.subscriptions, next_token) - def publish(self, arn, message, subject=None): + def publish(self, arn, message, subject=None, message_attributes=None): if subject is not None and len(subject) >= 100: raise ValueError('Subject must be less than 100 characters') try: topic = self.get_topic(arn) - message_id = topic.publish(message, subject=subject) + message_id = topic.publish(message, subject=subject, + message_attributes=message_attributes) except SNSNotFoundError: endpoint = self.get_endpoint(arn) message_id = endpoint.publish(message) @@ -352,7 +381,7 @@ class SNSBackend(BaseBackend): return subscription.attributes def set_subscription_attributes(self, arn, name, value): - if name not in ['RawMessageDelivery', 'DeliveryPolicy']: + if name not in ['RawMessageDelivery', 'DeliveryPolicy', 'FilterPolicy']: raise SNSInvalidParameter('AttributeName') # TODO: should do validation @@ -363,6 +392,9 @@ class SNSBackend(BaseBackend): subscription.attributes[name] = value + if name == 'FilterPolicy': + subscription._filter_policy = json.loads(value) + sns_backends = {} for region in boto.sns.regions(): diff --git a/moto/sns/responses.py b/moto/sns/responses.py index 3b4aade80..7f23214cf 100644 --- a/moto/sns/responses.py +++ b/moto/sns/responses.py @@ -241,6 +241,10 @@ class SNSResponse(BaseResponse): phone_number = self._get_param('PhoneNumber') subject = self._get_param('Subject') + message_attributes = self._get_map_prefix('MessageAttributes.entry', + key_end='Name', + value_end='Value') + if phone_number is not None: # Check phone is correct syntax (e164) if not is_e164(phone_number): @@ -265,7 +269,9 @@ class SNSResponse(BaseResponse): message = self._get_param('Message') try: - message_id = self.backend.publish(arn, message, subject=subject) + message_id = self.backend.publish( + arn, message, subject=subject, + message_attributes=message_attributes) except ValueError as err: error_response = self._error('InvalidParameter', str(err)) return error_response, dict(status=400) diff --git a/moto/sqs/responses.py b/moto/sqs/responses.py index 71aab9a58..c475f0ce0 100644 --- a/moto/sqs/responses.py +++ b/moto/sqs/responses.py @@ -30,7 +30,7 @@ class SQSResponse(BaseResponse): @property def attribute(self): if not hasattr(self, '_attribute'): - self._attribute = self._get_map_prefix('Attribute', key_end='Name', value_end='Value') + self._attribute = self._get_map_prefix('Attribute', key_end='.Name', value_end='.Value') return self._attribute def _get_queue_name(self): diff --git a/tests/test_sns/test_publishing_boto3.py b/tests/test_sns/test_publishing_boto3.py index 1540ceb84..3ccc3ef44 100644 --- a/tests/test_sns/test_publishing_boto3.py +++ b/tests/test_sns/test_publishing_boto3.py @@ -207,3 +207,136 @@ def test_publish_subject(): err.response['Error']['Code'].should.equal('InvalidParameter') else: raise RuntimeError('Should have raised an InvalidParameter exception') + + +def _setup_filter_policy_test(filter_policy): + sns = boto3.resource('sns', region_name='us-east-1') + topic = sns.create_topic(Name='some-topic') + + sqs = boto3.resource('sqs', region_name='us-east-1') + queue = sqs.create_queue(QueueName='test-queue') + + subscription = topic.subscribe( + Protocol='sqs', Endpoint=queue.attributes['QueueArn']) + + subscription.set_attributes( + AttributeName='FilterPolicy', AttributeValue=json.dumps(filter_policy)) + + return topic, subscription, queue + + +@mock_sqs +@mock_sns +def test_filtering_exact_string(): + topic, subscription, queue = _setup_filter_policy_test( + {'store': ['example_corp']}) + + topic.publish( + Message='match', + MessageAttributes={'store': {'DataType': 'String', + 'StringValue': 'example_corp'}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal(['match']) + +@mock_sqs +@mock_sns +def test_filtering_exact_string_multiple_message_attributes(): + topic, subscription, queue = _setup_filter_policy_test( + {'store': ['example_corp']}) + + topic.publish( + Message='match', + MessageAttributes={'store': {'DataType': 'String', + 'StringValue': 'example_corp'}, + 'event': {'DataType': 'String', + 'StringValue': 'order_cancelled'}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal(['match']) + +@mock_sqs +@mock_sns +def test_filtering_exact_string_OR_matching(): + topic, subscription, queue = _setup_filter_policy_test( + {'store': ['example_corp', 'different_corp']}) + + topic.publish( + Message='match example_corp', + MessageAttributes={'store': {'DataType': 'String', + 'StringValue': 'example_corp'}}) + topic.publish( + Message='match different_corp', + MessageAttributes={'store': {'DataType': 'String', + 'StringValue': 'different_corp'}}) + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal( + ['match example_corp', 'match different_corp']) + +@mock_sqs +@mock_sns +def test_filtering_exact_string_AND_matching_positive(): + topic, subscription, queue = _setup_filter_policy_test( + {'store': ['example_corp'], + 'event': ['order_cancelled']}) + + topic.publish( + Message='match example_corp order_cancelled', + MessageAttributes={'store': {'DataType': 'String', + 'StringValue': 'example_corp'}, + 'event': {'DataType': 'String', + 'StringValue': 'order_cancelled'}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal( + ['match example_corp order_cancelled']) + +@mock_sqs +@mock_sns +def test_filtering_exact_string_AND_matching_no_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'store': ['example_corp'], + 'event': ['order_cancelled']}) + + topic.publish( + Message='match example_corp order_accepted', + MessageAttributes={'store': {'DataType': 'String', + 'StringValue': 'example_corp'}, + 'event': {'DataType': 'String', + 'StringValue': 'order_accepted'}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal([]) + + +@mock_sqs +@mock_sns +def test_filtering_exact_string_no_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'store': ['example_corp']}) + + topic.publish( + Message='no match', + MessageAttributes={'store': {'DataType': 'String', + 'StringValue': 'different_corp'}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal([]) + +@mock_sqs +@mock_sns +def test_filtering_exact_string_no_attributes_no_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'store': ['example_corp']}) + + topic.publish(Message='no match') + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal([]) diff --git a/tests/test_sns/test_subscriptions_boto3.py b/tests/test_sns/test_subscriptions_boto3.py index 59cef221f..98075e617 100644 --- a/tests/test_sns/test_subscriptions_boto3.py +++ b/tests/test_sns/test_subscriptions_boto3.py @@ -223,11 +223,26 @@ def test_set_subscription_attributes(): AttributeName='DeliveryPolicy', AttributeValue=delivery_policy ) + + filter_policy = json.dumps({ + "store": ["example_corp"], + "event": ["order_cancelled"], + "encrypted": [False], + "customer_interests": ["basketball", "baseball"] + }) + conn.set_subscription_attributes( + SubscriptionArn=subscription_arn, + AttributeName='FilterPolicy', + AttributeValue=filter_policy + ) + attrs = conn.get_subscription_attributes( SubscriptionArn=subscription_arn ) + attrs['Attributes']['RawMessageDelivery'].should.equal('true') attrs['Attributes']['DeliveryPolicy'].should.equal(delivery_policy) + attrs['Attributes']['FilterPolicy'].should.equal(filter_policy) # not existing subscription with assert_raises(ClientError):