Refactor sqs.get_queue_attributes & add AttributeNames handling

This commit is contained in:
gruebel 2019-10-27 12:13:33 +01:00
parent 6573f69087
commit c3cb411c07
6 changed files with 149 additions and 16 deletions

View File

@ -6065,7 +6065,7 @@
- [X] untag_resource - [X] untag_resource
## sqs ## sqs
80% implemented 85% implemented
- [X] add_permission - [X] add_permission
- [X] change_message_visibility - [X] change_message_visibility
- [ ] change_message_visibility_batch - [ ] change_message_visibility_batch
@ -6073,7 +6073,7 @@
- [X] delete_message - [X] delete_message
- [ ] delete_message_batch - [ ] delete_message_batch
- [X] delete_queue - [X] delete_queue
- [ ] get_queue_attributes - [X] get_queue_attributes
- [X] get_queue_url - [X] get_queue_url
- [X] list_dead_letter_source_queues - [X] list_dead_letter_source_queues
- [X] list_queue_tags - [X] list_queue_tags

View File

@ -454,7 +454,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
index = 1 index = 1
while True: while True:
value_dict = self._get_multi_param_helper(prefix + str(index)) value_dict = self._get_multi_param_helper(prefix + str(index))
if not value_dict: if not value_dict and value_dict != '':
break break
values.append(value_dict) values.append(value_dict)

View File

@ -86,3 +86,13 @@ class TooManyEntriesInBatchRequest(RESTError):
'Maximum number of entries per request are 10. ' 'Maximum number of entries per request are 10. '
'You have sent {}.'.format(number) 'You have sent {}.'.format(number)
) )
class InvalidAttributeName(RESTError):
code = 400
def __init__(self, attribute_name):
super(InvalidAttributeName, self).__init__(
'InvalidAttributeName',
'Unknown Attribute {}.'.format(attribute_name)
)

View File

