Techdebt: MyPy SNS (#6258)

This commit is contained in:
Bert Blommers 2023-04-26 14:25:00 +00:00 committed by GitHub
parent 9b969f7e3f
commit 37f1456747
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 251 additions and 174 deletions

View File

@ -1,8 +1,9 @@
from typing import Any, Optional
from moto.core.exceptions import RESTError
class SNSException(RESTError):
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
kwargs["template"] = "wrapped_single_error"
super().__init__(*args, **kwargs)
@ -10,54 +11,54 @@ class SNSException(RESTError):
class SNSNotFoundError(SNSException):
code = 404
def __init__(self, message, **kwargs):
super().__init__("NotFound", message, **kwargs)
def __init__(self, message: str, template: Optional[str] = None):
super().__init__("NotFound", message, template=template)
class TopicNotFound(SNSNotFoundError):
def __init__(self):
def __init__(self) -> None:
super().__init__(message="Topic does not exist")
class ResourceNotFoundError(SNSException):
code = 404
def __init__(self):
def __init__(self) -> None:
super().__init__("ResourceNotFound", "Resource does not exist")
class DuplicateSnsEndpointError(SNSException):
code = 400
def __init__(self, message):
def __init__(self, message: str):
super().__init__("DuplicateEndpoint", message)
class SnsEndpointDisabled(SNSException):
code = 400
def __init__(self, message):
def __init__(self, message: str):
super().__init__("EndpointDisabled", message)
class SNSInvalidParameter(SNSException):
code = 400
def __init__(self, message):
def __init__(self, message: str):
super().__init__("InvalidParameter", message)
class InvalidParameterValue(SNSException):
code = 400
def __init__(self, message):
def __init__(self, message: str):
super().__init__("InvalidParameterValue", message)
class TagLimitExceededError(SNSException):
code = 400
def __init__(self):
def __init__(self) -> None:
super().__init__(
"TagLimitExceeded",
"Could not complete request: tag quota of per resource exceeded",
@ -67,14 +68,14 @@ class TagLimitExceededError(SNSException):
class InternalError(SNSException):
code = 500
def __init__(self, message):
def __init__(self, message: str):
super().__init__("InternalFailure", message)
class TooManyEntriesInBatchRequest(SNSException):
code = 400
def __init__(self):
def __init__(self) -> None:
super().__init__(
"TooManyEntriesInBatchRequest",
"The batch request contains more entries than permissible.",
@ -84,7 +85,7 @@ class TooManyEntriesInBatchRequest(SNSException):
class BatchEntryIdsNotDistinct(SNSException):
code = 400
def __init__(self):
def __init__(self) -> None:
super().__init__(
"BatchEntryIdsNotDistinct",
"Two or more batch entries in the request have the same Id.",

View File

@ -1,10 +1,11 @@
import datetime
import json
import requests
import re
from collections import OrderedDict
from typing import Any, Dict, List, Iterable, Optional, Tuple
from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel
from moto.core.utils import (
iso_8601_datetime_with_milliseconds,
@ -36,7 +37,7 @@ MAXIMUM_SMS_MESSAGE_BYTES = 1600 # Amazon limit for a single publish SMS action
class Topic(CloudFormationModel):
def __init__(self, name, sns_backend):
def __init__(self, name: str, sns_backend: "SNSBackend"):
self.name = name
self.sns_backend = sns_backend
self.account_id = sns_backend.account_id
@ -49,23 +50,25 @@ class Topic(CloudFormationModel):
self.subscriptions_pending = 0
self.subscriptions_confimed = 0
self.subscriptions_deleted = 0
self.sent_notifications = []
self.sent_notifications: List[
Tuple[str, str, Optional[str], Optional[Dict[str, Any]], Optional[str]]
] = []
self._policy_json = self._create_default_topic_policy(
sns_backend.region_name, self.account_id, name
)
self._tags = {}
self._tags: Dict[str, str] = {}
self.fifo_topic = "false"
self.content_based_deduplication = "false"
def publish(
self,
message,
subject=None,
message_attributes=None,
group_id=None,
deduplication_id=None,
):
message: str,
subject: Optional[str] = None,
message_attributes: Optional[Dict[str, Any]] = None,
group_id: Optional[str] = None,
deduplication_id: Optional[str] = None,
) -> str:
message_id = str(mock_random.uuid4())
subscriptions, _ = self.sns_backend.list_subscriptions(self.arn)
for subscription in subscriptions:
@ -83,10 +86,10 @@ class Topic(CloudFormationModel):
return message_id
@classmethod
def has_cfn_attr(cls, attr):
def has_cfn_attr(cls, attr: str) -> bool:
return attr in ["TopicName"]
def get_cfn_attribute(self, attribute_name):
def get_cfn_attribute(self, attribute_name: str) -> str:
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
if attribute_name == "TopicName":
@ -94,30 +97,35 @@ class Topic(CloudFormationModel):
raise UnformattedGetAttTemplateException()
@property
def physical_resource_id(self):
def physical_resource_id(self) -> str:
return self.arn
@property
def policy(self):
def policy(self) -> str:
return json.dumps(self._policy_json, separators=(",", ":"))
@policy.setter
def policy(self, policy):
def policy(self, policy: Any) -> None: # type: ignore[misc]
self._policy_json = json.loads(policy)
@staticmethod
def cloudformation_name_type():
def cloudformation_name_type() -> str:
return "TopicName"
@staticmethod
def cloudformation_type():
def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sns-topic.html
return "AWS::SNS::Topic"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs
):
def create_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
**kwargs: Any,
) -> "Topic":
sns_backend = sns_backends[account_id][region_name]
properties = cloudformation_json["Properties"]
@ -129,14 +137,14 @@ class Topic(CloudFormationModel):
return topic
@classmethod
def update_from_cloudformation_json(
def update_from_cloudformation_json( # type: ignore[misc]
cls,
original_resource,
new_resource_name,
cloudformation_json,
account_id,
region_name,
):
original_resource: Any,
new_resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
) -> "Topic":
cls.delete_from_cloudformation_json(
original_resource.name, cloudformation_json, account_id, region_name
)
@ -145,9 +153,13 @@ class Topic(CloudFormationModel):
)
@classmethod
def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, account_id, region_name
):
def delete_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
) -> None:
sns_backend = sns_backends[account_id][region_name]
properties = cloudformation_json["Properties"]
@ -158,7 +170,9 @@ class Topic(CloudFormationModel):
sns_backend.unsubscribe(subscription.arn)
sns_backend.delete_topic(topic_arn)
def _create_default_topic_policy(self, region_name, account_id, name):
def _create_default_topic_policy(
self, region_name: str, account_id: str, name: str
) -> Dict[str, Any]:
return {
"Version": "2008-10-17",
"Id": "__default_policy_ID",
@ -186,25 +200,25 @@ class Topic(CloudFormationModel):
class Subscription(BaseModel):
def __init__(self, account_id, topic, endpoint, protocol):
def __init__(self, account_id: str, topic: Topic, endpoint: str, protocol: str):
self.account_id = account_id
self.topic = topic
self.endpoint = endpoint
self.protocol = protocol
self.arn = make_arn_for_subscription(self.topic.arn)
self.attributes = {}
self.attributes: Dict[str, Any] = {}
self._filter_policy = None # filter policy as a dict, not json.
self.confirmed = False
def publish(
self,
message,
message_id,
subject=None,
message_attributes=None,
group_id=None,
deduplication_id=None,
):
message: str,
message_id: str,
subject: Optional[str] = None,
message_attributes: Optional[Dict[str, Any]] = None,
group_id: Optional[str] = None,
deduplication_id: Optional[str] = None,
) -> None:
if not self._matches_filter_policy(message_attributes):
return
@ -230,7 +244,7 @@ class Subscription(BaseModel):
)
else:
raw_message_attributes = {}
for key, value in message_attributes.items():
for key, value in message_attributes.items(): # type: ignore
attr_type = "string_value"
type_value = value["Value"]
if value["Type"].startswith("Binary"):
@ -279,14 +293,18 @@ class Subscription(BaseModel):
function_name, message, subject=subject, qualifier=qualifier
)
def _matches_filter_policy(self, message_attributes):
def _matches_filter_policy(
self, message_attributes: Optional[Dict[str, Any]]
) -> bool:
if not self._filter_policy:
return True
if message_attributes is None:
message_attributes = {}
def _field_match(field, rules, message_attributes):
def _field_match(
field: str, rules: List[Any], message_attributes: Dict[str, Any]
) -> bool:
for rule in rules:
# TODO: boolean value matching is not supported, SNS behavior unknown
if isinstance(rule, str):
@ -400,8 +418,14 @@ class Subscription(BaseModel):
for field, rules in self._filter_policy.items()
)
def get_post_data(self, message, message_id, subject, message_attributes=None):
post_data = {
def get_post_data(
self,
message: str,
message_id: str,
subject: Optional[str],
message_attributes: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
post_data: Dict[str, Any] = {
"Type": "Notification",
"MessageId": message_id,
"TopicArn": self.topic.arn,
@ -422,7 +446,14 @@ class Subscription(BaseModel):
class PlatformApplication(BaseModel):
def __init__(self, account_id, region, name, platform, attributes):
def __init__(
self,
account_id: str,
region: str,
name: str,
platform: str,
attributes: Dict[str, str],
):
self.region = region
self.name = name
self.platform = platform
@ -432,7 +463,13 @@ class PlatformApplication(BaseModel):
class PlatformEndpoint(BaseModel):
def __init__(
self, account_id, region, application, custom_user_data, token, attributes
self,
account_id: str,
region: str,
application: PlatformApplication,
custom_user_data: str,
token: str,
attributes: Dict[str, str],
):
self.region = region
self.application = application
@ -441,10 +478,10 @@ class PlatformEndpoint(BaseModel):
self.attributes = attributes
self.id = mock_random.uuid4()
self.arn = f"arn:aws:sns:{region}:{account_id}:endpoint/{self.application.platform}/{self.application.name}/{self.id}"
self.messages = OrderedDict()
self.messages: Dict[str, str] = OrderedDict()
self.__fixup_attributes()
def __fixup_attributes(self):
def __fixup_attributes(self) -> None:
# When AWS returns the attributes dict, it always contains these two elements, so we need to
# automatically ensure they exist as well.
if "Token" not in self.attributes:
@ -456,10 +493,10 @@ class PlatformEndpoint(BaseModel):
self.attributes["Enabled"] = "true"
@property
def enabled(self):
def enabled(self) -> bool:
return json.loads(self.attributes.get("Enabled", "true").lower())
def publish(self, message):
def publish(self, message: str) -> str:
if not self.enabled:
raise SnsEndpointDisabled(f"Endpoint {self.id} disabled")
@ -485,15 +522,15 @@ class SNSBackend(BaseBackend):
Note that, as this is an internal API, the exact format may differ per versions.
"""
def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self.topics = OrderedDict()
self.topics: Dict[str, Topic] = OrderedDict()
self.subscriptions: OrderedDict[str, Subscription] = OrderedDict()
self.applications = {}
self.platform_endpoints = {}
self.applications: Dict[str, PlatformApplication] = {}
self.platform_endpoints: Dict[str, PlatformEndpoint] = {}
self.region_name = region_name
self.sms_attributes = {}
self.sms_messages = OrderedDict()
self.sms_attributes: Dict[str, str] = {}
self.sms_messages: Dict[str, Tuple[str, str]] = OrderedDict()
self.opt_out_numbers = [
"+447420500600",
"+447420505401",
@ -506,23 +543,27 @@ class SNSBackend(BaseBackend):
]
@staticmethod
def default_vpc_endpoint_service(service_region, zones):
def default_vpc_endpoint_service(
service_region: str, zones: List[str]
) -> List[Dict[str, str]]:
"""List of dicts representing default VPC endpoints for this service."""
return BaseBackend.default_vpc_endpoint_service_factory(
service_region, zones, "sns"
)
def update_sms_attributes(self, attrs):
def update_sms_attributes(self, attrs: Dict[str, str]) -> None:
self.sms_attributes.update(attrs)
def create_topic(self, name, attributes=None, tags=None):
def create_topic(
self,
name: str,
attributes: Optional[Dict[str, str]] = None,
tags: Optional[Dict[str, str]] = None,
) -> Topic:
if attributes is None:
attributes = {}
if (
attributes.get("FifoTopic")
and attributes.get("FifoTopic").lower() == "true"
):
if attributes.get("FifoTopic") and attributes["FifoTopic"].lower() == "true":
fails_constraints = not re.match(r"^[a-zA-Z0-9_-]{1,256}\.fifo$", name)
msg = "Fifo Topic names must end with .fifo and must be made up of only uppercase and lowercase ASCII letters, numbers, underscores, and hyphens, and must be between 1 and 256 characters long."
@ -549,29 +590,33 @@ class SNSBackend(BaseBackend):
self.topics[candidate_topic.arn] = candidate_topic
return candidate_topic
def _get_values_nexttoken(self, values_map, next_token=None):
if next_token is None or not next_token:
next_token = 0
next_token = int(next_token)
values = list(values_map.values())[next_token : next_token + DEFAULT_PAGE_SIZE]
def _get_values_nexttoken(
self, values_map: Dict[str, Any], next_token: Optional[str] = None
) -> Tuple[List[Any], Optional[int]]:
i_next_token = int(next_token or "0")
values = list(values_map.values())[
i_next_token : i_next_token + DEFAULT_PAGE_SIZE
]
if len(values) == DEFAULT_PAGE_SIZE:
next_token = next_token + DEFAULT_PAGE_SIZE
i_next_token = i_next_token + DEFAULT_PAGE_SIZE
else:
next_token = None
return values, next_token
i_next_token = None # type: ignore
return values, i_next_token
def _get_topic_subscriptions(self, topic):
def _get_topic_subscriptions(self, topic: Topic) -> List[Subscription]:
return [sub for sub in self.subscriptions.values() if sub.topic == topic]
def list_topics(self, next_token=None):
def list_topics(
self, next_token: Optional[str] = None
) -> Tuple[List[Topic], Optional[int]]:
return self._get_values_nexttoken(self.topics, next_token)
def delete_topic_subscriptions(self, topic):
def delete_topic_subscriptions(self, topic: Topic) -> None:
for key, value in dict(self.subscriptions).items():
if value.topic == topic:
self.subscriptions.pop(key)
def delete_topic(self, arn):
def delete_topic(self, arn: str) -> None:
try:
topic = self.get_topic(arn)
self.delete_topic_subscriptions(topic)
@ -579,17 +624,19 @@ class SNSBackend(BaseBackend):
except KeyError:
raise SNSNotFoundError(f"Topic with arn {arn} not found")
def get_topic(self, arn):
def get_topic(self, arn: str) -> Topic:
try:
return self.topics[arn]
except KeyError:
raise SNSNotFoundError(f"Topic with arn {arn} not found")
def set_topic_attribute(self, topic_arn, attribute_name, attribute_value):
def set_topic_attribute(
self, topic_arn: str, attribute_name: str, attribute_value: str
) -> None:
topic = self.get_topic(topic_arn)
setattr(topic, attribute_name, attribute_value)
def subscribe(self, topic_arn, endpoint, protocol):
def subscribe(self, topic_arn: str, endpoint: str, protocol: str) -> Subscription:
if protocol == "sms":
if re.search(r"[./-]{2,}", endpoint) or re.search(
r"(^[./-]|[./-]$)", endpoint
@ -625,7 +672,9 @@ class SNSBackend(BaseBackend):
self.subscriptions[subscription.arn] = subscription
return subscription
def _find_subscription(self, topic_arn, endpoint, protocol):
def _find_subscription(
self, topic_arn: str, endpoint: str, protocol: str
) -> Optional[Subscription]:
for subscription in self.subscriptions.values():
if (
subscription.topic.arn == topic_arn
@ -635,10 +684,12 @@ class SNSBackend(BaseBackend):
return subscription
return None
def unsubscribe(self, subscription_arn):
def unsubscribe(self, subscription_arn: str) -> None:
self.subscriptions.pop(subscription_arn, None)
def list_subscriptions(self, topic_arn=None, next_token=None):
def list_subscriptions(
self, topic_arn: Optional[str] = None, next_token: Optional[str] = None
) -> Tuple[List[Subscription], Optional[int]]:
if topic_arn:
topic = self.get_topic(topic_arn)
filtered = OrderedDict(
@ -650,14 +701,14 @@ class SNSBackend(BaseBackend):
def publish(
self,
message,
arn=None,
phone_number=None,
subject=None,
message_attributes=None,
group_id=None,
deduplication_id=None,
):
message: str,
arn: Optional[str],
phone_number: Optional[str] = None,
subject: Optional[str] = None,
message_attributes: Optional[Dict[str, Any]] = None,
group_id: Optional[str] = None,
deduplication_id: Optional[str] = None,
) -> str:
if subject is not None and len(subject) > 100:
# Note that the AWS docs around length are wrong: https://github.com/getmoto/moto/issues/1503
raise ValueError("Subject must be less than 100 characters")
@ -677,7 +728,7 @@ class SNSBackend(BaseBackend):
)
try:
topic = self.get_topic(arn)
topic = self.get_topic(arn) # type: ignore
fifo_topic = topic.fifo_topic == "true"
if fifo_topic:
@ -705,40 +756,48 @@ class SNSBackend(BaseBackend):
deduplication_id=deduplication_id,
)
except SNSNotFoundError:
endpoint = self.get_endpoint(arn)
endpoint = self.get_endpoint(arn) # type: ignore
message_id = endpoint.publish(message)
return message_id
def create_platform_application(self, name, platform, attributes):
def create_platform_application(
self, name: str, platform: str, attributes: Dict[str, str]
) -> PlatformApplication:
application = PlatformApplication(
self.account_id, self.region_name, name, platform, attributes
)
self.applications[application.arn] = application
return application
def get_application(self, arn):
def get_application(self, arn: str) -> PlatformApplication:
try:
return self.applications[arn]
except KeyError:
raise SNSNotFoundError(f"Application with arn {arn} not found")
def set_application_attributes(self, arn, attributes):
def set_application_attributes(
self, arn: str, attributes: Dict[str, Any]
) -> PlatformApplication:
application = self.get_application(arn)
application.attributes.update(attributes)
return application
def list_platform_applications(self):
def list_platform_applications(self) -> Iterable[PlatformApplication]:
return self.applications.values()
def delete_platform_application(self, platform_arn):
def delete_platform_application(self, platform_arn: str) -> None:
self.applications.pop(platform_arn)
endpoints = self.list_endpoints_by_platform_application(platform_arn)
for endpoint in endpoints:
self.platform_endpoints.pop(endpoint.arn)
def create_platform_endpoint(
self, application, custom_user_data, token, attributes
):
self,
application: PlatformApplication,
custom_user_data: str,
token: str,
attributes: Dict[str, str],
) -> PlatformEndpoint:
for endpoint in self.platform_endpoints.values():
if token == endpoint.token:
if (
@ -760,33 +819,37 @@ class SNSBackend(BaseBackend):
self.platform_endpoints[platform_endpoint.arn] = platform_endpoint
return platform_endpoint
def list_endpoints_by_platform_application(self, application_arn):
def list_endpoints_by_platform_application(
self, application_arn: str
) -> List[PlatformEndpoint]:
return [
endpoint
for endpoint in self.platform_endpoints.values()
if endpoint.application.arn == application_arn
]
def get_endpoint(self, arn):
def get_endpoint(self, arn: str) -> PlatformEndpoint:
try:
return self.platform_endpoints[arn]
except KeyError:
raise SNSNotFoundError("Endpoint does not exist")
def set_endpoint_attributes(self, arn, attributes):
def set_endpoint_attributes(
self, arn: str, attributes: Dict[str, Any]
) -> PlatformEndpoint:
endpoint = self.get_endpoint(arn)
if "Enabled" in attributes:
attributes["Enabled"] = attributes["Enabled"].lower()
endpoint.attributes.update(attributes)
return endpoint
def delete_endpoint(self, arn):
def delete_endpoint(self, arn: str) -> None:
try:
del self.platform_endpoints[arn]
except KeyError:
raise SNSNotFoundError(f"Endpoint with arn {arn} not found")
def get_subscription_attributes(self, arn):
def get_subscription_attributes(self, arn: str) -> Dict[str, Any]:
subscription = self.subscriptions.get(arn)
if not subscription:
@ -796,7 +859,7 @@ class SNSBackend(BaseBackend):
return subscription.attributes
def set_subscription_attributes(self, arn, name, value):
def set_subscription_attributes(self, arn: str, name: str, value: Any) -> None:
if name not in [
"RawMessageDelivery",
"DeliveryPolicy",
@ -819,7 +882,7 @@ class SNSBackend(BaseBackend):
self._validate_filter_policy(filter_policy)
subscription._filter_policy = filter_policy
def _validate_filter_policy(self, value):
def _validate_filter_policy(self, value: Any) -> None:
# TODO: extend validation checks
combinations = 1
for rules in value.values():
@ -931,7 +994,13 @@ class SNSBackend(BaseBackend):
"Invalid parameter: FilterPolicy: Match value must be String, number, true, false, or null"
)
def add_permission(self, topic_arn, label, aws_account_ids, action_names):
def add_permission(
self,
topic_arn: str,
label: str,
aws_account_ids: List[str],
action_names: List[str],
) -> None:
if topic_arn not in self.topics:
raise SNSNotFoundError("Topic does not exist")
@ -966,7 +1035,7 @@ class SNSBackend(BaseBackend):
self.topics[topic_arn]._policy_json["Statement"].append(statement)
def remove_permission(self, topic_arn, label):
def remove_permission(self, topic_arn: str, label: str) -> None:
if topic_arn not in self.topics:
raise SNSNotFoundError("Topic does not exist")
@ -977,13 +1046,13 @@ class SNSBackend(BaseBackend):
self.topics[topic_arn]._policy_json["Statement"] = statements
def list_tags_for_resource(self, resource_arn):
def list_tags_for_resource(self, resource_arn: str) -> Dict[str, str]:
if resource_arn not in self.topics:
raise ResourceNotFoundError
return self.topics[resource_arn]._tags
def tag_resource(self, resource_arn, tags):
def tag_resource(self, resource_arn: str, tags: Dict[str, str]) -> None:
if resource_arn not in self.topics:
raise ResourceNotFoundError
@ -995,14 +1064,16 @@ class SNSBackend(BaseBackend):
self.topics[resource_arn]._tags = updated_tags
def untag_resource(self, resource_arn, tag_keys):
def untag_resource(self, resource_arn: str, tag_keys: List[str]) -> None:
if resource_arn not in self.topics:
raise ResourceNotFoundError
for key in tag_keys:
self.topics[resource_arn]._tags.pop(key, None)
def publish_batch(self, topic_arn, publish_batch_request_entries):
def publish_batch(
self, topic_arn: str, publish_batch_request_entries: List[Dict[str, Any]]
) -> Tuple[List[Dict[str, str]], List[Dict[str, Any]]]:
"""
The MessageStructure and MessageDeduplicationId-parameters have not yet been implemented.
"""
@ -1027,8 +1098,8 @@ class SNSBackend(BaseBackend):
"Invalid parameter: The MessageGroupId parameter is required for FIFO topics"
)
successful = []
failed = []
successful: List[Dict[str, str]] = []
failed: List[Dict[str, Any]] = []
for entry in publish_batch_request_entries:
try:

View File

@ -1,10 +1,11 @@
import json
import re
from collections import defaultdict
from typing import Any, Dict, Tuple, Union
from moto.core.responses import BaseResponse
from moto.core.utils import camelcase_to_underscores
from .models import sns_backends
from .models import sns_backends, SNSBackend
from .exceptions import InvalidParameterValue, SNSNotFoundError
from .utils import is_e164
@ -15,32 +16,34 @@ class SNSResponse(BaseResponse):
)
OPT_OUT_PHONE_NUMBER_REGEX = re.compile(r"^\+?\d+$")
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="sns")
@property
def backend(self):
def backend(self) -> SNSBackend:
return sns_backends[self.current_account][self.region]
def _error(self, code, message, sender="Sender"):
def _error(self, code: str, message: str, sender: str = "Sender") -> str:
template = self.response_template(ERROR_RESPONSE)
return template.render(code=code, message=message, sender=sender)
def _get_attributes(self):
def _get_attributes(self) -> Dict[str, str]:
attributes = self._get_list_prefix("Attributes.entry")
return dict((attribute["key"], attribute["value"]) for attribute in attributes)
def _get_tags(self):
def _get_tags(self) -> Dict[str, str]:
tags = self._get_list_prefix("Tags.member")
return {tag["key"]: tag["value"] for tag in tags}
def _parse_message_attributes(self):
def _parse_message_attributes(self) -> Dict[str, Any]:
message_attributes = self._get_object_map(
"MessageAttributes.entry", name="Name", value="Value"
)
return self._transform_message_attributes(message_attributes)
def _transform_message_attributes(self, message_attributes):
def _transform_message_attributes(
self, message_attributes: Dict[str, Any]
) -> Dict[str, Any]:
# SNS converts some key names before forwarding messages
# DataType -> Type, StringValue -> Value, BinaryValue -> Value
transformed_message_attributes = {}
@ -96,7 +99,7 @@ class SNSResponse(BaseResponse):
return transformed_message_attributes
def create_topic(self):
def create_topic(self) -> str:
name = self._get_param("Name")
attributes = self._get_attributes()
tags = self._get_tags()
@ -117,7 +120,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(CREATE_TOPIC_TEMPLATE)
return template.render(topic=topic)
def list_topics(self):
def list_topics(self) -> str:
next_token = self._get_param("NextToken")
topics, next_token = self.backend.list_topics(next_token=next_token)
@ -139,7 +142,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(LIST_TOPICS_TEMPLATE)
return template.render(topics=topics, next_token=next_token)
def delete_topic(self):
def delete_topic(self) -> str:
topic_arn = self._get_param("TopicArn")
self.backend.delete_topic(topic_arn)
@ -157,7 +160,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(DELETE_TOPIC_TEMPLATE)
return template.render()
def get_topic_attributes(self):
def get_topic_attributes(self) -> str:
topic_arn = self._get_param("TopicArn")
topic = self.backend.get_topic(topic_arn)
@ -193,7 +196,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(GET_TOPIC_ATTRIBUTES_TEMPLATE)
return template.render(topic=topic)
def set_topic_attributes(self):
def set_topic_attributes(self) -> str:
topic_arn = self._get_param("TopicArn")
attribute_name = self._get_param("AttributeName")
attribute_name = camelcase_to_underscores(attribute_name)
@ -214,7 +217,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(SET_TOPIC_ATTRIBUTES_TEMPLATE)
return template.render()
def subscribe(self):
def subscribe(self) -> str:
topic_arn = self._get_param("TopicArn")
endpoint = self._get_param("Endpoint")
protocol = self._get_param("Protocol")
@ -243,7 +246,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(SUBSCRIBE_TEMPLATE)
return template.render(subscription=subscription)
def unsubscribe(self):
def unsubscribe(self) -> str:
subscription_arn = self._get_param("SubscriptionArn")
self.backend.unsubscribe(subscription_arn)
@ -261,7 +264,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(UNSUBSCRIBE_TEMPLATE)
return template.render()
def list_subscriptions(self):
def list_subscriptions(self) -> str:
next_token = self._get_param("NextToken")
subscriptions, next_token = self.backend.list_subscriptions(
next_token=next_token
@ -294,7 +297,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(LIST_SUBSCRIPTIONS_TEMPLATE)
return template.render(subscriptions=subscriptions, next_token=next_token)
def list_subscriptions_by_topic(self):
def list_subscriptions_by_topic(self) -> str:
topic_arn = self._get_param("TopicArn")
next_token = self._get_param("NextToken")
subscriptions, next_token = self.backend.list_subscriptions(
@ -328,7 +331,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(LIST_SUBSCRIPTIONS_BY_TOPIC_TEMPLATE)
return template.render(subscriptions=subscriptions, next_token=next_token)
def publish(self):
def publish(self) -> Union[str, Tuple[str, Dict[str, int]]]:
target_arn = self._get_param("TargetArn")
topic_arn = self._get_param("TopicArn")
phone_number = self._get_param("PhoneNumber")
@ -384,7 +387,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(PUBLISH_TEMPLATE)
return template.render(message_id=message_id)
def publish_batch(self):
def publish_batch(self) -> str:
topic_arn = self._get_param("TopicArn")
publish_batch_request_entries = self._get_multi_param(
"PublishBatchRequestEntries.member"
@ -406,7 +409,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(PUBLISH_BATCH_TEMPLATE)
return template.render(successful=successful, failed=failed)
def create_platform_application(self):
def create_platform_application(self) -> str:
name = self._get_param("Name")
platform = self._get_param("Platform")
attributes = self._get_attributes()
@ -431,7 +434,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(CREATE_PLATFORM_APPLICATION_TEMPLATE)
return template.render(platform_application=platform_application)
def get_platform_application_attributes(self):
def get_platform_application_attributes(self) -> str:
arn = self._get_param("PlatformApplicationArn")
application = self.backend.get_application(arn)
@ -452,7 +455,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(GET_PLATFORM_APPLICATION_ATTRIBUTES_TEMPLATE)
return template.render(application=application)
def set_platform_application_attributes(self):
def set_platform_application_attributes(self) -> str:
arn = self._get_param("PlatformApplicationArn")
attributes = self._get_attributes()
@ -472,7 +475,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(SET_PLATFORM_APPLICATION_ATTRIBUTES_TEMPLATE)
return template.render()
def list_platform_applications(self):
def list_platform_applications(self) -> str:
applications = self.backend.list_platform_applications()
if self.request_json:
@ -499,7 +502,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(LIST_PLATFORM_APPLICATIONS_TEMPLATE)
return template.render(applications=applications)
def delete_platform_application(self):
def delete_platform_application(self) -> str:
platform_arn = self._get_param("PlatformApplicationArn")
self.backend.delete_platform_application(platform_arn)
@ -517,7 +520,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(DELETE_PLATFORM_APPLICATION_TEMPLATE)
return template.render()
def create_platform_endpoint(self):
def create_platform_endpoint(self) -> str:
application_arn = self._get_param("PlatformApplicationArn")
application = self.backend.get_application(application_arn)
@ -546,7 +549,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(CREATE_PLATFORM_ENDPOINT_TEMPLATE)
return template.render(platform_endpoint=platform_endpoint)
def list_endpoints_by_platform_application(self):
def list_endpoints_by_platform_application(self) -> str:
application_arn = self._get_param("PlatformApplicationArn")
endpoints = self.backend.list_endpoints_by_platform_application(application_arn)
@ -576,7 +579,7 @@ class SNSResponse(BaseResponse):
)
return template.render(endpoints=endpoints)
def get_endpoint_attributes(self):
def get_endpoint_attributes(self) -> Union[str, Tuple[str, Dict[str, int]]]:
arn = self._get_param("EndpointArn")
try:
endpoint = self.backend.get_endpoint(arn)
@ -601,7 +604,7 @@ class SNSResponse(BaseResponse):
error_response = self._error("NotFound", "Endpoint does not exist")
return error_response, dict(status=404)
def set_endpoint_attributes(self):
def set_endpoint_attributes(self) -> Union[str, Tuple[str, Dict[str, int]]]:
arn = self._get_param("EndpointArn")
attributes = self._get_attributes()
@ -621,7 +624,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(SET_ENDPOINT_ATTRIBUTES_TEMPLATE)
return template.render()
def delete_endpoint(self):
def delete_endpoint(self) -> str:
arn = self._get_param("EndpointArn")
self.backend.delete_endpoint(arn)
@ -639,13 +642,13 @@ class SNSResponse(BaseResponse):
template = self.response_template(DELETE_ENDPOINT_TEMPLATE)
return template.render()
def get_subscription_attributes(self):
def get_subscription_attributes(self) -> str:
arn = self._get_param("SubscriptionArn")
attributes = self.backend.get_subscription_attributes(arn)
template = self.response_template(GET_SUBSCRIPTION_ATTRIBUTES_TEMPLATE)
return template.render(attributes=attributes)
def set_subscription_attributes(self):
def set_subscription_attributes(self) -> str:
arn = self._get_param("SubscriptionArn")
attr_name = self._get_param("AttributeName")
attr_value = self._get_param("AttributeValue")
@ -653,12 +656,12 @@ class SNSResponse(BaseResponse):
template = self.response_template(SET_SUBSCRIPTION_ATTRIBUTES_TEMPLATE)
return template.render()
def set_sms_attributes(self):
def set_sms_attributes(self) -> str:
# attributes.entry.1.key
# attributes.entry.1.value
# to
# 1: {key:X, value:Y}
temp_dict = defaultdict(dict)
temp_dict: Dict[str, Any] = defaultdict(dict)
for key, value in self.querystring.items():
match = self.SMS_ATTR_REGEX.match(key)
if match is not None:
@ -678,7 +681,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(SET_SMS_ATTRIBUTES_TEMPLATE)
return template.render()
def get_sms_attributes(self):
def get_sms_attributes(self) -> str:
filter_list = set()
for key, value in self.querystring.items():
if key.startswith("attributes.member.1"):
@ -694,7 +697,9 @@ class SNSResponse(BaseResponse):
template = self.response_template(GET_SMS_ATTRIBUTES_TEMPLATE)
return template.render(attributes=result)
def check_if_phone_number_is_opted_out(self):
def check_if_phone_number_is_opted_out(
self,
) -> Union[str, Tuple[str, Dict[str, int]]]:
number = self._get_param("phoneNumber")
if self.OPT_OUT_PHONE_NUMBER_REGEX.match(number) is None:
error_response = self._error(
@ -707,11 +712,11 @@ class SNSResponse(BaseResponse):
template = self.response_template(CHECK_IF_OPTED_OUT_TEMPLATE)
return template.render(opt_out=str(number.endswith("99")).lower())
def list_phone_numbers_opted_out(self):
def list_phone_numbers_opted_out(self) -> str:
template = self.response_template(LIST_OPTOUT_TEMPLATE)
return template.render(opt_outs=self.backend.opt_out_numbers)
def opt_in_phone_number(self):
def opt_in_phone_number(self) -> str:
number = self._get_param("phoneNumber")
try:
@ -722,7 +727,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(OPT_IN_NUMBER_TEMPLATE)
return template.render()
def add_permission(self):
def add_permission(self) -> str:
topic_arn = self._get_param("TopicArn")
label = self._get_param("Label")
aws_account_ids = self._get_multi_param("AWSAccountId.member.")
@ -733,7 +738,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(ADD_PERMISSION_TEMPLATE)
return template.render()
def remove_permission(self):
def remove_permission(self) -> str:
topic_arn = self._get_param("TopicArn")
label = self._get_param("Label")
@ -742,7 +747,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(DEL_PERMISSION_TEMPLATE)
return template.render()
def confirm_subscription(self):
def confirm_subscription(self) -> Union[str, Tuple[str, Dict[str, int]]]:
arn = self._get_param("TopicArn")
if arn not in self.backend.topics:
@ -767,7 +772,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(CONFIRM_SUBSCRIPTION_TEMPLATE)
return template.render(sub_arn=f"{arn}:68762e72-e9b1-410a-8b3b-903da69ee1d5")
def list_tags_for_resource(self):
def list_tags_for_resource(self) -> str:
arn = self._get_param("ResourceArn")
result = self.backend.list_tags_for_resource(arn)
@ -775,7 +780,7 @@ class SNSResponse(BaseResponse):
template = self.response_template(LIST_TAGS_FOR_RESOURCE_TEMPLATE)
return template.render(tags=result)
def tag_resource(self):
def tag_resource(self) -> str:
arn = self._get_param("ResourceArn")
tags = self._get_tags()
@ -783,7 +788,7 @@ class SNSResponse(BaseResponse):
return self.response_template(TAG_RESOURCE_TEMPLATE).render()
def untag_resource(self):
def untag_resource(self) -> str:
arn = self._get_param("ResourceArn")
tag_keys = self._get_multi_param("TagKeys.member")

View File

@ -4,14 +4,14 @@ from moto.moto_api._internal import mock_random
E164_REGEX = re.compile(r"^\+?[1-9]\d{1,14}$")
def make_arn_for_topic(account_id, name, region_name):
def make_arn_for_topic(account_id: str, name: str, region_name: str) -> str:
return f"arn:aws:sns:{region_name}:{account_id}:{name}"
def make_arn_for_subscription(topic_arn):
def make_arn_for_subscription(topic_arn: str) -> str:
subscription_id = mock_random.uuid4()
return f"{topic_arn}:{subscription_id}"
def is_e164(number):
def is_e164(number: str) -> bool:
return E164_REGEX.match(number) is not None

View File

@ -239,7 +239,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy]
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s3*,moto/sagemaker,moto/secretsmanager,moto/ses,moto/sqs,moto/ssm,moto/scheduler,moto/swf
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s3*,moto/sagemaker,moto/secretsmanager,moto/ses,moto/sqs,moto/ssm,moto/scheduler,moto/swf,moto/sns
show_column_numbers=True
show_error_codes = True
disable_error_code=abstract