SNS: numeric filtering fixups + parameter validation (#6242)

This commit is contained in:
Jakub P 2023-04-24 12:05:21 +02:00 committed by GitHub
parent ceb52cbaae
commit 2d7c38f64f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 274 additions and 50 deletions

View File

@ -9,7 +9,7 @@ import json
SINGLE_ERROR_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?> SINGLE_ERROR_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<Error> <Error>
<Code>{{error_type}}</Code> <Code>{{error_type}}</Code>
<Message>{{message}}</Message> <Message><![CDATA[{{message}}]]></Message>
{% block extra %}{% endblock %} {% block extra %}{% endblock %}
<{{request_id_tag}}>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</{{request_id_tag}}> <{{request_id_tag}}>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</{{request_id_tag}}>
</Error> </Error>
@ -19,7 +19,7 @@ WRAPPED_SINGLE_ERROR_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<ErrorResponse{% if xmlns is defined %} xmlns="{{xmlns}}"{% endif %}> <ErrorResponse{% if xmlns is defined %} xmlns="{{xmlns}}"{% endif %}>
<Error> <Error>
<Code>{{error_type}}</Code> <Code>{{error_type}}</Code>
<Message>{{message}}</Message> <Message><![CDATA[{{message}}]]></Message>
{% block extra %}{% endblock %} {% block extra %}{% endblock %}
<{{request_id_tag}}>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</{{request_id_tag}}> <{{request_id_tag}}>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</{{request_id_tag}}>
</Error> </Error>

View File

@ -301,10 +301,11 @@ class Subscription(BaseModel):
else: else:
return False return False
for attribute_values in attribute_values: for attribute_value in attribute_values:
attribute_value = float(attribute_value)
# Even the official documentation states a 5 digits of accuracy after the decimal point for numerics, in reality it is 6 # Even the official documentation states a 5 digits of accuracy after the decimal point for numerics, in reality it is 6
# https://docs.aws.amazon.com/sns/latest/dg/sns-subscription-filter-policies.html#subscription-filter-policy-constraints # https://docs.aws.amazon.com/sns/latest/dg/sns-subscription-filter-policies.html#subscription-filter-policy-constraints
if int(attribute_values * 1000000) == int(rule * 1000000): if int(attribute_value * 1000000) == int(rule * 1000000):
return True return True
if isinstance(rule, dict): if isinstance(rule, dict):
keyword = list(rule.keys())[0] keyword = list(rule.keys())[0]
@ -347,7 +348,7 @@ class Subscription(BaseModel):
[value] if isinstance(value, str) else value [value] if isinstance(value, str) else value
) )
if attr["Type"] == "Number": if attr["Type"] == "Number":
actual_values = [str(attr["Value"])] actual_values = [float(attr["Value"])]
elif attr["Type"] == "String": elif attr["Type"] == "String":
actual_values = [attr["Value"]] actual_values = [attr["Value"]]
else: else:
@ -361,10 +362,10 @@ class Subscription(BaseModel):
message_attributes.get(field, {}).get("Type", "") message_attributes.get(field, {}).get("Type", "")
== "Number" == "Number"
): ):
msg_value = message_attributes[field]["Value"] msg_value = float(message_attributes[field]["Value"])
matches = [] matches = []
for operator, test_value in numeric_ranges: for operator, test_value in numeric_ranges:
test_value = int(test_value) test_value = test_value
if operator == ">": if operator == ">":
matches.append((msg_value > test_value)) matches.append((msg_value > test_value))
if operator == ">=": if operator == ">=":
@ -376,7 +377,6 @@ class Subscription(BaseModel):
if operator == "<=": if operator == "<=":
matches.append((msg_value <= test_value)) matches.append((msg_value <= test_value))
return all(matches) return all(matches)
attr = message_attributes[field]
return False return False
return all( return all(
@ -832,6 +832,70 @@ class SNSBackend(BaseBackend):
) )
continue continue
elif keyword == "numeric": elif keyword == "numeric":
# TODO: All of the exceptions listed below contain column pointing where the error is (in AWS response)
# Example: 'Value of < must be numeric\n at [Source: (String)"{"price":[{"numeric":["<","100"]}]}"; line: 1, column: 28]'
# While it probably can be implemented, it doesn't feel as important as the general parameter checking
attributes_copy = attributes[:]
if not attributes_copy:
raise SNSInvalidParameter(
"Invalid parameter: Attributes Reason: FilterPolicy: Invalid member in numeric match: ]\n at ..."
)
operator = attributes_copy.pop(0)
if not isinstance(operator, str):
raise SNSInvalidParameter(
f"Invalid parameter: Attributes Reason: FilterPolicy: Invalid member in numeric match: {(str(operator))}\n at ..."
)
if operator not in ("<", "<=", "=", ">", ">="):
raise SNSInvalidParameter(
f"Invalid parameter: Attributes Reason: FilterPolicy: Unrecognized numeric range operator: {(str(operator))}\n at ..."
)
try:
value = attributes_copy.pop(0)
except IndexError:
value = None
if value is None or not isinstance(value, (int, float)):
raise SNSInvalidParameter(
f"Invalid parameter: Attributes Reason: FilterPolicy: Value of {(str(operator))} must be numeric\n at ..."
)
if not attributes_copy:
continue
if operator not in (">", ">="):
raise SNSInvalidParameter(
"Invalid parameter: Attributes Reason: FilterPolicy: Too many elements in numeric expression\n at ..."
)
second_operator = attributes_copy.pop(0)
if second_operator not in ("<", "<="):
raise SNSInvalidParameter(
f"Invalid parameter: Attributes Reason: FilterPolicy: Bad numeric range operator: {(str(second_operator))}\n at ..."
)
try:
second_value = attributes_copy.pop(0)
except IndexError:
second_value = None
if second_value is None or not isinstance(
second_value, (int, float)
):
raise SNSInvalidParameter(
f"Invalid parameter: Attributes Reason: FilterPolicy: Value of {(str(second_operator))} must be numeric\n at ..."
)
if second_value <= value:
raise SNSInvalidParameter(
"Invalid parameter: Attributes Reason: FilterPolicy: Bottom must be less than top\n at ..."
)
continue continue
elif keyword == "prefix": elif keyword == "prefix":
continue continue