@ -23,7 +23,8 @@ from .exceptions import (
InvalidBatchEntryId, InvalidBatchEntryId,
BatchRequestTooLong, BatchRequestTooLong,
BatchEntryIdsNotDistinct, BatchEntryIdsNotDistinct,
TooManyEntriesInBatchRequest TooManyEntriesInBatchRequest,
InvalidAttributeName
) )
DEFAULT_ACCOUNT_ID = 123456789012 DEFAULT_ACCOUNT_ID = 123456789012
@ -161,7 +162,7 @@ class Message(BaseModel):
class Queue(BaseModel): class Queue(BaseModel):
base_attributes = ['ApproximateNumberOfMessages', BASE_ATTRIBUTES = ['ApproximateNumberOfMessages',
'ApproximateNumberOfMessagesDelayed', 'ApproximateNumberOfMessagesDelayed',
'ApproximateNumberOfMessagesNotVisible', 'ApproximateNumberOfMessagesNotVisible',
'CreatedTimestamp', 'CreatedTimestamp',
@ -172,9 +173,9 @@ class Queue(BaseModel):
'QueueArn', 'QueueArn',
'ReceiveMessageWaitTimeSeconds', 'ReceiveMessageWaitTimeSeconds',
'VisibilityTimeout'] 'VisibilityTimeout']
fifo_attributes = ['FifoQueue', FIFO_ATTRIBUTES = ['FifoQueue',
'ContentBasedDeduplication'] 'ContentBasedDeduplication']
kms_attributes = ['KmsDataKeyReusePeriodSeconds', KMS_ATTRIBUTES = ['KmsDataKeyReusePeriodSeconds',
'KmsMasterKeyId'] 'KmsMasterKeyId']
ALLOWED_PERMISSIONS = ('*', 'ChangeMessageVisibility', 'DeleteMessage', ALLOWED_PERMISSIONS = ('*', 'ChangeMessageVisibility', 'DeleteMessage',
'GetQueueAttributes', 'GetQueueUrl', 'GetQueueAttributes', 'GetQueueUrl',
@ -191,8 +192,9 @@ class Queue(BaseModel):
now = unix_time() now = unix_time()
self.created_timestamp = now self.created_timestamp = now
self.queue_arn = 'arn:aws:sqs:{0}:123456789012:{1}'.format(self.region, self.queue_arn = 'arn:aws:sqs:{0}:{1}:{2}'.format(self.region,
self.name) DEFAULT_ACCOUNT_ID,
self.name)
self.dead_letter_queue = None self.dead_letter_queue = None
self.lambda_event_source_mappings = {} self.lambda_event_source_mappings = {}
@ -336,17 +338,17 @@ class Queue(BaseModel):
def attributes(self): def attributes(self):
result = {} result = {}
for attribute in self.base_attributes: for attribute in self.BASE_ATTRIBUTES:
attr = getattr(self, camelcase_to_underscores(attribute)) attr = getattr(self, camelcase_to_underscores(attribute))
result[attribute] = attr result[attribute] = attr
if self.fifo_queue: if self.fifo_queue:
for attribute in self.fifo_attributes: for attribute in self.FIFO_ATTRIBUTES:
attr = getattr(self, camelcase_to_underscores(attribute)) attr = getattr(self, camelcase_to_underscores(attribute))
result[attribute] = attr result[attribute] = attr
if self.kms_master_key_id: if self.kms_master_key_id:
for attribute in self.kms_attributes: for attribute in self.KMS_ATTRIBUTES:
attr = getattr(self, camelcase_to_underscores(attribute)) attr = getattr(self, camelcase_to_underscores(attribute))
result[attribute] = attr result[attribute] = attr
@ -491,6 +493,28 @@ class SQSBackend(BaseBackend):
return self.queues.pop(queue_name) return self.queues.pop(queue_name)
return False return False
def get_queue_attributes(self, queue_name, attribute_names):
queue = self.get_queue(queue_name)
if not len(attribute_names):
attribute_names.append('All')
valid_names = ['All'] + queue.BASE_ATTRIBUTES + queue.FIFO_ATTRIBUTES + queue.KMS_ATTRIBUTES
invalid_name = next((name for name in attribute_names if name not in valid_names), None)
if invalid_name or invalid_name == '':
raise InvalidAttributeName(invalid_name)
attributes = {}
if 'All' in attribute_names:
attributes = queue.attributes
else:
for name in (name for name in attribute_names if name in queue.attributes):
attributes[name] = queue.attributes.get(name)
return attributes
def set_queue_attributes(self, queue_name, attributes): def set_queue_attributes(self, queue_name, attributes):
queue = self.get_queue(queue_name) queue = self.get_queue(queue_name)
queue._set_attributes(attributes) queue._set_attributes(attributes)

View File

@ -11,7 +11,8 @@ from .exceptions import (
MessageAttributesInvalid, MessageAttributesInvalid,
MessageNotInflight, MessageNotInflight,
ReceiptHandleIsInvalid, ReceiptHandleIsInvalid,
EmptyBatchRequest EmptyBatchRequest,
InvalidAttributeName
) )
MAXIMUM_VISIBILTY_TIMEOUT = 43200 MAXIMUM_VISIBILTY_TIMEOUT = 43200
@ -169,10 +170,15 @@ class SQSResponse(BaseResponse):
def get_queue_attributes(self): def get_queue_attributes(self):
queue_name = self._get_queue_name() queue_name = self._get_queue_name()
queue = self.sqs_backend.get_queue(queue_name) if self.querystring.get('AttributeNames'):
raise InvalidAttributeName('')
attribute_names = self._get_multi_param('AttributeName')
attributes = self.sqs_backend.get_queue_attributes(queue_name, attribute_names)
template = self.response_template(GET_QUEUE_ATTRIBUTES_RESPONSE) template = self.response_template(GET_QUEUE_ATTRIBUTES_RESPONSE)
return template.render(queue=queue) return template.render(attributes=attributes)
def set_queue_attributes(self): def set_queue_attributes(self):
# TODO validate self.get_param('QueueUrl') # TODO validate self.get_param('QueueUrl')
@ -443,7 +449,7 @@ DELETE_QUEUE_RESPONSE = """<DeleteQueueResponse>
GET_QUEUE_ATTRIBUTES_RESPONSE = """<GetQueueAttributesResponse> GET_QUEUE_ATTRIBUTES_RESPONSE = """<GetQueueAttributesResponse>
<GetQueueAttributesResult> <GetQueueAttributesResult>
{% for key, value in queue.attributes.items() %} {% for key, value in attributes.items() %}
<Attribute> <Attribute>
<Name>{{ key }}</Name> <Name>{{ key }}</Name>
<Value>{{ value }}</Value> <Value>{{ value }}</Value>

View File

@ -5,6 +5,7 @@ import os
import boto import boto
import boto3 import boto3
import botocore.exceptions import botocore.exceptions
import six
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
from boto.exception import SQSError from boto.exception import SQSError
from boto.sqs.message import RawMessage, Message from boto.sqs.message import RawMessage, Message
@ -365,6 +366,98 @@ def test_delete_queue():
queue.delete() queue.delete()
@mock_sqs
def test_get_queue_attributes():
client = boto3.client('sqs', region_name='us-east-1')
response = client.create_queue(QueueName='test-queue')
queue_url = response['QueueUrl']
response = client.get_queue_attributes(QueueUrl=queue_url)
response['Attributes']['ApproximateNumberOfMessages'].should.equal('0')
response['Attributes']['ApproximateNumberOfMessagesDelayed'].should.equal('0')
response['Attributes']['ApproximateNumberOfMessagesNotVisible'].should.equal('0')
response['Attributes']['CreatedTimestamp'].should.be.a(six.string_types)
response['Attributes']['DelaySeconds'].should.equal('0')
response['Attributes']['LastModifiedTimestamp'].should.be.a(six.string_types)
response['Attributes']['MaximumMessageSize'].should.equal('65536')
response['Attributes']['MessageRetentionPeriod'].should.equal('345600')
response['Attributes']['QueueArn'].should.equal('arn:aws:sqs:us-east-1:123456789012:test-queue')
response['Attributes']['ReceiveMessageWaitTimeSeconds'].should.equal('0')
response['Attributes']['VisibilityTimeout'].should.equal('30')
response = client.get_queue_attributes(
QueueUrl=queue_url,
AttributeNames=[
'ApproximateNumberOfMessages',
'MaximumMessageSize',
'QueueArn',
'VisibilityTimeout'
]
)
response['Attributes'].should.equal({
'ApproximateNumberOfMessages': '0',
'MaximumMessageSize': '65536',
'QueueArn': 'arn:aws:sqs:us-east-1:123456789012:test-queue',
'VisibilityTimeout': '30'
})
# should not return any attributes, if it was not set before
response = client.get_queue_attributes(
QueueUrl=queue_url,
AttributeNames=[
'KmsMasterKeyId'
]
)
response.should_not.have.key('Attributes')
@mock_sqs
def test_get_queue_attributes_errors():
client = boto3.client('sqs', region_name='us-east-1')
response = client.create_queue(QueueName='test-queue')
queue_url = response['QueueUrl']
client.get_queue_attributes.when.called_with(
QueueUrl=queue_url + '-non-existing'
).should.throw(
ClientError,
'The specified queue does not exist for this wsdl version.'
)
client.get_queue_attributes.when.called_with(
QueueUrl=queue_url,
AttributeNames=[
'QueueArn',
'not-existing',
'VisibilityTimeout'
]
).should.throw(
ClientError,
'Unknown Attribute not-existing.'
)
client.get_queue_attributes.when.called_with(
QueueUrl=queue_url,
AttributeNames=[
''
]
).should.throw(
ClientError,
'Unknown Attribute .'
)
client.get_queue_attributes.when.called_with(
QueueUrl = queue_url,
AttributeNames = []
).should.throw(
ClientError,
'Unknown Attribute .'
)
@mock_sqs @mock_sqs
def test_set_queue_attribute(): def test_set_queue_attribute():
sqs = boto3.resource('sqs', region_name='us-east-1') sqs = boto3.resource('sqs', region_name='us-east-1')