Techdebt: MyPy SQS (#6251)

This commit is contained in:
Bert Blommers 2023-04-24 16:46:01 +00:00 committed by GitHub
parent a32b721c79
commit 1f56b75ccf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 241 additions and 195 deletions

View File

@ -4,7 +4,7 @@ from moto.core.exceptions import RESTError
class ReceiptHandleIsInvalid(RESTError): class ReceiptHandleIsInvalid(RESTError):
code = 400 code = 400
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
"ReceiptHandleIsInvalid", "The input receipt handle is invalid." "ReceiptHandleIsInvalid", "The input receipt handle is invalid."
) )
@ -13,14 +13,14 @@ class ReceiptHandleIsInvalid(RESTError):
class MessageAttributesInvalid(RESTError): class MessageAttributesInvalid(RESTError):
code = 400 code = 400
def __init__(self, description): def __init__(self, description: str):
super().__init__("MessageAttributesInvalid", description) super().__init__("MessageAttributesInvalid", description)
class QueueDoesNotExist(RESTError): class QueueDoesNotExist(RESTError):
code = 400 code = 400
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
"AWS.SimpleQueueService.NonExistentQueue", "AWS.SimpleQueueService.NonExistentQueue",
"The specified queue does not exist for this wsdl version.", "The specified queue does not exist for this wsdl version.",
@ -31,14 +31,14 @@ class QueueDoesNotExist(RESTError):
class QueueAlreadyExists(RESTError): class QueueAlreadyExists(RESTError):
code = 400 code = 400
def __init__(self, message): def __init__(self, message: str):
super().__init__("QueueAlreadyExists", message) super().__init__("QueueAlreadyExists", message)
class EmptyBatchRequest(RESTError): class EmptyBatchRequest(RESTError):
code = 400 code = 400
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
"EmptyBatchRequest", "EmptyBatchRequest",
"There should be at least one SendMessageBatchRequestEntry in the request.", "There should be at least one SendMessageBatchRequestEntry in the request.",
@ -48,7 +48,7 @@ class EmptyBatchRequest(RESTError):
class InvalidBatchEntryId(RESTError): class InvalidBatchEntryId(RESTError):
code = 400 code = 400
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
"InvalidBatchEntryId", "InvalidBatchEntryId",
"A batch entry id can only contain alphanumeric characters, " "A batch entry id can only contain alphanumeric characters, "
@ -59,7 +59,7 @@ class InvalidBatchEntryId(RESTError):
class BatchRequestTooLong(RESTError): class BatchRequestTooLong(RESTError):
code = 400 code = 400
def __init__(self, length): def __init__(self, length: int):
super().__init__( super().__init__(
"BatchRequestTooLong", "BatchRequestTooLong",
"Batch requests cannot be longer than 262144 bytes. " "Batch requests cannot be longer than 262144 bytes. "
@ -70,14 +70,14 @@ class BatchRequestTooLong(RESTError):
class BatchEntryIdsNotDistinct(RESTError): class BatchEntryIdsNotDistinct(RESTError):
code = 400 code = 400
def __init__(self, entry_id): def __init__(self, entry_id: str):
super().__init__("BatchEntryIdsNotDistinct", f"Id {entry_id} repeated.") super().__init__("BatchEntryIdsNotDistinct", f"Id {entry_id} repeated.")
class TooManyEntriesInBatchRequest(RESTError): class TooManyEntriesInBatchRequest(RESTError):
code = 400 code = 400
def __init__(self, number): def __init__(self, number: int):
super().__init__( super().__init__(
"TooManyEntriesInBatchRequest", "TooManyEntriesInBatchRequest",
"Maximum number of entries per request are 10. " f"You have sent {number}.", "Maximum number of entries per request are 10. " f"You have sent {number}.",
@ -87,14 +87,14 @@ class TooManyEntriesInBatchRequest(RESTError):
class InvalidAttributeName(RESTError): class InvalidAttributeName(RESTError):
code = 400 code = 400
def __init__(self, attribute_name): def __init__(self, attribute_name: str):
super().__init__("InvalidAttributeName", f"Unknown Attribute {attribute_name}.") super().__init__("InvalidAttributeName", f"Unknown Attribute {attribute_name}.")
class InvalidAttributeValue(RESTError): class InvalidAttributeValue(RESTError):
code = 400 code = 400
def __init__(self, attribute_name): def __init__(self, attribute_name: str):
super().__init__( super().__init__(
"InvalidAttributeValue", "InvalidAttributeValue",
f"Invalid value for the parameter {attribute_name}.", f"Invalid value for the parameter {attribute_name}.",
@ -104,14 +104,14 @@ class InvalidAttributeValue(RESTError):
class InvalidParameterValue(RESTError): class InvalidParameterValue(RESTError):
code = 400 code = 400
def __init__(self, message): def __init__(self, message: str):
super().__init__("InvalidParameterValue", message) super().__init__("InvalidParameterValue", message)
class MissingParameter(RESTError): class MissingParameter(RESTError):
code = 400 code = 400
def __init__(self, parameter): def __init__(self, parameter: str):
super().__init__( super().__init__(
"MissingParameter", f"The request must contain the parameter {parameter}." "MissingParameter", f"The request must contain the parameter {parameter}."
) )
@ -120,7 +120,7 @@ class MissingParameter(RESTError):
class OverLimit(RESTError): class OverLimit(RESTError):
code = 403 code = 403
def __init__(self, count): def __init__(self, count: int):
super().__init__( super().__init__(
"OverLimit", f"{count} Actions were found, maximum allowed is 7." "OverLimit", f"{count} Actions were found, maximum allowed is 7."
) )

View File

@ -6,8 +6,9 @@ import string
import struct import struct
from copy import deepcopy from copy import deepcopy
from typing import Dict from typing import Any, Dict, List, Optional, Tuple, Set, TYPE_CHECKING
from threading import Condition from threading import Condition
from urllib.parse import ParseResult
from xml.sax.saxutils import escape from xml.sax.saxutils import escape
from moto.core.exceptions import RESTError from moto.core.exceptions import RESTError
@ -38,6 +39,9 @@ from .exceptions import (
InvalidAttributeValue, InvalidAttributeValue,
) )
if TYPE_CHECKING:
from moto.awslambda.models import EventSourceMapping
DEFAULT_SENDER_ID = "AIDAIT2UOQQY3AUEKVGXU" DEFAULT_SENDER_ID = "AIDAIT2UOQQY3AUEKVGXU"
MAXIMUM_MESSAGE_LENGTH = 262144 # 256 KiB MAXIMUM_MESSAGE_LENGTH = 262144 # 256 KiB
@ -67,31 +71,36 @@ DEDUPLICATION_TIME_IN_SECONDS = 300
class Message(BaseModel): class Message(BaseModel):
def __init__(self, message_id, body, system_attributes=None): def __init__(
self,
message_id: str,
body: str,
system_attributes: Optional[Dict[str, Any]] = None,
):
self.id = message_id self.id = message_id
self._body = body self._body = body
self.message_attributes = {} self.message_attributes: Dict[str, Any] = {}
self.receipt_handle = None self.receipt_handle: Optional[str] = None
self._old_receipt_handles = [] self._old_receipt_handles: List[str] = []
self.sender_id = DEFAULT_SENDER_ID self.sender_id = DEFAULT_SENDER_ID
self.sent_timestamp = None self.sent_timestamp = None
self.approximate_first_receive_timestamp = None self.approximate_first_receive_timestamp: Optional[int] = None
self.approximate_receive_count = 0 self.approximate_receive_count = 0
self.deduplication_id = None self.deduplication_id: Optional[str] = None
self.group_id = None self.group_id: Optional[str] = None
self.sequence_number = None self.sequence_number: Optional[str] = None
self.visible_at = 0 self.visible_at = 0.0
self.delayed_until = 0 self.delayed_until = 0.0
self.system_attributes = system_attributes or {} self.system_attributes = system_attributes or {}
@property @property
def body_md5(self): def body_md5(self) -> str:
md5 = md5_hash() md5 = md5_hash()
md5.update(self._body.encode("utf-8")) md5.update(self._body.encode("utf-8"))
return md5.hexdigest() return md5.hexdigest()
@property @property
def attribute_md5(self): def attribute_md5(self) -> str:
md5 = md5_hash() md5 = md5_hash()
@ -129,13 +138,13 @@ class Message(BaseModel):
return md5.hexdigest() return md5.hexdigest()
@staticmethod @staticmethod
def update_binary_length_and_value(md5, value): def update_binary_length_and_value(md5: Any, value: bytes) -> None: # type: ignore[misc]
length_bytes = struct.pack("!I".encode("ascii"), len(value)) length_bytes = struct.pack("!I".encode("ascii"), len(value))
md5.update(length_bytes) md5.update(length_bytes)
md5.update(value) md5.update(value)
@staticmethod @staticmethod
def validate_attribute_name(name): def validate_attribute_name(name: str) -> None:
if not ATTRIBUTE_NAME_PATTERN.match(name): if not ATTRIBUTE_NAME_PATTERN.match(name):
raise MessageAttributesInvalid( raise MessageAttributesInvalid(
f"The message attribute name '{name}' is invalid. " f"The message attribute name '{name}' is invalid. "
@ -144,21 +153,21 @@ class Message(BaseModel):
) )
@staticmethod @staticmethod
def utf8(value): def utf8(value: Any) -> bytes: # type: ignore[misc]
if isinstance(value, str): if isinstance(value, str):
return value.encode("utf-8") return value.encode("utf-8")
return value return value
@property @property
def body(self): def body(self) -> str:
return escape(self._body).replace('"', """).replace("\r", "
") return escape(self._body).replace('"', """).replace("\r", "
")
def mark_sent(self, delay_seconds=None): def mark_sent(self, delay_seconds: Optional[int] = None) -> None:
self.sent_timestamp = int(unix_time_millis()) self.sent_timestamp = int(unix_time_millis()) # type: ignore
if delay_seconds: if delay_seconds:
self.delay(delay_seconds=delay_seconds) self.delay(delay_seconds=delay_seconds)
def mark_received(self, visibility_timeout=None): def mark_received(self, visibility_timeout: Optional[int] = None) -> None:
""" """
When a message is received we will set the first receive timestamp, When a message is received we will set the first receive timestamp,
tap the ``approximate_receive_count`` and the ``visible_at`` time. tap the ``approximate_receive_count`` and the ``visible_at`` time.
@ -178,37 +187,37 @@ class Message(BaseModel):
if visibility_timeout: if visibility_timeout:
self.change_visibility(visibility_timeout) self.change_visibility(visibility_timeout)
self._old_receipt_handles.append(self.receipt_handle) self._old_receipt_handles.append(self.receipt_handle) # type: ignore
self.receipt_handle = generate_receipt_handle() self.receipt_handle = generate_receipt_handle()
def change_visibility(self, visibility_timeout): def change_visibility(self, visibility_timeout: int) -> None:
# We're dealing with milliseconds internally # We're dealing with milliseconds internally
visibility_timeout_msec = int(visibility_timeout) * 1000 visibility_timeout_msec = int(visibility_timeout) * 1000
self.visible_at = unix_time_millis() + visibility_timeout_msec self.visible_at = unix_time_millis() + visibility_timeout_msec
def delay(self, delay_seconds): def delay(self, delay_seconds: int) -> None:
delay_msec = int(delay_seconds) * 1000 delay_msec = int(delay_seconds) * 1000
self.delayed_until = unix_time_millis() + delay_msec self.delayed_until = unix_time_millis() + delay_msec
@property @property
def visible(self): def visible(self) -> bool:
current_time = unix_time_millis() current_time = unix_time_millis()
if current_time > self.visible_at: if current_time > self.visible_at:
return True return True
return False return False
@property @property
def delayed(self): def delayed(self) -> bool:
current_time = unix_time_millis() current_time = unix_time_millis()
if current_time < self.delayed_until: if current_time < self.delayed_until:
return True return True
return False return False
@property @property
def all_receipt_handles(self): def all_receipt_handles(self) -> List[Optional[str]]:
return [self.receipt_handle] + self._old_receipt_handles return [self.receipt_handle] + self._old_receipt_handles # type: ignore
def had_receipt_handle(self, receipt_handle): def had_receipt_handle(self, receipt_handle: str) -> bool:
""" """
Check if this message ever had this receipt_handle in the past Check if this message ever had this receipt_handle in the past
""" """
@ -250,24 +259,25 @@ class Queue(CloudFormationModel):
"SendMessage", "SendMessage",
) )
def __init__(self, name, region, account_id, **kwargs): def __init__(self, name: str, region: str, account_id: str, **kwargs: Any):
self.name = name self.name = name
self.region = region self.region = region
self.account_id = account_id self.account_id = account_id
self.tags = {} self.tags: Dict[str, str] = {}
self.permissions = {} self.permissions: Dict[str, Any] = {}
self._messages = [] self._messages: List[Message] = []
self._pending_messages = set() self._pending_messages: Set[Message] = set()
self.deleted_messages = set() self.deleted_messages: Set[str] = set()
self._messages_lock = Condition() self._messages_lock = Condition()
now = unix_time() now = unix_time()
self.created_timestamp = now self.created_timestamp = now
self.queue_arn = f"arn:aws:sqs:{region}:{account_id}:{name}" self.queue_arn = f"arn:aws:sqs:{region}:{account_id}:{name}"
self.dead_letter_queue = None self.dead_letter_queue: Optional["Queue"] = None
self.fifo_queue = False
self.lambda_event_source_mappings = {} self.lambda_event_source_mappings: Dict[str, "EventSourceMapping"] = {}
# default settings for a non fifo queue # default settings for a non fifo queue
defaults = { defaults = {
@ -293,24 +303,26 @@ class Queue(CloudFormationModel):
if self.fifo_queue and not self.name.endswith(".fifo"): if self.fifo_queue and not self.name.endswith(".fifo"):
raise InvalidParameterValue("Queue name must end in .fifo for FIFO queues") raise InvalidParameterValue("Queue name must end in .fifo for FIFO queues")
if ( if (
self.maximum_message_size < MAXIMUM_MESSAGE_SIZE_ATTR_LOWER_BOUND self.maximum_message_size < MAXIMUM_MESSAGE_SIZE_ATTR_LOWER_BOUND # type: ignore
or self.maximum_message_size > MAXIMUM_MESSAGE_SIZE_ATTR_UPPER_BOUND or self.maximum_message_size > MAXIMUM_MESSAGE_SIZE_ATTR_UPPER_BOUND # type: ignore
): ):
raise InvalidAttributeValue("MaximumMessageSize") raise InvalidAttributeValue("MaximumMessageSize")
@property @property
def pending_messages(self): def pending_messages(self) -> Set[Message]:
return self._pending_messages return self._pending_messages
@property @property
def pending_message_groups(self): def pending_message_groups(self) -> Set[str]:
return set( return set(
message.group_id message.group_id
for message in self._pending_messages for message in self._pending_messages
if message.group_id is not None if message.group_id is not None
) )
def _set_attributes(self, attributes, now=None): def _set_attributes(
self, attributes: Dict[str, Any], now: Optional[float] = None
) -> None:
if not now: if not now:
now = unix_time() now = unix_time()
@ -343,7 +355,7 @@ class Queue(CloudFormationModel):
self.last_modified_timestamp = now self.last_modified_timestamp = now
@staticmethod @staticmethod
def _is_empty_redrive_policy(policy): def _is_empty_redrive_policy(policy: Any) -> bool: # type: ignore[misc]
if isinstance(policy, str): if isinstance(policy, str):
if policy == "" or len(json.loads(policy)) == 0: if policy == "" or len(json.loads(policy)) == 0:
return True return True
@ -352,7 +364,7 @@ class Queue(CloudFormationModel):
return False return False
def _setup_dlq(self, policy): def _setup_dlq(self, policy: Any) -> None:
if Queue._is_empty_redrive_policy(policy): if Queue._is_empty_redrive_policy(policy):
self.redrive_policy = None self.redrive_policy = None
self.dead_letter_queue = None self.dead_letter_queue = None
@ -407,18 +419,23 @@ class Queue(CloudFormationModel):
) )
@staticmethod @staticmethod
def cloudformation_name_type(): def cloudformation_name_type() -> str:
return "QueueName" return "QueueName"
@staticmethod @staticmethod
def cloudformation_type(): def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sqs-queue.html # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sqs-queue.html
return "AWS::SQS::Queue" return "AWS::SQS::Queue"
@classmethod @classmethod
def create_from_cloudformation_json( def create_from_cloudformation_json( # type: ignore[misc]
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs cls,
): resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
**kwargs: Any,
) -> "Queue":
properties = deepcopy(cloudformation_json["Properties"]) properties = deepcopy(cloudformation_json["Properties"])
# remove Tags from properties and convert tags list to dict # remove Tags from properties and convert tags list to dict
tags = properties.pop("Tags", []) tags = properties.pop("Tags", [])
@ -433,14 +450,14 @@ class Queue(CloudFormationModel):
) )
@classmethod @classmethod
def update_from_cloudformation_json( def update_from_cloudformation_json( # type: ignore[misc]
cls, cls,
original_resource, original_resource: Any,
new_resource_name, new_resource_name: str,
cloudformation_json, cloudformation_json: Any,
account_id, account_id: str,
region_name, region_name: str,
): ) -> "Queue":
properties = cloudformation_json["Properties"] properties = cloudformation_json["Properties"]
queue_name = original_resource.name queue_name = original_resource.name
@ -456,9 +473,13 @@ class Queue(CloudFormationModel):
return queue return queue
@classmethod @classmethod
def delete_from_cloudformation_json( def delete_from_cloudformation_json( # type: ignore[misc]
cls, resource_name, cloudformation_json, account_id, region_name cls,
): resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
) -> None:
# ResourceName will be the full queue URL - we only need the name # ResourceName will be the full queue URL - we only need the name
# https://sqs.us-west-1.amazonaws.com/123456789012/queue_name # https://sqs.us-west-1.amazonaws.com/123456789012/queue_name
queue_name = resource_name.split("/")[-1] queue_name = resource_name.split("/")[-1]
@ -466,24 +487,24 @@ class Queue(CloudFormationModel):
sqs_backend.delete_queue(queue_name) sqs_backend.delete_queue(queue_name)
@property @property
def approximate_number_of_messages_delayed(self): def approximate_number_of_messages_delayed(self) -> int:
return len([m for m in self._messages if m.delayed]) return len([m for m in self._messages if m.delayed])
@property @property
def approximate_number_of_messages_not_visible(self): def approximate_number_of_messages_not_visible(self) -> int:
return len([m for m in self._messages if not m.visible]) return len([m for m in self._messages if not m.visible])
@property @property
def approximate_number_of_messages(self): def approximate_number_of_messages(self) -> int:
return len(self.messages) return len(self.messages)
@property @property
def physical_resource_id(self): def physical_resource_id(self) -> str:
return f"https://sqs.{self.region}.amazonaws.com/{self.account_id}/{self.name}" return f"https://sqs.{self.region}.amazonaws.com/{self.account_id}/{self.name}"
@property @property
def attributes(self): def attributes(self) -> Dict[str, Any]: # type: ignore[misc]
result = {} result: Dict[str, Any] = {}
for attribute in self.BASE_ATTRIBUTES: for attribute in self.BASE_ATTRIBUTES:
attr = getattr(self, camelcase_to_underscores(attribute)) attr = getattr(self, camelcase_to_underscores(attribute))
@ -494,7 +515,7 @@ class Queue(CloudFormationModel):
attr = getattr(self, camelcase_to_underscores(attribute)) attr = getattr(self, camelcase_to_underscores(attribute))
result[attribute] = attr result[attribute] = attr
if self.kms_master_key_id: if self.kms_master_key_id: # type: ignore
for attribute in self.KMS_ATTRIBUTES: for attribute in self.KMS_ATTRIBUTES:
attr = getattr(self, camelcase_to_underscores(attribute)) attr = getattr(self, camelcase_to_underscores(attribute))
result[attribute] = attr result[attribute] = attr
@ -511,13 +532,11 @@ class Queue(CloudFormationModel):
return result return result
def url(self, request_url): def url(self, request_url: ParseResult) -> str:
return ( return f"{request_url.scheme}://{request_url.netloc}/{self.account_id}/{self.name}" # type: ignore
f"{request_url.scheme}://{request_url.netloc}/{self.account_id}/{self.name}"
)
@property @property
def messages(self): def messages(self) -> List[Message]:
# TODO: This can become very inefficient if a large number of messages are in-flight # TODO: This can become very inefficient if a large number of messages are in-flight
return [ return [
message message
@ -525,7 +544,7 @@ class Queue(CloudFormationModel):
if message.visible and not message.delayed if message.visible and not message.delayed
] ]
def add_message(self, message): def add_message(self, message: Message) -> None:
if self.fifo_queue: if self.fifo_queue:
# the cases in which we dedupe fifo messages # the cases in which we dedupe fifo messages
@ -537,7 +556,7 @@ class Queue(CloudFormationModel):
): ):
for m in self._messages: for m in self._messages:
if m.deduplication_id == message.deduplication_id: if m.deduplication_id == message.deduplication_id:
diff = message.sent_timestamp - m.sent_timestamp diff = message.sent_timestamp - m.sent_timestamp # type: ignore
# if a duplicate message is received within the deduplication time then it should # if a duplicate message is received within the deduplication time then it should
# not be added to the queue # not be added to the queue
if diff / 1000 < DEDUPLICATION_TIME_IN_SECONDS: if diff / 1000 < DEDUPLICATION_TIME_IN_SECONDS:
@ -559,8 +578,8 @@ class Queue(CloudFormationModel):
messages = backend.receive_message( messages = backend.receive_message(
self.name, self.name,
esm.batch_size, esm.batch_size,
self.receive_message_wait_time_seconds, self.receive_message_wait_time_seconds, # type: ignore
self.visibility_timeout, self.visibility_timeout, # type: ignore
) )
from moto.awslambda import lambda_backends from moto.awslambda import lambda_backends
@ -580,7 +599,7 @@ class Queue(CloudFormationModel):
for m in messages for m in messages
] ]
def delete_message(self, receipt_handle): def delete_message(self, receipt_handle: str) -> None:
if receipt_handle in self.deleted_messages: if receipt_handle in self.deleted_messages:
# Already deleted - gracefully handle deleting it again # Already deleted - gracefully handle deleting it again
return return
@ -595,20 +614,20 @@ class Queue(CloudFormationModel):
for message in self._messages: for message in self._messages:
if message.had_receipt_handle(receipt_handle): if message.had_receipt_handle(receipt_handle):
self.pending_messages.discard(message) self.pending_messages.discard(message)
self.deleted_messages.update(message.all_receipt_handles) self.deleted_messages.update(message.all_receipt_handles) # type: ignore
continue continue
new_messages.append(message) new_messages.append(message)
self._messages = new_messages self._messages = new_messages
def wait_for_messages(self, timeout): def wait_for_messages(self, timeout: int) -> None:
with self._messages_lock: with self._messages_lock:
self._messages_lock.wait_for(lambda: self.messages, timeout=timeout) self._messages_lock.wait_for(lambda: self.messages, timeout=timeout)
@classmethod @classmethod
def has_cfn_attr(cls, attr): def has_cfn_attr(cls, attr: str) -> bool:
return attr in ["Arn", "QueueName"] return attr in ["Arn", "QueueName"]
def get_cfn_attribute(self, attribute_name): def get_cfn_attribute(self, attribute_name: str) -> str:
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
if attribute_name == "Arn": if attribute_name == "Arn":
@ -618,14 +637,14 @@ class Queue(CloudFormationModel):
raise UnformattedGetAttTemplateException() raise UnformattedGetAttTemplateException()
@property @property
def policy(self): def policy(self) -> Any: # type: ignore[misc]
if self._policy_json.get("Statement"): if self._policy_json.get("Statement"):
return json.dumps(self._policy_json) return json.dumps(self._policy_json)
else: else:
return None return None
@policy.setter @policy.setter
def policy(self, policy): def policy(self, policy: Any) -> None:
if policy: if policy:
self._policy_json = json.loads(policy) self._policy_json = json.loads(policy)
else: else:
@ -636,7 +655,9 @@ class Queue(CloudFormationModel):
} }
def _filter_message_attributes(message, input_message_attributes): def _filter_message_attributes(
message: Message, input_message_attributes: List[str]
) -> None:
filtered_message_attributes = {} filtered_message_attributes = {}
return_all = "All" in input_message_attributes return_all = "All" in input_message_attributes
for key, value in message.message_attributes.items(): for key, value in message.message_attributes.items():
@ -646,18 +667,22 @@ def _filter_message_attributes(message, input_message_attributes):
class SQSBackend(BaseBackend): class SQSBackend(BaseBackend):
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.queues: Dict[str, Queue] = {} self.queues: Dict[str, Queue] = {}
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(
service_region: str, zones: List[str]
) -> List[Dict[str, str]]:
"""Default VPC endpoint service.""" """Default VPC endpoint service."""
return BaseBackend.default_vpc_endpoint_service_factory( return BaseBackend.default_vpc_endpoint_service_factory(
service_region, zones, "sqs" service_region, zones, "sqs"
) )
def create_queue(self, name, tags=None, **kwargs): def create_queue(
self, name: str, tags: Optional[Dict[str, str]] = None, **kwargs: Any
) -> Queue:
queue = self.queues.get(name) queue = self.queues.get(name)
if queue: if queue:
try: try:
@ -692,10 +717,10 @@ class SQSBackend(BaseBackend):
return queue return queue
def get_queue_url(self, queue_name): def get_queue_url(self, queue_name: str) -> Queue:
return self.get_queue(queue_name) return self.get_queue(queue_name)
def list_queues(self, queue_name_prefix): def list_queues(self, queue_name_prefix: str) -> List[Queue]:
re_str = ".*" re_str = ".*"
if queue_name_prefix: if queue_name_prefix:
re_str = f"^{queue_name_prefix}.*" re_str = f"^{queue_name_prefix}.*"
@ -706,18 +731,20 @@ class SQSBackend(BaseBackend):
qs.append(q) qs.append(q)
return qs[:1000] return qs[:1000]
def get_queue(self, queue_name): def get_queue(self, queue_name: str) -> Queue:
queue = self.queues.get(queue_name) queue = self.queues.get(queue_name)
if queue is None: if queue is None:
raise QueueDoesNotExist() raise QueueDoesNotExist()
return queue return queue
def delete_queue(self, queue_name): def delete_queue(self, queue_name: str) -> None:
self.get_queue(queue_name) self.get_queue(queue_name)
del self.queues[queue_name] del self.queues[queue_name]
def get_queue_attributes(self, queue_name, attribute_names): def get_queue_attributes(
self, queue_name: str, attribute_names: List[str]
) -> Dict[str, Any]:
queue = self.get_queue(queue_name) queue = self.get_queue(queue_name)
if not attribute_names: if not attribute_names:
return {} return {}
@ -746,21 +773,23 @@ class SQSBackend(BaseBackend):
return attributes return attributes
def set_queue_attributes(self, queue_name, attributes): def set_queue_attributes(
self, queue_name: str, attributes: Dict[str, Any]
) -> Queue:
queue = self.get_queue(queue_name) queue = self.get_queue(queue_name)
queue._set_attributes(attributes) queue._set_attributes(attributes)
return queue return queue
def send_message( def send_message(
self, self,
queue_name, queue_name: str,
message_body, message_body: str,
message_attributes=None, message_attributes: Optional[Dict[str, Any]] = None,
delay_seconds=None, delay_seconds: Optional[int] = None,
deduplication_id=None, deduplication_id: Optional[str] = None,
group_id=None, group_id: Optional[str] = None,
system_attributes=None, system_attributes: Optional[Dict[str, Any]] = None,
): ) -> Message:
queue = self.get_queue(queue_name) queue = self.get_queue(queue_name)
@ -783,14 +812,14 @@ class SQSBackend(BaseBackend):
) )
raise InvalidParameterValue(msg) raise InvalidParameterValue(msg)
if len(message_body) > queue.maximum_message_size: if len(message_body) > queue.maximum_message_size: # type: ignore
msg = f"One or more parameters are invalid. Reason: Message must be shorter than {queue.maximum_message_size} bytes." msg = f"One or more parameters are invalid. Reason: Message must be shorter than {queue.maximum_message_size} bytes." # type: ignore
raise InvalidParameterValue(msg) raise InvalidParameterValue(msg)
if delay_seconds: if delay_seconds:
delay_seconds = int(delay_seconds) delay_seconds = int(delay_seconds)
else: else:
delay_seconds = queue.delay_seconds delay_seconds = queue.delay_seconds # type: ignore
message_id = str(random.uuid4()) message_id = str(random.uuid4())
message = Message(message_id, message_body, system_attributes) message = Message(message_id, message_body, system_attributes)
@ -826,7 +855,7 @@ class SQSBackend(BaseBackend):
if message_attributes: if message_attributes:
message.message_attributes = message_attributes message.message_attributes = message_attributes
if delay_seconds > MAXIMUM_MESSAGE_DELAY: if delay_seconds > MAXIMUM_MESSAGE_DELAY: # type: ignore
msg = ( msg = (
f"Value {delay_seconds} for parameter DelaySeconds is invalid. " f"Value {delay_seconds} for parameter DelaySeconds is invalid. "
"Reason: DelaySeconds must be >= 0 and <= 900." "Reason: DelaySeconds must be >= 0 and <= 900."
@ -839,7 +868,9 @@ class SQSBackend(BaseBackend):
return message return message
def send_message_batch(self, queue_name, entries): def send_message_batch(
self, queue_name: str, entries: Dict[str, Dict[str, Any]]
) -> Tuple[List[Message], List[Dict[str, Any]]]:
self.get_queue(queue_name) self.get_queue(queue_name)
if any( if any(
@ -880,14 +911,14 @@ class SQSBackend(BaseBackend):
group_id=entry.get("MessageGroupId"), group_id=entry.get("MessageGroupId"),
deduplication_id=entry.get("MessageDeduplicationId"), deduplication_id=entry.get("MessageDeduplicationId"),
) )
message.user_id = entry["Id"] message.user_id = entry["Id"] # type: ignore[attr-defined]
messages.append(message) messages.append(message)
except InvalidParameterValue: except InvalidParameterValue:
failedInvalidDelay.append(entry) failedInvalidDelay.append(entry)
return messages, failedInvalidDelay return messages, failedInvalidDelay
def _get_first_duplicate_id(self, ids): def _get_first_duplicate_id(self, ids: List[str]) -> Optional[str]:
unique_ids = set() unique_ids = set()
for _id in ids: for _id in ids:
if _id in unique_ids: if _id in unique_ids:
@ -897,12 +928,12 @@ class SQSBackend(BaseBackend):
def receive_message( def receive_message(
self, self,
queue_name, queue_name: str,
count, count: int,
wait_seconds_timeout, wait_seconds_timeout: int,
visibility_timeout, visibility_timeout: int,
message_attribute_names=None, message_attribute_names: Optional[List[str]] = None,
): ) -> List[Message]:
# Attempt to retrieve visible messages from a queue. # Attempt to retrieve visible messages from a queue.
# If a message was read by client and not deleted it is considered to be # If a message was read by client and not deleted it is considered to be
@ -913,7 +944,7 @@ class SQSBackend(BaseBackend):
if message_attribute_names is None: if message_attribute_names is None:
message_attribute_names = [] message_attribute_names = []
queue = self.get_queue(queue_name) queue = self.get_queue(queue_name)
result = [] result: List[Message] = []
previous_result_count = len(result) previous_result_count = len(result)
polling_end = unix_time() + wait_seconds_timeout polling_end = unix_time() + wait_seconds_timeout
@ -925,7 +956,7 @@ class SQSBackend(BaseBackend):
if result or (wait_seconds_timeout and unix_time() > polling_end): if result or (wait_seconds_timeout and unix_time() > polling_end):
break break
messages_to_dlq = [] messages_to_dlq: List[Message] = []
for message in queue.messages: for message in queue.messages:
if not message.visible: if not message.visible:
@ -966,7 +997,7 @@ class SQSBackend(BaseBackend):
for message in messages_to_dlq: for message in messages_to_dlq:
queue._messages.remove(message) queue._messages.remove(message)
queue.dead_letter_queue.add_message(message) queue.dead_letter_queue.add_message(message) # type: ignore
if previous_result_count == len(result): if previous_result_count == len(result):
if wait_seconds_timeout == 0: if wait_seconds_timeout == 0:
@ -981,12 +1012,14 @@ class SQSBackend(BaseBackend):
return result return result
def delete_message(self, queue_name, receipt_handle): def delete_message(self, queue_name: str, receipt_handle: str) -> None:
queue = self.get_queue(queue_name) queue = self.get_queue(queue_name)
queue.delete_message(receipt_handle) queue.delete_message(receipt_handle)
def delete_message_batch(self, queue_name, receipts): def delete_message_batch(
self, queue_name: str, receipts: List[Dict[str, Any]]
) -> Tuple[List[str], List[Dict[str, str]]]:
success = [] success = []
errors = [] errors = []
for receipt_and_id in receipts: for receipt_and_id in receipts:
@ -1004,14 +1037,16 @@ class SQSBackend(BaseBackend):
) )
return success, errors return success, errors
def change_message_visibility(self, queue_name, receipt_handle, visibility_timeout): def change_message_visibility(
self, queue_name: str, receipt_handle: str, visibility_timeout: int
) -> None:
queue = self.get_queue(queue_name) queue = self.get_queue(queue_name)
for message in queue._messages: for message in queue._messages:
if message.had_receipt_handle(receipt_handle): if message.had_receipt_handle(receipt_handle):
visibility_timeout_msec = int(visibility_timeout) * 1000 visibility_timeout_msec = int(visibility_timeout) * 1000
given_visibility_timeout = unix_time_millis() + visibility_timeout_msec given_visibility_timeout = unix_time_millis() + visibility_timeout_msec
if given_visibility_timeout - message.sent_timestamp > 43200 * 1000: if given_visibility_timeout - message.sent_timestamp > 43200 * 1000: # type: ignore
raise InvalidParameterValue( raise InvalidParameterValue(
f"Value {visibility_timeout} for parameter VisibilityTimeout is invalid. Reason: Total " f"Value {visibility_timeout} for parameter VisibilityTimeout is invalid. Reason: Total "
"VisibilityTimeout for the message is beyond the limit [43200 seconds]" "VisibilityTimeout for the message is beyond the limit [43200 seconds]"
@ -1025,7 +1060,9 @@ class SQSBackend(BaseBackend):
return return
raise ReceiptHandleIsInvalid raise ReceiptHandleIsInvalid
def change_message_visibility_batch(self, queue_name: str, entries): def change_message_visibility_batch(
self, queue_name: str, entries: List[Dict[str, Any]]
) -> Tuple[List[str], List[Dict[str, str]]]:
success = [] success = []
error = [] error = []
for entry in entries: for entry in entries:
@ -1061,22 +1098,24 @@ class SQSBackend(BaseBackend):
) )
return success, error return success, error
def purge_queue(self, queue_name): def purge_queue(self, queue_name: str) -> None:
queue = self.get_queue(queue_name) queue = self.get_queue(queue_name)
queue._messages = [] queue._messages = []
queue._pending_messages = set() queue._pending_messages = set()
def list_dead_letter_source_queues(self, queue_name): def list_dead_letter_source_queues(self, queue_name: str) -> List[Queue]:
dlq = self.get_queue(queue_name) dlq = self.get_queue(queue_name)
queues = [] queues: List[Queue] = []
for queue in self.queues.values(): for queue in self.queues.values():
if queue.dead_letter_queue is dlq: if queue.dead_letter_queue is dlq:
queues.append(queue) queues.append(queue)
return queues return queues
def add_permission(self, queue_name, actions, account_ids, label): def add_permission(
self, queue_name: str, actions: List[str], account_ids: List[str], label: str
) -> None:
queue = self.get_queue(queue_name) queue = self.get_queue(queue_name)
if not actions: if not actions:
@ -1128,7 +1167,7 @@ class SQSBackend(BaseBackend):
queue._policy_json["Statement"].append(statement) queue._policy_json["Statement"].append(statement)
def remove_permission(self, queue_name, label): def remove_permission(self, queue_name: str, label: str) -> None:
queue = self.get_queue(queue_name) queue = self.get_queue(queue_name)
statements = queue._policy_json["Statement"] statements = queue._policy_json["Statement"]
@ -1144,7 +1183,7 @@ class SQSBackend(BaseBackend):
queue._policy_json["Statement"] = statements_new queue._policy_json["Statement"] = statements_new
def tag_queue(self, queue_name, tags): def tag_queue(self, queue_name: str, tags: Dict[str, str]) -> None:
queue = self.get_queue(queue_name) queue = self.get_queue(queue_name)
if not len(tags): if not len(tags):
@ -1155,7 +1194,7 @@ class SQSBackend(BaseBackend):
queue.tags.update(tags) queue.tags.update(tags)
def untag_queue(self, queue_name, tag_keys): def untag_queue(self, queue_name: str, tag_keys: List[str]) -> None:
queue = self.get_queue(queue_name) queue = self.get_queue(queue_name)
if not len(tag_keys): if not len(tag_keys):
@ -1170,17 +1209,16 @@ class SQSBackend(BaseBackend):
except KeyError: except KeyError:
pass pass
def list_queue_tags(self, queue_name): def list_queue_tags(self, queue_name: str) -> Queue:
return self.get_queue(queue_name) return self.get_queue(queue_name)
def is_message_valid_based_on_retention_period(self, queue_name, message): def is_message_valid_based_on_retention_period(
message_attributes = self.get_queue_attributes( self, queue_name: str, message: Message
) -> bool:
retention_period = self.get_queue_attributes(
queue_name, ["MessageRetentionPeriod"] queue_name, ["MessageRetentionPeriod"]
) )["MessageRetentionPeriod"]
retain_until = ( retain_until = retention_period + message.sent_timestamp / 1000 # type: ignore
message_attributes.get("MessageRetentionPeriod")
+ message.sent_timestamp / 1000
)
if retain_until <= unix_time(): if retain_until <= unix_time():
return False return False
return True return True

