Merge pull request #398 from jotes/sqs_polling_messages
Added support for WaitTimeSeconds in SQS #392
This commit is contained in:
commit
ebfe7bb7b8
@ -103,11 +103,15 @@ class Queue(object):
|
||||
'MessageRetentionPeriod',
|
||||
'QueueArn',
|
||||
'ReceiveMessageWaitTimeSeconds',
|
||||
'VisibilityTimeout']
|
||||
'VisibilityTimeout',
|
||||
'WaitTimeSeconds']
|
||||
|
||||
def __init__(self, name, visibility_timeout):
|
||||
def __init__(self, name, visibility_timeout, wait_time_seconds):
|
||||
self.name = name
|
||||
self.visibility_timeout = visibility_timeout or 30
|
||||
|
||||
# wait_time_seconds will be set to immediate return messages
|
||||
self.wait_time_seconds = wait_time_seconds or 0
|
||||
self._messages = []
|
||||
|
||||
now = time.time()
|
||||
@ -128,6 +132,7 @@ class Queue(object):
|
||||
return sqs_backend.create_queue(
|
||||
name=properties['QueueName'],
|
||||
visibility_timeout=properties.get('VisibilityTimeout'),
|
||||
wait_time_seconds=properties.get('WaitTimeSeconds')
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -139,6 +144,9 @@ class Queue(object):
|
||||
queue = sqs_backend.get_queue(queue_name)
|
||||
if 'VisibilityTimeout' in properties:
|
||||
queue.visibility_timeout = int(properties['VisibilityTimeout'])
|
||||
|
||||
if 'WaitTimeSeconds' in properties:
|
||||
queue.wait_time_seconds = int(properties['WaitTimeSeconds'])
|
||||
return queue
|
||||
|
||||
@classmethod
|
||||
@ -192,10 +200,10 @@ class SQSBackend(BaseBackend):
|
||||
self.queues = {}
|
||||
super(SQSBackend, self).__init__()
|
||||
|
||||
def create_queue(self, name, visibility_timeout):
|
||||
def create_queue(self, name, visibility_timeout, wait_time_seconds):
|
||||
queue = self.queues.get(name)
|
||||
if queue is None:
|
||||
queue = Queue(name, visibility_timeout)
|
||||
queue = Queue(name, visibility_timeout, wait_time_seconds)
|
||||
self.queues[name] = queue
|
||||
return queue
|
||||
|
||||
@ -246,7 +254,7 @@ class SQSBackend(BaseBackend):
|
||||
|
||||
return message
|
||||
|
||||
def receive_messages(self, queue_name, count):
|
||||
def receive_messages(self, queue_name, count, wait_seconds_timeout):
|
||||
"""
|
||||
Attempt to retrieve visible messages from a queue.
|
||||
|
||||
@ -260,13 +268,20 @@ class SQSBackend(BaseBackend):
|
||||
"""
|
||||
queue = self.get_queue(queue_name)
|
||||
result = []
|
||||
|
||||
polling_end = time.time() + wait_seconds_timeout
|
||||
|
||||
# queue.messages only contains visible messages
|
||||
for message in queue.messages:
|
||||
message.mark_received(
|
||||
visibility_timeout=queue.visibility_timeout
|
||||
)
|
||||
result.append(message)
|
||||
if len(result) >= count:
|
||||
while True:
|
||||
for message in queue.messages:
|
||||
message.mark_received(
|
||||
visibility_timeout=queue.visibility_timeout
|
||||
)
|
||||
result.append(message)
|
||||
if len(result) >= count:
|
||||
break
|
||||
|
||||
if time.time() > polling_end:
|
||||
break
|
||||
|
||||
return result
|
||||
|
@ -23,6 +23,12 @@ class SQSResponse(BaseResponse):
|
||||
def sqs_backend(self):
|
||||
return sqs_backends[self.region]
|
||||
|
||||
@property
|
||||
def attribute(self):
|
||||
if not hasattr(self, '_attribute'):
|
||||
self._attribute = dict([(a['name'], a['value']) for a in self._get_list_prefix('Attribute')])
|
||||
return self._attribute
|
||||
|
||||
def _get_queue_name(self):
|
||||
try:
|
||||
queue_name = self.querystring.get('QueueUrl')[0].split("/")[-1]
|
||||
@ -32,12 +38,9 @@ class SQSResponse(BaseResponse):
|
||||
return queue_name
|
||||
|
||||
def create_queue(self):
|
||||
visibility_timeout = None
|
||||
if 'Attribute.1.Name' in self.querystring and self.querystring.get('Attribute.1.Name')[0] == 'VisibilityTimeout':
|
||||
visibility_timeout = self.querystring.get("Attribute.1.Value")[0]
|
||||
|
||||
queue_name = self.querystring.get("QueueName")[0]
|
||||
queue = self.sqs_backend.create_queue(queue_name, visibility_timeout=visibility_timeout)
|
||||
queue = self.sqs_backend.create_queue(queue_name, visibility_timeout=self.attribute.get('VisibilityTimeout'),
|
||||
wait_time_seconds=self.attribute.get('WaitTimeSeconds'))
|
||||
template = self.response_template(CREATE_QUEUE_RESPONSE)
|
||||
return template.render(queue=queue)
|
||||
|
||||
@ -209,11 +212,19 @@ class SQSResponse(BaseResponse):
|
||||
|
||||
def receive_message(self):
|
||||
queue_name = self._get_queue_name()
|
||||
queue = self.sqs_backend.get_queue(queue_name)
|
||||
|
||||
try:
|
||||
message_count = int(self.querystring.get("MaxNumberOfMessages")[0])
|
||||
except TypeError:
|
||||
message_count = DEFAULT_RECEIVED_MESSAGES
|
||||
messages = self.sqs_backend.receive_messages(queue_name, message_count)
|
||||
|
||||
try:
|
||||
wait_time = int(self.querystring.get("WaitTimeSeconds")[0])
|
||||
except TypeError:
|
||||
wait_time = queue.wait_time_seconds
|
||||
|
||||
messages = self.sqs_backend.receive_messages(queue_name, message_count, wait_time)
|
||||
template = self.response_template(RECEIVE_MESSAGE_RESPONSE)
|
||||
output = template.render(messages=messages)
|
||||
return output
|
||||
|
@ -1,6 +1,9 @@
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import re
|
||||
import sure # noqa
|
||||
import threading
|
||||
import time
|
||||
|
||||
import moto.server as server
|
||||
|
||||
@ -30,3 +33,38 @@ def test_sqs_list_identities():
|
||||
|
||||
message = re.search("<Body>(.*?)</Body>", res.data.decode('utf-8')).groups()[0]
|
||||
message.should.equal('test-message')
|
||||
|
||||
|
||||
def test_messages_polling():
|
||||
backend = server.create_backend_app("sqs")
|
||||
test_client = backend.test_client()
|
||||
messages = []
|
||||
|
||||
test_client.put('/?Action=CreateQueue&QueueName=testqueue')
|
||||
|
||||
def insert_messages():
|
||||
messages_count = 5
|
||||
while messages_count > 0:
|
||||
test_client.put(
|
||||
'/123/testqueue?MessageBody=test-message&Action=SendMessage'
|
||||
'&Attribute.1.Name=WaitTimeSeconds&Attribute.1.Value=10'
|
||||
)
|
||||
messages_count -= 1
|
||||
time.sleep(.5)
|
||||
|
||||
def get_messages():
|
||||
msg_res = test_client.get(
|
||||
'/123/testqueue?Action=ReceiveMessage&MaxNumberOfMessages=1&WaitTimeSeconds=5'
|
||||
)
|
||||
[messages.append(m) for m in re.findall("<Body>(.*?)</Body>", msg_res.data.decode('utf-8'))]
|
||||
|
||||
get_messages_thread = threading.Thread(target=get_messages)
|
||||
insert_messages_thread = threading.Thread(target=insert_messages)
|
||||
|
||||
get_messages_thread.start()
|
||||
insert_messages_thread.start()
|
||||
|
||||
get_messages_thread.join()
|
||||
insert_messages_thread.join()
|
||||
|
||||
assert len(messages) == 5
|
||||
|
Loading…
Reference in New Issue
Block a user