sns: support FilterPolicyScope attribute, including filtering (#6262)

This commit is contained in:
Jakub P 2023-05-15 19:24:44 +02:00 committed by GitHub
parent 523225d6e9
commit 054ebcb326
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 1707 additions and 150 deletions

View File

@ -28,7 +28,12 @@ from .exceptions import (
TooManyEntriesInBatchRequest,
BatchEntryIdsNotDistinct,
)
from .utils import make_arn_for_topic, make_arn_for_subscription, is_e164
from .utils import (
make_arn_for_topic,
make_arn_for_subscription,
is_e164,
FilterPolicyMatcher,
)
DEFAULT_PAGE_SIZE = 100
@ -192,6 +197,7 @@ class Subscription(BaseModel):
self.arn = make_arn_for_subscription(self.topic.arn)
self.attributes: Dict[str, Any] = {}
self._filter_policy = None # filter policy as a dict, not json.
self._filter_policy_matcher = None
self.confirmed = False
def publish(
@ -203,8 +209,9 @@ class Subscription(BaseModel):
group_id: Optional[str] = None,
deduplication_id: Optional[str] = None,
) -> None:
if not self._matches_filter_policy(message_attributes):
return
if self._filter_policy_matcher is not None:
if not self._filter_policy_matcher.matches(message_attributes, message):
return
if self.protocol == "sqs":
queue_name = self.endpoint.split(":")[-1]
@ -277,131 +284,6 @@ class Subscription(BaseModel):
function_name, message, subject=subject, qualifier=qualifier
)
# Decide whether a message's attributes satisfy this subscription's filter policy.
# Returns True when no filter policy is set; otherwise every field of the policy
# must be matched by at least one of its rules.
# NOTE(review): indentation was lost in this view; confirm the nesting of the
# `return False` statements against the original file before relying on these notes.
def _matches_filter_policy(
self, message_attributes: Optional[Dict[str, Any]]
) -> bool:
# No filter policy configured: every message matches.
if not self._filter_policy:
return True
# Treat missing attributes as an empty mapping so lookups below are uniform.
if message_attributes is None:
message_attributes = {}
# Check a single policy field: succeeds as soon as one rule in `rules` matches.
def _field_match(
field: str, rules: List[Any], message_attributes: Dict[str, Any]
) -> bool:
for rule in rules:
# TODO: boolean value matching is not supported, SNS behavior unknown
# String rule: exact match on the attribute Value, or membership in the
# JSON-decoded Value.
if isinstance(rule, str):
if field not in message_attributes:
return False
if message_attributes[field]["Value"] == rule:
return True
try:
json_data = json.loads(message_attributes[field]["Value"])
if rule in json_data:
return True
except (ValueError, TypeError):
pass
# Numeric rule: compared against a Number attribute, or against each entry
# of a JSON-decoded String.Array attribute.
if isinstance(rule, (int, float)):
if field not in message_attributes:
return False
if message_attributes[field]["Type"] == "Number":
attribute_values = [message_attributes[field]["Value"]]
elif message_attributes[field]["Type"] == "String.Array":
try:
attribute_values = json.loads(
message_attributes[field]["Value"]
)
if not isinstance(attribute_values, list):
attribute_values = [attribute_values]
except (ValueError, TypeError):
return False
else:
return False
for attribute_value in attribute_values:
attribute_value = float(attribute_value)
# Although the official documentation states 5 digits of accuracy after the decimal point for numerics, in reality it is 6
# https://docs.aws.amazon.com/sns/latest/dg/sns-subscription-filter-policies.html#subscription-filter-policy-constraints
# Both sides are scaled by 10**6 and truncated before comparing.
if int(attribute_value * 1000000) == int(rule * 1000000):
return True
# Keyword rule: only the first key/value pair of the dict is examined.
if isinstance(rule, dict):
keyword = list(rule.keys())[0]
value = list(rule.values())[0]
if keyword == "exists":
# Matches when the presence of the field agrees with the requested flag.
if value and field in message_attributes:
return True
elif not value and field not in message_attributes:
return True
elif keyword == "prefix" and isinstance(value, str):
# Only String attributes are considered for prefix matching.
if field in message_attributes:
attr = message_attributes[field]
if attr["Type"] == "String" and attr["Value"].startswith(
value
):
return True
elif keyword == "anything-but":
# A missing field cannot satisfy anything-but; move on to the next rule.
if field not in message_attributes:
continue
attr = message_attributes[field]
if isinstance(value, dict):
# We can combine anything-but with the prefix-filter
anything_but_key = list(value.keys())[0]
anything_but_val = list(value.values())[0]
if anything_but_key != "prefix":
return False
if attr["Type"] == "String":
actual_values = [attr["Value"]]
else:
# NOTE(review): this iterates attr["Value"] element by element; if Value
# is a JSON string this yields single characters - confirm intent.
actual_values = [v for v in attr["Value"]]
if all(
[
not v.startswith(anything_but_val)
for v in actual_values
]
):
return True
else:
# A single string is wrapped in a list; any other value is used as given.
undesired_values = (
[value] if isinstance(value, str) else value
)
if attr["Type"] == "Number":
actual_values = [float(attr["Value"])]
elif attr["Type"] == "String":
actual_values = [attr["Value"]]
else:
actual_values = [v for v in attr["Value"]]
if all([v not in undesired_values for v in actual_values]):
return True
elif keyword == "numeric" and isinstance(value, list):
# [(< x), (=, y), (>=, z)]
# The flat [op, number, op, number, ...] list is paired into (op, number) tuples.
numeric_ranges = zip(value[0::2], value[1::2])
if (
message_attributes.get(field, {}).get("Type", "")
== "Number"
):
msg_value = float(message_attributes[field]["Value"])
matches = []
for operator, test_value in numeric_ranges:
test_value = test_value
if operator == ">":
matches.append((msg_value > test_value))
if operator == ">=":
matches.append((msg_value >= test_value))
if operator == "=":
matches.append((msg_value == test_value))
if operator == "<":
matches.append((msg_value < test_value))
if operator == "<=":
matches.append((msg_value <= test_value))
# Every (op, number) pair must hold for the numeric rule to match.
return all(matches)
return False
# The message matches only if every field of the policy is matched.
return all(
_field_match(field, rules, message_attributes)
for field, rules in self._filter_policy.items()
)
def get_post_data(
self,
message: str,
@ -848,6 +730,7 @@ class SNSBackend(BaseBackend):
"RawMessageDelivery",
"DeliveryPolicy",
"FilterPolicy",
"FilterPolicyScope",
"RedrivePolicy",
"SubscriptionRoleArn",
]:
@ -859,26 +742,90 @@ class SNSBackend(BaseBackend):
raise SNSNotFoundError(f"Subscription with arn {arn} not found")
subscription = _subscription[0]
subscription.attributes[name] = value
if name == "FilterPolicy":
filter_policy = json.loads(value)
self._validate_filter_policy(filter_policy)
# we validate the filter policy differently depending on the scope
# we need to always set the scope first
filter_policy_scope = subscription.attributes.get("FilterPolicyScope")
self._validate_filter_policy(filter_policy, scope=filter_policy_scope)
subscription._filter_policy = filter_policy
subscription._filter_policy_matcher = FilterPolicyMatcher(
filter_policy, filter_policy_scope
)
def _validate_filter_policy(self, value: Any) -> None:
# TODO: extend validation checks
subscription.attributes[name] = value
def _validate_filter_policy(self, value: Any, scope: str) -> None:
combinations = 1
for rules in value.values():
combinations *= len(rules)
# Although the official documentation states that the total combination of values must not exceed 100, in reality the limit is 150
# https://docs.aws.amazon.com/sns/latest/dg/sns-subscription-filter-policies.html#subscription-filter-policy-constraints
def aggregate_rules(
filter_policy: Dict[str, Any], depth: int = 1
) -> List[List[Any]]:
"""
This function evaluates the filter policy recursively, and returns only a list of lists of rules.
It also updates the enclosing `combinations` counter, which depends on the nesting depth of the rules.
Example:
nested_filter_policy = {
"key_a": {
"key_b": {
"key_c": ["value_one", "value_two", "value_three", "value_four"]
}
},
"key_d": {
"key_e": ["value_one", "value_two", "value_three"]
}
}
This function iterates over the values of the top level keys of the filter policy: ("key_a", "key_d")
If the iterated value is a dict, it is a nested property. If the scope is `MessageBody`, this is
allowed: we call this function on the value, adding a level to the depth to keep track of how deep the key is.
If the value is a list, it contains rules: we append this list of rules to _rules, and
calculate the combinations it adds.
For the example filter policy containing nested properties, we calculate it this way:
The first array has four values in a three-level nested key, and the second has three values in a two-level
nested key. 3 x 4 x 2 x 3 = 72
The return value would be:
[["value_one", "value_two", "value_three", "value_four"], ["value_one", "value_two", "value_three"]]
It allows us to later iterate over the list of rules in an easy way, to verify its conditions.
:param filter_policy: a dict, starting at the FilterPolicy
:param depth: the depth/level of the rules we are evaluating
:return: a list of lists of rules
:raises SNSInvalidParameter: if a nested policy is used with a scope other than `MessageBody`,
or if a value is neither a dict nor a list
"""
# `combinations` lives in the enclosing function and is multiplied in place below.
nonlocal combinations
_rules = []
for key, _value in filter_policy.items():
if isinstance(_value, dict):
if scope == "MessageBody":
# From AWS docs: "unlike attribute-based policies, payload-based policies support property nesting."
_rules.extend(aggregate_rules(_value, depth=depth + 1))
else:
raise SNSInvalidParameter(
"Invalid parameter: Filter policy scope MessageAttributes does not support nested filter policy"
)
elif isinstance(_value, list):
_rules.append(_value)
# Each list contributes (number of values) x (nesting depth of its key).
combinations = combinations * len(_value) * depth
else:
raise SNSInvalidParameter(
f'Invalid parameter: FilterPolicy: "{key}" must be an object or an array'
)
return _rules
# A filter policy can have a maximum of five attribute names. For a nested policy, only parent keys are counted.
if len(value.values()) > 5:
raise SNSInvalidParameter(
"Invalid parameter: FilterPolicy: Filter policy can not have more than 5 keys"
)
aggregated_rules = aggregate_rules(value)
# For the complexity of the filter policy, the total combination of values must not exceed 150.
# https://docs.aws.amazon.com/sns/latest/dg/subscription-filter-policy-constraints.html
if combinations > 150:
raise SNSInvalidParameter(
"Invalid parameter: FilterPolicy: Filter policy is too complex"
)
for rules in value.values():
for rules in aggregated_rules:
for rule in rules:
if rule is None:
continue

