SNS: Support filters , suffix, equals-ignore-case (#7390)
This commit is contained in:
parent
644407b8c6
commit
14b3db77b9
@ -827,7 +827,11 @@ class SNSBackend(BaseBackend):
|
|||||||
"Invalid parameter: Filter policy scope MessageAttributes does not support nested filter policy"
|
"Invalid parameter: Filter policy scope MessageAttributes does not support nested filter policy"
|
||||||
)
|
)
|
||||||
elif isinstance(_value, list):
|
elif isinstance(_value, list):
|
||||||
_rules.append(_value)
|
if key == "$or":
|
||||||
|
for val in _value:
|
||||||
|
_rules.extend(aggregate_rules(val, depth=depth + 1))
|
||||||
|
else:
|
||||||
|
_rules.append(_value)
|
||||||
combinations = combinations * len(_value) * depth
|
combinations = combinations * len(_value) * depth
|
||||||
else:
|
else:
|
||||||
raise SNSInvalidParameter(
|
raise SNSInvalidParameter(
|
||||||
@ -864,7 +868,7 @@ class SNSBackend(BaseBackend):
|
|||||||
if isinstance(rule, dict):
|
if isinstance(rule, dict):
|
||||||
keyword = list(rule.keys())[0]
|
keyword = list(rule.keys())[0]
|
||||||
attributes = list(rule.values())[0]
|
attributes = list(rule.values())[0]
|
||||||
if keyword == "anything-but":
|
if keyword in ["anything-but", "equals-ignore-case"]:
|
||||||
continue
|
continue
|
||||||
elif keyword == "exists":
|
elif keyword == "exists":
|
||||||
if not isinstance(attributes, bool):
|
if not isinstance(attributes, bool):
|
||||||
@ -938,7 +942,7 @@ class SNSBackend(BaseBackend):
|
|||||||
)
|
)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
elif keyword == "prefix":
|
elif keyword in ["prefix", "suffix"]:
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
raise SNSInvalidParameter(
|
raise SNSInvalidParameter(
|
||||||
|
@ -47,7 +47,9 @@ class FilterPolicyMatcher:
|
|||||||
if message_attributes is None:
|
if message_attributes is None:
|
||||||
message_attributes = {}
|
message_attributes = {}
|
||||||
|
|
||||||
return self._attributes_based_match(message_attributes)
|
return FilterPolicyMatcher._attributes_based_match(
|
||||||
|
message_attributes, source=self.filter_policy
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
message_dict = json.loads(message)
|
message_dict = json.loads(message)
|
||||||
@ -55,10 +57,13 @@ class FilterPolicyMatcher:
|
|||||||
return False
|
return False
|
||||||
return self._body_based_match(message_dict)
|
return self._body_based_match(message_dict)
|
||||||
|
|
||||||
def _attributes_based_match(self, message_attributes: Dict[str, Any]) -> bool:
|
@staticmethod
|
||||||
|
def _attributes_based_match( # type: ignore[misc]
|
||||||
|
message_attributes: Dict[str, Any], source: Dict[str, Any]
|
||||||
|
) -> bool:
|
||||||
return all(
|
return all(
|
||||||
FilterPolicyMatcher._field_match(field, rules, message_attributes)
|
FilterPolicyMatcher._field_match(field, rules, message_attributes)
|
||||||
for field, rules in self.filter_policy.items()
|
for field, rules in source.items()
|
||||||
)
|
)
|
||||||
|
|
||||||
def _body_based_match(self, message_dict: Dict[str, Any]) -> bool:
|
def _body_based_match(self, message_dict: Dict[str, Any]) -> bool:
|
||||||
@ -218,6 +223,9 @@ class FilterPolicyMatcher:
|
|||||||
def _prefix_match(prefix: str, value: str) -> bool:
|
def _prefix_match(prefix: str, value: str) -> bool:
|
||||||
return value.startswith(prefix)
|
return value.startswith(prefix)
|
||||||
|
|
||||||
|
def _suffix_match(prefix: str, value: str) -> bool:
|
||||||
|
return value.endswith(prefix)
|
||||||
|
|
||||||
def _anything_but_match(
|
def _anything_but_match(
|
||||||
filter_value: Union[Dict[str, Any], List[str], str],
|
filter_value: Union[Dict[str, Any], List[str], str],
|
||||||
actual_values: List[str],
|
actual_values: List[str],
|
||||||
@ -316,6 +324,19 @@ class FilterPolicyMatcher:
|
|||||||
else:
|
else:
|
||||||
if _exists_match(value, field, dict_body):
|
if _exists_match(value, field, dict_body):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
elif keyword == "equals-ignore-case" and isinstance(value, str):
|
||||||
|
if attributes_based_check:
|
||||||
|
if field not in dict_body:
|
||||||
|
return False
|
||||||
|
if _str_exact_match(dict_body[field]["Value"].lower(), value):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
if field not in dict_body:
|
||||||
|
return False
|
||||||
|
if _str_exact_match(dict_body[field].lower(), value):
|
||||||
|
return True
|
||||||
|
|
||||||
elif keyword == "prefix" and isinstance(value, str):
|
elif keyword == "prefix" and isinstance(value, str):
|
||||||
if attributes_based_check:
|
if attributes_based_check:
|
||||||
if field in dict_body:
|
if field in dict_body:
|
||||||
@ -327,6 +348,17 @@ class FilterPolicyMatcher:
|
|||||||
if field in dict_body:
|
if field in dict_body:
|
||||||
if _prefix_match(value, dict_body[field]):
|
if _prefix_match(value, dict_body[field]):
|
||||||
return True
|
return True
|
||||||
|
elif keyword == "suffix" and isinstance(value, str):
|
||||||
|
if attributes_based_check:
|
||||||
|
if field in dict_body:
|
||||||
|
attr = dict_body[field]
|
||||||
|
if attr["Type"] == "String":
|
||||||
|
if _suffix_match(value, attr["Value"]):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
if field in dict_body:
|
||||||
|
if _suffix_match(value, dict_body[field]):
|
||||||
|
return True
|
||||||
|
|
||||||
elif keyword == "anything-but":
|
elif keyword == "anything-but":
|
||||||
if attributes_based_check:
|
if attributes_based_check:
|
||||||
@ -390,4 +422,12 @@ class FilterPolicyMatcher:
|
|||||||
if _numeric_match(numeric_ranges, dict_body[field]):
|
if _numeric_match(numeric_ranges, dict_body[field]):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
if field == "$or" and isinstance(rules, list):
|
||||||
|
return any(
|
||||||
|
[
|
||||||
|
FilterPolicyMatcher._attributes_based_match(dict_body, rule)
|
||||||
|
for rule in rules
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
@ -1764,6 +1764,45 @@ def test_filtering_all_AND_matching_no_match_message_body():
|
|||||||
assert message_bodies == []
|
assert message_bodies == []
|
||||||
|
|
||||||
|
|
||||||
|
@mock_aws
|
||||||
|
def test_filtering_or():
|
||||||
|
filter_policy = {
|
||||||
|
"source": ["aws.cloudwatch"],
|
||||||
|
"$or": [
|
||||||
|
{"metricName": ["CPUUtilization"]},
|
||||||
|
{"namespace": ["AWS/EC2"]},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
topic, queue = _setup_filter_policy_test(filter_policy)
|
||||||
|
|
||||||
|
topic.publish(
|
||||||
|
Message="match_first",
|
||||||
|
MessageAttributes={
|
||||||
|
"source": {"DataType": "String", "StringValue": "aws.cloudwatch"},
|
||||||
|
"metricName": {"DataType": "String", "StringValue": "CPUUtilization"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
topic.publish(
|
||||||
|
Message="match_second",
|
||||||
|
MessageAttributes={
|
||||||
|
"source": {"DataType": "String", "StringValue": "aws.cloudwatch"},
|
||||||
|
"namespace": {"DataType": "String", "StringValue": "AWS/EC2"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
topic.publish(
|
||||||
|
Message="no_match",
|
||||||
|
MessageAttributes={
|
||||||
|
"source": {"DataType": "String", "StringValue": "aws.cloudwatch"}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = queue.receive_messages(MaxNumberOfMessages=5)
|
||||||
|
message_bodies = [json.loads(m.body)["Message"] for m in messages]
|
||||||
|
assert sorted(message_bodies) == ["match_first", "match_second"]
|
||||||
|
|
||||||
|
|
||||||
@mock_aws
|
@mock_aws
|
||||||
def test_filtering_prefix():
|
def test_filtering_prefix():
|
||||||
topic, queue = _setup_filter_policy_test(
|
topic, queue = _setup_filter_policy_test(
|
||||||
@ -1808,6 +1847,42 @@ def test_filtering_prefix_message_body():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mock_aws
|
||||||
|
def test_filtering_suffix():
|
||||||
|
topic, queue = _setup_filter_policy_test(
|
||||||
|
{"customer_interests": [{"suffix": "ball"}]}
|
||||||
|
)
|
||||||
|
|
||||||
|
for interest, idx in [("basketball", "1"), ("rugby", "2"), ("baseball", "3")]:
|
||||||
|
topic.publish(
|
||||||
|
Message=f"match{idx}",
|
||||||
|
MessageAttributes={
|
||||||
|
"customer_interests": {"DataType": "String", "StringValue": interest},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = queue.receive_messages(MaxNumberOfMessages=5)
|
||||||
|
message_bodies = [json.loads(m.body)["Message"] for m in messages]
|
||||||
|
assert set(message_bodies) == {"match1", "match3"}
|
||||||
|
|
||||||
|
|
||||||
|
@mock_aws
|
||||||
|
def test_filtering_suffix_message_body():
|
||||||
|
topic, queue = _setup_filter_policy_test(
|
||||||
|
{"customer_interests": [{"suffix": "ball"}]},
|
||||||
|
filter_policy_scope="MessageBody",
|
||||||
|
)
|
||||||
|
|
||||||
|
for interest in ["basketball", "rugby", "baseball"]:
|
||||||
|
topic.publish(Message=json.dumps({"customer_interests": interest}))
|
||||||
|
|
||||||
|
messages = queue.receive_messages(MaxNumberOfMessages=5)
|
||||||
|
message_bodies = [json.loads(m.body)["Message"] for m in messages]
|
||||||
|
assert to_comparable_dicts(message_bodies) == to_comparable_dicts(
|
||||||
|
[{"customer_interests": "basketball"}, {"customer_interests": "baseball"}]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@mock_aws
|
@mock_aws
|
||||||
def test_filtering_anything_but():
|
def test_filtering_anything_but():
|
||||||
topic, queue = _setup_filter_policy_test(
|
topic, queue = _setup_filter_policy_test(
|
||||||
@ -2058,6 +2133,51 @@ def test_filtering_anything_but_numeric_string_message_body():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mock_aws
|
||||||
|
def test_filtering_equals_ignore_case():
|
||||||
|
topic, queue = _setup_filter_policy_test(
|
||||||
|
{"customer_interests": [{"equals-ignore-case": "tennis"}]}
|
||||||
|
)
|
||||||
|
|
||||||
|
for interest, idx in [
|
||||||
|
("tenis", "1"),
|
||||||
|
("TeNnis", "2"),
|
||||||
|
("tennis", "3"),
|
||||||
|
("baseball", "4"),
|
||||||
|
]:
|
||||||
|
topic.publish(
|
||||||
|
Message=f"match{idx}",
|
||||||
|
MessageAttributes={
|
||||||
|
"customer_interests": {"DataType": "String", "StringValue": interest},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = queue.receive_messages(MaxNumberOfMessages=5)
|
||||||
|
message_bodies = [json.loads(m.body)["Message"] for m in messages]
|
||||||
|
assert sorted(message_bodies) == ["match2", "match3"]
|
||||||
|
|
||||||
|
|
||||||
|
@mock_aws
|
||||||
|
def test_filtering_equals_ignore_case_message_body():
|
||||||
|
topic, queue = _setup_filter_policy_test(
|
||||||
|
{"customer_interests": [{"equals-ignore-case": "tennis"}]},
|
||||||
|
filter_policy_scope="MessageBody",
|
||||||
|
)
|
||||||
|
|
||||||
|
for interest in ["tenis", "TeNnis", "tennis", "baseball"]:
|
||||||
|
topic.publish(Message=json.dumps({"customer_interests": interest}))
|
||||||
|
|
||||||
|
messages = queue.receive_messages(MaxNumberOfMessages=5)
|
||||||
|
|
||||||
|
message_bodies = [json.loads(m.body)["Message"] for m in messages]
|
||||||
|
assert to_comparable_dicts(message_bodies) == to_comparable_dicts(
|
||||||
|
[
|
||||||
|
{"customer_interests": "TeNnis"},
|
||||||
|
{"customer_interests": "tennis"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@mock_aws
|
@mock_aws
|
||||||
def test_filtering_numeric_match():
|
def test_filtering_numeric_match():
|
||||||
topic, queue = _setup_filter_policy_test(
|
topic, queue = _setup_filter_policy_test(
|
||||||
|
@ -671,7 +671,7 @@ def test_subscribe_invalid_filter_policy():
|
|||||||
"Invalid parameter: Attributes Reason: FilterPolicy: Value of < must be numeric\n at ..."
|
"Invalid parameter: Attributes Reason: FilterPolicy: Value of < must be numeric\n at ..."
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
with pytest.raises(ClientError) as err_info:
|
||||||
conn.subscribe(
|
conn.subscribe(
|
||||||
TopicArn=topic_arn,
|
TopicArn=topic_arn,
|
||||||
Protocol="http",
|
Protocol="http",
|
||||||
@ -680,14 +680,13 @@ def test_subscribe_invalid_filter_policy():
|
|||||||
"FilterPolicy": json.dumps({"store": {"key": [{"exists": None}]}})
|
"FilterPolicy": json.dumps({"store": {"key": [{"exists": None}]}})
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
except ClientError as err:
|
assert err_info.value.response["Error"]["Code"] == "InvalidParameter"
|
||||||
assert err.response["Error"]["Code"] == "InvalidParameter"
|
assert (
|
||||||
assert err.response["Error"]["Message"] == (
|
err_info.value.response["Error"]["Message"]
|
||||||
"Invalid parameter: Filter policy scope MessageAttributes does "
|
== "Invalid parameter: Filter policy scope MessageAttributes does not support nested filter policy"
|
||||||
"not support nested filter policy"
|
)
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
with pytest.raises(ClientError) as err_info:
|
||||||
filter_policy = {
|
filter_policy = {
|
||||||
"key_a": ["value_one"],
|
"key_a": ["value_one"],
|
||||||
"key_b": ["value_two"],
|
"key_b": ["value_two"],
|
||||||
@ -702,13 +701,13 @@ def test_subscribe_invalid_filter_policy():
|
|||||||
Endpoint="http://example.com/",
|
Endpoint="http://example.com/",
|
||||||
Attributes={"FilterPolicy": json.dumps(filter_policy)},
|
Attributes={"FilterPolicy": json.dumps(filter_policy)},
|
||||||
)
|
)
|
||||||
except ClientError as err:
|
assert err_info.value.response["Error"]["Code"] == "InvalidParameter"
|
||||||
assert err.response["Error"]["Code"] == "InvalidParameter"
|
assert (
|
||||||
assert err.response["Error"]["Message"] == (
|
err_info.value.response["Error"]["Message"]
|
||||||
"Invalid parameter: FilterPolicy: Filter policy can not have more than 5 keys"
|
== "Invalid parameter: FilterPolicy: Filter policy can not have more than 5 keys"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
with pytest.raises(ClientError) as err_info:
|
||||||
nested_filter_policy = {
|
nested_filter_policy = {
|
||||||
"key_a": {
|
"key_a": {
|
||||||
"key_b": {
|
"key_b": {
|
||||||
@ -731,11 +730,10 @@ def test_subscribe_invalid_filter_policy():
|
|||||||
"FilterPolicy": json.dumps(nested_filter_policy),
|
"FilterPolicy": json.dumps(nested_filter_policy),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
except ClientError as err:
|
assert err_info.value.response["Error"]["Code"] == "InvalidParameter"
|
||||||
assert err.response["Error"]["Code"] == "InvalidParameter"
|
assert err_info.value.response["Error"]["Message"] == (
|
||||||
assert err.response["Error"]["Message"] == (
|
"Invalid parameter: FilterPolicy: Filter policy is too complex"
|
||||||
"Invalid parameter: FilterPolicy: Filter policy is too complex"
|
)
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@mock_aws
|
@mock_aws
|
||||||
|
Loading…
Reference in New Issue
Block a user