diff --git a/moto/sns/models.py b/moto/sns/models.py index e4a038977..c7f678a75 100644 --- a/moto/sns/models.py +++ b/moto/sns/models.py @@ -732,8 +732,17 @@ class SNSBackend(BaseBackend): raise SNSNotFoundError( "Subscription does not exist", template="wrapped_single_error" ) + # AWS does not return the FilterPolicy scope if the FilterPolicy is not set + # if the FilterPolicy is set and not the FilterPolicyScope, it returns the default value + attributes = {**subscription.attributes} + if "FilterPolicyScope" in attributes and not attributes.get("FilterPolicy"): + attributes.pop("FilterPolicyScope", None) + attributes.pop("FilterPolicy", None) - return subscription.attributes + elif "FilterPolicy" in attributes and "FilterPolicyScope" not in attributes: + attributes["FilterPolicyScope"] = "MessageAttributes" + + return attributes def set_subscription_attributes(self, arn: str, name: str, value: Any) -> None: if name not in [ @@ -753,15 +762,19 @@ class SNSBackend(BaseBackend): subscription = _subscription[0] if name == "FilterPolicy": - filter_policy = json.loads(value) - # we validate the filter policy differently depending on the scope - # we need to always set the scope first - filter_policy_scope = subscription.attributes.get("FilterPolicyScope") - self._validate_filter_policy(filter_policy, scope=filter_policy_scope) - subscription._filter_policy = filter_policy - subscription._filter_policy_matcher = FilterPolicyMatcher( - filter_policy, filter_policy_scope - ) + if value: + filter_policy = json.loads(value) + # we validate the filter policy differently depending on the scope + # we need to always set the scope first + filter_policy_scope = subscription.attributes.get("FilterPolicyScope") + self._validate_filter_policy(filter_policy, scope=filter_policy_scope) + subscription._filter_policy = filter_policy + subscription._filter_policy_matcher = FilterPolicyMatcher( + filter_policy, filter_policy_scope + ) + else: + subscription._filter_policy = None + subscription._filter_policy_matcher = None subscription.attributes[name] = value diff --git a/tests/test_sns/test_subscriptions_boto3.py b/tests/test_sns/test_subscriptions_boto3.py index 700ccd32a..a63df5f50 100644 --- a/tests/test_sns/test_subscriptions_boto3.py +++ b/tests/test_sns/test_subscriptions_boto3.py @@ -393,6 +393,7 @@ def test_set_subscription_attributes(): assert attrs["Attributes"]["RawMessageDelivery"] == "true" assert attrs["Attributes"]["DeliveryPolicy"] == delivery_policy assert attrs["Attributes"]["FilterPolicy"] == filter_policy + assert attrs["Attributes"]["FilterPolicyScope"] == "MessageAttributes" filter_policy_scope = "MessageBody" conn.set_subscription_attributes( @@ -405,6 +406,17 @@ def test_set_subscription_attributes(): assert attrs["Attributes"]["FilterPolicyScope"] == filter_policy_scope + # test unsetting a filter policy + conn.set_subscription_attributes( + SubscriptionArn=subscription_arn, + AttributeName="FilterPolicy", + AttributeValue="", + ) + + attrs = conn.get_subscription_attributes(SubscriptionArn=subscription_arn) + assert "FilterPolicy" not in attrs["Attributes"] + assert "FilterPolicyScope" not in attrs["Attributes"] + # not existing subscription with pytest.raises(ClientError): conn.set_subscription_attributes( @@ -413,7 +425,7 @@ def test_set_subscription_attributes(): AttributeValue="true", ) with pytest.raises(ClientError): - attrs = conn.get_subscription_attributes(SubscriptionArn="invalid") + conn.get_subscription_attributes(SubscriptionArn="invalid") # invalid attr name with pytest.raises(ClientError):