diff --git a/moto/sns/__init__.py b/moto/sns/__init__.py index f64c5e133..1aa1a0e3e 100644 --- a/moto/sns/__init__.py +++ b/moto/sns/__init__.py @@ -1,3 +1,12 @@ from __future__ import unicode_literals -from .models import sns_backend -mock_sns = sns_backend.decorator +from .models import sns_backends +from ..core.models import MockAWS + +sns_backend = sns_backends['us-east-1'] + + +def mock_sns(func=None): + if func: + return MockAWS(sns_backends)(func) + else: + return MockAWS(sns_backends) diff --git a/moto/sns/models.py b/moto/sns/models.py index 886007aee..891ccee43 100644 --- a/moto/sns/models.py +++ b/moto/sns/models.py @@ -1,7 +1,10 @@ from __future__ import unicode_literals + import datetime -import requests import uuid + +import boto.sns +import requests import six from moto.core import BaseBackend @@ -13,8 +16,9 @@ DEFAULT_ACCOUNT_ID = 123456789012 class Topic(object): - def __init__(self, name): + def __init__(self, name, sns_backend): self.name = name + self.sns_backend = sns_backend self.account_id = DEFAULT_ACCOUNT_ID self.display_name = "" self.policy = DEFAULT_TOPIC_POLICY @@ -28,7 +32,7 @@ class Topic(object): def publish(self, message): message_id = six.text_type(uuid.uuid4()) - subscriptions = sns_backend.list_subscriptions(self.arn) + subscriptions = self.sns_backend.list_subscriptions(self.arn) for subscription in subscriptions: subscription.publish(message, message_id) return message_id @@ -76,7 +80,7 @@ class SNSBackend(BaseBackend): self.subscriptions = {} def create_topic(self, name): - topic = Topic(name) + topic = Topic(name, self) self.topics[topic.arn] = topic return topic @@ -114,8 +118,9 @@ class SNSBackend(BaseBackend): message_id = topic.publish(message) return message_id - -sns_backend = SNSBackend() +sns_backends = {} +for region in boto.sns.regions(): + sns_backends[region.name] = SNSBackend() DEFAULT_TOPIC_POLICY = { diff --git a/moto/sns/responses.py b/moto/sns/responses.py index dc3735f6b..19ebdd04c 100644 --- a/moto/sns/responses.py +++ b/moto/sns/responses.py @@ -3,14 +3,18 @@ import json from moto.core.responses import BaseResponse from moto.core.utils import camelcase_to_underscores -from .models import sns_backend +from .models import sns_backends class SNSResponse(BaseResponse): + @property + def backend(self): + return sns_backends[self.region] + def create_topic(self): name = self._get_param('Name') - topic = sns_backend.create_topic(name) + topic = self.backend.create_topic(name) return json.dumps({ 'CreateTopicResponse': { @@ -24,7 +28,7 @@ class SNSResponse(BaseResponse): }) def list_topics(self): - topics = sns_backend.list_topics() + topics = self.backend.list_topics() return json.dumps({ 'ListTopicsResponse': { @@ -40,7 +44,7 @@ class SNSResponse(BaseResponse): def delete_topic(self): topic_arn = self._get_param('TopicArn') - sns_backend.delete_topic(topic_arn) + self.backend.delete_topic(topic_arn) return json.dumps({ 'DeleteTopicResponse': { @@ -52,7 +56,7 @@ class SNSResponse(BaseResponse): def get_topic_attributes(self): topic_arn = self._get_param('TopicArn') - topic = sns_backend.get_topic(topic_arn) + topic = self.backend.get_topic(topic_arn) return json.dumps({ "GetTopicAttributesResponse": { @@ -80,7 +84,7 @@ class SNSResponse(BaseResponse): attribute_name = self._get_param('AttributeName') attribute_name = camelcase_to_underscores(attribute_name) attribute_value = self._get_param('AttributeValue') - sns_backend.set_topic_attribute(topic_arn, attribute_name, attribute_value) + self.backend.set_topic_attribute(topic_arn, attribute_name, attribute_value) return json.dumps({ "SetTopicAttributesResponse": { @@ -94,7 +98,7 @@ class SNSResponse(BaseResponse): topic_arn = self._get_param('TopicArn') endpoint = self._get_param('Endpoint') protocol = self._get_param('Protocol') - subscription = sns_backend.subscribe(topic_arn, endpoint, protocol) + subscription = self.backend.subscribe(topic_arn, endpoint, protocol) return json.dumps({ "SubscribeResponse": { @@ -109,7 +113,7 @@ class SNSResponse(BaseResponse): def unsubscribe(self): subscription_arn = self._get_param('SubscriptionArn') - sns_backend.unsubscribe(subscription_arn) + self.backend.unsubscribe(subscription_arn) return json.dumps({ "UnsubscribeResponse": { @@ -120,7 +124,7 @@ class SNSResponse(BaseResponse): }) def list_subscriptions(self): - subscriptions = sns_backend.list_subscriptions() + subscriptions = self.backend.list_subscriptions() return json.dumps({ "ListSubscriptionsResponse": { @@ -142,7 +146,7 @@ class SNSResponse(BaseResponse): def list_subscriptions_by_topic(self): topic_arn = self._get_param('TopicArn') - subscriptions = sns_backend.list_subscriptions(topic_arn) + subscriptions = self.backend.list_subscriptions(topic_arn) return json.dumps({ "ListSubscriptionsByTopicResponse": { @@ -165,7 +169,7 @@ class SNSResponse(BaseResponse): def publish(self): topic_arn = self._get_param('TopicArn') message = self._get_param('Message') - message_id = sns_backend.publish(topic_arn, message) + message_id = self.backend.publish(topic_arn, message) return json.dumps({ "PublishResponse": { diff --git a/tests/test_sns/test_topics.py b/tests/test_sns/test_topics.py index b4129425b..427c8c003 100644 --- a/tests/test_sns/test_topics.py +++ b/tests/test_sns/test_topics.py @@ -27,6 +27,18 @@ def test_create_and_delete_topic(): topics.should.have.length_of(0) +@mock_sns +def test_create_topic_in_multiple_regions(): + west1_conn = boto.sns.connect_to_region("us-west-1") + west1_conn.create_topic("some-topic") + + west2_conn = boto.sns.connect_to_region("us-west-2") + west2_conn.create_topic("some-topic") + + list(west1_conn.get_all_topics()["ListTopicsResponse"]["ListTopicsResult"]["Topics"]).should.have.length_of(1) + list(west2_conn.get_all_topics()["ListTopicsResponse"]["ListTopicsResult"]["Topics"]).should.have.length_of(1) + + @mock_sns def test_topic_attributes(): conn = boto.connect_sns()