Make SNS multi-region.

This commit is contained in:
Steve Pulec 2014-11-16 18:35:11 -05:00
parent 2b775aa075
commit 53acdf6c76
4 changed files with 49 additions and 19 deletions

View File

@ -1,3 +1,12 @@
from __future__ import unicode_literals
from .models import sns_backend
mock_sns = sns_backend.decorator
from .models import sns_backends
from ..core.models import MockAWS
sns_backend = sns_backends['us-east-1']
def mock_sns(func=None):
if func:
return MockAWS(sns_backends)(func)
else:
return MockAWS(sns_backends)

View File

@ -1,7 +1,10 @@
from __future__ import unicode_literals
import datetime
import requests
import uuid
import boto.sns
import requests
import six
from moto.core import BaseBackend
@ -13,8 +16,9 @@ DEFAULT_ACCOUNT_ID = 123456789012
class Topic(object):
def __init__(self, name):
def __init__(self, name, sns_backend):
self.name = name
self.sns_backend = sns_backend
self.account_id = DEFAULT_ACCOUNT_ID
self.display_name = ""
self.policy = DEFAULT_TOPIC_POLICY
@ -28,7 +32,7 @@ class Topic(object):
def publish(self, message):
message_id = six.text_type(uuid.uuid4())
subscriptions = sns_backend.list_subscriptions(self.arn)
subscriptions = self.sns_backend.list_subscriptions(self.arn)
for subscription in subscriptions:
subscription.publish(message, message_id)
return message_id
@ -76,7 +80,7 @@ class SNSBackend(BaseBackend):
self.subscriptions = {}
def create_topic(self, name):
topic = Topic(name)
topic = Topic(name, self)
self.topics[topic.arn] = topic
return topic
@ -114,8 +118,9 @@ class SNSBackend(BaseBackend):
message_id = topic.publish(message)
return message_id
sns_backend = SNSBackend()
sns_backends = {}
for region in boto.sns.regions():
sns_backends[region.name] = SNSBackend()
DEFAULT_TOPIC_POLICY = {

View File

@ -3,14 +3,18 @@ import json
from moto.core.responses import BaseResponse
from moto.core.utils import camelcase_to_underscores
from .models import sns_backend
from .models import sns_backends
class SNSResponse(BaseResponse):
@property
def backend(self):
return sns_backends[self.region]
def create_topic(self):
name = self._get_param('Name')
topic = sns_backend.create_topic(name)
topic = self.backend.create_topic(name)
return json.dumps({
'CreateTopicResponse': {
@ -24,7 +28,7 @@ class SNSResponse(BaseResponse):
})
def list_topics(self):
topics = sns_backend.list_topics()
topics = self.backend.list_topics()
return json.dumps({
'ListTopicsResponse': {
@ -40,7 +44,7 @@ class SNSResponse(BaseResponse):
def delete_topic(self):
topic_arn = self._get_param('TopicArn')
sns_backend.delete_topic(topic_arn)
self.backend.delete_topic(topic_arn)
return json.dumps({
'DeleteTopicResponse': {
@ -52,7 +56,7 @@ class SNSResponse(BaseResponse):
def get_topic_attributes(self):
topic_arn = self._get_param('TopicArn')
topic = sns_backend.get_topic(topic_arn)
topic = self.backend.get_topic(topic_arn)
return json.dumps({
"GetTopicAttributesResponse": {
@ -80,7 +84,7 @@ class SNSResponse(BaseResponse):
attribute_name = self._get_param('AttributeName')
attribute_name = camelcase_to_underscores(attribute_name)
attribute_value = self._get_param('AttributeValue')
sns_backend.set_topic_attribute(topic_arn, attribute_name, attribute_value)
self.backend.set_topic_attribute(topic_arn, attribute_name, attribute_value)
return json.dumps({
"SetTopicAttributesResponse": {
@ -94,7 +98,7 @@ class SNSResponse(BaseResponse):
topic_arn = self._get_param('TopicArn')
endpoint = self._get_param('Endpoint')
protocol = self._get_param('Protocol')
subscription = sns_backend.subscribe(topic_arn, endpoint, protocol)
subscription = self.backend.subscribe(topic_arn, endpoint, protocol)
return json.dumps({
"SubscribeResponse": {
@ -109,7 +113,7 @@ class SNSResponse(BaseResponse):
def unsubscribe(self):
subscription_arn = self._get_param('SubscriptionArn')
sns_backend.unsubscribe(subscription_arn)
self.backend.unsubscribe(subscription_arn)
return json.dumps({
"UnsubscribeResponse": {
@ -120,7 +124,7 @@ class SNSResponse(BaseResponse):
})
def list_subscriptions(self):
subscriptions = sns_backend.list_subscriptions()
subscriptions = self.backend.list_subscriptions()
return json.dumps({
"ListSubscriptionsResponse": {
@ -142,7 +146,7 @@ class SNSResponse(BaseResponse):
def list_subscriptions_by_topic(self):
topic_arn = self._get_param('TopicArn')
subscriptions = sns_backend.list_subscriptions(topic_arn)
subscriptions = self.backend.list_subscriptions(topic_arn)
return json.dumps({
"ListSubscriptionsByTopicResponse": {
@ -165,7 +169,7 @@ class SNSResponse(BaseResponse):
def publish(self):
topic_arn = self._get_param('TopicArn')
message = self._get_param('Message')
message_id = sns_backend.publish(topic_arn, message)
message_id = self.backend.publish(topic_arn, message)
return json.dumps({
"PublishResponse": {

View File

@ -27,6 +27,18 @@ def test_create_and_delete_topic():
topics.should.have.length_of(0)
@mock_sns
def test_create_topic_in_multiple_regions():
west1_conn = boto.sns.connect_to_region("us-west-1")
west1_conn.create_topic("some-topic")
west2_conn = boto.sns.connect_to_region("us-west-2")
west2_conn.create_topic("some-topic")
list(west1_conn.get_all_topics()["ListTopicsResponse"]["ListTopicsResult"]["Topics"]).should.have.length_of(1)
list(west2_conn.get_all_topics()["ListTopicsResponse"]["ListTopicsResult"]["Topics"]).should.have.length_of(1)
@mock_sns
def test_topic_attributes():
conn = boto.connect_sns()