SNS: Support filters , suffix, equals-ignore-case (#7390)

This commit is contained in:
Bert Blommers 2024-02-25 13:13:30 +00:00 committed by GitHub
parent 644407b8c6
commit 14b3db77b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 187 additions and 25 deletions

View File

@ -827,7 +827,11 @@ class SNSBackend(BaseBackend):
"Invalid parameter: Filter policy scope MessageAttributes does not support nested filter policy" "Invalid parameter: Filter policy scope MessageAttributes does not support nested filter policy"
) )
elif isinstance(_value, list): elif isinstance(_value, list):
_rules.append(_value) if key == "$or":
for val in _value:
_rules.extend(aggregate_rules(val, depth=depth + 1))
else:
_rules.append(_value)
combinations = combinations * len(_value) * depth combinations = combinations * len(_value) * depth
else: else:
raise SNSInvalidParameter( raise SNSInvalidParameter(
@ -864,7 +868,7 @@ class SNSBackend(BaseBackend):
if isinstance(rule, dict): if isinstance(rule, dict):
keyword = list(rule.keys())[0] keyword = list(rule.keys())[0]
attributes = list(rule.values())[0] attributes = list(rule.values())[0]
if keyword == "anything-but": if keyword in ["anything-but", "equals-ignore-case"]:
continue continue
elif keyword == "exists": elif keyword == "exists":
if not isinstance(attributes, bool): if not isinstance(attributes, bool):
@ -938,7 +942,7 @@ class SNSBackend(BaseBackend):
) )
continue continue
elif keyword == "prefix": elif keyword in ["prefix", "suffix"]:
continue continue
else: else:
raise SNSInvalidParameter( raise SNSInvalidParameter(

View File

@ -47,7 +47,9 @@ class FilterPolicyMatcher:
if message_attributes is None: if message_attributes is None:
message_attributes = {} message_attributes = {}
return self._attributes_based_match(message_attributes) return FilterPolicyMatcher._attributes_based_match(
message_attributes, source=self.filter_policy
)
else: else:
try: try:
message_dict = json.loads(message) message_dict = json.loads(message)
@ -55,10 +57,13 @@ class FilterPolicyMatcher:
return False return False
return self._body_based_match(message_dict) return self._body_based_match(message_dict)
def _attributes_based_match(self, message_attributes: Dict[str, Any]) -> bool: @staticmethod
def _attributes_based_match( # type: ignore[misc]
message_attributes: Dict[str, Any], source: Dict[str, Any]
) -> bool:
return all( return all(
FilterPolicyMatcher._field_match(field, rules, message_attributes) FilterPolicyMatcher._field_match(field, rules, message_attributes)
for field, rules in self.filter_policy.items() for field, rules in source.items()
) )
def _body_based_match(self, message_dict: Dict[str, Any]) -> bool: def _body_based_match(self, message_dict: Dict[str, Any]) -> bool:
@ -218,6 +223,9 @@ class FilterPolicyMatcher:
def _prefix_match(prefix: str, value: str) -> bool: def _prefix_match(prefix: str, value: str) -> bool:
return value.startswith(prefix) return value.startswith(prefix)
def _suffix_match(prefix: str, value: str) -> bool:
return value.endswith(prefix)
def _anything_but_match( def _anything_but_match(
filter_value: Union[Dict[str, Any], List[str], str], filter_value: Union[Dict[str, Any], List[str], str],
actual_values: List[str], actual_values: List[str],
@ -316,6 +324,19 @@ class FilterPolicyMatcher:
else: else:
if _exists_match(value, field, dict_body): if _exists_match(value, field, dict_body):
return True return True
elif keyword == "equals-ignore-case" and isinstance(value, str):
if attributes_based_check:
if field not in dict_body:
return False
if _str_exact_match(dict_body[field]["Value"].lower(), value):
return True
else:
if field not in dict_body:
return False
if _str_exact_match(dict_body[field].lower(), value):
return True
elif keyword == "prefix" and isinstance(value, str): elif keyword == "prefix" and isinstance(value, str):
if attributes_based_check: if attributes_based_check:
if field in dict_body: if field in dict_body:
@ -327,6 +348,17 @@ class FilterPolicyMatcher:
if field in dict_body: if field in dict_body:
if _prefix_match(value, dict_body[field]): if _prefix_match(value, dict_body[field]):
return True return True
elif keyword == "suffix" and isinstance(value, str):
if attributes_based_check:
if field in dict_body:
attr = dict_body[field]
if attr["Type"] == "String":
if _suffix_match(value, attr["Value"]):
return True
else:
if field in dict_body:
if _suffix_match(value, dict_body[field]):
return True
elif keyword == "anything-but": elif keyword == "anything-but":
if attributes_based_check: if attributes_based_check:
@ -390,4 +422,12 @@ class FilterPolicyMatcher:
if _numeric_match(numeric_ranges, dict_body[field]): if _numeric_match(numeric_ranges, dict_body[field]):
return True return True
if field == "$or" and isinstance(rules, list):
return any(
[
FilterPolicyMatcher._attributes_based_match(dict_body, rule)
for rule in rules
]
)
return False return False

View File

@ -1764,6 +1764,45 @@ def test_filtering_all_AND_matching_no_match_message_body():
assert message_bodies == [] assert message_bodies == []
@mock_aws
def test_filtering_or():
filter_policy = {
"source": ["aws.cloudwatch"],
"$or": [
{"metricName": ["CPUUtilization"]},
{"namespace": ["AWS/EC2"]},
],
}
topic, queue = _setup_filter_policy_test(filter_policy)
topic.publish(
Message="match_first",
MessageAttributes={
"source": {"DataType": "String", "StringValue": "aws.cloudwatch"},
"metricName": {"DataType": "String", "StringValue": "CPUUtilization"},
},
)
topic.publish(
Message="match_second",
MessageAttributes={
"source": {"DataType": "String", "StringValue": "aws.cloudwatch"},
"namespace": {"DataType": "String", "StringValue": "AWS/EC2"},
},
)
topic.publish(
Message="no_match",
MessageAttributes={
"source": {"DataType": "String", "StringValue": "aws.cloudwatch"}
},
)
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)["Message"] for m in messages]
assert sorted(message_bodies) == ["match_first", "match_second"]
@mock_aws @mock_aws
def test_filtering_prefix(): def test_filtering_prefix():
topic, queue = _setup_filter_policy_test( topic, queue = _setup_filter_policy_test(
@ -1808,6 +1847,42 @@ def test_filtering_prefix_message_body():
) )
@mock_aws
def test_filtering_suffix():
topic, queue = _setup_filter_policy_test(
{"customer_interests": [{"suffix": "ball"}]}
)
for interest, idx in [("basketball", "1"), ("rugby", "2"), ("baseball", "3")]:
topic.publish(
Message=f"match{idx}",
MessageAttributes={
"customer_interests": {"DataType": "String", "StringValue": interest},
},
)
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)["Message"] for m in messages]
assert set(message_bodies) == {"match1", "match3"}
@mock_aws
def test_filtering_suffix_message_body():
topic, queue = _setup_filter_policy_test(
{"customer_interests": [{"suffix": "ball"}]},
filter_policy_scope="MessageBody",
)
for interest in ["basketball", "rugby", "baseball"]:
topic.publish(Message=json.dumps({"customer_interests": interest}))
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)["Message"] for m in messages]
assert to_comparable_dicts(message_bodies) == to_comparable_dicts(
[{"customer_interests": "basketball"}, {"customer_interests": "baseball"}]
)
@mock_aws @mock_aws
def test_filtering_anything_but(): def test_filtering_anything_but():
topic, queue = _setup_filter_policy_test( topic, queue = _setup_filter_policy_test(
@ -2058,6 +2133,51 @@ def test_filtering_anything_but_numeric_string_message_body():
) )
@mock_aws
def test_filtering_equals_ignore_case():
topic, queue = _setup_filter_policy_test(
{"customer_interests": [{"equals-ignore-case": "tennis"}]}
)
for interest, idx in [
("tenis", "1"),
("TeNnis", "2"),
("tennis", "3"),
("baseball", "4"),
]:
topic.publish(
Message=f"match{idx}",
MessageAttributes={
"customer_interests": {"DataType": "String", "StringValue": interest},
},
)
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)["Message"] for m in messages]
assert sorted(message_bodies) == ["match2", "match3"]
@mock_aws
def test_filtering_equals_ignore_case_message_body():
topic, queue = _setup_filter_policy_test(
{"customer_interests": [{"equals-ignore-case": "tennis"}]},
filter_policy_scope="MessageBody",
)
for interest in ["tenis", "TeNnis", "tennis", "baseball"]:
topic.publish(Message=json.dumps({"customer_interests": interest}))
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)["Message"] for m in messages]
assert to_comparable_dicts(message_bodies) == to_comparable_dicts(
[
{"customer_interests": "TeNnis"},
{"customer_interests": "tennis"},
]
)
@mock_aws @mock_aws
def test_filtering_numeric_match(): def test_filtering_numeric_match():
topic, queue = _setup_filter_policy_test( topic, queue = _setup_filter_policy_test(

View File

@ -671,7 +671,7 @@ def test_subscribe_invalid_filter_policy():
"Invalid parameter: Attributes Reason: FilterPolicy: Value of < must be numeric\n at ..." "Invalid parameter: Attributes Reason: FilterPolicy: Value of < must be numeric\n at ..."
) )
try: with pytest.raises(ClientError) as err_info:
conn.subscribe( conn.subscribe(
TopicArn=topic_arn, TopicArn=topic_arn,
Protocol="http", Protocol="http",
@ -680,14 +680,13 @@ def test_subscribe_invalid_filter_policy():
"FilterPolicy": json.dumps({"store": {"key": [{"exists": None}]}}) "FilterPolicy": json.dumps({"store": {"key": [{"exists": None}]}})
}, },
) )
except ClientError as err: assert err_info.value.response["Error"]["Code"] == "InvalidParameter"
assert err.response["Error"]["Code"] == "InvalidParameter" assert (
assert err.response["Error"]["Message"] == ( err_info.value.response["Error"]["Message"]
"Invalid parameter: Filter policy scope MessageAttributes does " == "Invalid parameter: Filter policy scope MessageAttributes does not support nested filter policy"
"not support nested filter policy" )
)
try: with pytest.raises(ClientError) as err_info:
filter_policy = { filter_policy = {
"key_a": ["value_one"], "key_a": ["value_one"],
"key_b": ["value_two"], "key_b": ["value_two"],
@ -702,13 +701,13 @@ def test_subscribe_invalid_filter_policy():
Endpoint="http://example.com/", Endpoint="http://example.com/",
Attributes={"FilterPolicy": json.dumps(filter_policy)}, Attributes={"FilterPolicy": json.dumps(filter_policy)},
) )
except ClientError as err: assert err_info.value.response["Error"]["Code"] == "InvalidParameter"
assert err.response["Error"]["Code"] == "InvalidParameter" assert (
assert err.response["Error"]["Message"] == ( err_info.value.response["Error"]["Message"]
"Invalid parameter: FilterPolicy: Filter policy can not have more than 5 keys" == "Invalid parameter: FilterPolicy: Filter policy can not have more than 5 keys"
) )
try: with pytest.raises(ClientError) as err_info:
nested_filter_policy = { nested_filter_policy = {
"key_a": { "key_a": {
"key_b": { "key_b": {
@ -731,11 +730,10 @@ def test_subscribe_invalid_filter_policy():
"FilterPolicy": json.dumps(nested_filter_policy), "FilterPolicy": json.dumps(nested_filter_policy),
}, },
) )
except ClientError as err: assert err_info.value.response["Error"]["Code"] == "InvalidParameter"
assert err.response["Error"]["Code"] == "InvalidParameter" assert err_info.value.response["Error"]["Message"] == (
assert err.response["Error"]["Message"] == ( "Invalid parameter: FilterPolicy: Filter policy is too complex"
"Invalid parameter: FilterPolicy: Filter policy is too complex" )
)
@mock_aws @mock_aws