diff --git a/moto/sns/models.py b/moto/sns/models.py index fd43eb181..545d72227 100644 --- a/moto/sns/models.py +++ b/moto/sns/models.py @@ -28,7 +28,12 @@ from .exceptions import ( TooManyEntriesInBatchRequest, BatchEntryIdsNotDistinct, ) -from .utils import make_arn_for_topic, make_arn_for_subscription, is_e164 +from .utils import ( + make_arn_for_topic, + make_arn_for_subscription, + is_e164, + FilterPolicyMatcher, +) DEFAULT_PAGE_SIZE = 100 @@ -192,6 +197,7 @@ class Subscription(BaseModel): self.arn = make_arn_for_subscription(self.topic.arn) self.attributes: Dict[str, Any] = {} self._filter_policy = None # filter policy as a dict, not json. + self._filter_policy_matcher = None self.confirmed = False def publish( @@ -203,8 +209,9 @@ class Subscription(BaseModel): group_id: Optional[str] = None, deduplication_id: Optional[str] = None, ) -> None: - if not self._matches_filter_policy(message_attributes): - return + if self._filter_policy_matcher is not None: + if not self._filter_policy_matcher.matches(message_attributes, message): + return if self.protocol == "sqs": queue_name = self.endpoint.split(":")[-1] @@ -277,131 +284,6 @@ class Subscription(BaseModel): function_name, message, subject=subject, qualifier=qualifier ) - def _matches_filter_policy( - self, message_attributes: Optional[Dict[str, Any]] - ) -> bool: - if not self._filter_policy: - return True - - if message_attributes is None: - message_attributes = {} - - def _field_match( - field: str, rules: List[Any], message_attributes: Dict[str, Any] - ) -> bool: - for rule in rules: - # TODO: boolean value matching is not supported, SNS behavior unknown - if isinstance(rule, str): - if field not in message_attributes: - return False - if message_attributes[field]["Value"] == rule: - return True - try: - json_data = json.loads(message_attributes[field]["Value"]) - if rule in json_data: - return True - except (ValueError, TypeError): - pass - if isinstance(rule, (int, float)): - if field not in message_attributes: - return False - if message_attributes[field]["Type"] == "Number": - attribute_values = [message_attributes[field]["Value"]] - elif message_attributes[field]["Type"] == "String.Array": - try: - attribute_values = json.loads( - message_attributes[field]["Value"] - ) - if not isinstance(attribute_values, list): - attribute_values = [attribute_values] - except (ValueError, TypeError): - return False - else: - return False - - for attribute_value in attribute_values: - attribute_value = float(attribute_value) - # Even the official documentation states a 5 digits of accuracy after the decimal point for numerics, in reality it is 6 - # https://docs.aws.amazon.com/sns/latest/dg/sns-subscription-filter-policies.html#subscription-filter-policy-constraints - if int(attribute_value * 1000000) == int(rule * 1000000): - return True - if isinstance(rule, dict): - keyword = list(rule.keys())[0] - value = list(rule.values())[0] - if keyword == "exists": - if value and field in message_attributes: - return True - elif not value and field not in message_attributes: - return True - elif keyword == "prefix" and isinstance(value, str): - if field in message_attributes: - attr = message_attributes[field] - if attr["Type"] == "String" and attr["Value"].startswith( - value - ): - return True - elif keyword == "anything-but": - if field not in message_attributes: - continue - attr = message_attributes[field] - if isinstance(value, dict): - # We can combine anything-but with the prefix-filter - anything_but_key = list(value.keys())[0] - anything_but_val = list(value.values())[0] - if anything_but_key != "prefix": - return False - if attr["Type"] == "String": - actual_values = [attr["Value"]] - else: - actual_values = [v for v in attr["Value"]] - if all( - [ - not v.startswith(anything_but_val) - for v in actual_values - ] - ): - return True - else: - undesired_values = ( - [value] if isinstance(value, str) else value - ) - if attr["Type"] == "Number": - actual_values = [float(attr["Value"])] - elif attr["Type"] == "String": - actual_values = [attr["Value"]] - else: - actual_values = [v for v in attr["Value"]] - if all([v not in undesired_values for v in actual_values]): - return True - elif keyword == "numeric" and isinstance(value, list): - # [(< x), (=, y), (>=, z)] - numeric_ranges = zip(value[0::2], value[1::2]) - if ( - message_attributes.get(field, {}).get("Type", "") - == "Number" - ): - msg_value = float(message_attributes[field]["Value"]) - matches = [] - for operator, test_value in numeric_ranges: - test_value = test_value - if operator == ">": - matches.append((msg_value > test_value)) - if operator == ">=": - matches.append((msg_value >= test_value)) - if operator == "=": - matches.append((msg_value == test_value)) - if operator == "<": - matches.append((msg_value < test_value)) - if operator == "<=": - matches.append((msg_value <= test_value)) - return all(matches) - return False - - return all( - _field_match(field, rules, message_attributes) - for field, rules in self._filter_policy.items() - ) - def get_post_data( self, message: str, @@ -848,6 +730,7 @@ class SNSBackend(BaseBackend): "RawMessageDelivery", "DeliveryPolicy", "FilterPolicy", + "FilterPolicyScope", "RedrivePolicy", "SubscriptionRoleArn", ]: @@ -859,26 +742,90 @@ class SNSBackend(BaseBackend): raise SNSNotFoundError(f"Subscription with arn {arn} not found") subscription = _subscription[0] - subscription.attributes[name] = value - if name == "FilterPolicy": filter_policy = json.loads(value) - self._validate_filter_policy(filter_policy) + # we validate the filter policy differently depending on the scope + # we need to always set the scope first + filter_policy_scope = subscription.attributes.get("FilterPolicyScope") + self._validate_filter_policy(filter_policy, scope=filter_policy_scope) subscription._filter_policy = filter_policy + subscription._filter_policy_matcher = FilterPolicyMatcher( + filter_policy, filter_policy_scope + ) - def _validate_filter_policy(self, value: Any) -> None: - # TODO: extend validation checks + subscription.attributes[name] = value + + def _validate_filter_policy(self, value: Any, scope: str) -> None: combinations = 1 - for rules in value.values(): - combinations *= len(rules) - # Even the official documentation states the total combination of values must not exceed 100, in reality it is 150 - # https://docs.aws.amazon.com/sns/latest/dg/sns-subscription-filter-policies.html#subscription-filter-policy-constraints + + def aggregate_rules( + filter_policy: Dict[str, Any], depth: int = 1 + ) -> List[List[Any]]: + """ + This method evaluate the filter policy recursively, and returns only a list of lists of rules. + It also calculates the combinations of rules, calculated depending on the nesting of the rules. + Example: + nested_filter_policy = { + "key_a": { + "key_b": { + "key_c": ["value_one", "value_two", "value_three", "value_four"] + } + }, + "key_d": { + "key_e": ["value_one", "value_two", "value_three"] + } + } + This function then iterates on the values of the top level keys of the filter policy: ("key_a", "key_d") + If the iterated value is not a list, it means it is a nested property. If the scope is `MessageBody`, it is + allowed, we call this method on the value, adding a level to the depth to keep track on how deep the key is. + If the value is a list, it means it contains rules: we will append this list of rules in _rules, and + calculate the combinations it adds. + For the example filter policy containing nested properties, we calculate it this way + The first array has four values in a three-level nested key, and the second has three values in a two-level + nested key. 3 x 4 x 2 x 3 = 72 + The return value would be: + [["value_one", "value_two", "value_three", "value_four"], ["value_one", "value_two", "value_three"]] + It allows us to later iterate of the list of rules in an easy way, to verify its conditions. + + :param filter_policy: a dict, starting at the FilterPolicy + :param depth: the depth/level of the rules we are evaluating + :return: a list of lists of rules + """ + nonlocal combinations + _rules = [] + for key, _value in filter_policy.items(): + if isinstance(_value, dict): + if scope == "MessageBody": + # From AWS docs: "unlike attribute-based policies, payload-based policies support property nesting." + _rules.extend(aggregate_rules(_value, depth=depth + 1)) + else: + raise SNSInvalidParameter( + "Invalid parameter: Filter policy scope MessageAttributes does not support nested filter policy" + ) + elif isinstance(_value, list): + _rules.append(_value) + combinations = combinations * len(_value) * depth + else: + raise SNSInvalidParameter( + f'Invalid parameter: FilterPolicy: "{key}" must be an object or an array' + ) + return _rules + + # A filter policy can have a maximum of five attribute names. For a nested policy, only parent keys are counted. + if len(value.values()) > 5: + raise SNSInvalidParameter( + "Invalid parameter: FilterPolicy: Filter policy can not have more than 5 keys" + ) + + aggregated_rules = aggregate_rules(value) + # For the complexity of the filter policy, the total combination of values must not exceed 150. + # https://docs.aws.amazon.com/sns/latest/dg/subscription-filter-policy-constraints.html if combinations > 150: raise SNSInvalidParameter( "Invalid parameter: FilterPolicy: Filter policy is too complex" ) - for rules in value.values(): + for rules in aggregated_rules: for rule in rules: if rule is None: continue diff --git a/moto/sns/responses.py b/moto/sns/responses.py index adad2dc96..5bbd84d92 100644 --- a/moto/sns/responses.py +++ b/moto/sns/responses.py @@ -226,6 +226,13 @@ class SNSResponse(BaseResponse): subscription = self.backend.subscribe(topic_arn, endpoint, protocol) if attributes is not None: + # We need to set the FilterPolicyScope first, as the validation of the FilterPolicy will depend on it + if "FilterPolicyScope" in attributes: + filter_policy_scope = attributes.pop("FilterPolicyScope") + self.backend.set_subscription_attributes( + subscription.arn, "FilterPolicyScope", filter_policy_scope + ) + for attr_name, attr_value in attributes.items(): self.backend.set_subscription_attributes( subscription.arn, attr_name, attr_value diff --git a/moto/sns/utils.py b/moto/sns/utils.py index 8849aa153..26a229e7e 100644 --- a/moto/sns/utils.py +++ b/moto/sns/utils.py @@ -1,5 +1,7 @@ import re from moto.moto_api._internal import mock_random +from typing import Any, Dict, List, Iterable, Optional, Tuple, Union, Callable +import json E164_REGEX = re.compile(r"^\+?[1-9]\d{1,14}$") @@ -15,3 +17,376 @@ def make_arn_for_subscription(topic_arn: str) -> str: def is_e164(number: str) -> bool: return E164_REGEX.match(number) is not None + + +class FilterPolicyMatcher: + class CheckException(Exception): + pass + + def __init__(self, filter_policy: Dict[str, Any], filter_policy_scope: str): + self.filter_policy = filter_policy + self.filter_policy_scope = ( + filter_policy_scope + if filter_policy_scope is not None + else "MessageAttributes" + ) + + if self.filter_policy_scope not in ("MessageAttributes", "MessageBody"): + raise FilterPolicyMatcher.CheckException( + f"Unsupported filter_policy_scope: {filter_policy_scope}" + ) + + def matches( + self, message_attributes: Optional[Dict[str, Any]], message: str + ) -> bool: + if not self.filter_policy: + return True + + if self.filter_policy_scope == "MessageAttributes": + if message_attributes is None: + message_attributes = {} + + return self._attributes_based_match(message_attributes) + else: + try: + message_dict = json.loads(message) + except ValueError: + return False + return self._body_based_match(message_dict) + + def _attributes_based_match(self, message_attributes: Dict[str, Any]) -> bool: + return all( + FilterPolicyMatcher._field_match(field, rules, message_attributes) + for field, rules in self.filter_policy.items() + ) + + def _body_based_match(self, message_dict: Dict[str, Any]) -> bool: + try: + checks = self._compute_body_checks(self.filter_policy, message_dict) + except FilterPolicyMatcher.CheckException: + return False + + return self._perform_body_checks(checks) + + def _perform_body_checks(self, check: Any) -> bool: + # If the checks are a list, only a single elem has to pass + # otherwise all the entries have to pass + + if isinstance(check, tuple): + if len(check) == 2: + # (any|all, checks) + aggregate_func, checks = check + return aggregate_func( + self._perform_body_checks(single_check) for single_check in checks + ) + elif len(check) == 3: + field, rules, dict_body = check + return FilterPolicyMatcher._field_match(field, rules, dict_body, False) + + raise FilterPolicyMatcher.CheckException(f"Check is not a tuple: {str(check)}") + + def _compute_body_checks( + self, + filter_policy: Dict[str, Union[Dict[str, Any], List[Any]]], + message_body: Union[Dict[str, Any], List[Any]], + ) -> Tuple[Callable[[Iterable[Any]], bool], Any]: + """ + Generate (possibly nested) list of checks to be performed based on the filter policy + Returned list is of format (any|all, checks), where first elem defines what aggregation should be used in checking + and the second argument is a list containing sublists of the same format or concrete checks (field, rule, body): Tuple[str, List[Any], Dict[str, Any]] + + All the checks returned by this function will only require one-level-deep entry into dict in _field_match function + This is done this way to simplify the actual check logic and keep it as close as possible between MessageAttributes and MessageBody + + Given message_body: + {"Records": [ + { + "eventName": "ObjectCreated:Put", + }, + { + "eventName": "ObjectCreated:Delete", + }, + ]} + + and filter policy: + {"Records": { + "eventName": [{"prefix": "ObjectCreated:"}], + }} + + the following check list would be computed: + (, ( + (, ( + (, ( + ('eventName', [{'prefix': 'ObjectCreated:'}], {'eventName': 'ObjectCreated:Put'}), + ('eventName', [{'prefix': 'ObjectCreated:'}], {'eventName': 'ObjectCreated:Delete'})) + ), + ) + ),)) + """ + rules = [] + for filter_key, filter_value in filter_policy.items(): + if isinstance(filter_value, dict): + if isinstance(message_body, dict): + message_value = message_body.get(filter_key) + if message_value is not None: + rules.append( + self._compute_body_checks(filter_value, message_value) + ) + else: + raise FilterPolicyMatcher.CheckException + elif isinstance(message_body, list): + subchecks = [] + for entry in message_body: + subchecks.append( + self._compute_body_checks(filter_policy, entry) + ) + rules.append((any, tuple(subchecks))) + else: + raise FilterPolicyMatcher.CheckException + + elif isinstance(filter_value, list): + # These are the real rules, same as in MessageAttributes case + + concrete_checks = [] + if isinstance(message_body, dict): + if message_body is not None: + concrete_checks.append((filter_key, filter_value, message_body)) + else: + raise FilterPolicyMatcher.CheckException + elif isinstance(message_body, list): + # Apply policy to each element of the list, pass if at list one element matches + for list_elem in message_body: + concrete_checks.append((filter_key, filter_value, list_elem)) + else: + raise FilterPolicyMatcher.CheckException + rules.append((any, tuple(concrete_checks))) + else: + raise FilterPolicyMatcher.CheckException + + return (all, tuple(rules)) + + @staticmethod + def _field_match( # type: ignore # decorated function contains type Any + field: str, + rules: List[Any], + dict_body: Dict[str, Any], + attributes_based_check: bool = True, + ) -> bool: + # dict_body = MessageAttributes if attributes_based_check is True + # otherwise it's the cut-out part of the MessageBody (so only single-level nesting must be supported) + + # Iterate over every rule from the list of rules + # At least one rule has to match the field for the function to return a match + + def _str_exact_match(value: str, rule: Union[str, List[str]]) -> bool: + if value == rule: + return True + + if isinstance(value, list): + if rule in value: + return True + + try: + json_data = json.loads(value) + if rule in json_data: + return True + except (ValueError, TypeError): + pass + + return False + + def _number_match(values: List[float], rule: float) -> bool: + for value in values: + # Even the official documentation states a 5 digits of accuracy after the decimal point for numerics, in reality it is 6 + # https://docs.aws.amazon.com/sns/latest/dg/sns-subscription-filter-policies.html#subscription-filter-policy-constraints + if int(value * 1000000) == int(rule * 1000000): + return True + + return False + + def _exists_match( + should_exist: bool, field: str, dict_body: Dict[str, Any] + ) -> bool: + if should_exist and field in dict_body: + return True + elif not should_exist and field not in dict_body: + return True + + return False + + def _prefix_match(prefix: str, value: str) -> bool: + return value.startswith(prefix) + + def _anything_but_match( + filter_value: Union[Dict[str, Any], List[str], str], + actual_values: List[str], + ) -> bool: + if isinstance(filter_value, dict): + # We can combine anything-but with the prefix-filter + anything_but_key = list(filter_value.keys())[0] + anything_but_val = list(filter_value.values())[0] + if anything_but_key != "prefix": + return False + if all([not v.startswith(anything_but_val) for v in actual_values]): + return True + else: + undesired_values = ( + [filter_value] if isinstance(filter_value, str) else filter_value + ) + if all([v not in undesired_values for v in actual_values]): + return True + + return False + + def _numeric_match( + numeric_ranges: Iterable[Tuple[str, float]], numeric_value: float + ) -> bool: + # numeric_ranges' format: + # [(< x), (=, y), (>=, z)] + msg_value = numeric_value + matches = [] + for operator, test_value in numeric_ranges: + if operator == ">": + matches.append((msg_value > test_value)) + if operator == ">=": + matches.append((msg_value >= test_value)) + if operator == "=": + matches.append((msg_value == test_value)) + if operator == "<": + matches.append((msg_value < test_value)) + if operator == "<=": + matches.append((msg_value <= test_value)) + return all(matches) + + for rule in rules: + # TODO: boolean value matching is not supported, SNS behavior unknown + if isinstance(rule, str): + if attributes_based_check: + if field not in dict_body: + return False + if _str_exact_match(dict_body[field]["Value"], rule): + return True + else: + if field not in dict_body: + return False + if _str_exact_match(dict_body[field], rule): + return True + + if isinstance(rule, (int, float)): + if attributes_based_check: + if field not in dict_body: + return False + if dict_body[field]["Type"] == "Number": + attribute_values = [dict_body[field]["Value"]] + elif dict_body[field]["Type"] == "String.Array": + try: + attribute_values = json.loads(dict_body[field]["Value"]) + if not isinstance(attribute_values, list): + attribute_values = [attribute_values] + except (ValueError, TypeError): + return False + else: + return False + + values = [float(value) for value in attribute_values] + if _number_match(values, rule): + return True + else: + if field not in dict_body: + return False + + if isinstance(dict_body[field], (int, float)): + values = [dict_body[field]] + elif isinstance(dict_body[field], list): + values = [float(value) for value in dict_body[field]] + else: + return False + + if _number_match(values, rule): + return True + + if isinstance(rule, dict): + keyword = list(rule.keys())[0] + value = list(rule.values())[0] + if keyword == "exists": + if attributes_based_check: + if _exists_match(value, field, dict_body): + return True + else: + if _exists_match(value, field, dict_body): + return True + elif keyword == "prefix" and isinstance(value, str): + if attributes_based_check: + if field in dict_body: + attr = dict_body[field] + if attr["Type"] == "String": + if _prefix_match(value, attr["Value"]): + return True + else: + if field in dict_body: + if _prefix_match(value, dict_body[field]): + return True + + elif keyword == "anything-but": + if attributes_based_check: + if field not in dict_body: + return False + attr = dict_body[field] + if isinstance(value, dict): + # We can combine anything-but with the prefix-filter + if attr["Type"] == "String": + actual_values = [attr["Value"]] + else: + actual_values = [v for v in attr["Value"]] + else: + if attr["Type"] == "Number": + actual_values = [float(attr["Value"])] + elif attr["Type"] == "String": + actual_values = [attr["Value"]] + else: + actual_values = [v for v in attr["Value"]] + + if _anything_but_match(value, actual_values): + return True + else: + if field not in dict_body: + return False + attr = dict_body[field] + if isinstance(value, dict): + if isinstance(attr, str): + actual_values = [attr] + elif isinstance(attr, list): + actual_values = attr + else: + return False + else: + if isinstance(attr, (int, float, str)): + actual_values = [attr] + elif isinstance(attr, list): + actual_values = attr + else: + return False + + if _anything_but_match(value, actual_values): + return True + + elif keyword == "numeric" and isinstance(value, list): + if attributes_based_check: + if dict_body.get(field, {}).get("Type", "") == "Number": + + checks = value + numeric_ranges = zip(checks[0::2], checks[1::2]) + if _numeric_match( + numeric_ranges, float(dict_body[field]["Value"]) + ): + return True + else: + if field not in dict_body: + return False + + checks = value + numeric_ranges = zip(checks[0::2], checks[1::2]) + if _numeric_match(numeric_ranges, dict_body[field]): + return True + + return False diff --git a/tests/test_sns/test_publishing_boto3.py b/tests/test_sns/test_publishing_boto3.py index 40c7e6a6e..e3d7de43c 100644 --- a/tests/test_sns/test_publishing_boto3.py +++ b/tests/test_sns/test_publishing_boto3.py @@ -23,6 +23,14 @@ MESSAGE_FROM_SQS_TEMPLATE = ( ) +def to_comparable_dicts(list_entry: list): + if list_entry: + if isinstance(list_entry[0], dict): + return set(map(json.dumps, list_entry)) + + return set(list_entry) + + @mock_sqs @mock_sns def test_publish_to_sqs(): @@ -656,7 +664,9 @@ def test_publish_deduplication_id_to_non_fifo(): topic.publish(Message="message") -def _setup_filter_policy_test(filter_policy): +def _setup_filter_policy_test( + filter_policy: dict, filter_policy_scope: str = "MessageAttributes" +): sns = boto3.resource("sns", region_name="us-east-1") topic = sns.create_topic(Name="some-topic") @@ -667,6 +677,10 @@ def _setup_filter_policy_test(filter_policy): Protocol="sqs", Endpoint=queue.attributes["QueueArn"] ) + subscription.set_attributes( + AttributeName="FilterPolicyScope", AttributeValue=filter_policy_scope + ) + subscription.set_attributes( AttributeName="FilterPolicy", AttributeValue=json.dumps(filter_policy) ) @@ -695,6 +709,25 @@ def test_filtering_exact_string(): ) +@mock_sqs +@mock_sns +def test_filtering_exact_string_message_body(): + topic, queue = _setup_filter_policy_test( + {"store": ["example_corp"]}, filter_policy_scope="MessageBody" + ) + + topic.publish( + Message=json.dumps({"store": "example_corp"}), + MessageAttributes={"result": {"DataType": "String", "StringValue": "match"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([{"result": {"Type": "String", "Value": "match"}}]) + message_bodies = [json.loads(json.loads(m.body)["Message"]) for m in messages] + message_bodies.should.equal([{"store": "example_corp"}]) + + @mock_sqs @mock_sns def test_filtering_exact_string_multiple_message_attributes(): @@ -722,6 +755,25 @@ def test_filtering_exact_string_multiple_message_attributes(): ) +@mock_sqs +@mock_sns +def test_filtering_exact_string_multiple_message_attributes_message_body(): + topic, queue = _setup_filter_policy_test( + {"store": ["example_corp"]}, filter_policy_scope="MessageBody" + ) + + topic.publish( + Message=json.dumps({"store": "example_corp", "event": "order_cancelled"}), + MessageAttributes={"result": {"DataType": "String", "StringValue": "match"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([{"result": {"Type": "String", "Value": "match"}}]) + message_bodies = [json.loads(json.loads(m.body)["Message"]) for m in messages] + message_bodies.should.equal([{"store": "example_corp", "event": "order_cancelled"}]) + + @mock_sqs @mock_sns def test_filtering_exact_string_OR_matching(): @@ -753,6 +805,43 @@ def test_filtering_exact_string_OR_matching(): ) +@mock_sqs +@mock_sns +def test_filtering_exact_string_OR_matching_message_body(): + topic, queue = _setup_filter_policy_test( + {"store": ["example_corp", "different_corp"]}, filter_policy_scope="MessageBody" + ) + + topic.publish( + Message=json.dumps({"store": "example_corp"}), + MessageAttributes={ + "result": {"DataType": "String", "StringValue": "match example_corp"} + }, + ) + + topic.publish( + Message=json.dumps({"store": "different_corp"}), + MessageAttributes={ + "result": {"DataType": "String", "StringValue": "match different_corp"} + }, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + to_comparable_dicts(message_attributes).should.equal( + to_comparable_dicts( + [ + {"result": {"Type": "String", "Value": "match example_corp"}}, + {"result": {"Type": "String", "Value": "match different_corp"}}, + ] + ) + ) + message_bodies = [json.loads(json.loads(m.body)["Message"]) for m in messages] + to_comparable_dicts(message_bodies).should.equal( + to_comparable_dicts([{"store": "example_corp"}, {"store": "different_corp"}]) + ) + + @mock_sqs @mock_sns def test_filtering_exact_string_AND_matching_positive(): @@ -782,6 +871,40 @@ def test_filtering_exact_string_AND_matching_positive(): ) +@mock_sqs +@mock_sns +def test_filtering_exact_string_AND_matching_positive_message_body(): + topic, queue = _setup_filter_policy_test( + {"store": ["example_corp"], "event": ["order_cancelled"]}, + filter_policy_scope="MessageBody", + ) + + topic.publish( + Message=json.dumps({"store": "example_corp", "event": "order_cancelled"}), + MessageAttributes={ + "result": { + "DataType": "String", + "StringValue": "match example_corp order_cancelled", + } + }, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal( + [ + { + "result": { + "Type": "String", + "Value": "match example_corp order_cancelled", + }, + } + ] + ) + message_bodies = [json.loads(json.loads(m.body)["Message"]) for m in messages] + message_bodies.should.equal([{"store": "example_corp", "event": "order_cancelled"}]) + + @mock_sqs @mock_sns def test_filtering_exact_string_AND_matching_no_match(): @@ -804,6 +927,31 @@ def test_filtering_exact_string_AND_matching_no_match(): message_attributes.should.equal([]) +@mock_sqs +@mock_sns +def test_filtering_exact_string_AND_matching_no_match_message_body(): + topic, queue = _setup_filter_policy_test( + {"store": ["example_corp"], "event": ["order_cancelled"]}, + filter_policy_scope="MessageBody", + ) + + topic.publish( + Message=json.dumps({"store": "example_corp", "event": "order_accepted"}), + MessageAttributes={ + "result": { + "DataType": "String", + "StringValue": "match example_corp order_accepted", + } + }, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([]) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal([]) + + @mock_sqs @mock_sns def test_filtering_exact_string_no_match(): @@ -823,6 +971,25 @@ def test_filtering_exact_string_no_match(): message_attributes.should.equal([]) +@mock_sqs +@mock_sns +def test_filtering_exact_string_no_match_message_body(): + topic, queue = _setup_filter_policy_test( + {"store": ["example_corp"]}, filter_policy_scope="MessageBody" + ) + + topic.publish( + Message=json.dumps({"store": "different_corp"}), + MessageAttributes={"result": {"DataType": "String", "StringValue": "no match"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([]) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal([]) + + @mock_sqs @mock_sns def test_filtering_exact_string_no_attributes_no_match(): @@ -837,6 +1004,25 @@ def test_filtering_exact_string_no_attributes_no_match(): message_attributes.should.equal([]) +@mock_sqs +@mock_sns +def test_filtering_exact_string_empty_body_no_match_message_body(): + topic, queue = _setup_filter_policy_test( + {"store": ["example_corp"]}, filter_policy_scope="MessageBody" + ) + + topic.publish( + Message=json.dumps({}), + MessageAttributes={"result": {"DataType": "String", "StringValue": "no match"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([]) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal([]) + + @mock_sqs @mock_sns def test_filtering_exact_number_int(): @@ -854,6 +1040,31 @@ def test_filtering_exact_number_int(): message_attributes.should.equal([{"price": {"Type": "Number", "Value": "100"}}]) +@mock_sqs +@mock_sns +def test_filtering_exact_number_int_message_body(): + topic, queue = _setup_filter_policy_test( + {"price": [100]}, filter_policy_scope="MessageBody" + ) + + topic.publish( + Message=json.dumps({"price": 100}), + MessageAttributes={"result": {"DataType": "String", "StringValue": "match"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal( + [ + { + "result": {"Type": "String", "Value": "match"}, + } + ] + ) + message_bodies = [json.loads(json.loads(m.body)["Message"]) for m in messages] + message_bodies.should.equal([{"price": 100}]) + + @mock_sqs @mock_sns def test_filtering_exact_number_float(): @@ -871,6 +1082,31 @@ def test_filtering_exact_number_float(): message_attributes.should.equal([{"price": {"Type": "Number", "Value": "100.1"}}]) +@mock_sqs +@mock_sns +def test_filtering_exact_number_float_message_body(): + topic, queue = _setup_filter_policy_test( + {"price": [100.1]}, filter_policy_scope="MessageBody" + ) + + topic.publish( + Message=json.dumps({"price": 100.1}), + MessageAttributes={"result": {"DataType": "String", "StringValue": "match"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal( + [ + { + "result": {"Type": "String", "Value": "match"}, + } + ] + ) + message_bodies = [json.loads(json.loads(m.body)["Message"]) for m in messages] + message_bodies.should.equal([{"price": 100.1}]) + + @mock_sqs @mock_sns def test_filtering_exact_number_float_accuracy(): @@ -892,6 +1128,31 @@ def test_filtering_exact_number_float_accuracy(): ) +@mock_sqs +@mock_sns +def test_filtering_exact_number_float_accuracy_message_body(): + topic, queue = _setup_filter_policy_test( + {"price": [100.123456789]}, filter_policy_scope="MessageBody" + ) + + topic.publish( + Message=json.dumps({"price": 100.1234567}), + MessageAttributes={"result": {"DataType": "String", "StringValue": "match"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal( + [ + { + "result": {"Type": "String", "Value": "match"}, + } + ] + ) + message_bodies = [json.loads(json.loads(m.body)["Message"]) for m in messages] + message_bodies.should.equal([{"price": 100.1234567}]) + + @mock_sqs @mock_sns def test_filtering_exact_number_no_match(): @@ -909,6 +1170,25 @@ def test_filtering_exact_number_no_match(): message_attributes.should.equal([]) +@mock_sqs +@mock_sns +def test_filtering_exact_number_no_match_message_body(): + topic, queue = _setup_filter_policy_test( + {"price": [100]}, filter_policy_scope="MessageBody" + ) + + topic.publish( + Message=json.dumps({"price": 101}), + MessageAttributes={"result": {"DataType": "String", "StringValue": "no match"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([]) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal([]) + + @mock_sqs @mock_sns def test_filtering_exact_number_with_string_no_match(): @@ -926,6 +1206,25 @@ def test_filtering_exact_number_with_string_no_match(): message_attributes.should.equal([]) +@mock_sqs +@mock_sns +def test_filtering_exact_number_with_string_no_match_message_body(): + topic, queue = _setup_filter_policy_test( + {"price": [100]}, filter_policy_scope="MessageBody" + ) + + topic.publish( + Message=json.dumps({"price": "100"}), + MessageAttributes={"result": {"DataType": "String", "StringValue": "no match"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([]) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal([]) + + @mock_sqs @mock_sns def test_filtering_string_array_match(): @@ -959,6 +1258,32 @@ def test_filtering_string_array_match(): ) +@mock_sqs +@mock_sns +def test_filtering_string_array_match_message_body(): + topic, queue = _setup_filter_policy_test( + {"customer_interests": ["basketball", "baseball"]}, + filter_policy_scope="MessageBody", + ) + + topic.publish( + Message=json.dumps({"customer_interests": ["basketball", "rugby"]}), + MessageAttributes={"result": {"DataType": "String", "StringValue": "match"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal( + [ + { + "result": {"Type": "String", "Value": "match"}, + } + ] + ) + message_bodies = [json.loads(json.loads(m.body)["Message"]) for m in messages] + message_bodies.should.equal([{"customer_interests": ["basketball", "rugby"]}]) + + @mock_sqs @mock_sns def test_filtering_string_array_no_match(): @@ -981,6 +1306,25 @@ def test_filtering_string_array_no_match(): message_attributes.should.equal([]) +@mock_sqs +@mock_sns +def test_filtering_string_array_no_match_message_body(): + topic, queue = _setup_filter_policy_test( + {"customer_interests": ["baseball"]}, filter_policy_scope="MessageBody" + ) + + topic.publish( + Message=json.dumps({"customer_interests": ["basketball", "rugby"]}), + MessageAttributes={"result": {"DataType": "String", "StringValue": "no_match"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([]) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal([]) + + @mock_sqs @mock_sns def test_filtering_string_array_with_number_match(): @@ -1002,6 +1346,31 @@ def test_filtering_string_array_with_number_match(): ) +@mock_sqs +@mock_sns +def test_filtering_string_array_with_number_match_message_body(): + topic, queue = _setup_filter_policy_test( + {"price": [100, 500]}, filter_policy_scope="MessageBody" + ) + + topic.publish( + Message=json.dumps({"price": [100, 50]}), + MessageAttributes={"result": {"DataType": "String", "StringValue": "match"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal( + [ + { + "result": {"Type": "String", "Value": "match"}, + } + ] + ) + message_bodies = [json.loads(json.loads(m.body)["Message"]) for m in messages] + message_bodies.should.equal([{"price": [100, 50]}]) + + @mock_sqs @mock_sns def test_filtering_string_array_with_number_float_accuracy_match(): @@ -1026,6 +1395,31 @@ def test_filtering_string_array_with_number_float_accuracy_match(): ) +@mock_sqs +@mock_sns +def test_filtering_string_array_with_number_float_accuracy_match_message_body(): + topic, queue = _setup_filter_policy_test( + {"price": [100.123456789, 500]}, filter_policy_scope="MessageBody" + ) + + topic.publish( + Message=json.dumps({"price": [100.1234567, 50]}), + MessageAttributes={"result": {"DataType": "String", "StringValue": "match"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal( + [ + { + "result": {"Type": "String", "Value": "match"}, + } + ] + ) + message_bodies = [json.loads(json.loads(m.body)["Message"]) for m in messages] + message_bodies.should.equal([{"price": [100.1234567, 50]}]) + + @mock_sqs @mock_sns # this is the correct behavior from SNS @@ -1046,6 +1440,32 @@ def test_filtering_string_array_with_number_no_array_match(): ) +@mock_sqs +@mock_sns +# this is the correct behavior from SNS +def test_filtering_string_array_with_number_no_array_match_message_body(): + topic, queue = _setup_filter_policy_test( + {"price": [100, 500]}, filter_policy_scope="MessageBody" + ) + + topic.publish( + Message=json.dumps({"price": 100}), + MessageAttributes={"result": {"DataType": "String", "StringValue": "match"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal( + [ + { + "result": {"Type": "String", "Value": "match"}, + } + ] + ) + message_bodies = [json.loads(json.loads(m.body)["Message"]) for m in messages] + message_bodies.should.equal([{"price": 100}]) + + @mock_sqs @mock_sns def test_filtering_string_array_with_number_no_match(): @@ -1065,6 +1485,25 @@ def test_filtering_string_array_with_number_no_match(): message_attributes.should.equal([]) +@mock_sqs +@mock_sns +def test_filtering_string_array_with_number_no_match_message_body(): + topic, queue = _setup_filter_policy_test( + {"price": [500]}, filter_policy_scope="MessageBody" + ) + + topic.publish( + Message=json.dumps({"price": [100, 50]}), + MessageAttributes={"result": {"DataType": "String", "StringValue": "no_match"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([]) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal([]) + + @mock_sqs @mock_sns # this is the correct behavior from SNS @@ -1085,6 +1524,25 @@ def test_filtering_string_array_with_string_no_array_no_match(): message_attributes.should.equal([]) +@mock_sqs +@mock_sns +def test_filtering_string_array_with_string_no_array_no_match_message_body(): + topic, queue = _setup_filter_policy_test( + {"price": [100]}, filter_policy_scope="MessageBody" + ) + + topic.publish( + Message=json.dumps({"price": "one hundred"}), + MessageAttributes={"result": {"DataType": "String", "StringValue": "no_match"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([]) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal([]) + + @mock_sqs @mock_sns def test_filtering_attribute_key_exists_match(): @@ -1106,6 +1564,31 @@ def test_filtering_attribute_key_exists_match(): ) +@mock_sqs +@mock_sns +def test_filtering_body_key_exists_message_body(): + topic, queue = _setup_filter_policy_test( + {"store": [{"exists": True}]}, filter_policy_scope="MessageBody" + ) + + topic.publish( + Message=json.dumps({"store": "example_corp"}), + MessageAttributes={"result": {"DataType": "String", "StringValue": "match"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal( + [ + { + "result": {"Type": "String", "Value": "match"}, + } + ] + ) + message_bodies = [json.loads(json.loads(m.body)["Message"]) for m in messages] + message_bodies.should.equal([{"store": "example_corp"}]) + + @mock_sqs @mock_sns def test_filtering_attribute_key_exists_no_match(): @@ -1125,6 +1608,25 @@ def test_filtering_attribute_key_exists_no_match(): message_attributes.should.equal([]) +@mock_sqs +@mock_sns +def test_filtering_body_key_exists_no_match_message_body(): + topic, queue = _setup_filter_policy_test( + {"store": [{"exists": True}]}, filter_policy_scope="MessageBody" + ) + + topic.publish( + Message=json.dumps({"event": "order_cancelled"}), + MessageAttributes={"result": {"DataType": "String", "StringValue": "no match"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([]) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal([]) + + @mock_sqs @mock_sns def test_filtering_attribute_key_not_exists_match(): @@ -1146,6 +1648,31 @@ def test_filtering_attribute_key_not_exists_match(): ) +@mock_sqs +@mock_sns +def test_filtering_body_key_not_exists_match_message_body(): + topic, queue = _setup_filter_policy_test( + {"store": [{"exists": False}]}, filter_policy_scope="MessageBody" + ) + + topic.publish( + Message=json.dumps({"event": "order_cancelled"}), + MessageAttributes={"result": {"DataType": "String", "StringValue": "match"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal( + [ + { + "result": {"Type": "String", "Value": "match"}, + } + ] + ) + message_bodies = [json.loads(json.loads(m.body)["Message"]) for m in messages] + message_bodies.should.equal([{"event": "order_cancelled"}]) + + @mock_sqs @mock_sns def test_filtering_attribute_key_not_exists_no_match(): @@ -1165,6 +1692,25 @@ def test_filtering_attribute_key_not_exists_no_match(): message_attributes.should.equal([]) +@mock_sqs +@mock_sns +def test_filtering_body_key_not_exists_no_match_message_body(): + topic, queue = _setup_filter_policy_test( + {"store": [{"exists": False}]}, filter_policy_scope="MessageBody" + ) + + topic.publish( + Message=json.dumps({"store": "example_corp"}), + MessageAttributes={"result": {"DataType": "String", "StringValue": "no match"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal([]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([]) + + @mock_sqs @mock_sns def test_filtering_all_AND_matching_match(): @@ -1209,6 +1755,53 @@ def test_filtering_all_AND_matching_match(): ) +@mock_sqs +@mock_sns +def test_filtering_all_AND_matching_match_message_body(): + topic, queue = _setup_filter_policy_test( + { + "store": [{"exists": True}], + "event": ["order_cancelled"], + "customer_interests": ["basketball", "baseball"], + "price": [100], + }, + filter_policy_scope="MessageBody", + ) + + topic.publish( + Message=json.dumps( + { + "store": "example_corp", + "event": "order_cancelled", + "customer_interests": ["basketball", "rugby"], + "price": 100, + } + ), + MessageAttributes={"result": {"DataType": "String", "StringValue": "match"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal( + [ + { + "result": {"Type": "String", "Value": "match"}, + } + ] + ) + message_bodies = [json.loads(json.loads(m.body)["Message"]) for m in messages] + message_bodies.should.equal( + [ + { + "store": "example_corp", + "event": "order_cancelled", + "customer_interests": ["basketball", "rugby"], + "price": 100, + } + ] + ) + + @mock_sqs @mock_sns def test_filtering_all_AND_matching_no_match(): @@ -1242,6 +1835,39 @@ def test_filtering_all_AND_matching_no_match(): message_attributes.should.equal([]) +@mock_sqs +@mock_sns +def test_filtering_all_AND_matching_no_match_message_body(): + topic, queue = _setup_filter_policy_test( + { + "store": [{"exists": True}], + "event": ["order_cancelled"], + "customer_interests": ["basketball", "baseball"], + "price": [100], + "encrypted": [False], + }, + filter_policy_scope="MessageBody", + ) + + topic.publish( + Message=json.dumps( + { + "store": "example_corp", + "event": "order_cancelled", + "customer_interests": ["basketball", "rugby"], + "price": 100, + } + ), + MessageAttributes={"result": {"DataType": "String", "StringValue": "no match"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([]) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal([]) + + @mock_sqs @mock_sns def test_filtering_prefix(): @@ -1262,6 +1888,34 @@ def test_filtering_prefix(): set(message_bodies).should.equal({"match1", "match3"}) +@mock_sqs +@mock_sns +def test_filtering_prefix_message_body(): + topic, queue = _setup_filter_policy_test( + { + "customer_interests": [{"prefix": "bas"}], + }, + filter_policy_scope="MessageBody", + ) + + for interest in ["basketball", "rugby", "baseball"]: + topic.publish( + Message=json.dumps( + { + "customer_interests": interest, + } + ) + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + to_comparable_dicts(message_bodies).should.equal( + to_comparable_dicts( + [{"customer_interests": "basketball"}, {"customer_interests": "baseball"}] + ) + ) + + @mock_sqs @mock_sns def test_filtering_anything_but(): @@ -1282,6 +1936,34 @@ def test_filtering_anything_but(): set(message_bodies).should.equal({"match2", "match3"}) +@mock_sqs +@mock_sns +def test_filtering_anything_but_message_body(): + topic, queue = _setup_filter_policy_test( + { + "customer_interests": [{"anything-but": "basketball"}], + }, + filter_policy_scope="MessageBody", + ) + + for interest in ["basketball", "rugby", "baseball"]: + topic.publish( + Message=json.dumps( + { + "customer_interests": interest, + } + ) + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + to_comparable_dicts(message_bodies).should.equal( + to_comparable_dicts( + [{"customer_interests": "rugby"}, {"customer_interests": "baseball"}] + ) + ) + + @mock_sqs @mock_sns def test_filtering_anything_but_multiple_values(): @@ -1302,6 +1984,32 @@ def test_filtering_anything_but_multiple_values(): set(message_bodies).should.equal({"match3"}) +@mock_sqs +@mock_sns +def test_filtering_anything_but_multiple_values_message_body(): + topic, queue = _setup_filter_policy_test( + { + "customer_interests": [{"anything-but": ["basketball", "rugby"]}], + }, + filter_policy_scope="MessageBody", + ) + + for interest in ["basketball", "rugby", "baseball"]: + topic.publish( + Message=json.dumps( + { + "customer_interests": interest, + } + ) + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + to_comparable_dicts(message_bodies).should.equal( + to_comparable_dicts([{"customer_interests": "baseball"}]) + ) + + @mock_sqs @mock_sns def test_filtering_anything_but_prefix(): @@ -1325,23 +2033,53 @@ def test_filtering_anything_but_prefix(): @mock_sqs @mock_sns -def test_filtering_anything_but_unknown(): +def test_filtering_anything_but_prefix_message_body(): topic, queue = _setup_filter_policy_test( - {"customer_interests": [{"anything-but": {"unknown": "bas"}}]} + { + "customer_interests": [{"anything-but": {"prefix": "bas"}}], + }, + filter_policy_scope="MessageBody", ) - for interest, idx in [("basketball", "1"), ("rugby", "2"), ("baseball", "3")]: + for interest in ["basketball", "rugby", "baseball"]: topic.publish( - Message=f"match{idx}", - MessageAttributes={ - "customer_interests": {"DataType": "String", "StringValue": interest}, - }, + Message=json.dumps( + { + "customer_interests": interest, + } + ) ) - # This should match rugby only messages = queue.receive_messages(MaxNumberOfMessages=5) message_bodies = [json.loads(m.body)["Message"] for m in messages] - message_bodies.should.equal([]) + to_comparable_dicts(message_bodies).should.equal( + to_comparable_dicts([{"customer_interests": "rugby"}]) + ) + + +@mock_sqs +@mock_sns +def test_filtering_anything_but_unknown(): + try: + _setup_filter_policy_test( + {"customer_interests": [{"anything-but": {"unknown": "bas"}}]} + ) + except ClientError as err: + err.response["Error"]["Code"].should.equal("InvalidParameter") + + +@mock_sqs +@mock_sns +def test_filtering_anything_but_unknown_message_body_raises(): + try: + _setup_filter_policy_test( + { + "customer_interests": [{"anything-but": {"unknown": "bas"}}], + }, + filter_policy_scope="MessageBody", + ) + except ClientError as err: + err.response["Error"]["Code"].should.equal("InvalidParameter") @mock_sqs @@ -1364,6 +2102,86 @@ def test_filtering_anything_but_numeric(): set(message_bodies).should.equal({"match1", "match3"}) +@mock_sqs +@mock_sns +def test_filtering_anything_but_numeric_message_body(): + topic, queue = _setup_filter_policy_test( + { + "customer_interests": [{"anything-but": [100]}], + }, + filter_policy_scope="MessageBody", + ) + + for nr in [50, 100, 150]: + topic.publish( + Message=json.dumps( + { + "customer_interests": nr, + } + ) + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + + message_bodies = [json.loads(m.body)["Message"] for m in messages] + to_comparable_dicts(message_bodies).should.equal( + to_comparable_dicts([{"customer_interests": 50}, {"customer_interests": 150}]) + ) + + +@mock_sqs +@mock_sns +def test_filtering_anything_but_numeric_string(): + topic, queue = _setup_filter_policy_test( + {"customer_interests": [{"anything-but": ["100"]}]} + ) + + for nr, idx in [("50", "1"), ("100", "2"), ("150", "3")]: + topic.publish( + Message=f"match{idx}", + MessageAttributes={ + "customer_interests": {"DataType": "Number", "StringValue": nr}, + }, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + set(message_bodies).should.equal({"match1", "match2", "match3"}) + + +@mock_sqs +@mock_sns +def test_filtering_anything_but_numeric_string_message_body(): + topic, queue = _setup_filter_policy_test( + { + "customer_interests": [{"anything-but": ["100"]}], + }, + filter_policy_scope="MessageBody", + ) + + for nr in [50, 100, 150]: + topic.publish( + Message=json.dumps( + { + "customer_interests": nr, + } + ) + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + + message_bodies = [json.loads(m.body)["Message"] for m in messages] + to_comparable_dicts(message_bodies).should.equal( + to_comparable_dicts( + [ + {"customer_interests": 50}, + {"customer_interests": 100}, + {"customer_interests": 150}, + ] + ) + ) + + @mock_sqs @mock_sns def test_filtering_numeric_match(): @@ -1384,6 +2202,32 @@ def test_filtering_numeric_match(): set(message_bodies).should.equal({"match2"}) +@mock_sqs +@mock_sns +def test_filtering_numeric_match_message_body(): + topic, queue = _setup_filter_policy_test( + { + "customer_interests": [{"numeric": ["=", 100]}], + }, + filter_policy_scope="MessageBody", + ) + + for nr in [50, 100, 150]: + topic.publish( + Message=json.dumps( + { + "customer_interests": nr, + } + ) + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + to_comparable_dicts(message_bodies).should.equal( + to_comparable_dicts([{"customer_interests": 100}]) + ) + + @mock_sqs @mock_sns def test_filtering_numeric_range(): @@ -1402,3 +2246,297 @@ def test_filtering_numeric_range(): messages = queue.receive_messages(MaxNumberOfMessages=5) message_bodies = [json.loads(m.body)["Message"] for m in messages] set(message_bodies).should.equal({"match1", "match2"}) + + +@mock_sqs +@mock_sns +def test_filtering_numeric_range_message_body(): + topic, queue = _setup_filter_policy_test( + { + "customer_interests": [{"numeric": [">", 49, "<=", 100]}], + }, + filter_policy_scope="MessageBody", + ) + + for nr in [50, 100, 150]: + topic.publish( + Message=json.dumps( + { + "customer_interests": nr, + } + ) + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + to_comparable_dicts(message_bodies).should.equal( + to_comparable_dicts([{"customer_interests": 50}, {"customer_interests": 100}]) + ) + + +@mock_sqs +@mock_sns +def test_filtering_exact_string_message_body_invalid_json_no_match(): + topic, queue = _setup_filter_policy_test( + {"store": ["example_corp"]}, filter_policy_scope="MessageBody" + ) + + topic.publish( + Message='{"store": "another_corp"', + MessageAttributes={"match": {"DataType": "String", "StringValue": "body"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal([]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([]) + + +@mock_sqs +@mock_sns +def test_filtering_exact_string_message_body_empty_filter_policy_match(): + topic, queue = _setup_filter_policy_test({}, filter_policy_scope="MessageBody") + + topic.publish( + Message='{"store": "another_corp"}', + MessageAttributes={"match": {"DataType": "String", "StringValue": "body"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + to_comparable_dicts(message_attributes).should.equal( + to_comparable_dicts([{"match": {"Type": "String", "Value": "body"}}]) + ) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + to_comparable_dicts(message_bodies).should.equal( + to_comparable_dicts([{"store": "another_corp"}]) + ) + + +@mock_sqs +@mock_sns +def test_filtering_exact_string_message_body_nested(): + topic, queue = _setup_filter_policy_test( + {"store": {"name": ["example_corp"]}}, filter_policy_scope="MessageBody" + ) + + topic.publish( + Message=json.dumps({"store": {"name": "example_corp"}}), + MessageAttributes={"match": {"DataType": "String", "StringValue": "body"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([{"match": {"Type": "String", "Value": "body"}}]) + message_bodies = [json.loads(json.loads(m.body)["Message"]) for m in messages] + to_comparable_dicts(message_bodies).should.equal( + to_comparable_dicts([{"store": {"name": "example_corp"}}]) + ) + + +@mock_sqs +@mock_sns +def test_filtering_exact_string_message_body_nested_no_match(): + topic, queue = _setup_filter_policy_test( + {"store": {"name": ["example_corp"]}}, filter_policy_scope="MessageBody" + ) + + topic.publish( + Message=json.dumps({"store": {"name": "another_corp"}}), + MessageAttributes={"match": {"DataType": "String", "StringValue": "body"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal([]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([]) + + +@mock_sqs +@mock_sns +def test_filtering_message_body_nested_prefix(): + topic, queue = _setup_filter_policy_test( + {"store": {"name": [{"prefix": "example_corp"}]}}, + filter_policy_scope="MessageBody", + ) + + topic.publish( + Message=json.dumps({"store": {"name": "example_corp"}}), + MessageAttributes={"match": {"DataType": "String", "StringValue": "body"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([{"match": {"Type": "String", "Value": "body"}}]) + message_bodies = [json.loads(json.loads(m.body)["Message"]) for m in messages] + to_comparable_dicts(message_bodies).should.equal( + to_comparable_dicts([{"store": {"name": "example_corp"}}]) + ) + + +@mock_sqs +@mock_sns +def test_filtering_message_body_nested_prefix_no_match(): + topic, queue = _setup_filter_policy_test( + {"store": {"name": [{"prefix": "example_corp"}]}}, + filter_policy_scope="MessageBody", + ) + + topic.publish( + Message=json.dumps({"store": {"name": "another_corp-1"}}), + MessageAttributes={"match": {"DataType": "String", "StringValue": "body"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal([]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([]) + + +@mock_sqs +@mock_sns +def test_filtering_message_body_nested_multiple_prefix(): + topic, queue = _setup_filter_policy_test( + { + "Records": { + "s3": {"object": {"key": [{"prefix": "test-"}]}}, + "eventName": [{"prefix": "ObjectCreated:"}], + } + }, + filter_policy_scope="MessageBody", + ) + + payload = { + "Records": [ + { + "eventName": "ObjectCreated:Put", + "s3": { + "object": { + "key": "test-entry.xml", + } + }, + } + ] + } + + topic.publish( + Message=json.dumps(payload), + MessageAttributes={"match": {"DataType": "String", "StringValue": "body"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([{"match": {"Type": "String", "Value": "body"}}]) + message_bodies = [json.loads(json.loads(m.body)["Message"]) for m in messages] + message_bodies.should.equal([payload]) + + +@mock_sqs +@mock_sns +def test_filtering_message_body_nested_multiple_prefix_no_match(): + topic, queue = _setup_filter_policy_test( + { + "Records": { + "s3": {"object": {"key": [{"prefix": "test-"}]}}, + "eventName": [{"prefix": "ObjectCreated:"}], + } + }, + filter_policy_scope="MessageBody", + ) + + payload = { + "Records": [ + { + "eventName": "ObjectCreated:Put", + "s3": { + "object": { + "key": "no-match-entry.xml", + } + }, + } + ] + } + + topic.publish( + Message=json.dumps(payload), + MessageAttributes={"match": {"DataType": "String", "StringValue": "body"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal([]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([]) + + +@mock_sqs +@mock_sns +def test_filtering_message_body_nested_multiple_records_partial_match(): + topic, queue = _setup_filter_policy_test( + { + "Records": { + "eventName": [{"prefix": "ObjectCreated:"}], + } + }, + filter_policy_scope="MessageBody", + ) + + payload = { + "Records": [ + { + "eventName": "ObjectCreated:Put", + }, + { + "eventName": "ObjectDeleted:Delete", + }, + ] + } + + topic.publish( + Message=json.dumps(payload), + MessageAttributes={"match": {"DataType": "String", "StringValue": "body"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([{"match": {"Type": "String", "Value": "body"}}]) + message_bodies = [json.loads(json.loads(m.body)["Message"]) for m in messages] + message_bodies.should.equal([payload]) + + +@mock_sqs +@mock_sns +def test_filtering_message_body_nested_multiple_records_match(): + topic, queue = _setup_filter_policy_test( + { + "Records": { + "eventName": [{"prefix": "ObjectCreated:"}], + } + }, + filter_policy_scope="MessageBody", + ) + + payload = { + "Records": [ + { + "eventName": "ObjectCreated:Put", + }, + { + "eventName": "ObjectCreated:Put", + }, + ] + } + + topic.publish( + Message=json.dumps(payload), + MessageAttributes={"match": {"DataType": "String", "StringValue": "body"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([{"match": {"Type": "String", "Value": "body"}}]) + message_bodies = [json.loads(json.loads(m.body)["Message"]) for m in messages] + message_bodies.should.equal([payload]) diff --git a/tests/test_sns/test_subscriptions_boto3.py b/tests/test_sns/test_subscriptions_boto3.py index 9066c3cae..226ff2073 100644 --- a/tests/test_sns/test_subscriptions_boto3.py +++ b/tests/test_sns/test_subscriptions_boto3.py @@ -250,7 +250,6 @@ def test_creating_subscription_with_attributes(): filter_policy = json.dumps( { "store": ["example_corp"], - "event": ["order_cancelled"], "encrypted": [False], "customer_interests": ["basketball", "baseball"], "price": [100, 100.12], @@ -371,7 +370,6 @@ def test_set_subscription_attributes(): filter_policy = json.dumps( { "store": ["example_corp"], - "event": ["order_cancelled"], "encrypted": [False], "customer_interests": ["basketball", "baseball"], "price": [100, 100.12], @@ -390,6 +388,17 @@ def test_set_subscription_attributes(): attrs["Attributes"]["DeliveryPolicy"].should.equal(delivery_policy) attrs["Attributes"]["FilterPolicy"].should.equal(filter_policy) + filter_policy_scope = "MessageBody" + conn.set_subscription_attributes( + SubscriptionArn=subscription_arn, + AttributeName="FilterPolicyScope", + AttributeValue=filter_policy_scope, + ) + + attrs = conn.get_subscription_attributes(SubscriptionArn=subscription_arn) + + attrs["Attributes"]["FilterPolicyScope"].should.equal(filter_policy_scope) + # not existing subscription with pytest.raises(ClientError): conn.set_subscription_attributes( @@ -641,6 +650,70 @@ def test_subscribe_invalid_filter_policy(): "Invalid parameter: Attributes Reason: FilterPolicy: Value of < must be numeric\n at ..." ) + try: + conn.subscribe( + TopicArn=topic_arn, + Protocol="http", + Endpoint="http://example.com/", + Attributes={ + "FilterPolicy": json.dumps({"store": {"key": [{"exists": None}]}}) + }, + ) + except ClientError as err: + err.response["Error"]["Code"].should.equal("InvalidParameter") + err.response["Error"]["Message"].should.equal( + "Invalid parameter: Filter policy scope MessageAttributes does not support nested filter policy" + ) + + try: + filter_policy = { + "key_a": ["value_one"], + "key_b": ["value_two"], + "key_c": ["value_three"], + "key_d": ["value_four"], + "key_e": ["value_five"], + "key_f": ["value_six"], + } + conn.subscribe( + TopicArn=topic_arn, + Protocol="http", + Endpoint="http://example.com/", + Attributes={"FilterPolicy": json.dumps(filter_policy)}, + ) + except ClientError as err: + err.response["Error"]["Code"].should.equal("InvalidParameter") + err.response["Error"]["Message"].should.equal( + "Invalid parameter: FilterPolicy: Filter policy can not have more than 5 keys" + ) + + try: + nested_filter_policy = { + "key_a": { + "key_b": { + "key_c": ["value_one", "value_two", "value_three", "value_four"] + }, + }, + "key_d": {"key_e": ["value_one", "value_two", "value_three"]}, + "key_f": ["value_one", "value_two", "value_three"], + } + # The first array has four values in a three-level nested key, and the second has three values in a two-level + # nested key. The total combination is calculated as follows: + # 3 x 4 x 2 x 3 x 1 x 3 = 216 + conn.subscribe( + TopicArn=topic_arn, + Protocol="http", + Endpoint="http://example.com/", + Attributes={ + "FilterPolicyScope": "MessageBody", + "FilterPolicy": json.dumps(nested_filter_policy), + }, + ) + except ClientError as err: + err.response["Error"]["Code"].should.equal("InvalidParameter") + err.response["Error"]["Message"].should.equal( + "Invalid parameter: FilterPolicy: Filter policy is too complex" + ) + @mock_sns def test_check_not_opted_out(): diff --git a/tests/test_sns/test_utils.py b/tests/test_sns/test_utils.py new file mode 100644 index 000000000..4ef9afe1c --- /dev/null +++ b/tests/test_sns/test_utils.py @@ -0,0 +1,17 @@ +from moto.sns.utils import FilterPolicyMatcher +import pytest + + +def test_filter_policy_matcher_scope_sanity_check(): + with pytest.raises(FilterPolicyMatcher.CheckException): + FilterPolicyMatcher({}, "IncorrectFilterPolicyScope") + + +def test_filter_policy_matcher_empty_message_attributes(): + matcher = FilterPolicyMatcher({}, None) + assert matcher.matches(None, "") + + +def test_filter_policy_matcher_empty_message_attributes_filtering_fail(): + matcher = FilterPolicyMatcher({"store": ["test"]}, None) + assert not matcher.matches(None, "")