Implement basic SNS message filtering (#1521)
* Add support for FilterPolicy to sns subscription set_filter_attributes * Add basic support for sns message filtering This adds support for exact string value matching along with AND/OR logic as described here: https://docs.aws.amazon.com/sns/latest/dg/message-filtering.html It does not provide support for: - Anything-but string matching - Prefix string matching - Numeric Value Matching The above filter policies (if configured) will not match messages.
This commit is contained in:
parent
6dce7dcb18
commit
d3d9557d49
@ -42,11 +42,12 @@ class Topic(BaseModel):
|
|||||||
self.subscriptions_confimed = 0
|
self.subscriptions_confimed = 0
|
||||||
self.subscriptions_deleted = 0
|
self.subscriptions_deleted = 0
|
||||||
|
|
||||||
def publish(self, message, subject=None):
|
def publish(self, message, subject=None, message_attributes=None):
|
||||||
message_id = six.text_type(uuid.uuid4())
|
message_id = six.text_type(uuid.uuid4())
|
||||||
subscriptions, _ = self.sns_backend.list_subscriptions(self.arn)
|
subscriptions, _ = self.sns_backend.list_subscriptions(self.arn)
|
||||||
for subscription in subscriptions:
|
for subscription in subscriptions:
|
||||||
subscription.publish(message, message_id, subject=subject)
|
subscription.publish(message, message_id, subject=subject,
|
||||||
|
message_attributes=message_attributes)
|
||||||
return message_id
|
return message_id
|
||||||
|
|
||||||
def get_cfn_attribute(self, attribute_name):
|
def get_cfn_attribute(self, attribute_name):
|
||||||
@ -81,9 +82,14 @@ class Subscription(BaseModel):
|
|||||||
self.protocol = protocol
|
self.protocol = protocol
|
||||||
self.arn = make_arn_for_subscription(self.topic.arn)
|
self.arn = make_arn_for_subscription(self.topic.arn)
|
||||||
self.attributes = {}
|
self.attributes = {}
|
||||||
|
self._filter_policy = None # filter policy as a dict, not json.
|
||||||
self.confirmed = False
|
self.confirmed = False
|
||||||
|
|
||||||
def publish(self, message, message_id, subject=None):
|
def publish(self, message, message_id, subject=None,
|
||||||
|
message_attributes=None):
|
||||||
|
if not self._matches_filter_policy(message_attributes):
|
||||||
|
return
|
||||||
|
|
||||||
if self.protocol == 'sqs':
|
if self.protocol == 'sqs':
|
||||||
queue_name = self.endpoint.split(":")[-1]
|
queue_name = self.endpoint.split(":")[-1]
|
||||||
region = self.endpoint.split(":")[3]
|
region = self.endpoint.split(":")[3]
|
||||||
@ -98,6 +104,28 @@ class Subscription(BaseModel):
|
|||||||
region = self.arn.split(':')[3]
|
region = self.arn.split(':')[3]
|
||||||
lambda_backends[region].send_message(function_name, message, subject=subject)
|
lambda_backends[region].send_message(function_name, message, subject=subject)
|
||||||
|
|
||||||
|
def _matches_filter_policy(self, message_attributes):
|
||||||
|
# TODO: support Anything-but matching, prefix matching and
|
||||||
|
# numeric value matching.
|
||||||
|
if not self._filter_policy:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if message_attributes is None:
|
||||||
|
message_attributes = {}
|
||||||
|
|
||||||
|
def _field_match(field, rules, message_attributes):
|
||||||
|
if field not in message_attributes:
|
||||||
|
return False
|
||||||
|
for rule in rules:
|
||||||
|
if isinstance(rule, six.string_types):
|
||||||
|
# only string value matching is supported
|
||||||
|
if message_attributes[field] == rule:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
return all(_field_match(field, rules, message_attributes)
|
||||||
|
for field, rules in six.iteritems(self._filter_policy))
|
||||||
|
|
||||||
def get_post_data(self, message, message_id, subject):
|
def get_post_data(self, message, message_id, subject):
|
||||||
return {
|
return {
|
||||||
"Type": "Notification",
|
"Type": "Notification",
|
||||||
@ -274,13 +302,14 @@ class SNSBackend(BaseBackend):
|
|||||||
else:
|
else:
|
||||||
return self._get_values_nexttoken(self.subscriptions, next_token)
|
return self._get_values_nexttoken(self.subscriptions, next_token)
|
||||||
|
|
||||||
def publish(self, arn, message, subject=None):
|
def publish(self, arn, message, subject=None, message_attributes=None):
|
||||||
if subject is not None and len(subject) >= 100:
|
if subject is not None and len(subject) >= 100:
|
||||||
raise ValueError('Subject must be less than 100 characters')
|
raise ValueError('Subject must be less than 100 characters')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
topic = self.get_topic(arn)
|
topic = self.get_topic(arn)
|
||||||
message_id = topic.publish(message, subject=subject)
|
message_id = topic.publish(message, subject=subject,
|
||||||
|
message_attributes=message_attributes)
|
||||||
except SNSNotFoundError:
|
except SNSNotFoundError:
|
||||||
endpoint = self.get_endpoint(arn)
|
endpoint = self.get_endpoint(arn)
|
||||||
message_id = endpoint.publish(message)
|
message_id = endpoint.publish(message)
|
||||||
@ -352,7 +381,7 @@ class SNSBackend(BaseBackend):
|
|||||||
return subscription.attributes
|
return subscription.attributes
|
||||||
|
|
||||||
def set_subscription_attributes(self, arn, name, value):
|
def set_subscription_attributes(self, arn, name, value):
|
||||||
if name not in ['RawMessageDelivery', 'DeliveryPolicy']:
|
if name not in ['RawMessageDelivery', 'DeliveryPolicy', 'FilterPolicy']:
|
||||||
raise SNSInvalidParameter('AttributeName')
|
raise SNSInvalidParameter('AttributeName')
|
||||||
|
|
||||||
# TODO: should do validation
|
# TODO: should do validation
|
||||||
@ -363,6 +392,9 @@ class SNSBackend(BaseBackend):
|
|||||||
|
|
||||||
subscription.attributes[name] = value
|
subscription.attributes[name] = value
|
||||||
|
|
||||||
|
if name == 'FilterPolicy':
|
||||||
|
subscription._filter_policy = json.loads(value)
|
||||||
|
|
||||||
|
|
||||||
sns_backends = {}
|
sns_backends = {}
|
||||||
for region in boto.sns.regions():
|
for region in boto.sns.regions():
|
||||||
|
@ -241,6 +241,10 @@ class SNSResponse(BaseResponse):
|
|||||||
phone_number = self._get_param('PhoneNumber')
|
phone_number = self._get_param('PhoneNumber')
|
||||||
subject = self._get_param('Subject')
|
subject = self._get_param('Subject')
|
||||||
|
|
||||||
|
message_attributes = self._get_map_prefix('MessageAttributes.entry',
|
||||||
|
key_end='Name',
|
||||||
|
value_end='Value')
|
||||||
|
|
||||||
if phone_number is not None:
|
if phone_number is not None:
|
||||||
# Check phone is correct syntax (e164)
|
# Check phone is correct syntax (e164)
|
||||||
if not is_e164(phone_number):
|
if not is_e164(phone_number):
|
||||||
@ -265,7 +269,9 @@ class SNSResponse(BaseResponse):
|
|||||||
message = self._get_param('Message')
|
message = self._get_param('Message')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
message_id = self.backend.publish(arn, message, subject=subject)
|
message_id = self.backend.publish(
|
||||||
|
arn, message, subject=subject,
|
||||||
|
message_attributes=message_attributes)
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
error_response = self._error('InvalidParameter', str(err))
|
error_response = self._error('InvalidParameter', str(err))
|
||||||
return error_response, dict(status=400)
|
return error_response, dict(status=400)
|
||||||
|
@ -30,7 +30,7 @@ class SQSResponse(BaseResponse):
|
|||||||
@property
|
@property
|
||||||
def attribute(self):
|
def attribute(self):
|
||||||
if not hasattr(self, '_attribute'):
|
if not hasattr(self, '_attribute'):
|
||||||
self._attribute = self._get_map_prefix('Attribute', key_end='Name', value_end='Value')
|
self._attribute = self._get_map_prefix('Attribute', key_end='.Name', value_end='.Value')
|
||||||
return self._attribute
|
return self._attribute
|
||||||
|
|
||||||
def _get_queue_name(self):
|
def _get_queue_name(self):
|
||||||
|
@ -207,3 +207,136 @@ def test_publish_subject():
|
|||||||
err.response['Error']['Code'].should.equal('InvalidParameter')
|
err.response['Error']['Code'].should.equal('InvalidParameter')
|
||||||
else:
|
else:
|
||||||
raise RuntimeError('Should have raised an InvalidParameter exception')
|
raise RuntimeError('Should have raised an InvalidParameter exception')
|
||||||
|
|
||||||
|
|
||||||
|
def _setup_filter_policy_test(filter_policy):
|
||||||
|
sns = boto3.resource('sns', region_name='us-east-1')
|
||||||
|
topic = sns.create_topic(Name='some-topic')
|
||||||
|
|
||||||
|
sqs = boto3.resource('sqs', region_name='us-east-1')
|
||||||
|
queue = sqs.create_queue(QueueName='test-queue')
|
||||||
|
|
||||||
|
subscription = topic.subscribe(
|
||||||
|
Protocol='sqs', Endpoint=queue.attributes['QueueArn'])
|
||||||
|
|
||||||
|
subscription.set_attributes(
|
||||||
|
AttributeName='FilterPolicy', AttributeValue=json.dumps(filter_policy))
|
||||||
|
|
||||||
|
return topic, subscription, queue
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sqs
|
||||||
|
@mock_sns
|
||||||
|
def test_filtering_exact_string():
|
||||||
|
topic, subscription, queue = _setup_filter_policy_test(
|
||||||
|
{'store': ['example_corp']})
|
||||||
|
|
||||||
|
topic.publish(
|
||||||
|
Message='match',
|
||||||
|
MessageAttributes={'store': {'DataType': 'String',
|
||||||
|
'StringValue': 'example_corp'}})
|
||||||
|
|
||||||
|
messages = queue.receive_messages(MaxNumberOfMessages=5)
|
||||||
|
message_bodies = [json.loads(m.body)['Message'] for m in messages]
|
||||||
|
message_bodies.should.equal(['match'])
|
||||||
|
|
||||||
|
@mock_sqs
|
||||||
|
@mock_sns
|
||||||
|
def test_filtering_exact_string_multiple_message_attributes():
|
||||||
|
topic, subscription, queue = _setup_filter_policy_test(
|
||||||
|
{'store': ['example_corp']})
|
||||||
|
|
||||||
|
topic.publish(
|
||||||
|
Message='match',
|
||||||
|
MessageAttributes={'store': {'DataType': 'String',
|
||||||
|
'StringValue': 'example_corp'},
|
||||||
|
'event': {'DataType': 'String',
|
||||||
|
'StringValue': 'order_cancelled'}})
|
||||||
|
|
||||||
|
messages = queue.receive_messages(MaxNumberOfMessages=5)
|
||||||
|
message_bodies = [json.loads(m.body)['Message'] for m in messages]
|
||||||
|
message_bodies.should.equal(['match'])
|
||||||
|
|
||||||
|
@mock_sqs
|
||||||
|
@mock_sns
|
||||||
|
def test_filtering_exact_string_OR_matching():
|
||||||
|
topic, subscription, queue = _setup_filter_policy_test(
|
||||||
|
{'store': ['example_corp', 'different_corp']})
|
||||||
|
|
||||||
|
topic.publish(
|
||||||
|
Message='match example_corp',
|
||||||
|
MessageAttributes={'store': {'DataType': 'String',
|
||||||
|
'StringValue': 'example_corp'}})
|
||||||
|
topic.publish(
|
||||||
|
Message='match different_corp',
|
||||||
|
MessageAttributes={'store': {'DataType': 'String',
|
||||||
|
'StringValue': 'different_corp'}})
|
||||||
|
messages = queue.receive_messages(MaxNumberOfMessages=5)
|
||||||
|
message_bodies = [json.loads(m.body)['Message'] for m in messages]
|
||||||
|
message_bodies.should.equal(
|
||||||
|
['match example_corp', 'match different_corp'])
|
||||||
|
|
||||||
|
@mock_sqs
|
||||||
|
@mock_sns
|
||||||
|
def test_filtering_exact_string_AND_matching_positive():
|
||||||
|
topic, subscription, queue = _setup_filter_policy_test(
|
||||||
|
{'store': ['example_corp'],
|
||||||
|
'event': ['order_cancelled']})
|
||||||
|
|
||||||
|
topic.publish(
|
||||||
|
Message='match example_corp order_cancelled',
|
||||||
|
MessageAttributes={'store': {'DataType': 'String',
|
||||||
|
'StringValue': 'example_corp'},
|
||||||
|
'event': {'DataType': 'String',
|
||||||
|
'StringValue': 'order_cancelled'}})
|
||||||
|
|
||||||
|
messages = queue.receive_messages(MaxNumberOfMessages=5)
|
||||||
|
message_bodies = [json.loads(m.body)['Message'] for m in messages]
|
||||||
|
message_bodies.should.equal(
|
||||||
|
['match example_corp order_cancelled'])
|
||||||
|
|
||||||
|
@mock_sqs
|
||||||
|
@mock_sns
|
||||||
|
def test_filtering_exact_string_AND_matching_no_match():
|
||||||
|
topic, subscription, queue = _setup_filter_policy_test(
|
||||||
|
{'store': ['example_corp'],
|
||||||
|
'event': ['order_cancelled']})
|
||||||
|
|
||||||
|
topic.publish(
|
||||||
|
Message='match example_corp order_accepted',
|
||||||
|
MessageAttributes={'store': {'DataType': 'String',
|
||||||
|
'StringValue': 'example_corp'},
|
||||||
|
'event': {'DataType': 'String',
|
||||||
|
'StringValue': 'order_accepted'}})
|
||||||
|
|
||||||
|
messages = queue.receive_messages(MaxNumberOfMessages=5)
|
||||||
|
message_bodies = [json.loads(m.body)['Message'] for m in messages]
|
||||||
|
message_bodies.should.equal([])
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sqs
|
||||||
|
@mock_sns
|
||||||
|
def test_filtering_exact_string_no_match():
|
||||||
|
topic, subscription, queue = _setup_filter_policy_test(
|
||||||
|
{'store': ['example_corp']})
|
||||||
|
|
||||||
|
topic.publish(
|
||||||
|
Message='no match',
|
||||||
|
MessageAttributes={'store': {'DataType': 'String',
|
||||||
|
'StringValue': 'different_corp'}})
|
||||||
|
|
||||||
|
messages = queue.receive_messages(MaxNumberOfMessages=5)
|
||||||
|
message_bodies = [json.loads(m.body)['Message'] for m in messages]
|
||||||
|
message_bodies.should.equal([])
|
||||||
|
|
||||||
|
@mock_sqs
|
||||||
|
@mock_sns
|
||||||
|
def test_filtering_exact_string_no_attributes_no_match():
|
||||||
|
topic, subscription, queue = _setup_filter_policy_test(
|
||||||
|
{'store': ['example_corp']})
|
||||||
|
|
||||||
|
topic.publish(Message='no match')
|
||||||
|
|
||||||
|
messages = queue.receive_messages(MaxNumberOfMessages=5)
|
||||||
|
message_bodies = [json.loads(m.body)['Message'] for m in messages]
|
||||||
|
message_bodies.should.equal([])
|
||||||
|
@ -223,11 +223,26 @@ def test_set_subscription_attributes():
|
|||||||
AttributeName='DeliveryPolicy',
|
AttributeName='DeliveryPolicy',
|
||||||
AttributeValue=delivery_policy
|
AttributeValue=delivery_policy
|
||||||
)
|
)
|
||||||
|
|
||||||
|
filter_policy = json.dumps({
|
||||||
|
"store": ["example_corp"],
|
||||||
|
"event": ["order_cancelled"],
|
||||||
|
"encrypted": [False],
|
||||||
|
"customer_interests": ["basketball", "baseball"]
|
||||||
|
})
|
||||||
|
conn.set_subscription_attributes(
|
||||||
|
SubscriptionArn=subscription_arn,
|
||||||
|
AttributeName='FilterPolicy',
|
||||||
|
AttributeValue=filter_policy
|
||||||
|
)
|
||||||
|
|
||||||
attrs = conn.get_subscription_attributes(
|
attrs = conn.get_subscription_attributes(
|
||||||
SubscriptionArn=subscription_arn
|
SubscriptionArn=subscription_arn
|
||||||
)
|
)
|
||||||
|
|
||||||
attrs['Attributes']['RawMessageDelivery'].should.equal('true')
|
attrs['Attributes']['RawMessageDelivery'].should.equal('true')
|
||||||
attrs['Attributes']['DeliveryPolicy'].should.equal(delivery_policy)
|
attrs['Attributes']['DeliveryPolicy'].should.equal(delivery_policy)
|
||||||
|
attrs['Attributes']['FilterPolicy'].should.equal(filter_policy)
|
||||||
|
|
||||||
# not existing subscription
|
# not existing subscription
|
||||||
with assert_raises(ClientError):
|
with assert_raises(ClientError):
|
||||||
|
Loading…
Reference in New Issue
Block a user