Merge pull request #1180 from terrycain/sqs_improvement

SQS Cleanup, and fix #1105
This commit is contained in:
Jack Danger 2017-09-22 13:23:20 -07:00 committed by GitHub
commit d692219927
4 changed files with 107 additions and 48 deletions

View File

@ -310,7 +310,7 @@ class BaseResponse(_TemplateEnvironmentMixin):
param_index += 1
return results
def _get_map_prefix(self, param_prefix):
def _get_map_prefix(self, param_prefix, key_end='.key', value_end='.value'):
results = {}
param_index = 1
while 1:
@ -319,9 +319,9 @@ class BaseResponse(_TemplateEnvironmentMixin):
k, v = None, None
for key, value in self.querystring.items():
if key.startswith(index_prefix):
if key.endswith('.key'):
if key.endswith(key_end):
k = value[0]
elif key.endswith('.value'):
elif key.endswith(value_end):
v = value[0]
if not (k and v):

View File

@ -12,10 +12,7 @@ import boto.sqs
from moto.core import BaseBackend, BaseModel
from moto.core.utils import camelcase_to_underscores, get_random_message_id, unix_time, unix_time_millis
from .utils import generate_receipt_handle
from .exceptions import (
ReceiptHandleIsInvalid,
MessageNotInflight
)
from .exceptions import ReceiptHandleIsInvalid, MessageNotInflight, MessageAttributesInvalid
DEFAULT_ACCOUNT_ID = 123456789012
DEFAULT_SENDER_ID = "AIDAIT2UOQQY3AUEKVGXU"
@ -151,8 +148,12 @@ class Queue(BaseModel):
camelcase_attributes = ['ApproximateNumberOfMessages',
'ApproximateNumberOfMessagesDelayed',
'ApproximateNumberOfMessagesNotVisible',
'ContentBasedDeduplication',
'CreatedTimestamp',
'DelaySeconds',
'FifoQueue',
'KmsDataKeyReusePeriodSeconds',
'KmsMasterKeyId',
'LastModifiedTimestamp',
'MaximumMessageSize',
'MessageRetentionPeriod',
@ -161,25 +162,35 @@ class Queue(BaseModel):
'VisibilityTimeout',
'WaitTimeSeconds']
def __init__(self, name, visibility_timeout, wait_time_seconds, region):
def __init__(self, name, region, **kwargs):
self.name = name
self.visibility_timeout = visibility_timeout or 30
self.visibility_timeout = int(kwargs.get('VisibilityTimeout', 30))
self.region = region
# wait_time_seconds will be set to immediate return messages
self.wait_time_seconds = int(wait_time_seconds) if wait_time_seconds else 0
self._messages = []
now = unix_time()
# kwargs can also have:
# [Policy, RedrivePolicy]
self.fifo_queue = kwargs.get('FifoQueue', 'false') == 'true'
self.content_based_deduplication = kwargs.get('ContentBasedDeduplication', 'false') == 'true'
self.kms_master_key_id = kwargs.get('KmsMasterKeyId', 'alias/aws/sqs')
self.kms_data_key_reuse_period_seconds = int(kwargs.get('KmsDataKeyReusePeriodSeconds', 300))
self.created_timestamp = now
self.delay_seconds = 0
self.delay_seconds = int(kwargs.get('DelaySeconds', 0))
self.last_modified_timestamp = now
self.maximum_message_size = 64 << 10
self.message_retention_period = 86400 * 4 # four days
self.queue_arn = 'arn:aws:sqs:{0}:123456789012:{1}'.format(
self.region, self.name)
self.receive_message_wait_time_seconds = 0
self.maximum_message_size = int(kwargs.get('MaximumMessageSize', 64 << 10))
self.message_retention_period = int(kwargs.get('MessageRetentionPeriod', 86400 * 4)) # four days
self.queue_arn = 'arn:aws:sqs:{0}:123456789012:{1}'.format(self.region, self.name)
self.receive_message_wait_time_seconds = int(kwargs.get('ReceiveMessageWaitTimeSeconds', 0))
# wait_time_seconds will be set to immediate return messages
self.wait_time_seconds = int(kwargs.get('WaitTimeSeconds', 0))
# Check some conditions
if self.fifo_queue and not self.name.endswith('.fifo'):
raise MessageAttributesInvalid('Queue name must end in .fifo for FIFO queues')
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
@ -188,8 +199,8 @@ class Queue(BaseModel):
sqs_backend = sqs_backends[region_name]
return sqs_backend.create_queue(
name=properties['QueueName'],
visibility_timeout=properties.get('VisibilityTimeout'),
wait_time_seconds=properties.get('WaitTimeSeconds')
region=region_name,
**properties
)
@classmethod
@ -233,8 +244,10 @@ class Queue(BaseModel):
def attributes(self):
result = {}
for attribute in self.camelcase_attributes:
result[attribute] = getattr(
self, camelcase_to_underscores(attribute))
attr = getattr(self, camelcase_to_underscores(attribute))
if isinstance(attr, bool):
attr = str(attr).lower()
result[attribute] = attr
return result
def url(self, request_url):
@ -268,11 +281,14 @@ class SQSBackend(BaseBackend):
self.__dict__ = {}
self.__init__(region_name)
def create_queue(self, name, visibility_timeout, wait_time_seconds):
def create_queue(self, name, **kwargs):
queue = self.queues.get(name)
if queue is None:
queue = Queue(name, visibility_timeout,
wait_time_seconds, self.region_name)
try:
kwargs.pop('region')
except KeyError:
pass
queue = Queue(name, region=self.region_name, **kwargs)
self.queues[name] = queue
return queue

