Cleanup SNS exceptions. Closes #751.
This commit is contained in:
		
							parent
							
								
									e7a3f3408e
								
							
						
					
					
						commit
						c207963a86
					
				| @ -8,3 +8,19 @@ class SNSNotFoundError(RESTError): | |||||||
|     def __init__(self, message): |     def __init__(self, message): | ||||||
|         super(SNSNotFoundError, self).__init__( |         super(SNSNotFoundError, self).__init__( | ||||||
|             "NotFound", message) |             "NotFound", message) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class DuplicateSnsEndpointError(RESTError): | ||||||
|  |     code = 400 | ||||||
|  | 
 | ||||||
|  |     def __init__(self, message): | ||||||
|  |         super(DuplicateSnsEndpointError, self).__init__( | ||||||
|  |             "DuplicateEndpoint", message) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class SnsEndpointDisabled(RESTError): | ||||||
|  |     code = 400 | ||||||
|  | 
 | ||||||
|  |     def __init__(self, message): | ||||||
|  |         super(SnsEndpointDisabled, self).__init__( | ||||||
|  |             "EndpointDisabled", message) | ||||||
|  | |||||||
| @ -12,7 +12,9 @@ from moto.compat import OrderedDict | |||||||
| from moto.core import BaseBackend, BaseModel | from moto.core import BaseBackend, BaseModel | ||||||
| from moto.core.utils import iso_8601_datetime_with_milliseconds | from moto.core.utils import iso_8601_datetime_with_milliseconds | ||||||
| from moto.sqs import sqs_backends | from moto.sqs import sqs_backends | ||||||
| from .exceptions import SNSNotFoundError | from .exceptions import ( | ||||||
|  |     SNSNotFoundError, DuplicateSnsEndpointError, SnsEndpointDisabled | ||||||
|  | ) | ||||||
| from .utils import make_arn_for_topic, make_arn_for_subscription | from .utils import make_arn_for_topic, make_arn_for_subscription | ||||||
| 
 | 
 | ||||||
| DEFAULT_ACCOUNT_ID = 123456789012 | DEFAULT_ACCOUNT_ID = 123456789012 | ||||||
| @ -136,6 +138,10 @@ class PlatformEndpoint(BaseModel): | |||||||
|         if 'Enabled' not in self.attributes: |         if 'Enabled' not in self.attributes: | ||||||
|             self.attributes['Enabled'] = True |             self.attributes['Enabled'] = True | ||||||
| 
 | 
 | ||||||
|  |     @property | ||||||
|  |     def enabled(self): | ||||||
|  |         return json.loads(self.attributes.get('Enabled', 'true').lower()) | ||||||
|  | 
 | ||||||
|     @property |     @property | ||||||
|     def arn(self): |     def arn(self): | ||||||
|         return "arn:aws:sns:{region}:123456789012:endpoint/{platform}/{name}/{id}".format( |         return "arn:aws:sns:{region}:123456789012:endpoint/{platform}/{name}/{id}".format( | ||||||
| @ -146,6 +152,9 @@ class PlatformEndpoint(BaseModel): | |||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|     def publish(self, message): |     def publish(self, message): | ||||||
|  |         if not self.enabled: | ||||||
|  |             raise SnsEndpointDisabled("Endpoint %s disabled" % self.id) | ||||||
|  | 
 | ||||||
|         # This is where we would actually send a message |         # This is where we would actually send a message | ||||||
|         message_id = six.text_type(uuid.uuid4()) |         message_id = six.text_type(uuid.uuid4()) | ||||||
|         self.messages[message_id] = message |         self.messages[message_id] = message | ||||||
| @ -251,6 +260,8 @@ class SNSBackend(BaseBackend): | |||||||
|         self.applications.pop(platform_arn) |         self.applications.pop(platform_arn) | ||||||
| 
 | 
 | ||||||
|     def create_platform_endpoint(self, region, application, custom_user_data, token, attributes): |     def create_platform_endpoint(self, region, application, custom_user_data, token, attributes): | ||||||
|  |         if any(token == endpoint.token for endpoint in self.platform_endpoints.values()): | ||||||
|  |             raise DuplicateSnsEndpointError("Duplicate endpoint token: %s" % token) | ||||||
|         platform_endpoint = PlatformEndpoint( |         platform_endpoint = PlatformEndpoint( | ||||||
|             region, application, custom_user_data, token, attributes) |             region, application, custom_user_data, token, attributes) | ||||||
|         self.platform_endpoints[platform_endpoint.arn] = platform_endpoint |         self.platform_endpoints[platform_endpoint.arn] = platform_endpoint | ||||||
|  | |||||||
| @ -297,7 +297,7 @@ def test_publish_to_platform_endpoint(): | |||||||
|         token="some_unique_id", |         token="some_unique_id", | ||||||
|         custom_user_data="some user data", |         custom_user_data="some user data", | ||||||
|         attributes={ |         attributes={ | ||||||
|             "Enabled": False, |             "Enabled": True, | ||||||
|         }, |         }, | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -142,6 +142,35 @@ def test_create_platform_endpoint(): | |||||||
|         "arn:aws:sns:us-east-1:123456789012:endpoint/APNS/my-application/") |         "arn:aws:sns:us-east-1:123456789012:endpoint/APNS/my-application/") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @mock_sns | ||||||
|  | def test_create_duplicate_platform_endpoint(): | ||||||
|  |     conn = boto3.client('sns', region_name='us-east-1') | ||||||
|  |     platform_application = conn.create_platform_application( | ||||||
|  |         Name="my-application", | ||||||
|  |         Platform="APNS", | ||||||
|  |         Attributes={}, | ||||||
|  |     ) | ||||||
|  |     application_arn = platform_application['PlatformApplicationArn'] | ||||||
|  | 
 | ||||||
