diff --git a/moto/sqs/exceptions.py b/moto/sqs/exceptions.py index faf49255d..d2f483c4e 100644 --- a/moto/sqs/exceptions.py +++ b/moto/sqs/exceptions.py @@ -4,7 +4,7 @@ from moto.core.exceptions import RESTError class ReceiptHandleIsInvalid(RESTError): code = 400 - def __init__(self): + def __init__(self) -> None: super().__init__( "ReceiptHandleIsInvalid", "The input receipt handle is invalid." ) @@ -13,14 +13,14 @@ class ReceiptHandleIsInvalid(RESTError): class MessageAttributesInvalid(RESTError): code = 400 - def __init__(self, description): + def __init__(self, description: str): super().__init__("MessageAttributesInvalid", description) class QueueDoesNotExist(RESTError): code = 400 - def __init__(self): + def __init__(self) -> None: super().__init__( "AWS.SimpleQueueService.NonExistentQueue", "The specified queue does not exist for this wsdl version.", @@ -31,14 +31,14 @@ class QueueDoesNotExist(RESTError): class QueueAlreadyExists(RESTError): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("QueueAlreadyExists", message) class EmptyBatchRequest(RESTError): code = 400 - def __init__(self): + def __init__(self) -> None: super().__init__( "EmptyBatchRequest", "There should be at least one SendMessageBatchRequestEntry in the request.", @@ -48,7 +48,7 @@ class EmptyBatchRequest(RESTError): class InvalidBatchEntryId(RESTError): code = 400 - def __init__(self): + def __init__(self) -> None: super().__init__( "InvalidBatchEntryId", "A batch entry id can only contain alphanumeric characters, " @@ -59,7 +59,7 @@ class InvalidBatchEntryId(RESTError): class BatchRequestTooLong(RESTError): code = 400 - def __init__(self, length): + def __init__(self, length: int): super().__init__( "BatchRequestTooLong", "Batch requests cannot be longer than 262144 bytes. " @@ -70,14 +70,14 @@ class BatchRequestTooLong(RESTError): class BatchEntryIdsNotDistinct(RESTError): code = 400 - def __init__(self, entry_id): + def __init__(self, entry_id: str): super().__init__("BatchEntryIdsNotDistinct", f"Id {entry_id} repeated.") class TooManyEntriesInBatchRequest(RESTError): code = 400 - def __init__(self, number): + def __init__(self, number: int): super().__init__( "TooManyEntriesInBatchRequest", "Maximum number of entries per request are 10. " f"You have sent {number}.", @@ -87,14 +87,14 @@ class TooManyEntriesInBatchRequest(RESTError): class InvalidAttributeName(RESTError): code = 400 - def __init__(self, attribute_name): + def __init__(self, attribute_name: str): super().__init__("InvalidAttributeName", f"Unknown Attribute {attribute_name}.") class InvalidAttributeValue(RESTError): code = 400 - def __init__(self, attribute_name): + def __init__(self, attribute_name: str): super().__init__( "InvalidAttributeValue", f"Invalid value for the parameter {attribute_name}.", @@ -104,14 +104,14 @@ class InvalidAttributeValue(RESTError): class InvalidParameterValue(RESTError): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("InvalidParameterValue", message) class MissingParameter(RESTError): code = 400 - def __init__(self, parameter): + def __init__(self, parameter: str): super().__init__( "MissingParameter", f"The request must contain the parameter {parameter}." ) @@ -120,7 +120,7 @@ class MissingParameter(RESTError): class OverLimit(RESTError): code = 403 - def __init__(self, count): + def __init__(self, count: int): super().__init__( "OverLimit", f"{count} Actions were found, maximum allowed is 7." ) diff --git a/moto/sqs/models.py b/moto/sqs/models.py index 004aa86fb..7d775d274 100644 --- a/moto/sqs/models.py +++ b/moto/sqs/models.py @@ -6,8 +6,9 @@ import string import struct from copy import deepcopy -from typing import Dict +from typing import Any, Dict, List, Optional, Tuple, Set, TYPE_CHECKING from threading import Condition +from urllib.parse import ParseResult from xml.sax.saxutils import escape from moto.core.exceptions import RESTError @@ -38,6 +39,9 @@ from .exceptions import ( InvalidAttributeValue, ) +if TYPE_CHECKING: + from moto.awslambda.models import EventSourceMapping + DEFAULT_SENDER_ID = "AIDAIT2UOQQY3AUEKVGXU" MAXIMUM_MESSAGE_LENGTH = 262144 # 256 KiB @@ -67,31 +71,36 @@ DEDUPLICATION_TIME_IN_SECONDS = 300 class Message(BaseModel): - def __init__(self, message_id, body, system_attributes=None): + def __init__( + self, + message_id: str, + body: str, + system_attributes: Optional[Dict[str, Any]] = None, + ): self.id = message_id self._body = body - self.message_attributes = {} - self.receipt_handle = None - self._old_receipt_handles = [] + self.message_attributes: Dict[str, Any] = {} + self.receipt_handle: Optional[str] = None + self._old_receipt_handles: List[str] = [] self.sender_id = DEFAULT_SENDER_ID self.sent_timestamp = None - self.approximate_first_receive_timestamp = None + self.approximate_first_receive_timestamp: Optional[int] = None self.approximate_receive_count = 0 - self.deduplication_id = None - self.group_id = None - self.sequence_number = None - self.visible_at = 0 - self.delayed_until = 0 + self.deduplication_id: Optional[str] = None + self.group_id: Optional[str] = None + self.sequence_number: Optional[str] = None + self.visible_at = 0.0 + self.delayed_until = 0.0 self.system_attributes = system_attributes or {} @property - def body_md5(self): + def body_md5(self) -> str: md5 = md5_hash() md5.update(self._body.encode("utf-8")) return md5.hexdigest() @property - def attribute_md5(self): + def attribute_md5(self) -> str: md5 = md5_hash() @@ -129,13 +138,13 @@ class Message(BaseModel): return md5.hexdigest() @staticmethod - def update_binary_length_and_value(md5, value): + def update_binary_length_and_value(md5: Any, value: bytes) -> None: # type: ignore[misc] length_bytes = struct.pack("!I".encode("ascii"), len(value)) md5.update(length_bytes) md5.update(value) @staticmethod - def validate_attribute_name(name): + def validate_attribute_name(name: str) -> None: if not ATTRIBUTE_NAME_PATTERN.match(name): raise MessageAttributesInvalid( f"The message attribute name '{name}' is invalid. " @@ -144,21 +153,21 @@ class Message(BaseModel): ) @staticmethod - def utf8(value): + def utf8(value: Any) -> bytes: # type: ignore[misc] if isinstance(value, str): return value.encode("utf-8") return value @property - def body(self): + def body(self) -> str: return escape(self._body).replace('"', """).replace("\r", " ") - def mark_sent(self, delay_seconds=None): - self.sent_timestamp = int(unix_time_millis()) + def mark_sent(self, delay_seconds: Optional[int] = None) -> None: + self.sent_timestamp = int(unix_time_millis()) # type: ignore if delay_seconds: self.delay(delay_seconds=delay_seconds) - def mark_received(self, visibility_timeout=None): + def mark_received(self, visibility_timeout: Optional[int] = None) -> None: """ When a message is received we will set the first receive timestamp, tap the ``approximate_receive_count`` and the ``visible_at`` time. @@ -178,37 +187,37 @@ class Message(BaseModel): if visibility_timeout: self.change_visibility(visibility_timeout) - self._old_receipt_handles.append(self.receipt_handle) + self._old_receipt_handles.append(self.receipt_handle) # type: ignore self.receipt_handle = generate_receipt_handle() - def change_visibility(self, visibility_timeout): + def change_visibility(self, visibility_timeout: int) -> None: # We're dealing with milliseconds internally visibility_timeout_msec = int(visibility_timeout) * 1000 self.visible_at = unix_time_millis() + visibility_timeout_msec - def delay(self, delay_seconds): + def delay(self, delay_seconds: int) -> None: delay_msec = int(delay_seconds) * 1000 self.delayed_until = unix_time_millis() + delay_msec @property - def visible(self): + def visible(self) -> bool: current_time = unix_time_millis() if current_time > self.visible_at: return True return False @property - def delayed(self): + def delayed(self) -> bool: current_time = unix_time_millis() if current_time < self.delayed_until: return True return False @property - def all_receipt_handles(self): - return [self.receipt_handle] + self._old_receipt_handles + def all_receipt_handles(self) -> List[Optional[str]]: + return [self.receipt_handle] + self._old_receipt_handles # type: ignore - def had_receipt_handle(self, receipt_handle): + def had_receipt_handle(self, receipt_handle: str) -> bool: """ Check if this message ever had this receipt_handle in the past """ @@ -250,24 +259,25 @@ class Queue(CloudFormationModel): "SendMessage", ) - def __init__(self, name, region, account_id, **kwargs): + def __init__(self, name: str, region: str, account_id: str, **kwargs: Any): self.name = name self.region = region self.account_id = account_id - self.tags = {} - self.permissions = {} + self.tags: Dict[str, str] = {} + self.permissions: Dict[str, Any] = {} - self._messages = [] - self._pending_messages = set() - self.deleted_messages = set() + self._messages: List[Message] = [] + self._pending_messages: Set[Message] = set() + self.deleted_messages: Set[str] = set() self._messages_lock = Condition() now = unix_time() self.created_timestamp = now self.queue_arn = f"arn:aws:sqs:{region}:{account_id}:{name}" - self.dead_letter_queue = None + self.dead_letter_queue: Optional["Queue"] = None + self.fifo_queue = False - self.lambda_event_source_mappings = {} + self.lambda_event_source_mappings: Dict[str, "EventSourceMapping"] = {} # default settings for a non fifo queue defaults = { @@ -293,24 +303,26 @@ class Queue(CloudFormationModel): if self.fifo_queue and not self.name.endswith(".fifo"): raise InvalidParameterValue("Queue name must end in .fifo for FIFO queues") if ( - self.maximum_message_size < MAXIMUM_MESSAGE_SIZE_ATTR_LOWER_BOUND - or self.maximum_message_size > MAXIMUM_MESSAGE_SIZE_ATTR_UPPER_BOUND + self.maximum_message_size < MAXIMUM_MESSAGE_SIZE_ATTR_LOWER_BOUND # type: ignore + or self.maximum_message_size > MAXIMUM_MESSAGE_SIZE_ATTR_UPPER_BOUND # type: ignore ): raise InvalidAttributeValue("MaximumMessageSize") @property - def pending_messages(self): + def pending_messages(self) -> Set[Message]: return self._pending_messages @property - def pending_message_groups(self): + def pending_message_groups(self) -> Set[str]: return set( message.group_id for message in self._pending_messages if message.group_id is not None ) - def _set_attributes(self, attributes, now=None): + def _set_attributes( + self, attributes: Dict[str, Any], now: Optional[float] = None + ) -> None: if not now: now = unix_time() @@ -343,7 +355,7 @@ class Queue(CloudFormationModel): self.last_modified_timestamp = now @staticmethod - def _is_empty_redrive_policy(policy): + def _is_empty_redrive_policy(policy: Any) -> bool: # type: ignore[misc] if isinstance(policy, str): if policy == "" or len(json.loads(policy)) == 0: return True @@ -352,7 +364,7 @@ class Queue(CloudFormationModel): return False - def _setup_dlq(self, policy): + def _setup_dlq(self, policy: Any) -> None: if Queue._is_empty_redrive_policy(policy): self.redrive_policy = None self.dead_letter_queue = None @@ -407,18 +419,23 @@ class Queue(CloudFormationModel): ) @staticmethod - def cloudformation_name_type(): + def cloudformation_name_type() -> str: return "QueueName" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sqs-queue.html return "AWS::SQS::Queue" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "Queue": properties = deepcopy(cloudformation_json["Properties"]) # remove Tags from properties and convert tags list to dict tags = properties.pop("Tags", []) @@ -433,14 +450,14 @@ class Queue(CloudFormationModel): ) @classmethod - def update_from_cloudformation_json( + def update_from_cloudformation_json( # type: ignore[misc] cls, - original_resource, - new_resource_name, - cloudformation_json, - account_id, - region_name, - ): + original_resource: Any, + new_resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + ) -> "Queue": properties = cloudformation_json["Properties"] queue_name = original_resource.name @@ -456,9 +473,13 @@ class Queue(CloudFormationModel): return queue @classmethod - def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name - ): + def delete_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + ) -> None: # ResourceName will be the full queue URL - we only need the name # https://sqs.us-west-1.amazonaws.com/123456789012/queue_name queue_name = resource_name.split("/")[-1] @@ -466,24 +487,24 @@ class Queue(CloudFormationModel): sqs_backend.delete_queue(queue_name) @property - def approximate_number_of_messages_delayed(self): + def approximate_number_of_messages_delayed(self) -> int: return len([m for m in self._messages if m.delayed]) @property - def approximate_number_of_messages_not_visible(self): + def approximate_number_of_messages_not_visible(self) -> int: return len([m for m in self._messages if not m.visible]) @property - def approximate_number_of_messages(self): + def approximate_number_of_messages(self) -> int: return len(self.messages) @property - def physical_resource_id(self): + def physical_resource_id(self) -> str: return f"https://sqs.{self.region}.amazonaws.com/{self.account_id}/{self.name}" @property - def attributes(self): - result = {} + def attributes(self) -> Dict[str, Any]: # type: ignore[misc] + result: Dict[str, Any] = {} for attribute in self.BASE_ATTRIBUTES: attr = getattr(self, camelcase_to_underscores(attribute)) @@ -494,7 +515,7 @@ class Queue(CloudFormationModel): attr = getattr(self, camelcase_to_underscores(attribute)) result[attribute] = attr - if self.kms_master_key_id: + if self.kms_master_key_id: # type: ignore for attribute in self.KMS_ATTRIBUTES: attr = getattr(self, camelcase_to_underscores(attribute)) result[attribute] = attr @@ -511,13 +532,11 @@ class Queue(CloudFormationModel): return result - def url(self, request_url): - return ( - f"{request_url.scheme}://{request_url.netloc}/{self.account_id}/{self.name}" - ) + def url(self, request_url: ParseResult) -> str: + return f"{request_url.scheme}://{request_url.netloc}/{self.account_id}/{self.name}" # type: ignore @property - def messages(self): + def messages(self) -> List[Message]: # TODO: This can become very inefficient if a large number of messages are in-flight return [ message @@ -525,7 +544,7 @@ class Queue(CloudFormationModel): if message.visible and not message.delayed ] - def add_message(self, message): + def add_message(self, message: Message) -> None: if self.fifo_queue: # the cases in which we dedupe fifo messages @@ -537,7 +556,7 @@ class Queue(CloudFormationModel): ): for m in self._messages: if m.deduplication_id == message.deduplication_id: - diff = message.sent_timestamp - m.sent_timestamp + diff = message.sent_timestamp - m.sent_timestamp # type: ignore # if a duplicate message is received within the deduplication time then it should # not be added to the queue if diff / 1000 < DEDUPLICATION_TIME_IN_SECONDS: @@ -559,8 +578,8 @@ class Queue(CloudFormationModel): messages = backend.receive_message( self.name, esm.batch_size, - self.receive_message_wait_time_seconds, - self.visibility_timeout, + self.receive_message_wait_time_seconds, # type: ignore + self.visibility_timeout, # type: ignore ) from moto.awslambda import lambda_backends @@ -580,7 +599,7 @@ class Queue(CloudFormationModel): for m in messages ] - def delete_message(self, receipt_handle): + def delete_message(self, receipt_handle: str) -> None: if receipt_handle in self.deleted_messages: # Already deleted - gracefully handle deleting it again return @@ -595,20 +614,20 @@ class Queue(CloudFormationModel): for message in self._messages: if message.had_receipt_handle(receipt_handle): self.pending_messages.discard(message) - self.deleted_messages.update(message.all_receipt_handles) + self.deleted_messages.update(message.all_receipt_handles) # type: ignore continue new_messages.append(message) self._messages = new_messages - def wait_for_messages(self, timeout): + def wait_for_messages(self, timeout: int) -> None: with self._messages_lock: self._messages_lock.wait_for(lambda: self.messages, timeout=timeout) @classmethod - def has_cfn_attr(cls, attr): + def has_cfn_attr(cls, attr: str) -> bool: return attr in ["Arn", "QueueName"] - def get_cfn_attribute(self, attribute_name): + def get_cfn_attribute(self, attribute_name: str) -> str: from moto.cloudformation.exceptions import UnformattedGetAttTemplateException if attribute_name == "Arn": @@ -618,14 +637,14 @@ class Queue(CloudFormationModel): raise UnformattedGetAttTemplateException() @property - def policy(self): + def policy(self) -> Any: # type: ignore[misc] if self._policy_json.get("Statement"): return json.dumps(self._policy_json) else: return None @policy.setter - def policy(self, policy): + def policy(self, policy: Any) -> None: if policy: self._policy_json = json.loads(policy) else: @@ -636,7 +655,9 @@ class Queue(CloudFormationModel): } -def _filter_message_attributes(message, input_message_attributes): +def _filter_message_attributes( + message: Message, input_message_attributes: List[str] +) -> None: filtered_message_attributes = {} return_all = "All" in input_message_attributes for key, value in message.message_attributes.items(): @@ -646,18 +667,22 @@ def _filter_message_attributes(message, input_message_attributes): class SQSBackend(BaseBackend): - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) self.queues: Dict[str, Queue] = {} @staticmethod - def default_vpc_endpoint_service(service_region, zones): + def default_vpc_endpoint_service( + service_region: str, zones: List[str] + ) -> List[Dict[str, str]]: """Default VPC endpoint service.""" return BaseBackend.default_vpc_endpoint_service_factory( service_region, zones, "sqs" ) - def create_queue(self, name, tags=None, **kwargs): + def create_queue( + self, name: str, tags: Optional[Dict[str, str]] = None, **kwargs: Any + ) -> Queue: queue = self.queues.get(name) if queue: try: @@ -692,10 +717,10 @@ class SQSBackend(BaseBackend): return queue - def get_queue_url(self, queue_name): + def get_queue_url(self, queue_name: str) -> Queue: return self.get_queue(queue_name) - def list_queues(self, queue_name_prefix): + def list_queues(self, queue_name_prefix: str) -> List[Queue]: re_str = ".*" if queue_name_prefix: re_str = f"^{queue_name_prefix}.*" @@ -706,18 +731,20 @@ class SQSBackend(BaseBackend): qs.append(q) return qs[:1000] - def get_queue(self, queue_name): + def get_queue(self, queue_name: str) -> Queue: queue = self.queues.get(queue_name) if queue is None: raise QueueDoesNotExist() return queue - def delete_queue(self, queue_name): + def delete_queue(self, queue_name: str) -> None: self.get_queue(queue_name) del self.queues[queue_name] - def get_queue_attributes(self, queue_name, attribute_names): + def get_queue_attributes( + self, queue_name: str, attribute_names: List[str] + ) -> Dict[str, Any]: queue = self.get_queue(queue_name) if not attribute_names: return {} @@ -746,21 +773,23 @@ class SQSBackend(BaseBackend): return attributes - def set_queue_attributes(self, queue_name, attributes): + def set_queue_attributes( + self, queue_name: str, attributes: Dict[str, Any] + ) -> Queue: queue = self.get_queue(queue_name) queue._set_attributes(attributes) return queue def send_message( self, - queue_name, - message_body, - message_attributes=None, - delay_seconds=None, - deduplication_id=None, - group_id=None, - system_attributes=None, - ): + queue_name: str, + message_body: str, + message_attributes: Optional[Dict[str, Any]] = None, + delay_seconds: Optional[int] = None, + deduplication_id: Optional[str] = None, + group_id: Optional[str] = None, + system_attributes: Optional[Dict[str, Any]] = None, + ) -> Message: queue = self.get_queue(queue_name) @@ -783,14 +812,14 @@ class SQSBackend(BaseBackend): ) raise InvalidParameterValue(msg) - if len(message_body) > queue.maximum_message_size: - msg = f"One or more parameters are invalid. Reason: Message must be shorter than {queue.maximum_message_size} bytes." + if len(message_body) > queue.maximum_message_size: # type: ignore + msg = f"One or more parameters are invalid. Reason: Message must be shorter than {queue.maximum_message_size} bytes." # type: ignore raise InvalidParameterValue(msg) if delay_seconds: delay_seconds = int(delay_seconds) else: - delay_seconds = queue.delay_seconds + delay_seconds = queue.delay_seconds # type: ignore message_id = str(random.uuid4()) message = Message(message_id, message_body, system_attributes) @@ -826,7 +855,7 @@ class SQSBackend(BaseBackend): if message_attributes: message.message_attributes = message_attributes - if delay_seconds > MAXIMUM_MESSAGE_DELAY: + if delay_seconds > MAXIMUM_MESSAGE_DELAY: # type: ignore msg = ( f"Value {delay_seconds} for parameter DelaySeconds is invalid. " "Reason: DelaySeconds must be >= 0 and <= 900." @@ -839,7 +868,9 @@ class SQSBackend(BaseBackend): return message - def send_message_batch(self, queue_name, entries): + def send_message_batch( + self, queue_name: str, entries: Dict[str, Dict[str, Any]] + ) -> Tuple[List[Message], List[Dict[str, Any]]]: self.get_queue(queue_name) if any( @@ -880,14 +911,14 @@ class SQSBackend(BaseBackend): group_id=entry.get("MessageGroupId"), deduplication_id=entry.get("MessageDeduplicationId"), ) - message.user_id = entry["Id"] + message.user_id = entry["Id"] # type: ignore[attr-defined] messages.append(message) except InvalidParameterValue: failedInvalidDelay.append(entry) return messages, failedInvalidDelay - def _get_first_duplicate_id(self, ids): + def _get_first_duplicate_id(self, ids: List[str]) -> Optional[str]: unique_ids = set() for _id in ids: if _id in unique_ids: @@ -897,12 +928,12 @@ class SQSBackend(BaseBackend): def receive_message( self, - queue_name, - count, - wait_seconds_timeout, - visibility_timeout, - message_attribute_names=None, - ): + queue_name: str, + count: int, + wait_seconds_timeout: int, + visibility_timeout: int, + message_attribute_names: Optional[List[str]] = None, + ) -> List[Message]: # Attempt to retrieve visible messages from a queue. # If a message was read by client and not deleted it is considered to be @@ -913,7 +944,7 @@ class SQSBackend(BaseBackend): if message_attribute_names is None: message_attribute_names = [] queue = self.get_queue(queue_name) - result = [] + result: List[Message] = [] previous_result_count = len(result) polling_end = unix_time() + wait_seconds_timeout @@ -925,7 +956,7 @@ class SQSBackend(BaseBackend): if result or (wait_seconds_timeout and unix_time() > polling_end): break - messages_to_dlq = [] + messages_to_dlq: List[Message] = [] for message in queue.messages: if not message.visible: @@ -966,7 +997,7 @@ class SQSBackend(BaseBackend): for message in messages_to_dlq: queue._messages.remove(message) - queue.dead_letter_queue.add_message(message) + queue.dead_letter_queue.add_message(message) # type: ignore if previous_result_count == len(result): if wait_seconds_timeout == 0: @@ -981,12 +1012,14 @@ class SQSBackend(BaseBackend): return result - def delete_message(self, queue_name, receipt_handle): + def delete_message(self, queue_name: str, receipt_handle: str) -> None: queue = self.get_queue(queue_name) queue.delete_message(receipt_handle) - def delete_message_batch(self, queue_name, receipts): + def delete_message_batch( + self, queue_name: str, receipts: List[Dict[str, Any]] + ) -> Tuple[List[str], List[Dict[str, str]]]: success = [] errors = [] for receipt_and_id in receipts: @@ -1004,14 +1037,16 @@ class SQSBackend(BaseBackend): ) return success, errors - def change_message_visibility(self, queue_name, receipt_handle, visibility_timeout): + def change_message_visibility( + self, queue_name: str, receipt_handle: str, visibility_timeout: int + ) -> None: queue = self.get_queue(queue_name) for message in queue._messages: if message.had_receipt_handle(receipt_handle): visibility_timeout_msec = int(visibility_timeout) * 1000 given_visibility_timeout = unix_time_millis() + visibility_timeout_msec - if given_visibility_timeout - message.sent_timestamp > 43200 * 1000: + if given_visibility_timeout - message.sent_timestamp > 43200 * 1000: # type: ignore raise InvalidParameterValue( f"Value {visibility_timeout} for parameter VisibilityTimeout is invalid. Reason: Total " "VisibilityTimeout for the message is beyond the limit [43200 seconds]" @@ -1025,7 +1060,9 @@ class SQSBackend(BaseBackend): return raise ReceiptHandleIsInvalid - def change_message_visibility_batch(self, queue_name: str, entries): + def change_message_visibility_batch( + self, queue_name: str, entries: List[Dict[str, Any]] + ) -> Tuple[List[str], List[Dict[str, str]]]: success = [] error = [] for entry in entries: @@ -1061,22 +1098,24 @@ class SQSBackend(BaseBackend): ) return success, error - def purge_queue(self, queue_name): + def purge_queue(self, queue_name: str) -> None: queue = self.get_queue(queue_name) queue._messages = [] queue._pending_messages = set() - def list_dead_letter_source_queues(self, queue_name): + def list_dead_letter_source_queues(self, queue_name: str) -> List[Queue]: dlq = self.get_queue(queue_name) - queues = [] + queues: List[Queue] = [] for queue in self.queues.values(): if queue.dead_letter_queue is dlq: queues.append(queue) return queues - def add_permission(self, queue_name, actions, account_ids, label): + def add_permission( + self, queue_name: str, actions: List[str], account_ids: List[str], label: str + ) -> None: queue = self.get_queue(queue_name) if not actions: @@ -1128,7 +1167,7 @@ class SQSBackend(BaseBackend): queue._policy_json["Statement"].append(statement) - def remove_permission(self, queue_name, label): + def remove_permission(self, queue_name: str, label: str) -> None: queue = self.get_queue(queue_name) statements = queue._policy_json["Statement"] @@ -1144,7 +1183,7 @@ class SQSBackend(BaseBackend): queue._policy_json["Statement"] = statements_new - def tag_queue(self, queue_name, tags): + def tag_queue(self, queue_name: str, tags: Dict[str, str]) -> None: queue = self.get_queue(queue_name) if not len(tags): @@ -1155,7 +1194,7 @@ class SQSBackend(BaseBackend): queue.tags.update(tags) - def untag_queue(self, queue_name, tag_keys): + def untag_queue(self, queue_name: str, tag_keys: List[str]) -> None: queue = self.get_queue(queue_name) if not len(tag_keys): @@ -1170,17 +1209,16 @@ class SQSBackend(BaseBackend): except KeyError: pass - def list_queue_tags(self, queue_name): + def list_queue_tags(self, queue_name: str) -> Queue: return self.get_queue(queue_name) - def is_message_valid_based_on_retention_period(self, queue_name, message): - message_attributes = self.get_queue_attributes( + def is_message_valid_based_on_retention_period( + self, queue_name: str, message: Message + ) -> bool: + retention_period = self.get_queue_attributes( queue_name, ["MessageRetentionPeriod"] - ) - retain_until = ( - message_attributes.get("MessageRetentionPeriod") - + message.sent_timestamp / 1000 - ) + )["MessageRetentionPeriod"] + retain_until = retention_period + message.sent_timestamp / 1000 # type: ignore if retain_until <= unix_time(): return False return True diff --git a/moto/sqs/responses.py b/moto/sqs/responses.py index 90165c0ee..56d1d93e1 100644 --- a/moto/sqs/responses.py +++ b/moto/sqs/responses.py @@ -1,5 +1,7 @@ import re +from typing import Any, Dict, Optional, Tuple, Union +from moto.core.common_types import TYPE_RESPONSE from moto.core.exceptions import RESTError from moto.core.responses import BaseResponse from moto.core.utils import underscores_to_camelcase, camelcase_to_pascal @@ -24,7 +26,7 @@ class SQSResponse(BaseResponse): region_regex = re.compile(r"://(.+?)\.queue\.amazonaws\.com") - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="sqs") @property @@ -32,7 +34,7 @@ class SQSResponse(BaseResponse): return sqs_backends[self.current_account][self.region] @property - def attribute(self): + def attribute(self) -> Any: # type: ignore[misc] if not hasattr(self, "_attribute"): self._attribute = self._get_map_prefix( "Attribute", key_end=".Name", value_end=".Value" @@ -40,14 +42,14 @@ class SQSResponse(BaseResponse): return self._attribute @property - def tags(self): + def tags(self) -> Dict[str, str]: if not hasattr(self, "_tags"): self._tags = self._get_map_prefix("Tag", key_end=".Key", value_end=".Value") return self._tags - def _get_queue_name(self): + def _get_queue_name(self) -> str: try: - queue_url = self.querystring.get("QueueUrl")[0] + queue_url = self.querystring.get("QueueUrl")[0] # type: ignore if queue_url.startswith("http://") or queue_url.startswith("https://"): return queue_url.split("/")[-1] else: @@ -57,7 +59,7 @@ class SQSResponse(BaseResponse): # Fallback to reading from the URL for botocore return self.path.split("/")[-1] - def _get_validated_visibility_timeout(self, timeout=None): + def _get_validated_visibility_timeout(self, timeout: Optional[str] = None) -> int: """ :raises ValueError: If specified visibility timeout exceeds MAXIMUM_VISIBILITY_TIMEOUT :raises TypeError: If visibility timeout was not specified @@ -65,7 +67,7 @@ class SQSResponse(BaseResponse): if timeout is not None: visibility_timeout = int(timeout) else: - visibility_timeout = int(self.querystring.get("VisibilityTimeout")[0]) + visibility_timeout = int(self.querystring.get("VisibilityTimeout")[0]) # type: ignore if visibility_timeout > MAXIMUM_VISIBILITY_TIMEOUT: raise ValueError @@ -74,7 +76,7 @@ class SQSResponse(BaseResponse): @amz_crc32 # crc last as request_id can edit XML @amzn_request_id - def call_action(self): + def call_action(self) -> TYPE_RESPONSE: status_code, headers, body = super().call_action() if status_code == 404: queue_name = self.querystring.get("QueueName", [""])[0] @@ -83,11 +85,13 @@ class SQSResponse(BaseResponse): return 404, headers, response return status_code, headers, body - def _error(self, code, message, status=400): + def _error( + self, code: str, message: str, status: int = 400 + ) -> Tuple[str, Dict[str, int]]: template = self.response_template(ERROR_TEMPLATE) return template.render(code=code, message=message), dict(status=status) - def create_queue(self): + def create_queue(self) -> str: request_url = urlparse(self.uri) queue_name = self._get_param("QueueName") @@ -96,7 +100,7 @@ class SQSResponse(BaseResponse): template = self.response_template(CREATE_QUEUE_RESPONSE) return template.render(queue_url=queue.url(request_url)) - def get_queue_url(self): + def get_queue_url(self) -> str: request_url = urlparse(self.uri) queue_name = self._get_param("QueueName") @@ -105,14 +109,14 @@ class SQSResponse(BaseResponse): template = self.response_template(GET_QUEUE_URL_RESPONSE) return template.render(queue_url=queue.url(request_url)) - def list_queues(self): + def list_queues(self) -> str: request_url = urlparse(self.uri) queue_name_prefix = self._get_param("QueueNamePrefix") queues = self.sqs_backend.list_queues(queue_name_prefix) template = self.response_template(LIST_QUEUES_RESPONSE) return template.render(queues=queues, request_url=request_url) - def change_message_visibility(self): + def change_message_visibility(self) -> Union[str, Tuple[str, Dict[str, int]]]: queue_name = self._get_queue_name() receipt_handle = self._get_param("ReceiptHandle") @@ -130,7 +134,7 @@ class SQSResponse(BaseResponse): template = self.response_template(CHANGE_MESSAGE_VISIBILITY_RESPONSE) return template.render() - def change_message_visibility_batch(self): + def change_message_visibility_batch(self) -> str: queue_name = self._get_queue_name() entries = self._get_list_prefix("ChangeMessageVisibilityBatchRequestEntry") @@ -141,24 +145,23 @@ class SQSResponse(BaseResponse): template = self.response_template(CHANGE_MESSAGE_VISIBILITY_BATCH_RESPONSE) return template.render(success=success, errors=error) - def get_queue_attributes(self): + def get_queue_attributes(self) -> str: queue_name = self._get_queue_name() if self.querystring.get("AttributeNames"): raise InvalidAttributeName("") - attribute_names = self._get_multi_param("AttributeName") - # if connecting to AWS via boto, then 'AttributeName' is just a normal parameter - if not attribute_names: - attribute_names = self.querystring.get("AttributeName") + attribute_names = self._get_multi_param( + "AttributeName" + ) or self.querystring.get("AttributeName") - attributes = self.sqs_backend.get_queue_attributes(queue_name, attribute_names) + attributes = self.sqs_backend.get_queue_attributes(queue_name, attribute_names) # type: ignore template = self.response_template(GET_QUEUE_ATTRIBUTES_RESPONSE) return template.render(attributes=attributes) - def set_queue_attributes(self): + def set_queue_attributes(self) -> str: # TODO validate self.get_param('QueueUrl') attribute = self.attribute @@ -174,7 +177,7 @@ class SQSResponse(BaseResponse): return SET_QUEUE_ATTRIBUTE_RESPONSE - def delete_queue(self): + def delete_queue(self) -> str: # TODO validate self.get_param('QueueUrl') queue_name = self._get_queue_name() @@ -183,7 +186,7 @@ class SQSResponse(BaseResponse): template = self.response_template(DELETE_QUEUE_RESPONSE) return template.render() - def send_message(self): + def send_message(self) -> Union[str, Tuple[str, Dict[str, int]]]: message = self._get_param("MessageBody") delay_seconds = int(self._get_param("DelaySeconds", 0)) message_group_id = self._get_param("MessageGroupId") @@ -215,7 +218,7 @@ class SQSResponse(BaseResponse): template = self.response_template(SEND_MESSAGE_RESPONSE) return template.render(message=message, message_attributes=message_attributes) - def send_message_batch(self): + def send_message_batch(self) -> str: """ The querystring comes like this @@ -247,7 +250,7 @@ class SQSResponse(BaseResponse): entries[index] = { "Id": value[0], - "MessageBody": self.querystring.get( + "MessageBody": self.querystring.get( # type: ignore f"SendMessageBatchRequestEntry.{index}.MessageBody" )[0], "DelaySeconds": self.querystring.get( @@ -286,14 +289,14 @@ class SQSResponse(BaseResponse): template = self.response_template(SEND_MESSAGE_BATCH_RESPONSE) return template.render(messages=messages, errors=errors) - def delete_message(self): + def delete_message(self) -> str: queue_name = self._get_queue_name() - receipt_handle = self.querystring.get("ReceiptHandle")[0] + receipt_handle = self.querystring.get("ReceiptHandle")[0] # type: ignore self.sqs_backend.delete_message(queue_name, receipt_handle) template = self.response_template(DELETE_MESSAGE_RESPONSE) return template.render() - def delete_message_batch(self): + def delete_message_batch(self) -> str: """ The querystring comes like this @@ -316,7 +319,7 @@ class SQSResponse(BaseResponse): break message_user_id_key = f"DeleteMessageBatchRequestEntry.{index}.Id" - message_user_id = self.querystring.get(message_user_id_key)[0] + message_user_id = self.querystring.get(message_user_id_key)[0] # type: ignore receipts.append( {"receipt_handle": receipt_handle[0], "msg_user_id": message_user_id} ) @@ -333,13 +336,13 @@ class SQSResponse(BaseResponse): template = self.response_template(DELETE_MESSAGE_BATCH_RESPONSE) return template.render(success=success, errors=errors) - def purge_queue(self): + def purge_queue(self) -> str: queue_name = self._get_queue_name() self.sqs_backend.purge_queue(queue_name) template = self.response_template(PURGE_QUEUE_RESPONSE) return template.render() - def receive_message(self): + def receive_message(self) -> Union[str, Tuple[str, Dict[str, int]]]: queue_name = self._get_queue_name() message_attributes = self._get_multi_param("message_attributes") if not message_attributes: @@ -350,7 +353,7 @@ class SQSResponse(BaseResponse): queue = self.sqs_backend.get_queue(queue_name) try: - message_count = int(self.querystring.get("MaxNumberOfMessages")[0]) + message_count = int(self.querystring.get("MaxNumberOfMessages")[0]) # type: ignore except TypeError: message_count = DEFAULT_RECEIVED_MESSAGES @@ -364,9 +367,9 @@ class SQSResponse(BaseResponse): ) try: - wait_time = int(self.querystring.get("WaitTimeSeconds")[0]) + wait_time = int(self.querystring.get("WaitTimeSeconds")[0]) # type: ignore except TypeError: - wait_time = int(queue.receive_message_wait_time_seconds) + wait_time = int(queue.receive_message_wait_time_seconds) # type: ignore if wait_time < 0 or wait_time > 20: return self._error( @@ -380,7 +383,7 @@ class SQSResponse(BaseResponse): try: visibility_timeout = self._get_validated_visibility_timeout() except TypeError: - visibility_timeout = queue.visibility_timeout + visibility_timeout = queue.visibility_timeout # type: ignore except ValueError: return ERROR_MAX_VISIBILITY_TIMEOUT_RESPONSE, dict(status=400) @@ -406,7 +409,7 @@ class SQSResponse(BaseResponse): template = self.response_template(RECEIVE_MESSAGE_RESPONSE) return template.render(messages=messages, attributes=attributes) - def list_dead_letter_source_queues(self): + def list_dead_letter_source_queues(self) -> str: request_url = urlparse(self.uri) queue_name = self._get_queue_name() @@ -415,7 +418,7 @@ class SQSResponse(BaseResponse): template = self.response_template(LIST_DEAD_LETTER_SOURCE_QUEUES_RESPONSE) return template.render(queues=source_queue_urls, request_url=request_url) - def add_permission(self): + def add_permission(self) -> str: queue_name = self._get_queue_name() actions = self._get_multi_param("ActionName") account_ids = self._get_multi_param("AWSAccountId") @@ -426,7 +429,7 @@ class SQSResponse(BaseResponse): template = self.response_template(ADD_PERMISSION_RESPONSE) return template.render() - def remove_permission(self): + def remove_permission(self) -> str: queue_name = self._get_queue_name() label = self._get_param("Label") @@ -435,7 +438,7 @@ class SQSResponse(BaseResponse): template = self.response_template(REMOVE_PERMISSION_RESPONSE) return template.render() - def tag_queue(self): + def tag_queue(self) -> str: queue_name = self._get_queue_name() tags = self._get_map_prefix("Tag", key_end=".Key", value_end=".Value") @@ -444,7 +447,7 @@ class SQSResponse(BaseResponse): template = self.response_template(TAG_QUEUE_RESPONSE) return template.render() - def untag_queue(self): + def untag_queue(self) -> str: queue_name = self._get_queue_name() tag_keys = self._get_multi_param("TagKey") @@ -453,7 +456,7 @@ class SQSResponse(BaseResponse): template = self.response_template(UNTAG_QUEUE_RESPONSE) return template.render() - def list_queue_tags(self): + def list_queue_tags(self) -> str: queue_name = self._get_queue_name() queue = self.sqs_backend.list_queue_tags(queue_name) diff --git a/moto/sqs/utils.py b/moto/sqs/utils.py index 8fd6ab107..33ba0ec9b 100644 --- a/moto/sqs/utils.py +++ b/moto/sqs/utils.py @@ -1,15 +1,17 @@ import string +from typing import Any, Dict, List + from moto.moto_api._internal import mock_random as random from .exceptions import MessageAttributesInvalid -def generate_receipt_handle(): +def generate_receipt_handle() -> str: # http://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/ImportantIdentifiers.html#ImportantIdentifiers-receipt-handles length = 185 return "".join(random.choice(string.ascii_lowercase) for x in range(length)) -def extract_input_message_attributes(querystring): +def extract_input_message_attributes(querystring: Dict[str, Any]) -> List[str]: message_attributes = [] index = 1 while True: @@ -25,8 +27,11 @@ def extract_input_message_attributes(querystring): def parse_message_attributes( - querystring, key="MessageAttribute", base="", value_namespace="Value." -): + querystring: Dict[str, Any], + key: str = "MessageAttribute", + base: str = "", + value_namespace: str = "Value.", +) -> Dict[str, Any]: message_attributes = {} index = 1 while True: diff --git a/setup.cfg b/setup.cfg index 1679a8068..b98e4f7ec 100644 --- a/setup.cfg +++ b/setup.cfg @@ -239,7 +239,7 @@ disable = W,C,R,E enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import [mypy] -files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s3*,moto/sagemaker,moto/secretsmanager,moto/ssm,moto/scheduler +files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s3*,moto/sagemaker,moto/secretsmanager,moto/sqs,moto/ssm,moto/scheduler show_column_numbers=True show_error_codes = True disable_error_code=abstract