SNS: Support filters , suffix, equals-ignore-case (#7390)

This commit is contained in:
Bert Blommers 2024-02-25 13:13:30 +00:00 committed by GitHub
parent 644407b8c6
commit 14b3db77b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 187 additions and 25 deletions

View File

@ -827,7 +827,11 @@ class SNSBackend(BaseBackend):
"Invalid parameter: Filter policy scope MessageAttributes does not support nested filter policy"
)
elif isinstance(_value, list):
_rules.append(_value)
if key == "$or":
for val in _value:
_rules.extend(aggregate_rules(val, depth=depth + 1))
else:
_rules.append(_value)
combinations = combinations * len(_value) * depth
else:
raise SNSInvalidParameter(
@ -864,7 +868,7 @@ class SNSBackend(BaseBackend):
if isinstance(rule, dict):
keyword = list(rule.keys())[0]
attributes = list(rule.values())[0]
if keyword == "anything-but":
if keyword in ["anything-but", "equals-ignore-case"]:
continue
elif keyword == "exists":
if not isinstance(attributes, bool):
@ -938,7 +942,7 @@ class SNSBackend(BaseBackend):
)
continue
elif keyword == "prefix":
elif keyword in ["prefix", "suffix"]:
continue
else:
raise SNSInvalidParameter(

View File

@ -47,7 +47,9 @@ class FilterPolicyMatcher:
if message_attributes is None:
message_attributes = {}
return self._attributes_based_match(message_attributes)
return FilterPolicyMatcher._attributes_based_match(
message_attributes, source=self.filter_policy
)
else:
try:
message_dict = json.loads(message)
@ -55,10 +57,13 @@ class FilterPolicyMatcher:
return False
return self._body_based_match(message_dict)
def _attributes_based_match(self, message_attributes: Dict[str, Any]) -> bool:
@staticmethod
def _attributes_based_match( # type: ignore[misc]
message_attributes: Dict[str, Any], source: Dict[str, Any]
) -> bool:
return all(
FilterPolicyMatcher._field_match(field, rules, message_attributes)
for field, rules in self.filter_policy.items()
for field, rules in source.items()
)
def _body_based_match(self, message_dict: Dict[str, Any]) -> bool:
@ -218,6 +223,9 @@ class FilterPolicyMatcher:
def _prefix_match(prefix: str, value: str) -> bool:
return value.startswith(prefix)
def _suffix_match(prefix: str, value: str) -> bool:
return value.endswith(prefix)
def _anything_but_match(
filter_value: Union[Dict[str, Any], List[str], str],
actual_values: List[str],
@ -316,6 +324,19 @@ class FilterPolicyMatcher:
else:
if _exists_match(value, field, dict_body):
return True
elif keyword == "equals-ignore-case" and isinstance(value, str):
if attributes_based_check:
if field not in dict_body:
return False
if _str_exact_match(dict_body[field]["Value"].lower(), value):
return True
else:
if field not in dict_body:
return False
if _str_exact_match(dict_body[field].lower(), value):
return True
elif keyword == "prefix" and isinstance(value, str):
if attributes_based_check:
if field in dict_body:
@ -327,6 +348,17 @@ class FilterPolicyMatcher:
if field in dict_body:
if _prefix_match(value, dict_body[field]):
return True
elif keyword == "suffix" and isinstance(value, str):
if attributes_based_check:
if field in dict_body:
attr = dict_body[field]
if attr["Type"] == "String":
if _suffix_match(value, attr["Value"]):
return True
else:
if field in dict_body:
if _suffix_match(value, dict_body[field]):
return True
elif keyword == "anything-but":
if attributes_based_check:
@ -390,4 +422,12 @@ class FilterPolicyMatcher:
if _numeric_match(numeric_ranges, dict_body[field]):
return True
if field == "$or" and isinstance(rules, list):
return any(
[
FilterPolicyMatcher._attributes_based_match(dict_body, rule)
for rule in rules
]
)
return False

View File

@ -1764,6 +1764,45 @@ def test_filtering_all_AND_matching_no_match_message_body():
assert message_bodies == []
@mock_aws
def test_filtering_or():
filter_policy = {
"source": ["aws.cloudwatch"],
"$or": [
{"metricName": ["CPUUtilization"]},
{"namespace": ["AWS/EC2"]},
],
}
topic, queue = _setup_filter_policy_test(filter_policy)
topic.publish(
Message="match_first",
MessageAttributes={
"source": {"DataType": "String", "StringValue": "aws.cloudwatch"},
"metricName": {"DataType": "String", "StringValue": "CPUUtilization"},
},
)
topic.publish(
Message="match_second",
MessageAttributes={
"source": {"DataType": "String", "StringValue": "aws.cloudwatch"},
"namespace": {"DataType": "String", "StringValue": "AWS/EC2"},
},
)
topic.publish(
Message="no_match",
MessageAttributes={
"source": {"DataType": "String", "StringValue": "aws.cloudwatch"}
},
)
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)["Message"] for m in messages]
assert sorted(message_bodies) == ["match_first", "match_second"]
@mock_aws
def test_filtering_prefix():
topic, queue = _setup_filter_policy_test(
@ -1808,6 +1847,42 @@ def test_filtering_prefix_message_body():
)
@mock_aws
def test_filtering_suffix():
topic, queue = _setup_filter_policy_test(
{"customer_interests": [{"suffix": "ball"}]}
)
for interest, idx in [("basketball", "1"), ("rugby", "2"), ("baseball", "3")]:
topic.publish(
Message=f"match{idx}",
MessageAttributes={
"customer_interests": {"DataType": "String", "StringValue": interest},
},
)
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)["Message"] for m in messages]
assert set(message_bodies) == {"match1", "match3"}
@mock_aws
def test_filtering_suffix_message_body():
topic, queue = _setup_filter_policy_test(
{"customer_interests": [{"suffix": "ball"}]},
filter_policy_scope="MessageBody",
)
for interest in ["basketball", "rugby", "baseball"]:
topic.publish(Message=json.dumps({"customer_interests": interest}))
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)["Message"] for m in messages]
assert to_comparable_dicts(message_bodies) == to_comparable_dicts(
[{"customer_interests": "basketball"}, {"customer_interests": "baseball"}]
)
@mock_aws
def test_filtering_anything_but():
topic, queue = _setup_filter_policy_test(
@ -2058,6 +2133,51 @@ def test_filtering_anything_but_numeric_string_message_body():
)
@mock_aws
def test_filtering_equals_ignore_case():
topic, queue = _setup_filter_policy_test(
{"customer_interests": [{"equals-ignore-case": "tennis"}]}
)
for interest, idx in [
("tenis", "1"),
("TeNnis", "2"),
("tennis", "3"),
("baseball", "4"),
]:
topic.publish(
Message=f"match{idx}",
MessageAttributes={
"customer_interests": {"DataType": "String", "StringValue": interest},
},
)
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)["Message"] for m in messages]
assert sorted(message_bodies) == ["match2", "match3"]
@mock_aws
def test_filtering_equals_ignore_case_message_body():
topic, queue = _setup_filter_policy_test(
{"customer_interests": [{"equals-ignore-case": "tennis"}]},
filter_policy_scope="MessageBody",
)
for interest in ["tenis", "TeNnis", "tennis", "baseball"]:
topic.publish(Message=json.dumps({"customer_interests": interest}))
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)["Message"] for m in messages]
assert to_comparable_dicts(message_bodies) == to_comparable_dicts(
[
{"customer_interests": "TeNnis"},
{"customer_interests": "tennis"},
]
)
@mock_aws
def test_filtering_numeric_match():
topic, queue = _setup_filter_policy_test(

View File

@ -671,7 +671,7 @@ def test_subscribe_invalid_filter_policy():
"Invalid parameter: Attributes Reason: FilterPolicy: Value of < must be numeric\n at ..."
)
try:
with pytest.raises(ClientError) as err_info:
conn.subscribe(
TopicArn=topic_arn,
Protocol="http",
@ -680,14 +680,13 @@ def test_subscribe_invalid_filter_policy():
"FilterPolicy": json.dumps({"store": {"key": [{"exists": None}]}})
},
)
except ClientError as err:
assert err.response["Error"]["Code"] == "InvalidParameter"
assert err.response["Error"]["Message"] == (
"Invalid parameter: Filter policy scope MessageAttributes does "
"not support nested filter policy"
)
assert err_info.value.response["Error"]["Code"] == "InvalidParameter"
assert (
err_info.value.response["Error"]["Message"]
== "Invalid parameter: Filter policy scope MessageAttributes does not support nested filter policy"
)
try:
with pytest.raises(ClientError) as err_info:
filter_policy = {
"key_a": ["value_one"],
"key_b": ["value_two"],
@ -702,13 +701,13 @@ def test_subscribe_invalid_filter_policy():
Endpoint="http://example.com/",
Attributes={"FilterPolicy": json.dumps(filter_policy)},
)
except ClientError as err:
assert err.response["Error"]["Code"] == "InvalidParameter"
assert err.response["Error"]["Message"] == (
"Invalid parameter: FilterPolicy: Filter policy can not have more than 5 keys"
)
assert err_info.value.response["Error"]["Code"] == "InvalidParameter"
assert (
err_info.value.response["Error"]["Message"]
== "Invalid parameter: FilterPolicy: Filter policy can not have more than 5 keys"
)
try:
with pytest.raises(ClientError) as err_info:
nested_filter_policy = {
"key_a": {
"key_b": {
@ -731,11 +730,10 @@ def test_subscribe_invalid_filter_policy():
"FilterPolicy": json.dumps(nested_filter_policy),
},
)
except ClientError as err:
assert err.response["Error"]["Code"] == "InvalidParameter"
assert err.response["Error"]["Message"] == (
"Invalid parameter: FilterPolicy: Filter policy is too complex"
)
assert err_info.value.response["Error"]["Code"] == "InvalidParameter"
assert err_info.value.response["Error"]["Message"] == (
"Invalid parameter: FilterPolicy: Filter policy is too complex"
)
@mock_aws