View File

@ -66,20 +66,19 @@ class SNSResponse(BaseResponse):
transform_value = None transform_value = None
if "StringValue" in value: if "StringValue" in value:
transform_value = value["StringValue"]
if data_type == "Number": if data_type == "Number":
try: try:
transform_value = int(value["StringValue"]) int(transform_value)
except ValueError: except ValueError:
try: try:
transform_value = float(value["StringValue"]) float(transform_value)
except ValueError: except ValueError:
raise InvalidParameterValue( raise InvalidParameterValue(
"An error occurred (ParameterValueInvalid) " "An error occurred (ParameterValueInvalid) "
"when calling the Publish operation: " "when calling the Publish operation: "
f"Could not cast message attribute '{name}' value to number." f"Could not cast message attribute '{name}' value to number."
) )
else:
transform_value = value["StringValue"]
elif "BinaryValue" in value: elif "BinaryValue" in value:
transform_value = value["BinaryValue"] transform_value = value["BinaryValue"]
if transform_value == "": if transform_value == "":

View File

@ -230,7 +230,7 @@ def test_publish_to_sqs_msg_attr_number_type():
message = json.loads(queue.receive_messages()[0].body) message = json.loads(queue.receive_messages()[0].body)
message["Message"].should.equal("test message") message["Message"].should.equal("test message")
message["MessageAttributes"].should.equal( message["MessageAttributes"].should.equal(
{"retries": {"Type": "Number", "Value": 0}} {"retries": {"Type": "Number", "Value": "0"}}
) )
message = queue_raw.receive_messages()[0] message = queue_raw.receive_messages()[0]
@ -731,7 +731,7 @@ def test_filtering_exact_number_int():
message_bodies = [json.loads(m.body)["Message"] for m in messages] message_bodies = [json.loads(m.body)["Message"] for m in messages]
message_bodies.should.equal(["match"]) message_bodies.should.equal(["match"])
message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages]
message_attributes.should.equal([{"price": {"Type": "Number", "Value": 100}}]) message_attributes.should.equal([{"price": {"Type": "Number", "Value": "100"}}])
@mock_sqs @mock_sqs
@ -748,7 +748,7 @@ def test_filtering_exact_number_float():
message_bodies = [json.loads(m.body)["Message"] for m in messages] message_bodies = [json.loads(m.body)["Message"] for m in messages]
message_bodies.should.equal(["match"]) message_bodies.should.equal(["match"])
message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages]
message_attributes.should.equal([{"price": {"Type": "Number", "Value": 100.1}}]) message_attributes.should.equal([{"price": {"Type": "Number", "Value": "100.1"}}])
@mock_sqs @mock_sqs
@ -759,7 +759,7 @@ def test_filtering_exact_number_float_accuracy():
topic.publish( topic.publish(
Message="match", Message="match",
MessageAttributes={ MessageAttributes={
"price": {"DataType": "Number", "StringValue": "100.1234561"} "price": {"DataType": "Number", "StringValue": "100.1234567"}
}, },
) )
@ -768,7 +768,7 @@ def test_filtering_exact_number_float_accuracy():
message_bodies.should.equal(["match"]) message_bodies.should.equal(["match"])
message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages]
message_attributes.should.equal( message_attributes.should.equal(
[{"price": {"Type": "Number", "Value": 100.1234561}}] [{"price": {"Type": "Number", "Value": "100.1234567"}}]
) )
@ -892,7 +892,7 @@ def test_filtering_string_array_with_number_float_accuracy_match():
MessageAttributes={ MessageAttributes={
"price": { "price": {
"DataType": "String.Array", "DataType": "String.Array",
"StringValue": json.dumps([100.1234561, 50]), "StringValue": json.dumps([100.1234567, 50]),
} }
}, },
) )
@ -902,7 +902,7 @@ def test_filtering_string_array_with_number_float_accuracy_match():
message_bodies.should.equal(["match"]) message_bodies.should.equal(["match"])
message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages]
message_attributes.should.equal( message_attributes.should.equal(
[{"price": {"Type": "String.Array", "Value": json.dumps([100.1234561, 50])}}] [{"price": {"Type": "String.Array", "Value": json.dumps([100.1234567, 50])}}]
) )
@ -1083,7 +1083,7 @@ def test_filtering_all_AND_matching_match():
"Type": "String.Array", "Type": "String.Array",
"Value": json.dumps(["basketball", "rugby"]), "Value": json.dumps(["basketball", "rugby"]),
}, },
"price": {"Type": "Number", "Value": 100}, "price": {"Type": "Number", "Value": "100"},
} }
] ]
) )
@ -1228,7 +1228,7 @@ def test_filtering_anything_but_unknown():
@mock_sns @mock_sns
def test_filtering_anything_but_numeric(): def test_filtering_anything_but_numeric():
topic, queue = _setup_filter_policy_test( topic, queue = _setup_filter_policy_test(
{"customer_interests": [{"anything-but": ["100"]}]} {"customer_interests": [{"anything-but": [100]}]}
) )
for nr, idx in [("50", "1"), ("100", "2"), ("150", "3")]: for nr, idx in [("50", "1"), ("100", "2"), ("150", "3")]:
@ -1248,7 +1248,7 @@ def test_filtering_anything_but_numeric():
@mock_sns @mock_sns
def test_filtering_numeric_match(): def test_filtering_numeric_match():
topic, queue = _setup_filter_policy_test( topic, queue = _setup_filter_policy_test(
{"customer_interests": [{"numeric": ["=", "100"]}]} {"customer_interests": [{"numeric": ["=", 100]}]}
) )
for nr, idx in [("50", "1"), ("100", "2"), ("150", "3")]: for nr, idx in [("50", "1"), ("100", "2"), ("150", "3")]:
@ -1268,7 +1268,7 @@ def test_filtering_numeric_match():
@mock_sns @mock_sns
def test_filtering_numeric_range(): def test_filtering_numeric_range():
topic, queue = _setup_filter_policy_test( topic, queue = _setup_filter_policy_test(
{"customer_interests": [{"numeric": [">", "49", "<=", "100"]}]} {"customer_interests": [{"numeric": [">", 49, "<=", 100]}]}
) )
for nr, idx in [("50", "1"), ("100", "2"), ("150", "3")]: for nr, idx in [("50", "1"), ("100", "2"), ("150", "3")]:

