diff --git a/moto/sqs/models.py b/moto/sqs/models.py index 33abe3684..31d919d87 100644 --- a/moto/sqs/models.py +++ b/moto/sqs/models.py @@ -103,11 +103,15 @@ class Queue(object): 'MessageRetentionPeriod', 'QueueArn', 'ReceiveMessageWaitTimeSeconds', - 'VisibilityTimeout'] + 'VisibilityTimeout', + 'WaitTimeSeconds'] - def __init__(self, name, visibility_timeout): + def __init__(self, name, visibility_timeout, wait_time_seconds): self.name = name self.visibility_timeout = visibility_timeout or 30 + + # wait_time_seconds will be set to immediate return messages + self.wait_time_seconds = wait_time_seconds or 0 self._messages = [] now = time.time() @@ -128,6 +132,7 @@ class Queue(object): return sqs_backend.create_queue( name=properties['QueueName'], visibility_timeout=properties.get('VisibilityTimeout'), + wait_time_seconds=properties.get('WaitTimeSeconds') ) @classmethod @@ -139,6 +144,9 @@ class Queue(object): queue = sqs_backend.get_queue(queue_name) if 'VisibilityTimeout' in properties: queue.visibility_timeout = int(properties['VisibilityTimeout']) + + if 'WaitTimeSeconds' in properties: + queue.wait_time_seconds = int(properties['WaitTimeSeconds']) return queue @classmethod @@ -192,10 +200,10 @@ class SQSBackend(BaseBackend): self.queues = {} super(SQSBackend, self).__init__() - def create_queue(self, name, visibility_timeout): + def create_queue(self, name, visibility_timeout, wait_time_seconds): queue = self.queues.get(name) if queue is None: - queue = Queue(name, visibility_timeout) + queue = Queue(name, visibility_timeout, wait_time_seconds) self.queues[name] = queue return queue @@ -246,7 +254,7 @@ class SQSBackend(BaseBackend): return message - def receive_messages(self, queue_name, count): + def receive_messages(self, queue_name, count, wait_seconds_timeout): """ Attempt to retrieve visible messages from a queue. @@ -260,13 +268,20 @@ class SQSBackend(BaseBackend): """ queue = self.get_queue(queue_name) result = [] + + polling_end = time.time() + wait_seconds_timeout + # queue.messages only contains visible messages - for message in queue.messages: - message.mark_received( - visibility_timeout=queue.visibility_timeout - ) - result.append(message) - if len(result) >= count: + while True: + for message in queue.messages: + message.mark_received( + visibility_timeout=queue.visibility_timeout + ) + result.append(message) + if len(result) >= count: + break + + if time.time() > polling_end: break return result diff --git a/moto/sqs/responses.py b/moto/sqs/responses.py index d20372762..648e939d2 100644 --- a/moto/sqs/responses.py +++ b/moto/sqs/responses.py @@ -23,6 +23,12 @@ class SQSResponse(BaseResponse): def sqs_backend(self): return sqs_backends[self.region] + @property + def attribute(self): + if not hasattr(self, '_attribute'): + self._attribute = dict([(a['name'], a['value']) for a in self._get_list_prefix('Attribute')]) + return self._attribute + def _get_queue_name(self): try: queue_name = self.querystring.get('QueueUrl')[0].split("/")[-1] @@ -32,12 +38,9 @@ class SQSResponse(BaseResponse): return queue_name def create_queue(self): - visibility_timeout = None - if 'Attribute.1.Name' in self.querystring and self.querystring.get('Attribute.1.Name')[0] == 'VisibilityTimeout': - visibility_timeout = self.querystring.get("Attribute.1.Value")[0] - queue_name = self.querystring.get("QueueName")[0] - queue = self.sqs_backend.create_queue(queue_name, visibility_timeout=visibility_timeout) + queue = self.sqs_backend.create_queue(queue_name, visibility_timeout=self.attribute.get('VisibilityTimeout'), + wait_time_seconds=self.attribute.get('WaitTimeSeconds')) template = self.response_template(CREATE_QUEUE_RESPONSE) return template.render(queue=queue) @@ -209,11 +212,19 @@ class SQSResponse(BaseResponse): def receive_message(self): queue_name = self._get_queue_name() + queue = self.sqs_backend.get_queue(queue_name) + try: message_count = int(self.querystring.get("MaxNumberOfMessages")[0]) except TypeError: message_count = DEFAULT_RECEIVED_MESSAGES - messages = self.sqs_backend.receive_messages(queue_name, message_count) + + try: + wait_time = int(self.querystring.get("WaitTimeSeconds")[0]) + except TypeError: + wait_time = queue.wait_time_seconds + + messages = self.sqs_backend.receive_messages(queue_name, message_count, wait_time) template = self.response_template(RECEIVE_MESSAGE_RESPONSE) output = template.render(messages=messages) return output diff --git a/tests/test_sqs/test_server.py b/tests/test_sqs/test_server.py index 757d4de99..56888ca5c 100644 --- a/tests/test_sqs/test_server.py +++ b/tests/test_sqs/test_server.py @@ -1,6 +1,9 @@ from __future__ import unicode_literals + import re import sure # noqa +import threading +import time import moto.server as server @@ -30,3 +33,38 @@ def test_sqs_list_identities(): message = re.search("(.*?)", res.data.decode('utf-8')).groups()[0] message.should.equal('test-message') + + +def test_messages_polling(): + backend = server.create_backend_app("sqs") + test_client = backend.test_client() + messages = [] + + test_client.put('/?Action=CreateQueue&QueueName=testqueue') + + def insert_messages(): + messages_count = 5 + while messages_count > 0: + test_client.put( + '/123/testqueue?MessageBody=test-message&Action=SendMessage' + '&Attribute.1.Name=WaitTimeSeconds&Attribute.1.Value=10' + ) + messages_count -= 1 + time.sleep(.5) + + def get_messages(): + msg_res = test_client.get( + '/123/testqueue?Action=ReceiveMessage&MaxNumberOfMessages=1&WaitTimeSeconds=5' + ) + [messages.append(m) for m in re.findall("(.*?)", msg_res.data.decode('utf-8'))] + + get_messages_thread = threading.Thread(target=get_messages) + insert_messages_thread = threading.Thread(target=insert_messages) + + get_messages_thread.start() + insert_messages_thread.start() + + get_messages_thread.join() + insert_messages_thread.join() + + assert len(messages) == 5