SNS: do not duplicate subscriptions
This commit is contained in:
		
							parent
							
								
									b855fee2e4
								
							
						
					
					
						commit
						633decc6c0
					
				| @ -247,11 +247,21 @@ class SNSBackend(BaseBackend): | ||||
|         setattr(topic, attribute_name, attribute_value) | ||||
| 
 | ||||
|     def subscribe(self, topic_arn, endpoint, protocol): | ||||
|         # AWS doesn't create duplicates | ||||
|         old_subscription = self._find_subscription(topic_arn, endpoint, protocol) | ||||
|         if old_subscription: | ||||
|             return old_subscription | ||||
|         topic = self.get_topic(topic_arn) | ||||
|         subscription = Subscription(topic, endpoint, protocol) | ||||
|         self.subscriptions[subscription.arn] = subscription | ||||
|         return subscription | ||||
| 
 | ||||
|     def _find_subscription(self, topic_arn, endpoint, protocol): | ||||
|         for subscription in self.subscriptions.values(): | ||||
|             if subscription.topic.arn == topic_arn and subscription.endpoint == endpoint and subscription.protocol == protocol: | ||||
|                 return subscription | ||||
|         return None | ||||
| 
 | ||||
|     def unsubscribe(self, subscription_arn): | ||||
|         self.subscriptions.pop(subscription_arn) | ||||
| 
 | ||||
|  | ||||
| @ -25,6 +25,23 @@ def test_subscribe_sms(): | ||||
|     ) | ||||
|     resp.should.contain('SubscriptionArn') | ||||
| 
 | ||||
| @mock_sns | ||||
| def test_double_subscription(): | ||||
|     client = boto3.client('sns', region_name='us-east-1') | ||||
|     client.create_topic(Name="some-topic") | ||||
|     resp = client.create_topic(Name="some-topic") | ||||
|     arn = resp['TopicArn'] | ||||
| 
 | ||||
|     do_subscribe_sqs = lambda sqs_arn: client.subscribe( | ||||
|         TopicArn=arn, | ||||
|         Protocol='sqs', | ||||
|         Endpoint=sqs_arn | ||||
|     ) | ||||
|     resp1 = do_subscribe_sqs('arn:aws:sqs:elasticmq:000000000000:foo') | ||||
|     resp2 = do_subscribe_sqs('arn:aws:sqs:elasticmq:000000000000:foo') | ||||
| 
 | ||||
|     resp1['SubscriptionArn'].should.equal(resp2['SubscriptionArn']) | ||||
| 
 | ||||
| 
 | ||||
| @mock_sns | ||||
| def test_subscribe_bad_sms(): | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user