diff --git a/moto/sns/exceptions.py b/moto/sns/exceptions.py index 092bb9d69..95b91acca 100644 --- a/moto/sns/exceptions.py +++ b/moto/sns/exceptions.py @@ -24,3 +24,11 @@ class SnsEndpointDisabled(RESTError): def __init__(self, message): super(SnsEndpointDisabled, self).__init__( "EndpointDisabled", message) + + +class SNSInvalidParameter(RESTError): + code = 400 + + def __init__(self, message): + super(SNSInvalidParameter, self).__init__( + "InvalidParameter", message) diff --git a/moto/sns/models.py b/moto/sns/models.py index dc7420db4..009398407 100644 --- a/moto/sns/models.py +++ b/moto/sns/models.py @@ -13,7 +13,7 @@ from moto.core import BaseBackend, BaseModel from moto.core.utils import iso_8601_datetime_with_milliseconds from moto.sqs import sqs_backends from .exceptions import ( - SNSNotFoundError, DuplicateSnsEndpointError, SnsEndpointDisabled + SNSNotFoundError, DuplicateSnsEndpointError, SnsEndpointDisabled, SNSInvalidParameter ) from .utils import make_arn_for_topic, make_arn_for_subscription @@ -76,6 +76,7 @@ class Subscription(BaseModel): self.endpoint = endpoint self.protocol = protocol self.arn = make_arn_for_subscription(self.topic.arn) + self.attributes = {} def publish(self, message, message_id): if self.protocol == 'sqs': @@ -301,6 +302,26 @@ class SNSBackend(BaseBackend): raise SNSNotFoundError( "Endpoint with arn {0} not found".format(arn)) + def get_subscription_attributes(self, arn): + _subscription = [_ for _ in self.subscriptions.values() if _.arn == arn] + if not _subscription: + raise SNSNotFoundError("Subscription with arn {0} not found".format(arn)) + subscription = _subscription[0] + + return subscription.attributes + + def set_subscription_attributes(self, arn, name, value): + if name not in ['RawMessageDelivery', 'DeliveryPolicy']: + raise SNSInvalidParameter('AttributeName') + + # TODO: should do validation + _subscription = [_ for _ in self.subscriptions.values() if _.arn == arn] + if not _subscription: + raise SNSNotFoundError("Subscription with arn {0} not found".format(arn)) + subscription = _subscription[0] + + subscription.attributes[name] = value + sns_backends = {} for region in boto.sns.regions(): diff --git a/moto/sns/responses.py b/moto/sns/responses.py index edb82e40c..9c079b006 100644 --- a/moto/sns/responses.py +++ b/moto/sns/responses.py @@ -445,6 +445,20 @@ class SNSResponse(BaseResponse): template = self.response_template(DELETE_ENDPOINT_TEMPLATE) return template.render() + def get_subscription_attributes(self): + arn = self._get_param('SubscriptionArn') + attributes = self.backend.get_subscription_attributes(arn) + template = self.response_template(GET_SUBSCRIPTION_ATTRIBUTES_TEMPLATE) + return template.render(attributes=attributes) + + def set_subscription_attributes(self): + arn = self._get_param('SubscriptionArn') + attr_name = self._get_param('AttributeName') + attr_value = self._get_param('AttributeValue') + self.backend.set_subscription_attributes(arn, attr_name, attr_value) + template = self.response_template(SET_SUBSCRIPTION_ATTRIBUTES_TEMPLATE) + return template.render() + CREATE_TOPIC_TEMPLATE = """ @@ -719,3 +733,28 @@ LIST_SUBSCRIPTIONS_BY_TOPIC_TEMPLATE = """384ac68d-3775-11df-8963-01868b7c937a """ + + +# Not responding aws system attribetus like 'Owner' and 'SubscriptionArn' +GET_SUBSCRIPTION_ATTRIBUTES_TEMPLATE = """ + + + {% for name, value in attributes.items() %} + + {{ name }} + {{ value }} + + {% endfor %} + + + + 057f074c-33a7-11df-9540-99d0768312d3 + +""" + + +SET_SUBSCRIPTION_ATTRIBUTES_TEMPLATE = """ + + a8763b99-33a7-11df-a9b7-05d48da6f042 + +""" diff --git a/tests/test_sns/test_publishing_boto3.py b/tests/test_sns/test_publishing_boto3.py index cfb57b9ec..a53744d63 100644 --- a/tests/test_sns/test_publishing_boto3.py +++ b/tests/test_sns/test_publishing_boto3.py @@ -17,8 +17,6 @@ from freezegun import freeze_time MESSAGE_FROM_SQS_TEMPLATE = '{\n "Message": "%s",\n "MessageId": "%s",\n "Signature": "EXAMPLElDMXvB8r9R83tGoNn0ecwd5UjllzsvSvbItzfaMpN2nk5HVSw7XnOn/49IkxDKz8YrlH2qJXj2iZB0Zo2O71c4qQk1fMUDi3LGpij7RCW7AW9vYYsSqIKRnFS94ilu7NFhUzLiieYr4BKHpdTmdD6c0esKEYBpabxDSc=",\n "SignatureVersion": "1",\n "SigningCertURL": "https://sns.us-east-1.amazonaws.com/SimpleNotificationService-f3ecfb7224c7233fe7bb5f59f96de52f.pem",\n "Subject": "my subject",\n "Timestamp": "2015-01-01T12:00:00.000Z",\n "TopicArn": "arn:aws:sns:%s:123456789012:some-topic",\n "Type": "Notification",\n "UnsubscribeURL": "https://sns.us-east-1.amazonaws.com/?Action=Unsubscribe&SubscriptionArn=arn:aws:sns:us-east-1:123456789012:some-topic:2bcfbf39-05c3-41de-beaa-fcfcc21c8f55"\n}' -from nose.plugins.attrib import attr -@attr('slow') @mock_sqs @mock_sns def test_publish_to_sqs(): diff --git a/tests/test_sns/test_subscriptions_boto3.py b/tests/test_sns/test_subscriptions_boto3.py index ac325ed20..8cb5c1886 100644 --- a/tests/test_sns/test_subscriptions_boto3.py +++ b/tests/test_sns/test_subscriptions_boto3.py @@ -1,8 +1,12 @@ from __future__ import unicode_literals import boto3 +import json import sure # noqa +from botocore.exceptions import ClientError +from nose.tools import assert_raises + from moto import mock_sns from moto.sns.models import DEFAULT_PAGE_SIZE @@ -124,3 +128,72 @@ def test_subscription_paging(): topic1_subscriptions["Subscriptions"].should.have.length_of( int(DEFAULT_PAGE_SIZE / 3)) topic1_subscriptions.shouldnt.have("NextToken") + + +@mock_sns +def test_set_subscription_attributes(): + conn = boto3.client('sns', region_name='us-east-1') + conn.create_topic(Name="some-topic") + response = conn.list_topics() + topic_arn = response["Topics"][0]['TopicArn'] + + conn.subscribe(TopicArn=topic_arn, + Protocol="http", + Endpoint="http://example.com/") + + subscriptions = conn.list_subscriptions()["Subscriptions"] + subscriptions.should.have.length_of(1) + subscription = subscriptions[0] + subscription["TopicArn"].should.equal(topic_arn) + subscription["Protocol"].should.equal("http") + subscription["SubscriptionArn"].should.contain(topic_arn) + subscription["Endpoint"].should.equal("http://example.com/") + + subscription_arn = subscription["SubscriptionArn"] + attrs = conn.get_subscription_attributes( + SubscriptionArn=subscription_arn + ) + attrs.should.have.key('Attributes') + conn.set_subscription_attributes( + SubscriptionArn=subscription_arn, + AttributeName='RawMessageDelivery', + AttributeValue='true' + ) + delivery_policy = json.dumps({ + 'healthyRetryPolicy': { + "numRetries": 10, + "minDelayTarget": 1, + "maxDelayTarget":2 + } + }) + conn.set_subscription_attributes( + SubscriptionArn=subscription_arn, + AttributeName='DeliveryPolicy', + AttributeValue=delivery_policy + ) + attrs = conn.get_subscription_attributes( + SubscriptionArn=subscription_arn + ) + attrs['Attributes']['RawMessageDelivery'].should.equal('true') + attrs['Attributes']['DeliveryPolicy'].should.equal(delivery_policy) + + # not existing subscription + with assert_raises(ClientError): + conn.set_subscription_attributes( + SubscriptionArn='invalid', + AttributeName='RawMessageDelivery', + AttributeValue='true' + ) + with assert_raises(ClientError): + attrs = conn.get_subscription_attributes( + SubscriptionArn='invalid' + ) + + + # invalid attr name + with assert_raises(ClientError): + conn.set_subscription_attributes( + SubscriptionArn=subscription_arn, + AttributeName='InvalidName', + AttributeValue='true' + )