diff --git a/moto/sns/responses.py b/moto/sns/responses.py index 035d56584..8c1bb885e 100644 --- a/moto/sns/responses.py +++ b/moto/sns/responses.py @@ -181,6 +181,7 @@ class SNSResponse(BaseResponse): topic_arn = self._get_param('TopicArn') endpoint = self._get_param('Endpoint') protocol = self._get_param('Protocol') + attributes = self._get_attributes() if protocol == 'sms' and not is_e164(endpoint): return self._error( @@ -190,6 +191,10 @@ class SNSResponse(BaseResponse): subscription = self.backend.subscribe(topic_arn, endpoint, protocol) + if attributes is not None: + for attr_name, attr_value in attributes.items(): + self.backend.set_subscription_attributes(subscription.arn, attr_name, attr_value) + if self.request_json: return json.dumps({ "SubscribeResponse": { diff --git a/tests/test_sns/test_subscriptions_boto3.py b/tests/test_sns/test_subscriptions_boto3.py index 98075e617..2a56c8213 100644 --- a/tests/test_sns/test_subscriptions_boto3.py +++ b/tests/test_sns/test_subscriptions_boto3.py @@ -182,6 +182,72 @@ def test_subscription_paging(): topic1_subscriptions.shouldnt.have("NextToken") +@mock_sns +def test_creating_subscription_with_attributes(): + conn = boto3.client('sns', region_name='us-east-1') + conn.create_topic(Name="some-topic") + response = conn.list_topics() + topic_arn = response["Topics"][0]['TopicArn'] + + delivery_policy = json.dumps({ + 'healthyRetryPolicy': { + "numRetries": 10, + "minDelayTarget": 1, + "maxDelayTarget":2 + } + }) + + filter_policy = json.dumps({ + "store": ["example_corp"], + "event": ["order_cancelled"], + "encrypted": [False], + "customer_interests": ["basketball", "baseball"] + }) + + conn.subscribe(TopicArn=topic_arn, + Protocol="http", + Endpoint="http://example.com/", + Attributes={ + 'RawMessageDelivery': 'true', + 'DeliveryPolicy': delivery_policy, + 'FilterPolicy': filter_policy + }) + + subscriptions = conn.list_subscriptions()["Subscriptions"] + subscriptions.should.have.length_of(1) + subscription = subscriptions[0] + subscription["TopicArn"].should.equal(topic_arn) + subscription["Protocol"].should.equal("http") + subscription["SubscriptionArn"].should.contain(topic_arn) + subscription["Endpoint"].should.equal("http://example.com/") + + # Test the subscription attributes have been set + subscription_arn = subscription["SubscriptionArn"] + attrs = conn.get_subscription_attributes( + SubscriptionArn=subscription_arn + ) + + attrs['Attributes']['RawMessageDelivery'].should.equal('true') + attrs['Attributes']['DeliveryPolicy'].should.equal(delivery_policy) + attrs['Attributes']['FilterPolicy'].should.equal(filter_policy) + + # Now unsubscribe the subscription + conn.unsubscribe(SubscriptionArn=subscription["SubscriptionArn"]) + + # And there should be zero subscriptions left + subscriptions = conn.list_subscriptions()["Subscriptions"] + subscriptions.should.have.length_of(0) + + # invalid attr name + with assert_raises(ClientError): + conn.subscribe(TopicArn=topic_arn, + Protocol="http", + Endpoint="http://example.com/", + Attributes={ + 'InvalidName': 'true' + }) + + @mock_sns def test_set_subscription_attributes(): conn = boto3.client('sns', region_name='us-east-1')