From 706c60175b1acdd175dfd136af8ebaffcdf29a59 Mon Sep 17 00:00:00 2001 From: Steve Pulec Date: Sat, 14 Mar 2015 09:06:31 -0400 Subject: [PATCH 1/3] Add SNS applications and endpoints. --- moto/sns/exceptions.py | 10 ++ moto/sns/models.py | 104 ++++++++++++- moto/sns/responses.py | 161 +++++++++++++++++++- tests/test_sns/test_application.py | 237 ++++++++++++++++++++++++++++- tests/test_sns/test_topics.py | 7 + 5 files changed, 503 insertions(+), 16 deletions(-) create mode 100644 moto/sns/exceptions.py diff --git a/moto/sns/exceptions.py b/moto/sns/exceptions.py new file mode 100644 index 000000000..76e0bccb1 --- /dev/null +++ b/moto/sns/exceptions.py @@ -0,0 +1,10 @@ +from __future__ import unicode_literals +from moto.core.exceptions import RESTError + + +class SNSNotFoundError(RESTError): + code = 404 + + def __init__(self, message): + super(SNSNotFoundError, self).__init__( + "NotFound", message) diff --git a/moto/sns/models.py b/moto/sns/models.py index 769b755aa..8e9404690 100644 --- a/moto/sns/models.py +++ b/moto/sns/models.py @@ -11,6 +11,7 @@ from moto.compat import OrderedDict from moto.core import BaseBackend from moto.core.utils import iso_8601_datetime_with_milliseconds from moto.sqs import sqs_backends +from .exceptions import SNSNotFoundError from .utils import make_arn_for_topic, make_arn_for_subscription DEFAULT_ACCOUNT_ID = 123456789012 @@ -93,10 +94,52 @@ class Subscription(object): } +class PlatformApplication(object): + def __init__(self, region, name, platform, attributes): + self.region = region + self.name = name + self.platform = platform + self.attributes = attributes + + @property + def arn(self): + return "arn:aws:sns:{region}:123456789012:app/{platform}/{name}".format( + region=self.region, + platform=self.platform, + name=self.name, + ) + + +class PlatformEndpoint(object): + def __init__(self, region, application, custom_user_data, token, attributes): + self.region = region + self.application = application + self.custom_user_data = custom_user_data + self.token = token + self.attributes = attributes + self.id = uuid.uuid4() + + @property + def arn(self): + return "arn:aws:sns:{region}:123456789012:endpoint/{platform}/{name}/{id}".format( + region=self.region, + platform=self.application.platform, + name=self.application.name, + id=self.id, + ) + + def publish(self, message): + message_id = six.text_type(uuid.uuid4()) + # This is where we would actually send a message + return message_id + + class SNSBackend(BaseBackend): def __init__(self): self.topics = OrderedDict() self.subscriptions = OrderedDict() + self.applications = {} + self.platform_endpoints = {} def create_topic(self, name): topic = Topic(name, self) @@ -121,7 +164,10 @@ class SNSBackend(BaseBackend): self.topics.pop(arn) def get_topic(self, arn): - return self.topics[arn] + try: + return self.topics[arn] + except KeyError: + raise SNSNotFoundError("Topic with arn {} not found".format(arn)) def set_topic_attribute(self, topic_arn, attribute_name, attribute_value): topic = self.get_topic(topic_arn) @@ -144,11 +190,61 @@ class SNSBackend(BaseBackend): else: return self._get_values_nexttoken(self.subscriptions, next_token) - def publish(self, topic_arn, message): - topic = self.get_topic(topic_arn) - message_id = topic.publish(message) + def publish(self, arn, message): + try: + topic = self.get_topic(arn) + message_id = topic.publish(message) + except SNSNotFoundError: + endpoint = self.get_endpoint(arn) + message_id = endpoint.publish(message) return message_id + def create_platform_application(self, region, name, platform, attributes): + application = PlatformApplication(region, name, platform, attributes) + self.applications[application.arn] = application + return application + + def get_application(self, arn): + try: + return self.applications[arn] + except KeyError: + raise SNSNotFoundError("Application with arn {} not found".format(arn)) + + def set_application_attributes(self, arn, attributes): + application = self.get_application(arn) + application.attributes.update(attributes) + return application + + def list_platform_applications(self): + return self.applications.values() + + def delete_platform_application(self, platform_arn): + self.applications.pop(platform_arn) + + def create_platform_endpoint(self, region, application, custom_user_data, token, attributes): + platform_endpoint = PlatformEndpoint(region, application, custom_user_data, token, attributes) + self.platform_endpoints[platform_endpoint.arn] = platform_endpoint + return platform_endpoint + + def list_endpoints_by_platform_application(self, application_arn): + return [ + endpoint for endpoint + in self.platform_endpoints.values() + if endpoint.application.arn == application_arn + ] + + def get_endpoint(self, arn): + try: + return self.platform_endpoints[arn] + except KeyError: + raise SNSNotFoundError("Endpoint with arn {} not found".format(arn)) + + def set_endpoint_attributes(self, arn, attributes): + endpoint = self.get_endpoint(arn) + endpoint.attributes.update(attributes) + return endpoint + + sns_backends = {} for region in boto.sns.regions(): sns_backends[region.name] = SNSBackend() diff --git a/moto/sns/responses.py b/moto/sns/responses.py index cf500376a..b18d40bb7 100644 --- a/moto/sns/responses.py +++ b/moto/sns/responses.py @@ -170,9 +170,11 @@ class SNSResponse(BaseResponse): }) def publish(self): + target_arn = self._get_param('TargetArn') topic_arn = self._get_param('TopicArn') + arn = target_arn if target_arn else topic_arn message = self._get_param('Message') - message_id = self.backend.publish(topic_arn, message) + message_id = self.backend.publish(arn, message) return json.dumps({ "PublishResponse": { @@ -185,19 +187,129 @@ class SNSResponse(BaseResponse): } }) + def create_platform_application(self): + name = self._get_param('Name') + platform = self._get_param('Platform') + attributes = self._get_list_prefix('Attributes.entry') + attributes = { + attribute['key']: attribute['value'] + for attribute + in attributes + } + platform_application = self.backend.create_platform_application(self.region, name, platform, attributes) + + return json.dumps({ + "CreatePlatformApplicationResponse": { + "CreatePlatformApplicationResult": { + "PlatformApplicationArn": platform_application.arn, + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937b", + } + } + }) + + def get_platform_application_attributes(self): + arn = self._get_param('PlatformApplicationArn') + application = self.backend.get_application(arn) + + return json.dumps({ + "GetPlatformApplicationAttributesResponse": { + "GetPlatformApplicationAttributesResult": { + "Attributes": application.attributes, + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937f", + } + } + }) + + def set_platform_application_attributes(self): + arn = self._get_param('PlatformApplicationArn') + attributes = self._get_list_prefix('Attributes.entry') + attributes = { + attribute['key']: attribute['value'] + for attribute + in attributes + } + self.backend.set_application_attributes(arn, attributes) + + return json.dumps({ + "SetPlatformApplicationAttributesResponse": { + "ResponseMetadata": { + "RequestId": "384ac68d-3775-12df-8963-01868b7c937f", + } + } + }) + + def list_platform_applications(self): + applications = self.backend.list_platform_applications() + + return json.dumps({ + "ListPlatformApplicationsResponse": { + "ListPlatformApplicationsResult": { + "PlatformApplications": [{ + "PlatformApplicationArn": application.arn, + "attributes": application.attributes, + } for application in applications] + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937c", + } + } + }) + + def delete_platform_application(self): + platform_arn = self._get_param('PlatformApplicationArn') + self.backend.delete_platform_application(platform_arn) + + return json.dumps({ + "DeletePlatformApplicationResponse": { + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937e", + } + } + }) + + def create_platform_endpoint(self): + application_arn = self._get_param('PlatformApplicationArn') + application = self.backend.get_application(application_arn) + + custom_user_data = self._get_param('CustomUserData') + token = self._get_param('Token') + attributes = self._get_list_prefix('Attributes.entry') + attributes = { + attribute['key']: attribute['value'] + for attribute + in attributes + } + + platform_endpoint = self.backend.create_platform_endpoint( + self.region, application, custom_user_data, token, attributes) + + return json.dumps({ + "CreatePlatformEndpointResponse": { + "CreatePlatformEndpointResult": { + "EndpointArn": platform_endpoint.arn, + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3779-11df-8963-01868b7c937b", + } + } + }) + def list_endpoints_by_platform_application(self): + application_arn = self._get_param('PlatformApplicationArn') + endpoints = self.backend.list_endpoints_by_platform_application(application_arn) + return json.dumps({ "ListEndpointsByPlatformApplicationResponse": { "ListEndpointsByPlatformApplicationResult": { "Endpoints": [ { - "Attributes": { - "Token": "TOKEN", - "Enabled": "true", - "CustomUserData": "" - }, - "EndpointArn": "FAKE_ARN_ENDPOINT" - } + "Attributes": endpoint.attributes, + "EndpointArn": endpoint.arn, + } for endpoint in endpoints ], "NextToken": None }, @@ -206,3 +318,36 @@ class SNSResponse(BaseResponse): } } }) + + def get_endpoint_attributes(self): + arn = self._get_param('EndpointArn') + endpoint = self.backend.get_endpoint(arn) + + return json.dumps({ + "GetEndpointAttributesResponse": { + "GetEndpointAttributesResult": { + "Attributes": endpoint.attributes, + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937f", + } + } + }) + + def set_endpoint_attributes(self): + arn = self._get_param('EndpointArn') + attributes = self._get_list_prefix('Attributes.entry') + attributes = { + attribute['key']: attribute['value'] + for attribute + in attributes + } + self.backend.set_endpoint_attributes(arn, attributes) + + return json.dumps({ + "SetEndpointAttributesResponse": { + "ResponseMetadata": { + "RequestId": "384bc68d-3775-12df-8963-01868b7c937f", + } + } + }) diff --git a/tests/test_sns/test_application.py b/tests/test_sns/test_application.py index 24c5a1fbd..87d0316d6 100644 --- a/tests/test_sns/test_application.py +++ b/tests/test_sns/test_application.py @@ -1,16 +1,245 @@ from __future__ import unicode_literals -import boto +import boto +from boto.exception import BotoServerError from moto import mock_sns +import sure # noqa + + +@mock_sns +def test_create_platform_application(): + conn = boto.connect_sns() + platform_application = conn.create_platform_application( + name="my-application", + platform="APNS", + attributes={ + "PlatformCredential": "platform_credential", + "PlatformPrincipal": "platform_principal", + }, + ) + application_arn = platform_application['CreatePlatformApplicationResponse']['CreatePlatformApplicationResult']['PlatformApplicationArn'] + application_arn.should.equal('arn:aws:sns:us-east-1:123456789012:app/APNS/my-application') + + +@mock_sns +def test_get_platform_application_attributes(): + conn = boto.connect_sns() + platform_application = conn.create_platform_application( + name="my-application", + platform="APNS", + attributes={ + "PlatformCredential": "platform_credential", + "PlatformPrincipal": "platform_principal", + }, + ) + arn = platform_application['CreatePlatformApplicationResponse']['CreatePlatformApplicationResult']['PlatformApplicationArn'] + attributes = conn.get_platform_application_attributes(arn)['GetPlatformApplicationAttributesResponse']['GetPlatformApplicationAttributesResult']['Attributes'] + attributes.should.equal({ + "PlatformCredential": "platform_credential", + "PlatformPrincipal": "platform_principal", + }) + + +@mock_sns +def test_get_missing_platform_application_attributes(): + conn = boto.connect_sns() + conn.get_platform_application_attributes.when.called_with("a-fake-arn").should.throw(BotoServerError) + + +@mock_sns +def test_set_platform_application_attributes(): + conn = boto.connect_sns() + platform_application = conn.create_platform_application( + name="my-application", + platform="APNS", + attributes={ + "PlatformCredential": "platform_credential", + "PlatformPrincipal": "platform_principal", + }, + ) + arn = platform_application['CreatePlatformApplicationResponse']['CreatePlatformApplicationResult']['PlatformApplicationArn'] + conn.set_platform_application_attributes(arn, + {"PlatformPrincipal": "other"} + ) + attributes = conn.get_platform_application_attributes(arn)['GetPlatformApplicationAttributesResponse']['GetPlatformApplicationAttributesResult']['Attributes'] + attributes.should.equal({ + "PlatformCredential": "platform_credential", + "PlatformPrincipal": "other", + }) + + +@mock_sns +def test_list_platform_applications(): + conn = boto.connect_sns() + conn.create_platform_application( + name="application1", + platform="APNS", + ) + conn.create_platform_application( + name="application2", + platform="APNS", + ) + + applications_repsonse = conn.list_platform_applications() + applications = applications_repsonse['ListPlatformApplicationsResponse']['ListPlatformApplicationsResult']['PlatformApplications'] + applications.should.have.length_of(2) + + +@mock_sns +def test_delete_platform_application(): + conn = boto.connect_sns() + conn.create_platform_application( + name="application1", + platform="APNS", + ) + conn.create_platform_application( + name="application2", + platform="APNS", + ) + + applications_repsonse = conn.list_platform_applications() + applications = applications_repsonse['ListPlatformApplicationsResponse']['ListPlatformApplicationsResult']['PlatformApplications'] + applications.should.have.length_of(2) + + application_arn = applications[0]['PlatformApplicationArn'] + conn.delete_platform_application(application_arn) + + applications_repsonse = conn.list_platform_applications() + applications = applications_repsonse['ListPlatformApplicationsResponse']['ListPlatformApplicationsResult']['PlatformApplications'] + applications.should.have.length_of(1) + + +@mock_sns +def test_create_platform_endpoint(): + conn = boto.connect_sns() + platform_application = conn.create_platform_application( + name="my-application", + platform="APNS", + ) + application_arn = platform_application['CreatePlatformApplicationResponse']['CreatePlatformApplicationResult']['PlatformApplicationArn'] + + endpoint = conn.create_platform_endpoint( + platform_application_arn=application_arn, + token="some_unique_id", + custom_user_data="some user data", + attributes={ + "Enabled": False, + }, + ) + + endpoint_arn = endpoint['CreatePlatformEndpointResponse']['CreatePlatformEndpointResult']['EndpointArn'] + endpoint_arn.should.contain("arn:aws:sns:us-east-1:123456789012:endpoint/APNS/my-application/") @mock_sns def test_get_list_endpoints_by_platform_application(): conn = boto.connect_sns() + platform_application = conn.create_platform_application( + name="my-application", + platform="APNS", + ) + application_arn = platform_application['CreatePlatformApplicationResponse']['CreatePlatformApplicationResult']['PlatformApplicationArn'] + + endpoint = conn.create_platform_endpoint( + platform_application_arn=application_arn, + token="some_unique_id", + custom_user_data="some user data", + attributes={ + "CustomUserData": "some data", + }, + ) + endpoint_arn = endpoint['CreatePlatformEndpointResponse']['CreatePlatformEndpointResult']['EndpointArn'] + endpoint_list = conn.list_endpoints_by_platform_application( - platform_application_arn='fake_arn' + platform_application_arn=application_arn )['ListEndpointsByPlatformApplicationResponse']['ListEndpointsByPlatformApplicationResult']['Endpoints'] endpoint_list.should.have.length_of(1) - endpoint_list[0]['Attributes']['Enabled'].should.equal('true') - endpoint_list[0]['EndpointArn'].should.equal('FAKE_ARN_ENDPOINT') + endpoint_list[0]['Attributes']['CustomUserData'].should.equal('some data') + endpoint_list[0]['EndpointArn'].should.equal(endpoint_arn) + + +@mock_sns +def test_get_endpoint_attributes(): + conn = boto.connect_sns() + platform_application = conn.create_platform_application( + name="my-application", + platform="APNS", + ) + application_arn = platform_application['CreatePlatformApplicationResponse']['CreatePlatformApplicationResult']['PlatformApplicationArn'] + + endpoint = conn.create_platform_endpoint( + platform_application_arn=application_arn, + token="some_unique_id", + custom_user_data="some user data", + attributes={ + "Enabled": False, + "CustomUserData": "some data", + }, + ) + endpoint_arn = endpoint['CreatePlatformEndpointResponse']['CreatePlatformEndpointResult']['EndpointArn'] + + attributes = conn.get_endpoint_attributes(endpoint_arn)['GetEndpointAttributesResponse']['GetEndpointAttributesResult']['Attributes'] + attributes.should.equal({ + "Enabled": 'False', + "CustomUserData": "some data", + }) + + +@mock_sns +def test_get_missing_endpoint_attributes(): + conn = boto.connect_sns() + conn.get_endpoint_attributes.when.called_with("a-fake-arn").should.throw(BotoServerError) + + +@mock_sns +def test_set_endpoint_attributes(): + conn = boto.connect_sns() + platform_application = conn.create_platform_application( + name="my-application", + platform="APNS", + ) + application_arn = platform_application['CreatePlatformApplicationResponse']['CreatePlatformApplicationResult']['PlatformApplicationArn'] + + endpoint = conn.create_platform_endpoint( + platform_application_arn=application_arn, + token="some_unique_id", + custom_user_data="some user data", + attributes={ + "Enabled": False, + "CustomUserData": "some data", + }, + ) + endpoint_arn = endpoint['CreatePlatformEndpointResponse']['CreatePlatformEndpointResult']['EndpointArn'] + + conn.set_endpoint_attributes(endpoint_arn, + {"CustomUserData": "other data"} + ) + attributes = conn.get_endpoint_attributes(endpoint_arn)['GetEndpointAttributesResponse']['GetEndpointAttributesResult']['Attributes'] + attributes.should.equal({ + "Enabled": 'False', + "CustomUserData": "other data", + }) + + +@mock_sns +def test_publish_to_platform_endpoint(): + conn = boto.connect_sns() + platform_application = conn.create_platform_application( + name="my-application", + platform="APNS", + ) + application_arn = platform_application['CreatePlatformApplicationResponse']['CreatePlatformApplicationResult']['PlatformApplicationArn'] + + endpoint = conn.create_platform_endpoint( + platform_application_arn=application_arn, + token="some_unique_id", + custom_user_data="some user data", + attributes={ + "Enabled": False, + }, + ) + + endpoint_arn = endpoint['CreatePlatformEndpointResponse']['CreatePlatformEndpointResult']['EndpointArn'] + + conn.publish(message="some message", message_structure="json", target_arn=endpoint_arn) diff --git a/tests/test_sns/test_topics.py b/tests/test_sns/test_topics.py index 817426244..e2488f3d2 100644 --- a/tests/test_sns/test_topics.py +++ b/tests/test_sns/test_topics.py @@ -4,6 +4,7 @@ import six import sure # noqa +from boto.exception import BotoServerError from moto import mock_sns from moto.sns.models import DEFAULT_TOPIC_POLICY, DEFAULT_EFFECTIVE_DELIVERY_POLICY, DEFAULT_PAGE_SIZE @@ -27,6 +28,12 @@ def test_create_and_delete_topic(): topics.should.have.length_of(0) +@mock_sns +def test_get_missing_topic(): + conn = boto.connect_sns() + conn.get_topic_attributes.when.called_with("a-fake-arn").should.throw(BotoServerError) + + @mock_sns def test_create_topic_in_multiple_regions(): west1_conn = boto.sns.connect_to_region("us-west-1") From e2d75cba2c06a32c963a6d2d387a9a59c684fa3f Mon Sep 17 00:00:00 2001 From: Steve Pulec Date: Sat, 14 Mar 2015 09:13:58 -0400 Subject: [PATCH 2/3] Remove dict comprehension for py2.6 --- moto/sns/responses.py | 38 ++++++++++++++------------------------ 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/moto/sns/responses.py b/moto/sns/responses.py index b18d40bb7..51381f008 100644 --- a/moto/sns/responses.py +++ b/moto/sns/responses.py @@ -12,6 +12,14 @@ class SNSResponse(BaseResponse): def backend(self): return sns_backends[self.region] + def _get_attributes(self): + attributes = self._get_list_prefix('Attributes.entry') + return dict( + (attribute['key'], attribute['value']) + for attribute + in attributes + ) + def create_topic(self): name = self._get_param('Name') topic = self.backend.create_topic(name) @@ -190,12 +198,7 @@ class SNSResponse(BaseResponse): def create_platform_application(self): name = self._get_param('Name') platform = self._get_param('Platform') - attributes = self._get_list_prefix('Attributes.entry') - attributes = { - attribute['key']: attribute['value'] - for attribute - in attributes - } + attributes = self._get_attributes() platform_application = self.backend.create_platform_application(self.region, name, platform, attributes) return json.dumps({ @@ -226,12 +229,8 @@ class SNSResponse(BaseResponse): def set_platform_application_attributes(self): arn = self._get_param('PlatformApplicationArn') - attributes = self._get_list_prefix('Attributes.entry') - attributes = { - attribute['key']: attribute['value'] - for attribute - in attributes - } + attributes = self._get_attributes() + self.backend.set_application_attributes(arn, attributes) return json.dumps({ @@ -277,12 +276,7 @@ class SNSResponse(BaseResponse): custom_user_data = self._get_param('CustomUserData') token = self._get_param('Token') - attributes = self._get_list_prefix('Attributes.entry') - attributes = { - attribute['key']: attribute['value'] - for attribute - in attributes - } + attributes = self._get_attributes() platform_endpoint = self.backend.create_platform_endpoint( self.region, application, custom_user_data, token, attributes) @@ -336,12 +330,8 @@ class SNSResponse(BaseResponse): def set_endpoint_attributes(self): arn = self._get_param('EndpointArn') - attributes = self._get_list_prefix('Attributes.entry') - attributes = { - attribute['key']: attribute['value'] - for attribute - in attributes - } + attributes = self._get_attributes() + self.backend.set_endpoint_attributes(arn, attributes) return json.dumps({ From ca39591ef2b0f49a1e560c81746ec4da1d2cb3d8 Mon Sep 17 00:00:00 2001 From: Steve Pulec Date: Sat, 14 Mar 2015 09:19:36 -0400 Subject: [PATCH 3/3] Fix error string formatting for py26. --- moto/sns/models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/moto/sns/models.py b/moto/sns/models.py index 8e9404690..32514fc0e 100644 --- a/moto/sns/models.py +++ b/moto/sns/models.py @@ -167,7 +167,7 @@ class SNSBackend(BaseBackend): try: return self.topics[arn] except KeyError: - raise SNSNotFoundError("Topic with arn {} not found".format(arn)) + raise SNSNotFoundError("Topic with arn {0} not found".format(arn)) def set_topic_attribute(self, topic_arn, attribute_name, attribute_value): topic = self.get_topic(topic_arn) @@ -208,7 +208,7 @@ class SNSBackend(BaseBackend): try: return self.applications[arn] except KeyError: - raise SNSNotFoundError("Application with arn {} not found".format(arn)) + raise SNSNotFoundError("Application with arn {0} not found".format(arn)) def set_application_attributes(self, arn, attributes): application = self.get_application(arn) @@ -237,7 +237,7 @@ class SNSBackend(BaseBackend): try: return self.platform_endpoints[arn] except KeyError: - raise SNSNotFoundError("Endpoint with arn {} not found".format(arn)) + raise SNSNotFoundError("Endpoint with arn {0} not found".format(arn)) def set_endpoint_attributes(self, arn, attributes): endpoint = self.get_endpoint(arn)