SQS: perform validation on all messages before attempting a send (#6722)
This commit is contained in:
parent
a29f556358
commit
4d4cae08d2
@ -101,7 +101,6 @@ class Message(BaseModel):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def attribute_md5(self) -> str:
|
def attribute_md5(self) -> str:
|
||||||
|
|
||||||
md5 = md5_hash()
|
md5 = md5_hash()
|
||||||
|
|
||||||
for attrName in sorted(self.message_attributes.keys()):
|
for attrName in sorted(self.message_attributes.keys()):
|
||||||
@ -552,7 +551,6 @@ class Queue(CloudFormationModel):
|
|||||||
|
|
||||||
def add_message(self, message: Message) -> None:
|
def add_message(self, message: Message) -> None:
|
||||||
if self.fifo_queue:
|
if self.fifo_queue:
|
||||||
|
|
||||||
# the cases in which we dedupe fifo messages
|
# the cases in which we dedupe fifo messages
|
||||||
# from https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/using-messagededuplicationid-property.html
|
# from https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/using-messagededuplicationid-property.html
|
||||||
# https://docs.aws.amazon.com/AWSSimpleQueueService/latest/APIReference/API_SendMessage.html
|
# https://docs.aws.amazon.com/AWSSimpleQueueService/latest/APIReference/API_SendMessage.html
|
||||||
@ -786,19 +784,13 @@ class SQSBackend(BaseBackend):
|
|||||||
queue._set_attributes(attributes)
|
queue._set_attributes(attributes)
|
||||||
return queue
|
return queue
|
||||||
|
|
||||||
def send_message(
|
def _validate_message(
|
||||||
self,
|
self,
|
||||||
queue_name: str,
|
queue: Queue,
|
||||||
message_body: str,
|
message_body: str,
|
||||||
message_attributes: Optional[Dict[str, Any]] = None,
|
|
||||||
delay_seconds: Optional[int] = None,
|
|
||||||
deduplication_id: Optional[str] = None,
|
deduplication_id: Optional[str] = None,
|
||||||
group_id: Optional[str] = None,
|
group_id: Optional[str] = None,
|
||||||
system_attributes: Optional[Dict[str, Any]] = None,
|
) -> None:
|
||||||
) -> Message:
|
|
||||||
|
|
||||||
queue = self.get_queue(queue_name)
|
|
||||||
|
|
||||||
if queue.fifo_queue:
|
if queue.fifo_queue:
|
||||||
if (
|
if (
|
||||||
queue.attributes.get("ContentBasedDeduplication") == "false"
|
queue.attributes.get("ContentBasedDeduplication") == "false"
|
||||||
@ -822,6 +814,33 @@ class SQSBackend(BaseBackend):
|
|||||||
msg = f"One or more parameters are invalid. Reason: Message must be shorter than {queue.maximum_message_size} bytes." # type: ignore
|
msg = f"One or more parameters are invalid. Reason: Message must be shorter than {queue.maximum_message_size} bytes." # type: ignore
|
||||||
raise InvalidParameterValue(msg)
|
raise InvalidParameterValue(msg)
|
||||||
|
|
||||||
|
if group_id is None:
|
||||||
|
if queue.fifo_queue:
|
||||||
|
# MessageGroupId is a mandatory parameter for all
|
||||||
|
# messages in a fifo queue
|
||||||
|
raise MissingParameter("MessageGroupId")
|
||||||
|
else:
|
||||||
|
if not queue.fifo_queue:
|
||||||
|
msg = (
|
||||||
|
f"Value {group_id} for parameter MessageGroupId is invalid. "
|
||||||
|
"Reason: The request include parameter that is not valid for this queue type."
|
||||||
|
)
|
||||||
|
raise InvalidParameterValue(msg)
|
||||||
|
|
||||||
|
def send_message(
|
||||||
|
self,
|
||||||
|
queue_name: str,
|
||||||
|
message_body: str,
|
||||||
|
message_attributes: Optional[Dict[str, Any]] = None,
|
||||||
|
delay_seconds: Optional[int] = None,
|
||||||
|
deduplication_id: Optional[str] = None,
|
||||||
|
group_id: Optional[str] = None,
|
||||||
|
system_attributes: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> Message:
|
||||||
|
queue = self.get_queue(queue_name)
|
||||||
|
|
||||||
|
self._validate_message(queue, message_body, deduplication_id, group_id)
|
||||||
|
|
||||||
if delay_seconds:
|
if delay_seconds:
|
||||||
delay_seconds = int(delay_seconds)
|
delay_seconds = int(delay_seconds)
|
||||||
else:
|
else:
|
||||||
@ -844,18 +863,7 @@ class SQSBackend(BaseBackend):
|
|||||||
random.choice(string.digits) for _ in range(20)
|
random.choice(string.digits) for _ in range(20)
|
||||||
)
|
)
|
||||||
|
|
||||||
if group_id is None:
|
if group_id is not None:
|
||||||
# MessageGroupId is a mandatory parameter for all
|
|
||||||
# messages in a fifo queue
|
|
||||||
if queue.fifo_queue:
|
|
||||||
raise MissingParameter("MessageGroupId")
|
|
||||||
else:
|
|
||||||
if not queue.fifo_queue:
|
|
||||||
msg = (
|
|
||||||
f"Value {group_id} for parameter MessageGroupId is invalid. "
|
|
||||||
"Reason: The request include parameter that is not valid for this queue type."
|
|
||||||
)
|
|
||||||
raise InvalidParameterValue(msg)
|
|
||||||
message.group_id = group_id
|
message.group_id = group_id
|
||||||
|
|
||||||
if message_attributes:
|
if message_attributes:
|
||||||
@ -877,7 +885,7 @@ class SQSBackend(BaseBackend):
|
|||||||
def send_message_batch(
|
def send_message_batch(
|
||||||
self, queue_name: str, entries: Dict[str, Dict[str, Any]]
|
self, queue_name: str, entries: Dict[str, Dict[str, Any]]
|
||||||
) -> Tuple[List[Message], List[Dict[str, Any]]]:
|
) -> Tuple[List[Message], List[Dict[str, Any]]]:
|
||||||
self.get_queue(queue_name)
|
queue = self.get_queue(queue_name)
|
||||||
|
|
||||||
if any(
|
if any(
|
||||||
not re.match(r"^[\w-]{1,80}$", entry["Id"]) for entry in entries.values()
|
not re.match(r"^[\w-]{1,80}$", entry["Id"]) for entry in entries.values()
|
||||||
@ -899,6 +907,16 @@ class SQSBackend(BaseBackend):
|
|||||||
|
|
||||||
messages = []
|
messages = []
|
||||||
failedInvalidDelay = []
|
failedInvalidDelay = []
|
||||||
|
|
||||||
|
for entry in entries.values():
|
||||||
|
# validate ALL messages before trying to send any
|
||||||
|
self._validate_message(
|
||||||
|
queue,
|
||||||
|
entry["MessageBody"],
|
||||||
|
entry.get("MessageDeduplicationId"),
|
||||||
|
entry.get("MessageGroupId"),
|
||||||
|
)
|
||||||
|
|
||||||
for entry in entries.values():
|
for entry in entries.values():
|
||||||
try:
|
try:
|
||||||
# Loop through looking for messages
|
# Loop through looking for messages
|
||||||
@ -954,7 +972,6 @@ class SQSBackend(BaseBackend):
|
|||||||
|
|
||||||
# queue.messages only contains visible messages
|
# queue.messages only contains visible messages
|
||||||
while True:
|
while True:
|
||||||
|
|
||||||
if result or (wait_seconds_timeout and unix_time() > polling_end):
|
if result or (wait_seconds_timeout and unix_time() > polling_end):
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -1045,7 +1062,6 @@ class SQSBackend(BaseBackend):
|
|||||||
queue = self.get_queue(queue_name)
|
queue = self.get_queue(queue_name)
|
||||||
for message in queue._messages:
|
for message in queue._messages:
|
||||||
if message.had_receipt_handle(receipt_handle):
|
if message.had_receipt_handle(receipt_handle):
|
||||||
|
|
||||||
visibility_timeout_msec = int(visibility_timeout) * 1000
|
visibility_timeout_msec = int(visibility_timeout) * 1000
|
||||||
given_visibility_timeout = unix_time_millis() + visibility_timeout_msec
|
given_visibility_timeout = unix_time_millis() + visibility_timeout_msec
|
||||||
if given_visibility_timeout - message.sent_timestamp > 43200 * 1000: # type: ignore
|
if given_visibility_timeout - message.sent_timestamp > 43200 * 1000: # type: ignore
|
||||||
|
Loading…
x
Reference in New Issue
Block a user