SQS - Align logic around receipt_handles with AWS (#4655)

This commit is contained in:
Bert Blommers 2021-12-04 21:51:51 -01:00 committed by GitHub
parent 695a3ca3d3
commit cbfe962b70
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 150 additions and 109 deletions

View File

@ -1,11 +1,6 @@
from moto.core.exceptions import RESTError
class MessageNotInflight(Exception):
description = "The message referred to is not in flight."
status_code = 400
class ReceiptHandleIsInvalid(RESTError):
code = 400

View File

@ -24,7 +24,6 @@ from moto.core.utils import (
from .utils import generate_receipt_handle
from .exceptions import (
MessageAttributesInvalid,
MessageNotInflight,
QueueDoesNotExist,
QueueAlreadyExists,
ReceiptHandleIsInvalid,
@ -73,6 +72,7 @@ class Message(BaseModel):
self._body = body
self.message_attributes = {}
self.receipt_handle = None
self._old_receipt_handles = []
self.sender_id = DEFAULT_SENDER_ID
self.sent_timestamp = None
self.approximate_first_receive_timestamp = None
@ -178,6 +178,7 @@ class Message(BaseModel):
if visibility_timeout:
self.change_visibility(visibility_timeout)
self._old_receipt_handles.append(self.receipt_handle)
self.receipt_handle = generate_receipt_handle()
def change_visibility(self, visibility_timeout):
@ -203,6 +204,16 @@ class Message(BaseModel):
return True
return False
@property
def all_receipt_handles(self):
return [self.receipt_handle] + self._old_receipt_handles
def had_receipt_handle(self, receipt_handle):
"""
Check if this message ever had this receipt_handle in the past
"""
return receipt_handle in self.all_receipt_handles
class Queue(CloudFormationModel):
BASE_ATTRIBUTES = [
@ -247,6 +258,7 @@ class Queue(CloudFormationModel):
self._messages = []
self._pending_messages = set()
self.deleted_messages = set()
now = unix_time()
self.created_timestamp = now
@ -541,6 +553,26 @@ class Queue(CloudFormationModel):
for m in messages
]
def delete_message(self, receipt_handle):
if receipt_handle in self.deleted_messages:
# Already deleted - gracefully handle deleting it again
return
if not any(
message.had_receipt_handle(receipt_handle) for message in self._messages
):
raise ReceiptHandleIsInvalid()
# Delete message from queue regardless of pending state
new_messages = []
for message in self._messages:
if message.had_receipt_handle(receipt_handle):
self.pending_messages.discard(message)
self.deleted_messages.update(message.all_receipt_handles)
continue
new_messages.append(message)
self._messages = new_messages
@classmethod
def has_cfn_attr(cls, attribute_name):
return attribute_name in ["Arn", "QueueName"]
@ -906,26 +938,12 @@ class SQSBackend(BaseBackend):
def delete_message(self, queue_name, receipt_handle):
queue = self.get_queue(queue_name)
if not any(
message.receipt_handle == receipt_handle for message in queue._messages
):
raise ReceiptHandleIsInvalid()
# Delete message from queue regardless of pending state
new_messages = []
for message in queue._messages:
if message.receipt_handle == receipt_handle:
queue.pending_messages.discard(message)
continue
new_messages.append(message)
queue._messages = new_messages
queue.delete_message(receipt_handle)
def change_message_visibility(self, queue_name, receipt_handle, visibility_timeout):
queue = self.get_queue(queue_name)
for message in queue._messages:
if message.receipt_handle == receipt_handle:
if message.visible:
raise MessageNotInflight
if message.had_receipt_handle(receipt_handle):
visibility_timeout_msec = int(visibility_timeout) * 1000
given_visibility_timeout = unix_time_millis() + visibility_timeout_msec
@ -938,7 +956,7 @@ class SQSBackend(BaseBackend):
)
message.change_visibility(visibility_timeout)
if message.visible:
if message.visible and message in queue.pending_messages:
# If the message is visible again, remove it from pending
# messages.
queue.pending_messages.remove(message)

View File

@ -14,7 +14,6 @@ from .exceptions import (
EmptyBatchRequest,
InvalidAddress,
InvalidAttributeName,
MessageNotInflight,
ReceiptHandleIsInvalid,
BatchEntryIdsNotDistinct,
)
@ -123,17 +122,11 @@ class SQSResponse(BaseResponse):
except ValueError:
return ERROR_MAX_VISIBILITY_TIMEOUT_RESPONSE, dict(status=400)
try:
self.sqs_backend.change_message_visibility(
queue_name=queue_name,
receipt_handle=receipt_handle,
visibility_timeout=visibility_timeout,
)
except MessageNotInflight as e:
return (
"Invalid request: {0}".format(e.description),
dict(status=e.status_code),
)
self.sqs_backend.change_message_visibility(
queue_name=queue_name,
receipt_handle=receipt_handle,
visibility_timeout=visibility_timeout,
)
template = self.response_template(CHANGE_MESSAGE_VISIBILITY_RESPONSE)
return template.render()
@ -176,15 +169,6 @@ class SQSResponse(BaseResponse):
"Message": e.description,
}
)
except MessageNotInflight as e:
error.append(
{
"Id": entry["id"],
"SenderFault": "false",
"Code": "AWS.SimpleQueueService.MessageNotInflight",
"Message": e.description,
}
)
template = self.response_template(CHANGE_MESSAGE_VISIBILITY_BATCH_RESPONSE)
return template.render(success=success, errors=error)