View File

@ -226,6 +226,13 @@ class SNSResponse(BaseResponse):
subscription = self.backend.subscribe(topic_arn, endpoint, protocol)
if attributes is not None:
# We need to set the FilterPolicyScope first, as the validation of the FilterPolicy will depend on it
if "FilterPolicyScope" in attributes:
filter_policy_scope = attributes.pop("FilterPolicyScope")
self.backend.set_subscription_attributes(
subscription.arn, "FilterPolicyScope", filter_policy_scope
)
for attr_name, attr_value in attributes.items():
self.backend.set_subscription_attributes(
subscription.arn, attr_name, attr_value

View File

@ -1,5 +1,7 @@
import re
from moto.moto_api._internal import mock_random
from typing import Any, Dict, List, Iterable, Optional, Tuple, Union, Callable
import json
E164_REGEX = re.compile(r"^\+?[1-9]\d{1,14}$")
@ -15,3 +17,376 @@ def make_arn_for_subscription(topic_arn: str) -> str:
def is_e164(number: str) -> bool:
    """Return True when *number* matches the module's E.164 pattern."""
    # `re.match` yields a (always truthy) match object or None.
    return bool(E164_REGEX.match(number))
class FilterPolicyMatcher:
    """Evaluate an SNS subscription filter policy against a published message.

    Depending on the scope, the policy is applied either to the message
    attributes ("MessageAttributes", the default) or to the JSON-decoded
    message body ("MessageBody").
    """

    class CheckException(Exception):
        """Raised for an unsupported scope or when a body check cannot be built."""

    def __init__(self, filter_policy: Dict[str, Any], filter_policy_scope: str):
        """Store the policy and validate the scope.

        :param filter_policy: the filter policy as a dict (not JSON)
        :param filter_policy_scope: "MessageAttributes", "MessageBody" or None
            (None falls back to "MessageAttributes")
        :raises FilterPolicyMatcher.CheckException: for any other scope value
        """
        self.filter_policy = filter_policy
        self.filter_policy_scope = (
            "MessageAttributes" if filter_policy_scope is None else filter_policy_scope
        )

        if self.filter_policy_scope not in ("MessageAttributes", "MessageBody"):
            raise FilterPolicyMatcher.CheckException(
                f"Unsupported filter_policy_scope: {filter_policy_scope}"
            )

    def matches(
        self, message_attributes: Optional[Dict[str, Any]], message: str
    ) -> bool:
        """Return True when the message passes the filter policy.

        :param message_attributes: attributes of the message, used for the
            "MessageAttributes" scope (None is treated as no attributes)
        :param message: raw message body, JSON-decoded for the "MessageBody" scope
        :return: True if the policy is empty or satisfied, False otherwise
            (including when the body is not valid JSON in "MessageBody" scope)
        """
        if not self.filter_policy:
            return True

        if self.filter_policy_scope == "MessageAttributes":
            if message_attributes is None:
                message_attributes = {}
            return self._attributes_based_match(message_attributes)

        try:
            message_dict = json.loads(message)
        except ValueError:
            return False
        return self._body_based_match(message_dict)

    def _attributes_based_match(self, message_attributes: Dict[str, Any]) -> bool:
        """Every field of the policy must be matched by the message attributes."""
        return all(
            FilterPolicyMatcher._field_match(field, rules, message_attributes)
            for field, rules in self.filter_policy.items()
        )

    def _body_based_match(self, message_dict: Dict[str, Any]) -> bool:
        """Build the check tree for the decoded body and evaluate it."""
        try:
            checks = self._compute_body_checks(self.filter_policy, message_dict)
        except FilterPolicyMatcher.CheckException:
            # The body does not have the structure the policy requires.
            return False

        return self._perform_body_checks(checks)

    def _perform_body_checks(self, check: Any) -> bool:
        """Evaluate a check tree produced by `_compute_body_checks`.

        A 2-tuple is an aggregation ``(any|all, checks)``; a 3-tuple is a
        concrete check ``(field, rules, dict_body)``.

        :raises FilterPolicyMatcher.CheckException: if `check` has another shape
        """
        if isinstance(check, tuple):
            if len(check) == 2:
                # (any|all, checks)
                aggregate_func, checks = check
                return aggregate_func(
                    self._perform_body_checks(single_check) for single_check in checks
                )
            if len(check) == 3:
                field, rules, dict_body = check
                return FilterPolicyMatcher._field_match(field, rules, dict_body, False)

        raise FilterPolicyMatcher.CheckException(f"Check is not a tuple: {str(check)}")

    def _compute_body_checks(
        self,
        filter_policy: Dict[str, Union[Dict[str, Any], List[Any]]],
        message_body: Union[Dict[str, Any], List[Any]],
    ) -> Tuple[Callable[[Iterable[Any]], bool], Any]:
        """
        Generate a (possibly nested) tree of checks to be performed based on the filter policy.

        The returned value has the format (any|all, checks): the first element defines which
        aggregation is used, the second is a tuple containing sub-trees of the same format or
        concrete checks (field, rules, body): Tuple[str, List[Any], Dict[str, Any]].

        All concrete checks only require a one-level-deep lookup in `_field_match`. This keeps the
        actual check logic as close as possible between MessageAttributes and MessageBody.

        Given message_body:
            {"Records": [
                {"eventName": "ObjectCreated:Put"},
                {"eventName": "ObjectCreated:Delete"},
            ]}
        and filter policy:
            {"Records": {"eventName": [{"prefix": "ObjectCreated:"}]}}
        the following check tree would be computed:
            (<built-in function all>, (
                (<built-in function all>, (
                    (<built-in function any>, (
                        ('eventName', [{'prefix': 'ObjectCreated:'}], {'eventName': 'ObjectCreated:Put'}),
                        ('eventName', [{'prefix': 'ObjectCreated:'}], {'eventName': 'ObjectCreated:Delete'}))
                    ),
                )),
            ))

        :raises FilterPolicyMatcher.CheckException: when the body (or policy) does not have a
            structure the checks can be built for
        """
        rules = []
        for filter_key, filter_value in filter_policy.items():
            if isinstance(filter_value, dict):
                if isinstance(message_body, dict):
                    message_value = message_body.get(filter_key)
                    if message_value is None:
                        raise FilterPolicyMatcher.CheckException
                    rules.append(
                        self._compute_body_checks(filter_value, message_value)
                    )
                elif isinstance(message_body, list):
                    subchecks = []
                    for entry in message_body:
                        # One entry that cannot satisfy the policy only fails that
                        # entry; the others may still match (aggregated with `any`).
                        try:
                            subchecks.append(
                                self._compute_body_checks(filter_policy, entry)
                            )
                        except FilterPolicyMatcher.CheckException:
                            continue
                    rules.append((any, tuple(subchecks)))
                else:
                    raise FilterPolicyMatcher.CheckException

            elif isinstance(filter_value, list):
                # These are the real rules, same as in MessageAttributes case
                if isinstance(message_body, dict):
                    concrete_checks = [(filter_key, filter_value, message_body)]
                elif isinstance(message_body, list):
                    # Apply the rules to each element of the list, pass if at least one
                    # element matches. Non-dict elements cannot hold `filter_key`, and
                    # `_field_match` requires a mapping, so they are skipped.
                    concrete_checks = [
                        (filter_key, filter_value, list_elem)
                        for list_elem in message_body
                        if isinstance(list_elem, dict)
                    ]
                else:
                    raise FilterPolicyMatcher.CheckException
                rules.append((any, tuple(concrete_checks)))
            else:
                raise FilterPolicyMatcher.CheckException

        return (all, tuple(rules))

    @staticmethod
    def _field_match(  # type: ignore # decorated function contains type Any
        field: str,
        rules: List[Any],
        dict_body: Dict[str, Any],
        attributes_based_check: bool = True,
    ) -> bool:
        """Return True if at least one rule in `rules` matches `field` in `dict_body`.

        :param field: name of the attribute / body key the rules apply to
        :param rules: list of rules (strings, numbers or keyword dicts), OR-ed together
        :param dict_body: the MessageAttributes when `attributes_based_check` is True,
            otherwise the cut-out part of the MessageBody (single-level lookup only)
        :param attributes_based_check: selects how values are read from `dict_body`
        """

        def _str_exact_match(value: Any, rule: str) -> bool:
            if value == rule:
                return True
            if isinstance(value, list) and rule in value:
                return True
            try:
                json_data = json.loads(value)
            except (ValueError, TypeError):
                return False
            # Only a JSON array is a collection of candidate values; a JSON string or
            # object would turn `in` into a substring / key test.
            return isinstance(json_data, list) and rule in json_data

        def _to_floats(candidates: Iterable[Any]) -> List[float]:
            # Entries that are not convertible to a number simply cannot match.
            converted = []
            for candidate in candidates:
                try:
                    converted.append(float(candidate))
                except (ValueError, TypeError):
                    continue
            return converted

        def _number_match(values: Iterable[float], rule: float) -> bool:
            # Although the official documentation states 5 digits of accuracy after the decimal point for numerics, in reality it is 6
            # https://docs.aws.amazon.com/sns/latest/dg/sns-subscription-filter-policies.html#subscription-filter-policy-constraints
            scaled_rule = int(rule * 1000000)
            return any(int(value * 1000000) == scaled_rule for value in values)

        def _anything_but_match(
            filter_value: Any,
            actual_values: Iterable[Any],
        ) -> bool:
            if isinstance(filter_value, dict):
                # We can combine anything-but with the prefix-filter
                if not filter_value:
                    return False
                anything_but_key, anything_but_val = next(iter(filter_value.items()))
                if anything_but_key != "prefix":
                    return False
                return all(
                    not (isinstance(v, str) and v.startswith(anything_but_val))
                    for v in actual_values
                )
            # A scalar (string or number) is a single undesired value.
            undesired_values = (
                filter_value if isinstance(filter_value, list) else [filter_value]
            )
            return all(v not in undesired_values for v in actual_values)

        def _numeric_match(
            numeric_ranges: Iterable[Tuple[str, float]], numeric_value: float
        ) -> bool:
            # numeric_ranges' format:
            # [(< x), (=, y), (>=, z)]
            matches = []
            for operator, test_value in numeric_ranges:
                if operator == ">":
                    matches.append(numeric_value > test_value)
                if operator == ">=":
                    matches.append(numeric_value >= test_value)
                if operator == "=":
                    matches.append(numeric_value == test_value)
                if operator == "<":
                    matches.append(numeric_value < test_value)
                if operator == "<=":
                    matches.append(numeric_value <= test_value)
            return all(matches)

        field_present = field in dict_body

        # Rules are OR-ed: a rule that cannot match must never stop the evaluation
        # of the remaining ones, hence `continue` instead of an early `return False`.
        for rule in rules:
            #  TODO: boolean value matching is not supported, SNS behavior unknown
            if isinstance(rule, str):
                if not field_present:
                    continue
                candidate = (
                    dict_body[field]["Value"]
                    if attributes_based_check
                    else dict_body[field]
                )
                if _str_exact_match(candidate, rule):
                    return True

            elif isinstance(rule, (int, float)):
                if not field_present:
                    continue
                if attributes_based_check:
                    attr = dict_body[field]
                    if attr["Type"] == "Number":
                        raw_values = [attr["Value"]]
                    elif attr["Type"] == "String.Array":
                        try:
                            raw_values = json.loads(attr["Value"])
                        except (ValueError, TypeError):
                            continue
                        if not isinstance(raw_values, list):
                            raw_values = [raw_values]
                    else:
                        continue
                    values = _to_floats(raw_values)
                else:
                    body_value = dict_body[field]
                    if isinstance(body_value, (int, float)):
                        values = [body_value]
                    elif isinstance(body_value, list):
                        values = _to_floats(body_value)
                    else:
                        continue
                if _number_match(values, rule):
                    return True

            elif isinstance(rule, dict):
                if not rule:
                    continue
                # Only the first key/value pair of a keyword rule is examined.
                keyword, value = next(iter(rule.items()))

                if keyword == "exists":
                    if bool(value) == field_present:
                        return True

                elif keyword == "prefix" and isinstance(value, str):
                    if not field_present:
                        continue
                    if attributes_based_check:
                        attr = dict_body[field]
                        candidates = [attr["Value"]] if attr["Type"] == "String" else []
                    else:
                        body_value = dict_body[field]
                        candidates = (
                            body_value if isinstance(body_value, list) else [body_value]
                        )
                    if any(
                        isinstance(candidate, str) and candidate.startswith(value)
                        for candidate in candidates
                    ):
                        return True

                elif keyword == "anything-but":
                    if not field_present:
                        continue
                    attr = dict_body[field]
                    if attributes_based_check:
                        if attr["Type"] == "String":
                            actual_values = [attr["Value"]]
                        elif attr["Type"] == "Number" and not isinstance(value, dict):
                            actual_values = [float(attr["Value"])]
                        else:
                            # NOTE(review): iterates the raw Value; for a JSON-encoded
                            # String.Array this yields single characters - confirm intent.
                            actual_values = [v for v in attr["Value"]]
                    else:
                        if isinstance(attr, list):
                            actual_values = attr
                        elif isinstance(attr, str) or (
                            not isinstance(value, dict)
                            and isinstance(attr, (int, float))
                        ):
                            actual_values = [attr]
                        else:
                            continue
                    if _anything_but_match(value, actual_values):
                        return True

                elif keyword == "numeric" and isinstance(value, list):
                    if attributes_based_check:
                        if dict_body.get(field, {}).get("Type", "") == "Number":
                            numeric_values = [float(dict_body[field]["Value"])]
                        else:
                            numeric_values = []
                    else:
                        if not field_present:
                            continue
                        body_value = dict_body[field]
                        if isinstance(body_value, (int, float)):
                            numeric_values = [body_value]
                        elif isinstance(body_value, list):
                            numeric_values = [
                                v for v in body_value if isinstance(v, (int, float))
                            ]
                        else:
                            numeric_values = []
                    # Materialized because it may be evaluated for several values.
                    numeric_ranges = list(zip(value[0::2], value[1::2]))
                    if any(
                        _numeric_match(numeric_ranges, numeric_value)
                        for numeric_value in numeric_values
                    ):
                        return True

        return False

File diff suppressed because it is too large Load Diff

View File

@ -250,7 +250,6 @@ def test_creating_subscription_with_attributes():
filter_policy = json.dumps(
{
"store": ["example_corp"],
"event": ["order_cancelled"],
"encrypted": [False],
"customer_interests": ["basketball", "baseball"],
"price": [100, 100.12],
@ -371,7 +370,6 @@ def test_set_subscription_attributes():
filter_policy = json.dumps(
{
"store": ["example_corp"],
"event": ["order_cancelled"],
"encrypted": [False],
"customer_interests": ["basketball", "baseball"],
"price": [100, 100.12],
@ -390,6 +388,17 @@ def test_set_subscription_attributes():
attrs["Attributes"]["DeliveryPolicy"].should.equal(delivery_policy)
attrs["Attributes"]["FilterPolicy"].should.equal(filter_policy)
filter_policy_scope = "MessageBody"
conn.set_subscription_attributes(
SubscriptionArn=subscription_arn,
AttributeName="FilterPolicyScope",
AttributeValue=filter_policy_scope,
)
attrs = conn.get_subscription_attributes(SubscriptionArn=subscription_arn)
attrs["Attributes"]["FilterPolicyScope"].should.equal(filter_policy_scope)
# not existing subscription
with pytest.raises(ClientError):
conn.set_subscription_attributes(
@ -641,6 +650,70 @@ def test_subscribe_invalid_filter_policy():
"Invalid parameter: Attributes Reason: FilterPolicy: Value of < must be numeric\n at ..."
)
try:
conn.subscribe(
TopicArn=topic_arn,
Protocol="http",
Endpoint="http://example.com/",
Attributes={
"FilterPolicy": json.dumps({"store": {"key": [{"exists": None}]}})
},
)
except ClientError as err:
err.response["Error"]["Code"].should.equal("InvalidParameter")
err.response["Error"]["Message"].should.equal(
"Invalid parameter: Filter policy scope MessageAttributes does not support nested filter policy"
)
try:
filter_policy = {
"key_a": ["value_one"],
"key_b": ["value_two"],
"key_c": ["value_three"],
"key_d": ["value_four"],
"key_e": ["value_five"],
"key_f": ["value_six"],
}
conn.subscribe(
TopicArn=topic_arn,
Protocol="http",
Endpoint="http://example.com/",
Attributes={"FilterPolicy": json.dumps(filter_policy)},
)
except ClientError as err:
err.response["Error"]["Code"].should.equal("InvalidParameter")
err.response["Error"]["Message"].should.equal(
"Invalid parameter: FilterPolicy: Filter policy can not have more than 5 keys"
)
try:
nested_filter_policy = {
"key_a": {
"key_b": {
"key_c": ["value_one", "value_two", "value_three", "value_four"]
},
},
"key_d": {"key_e": ["value_one", "value_two", "value_three"]},
"key_f": ["value_one", "value_two", "value_three"],
}
# The first array has four values in a three-level nested key, and the second has three values in a two-level
# nested key. The total combination is calculated as follows:
# 3 x 4 x 2 x 3 x 1 x 3 = 216
conn.subscribe(
TopicArn=topic_arn,
Protocol="http",
Endpoint="http://example.com/",
Attributes={
"FilterPolicyScope": "MessageBody",
"FilterPolicy": json.dumps(nested_filter_policy),
},
)
except ClientError as err:
err.response["Error"]["Code"].should.equal("InvalidParameter")
err.response["Error"]["Message"].should.equal(
"Invalid parameter: FilterPolicy: Filter policy is too complex"
)
@mock_sns
def test_check_not_opted_out():

View File

@ -0,0 +1,17 @@
from moto.sns.utils import FilterPolicyMatcher
import pytest
def test_filter_policy_matcher_scope_sanity_check():
    """Constructing a matcher with an unknown scope raises CheckException."""
    unsupported_scope = "IncorrectFilterPolicyScope"
    expected_error = FilterPolicyMatcher.CheckException
    with pytest.raises(expected_error):
        FilterPolicyMatcher({}, unsupported_scope)
def test_filter_policy_matcher_empty_message_attributes():
    """An empty policy matches even when no message attributes are given."""
    empty_policy_matcher = FilterPolicyMatcher({}, None)
    outcome = empty_policy_matcher.matches(None, "")
    assert outcome
def test_filter_policy_matcher_empty_message_attributes_filtering_fail():
    """A non-empty policy does not match when no message attributes are given."""
    policy = {"store": ["test"]}
    store_matcher = FilterPolicyMatcher(policy, None)
    outcome = store_matcher.matches(None, "")
    assert not outcome