11883a1fda
* Add missing dependencies for EFS
1054 lines
35 KiB
Python
1054 lines
35 KiB
Python
from __future__ import unicode_literals
|
|
|
|
import base64
|
|
import hashlib
|
|
import json
|
|
import random
|
|
import re
|
|
import string
|
|
|
|
import struct
|
|
from copy import deepcopy
|
|
from xml.sax.saxutils import escape
|
|
|
|
from boto3 import Session
|
|
|
|
from moto.core.exceptions import RESTError
|
|
from moto.core import BaseBackend, BaseModel, CloudFormationModel
|
|
from moto.core.utils import (
|
|
camelcase_to_underscores,
|
|
get_random_message_id,
|
|
unix_time,
|
|
unix_time_millis,
|
|
tags_from_cloudformation_tags_list,
|
|
)
|
|
from .utils import generate_receipt_handle
|
|
from .exceptions import (
|
|
MessageAttributesInvalid,
|
|
MessageNotInflight,
|
|
QueueDoesNotExist,
|
|
QueueAlreadyExists,
|
|
ReceiptHandleIsInvalid,
|
|
InvalidBatchEntryId,
|
|
BatchRequestTooLong,
|
|
BatchEntryIdsNotDistinct,
|
|
TooManyEntriesInBatchRequest,
|
|
InvalidAttributeName,
|
|
InvalidParameterValue,
|
|
MissingParameter,
|
|
OverLimit,
|
|
InvalidAttributeValue,
|
|
)
|
|
|
|
from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID
|
|
|
|
DEFAULT_SENDER_ID = "AIDAIT2UOQQY3AUEKVGXU"
|
|
|
|
MAXIMUM_MESSAGE_LENGTH = 262144 # 256 KiB
|
|
|
|
MAXIMUM_MESSAGE_SIZE_ATTR_LOWER_BOUND = 1024
|
|
MAXIMUM_MESSAGE_SIZE_ATTR_UPPER_BOUND = MAXIMUM_MESSAGE_LENGTH
|
|
|
|
TRANSPORT_TYPE_ENCODINGS = {
|
|
"String": b"\x01",
|
|
"Binary": b"\x02",
|
|
"Number": b"\x01",
|
|
"String.custom": b"\x01",
|
|
}
|
|
|
|
STRING_TYPE_FIELD_INDEX = 1
|
|
BINARY_TYPE_FIELD_INDEX = 2
|
|
STRING_LIST_TYPE_FIELD_INDEX = 3
|
|
BINARY_LIST_TYPE_FIELD_INDEX = 4
|
|
|
|
# Valid attribute name rules can found at
|
|
# https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-message-metadata.html
|
|
ATTRIBUTE_NAME_PATTERN = re.compile("^([a-z]|[A-Z]|[0-9]|[_.\\-])+$")
|
|
|
|
DEDUPLICATION_TIME_IN_SECONDS = 300
|
|
|
|
|
|
class Message(BaseModel):
|
|
def __init__(self, message_id, body, system_attributes={}):
|
|
self.id = message_id
|
|
self._body = body
|
|
self.message_attributes = {}
|
|
self.receipt_handle = None
|
|
self.sender_id = DEFAULT_SENDER_ID
|
|
self.sent_timestamp = None
|
|
self.approximate_first_receive_timestamp = None
|
|
self.approximate_receive_count = 0
|
|
self.deduplication_id = None
|
|
self.group_id = None
|
|
self.sequence_number = None
|
|
self.visible_at = 0
|
|
self.delayed_until = 0
|
|
self.system_attributes = system_attributes
|
|
|
|
@property
|
|
def body_md5(self):
|
|
md5 = hashlib.md5()
|
|
md5.update(self._body.encode("utf-8"))
|
|
return md5.hexdigest()
|
|
|
|
@property
|
|
def attribute_md5(self):
|
|
|
|
md5 = hashlib.md5()
|
|
|
|
for attrName in sorted(self.message_attributes.keys()):
|
|
self.validate_attribute_name(attrName)
|
|
attrValue = self.message_attributes[attrName]
|
|
# Encode name
|
|
self.update_binary_length_and_value(md5, self.utf8(attrName))
|
|
# Encode type
|
|
self.update_binary_length_and_value(md5, self.utf8(attrValue["data_type"]))
|
|
|
|
if attrValue.get("string_value"):
|
|
md5.update(bytearray([STRING_TYPE_FIELD_INDEX]))
|
|
self.update_binary_length_and_value(
|
|
md5, self.utf8(attrValue.get("string_value"))
|
|
)
|
|
elif attrValue.get("binary_value"):
|
|
md5.update(bytearray([BINARY_TYPE_FIELD_INDEX]))
|
|
decoded_binary_value = base64.b64decode(attrValue.get("binary_value"))
|
|
self.update_binary_length_and_value(md5, decoded_binary_value)
|
|
# string_list_value type is not implemented, reserved for the future use.
|
|
# See https://docs.aws.amazon.com/AWSSimpleQueueService/latest/APIReference/API_MessageAttributeValue.html
|
|
elif len(attrValue["string_list_value"]) > 0:
|
|
md5.update(bytearray([STRING_LIST_TYPE_FIELD_INDEX]))
|
|
for strListMember in attrValue["string_list_value"]:
|
|
self.update_binary_length_and_value(md5, self.utf8(strListMember))
|
|
# binary_list_value type is not implemented, reserved for the future use.
|
|
# See https://docs.aws.amazon.com/AWSSimpleQueueService/latest/APIReference/API_MessageAttributeValue.html
|
|
elif len(attrValue["binary_list_value"]) > 0:
|
|
md5.update(bytearray([BINARY_LIST_TYPE_FIELD_INDEX]))
|
|
for strListMember in attrValue["binary_list_value"]:
|
|
decoded_binary_value = base64.b64decode(strListMember)
|
|
self.update_binary_length_and_value(md5, decoded_binary_value)
|
|
|
|
return md5.hexdigest()
|
|
|
|
@staticmethod
|
|
def update_binary_length_and_value(md5, value):
|
|
length_bytes = struct.pack("!I".encode("ascii"), len(value))
|
|
md5.update(length_bytes)
|
|
md5.update(value)
|
|
|
|
@staticmethod
|
|
def validate_attribute_name(name):
|
|
if not ATTRIBUTE_NAME_PATTERN.match(name):
|
|
raise MessageAttributesInvalid(
|
|
"The message attribute name '{0}' is invalid. "
|
|
"Attribute name can contain A-Z, a-z, 0-9, "
|
|
"underscore (_), hyphen (-), and period (.) characters.".format(name)
|
|
)
|
|
|
|
@staticmethod
|
|
def utf8(string):
|
|
if isinstance(string, str):
|
|
return string.encode("utf-8")
|
|
return string
|
|
|
|
@property
|
|
def body(self):
|
|
return escape(self._body)
|
|
|
|
def mark_sent(self, delay_seconds=None):
|
|
self.sent_timestamp = int(unix_time_millis())
|
|
if delay_seconds:
|
|
self.delay(delay_seconds=delay_seconds)
|
|
|
|
def mark_received(self, visibility_timeout=None):
|
|
"""
|
|
When a message is received we will set the first receive timestamp,
|
|
tap the ``approximate_receive_count`` and the ``visible_at`` time.
|
|
"""
|
|
if visibility_timeout:
|
|
visibility_timeout = int(visibility_timeout)
|
|
else:
|
|
visibility_timeout = 0
|
|
|
|
if not self.approximate_first_receive_timestamp:
|
|
self.approximate_first_receive_timestamp = int(unix_time_millis())
|
|
|
|
self.approximate_receive_count += 1
|
|
|
|
# Make message visible again in the future unless its
|
|
# destroyed.
|
|
if visibility_timeout:
|
|
self.change_visibility(visibility_timeout)
|
|
|
|
self.receipt_handle = generate_receipt_handle()
|
|
|
|
def change_visibility(self, visibility_timeout):
|
|
# We're dealing with milliseconds internally
|
|
visibility_timeout_msec = int(visibility_timeout) * 1000
|
|
self.visible_at = unix_time_millis() + visibility_timeout_msec
|
|
|
|
def delay(self, delay_seconds):
|
|
delay_msec = int(delay_seconds) * 1000
|
|
self.delayed_until = unix_time_millis() + delay_msec
|
|
|
|
@property
|
|
def visible(self):
|
|
current_time = unix_time_millis()
|
|
if current_time > self.visible_at:
|
|
return True
|
|
return False
|
|
|
|
@property
|
|
def delayed(self):
|
|
current_time = unix_time_millis()
|
|
if current_time < self.delayed_until:
|
|
return True
|
|
return False
|
|
|
|
|
|
class Queue(CloudFormationModel):
|
|
BASE_ATTRIBUTES = [
|
|
"ApproximateNumberOfMessages",
|
|
"ApproximateNumberOfMessagesDelayed",
|
|
"ApproximateNumberOfMessagesNotVisible",
|
|
"CreatedTimestamp",
|
|
"DelaySeconds",
|
|
"LastModifiedTimestamp",
|
|
"MaximumMessageSize",
|
|
"MessageRetentionPeriod",
|
|
"QueueArn",
|
|
"Policy",
|
|
"RedrivePolicy",
|
|
"ReceiveMessageWaitTimeSeconds",
|
|
"VisibilityTimeout",
|
|
]
|
|
FIFO_ATTRIBUTES = ["FifoQueue", "ContentBasedDeduplication"]
|
|
KMS_ATTRIBUTES = ["KmsDataKeyReusePeriodSeconds", "KmsMasterKeyId"]
|
|
ALLOWED_PERMISSIONS = (
|
|
"*",
|
|
"ChangeMessageVisibility",
|
|
"DeleteMessage",
|
|
"GetQueueAttributes",
|
|
"GetQueueUrl",
|
|
"ListDeadLetterSourceQueues",
|
|
"PurgeQueue",
|
|
"ReceiveMessage",
|
|
"SendMessage",
|
|
)
|
|
|
|
def __init__(self, name, region, **kwargs):
|
|
self.name = name
|
|
self.region = region
|
|
self.tags = {}
|
|
self.permissions = {}
|
|
|
|
self._messages = []
|
|
self._pending_messages = set()
|
|
|
|
now = unix_time()
|
|
self.created_timestamp = now
|
|
self.queue_arn = "arn:aws:sqs:{0}:{1}:{2}".format(
|
|
self.region, DEFAULT_ACCOUNT_ID, self.name
|
|
)
|
|
self.dead_letter_queue = None
|
|
|
|
self.lambda_event_source_mappings = {}
|
|
|
|
# default settings for a non fifo queue
|
|
defaults = {
|
|
"ContentBasedDeduplication": "false",
|
|
"DelaySeconds": 0,
|
|
"FifoQueue": "false",
|
|
"KmsDataKeyReusePeriodSeconds": 300, # five minutes
|
|
"KmsMasterKeyId": None,
|
|
"MaximumMessageSize": MAXIMUM_MESSAGE_LENGTH,
|
|
"MessageRetentionPeriod": 86400 * 4, # four days
|
|
"Policy": None,
|
|
"ReceiveMessageWaitTimeSeconds": 0,
|
|
"RedrivePolicy": None,
|
|
"VisibilityTimeout": 30,
|
|
}
|
|
|
|
defaults.update(kwargs)
|
|
self._set_attributes(defaults, now)
|
|
|
|
# Check some conditions
|
|
if self.fifo_queue and not self.name.endswith(".fifo"):
|
|
raise InvalidParameterValue("Queue name must end in .fifo for FIFO queues")
|
|
if (
|
|
self.maximum_message_size < MAXIMUM_MESSAGE_SIZE_ATTR_LOWER_BOUND
|
|
or self.maximum_message_size > MAXIMUM_MESSAGE_SIZE_ATTR_UPPER_BOUND
|
|
):
|
|
raise InvalidAttributeValue("MaximumMessageSize")
|
|
|
|
@property
|
|
def pending_messages(self):
|
|
return self._pending_messages
|
|
|
|
@property
|
|
def pending_message_groups(self):
|
|
return set(
|
|
message.group_id
|
|
for message in self._pending_messages
|
|
if message.group_id is not None
|
|
)
|
|
|
|
def _set_attributes(self, attributes, now=None):
|
|
if not now:
|
|
now = unix_time()
|
|
|
|
integer_fields = (
|
|
"DelaySeconds",
|
|
"KmsDataKeyreusePeriodSeconds",
|
|
"MaximumMessageSize",
|
|
"MessageRetentionPeriod",
|
|
"ReceiveMessageWaitTime",
|
|
"VisibilityTimeout",
|
|
)
|
|
bool_fields = ("ContentBasedDeduplication", "FifoQueue")
|
|
|
|
for key, value in attributes.items():
|
|
if key in integer_fields:
|
|
value = int(value)
|
|
if key in bool_fields:
|
|
value = value == "true"
|
|
|
|
if key in ["Policy", "RedrivePolicy"] and value is not None:
|
|
continue
|
|
|
|
setattr(self, camelcase_to_underscores(key), value)
|
|
|
|
if attributes.get("RedrivePolicy", None):
|
|
self._setup_dlq(attributes["RedrivePolicy"])
|
|
|
|
if attributes.get("Policy"):
|
|
self.policy = attributes["Policy"]
|
|
|
|
self.last_modified_timestamp = now
|
|
|
|
def _setup_dlq(self, policy):
|
|
|
|
if isinstance(policy, str):
|
|
try:
|
|
self.redrive_policy = json.loads(policy)
|
|
except ValueError:
|
|
raise RESTError(
|
|
"InvalidParameterValue",
|
|
"Redrive policy is not a dict or valid json",
|
|
)
|
|
elif isinstance(policy, dict):
|
|
self.redrive_policy = policy
|
|
else:
|
|
raise RESTError(
|
|
"InvalidParameterValue", "Redrive policy is not a dict or valid json"
|
|
)
|
|
|
|
if "deadLetterTargetArn" not in self.redrive_policy:
|
|
raise RESTError(
|
|
"InvalidParameterValue",
|
|
"Redrive policy does not contain deadLetterTargetArn",
|
|
)
|
|
if "maxReceiveCount" not in self.redrive_policy:
|
|
raise RESTError(
|
|
"InvalidParameterValue",
|
|
"Redrive policy does not contain maxReceiveCount",
|
|
)
|
|
|
|
# 'maxReceiveCount' is stored as int
|
|
self.redrive_policy["maxReceiveCount"] = int(
|
|
self.redrive_policy["maxReceiveCount"]
|
|
)
|
|
|
|
for queue in sqs_backends[self.region].queues.values():
|
|
if queue.queue_arn == self.redrive_policy["deadLetterTargetArn"]:
|
|
self.dead_letter_queue = queue
|
|
|
|
if self.fifo_queue and not queue.fifo_queue:
|
|
raise RESTError(
|
|
"InvalidParameterCombination",
|
|
"Fifo queues cannot use non fifo dead letter queues",
|
|
)
|
|
break
|
|
else:
|
|
raise RESTError(
|
|
"AWS.SimpleQueueService.NonExistentQueue",
|
|
"Could not find DLQ for {0}".format(
|
|
self.redrive_policy["deadLetterTargetArn"]
|
|
),
|
|
)
|
|
|
|
@staticmethod
|
|
def cloudformation_name_type():
|
|
return "QueueName"
|
|
|
|
@staticmethod
|
|
def cloudformation_type():
|
|
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sqs-queue.html
|
|
return "AWS::SQS::Queue"
|
|
|
|
@classmethod
|
|
def create_from_cloudformation_json(
|
|
cls, resource_name, cloudformation_json, region_name
|
|
):
|
|
properties = deepcopy(cloudformation_json["Properties"])
|
|
# remove Tags from properties and convert tags list to dict
|
|
tags = properties.pop("Tags", [])
|
|
tags_dict = tags_from_cloudformation_tags_list(tags)
|
|
|
|
sqs_backend = sqs_backends[region_name]
|
|
return sqs_backend.create_queue(
|
|
name=resource_name, tags=tags_dict, region=region_name, **properties
|
|
)
|
|
|
|
@classmethod
|
|
def update_from_cloudformation_json(
|
|
cls, original_resource, new_resource_name, cloudformation_json, region_name
|
|
):
|
|
properties = cloudformation_json["Properties"]
|
|
queue_name = original_resource.name
|
|
|
|
sqs_backend = sqs_backends[region_name]
|
|
queue = sqs_backend.get_queue(queue_name)
|
|
if "VisibilityTimeout" in properties:
|
|
queue.visibility_timeout = int(properties["VisibilityTimeout"])
|
|
|
|
if "ReceiveMessageWaitTimeSeconds" in properties:
|
|
queue.receive_message_wait_time_seconds = int(
|
|
properties["ReceiveMessageWaitTimeSeconds"]
|
|
)
|
|
return queue
|
|
|
|
@classmethod
|
|
def delete_from_cloudformation_json(
|
|
cls, resource_name, cloudformation_json, region_name
|
|
):
|
|
sqs_backend = sqs_backends[region_name]
|
|
sqs_backend.delete_queue(resource_name)
|
|
|
|
@property
|
|
def approximate_number_of_messages_delayed(self):
|
|
return len([m for m in self._messages if m.delayed])
|
|
|
|
@property
|
|
def approximate_number_of_messages_not_visible(self):
|
|
return len([m for m in self._messages if not m.visible])
|
|
|
|
@property
|
|
def approximate_number_of_messages(self):
|
|
return len(self.messages)
|
|
|
|
@property
|
|
def physical_resource_id(self):
|
|
return self.name
|
|
|
|
@property
|
|
def attributes(self):
|
|
result = {}
|
|
|
|
for attribute in self.BASE_ATTRIBUTES:
|
|
attr = getattr(self, camelcase_to_underscores(attribute))
|
|
result[attribute] = attr
|
|
|
|
if self.fifo_queue:
|
|
for attribute in self.FIFO_ATTRIBUTES:
|
|
attr = getattr(self, camelcase_to_underscores(attribute))
|
|
result[attribute] = attr
|
|
|
|
if self.kms_master_key_id:
|
|
for attribute in self.KMS_ATTRIBUTES:
|
|
attr = getattr(self, camelcase_to_underscores(attribute))
|
|
result[attribute] = attr
|
|
|
|
if self.policy:
|
|
result["Policy"] = self.policy
|
|
|
|
if self.redrive_policy:
|
|
result["RedrivePolicy"] = json.dumps(self.redrive_policy)
|
|
|
|
for key in result:
|
|
if isinstance(result[key], bool):
|
|
result[key] = str(result[key]).lower()
|
|
|
|
return result
|
|
|
|
def url(self, request_url):
|
|
return "{0}://{1}/{2}/{3}".format(
|
|
request_url.scheme, request_url.netloc, DEFAULT_ACCOUNT_ID, self.name
|
|
)
|
|
|
|
@property
|
|
def messages(self):
|
|
# TODO: This can become very inefficient if a large number of messages are in-flight
|
|
return [
|
|
message
|
|
for message in self._messages
|
|
if message.visible and not message.delayed
|
|
]
|
|
|
|
def add_message(self, message):
|
|
if (
|
|
self.fifo_queue
|
|
and self.attributes.get("ContentBasedDeduplication") == "true"
|
|
):
|
|
for m in self._messages:
|
|
if m.deduplication_id == message.deduplication_id:
|
|
diff = message.sent_timestamp - m.sent_timestamp
|
|
# if a duplicate message is received within the deduplication time then it should
|
|
# not be added to the queue
|
|
if diff / 1000 < DEDUPLICATION_TIME_IN_SECONDS:
|
|
return
|
|
|
|
self._messages.append(message)
|
|
|
|
for arn, esm in self.lambda_event_source_mappings.items():
|
|
backend = sqs_backends[self.region]
|
|
|
|
"""
|
|
Lambda polls the queue and invokes your function synchronously with an event
|
|
that contains queue messages. Lambda reads messages in batches and invokes
|
|
your function once for each batch. When your function successfully processes
|
|
a batch, Lambda deletes its messages from the queue.
|
|
"""
|
|
messages = backend.receive_messages(
|
|
self.name,
|
|
esm.batch_size,
|
|
self.receive_message_wait_time_seconds,
|
|
self.visibility_timeout,
|
|
)
|
|
|
|
from moto.awslambda import lambda_backends
|
|
|
|
result = lambda_backends[self.region].send_sqs_batch(
|
|
arn, messages, self.queue_arn
|
|
)
|
|
|
|
if result:
|
|
[backend.delete_message(self.name, m.receipt_handle) for m in messages]
|
|
else:
|
|
# Make messages visible again
|
|
[m.change_visibility(visibility_timeout=0) for m in messages]
|
|
|
|
def get_cfn_attribute(self, attribute_name):
|
|
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
|
|
|
|
if attribute_name == "Arn":
|
|
return self.queue_arn
|
|
elif attribute_name == "QueueName":
|
|
return self.name
|
|
raise UnformattedGetAttTemplateException()
|
|
|
|
@property
|
|
def policy(self):
|
|
if self._policy_json.get("Statement"):
|
|
return json.dumps(self._policy_json)
|
|
else:
|
|
return None
|
|
|
|
@policy.setter
|
|
def policy(self, policy):
|
|
if policy:
|
|
self._policy_json = json.loads(policy)
|
|
else:
|
|
self._policy_json = {
|
|
"Version": "2012-10-17",
|
|
"Id": "{}/SQSDefaultPolicy".format(self.queue_arn),
|
|
"Statement": [],
|
|
}
|
|
|
|
|
|
def _filter_message_attributes(message, input_message_attributes):
|
|
filtered_message_attributes = {}
|
|
return_all = "All" in input_message_attributes
|
|
for key, value in message.message_attributes.items():
|
|
if return_all or key in input_message_attributes:
|
|
filtered_message_attributes[key] = value
|
|
message.message_attributes = filtered_message_attributes
|
|
|
|
|
|
class SQSBackend(BaseBackend):
|
|
def __init__(self, region_name):
|
|
self.region_name = region_name
|
|
self.queues = {}
|
|
super(SQSBackend, self).__init__()
|
|
|
|
def reset(self):
|
|
region_name = self.region_name
|
|
self._reset_model_refs()
|
|
self.__dict__ = {}
|
|
self.__init__(region_name)
|
|
|
|
def create_queue(self, name, tags=None, **kwargs):
|
|
queue = self.queues.get(name)
|
|
if queue:
|
|
try:
|
|
kwargs.pop("region")
|
|
except KeyError:
|
|
pass
|
|
|
|
new_queue = Queue(name, region=self.region_name, **kwargs)
|
|
|
|
queue_attributes = queue.attributes
|
|
new_queue_attributes = new_queue.attributes
|
|
|
|
# only the attributes which are being sent for the queue
|
|
# creation have to be compared if the queue is existing.
|
|
for key in kwargs:
|
|
if queue_attributes.get(key) != new_queue_attributes.get(key):
|
|
raise QueueAlreadyExists("The specified queue already exists.")
|
|
else:
|
|
try:
|
|
kwargs.pop("region")
|
|
except KeyError:
|
|
pass
|
|
queue = Queue(name, region=self.region_name, **kwargs)
|
|
self.queues[name] = queue
|
|
|
|
if tags:
|
|
queue.tags = tags
|
|
|
|
return queue
|
|
|
|
def get_queue_url(self, queue_name):
|
|
return self.get_queue(queue_name)
|
|
|
|
def list_queues(self, queue_name_prefix):
|
|
re_str = ".*"
|
|
if queue_name_prefix:
|
|
re_str = "^{0}.*".format(queue_name_prefix)
|
|
prefix_re = re.compile(re_str)
|
|
qs = []
|
|
for name, q in self.queues.items():
|
|
if prefix_re.search(name):
|
|
qs.append(q)
|
|
return qs[:1000]
|
|
|
|
def get_queue(self, queue_name):
|
|
queue = self.queues.get(queue_name)
|
|
if queue is None:
|
|
raise QueueDoesNotExist()
|
|
return queue
|
|
|
|
def delete_queue(self, queue_name):
|
|
if queue_name in self.queues:
|
|
return self.queues.pop(queue_name)
|
|
return False
|
|
|
|
def get_queue_attributes(self, queue_name, attribute_names):
|
|
queue = self.get_queue(queue_name)
|
|
|
|
if not len(attribute_names):
|
|
attribute_names.append("All")
|
|
|
|
valid_names = (
|
|
["All"]
|
|
+ queue.BASE_ATTRIBUTES
|
|
+ queue.FIFO_ATTRIBUTES
|
|
+ queue.KMS_ATTRIBUTES
|
|
)
|
|
invalid_name = next(
|
|
(name for name in attribute_names if name not in valid_names), None
|
|
)
|
|
|
|
if invalid_name or invalid_name == "":
|
|
raise InvalidAttributeName(invalid_name)
|
|
|
|
attributes = {}
|
|
|
|
if "All" in attribute_names:
|
|
attributes = queue.attributes
|
|
else:
|
|
for name in (name for name in attribute_names if name in queue.attributes):
|
|
if queue.attributes.get(name) is not None:
|
|
attributes[name] = queue.attributes.get(name)
|
|
|
|
return attributes
|
|
|
|
def set_queue_attributes(self, queue_name, attributes):
|
|
queue = self.get_queue(queue_name)
|
|
queue._set_attributes(attributes)
|
|
return queue
|
|
|
|
def send_message(
|
|
self,
|
|
queue_name,
|
|
message_body,
|
|
message_attributes=None,
|
|
delay_seconds=None,
|
|
deduplication_id=None,
|
|
group_id=None,
|
|
system_attributes=None,
|
|
):
|
|
|
|
queue = self.get_queue(queue_name)
|
|
|
|
if len(message_body) > queue.maximum_message_size:
|
|
msg = "One or more parameters are invalid. Reason: Message must be shorter than {} bytes.".format(
|
|
queue.maximum_message_size
|
|
)
|
|
raise InvalidParameterValue(msg)
|
|
|
|
if delay_seconds:
|
|
delay_seconds = int(delay_seconds)
|
|
else:
|
|
delay_seconds = queue.delay_seconds
|
|
|
|
message_id = get_random_message_id()
|
|
message = Message(message_id, message_body, system_attributes)
|
|
|
|
# if content based deduplication is set then set sha256 hash of the message
|
|
# as the deduplication_id
|
|
if queue.attributes.get("ContentBasedDeduplication") == "true":
|
|
sha256 = hashlib.sha256()
|
|
sha256.update(message_body.encode("utf-8"))
|
|
message.deduplication_id = sha256.hexdigest()
|
|
|
|
# Attributes, but not *message* attributes
|
|
if deduplication_id is not None:
|
|
message.deduplication_id = deduplication_id
|
|
message.sequence_number = "".join(
|
|
random.choice(string.digits) for _ in range(20)
|
|
)
|
|
|
|
if group_id is None:
|
|
# MessageGroupId is a mandatory parameter for all
|
|
# messages in a fifo queue
|
|
if queue.fifo_queue:
|
|
raise MissingParameter("MessageGroupId")
|
|
else:
|
|
message.group_id = group_id
|
|
|
|
if message_attributes:
|
|
message.message_attributes = message_attributes
|
|
|
|
message.mark_sent(delay_seconds=delay_seconds)
|
|
|
|
queue.add_message(message)
|
|
|
|
return message
|
|
|
|
def send_message_batch(self, queue_name, entries):
|
|
self.get_queue(queue_name)
|
|
|
|
if any(
|
|
not re.match(r"^[\w-]{1,80}$", entry["Id"]) for entry in entries.values()
|
|
):
|
|
raise InvalidBatchEntryId()
|
|
|
|
body_length = next(
|
|
(
|
|
len(entry["MessageBody"])
|
|
for entry in entries.values()
|
|
if len(entry["MessageBody"]) > MAXIMUM_MESSAGE_LENGTH
|
|
),
|
|
False,
|
|
)
|
|
if body_length:
|
|
raise BatchRequestTooLong(body_length)
|
|
|
|
duplicate_id = self._get_first_duplicate_id(
|
|
[entry["Id"] for entry in entries.values()]
|
|
)
|
|
if duplicate_id:
|
|
raise BatchEntryIdsNotDistinct(duplicate_id)
|
|
|
|
if len(entries) > 10:
|
|
raise TooManyEntriesInBatchRequest(len(entries))
|
|
|
|
messages = []
|
|
for index, entry in entries.items():
|
|
# Loop through looking for messages
|
|
message = self.send_message(
|
|
queue_name,
|
|
entry["MessageBody"],
|
|
message_attributes=entry["MessageAttributes"],
|
|
delay_seconds=entry["DelaySeconds"],
|
|
group_id=entry.get("MessageGroupId"),
|
|
deduplication_id=entry.get("MessageDeduplicationId"),
|
|
)
|
|
message.user_id = entry["Id"]
|
|
|
|
messages.append(message)
|
|
|
|
return messages
|
|
|
|
def _get_first_duplicate_id(self, ids):
|
|
unique_ids = set()
|
|
for id in ids:
|
|
if id in unique_ids:
|
|
return id
|
|
unique_ids.add(id)
|
|
return None
|
|
|
|
def receive_messages(
|
|
self,
|
|
queue_name,
|
|
count,
|
|
wait_seconds_timeout,
|
|
visibility_timeout,
|
|
message_attribute_names=None,
|
|
):
|
|
"""
|
|
Attempt to retrieve visible messages from a queue.
|
|
|
|
If a message was read by client and not deleted it is considered to be
|
|
"inflight" and cannot be read. We make attempts to obtain ``count``
|
|
messages but we may return less if messages are in-flight or there
|
|
are simple not enough messages in the queue.
|
|
|
|
:param string queue_name: The name of the queue to read from.
|
|
:param int count: The maximum amount of messages to retrieve.
|
|
:param int visibility_timeout: The number of seconds the message should remain invisible to other queue readers.
|
|
:param int wait_seconds_timeout: The duration (in seconds) for which the call waits for a message to arrive in
|
|
the queue before returning. If a message is available, the call returns sooner than WaitTimeSeconds
|
|
"""
|
|
if message_attribute_names is None:
|
|
message_attribute_names = []
|
|
queue = self.get_queue(queue_name)
|
|
result = []
|
|
previous_result_count = len(result)
|
|
|
|
polling_end = unix_time() + wait_seconds_timeout
|
|
currently_pending_groups = deepcopy(queue.pending_message_groups)
|
|
|
|
# queue.messages only contains visible messages
|
|
while True:
|
|
|
|
if result or (wait_seconds_timeout and unix_time() > polling_end):
|
|
break
|
|
|
|
messages_to_dlq = []
|
|
|
|
for message in queue.messages:
|
|
if not message.visible:
|
|
continue
|
|
|
|
if message in queue.pending_messages:
|
|
# The message is pending but is visible again, so the
|
|
# consumer must have timed out.
|
|
queue.pending_messages.remove(message)
|
|
currently_pending_groups = deepcopy(queue.pending_message_groups)
|
|
|
|
if message.group_id and queue.fifo_queue:
|
|
if message.group_id in currently_pending_groups:
|
|
# A previous call is still processing messages in this group, so we cannot deliver this one.
|
|
continue
|
|
|
|
if (
|
|
queue.dead_letter_queue is not None
|
|
and queue.redrive_policy
|
|
and message.approximate_receive_count
|
|
>= queue.redrive_policy["maxReceiveCount"]
|
|
):
|
|
messages_to_dlq.append(message)
|
|
continue
|
|
|
|
queue.pending_messages.add(message)
|
|
message.mark_received(visibility_timeout=visibility_timeout)
|
|
_filter_message_attributes(message, message_attribute_names)
|
|
if not self.is_message_valid_based_on_retention_period(
|
|
queue_name, message
|
|
):
|
|
break
|
|
result.append(message)
|
|
if len(result) >= count:
|
|
break
|
|
|
|
for message in messages_to_dlq:
|
|
queue._messages.remove(message)
|
|
queue.dead_letter_queue.add_message(message)
|
|
|
|
if previous_result_count == len(result):
|
|
if wait_seconds_timeout == 0:
|
|
# There is timeout and we have added no additional results,
|
|
# so break to avoid an infinite loop.
|
|
break
|
|
|
|
import time
|
|
|
|
time.sleep(0.01)
|
|
continue
|
|
|
|
previous_result_count = len(result)
|
|
|
|
return result
|
|
|
|
def delete_message(self, queue_name, receipt_handle):
|
|
queue = self.get_queue(queue_name)
|
|
|
|
if not any(
|
|
message.receipt_handle == receipt_handle for message in queue._messages
|
|
):
|
|
raise ReceiptHandleIsInvalid()
|
|
|
|
new_messages = []
|
|
for message in queue._messages:
|
|
# Only delete message if it is not visible and the receipt_handle
|
|
# matches.
|
|
if message.receipt_handle == receipt_handle:
|
|
queue.pending_messages.remove(message)
|
|
continue
|
|
new_messages.append(message)
|
|
queue._messages = new_messages
|
|
|
|
def change_message_visibility(self, queue_name, receipt_handle, visibility_timeout):
|
|
queue = self.get_queue(queue_name)
|
|
for message in queue._messages:
|
|
if message.receipt_handle == receipt_handle:
|
|
if message.visible:
|
|
raise MessageNotInflight
|
|
|
|
visibility_timeout_msec = int(visibility_timeout) * 1000
|
|
given_visibility_timeout = unix_time_millis() + visibility_timeout_msec
|
|
if given_visibility_timeout - message.sent_timestamp > 43200 * 1000:
|
|
raise InvalidParameterValue(
|
|
"Value {0} for parameter VisibilityTimeout is invalid. Reason: Total "
|
|
"VisibilityTimeout for the message is beyond the limit [43200 seconds]".format(
|
|
visibility_timeout
|
|
)
|
|
)
|
|
|
|
message.change_visibility(visibility_timeout)
|
|
if message.visible:
|
|
# If the message is visible again, remove it from pending
|
|
# messages.
|
|
queue.pending_messages.remove(message)
|
|
return
|
|
raise ReceiptHandleIsInvalid
|
|
|
|
def purge_queue(self, queue_name):
|
|
queue = self.get_queue(queue_name)
|
|
queue._messages = []
|
|
queue._pending_messages = set()
|
|
|
|
def list_dead_letter_source_queues(self, queue_name):
|
|
dlq = self.get_queue(queue_name)
|
|
|
|
queues = []
|
|
for queue in self.queues.values():
|
|
if queue.dead_letter_queue is dlq:
|
|
queues.append(queue)
|
|
|
|
return queues
|
|
|
|
def add_permission(self, queue_name, actions, account_ids, label):
|
|
queue = self.get_queue(queue_name)
|
|
|
|
if not actions:
|
|
raise MissingParameter("Actions")
|
|
|
|
if not account_ids:
|
|
raise InvalidParameterValue(
|
|
"Value [] for parameter PrincipalId is invalid. Reason: Unable to verify."
|
|
)
|
|
|
|
count = len(actions)
|
|
if count > 7:
|
|
raise OverLimit(count)
|
|
|
|
invalid_action = next(
|
|
(action for action in actions if action not in Queue.ALLOWED_PERMISSIONS),
|
|
None,
|
|
)
|
|
if invalid_action:
|
|
raise InvalidParameterValue(
|
|
"Value SQS:{} for parameter ActionName is invalid. "
|
|
"Reason: Only the queue owner is allowed to invoke this action.".format(
|
|
invalid_action
|
|
)
|
|
)
|
|
|
|
policy = queue._policy_json
|
|
statement = next(
|
|
(
|
|
statement
|
|
for statement in policy["Statement"]
|
|
if statement["Sid"] == label
|
|
),
|
|
None,
|
|
)
|
|
if statement:
|
|
raise InvalidParameterValue(
|
|
"Value {} for parameter Label is invalid. "
|
|
"Reason: Already exists.".format(label)
|
|
)
|
|
|
|
principals = [
|
|
"arn:aws:iam::{}:root".format(account_id) for account_id in account_ids
|
|
]
|
|
actions = ["SQS:{}".format(action) for action in actions]
|
|
|
|
statement = {
|
|
"Sid": label,
|
|
"Effect": "Allow",
|
|
"Principal": {"AWS": principals[0] if len(principals) == 1 else principals},
|
|
"Action": actions[0] if len(actions) == 1 else actions,
|
|
"Resource": queue.queue_arn,
|
|
}
|
|
|
|
queue._policy_json["Statement"].append(statement)
|
|
|
|
def remove_permission(self, queue_name, label):
|
|
queue = self.get_queue(queue_name)
|
|
|
|
statements = queue._policy_json["Statement"]
|
|
statements_new = [
|
|
statement for statement in statements if statement["Sid"] != label
|
|
]
|
|
|
|
if len(statements) == len(statements_new):
|
|
raise InvalidParameterValue(
|
|
"Value {} for parameter Label is invalid. "
|
|
"Reason: can't find label on existing policy.".format(label)
|
|
)
|
|
|
|
queue._policy_json["Statement"] = statements_new
|
|
|
|
def tag_queue(self, queue_name, tags):
|
|
queue = self.get_queue(queue_name)
|
|
|
|
if not len(tags):
|
|
raise MissingParameter("Tags")
|
|
|
|
if len(tags) > 50:
|
|
raise InvalidParameterValue(
|
|
"Too many tags added for queue {}.".format(queue_name)
|
|
)
|
|
|
|
queue.tags.update(tags)
|
|
|
|
def untag_queue(self, queue_name, tag_keys):
|
|
queue = self.get_queue(queue_name)
|
|
|
|
if not len(tag_keys):
|
|
raise RESTError(
|
|
"InvalidParameterValue",
|
|
"Tag keys must be between 1 and 128 characters in length.",
|
|
)
|
|
|
|
for key in tag_keys:
|
|
try:
|
|
del queue.tags[key]
|
|
except KeyError:
|
|
pass
|
|
|
|
def list_queue_tags(self, queue_name):
|
|
return self.get_queue(queue_name)
|
|
|
|
def is_message_valid_based_on_retention_period(self, queue_name, message):
|
|
message_attributes = self.get_queue_attributes(queue_name, [])
|
|
retain_until = (
|
|
message_attributes.get("MessageRetentionPeriod")
|
|
+ message.sent_timestamp / 1000
|
|
)
|
|
if retain_until <= unix_time():
|
|
return False
|
|
return True
|
|
|
|
|
|
sqs_backends = {}
|
|
for region in Session().get_available_regions("sqs"):
|
|
sqs_backends[region] = SQSBackend(region)
|
|
for region in Session().get_available_regions("sqs", partition_name="aws-us-gov"):
|
|
sqs_backends[region] = SQSBackend(region)
|
|
for region in Session().get_available_regions("sqs", partition_name="aws-cn"):
|
|
sqs_backends[region] = SQSBackend(region)
|