From 14b3db77b927ff18c59be41727c432b389696cfa Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Sun, 25 Feb 2024 13:13:30 +0000 Subject: [PATCH] SNS: Support filters , suffix, equals-ignore-case (#7390) --- moto/sns/models.py | 10 +- moto/sns/utils.py | 46 +++++++- tests/test_sns/test_publishing_boto3.py | 120 +++++++++++++++++++++ tests/test_sns/test_subscriptions_boto3.py | 36 +++---- 4 files changed, 187 insertions(+), 25 deletions(-) diff --git a/moto/sns/models.py b/moto/sns/models.py index 78c8690fb..481599150 100644 --- a/moto/sns/models.py +++ b/moto/sns/models.py @@ -827,7 +827,11 @@ class SNSBackend(BaseBackend): "Invalid parameter: Filter policy scope MessageAttributes does not support nested filter policy" ) elif isinstance(_value, list): - _rules.append(_value) + if key == "$or": + for val in _value: + _rules.extend(aggregate_rules(val, depth=depth + 1)) + else: + _rules.append(_value) combinations = combinations * len(_value) * depth else: raise SNSInvalidParameter( @@ -864,7 +868,7 @@ class SNSBackend(BaseBackend): if isinstance(rule, dict): keyword = list(rule.keys())[0] attributes = list(rule.values())[0] - if keyword == "anything-but": + if keyword in ["anything-but", "equals-ignore-case"]: continue elif keyword == "exists": if not isinstance(attributes, bool): @@ -938,7 +942,7 @@ class SNSBackend(BaseBackend): ) continue - elif keyword == "prefix": + elif keyword in ["prefix", "suffix"]: continue else: raise SNSInvalidParameter( diff --git a/moto/sns/utils.py b/moto/sns/utils.py index 901498416..e71fa43d1 100644 --- a/moto/sns/utils.py +++ b/moto/sns/utils.py @@ -47,7 +47,9 @@ class FilterPolicyMatcher: if message_attributes is None: message_attributes = {} - return self._attributes_based_match(message_attributes) + return FilterPolicyMatcher._attributes_based_match( + message_attributes, source=self.filter_policy + ) else: try: message_dict = json.loads(message) @@ -55,10 +57,13 @@ class FilterPolicyMatcher: return False return self._body_based_match(message_dict) - def _attributes_based_match(self, message_attributes: Dict[str, Any]) -> bool: + @staticmethod + def _attributes_based_match( # type: ignore[misc] + message_attributes: Dict[str, Any], source: Dict[str, Any] + ) -> bool: return all( FilterPolicyMatcher._field_match(field, rules, message_attributes) - for field, rules in self.filter_policy.items() + for field, rules in source.items() ) def _body_based_match(self, message_dict: Dict[str, Any]) -> bool: @@ -218,6 +223,9 @@ class FilterPolicyMatcher: def _prefix_match(prefix: str, value: str) -> bool: return value.startswith(prefix) + def _suffix_match(prefix: str, value: str) -> bool: + return value.endswith(prefix) + def _anything_but_match( filter_value: Union[Dict[str, Any], List[str], str], actual_values: List[str], @@ -316,6 +324,19 @@ class FilterPolicyMatcher: else: if _exists_match(value, field, dict_body): return True + + elif keyword == "equals-ignore-case" and isinstance(value, str): + if attributes_based_check: + if field not in dict_body: + return False + if _str_exact_match(dict_body[field]["Value"].lower(), value): + return True + else: + if field not in dict_body: + return False + if _str_exact_match(dict_body[field].lower(), value): + return True + elif keyword == "prefix" and isinstance(value, str): if attributes_based_check: if field in dict_body: @@ -327,6 +348,17 @@ class FilterPolicyMatcher: if field in dict_body: if _prefix_match(value, dict_body[field]): return True + elif keyword == "suffix" and isinstance(value, str): + if attributes_based_check: + if field in dict_body: + attr = dict_body[field] + if attr["Type"] == "String": + if _suffix_match(value, attr["Value"]): + return True + else: + if field in dict_body: + if _suffix_match(value, dict_body[field]): + return True elif keyword == "anything-but": if attributes_based_check: @@ -390,4 +422,12 @@ class FilterPolicyMatcher: if _numeric_match(numeric_ranges, dict_body[field]): return True + if field == "$or" and isinstance(rules, list): + return any( + [ + FilterPolicyMatcher._attributes_based_match(dict_body, rule) + for rule in rules + ] + ) + return False diff --git a/tests/test_sns/test_publishing_boto3.py b/tests/test_sns/test_publishing_boto3.py index 18961b3e2..d6c82ab64 100644 --- a/tests/test_sns/test_publishing_boto3.py +++ b/tests/test_sns/test_publishing_boto3.py @@ -1764,6 +1764,45 @@ def test_filtering_all_AND_matching_no_match_message_body(): assert message_bodies == [] +@mock_aws +def test_filtering_or(): + filter_policy = { + "source": ["aws.cloudwatch"], + "$or": [ + {"metricName": ["CPUUtilization"]}, + {"namespace": ["AWS/EC2"]}, + ], + } + topic, queue = _setup_filter_policy_test(filter_policy) + + topic.publish( + Message="match_first", + MessageAttributes={ + "source": {"DataType": "String", "StringValue": "aws.cloudwatch"}, + "metricName": {"DataType": "String", "StringValue": "CPUUtilization"}, + }, + ) + + topic.publish( + Message="match_second", + MessageAttributes={ + "source": {"DataType": "String", "StringValue": "aws.cloudwatch"}, + "namespace": {"DataType": "String", "StringValue": "AWS/EC2"}, + }, + ) + + topic.publish( + Message="no_match", + MessageAttributes={ + "source": {"DataType": "String", "StringValue": "aws.cloudwatch"} + }, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + assert sorted(message_bodies) == ["match_first", "match_second"] + + @mock_aws def test_filtering_prefix(): topic, queue = _setup_filter_policy_test( @@ -1808,6 +1847,42 @@ def test_filtering_prefix_message_body(): ) +@mock_aws +def test_filtering_suffix(): + topic, queue = _setup_filter_policy_test( + {"customer_interests": [{"suffix": "ball"}]} + ) + + for interest, idx in [("basketball", "1"), ("rugby", "2"), ("baseball", "3")]: + topic.publish( + Message=f"match{idx}", + MessageAttributes={ + "customer_interests": {"DataType": "String", "StringValue": interest}, + }, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + assert set(message_bodies) == {"match1", "match3"} + + +@mock_aws +def test_filtering_suffix_message_body(): + topic, queue = _setup_filter_policy_test( + {"customer_interests": [{"suffix": "ball"}]}, + filter_policy_scope="MessageBody", + ) + + for interest in ["basketball", "rugby", "baseball"]: + topic.publish(Message=json.dumps({"customer_interests": interest})) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + assert to_comparable_dicts(message_bodies) == to_comparable_dicts( + [{"customer_interests": "basketball"}, {"customer_interests": "baseball"}] + ) + + @mock_aws def test_filtering_anything_but(): topic, queue = _setup_filter_policy_test( @@ -2058,6 +2133,51 @@ def test_filtering_anything_but_numeric_string_message_body(): ) +@mock_aws +def test_filtering_equals_ignore_case(): + topic, queue = _setup_filter_policy_test( + {"customer_interests": [{"equals-ignore-case": "tennis"}]} + ) + + for interest, idx in [ + ("tenis", "1"), + ("TeNnis", "2"), + ("tennis", "3"), + ("baseball", "4"), + ]: + topic.publish( + Message=f"match{idx}", + MessageAttributes={ + "customer_interests": {"DataType": "String", "StringValue": interest}, + }, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + assert sorted(message_bodies) == ["match2", "match3"] + + +@mock_aws +def test_filtering_equals_ignore_case_message_body(): + topic, queue = _setup_filter_policy_test( + {"customer_interests": [{"equals-ignore-case": "tennis"}]}, + filter_policy_scope="MessageBody", + ) + + for interest in ["tenis", "TeNnis", "tennis", "baseball"]: + topic.publish(Message=json.dumps({"customer_interests": interest})) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + + message_bodies = [json.loads(m.body)["Message"] for m in messages] + assert to_comparable_dicts(message_bodies) == to_comparable_dicts( + [ + {"customer_interests": "TeNnis"}, + {"customer_interests": "tennis"}, + ] + ) + + @mock_aws def test_filtering_numeric_match(): topic, queue = _setup_filter_policy_test( diff --git a/tests/test_sns/test_subscriptions_boto3.py b/tests/test_sns/test_subscriptions_boto3.py index 1f7b18279..1ed7d6b1b 100644 --- a/tests/test_sns/test_subscriptions_boto3.py +++ b/tests/test_sns/test_subscriptions_boto3.py @@ -671,7 +671,7 @@ def test_subscribe_invalid_filter_policy(): "Invalid parameter: Attributes Reason: FilterPolicy: Value of < must be numeric\n at ..." ) - try: + with pytest.raises(ClientError) as err_info: conn.subscribe( TopicArn=topic_arn, Protocol="http", @@ -680,14 +680,13 @@ def test_subscribe_invalid_filter_policy(): "FilterPolicy": json.dumps({"store": {"key": [{"exists": None}]}}) }, ) - except ClientError as err: - assert err.response["Error"]["Code"] == "InvalidParameter" - assert err.response["Error"]["Message"] == ( - "Invalid parameter: Filter policy scope MessageAttributes does " - "not support nested filter policy" - ) + assert err_info.value.response["Error"]["Code"] == "InvalidParameter" + assert ( + err_info.value.response["Error"]["Message"] + == "Invalid parameter: Filter policy scope MessageAttributes does not support nested filter policy" + ) - try: + with pytest.raises(ClientError) as err_info: filter_policy = { "key_a": ["value_one"], "key_b": ["value_two"], @@ -702,13 +701,13 @@ def test_subscribe_invalid_filter_policy(): Endpoint="http://example.com/", Attributes={"FilterPolicy": json.dumps(filter_policy)}, ) - except ClientError as err: - assert err.response["Error"]["Code"] == "InvalidParameter" - assert err.response["Error"]["Message"] == ( - "Invalid parameter: FilterPolicy: Filter policy can not have more than 5 keys" - ) + assert err_info.value.response["Error"]["Code"] == "InvalidParameter" + assert ( + err_info.value.response["Error"]["Message"] + == "Invalid parameter: FilterPolicy: Filter policy can not have more than 5 keys" + ) - try: + with pytest.raises(ClientError) as err_info: nested_filter_policy = { "key_a": { "key_b": { @@ -731,11 +730,10 @@ def test_subscribe_invalid_filter_policy(): "FilterPolicy": json.dumps(nested_filter_policy), }, ) - except ClientError as err: - assert err.response["Error"]["Code"] == "InvalidParameter" - assert err.response["Error"]["Message"] == ( - "Invalid parameter: FilterPolicy: Filter policy is too complex" - ) + assert err_info.value.response["Error"]["Code"] == "InvalidParameter" + assert err_info.value.response["Error"]["Message"] == ( + "Invalid parameter: FilterPolicy: Filter policy is too complex" + ) @mock_aws