DynamoDB: query() now returns the correct ScannedCount (#7208)

This commit is contained in:
Bert Blommers 2024-01-13 18:36:19 +00:00 committed by GitHub
parent 624de34d82
commit 455fbd5eaa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 58 additions and 58 deletions

View File

@ -656,8 +656,8 @@ class Table(CloudFormationModel):
filter_expression: Any = None, filter_expression: Any = None,
**filter_kwargs: Any, **filter_kwargs: Any,
) -> Tuple[List[Item], int, Optional[Dict[str, Any]]]: ) -> Tuple[List[Item], int, Optional[Dict[str, Any]]]:
results = []
# FIND POSSIBLE RESULTS
if index_name: if index_name:
all_indexes = self.all_indexes() all_indexes = self.all_indexes()
indexes_by_name = dict((i.name, i) for i in all_indexes) indexes_by_name = dict((i.name, i) for i in all_indexes)
@ -683,6 +683,10 @@ class Table(CloudFormationModel):
][0] ][0]
except IndexError: except IndexError:
index_range_key = None index_range_key = None
if range_comparison:
raise ValueError(
f"Range Key comparison but no range key found for index: {index_name}"
)
possible_results = [] possible_results = []
for item in self.all_items(): for item in self.all_items():
@ -703,26 +707,51 @@ class Table(CloudFormationModel):
if isinstance(item, Item) and item.hash_key == hash_key if isinstance(item, Item) and item.hash_key == hash_key
] ]
if range_comparison: # FILTER
if index_name and not index_range_key: results: List[Item] = []
raise ValueError( result_size = 0
"Range Key comparison but no range key found for index: %s" scanned_count = 0
% index_name last_evaluated_key = None
) processing_previous_page = exclusive_start_key is not None
elif index_name:
for result in possible_results: for result in possible_results:
# Cycle through the previous page of results
# When we encounter our start key, we know we've reached the end of the previous page
if processing_previous_page:
if self._item_equals_dct(result, exclusive_start_key):
processing_previous_page = False
continue
# Check whether we've reached the limit of our result set
# That can be either in number, or in size
reached_length_limit = len(results) == limit
reached_size_limit = (result_size + result.size()) > RESULT_SIZE_LIMIT
if reached_length_limit or reached_size_limit:
last_evaluated_key = self._get_last_evaluated_key(
results[-1], index_name
)
break
if not range_comparison and not filter_kwargs:
# If we're not filtering on range key or on an index
results.append(result)
result_size += result.size()
scanned_count += 1
if range_comparison:
if index_name:
if result.attrs.get(index_range_key["AttributeName"]).compare( # type: ignore if result.attrs.get(index_range_key["AttributeName"]).compare( # type: ignore
range_comparison, range_objs range_comparison, range_objs
): ):
results.append(result) results.append(result)
result_size += result.size()
scanned_count += 1
else: else:
for result in possible_results:
if result.range_key.compare(range_comparison, range_objs): # type: ignore[union-attr] if result.range_key.compare(range_comparison, range_objs): # type: ignore[union-attr]
results.append(result) results.append(result)
result_size += result.size()
scanned_count += 1
if filter_kwargs: if filter_kwargs:
for result in possible_results:
for field, value in filter_kwargs.items(): for field, value in filter_kwargs.items():
dynamo_types = [ dynamo_types = [
DynamoType(ele) for ele in value["AttributeValueList"] DynamoType(ele) for ele in value["AttributeValueList"]
@ -731,12 +760,10 @@ class Table(CloudFormationModel):
value["ComparisonOperator"], dynamo_types value["ComparisonOperator"], dynamo_types
): ):
results.append(result) results.append(result)
result_size += result.size()
scanned_count += 1
if not range_comparison and not filter_kwargs: # SORT
# If we're not filtering on range key or on an index return all
# values
results = possible_results
if index_name: if index_name:
if index_range_key: if index_range_key:
# Convert to float if necessary to ensure proper ordering # Convert to float if necessary to ensure proper ordering
@ -754,17 +781,11 @@ class Table(CloudFormationModel):
if scan_index_forward is False: if scan_index_forward is False:
results.reverse() results.reverse()
scanned_count = len(list(self.all_items()))
results = copy.deepcopy(results) results = copy.deepcopy(results)
if index_name: if index_name:
index = self.get_index(index_name) index = self.get_index(index_name)
results = [index.project(r) for r in results] results = [index.project(r) for r in results]
results, last_evaluated_key = self._trim_results(
results, limit, exclusive_start_key, scanned_index=index_name
)
if filter_expression is not None: if filter_expression is not None:
results = [item for item in results if filter_expression.expr(item)] results = [item for item in results if filter_expression.expr(item)]
@ -891,35 +912,6 @@ class Table(CloudFormationModel):
range_key = DynamoType(range_key) range_key = DynamoType(range_key)
return item.hash_key == hash_key and item.range_key == range_key return item.hash_key == hash_key and item.range_key == range_key
def _trim_results(
self,
results: List[Item],
limit: int,
exclusive_start_key: Optional[Dict[str, Any]],
scanned_index: Optional[str] = None,
) -> Tuple[List[Item], Optional[Dict[str, Any]]]:
if exclusive_start_key is not None:
for i in range(len(results)):
if self._item_equals_dct(results[i], exclusive_start_key):
results = results[i + 1 :]
break
last_evaluated_key = None
item_size = sum(res.size() for res in results)
if item_size > RESULT_SIZE_LIMIT:
item_size = idx = 0
while item_size + results[idx].size() < RESULT_SIZE_LIMIT:
item_size += results[idx].size()
idx += 1
limit = min(limit, idx) if limit else idx
if limit and len(results) > limit:
results = results[:limit]
last_evaluated_key = self._get_last_evaluated_key(
last_result=results[-1], index_name=scanned_index
)
return results, last_evaluated_key
def _get_last_evaluated_key( def _get_last_evaluated_key(
self, last_result: Item, index_name: Optional[str] self, last_result: Item, index_name: Optional[str]
) -> Dict[str, Any]: ) -> Dict[str, Any]:

