Refactor sqs.send_message_batch
This commit is contained in:
parent
6b67002a42
commit
23978e644f
@ -6065,7 +6065,7 @@
|
||||
- [X] untag_resource
|
||||
|
||||
## sqs
|
||||
65% implemented
|
||||
75% implemented
|
||||
- [X] add_permission
|
||||
- [X] change_message_visibility
|
||||
- [ ] change_message_visibility_batch
|
||||
@ -6076,13 +6076,13 @@
|
||||
- [ ] get_queue_attributes
|
||||
- [ ] get_queue_url
|
||||
- [X] list_dead_letter_source_queues
|
||||
- [ ] list_queue_tags
|
||||
- [X] list_queue_tags
|
||||
- [X] list_queues
|
||||
- [X] purge_queue
|
||||
- [ ] receive_message
|
||||
- [X] remove_permission
|
||||
- [X] send_message
|
||||
- [ ] send_message_batch
|
||||
- [X] send_message_batch
|
||||
- [X] set_queue_attributes
|
||||
- [X] tag_queue
|
||||
- [X] untag_queue
|
||||
|
@ -33,3 +33,56 @@ class QueueAlreadyExists(RESTError):
|
||||
def __init__(self, message):
|
||||
super(QueueAlreadyExists, self).__init__(
|
||||
"QueueAlreadyExists", message)
|
||||
|
||||
|
||||
class EmptyBatchRequest(RESTError):
|
||||
code = 400
|
||||
|
||||
def __init__(self):
|
||||
super(EmptyBatchRequest, self).__init__(
|
||||
'EmptyBatchRequest',
|
||||
'There should be at least one SendMessageBatchRequestEntry in the request.'
|
||||
)
|
||||
|
||||
|
||||
class InvalidBatchEntryId(RESTError):
|
||||
code = 400
|
||||
|
||||
def __init__(self):
|
||||
super(InvalidBatchEntryId, self).__init__(
|
||||
'InvalidBatchEntryId',
|
||||
'A batch entry id can only contain alphanumeric characters, '
|
||||
'hyphens and underscores. It can be at most 80 letters long.'
|
||||
)
|
||||
|
||||
|
||||
class BatchRequestTooLong(RESTError):
|
||||
code = 400
|
||||
|
||||
def __init__(self, length):
|
||||
super(BatchRequestTooLong, self).__init__(
|
||||
'BatchRequestTooLong',
|
||||
'Batch requests cannot be longer than 262144 bytes. '
|
||||
'You have sent {} bytes.'.format(length)
|
||||
)
|
||||
|
||||
|
||||
class BatchEntryIdsNotDistinct(RESTError):
|
||||
code = 400
|
||||
|
||||
def __init__(self, entry_id):
|
||||
super(BatchEntryIdsNotDistinct, self).__init__(
|
||||
'BatchEntryIdsNotDistinct',
|
||||
'Id {} repeated.'.format(entry_id)
|
||||
)
|
||||
|
||||
|
||||
class TooManyEntriesInBatchRequest(RESTError):
|
||||
code = 400
|
||||
|
||||
def __init__(self, number):
|
||||
super(TooManyEntriesInBatchRequest, self).__init__(
|
||||
'TooManyEntriesInBatchRequest',
|
||||
'Maximum number of entries per request are 10. '
|
||||
'You have sent {}.'.format(number)
|
||||
)
|
||||
|
@ -20,11 +20,17 @@ from .exceptions import (
|
||||
QueueDoesNotExist,
|
||||
QueueAlreadyExists,
|
||||
ReceiptHandleIsInvalid,
|
||||
InvalidBatchEntryId,
|
||||
BatchRequestTooLong,
|
||||
BatchEntryIdsNotDistinct,
|
||||
TooManyEntriesInBatchRequest
|
||||
)
|
||||
|
||||
DEFAULT_ACCOUNT_ID = 123456789012
|
||||
DEFAULT_SENDER_ID = "AIDAIT2UOQQY3AUEKVGXU"
|
||||
|
||||
MAXIMUM_MESSAGE_LENGTH = 262144 # 256 KiB
|
||||
|
||||
TRANSPORT_TYPE_ENCODINGS = {'String': b'\x01', 'Binary': b'\x02', 'Number': b'\x01'}
|
||||
|
||||
|
||||
@ -516,6 +522,49 @@ class SQSBackend(BaseBackend):
|
||||
|
||||
return message
|
||||
|
||||
def send_message_batch(self, queue_name, entries):
|
||||
self.get_queue(queue_name)
|
||||
|
||||
if any(not re.match(r'^[\w-]{1,80}$', entry['Id']) for entry in entries.values()):
|
||||
raise InvalidBatchEntryId()
|
||||
|
||||
body_length = next(
|
||||
(len(entry['MessageBody']) for entry in entries.values() if len(entry['MessageBody']) > MAXIMUM_MESSAGE_LENGTH),
|
||||
False
|
||||
)
|
||||
if body_length:
|
||||
raise BatchRequestTooLong(body_length)
|
||||
|
||||
duplicate_id = self._get_first_duplicate_id([entry['Id'] for entry in entries.values()])
|
||||
if duplicate_id:
|
||||
raise BatchEntryIdsNotDistinct(duplicate_id)
|
||||
|
||||
if len(entries) > 10:
|
||||
raise TooManyEntriesInBatchRequest(len(entries))
|
||||
|
||||
messages = []
|
||||
for index, entry in entries.items():
|
||||
# Loop through looking for messages
|
||||
message = self.send_message(
|
||||
queue_name,
|
||||
entry['MessageBody'],
|
||||
message_attributes=entry['MessageAttributes'],
|
||||
delay_seconds=entry['DelaySeconds']
|
||||
)
|
||||
message.user_id = entry['Id']
|
||||
|
||||
messages.append(message)
|
||||
|
||||
return messages
|
||||
|
||||
def _get_first_duplicate_id(self, ids):
|
||||
unique_ids = set()
|
||||
for id in ids:
|
||||
if id in unique_ids:
|
||||
return id
|
||||
unique_ids.add(id)
|
||||
return None
|
||||
|
||||
def receive_messages(self, queue_name, count, wait_seconds_timeout, visibility_timeout):
|
||||
"""
|
||||
Attempt to retrieve visible messages from a queue.
|
||||
@ -677,6 +726,9 @@ class SQSBackend(BaseBackend):
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def list_queue_tags(self, queue_name):
|
||||
return self.get_queue(queue_name)
|
||||
|
||||
|
||||
sqs_backends = {}
|
||||
for region in boto.sqs.regions():
|
||||
|
@ -11,6 +11,7 @@ from .exceptions import (
|
||||
MessageAttributesInvalid,
|
||||
MessageNotInflight,
|
||||
ReceiptHandleIsInvalid,
|
||||
EmptyBatchRequest
|
||||
)
|
||||
|
||||
MAXIMUM_VISIBILTY_TIMEOUT = 43200
|
||||
@ -237,72 +238,33 @@ class SQSResponse(BaseResponse):
|
||||
self.sqs_backend.get_queue(queue_name)
|
||||
|
||||
if self.querystring.get('Entries'):
|
||||
return self._error('AWS.SimpleQueueService.EmptyBatchRequest',
|
||||
'There should be at least one SendMessageBatchRequestEntry in the request.')
|
||||
raise EmptyBatchRequest()
|
||||
|
||||
entries = {}
|
||||
for key, value in self.querystring.items():
|
||||
match = re.match(r'^SendMessageBatchRequestEntry\.(\d+)\.Id', key)
|
||||
if match:
|
||||
entries[match.group(1)] = {
|
||||
'Id': value[0],
|
||||
'MessageBody': self.querystring.get(
|
||||
'SendMessageBatchRequestEntry.{}.MessageBody'.format(match.group(1)))[0]
|
||||
}
|
||||
|
||||
if any(not re.match(r'^[\w-]{1,80}$', entry['Id']) for entry in entries.values()):
|
||||
return self._error('AWS.SimpleQueueService.InvalidBatchEntryId',
|
||||
'A batch entry id can only contain alphanumeric characters, '
|
||||
'hyphens and underscores. It can be at most 80 letters long.')
|
||||
|
||||
body_length = next(
|
||||
(len(entry['MessageBody']) for entry in entries.values() if len(entry['MessageBody']) > MAXIMUM_MESSAGE_LENGTH),
|
||||
False
|
||||
)
|
||||
if body_length:
|
||||
return self._error('AWS.SimpleQueueService.BatchRequestTooLong',
|
||||
'Batch requests cannot be longer than 262144 bytes. '
|
||||
'You have sent {} bytes.'.format(body_length))
|
||||
|
||||
duplicate_id = self._get_first_duplicate_id([entry['Id'] for entry in entries.values()])
|
||||
if duplicate_id:
|
||||
return self._error('AWS.SimpleQueueService.BatchEntryIdsNotDistinct',
|
||||
'Id {} repeated.'.format(duplicate_id))
|
||||
|
||||
if len(entries) > 10:
|
||||
return self._error('AWS.SimpleQueueService.TooManyEntriesInBatchRequest',
|
||||
'Maximum number of entries per request are 10. '
|
||||
'You have sent 11.')
|
||||
|
||||
messages = []
|
||||
for index, entry in entries.items():
|
||||
# Loop through looking for messages
|
||||
delay_key = 'SendMessageBatchRequestEntry.{0}.DelaySeconds'.format(
|
||||
index)
|
||||
delay_seconds = self.querystring.get(delay_key, [None])[0]
|
||||
message = self.sqs_backend.send_message(
|
||||
queue_name, entry['MessageBody'], delay_seconds=delay_seconds)
|
||||
message.user_id = entry['Id']
|
||||
index = match.group(1)
|
||||
|
||||
message_attributes = parse_message_attributes(
|
||||
self.querystring, base='SendMessageBatchRequestEntry.{0}.'.format(index))
|
||||
self.querystring, base='SendMessageBatchRequestEntry.{}.'.format(index))
|
||||
if type(message_attributes) == tuple:
|
||||
return message_attributes[0], message_attributes[1]
|
||||
message.message_attributes = message_attributes
|
||||
|
||||
messages.append(message)
|
||||
entries[index] = {
|
||||
'Id': value[0],
|
||||
'MessageBody': self.querystring.get(
|
||||
'SendMessageBatchRequestEntry.{}.MessageBody'.format(index))[0],
|
||||
'DelaySeconds': self.querystring.get(
|
||||
'SendMessageBatchRequestEntry.{}.DelaySeconds'.format(index), [None])[0],
|
||||
'MessageAttributes': message_attributes
|
||||
}
|
||||
|
||||
messages = self.sqs_backend.send_message_batch(queue_name, entries)
|
||||
|
||||
template = self.response_template(SEND_MESSAGE_BATCH_RESPONSE)
|
||||
return template.render(messages=messages)
|
||||
|
||||
def _get_first_duplicate_id(self, ids):
|
||||
unique_ids = set()
|
||||
for id in ids:
|
||||
if id in unique_ids:
|
||||
return id
|
||||
unique_ids.add(id)
|
||||
return None
|
||||
|
||||
def delete_message(self):
|
||||
queue_name = self._get_queue_name()
|
||||
receipt_handle = self.querystring.get("ReceiptHandle")[0]
|
||||
@ -441,7 +403,7 @@ class SQSResponse(BaseResponse):
|
||||
def list_queue_tags(self):
|
||||
queue_name = self._get_queue_name()
|
||||
|
||||
queue = self.sqs_backend.get_queue(queue_name)
|
||||
queue = self.sqs_backend.list_queue_tags(queue_name)
|
||||
|
||||
template = self.response_template(LIST_QUEUE_TAGS_RESPONSE)
|
||||
return template.render(tags=queue.tags)
|
||||
|
@ -883,10 +883,70 @@ def test_delete_message_after_visibility_timeout():
|
||||
|
||||
|
||||
@mock_sqs
|
||||
def test_send_message_batch_errors():
|
||||
client = boto3.client('sqs', region_name = 'us-east-1')
|
||||
def test_send_message_batch():
|
||||
client = boto3.client('sqs', region_name='us-east-1')
|
||||
response = client.create_queue(QueueName='test-queue')
|
||||
queue_url = response['QueueUrl']
|
||||
|
||||
response = client.create_queue(QueueName='test-queue-with-tags')
|
||||
response = client.send_message_batch(
|
||||
QueueUrl=queue_url,
|
||||
Entries=[
|
||||
{
|
||||
'Id': 'id_1',
|
||||
'MessageBody': 'body_1',
|
||||
'DelaySeconds': 0,
|
||||
'MessageAttributes': {
|
||||
'attribute_name_1': {
|
||||
'StringValue': 'attribute_value_1',
|
||||
'DataType': 'String'
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
'Id': 'id_2',
|
||||
'MessageBody': 'body_2',
|
||||
'DelaySeconds': 0,
|
||||
'MessageAttributes': {
|
||||
'attribute_name_2': {
|
||||
'StringValue': '123',
|
||||
'DataType': 'Number'
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
sorted([entry['Id'] for entry in response['Successful']]).should.equal([
|
||||
'id_1',
|
||||
'id_2'
|
||||
])
|
||||
|
||||
response = client.receive_message(
|
||||
QueueUrl=queue_url,
|
||||
MaxNumberOfMessages=10
|
||||
)
|
||||
|
||||
response['Messages'][0]['Body'].should.equal('body_1')
|
||||
response['Messages'][0]['MessageAttributes'].should.equal({
|
||||
'attribute_name_1': {
|
||||
'StringValue': 'attribute_value_1',
|
||||
'DataType': 'String'
|
||||
}
|
||||
})
|
||||
response['Messages'][1]['Body'].should.equal('body_2')
|
||||
response['Messages'][1]['MessageAttributes'].should.equal({
|
||||
'attribute_name_2': {
|
||||
'StringValue': '123',
|
||||
'DataType': 'Number'
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@mock_sqs
|
||||
def test_send_message_batch_errors():
|
||||
client = boto3.client('sqs', region_name='us-east-1')
|
||||
|
||||
response = client.create_queue(QueueName='test-queue')
|
||||
queue_url = response['QueueUrl']
|
||||
|
||||
client.send_message_batch.when.called_with(
|
||||
|
Loading…
Reference in New Issue
Block a user