SQS: perform validation on all messages before attempting a send (#6722)

This commit is contained in:
amlodzianowski 2023-08-25 01:03:18 -07:00 committed by GitHub
parent a29f556358
commit 4d4cae08d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -101,7 +101,6 @@ class Message(BaseModel):
@property
def attribute_md5(self) -> str:
md5 = md5_hash()
for attrName in sorted(self.message_attributes.keys()):
@ -552,7 +551,6 @@ class Queue(CloudFormationModel):
def add_message(self, message: Message) -> None:
if self.fifo_queue:
# the cases in which we dedupe fifo messages
# from https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/using-messagededuplicationid-property.html
# https://docs.aws.amazon.com/AWSSimpleQueueService/latest/APIReference/API_SendMessage.html
@ -786,19 +784,13 @@ class SQSBackend(BaseBackend):
queue._set_attributes(attributes)
return queue
def send_message(
def _validate_message(
self,
queue_name: str,
queue: Queue,
message_body: str,
message_attributes: Optional[Dict[str, Any]] = None,
delay_seconds: Optional[int] = None,
deduplication_id: Optional[str] = None,
group_id: Optional[str] = None,
system_attributes: Optional[Dict[str, Any]] = None,
) -> Message:
queue = self.get_queue(queue_name)
) -> None:
if queue.fifo_queue:
if (
queue.attributes.get("ContentBasedDeduplication") == "false"
@ -822,6 +814,33 @@ class SQSBackend(BaseBackend):
msg = f"One or more parameters are invalid. Reason: Message must be shorter than {queue.maximum_message_size} bytes." # type: ignore
raise InvalidParameterValue(msg)
if group_id is None:
if queue.fifo_queue:
# MessageGroupId is a mandatory parameter for all
# messages in a fifo queue
raise MissingParameter("MessageGroupId")
else:
if not queue.fifo_queue:
msg = (
f"Value {group_id} for parameter MessageGroupId is invalid. "
"Reason: The request include parameter that is not valid for this queue type."
)
raise InvalidParameterValue(msg)
def send_message(
self,
queue_name: str,
message_body: str,
message_attributes: Optional[Dict[str, Any]] = None,
delay_seconds: Optional[int] = None,
deduplication_id: Optional[str] = None,
group_id: Optional[str] = None,
system_attributes: Optional[Dict[str, Any]] = None,
) -> Message:
queue = self.get_queue(queue_name)
self._validate_message(queue, message_body, deduplication_id, group_id)
if delay_seconds:
delay_seconds = int(delay_seconds)
else:
@ -844,18 +863,7 @@ class SQSBackend(BaseBackend):
random.choice(string.digits) for _ in range(20)
)
if group_id is None:
# MessageGroupId is a mandatory parameter for all
# messages in a fifo queue
if queue.fifo_queue:
raise MissingParameter("MessageGroupId")
else:
if not queue.fifo_queue:
msg = (
f"Value {group_id} for parameter MessageGroupId is invalid. "
"Reason: The request include parameter that is not valid for this queue type."
)
raise InvalidParameterValue(msg)
if group_id is not None:
message.group_id = group_id
if message_attributes:
@ -877,7 +885,7 @@ class SQSBackend(BaseBackend):
def send_message_batch(
self, queue_name: str, entries: Dict[str, Dict[str, Any]]
) -> Tuple[List[Message], List[Dict[str, Any]]]:
self.get_queue(queue_name)
queue = self.get_queue(queue_name)
if any(
not re.match(r"^[\w-]{1,80}$", entry["Id"]) for entry in entries.values()
@ -899,6 +907,16 @@ class SQSBackend(BaseBackend):
messages = []
failedInvalidDelay = []
for entry in entries.values():
# validate ALL messages before trying to send any
self._validate_message(
queue,
entry["MessageBody"],
entry.get("MessageDeduplicationId"),
entry.get("MessageGroupId"),
)
for entry in entries.values():
try:
# Loop through looking for messages
@ -954,7 +972,6 @@ class SQSBackend(BaseBackend):
# queue.messages only contains visible messages
while True:
if result or (wait_seconds_timeout and unix_time() > polling_end):
break
@ -1045,7 +1062,6 @@ class SQSBackend(BaseBackend):
queue = self.get_queue(queue_name)
for message in queue._messages:
if message.had_receipt_handle(receipt_handle):
visibility_timeout_msec = int(visibility_timeout) * 1000
given_visibility_timeout = unix_time_millis() + visibility_timeout_msec
if given_visibility_timeout - message.sent_timestamp > 43200 * 1000: # type: ignore