SQS: perform validation on all messages before attempting a send (#6722)
This commit is contained in:
parent
a29f556358
commit
4d4cae08d2
@ -101,7 +101,6 @@ class Message(BaseModel):
|
||||
|
||||
@property
|
||||
def attribute_md5(self) -> str:
|
||||
|
||||
md5 = md5_hash()
|
||||
|
||||
for attrName in sorted(self.message_attributes.keys()):
|
||||
@ -552,7 +551,6 @@ class Queue(CloudFormationModel):
|
||||
|
||||
def add_message(self, message: Message) -> None:
|
||||
if self.fifo_queue:
|
||||
|
||||
# the cases in which we dedupe fifo messages
|
||||
# from https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/using-messagededuplicationid-property.html
|
||||
# https://docs.aws.amazon.com/AWSSimpleQueueService/latest/APIReference/API_SendMessage.html
|
||||
@ -786,19 +784,13 @@ class SQSBackend(BaseBackend):
|
||||
queue._set_attributes(attributes)
|
||||
return queue
|
||||
|
||||
def send_message(
|
||||
def _validate_message(
|
||||
self,
|
||||
queue_name: str,
|
||||
queue: Queue,
|
||||
message_body: str,
|
||||
message_attributes: Optional[Dict[str, Any]] = None,
|
||||
delay_seconds: Optional[int] = None,
|
||||
deduplication_id: Optional[str] = None,
|
||||
group_id: Optional[str] = None,
|
||||
system_attributes: Optional[Dict[str, Any]] = None,
|
||||
) -> Message:
|
||||
|
||||
queue = self.get_queue(queue_name)
|
||||
|
||||
) -> None:
|
||||
if queue.fifo_queue:
|
||||
if (
|
||||
queue.attributes.get("ContentBasedDeduplication") == "false"
|
||||
@ -822,6 +814,33 @@ class SQSBackend(BaseBackend):
|
||||
msg = f"One or more parameters are invalid. Reason: Message must be shorter than {queue.maximum_message_size} bytes." # type: ignore
|
||||
raise InvalidParameterValue(msg)
|
||||
|
||||
if group_id is None:
|
||||
if queue.fifo_queue:
|
||||
# MessageGroupId is a mandatory parameter for all
|
||||
# messages in a fifo queue
|
||||
raise MissingParameter("MessageGroupId")
|
||||
else:
|
||||
if not queue.fifo_queue:
|
||||
msg = (
|
||||
f"Value {group_id} for parameter MessageGroupId is invalid. "
|
||||
"Reason: The request include parameter that is not valid for this queue type."
|
||||
)
|
||||
raise InvalidParameterValue(msg)
|
||||
|
||||
def send_message(
|
||||
self,
|
||||
queue_name: str,
|
||||
message_body: str,
|
||||
message_attributes: Optional[Dict[str, Any]] = None,
|
||||
delay_seconds: Optional[int] = None,
|
||||
deduplication_id: Optional[str] = None,
|
||||
group_id: Optional[str] = None,
|
||||
system_attributes: Optional[Dict[str, Any]] = None,
|
||||
) -> Message:
|
||||
queue = self.get_queue(queue_name)
|
||||
|
||||
self._validate_message(queue, message_body, deduplication_id, group_id)
|
||||
|
||||
if delay_seconds:
|
||||
delay_seconds = int(delay_seconds)
|
||||
else:
|
||||
@ -844,18 +863,7 @@ class SQSBackend(BaseBackend):
|
||||
random.choice(string.digits) for _ in range(20)
|
||||
)
|
||||
|
||||
if group_id is None:
|
||||
# MessageGroupId is a mandatory parameter for all
|
||||
# messages in a fifo queue
|
||||
if queue.fifo_queue:
|
||||
raise MissingParameter("MessageGroupId")
|
||||
else:
|
||||
if not queue.fifo_queue:
|
||||
msg = (
|
||||
f"Value {group_id} for parameter MessageGroupId is invalid. "
|
||||
"Reason: The request include parameter that is not valid for this queue type."
|
||||
)
|
||||
raise InvalidParameterValue(msg)
|
||||
if group_id is not None:
|
||||
message.group_id = group_id
|
||||
|
||||
if message_attributes:
|
||||
@ -877,7 +885,7 @@ class SQSBackend(BaseBackend):
|
||||
def send_message_batch(
|
||||
self, queue_name: str, entries: Dict[str, Dict[str, Any]]
|
||||
) -> Tuple[List[Message], List[Dict[str, Any]]]:
|
||||
self.get_queue(queue_name)
|
||||
queue = self.get_queue(queue_name)
|
||||
|
||||
if any(
|
||||
not re.match(r"^[\w-]{1,80}$", entry["Id"]) for entry in entries.values()
|
||||
@ -899,6 +907,16 @@ class SQSBackend(BaseBackend):
|
||||
|
||||
messages = []
|
||||
failedInvalidDelay = []
|
||||
|
||||
for entry in entries.values():
|
||||
# validate ALL messages before trying to send any
|
||||
self._validate_message(
|
||||
queue,
|
||||
entry["MessageBody"],
|
||||
entry.get("MessageDeduplicationId"),
|
||||
entry.get("MessageGroupId"),
|
||||
)
|
||||
|
||||
for entry in entries.values():
|
||||
try:
|
||||
# Loop through looking for messages
|
||||
@ -954,7 +972,6 @@ class SQSBackend(BaseBackend):
|
||||
|
||||
# queue.messages only contains visible messages
|
||||
while True:
|
||||
|
||||
if result or (wait_seconds_timeout and unix_time() > polling_end):
|
||||
break
|
||||
|
||||
@ -1045,7 +1062,6 @@ class SQSBackend(BaseBackend):
|
||||
queue = self.get_queue(queue_name)
|
||||
for message in queue._messages:
|
||||
if message.had_receipt_handle(receipt_handle):
|
||||
|
||||
visibility_timeout_msec = int(visibility_timeout) * 1000
|
||||
given_visibility_timeout = unix_time_millis() + visibility_timeout_msec
|
||||
if given_visibility_timeout - message.sent_timestamp > 43200 * 1000: # type: ignore
|
||||
|
Loading…
Reference in New Issue
Block a user