sns: support FilterPolicyScope attribute, including filtering (#6262)
This commit is contained in:
parent
523225d6e9
commit
054ebcb326
@ -28,7 +28,12 @@ from .exceptions import (
|
||||
TooManyEntriesInBatchRequest,
|
||||
BatchEntryIdsNotDistinct,
|
||||
)
|
||||
from .utils import make_arn_for_topic, make_arn_for_subscription, is_e164
|
||||
from .utils import (
|
||||
make_arn_for_topic,
|
||||
make_arn_for_subscription,
|
||||
is_e164,
|
||||
FilterPolicyMatcher,
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_PAGE_SIZE = 100
|
||||
@ -192,6 +197,7 @@ class Subscription(BaseModel):
|
||||
self.arn = make_arn_for_subscription(self.topic.arn)
|
||||
self.attributes: Dict[str, Any] = {}
|
||||
self._filter_policy = None # filter policy as a dict, not json.
|
||||
self._filter_policy_matcher = None
|
||||
self.confirmed = False
|
||||
|
||||
def publish(
|
||||
@ -203,8 +209,9 @@ class Subscription(BaseModel):
|
||||
group_id: Optional[str] = None,
|
||||
deduplication_id: Optional[str] = None,
|
||||
) -> None:
|
||||
if not self._matches_filter_policy(message_attributes):
|
||||
return
|
||||
if self._filter_policy_matcher is not None:
|
||||
if not self._filter_policy_matcher.matches(message_attributes, message):
|
||||
return
|
||||
|
||||
if self.protocol == "sqs":
|
||||
queue_name = self.endpoint.split(":")[-1]
|
||||
@ -277,131 +284,6 @@ class Subscription(BaseModel):
|
||||
function_name, message, subject=subject, qualifier=qualifier
|
||||
)
|
||||
|
||||
def _matches_filter_policy(
|
||||
self, message_attributes: Optional[Dict[str, Any]]
|
||||
) -> bool:
|
||||
if not self._filter_policy:
|
||||
return True
|
||||
|
||||
if message_attributes is None:
|
||||
message_attributes = {}
|
||||
|
||||
def _field_match(
|
||||
field: str, rules: List[Any], message_attributes: Dict[str, Any]
|
||||
) -> bool:
|
||||
for rule in rules:
|
||||
# TODO: boolean value matching is not supported, SNS behavior unknown
|
||||
if isinstance(rule, str):
|
||||
if field not in message_attributes:
|
||||
return False
|
||||
if message_attributes[field]["Value"] == rule:
|
||||
return True
|
||||
try:
|
||||
json_data = json.loads(message_attributes[field]["Value"])
|
||||
if rule in json_data:
|
||||
return True
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
if isinstance(rule, (int, float)):
|
||||
if field not in message_attributes:
|
||||
return False
|
||||
if message_attributes[field]["Type"] == "Number":
|
||||
attribute_values = [message_attributes[field]["Value"]]
|
||||
elif message_attributes[field]["Type"] == "String.Array":
|
||||
try:
|
||||
attribute_values = json.loads(
|
||||
message_attributes[field]["Value"]
|
||||
)
|
||||
if not isinstance(attribute_values, list):
|
||||
attribute_values = [attribute_values]
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
for attribute_value in attribute_values:
|
||||
attribute_value = float(attribute_value)
|
||||
# Even the official documentation states a 5 digits of accuracy after the decimal point for numerics, in reality it is 6
|
||||
# https://docs.aws.amazon.com/sns/latest/dg/sns-subscription-filter-policies.html#subscription-filter-policy-constraints
|
||||
if int(attribute_value * 1000000) == int(rule * 1000000):
|
||||
return True
|
||||
if isinstance(rule, dict):
|
||||
keyword = list(rule.keys())[0]
|
||||
value = list(rule.values())[0]
|
||||
if keyword == "exists":
|
||||
if value and field in message_attributes:
|
||||
return True
|
||||
elif not value and field not in message_attributes:
|
||||
return True
|
||||
elif keyword == "prefix" and isinstance(value, str):
|
||||
if field in message_attributes:
|
||||
attr = message_attributes[field]
|
||||
if attr["Type"] == "String" and attr["Value"].startswith(
|
||||
value
|
||||
):
|
||||
return True
|
||||
elif keyword == "anything-but":
|
||||
if field not in message_attributes:
|
||||
continue
|
||||
attr = message_attributes[field]
|
||||
if isinstance(value, dict):
|
||||
# We can combine anything-but with the prefix-filter
|
||||
anything_but_key = list(value.keys())[0]
|
||||
anything_but_val = list(value.values())[0]
|
||||
if anything_but_key != "prefix":
|
||||
return False
|
||||
if attr["Type"] == "String":
|
||||
actual_values = [attr["Value"]]
|
||||
else:
|
||||
actual_values = [v for v in attr["Value"]]
|
||||
if all(
|
||||
[
|
||||
not v.startswith(anything_but_val)
|
||||
for v in actual_values
|
||||
]
|
||||
):
|
||||
return True
|
||||
else:
|
||||
undesired_values = (
|
||||
[value] if isinstance(value, str) else value
|
||||
)
|
||||
if attr["Type"] == "Number":
|
||||
actual_values = [float(attr["Value"])]
|
||||
elif attr["Type"] == "String":
|
||||
actual_values = [attr["Value"]]
|
||||
else:
|
||||
actual_values = [v for v in attr["Value"]]
|
||||
if all([v not in undesired_values for v in actual_values]):
|
||||
return True
|
||||
elif keyword == "numeric" and isinstance(value, list):
|
||||
# [(< x), (=, y), (>=, z)]
|
||||
numeric_ranges = zip(value[0::2], value[1::2])
|
||||
if (
|
||||
message_attributes.get(field, {}).get("Type", "")
|
||||
== "Number"
|
||||
):
|
||||
msg_value = float(message_attributes[field]["Value"])
|
||||
matches = []
|
||||
for operator, test_value in numeric_ranges:
|
||||
test_value = test_value
|
||||
if operator == ">":
|
||||
matches.append((msg_value > test_value))
|
||||
if operator == ">=":
|
||||
matches.append((msg_value >= test_value))
|
||||
if operator == "=":
|
||||
matches.append((msg_value == test_value))
|
||||
if operator == "<":
|
||||
matches.append((msg_value < test_value))
|
||||
if operator == "<=":
|
||||
matches.append((msg_value <= test_value))
|
||||
return all(matches)
|
||||
return False
|
||||
|
||||
return all(
|
||||
_field_match(field, rules, message_attributes)
|
||||
for field, rules in self._filter_policy.items()
|
||||
)
|
||||
|
||||
def get_post_data(
|
||||
self,
|
||||
message: str,
|
||||
@ -848,6 +730,7 @@ class SNSBackend(BaseBackend):
|
||||
"RawMessageDelivery",
|
||||
"DeliveryPolicy",
|
||||
"FilterPolicy",
|
||||
"FilterPolicyScope",
|
||||
"RedrivePolicy",
|
||||
"SubscriptionRoleArn",
|
||||
]:
|
||||
@ -859,26 +742,90 @@ class SNSBackend(BaseBackend):
|
||||
raise SNSNotFoundError(f"Subscription with arn {arn} not found")
|
||||
subscription = _subscription[0]
|
||||
|
||||
subscription.attributes[name] = value
|
||||
|
||||
if name == "FilterPolicy":
|
||||
filter_policy = json.loads(value)
|
||||
self._validate_filter_policy(filter_policy)
|
||||
# we validate the filter policy differently depending on the scope
|
||||
# we need to always set the scope first
|
||||
filter_policy_scope = subscription.attributes.get("FilterPolicyScope")
|
||||
self._validate_filter_policy(filter_policy, scope=filter_policy_scope)
|
||||
subscription._filter_policy = filter_policy
|
||||
subscription._filter_policy_matcher = FilterPolicyMatcher(
|
||||
filter_policy, filter_policy_scope
|
||||
)
|
||||
|
||||
def _validate_filter_policy(self, value: Any) -> None:
|
||||
# TODO: extend validation checks
|
||||
subscription.attributes[name] = value
|
||||
|
||||
def _validate_filter_policy(self, value: Any, scope: str) -> None:
|
||||
combinations = 1
|
||||
for rules in value.values():
|
||||
combinations *= len(rules)
|
||||
# Even the official documentation states the total combination of values must not exceed 100, in reality it is 150
|
||||
# https://docs.aws.amazon.com/sns/latest/dg/sns-subscription-filter-policies.html#subscription-filter-policy-constraints
|
||||
|
||||
def aggregate_rules(
|
||||
filter_policy: Dict[str, Any], depth: int = 1
|
||||
) -> List[List[Any]]:
|
||||
"""
|
||||
This method evaluate the filter policy recursively, and returns only a list of lists of rules.
|
||||
It also calculates the combinations of rules, calculated depending on the nesting of the rules.
|
||||
Example:
|
||||
nested_filter_policy = {
|
||||
"key_a": {
|
||||
"key_b": {
|
||||
"key_c": ["value_one", "value_two", "value_three", "value_four"]
|
||||
}
|
||||
},
|
||||
"key_d": {
|
||||
"key_e": ["value_one", "value_two", "value_three"]
|
||||
}
|
||||
}
|
||||
This function then iterates on the values of the top level keys of the filter policy: ("key_a", "key_d")
|
||||
If the iterated value is not a list, it means it is a nested property. If the scope is `MessageBody`, it is
|
||||
allowed, we call this method on the value, adding a level to the depth to keep track on how deep the key is.
|
||||
If the value is a list, it means it contains rules: we will append this list of rules in _rules, and
|
||||
calculate the combinations it adds.
|
||||
For the example filter policy containing nested properties, we calculate it this way
|
||||
The first array has four values in a three-level nested key, and the second has three values in a two-level
|
||||
nested key. 3 x 4 x 2 x 3 = 72
|
||||
The return value would be:
|
||||
[["value_one", "value_two", "value_three", "value_four"], ["value_one", "value_two", "value_three"]]
|
||||
It allows us to later iterate of the list of rules in an easy way, to verify its conditions.
|
||||
|
||||
:param filter_policy: a dict, starting at the FilterPolicy
|
||||
:param depth: the depth/level of the rules we are evaluating
|
||||
:return: a list of lists of rules
|
||||
"""
|
||||
nonlocal combinations
|
||||
_rules = []
|
||||
for key, _value in filter_policy.items():
|
||||
if isinstance(_value, dict):
|
||||
if scope == "MessageBody":
|
||||
# From AWS docs: "unlike attribute-based policies, payload-based policies support property nesting."
|
||||
_rules.extend(aggregate_rules(_value, depth=depth + 1))
|
||||
else:
|
||||
raise SNSInvalidParameter(
|
||||
"Invalid parameter: Filter policy scope MessageAttributes does not support nested filter policy"
|
||||
)
|
||||
elif isinstance(_value, list):
|
||||
_rules.append(_value)
|
||||
combinations = combinations * len(_value) * depth
|
||||
else:
|
||||
raise SNSInvalidParameter(
|
||||
f'Invalid parameter: FilterPolicy: "{key}" must be an object or an array'
|
||||
)
|
||||
return _rules
|
||||
|
||||
# A filter policy can have a maximum of five attribute names. For a nested policy, only parent keys are counted.
|
||||
if len(value.values()) > 5:
|
||||
raise SNSInvalidParameter(
|
||||
"Invalid parameter: FilterPolicy: Filter policy can not have more than 5 keys"
|
||||
)
|
||||
|
||||
aggregated_rules = aggregate_rules(value)
|
||||
# For the complexity of the filter policy, the total combination of values must not exceed 150.
|
||||
# https://docs.aws.amazon.com/sns/latest/dg/subscription-filter-policy-constraints.html
|
||||
if combinations > 150:
|
||||
raise SNSInvalidParameter(
|
||||
"Invalid parameter: FilterPolicy: Filter policy is too complex"
|
||||
)
|
||||
|
||||
for rules in value.values():
|
||||
for rules in aggregated_rules:
|
||||
for rule in rules:
|
||||
if rule is None:
|
||||
continue
|
||||
|
@ -226,6 +226,13 @@ class SNSResponse(BaseResponse):
|
||||
subscription = self.backend.subscribe(topic_arn, endpoint, protocol)
|
||||
|
||||
if attributes is not None:
|
||||
# We need to set the FilterPolicyScope first, as the validation of the FilterPolicy will depend on it
|
||||
if "FilterPolicyScope" in attributes:
|
||||
filter_policy_scope = attributes.pop("FilterPolicyScope")
|
||||
self.backend.set_subscription_attributes(
|
||||
subscription.arn, "FilterPolicyScope", filter_policy_scope
|
||||
)
|
||||
|
||||
for attr_name, attr_value in attributes.items():
|
||||
self.backend.set_subscription_attributes(
|
||||
subscription.arn, attr_name, attr_value
|
||||
|
@ -1,5 +1,7 @@
|
||||
import re
|
||||
from moto.moto_api._internal import mock_random
|
||||
from typing import Any, Dict, List, Iterable, Optional, Tuple, Union, Callable
|
||||
import json
|
||||
|
||||
E164_REGEX = re.compile(r"^\+?[1-9]\d{1,14}$")
|
||||
|
||||
@ -15,3 +17,376 @@ def make_arn_for_subscription(topic_arn: str) -> str:
|
||||
|
||||
def is_e164(number: str) -> bool:
|
||||
return E164_REGEX.match(number) is not None
|
||||
|
||||
|
||||
class FilterPolicyMatcher:
|
||||
class CheckException(Exception):
|
||||
pass
|
||||
|
||||
def __init__(self, filter_policy: Dict[str, Any], filter_policy_scope: str):
|
||||
self.filter_policy = filter_policy
|
||||
self.filter_policy_scope = (
|
||||
filter_policy_scope
|
||||
if filter_policy_scope is not None
|
||||
else "MessageAttributes"
|
||||
)
|
||||
|
||||
if self.filter_policy_scope not in ("MessageAttributes", "MessageBody"):
|
||||
raise FilterPolicyMatcher.CheckException(
|
||||
f"Unsupported filter_policy_scope: {filter_policy_scope}"
|
||||
)
|
||||
|
||||
def matches(
|
||||
self, message_attributes: Optional[Dict[str, Any]], message: str
|
||||
) -> bool:
|
||||
if not self.filter_policy:
|
||||
return True
|
||||
|
||||
if self.filter_policy_scope == "MessageAttributes":
|
||||
if message_attributes is None:
|
||||
message_attributes = {}
|
||||
|
||||
return self._attributes_based_match(message_attributes)
|
||||
else:
|
||||
try:
|
||||
message_dict = json.loads(message)
|
||||
except ValueError:
|
||||
return False
|
||||
return self._body_based_match(message_dict)
|
||||
|
||||
def _attributes_based_match(self, message_attributes: Dict[str, Any]) -> bool:
|
||||
return all(
|
||||
FilterPolicyMatcher._field_match(field, rules, message_attributes)
|
||||
for field, rules in self.filter_policy.items()
|
||||
)
|
||||
|
||||
def _body_based_match(self, message_dict: Dict[str, Any]) -> bool:
|
||||
try:
|
||||
checks = self._compute_body_checks(self.filter_policy, message_dict)
|
||||
except FilterPolicyMatcher.CheckException:
|
||||
return False
|
||||
|
||||
return self._perform_body_checks(checks)
|
||||
|
||||
def _perform_body_checks(self, check: Any) -> bool:
|
||||
# If the checks are a list, only a single elem has to pass
|
||||
# otherwise all the entries have to pass
|
||||
|
||||
if isinstance(check, tuple):
|
||||
if len(check) == 2:
|
||||
# (any|all, checks)
|
||||
aggregate_func, checks = check
|
||||
return aggregate_func(
|
||||
self._perform_body_checks(single_check) for single_check in checks
|
||||
)
|
||||
elif len(check) == 3:
|
||||
field, rules, dict_body = check
|
||||
return FilterPolicyMatcher._field_match(field, rules, dict_body, False)
|
||||
|
||||
raise FilterPolicyMatcher.CheckException(f"Check is not a tuple: {str(check)}")
|
||||
|
||||
def _compute_body_checks(
|
||||
self,
|
||||
filter_policy: Dict[str, Union[Dict[str, Any], List[Any]]],
|
||||
message_body: Union[Dict[str, Any], List[Any]],
|
||||
) -> Tuple[Callable[[Iterable[Any]], bool], Any]:
|
||||
"""
|
||||
Generate (possibly nested) list of checks to be performed based on the filter policy
|
||||
Returned list is of format (any|all, checks), where first elem defines what aggregation should be used in checking
|
||||
and the second argument is a list containing sublists of the same format or concrete checks (field, rule, body): Tuple[str, List[Any], Dict[str, Any]]
|
||||
|
||||
All the checks returned by this function will only require one-level-deep entry into dict in _field_match function
|
||||
This is done this way to simplify the actual check logic and keep it as close as possible between MessageAttributes and MessageBody
|
||||
|
||||
Given message_body:
|
||||
{"Records": [
|
||||
{
|
||||
"eventName": "ObjectCreated:Put",
|
||||
},
|
||||
{
|
||||
"eventName": "ObjectCreated:Delete",
|
||||
},
|
||||
]}
|
||||
|
||||
and filter policy:
|
||||
{"Records": {
|
||||
"eventName": [{"prefix": "ObjectCreated:"}],
|
||||
}}
|
||||
|
||||
the following check list would be computed:
|
||||
(<built-in function all>, (
|
||||
(<built-in function all>, (
|
||||
(<built-in function any>, (
|
||||
('eventName', [{'prefix': 'ObjectCreated:'}], {'eventName': 'ObjectCreated:Put'}),
|
||||
('eventName', [{'prefix': 'ObjectCreated:'}], {'eventName': 'ObjectCreated:Delete'}))
|
||||
),
|
||||
)
|
||||
),))
|
||||
"""
|
||||
rules = []
|
||||
for filter_key, filter_value in filter_policy.items():
|
||||
if isinstance(filter_value, dict):
|
||||
if isinstance(message_body, dict):
|
||||
message_value = message_body.get(filter_key)
|
||||
if message_value is not None:
|
||||
rules.append(
|
||||
self._compute_body_checks(filter_value, message_value)
|
||||
)
|
||||
else:
|
||||
raise FilterPolicyMatcher.CheckException
|
||||
elif isinstance(message_body, list):
|
||||
subchecks = []
|
||||
for entry in message_body:
|
||||
subchecks.append(
|
||||
self._compute_body_checks(filter_policy, entry)
|
||||
)
|
||||
rules.append((any, tuple(subchecks)))
|
||||
else:
|
||||
raise FilterPolicyMatcher.CheckException
|
||||
|
||||
elif isinstance(filter_value, list):
|
||||
# These are the real rules, same as in MessageAttributes case
|
||||
|
||||
concrete_checks = []
|
||||
if isinstance(message_body, dict):
|
||||
if message_body is not None:
|
||||
concrete_checks.append((filter_key, filter_value, message_body))
|
||||
else:
|
||||
raise FilterPolicyMatcher.CheckException
|
||||
elif isinstance(message_body, list):
|
||||
# Apply policy to each element of the list, pass if at list one element matches
|
||||
for list_elem in message_body:
|
||||
concrete_checks.append((filter_key, filter_value, list_elem))
|
||||
else:
|
||||
raise FilterPolicyMatcher.CheckException
|
||||
rules.append((any, tuple(concrete_checks)))
|
||||
else:
|
||||
raise FilterPolicyMatcher.CheckException
|
||||
|
||||
return (all, tuple(rules))
|
||||
|
||||
@staticmethod
|
||||
def _field_match( # type: ignore # decorated function contains type Any
|
||||
field: str,
|
||||
rules: List[Any],
|
||||
dict_body: Dict[str, Any],
|
||||
attributes_based_check: bool = True,
|
||||
) -> bool:
|
||||
# dict_body = MessageAttributes if attributes_based_check is True
|
||||
# otherwise it's the cut-out part of the MessageBody (so only single-level nesting must be supported)
|
||||
|
||||
# Iterate over every rule from the list of rules
|
||||
# At least one rule has to match the field for the function to return a match
|
||||
|
||||
def _str_exact_match(value: str, rule: Union[str, List[str]]) -> bool:
|
||||
if value == rule:
|
||||
return True
|
||||
|
||||
if isinstance(value, list):
|
||||
if rule in value:
|
||||
return True
|
||||
|
||||
try:
|
||||
json_data = json.loads(value)
|
||||
if rule in json_data:
|
||||
return True
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
def _number_match(values: List[float], rule: float) -> bool:
|
||||
for value in values:
|
||||
# Even the official documentation states a 5 digits of accuracy after the decimal point for numerics, in reality it is 6
|
||||
# https://docs.aws.amazon.com/sns/latest/dg/sns-subscription-filter-policies.html#subscription-filter-policy-constraints
|
||||
if int(value * 1000000) == int(rule * 1000000):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _exists_match(
|
||||
should_exist: bool, field: str, dict_body: Dict[str, Any]
|
||||
) -> bool:
|
||||
if should_exist and field in dict_body:
|
||||
return True
|
||||
elif not should_exist and field not in dict_body:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _prefix_match(prefix: str, value: str) -> bool:
|
||||
return value.startswith(prefix)
|
||||
|
||||
def _anything_but_match(
|
||||
filter_value: Union[Dict[str, Any], List[str], str],
|
||||
actual_values: List[str],
|
||||
) -> bool:
|
||||
if isinstance(filter_value, dict):
|
||||
# We can combine anything-but with the prefix-filter
|
||||
anything_but_key = list(filter_value.keys())[0]
|
||||
anything_but_val = list(filter_value.values())[0]
|
||||
if anything_but_key != "prefix":
|
||||
return False
|
||||
if all([not v.startswith(anything_but_val) for v in actual_values]):
|
||||
return True
|
||||
else:
|
||||
undesired_values = (
|
||||
[filter_value] if isinstance(filter_value, str) else filter_value
|
||||
)
|
||||
if all([v not in undesired_values for v in actual_values]):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _numeric_match(
|
||||
numeric_ranges: Iterable[Tuple[str, float]], numeric_value: float
|
||||
) -> bool:
|
||||
# numeric_ranges' format:
|
||||
# [(< x), (=, y), (>=, z)]
|
||||
msg_value = numeric_value
|
||||
matches = []
|
||||
for operator, test_value in numeric_ranges:
|
||||
if operator == ">":
|
||||
matches.append((msg_value > test_value))
|
||||
if operator == ">=":
|
||||
matches.append((msg_value >= test_value))
|
||||
if operator == "=":
|
||||
matches.append((msg_value == test_value))
|
||||
if operator == "<":
|
||||
matches.append((msg_value < test_value))
|
||||
if operator == "<=":
|
||||
matches.append((msg_value <= test_value))
|
||||
return all(matches)
|
||||
|
||||
for rule in rules:
|
||||
# TODO: boolean value matching is not supported, SNS behavior unknown
|
||||
if isinstance(rule, str):
|
||||
if attributes_based_check:
|
||||
if field not in dict_body:
|
||||
return False
|
||||
if _str_exact_match(dict_body[field]["Value"], rule):
|
||||
return True
|
||||
else:
|
||||
if field not in dict_body:
|
||||
return False
|
||||
if _str_exact_match(dict_body[field], rule):
|
||||
return True
|
||||
|
||||
if isinstance(rule, (int, float)):
|
||||
if attributes_based_check:
|
||||
if field not in dict_body:
|
||||
return False
|
||||
if dict_body[field]["Type"] == "Number":
|
||||
attribute_values = [dict_body[field]["Value"]]
|
||||
elif dict_body[field]["Type"] == "String.Array":
|
||||
try:
|
||||
attribute_values = json.loads(dict_body[field]["Value"])
|
||||
if not isinstance(attribute_values, list):
|
||||
attribute_values = [attribute_values]
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
values = [float(value) for value in attribute_values]
|
||||
if _number_match(values, rule):
|
||||
return True
|
||||
else:
|
||||
if field not in dict_body:
|
||||
return False
|
||||
|
||||
if isinstance(dict_body[field], (int, float)):
|
||||
values = [dict_body[field]]
|
||||
elif isinstance(dict_body[field], list):
|
||||
values = [float(value) for value in dict_body[field]]
|
||||
else:
|
||||
return False
|
||||
|
||||
if _number_match(values, rule):
|
||||
return True
|
||||
|
||||
if isinstance(rule, dict):
|
||||
keyword = list(rule.keys())[0]
|
||||
value = list(rule.values())[0]
|
||||
if keyword == "exists":
|
||||
if attributes_based_check:
|
||||
if _exists_match(value, field, dict_body):
|
||||
return True
|
||||
else:
|
||||
if _exists_match(value, field, dict_body):
|
||||
return True
|
||||
elif keyword == "prefix" and isinstance(value, str):
|
||||
if attributes_based_check:
|
||||
if field in dict_body:
|
||||
attr = dict_body[field]
|
||||
if attr["Type"] == "String":
|
||||
if _prefix_match(value, attr["Value"]):
|
||||
return True
|
||||
else:
|
||||
if field in dict_body:
|
||||
if _prefix_match(value, dict_body[field]):
|
||||
return True
|
||||
|
||||
elif keyword == "anything-but":
|
||||
if attributes_based_check:
|
||||
if field not in dict_body:
|
||||
return False
|
||||
attr = dict_body[field]
|
||||
if isinstance(value, dict):
|
||||
# We can combine anything-but with the prefix-filter
|
||||
if attr["Type"] == "String":
|
||||
actual_values = [attr["Value"]]
|
||||
else:
|
||||
actual_values = [v for v in attr["Value"]]
|
||||
else:
|
||||
if attr["Type"] == "Number":
|
||||
actual_values = [float(attr["Value"])]
|
||||
elif attr["Type"] == "String":
|
||||
actual_values = [attr["Value"]]
|
||||
else:
|
||||
actual_values = [v for v in attr["Value"]]
|
||||
|
||||
if _anything_but_match(value, actual_values):
|
||||
return True
|
||||
else:
|
||||
if field not in dict_body:
|
||||
return False
|
||||
attr = dict_body[field]
|
||||
if isinstance(value, dict):
|
||||
if isinstance(attr, str):
|
||||
actual_values = [attr]
|
||||
elif isinstance(attr, list):
|
||||
actual_values = attr
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
if isinstance(attr, (int, float, str)):
|
||||
actual_values = [attr]
|
||||
elif isinstance(attr, list):
|
||||
actual_values = attr
|
||||
else:
|
||||
return False
|
||||
|
||||
if _anything_but_match(value, actual_values):
|
||||
return True
|
||||
|
||||
elif keyword == "numeric" and isinstance(value, list):
|
||||
if attributes_based_check:
|
||||
if dict_body.get(field, {}).get("Type", "") == "Number":
|
||||
|
||||
checks = value
|
||||
numeric_ranges = zip(checks[0::2], checks[1::2])
|
||||
if _numeric_match(
|
||||
numeric_ranges, float(dict_body[field]["Value"])
|
||||
):
|
||||
return True
|
||||
else:
|
||||
if field not in dict_body:
|
||||
return False
|
||||
|
||||
checks = value
|
||||
numeric_ranges = zip(checks[0::2], checks[1::2])
|
||||
if _numeric_match(numeric_ranges, dict_body[field]):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -250,7 +250,6 @@ def test_creating_subscription_with_attributes():
|
||||
filter_policy = json.dumps(
|
||||
{
|
||||
"store": ["example_corp"],
|
||||
"event": ["order_cancelled"],
|
||||
"encrypted": [False],
|
||||
"customer_interests": ["basketball", "baseball"],
|
||||
"price": [100, 100.12],
|
||||
@ -371,7 +370,6 @@ def test_set_subscription_attributes():
|
||||
filter_policy = json.dumps(
|
||||
{
|
||||
"store": ["example_corp"],
|
||||
"event": ["order_cancelled"],
|
||||
"encrypted": [False],
|
||||
"customer_interests": ["basketball", "baseball"],
|
||||
"price": [100, 100.12],
|
||||
@ -390,6 +388,17 @@ def test_set_subscription_attributes():
|
||||
attrs["Attributes"]["DeliveryPolicy"].should.equal(delivery_policy)
|
||||
attrs["Attributes"]["FilterPolicy"].should.equal(filter_policy)
|
||||
|
||||
filter_policy_scope = "MessageBody"
|
||||
conn.set_subscription_attributes(
|
||||
SubscriptionArn=subscription_arn,
|
||||
AttributeName="FilterPolicyScope",
|
||||
AttributeValue=filter_policy_scope,
|
||||
)
|
||||
|
||||
attrs = conn.get_subscription_attributes(SubscriptionArn=subscription_arn)
|
||||
|
||||
attrs["Attributes"]["FilterPolicyScope"].should.equal(filter_policy_scope)
|
||||
|
||||
# not existing subscription
|
||||
with pytest.raises(ClientError):
|
||||
conn.set_subscription_attributes(
|
||||
@ -641,6 +650,70 @@ def test_subscribe_invalid_filter_policy():
|
||||
"Invalid parameter: Attributes Reason: FilterPolicy: Value of < must be numeric\n at ..."
|
||||
)
|
||||
|
||||
try:
|
||||
conn.subscribe(
|
||||
TopicArn=topic_arn,
|
||||
Protocol="http",
|
||||
Endpoint="http://example.com/",
|
||||
Attributes={
|
||||
"FilterPolicy": json.dumps({"store": {"key": [{"exists": None}]}})
|
||||
},
|
||||
)
|
||||
except ClientError as err:
|
||||
err.response["Error"]["Code"].should.equal("InvalidParameter")
|
||||
err.response["Error"]["Message"].should.equal(
|
||||
"Invalid parameter: Filter policy scope MessageAttributes does not support nested filter policy"
|
||||
)
|
||||
|
||||
try:
|
||||
filter_policy = {
|
||||
"key_a": ["value_one"],
|
||||
"key_b": ["value_two"],
|
||||
"key_c": ["value_three"],
|
||||
"key_d": ["value_four"],
|
||||
"key_e": ["value_five"],
|
||||
"key_f": ["value_six"],
|
||||
}
|
||||
conn.subscribe(
|
||||
TopicArn=topic_arn,
|
||||
Protocol="http",
|
||||
Endpoint="http://example.com/",
|
||||
Attributes={"FilterPolicy": json.dumps(filter_policy)},
|
||||
)
|
||||
except ClientError as err:
|
||||
err.response["Error"]["Code"].should.equal("InvalidParameter")
|
||||
err.response["Error"]["Message"].should.equal(
|
||||
"Invalid parameter: FilterPolicy: Filter policy can not have more than 5 keys"
|
||||
)
|
||||
|
||||
try:
|
||||
nested_filter_policy = {
|
||||
"key_a": {
|
||||
"key_b": {
|
||||
"key_c": ["value_one", "value_two", "value_three", "value_four"]
|
||||
},
|
||||
},
|
||||
"key_d": {"key_e": ["value_one", "value_two", "value_three"]},
|
||||
"key_f": ["value_one", "value_two", "value_three"],
|
||||
}
|
||||
# The first array has four values in a three-level nested key, and the second has three values in a two-level
|
||||
# nested key. The total combination is calculated as follows:
|
||||
# 3 x 4 x 2 x 3 x 1 x 3 = 216
|
||||
conn.subscribe(
|
||||
TopicArn=topic_arn,
|
||||
Protocol="http",
|
||||
Endpoint="http://example.com/",
|
||||
Attributes={
|
||||
"FilterPolicyScope": "MessageBody",
|
||||
"FilterPolicy": json.dumps(nested_filter_policy),
|
||||
},
|
||||
)
|
||||
except ClientError as err:
|
||||
err.response["Error"]["Code"].should.equal("InvalidParameter")
|
||||
err.response["Error"]["Message"].should.equal(
|
||||
"Invalid parameter: FilterPolicy: Filter policy is too complex"
|
||||
)
|
||||
|
||||
|
||||
@mock_sns
|
||||
def test_check_not_opted_out():
|
||||
|
17
tests/test_sns/test_utils.py
Normal file
17
tests/test_sns/test_utils.py
Normal file
@ -0,0 +1,17 @@
|
||||
from moto.sns.utils import FilterPolicyMatcher
|
||||
import pytest
|
||||
|
||||
|
||||
def test_filter_policy_matcher_scope_sanity_check():
|
||||
with pytest.raises(FilterPolicyMatcher.CheckException):
|
||||
FilterPolicyMatcher({}, "IncorrectFilterPolicyScope")
|
||||
|
||||
|
||||
def test_filter_policy_matcher_empty_message_attributes():
|
||||
matcher = FilterPolicyMatcher({}, None)
|
||||
assert matcher.matches(None, "")
|
||||
|
||||
|
||||
def test_filter_policy_matcher_empty_message_attributes_filtering_fail():
|
||||
matcher = FilterPolicyMatcher({"store": ["test"]}, None)
|
||||
assert not matcher.matches(None, "")
|
Loading…
Reference in New Issue
Block a user