Cleanup SNS exceptions. Closes #751.

This commit is contained in:
Steve Pulec 2017-03-16 22:28:30 -04:00
parent e7a3f3408e
commit c207963a86
4 changed files with 87 additions and 3 deletions

View File

@ -8,3 +8,19 @@ class SNSNotFoundError(RESTError):
def __init__(self, message):
super(SNSNotFoundError, self).__init__(
"NotFound", message)
class DuplicateSnsEndpointError(RESTError):
code = 400
def __init__(self, message):
super(DuplicateSnsEndpointError, self).__init__(
"DuplicateEndpoint", message)
class SnsEndpointDisabled(RESTError):
code = 400
def __init__(self, message):
super(SnsEndpointDisabled, self).__init__(
"EndpointDisabled", message)

View File

@ -12,7 +12,9 @@ from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_with_milliseconds
from moto.sqs import sqs_backends
from .exceptions import SNSNotFoundError
from .exceptions import (
SNSNotFoundError, DuplicateSnsEndpointError, SnsEndpointDisabled
)
from .utils import make_arn_for_topic, make_arn_for_subscription
DEFAULT_ACCOUNT_ID = 123456789012
@ -136,6 +138,10 @@ class PlatformEndpoint(BaseModel):
if 'Enabled' not in self.attributes:
self.attributes['Enabled'] = True
@property
def enabled(self):
return json.loads(self.attributes.get('Enabled', 'true').lower())
@property
def arn(self):
return "arn:aws:sns:{region}:123456789012:endpoint/{platform}/{name}/{id}".format(
@ -146,6 +152,9 @@ class PlatformEndpoint(BaseModel):
)
def publish(self, message):
if not self.enabled:
raise SnsEndpointDisabled("Endpoint %s disabled" % self.id)
# This is where we would actually send a message
message_id = six.text_type(uuid.uuid4())
self.messages[message_id] = message
@ -251,6 +260,8 @@ class SNSBackend(BaseBackend):
self.applications.pop(platform_arn)
def create_platform_endpoint(self, region, application, custom_user_data, token, attributes):
if any(token == endpoint.token for endpoint in self.platform_endpoints.values()):
raise DuplicateSnsEndpointError("Duplicate endpoint token: %s" % token)
platform_endpoint = PlatformEndpoint(
region, application, custom_user_data, token, attributes)
self.platform_endpoints[platform_endpoint.arn] = platform_endpoint

View File

@ -297,7 +297,7 @@ def test_publish_to_platform_endpoint():
token="some_unique_id",
custom_user_data="some user data",
attributes={
"Enabled": False,
"Enabled": True,
},
)

View File

@ -142,6 +142,35 @@ def test_create_platform_endpoint():
"arn:aws:sns:us-east-1:123456789012:endpoint/APNS/my-application/")
@mock_sns
def test_create_duplicate_platform_endpoint():
conn = boto3.client('sns', region_name='us-east-1')
platform_application = conn.create_platform_application(
Name="my-application",
Platform="APNS",
Attributes={},
)
application_arn = platform_application['PlatformApplicationArn']
endpoint = conn.create_platform_endpoint(
PlatformApplicationArn=application_arn,
Token="some_unique_id",
CustomUserData="some user data",
Attributes={
"Enabled": 'false',
},
)
endpoint = conn.create_platform_endpoint.when.called_with(
PlatformApplicationArn=application_arn,
Token="some_unique_id",
CustomUserData="some user data",
Attributes={
"Enabled": 'false',
},
).should.throw(ClientError)
@mock_sns
def test_get_list_endpoints_by_platform_application():
conn = boto3.client('sns', region_name='us-east-1')
@ -256,7 +285,7 @@ def test_publish_to_platform_endpoint():
Token="some_unique_id",
CustomUserData="some user data",
Attributes={
"Enabled": 'false',
"Enabled": 'true',
},
)
@ -264,3 +293,31 @@ def test_publish_to_platform_endpoint():
conn.publish(Message="some message",
MessageStructure="json", TargetArn=endpoint_arn)
@mock_sns
def test_publish_to_disabled_platform_endpoint():
conn = boto3.client('sns', region_name='us-east-1')
platform_application = conn.create_platform_application(
Name="my-application",
Platform="APNS",
Attributes={},
)
application_arn = platform_application['PlatformApplicationArn']
endpoint = conn.create_platform_endpoint(
PlatformApplicationArn=application_arn,
Token="some_unique_id",
CustomUserData="some user data",
Attributes={
"Enabled": 'false',
},
)
endpoint_arn = endpoint['EndpointArn']
conn.publish.when.called_with(
Message="some message",
MessageStructure="json",
TargetArn=endpoint_arn,
).should.throw(ClientError)