|  |     endpoint = conn.create_platform_endpoint( | ||||||
|  |         PlatformApplicationArn=application_arn, | ||||||
|  |         Token="some_unique_id", | ||||||
|  |         CustomUserData="some user data", | ||||||
|  |         Attributes={ | ||||||
|  |             "Enabled": 'false', | ||||||
|  |         }, | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     endpoint = conn.create_platform_endpoint.when.called_with( | ||||||
|  |         PlatformApplicationArn=application_arn, | ||||||
|  |         Token="some_unique_id", | ||||||
|  |         CustomUserData="some user data", | ||||||
|  |         Attributes={ | ||||||
|  |             "Enabled": 'false', | ||||||
|  |         }, | ||||||
|  |     ).should.throw(ClientError) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| @mock_sns | @mock_sns | ||||||
| def test_get_list_endpoints_by_platform_application(): | def test_get_list_endpoints_by_platform_application(): | ||||||
|     conn = boto3.client('sns', region_name='us-east-1') |     conn = boto3.client('sns', region_name='us-east-1') | ||||||
| @ -256,7 +285,7 @@ def test_publish_to_platform_endpoint(): | |||||||
|         Token="some_unique_id", |         Token="some_unique_id", | ||||||
|         CustomUserData="some user data", |         CustomUserData="some user data", | ||||||
|         Attributes={ |         Attributes={ | ||||||
|             "Enabled": 'false', |             "Enabled": 'true', | ||||||
|         }, |         }, | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
| @ -264,3 +293,31 @@ def test_publish_to_platform_endpoint(): | |||||||
| 
 | 
 | ||||||
|     conn.publish(Message="some message", |     conn.publish(Message="some message", | ||||||
|                  MessageStructure="json", TargetArn=endpoint_arn) |                  MessageStructure="json", TargetArn=endpoint_arn) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @mock_sns | ||||||
|  | def test_publish_to_disabled_platform_endpoint(): | ||||||
|  |     conn = boto3.client('sns', region_name='us-east-1') | ||||||
|  |     platform_application = conn.create_platform_application( | ||||||
|  |         Name="my-application", | ||||||
|  |         Platform="APNS", | ||||||
|  |         Attributes={}, | ||||||
|  |     ) | ||||||
|  |     application_arn = platform_application['PlatformApplicationArn'] | ||||||
|  | 
 | ||||||
|  |     endpoint = conn.create_platform_endpoint( | ||||||
|  |         PlatformApplicationArn=application_arn, | ||||||
|  |         Token="some_unique_id", | ||||||
|  |         CustomUserData="some user data", | ||||||
|  |         Attributes={ | ||||||
|  |             "Enabled": 'false', | ||||||
|  |         }, | ||||||
|  |     ) | ||||||
|  | 
 | ||||||
|  |     endpoint_arn = endpoint['EndpointArn'] | ||||||
|  | 
 | ||||||
|  |     conn.publish.when.called_with( | ||||||
|  |         Message="some message", | ||||||
|  |         MessageStructure="json", | ||||||
|  |         TargetArn=endpoint_arn, | ||||||
|  |     ).should.throw(ClientError) | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user