diff --git a/moto/sns/models.py b/moto/sns/models.py index 5289c8bcd..6d0833476 100644 --- a/moto/sns/models.py +++ b/moto/sns/models.py @@ -193,10 +193,17 @@ class SNSBackend(BaseBackend): next_token = None return values, next_token + def _get_topic_subscriptions(self, topic): + return [sub for sub in self.subscriptions.values() if sub.topic == topic] + def list_topics(self, next_token=None): return self._get_values_nexttoken(self.topics, next_token) def delete_topic(self, arn): + topic = self.get_topic(arn) + subscriptions = self._get_topic_subscriptions(topic) + for sub in subscriptions: + self.unsubscribe(sub.arn) self.topics.pop(arn) def get_topic(self, arn): @@ -222,7 +229,7 @@ class SNSBackend(BaseBackend): if topic_arn: topic = self.get_topic(topic_arn) filtered = OrderedDict( - [(k, sub) for k, sub in self.subscriptions.items() if sub.topic == topic]) + [(sub.arn, sub) for sub in self._get_topic_subscriptions(topic)]) return self._get_values_nexttoken(filtered, next_token) else: return self._get_values_nexttoken(self.subscriptions, next_token) diff --git a/tests/test_sns/test_subscriptions.py b/tests/test_sns/test_subscriptions.py index c521bb428..292fd83c0 100644 --- a/tests/test_sns/test_subscriptions.py +++ b/tests/test_sns/test_subscriptions.py @@ -34,6 +34,37 @@ def test_creating_subscription(): "ListSubscriptionsResult"]["Subscriptions"] subscriptions.should.have.length_of(0) +@mock_sns_deprecated +def test_deleting_subscriptions_by_deleting_topic(): + conn = boto.connect_sns() + conn.create_topic("some-topic") + topics_json = conn.get_all_topics() + topic_arn = topics_json["ListTopicsResponse"][ + "ListTopicsResult"]["Topics"][0]['TopicArn'] + + conn.subscribe(topic_arn, "http", "http://example.com/") + + subscriptions = conn.get_all_subscriptions()["ListSubscriptionsResponse"][ + "ListSubscriptionsResult"]["Subscriptions"] + subscriptions.should.have.length_of(1) + subscription = subscriptions[0] + subscription["TopicArn"].should.equal(topic_arn) + subscription["Protocol"].should.equal("http") + subscription["SubscriptionArn"].should.contain(topic_arn) + subscription["Endpoint"].should.equal("http://example.com/") + + # Now delete the topic + conn.delete_topic(topic_arn) + + # And there should now be 0 topics + topics_json = conn.get_all_topics() + topics = topics_json["ListTopicsResponse"]["ListTopicsResult"]["Topics"] + topics.should.have.length_of(0) + + # And there should be zero subscriptions left + subscriptions = conn.get_all_subscriptions()["ListSubscriptionsResponse"][ + "ListSubscriptionsResult"]["Subscriptions"] + subscriptions.should.have.length_of(0) @mock_sns_deprecated def test_getting_subscriptions_by_topic(): diff --git a/tests/test_sns/test_subscriptions_boto3.py b/tests/test_sns/test_subscriptions_boto3.py index 906c483f7..ac325ed20 100644 --- a/tests/test_sns/test_subscriptions_boto3.py +++ b/tests/test_sns/test_subscriptions_boto3.py @@ -33,6 +33,36 @@ def test_creating_subscription(): subscriptions = conn.list_subscriptions()["Subscriptions"] subscriptions.should.have.length_of(0) +@mock_sns +def test_deleting_subscriptions_by_deleting_topic(): + conn = boto3.client('sns', region_name='us-east-1') + conn.create_topic(Name="some-topic") + response = conn.list_topics() + topic_arn = response["Topics"][0]['TopicArn'] + + conn.subscribe(TopicArn=topic_arn, + Protocol="http", + Endpoint="http://example.com/") + + subscriptions = conn.list_subscriptions()["Subscriptions"] + subscriptions.should.have.length_of(1) + subscription = subscriptions[0] + subscription["TopicArn"].should.equal(topic_arn) + subscription["Protocol"].should.equal("http") + subscription["SubscriptionArn"].should.contain(topic_arn) + subscription["Endpoint"].should.equal("http://example.com/") + + # Now delete the topic + conn.delete_topic(TopicArn=topic_arn) + + # And there should now be 0 topics + topics_json = conn.list_topics() + topics = topics_json["Topics"] + topics.should.have.length_of(0) + + # And there should be zero subscriptions left + subscriptions = conn.list_subscriptions()["Subscriptions"] + subscriptions.should.have.length_of(0) @mock_sns def test_getting_subscriptions_by_topic():