View File

@ -689,18 +689,21 @@ def test_nested_projection_expression_using_query():
} }
) )
# Test a query returning all items # Test a query returning nested attributes
result = table.query( result = table.query(
KeyConditionExpression=Key("name").eq("key1"), KeyConditionExpression=Key("name").eq("key1"),
ProjectionExpression="nested.level1.id, nested.level2", ProjectionExpression="nested.level1.id, nested.level2",
)["Items"][0] )
assert result["ScannedCount"] == 1
item = result["Items"][0]
assert "nested" in result assert "nested" in item
assert result["nested"] == { assert item["nested"] == {
"level1": {"id": "id1"}, "level1": {"id": "id1"},
"level2": {"id": "id2", "include": "all"}, "level2": {"id": "id2", "include": "all"},
} }
assert "foo" not in result assert "foo" not in item
# Assert actual data has not been deleted # Assert actual data has not been deleted
result = table.query(KeyConditionExpression=Key("name").eq("key1"))["Items"][0] result = table.query(KeyConditionExpression=Key("name").eq("key1"))["Items"][0]
assert result == { assert result == {
@ -1356,12 +1359,14 @@ def test_query_filter():
table = dynamodb.Table("test1") table = dynamodb.Table("test1")
response = table.query(KeyConditionExpression=Key("client").eq("client1")) response = table.query(KeyConditionExpression=Key("client").eq("client1"))
assert response["Count"] == 2 assert response["Count"] == 2
assert response["ScannedCount"] == 2
response = table.query( response = table.query(
KeyConditionExpression=Key("client").eq("client1"), KeyConditionExpression=Key("client").eq("client1"),
FilterExpression=Attr("app").eq("app2"), FilterExpression=Attr("app").eq("app2"),
) )
assert response["Count"] == 1 assert response["Count"] == 1
assert response["ScannedCount"] == 2
assert response["Items"][0]["app"] == "app2" assert response["Items"][0]["app"] == "app2"
response = table.query( response = table.query(
KeyConditionExpression=Key("client").eq("client1"), KeyConditionExpression=Key("client").eq("client1"),

View File

@ -794,6 +794,7 @@ def test_boto3_query_gsi_range_comparison():
ScanIndexForward=True, ScanIndexForward=True,
IndexName="TestGSI", IndexName="TestGSI",
) )
assert results["ScannedCount"] == 3
expected = ["456", "789", "123"] expected = ["456", "789", "123"]
for index, item in enumerate(results["Items"]): for index, item in enumerate(results["Items"]):
assert item["subject"] == expected[index] assert item["subject"] == expected[index]
@ -1077,6 +1078,7 @@ def test_query_pagination():
page1 = table.query(KeyConditionExpression=Key("forum_name").eq("the-key"), Limit=6) page1 = table.query(KeyConditionExpression=Key("forum_name").eq("the-key"), Limit=6)
assert page1["Count"] == 6 assert page1["Count"] == 6
assert page1["ScannedCount"] == 6
assert len(page1["Items"]) == 6 assert len(page1["Items"]) == 6
page2 = table.query( page2 = table.query(
@ -1085,6 +1087,7 @@ def test_query_pagination():
ExclusiveStartKey=page1["LastEvaluatedKey"], ExclusiveStartKey=page1["LastEvaluatedKey"],
) )
assert page2["Count"] == 4 assert page2["Count"] == 4
assert page2["ScannedCount"] == 4
assert len(page2["Items"]) == 4 assert len(page2["Items"]) == 4
assert "LastEvaluatedKey" not in page2 assert "LastEvaluatedKey" not in page2