View File

@ -28,8 +28,7 @@ class SQSResponse(BaseResponse):
@property
def attribute(self):
if not hasattr(self, '_attribute'):
self._attribute = dict([(a['name'], a['value'])
for a in self._get_list_prefix('Attribute')])
self._attribute = self._get_map_prefix('Attribute', key_end='Name', value_end='Value')
return self._attribute
def _get_queue_name(self):
@ -58,17 +57,25 @@ class SQSResponse(BaseResponse):
return 404, headers, ERROR_INEXISTENT_QUEUE
return status_code, headers, body
def _error(self, code, message, status=400):
template = self.response_template(ERROR_TEMPLATE)
return template.render(code=code, message=message), dict(status=status)
def create_queue(self):
request_url = urlparse(self.uri)
queue_name = self.querystring.get("QueueName")[0]
queue = self.sqs_backend.create_queue(queue_name, visibility_timeout=self.attribute.get('VisibilityTimeout'),
wait_time_seconds=self.attribute.get('WaitTimeSeconds'))
queue_name = self._get_param("QueueName")
try:
queue = self.sqs_backend.create_queue(queue_name, **self.attribute)
except MessageAttributesInvalid as e:
return self._error('InvalidParameterValue', e.description)
template = self.response_template(CREATE_QUEUE_RESPONSE)
return template.render(queue=queue, request_url=request_url)
def get_queue_url(self):
request_url = urlparse(self.uri)
queue_name = self.querystring.get("QueueName")[0]
queue_name = self._get_param("QueueName")
queue = self.sqs_backend.get_queue(queue_name)
if queue:
template = self.response_template(GET_QUEUE_URL_RESPONSE)
@ -78,14 +85,14 @@ class SQSResponse(BaseResponse):
def list_queues(self):
request_url = urlparse(self.uri)
queue_name_prefix = self.querystring.get("QueueNamePrefix", [None])[0]
queue_name_prefix = self._get_param('QueueNamePrefix')
queues = self.sqs_backend.list_queues(queue_name_prefix)
template = self.response_template(LIST_QUEUES_RESPONSE)
return template.render(queues=queues, request_url=request_url)
def change_message_visibility(self):
queue_name = self._get_queue_name()
receipt_handle = self.querystring.get("ReceiptHandle")[0]
receipt_handle = self._get_param('ReceiptHandle')
try:
visibility_timeout = self._get_validated_visibility_timeout()
@ -111,19 +118,15 @@ class SQSResponse(BaseResponse):
return template.render(queue=queue)
def set_queue_attributes(self):
# TODO validate self.get_param('QueueUrl')
queue_name = self._get_queue_name()
if "Attribute.Name" in self.querystring:
key = camelcase_to_underscores(
self.querystring.get("Attribute.Name")[0])
value = self.querystring.get("Attribute.Value")[0]
self.sqs_backend.set_queue_attribute(queue_name, key, value)
for a in self._get_list_prefix("Attribute"):
key = camelcase_to_underscores(a["name"])
value = a["value"]
for key, value in self.attribute.items():
key = camelcase_to_underscores(key)
self.sqs_backend.set_queue_attribute(queue_name, key, value)
return SET_QUEUE_ATTRIBUTE_RESPONSE
def delete_queue(self):
# TODO validate self.get_param('QueueUrl')
queue_name = self._get_queue_name()
queue = self.sqs_backend.delete_queue(queue_name)
if not queue:
@ -133,17 +136,12 @@ class SQSResponse(BaseResponse):
return template.render(queue=queue)
def send_message(self):
message = self.querystring.get("MessageBody")[0]
delay_seconds = self.querystring.get('DelaySeconds')
message = self._get_param('MessageBody')
delay_seconds = int(self._get_param('DelaySeconds', 0))
if len(message) > MAXIMUM_MESSAGE_LENGTH:
return ERROR_TOO_LONG_RESPONSE, dict(status=400)
if delay_seconds:
delay_seconds = int(delay_seconds[0])
else:
delay_seconds = 0
try:
message_attributes = parse_message_attributes(self.querystring)
except MessageAttributesInvalid as e:
@ -470,3 +468,13 @@ ERROR_INEXISTENT_QUEUE = """<ErrorResponse xmlns="http://queue.amazonaws.com/doc
</Error>
<RequestId>b8bc806b-fa6b-53b5-8be8-cfa2f9836bc3</RequestId>
</ErrorResponse>"""
ERROR_TEMPLATE = """<ErrorResponse xmlns="http://queue.amazonaws.com/doc/2012-11-05/">
<Error>
<Type>Sender</Type>
<Code>{{ code }}</Code>
<Message>{{ message }}</Message>
<Detail/>
</Error>
<RequestId>6fde8d1e-52cd-4581-8cd9-c512f4c64223</RequestId>
</ErrorResponse>"""

