Techdebt: MyPy SNS (#6258)

This commit is contained in:
Bert Blommers 2023-04-26 14:25:00 +00:00 committed by GitHub
parent 9b969f7e3f
commit 37f1456747
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 251 additions and 174 deletions

View File

@ -1,8 +1,9 @@
from typing import Any, Optional
from moto.core.exceptions import RESTError from moto.core.exceptions import RESTError
class SNSException(RESTError): class SNSException(RESTError):
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any):
kwargs["template"] = "wrapped_single_error" kwargs["template"] = "wrapped_single_error"
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -10,54 +11,54 @@ class SNSException(RESTError):
class SNSNotFoundError(SNSException): class SNSNotFoundError(SNSException):
code = 404 code = 404
def __init__(self, message, **kwargs): def __init__(self, message: str, template: Optional[str] = None):
super().__init__("NotFound", message, **kwargs) super().__init__("NotFound", message, template=template)
class TopicNotFound(SNSNotFoundError): class TopicNotFound(SNSNotFoundError):
def __init__(self): def __init__(self) -> None:
super().__init__(message="Topic does not exist") super().__init__(message="Topic does not exist")
class ResourceNotFoundError(SNSException): class ResourceNotFoundError(SNSException):
code = 404 code = 404
def __init__(self): def __init__(self) -> None:
super().__init__("ResourceNotFound", "Resource does not exist") super().__init__("ResourceNotFound", "Resource does not exist")
class DuplicateSnsEndpointError(SNSException): class DuplicateSnsEndpointError(SNSException):
code = 400 code = 400
def __init__(self, message): def __init__(self, message: str):
super().__init__("DuplicateEndpoint", message) super().__init__("DuplicateEndpoint", message)
class SnsEndpointDisabled(SNSException): class SnsEndpointDisabled(SNSException):
code = 400 code = 400
def __init__(self, message): def __init__(self, message: str):
super().__init__("EndpointDisabled", message) super().__init__("EndpointDisabled", message)
class SNSInvalidParameter(SNSException): class SNSInvalidParameter(SNSException):
code = 400 code = 400
def __init__(self, message): def __init__(self, message: str):
super().__init__("InvalidParameter", message) super().__init__("InvalidParameter", message)
class InvalidParameterValue(SNSException): class InvalidParameterValue(SNSException):
code = 400 code = 400
def __init__(self, message): def __init__(self, message: str):
super().__init__("InvalidParameterValue", message) super().__init__("InvalidParameterValue", message)
class TagLimitExceededError(SNSException): class TagLimitExceededError(SNSException):
code = 400 code = 400
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
"TagLimitExceeded", "TagLimitExceeded",
"Could not complete request: tag quota of per resource exceeded", "Could not complete request: tag quota of per resource exceeded",
@ -67,14 +68,14 @@ class TagLimitExceededError(SNSException):
class InternalError(SNSException): class InternalError(SNSException):
code = 500 code = 500
def __init__(self, message): def __init__(self, message: str):
super().__init__("InternalFailure", message) super().__init__("InternalFailure", message)
class TooManyEntriesInBatchRequest(SNSException): class TooManyEntriesInBatchRequest(SNSException):
code = 400 code = 400
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
"TooManyEntriesInBatchRequest", "TooManyEntriesInBatchRequest",
"The batch request contains more entries than permissible.", "The batch request contains more entries than permissible.",
@ -84,7 +85,7 @@ class TooManyEntriesInBatchRequest(SNSException):
class BatchEntryIdsNotDistinct(SNSException): class BatchEntryIdsNotDistinct(SNSException):
code = 400 code = 400
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
"BatchEntryIdsNotDistinct", "BatchEntryIdsNotDistinct",
"Two or more batch entries in the request have the same Id.", "Two or more batch entries in the request have the same Id.",

View File

