From 23978e644fae10756d97c446d6db4c7f7a5b4c04 Mon Sep 17 00:00:00 2001 From: gruebel Date: Sat, 26 Oct 2019 22:08:45 +0200 Subject: [PATCH] Refactor sqs.send_message_batch --- IMPLEMENTATION_COVERAGE.md | 6 ++-- moto/sqs/exceptions.py | 53 +++++++++++++++++++++++++++++ moto/sqs/models.py | 52 ++++++++++++++++++++++++++++ moto/sqs/responses.py | 70 +++++++++----------------------------- tests/test_sqs/test_sqs.py | 66 +++++++++++++++++++++++++++++++++-- 5 files changed, 187 insertions(+), 60 deletions(-) diff --git a/IMPLEMENTATION_COVERAGE.md b/IMPLEMENTATION_COVERAGE.md index 072173226..126d454c6 100644 --- a/IMPLEMENTATION_COVERAGE.md +++ b/IMPLEMENTATION_COVERAGE.md @@ -6065,7 +6065,7 @@ - [X] untag_resource ## sqs -65% implemented +75% implemented - [X] add_permission - [X] change_message_visibility - [ ] change_message_visibility_batch @@ -6076,13 +6076,13 @@ - [ ] get_queue_attributes - [ ] get_queue_url - [X] list_dead_letter_source_queues -- [ ] list_queue_tags +- [X] list_queue_tags - [X] list_queues - [X] purge_queue - [ ] receive_message - [X] remove_permission - [X] send_message -- [ ] send_message_batch +- [X] send_message_batch - [X] set_queue_attributes - [X] tag_queue - [X] untag_queue diff --git a/moto/sqs/exceptions.py b/moto/sqs/exceptions.py index 02f28b2d2..6eee5b843 100644 --- a/moto/sqs/exceptions.py +++ b/moto/sqs/exceptions.py @@ -33,3 +33,56 @@ class QueueAlreadyExists(RESTError): def __init__(self, message): super(QueueAlreadyExists, self).__init__( "QueueAlreadyExists", message) + + +class EmptyBatchRequest(RESTError): + code = 400 + + def __init__(self): + super(EmptyBatchRequest, self).__init__( + 'EmptyBatchRequest', + 'There should be at least one SendMessageBatchRequestEntry in the request.' + ) + + +class InvalidBatchEntryId(RESTError): + code = 400 + + def __init__(self): + super(InvalidBatchEntryId, self).__init__( + 'InvalidBatchEntryId', + 'A batch entry id can only contain alphanumeric characters, ' + 'hyphens and underscores. It can be at most 80 letters long.' + ) + + +class BatchRequestTooLong(RESTError): + code = 400 + + def __init__(self, length): + super(BatchRequestTooLong, self).__init__( + 'BatchRequestTooLong', + 'Batch requests cannot be longer than 262144 bytes. ' + 'You have sent {} bytes.'.format(length) + ) + + +class BatchEntryIdsNotDistinct(RESTError): + code = 400 + + def __init__(self, entry_id): + super(BatchEntryIdsNotDistinct, self).__init__( + 'BatchEntryIdsNotDistinct', + 'Id {} repeated.'.format(entry_id) + ) + + +class TooManyEntriesInBatchRequest(RESTError): + code = 400 + + def __init__(self, number): + super(TooManyEntriesInBatchRequest, self).__init__( + 'TooManyEntriesInBatchRequest', + 'Maximum number of entries per request are 10. ' + 'You have sent {}.'.format(number) + ) diff --git a/moto/sqs/models.py b/moto/sqs/models.py index eb237e437..9900846ac 100644 --- a/moto/sqs/models.py +++ b/moto/sqs/models.py @@ -20,11 +20,17 @@ from .exceptions import ( QueueDoesNotExist, QueueAlreadyExists, ReceiptHandleIsInvalid, + InvalidBatchEntryId, + BatchRequestTooLong, + BatchEntryIdsNotDistinct, + TooManyEntriesInBatchRequest ) DEFAULT_ACCOUNT_ID = 123456789012 DEFAULT_SENDER_ID = "AIDAIT2UOQQY3AUEKVGXU" +MAXIMUM_MESSAGE_LENGTH = 262144 # 256 KiB + TRANSPORT_TYPE_ENCODINGS = {'String': b'\x01', 'Binary': b'\x02', 'Number': b'\x01'} @@ -516,6 +522,49 @@ class SQSBackend(BaseBackend): return message + def send_message_batch(self, queue_name, entries): + self.get_queue(queue_name) + + if any(not re.match(r'^[\w-]{1,80}$', entry['Id']) for entry in entries.values()): + raise InvalidBatchEntryId() + + body_length = next( + (len(entry['MessageBody']) for entry in entries.values() if len(entry['MessageBody']) > MAXIMUM_MESSAGE_LENGTH), + False + ) + if body_length: + raise BatchRequestTooLong(body_length) + + duplicate_id = self._get_first_duplicate_id([entry['Id'] for entry in entries.values()]) + if duplicate_id: + raise BatchEntryIdsNotDistinct(duplicate_id) + + if len(entries) > 10: + raise TooManyEntriesInBatchRequest(len(entries)) + + messages = [] + for index, entry in entries.items(): + # Loop through looking for messages + message = self.send_message( + queue_name, + entry['MessageBody'], + message_attributes=entry['MessageAttributes'], + delay_seconds=entry['DelaySeconds'] + ) + message.user_id = entry['Id'] + + messages.append(message) + + return messages + + def _get_first_duplicate_id(self, ids): + unique_ids = set() + for id in ids: + if id in unique_ids: + return id + unique_ids.add(id) + return None + def receive_messages(self, queue_name, count, wait_seconds_timeout, visibility_timeout): """ Attempt to retrieve visible messages from a queue. @@ -677,6 +726,9 @@ class SQSBackend(BaseBackend): except KeyError: pass + def list_queue_tags(self, queue_name): + return self.get_queue(queue_name) + sqs_backends = {} for region in boto.sqs.regions(): diff --git a/moto/sqs/responses.py b/moto/sqs/responses.py index b6f717f3b..52b237235 100644 --- a/moto/sqs/responses.py +++ b/moto/sqs/responses.py @@ -11,6 +11,7 @@ from .exceptions import ( MessageAttributesInvalid, MessageNotInflight, ReceiptHandleIsInvalid, + EmptyBatchRequest ) MAXIMUM_VISIBILTY_TIMEOUT = 43200 @@ -237,72 +238,33 @@ class SQSResponse(BaseResponse): self.sqs_backend.get_queue(queue_name) if self.querystring.get('Entries'): - return self._error('AWS.SimpleQueueService.EmptyBatchRequest', - 'There should be at least one SendMessageBatchRequestEntry in the request.') + raise EmptyBatchRequest() entries = {} for key, value in self.querystring.items(): match = re.match(r'^SendMessageBatchRequestEntry\.(\d+)\.Id', key) if match: - entries[match.group(1)] = { + index = match.group(1) + + message_attributes = parse_message_attributes( + self.querystring, base='SendMessageBatchRequestEntry.{}.'.format(index)) + if type(message_attributes) == tuple: + return message_attributes[0], message_attributes[1] + + entries[index] = { 'Id': value[0], 'MessageBody': self.querystring.get( - 'SendMessageBatchRequestEntry.{}.MessageBody'.format(match.group(1)))[0] + 'SendMessageBatchRequestEntry.{}.MessageBody'.format(index))[0], + 'DelaySeconds': self.querystring.get( + 'SendMessageBatchRequestEntry.{}.DelaySeconds'.format(index), [None])[0], + 'MessageAttributes': message_attributes } - if any(not re.match(r'^[\w-]{1,80}$', entry['Id']) for entry in entries.values()): - return self._error('AWS.SimpleQueueService.InvalidBatchEntryId', - 'A batch entry id can only contain alphanumeric characters, ' - 'hyphens and underscores. It can be at most 80 letters long.') - - body_length = next( - (len(entry['MessageBody']) for entry in entries.values() if len(entry['MessageBody']) > MAXIMUM_MESSAGE_LENGTH), - False - ) - if body_length: - return self._error('AWS.SimpleQueueService.BatchRequestTooLong', - 'Batch requests cannot be longer than 262144 bytes. ' - 'You have sent {} bytes.'.format(body_length)) - - duplicate_id = self._get_first_duplicate_id([entry['Id'] for entry in entries.values()]) - if duplicate_id: - return self._error('AWS.SimpleQueueService.BatchEntryIdsNotDistinct', - 'Id {} repeated.'.format(duplicate_id)) - - if len(entries) > 10: - return self._error('AWS.SimpleQueueService.TooManyEntriesInBatchRequest', - 'Maximum number of entries per request are 10. ' - 'You have sent 11.') - - messages = [] - for index, entry in entries.items(): - # Loop through looking for messages - delay_key = 'SendMessageBatchRequestEntry.{0}.DelaySeconds'.format( - index) - delay_seconds = self.querystring.get(delay_key, [None])[0] - message = self.sqs_backend.send_message( - queue_name, entry['MessageBody'], delay_seconds=delay_seconds) - message.user_id = entry['Id'] - - message_attributes = parse_message_attributes( - self.querystring, base='SendMessageBatchRequestEntry.{0}.'.format(index)) - if type(message_attributes) == tuple: - return message_attributes[0], message_attributes[1] - message.message_attributes = message_attributes - - messages.append(message) + messages = self.sqs_backend.send_message_batch(queue_name, entries) template = self.response_template(SEND_MESSAGE_BATCH_RESPONSE) return template.render(messages=messages) - def _get_first_duplicate_id(self, ids): - unique_ids = set() - for id in ids: - if id in unique_ids: - return id - unique_ids.add(id) - return None - def delete_message(self): queue_name = self._get_queue_name() receipt_handle = self.querystring.get("ReceiptHandle")[0] @@ -441,7 +403,7 @@ class SQSResponse(BaseResponse): def list_queue_tags(self): queue_name = self._get_queue_name() - queue = self.sqs_backend.get_queue(queue_name) + queue = self.sqs_backend.list_queue_tags(queue_name) template = self.response_template(LIST_QUEUE_TAGS_RESPONSE) return template.render(tags=queue.tags) diff --git a/tests/test_sqs/test_sqs.py b/tests/test_sqs/test_sqs.py index 1ad2e1a80..1a9038aa5 100644 --- a/tests/test_sqs/test_sqs.py +++ b/tests/test_sqs/test_sqs.py @@ -883,10 +883,70 @@ def test_delete_message_after_visibility_timeout(): @mock_sqs -def test_send_message_batch_errors(): - client = boto3.client('sqs', region_name = 'us-east-1') +def test_send_message_batch(): + client = boto3.client('sqs', region_name='us-east-1') + response = client.create_queue(QueueName='test-queue') + queue_url = response['QueueUrl'] - response = client.create_queue(QueueName='test-queue-with-tags') + response = client.send_message_batch( + QueueUrl=queue_url, + Entries=[ + { + 'Id': 'id_1', + 'MessageBody': 'body_1', + 'DelaySeconds': 0, + 'MessageAttributes': { + 'attribute_name_1': { + 'StringValue': 'attribute_value_1', + 'DataType': 'String' + } + } + }, + { + 'Id': 'id_2', + 'MessageBody': 'body_2', + 'DelaySeconds': 0, + 'MessageAttributes': { + 'attribute_name_2': { + 'StringValue': '123', + 'DataType': 'Number' + } + } + } + ] + ) + + sorted([entry['Id'] for entry in response['Successful']]).should.equal([ + 'id_1', + 'id_2' + ]) + + response = client.receive_message( + QueueUrl=queue_url, + MaxNumberOfMessages=10 + ) + + response['Messages'][0]['Body'].should.equal('body_1') + response['Messages'][0]['MessageAttributes'].should.equal({ + 'attribute_name_1': { + 'StringValue': 'attribute_value_1', + 'DataType': 'String' + } + }) + response['Messages'][1]['Body'].should.equal('body_2') + response['Messages'][1]['MessageAttributes'].should.equal({ + 'attribute_name_2': { + 'StringValue': '123', + 'DataType': 'Number' + } + }) + + +@mock_sqs +def test_send_message_batch_errors(): + client = boto3.client('sqs', region_name='us-east-1') + + response = client.create_queue(QueueName='test-queue') queue_url = response['QueueUrl'] client.send_message_batch.when.called_with(