From 4d4cae08d2a6914bca826347a7616aa457bd2c79 Mon Sep 17 00:00:00 2001 From: amlodzianowski Date: Fri, 25 Aug 2023 01:03:18 -0700 Subject: [PATCH] SQS: perform validation on all messages before attempting a send (#6722) --- moto/sqs/models.py | 68 ++++++++++++++++++++++++++++------------------ 1 file changed, 42 insertions(+), 26 deletions(-) diff --git a/moto/sqs/models.py b/moto/sqs/models.py index 06c27f7fd..05d03e583 100644 --- a/moto/sqs/models.py +++ b/moto/sqs/models.py @@ -101,7 +101,6 @@ class Message(BaseModel): @property def attribute_md5(self) -> str: - md5 = md5_hash() for attrName in sorted(self.message_attributes.keys()): @@ -552,7 +551,6 @@ class Queue(CloudFormationModel): def add_message(self, message: Message) -> None: if self.fifo_queue: - # the cases in which we dedupe fifo messages # from https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/using-messagededuplicationid-property.html # https://docs.aws.amazon.com/AWSSimpleQueueService/latest/APIReference/API_SendMessage.html @@ -786,19 +784,13 @@ class SQSBackend(BaseBackend): queue._set_attributes(attributes) return queue - def send_message( + def _validate_message( self, - queue_name: str, + queue: Queue, message_body: str, - message_attributes: Optional[Dict[str, Any]] = None, - delay_seconds: Optional[int] = None, deduplication_id: Optional[str] = None, group_id: Optional[str] = None, - system_attributes: Optional[Dict[str, Any]] = None, - ) -> Message: - - queue = self.get_queue(queue_name) - + ) -> None: if queue.fifo_queue: if ( queue.attributes.get("ContentBasedDeduplication") == "false" @@ -822,6 +814,33 @@ class SQSBackend(BaseBackend): msg = f"One or more parameters are invalid. Reason: Message must be shorter than {queue.maximum_message_size} bytes." # type: ignore raise InvalidParameterValue(msg) + if group_id is None: + if queue.fifo_queue: + # MessageGroupId is a mandatory parameter for all + # messages in a fifo queue + raise MissingParameter("MessageGroupId") + else: + if not queue.fifo_queue: + msg = ( + f"Value {group_id} for parameter MessageGroupId is invalid. " + "Reason: The request include parameter that is not valid for this queue type." + ) + raise InvalidParameterValue(msg) + + def send_message( + self, + queue_name: str, + message_body: str, + message_attributes: Optional[Dict[str, Any]] = None, + delay_seconds: Optional[int] = None, + deduplication_id: Optional[str] = None, + group_id: Optional[str] = None, + system_attributes: Optional[Dict[str, Any]] = None, + ) -> Message: + queue = self.get_queue(queue_name) + + self._validate_message(queue, message_body, deduplication_id, group_id) + if delay_seconds: delay_seconds = int(delay_seconds) else: @@ -844,18 +863,7 @@ class SQSBackend(BaseBackend): random.choice(string.digits) for _ in range(20) ) - if group_id is None: - # MessageGroupId is a mandatory parameter for all - # messages in a fifo queue - if queue.fifo_queue: - raise MissingParameter("MessageGroupId") - else: - if not queue.fifo_queue: - msg = ( - f"Value {group_id} for parameter MessageGroupId is invalid. " - "Reason: The request include parameter that is not valid for this queue type." - ) - raise InvalidParameterValue(msg) + if group_id is not None: message.group_id = group_id if message_attributes: @@ -877,7 +885,7 @@ class SQSBackend(BaseBackend): def send_message_batch( self, queue_name: str, entries: Dict[str, Dict[str, Any]] ) -> Tuple[List[Message], List[Dict[str, Any]]]: - self.get_queue(queue_name) + queue = self.get_queue(queue_name) if any( not re.match(r"^[\w-]{1,80}$", entry["Id"]) for entry in entries.values() @@ -899,6 +907,16 @@ class SQSBackend(BaseBackend): messages = [] failedInvalidDelay = [] + + for entry in entries.values(): + # validate ALL messages before trying to send any + self._validate_message( + queue, + entry["MessageBody"], + entry.get("MessageDeduplicationId"), + entry.get("MessageGroupId"), + ) + for entry in entries.values(): try: # Loop through looking for messages @@ -954,7 +972,6 @@ class SQSBackend(BaseBackend): # queue.messages only contains visible messages while True: - if result or (wait_seconds_timeout and unix_time() > polling_end): break @@ -1045,7 +1062,6 @@ class SQSBackend(BaseBackend): queue = self.get_queue(queue_name) for message in queue._messages: if message.had_receipt_handle(receipt_handle): - visibility_timeout_msec = int(visibility_timeout) * 1000 given_visibility_timeout = unix_time_millis() + visibility_timeout_msec if given_visibility_timeout - message.sent_timestamp > 43200 * 1000: # type: ignore