From c207963a86b7dcc68006c760cab1f4169519b94a Mon Sep 17 00:00:00 2001 From: Steve Pulec Date: Thu, 16 Mar 2017 22:28:30 -0400 Subject: [PATCH] Cleanup SNS exceptions. Closes #751. --- moto/sns/exceptions.py | 16 +++++++ moto/sns/models.py | 13 +++++- tests/test_sns/test_application.py | 2 +- tests/test_sns/test_application_boto3.py | 59 +++++++++++++++++++++++- 4 files changed, 87 insertions(+), 3 deletions(-) diff --git a/moto/sns/exceptions.py b/moto/sns/exceptions.py index 76e0bccb1..092bb9d69 100644 --- a/moto/sns/exceptions.py +++ b/moto/sns/exceptions.py @@ -8,3 +8,19 @@ class SNSNotFoundError(RESTError): def __init__(self, message): super(SNSNotFoundError, self).__init__( "NotFound", message) + + +class DuplicateSnsEndpointError(RESTError): + code = 400 + + def __init__(self, message): + super(DuplicateSnsEndpointError, self).__init__( + "DuplicateEndpoint", message) + + +class SnsEndpointDisabled(RESTError): + code = 400 + + def __init__(self, message): + super(SnsEndpointDisabled, self).__init__( + "EndpointDisabled", message) diff --git a/moto/sns/models.py b/moto/sns/models.py index 64352d545..5289c8bcd 100644 --- a/moto/sns/models.py +++ b/moto/sns/models.py @@ -12,7 +12,9 @@ from moto.compat import OrderedDict from moto.core import BaseBackend, BaseModel from moto.core.utils import iso_8601_datetime_with_milliseconds from moto.sqs import sqs_backends -from .exceptions import SNSNotFoundError +from .exceptions import ( + SNSNotFoundError, DuplicateSnsEndpointError, SnsEndpointDisabled +) from .utils import make_arn_for_topic, make_arn_for_subscription DEFAULT_ACCOUNT_ID = 123456789012 @@ -136,6 +138,10 @@ class PlatformEndpoint(BaseModel): if 'Enabled' not in self.attributes: self.attributes['Enabled'] = True + @property + def enabled(self): + return json.loads(self.attributes.get('Enabled', 'true').lower()) + @property def arn(self): return "arn:aws:sns:{region}:123456789012:endpoint/{platform}/{name}/{id}".format( @@ -146,6 +152,9 @@ class PlatformEndpoint(BaseModel): ) def publish(self, message): + if not self.enabled: + raise SnsEndpointDisabled("Endpoint %s disabled" % self.id) + # This is where we would actually send a message message_id = six.text_type(uuid.uuid4()) self.messages[message_id] = message @@ -251,6 +260,8 @@ class SNSBackend(BaseBackend): self.applications.pop(platform_arn) def create_platform_endpoint(self, region, application, custom_user_data, token, attributes): + if any(token == endpoint.token for endpoint in self.platform_endpoints.values()): + raise DuplicateSnsEndpointError("Duplicate endpoint token: %s" % token) platform_endpoint = PlatformEndpoint( region, application, custom_user_data, token, attributes) self.platform_endpoints[platform_endpoint.arn] = platform_endpoint diff --git a/tests/test_sns/test_application.py b/tests/test_sns/test_application.py index 613b11af5..319e4a6f8 100644 --- a/tests/test_sns/test_application.py +++ b/tests/test_sns/test_application.py @@ -297,7 +297,7 @@ def test_publish_to_platform_endpoint(): token="some_unique_id", custom_user_data="some user data", attributes={ - "Enabled": False, + "Enabled": True, }, ) diff --git a/tests/test_sns/test_application_boto3.py b/tests/test_sns/test_application_boto3.py index 968240b15..99c378fe4 100644 --- a/tests/test_sns/test_application_boto3.py +++ b/tests/test_sns/test_application_boto3.py @@ -142,6 +142,35 @@ def test_create_platform_endpoint(): "arn:aws:sns:us-east-1:123456789012:endpoint/APNS/my-application/") +@mock_sns +def test_create_duplicate_platform_endpoint(): + conn = boto3.client('sns', region_name='us-east-1') + platform_application = conn.create_platform_application( + Name="my-application", + Platform="APNS", + Attributes={}, + ) + application_arn = platform_application['PlatformApplicationArn'] + + endpoint = conn.create_platform_endpoint( + PlatformApplicationArn=application_arn, + Token="some_unique_id", + CustomUserData="some user data", + Attributes={ + "Enabled": 'false', + }, + ) + + endpoint = conn.create_platform_endpoint.when.called_with( + PlatformApplicationArn=application_arn, + Token="some_unique_id", + CustomUserData="some user data", + Attributes={ + "Enabled": 'false', + }, + ).should.throw(ClientError) + + @mock_sns def test_get_list_endpoints_by_platform_application(): conn = boto3.client('sns', region_name='us-east-1') @@ -256,7 +285,7 @@ def test_publish_to_platform_endpoint(): Token="some_unique_id", CustomUserData="some user data", Attributes={ - "Enabled": 'false', + "Enabled": 'true', }, ) @@ -264,3 +293,31 @@ def test_publish_to_platform_endpoint(): conn.publish(Message="some message", MessageStructure="json", TargetArn=endpoint_arn) + + +@mock_sns +def test_publish_to_disabled_platform_endpoint(): + conn = boto3.client('sns', region_name='us-east-1') + platform_application = conn.create_platform_application( + Name="my-application", + Platform="APNS", + Attributes={}, + ) + application_arn = platform_application['PlatformApplicationArn'] + + endpoint = conn.create_platform_endpoint( + PlatformApplicationArn=application_arn, + Token="some_unique_id", + CustomUserData="some user data", + Attributes={ + "Enabled": 'false', + }, + ) + + endpoint_arn = endpoint['EndpointArn'] + + conn.publish.when.called_with( + Message="some message", + MessageStructure="json", + TargetArn=endpoint_arn, + ).should.throw(ClientError)