diff --git a/moto/dynamodb2/models.py b/moto/dynamodb2/models.py index 855728ec1..e5bce2f67 100644 --- a/moto/dynamodb2/models.py +++ b/moto/dynamodb2/models.py @@ -419,7 +419,7 @@ class Table(BaseModel): def query(self, hash_key, range_comparison, range_objs, limit, exclusive_start_key, scan_index_forward, projection_expression, - index_name=None, **filter_kwargs): + index_name=None, filter_expression=None, **filter_kwargs): results = [] if index_name: all_indexes = (self.global_indexes or []) + (self.indexes or []) @@ -502,6 +502,9 @@ class Table(BaseModel): scanned_count = len(list(self.all_items())) + if filter_expression is not None: + results = [item for item in results if filter_expression.expr(item)] + results, last_evaluated_key = self._trim_results(results, limit, exclusive_start_key) return results, scanned_count, last_evaluated_key @@ -707,7 +710,9 @@ class DynamoDBBackend(BaseBackend): return table.get_item(hash_key, range_key) def query(self, table_name, hash_key_dict, range_comparison, range_value_dicts, - limit, exclusive_start_key, scan_index_forward, projection_expression, index_name=None, **filter_kwargs): + limit, exclusive_start_key, scan_index_forward, projection_expression, index_name=None, + expr_names=None, expr_values=None, filter_expression=None, + **filter_kwargs): table = self.tables.get(table_name) if not table: return None, None @@ -716,8 +721,13 @@ class DynamoDBBackend(BaseBackend): range_values = [DynamoType(range_value) for range_value in range_value_dicts] + if filter_expression is not None: + filter_expression = get_filter_expression(filter_expression, expr_names, expr_values) + else: + filter_expression = Op(None, None) # Will always eval to true + return table.query(hash_key, range_comparison, range_values, limit, - exclusive_start_key, scan_index_forward, projection_expression, index_name, **filter_kwargs) + exclusive_start_key, scan_index_forward, projection_expression, index_name, filter_expression, **filter_kwargs) def scan(self, table_name, filters, limit, exclusive_start_key, filter_expression, expr_names, expr_values): table = self.tables.get(table_name) diff --git a/moto/dynamodb2/responses.py b/moto/dynamodb2/responses.py index b9154b6e1..c0420c2a4 100644 --- a/moto/dynamodb2/responses.py +++ b/moto/dynamodb2/responses.py @@ -298,7 +298,9 @@ class DynamoHandler(BaseResponse): # {u'KeyConditionExpression': u'#n0 = :v0', u'ExpressionAttributeValues': {u':v0': {u'S': u'johndoe'}}, u'ExpressionAttributeNames': {u'#n0': u'username'}} key_condition_expression = self.body.get('KeyConditionExpression') projection_expression = self.body.get('ProjectionExpression') - expression_attribute_names = self.body.get('ExpressionAttributeNames') + expression_attribute_names = self.body.get('ExpressionAttributeNames', {}) + filter_expression = self.body.get('FilterExpression') + expression_attribute_values = self.body.get('ExpressionAttributeValues', {}) if projection_expression and expression_attribute_names: expressions = [x.strip() for x in projection_expression.split(',')] @@ -307,6 +309,7 @@ class DynamoHandler(BaseResponse): projection_expression = projection_expression.replace(expression, expression_attribute_names[expression]) filter_kwargs = {} + if key_condition_expression: value_alias_map = self.body['ExpressionAttributeValues'] @@ -413,7 +416,9 @@ class DynamoHandler(BaseResponse): scan_index_forward = self.body.get("ScanIndexForward") items, scanned_count, last_evaluated_key = self.dynamodb_backend.query( name, hash_key, range_comparison, range_values, limit, - exclusive_start_key, scan_index_forward, projection_expression, index_name=index_name, **filter_kwargs + exclusive_start_key, scan_index_forward, projection_expression, index_name=index_name, + expr_names=expression_attribute_names, expr_values=expression_attribute_values, + filter_expression=filter_expression, **filter_kwargs ) if items is None: er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' diff --git a/tests/test_dynamodb2/test_dynamodb.py b/tests/test_dynamodb2/test_dynamodb.py index c645a0c4e..a0bfc9833 100644 --- a/tests/test_dynamodb2/test_dynamodb.py +++ b/tests/test_dynamodb2/test_dynamodb.py @@ -649,6 +649,47 @@ def test_filter_expression(): filter_expr.expr(row1).should.be(True) +@mock_dynamodb2 +def test_query_filter(): + client = boto3.client('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + + # Create the DynamoDB table. + client.create_table( + TableName='test1', + AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'S'}], + KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], + ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123} + ) + client.put_item( + TableName='test1', + Item={ + 'client': {'S': 'client1'}, + 'app': {'S': 'app1'} + } + ) + client.put_item( + TableName='test1', + Item={ + 'client': {'S': 'client1'}, + 'app': {'S': 'app2'} + } + ) + + table = dynamodb.Table('test1') + response = table.query( + KeyConditionExpression=Key('client').eq('client1') + ) + assert response['Count'] == 2 + + response = table.query( + KeyConditionExpression=Key('client').eq('client1'), + FilterExpression=Attr('app').eq('app2') + ) + assert response['Count'] == 1 + assert response['Items'][0]['app'] == 'app2' + + @mock_dynamodb2 def test_scan_filter(): client = boto3.client('dynamodb', region_name='us-east-1')