sns: support FilterPolicyScope attribute, including filtering (#6262)
This commit is contained in:
parent
523225d6e9
commit
054ebcb326
@ -28,7 +28,12 @@ from .exceptions import (
|
|||||||
TooManyEntriesInBatchRequest,
|
TooManyEntriesInBatchRequest,
|
||||||
BatchEntryIdsNotDistinct,
|
BatchEntryIdsNotDistinct,
|
||||||
)
|
)
|
||||||
from .utils import make_arn_for_topic, make_arn_for_subscription, is_e164
|
from .utils import (
|
||||||
|
make_arn_for_topic,
|
||||||
|
make_arn_for_subscription,
|
||||||
|
is_e164,
|
||||||
|
FilterPolicyMatcher,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_PAGE_SIZE = 100
|
DEFAULT_PAGE_SIZE = 100
|
||||||
@ -192,6 +197,7 @@ class Subscription(BaseModel):
|
|||||||
self.arn = make_arn_for_subscription(self.topic.arn)
|
self.arn = make_arn_for_subscription(self.topic.arn)
|
||||||
self.attributes: Dict[str, Any] = {}
|
self.attributes: Dict[str, Any] = {}
|
||||||
self._filter_policy = None # filter policy as a dict, not json.
|
self._filter_policy = None # filter policy as a dict, not json.
|
||||||
|
self._filter_policy_matcher = None
|
||||||
self.confirmed = False
|
self.confirmed = False
|
||||||
|
|
||||||
def publish(
|
def publish(
|
||||||
@ -203,8 +209,9 @@ class Subscription(BaseModel):
|
|||||||
group_id: Optional[str] = None,
|
group_id: Optional[str] = None,
|
||||||
deduplication_id: Optional[str] = None,
|
deduplication_id: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if not self._matches_filter_policy(message_attributes):
|
if self._filter_policy_matcher is not None:
|
||||||
return
|
if not self._filter_policy_matcher.matches(message_attributes, message):
|
||||||
|
return
|
||||||
|
|
||||||
if self.protocol == "sqs":
|
if self.protocol == "sqs":
|
||||||
queue_name = self.endpoint.split(":")[-1]
|
queue_name = self.endpoint.split(":")[-1]
|
||||||
@ -277,131 +284,6 @@ class Subscription(BaseModel):
|
|||||||
function_name, message, subject=subject, qualifier=qualifier
|
function_name, message, subject=subject, qualifier=qualifier
|
||||||
)
|
)
|
||||||
|
|
||||||
def _matches_filter_policy(
|
|
||||||
self, message_attributes: Optional[Dict[str, Any]]
|
|
||||||
) -> bool:
|
|
||||||
if not self._filter_policy:
|
|
||||||
return True
|
|
||||||
|
|
||||||
if message_attributes is None:
|
|
||||||
message_attributes = {}
|
|
||||||
|
|
||||||
def _field_match(
|
|
||||||
field: str, rules: List[Any], message_attributes: Dict[str, Any]
|
|
||||||
) -> bool:
|
|
||||||
for rule in rules:
|
|
||||||
# TODO: boolean value matching is not supported, SNS behavior unknown
|
|
||||||
if isinstance(rule, str):
|
|
||||||
if field not in message_attributes:
|
|
||||||
return False
|
|
||||||
if message_attributes[field]["Value"] == rule:
|
|
||||||
return True
|
|
||||||
try:
|
|
||||||
json_data = json.loads(message_attributes[field]["Value"])
|
|
||||||
if rule in json_data:
|
|
||||||
return True
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
pass
|
|
||||||
if isinstance(rule, (int, float)):
|
|
||||||
if field not in message_attributes:
|
|
||||||
return False
|
|
||||||
if message_attributes[field]["Type"] == "Number":
|
|
||||||
attribute_values = [message_attributes[field]["Value"]]
|
|
||||||
elif message_attributes[field]["Type"] == "String.Array":
|
|
||||||
try:
|
|
||||||
attribute_values = json.loads(
|
|
||||||
message_attributes[field]["Value"]
|
|
||||||
)
|
|
||||||
if not isinstance(attribute_values, list):
|
|
||||||
attribute_values = [attribute_values]
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
for attribute_value in attribute_values:
|
|
||||||
attribute_value = float(attribute_value)
|
|
||||||
# Even the official documentation states a 5 digits of accuracy after the decimal point for numerics, in reality it is 6
|
|
||||||
# https://docs.aws.amazon.com/sns/latest/dg/sns-subscription-filter-policies.html#subscription-filter-policy-constraints
|
|
||||||
if int(attribute_value * 1000000) == int(rule * 1000000):
|
|
||||||
return True
|
|
||||||
if isinstance(rule, dict):
|
|
||||||
keyword = list(rule.keys())[0]
|
|
||||||
value = list(rule.values())[0]
|
|
||||||
if keyword == "exists":
|
|
||||||
if value and field in message_attributes:
|
|
||||||
return True
|
|
||||||
elif not value and field not in message_attributes:
|
|
||||||
return True
|
|
||||||
elif keyword == "prefix" and isinstance(value, str):
|
|
||||||
if field in message_attributes:
|
|
||||||
attr = message_attributes[field]
|
|
||||||
if attr["Type"] == "String" and attr["Value"].startswith(
|
|
||||||
value
|
|
||||||
):
|
|
||||||
return True
|
|
||||||
elif keyword == "anything-but":
|
|
||||||
if field not in message_attributes:
|
|
||||||
continue
|
|
||||||
attr = message_attributes[field]
|
|
||||||
if isinstance(value, dict):
|
|
||||||
# We can combine anything-but with the prefix-filter
|
|
||||||
anything_but_key = list(value.keys())[0]
|
|
||||||
anything_but_val = list(value.values())[0]
|
|
||||||
if anything_but_key != "prefix":
|
|
||||||
return False
|
|
||||||
if attr["Type"] == "String":
|
|
||||||
actual_values = [attr["Value"]]
|
|
||||||
else:
|
|
||||||
actual_values = [v for v in attr["Value"]]
|
|
||||||
if all(
|
|
||||||
[
|
|
||||||
not v.startswith(anything_but_val)
|
|
||||||
for v in actual_values
|
|
||||||
]
|
|
||||||
):
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
undesired_values = (
|
|
||||||
[value] if isinstance(value, str) else value
|
|
||||||
)
|
|
||||||
if attr["Type"] == "Number":
|
|
||||||
actual_values = [float(attr["Value"])]
|
|
||||||
elif attr["Type"] == "String":
|
|
||||||
actual_values = [attr["Value"]]
|
|
||||||
else:
|
|
||||||
actual_values = [v for v in attr["Value"]]
|
|
||||||
if all([v not in undesired_values for v in actual_values]):
|
|
||||||
return True
|
|
||||||
elif keyword == "numeric" and isinstance(value, list):
|
|
||||||
# [(< x), (=, y), (>=, z)]
|
|
||||||
numeric_ranges = zip(value[0::2], value[1::2])
|
|
||||||
if (
|
|
||||||
message_attributes.get(field, {}).get("Type", "")
|
|
||||||
== "Number"
|
|
||||||
):
|
|
||||||
msg_value = float(message_attributes[field]["Value"])
|
|
||||||
matches = []
|
|
||||||
for operator, test_value in numeric_ranges:
|
|
||||||
test_value = test_value
|
|
||||||
if operator == ">":
|
|
||||||
matches.append((msg_value > test_value))
|
|
||||||
if operator == ">=":
|
|
||||||
matches.append((msg_value >= test_value))
|
|
||||||
if operator == "=":
|
|
||||||
matches.append((msg_value == test_value))
|
|
||||||
if operator == "<":
|
|
||||||
matches.append((msg_value < test_value))
|
|
||||||
if operator == "<=":
|
|
||||||
matches.append((msg_value <= test_value))
|
|
||||||
return all(matches)
|
|
||||||
return False
|
|
||||||
|
|
||||||
return all(
|
|
||||||
_field_match(field, rules, message_attributes)
|
|
||||||
for field, rules in self._filter_policy.items()
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_post_data(
|
def get_post_data(
|
||||||
self,
|
self,
|
||||||
message: str,
|
message: str,
|
||||||
@ -848,6 +730,7 @@ class SNSBackend(BaseBackend):
|
|||||||
"RawMessageDelivery",
|
"RawMessageDelivery",
|
||||||
"DeliveryPolicy",
|
"DeliveryPolicy",
|
||||||
"FilterPolicy",
|
"FilterPolicy",
|
||||||
|
"FilterPolicyScope",
|
||||||
"RedrivePolicy",
|
"RedrivePolicy",
|
||||||
"SubscriptionRoleArn",
|
"SubscriptionRoleArn",
|
||||||
]:
|
]:
|
||||||
@ -859,26 +742,90 @@ class SNSBackend(BaseBackend):
|
|||||||
raise SNSNotFoundError(f"Subscription with arn {arn} not found")
|
raise SNSNotFoundError(f"Subscription with arn {arn} not found")
|
||||||
subscription = _subscription[0]
|
subscription = _subscription[0]
|
||||||
|
|
||||||
subscription.attributes[name] = value
|
|
||||||
|
|
||||||
if name == "FilterPolicy":
|
if name == "FilterPolicy":
|
||||||
filter_policy = json.loads(value)
|
filter_policy = json.loads(value)
|
||||||
self._validate_filter_policy(filter_policy)
|
# we validate the filter policy differently depending on the scope
|
||||||
|
# we need to always set the scope first
|
||||||
|
filter_policy_scope = subscription.attributes.get("FilterPolicyScope")
|
||||||
|
self._validate_filter_policy(filter_policy, scope=filter_policy_scope)
|
||||||
subscription._filter_policy = filter_policy
|
subscription._filter_policy = filter_policy
|
||||||
|
subscription._filter_policy_matcher = FilterPolicyMatcher(
|
||||||
|
filter_policy, filter_policy_scope
|
||||||
|
)
|
||||||
|
|
||||||
def _validate_filter_policy(self, value: Any) -> None:
|
subscription.attributes[name] = value
|
||||||
# TODO: extend validation checks
|
|
||||||
|
def _validate_filter_policy(self, value: Any, scope: str) -> None:
|
||||||
combinations = 1
|
combinations = 1
|
||||||
for rules in value.values():
|
|
||||||
combinations *= len(rules)
|
def aggregate_rules(
|
||||||
# Even the official documentation states the total combination of values must not exceed 100, in reality it is 150
|
filter_policy: Dict[str, Any], depth: int = 1
|
||||||
# https://docs.aws.amazon.com/sns/latest/dg/sns-subscription-filter-policies.html#subscription-filter-policy-constraints
|
) -> List[List[Any]]:
|
||||||
|
"""
|
||||||
|
This method evaluate the filter policy recursively, and returns only a list of lists of rules.
|
||||||
|
It also calculates the combinations of rules, calculated depending on the nesting of the rules.
|
||||||
|
Example:
|
||||||
|
nested_filter_policy = {
|
||||||
|
"key_a": {
|
||||||
|
"key_b": {
|
||||||
|
"key_c": ["value_one", "value_two", "value_three", "value_four"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"key_d": {
|
||||||
|
"key_e": ["value_one", "value_two", "value_three"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
This function then iterates on the values of the top level keys of the filter policy: ("key_a", "key_d")
|
||||||
|
If the iterated value is not a list, it means it is a nested property. If the scope is `MessageBody`, it is
|
||||||
|
allowed, we call this method on the value, adding a level to the depth to keep track on how deep the key is.
|
||||||
|
If the value is a list, it means it contains rules: we will append this list of rules in _rules, and
|
||||||
|
calculate the combinations it adds.
|
||||||
|
For the example filter policy containing nested properties, we calculate it this way
|
||||||
|
The first array has four values in a three-level nested key, and the second has three values in a two-level
|
||||||
|
nested key. 3 x 4 x 2 x 3 = 72
|
||||||
|
The return value would be:
|
||||||
|
[["value_one", "value_two", "value_three", "value_four"], ["value_one", "value_two", "value_three"]]
|
||||||
|
It allows us to later iterate of the list of rules in an easy way, to verify its conditions.
|
||||||
|
|
||||||
|
:param filter_policy: a dict, starting at the FilterPolicy
|
||||||
|
:param depth: the depth/level of the rules we are evaluating
|
||||||
|
:return: a list of lists of rules
|
||||||
|
"""
|
||||||
|
nonlocal combinations
|
||||||
|
_rules = []
|
||||||
|
for key, _value in filter_policy.items():
|
||||||
|
if isinstance(_value, dict):
|
||||||
|
if scope == "MessageBody":
|
||||||
|
# From AWS docs: "unlike attribute-based policies, payload-based policies support property nesting."
|
||||||
|
_rules.extend(aggregate_rules(_value, depth=depth + 1))
|
||||||
|
else:
|
||||||
|
raise SNSInvalidParameter(
|
||||||
|
"Invalid parameter: Filter policy scope MessageAttributes does not support nested filter policy"
|
||||||
|
)
|
||||||
|
elif isinstance(_value, list):
|
||||||
|
_rules.append(_value)
|
||||||
|
combinations = combinations * len(_value) * depth
|
||||||
|
else:
|
||||||
|
raise SNSInvalidParameter(
|
||||||
|
f'Invalid parameter: FilterPolicy: "{key}" must be an object or an array'
|
||||||
|
)
|
||||||
|
return _rules
|
||||||
|
|
||||||
|
# A filter policy can have a maximum of five attribute names. For a nested policy, only parent keys are counted.
|
||||||
|
if len(value.values()) > 5:
|
||||||
|
raise SNSInvalidParameter(
|
||||||
|
"Invalid parameter: FilterPolicy: Filter policy can not have more than 5 keys"
|
||||||
|
)
|
||||||
|
|
||||||
|
aggregated_rules = aggregate_rules(value)
|
||||||
|
# For the complexity of the filter policy, the total combination of values must not exceed 150.
|
||||||
|
# https://docs.aws.amazon.com/sns/latest/dg/subscription-filter-policy-constraints.html
|
||||||
if combinations > 150:
|
if combinations > 150:
|
||||||
raise SNSInvalidParameter(
|
raise SNSInvalidParameter(
|
||||||
"Invalid parameter: FilterPolicy: Filter policy is too complex"
|
"Invalid parameter: FilterPolicy: Filter policy is too complex"
|
||||||
)
|
)
|
||||||
|
|
||||||
for rules in value.values():
|
for rules in aggregated_rules:
|
||||||
for rule in rules:
|
for rule in rules:
|
||||||
if rule is None:
|
if rule is None:
|
||||||
continue
|
continue
|
||||||
|
@ -226,6 +226,13 @@ class SNSResponse(BaseResponse):
|
|||||||
subscription = self.backend.subscribe(topic_arn, endpoint, protocol)
|
subscription = self.backend.subscribe(topic_arn, endpoint, protocol)
|
||||||
|
|
||||||
if attributes is not None:
|
if attributes is not None:
|
||||||
|
# We need to set the FilterPolicyScope first, as the validation of the FilterPolicy will depend on it
|
||||||
|
if "FilterPolicyScope" in attributes:
|
||||||
|
filter_policy_scope = attributes.pop("FilterPolicyScope")
|
||||||
|
self.backend.set_subscription_attributes(
|
||||||
|
subscription.arn, "FilterPolicyScope", filter_policy_scope
|
||||||
|
)
|
||||||
|
|
||||||
for attr_name, attr_value in attributes.items():
|
for attr_name, attr_value in attributes.items():
|
||||||
self.backend.set_subscription_attributes(
|
self.backend.set_subscription_attributes(
|
||||||
subscription.arn, attr_name, attr_value
|
subscription.arn, attr_name, attr_value
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
import re
|
import re
|
||||||
from moto.moto_api._internal import mock_random
|
from moto.moto_api._internal import mock_random
|
||||||
|
from typing import Any, Dict, List, Iterable, Optional, Tuple, Union, Callable
|
||||||
|
import json
|
||||||
|
|
||||||
E164_REGEX = re.compile(r"^\+?[1-9]\d{1,14}$")
|
E164_REGEX = re.compile(r"^\+?[1-9]\d{1,14}$")
|
||||||
|
|
||||||
@ -15,3 +17,376 @@ def make_arn_for_subscription(topic_arn: str) -> str:
|
|||||||
|
|
||||||
def is_e164(number: str) -> bool:
|
def is_e164(number: str) -> bool:
|
||||||
return E164_REGEX.match(number) is not None
|
return E164_REGEX.match(number) is not None
|
||||||
|
|
||||||
|
|
||||||
|
class FilterPolicyMatcher:
|
||||||
|
class CheckException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __init__(self, filter_policy: Dict[str, Any], filter_policy_scope: str):
|
||||||
|
self.filter_policy = filter_policy
|
||||||
|
self.filter_policy_scope = (
|
||||||
|
filter_policy_scope
|
||||||
|
if filter_policy_scope is not None
|
||||||
|
else "MessageAttributes"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.filter_policy_scope not in ("MessageAttributes", "MessageBody"):
|
||||||
|
raise FilterPolicyMatcher.CheckException(
|
||||||
|
f"Unsupported filter_policy_scope: {filter_policy_scope}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def matches(
|
||||||
|
self, message_attributes: Optional[Dict[str, Any]], message: str
|
||||||
|
) -> bool:
|
||||||
|
if not self.filter_policy:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if self.filter_policy_scope == "MessageAttributes":
|
||||||
|
if message_attributes is None:
|
||||||
|
message_attributes = {}
|
||||||
|
|
||||||
|
return self._attributes_based_match(message_attributes)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
message_dict = json.loads(message)
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
return self._body_based_match(message_dict)
|
||||||
|
|
||||||
|
def _attributes_based_match(self, message_attributes: Dict[str, Any]) -> bool:
|
||||||
|
return all(
|
||||||
|
FilterPolicyMatcher._field_match(field, rules, message_attributes)
|
||||||
|
for field, rules in self.filter_policy.items()
|
||||||
|
)
|
||||||
|
|
||||||
|
def _body_based_match(self, message_dict: Dict[str, Any]) -> bool:
|
||||||
|
try:
|
||||||
|
checks = self._compute_body_checks(self.filter_policy, message_dict)
|
||||||
|
except FilterPolicyMatcher.CheckException:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return self._perform_body_checks(checks)
|
||||||
|
|
||||||
|
def _perform_body_checks(self, check: Any) -> bool:
|
||||||
|
# If the checks are a list, only a single elem has to pass
|
||||||
|
# otherwise all the entries have to pass
|
||||||
|
|
||||||
|
if isinstance(check, tuple):
|
||||||
|
if len(check) == 2:
|
||||||
|
# (any|all, checks)
|
||||||
|
aggregate_func, checks = check
|
||||||
|
return aggregate_func(
|
||||||
|
self._perform_body_checks(single_check) for single_check in checks
|
||||||
|
)
|
||||||
|
elif len(check) == 3:
|
||||||
|
field, rules, dict_body = check
|
||||||
|
return FilterPolicyMatcher._field_match(field, rules, dict_body, False)
|
||||||
|
|
||||||
|
raise FilterPolicyMatcher.CheckException(f"Check is not a tuple: {str(check)}")
|
||||||
|
|
||||||
|
def _compute_body_checks(
|
||||||
|
self,
|
||||||
|
filter_policy: Dict[str, Union[Dict[str, Any], List[Any]]],
|
||||||
|
message_body: Union[Dict[str, Any], List[Any]],
|
||||||
|
) -> Tuple[Callable[[Iterable[Any]], bool], Any]:
|
||||||
|
"""
|
||||||
|
Generate (possibly nested) list of checks to be performed based on the filter policy
|
||||||
|
Returned list is of format (any|all, checks), where first elem defines what aggregation should be used in checking
|
||||||
|
and the second argument is a list containing sublists of the same format or concrete checks (field, rule, body): Tuple[str, List[Any], Dict[str, Any]]
|
||||||
|
|
||||||
|
All the checks returned by this function will only require one-level-deep entry into dict in _field_match function
|
||||||
|
This is done this way to simplify the actual check logic and keep it as close as possible between MessageAttributes and MessageBody
|
||||||
|
|
||||||
|
Given message_body:
|
||||||
|
{"Records": [
|
||||||
|
{
|
||||||
|
"eventName": "ObjectCreated:Put",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"eventName": "ObjectCreated:Delete",
|
||||||
|
},
|
||||||
|
]}
|
||||||
|
|
||||||
|
and filter policy:
|
||||||
|
{"Records": {
|
||||||
|
"eventName": [{"prefix": "ObjectCreated:"}],
|
||||||
|
}}
|
||||||
|
|
||||||
|
the following check list would be computed:
|
||||||
|
(<built-in function all>, (
|
||||||
|
(<built-in function all>, (
|
||||||
|
(<built-in function any>, (
|
||||||
|
('eventName', [{'prefix': 'ObjectCreated:'}], {'eventName': 'ObjectCreated:Put'}),
|
||||||
|
('eventName', [{'prefix': 'ObjectCreated:'}], {'eventName': 'ObjectCreated:Delete'}))
|
||||||
|
),
|
||||||
|
)
|
||||||
|
),))
|
||||||
|
"""
|
||||||
|
rules = []
|
||||||
|
for filter_key, filter_value in filter_policy.items():
|
||||||
|
if isinstance(filter_value, dict):
|
||||||
|
if isinstance(message_body, dict):
|
||||||
|
message_value = message_body.get(filter_key)
|
||||||
|
if message_value is not None:
|
||||||
|
rules.append(
|
||||||
|
self._compute_body_checks(filter_value, message_value)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise FilterPolicyMatcher.CheckException
|
||||||
|
elif isinstance(message_body, list):
|
||||||
|
subchecks = []
|
||||||
|
for entry in message_body:
|
||||||
|
subchecks.append(
|
||||||
|
self._compute_body_checks(filter_policy, entry)
|
||||||
|
)
|
||||||
|
rules.append((any, tuple(subchecks)))
|
||||||
|
else:
|
||||||
|
raise FilterPolicyMatcher.CheckException
|
||||||
|
|
||||||
|
elif isinstance(filter_value, list):
|
||||||
|
# These are the real rules, same as in MessageAttributes case
|
||||||
|
|
||||||
|
concrete_checks = []
|
||||||
|
if isinstance(message_body, dict):
|
||||||
|
if message_body is not None:
|
||||||
|
concrete_checks.append((filter_key, filter_value, message_body))
|
||||||
|
else:
|
||||||
|
raise FilterPolicyMatcher.CheckException
|
||||||
|
elif isinstance(message_body, list):
|
||||||
|
# Apply policy to each element of the list, pass if at list one element matches
|
||||||
|
for list_elem in message_body:
|
||||||
|
concrete_checks.append((filter_key, filter_value, list_elem))
|
||||||
|
else:
|
||||||
|
raise FilterPolicyMatcher.CheckException
|
||||||
|
rules.append((any, tuple(concrete_checks)))
|
||||||
|
else:
|
||||||
|
raise FilterPolicyMatcher.CheckException
|
||||||
|
|
||||||
|
return (all, tuple(rules))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _field_match( # type: ignore # decorated function contains type Any
|
||||||
|
field: str,
|
||||||
|
rules: List[Any],
|
||||||
|
dict_body: Dict[str, Any],
|
||||||
|
attributes_based_check: bool = True,
|
||||||
|
) -> bool:
|
||||||
|
# dict_body = MessageAttributes if attributes_based_check is True
|
||||||
|
# otherwise it's the cut-out part of the MessageBody (so only single-level nesting must be supported)
|
||||||
|
|
||||||
|
# Iterate over every rule from the list of rules
|
||||||
|
# At least one rule has to match the field for the function to return a match
|
||||||
|
|
||||||
|
def _str_exact_match(value: str, rule: Union[str, List[str]]) -> bool:
|
||||||
|
if value == rule:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if isinstance(value, list):
|
||||||
|
if rule in value:
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
json_data = json.loads(value)
|
||||||
|
if rule in json_data:
|
||||||
|
return True
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _number_match(values: List[float], rule: float) -> bool:
|
||||||
|
for value in values:
|
||||||
|
# Even the official documentation states a 5 digits of accuracy after the decimal point for numerics, in reality it is 6
|
||||||
|
# https://docs.aws.amazon.com/sns/latest/dg/sns-subscription-filter-policies.html#subscription-filter-policy-constraints
|
||||||
|
if int(value * 1000000) == int(rule * 1000000):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _exists_match(
|
||||||
|
should_exist: bool, field: str, dict_body: Dict[str, Any]
|
||||||
|
) -> bool:
|
||||||
|
if should_exist and field in dict_body:
|
||||||
|
return True
|
||||||
|
elif not should_exist and field not in dict_body:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _prefix_match(prefix: str, value: str) -> bool:
|
||||||
|
return value.startswith(prefix)
|
||||||
|
|
||||||
|
def _anything_but_match(
|
||||||
|
filter_value: Union[Dict[str, Any], List[str], str],
|
||||||
|
actual_values: List[str],
|
||||||
|
) -> bool:
|
||||||
|
if isinstance(filter_value, dict):
|
||||||
|
# We can combine anything-but with the prefix-filter
|
||||||
|
anything_but_key = list(filter_value.keys())[0]
|
||||||
|
anything_but_val = list(filter_value.values())[0]
|
||||||
|
if anything_but_key != "prefix":
|
||||||
|
return False
|
||||||
|
if all([not v.startswith(anything_but_val) for v in actual_values]):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
undesired_values = (
|
||||||
|
[filter_value] if isinstance(filter_value, str) else filter_value
|
||||||
|
)
|
||||||
|
if all([v not in undesired_values for v in actual_values]):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _numeric_match(
|
||||||
|
numeric_ranges: Iterable[Tuple[str, float]], numeric_value: float
|
||||||
|
) -> bool:
|
||||||
|
# numeric_ranges' format:
|
||||||
|
# [(< x), (=, y), (>=, z)]
|
||||||
|
msg_value = numeric_value
|
||||||
|
matches = []
|
||||||
|
for operator, test_value in numeric_ranges:
|
||||||
|
if operator == ">":
|
||||||
|
matches.append((msg_value > test_value))
|
||||||
|
if operator == ">=":
|
||||||
|
matches.append((msg_value >= test_value))
|
||||||
|
if operator == "=":
|
||||||
|
matches.append((msg_value == test_value))
|
||||||
|
if operator == "<":
|
||||||
|
matches.append((msg_value < test_value))
|
||||||
|
if operator == "<=":
|
||||||
|
matches.append((msg_value <= test_value))
|
||||||
|
return all(matches)
|
||||||
|
|
||||||
|
for rule in rules:
|
||||||
|
# TODO: boolean value matching is not supported, SNS behavior unknown
|
||||||
|
if isinstance(rule, str):
|
||||||
|
if attributes_based_check:
|
||||||
|
if field not in dict_body:
|
||||||
|
return False
|
||||||
|
if _str_exact_match(dict_body[field]["Value"], rule):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
if field not in dict_body:
|
||||||
|
return False
|
||||||
|
if _str_exact_match(dict_body[field], rule):
|
||||||
|
return True
|
||||||
|
|
||||||
|
if isinstance(rule, (int, float)):
|
||||||
|
if attributes_based_check:
|
||||||
|
if field not in dict_body:
|
||||||
|
return False
|
||||||
|
if dict_body[field]["Type"] == "Number":
|
||||||
|
attribute_values = [dict_body[field]["Value"]]
|
||||||
|
elif dict_body[field]["Type"] == "String.Array":
|
||||||
|
try:
|
||||||
|
attribute_values = json.loads(dict_body[field]["Value"])
|
||||||
|
if not isinstance(attribute_values, list):
|
||||||
|
attribute_values = [attribute_values]
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
values = [float(value) for value in attribute_values]
|
||||||
|
if _number_match(values, rule):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
if field not in dict_body:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if isinstance(dict_body[field], (int, float)):
|
||||||
|
values = [dict_body[field]]
|
||||||
|
elif isinstance(dict_body[field], list):
|
||||||
|
values = [float(value) for value in dict_body[field]]
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if _number_match(values, rule):
|
||||||
|
return True
|
||||||
|
|
||||||
|
if isinstance(rule, dict):
|
||||||
|
keyword = list(rule.keys())[0]
|
||||||
|
value = list(rule.values())[0]
|
||||||
|
if keyword == "exists":
|
||||||
|
if attributes_based_check:
|
||||||
|
if _exists_match(value, field, dict_body):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
if _exists_match(value, field, dict_body):
|
||||||
|
return True
|
||||||
|
elif keyword == "prefix" and isinstance(value, str):
|
||||||
|
if attributes_based_check:
|
||||||
|
if field in dict_body:
|
||||||
|
attr = dict_body[field]
|
||||||
|
if attr["Type"] == "String":
|
||||||
|
if _prefix_match(value, attr["Value"]):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
if field in dict_body:
|
||||||
|
if _prefix_match(value, dict_body[field]):
|
||||||
|
return True
|
||||||
|
|
||||||
|
elif keyword == "anything-but":
|
||||||
|
if attributes_based_check:
|
||||||
|
if field not in dict_body:
|
||||||
|
return False
|
||||||
|
attr = dict_body[field]
|
||||||
|
if isinstance(value, dict):
|
||||||
|
# We can combine anything-but with the prefix-filter
|
||||||
|
if attr["Type"] == "String":
|
||||||
|
actual_values = [attr["Value"]]
|
||||||
|
else:
|
||||||
|
actual_values = [v for v in attr["Value"]]
|
||||||
|
else:
|
||||||
|
if attr["Type"] == "Number":
|
||||||
|
actual_values = [float(attr["Value"])]
|
||||||
|
elif attr["Type"] == "String":
|
||||||
|
actual_values = [attr["Value"]]
|
||||||
|
else:
|
||||||
|
actual_values = [v for v in attr["Value"]]
|
||||||
|
|
||||||
|
if _anything_but_match(value, actual_values):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
if field not in dict_body:
|
||||||
|
return False
|
||||||
|
attr = dict_body[field]
|
||||||
|
if isinstance(value, dict):
|
||||||
|
if isinstance(attr, str):
|
||||||
|
actual_values = [attr]
|
||||||
|
elif isinstance(attr, list):
|
||||||
|
actual_values = attr
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
if isinstance(attr, (int, float, str)):
|
||||||
|
actual_values = [attr]
|
||||||
|
elif isinstance(attr, list):
|
||||||
|
actual_values = attr
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if _anything_but_match(value, actual_values):
|
||||||
|
return True
|
||||||
|
|
||||||
|
elif keyword == "numeric" and isinstance(value, list):
|
||||||
|
if attributes_based_check:
|
||||||
|
if dict_body.get(field, {}).get("Type", "") == "Number":
|
||||||
|
|
||||||
|
checks = value
|
||||||
|
numeric_ranges = zip(checks[0::2], checks[1::2])
|
||||||
|
if _numeric_match(
|
||||||
|
numeric_ranges, float(dict_body[field]["Value"])
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
if field not in dict_body:
|
||||||
|
return False
|
||||||
|
|
||||||
|
checks = value
|
||||||
|
numeric_ranges = zip(checks[0::2], checks[1::2])
|
||||||
|
if _numeric_match(numeric_ranges, dict_body[field]):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -250,7 +250,6 @@ def test_creating_subscription_with_attributes():
|
|||||||
filter_policy = json.dumps(
|
filter_policy = json.dumps(
|
||||||
{
|
{
|
||||||
"store": ["example_corp"],
|
"store": ["example_corp"],
|
||||||
"event": ["order_cancelled"],
|
|
||||||
"encrypted": [False],
|
"encrypted": [False],
|
||||||
"customer_interests": ["basketball", "baseball"],
|
"customer_interests": ["basketball", "baseball"],
|
||||||
"price": [100, 100.12],
|
"price": [100, 100.12],
|
||||||
@ -371,7 +370,6 @@ def test_set_subscription_attributes():
|
|||||||
filter_policy = json.dumps(
|
filter_policy = json.dumps(
|
||||||
{
|
{
|
||||||
"store": ["example_corp"],
|
"store": ["example_corp"],
|
||||||
"event": ["order_cancelled"],
|
|
||||||
"encrypted": [False],
|
"encrypted": [False],
|
||||||
"customer_interests": ["basketball", "baseball"],
|
"customer_interests": ["basketball", "baseball"],
|
||||||
"price": [100, 100.12],
|
"price": [100, 100.12],
|
||||||
@ -390,6 +388,17 @@ def test_set_subscription_attributes():
|
|||||||
attrs["Attributes"]["DeliveryPolicy"].should.equal(delivery_policy)
|
attrs["Attributes"]["DeliveryPolicy"].should.equal(delivery_policy)
|
||||||
attrs["Attributes"]["FilterPolicy"].should.equal(filter_policy)
|
attrs["Attributes"]["FilterPolicy"].should.equal(filter_policy)
|
||||||
|
|
||||||
|
filter_policy_scope = "MessageBody"
|
||||||
|
conn.set_subscription_attributes(
|
||||||
|
SubscriptionArn=subscription_arn,
|
||||||
|
AttributeName="FilterPolicyScope",
|
||||||
|
AttributeValue=filter_policy_scope,
|
||||||
|
)
|
||||||
|
|
||||||
|
attrs = conn.get_subscription_attributes(SubscriptionArn=subscription_arn)
|
||||||
|
|
||||||
|
attrs["Attributes"]["FilterPolicyScope"].should.equal(filter_policy_scope)
|
||||||
|
|
||||||
# not existing subscription
|
# not existing subscription
|
||||||
with pytest.raises(ClientError):
|
with pytest.raises(ClientError):
|
||||||
conn.set_subscription_attributes(
|
conn.set_subscription_attributes(
|
||||||
@ -641,6 +650,70 @@ def test_subscribe_invalid_filter_policy():
|
|||||||
"Invalid parameter: Attributes Reason: FilterPolicy: Value of < must be numeric\n at ..."
|
"Invalid parameter: Attributes Reason: FilterPolicy: Value of < must be numeric\n at ..."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
conn.subscribe(
|
||||||
|
TopicArn=topic_arn,
|
||||||
|
Protocol="http",
|
||||||
|
Endpoint="http://example.com/",
|
||||||
|
Attributes={
|
||||||
|
"FilterPolicy": json.dumps({"store": {"key": [{"exists": None}]}})
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except ClientError as err:
|
||||||
|
err.response["Error"]["Code"].should.equal("InvalidParameter")
|
||||||
|
err.response["Error"]["Message"].should.equal(
|
||||||
|
"Invalid parameter: Filter policy scope MessageAttributes does not support nested filter policy"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
filter_policy = {
|
||||||
|
"key_a": ["value_one"],
|
||||||
|
"key_b": ["value_two"],
|
||||||
|
"key_c": ["value_three"],
|
||||||
|
"key_d": ["value_four"],
|
||||||
|
"key_e": ["value_five"],
|
||||||
|
"key_f": ["value_six"],
|
||||||
|
}
|
||||||
|
conn.subscribe(
|
||||||
|
TopicArn=topic_arn,
|
||||||
|
Protocol="http",
|
||||||
|
Endpoint="http://example.com/",
|
||||||
|
Attributes={"FilterPolicy": json.dumps(filter_policy)},
|
||||||
|
)
|
||||||
|
except ClientError as err:
|
||||||
|
err.response["Error"]["Code"].should.equal("InvalidParameter")
|
||||||
|
err.response["Error"]["Message"].should.equal(
|
||||||
|
"Invalid parameter: FilterPolicy: Filter policy can not have more than 5 keys"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
nested_filter_policy = {
|
||||||
|
"key_a": {
|
||||||
|
"key_b": {
|
||||||
|
"key_c": ["value_one", "value_two", "value_three", "value_four"]
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"key_d": {"key_e": ["value_one", "value_two", "value_three"]},
|
||||||
|
"key_f": ["value_one", "value_two", "value_three"],
|
||||||
|
}
|
||||||
|
# The first array has four values in a three-level nested key, and the second has three values in a two-level
|
||||||
|
# nested key. The total combination is calculated as follows:
|
||||||
|
# 3 x 4 x 2 x 3 x 1 x 3 = 216
|
||||||
|
conn.subscribe(
|
||||||
|
TopicArn=topic_arn,
|
||||||
|
Protocol="http",
|
||||||
|
Endpoint="http://example.com/",
|
||||||
|
Attributes={
|
||||||
|
"FilterPolicyScope": "MessageBody",
|
||||||
|
"FilterPolicy": json.dumps(nested_filter_policy),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except ClientError as err:
|
||||||
|
err.response["Error"]["Code"].should.equal("InvalidParameter")
|
||||||
|
err.response["Error"]["Message"].should.equal(
|
||||||
|
"Invalid parameter: FilterPolicy: Filter policy is too complex"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@mock_sns
|
@mock_sns
|
||||||
def test_check_not_opted_out():
|
def test_check_not_opted_out():
|
||||||
|
17
tests/test_sns/test_utils.py
Normal file
17
tests/test_sns/test_utils.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
from moto.sns.utils import FilterPolicyMatcher
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_policy_matcher_scope_sanity_check():
|
||||||
|
with pytest.raises(FilterPolicyMatcher.CheckException):
|
||||||
|
FilterPolicyMatcher({}, "IncorrectFilterPolicyScope")
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_policy_matcher_empty_message_attributes():
|
||||||
|
matcher = FilterPolicyMatcher({}, None)
|
||||||
|
assert matcher.matches(None, "")
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_policy_matcher_empty_message_attributes_filtering_fail():
|
||||||
|
matcher = FilterPolicyMatcher({"store": ["test"]}, None)
|
||||||
|
assert not matcher.matches(None, "")
|
Loading…
x
Reference in New Issue
Block a user