Techdebt: MyPy SQS (#6251)

This commit is contained in:
Bert Blommers 2023-04-24 16:46:01 +00:00 committed by GitHub
parent a32b721c79
commit 1f56b75ccf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 241 additions and 195 deletions

View File

@ -4,7 +4,7 @@ from moto.core.exceptions import RESTError
class ReceiptHandleIsInvalid(RESTError):
code = 400
def __init__(self):
def __init__(self) -> None:
super().__init__(
"ReceiptHandleIsInvalid", "The input receipt handle is invalid."
)
@ -13,14 +13,14 @@ class ReceiptHandleIsInvalid(RESTError):
class MessageAttributesInvalid(RESTError):
code = 400
def __init__(self, description):
def __init__(self, description: str):
super().__init__("MessageAttributesInvalid", description)
class QueueDoesNotExist(RESTError):
code = 400
def __init__(self):
def __init__(self) -> None:
super().__init__(
"AWS.SimpleQueueService.NonExistentQueue",
"The specified queue does not exist for this wsdl version.",
@ -31,14 +31,14 @@ class QueueDoesNotExist(RESTError):
class QueueAlreadyExists(RESTError):
code = 400
def __init__(self, message):
def __init__(self, message: str):
super().__init__("QueueAlreadyExists", message)
class EmptyBatchRequest(RESTError):
code = 400
def __init__(self):
def __init__(self) -> None:
super().__init__(
"EmptyBatchRequest",
"There should be at least one SendMessageBatchRequestEntry in the request.",
@ -48,7 +48,7 @@ class EmptyBatchRequest(RESTError):
class InvalidBatchEntryId(RESTError):
code = 400
def __init__(self):
def __init__(self) -> None:
super().__init__(
"InvalidBatchEntryId",
"A batch entry id can only contain alphanumeric characters, "
@ -59,7 +59,7 @@ class InvalidBatchEntryId(RESTError):
class BatchRequestTooLong(RESTError):
code = 400
def __init__(self, length):
def __init__(self, length: int):
super().__init__(
"BatchRequestTooLong",
"Batch requests cannot be longer than 262144 bytes. "
@ -70,14 +70,14 @@ class BatchRequestTooLong(RESTError):
class BatchEntryIdsNotDistinct(RESTError):
code = 400
def __init__(self, entry_id):
def __init__(self, entry_id: str):
super().__init__("BatchEntryIdsNotDistinct", f"Id {entry_id} repeated.")
class TooManyEntriesInBatchRequest(RESTError):
code = 400
def __init__(self, number):
def __init__(self, number: int):
super().__init__(
"TooManyEntriesInBatchRequest",
"Maximum number of entries per request are 10. " f"You have sent {number}.",
@ -87,14 +87,14 @@ class TooManyEntriesInBatchRequest(RESTError):
class InvalidAttributeName(RESTError):
code = 400
def __init__(self, attribute_name):
def __init__(self, attribute_name: str):
super().__init__("InvalidAttributeName", f"Unknown Attribute {attribute_name}.")
class InvalidAttributeValue(RESTError):
code = 400
def __init__(self, attribute_name):
def __init__(self, attribute_name: str):
super().__init__(
"InvalidAttributeValue",
f"Invalid value for the parameter {attribute_name}.",
@ -104,14 +104,14 @@ class InvalidAttributeValue(RESTError):
class InvalidParameterValue(RESTError):
code = 400
def __init__(self, message):
def __init__(self, message: str):
super().__init__("InvalidParameterValue", message)
class MissingParameter(RESTError):
code = 400
def __init__(self, parameter):
def __init__(self, parameter: str):
super().__init__(
"MissingParameter", f"The request must contain the parameter {parameter}."
)
@ -120,7 +120,7 @@ class MissingParameter(RESTError):
class OverLimit(RESTError):
code = 403
def __init__(self, count):
def __init__(self, count: int):
super().__init__(
"OverLimit", f"{count} Actions were found, maximum allowed is 7."
)

View File

@ -6,8 +6,9 @@ import string
import struct
from copy import deepcopy
from typing import Dict
from typing import Any, Dict, List, Optional, Tuple, Set, TYPE_CHECKING
from threading import Condition
from urllib.parse import ParseResult
from xml.sax.saxutils import escape
from moto.core.exceptions import RESTError
@ -38,6 +39,9 @@ from .exceptions import (
InvalidAttributeValue,
)
if TYPE_CHECKING:
from moto.awslambda.models import EventSourceMapping
DEFAULT_SENDER_ID = "AIDAIT2UOQQY3AUEKVGXU"
MAXIMUM_MESSAGE_LENGTH = 262144 # 256 KiB
@ -67,31 +71,36 @@ DEDUPLICATION_TIME_IN_SECONDS = 300
class Message(BaseModel):
def __init__(self, message_id, body, system_attributes=None):
def __init__(
self,
message_id: str,
body: str,
system_attributes: Optional[Dict[str, Any]] = None,
):
self.id = message_id
self._body = body
self.message_attributes = {}
self.receipt_handle = None
self._old_receipt_handles = []
self.message_attributes: Dict[str, Any] = {}
self.receipt_handle: Optional[str] = None
self._old_receipt_handles: List[str] = []
self.sender_id = DEFAULT_SENDER_ID
self.sent_timestamp = None
self.approximate_first_receive_timestamp = None
self.approximate_first_receive_timestamp: Optional[int] = None
self.approximate_receive_count = 0
self.deduplication_id = None
self.group_id = None
self.sequence_number = None
self.visible_at = 0
self.delayed_until = 0
self.deduplication_id: Optional[str] = None
self.group_id: Optional[str] = None
self.sequence_number: Optional[str] = None
self.visible_at = 0.0
self.delayed_until = 0.0
self.system_attributes = system_attributes or {}
@property
def body_md5(self):
def body_md5(self) -> str:
md5 = md5_hash()
md5.update(self._body.encode("utf-8"))
return md5.hexdigest()
@property
def attribute_md5(self):
def attribute_md5(self) -> str:
md5 = md5_hash()
@ -129,13 +138,13 @@ class Message(BaseModel):
return md5.hexdigest()
@staticmethod
def update_binary_length_and_value(md5, value):
def update_binary_length_and_value(md5: Any, value: bytes) -> None: # type: ignore[misc]
length_bytes = struct.pack("!I".encode("ascii"), len(value))
md5.update(length_bytes)
md5.update(value)
@staticmethod
def validate_attribute_name(name):
def validate_attribute_name(name: str) -> None:
if not ATTRIBUTE_NAME_PATTERN.match(name):
raise MessageAttributesInvalid(
f"The message attribute name '{name}' is invalid. "
@ -144,21 +153,21 @@ class Message(BaseModel):
)
@staticmethod
def utf8(value):
def utf8(value: Any) -> bytes: # type: ignore[misc]
if isinstance(value, str):
return value.encode("utf-8")
return value
@property
def body(self):
def body(self) -> str:
return escape(self._body).replace('"', """).replace("\r", "
")
def mark_sent(self, delay_seconds=None):
self.sent_timestamp = int(unix_time_millis())
def mark_sent(self, delay_seconds: Optional[int] = None) -> None:
self.sent_timestamp = int(unix_time_millis()) # type: ignore
if delay_seconds:
self.delay(delay_seconds=delay_seconds)
def mark_received(self, visibility_timeout=None):
def mark_received(self, visibility_timeout: Optional[int] = None) -> None:
"""
When a message is received we will set the first receive timestamp,
tap the ``approximate_receive_count`` and the ``visible_at`` time.
@ -178,37 +187,37 @@ class Message(BaseModel):
if visibility_timeout:
self.change_visibility(visibility_timeout)
self._old_receipt_handles.append(self.receipt_handle)
self._old_receipt_handles.append(self.receipt_handle) # type: ignore
self.receipt_handle = generate_receipt_handle()
def change_visibility(self, visibility_timeout):
def change_visibility(self, visibility_timeout: int) -> None:
# We're dealing with milliseconds internally
visibility_timeout_msec = int(visibility_timeout) * 1000
self.visible_at = unix_time_millis() + visibility_timeout_msec
def delay(self, delay_seconds):
def delay(self, delay_seconds: int) -> None:
delay_msec = int(delay_seconds) * 1000
self.delayed_until = unix_time_millis() + delay_msec
@property
def visible(self):
def visible(self) -> bool:
current_time = unix_time_millis()
if current_time > self.visible_at:
return True
return False
@property
def delayed(self):
def delayed(self) -> bool:
current_time = unix_time_millis()
if current_time < self.delayed_until:
return True
return False
@property
def all_receipt_handles(self):
return [self.receipt_handle] + self._old_receipt_handles
def all_receipt_handles(self) -> List[Optional[str]]:
return [self.receipt_handle] + self._old_receipt_handles # type: ignore
def had_receipt_handle(self, receipt_handle):
def had_receipt_handle(self, receipt_handle: str) -> bool:
"""
Check if this message ever had this receipt_handle in the past
"""
@ -250,24 +259,25 @@ class Queue(CloudFormationModel):
"SendMessage",
)
def __init__(self, name, region, account_id, **kwargs):
def __init__(self, name: str, region: str, account_id: str, **kwargs: Any):
self.name = name
self.region = region
self.account_id = account_id
self.tags = {}
self.permissions = {}
self.tags: Dict[str, str] = {}
self.permissions: Dict[str, Any] = {}
self._messages = []
self._pending_messages = set()
self.deleted_messages = set()
self._messages: List[Message] = []
self._pending_messages: Set[Message] = set()
self.deleted_messages: Set[str] = set()
self._messages_lock = Condition()
now = unix_time()
self.created_timestamp = now
self.queue_arn = f"arn:aws:sqs:{region}:{account_id}:{name}"
self.dead_letter_queue = None
self.dead_letter_queue: Optional["Queue"] = None
self.fifo_queue = False
self.lambda_event_source_mappings = {}
self.lambda_event_source_mappings: Dict[str, "EventSourceMapping"] = {}
# default settings for a non fifo queue
defaults = {
@ -293,24 +303,26 @@ class Queue(CloudFormationModel):
if self.fifo_queue and not self.name.endswith(".fifo"):
raise InvalidParameterValue("Queue name must end in .fifo for FIFO queues")
if (
self.maximum_message_size < MAXIMUM_MESSAGE_SIZE_ATTR_LOWER_BOUND
or self.maximum_message_size > MAXIMUM_MESSAGE_SIZE_ATTR_UPPER_BOUND
self.maximum_message_size < MAXIMUM_MESSAGE_SIZE_ATTR_LOWER_BOUND # type: ignore
or self.maximum_message_size > MAXIMUM_MESSAGE_SIZE_ATTR_UPPER_BOUND # type: ignore
):
raise InvalidAttributeValue("MaximumMessageSize")
@property
def pending_messages(self):
def pending_messages(self) -> Set[Message]:
return self._pending_messages
@property
def pending_message_groups(self):
def pending_message_groups(self) -> Set[str]:
return set(
message.group_id
for message in self._pending_messages
if message.group_id is not None
)
def _set_attributes(self, attributes, now=None):
def _set_attributes(
self, attributes: Dict[str, Any], now: Optional[float] = None
) -> None:
if not now:
now = unix_time()
@ -343,7 +355,7 @@ class Queue(CloudFormationModel):
self.last_modified_timestamp = now
@staticmethod
def _is_empty_redrive_policy(policy):
def _is_empty_redrive_policy(policy: Any) -> bool: # type: ignore[misc]
if isinstance(policy, str):
if policy == "" or len(json.loads(policy)) == 0:
return True
@ -352,7 +364,7 @@ class Queue(CloudFormationModel):
return False
def _setup_dlq(self, policy):
def _setup_dlq(self, policy: Any) -> None:
if Queue._is_empty_redrive_policy(policy):
self.redrive_policy = None
self.dead_letter_queue = None
@ -407,18 +419,23 @@ class Queue(CloudFormationModel):
)
@staticmethod
def cloudformation_name_type():
def cloudformation_name_type() -> str:
return "QueueName"
@staticmethod
def cloudformation_type():
def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sqs-queue.html
return "AWS::SQS::Queue"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs
):
def create_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
**kwargs: Any,
) -> "Queue":
properties = deepcopy(cloudformation_json["Properties"])
# remove Tags from properties and convert tags list to dict
tags = properties.pop("Tags", [])
@ -433,14 +450,14 @@ class Queue(CloudFormationModel):
)
@classmethod
def update_from_cloudformation_json(
def update_from_cloudformation_json( # type: ignore[misc]
cls,
original_resource,
new_resource_name,
cloudformation_json,
account_id,
region_name,
):
original_resource: Any,
new_resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
) -> "Queue":
properties = cloudformation_json["Properties"]
queue_name = original_resource.name
@ -456,9 +473,13 @@ class Queue(CloudFormationModel):
return queue
@classmethod
def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, account_id, region_name
):
def delete_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
) -> None:
# ResourceName will be the full queue URL - we only need the name
# https://sqs.us-west-1.amazonaws.com/123456789012/queue_name
queue_name = resource_name.split("/")[-1]
@ -466,24 +487,24 @@ class Queue(CloudFormationModel):
sqs_backend.delete_queue(queue_name)
@property
def approximate_number_of_messages_delayed(self):
def approximate_number_of_messages_delayed(self) -> int:
return len([m for m in self._messages if m.delayed])
@property
def approximate_number_of_messages_not_visible(self):
def approximate_number_of_messages_not_visible(self) -> int:
return len([m for m in self._messages if not m.visible])
@property
def approximate_number_of_messages(self):
def approximate_number_of_messages(self) -> int:
return len(self.messages)
@property
def physical_resource_id(self):
def physical_resource_id(self) -> str:
return f"https://sqs.{self.region}.amazonaws.com/{self.account_id}/{self.name}"
@property
def attributes(self):
result = {}
def attributes(self) -> Dict[str, Any]: # type: ignore[misc]
result: Dict[str, Any] = {}
for attribute in self.BASE_ATTRIBUTES:
attr = getattr(self, camelcase_to_underscores(attribute))
@ -494,7 +515,7 @@ class Queue(CloudFormationModel):
attr = getattr(self, camelcase_to_underscores(attribute))
result[attribute] = attr
if self.kms_master_key_id:
if self.kms_master_key_id: # type: ignore
for attribute in self.KMS_ATTRIBUTES:
attr = getattr(self, camelcase_to_underscores(attribute))
result[attribute] = attr
@ -511,13 +532,11 @@ class Queue(CloudFormationModel):
return result
def url(self, request_url):
return (
f"{request_url.scheme}://{request_url.netloc}/{self.account_id}/{self.name}"
)
def url(self, request_url: ParseResult) -> str:
return f"{request_url.scheme}://{request_url.netloc}/{self.account_id}/{self.name}" # type: ignore
@property
def messages(self):
def messages(self) -> List[Message]:
# TODO: This can become very inefficient if a large number of messages are in-flight
return [
message
@ -525,7 +544,7 @@ class Queue(CloudFormationModel):
if message.visible and not message.delayed
]
def add_message(self, message):
def add_message(self, message: Message) -> None:
if self.fifo_queue:
# the cases in which we dedupe fifo messages
@ -537,7 +556,7 @@ class Queue(CloudFormationModel):
):
for m in self._messages:
if m.deduplication_id == message.deduplication_id:
diff = message.sent_timestamp - m.sent_timestamp
diff = message.sent_timestamp - m.sent_timestamp # type: ignore
# if a duplicate message is received within the deduplication time then it should
# not be added to the queue
if diff / 1000 < DEDUPLICATION_TIME_IN_SECONDS:
@ -559,8 +578,8 @@ class Queue(CloudFormationModel):
messages = backend.receive_message(
self.name,
esm.batch_size,
self.receive_message_wait_time_seconds,
self.visibility_timeout,
self.receive_message_wait_time_seconds, # type: ignore
self.visibility_timeout, # type: ignore
)
from moto.awslambda import lambda_backends
@ -580,7 +599,7 @@ class Queue(CloudFormationModel):
for m in messages
]
def delete_message(self, receipt_handle):
def delete_message(self, receipt_handle: str) -> None:
if receipt_handle in self.deleted_messages:
# Already deleted - gracefully handle deleting it again
return
@ -595,20 +614,20 @@ class Queue(CloudFormationModel):
for message in self._messages:
if message.had_receipt_handle(receipt_handle):
self.pending_messages.discard(message)
self.deleted_messages.update(message.all_receipt_handles)
self.deleted_messages.update(message.all_receipt_handles) # type: ignore
continue
new_messages.append(message)
self._messages = new_messages
def wait_for_messages(self, timeout):
def wait_for_messages(self, timeout: int) -> None:
with self._messages_lock:
self._messages_lock.wait_for(lambda: self.messages, timeout=timeout)
@classmethod
def has_cfn_attr(cls, attr):
def has_cfn_attr(cls, attr: str) -> bool:
return attr in ["Arn", "QueueName"]
def get_cfn_attribute(self, attribute_name):
def get_cfn_attribute(self, attribute_name: str) -> str:
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
if attribute_name == "Arn":
@ -618,14 +637,14 @@ class Queue(CloudFormationModel):
raise UnformattedGetAttTemplateException()
@property
def policy(self):
def policy(self) -> Any: # type: ignore[misc]
if self._policy_json.get("Statement"):
return json.dumps(self._policy_json)
else:
return None
@policy.setter
def policy(self, policy):
def policy(self, policy: Any) -> None:
if policy:
self._policy_json = json.loads(policy)
else:
@ -636,7 +655,9 @@ class Queue(CloudFormationModel):
}
def _filter_message_attributes(message, input_message_attributes):
def _filter_message_attributes(
message: Message, input_message_attributes: List[str]
) -> None:
filtered_message_attributes = {}
return_all = "All" in input_message_attributes
for key, value in message.message_attributes.items():
@ -646,18 +667,22 @@ def _filter_message_attributes(message, input_message_attributes):
class SQSBackend(BaseBackend):
def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self.queues: Dict[str, Queue] = {}
@staticmethod
def default_vpc_endpoint_service(service_region, zones):
def default_vpc_endpoint_service(
service_region: str, zones: List[str]
) -> List[Dict[str, str]]:
"""Default VPC endpoint service."""
return BaseBackend.default_vpc_endpoint_service_factory(
service_region, zones, "sqs"
)
def create_queue(self, name, tags=None, **kwargs):
def create_queue(
self, name: str, tags: Optional[Dict[str, str]] = None, **kwargs: Any
) -> Queue:
queue = self.queues.get(name)
if queue:
try:
@ -692,10 +717,10 @@ class SQSBackend(BaseBackend):
return queue
def get_queue_url(self, queue_name):
def get_queue_url(self, queue_name: str) -> Queue:
return self.get_queue(queue_name)
def list_queues(self, queue_name_prefix):
def list_queues(self, queue_name_prefix: str) -> List[Queue]:
re_str = ".*"
if queue_name_prefix:
re_str = f"^{queue_name_prefix}.*"
@ -706,18 +731,20 @@ class SQSBackend(BaseBackend):
qs.append(q)
return qs[:1000]
def get_queue(self, queue_name):
def get_queue(self, queue_name: str) -> Queue:
queue = self.queues.get(queue_name)
if queue is None:
raise QueueDoesNotExist()
return queue
def delete_queue(self, queue_name):
def delete_queue(self, queue_name: str) -> None:
self.get_queue(queue_name)
del self.queues[queue_name]
def get_queue_attributes(self, queue_name, attribute_names):
def get_queue_attributes(
self, queue_name: str, attribute_names: List[str]
) -> Dict[str, Any]:
queue = self.get_queue(queue_name)
if not attribute_names:
return {}
@ -746,21 +773,23 @@ class SQSBackend(BaseBackend):
return attributes
def set_queue_attributes(self, queue_name, attributes):
def set_queue_attributes(
self, queue_name: str, attributes: Dict[str, Any]
) -> Queue:
queue = self.get_queue(queue_name)
queue._set_attributes(attributes)
return queue
def send_message(
self,
queue_name,
message_body,
message_attributes=None,
delay_seconds=None,
deduplication_id=None,
group_id=None,
system_attributes=None,
):
queue_name: str,
message_body: str,
message_attributes: Optional[Dict[str, Any]] = None,
delay_seconds: Optional[int] = None,
deduplication_id: Optional[str] = None,
group_id: Optional[str] = None,
system_attributes: Optional[Dict[str, Any]] = None,
) -> Message:
queue = self.get_queue(queue_name)
@ -783,14 +812,14 @@ class SQSBackend(BaseBackend):
)
raise InvalidParameterValue(msg)
if len(message_body) > queue.maximum_message_size:
msg = f"One or more parameters are invalid. Reason: Message must be shorter than {queue.maximum_message_size} bytes."
if len(message_body) > queue.maximum_message_size: # type: ignore
msg = f"One or more parameters are invalid. Reason: Message must be shorter than {queue.maximum_message_size} bytes." # type: ignore
raise InvalidParameterValue(msg)
if delay_seconds:
delay_seconds = int(delay_seconds)
else:
delay_seconds = queue.delay_seconds
delay_seconds = queue.delay_seconds # type: ignore
message_id = str(random.uuid4())
message = Message(message_id, message_body, system_attributes)
@ -826,7 +855,7 @@ class SQSBackend(BaseBackend):
if message_attributes:
message.message_attributes = message_attributes
if delay_seconds > MAXIMUM_MESSAGE_DELAY:
if delay_seconds > MAXIMUM_MESSAGE_DELAY: # type: ignore
msg = (
f"Value {delay_seconds} for parameter DelaySeconds is invalid. "
"Reason: DelaySeconds must be >= 0 and <= 900."
@ -839,7 +868,9 @@ class SQSBackend(BaseBackend):
return message
def send_message_batch(self, queue_name, entries):
def send_message_batch(
self, queue_name: str, entries: Dict[str, Dict[str, Any]]
) -> Tuple[List[Message], List[Dict[str, Any]]]:
self.get_queue(queue_name)
if any(
@ -880,14 +911,14 @@ class SQSBackend(BaseBackend):
group_id=entry.get("MessageGroupId"),
deduplication_id=entry.get("MessageDeduplicationId"),
)
message.user_id = entry["Id"]
message.user_id = entry["Id"] # type: ignore[attr-defined]
messages.append(message)
except InvalidParameterValue:
failedInvalidDelay.append(entry)
return messages, failedInvalidDelay
def _get_first_duplicate_id(self, ids):
def _get_first_duplicate_id(self, ids: List[str]) -> Optional[str]:
unique_ids = set()
for _id in ids:
if _id in unique_ids:
@ -897,12 +928,12 @@ class SQSBackend(BaseBackend):
def receive_message(
self,
queue_name,
count,
wait_seconds_timeout,
visibility_timeout,
message_attribute_names=None,
):
queue_name: str,
count: int,
wait_seconds_timeout: int,
visibility_timeout: int,
message_attribute_names: Optional[List[str]] = None,
) -> List[Message]:
# Attempt to retrieve visible messages from a queue.
# If a message was read by client and not deleted it is considered to be
@ -913,7 +944,7 @@ class SQSBackend(BaseBackend):
if message_attribute_names is None:
message_attribute_names = []
queue = self.get_queue(queue_name)
result = []
result: List[Message] = []
previous_result_count = len(result)
polling_end = unix_time() + wait_seconds_timeout
@ -925,7 +956,7 @@ class SQSBackend(BaseBackend):
if result or (wait_seconds_timeout and unix_time() > polling_end):
break
messages_to_dlq = []
messages_to_dlq: List[Message] = []
for message in queue.messages:
if not message.visible:
@ -966,7 +997,7 @@ class SQSBackend(BaseBackend):
for message in messages_to_dlq:
queue._messages.remove(message)
queue.dead_letter_queue.add_message(message)
queue.dead_letter_queue.add_message(message) # type: ignore
if previous_result_count == len(result):
if wait_seconds_timeout == 0:
@ -981,12 +1012,14 @@ class SQSBackend(BaseBackend):
return result
def delete_message(self, queue_name, receipt_handle):
def delete_message(self, queue_name: str, receipt_handle: str) -> None:
queue = self.get_queue(queue_name)
queue.delete_message(receipt_handle)
def delete_message_batch(self, queue_name, receipts):
def delete_message_batch(
self, queue_name: str, receipts: List[Dict[str, Any]]
) -> Tuple[List[str], List[Dict[str, str]]]:
success = []
errors = []
for receipt_and_id in receipts:
@ -1004,14 +1037,16 @@ class SQSBackend(BaseBackend):
)
return success, errors
def change_message_visibility(self, queue_name, receipt_handle, visibility_timeout):
def change_message_visibility(
self, queue_name: str, receipt_handle: str, visibility_timeout: int
) -> None:
queue = self.get_queue(queue_name)
for message in queue._messages:
if message.had_receipt_handle(receipt_handle):
visibility_timeout_msec = int(visibility_timeout) * 1000
given_visibility_timeout = unix_time_millis() + visibility_timeout_msec
if given_visibility_timeout - message.sent_timestamp > 43200 * 1000:
if given_visibility_timeout - message.sent_timestamp > 43200 * 1000: # type: ignore
raise InvalidParameterValue(
f"Value {visibility_timeout} for parameter VisibilityTimeout is invalid. Reason: Total "
"VisibilityTimeout for the message is beyond the limit [43200 seconds]"
@ -1025,7 +1060,9 @@ class SQSBackend(BaseBackend):
return
raise ReceiptHandleIsInvalid
def change_message_visibility_batch(self, queue_name: str, entries):
def change_message_visibility_batch(
self, queue_name: str, entries: List[Dict[str, Any]]
) -> Tuple[List[str], List[Dict[str, str]]]:
success = []
error = []
for entry in entries:
@ -1061,22 +1098,24 @@ class SQSBackend(BaseBackend):
)
return success, error
def purge_queue(self, queue_name):
def purge_queue(self, queue_name: str) -> None:
queue = self.get_queue(queue_name)
queue._messages = []
queue._pending_messages = set()
def list_dead_letter_source_queues(self, queue_name):
def list_dead_letter_source_queues(self, queue_name: str) -> List[Queue]:
dlq = self.get_queue(queue_name)
queues = []
queues: List[Queue] = []
for queue in self.queues.values():
if queue.dead_letter_queue is dlq:
queues.append(queue)
return queues
def add_permission(self, queue_name, actions, account_ids, label):
def add_permission(
self, queue_name: str, actions: List[str], account_ids: List[str], label: str
) -> None:
queue = self.get_queue(queue_name)
if not actions:
@ -1128,7 +1167,7 @@ class SQSBackend(BaseBackend):
queue._policy_json["Statement"].append(statement)
def remove_permission(self, queue_name, label):
def remove_permission(self, queue_name: str, label: str) -> None:
queue = self.get_queue(queue_name)
statements = queue._policy_json["Statement"]
@ -1144,7 +1183,7 @@ class SQSBackend(BaseBackend):
queue._policy_json["Statement"] = statements_new
def tag_queue(self, queue_name, tags):
def tag_queue(self, queue_name: str, tags: Dict[str, str]) -> None:
queue = self.get_queue(queue_name)
if not len(tags):
@ -1155,7 +1194,7 @@ class SQSBackend(BaseBackend):
queue.tags.update(tags)
def untag_queue(self, queue_name, tag_keys):
def untag_queue(self, queue_name: str, tag_keys: List[str]) -> None:
queue = self.get_queue(queue_name)
if not len(tag_keys):
@ -1170,17 +1209,16 @@ class SQSBackend(BaseBackend):
except KeyError:
pass
def list_queue_tags(self, queue_name):
def list_queue_tags(self, queue_name: str) -> Queue:
return self.get_queue(queue_name)
def is_message_valid_based_on_retention_period(self, queue_name, message):
message_attributes = self.get_queue_attributes(
def is_message_valid_based_on_retention_period(
self, queue_name: str, message: Message
) -> bool:
retention_period = self.get_queue_attributes(
queue_name, ["MessageRetentionPeriod"]
)
retain_until = (
message_attributes.get("MessageRetentionPeriod")
+ message.sent_timestamp / 1000
)
)["MessageRetentionPeriod"]
retain_until = retention_period + message.sent_timestamp / 1000 # type: ignore
if retain_until <= unix_time():
return False
return True