View File

@ -1677,6 +1677,23 @@ def test_change_message_visibility_boto3():
queue.attributes["ApproximateNumberOfMessages"].should.equal("0")
@mock_sqs
def test_change_message_visibility_on_unknown_receipt_handle():
sqs = boto3.resource("sqs", region_name="us-east-1")
conn = boto3.client("sqs", region_name="us-east-1")
queue = sqs.create_queue(
QueueName=str(uuid4())[0:6], Attributes={"VisibilityTimeout": "2"}
)
with pytest.raises(ClientError) as exc:
conn.change_message_visibility(
QueueUrl=queue.url, ReceiptHandle="unknown-stuff", VisibilityTimeout=432,
)
err = exc.value.response["Error"]
err["Code"].should.equal("ReceiptHandleIsInvalid")
err["Message"].should.equal("The input receipt handle is invalid.")
# Has boto3 equivalent
@mock_sqs_deprecated
def test_message_attributes():
@ -1887,39 +1904,11 @@ def test_queue_attributes():
attribute_names.should.contain("QueueArn")
# Has boto3 equivalent
@mock_sqs_deprecated
def test_change_message_visibility_on_invalid_receipt():
conn = boto.connect_sqs("the_key", "the_secret")
queue = conn.create_queue("test-queue", visibility_timeout=1)
queue.set_message_class(RawMessage)
queue.write(queue.new_message("this is another test message"))
queue.count().should.equal(1)
messages = conn.receive_message(queue, number_messages=1)
assert len(messages) == 1
original_message = messages[0]
queue.count().should.equal(0)
time.sleep(2)
queue.count().should.equal(1)
messages = conn.receive_message(queue, number_messages=1)
assert len(messages) == 1
original_message.change_visibility.when.called_with(100).should.throw(SQSError)
@mock_sqs
def test_change_message_visibility_on_invalid_receipt_boto3():
def test_change_message_visibility_on_old_message_boto3():
sqs = boto3.resource("sqs", region_name="us-east-1")
queue = sqs.create_queue(
QueueName=str(uuid4())[0:6], Attributes={"VisibilityTimeout ": "1"}
QueueName=str(uuid4())[0:6], Attributes={"VisibilityTimeout": "1"}
)
queue.send_message(MessageBody="test message 1")
@ -1942,62 +1931,43 @@ def test_change_message_visibility_on_invalid_receipt_boto3():
messages.should.have.length_of(1)
with pytest.raises(ClientError) as ex:
original_message.change_visibility(VisibilityTimeout=100)
err = ex.value.response["Error"]
err["Code"].should.equal("ReceiptHandleIsInvalid")
err["Message"].should.equal("The input receipt handle is invalid.")
# Has boto3 equivalent
@mock_sqs_deprecated
def test_change_message_visibility_on_visible_message():
conn = boto.connect_sqs("the_key", "the_secret")
queue = conn.create_queue("test-queue", visibility_timeout=1)
queue.set_message_class(RawMessage)
queue.write(queue.new_message("this is another test message"))
queue.count().should.equal(1)
messages = conn.receive_message(queue, number_messages=1)
assert len(messages) == 1
original_message = messages[0]
queue.count().should.equal(0)
# Docs indicate this should throw an ReceiptHandleIsInvalid, but this is allowed in AWS
original_message.change_visibility(VisibilityTimeout=100)
# Docs indicate this should throw a MessageNotInflight, but this is allowed in AWS
original_message.change_visibility(VisibilityTimeout=100)
time.sleep(2)
queue.count().should.equal(1)
original_message.change_visibility.when.called_with(100).should.throw(SQSError)
# Message is not yet available, because of the visibility-timeout
messages = queue.receive_messages(MaxNumberOfMessages=1)
messages.should.have.length_of(0)
@mock_sqs
def test_change_message_visibility_on_visible_message_boto3():
sqs = boto3.resource("sqs", region_name="us-east-1")
queue = sqs.create_queue(
QueueName=str(uuid4())[0:6], Attributes={"VisibilityTimeout ": "1"}
QueueName=str(uuid4())[0:6], Attributes={"VisibilityTimeout": "1"}
)
queue.send_message(MessageBody="test message")
messages = queue.receive_messages(MaxNumberOfMessages=1)
messages.should.have.length_of(1)
original_message = messages[0]
queue.reload()
queue.attributes["ApproximateNumberOfMessages"].should.equal("0")
time.sleep(2)
queue.reload()
queue.attributes["ApproximateNumberOfMessages"].should.equal("1")
messages = queue.receive_messages(MaxNumberOfMessages=1)
messages.should.have.length_of(1)
# TODO: We should catch a ClientError here, but Moto throws an error in the wrong format
with pytest.raises(Exception) as ex:
original_message.change_visibility(VisibilityTimeout=100)
str(ex).should.match("Invalid request: The message referred to is not in flight.")
messages[0].change_visibility(VisibilityTimeout=100)
time.sleep(2)
queue.reload()
queue.attributes["ApproximateNumberOfMessages"].should.equal("0")
# Has boto3 equivalent
@ -2112,6 +2082,48 @@ def test_delete_message_errors():
).should.throw(ClientError, "The input receipt handle is invalid.")
@mock_sqs
def test_delete_message_twice_using_same_receipt_handle():
client = boto3.client("sqs", region_name="us-east-1")
response = client.create_queue(QueueName=str(uuid4())[0:6])
queue_url = response["QueueUrl"]
client.send_message(QueueUrl=queue_url, MessageBody="body")
response = client.receive_message(QueueUrl=queue_url)
receipt_handle = response["Messages"][0]["ReceiptHandle"]
client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle)
client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle)
@mock_sqs
def test_delete_message_using_old_receipt_handle():
client = boto3.client("sqs", region_name="us-east-1")
response = client.create_queue(
QueueName=str(uuid4())[0:6], Attributes={"VisibilityTimeout": "0"}
)
queue_url = response["QueueUrl"]
client.send_message(QueueUrl=queue_url, MessageBody="body")
response = client.receive_message(QueueUrl=queue_url)
receipt_1 = response["Messages"][0]["ReceiptHandle"]
response = client.receive_message(QueueUrl=queue_url)
receipt_2 = response["Messages"][0]["ReceiptHandle"]
receipt_1.shouldnt.equal(receipt_2)
# Can use an old receipt_handle to delete a message
client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_1)
# Sanity check the message really is gone
client.receive_message(QueueUrl=queue_url).shouldnt.have.key("Messages")
# We can delete it again
client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_1)
# Can use the second receipt handle to delete it 'again' - succeeds, as it is idempotent against the message
client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_2)
@mock_sqs
def test_send_message_batch():
client = boto3.client("sqs", region_name="us-east-1")
@ -2421,6 +2433,38 @@ def test_batch_change_message_visibility():
len(resp["Messages"]).should.equal(3)
@mock_sqs
def test_batch_change_message_visibility_on_old_message():
sqs = boto3.resource("sqs", region_name="us-east-1")
queue = sqs.create_queue(
QueueName=str(uuid4())[0:6], Attributes={"VisibilityTimeout": "1"}
)
queue.send_message(MessageBody="test message 1")
messages = queue.receive_messages(MaxNumberOfMessages=1)
messages.should.have.length_of(1)
original_message = messages[0]
time.sleep(2)
messages = queue.receive_messages(MaxNumberOfMessages=1)
messages[0].receipt_handle.shouldnt.equal(original_message.receipt_handle)
entries = [
{
"Id": str(uuid.uuid4()),
"ReceiptHandle": original_message.receipt_handle,
"VisibilityTimeout": 4,
}
]
resp = queue.change_message_visibility_batch(Entries=entries)
resp["Successful"].should.have.length_of(1)
@mock_sqs
def test_permissions():
client = boto3.client("sqs", region_name="us-east-1")