diff --git a/moto/sns/exceptions.py b/moto/sns/exceptions.py index 954d28d5a..705d5539f 100644 --- a/moto/sns/exceptions.py +++ b/moto/sns/exceptions.py @@ -1,8 +1,9 @@ +from typing import Any, Optional from moto.core.exceptions import RESTError class SNSException(RESTError): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): kwargs["template"] = "wrapped_single_error" super().__init__(*args, **kwargs) @@ -10,54 +11,54 @@ class SNSException(RESTError): class SNSNotFoundError(SNSException): code = 404 - def __init__(self, message, **kwargs): - super().__init__("NotFound", message, **kwargs) + def __init__(self, message: str, template: Optional[str] = None): + super().__init__("NotFound", message, template=template) class TopicNotFound(SNSNotFoundError): - def __init__(self): + def __init__(self) -> None: super().__init__(message="Topic does not exist") class ResourceNotFoundError(SNSException): code = 404 - def __init__(self): + def __init__(self) -> None: super().__init__("ResourceNotFound", "Resource does not exist") class DuplicateSnsEndpointError(SNSException): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("DuplicateEndpoint", message) class SnsEndpointDisabled(SNSException): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("EndpointDisabled", message) class SNSInvalidParameter(SNSException): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("InvalidParameter", message) class InvalidParameterValue(SNSException): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("InvalidParameterValue", message) class TagLimitExceededError(SNSException): code = 400 - def __init__(self): + def __init__(self) -> None: super().__init__( "TagLimitExceeded", "Could not complete request: tag quota of per resource exceeded", @@ -67,14 +68,14 @@ class TagLimitExceededError(SNSException): class InternalError(SNSException): code = 500 - def __init__(self, message): + def __init__(self, message: str): super().__init__("InternalFailure", message) class TooManyEntriesInBatchRequest(SNSException): code = 400 - def __init__(self): + def __init__(self) -> None: super().__init__( "TooManyEntriesInBatchRequest", "The batch request contains more entries than permissible.", @@ -84,7 +85,7 @@ class TooManyEntriesInBatchRequest(SNSException): class BatchEntryIdsNotDistinct(SNSException): code = 400 - def __init__(self): + def __init__(self) -> None: super().__init__( "BatchEntryIdsNotDistinct", "Two or more batch entries in the request have the same Id.", diff --git a/moto/sns/models.py b/moto/sns/models.py index 7cc3f461b..a3cf2afd0 100644 --- a/moto/sns/models.py +++ b/moto/sns/models.py @@ -1,10 +1,11 @@ import datetime import json - import requests import re from collections import OrderedDict +from typing import Any, Dict, List, Iterable, Optional, Tuple + from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel from moto.core.utils import ( iso_8601_datetime_with_milliseconds, @@ -36,7 +37,7 @@ MAXIMUM_SMS_MESSAGE_BYTES = 1600 # Amazon limit for a single publish SMS action class Topic(CloudFormationModel): - def __init__(self, name, sns_backend): + def __init__(self, name: str, sns_backend: "SNSBackend"): self.name = name self.sns_backend = sns_backend self.account_id = sns_backend.account_id @@ -49,23 +50,25 @@ class Topic(CloudFormationModel): self.subscriptions_pending = 0 self.subscriptions_confimed = 0 self.subscriptions_deleted = 0 - self.sent_notifications = [] + self.sent_notifications: List[ + Tuple[str, str, Optional[str], Optional[Dict[str, Any]], Optional[str]] + ] = [] self._policy_json = self._create_default_topic_policy( sns_backend.region_name, self.account_id, name ) - self._tags = {} + self._tags: Dict[str, str] = {} self.fifo_topic = "false" self.content_based_deduplication = "false" def publish( self, - message, - subject=None, - message_attributes=None, - group_id=None, - deduplication_id=None, - ): + message: str, + subject: Optional[str] = None, + message_attributes: Optional[Dict[str, Any]] = None, + group_id: Optional[str] = None, + deduplication_id: Optional[str] = None, + ) -> str: message_id = str(mock_random.uuid4()) subscriptions, _ = self.sns_backend.list_subscriptions(self.arn) for subscription in subscriptions: @@ -83,10 +86,10 @@ class Topic(CloudFormationModel): return message_id @classmethod - def has_cfn_attr(cls, attr): + def has_cfn_attr(cls, attr: str) -> bool: return attr in ["TopicName"] - def get_cfn_attribute(self, attribute_name): + def get_cfn_attribute(self, attribute_name: str) -> str: from moto.cloudformation.exceptions import UnformattedGetAttTemplateException if attribute_name == "TopicName": @@ -94,30 +97,35 @@ class Topic(CloudFormationModel): raise UnformattedGetAttTemplateException() @property - def physical_resource_id(self): + def physical_resource_id(self) -> str: return self.arn @property - def policy(self): + def policy(self) -> str: return json.dumps(self._policy_json, separators=(",", ":")) @policy.setter - def policy(self, policy): + def policy(self, policy: Any) -> None: # type: ignore[misc] self._policy_json = json.loads(policy) @staticmethod - def cloudformation_name_type(): + def cloudformation_name_type() -> str: return "TopicName" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sns-topic.html return "AWS::SNS::Topic" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "Topic": sns_backend = sns_backends[account_id][region_name] properties = cloudformation_json["Properties"] @@ -129,14 +137,14 @@ class Topic(CloudFormationModel): return topic @classmethod - def update_from_cloudformation_json( + def update_from_cloudformation_json( # type: ignore[misc] cls, - original_resource, - new_resource_name, - cloudformation_json, - account_id, - region_name, - ): + original_resource: Any, + new_resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + ) -> "Topic": cls.delete_from_cloudformation_json( original_resource.name, cloudformation_json, account_id, region_name ) @@ -145,9 +153,13 @@ class Topic(CloudFormationModel): ) @classmethod - def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name - ): + def delete_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + ) -> None: sns_backend = sns_backends[account_id][region_name] properties = cloudformation_json["Properties"] @@ -158,7 +170,9 @@ class Topic(CloudFormationModel): sns_backend.unsubscribe(subscription.arn) sns_backend.delete_topic(topic_arn) - def _create_default_topic_policy(self, region_name, account_id, name): + def _create_default_topic_policy( + self, region_name: str, account_id: str, name: str + ) -> Dict[str, Any]: return { "Version": "2008-10-17", "Id": "__default_policy_ID", @@ -186,25 +200,25 @@ class Topic(CloudFormationModel): class Subscription(BaseModel): - def __init__(self, account_id, topic, endpoint, protocol): + def __init__(self, account_id: str, topic: Topic, endpoint: str, protocol: str): self.account_id = account_id self.topic = topic self.endpoint = endpoint self.protocol = protocol self.arn = make_arn_for_subscription(self.topic.arn) - self.attributes = {} + self.attributes: Dict[str, Any] = {} self._filter_policy = None # filter policy as a dict, not json. self.confirmed = False def publish( self, - message, - message_id, - subject=None, - message_attributes=None, - group_id=None, - deduplication_id=None, - ): + message: str, + message_id: str, + subject: Optional[str] = None, + message_attributes: Optional[Dict[str, Any]] = None, + group_id: Optional[str] = None, + deduplication_id: Optional[str] = None, + ) -> None: if not self._matches_filter_policy(message_attributes): return @@ -230,7 +244,7 @@ class Subscription(BaseModel): ) else: raw_message_attributes = {} - for key, value in message_attributes.items(): + for key, value in message_attributes.items(): # type: ignore attr_type = "string_value" type_value = value["Value"] if value["Type"].startswith("Binary"): @@ -279,14 +293,18 @@ class Subscription(BaseModel): function_name, message, subject=subject, qualifier=qualifier ) - def _matches_filter_policy(self, message_attributes): + def _matches_filter_policy( + self, message_attributes: Optional[Dict[str, Any]] + ) -> bool: if not self._filter_policy: return True if message_attributes is None: message_attributes = {} - def _field_match(field, rules, message_attributes): + def _field_match( + field: str, rules: List[Any], message_attributes: Dict[str, Any] + ) -> bool: for rule in rules: # TODO: boolean value matching is not supported, SNS behavior unknown if isinstance(rule, str): @@ -400,8 +418,14 @@ class Subscription(BaseModel): for field, rules in self._filter_policy.items() ) - def get_post_data(self, message, message_id, subject, message_attributes=None): - post_data = { + def get_post_data( + self, + message: str, + message_id: str, + subject: Optional[str], + message_attributes: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + post_data: Dict[str, Any] = { "Type": "Notification", "MessageId": message_id, "TopicArn": self.topic.arn, @@ -422,7 +446,14 @@ class Subscription(BaseModel): class PlatformApplication(BaseModel): - def __init__(self, account_id, region, name, platform, attributes): + def __init__( + self, + account_id: str, + region: str, + name: str, + platform: str, + attributes: Dict[str, str], + ): self.region = region self.name = name self.platform = platform @@ -432,7 +463,13 @@ class PlatformApplication(BaseModel): class PlatformEndpoint(BaseModel): def __init__( - self, account_id, region, application, custom_user_data, token, attributes + self, + account_id: str, + region: str, + application: PlatformApplication, + custom_user_data: str, + token: str, + attributes: Dict[str, str], ): self.region = region self.application = application @@ -441,10 +478,10 @@ class PlatformEndpoint(BaseModel): self.attributes = attributes self.id = mock_random.uuid4() self.arn = f"arn:aws:sns:{region}:{account_id}:endpoint/{self.application.platform}/{self.application.name}/{self.id}" - self.messages = OrderedDict() + self.messages: Dict[str, str] = OrderedDict() self.__fixup_attributes() - def __fixup_attributes(self): + def __fixup_attributes(self) -> None: # When AWS returns the attributes dict, it always contains these two elements, so we need to # automatically ensure they exist as well. if "Token" not in self.attributes: @@ -456,10 +493,10 @@ class PlatformEndpoint(BaseModel): self.attributes["Enabled"] = "true" @property - def enabled(self): + def enabled(self) -> bool: return json.loads(self.attributes.get("Enabled", "true").lower()) - def publish(self, message): + def publish(self, message: str) -> str: if not self.enabled: raise SnsEndpointDisabled(f"Endpoint {self.id} disabled") @@ -485,15 +522,15 @@ class SNSBackend(BaseBackend): Note that, as this is an internal API, the exact format may differ per versions. """ - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.topics = OrderedDict() + self.topics: Dict[str, Topic] = OrderedDict() self.subscriptions: OrderedDict[str, Subscription] = OrderedDict() - self.applications = {} - self.platform_endpoints = {} + self.applications: Dict[str, PlatformApplication] = {} + self.platform_endpoints: Dict[str, PlatformEndpoint] = {} self.region_name = region_name - self.sms_attributes = {} - self.sms_messages = OrderedDict() + self.sms_attributes: Dict[str, str] = {} + self.sms_messages: Dict[str, Tuple[str, str]] = OrderedDict() self.opt_out_numbers = [ "+447420500600", "+447420505401", @@ -506,23 +543,27 @@ class SNSBackend(BaseBackend): ] @staticmethod - def default_vpc_endpoint_service(service_region, zones): + def default_vpc_endpoint_service( + service_region: str, zones: List[str] + ) -> List[Dict[str, str]]: """List of dicts representing default VPC endpoints for this service.""" return BaseBackend.default_vpc_endpoint_service_factory( service_region, zones, "sns" ) - def update_sms_attributes(self, attrs): + def update_sms_attributes(self, attrs: Dict[str, str]) -> None: self.sms_attributes.update(attrs) - def create_topic(self, name, attributes=None, tags=None): + def create_topic( + self, + name: str, + attributes: Optional[Dict[str, str]] = None, + tags: Optional[Dict[str, str]] = None, + ) -> Topic: if attributes is None: attributes = {} - if ( - attributes.get("FifoTopic") - and attributes.get("FifoTopic").lower() == "true" - ): + if attributes.get("FifoTopic") and attributes["FifoTopic"].lower() == "true": fails_constraints = not re.match(r"^[a-zA-Z0-9_-]{1,256}\.fifo$", name) msg = "Fifo Topic names must end with .fifo and must be made up of only uppercase and lowercase ASCII letters, numbers, underscores, and hyphens, and must be between 1 and 256 characters long." @@ -549,29 +590,33 @@ class SNSBackend(BaseBackend): self.topics[candidate_topic.arn] = candidate_topic return candidate_topic - def _get_values_nexttoken(self, values_map, next_token=None): - if next_token is None or not next_token: - next_token = 0 - next_token = int(next_token) - values = list(values_map.values())[next_token : next_token + DEFAULT_PAGE_SIZE] + def _get_values_nexttoken( + self, values_map: Dict[str, Any], next_token: Optional[str] = None + ) -> Tuple[List[Any], Optional[int]]: + i_next_token = int(next_token or "0") + values = list(values_map.values())[ + i_next_token : i_next_token + DEFAULT_PAGE_SIZE + ] if len(values) == DEFAULT_PAGE_SIZE: - next_token = next_token + DEFAULT_PAGE_SIZE + i_next_token = i_next_token + DEFAULT_PAGE_SIZE else: - next_token = None - return values, next_token + i_next_token = None # type: ignore + return values, i_next_token - def _get_topic_subscriptions(self, topic): + def _get_topic_subscriptions(self, topic: Topic) -> List[Subscription]: return [sub for sub in self.subscriptions.values() if sub.topic == topic] - def list_topics(self, next_token=None): + def list_topics( + self, next_token: Optional[str] = None + ) -> Tuple[List[Topic], Optional[int]]: return self._get_values_nexttoken(self.topics, next_token) - def delete_topic_subscriptions(self, topic): + def delete_topic_subscriptions(self, topic: Topic) -> None: for key, value in dict(self.subscriptions).items(): if value.topic == topic: self.subscriptions.pop(key) - def delete_topic(self, arn): + def delete_topic(self, arn: str) -> None: try: topic = self.get_topic(arn) self.delete_topic_subscriptions(topic) @@ -579,17 +624,19 @@ class SNSBackend(BaseBackend): except KeyError: raise SNSNotFoundError(f"Topic with arn {arn} not found") - def get_topic(self, arn): + def get_topic(self, arn: str) -> Topic: try: return self.topics[arn] except KeyError: raise SNSNotFoundError(f"Topic with arn {arn} not found") - def set_topic_attribute(self, topic_arn, attribute_name, attribute_value): + def set_topic_attribute( + self, topic_arn: str, attribute_name: str, attribute_value: str + ) -> None: topic = self.get_topic(topic_arn) setattr(topic, attribute_name, attribute_value) - def subscribe(self, topic_arn, endpoint, protocol): + def subscribe(self, topic_arn: str, endpoint: str, protocol: str) -> Subscription: if protocol == "sms": if re.search(r"[./-]{2,}", endpoint) or re.search( r"(^[./-]|[./-]$)", endpoint @@ -625,7 +672,9 @@ class SNSBackend(BaseBackend): self.subscriptions[subscription.arn] = subscription return subscription - def _find_subscription(self, topic_arn, endpoint, protocol): + def _find_subscription( + self, topic_arn: str, endpoint: str, protocol: str + ) -> Optional[Subscription]: for subscription in self.subscriptions.values(): if ( subscription.topic.arn == topic_arn @@ -635,10 +684,12 @@ class SNSBackend(BaseBackend): return subscription return None - def unsubscribe(self, subscription_arn): + def unsubscribe(self, subscription_arn: str) -> None: self.subscriptions.pop(subscription_arn, None) - def list_subscriptions(self, topic_arn=None, next_token=None): + def list_subscriptions( + self, topic_arn: Optional[str] = None, next_token: Optional[str] = None + ) -> Tuple[List[Subscription], Optional[int]]: if topic_arn: topic = self.get_topic(topic_arn) filtered = OrderedDict( @@ -650,14 +701,14 @@ class SNSBackend(BaseBackend): def publish( self, - message, - arn=None, - phone_number=None, - subject=None, - message_attributes=None, - group_id=None, - deduplication_id=None, - ): + message: str, + arn: Optional[str], + phone_number: Optional[str] = None, + subject: Optional[str] = None, + message_attributes: Optional[Dict[str, Any]] = None, + group_id: Optional[str] = None, + deduplication_id: Optional[str] = None, + ) -> str: if subject is not None and len(subject) > 100: # Note that the AWS docs around length are wrong: https://github.com/getmoto/moto/issues/1503 raise ValueError("Subject must be less than 100 characters") @@ -677,7 +728,7 @@ class SNSBackend(BaseBackend): ) try: - topic = self.get_topic(arn) + topic = self.get_topic(arn) # type: ignore fifo_topic = topic.fifo_topic == "true" if fifo_topic: @@ -705,40 +756,48 @@ class SNSBackend(BaseBackend): deduplication_id=deduplication_id, ) except SNSNotFoundError: - endpoint = self.get_endpoint(arn) + endpoint = self.get_endpoint(arn) # type: ignore message_id = endpoint.publish(message) return message_id - def create_platform_application(self, name, platform, attributes): + def create_platform_application( + self, name: str, platform: str, attributes: Dict[str, str] + ) -> PlatformApplication: application = PlatformApplication( self.account_id, self.region_name, name, platform, attributes ) self.applications[application.arn] = application return application - def get_application(self, arn): + def get_application(self, arn: str) -> PlatformApplication: try: return self.applications[arn] except KeyError: raise SNSNotFoundError(f"Application with arn {arn} not found") - def set_application_attributes(self, arn, attributes): + def set_application_attributes( + self, arn: str, attributes: Dict[str, Any] + ) -> PlatformApplication: application = self.get_application(arn) application.attributes.update(attributes) return application - def list_platform_applications(self): + def list_platform_applications(self) -> Iterable[PlatformApplication]: return self.applications.values() - def delete_platform_application(self, platform_arn): + def delete_platform_application(self, platform_arn: str) -> None: self.applications.pop(platform_arn) endpoints = self.list_endpoints_by_platform_application(platform_arn) for endpoint in endpoints: self.platform_endpoints.pop(endpoint.arn) def create_platform_endpoint( - self, application, custom_user_data, token, attributes - ): + self, + application: PlatformApplication, + custom_user_data: str, + token: str, + attributes: Dict[str, str], + ) -> PlatformEndpoint: for endpoint in self.platform_endpoints.values(): if token == endpoint.token: if ( @@ -760,33 +819,37 @@ class SNSBackend(BaseBackend): self.platform_endpoints[platform_endpoint.arn] = platform_endpoint return platform_endpoint - def list_endpoints_by_platform_application(self, application_arn): + def list_endpoints_by_platform_application( + self, application_arn: str + ) -> List[PlatformEndpoint]: return [ endpoint for endpoint in self.platform_endpoints.values() if endpoint.application.arn == application_arn ] - def get_endpoint(self, arn): + def get_endpoint(self, arn: str) -> PlatformEndpoint: try: return self.platform_endpoints[arn] except KeyError: raise SNSNotFoundError("Endpoint does not exist") - def set_endpoint_attributes(self, arn, attributes): + def set_endpoint_attributes( + self, arn: str, attributes: Dict[str, Any] + ) -> PlatformEndpoint: endpoint = self.get_endpoint(arn) if "Enabled" in attributes: attributes["Enabled"] = attributes["Enabled"].lower() endpoint.attributes.update(attributes) return endpoint - def delete_endpoint(self, arn): + def delete_endpoint(self, arn: str) -> None: try: del self.platform_endpoints[arn] except KeyError: raise SNSNotFoundError(f"Endpoint with arn {arn} not found") - def get_subscription_attributes(self, arn): + def get_subscription_attributes(self, arn: str) -> Dict[str, Any]: subscription = self.subscriptions.get(arn) if not subscription: @@ -796,7 +859,7 @@ class SNSBackend(BaseBackend): return subscription.attributes - def set_subscription_attributes(self, arn, name, value): + def set_subscription_attributes(self, arn: str, name: str, value: Any) -> None: if name not in [ "RawMessageDelivery", "DeliveryPolicy", @@ -819,7 +882,7 @@ class SNSBackend(BaseBackend): self._validate_filter_policy(filter_policy) subscription._filter_policy = filter_policy - def _validate_filter_policy(self, value): + def _validate_filter_policy(self, value: Any) -> None: # TODO: extend validation checks combinations = 1 for rules in value.values(): @@ -931,7 +994,13 @@ class SNSBackend(BaseBackend): "Invalid parameter: FilterPolicy: Match value must be String, number, true, false, or null" ) - def add_permission(self, topic_arn, label, aws_account_ids, action_names): + def add_permission( + self, + topic_arn: str, + label: str, + aws_account_ids: List[str], + action_names: List[str], + ) -> None: if topic_arn not in self.topics: raise SNSNotFoundError("Topic does not exist") @@ -966,7 +1035,7 @@ class SNSBackend(BaseBackend): self.topics[topic_arn]._policy_json["Statement"].append(statement) - def remove_permission(self, topic_arn, label): + def remove_permission(self, topic_arn: str, label: str) -> None: if topic_arn not in self.topics: raise SNSNotFoundError("Topic does not exist") @@ -977,13 +1046,13 @@ class SNSBackend(BaseBackend): self.topics[topic_arn]._policy_json["Statement"] = statements - def list_tags_for_resource(self, resource_arn): + def list_tags_for_resource(self, resource_arn: str) -> Dict[str, str]: if resource_arn not in self.topics: raise ResourceNotFoundError return self.topics[resource_arn]._tags - def tag_resource(self, resource_arn, tags): + def tag_resource(self, resource_arn: str, tags: Dict[str, str]) -> None: if resource_arn not in self.topics: raise ResourceNotFoundError @@ -995,14 +1064,16 @@ class SNSBackend(BaseBackend): self.topics[resource_arn]._tags = updated_tags - def untag_resource(self, resource_arn, tag_keys): + def untag_resource(self, resource_arn: str, tag_keys: List[str]) -> None: if resource_arn not in self.topics: raise ResourceNotFoundError for key in tag_keys: self.topics[resource_arn]._tags.pop(key, None) - def publish_batch(self, topic_arn, publish_batch_request_entries): + def publish_batch( + self, topic_arn: str, publish_batch_request_entries: List[Dict[str, Any]] + ) -> Tuple[List[Dict[str, str]], List[Dict[str, Any]]]: """ The MessageStructure and MessageDeduplicationId-parameters have not yet been implemented. """ @@ -1027,8 +1098,8 @@ class SNSBackend(BaseBackend): "Invalid parameter: The MessageGroupId parameter is required for FIFO topics" ) - successful = [] - failed = [] + successful: List[Dict[str, str]] = [] + failed: List[Dict[str, Any]] = [] for entry in publish_batch_request_entries: try: diff --git a/moto/sns/responses.py b/moto/sns/responses.py index 21372218e..adad2dc96 100644 --- a/moto/sns/responses.py +++ b/moto/sns/responses.py @@ -1,10 +1,11 @@ import json import re from collections import defaultdict +from typing import Any, Dict, Tuple, Union from moto.core.responses import BaseResponse from moto.core.utils import camelcase_to_underscores -from .models import sns_backends +from .models import sns_backends, SNSBackend from .exceptions import InvalidParameterValue, SNSNotFoundError from .utils import is_e164 @@ -15,32 +16,34 @@ class SNSResponse(BaseResponse): ) OPT_OUT_PHONE_NUMBER_REGEX = re.compile(r"^\+?\d+$") - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="sns") @property - def backend(self): + def backend(self) -> SNSBackend: return sns_backends[self.current_account][self.region] - def _error(self, code, message, sender="Sender"): + def _error(self, code: str, message: str, sender: str = "Sender") -> str: template = self.response_template(ERROR_RESPONSE) return template.render(code=code, message=message, sender=sender) - def _get_attributes(self): + def _get_attributes(self) -> Dict[str, str]: attributes = self._get_list_prefix("Attributes.entry") return dict((attribute["key"], attribute["value"]) for attribute in attributes) - def _get_tags(self): + def _get_tags(self) -> Dict[str, str]: tags = self._get_list_prefix("Tags.member") return {tag["key"]: tag["value"] for tag in tags} - def _parse_message_attributes(self): + def _parse_message_attributes(self) -> Dict[str, Any]: message_attributes = self._get_object_map( "MessageAttributes.entry", name="Name", value="Value" ) return self._transform_message_attributes(message_attributes) - def _transform_message_attributes(self, message_attributes): + def _transform_message_attributes( + self, message_attributes: Dict[str, Any] + ) -> Dict[str, Any]: # SNS converts some key names before forwarding messages # DataType -> Type, StringValue -> Value, BinaryValue -> Value transformed_message_attributes = {} @@ -96,7 +99,7 @@ class SNSResponse(BaseResponse): return transformed_message_attributes - def create_topic(self): + def create_topic(self) -> str: name = self._get_param("Name") attributes = self._get_attributes() tags = self._get_tags() @@ -117,7 +120,7 @@ class SNSResponse(BaseResponse): template = self.response_template(CREATE_TOPIC_TEMPLATE) return template.render(topic=topic) - def list_topics(self): + def list_topics(self) -> str: next_token = self._get_param("NextToken") topics, next_token = self.backend.list_topics(next_token=next_token) @@ -139,7 +142,7 @@ class SNSResponse(BaseResponse): template = self.response_template(LIST_TOPICS_TEMPLATE) return template.render(topics=topics, next_token=next_token) - def delete_topic(self): + def delete_topic(self) -> str: topic_arn = self._get_param("TopicArn") self.backend.delete_topic(topic_arn) @@ -157,7 +160,7 @@ class SNSResponse(BaseResponse): template = self.response_template(DELETE_TOPIC_TEMPLATE) return template.render() - def get_topic_attributes(self): + def get_topic_attributes(self) -> str: topic_arn = self._get_param("TopicArn") topic = self.backend.get_topic(topic_arn) @@ -193,7 +196,7 @@ class SNSResponse(BaseResponse): template = self.response_template(GET_TOPIC_ATTRIBUTES_TEMPLATE) return template.render(topic=topic) - def set_topic_attributes(self): + def set_topic_attributes(self) -> str: topic_arn = self._get_param("TopicArn") attribute_name = self._get_param("AttributeName") attribute_name = camelcase_to_underscores(attribute_name) @@ -214,7 +217,7 @@ class SNSResponse(BaseResponse): template = self.response_template(SET_TOPIC_ATTRIBUTES_TEMPLATE) return template.render() - def subscribe(self): + def subscribe(self) -> str: topic_arn = self._get_param("TopicArn") endpoint = self._get_param("Endpoint") protocol = self._get_param("Protocol") @@ -243,7 +246,7 @@ class SNSResponse(BaseResponse): template = self.response_template(SUBSCRIBE_TEMPLATE) return template.render(subscription=subscription) - def unsubscribe(self): + def unsubscribe(self) -> str: subscription_arn = self._get_param("SubscriptionArn") self.backend.unsubscribe(subscription_arn) @@ -261,7 +264,7 @@ class SNSResponse(BaseResponse): template = self.response_template(UNSUBSCRIBE_TEMPLATE) return template.render() - def list_subscriptions(self): + def list_subscriptions(self) -> str: next_token = self._get_param("NextToken") subscriptions, next_token = self.backend.list_subscriptions( next_token=next_token @@ -294,7 +297,7 @@ class SNSResponse(BaseResponse): template = self.response_template(LIST_SUBSCRIPTIONS_TEMPLATE) return template.render(subscriptions=subscriptions, next_token=next_token) - def list_subscriptions_by_topic(self): + def list_subscriptions_by_topic(self) -> str: topic_arn = self._get_param("TopicArn") next_token = self._get_param("NextToken") subscriptions, next_token = self.backend.list_subscriptions( @@ -328,7 +331,7 @@ class SNSResponse(BaseResponse): template = self.response_template(LIST_SUBSCRIPTIONS_BY_TOPIC_TEMPLATE) return template.render(subscriptions=subscriptions, next_token=next_token) - def publish(self): + def publish(self) -> Union[str, Tuple[str, Dict[str, int]]]: target_arn = self._get_param("TargetArn") topic_arn = self._get_param("TopicArn") phone_number = self._get_param("PhoneNumber") @@ -384,7 +387,7 @@ class SNSResponse(BaseResponse): template = self.response_template(PUBLISH_TEMPLATE) return template.render(message_id=message_id) - def publish_batch(self): + def publish_batch(self) -> str: topic_arn = self._get_param("TopicArn") publish_batch_request_entries = self._get_multi_param( "PublishBatchRequestEntries.member" @@ -406,7 +409,7 @@ class SNSResponse(BaseResponse): template = self.response_template(PUBLISH_BATCH_TEMPLATE) return template.render(successful=successful, failed=failed) - def create_platform_application(self): + def create_platform_application(self) -> str: name = self._get_param("Name") platform = self._get_param("Platform") attributes = self._get_attributes() @@ -431,7 +434,7 @@ class SNSResponse(BaseResponse): template = self.response_template(CREATE_PLATFORM_APPLICATION_TEMPLATE) return template.render(platform_application=platform_application) - def get_platform_application_attributes(self): + def get_platform_application_attributes(self) -> str: arn = self._get_param("PlatformApplicationArn") application = self.backend.get_application(arn) @@ -452,7 +455,7 @@ class SNSResponse(BaseResponse): template = self.response_template(GET_PLATFORM_APPLICATION_ATTRIBUTES_TEMPLATE) return template.render(application=application) - def set_platform_application_attributes(self): + def set_platform_application_attributes(self) -> str: arn = self._get_param("PlatformApplicationArn") attributes = self._get_attributes() @@ -472,7 +475,7 @@ class SNSResponse(BaseResponse): template = self.response_template(SET_PLATFORM_APPLICATION_ATTRIBUTES_TEMPLATE) return template.render() - def list_platform_applications(self): + def list_platform_applications(self) -> str: applications = self.backend.list_platform_applications() if self.request_json: @@ -499,7 +502,7 @@ class SNSResponse(BaseResponse): template = self.response_template(LIST_PLATFORM_APPLICATIONS_TEMPLATE) return template.render(applications=applications) - def delete_platform_application(self): + def delete_platform_application(self) -> str: platform_arn = self._get_param("PlatformApplicationArn") self.backend.delete_platform_application(platform_arn) @@ -517,7 +520,7 @@ class SNSResponse(BaseResponse): template = self.response_template(DELETE_PLATFORM_APPLICATION_TEMPLATE) return template.render() - def create_platform_endpoint(self): + def create_platform_endpoint(self) -> str: application_arn = self._get_param("PlatformApplicationArn") application = self.backend.get_application(application_arn) @@ -546,7 +549,7 @@ class SNSResponse(BaseResponse): template = self.response_template(CREATE_PLATFORM_ENDPOINT_TEMPLATE) return template.render(platform_endpoint=platform_endpoint) - def list_endpoints_by_platform_application(self): + def list_endpoints_by_platform_application(self) -> str: application_arn = self._get_param("PlatformApplicationArn") endpoints = self.backend.list_endpoints_by_platform_application(application_arn) @@ -576,7 +579,7 @@ class SNSResponse(BaseResponse): ) return template.render(endpoints=endpoints) - def get_endpoint_attributes(self): + def get_endpoint_attributes(self) -> Union[str, Tuple[str, Dict[str, int]]]: arn = self._get_param("EndpointArn") try: endpoint = self.backend.get_endpoint(arn) @@ -601,7 +604,7 @@ class SNSResponse(BaseResponse): error_response = self._error("NotFound", "Endpoint does not exist") return error_response, dict(status=404) - def set_endpoint_attributes(self): + def set_endpoint_attributes(self) -> Union[str, Tuple[str, Dict[str, int]]]: arn = self._get_param("EndpointArn") attributes = self._get_attributes() @@ -621,7 +624,7 @@ class SNSResponse(BaseResponse): template = self.response_template(SET_ENDPOINT_ATTRIBUTES_TEMPLATE) return template.render() - def delete_endpoint(self): + def delete_endpoint(self) -> str: arn = self._get_param("EndpointArn") self.backend.delete_endpoint(arn) @@ -639,13 +642,13 @@ class SNSResponse(BaseResponse): template = self.response_template(DELETE_ENDPOINT_TEMPLATE) return template.render() - def get_subscription_attributes(self): + def get_subscription_attributes(self) -> str: arn = self._get_param("SubscriptionArn") attributes = self.backend.get_subscription_attributes(arn) template = self.response_template(GET_SUBSCRIPTION_ATTRIBUTES_TEMPLATE) return template.render(attributes=attributes) - def set_subscription_attributes(self): + def set_subscription_attributes(self) -> str: arn = self._get_param("SubscriptionArn") attr_name = self._get_param("AttributeName") attr_value = self._get_param("AttributeValue") @@ -653,12 +656,12 @@ class SNSResponse(BaseResponse): template = self.response_template(SET_SUBSCRIPTION_ATTRIBUTES_TEMPLATE) return template.render() - def set_sms_attributes(self): + def set_sms_attributes(self) -> str: # attributes.entry.1.key # attributes.entry.1.value # to # 1: {key:X, value:Y} - temp_dict = defaultdict(dict) + temp_dict: Dict[str, Any] = defaultdict(dict) for key, value in self.querystring.items(): match = self.SMS_ATTR_REGEX.match(key) if match is not None: @@ -678,7 +681,7 @@ class SNSResponse(BaseResponse): template = self.response_template(SET_SMS_ATTRIBUTES_TEMPLATE) return template.render() - def get_sms_attributes(self): + def get_sms_attributes(self) -> str: filter_list = set() for key, value in self.querystring.items(): if key.startswith("attributes.member.1"): @@ -694,7 +697,9 @@ class SNSResponse(BaseResponse): template = self.response_template(GET_SMS_ATTRIBUTES_TEMPLATE) return template.render(attributes=result) - def check_if_phone_number_is_opted_out(self): + def check_if_phone_number_is_opted_out( + self, + ) -> Union[str, Tuple[str, Dict[str, int]]]: number = self._get_param("phoneNumber") if self.OPT_OUT_PHONE_NUMBER_REGEX.match(number) is None: error_response = self._error( @@ -707,11 +712,11 @@ class SNSResponse(BaseResponse): template = self.response_template(CHECK_IF_OPTED_OUT_TEMPLATE) return template.render(opt_out=str(number.endswith("99")).lower()) - def list_phone_numbers_opted_out(self): + def list_phone_numbers_opted_out(self) -> str: template = self.response_template(LIST_OPTOUT_TEMPLATE) return template.render(opt_outs=self.backend.opt_out_numbers) - def opt_in_phone_number(self): + def opt_in_phone_number(self) -> str: number = self._get_param("phoneNumber") try: @@ -722,7 +727,7 @@ class SNSResponse(BaseResponse): template = self.response_template(OPT_IN_NUMBER_TEMPLATE) return template.render() - def add_permission(self): + def add_permission(self) -> str: topic_arn = self._get_param("TopicArn") label = self._get_param("Label") aws_account_ids = self._get_multi_param("AWSAccountId.member.") @@ -733,7 +738,7 @@ class SNSResponse(BaseResponse): template = self.response_template(ADD_PERMISSION_TEMPLATE) return template.render() - def remove_permission(self): + def remove_permission(self) -> str: topic_arn = self._get_param("TopicArn") label = self._get_param("Label") @@ -742,7 +747,7 @@ class SNSResponse(BaseResponse): template = self.response_template(DEL_PERMISSION_TEMPLATE) return template.render() - def confirm_subscription(self): + def confirm_subscription(self) -> Union[str, Tuple[str, Dict[str, int]]]: arn = self._get_param("TopicArn") if arn not in self.backend.topics: @@ -767,7 +772,7 @@ class SNSResponse(BaseResponse): template = self.response_template(CONFIRM_SUBSCRIPTION_TEMPLATE) return template.render(sub_arn=f"{arn}:68762e72-e9b1-410a-8b3b-903da69ee1d5") - def list_tags_for_resource(self): + def list_tags_for_resource(self) -> str: arn = self._get_param("ResourceArn") result = self.backend.list_tags_for_resource(arn) @@ -775,7 +780,7 @@ class SNSResponse(BaseResponse): template = self.response_template(LIST_TAGS_FOR_RESOURCE_TEMPLATE) return template.render(tags=result) - def tag_resource(self): + def tag_resource(self) -> str: arn = self._get_param("ResourceArn") tags = self._get_tags() @@ -783,7 +788,7 @@ class SNSResponse(BaseResponse): return self.response_template(TAG_RESOURCE_TEMPLATE).render() - def untag_resource(self): + def untag_resource(self) -> str: arn = self._get_param("ResourceArn") tag_keys = self._get_multi_param("TagKeys.member") diff --git a/moto/sns/utils.py b/moto/sns/utils.py index 561097546..8849aa153 100644 --- a/moto/sns/utils.py +++ b/moto/sns/utils.py @@ -4,14 +4,14 @@ from moto.moto_api._internal import mock_random E164_REGEX = re.compile(r"^\+?[1-9]\d{1,14}$") -def make_arn_for_topic(account_id, name, region_name): +def make_arn_for_topic(account_id: str, name: str, region_name: str) -> str: return f"arn:aws:sns:{region_name}:{account_id}:{name}" -def make_arn_for_subscription(topic_arn): +def make_arn_for_subscription(topic_arn: str) -> str: subscription_id = mock_random.uuid4() return f"{topic_arn}:{subscription_id}" -def is_e164(number): +def is_e164(number: str) -> bool: return E164_REGEX.match(number) is not None diff --git a/setup.cfg b/setup.cfg index cfaf0380e..7d7b3bef9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -239,7 +239,7 @@ disable = W,C,R,E enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import [mypy] -files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s3*,moto/sagemaker,moto/secretsmanager,moto/ses,moto/sqs,moto/ssm,moto/scheduler,moto/swf +files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s3*,moto/sagemaker,moto/secretsmanager,moto/ses,moto/sqs,moto/ssm,moto/scheduler,moto/swf,moto/sns show_column_numbers=True show_error_codes = True disable_error_code=abstract