Merge pull request #398 from jotes/sqs_polling_messages

Added support for WaitTimeSeconds in SQS #392
This commit is contained in:
Steve Pulec 2015-08-13 17:24:27 -04:00
commit ebfe7bb7b8
3 changed files with 81 additions and 17 deletions

View File

@ -103,11 +103,15 @@ class Queue(object):
'MessageRetentionPeriod',
'QueueArn',
'ReceiveMessageWaitTimeSeconds',
'VisibilityTimeout']
'VisibilityTimeout',
'WaitTimeSeconds']
def __init__(self, name, visibility_timeout):
def __init__(self, name, visibility_timeout, wait_time_seconds):
self.name = name
self.visibility_timeout = visibility_timeout or 30
# wait_time_seconds will be set to immediate return messages
self.wait_time_seconds = wait_time_seconds or 0
self._messages = []
now = time.time()
@ -128,6 +132,7 @@ class Queue(object):
return sqs_backend.create_queue(
name=properties['QueueName'],
visibility_timeout=properties.get('VisibilityTimeout'),
wait_time_seconds=properties.get('WaitTimeSeconds')
)
@classmethod
@ -139,6 +144,9 @@ class Queue(object):
queue = sqs_backend.get_queue(queue_name)
if 'VisibilityTimeout' in properties:
queue.visibility_timeout = int(properties['VisibilityTimeout'])
if 'WaitTimeSeconds' in properties:
queue.wait_time_seconds = int(properties['WaitTimeSeconds'])
return queue
@classmethod
@ -192,10 +200,10 @@ class SQSBackend(BaseBackend):
self.queues = {}
super(SQSBackend, self).__init__()
def create_queue(self, name, visibility_timeout):
def create_queue(self, name, visibility_timeout, wait_time_seconds):
queue = self.queues.get(name)
if queue is None:
queue = Queue(name, visibility_timeout)
queue = Queue(name, visibility_timeout, wait_time_seconds)
self.queues[name] = queue
return queue
@ -246,7 +254,7 @@ class SQSBackend(BaseBackend):
return message
def receive_messages(self, queue_name, count):
def receive_messages(self, queue_name, count, wait_seconds_timeout):
"""
Attempt to retrieve visible messages from a queue.
@ -260,13 +268,20 @@ class SQSBackend(BaseBackend):
"""
queue = self.get_queue(queue_name)
result = []
polling_end = time.time() + wait_seconds_timeout
# queue.messages only contains visible messages
for message in queue.messages:
message.mark_received(
visibility_timeout=queue.visibility_timeout
)
result.append(message)
if len(result) >= count:
while True:
for message in queue.messages:
message.mark_received(
visibility_timeout=queue.visibility_timeout
)
result.append(message)
if len(result) >= count:
break
if time.time() > polling_end:
break
return result

View File

@ -23,6 +23,12 @@ class SQSResponse(BaseResponse):
def sqs_backend(self):
return sqs_backends[self.region]
@property
def attribute(self):
if not hasattr(self, '_attribute'):
self._attribute = dict([(a['name'], a['value']) for a in self._get_list_prefix('Attribute')])
return self._attribute
def _get_queue_name(self):
try:
queue_name = self.querystring.get('QueueUrl')[0].split("/")[-1]
@ -32,12 +38,9 @@ class SQSResponse(BaseResponse):
return queue_name
def create_queue(self):
visibility_timeout = None
if 'Attribute.1.Name' in self.querystring and self.querystring.get('Attribute.1.Name')[0] == 'VisibilityTimeout':
visibility_timeout = self.querystring.get("Attribute.1.Value")[0]
queue_name = self.querystring.get("QueueName")[0]
queue = self.sqs_backend.create_queue(queue_name, visibility_timeout=visibility_timeout)
queue = self.sqs_backend.create_queue(queue_name, visibility_timeout=self.attribute.get('VisibilityTimeout'),
wait_time_seconds=self.attribute.get('WaitTimeSeconds'))
template = self.response_template(CREATE_QUEUE_RESPONSE)
return template.render(queue=queue)
@ -209,11 +212,19 @@ class SQSResponse(BaseResponse):
def receive_message(self):
queue_name = self._get_queue_name()
queue = self.sqs_backend.get_queue(queue_name)
try:
message_count = int(self.querystring.get("MaxNumberOfMessages")[0])
except TypeError:
message_count = DEFAULT_RECEIVED_MESSAGES
messages = self.sqs_backend.receive_messages(queue_name, message_count)
try:
wait_time = int(self.querystring.get("WaitTimeSeconds")[0])
except TypeError:
wait_time = queue.wait_time_seconds
messages = self.sqs_backend.receive_messages(queue_name, message_count, wait_time)
template = self.response_template(RECEIVE_MESSAGE_RESPONSE)
output = template.render(messages=messages)
return output

View File

@ -1,6 +1,9 @@
from __future__ import unicode_literals
import re
import sure # noqa
import threading
import time
import moto.server as server
@ -30,3 +33,38 @@ def test_sqs_list_identities():
message = re.search("<Body>(.*?)</Body>", res.data.decode('utf-8')).groups()[0]
message.should.equal('test-message')
def test_messages_polling():
backend = server.create_backend_app("sqs")
test_client = backend.test_client()
messages = []
test_client.put('/?Action=CreateQueue&QueueName=testqueue')
def insert_messages():
messages_count = 5
while messages_count > 0:
test_client.put(
'/123/testqueue?MessageBody=test-message&Action=SendMessage'
'&Attribute.1.Name=WaitTimeSeconds&Attribute.1.Value=10'
)
messages_count -= 1
time.sleep(.5)
def get_messages():
msg_res = test_client.get(
'/123/testqueue?Action=ReceiveMessage&MaxNumberOfMessages=1&WaitTimeSeconds=5'
)
[messages.append(m) for m in re.findall("<Body>(.*?)</Body>", msg_res.data.decode('utf-8'))]
get_messages_thread = threading.Thread(target=get_messages)
insert_messages_thread = threading.Thread(target=insert_messages)
get_messages_thread.start()
insert_messages_thread.start()
get_messages_thread.join()
insert_messages_thread.join()
assert len(messages) == 5