@ -1,10 +1,11 @@
import datetime import datetime
import json import json
import requests import requests
import re import re
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Dict, List, Iterable, Optional, Tuple
from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel
from moto.core.utils import ( from moto.core.utils import (
iso_8601_datetime_with_milliseconds, iso_8601_datetime_with_milliseconds,
@ -36,7 +37,7 @@ MAXIMUM_SMS_MESSAGE_BYTES = 1600 # Amazon limit for a single publish SMS action
class Topic(CloudFormationModel): class Topic(CloudFormationModel):
def __init__(self, name, sns_backend): def __init__(self, name: str, sns_backend: "SNSBackend"):
self.name = name self.name = name
self.sns_backend = sns_backend self.sns_backend = sns_backend
self.account_id = sns_backend.account_id self.account_id = sns_backend.account_id
@ -49,23 +50,25 @@ class Topic(CloudFormationModel):
self.subscriptions_pending = 0 self.subscriptions_pending = 0
self.subscriptions_confimed = 0 self.subscriptions_confimed = 0
self.subscriptions_deleted = 0 self.subscriptions_deleted = 0
self.sent_notifications = [] self.sent_notifications: List[
Tuple[str, str, Optional[str], Optional[Dict[str, Any]], Optional[str]]
] = []
self._policy_json = self._create_default_topic_policy( self._policy_json = self._create_default_topic_policy(
sns_backend.region_name, self.account_id, name sns_backend.region_name, self.account_id, name
) )
self._tags = {} self._tags: Dict[str, str] = {}
self.fifo_topic = "false" self.fifo_topic = "false"
self.content_based_deduplication = "false" self.content_based_deduplication = "false"
def publish( def publish(
self, self,
message, message: str,
subject=None, subject: Optional[str] = None,
message_attributes=None, message_attributes: Optional[Dict[str, Any]] = None,
group_id=None, group_id: Optional[str] = None,
deduplication_id=None, deduplication_id: Optional[str] = None,
): ) -> str:
message_id = str(mock_random.uuid4()) message_id = str(mock_random.uuid4())
subscriptions, _ = self.sns_backend.list_subscriptions(self.arn) subscriptions, _ = self.sns_backend.list_subscriptions(self.arn)
for subscription in subscriptions: for subscription in subscriptions:
@ -83,10 +86,10 @@ class Topic(CloudFormationModel):
return message_id return message_id
@classmethod @classmethod
def has_cfn_attr(cls, attr): def has_cfn_attr(cls, attr: str) -> bool:
return attr in ["TopicName"] return attr in ["TopicName"]
def get_cfn_attribute(self, attribute_name): def get_cfn_attribute(self, attribute_name: str) -> str:
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
if attribute_name == "TopicName": if attribute_name == "TopicName":
@ -94,30 +97,35 @@ class Topic(CloudFormationModel):
raise UnformattedGetAttTemplateException() raise UnformattedGetAttTemplateException()
@property @property
def physical_resource_id(self): def physical_resource_id(self) -> str:
return self.arn return self.arn
@property @property
def policy(self): def policy(self) -> str:
return json.dumps(self._policy_json, separators=(",", ":")) return json.dumps(self._policy_json, separators=(",", ":"))
@policy.setter @policy.setter
def policy(self, policy): def policy(self, policy: Any) -> None: # type: ignore[misc]
self._policy_json = json.loads(policy) self._policy_json = json.loads(policy)
@staticmethod @staticmethod
def cloudformation_name_type(): def cloudformation_name_type() -> str:
return "TopicName" return "TopicName"
@staticmethod @staticmethod
def cloudformation_type(): def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sns-topic.html # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sns-topic.html
return "AWS::SNS::Topic" return "AWS::SNS::Topic"
@classmethod @classmethod
def create_from_cloudformation_json( def create_from_cloudformation_json( # type: ignore[misc]
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs cls,
): resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
**kwargs: Any,
) -> "Topic":
sns_backend = sns_backends[account_id][region_name] sns_backend = sns_backends[account_id][region_name]
properties = cloudformation_json["Properties"] properties = cloudformation_json["Properties"]
@ -129,14 +137,14 @@ class Topic(CloudFormationModel):
return topic return topic
@classmethod @classmethod
def update_from_cloudformation_json( def update_from_cloudformation_json( # type: ignore[misc]
cls, cls,
original_resource, original_resource: Any,
new_resource_name, new_resource_name: str,
cloudformation_json, cloudformation_json: Any,
account_id, account_id: str,
region_name, region_name: str,
): ) -> "Topic":
cls.delete_from_cloudformation_json( cls.delete_from_cloudformation_json(
original_resource.name, cloudformation_json, account_id, region_name original_resource.name, cloudformation_json, account_id, region_name
) )
@ -145,9 +153,13 @@ class Topic(CloudFormationModel):
) )
@classmethod @classmethod
def delete_from_cloudformation_json( def delete_from_cloudformation_json( # type: ignore[misc]
cls, resource_name, cloudformation_json, account_id, region_name cls,
): resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
) -> None:
sns_backend = sns_backends[account_id][region_name] sns_backend = sns_backends[account_id][region_name]
properties = cloudformation_json["Properties"] properties = cloudformation_json["Properties"]
@ -158,7 +170,9 @@ class Topic(CloudFormationModel):
sns_backend.unsubscribe(subscription.arn) sns_backend.unsubscribe(subscription.arn)
sns_backend.delete_topic(topic_arn) sns_backend.delete_topic(topic_arn)
def _create_default_topic_policy(self, region_name, account_id, name): def _create_default_topic_policy(
self, region_name: str, account_id: str, name: str
) -> Dict[str, Any]:
return { return {
"Version": "2008-10-17", "Version": "2008-10-17",
"Id": "__default_policy_ID", "Id": "__default_policy_ID",
@ -186,25 +200,25 @@ class Topic(CloudFormationModel):
class Subscription(BaseModel): class Subscription(BaseModel):
def __init__(self, account_id, topic, endpoint, protocol): def __init__(self, account_id: str, topic: Topic, endpoint: str, protocol: str):
self.account_id = account_id self.account_id = account_id
self.topic = topic self.topic = topic
self.endpoint = endpoint self.endpoint = endpoint
self.protocol = protocol self.protocol = protocol
self.arn = make_arn_for_subscription(self.topic.arn) self.arn = make_arn_for_subscription(self.topic.arn)
self.attributes = {} self.attributes: Dict[str, Any] = {}
self._filter_policy = None # filter policy as a dict, not json. self._filter_policy = None # filter policy as a dict, not json.
self.confirmed = False self.confirmed = False
def publish( def publish(
self, self,
message, message: str,
message_id, message_id: str,
subject=None, subject: Optional[str] = None,
message_attributes=None, message_attributes: Optional[Dict[str, Any]] = None,
group_id=None, group_id: Optional[str] = None,
deduplication_id=None, deduplication_id: Optional[str] = None,
): ) -> None:
if not self._matches_filter_policy(message_attributes): if not self._matches_filter_policy(message_attributes):
return return
@ -230,7 +244,7 @@ class Subscription(BaseModel):
) )
else: else:
raw_message_attributes = {} raw_message_attributes = {}
for key, value in message_attributes.items(): for key, value in message_attributes.items(): # type: ignore
attr_type = "string_value" attr_type = "string_value"
type_value = value["Value"] type_value = value["Value"]
if value["Type"].startswith("Binary"): if value["Type"].startswith("Binary"):
@ -279,14 +293,18 @@ class Subscription(BaseModel):
function_name, message, subject=subject, qualifier=qualifier function_name, message, subject=subject, qualifier=qualifier
) )
def _matches_filter_policy(self, message_attributes): def _matches_filter_policy(
self, message_attributes: Optional[Dict[str, Any]]
) -> bool:
if not self._filter_policy: if not self._filter_policy:
return True return True
if message_attributes is None: if message_attributes is None:
message_attributes = {} message_attributes = {}
def _field_match(field, rules, message_attributes): def _field_match(
field: str, rules: List[Any], message_attributes: Dict[str, Any]
) -> bool:
for rule in rules: for rule in rules:
# TODO: boolean value matching is not supported, SNS behavior unknown # TODO: boolean value matching is not supported, SNS behavior unknown
if isinstance(rule, str): if isinstance(rule, str):
@ -400,8 +418,14 @@ class Subscription(BaseModel):
for field, rules in self._filter_policy.items() for field, rules in self._filter_policy.items()
) )
def get_post_data(self, message, message_id, subject, message_attributes=None): def get_post_data(
post_data = { self,
message: str,
message_id: str,
subject: Optional[str],
message_attributes: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
post_data: Dict[str, Any] = {
"Type": "Notification", "Type": "Notification",
"MessageId": message_id, "MessageId": message_id,
"TopicArn": self.topic.arn, "TopicArn": self.topic.arn,
@ -422,7 +446,14 @@ class Subscription(BaseModel):
class PlatformApplication(BaseModel): class PlatformApplication(BaseModel):
def __init__(self, account_id, region, name, platform, attributes): def __init__(
self,
account_id: str,
region: str,
name: str,
platform: str,
attributes: Dict[str, str],
):
self.region = region self.region = region
self.name = name self.name = name
self.platform = platform self.platform = platform
@ -432,7 +463,13 @@ class PlatformApplication(BaseModel):
class PlatformEndpoint(BaseModel): class PlatformEndpoint(BaseModel):
def __init__( def __init__(
self, account_id, region, application, custom_user_data, token, attributes self,
account_id: str,
region: str,
application: PlatformApplication,
custom_user_data: str,
token: str,
attributes: Dict[str, str],
): ):
self.region = region self.region = region
self.application = application self.application = application
@ -441,10 +478,10 @@ class PlatformEndpoint(BaseModel):
self.attributes = attributes self.attributes = attributes
self.id = mock_random.uuid4() self.id = mock_random.uuid4()
self.arn = f"arn:aws:sns:{region}:{account_id}:endpoint/{self.application.platform}/{self.application.name}/{self.id}" self.arn = f"arn:aws:sns:{region}:{account_id}:endpoint/{self.application.platform}/{self.application.name}/{self.id}"
self.messages = OrderedDict() self.messages: Dict[str, str] = OrderedDict()
self.__fixup_attributes() self.__fixup_attributes()
def __fixup_attributes(self): def __fixup_attributes(self) -> None:
# When AWS returns the attributes dict, it always contains these two elements, so we need to # When AWS returns the attributes dict, it always contains these two elements, so we need to
# automatically ensure they exist as well. # automatically ensure they exist as well.
if "Token" not in self.attributes: if "Token" not in self.attributes:
@ -456,10 +493,10 @@ class PlatformEndpoint(BaseModel):
self.attributes["Enabled"] = "true" self.attributes["Enabled"] = "true"
@property @property
def enabled(self): def enabled(self) -> bool:
return json.loads(self.attributes.get("Enabled", "true").lower()) return json.loads(self.attributes.get("Enabled", "true").lower())
def publish(self, message): def publish(self, message: str) -> str:
if not self.enabled: if not self.enabled:
raise SnsEndpointDisabled(f"Endpoint {self.id} disabled") raise SnsEndpointDisabled(f"Endpoint {self.id} disabled")
@ -485,15 +522,15 @@ class SNSBackend(BaseBackend):
Note that, as this is an internal API, the exact format may differ per versions. Note that, as this is an internal API, the exact format may differ per versions.
""" """
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.topics = OrderedDict() self.topics: Dict[str, Topic] = OrderedDict()
self.subscriptions: OrderedDict[str, Subscription] = OrderedDict() self.subscriptions: OrderedDict[str, Subscription] = OrderedDict()
self.applications = {} self.applications: Dict[str, PlatformApplication] = {}
self.platform_endpoints = {} self.platform_endpoints: Dict[str, PlatformEndpoint] = {}
self.region_name = region_name self.region_name = region_name
self.sms_attributes = {} self.sms_attributes: Dict[str, str] = {}
self.sms_messages = OrderedDict() self.sms_messages: Dict[str, Tuple[str, str]] = OrderedDict()
self.opt_out_numbers = [ self.opt_out_numbers = [
"+447420500600", "+447420500600",
"+447420505401", "+447420505401",
@ -506,23 +543,27 @@ class SNSBackend(BaseBackend):
] ]
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(
service_region: str, zones: List[str]
) -> List[Dict[str, str]]:
"""List of dicts representing default VPC endpoints for this service.""" """List of dicts representing default VPC endpoints for this service."""
return BaseBackend.default_vpc_endpoint_service_factory( return BaseBackend.default_vpc_endpoint_service_factory(
service_region, zones, "sns" service_region, zones, "sns"
) )
def update_sms_attributes(self, attrs): def update_sms_attributes(self, attrs: Dict[str, str]) -> None:
self.sms_attributes.update(attrs) self.sms_attributes.update(attrs)
def create_topic(self, name, attributes=None, tags=None): def create_topic(
self,
name: str,
attributes: Optional[Dict[str, str]] = None,
tags: Optional[Dict[str, str]] = None,
) -> Topic:
if attributes is None: if attributes is None:
attributes = {} attributes = {}
if ( if attributes.get("FifoTopic") and attributes["FifoTopic"].lower() == "true":
attributes.get("FifoTopic")
and attributes.get("FifoTopic").lower() == "true"
):
fails_constraints = not re.match(r"^[a-zA-Z0-9_-]{1,256}\.fifo$", name) fails_constraints = not re.match(r"^[a-zA-Z0-9_-]{1,256}\.fifo$", name)
msg = "Fifo Topic names must end with .fifo and must be made up of only uppercase and lowercase ASCII letters, numbers, underscores, and hyphens, and must be between 1 and 256 characters long." msg = "Fifo Topic names must end with .fifo and must be made up of only uppercase and lowercase ASCII letters, numbers, underscores, and hyphens, and must be between 1 and 256 characters long."
@ -549,29 +590,33 @@ class SNSBackend(BaseBackend):
self.topics[candidate_topic.arn] = candidate_topic self.topics[candidate_topic.arn] = candidate_topic
return candidate_topic return candidate_topic
def _get_values_nexttoken(self, values_map, next_token=None): def _get_values_nexttoken(
if next_token is None or not next_token: self, values_map: Dict[str, Any], next_token: Optional[str] = None
next_token = 0 ) -> Tuple[List[Any], Optional[int]]:
next_token = int(next_token) i_next_token = int(next_token or "0")
values = list(values_map.values())[next_token : next_token + DEFAULT_PAGE_SIZE] values = list(values_map.values())[
i_next_token : i_next_token + DEFAULT_PAGE_SIZE
]
if len(values) == DEFAULT_PAGE_SIZE: if len(values) == DEFAULT_PAGE_SIZE:
next_token = next_token + DEFAULT_PAGE_SIZE i_next_token = i_next_token + DEFAULT_PAGE_SIZE
else: else:
next_token = None i_next_token = None # type: ignore
return values, next_token return values, i_next_token
def _get_topic_subscriptions(self, topic): def _get_topic_subscriptions(self, topic: Topic) -> List[Subscription]:
return [sub for sub in self.subscriptions.values() if sub.topic == topic] return [sub for sub in self.subscriptions.values() if sub.topic == topic]
def list_topics(self, next_token=None): def list_topics(
self, next_token: Optional[str] = None
) -> Tuple[List[Topic], Optional[int]]:
return self._get_values_nexttoken(self.topics, next_token) return self._get_values_nexttoken(self.topics, next_token)
def delete_topic_subscriptions(self, topic): def delete_topic_subscriptions(self, topic: Topic) -> None:
for key, value in dict(self.subscriptions).items(): for key, value in dict(self.subscriptions).items():
if value.topic == topic: if value.topic == topic:
self.subscriptions.pop(key) self.subscriptions.pop(key)
def delete_topic(self, arn): def delete_topic(self, arn: str) -> None:
try: try:
topic = self.get_topic(arn) topic = self.get_topic(arn)
self.delete_topic_subscriptions(topic) self.delete_topic_subscriptions(topic)
@ -579,17 +624,19 @@ class SNSBackend(BaseBackend):
except KeyError: except KeyError:
raise SNSNotFoundError(f"Topic with arn {arn} not found") raise SNSNotFoundError(f"Topic with arn {arn} not found")
def get_topic(self, arn): def get_topic(self, arn: str) -> Topic:
try: try:
return self.topics[arn] return self.topics[arn]
except KeyError: except KeyError:
raise SNSNotFoundError(f"Topic with arn {arn} not found") raise SNSNotFoundError(f"Topic with arn {arn} not found")
def set_topic_attribute(self, topic_arn, attribute_name, attribute_value): def set_topic_attribute(
self, topic_arn: str, attribute_name: str, attribute_value: str
) -> None:
topic = self.get_topic(topic_arn) topic = self.get_topic(topic_arn)
setattr(topic, attribute_name, attribute_value) setattr(topic, attribute_name, attribute_value)
def subscribe(self, topic_arn, endpoint, protocol): def subscribe(self, topic_arn: str, endpoint: str, protocol: str) -> Subscription:
if protocol == "sms": if protocol == "sms":
if re.search(r"[./-]{2,}", endpoint) or re.search( if re.search(r"[./-]{2,}", endpoint) or re.search(
r"(^[./-]|[./-]$)", endpoint r"(^[./-]|[./-]$)", endpoint
@ -625,7 +672,9 @@ class SNSBackend(BaseBackend):
self.subscriptions[subscription.arn] = subscription self.subscriptions[subscription.arn] = subscription
return subscription return subscription
def _find_subscription(self, topic_arn, endpoint, protocol): def _find_subscription(
self, topic_arn: str, endpoint: str, protocol: str
) -> Optional[Subscription]:
for subscription in self.subscriptions.values(): for subscription in self.subscriptions.values():
if ( if (
subscription.topic.arn == topic_arn subscription.topic.arn == topic_arn
@ -635,10 +684,12 @@ class SNSBackend(BaseBackend):
return subscription return subscription
return None return None
def unsubscribe(self, subscription_arn): def unsubscribe(self, subscription_arn: str) -> None:
self.subscriptions.pop(subscription_arn, None) self.subscriptions.pop(subscription_arn, None)
def list_subscriptions(self, topic_arn=None, next_token=None): def list_subscriptions(
self, topic_arn: Optional[str] = None, next_token: Optional[str] = None
) -> Tuple[List[Subscription], Optional[int]]:
if topic_arn: if topic_arn:
topic = self.get_topic(topic_arn) topic = self.get_topic(topic_arn)
filtered = OrderedDict( filtered = OrderedDict(
@ -650,14 +701,14 @@ class SNSBackend(BaseBackend):
def publish( def publish(
self, self,
message, message: str,
arn=None, arn: Optional[str],
phone_number=None, phone_number: Optional[str] = None,
subject=None, subject: Optional[str] = None,
message_attributes=None, message_attributes: Optional[Dict[str, Any]] = None,
group_id=None, group_id: Optional[str] = None,
deduplication_id=None, deduplication_id: Optional[str] = None,
): ) -> str:
if subject is not None and len(subject) > 100: if subject is not None and len(subject) > 100:
# Note that the AWS docs around length are wrong: https://github.com/getmoto/moto/issues/1503 # Note that the AWS docs around length are wrong: https://github.com/getmoto/moto/issues/1503
raise ValueError("Subject must be less than 100 characters") raise ValueError("Subject must be less than 100 characters")
@ -677,7 +728,7 @@ class SNSBackend(BaseBackend):
) )
try: try:
topic = self.get_topic(arn) topic = self.get_topic(arn) # type: ignore
fifo_topic = topic.fifo_topic == "true" fifo_topic = topic.fifo_topic == "true"
if fifo_topic: if fifo_topic:
@ -705,40 +756,48 @@ class SNSBackend(BaseBackend):
deduplication_id=deduplication_id, deduplication_id=deduplication_id,
) )
except SNSNotFoundError: except SNSNotFoundError:
endpoint = self.get_endpoint(arn) endpoint = self.get_endpoint(arn) # type: ignore
message_id = endpoint.publish(message) message_id = endpoint.publish(message)
return message_id return message_id
def create_platform_application(self, name, platform, attributes): def create_platform_application(
self, name: str, platform: str, attributes: Dict[str, str]
) -> PlatformApplication:
application = PlatformApplication( application = PlatformApplication(
self.account_id, self.region_name, name, platform, attributes self.account_id, self.region_name, name, platform, attributes
) )
self.applications[application.arn] = application self.applications[application.arn] = application
return application return application
def get_application(self, arn): def get_application(self, arn: str) -> PlatformApplication:
try: try:
return self.applications[arn] return self.applications[arn]
except KeyError: except KeyError:
raise SNSNotFoundError(f"Application with arn {arn} not found") raise SNSNotFoundError(f"Application with arn {arn} not found")
def set_application_attributes(self, arn, attributes): def set_application_attributes(
self, arn: str, attributes: Dict[str, Any]
) -> PlatformApplication:
application = self.get_application(arn) application = self.get_application(arn)
application.attributes.update(attributes) application.attributes.update(attributes)
return application return application
def list_platform_applications(self): def list_platform_applications(self) -> Iterable[PlatformApplication]:
return self.applications.values() return self.applications.values()
def delete_platform_application(self, platform_arn): def delete_platform_application(self, platform_arn: str) -> None:
self.applications.pop(platform_arn) self.applications.pop(platform_arn)
endpoints = self.list_endpoints_by_platform_application(platform_arn) endpoints = self.list_endpoints_by_platform_application(platform_arn)
for endpoint in endpoints: for endpoint in endpoints:
self.platform_endpoints.pop(endpoint.arn) self.platform_endpoints.pop(endpoint.arn)
def create_platform_endpoint( def create_platform_endpoint(
self, application, custom_user_data, token, attributes self,
): application: PlatformApplication,
custom_user_data: str,
token: str,
attributes: Dict[str, str],
) -> PlatformEndpoint:
for endpoint in self.platform_endpoints.values(): for endpoint in self.platform_endpoints.values():
if token == endpoint.token: if token == endpoint.token:
if ( if (
@ -760,33 +819,37 @@ class SNSBackend(BaseBackend):
self.platform_endpoints[platform_endpoint.arn] = platform_endpoint self.platform_endpoints[platform_endpoint.arn] = platform_endpoint
return platform_endpoint return platform_endpoint
def list_endpoints_by_platform_application(self, application_arn): def list_endpoints_by_platform_application(
self, application_arn: str
) -> List[PlatformEndpoint]:
return [ return [
endpoint endpoint
for endpoint in self.platform_endpoints.values() for endpoint in self.platform_endpoints.values()
if endpoint.application.arn == application_arn if endpoint.application.arn == application_arn
] ]
def get_endpoint(self, arn): def get_endpoint(self, arn: str) -> PlatformEndpoint:
try: try:
return self.platform_endpoints[arn] return self.platform_endpoints[arn]
except KeyError: except KeyError:
raise SNSNotFoundError("Endpoint does not exist") raise SNSNotFoundError("Endpoint does not exist")
def set_endpoint_attributes(self, arn, attributes): def set_endpoint_attributes(
self, arn: str, attributes: Dict[str, Any]
) -> PlatformEndpoint:
endpoint = self.get_endpoint(arn) endpoint = self.get_endpoint(arn)
if "Enabled" in attributes: if "Enabled" in attributes:
attributes["Enabled"] = attributes["Enabled"].lower() attributes["Enabled"] = attributes["Enabled"].lower()
endpoint.attributes.update(attributes) endpoint.attributes.update(attributes)
return endpoint return endpoint
def delete_endpoint(self, arn): def delete_endpoint(self, arn: str) -> None:
try: try:
del self.platform_endpoints[arn] del self.platform_endpoints[arn]
except KeyError: except KeyError:
raise SNSNotFoundError(f"Endpoint with arn {arn} not found") raise SNSNotFoundError(f"Endpoint with arn {arn} not found")
def get_subscription_attributes(self, arn): def get_subscription_attributes(self, arn: str) -> Dict[str, Any]:
subscription = self.subscriptions.get(arn) subscription = self.subscriptions.get(arn)
if not subscription: if not subscription:
@ -796,7 +859,7 @@ class SNSBackend(BaseBackend):
return subscription.attributes return subscription.attributes
def set_subscription_attributes(self, arn, name, value): def set_subscription_attributes(self, arn: str, name: str, value: Any) -> None:
if name not in [ if name not in [
"RawMessageDelivery", "RawMessageDelivery",
"DeliveryPolicy", "DeliveryPolicy",
@ -819,7 +882,7 @@ class SNSBackend(BaseBackend):
self._validate_filter_policy(filter_policy) self._validate_filter_policy(filter_policy)
subscription._filter_policy = filter_policy subscription._filter_policy = filter_policy
def _validate_filter_policy(self, value): def _validate_filter_policy(self, value: Any) -> None:
# TODO: extend validation checks # TODO: extend validation checks
combinations = 1 combinations = 1
for rules in value.values(): for rules in value.values():
@ -931,7 +994,13 @@ class SNSBackend(BaseBackend):
"Invalid parameter: FilterPolicy: Match value must be String, number, true, false, or null" "Invalid parameter: FilterPolicy: Match value must be String, number, true, false, or null"
) )
def add_permission(self, topic_arn, label, aws_account_ids, action_names): def add_permission(
self,
topic_arn: str,
label: str,
aws_account_ids: List[str],
action_names: List[str],
) -> None:
if topic_arn not in self.topics: if topic_arn not in self.topics:
raise SNSNotFoundError("Topic does not exist") raise SNSNotFoundError("Topic does not exist")
@ -966,7 +1035,7 @@ class SNSBackend(BaseBackend):
self.topics[topic_arn]._policy_json["Statement"].append(statement) self.topics[topic_arn]._policy_json["Statement"].append(statement)
def remove_permission(self, topic_arn, label): def remove_permission(self, topic_arn: str, label: str) -> None:
if topic_arn not in self.topics: if topic_arn not in self.topics:
raise SNSNotFoundError("Topic does not exist") raise SNSNotFoundError("Topic does not exist")
@ -977,13 +1046,13 @@ class SNSBackend(BaseBackend):
self.topics[topic_arn]._policy_json["Statement"] = statements self.topics[topic_arn]._policy_json["Statement"] = statements
def list_tags_for_resource(self, resource_arn): def list_tags_for_resource(self, resource_arn: str) -> Dict[str, str]:
if resource_arn not in self.topics: if resource_arn not in self.topics:
raise ResourceNotFoundError raise ResourceNotFoundError
return self.topics[resource_arn]._tags return self.topics[resource_arn]._tags
def tag_resource(self, resource_arn, tags): def tag_resource(self, resource_arn: str, tags: Dict[str, str]) -> None:
if resource_arn not in self.topics: if resource_arn not in self.topics:
raise ResourceNotFoundError raise ResourceNotFoundError
@ -995,14 +1064,16 @@ class SNSBackend(BaseBackend):
self.topics[resource_arn]._tags = updated_tags self.topics[resource_arn]._tags = updated_tags
def untag_resource(self, resource_arn, tag_keys): def untag_resource(self, resource_arn: str, tag_keys: List[str]) -> None:
if resource_arn not in self.topics: if resource_arn not in self.topics:
raise ResourceNotFoundError raise ResourceNotFoundError
for key in tag_keys: for key in tag_keys:
self.topics[resource_arn]._tags.pop(key, None) self.topics[resource_arn]._tags.pop(key, None)
def publish_batch(self, topic_arn, publish_batch_request_entries): def publish_batch(
self, topic_arn: str, publish_batch_request_entries: List[Dict[str, Any]]
) -> Tuple[List[Dict[str, str]], List[Dict[str, Any]]]:
""" """
The MessageStructure and MessageDeduplicationId-parameters have not yet been implemented. The MessageStructure and MessageDeduplicationId-parameters have not yet been implemented.
""" """
@ -1027,8 +1098,8 @@ class SNSBackend(BaseBackend):
"Invalid parameter: The MessageGroupId parameter is required for FIFO topics" "Invalid parameter: The MessageGroupId parameter is required for FIFO topics"
) )
successful = [] successful: List[Dict[str, str]] = []
failed = [] failed: List[Dict[str, Any]] = []
for entry in publish_batch_request_entries: for entry in publish_batch_request_entries:
try: try:

View File

@ -1,10 +1,11 @@
import json import json
import re import re
from collections import defaultdict from collections import defaultdict
from typing import Any, Dict, Tuple, Union
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.core.utils import camelcase_to_underscores from moto.core.utils import camelcase_to_underscores
from .models import sns_backends from .models import sns_backends, SNSBackend
from .exceptions import InvalidParameterValue, SNSNotFoundError from .exceptions import InvalidParameterValue, SNSNotFoundError
from .utils import is_e164 from .utils import is_e164
@ -15,32 +16,34 @@ class SNSResponse(BaseResponse):
) )
OPT_OUT_PHONE_NUMBER_REGEX = re.compile(r"^\+?\d+$") OPT_OUT_PHONE_NUMBER_REGEX = re.compile(r"^\+?\d+$")
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="sns") super().__init__(service_name="sns")
@property @property
def backend(self): def backend(self) -> SNSBackend:
return sns_backends[self.current_account][self.region] return sns_backends[self.current_account][self.region]
def _error(self, code, message, sender="Sender"): def _error(self, code: str, message: str, sender: str = "Sender") -> str:
template = self.response_template(ERROR_RESPONSE) template = self.response_template(ERROR_RESPONSE)
return template.render(code=code, message=message, sender=sender) return template.render(code=code, message=message, sender=sender)
def _get_attributes(self): def _get_attributes(self) -> Dict[str, str]:
attributes = self._get_list_prefix("Attributes.entry") attributes = self._get_list_prefix("Attributes.entry")
return dict((attribute["key"], attribute["value"]) for attribute in attributes) return dict((attribute["key"], attribute["value"]) for attribute in attributes)
def _get_tags(self): def _get_tags(self) -> Dict[str, str]:
tags = self._get_list_prefix("Tags.member") tags = self._get_list_prefix("Tags.member")
return {tag["key"]: tag["value"] for tag in tags} return {tag["key"]: tag["value"] for tag in tags}
def _parse_message_attributes(self): def _parse_message_attributes(self) -> Dict[str, Any]:
message_attributes = self._get_object_map( message_attributes = self._get_object_map(
"MessageAttributes.entry", name="Name", value="Value" "MessageAttributes.entry", name="Name", value="Value"
) )
return self._transform_message_attributes(message_attributes) return self._transform_message_attributes(message_attributes)
def _transform_message_attributes(self, message_attributes): def _transform_message_attributes(
self, message_attributes: Dict[str, Any]
) -> Dict[str, Any]:
# SNS converts some key names before forwarding messages # SNS converts some key names before forwarding messages
# DataType -> Type, StringValue -> Value, BinaryValue -> Value # DataType -> Type, StringValue -> Value, BinaryValue -> Value
transformed_message_attributes = {} transformed_message_attributes = {}
@ -96,7 +99,7 @@ class SNSResponse(BaseResponse):
return transformed_message_attributes return transformed_message_attributes
def create_topic(self): def create_topic(self) -> str:
name = self._get_param("Name") name = self._get_param("Name")
attributes = self._get_attributes() attributes = self._get_attributes()
tags = self._get_tags() tags = self._get_tags()
@ -117,7 +120,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(CREATE_TOPIC_TEMPLATE) template = self.response_template(CREATE_TOPIC_TEMPLATE)
return template.render(topic=topic) return template.render(topic=topic)
def list_topics(self): def list_topics(self) -> str:
next_token = self._get_param("NextToken") next_token = self._get_param("NextToken")
topics, next_token = self.backend.list_topics(next_token=next_token) topics, next_token = self.backend.list_topics(next_token=next_token)
@ -139,7 +142,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(LIST_TOPICS_TEMPLATE) template = self.response_template(LIST_TOPICS_TEMPLATE)
return template.render(topics=topics, next_token=next_token) return template.render(topics=topics, next_token=next_token)
def delete_topic(self): def delete_topic(self) -> str:
topic_arn = self._get_param("TopicArn") topic_arn = self._get_param("TopicArn")
self.backend.delete_topic(topic_arn) self.backend.delete_topic(topic_arn)
@ -157,7 +160,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(DELETE_TOPIC_TEMPLATE) template = self.response_template(DELETE_TOPIC_TEMPLATE)
return template.render() return template.render()
def get_topic_attributes(self): def get_topic_attributes(self) -> str:
topic_arn = self._get_param("TopicArn") topic_arn = self._get_param("TopicArn")
topic = self.backend.get_topic(topic_arn) topic = self.backend.get_topic(topic_arn)
@ -193,7 +196,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(GET_TOPIC_ATTRIBUTES_TEMPLATE) template = self.response_template(GET_TOPIC_ATTRIBUTES_TEMPLATE)
return template.render(topic=topic) return template.render(topic=topic)
def set_topic_attributes(self): def set_topic_attributes(self) -> str:
topic_arn = self._get_param("TopicArn") topic_arn = self._get_param("TopicArn")
attribute_name = self._get_param("AttributeName") attribute_name = self._get_param("AttributeName")
attribute_name = camelcase_to_underscores(attribute_name) attribute_name = camelcase_to_underscores(attribute_name)
@ -214,7 +217,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(SET_TOPIC_ATTRIBUTES_TEMPLATE) template = self.response_template(SET_TOPIC_ATTRIBUTES_TEMPLATE)
return template.render() return template.render()
def subscribe(self): def subscribe(self) -> str:
topic_arn = self._get_param("TopicArn") topic_arn = self._get_param("TopicArn")
endpoint = self._get_param("Endpoint") endpoint = self._get_param("Endpoint")
protocol = self._get_param("Protocol") protocol = self._get_param("Protocol")
@ -243,7 +246,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(SUBSCRIBE_TEMPLATE) template = self.response_template(SUBSCRIBE_TEMPLATE)
return template.render(subscription=subscription) return template.render(subscription=subscription)
def unsubscribe(self): def unsubscribe(self) -> str:
subscription_arn = self._get_param("SubscriptionArn") subscription_arn = self._get_param("SubscriptionArn")
self.backend.unsubscribe(subscription_arn) self.backend.unsubscribe(subscription_arn)
@ -261,7 +264,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(UNSUBSCRIBE_TEMPLATE) template = self.response_template(UNSUBSCRIBE_TEMPLATE)
return template.render() return template.render()
def list_subscriptions(self): def list_subscriptions(self) -> str:
next_token = self._get_param("NextToken") next_token = self._get_param("NextToken")
subscriptions, next_token = self.backend.list_subscriptions( subscriptions, next_token = self.backend.list_subscriptions(
next_token=next_token next_token=next_token
@ -294,7 +297,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(LIST_SUBSCRIPTIONS_TEMPLATE) template = self.response_template(LIST_SUBSCRIPTIONS_TEMPLATE)
return template.render(subscriptions=subscriptions, next_token=next_token) return template.render(subscriptions=subscriptions, next_token=next_token)
def list_subscriptions_by_topic(self): def list_subscriptions_by_topic(self) -> str:
topic_arn = self._get_param("TopicArn") topic_arn = self._get_param("TopicArn")
next_token = self._get_param("NextToken") next_token = self._get_param("NextToken")
subscriptions, next_token = self.backend.list_subscriptions( subscriptions, next_token = self.backend.list_subscriptions(
@ -328,7 +331,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(LIST_SUBSCRIPTIONS_BY_TOPIC_TEMPLATE) template = self.response_template(LIST_SUBSCRIPTIONS_BY_TOPIC_TEMPLATE)
return template.render(subscriptions=subscriptions, next_token=next_token) return template.render(subscriptions=subscriptions, next_token=next_token)
def publish(self): def publish(self) -> Union[str, Tuple[str, Dict[str, int]]]:
target_arn = self._get_param("TargetArn") target_arn = self._get_param("TargetArn")
topic_arn = self._get_param("TopicArn") topic_arn = self._get_param("TopicArn")
phone_number = self._get_param("PhoneNumber") phone_number = self._get_param("PhoneNumber")
@ -384,7 +387,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(PUBLISH_TEMPLATE) template = self.response_template(PUBLISH_TEMPLATE)
return template.render(message_id=message_id) return template.render(message_id=message_id)
def publish_batch(self): def publish_batch(self) -> str:
topic_arn = self._get_param("TopicArn") topic_arn = self._get_param("TopicArn")
publish_batch_request_entries = self._get_multi_param( publish_batch_request_entries = self._get_multi_param(
"PublishBatchRequestEntries.member" "PublishBatchRequestEntries.member"
@ -406,7 +409,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(PUBLISH_BATCH_TEMPLATE) template = self.response_template(PUBLISH_BATCH_TEMPLATE)
return template.render(successful=successful, failed=failed) return template.render(successful=successful, failed=failed)
def create_platform_application(self): def create_platform_application(self) -> str:
name = self._get_param("Name") name = self._get_param("Name")
platform = self._get_param("Platform") platform = self._get_param("Platform")
attributes = self._get_attributes() attributes = self._get_attributes()
@ -431,7 +434,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(CREATE_PLATFORM_APPLICATION_TEMPLATE) template = self.response_template(CREATE_PLATFORM_APPLICATION_TEMPLATE)
return template.render(platform_application=platform_application) return template.render(platform_application=platform_application)
def get_platform_application_attributes(self): def get_platform_application_attributes(self) -> str:
arn = self._get_param("PlatformApplicationArn") arn = self._get_param("PlatformApplicationArn")
application = self.backend.get_application(arn) application = self.backend.get_application(arn)
@ -452,7 +455,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(GET_PLATFORM_APPLICATION_ATTRIBUTES_TEMPLATE) template = self.response_template(GET_PLATFORM_APPLICATION_ATTRIBUTES_TEMPLATE)
return template.render(application=application) return template.render(application=application)
def set_platform_application_attributes(self): def set_platform_application_attributes(self) -> str:
arn = self._get_param("PlatformApplicationArn") arn = self._get_param("PlatformApplicationArn")
attributes = self._get_attributes() attributes = self._get_attributes()
@ -472,7 +475,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(SET_PLATFORM_APPLICATION_ATTRIBUTES_TEMPLATE) template = self.response_template(SET_PLATFORM_APPLICATION_ATTRIBUTES_TEMPLATE)
return template.render() return template.render()
def list_platform_applications(self): def list_platform_applications(self) -> str:
applications = self.backend.list_platform_applications() applications = self.backend.list_platform_applications()
if self.request_json: if self.request_json:
@ -499,7 +502,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(LIST_PLATFORM_APPLICATIONS_TEMPLATE) template = self.response_template(LIST_PLATFORM_APPLICATIONS_TEMPLATE)
return template.render(applications=applications) return template.render(applications=applications)
def delete_platform_application(self): def delete_platform_application(self) -> str:
platform_arn = self._get_param("PlatformApplicationArn") platform_arn = self._get_param("PlatformApplicationArn")
self.backend.delete_platform_application(platform_arn) self.backend.delete_platform_application(platform_arn)
@ -517,7 +520,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(DELETE_PLATFORM_APPLICATION_TEMPLATE) template = self.response_template(DELETE_PLATFORM_APPLICATION_TEMPLATE)
return template.render() return template.render()
def create_platform_endpoint(self): def create_platform_endpoint(self) -> str:
application_arn = self._get_param("PlatformApplicationArn") application_arn = self._get_param("PlatformApplicationArn")
application = self.backend.get_application(application_arn) application = self.backend.get_application(application_arn)
@ -546,7 +549,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(CREATE_PLATFORM_ENDPOINT_TEMPLATE) template = self.response_template(CREATE_PLATFORM_ENDPOINT_TEMPLATE)
return template.render(platform_endpoint=platform_endpoint) return template.render(platform_endpoint=platform_endpoint)
def list_endpoints_by_platform_application(self): def list_endpoints_by_platform_application(self) -> str:
application_arn = self._get_param("PlatformApplicationArn") application_arn = self._get_param("PlatformApplicationArn")
endpoints = self.backend.list_endpoints_by_platform_application(application_arn) endpoints = self.backend.list_endpoints_by_platform_application(application_arn)
@ -576,7 +579,7 @@ class SNSResponse(BaseResponse):
) )
return template.render(endpoints=endpoints) return template.render(endpoints=endpoints)
def get_endpoint_attributes(self): def get_endpoint_attributes(self) -> Union[str, Tuple[str, Dict[str, int]]]:
arn = self._get_param("EndpointArn") arn = self._get_param("EndpointArn")
try: try:
endpoint = self.backend.get_endpoint(arn) endpoint = self.backend.get_endpoint(arn)
@ -601,7 +604,7 @@ class SNSResponse(BaseResponse):
error_response = self._error("NotFound", "Endpoint does not exist") error_response = self._error("NotFound", "Endpoint does not exist")
return error_response, dict(status=404) return error_response, dict(status=404)
def set_endpoint_attributes(self): def set_endpoint_attributes(self) -> Union[str, Tuple[str, Dict[str, int]]]:
arn = self._get_param("EndpointArn") arn = self._get_param("EndpointArn")
attributes = self._get_attributes() attributes = self._get_attributes()
@ -621,7 +624,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(SET_ENDPOINT_ATTRIBUTES_TEMPLATE) template = self.response_template(SET_ENDPOINT_ATTRIBUTES_TEMPLATE)
return template.render() return template.render()
def delete_endpoint(self): def delete_endpoint(self) -> str:
arn = self._get_param("EndpointArn") arn = self._get_param("EndpointArn")
self.backend.delete_endpoint(arn) self.backend.delete_endpoint(arn)
@ -639,13 +642,13 @@ class SNSResponse(BaseResponse):
template = self.response_template(DELETE_ENDPOINT_TEMPLATE) template = self.response_template(DELETE_ENDPOINT_TEMPLATE)
return template.render() return template.render()
def get_subscription_attributes(self): def get_subscription_attributes(self) -> str:
arn = self._get_param("SubscriptionArn") arn = self._get_param("SubscriptionArn")
attributes = self.backend.get_subscription_attributes(arn) attributes = self.backend.get_subscription_attributes(arn)
template = self.response_template(GET_SUBSCRIPTION_ATTRIBUTES_TEMPLATE) template = self.response_template(GET_SUBSCRIPTION_ATTRIBUTES_TEMPLATE)
return template.render(attributes=attributes) return template.render(attributes=attributes)
def set_subscription_attributes(self): def set_subscription_attributes(self) -> str:
arn = self._get_param("SubscriptionArn") arn = self._get_param("SubscriptionArn")
attr_name = self._get_param("AttributeName") attr_name = self._get_param("AttributeName")
attr_value = self._get_param("AttributeValue") attr_value = self._get_param("AttributeValue")
@ -653,12 +656,12 @@ class SNSResponse(BaseResponse):
template = self.response_template(SET_SUBSCRIPTION_ATTRIBUTES_TEMPLATE) template = self.response_template(SET_SUBSCRIPTION_ATTRIBUTES_TEMPLATE)
return template.render() return template.render()
def set_sms_attributes(self): def set_sms_attributes(self) -> str:
# attributes.entry.1.key # attributes.entry.1.key
# attributes.entry.1.value # attributes.entry.1.value
# to # to
# 1: {key:X, value:Y} # 1: {key:X, value:Y}
temp_dict = defaultdict(dict) temp_dict: Dict[str, Any] = defaultdict(dict)
for key, value in self.querystring.items(): for key, value in self.querystring.items():
match = self.SMS_ATTR_REGEX.match(key) match = self.SMS_ATTR_REGEX.match(key)
if match is not None: if match is not None:
@ -678,7 +681,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(SET_SMS_ATTRIBUTES_TEMPLATE) template = self.response_template(SET_SMS_ATTRIBUTES_TEMPLATE)
return template.render() return template.render()
def get_sms_attributes(self): def get_sms_attributes(self) -> str:
filter_list = set() filter_list = set()
for key, value in self.querystring.items(): for key, value in self.querystring.items():
if key.startswith("attributes.member.1"): if key.startswith("attributes.member.1"):
@ -694,7 +697,9 @@ class SNSResponse(BaseResponse):
template = self.response_template(GET_SMS_ATTRIBUTES_TEMPLATE) template = self.response_template(GET_SMS_ATTRIBUTES_TEMPLATE)
return template.render(attributes=result) return template.render(attributes=result)
def check_if_phone_number_is_opted_out(self): def check_if_phone_number_is_opted_out(
self,
) -> Union[str, Tuple[str, Dict[str, int]]]:
number = self._get_param("phoneNumber") number = self._get_param("phoneNumber")
if self.OPT_OUT_PHONE_NUMBER_REGEX.match(number) is None: if self.OPT_OUT_PHONE_NUMBER_REGEX.match(number) is None:
error_response = self._error( error_response = self._error(
@ -707,11 +712,11 @@ class SNSResponse(BaseResponse):
template = self.response_template(CHECK_IF_OPTED_OUT_TEMPLATE) template = self.response_template(CHECK_IF_OPTED_OUT_TEMPLATE)
return template.render(opt_out=str(number.endswith("99")).lower()) return template.render(opt_out=str(number.endswith("99")).lower())
def list_phone_numbers_opted_out(self): def list_phone_numbers_opted_out(self) -> str:
template = self.response_template(LIST_OPTOUT_TEMPLATE) template = self.response_template(LIST_OPTOUT_TEMPLATE)
return template.render(opt_outs=self.backend.opt_out_numbers) return template.render(opt_outs=self.backend.opt_out_numbers)
def opt_in_phone_number(self): def opt_in_phone_number(self) -> str:
number = self._get_param("phoneNumber") number = self._get_param("phoneNumber")
try: try:
@ -722,7 +727,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(OPT_IN_NUMBER_TEMPLATE) template = self.response_template(OPT_IN_NUMBER_TEMPLATE)
return template.render() return template.render()
def add_permission(self): def add_permission(self) -> str:
topic_arn = self._get_param("TopicArn") topic_arn = self._get_param("TopicArn")
label = self._get_param("Label") label = self._get_param("Label")
aws_account_ids = self._get_multi_param("AWSAccountId.member.") aws_account_ids = self._get_multi_param("AWSAccountId.member.")
@ -733,7 +738,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(ADD_PERMISSION_TEMPLATE) template = self.response_template(ADD_PERMISSION_TEMPLATE)
return template.render() return template.render()
def remove_permission(self): def remove_permission(self) -> str:
topic_arn = self._get_param("TopicArn") topic_arn = self._get_param("TopicArn")
label = self._get_param("Label") label = self._get_param("Label")
@ -742,7 +747,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(DEL_PERMISSION_TEMPLATE) template = self.response_template(DEL_PERMISSION_TEMPLATE)
return template.render() return template.render()
def confirm_subscription(self): def confirm_subscription(self) -> Union[str, Tuple[str, Dict[str, int]]]:
arn = self._get_param("TopicArn") arn = self._get_param("TopicArn")
if arn not in self.backend.topics: if arn not in self.backend.topics:
@ -767,7 +772,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(CONFIRM_SUBSCRIPTION_TEMPLATE) template = self.response_template(CONFIRM_SUBSCRIPTION_TEMPLATE)
return template.render(sub_arn=f"{arn}:68762e72-e9b1-410a-8b3b-903da69ee1d5") return template.render(sub_arn=f"{arn}:68762e72-e9b1-410a-8b3b-903da69ee1d5")
def list_tags_for_resource(self): def list_tags_for_resource(self) -> str:
arn = self._get_param("ResourceArn") arn = self._get_param("ResourceArn")
result = self.backend.list_tags_for_resource(arn) result = self.backend.list_tags_for_resource(arn)
@ -775,7 +780,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(LIST_TAGS_FOR_RESOURCE_TEMPLATE) template = self.response_template(LIST_TAGS_FOR_RESOURCE_TEMPLATE)
return template.render(tags=result) return template.render(tags=result)
def tag_resource(self): def tag_resource(self) -> str:
arn = self._get_param("ResourceArn") arn = self._get_param("ResourceArn")
tags = self._get_tags() tags = self._get_tags()
@ -783,7 +788,7 @@ class SNSResponse(BaseResponse):
return self.response_template(TAG_RESOURCE_TEMPLATE).render() return self.response_template(TAG_RESOURCE_TEMPLATE).render()
def untag_resource(self): def untag_resource(self) -> str:
arn = self._get_param("ResourceArn") arn = self._get_param("ResourceArn")
tag_keys = self._get_multi_param("TagKeys.member") tag_keys = self._get_multi_param("TagKeys.member")

View File

@ -4,14 +4,14 @@ from moto.moto_api._internal import mock_random
E164_REGEX = re.compile(r"^\+?[1-9]\d{1,14}$") E164_REGEX = re.compile(r"^\+?[1-9]\d{1,14}$")
def make_arn_for_topic(account_id, name, region_name): def make_arn_for_topic(account_id: str, name: str, region_name: str) -> str:
return f"arn:aws:sns:{region_name}:{account_id}:{name}" return f"arn:aws:sns:{region_name}:{account_id}:{name}"
def make_arn_for_subscription(topic_arn): def make_arn_for_subscription(topic_arn: str) -> str:
subscription_id = mock_random.uuid4() subscription_id = mock_random.uuid4()
return f"{topic_arn}:{subscription_id}" return f"{topic_arn}:{subscription_id}"
def is_e164(number): def is_e164(number: str) -> bool:
return E164_REGEX.match(number) is not None return E164_REGEX.match(number) is not None

View File

@ -239,7 +239,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy] [mypy]
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s3*,moto/sagemaker,moto/secretsmanager,moto/ses,moto/sqs,moto/ssm,moto/scheduler,moto/swf files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s3*,moto/sagemaker,moto/secretsmanager,moto/ses,moto/sqs,moto/ssm,moto/scheduler,moto/swf,moto/sns
show_column_numbers=True show_column_numbers=True
show_error_codes = True show_error_codes = True
disable_error_code=abstract disable_error_code=abstract