diff --git a/moto/sns/models.py b/moto/sns/models.py index 009398407..bc80f9e41 100644 --- a/moto/sns/models.py +++ b/moto/sns/models.py @@ -172,12 +172,16 @@ class SNSBackend(BaseBackend): self.applications = {} self.platform_endpoints = {} self.region_name = region_name + self.sms_attributes = {} def reset(self): region_name = self.region_name self.__dict__ = {} self.__init__(region_name) + def update_sms_attributes(self, attrs): + self.sms_attributes.update(attrs) + def create_topic(self, name): topic = Topic(name, self) self.topics[topic.arn] = topic diff --git a/moto/sns/responses.py b/moto/sns/responses.py index 9c079b006..9ffe298af 100644 --- a/moto/sns/responses.py +++ b/moto/sns/responses.py @@ -1,5 +1,7 @@ from __future__ import unicode_literals import json +import re +from collections import defaultdict from moto.core.responses import BaseResponse from moto.core.utils import camelcase_to_underscores @@ -459,6 +461,47 @@ class SNSResponse(BaseResponse): template = self.response_template(SET_SUBSCRIPTION_ATTRIBUTES_TEMPLATE) return template.render() + def set_sms_attributes(self): + attr_regex = re.compile(r'^attributes\.entry\.(?P\d+)\.(?Pkey|value)$') + + # attributes.entry.1.key + # attributes.entry.1.value + # to + # 1: {key:X, value:Y} + temp_dict = defaultdict(dict) + for key, value in self.querystring.items(): + match = attr_regex.match(key) + if match is not None: + temp_dict[match.group('index')][match.group('type')] = value[0] + + # 1: {key:X, value:Y} + # to + # X: Y + # All of this, just to take into account when people provide invalid stuff. + result = {} + for item in temp_dict.values(): + if 'key' in item and 'value' in item: + result[item['key']] = item['value'] + + self.backend.update_sms_attributes(result) + + template = self.response_template(SET_SMS_ATTRIBUTES_TEMPLATE) + return template.render() + + def get_sms_attributes(self): + filter_list = set() + for key, value in self.querystring.items(): + if key.startswith('attributes.member.1'): + filter_list.add(value[0]) + + if len(filter_list) > 0: + result = {k: v for k, v in self.backend.sms_attributes.items() if k in filter_list} + else: + result = self.backend.sms_attributes + + template = self.response_template(GET_SMS_ATTRIBUTES_TEMPLATE) + return template.render(attributes=result) + CREATE_TOPIC_TEMPLATE = """ @@ -758,3 +801,26 @@ SET_SUBSCRIPTION_ATTRIBUTES_TEMPLATE = """a8763b99-33a7-11df-a9b7-05d48da6f042 """ + +SET_SMS_ATTRIBUTES_TEMPLATE = """ + + + 26332069-c04a-5428-b829-72524b56a364 + +""" + +GET_SMS_ATTRIBUTES_TEMPLATE = """ + + + {% for name, value in attributes.items() %} + + {{ name }} + {{ value }} + + {% endfor %} + + + + 287f9554-8db3-5e66-8abc-c76f0186db7e + +""" diff --git a/tests/test_sns/test_sms_boto3.py b/tests/test_sns/test_sms_boto3.py new file mode 100644 index 000000000..220a7530a --- /dev/null +++ b/tests/test_sns/test_sms_boto3.py @@ -0,0 +1,32 @@ +from __future__ import unicode_literals +import boto3 +import sure # noqa + +from moto import mock_sns + + +@mock_sns +def test_set_sms_attributes(): + conn = boto3.client('sns', region_name='us-east-1') + + conn.set_sms_attributes(attributes={'DefaultSMSType': 'Transactional', 'test': 'test'}) + + response = conn.get_sms_attributes() + response.should.contain('attributes') + response['attributes'].should.contain('DefaultSMSType') + response['attributes'].should.contain('test') + response['attributes']['DefaultSMSType'].should.equal('Transactional') + response['attributes']['test'].should.equal('test') + + +@mock_sns +def test_get_sms_attributes_filtered(): + conn = boto3.client('sns', region_name='us-east-1') + + conn.set_sms_attributes(attributes={'DefaultSMSType': 'Transactional', 'test': 'test'}) + + response = conn.get_sms_attributes(attributes=['DefaultSMSType']) + response.should.contain('attributes') + response['attributes'].should.contain('DefaultSMSType') + response['attributes'].should_not.contain('test') + response['attributes']['DefaultSMSType'].should.equal('Transactional')