View File

@ -8,7 +8,6 @@ from boto.exception import SQSError
from boto.sqs.message import RawMessage, Message
import base64
import requests
import sure # noqa
import time
@ -18,6 +17,39 @@ import tests.backport_assert_raises # noqa
from nose.tools import assert_raises
@mock_sqs
def test_create_fifo_queue_fail():
sqs = boto3.client('sqs', region_name='us-east-1')
try:
sqs.create_queue(
QueueName='test-queue',
Attributes={
'FifoQueue': 'true',
}
)
except botocore.exceptions.ClientError as err:
err.response['Error']['Code'].should.equal('InvalidParameterValue')
else:
raise RuntimeError('Should of raised InvalidParameterValue Exception')
@mock_sqs
def test_create_fifo_queue():
sqs = boto3.client('sqs', region_name='us-east-1')
resp = sqs.create_queue(
QueueName='test-queue.fifo',
Attributes={
'FifoQueue': 'true',
}
)
queue_url = resp['QueueUrl']
response = sqs.get_queue_attributes(QueueUrl=queue_url)
response['Attributes'].should.contain('FifoQueue')
response['Attributes']['FifoQueue'].should.equal('true')
@mock_sqs
def test_create_queue():
sqs = boto3.resource('sqs', region_name='us-east-1')
@ -39,6 +71,7 @@ def test_get_inexistent_queue():
sqs.get_queue_by_name.when.called_with(
QueueName='nonexisting-queue').should.throw(botocore.exceptions.ClientError)
@mock_sqs
def test_message_send_without_attributes():
sqs = boto3.resource('sqs', region_name='us-east-1')
@ -56,6 +89,7 @@ def test_message_send_without_attributes():
messages = queue.receive_messages()
messages.should.have.length_of(1)
@mock_sqs
def test_message_send_with_attributes():
sqs = boto3.resource('sqs', region_name='us-east-1')
@ -229,6 +263,7 @@ def test_send_receive_message_without_attributes():
message1.shouldnt.have.key('MD5OfMessageAttributes')
message2.shouldnt.have.key('MD5OfMessageAttributes')
@mock_sqs
def test_send_receive_message_with_attributes():
sqs = boto3.resource('sqs', region_name='us-east-1')