Merge pull request #398 from jotes/sqs_polling_messages
Added support for WaitTimeSeconds in SQS #392
This commit is contained in:
commit
ebfe7bb7b8
@ -103,11 +103,15 @@ class Queue(object):
|
|||||||
'MessageRetentionPeriod',
|
'MessageRetentionPeriod',
|
||||||
'QueueArn',
|
'QueueArn',
|
||||||
'ReceiveMessageWaitTimeSeconds',
|
'ReceiveMessageWaitTimeSeconds',
|
||||||
'VisibilityTimeout']
|
'VisibilityTimeout',
|
||||||
|
'WaitTimeSeconds']
|
||||||
|
|
||||||
def __init__(self, name, visibility_timeout):
|
def __init__(self, name, visibility_timeout, wait_time_seconds):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.visibility_timeout = visibility_timeout or 30
|
self.visibility_timeout = visibility_timeout or 30
|
||||||
|
|
||||||
|
# wait_time_seconds will be set to immediate return messages
|
||||||
|
self.wait_time_seconds = wait_time_seconds or 0
|
||||||
self._messages = []
|
self._messages = []
|
||||||
|
|
||||||
now = time.time()
|
now = time.time()
|
||||||
@ -128,6 +132,7 @@ class Queue(object):
|
|||||||
return sqs_backend.create_queue(
|
return sqs_backend.create_queue(
|
||||||
name=properties['QueueName'],
|
name=properties['QueueName'],
|
||||||
visibility_timeout=properties.get('VisibilityTimeout'),
|
visibility_timeout=properties.get('VisibilityTimeout'),
|
||||||
|
wait_time_seconds=properties.get('WaitTimeSeconds')
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -139,6 +144,9 @@ class Queue(object):
|
|||||||
queue = sqs_backend.get_queue(queue_name)
|
queue = sqs_backend.get_queue(queue_name)
|
||||||
if 'VisibilityTimeout' in properties:
|
if 'VisibilityTimeout' in properties:
|
||||||
queue.visibility_timeout = int(properties['VisibilityTimeout'])
|
queue.visibility_timeout = int(properties['VisibilityTimeout'])
|
||||||
|
|
||||||
|
if 'WaitTimeSeconds' in properties:
|
||||||
|
queue.wait_time_seconds = int(properties['WaitTimeSeconds'])
|
||||||
return queue
|
return queue
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -192,10 +200,10 @@ class SQSBackend(BaseBackend):
|
|||||||
self.queues = {}
|
self.queues = {}
|
||||||
super(SQSBackend, self).__init__()
|
super(SQSBackend, self).__init__()
|
||||||
|
|
||||||
def create_queue(self, name, visibility_timeout):
|
def create_queue(self, name, visibility_timeout, wait_time_seconds):
|
||||||
queue = self.queues.get(name)
|
queue = self.queues.get(name)
|
||||||
if queue is None:
|
if queue is None:
|
||||||
queue = Queue(name, visibility_timeout)
|
queue = Queue(name, visibility_timeout, wait_time_seconds)
|
||||||
self.queues[name] = queue
|
self.queues[name] = queue
|
||||||
return queue
|
return queue
|
||||||
|
|
||||||
@ -246,7 +254,7 @@ class SQSBackend(BaseBackend):
|
|||||||
|
|
||||||
return message
|
return message
|
||||||
|
|
||||||
def receive_messages(self, queue_name, count):
|
def receive_messages(self, queue_name, count, wait_seconds_timeout):
|
||||||
"""
|
"""
|
||||||
Attempt to retrieve visible messages from a queue.
|
Attempt to retrieve visible messages from a queue.
|
||||||
|
|
||||||
@ -260,13 +268,20 @@ class SQSBackend(BaseBackend):
|
|||||||
"""
|
"""
|
||||||
queue = self.get_queue(queue_name)
|
queue = self.get_queue(queue_name)
|
||||||
result = []
|
result = []
|
||||||
|
|
||||||
|
polling_end = time.time() + wait_seconds_timeout
|
||||||
|
|
||||||
# queue.messages only contains visible messages
|
# queue.messages only contains visible messages
|
||||||
for message in queue.messages:
|
while True:
|
||||||
message.mark_received(
|
for message in queue.messages:
|
||||||
visibility_timeout=queue.visibility_timeout
|
message.mark_received(
|
||||||
)
|
visibility_timeout=queue.visibility_timeout
|
||||||
result.append(message)
|
)
|
||||||
if len(result) >= count:
|
result.append(message)
|
||||||
|
if len(result) >= count:
|
||||||
|
break
|
||||||
|
|
||||||
|
if time.time() > polling_end:
|
||||||
break
|
break
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
@ -23,6 +23,12 @@ class SQSResponse(BaseResponse):
|
|||||||
def sqs_backend(self):
|
def sqs_backend(self):
|
||||||
return sqs_backends[self.region]
|
return sqs_backends[self.region]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def attribute(self):
|
||||||
|
if not hasattr(self, '_attribute'):
|
||||||
|
self._attribute = dict([(a['name'], a['value']) for a in self._get_list_prefix('Attribute')])
|
||||||
|
return self._attribute
|
||||||
|
|
||||||
def _get_queue_name(self):
|
def _get_queue_name(self):
|
||||||
try:
|
try:
|
||||||
queue_name = self.querystring.get('QueueUrl')[0].split("/")[-1]
|
queue_name = self.querystring.get('QueueUrl')[0].split("/")[-1]
|
||||||
@ -32,12 +38,9 @@ class SQSResponse(BaseResponse):
|
|||||||
return queue_name
|
return queue_name
|
||||||
|
|
||||||
def create_queue(self):
|
def create_queue(self):
|
||||||
visibility_timeout = None
|
|
||||||
if 'Attribute.1.Name' in self.querystring and self.querystring.get('Attribute.1.Name')[0] == 'VisibilityTimeout':
|
|
||||||
visibility_timeout = self.querystring.get("Attribute.1.Value")[0]
|
|
||||||
|
|
||||||
queue_name = self.querystring.get("QueueName")[0]
|
queue_name = self.querystring.get("QueueName")[0]
|
||||||
queue = self.sqs_backend.create_queue(queue_name, visibility_timeout=visibility_timeout)
|
queue = self.sqs_backend.create_queue(queue_name, visibility_timeout=self.attribute.get('VisibilityTimeout'),
|
||||||
|
wait_time_seconds=self.attribute.get('WaitTimeSeconds'))
|
||||||
template = self.response_template(CREATE_QUEUE_RESPONSE)
|
template = self.response_template(CREATE_QUEUE_RESPONSE)
|
||||||
return template.render(queue=queue)
|
return template.render(queue=queue)
|
||||||
|
|
||||||
@ -209,11 +212,19 @@ class SQSResponse(BaseResponse):
|
|||||||
|
|
||||||
def receive_message(self):
|
def receive_message(self):
|
||||||
queue_name = self._get_queue_name()
|
queue_name = self._get_queue_name()
|
||||||
|
queue = self.sqs_backend.get_queue(queue_name)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
message_count = int(self.querystring.get("MaxNumberOfMessages")[0])
|
message_count = int(self.querystring.get("MaxNumberOfMessages")[0])
|
||||||
except TypeError:
|
except TypeError:
|
||||||
message_count = DEFAULT_RECEIVED_MESSAGES
|
message_count = DEFAULT_RECEIVED_MESSAGES
|
||||||
messages = self.sqs_backend.receive_messages(queue_name, message_count)
|
|
||||||
|
try:
|
||||||
|
wait_time = int(self.querystring.get("WaitTimeSeconds")[0])
|
||||||
|
except TypeError:
|
||||||
|
wait_time = queue.wait_time_seconds
|
||||||
|
|
||||||
|
messages = self.sqs_backend.receive_messages(queue_name, message_count, wait_time)
|
||||||
template = self.response_template(RECEIVE_MESSAGE_RESPONSE)
|
template = self.response_template(RECEIVE_MESSAGE_RESPONSE)
|
||||||
output = template.render(messages=messages)
|
output = template.render(messages=messages)
|
||||||
return output
|
return output
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import re
|
import re
|
||||||
import sure # noqa
|
import sure # noqa
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
import moto.server as server
|
import moto.server as server
|
||||||
|
|
||||||
@ -30,3 +33,38 @@ def test_sqs_list_identities():
|
|||||||
|
|
||||||
message = re.search("<Body>(.*?)</Body>", res.data.decode('utf-8')).groups()[0]
|
message = re.search("<Body>(.*?)</Body>", res.data.decode('utf-8')).groups()[0]
|
||||||
message.should.equal('test-message')
|
message.should.equal('test-message')
|
||||||
|
|
||||||
|
|
||||||
|
def test_messages_polling():
|
||||||
|
backend = server.create_backend_app("sqs")
|
||||||
|
test_client = backend.test_client()
|
||||||
|
messages = []
|
||||||
|
|
||||||
|
test_client.put('/?Action=CreateQueue&QueueName=testqueue')
|
||||||
|
|
||||||
|
def insert_messages():
|
||||||
|
messages_count = 5
|
||||||
|
while messages_count > 0:
|
||||||
|
test_client.put(
|
||||||
|
'/123/testqueue?MessageBody=test-message&Action=SendMessage'
|
||||||
|
'&Attribute.1.Name=WaitTimeSeconds&Attribute.1.Value=10'
|
||||||
|
)
|
||||||
|
messages_count -= 1
|
||||||
|
time.sleep(.5)
|
||||||
|
|
||||||
|
def get_messages():
|
||||||
|
msg_res = test_client.get(
|
||||||
|
'/123/testqueue?Action=ReceiveMessage&MaxNumberOfMessages=1&WaitTimeSeconds=5'
|
||||||
|
)
|
||||||
|
[messages.append(m) for m in re.findall("<Body>(.*?)</Body>", msg_res.data.decode('utf-8'))]
|
||||||
|
|
||||||
|
get_messages_thread = threading.Thread(target=get_messages)
|
||||||
|
insert_messages_thread = threading.Thread(target=insert_messages)
|
||||||
|
|
||||||
|
get_messages_thread.start()
|
||||||
|
insert_messages_thread.start()
|
||||||
|
|
||||||
|
get_messages_thread.join()
|
||||||
|
insert_messages_thread.join()
|
||||||
|
|
||||||
|
assert len(messages) == 5
|
||||||
|
Loading…
x
Reference in New Issue
Block a user