View File

@ -1,5 +1,7 @@
import re import re
from typing import Any, Dict, Optional, Tuple, Union
from moto.core.common_types import TYPE_RESPONSE
from moto.core.exceptions import RESTError from moto.core.exceptions import RESTError
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.core.utils import underscores_to_camelcase, camelcase_to_pascal from moto.core.utils import underscores_to_camelcase, camelcase_to_pascal
@ -24,7 +26,7 @@ class SQSResponse(BaseResponse):
region_regex = re.compile(r"://(.+?)\.queue\.amazonaws\.com") region_regex = re.compile(r"://(.+?)\.queue\.amazonaws\.com")
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="sqs") super().__init__(service_name="sqs")
@property @property
@ -32,7 +34,7 @@ class SQSResponse(BaseResponse):
return sqs_backends[self.current_account][self.region] return sqs_backends[self.current_account][self.region]
@property @property
def attribute(self): def attribute(self) -> Any: # type: ignore[misc]
if not hasattr(self, "_attribute"): if not hasattr(self, "_attribute"):
self._attribute = self._get_map_prefix( self._attribute = self._get_map_prefix(
"Attribute", key_end=".Name", value_end=".Value" "Attribute", key_end=".Name", value_end=".Value"
@ -40,14 +42,14 @@ class SQSResponse(BaseResponse):
return self._attribute return self._attribute
@property @property
def tags(self): def tags(self) -> Dict[str, str]:
if not hasattr(self, "_tags"): if not hasattr(self, "_tags"):
self._tags = self._get_map_prefix("Tag", key_end=".Key", value_end=".Value") self._tags = self._get_map_prefix("Tag", key_end=".Key", value_end=".Value")
return self._tags return self._tags
def _get_queue_name(self): def _get_queue_name(self) -> str:
try: try:
queue_url = self.querystring.get("QueueUrl")[0] queue_url = self.querystring.get("QueueUrl")[0] # type: ignore
if queue_url.startswith("http://") or queue_url.startswith("https://"): if queue_url.startswith("http://") or queue_url.startswith("https://"):
return queue_url.split("/")[-1] return queue_url.split("/")[-1]
else: else:
@ -57,7 +59,7 @@ class SQSResponse(BaseResponse):
# Fallback to reading from the URL for botocore # Fallback to reading from the URL for botocore
return self.path.split("/")[-1] return self.path.split("/")[-1]
def _get_validated_visibility_timeout(self, timeout=None): def _get_validated_visibility_timeout(self, timeout: Optional[str] = None) -> int:
""" """
:raises ValueError: If specified visibility timeout exceeds MAXIMUM_VISIBILITY_TIMEOUT :raises ValueError: If specified visibility timeout exceeds MAXIMUM_VISIBILITY_TIMEOUT
:raises TypeError: If visibility timeout was not specified :raises TypeError: If visibility timeout was not specified
@ -65,7 +67,7 @@ class SQSResponse(BaseResponse):
if timeout is not None: if timeout is not None:
visibility_timeout = int(timeout) visibility_timeout = int(timeout)
else: else:
visibility_timeout = int(self.querystring.get("VisibilityTimeout")[0]) visibility_timeout = int(self.querystring.get("VisibilityTimeout")[0]) # type: ignore
if visibility_timeout > MAXIMUM_VISIBILITY_TIMEOUT: if visibility_timeout > MAXIMUM_VISIBILITY_TIMEOUT:
raise ValueError raise ValueError
@ -74,7 +76,7 @@ class SQSResponse(BaseResponse):
@amz_crc32 # crc last as request_id can edit XML @amz_crc32 # crc last as request_id can edit XML
@amzn_request_id @amzn_request_id
def call_action(self): def call_action(self) -> TYPE_RESPONSE:
status_code, headers, body = super().call_action() status_code, headers, body = super().call_action()
if status_code == 404: if status_code == 404:
queue_name = self.querystring.get("QueueName", [""])[0] queue_name = self.querystring.get("QueueName", [""])[0]
@ -83,11 +85,13 @@ class SQSResponse(BaseResponse):
return 404, headers, response return 404, headers, response
return status_code, headers, body return status_code, headers, body
def _error(self, code, message, status=400): def _error(
self, code: str, message: str, status: int = 400
) -> Tuple[str, Dict[str, int]]:
template = self.response_template(ERROR_TEMPLATE) template = self.response_template(ERROR_TEMPLATE)
return template.render(code=code, message=message), dict(status=status) return template.render(code=code, message=message), dict(status=status)
def create_queue(self): def create_queue(self) -> str:
request_url = urlparse(self.uri) request_url = urlparse(self.uri)
queue_name = self._get_param("QueueName") queue_name = self._get_param("QueueName")
@ -96,7 +100,7 @@ class SQSResponse(BaseResponse):
template = self.response_template(CREATE_QUEUE_RESPONSE) template = self.response_template(CREATE_QUEUE_RESPONSE)
return template.render(queue_url=queue.url(request_url)) return template.render(queue_url=queue.url(request_url))
def get_queue_url(self): def get_queue_url(self) -> str:
request_url = urlparse(self.uri) request_url = urlparse(self.uri)
queue_name = self._get_param("QueueName") queue_name = self._get_param("QueueName")
@ -105,14 +109,14 @@ class SQSResponse(BaseResponse):
template = self.response_template(GET_QUEUE_URL_RESPONSE) template = self.response_template(GET_QUEUE_URL_RESPONSE)
return template.render(queue_url=queue.url(request_url)) return template.render(queue_url=queue.url(request_url))
def list_queues(self): def list_queues(self) -> str:
request_url = urlparse(self.uri) request_url = urlparse(self.uri)
queue_name_prefix = self._get_param("QueueNamePrefix") queue_name_prefix = self._get_param("QueueNamePrefix")
queues = self.sqs_backend.list_queues(queue_name_prefix) queues = self.sqs_backend.list_queues(queue_name_prefix)
template = self.response_template(LIST_QUEUES_RESPONSE) template = self.response_template(LIST_QUEUES_RESPONSE)
return template.render(queues=queues, request_url=request_url) return template.render(queues=queues, request_url=request_url)
def change_message_visibility(self): def change_message_visibility(self) -> Union[str, Tuple[str, Dict[str, int]]]:
queue_name = self._get_queue_name() queue_name = self._get_queue_name()
receipt_handle = self._get_param("ReceiptHandle") receipt_handle = self._get_param("ReceiptHandle")
@ -130,7 +134,7 @@ class SQSResponse(BaseResponse):
template = self.response_template(CHANGE_MESSAGE_VISIBILITY_RESPONSE) template = self.response_template(CHANGE_MESSAGE_VISIBILITY_RESPONSE)
return template.render() return template.render()
def change_message_visibility_batch(self): def change_message_visibility_batch(self) -> str:
queue_name = self._get_queue_name() queue_name = self._get_queue_name()
entries = self._get_list_prefix("ChangeMessageVisibilityBatchRequestEntry") entries = self._get_list_prefix("ChangeMessageVisibilityBatchRequestEntry")
@ -141,24 +145,23 @@ class SQSResponse(BaseResponse):
template = self.response_template(CHANGE_MESSAGE_VISIBILITY_BATCH_RESPONSE) template = self.response_template(CHANGE_MESSAGE_VISIBILITY_BATCH_RESPONSE)
return template.render(success=success, errors=error) return template.render(success=success, errors=error)
def get_queue_attributes(self): def get_queue_attributes(self) -> str:
queue_name = self._get_queue_name() queue_name = self._get_queue_name()
if self.querystring.get("AttributeNames"): if self.querystring.get("AttributeNames"):
raise InvalidAttributeName("") raise InvalidAttributeName("")
attribute_names = self._get_multi_param("AttributeName")
# if connecting to AWS via boto, then 'AttributeName' is just a normal parameter # if connecting to AWS via boto, then 'AttributeName' is just a normal parameter
if not attribute_names: attribute_names = self._get_multi_param(
attribute_names = self.querystring.get("AttributeName") "AttributeName"
) or self.querystring.get("AttributeName")
attributes = self.sqs_backend.get_queue_attributes(queue_name, attribute_names) attributes = self.sqs_backend.get_queue_attributes(queue_name, attribute_names) # type: ignore
template = self.response_template(GET_QUEUE_ATTRIBUTES_RESPONSE) template = self.response_template(GET_QUEUE_ATTRIBUTES_RESPONSE)
return template.render(attributes=attributes) return template.render(attributes=attributes)
def set_queue_attributes(self): def set_queue_attributes(self) -> str:
# TODO validate self.get_param('QueueUrl') # TODO validate self.get_param('QueueUrl')
attribute = self.attribute attribute = self.attribute
@ -174,7 +177,7 @@ class SQSResponse(BaseResponse):
return SET_QUEUE_ATTRIBUTE_RESPONSE return SET_QUEUE_ATTRIBUTE_RESPONSE
def delete_queue(self): def delete_queue(self) -> str:
# TODO validate self.get_param('QueueUrl') # TODO validate self.get_param('QueueUrl')
queue_name = self._get_queue_name() queue_name = self._get_queue_name()
@ -183,7 +186,7 @@ class SQSResponse(BaseResponse):
template = self.response_template(DELETE_QUEUE_RESPONSE) template = self.response_template(DELETE_QUEUE_RESPONSE)
return template.render() return template.render()
def send_message(self): def send_message(self) -> Union[str, Tuple[str, Dict[str, int]]]:
message = self._get_param("MessageBody") message = self._get_param("MessageBody")
delay_seconds = int(self._get_param("DelaySeconds", 0)) delay_seconds = int(self._get_param("DelaySeconds", 0))
message_group_id = self._get_param("MessageGroupId") message_group_id = self._get_param("MessageGroupId")
@ -215,7 +218,7 @@ class SQSResponse(BaseResponse):
template = self.response_template(SEND_MESSAGE_RESPONSE) template = self.response_template(SEND_MESSAGE_RESPONSE)
return template.render(message=message, message_attributes=message_attributes) return template.render(message=message, message_attributes=message_attributes)
def send_message_batch(self): def send_message_batch(self) -> str:
""" """
The querystring comes like this The querystring comes like this
@ -247,7 +250,7 @@ class SQSResponse(BaseResponse):
entries[index] = { entries[index] = {
"Id": value[0], "Id": value[0],
"MessageBody": self.querystring.get( "MessageBody": self.querystring.get( # type: ignore
f"SendMessageBatchRequestEntry.{index}.MessageBody" f"SendMessageBatchRequestEntry.{index}.MessageBody"
)[0], )[0],
"DelaySeconds": self.querystring.get( "DelaySeconds": self.querystring.get(
@ -286,14 +289,14 @@ class SQSResponse(BaseResponse):
template = self.response_template(SEND_MESSAGE_BATCH_RESPONSE) template = self.response_template(SEND_MESSAGE_BATCH_RESPONSE)
return template.render(messages=messages, errors=errors) return template.render(messages=messages, errors=errors)
def delete_message(self): def delete_message(self) -> str:
queue_name = self._get_queue_name() queue_name = self._get_queue_name()
receipt_handle = self.querystring.get("ReceiptHandle")[0] receipt_handle = self.querystring.get("ReceiptHandle")[0] # type: ignore
self.sqs_backend.delete_message(queue_name, receipt_handle) self.sqs_backend.delete_message(queue_name, receipt_handle)
template = self.response_template(DELETE_MESSAGE_RESPONSE) template = self.response_template(DELETE_MESSAGE_RESPONSE)
return template.render() return template.render()
def delete_message_batch(self): def delete_message_batch(self) -> str:
""" """
The querystring comes like this The querystring comes like this
@ -316,7 +319,7 @@ class SQSResponse(BaseResponse):
break break
message_user_id_key = f"DeleteMessageBatchRequestEntry.{index}.Id" message_user_id_key = f"DeleteMessageBatchRequestEntry.{index}.Id"
message_user_id = self.querystring.get(message_user_id_key)[0] message_user_id = self.querystring.get(message_user_id_key)[0] # type: ignore
receipts.append( receipts.append(
{"receipt_handle": receipt_handle[0], "msg_user_id": message_user_id} {"receipt_handle": receipt_handle[0], "msg_user_id": message_user_id}
) )
@ -333,13 +336,13 @@ class SQSResponse(BaseResponse):
template = self.response_template(DELETE_MESSAGE_BATCH_RESPONSE) template = self.response_template(DELETE_MESSAGE_BATCH_RESPONSE)
return template.render(success=success, errors=errors) return template.render(success=success, errors=errors)
def purge_queue(self): def purge_queue(self) -> str:
queue_name = self._get_queue_name() queue_name = self._get_queue_name()
self.sqs_backend.purge_queue(queue_name) self.sqs_backend.purge_queue(queue_name)
template = self.response_template(PURGE_QUEUE_RESPONSE) template = self.response_template(PURGE_QUEUE_RESPONSE)
return template.render() return template.render()
def receive_message(self): def receive_message(self) -> Union[str, Tuple[str, Dict[str, int]]]:
queue_name = self._get_queue_name() queue_name = self._get_queue_name()
message_attributes = self._get_multi_param("message_attributes") message_attributes = self._get_multi_param("message_attributes")
if not message_attributes: if not message_attributes:
@ -350,7 +353,7 @@ class SQSResponse(BaseResponse):
queue = self.sqs_backend.get_queue(queue_name) queue = self.sqs_backend.get_queue(queue_name)
try: try:
message_count = int(self.querystring.get("MaxNumberOfMessages")[0]) message_count = int(self.querystring.get("MaxNumberOfMessages")[0]) # type: ignore
except TypeError: except TypeError:
message_count = DEFAULT_RECEIVED_MESSAGES message_count = DEFAULT_RECEIVED_MESSAGES
@ -364,9 +367,9 @@ class SQSResponse(BaseResponse):
) )
try: try:
wait_time = int(self.querystring.get("WaitTimeSeconds")[0]) wait_time = int(self.querystring.get("WaitTimeSeconds")[0]) # type: ignore
except TypeError: except TypeError:
wait_time = int(queue.receive_message_wait_time_seconds) wait_time = int(queue.receive_message_wait_time_seconds) # type: ignore
if wait_time < 0 or wait_time > 20: if wait_time < 0 or wait_time > 20:
return self._error( return self._error(
@ -380,7 +383,7 @@ class SQSResponse(BaseResponse):
try: try:
visibility_timeout = self._get_validated_visibility_timeout() visibility_timeout = self._get_validated_visibility_timeout()
except TypeError: except TypeError:
visibility_timeout = queue.visibility_timeout visibility_timeout = queue.visibility_timeout # type: ignore
except ValueError: except ValueError:
return ERROR_MAX_VISIBILITY_TIMEOUT_RESPONSE, dict(status=400) return ERROR_MAX_VISIBILITY_TIMEOUT_RESPONSE, dict(status=400)
@ -406,7 +409,7 @@ class SQSResponse(BaseResponse):
template = self.response_template(RECEIVE_MESSAGE_RESPONSE) template = self.response_template(RECEIVE_MESSAGE_RESPONSE)
return template.render(messages=messages, attributes=attributes) return template.render(messages=messages, attributes=attributes)
def list_dead_letter_source_queues(self): def list_dead_letter_source_queues(self) -> str:
request_url = urlparse(self.uri) request_url = urlparse(self.uri)
queue_name = self._get_queue_name() queue_name = self._get_queue_name()
@ -415,7 +418,7 @@ class SQSResponse(BaseResponse):
template = self.response_template(LIST_DEAD_LETTER_SOURCE_QUEUES_RESPONSE) template = self.response_template(LIST_DEAD_LETTER_SOURCE_QUEUES_RESPONSE)
return template.render(queues=source_queue_urls, request_url=request_url) return template.render(queues=source_queue_urls, request_url=request_url)
def add_permission(self): def add_permission(self) -> str:
queue_name = self._get_queue_name() queue_name = self._get_queue_name()
actions = self._get_multi_param("ActionName") actions = self._get_multi_param("ActionName")
account_ids = self._get_multi_param("AWSAccountId") account_ids = self._get_multi_param("AWSAccountId")
@ -426,7 +429,7 @@ class SQSResponse(BaseResponse):
template = self.response_template(ADD_PERMISSION_RESPONSE) template = self.response_template(ADD_PERMISSION_RESPONSE)
return template.render() return template.render()
def remove_permission(self): def remove_permission(self) -> str:
queue_name = self._get_queue_name() queue_name = self._get_queue_name()
label = self._get_param("Label") label = self._get_param("Label")
@ -435,7 +438,7 @@ class SQSResponse(BaseResponse):
template = self.response_template(REMOVE_PERMISSION_RESPONSE) template = self.response_template(REMOVE_PERMISSION_RESPONSE)
return template.render() return template.render()
def tag_queue(self): def tag_queue(self) -> str:
queue_name = self._get_queue_name() queue_name = self._get_queue_name()
tags = self._get_map_prefix("Tag", key_end=".Key", value_end=".Value") tags = self._get_map_prefix("Tag", key_end=".Key", value_end=".Value")
@ -444,7 +447,7 @@ class SQSResponse(BaseResponse):
template = self.response_template(TAG_QUEUE_RESPONSE) template = self.response_template(TAG_QUEUE_RESPONSE)
return template.render() return template.render()
def untag_queue(self): def untag_queue(self) -> str:
queue_name = self._get_queue_name() queue_name = self._get_queue_name()
tag_keys = self._get_multi_param("TagKey") tag_keys = self._get_multi_param("TagKey")
@ -453,7 +456,7 @@ class SQSResponse(BaseResponse):
template = self.response_template(UNTAG_QUEUE_RESPONSE) template = self.response_template(UNTAG_QUEUE_RESPONSE)
return template.render() return template.render()
def list_queue_tags(self): def list_queue_tags(self) -> str:
queue_name = self._get_queue_name() queue_name = self._get_queue_name()
queue = self.sqs_backend.list_queue_tags(queue_name) queue = self.sqs_backend.list_queue_tags(queue_name)

View File

@ -1,15 +1,17 @@
import string import string
from typing import Any, Dict, List
from moto.moto_api._internal import mock_random as random from moto.moto_api._internal import mock_random as random
from .exceptions import MessageAttributesInvalid from .exceptions import MessageAttributesInvalid
def generate_receipt_handle(): def generate_receipt_handle() -> str:
# http://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/ImportantIdentifiers.html#ImportantIdentifiers-receipt-handles # http://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/ImportantIdentifiers.html#ImportantIdentifiers-receipt-handles
length = 185 length = 185
return "".join(random.choice(string.ascii_lowercase) for x in range(length)) return "".join(random.choice(string.ascii_lowercase) for x in range(length))
def extract_input_message_attributes(querystring): def extract_input_message_attributes(querystring: Dict[str, Any]) -> List[str]:
message_attributes = [] message_attributes = []
index = 1 index = 1
while True: while True:
@ -25,8 +27,11 @@ def extract_input_message_attributes(querystring):
def parse_message_attributes( def parse_message_attributes(
querystring, key="MessageAttribute", base="", value_namespace="Value." querystring: Dict[str, Any],
): key: str = "MessageAttribute",
base: str = "",
value_namespace: str = "Value.",
) -> Dict[str, Any]:
message_attributes = {} message_attributes = {}
index = 1 index = 1
while True: while True:

View File

@ -239,7 +239,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy] [mypy]
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s3*,moto/sagemaker,moto/secretsmanager,moto/ssm,moto/scheduler files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s3*,moto/sagemaker,moto/secretsmanager,moto/sqs,moto/ssm,moto/scheduler
show_column_numbers=True show_column_numbers=True
show_error_codes = True show_error_codes = True
disable_error_code=abstract disable_error_code=abstract