SNS: do not duplicate subscriptions
This commit is contained in:
		
							parent
							
								
									b855fee2e4
								
							
						
					
					
						commit
						633decc6c0
					
				| @ -247,11 +247,21 @@ class SNSBackend(BaseBackend): | |||||||
|         setattr(topic, attribute_name, attribute_value) |         setattr(topic, attribute_name, attribute_value) | ||||||
| 
 | 
 | ||||||
|     def subscribe(self, topic_arn, endpoint, protocol): |     def subscribe(self, topic_arn, endpoint, protocol): | ||||||
|  |         # AWS doesn't create duplicates | ||||||
|  |         old_subscription = self._find_subscription(topic_arn, endpoint, protocol) | ||||||
|  |         if old_subscription: | ||||||
|  |             return old_subscription | ||||||
|         topic = self.get_topic(topic_arn) |         topic = self.get_topic(topic_arn) | ||||||
|         subscription = Subscription(topic, endpoint, protocol) |         subscription = Subscription(topic, endpoint, protocol) | ||||||
|         self.subscriptions[subscription.arn] = subscription |         self.subscriptions[subscription.arn] = subscription | ||||||
|         return subscription |         return subscription | ||||||
| 
 | 
 | ||||||
|  |     def _find_subscription(self, topic_arn, endpoint, protocol): | ||||||
|  |         for subscription in self.subscriptions.values(): | ||||||
|  |             if subscription.topic.arn == topic_arn and subscription.endpoint == endpoint and subscription.protocol == protocol: | ||||||
|  |                 return subscription | ||||||
|  |         return None | ||||||
|  | 
 | ||||||
|     def unsubscribe(self, subscription_arn): |     def unsubscribe(self, subscription_arn): | ||||||
|         self.subscriptions.pop(subscription_arn) |         self.subscriptions.pop(subscription_arn) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -25,6 +25,23 @@ def test_subscribe_sms(): | |||||||
|     ) |     ) | ||||||
|     resp.should.contain('SubscriptionArn') |     resp.should.contain('SubscriptionArn') | ||||||
| 
 | 
 | ||||||
|  | @mock_sns | ||||||
|  | def test_double_subscription(): | ||||||
|  |     client = boto3.client('sns', region_name='us-east-1') | ||||||
|  |     client.create_topic(Name="some-topic") | ||||||
|  |     resp = client.create_topic(Name="some-topic") | ||||||
|  |     arn = resp['TopicArn'] | ||||||
|  | 
 | ||||||
|  |     do_subscribe_sqs = lambda sqs_arn: client.subscribe( | ||||||
|  |         TopicArn=arn, | ||||||
|  |         Protocol='sqs', | ||||||
|  |         Endpoint=sqs_arn | ||||||
|  |     ) | ||||||
|  |     resp1 = do_subscribe_sqs('arn:aws:sqs:elasticmq:000000000000:foo') | ||||||
|  |     resp2 = do_subscribe_sqs('arn:aws:sqs:elasticmq:000000000000:foo') | ||||||
|  | 
 | ||||||
|  |     resp1['SubscriptionArn'].should.equal(resp2['SubscriptionArn']) | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| @mock_sns | @mock_sns | ||||||
| def test_subscribe_bad_sms(): | def test_subscribe_bad_sms(): | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user