Adds FilterExpression to dynamodb.query (#1326)

* Added FilterExpression for dynamodb.query

* flake8

* Fixes using mutable default argument values
This commit is contained in:
Terry Cain 2017-11-08 22:53:31 +00:00 committed by GitHub
parent 6e199d35b3
commit 884fc6f260
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 61 additions and 5 deletions

View File

@ -419,7 +419,7 @@ class Table(BaseModel):
def query(self, hash_key, range_comparison, range_objs, limit,
exclusive_start_key, scan_index_forward, projection_expression,
index_name=None, **filter_kwargs):
index_name=None, filter_expression=None, **filter_kwargs):
results = []
if index_name:
all_indexes = (self.global_indexes or []) + (self.indexes or [])
@ -502,6 +502,9 @@ class Table(BaseModel):
scanned_count = len(list(self.all_items()))
if filter_expression is not None:
results = [item for item in results if filter_expression.expr(item)]
results, last_evaluated_key = self._trim_results(results, limit,
exclusive_start_key)
return results, scanned_count, last_evaluated_key
@ -707,7 +710,9 @@ class DynamoDBBackend(BaseBackend):
return table.get_item(hash_key, range_key)
def query(self, table_name, hash_key_dict, range_comparison, range_value_dicts,
limit, exclusive_start_key, scan_index_forward, projection_expression, index_name=None, **filter_kwargs):
limit, exclusive_start_key, scan_index_forward, projection_expression, index_name=None,
expr_names=None, expr_values=None, filter_expression=None,
**filter_kwargs):
table = self.tables.get(table_name)
if not table:
return None, None
@ -716,8 +721,13 @@ class DynamoDBBackend(BaseBackend):
range_values = [DynamoType(range_value)
for range_value in range_value_dicts]
if filter_expression is not None:
filter_expression = get_filter_expression(filter_expression, expr_names, expr_values)
else:
filter_expression = Op(None, None) # Will always eval to true
return table.query(hash_key, range_comparison, range_values, limit,
exclusive_start_key, scan_index_forward, projection_expression, index_name, **filter_kwargs)
exclusive_start_key, scan_index_forward, projection_expression, index_name, filter_expression, **filter_kwargs)
def scan(self, table_name, filters, limit, exclusive_start_key, filter_expression, expr_names, expr_values):
table = self.tables.get(table_name)

View File

@ -298,7 +298,9 @@ class DynamoHandler(BaseResponse):
# {u'KeyConditionExpression': u'#n0 = :v0', u'ExpressionAttributeValues': {u':v0': {u'S': u'johndoe'}}, u'ExpressionAttributeNames': {u'#n0': u'username'}}
key_condition_expression = self.body.get('KeyConditionExpression')
projection_expression = self.body.get('ProjectionExpression')
expression_attribute_names = self.body.get('ExpressionAttributeNames')
expression_attribute_names = self.body.get('ExpressionAttributeNames', {})
filter_expression = self.body.get('FilterExpression')
expression_attribute_values = self.body.get('ExpressionAttributeValues', {})
if projection_expression and expression_attribute_names:
expressions = [x.strip() for x in projection_expression.split(',')]
@ -307,6 +309,7 @@ class DynamoHandler(BaseResponse):
projection_expression = projection_expression.replace(expression, expression_attribute_names[expression])
filter_kwargs = {}
if key_condition_expression:
value_alias_map = self.body['ExpressionAttributeValues']
@ -413,7 +416,9 @@ class DynamoHandler(BaseResponse):
scan_index_forward = self.body.get("ScanIndexForward")
items, scanned_count, last_evaluated_key = self.dynamodb_backend.query(
name, hash_key, range_comparison, range_values, limit,
exclusive_start_key, scan_index_forward, projection_expression, index_name=index_name, **filter_kwargs
exclusive_start_key, scan_index_forward, projection_expression, index_name=index_name,
expr_names=expression_attribute_names, expr_values=expression_attribute_values,
filter_expression=filter_expression, **filter_kwargs
)
if items is None:
er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException'

View File

@ -649,6 +649,47 @@ def test_filter_expression():
filter_expr.expr(row1).should.be(True)
@mock_dynamodb2
def test_query_filter():
client = boto3.client('dynamodb', region_name='us-east-1')
dynamodb = boto3.resource('dynamodb', region_name='us-east-1')
# Create the DynamoDB table.
client.create_table(
TableName='test1',
AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'S'}],
KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}],
ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123}
)
client.put_item(
TableName='test1',
Item={
'client': {'S': 'client1'},
'app': {'S': 'app1'}
}
)
client.put_item(
TableName='test1',
Item={
'client': {'S': 'client1'},
'app': {'S': 'app2'}
}
)
table = dynamodb.Table('test1')
response = table.query(
KeyConditionExpression=Key('client').eq('client1')
)
assert response['Count'] == 2
response = table.query(
KeyConditionExpression=Key('client').eq('client1'),
FilterExpression=Attr('app').eq('app2')
)
assert response['Count'] == 1
assert response['Items'][0]['app'] == 'app2'
@mock_dynamodb2
def test_scan_filter():
client = boto3.client('dynamodb', region_name='us-east-1')