SNS: Support filters , suffix, equals-ignore-case (#7390)
This commit is contained in:
parent
644407b8c6
commit
14b3db77b9
@ -827,7 +827,11 @@ class SNSBackend(BaseBackend):
|
||||
"Invalid parameter: Filter policy scope MessageAttributes does not support nested filter policy"
|
||||
)
|
||||
elif isinstance(_value, list):
|
||||
_rules.append(_value)
|
||||
if key == "$or":
|
||||
for val in _value:
|
||||
_rules.extend(aggregate_rules(val, depth=depth + 1))
|
||||
else:
|
||||
_rules.append(_value)
|
||||
combinations = combinations * len(_value) * depth
|
||||
else:
|
||||
raise SNSInvalidParameter(
|
||||
@ -864,7 +868,7 @@ class SNSBackend(BaseBackend):
|
||||
if isinstance(rule, dict):
|
||||
keyword = list(rule.keys())[0]
|
||||
attributes = list(rule.values())[0]
|
||||
if keyword == "anything-but":
|
||||
if keyword in ["anything-but", "equals-ignore-case"]:
|
||||
continue
|
||||
elif keyword == "exists":
|
||||
if not isinstance(attributes, bool):
|
||||
@ -938,7 +942,7 @@ class SNSBackend(BaseBackend):
|
||||
)
|
||||
|
||||
continue
|
||||
elif keyword == "prefix":
|
||||
elif keyword in ["prefix", "suffix"]:
|
||||
continue
|
||||
else:
|
||||
raise SNSInvalidParameter(
|
||||
|
@ -47,7 +47,9 @@ class FilterPolicyMatcher:
|
||||
if message_attributes is None:
|
||||
message_attributes = {}
|
||||
|
||||
return self._attributes_based_match(message_attributes)
|
||||
return FilterPolicyMatcher._attributes_based_match(
|
||||
message_attributes, source=self.filter_policy
|
||||
)
|
||||
else:
|
||||
try:
|
||||
message_dict = json.loads(message)
|
||||
@ -55,10 +57,13 @@ class FilterPolicyMatcher:
|
||||
return False
|
||||
return self._body_based_match(message_dict)
|
||||
|
||||
def _attributes_based_match(self, message_attributes: Dict[str, Any]) -> bool:
|
||||
@staticmethod
|
||||
def _attributes_based_match( # type: ignore[misc]
|
||||
message_attributes: Dict[str, Any], source: Dict[str, Any]
|
||||
) -> bool:
|
||||
return all(
|
||||
FilterPolicyMatcher._field_match(field, rules, message_attributes)
|
||||
for field, rules in self.filter_policy.items()
|
||||
for field, rules in source.items()
|
||||
)
|
||||
|
||||
def _body_based_match(self, message_dict: Dict[str, Any]) -> bool:
|
||||
@ -218,6 +223,9 @@ class FilterPolicyMatcher:
|
||||
def _prefix_match(prefix: str, value: str) -> bool:
|
||||
return value.startswith(prefix)
|
||||
|
||||
def _suffix_match(prefix: str, value: str) -> bool:
|
||||
return value.endswith(prefix)
|
||||
|
||||
def _anything_but_match(
|
||||
filter_value: Union[Dict[str, Any], List[str], str],
|
||||
actual_values: List[str],
|
||||
@ -316,6 +324,19 @@ class FilterPolicyMatcher:
|
||||
else:
|
||||
if _exists_match(value, field, dict_body):
|
||||
return True
|
||||
|
||||
elif keyword == "equals-ignore-case" and isinstance(value, str):
|
||||
if attributes_based_check:
|
||||
if field not in dict_body:
|
||||
return False
|
||||
if _str_exact_match(dict_body[field]["Value"].lower(), value):
|
||||
return True
|
||||
else:
|
||||
if field not in dict_body:
|
||||
return False
|
||||
if _str_exact_match(dict_body[field].lower(), value):
|
||||
return True
|
||||
|
||||
elif keyword == "prefix" and isinstance(value, str):
|
||||
if attributes_based_check:
|
||||
if field in dict_body:
|
||||
@ -327,6 +348,17 @@ class FilterPolicyMatcher:
|
||||
if field in dict_body:
|
||||
if _prefix_match(value, dict_body[field]):
|
||||
return True
|
||||
elif keyword == "suffix" and isinstance(value, str):
|
||||
if attributes_based_check:
|
||||
if field in dict_body:
|
||||
attr = dict_body[field]
|
||||
if attr["Type"] == "String":
|
||||
if _suffix_match(value, attr["Value"]):
|
||||
return True
|
||||
else:
|
||||
if field in dict_body:
|
||||
if _suffix_match(value, dict_body[field]):
|
||||
return True
|
||||
|
||||
elif keyword == "anything-but":
|
||||
if attributes_based_check:
|
||||
@ -390,4 +422,12 @@ class FilterPolicyMatcher:
|
||||
if _numeric_match(numeric_ranges, dict_body[field]):
|
||||
return True
|
||||
|
||||
if field == "$or" and isinstance(rules, list):
|
||||
return any(
|
||||
[
|
||||
FilterPolicyMatcher._attributes_based_match(dict_body, rule)
|
||||
for rule in rules
|
||||
]
|
||||
)
|
||||
|
||||
return False
|
||||
|
@ -1764,6 +1764,45 @@ def test_filtering_all_AND_matching_no_match_message_body():
|
||||
assert message_bodies == []
|
||||
|
||||
|
||||
@mock_aws
|
||||
def test_filtering_or():
|
||||
filter_policy = {
|
||||
"source": ["aws.cloudwatch"],
|
||||
"$or": [
|
||||
{"metricName": ["CPUUtilization"]},
|
||||
{"namespace": ["AWS/EC2"]},
|
||||
],
|
||||
}
|
||||
topic, queue = _setup_filter_policy_test(filter_policy)
|
||||
|
||||
topic.publish(
|
||||
Message="match_first",
|
||||
MessageAttributes={
|
||||
"source": {"DataType": "String", "StringValue": "aws.cloudwatch"},
|
||||
"metricName": {"DataType": "String", "StringValue": "CPUUtilization"},
|
||||
},
|
||||
)
|
||||
|
||||
topic.publish(
|
||||
Message="match_second",
|
||||
MessageAttributes={
|
||||
"source": {"DataType": "String", "StringValue": "aws.cloudwatch"},
|
||||
"namespace": {"DataType": "String", "StringValue": "AWS/EC2"},
|
||||
},
|
||||
)
|
||||
|
||||
topic.publish(
|
||||
Message="no_match",
|
||||
MessageAttributes={
|
||||
"source": {"DataType": "String", "StringValue": "aws.cloudwatch"}
|
||||
},
|
||||
)
|
||||
|
||||
messages = queue.receive_messages(MaxNumberOfMessages=5)
|
||||
message_bodies = [json.loads(m.body)["Message"] for m in messages]
|
||||
assert sorted(message_bodies) == ["match_first", "match_second"]
|
||||
|
||||
|
||||
@mock_aws
|
||||
def test_filtering_prefix():
|
||||
topic, queue = _setup_filter_policy_test(
|
||||
@ -1808,6 +1847,42 @@ def test_filtering_prefix_message_body():
|
||||
)
|
||||
|
||||
|
||||
@mock_aws
|
||||
def test_filtering_suffix():
|
||||
topic, queue = _setup_filter_policy_test(
|
||||
{"customer_interests": [{"suffix": "ball"}]}
|
||||
)
|
||||
|
||||
for interest, idx in [("basketball", "1"), ("rugby", "2"), ("baseball", "3")]:
|
||||
topic.publish(
|
||||
Message=f"match{idx}",
|
||||
MessageAttributes={
|
||||
"customer_interests": {"DataType": "String", "StringValue": interest},
|
||||
},
|
||||
)
|
||||
|
||||
messages = queue.receive_messages(MaxNumberOfMessages=5)
|
||||
message_bodies = [json.loads(m.body)["Message"] for m in messages]
|
||||
assert set(message_bodies) == {"match1", "match3"}
|
||||
|
||||
|
||||
@mock_aws
|
||||
def test_filtering_suffix_message_body():
|
||||
topic, queue = _setup_filter_policy_test(
|
||||
{"customer_interests": [{"suffix": "ball"}]},
|
||||
filter_policy_scope="MessageBody",
|
||||
)
|
||||
|
||||
for interest in ["basketball", "rugby", "baseball"]:
|
||||
topic.publish(Message=json.dumps({"customer_interests": interest}))
|
||||
|
||||
messages = queue.receive_messages(MaxNumberOfMessages=5)
|
||||
message_bodies = [json.loads(m.body)["Message"] for m in messages]
|
||||
assert to_comparable_dicts(message_bodies) == to_comparable_dicts(
|
||||
[{"customer_interests": "basketball"}, {"customer_interests": "baseball"}]
|
||||
)
|
||||
|
||||
|
||||
@mock_aws
|
||||
def test_filtering_anything_but():
|
||||
topic, queue = _setup_filter_policy_test(
|
||||
@ -2058,6 +2133,51 @@ def test_filtering_anything_but_numeric_string_message_body():
|
||||
)
|
||||
|
||||
|
||||
@mock_aws
|
||||
def test_filtering_equals_ignore_case():
|
||||
topic, queue = _setup_filter_policy_test(
|
||||
{"customer_interests": [{"equals-ignore-case": "tennis"}]}
|
||||
)
|
||||
|
||||
for interest, idx in [
|
||||
("tenis", "1"),
|
||||
("TeNnis", "2"),
|
||||
("tennis", "3"),
|
||||
("baseball", "4"),
|
||||
]:
|
||||
topic.publish(
|
||||
Message=f"match{idx}",
|
||||
MessageAttributes={
|
||||
"customer_interests": {"DataType": "String", "StringValue": interest},
|
||||
},
|
||||
)
|
||||
|
||||
messages = queue.receive_messages(MaxNumberOfMessages=5)
|
||||
message_bodies = [json.loads(m.body)["Message"] for m in messages]
|
||||
assert sorted(message_bodies) == ["match2", "match3"]
|
||||
|
||||
|
||||
@mock_aws
|
||||
def test_filtering_equals_ignore_case_message_body():
|
||||
topic, queue = _setup_filter_policy_test(
|
||||
{"customer_interests": [{"equals-ignore-case": "tennis"}]},
|
||||
filter_policy_scope="MessageBody",
|
||||
)
|
||||
|
||||
for interest in ["tenis", "TeNnis", "tennis", "baseball"]:
|
||||
topic.publish(Message=json.dumps({"customer_interests": interest}))
|
||||
|
||||
messages = queue.receive_messages(MaxNumberOfMessages=5)
|
||||
|
||||
message_bodies = [json.loads(m.body)["Message"] for m in messages]
|
||||
assert to_comparable_dicts(message_bodies) == to_comparable_dicts(
|
||||
[
|
||||
{"customer_interests": "TeNnis"},
|
||||
{"customer_interests": "tennis"},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@mock_aws
|
||||
def test_filtering_numeric_match():
|
||||
topic, queue = _setup_filter_policy_test(
|
||||
|
@ -671,7 +671,7 @@ def test_subscribe_invalid_filter_policy():
|
||||
"Invalid parameter: Attributes Reason: FilterPolicy: Value of < must be numeric\n at ..."
|
||||
)
|
||||
|
||||
try:
|
||||
with pytest.raises(ClientError) as err_info:
|
||||
conn.subscribe(
|
||||
TopicArn=topic_arn,
|
||||
Protocol="http",
|
||||
@ -680,14 +680,13 @@ def test_subscribe_invalid_filter_policy():
|
||||
"FilterPolicy": json.dumps({"store": {"key": [{"exists": None}]}})
|
||||
},
|
||||
)
|
||||
except ClientError as err:
|
||||
assert err.response["Error"]["Code"] == "InvalidParameter"
|
||||
assert err.response["Error"]["Message"] == (
|
||||
"Invalid parameter: Filter policy scope MessageAttributes does "
|
||||
"not support nested filter policy"
|
||||
)
|
||||
assert err_info.value.response["Error"]["Code"] == "InvalidParameter"
|
||||
assert (
|
||||
err_info.value.response["Error"]["Message"]
|
||||
== "Invalid parameter: Filter policy scope MessageAttributes does not support nested filter policy"
|
||||
)
|
||||
|
||||
try:
|
||||
with pytest.raises(ClientError) as err_info:
|
||||
filter_policy = {
|
||||
"key_a": ["value_one"],
|
||||
"key_b": ["value_two"],
|
||||
@ -702,13 +701,13 @@ def test_subscribe_invalid_filter_policy():
|
||||
Endpoint="http://example.com/",
|
||||
Attributes={"FilterPolicy": json.dumps(filter_policy)},
|
||||
)
|
||||
except ClientError as err:
|
||||
assert err.response["Error"]["Code"] == "InvalidParameter"
|
||||
assert err.response["Error"]["Message"] == (
|
||||
"Invalid parameter: FilterPolicy: Filter policy can not have more than 5 keys"
|
||||
)
|
||||
assert err_info.value.response["Error"]["Code"] == "InvalidParameter"
|
||||
assert (
|
||||
err_info.value.response["Error"]["Message"]
|
||||
== "Invalid parameter: FilterPolicy: Filter policy can not have more than 5 keys"
|
||||
)
|
||||
|
||||
try:
|
||||
with pytest.raises(ClientError) as err_info:
|
||||
nested_filter_policy = {
|
||||
"key_a": {
|
||||
"key_b": {
|
||||
@ -731,11 +730,10 @@ def test_subscribe_invalid_filter_policy():
|
||||
"FilterPolicy": json.dumps(nested_filter_policy),
|
||||
},
|
||||
)
|
||||
except ClientError as err:
|
||||
assert err.response["Error"]["Code"] == "InvalidParameter"
|
||||
assert err.response["Error"]["Message"] == (
|
||||
"Invalid parameter: FilterPolicy: Filter policy is too complex"
|
||||
)
|
||||
assert err_info.value.response["Error"]["Code"] == "InvalidParameter"
|
||||
assert err_info.value.response["Error"]["Message"] == (
|
||||
"Invalid parameter: FilterPolicy: Filter policy is too complex"
|
||||
)
|
||||
|
||||
|
||||
@mock_aws
|
||||
|
Loading…
Reference in New Issue
Block a user