View File

@ -1,5 +1,7 @@
import re
from typing import Any, Dict, Optional, Tuple, Union
from moto.core.common_types import TYPE_RESPONSE
from moto.core.exceptions import RESTError
from moto.core.responses import BaseResponse
from moto.core.utils import underscores_to_camelcase, camelcase_to_pascal
@ -24,7 +26,7 @@ class SQSResponse(BaseResponse):
region_regex = re.compile(r"://(.+?)\.queue\.amazonaws\.com")
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="sqs")
@property
@ -32,7 +34,7 @@ class SQSResponse(BaseResponse):
return sqs_backends[self.current_account][self.region]
@property
def attribute(self):
def attribute(self) -> Any: # type: ignore[misc]
if not hasattr(self, "_attribute"):
self._attribute = self._get_map_prefix(
"Attribute", key_end=".Name", value_end=".Value"
@ -40,14 +42,14 @@ class SQSResponse(BaseResponse):
return self._attribute
@property
def tags(self):
def tags(self) -> Dict[str, str]:
if not hasattr(self, "_tags"):
self._tags = self._get_map_prefix("Tag", key_end=".Key", value_end=".Value")
return self._tags
def _get_queue_name(self):
def _get_queue_name(self) -> str:
try:
queue_url = self.querystring.get("QueueUrl")[0]
queue_url = self.querystring.get("QueueUrl")[0] # type: ignore
if queue_url.startswith("http://") or queue_url.startswith("https://"):
return queue_url.split("/")[-1]
else:
@ -57,7 +59,7 @@ class SQSResponse(BaseResponse):
# Fallback to reading from the URL for botocore
return self.path.split("/")[-1]
def _get_validated_visibility_timeout(self, timeout=None):
def _get_validated_visibility_timeout(self, timeout: Optional[str] = None) -> int:
"""
:raises ValueError: If specified visibility timeout exceeds MAXIMUM_VISIBILITY_TIMEOUT
:raises TypeError: If visibility timeout was not specified
@ -65,7 +67,7 @@ class SQSResponse(BaseResponse):
if timeout is not None:
visibility_timeout = int(timeout)
else:
visibility_timeout = int(self.querystring.get("VisibilityTimeout")[0])
visibility_timeout = int(self.querystring.get("VisibilityTimeout")[0]) # type: ignore
if visibility_timeout > MAXIMUM_VISIBILITY_TIMEOUT:
raise ValueError
@ -74,7 +76,7 @@ class SQSResponse(BaseResponse):
@amz_crc32 # crc last as request_id can edit XML
@amzn_request_id
def call_action(self):
def call_action(self) -> TYPE_RESPONSE:
status_code, headers, body = super().call_action()
if status_code == 404:
queue_name = self.querystring.get("QueueName", [""])[0]
@ -83,11 +85,13 @@ class SQSResponse(BaseResponse):
return 404, headers, response
return status_code, headers, body
def _error(self, code, message, status=400):
def _error(
self, code: str, message: str, status: int = 400
) -> Tuple[str, Dict[str, int]]:
template = self.response_template(ERROR_TEMPLATE)
return template.render(code=code, message=message), dict(status=status)
def create_queue(self):
def create_queue(self) -> str:
request_url = urlparse(self.uri)
queue_name = self._get_param("QueueName")
@ -96,7 +100,7 @@ class SQSResponse(BaseResponse):
template = self.response_template(CREATE_QUEUE_RESPONSE)
return template.render(queue_url=queue.url(request_url))
def get_queue_url(self):
def get_queue_url(self) -> str:
request_url = urlparse(self.uri)
queue_name = self._get_param("QueueName")
@ -105,14 +109,14 @@ class SQSResponse(BaseResponse):
template = self.response_template(GET_QUEUE_URL_RESPONSE)
return template.render(queue_url=queue.url(request_url))
def list_queues(self):
def list_queues(self) -> str:
request_url = urlparse(self.uri)
queue_name_prefix = self._get_param("QueueNamePrefix")
queues = self.sqs_backend.list_queues(queue_name_prefix)
template = self.response_template(LIST_QUEUES_RESPONSE)
return template.render(queues=queues, request_url=request_url)
def change_message_visibility(self):
def change_message_visibility(self) -> Union[str, Tuple[str, Dict[str, int]]]:
queue_name = self._get_queue_name()
receipt_handle = self._get_param("ReceiptHandle")
@ -130,7 +134,7 @@ class SQSResponse(BaseResponse):
template = self.response_template(CHANGE_MESSAGE_VISIBILITY_RESPONSE)
return template.render()
def change_message_visibility_batch(self):
def change_message_visibility_batch(self) -> str:
queue_name = self._get_queue_name()
entries = self._get_list_prefix("ChangeMessageVisibilityBatchRequestEntry")
@ -141,24 +145,23 @@ class SQSResponse(BaseResponse):
template = self.response_template(CHANGE_MESSAGE_VISIBILITY_BATCH_RESPONSE)
return template.render(success=success, errors=error)
def get_queue_attributes(self):
def get_queue_attributes(self) -> str:
queue_name = self._get_queue_name()
if self.querystring.get("AttributeNames"):
raise InvalidAttributeName("")
attribute_names = self._get_multi_param("AttributeName")
# if connecting to AWS via boto, then 'AttributeName' is just a normal parameter
if not attribute_names:
attribute_names = self.querystring.get("AttributeName")
attribute_names = self._get_multi_param(
"AttributeName"
) or self.querystring.get("AttributeName")
attributes = self.sqs_backend.get_queue_attributes(queue_name, attribute_names)
attributes = self.sqs_backend.get_queue_attributes(queue_name, attribute_names) # type: ignore
template = self.response_template(GET_QUEUE_ATTRIBUTES_RESPONSE)
return template.render(attributes=attributes)
def set_queue_attributes(self):
def set_queue_attributes(self) -> str:
# TODO validate self.get_param('QueueUrl')
attribute = self.attribute
@ -174,7 +177,7 @@ class SQSResponse(BaseResponse):
return SET_QUEUE_ATTRIBUTE_RESPONSE
def delete_queue(self):
def delete_queue(self) -> str:
# TODO validate self.get_param('QueueUrl')
queue_name = self._get_queue_name()
@ -183,7 +186,7 @@ class SQSResponse(BaseResponse):
template = self.response_template(DELETE_QUEUE_RESPONSE)
return template.render()
def send_message(self):
def send_message(self) -> Union[str, Tuple[str, Dict[str, int]]]:
message = self._get_param("MessageBody")
delay_seconds = int(self._get_param("DelaySeconds", 0))
message_group_id = self._get_param("MessageGroupId")
@ -215,7 +218,7 @@ class SQSResponse(BaseResponse):
template = self.response_template(SEND_MESSAGE_RESPONSE)
return template.render(message=message, message_attributes=message_attributes)
def send_message_batch(self):
def send_message_batch(self) -> str:
"""
The querystring comes like this
@ -247,7 +250,7 @@ class SQSResponse(BaseResponse):
entries[index] = {
"Id": value[0],
"MessageBody": self.querystring.get(
"MessageBody": self.querystring.get( # type: ignore
f"SendMessageBatchRequestEntry.{index}.MessageBody"
)[0],
"DelaySeconds": self.querystring.get(
@ -286,14 +289,14 @@ class SQSResponse(BaseResponse):
template = self.response_template(SEND_MESSAGE_BATCH_RESPONSE)
return template.render(messages=messages, errors=errors)
def delete_message(self):
def delete_message(self) -> str:
queue_name = self._get_queue_name()
receipt_handle = self.querystring.get("ReceiptHandle")[0]
receipt_handle = self.querystring.get("ReceiptHandle")[0] # type: ignore
self.sqs_backend.delete_message(queue_name, receipt_handle)
template = self.response_template(DELETE_MESSAGE_RESPONSE)
return template.render()
def delete_message_batch(self):
def delete_message_batch(self) -> str:
"""
The querystring comes like this
@ -316,7 +319,7 @@ class SQSResponse(BaseResponse):
break
message_user_id_key = f"DeleteMessageBatchRequestEntry.{index}.Id"
message_user_id = self.querystring.get(message_user_id_key)[0]
message_user_id = self.querystring.get(message_user_id_key)[0] # type: ignore
receipts.append(
{"receipt_handle": receipt_handle[0], "msg_user_id": message_user_id}
)
@ -333,13 +336,13 @@ class SQSResponse(BaseResponse):
template = self.response_template(DELETE_MESSAGE_BATCH_RESPONSE)
return template.render(success=success, errors=errors)
def purge_queue(self):
def purge_queue(self) -> str:
queue_name = self._get_queue_name()
self.sqs_backend.purge_queue(queue_name)
template = self.response_template(PURGE_QUEUE_RESPONSE)
return template.render()
def receive_message(self):
def receive_message(self) -> Union[str, Tuple[str, Dict[str, int]]]:
queue_name = self._get_queue_name()
message_attributes = self._get_multi_param("message_attributes")
if not message_attributes:
@ -350,7 +353,7 @@ class SQSResponse(BaseResponse):
queue = self.sqs_backend.get_queue(queue_name)
try:
message_count = int(self.querystring.get("MaxNumberOfMessages")[0])
message_count = int(self.querystring.get("MaxNumberOfMessages")[0]) # type: ignore
except TypeError:
message_count = DEFAULT_RECEIVED_MESSAGES
@ -364,9 +367,9 @@ class SQSResponse(BaseResponse):
)
try:
wait_time = int(self.querystring.get("WaitTimeSeconds")[0])
wait_time = int(self.querystring.get("WaitTimeSeconds")[0]) # type: ignore
except TypeError:
wait_time = int(queue.receive_message_wait_time_seconds)
wait_time = int(queue.receive_message_wait_time_seconds) # type: ignore
if wait_time < 0 or wait_time > 20:
return self._error(
@ -380,7 +383,7 @@ class SQSResponse(BaseResponse):
try:
visibility_timeout = self._get_validated_visibility_timeout()
except TypeError:
visibility_timeout = queue.visibility_timeout
visibility_timeout = queue.visibility_timeout # type: ignore
except ValueError:
return ERROR_MAX_VISIBILITY_TIMEOUT_RESPONSE, dict(status=400)
@ -406,7 +409,7 @@ class SQSResponse(BaseResponse):
template = self.response_template(RECEIVE_MESSAGE_RESPONSE)
return template.render(messages=messages, attributes=attributes)
def list_dead_letter_source_queues(self):
def list_dead_letter_source_queues(self) -> str:
request_url = urlparse(self.uri)
queue_name = self._get_queue_name()
@ -415,7 +418,7 @@ class SQSResponse(BaseResponse):
template = self.response_template(LIST_DEAD_LETTER_SOURCE_QUEUES_RESPONSE)
return template.render(queues=source_queue_urls, request_url=request_url)
def add_permission(self):
def add_permission(self) -> str:
queue_name = self._get_queue_name()
actions = self._get_multi_param("ActionName")
account_ids = self._get_multi_param("AWSAccountId")
@ -426,7 +429,7 @@ class SQSResponse(BaseResponse):
template = self.response_template(ADD_PERMISSION_RESPONSE)
return template.render()
def remove_permission(self):
def remove_permission(self) -> str:
queue_name = self._get_queue_name()
label = self._get_param("Label")
@ -435,7 +438,7 @@ class SQSResponse(BaseResponse):
template = self.response_template(REMOVE_PERMISSION_RESPONSE)
return template.render()
def tag_queue(self):
def tag_queue(self) -> str:
queue_name = self._get_queue_name()
tags = self._get_map_prefix("Tag", key_end=".Key", value_end=".Value")
@ -444,7 +447,7 @@ class SQSResponse(BaseResponse):
template = self.response_template(TAG_QUEUE_RESPONSE)
return template.render()
def untag_queue(self):
def untag_queue(self) -> str:
queue_name = self._get_queue_name()
tag_keys = self._get_multi_param("TagKey")
@ -453,7 +456,7 @@ class SQSResponse(BaseResponse):
template = self.response_template(UNTAG_QUEUE_RESPONSE)
return template.render()
def list_queue_tags(self):
def list_queue_tags(self) -> str:
queue_name = self._get_queue_name()
queue = self.sqs_backend.list_queue_tags(queue_name)

View File

@ -1,15 +1,17 @@
import string
from typing import Any, Dict, List
from moto.moto_api._internal import mock_random as random
from .exceptions import MessageAttributesInvalid
def generate_receipt_handle():
def generate_receipt_handle() -> str:
# http://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/ImportantIdentifiers.html#ImportantIdentifiers-receipt-handles
length = 185
return "".join(random.choice(string.ascii_lowercase) for x in range(length))
def extract_input_message_attributes(querystring):
def extract_input_message_attributes(querystring: Dict[str, Any]) -> List[str]:
message_attributes = []
index = 1
while True:
@ -25,8 +27,11 @@ def extract_input_message_attributes(querystring):
def parse_message_attributes(
querystring, key="MessageAttribute", base="", value_namespace="Value."
):
querystring: Dict[str, Any],
key: str = "MessageAttribute",
base: str = "",
value_namespace: str = "Value.",
) -> Dict[str, Any]:
message_attributes = {}
index = 1
while True:

View File

@ -239,7 +239,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy]
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s3*,moto/sagemaker,moto/secretsmanager,moto/ssm,moto/scheduler
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s3*,moto/sagemaker,moto/secretsmanager,moto/sqs,moto/ssm,moto/scheduler
show_column_numbers=True
show_error_codes = True
disable_error_code=abstract