View File

@ -416,7 +416,7 @@ def test_subscribe_invalid_filter_policy():
response = conn.list_topics() response = conn.list_topics()
topic_arn = response["Topics"][0]["TopicArn"] topic_arn = response["Topics"][0]["TopicArn"]
try: with pytest.raises(ClientError) as err_info:
conn.subscribe( conn.subscribe(
TopicArn=topic_arn, TopicArn=topic_arn,
Protocol="http", Protocol="http",
@ -425,61 +425,222 @@ def test_subscribe_invalid_filter_policy():
"FilterPolicy": json.dumps({"store": [str(i) for i in range(151)]}) "FilterPolicy": json.dumps({"store": [str(i) for i in range(151)]})
}, },
) )
except ClientError as err:
err = err_info.value
err.response["Error"]["Code"].should.equal("InvalidParameter") err.response["Error"]["Code"].should.equal("InvalidParameter")
err.response["Error"]["Message"].should.equal( err.response["Error"]["Message"].should.equal(
"Invalid parameter: FilterPolicy: Filter policy is too complex" "Invalid parameter: FilterPolicy: Filter policy is too complex"
) )
try: with pytest.raises(ClientError) as err_info:
conn.subscribe( conn.subscribe(
TopicArn=topic_arn, TopicArn=topic_arn,
Protocol="http", Protocol="http",
Endpoint="http://example.com/", Endpoint="http://example.com/",
Attributes={"FilterPolicy": json.dumps({"store": [["example_corp"]]})}, Attributes={"FilterPolicy": json.dumps({"store": [["example_corp"]]})},
) )
except ClientError as err:
err = err_info.value
err.response["Error"]["Code"].should.equal("InvalidParameter") err.response["Error"]["Code"].should.equal("InvalidParameter")
err.response["Error"]["Message"].should.equal( err.response["Error"]["Message"].should.equal(
"Invalid parameter: FilterPolicy: Match value must be String, number, true, false, or null" "Invalid parameter: FilterPolicy: Match value must be String, number, true, false, or null"
) )
try: with pytest.raises(ClientError) as err_info:
conn.subscribe( conn.subscribe(
TopicArn=topic_arn, TopicArn=topic_arn,
Protocol="http", Protocol="http",
Endpoint="http://example.com/", Endpoint="http://example.com/",
Attributes={"FilterPolicy": json.dumps({"store": [{"exists": None}]})}, Attributes={"FilterPolicy": json.dumps({"store": [{"exists": None}]})},
) )
except ClientError as err:
err = err_info.value
err.response["Error"]["Code"].should.equal("InvalidParameter") err.response["Error"]["Code"].should.equal("InvalidParameter")
err.response["Error"]["Message"].should.equal( err.response["Error"]["Message"].should.equal(
"Invalid parameter: FilterPolicy: exists match pattern must be either true or false." "Invalid parameter: FilterPolicy: exists match pattern must be either true or false."
) )
try: with pytest.raises(ClientError) as err_info:
conn.subscribe( conn.subscribe(
TopicArn=topic_arn, TopicArn=topic_arn,
Protocol="http", Protocol="http",
Endpoint="http://example.com/", Endpoint="http://example.com/",
Attributes={"FilterPolicy": json.dumps({"store": [{"error": True}]})}, Attributes={"FilterPolicy": json.dumps({"store": [{"error": True}]})},
) )
except ClientError as err:
err = err_info.value
err.response["Error"]["Code"].should.equal("InvalidParameter") err.response["Error"]["Code"].should.equal("InvalidParameter")
err.response["Error"]["Message"].should.equal( err.response["Error"]["Message"].should.equal(
"Invalid parameter: FilterPolicy: Unrecognized match type error" "Invalid parameter: FilterPolicy: Unrecognized match type error"
) )
try: with pytest.raises(ClientError) as err_info:
conn.subscribe( conn.subscribe(
TopicArn=topic_arn, TopicArn=topic_arn,
Protocol="http", Protocol="http",
Endpoint="http://example.com/", Endpoint="http://example.com/",
Attributes={"FilterPolicy": json.dumps({"store": [1000000001]})}, Attributes={"FilterPolicy": json.dumps({"store": [1000000001]})},
) )
except ClientError as err:
err = err_info.value
err.response["Error"]["Code"].should.equal("InternalFailure") err.response["Error"]["Code"].should.equal("InternalFailure")
with pytest.raises(ClientError) as err_info:
conn.subscribe(
TopicArn=topic_arn,
Protocol="http",
Endpoint="http://example.com/",
Attributes={
"FilterPolicy": json.dumps({"price": [{"numeric": ["<", "100"]}]})
},
)
err = err_info.value
err.response["Error"]["Code"].should.equal("InvalidParameter")
err.response["Error"]["Message"].should.equal(
"Invalid parameter: Attributes Reason: FilterPolicy: Value of < must be numeric\n at ..."
)
with pytest.raises(ClientError) as err_info:
conn.subscribe(
TopicArn=topic_arn,
Protocol="http",
Endpoint="http://example.com/",
Attributes={
"FilterPolicy": json.dumps(
{"price": [{"numeric": [">", 50, "<=", "100"]}]}
)
},
)
err = err_info.value
err.response["Error"]["Code"].should.equal("InvalidParameter")
err.response["Error"]["Message"].should.equal(
"Invalid parameter: Attributes Reason: FilterPolicy: Value of <= must be numeric\n at ..."
)
with pytest.raises(ClientError) as err_info:
conn.subscribe(
TopicArn=topic_arn,
Protocol="http",
Endpoint="http://example.com/",
Attributes={"FilterPolicy": json.dumps({"price": [{"numeric": []}]})},
)
err = err_info.value
err.response["Error"]["Code"].should.equal("InvalidParameter")
err.response["Error"]["Message"].should.equal(
"Invalid parameter: Attributes Reason: FilterPolicy: Invalid member in numeric match: ]\n at ..."
)
with pytest.raises(ClientError) as err_info:
conn.subscribe(
TopicArn=topic_arn,
Protocol="http",
Endpoint="http://example.com/",
Attributes={
"FilterPolicy": json.dumps({"price": [{"numeric": [50, "<=", "100"]}]})
},
)
err = err_info.value
err.response["Error"]["Code"].should.equal("InvalidParameter")
err.response["Error"]["Message"].should.equal(
"Invalid parameter: Attributes Reason: FilterPolicy: Invalid member in numeric match: 50\n at ..."
)
with pytest.raises(ClientError) as err_info:
conn.subscribe(
TopicArn=topic_arn,
Protocol="http",
Endpoint="http://example.com/",
Attributes={"FilterPolicy": json.dumps({"price": [{"numeric": ["<"]}]})},
)
err = err_info.value
err.response["Error"]["Code"].should.equal("InvalidParameter")
err.response["Error"]["Message"].should.equal(
"Invalid parameter: Attributes Reason: FilterPolicy: Value of < must be numeric\n at ..."
)
with pytest.raises(ClientError) as err_info:
conn.subscribe(
TopicArn=topic_arn,
Protocol="http",
Endpoint="http://example.com/",
Attributes={"FilterPolicy": json.dumps({"price": [{"numeric": ["0"]}]})},
)
err = err_info.value
err.response["Error"]["Code"].should.equal("InvalidParameter")
err.response["Error"]["Message"].should.equal(
"Invalid parameter: Attributes Reason: FilterPolicy: Unrecognized numeric range operator: 0\n at ..."
)
with pytest.raises(ClientError) as err_info:
conn.subscribe(
TopicArn=topic_arn,
Protocol="http",
Endpoint="http://example.com/",
Attributes={
"FilterPolicy": json.dumps({"price": [{"numeric": ["<", 20, ">", 1]}]})
},
)
err = err_info.value
err.response["Error"]["Code"].should.equal("InvalidParameter")
err.response["Error"]["Message"].should.equal(
"Invalid parameter: Attributes Reason: FilterPolicy: Too many elements in numeric expression\n at ..."
)
with pytest.raises(ClientError) as err_info:
conn.subscribe(
TopicArn=topic_arn,
Protocol="http",
Endpoint="http://example.com/",
Attributes={
"FilterPolicy": json.dumps({"price": [{"numeric": [">", 20, ">", 1]}]})
},
)
err = err_info.value
err.response["Error"]["Code"].should.equal("InvalidParameter")
err.response["Error"]["Message"].should.equal(
"Invalid parameter: Attributes Reason: FilterPolicy: Bad numeric range operator: >\n at ..."
)
with pytest.raises(ClientError) as err_info:
conn.subscribe(
TopicArn=topic_arn,
Protocol="http",
Endpoint="http://example.com/",
Attributes={
"FilterPolicy": json.dumps({"price": [{"numeric": [">", 20, "<", 1]}]})
},
)
err = err_info.value
err.response["Error"]["Code"].should.equal("InvalidParameter")
err.response["Error"]["Message"].should.equal(
"Invalid parameter: Attributes Reason: FilterPolicy: Bottom must be less than top\n at ..."
)
with pytest.raises(ClientError) as err_info:
conn.subscribe(
TopicArn=topic_arn,
Protocol="http",
Endpoint="http://example.com/",
Attributes={
"FilterPolicy": json.dumps({"price": [{"numeric": [">", 20, "<"]}]})
},
)
err = err_info.value
err.response["Error"]["Code"].should.equal("InvalidParameter")
err.response["Error"]["Message"].should.equal(
"Invalid parameter: Attributes Reason: FilterPolicy: Value of < must be numeric\n at ..."
)
@mock_sns @mock_sns
def test_check_not_opted_out(): def test_check_not_opted_out():