DynamoDB: query() now returns the correct ScannedCount (#7208)
This commit is contained in:
parent
624de34d82
commit
455fbd5eaa
@ -656,8 +656,8 @@ class Table(CloudFormationModel):
|
|||||||
filter_expression: Any = None,
|
filter_expression: Any = None,
|
||||||
**filter_kwargs: Any,
|
**filter_kwargs: Any,
|
||||||
) -> Tuple[List[Item], int, Optional[Dict[str, Any]]]:
|
) -> Tuple[List[Item], int, Optional[Dict[str, Any]]]:
|
||||||
results = []
|
|
||||||
|
|
||||||
|
# FIND POSSIBLE RESULTS
|
||||||
if index_name:
|
if index_name:
|
||||||
all_indexes = self.all_indexes()
|
all_indexes = self.all_indexes()
|
||||||
indexes_by_name = dict((i.name, i) for i in all_indexes)
|
indexes_by_name = dict((i.name, i) for i in all_indexes)
|
||||||
@ -683,6 +683,10 @@ class Table(CloudFormationModel):
|
|||||||
][0]
|
][0]
|
||||||
except IndexError:
|
except IndexError:
|
||||||
index_range_key = None
|
index_range_key = None
|
||||||
|
if range_comparison:
|
||||||
|
raise ValueError(
|
||||||
|
f"Range Key comparison but no range key found for index: {index_name}"
|
||||||
|
)
|
||||||
|
|
||||||
possible_results = []
|
possible_results = []
|
||||||
for item in self.all_items():
|
for item in self.all_items():
|
||||||
@ -703,26 +707,51 @@ class Table(CloudFormationModel):
|
|||||||
if isinstance(item, Item) and item.hash_key == hash_key
|
if isinstance(item, Item) and item.hash_key == hash_key
|
||||||
]
|
]
|
||||||
|
|
||||||
if range_comparison:
|
# FILTER
|
||||||
if index_name and not index_range_key:
|
results: List[Item] = []
|
||||||
raise ValueError(
|
result_size = 0
|
||||||
"Range Key comparison but no range key found for index: %s"
|
scanned_count = 0
|
||||||
% index_name
|
last_evaluated_key = None
|
||||||
)
|
processing_previous_page = exclusive_start_key is not None
|
||||||
|
for result in possible_results:
|
||||||
|
# Cycle through the previous page of results
|
||||||
|
# When we encounter our start key, we know we've reached the end of the previous page
|
||||||
|
if processing_previous_page:
|
||||||
|
if self._item_equals_dct(result, exclusive_start_key):
|
||||||
|
processing_previous_page = False
|
||||||
|
continue
|
||||||
|
|
||||||
elif index_name:
|
# Check wether we've reached the limit of our result set
|
||||||
for result in possible_results:
|
# That can be either in number, or in size
|
||||||
|
reached_length_limit = len(results) == limit
|
||||||
|
reached_size_limit = (result_size + result.size()) > RESULT_SIZE_LIMIT
|
||||||
|
if reached_length_limit or reached_size_limit:
|
||||||
|
last_evaluated_key = self._get_last_evaluated_key(
|
||||||
|
results[-1], index_name
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
if not range_comparison and not filter_kwargs:
|
||||||
|
# If we're not filtering on range key or on an index
|
||||||
|
results.append(result)
|
||||||
|
result_size += result.size()
|
||||||
|
scanned_count += 1
|
||||||
|
|
||||||
|
if range_comparison:
|
||||||
|
if index_name:
|
||||||
if result.attrs.get(index_range_key["AttributeName"]).compare( # type: ignore
|
if result.attrs.get(index_range_key["AttributeName"]).compare( # type: ignore
|
||||||
range_comparison, range_objs
|
range_comparison, range_objs
|
||||||
):
|
):
|
||||||
results.append(result)
|
results.append(result)
|
||||||
else:
|
result_size += result.size()
|
||||||
for result in possible_results:
|
scanned_count += 1
|
||||||
|
else:
|
||||||
if result.range_key.compare(range_comparison, range_objs): # type: ignore[union-attr]
|
if result.range_key.compare(range_comparison, range_objs): # type: ignore[union-attr]
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
result_size += result.size()
|
||||||
|
scanned_count += 1
|
||||||
|
|
||||||
if filter_kwargs:
|
if filter_kwargs:
|
||||||
for result in possible_results:
|
|
||||||
for field, value in filter_kwargs.items():
|
for field, value in filter_kwargs.items():
|
||||||
dynamo_types = [
|
dynamo_types = [
|
||||||
DynamoType(ele) for ele in value["AttributeValueList"]
|
DynamoType(ele) for ele in value["AttributeValueList"]
|
||||||
@ -731,12 +760,10 @@ class Table(CloudFormationModel):
|
|||||||
value["ComparisonOperator"], dynamo_types
|
value["ComparisonOperator"], dynamo_types
|
||||||
):
|
):
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
result_size += result.size()
|
||||||
|
scanned_count += 1
|
||||||
|
|
||||||
if not range_comparison and not filter_kwargs:
|
# SORT
|
||||||
# If we're not filtering on range key or on an index return all
|
|
||||||
# values
|
|
||||||
results = possible_results
|
|
||||||
|
|
||||||
if index_name:
|
if index_name:
|
||||||
if index_range_key:
|
if index_range_key:
|
||||||
# Convert to float if necessary to ensure proper ordering
|
# Convert to float if necessary to ensure proper ordering
|
||||||
@ -754,17 +781,11 @@ class Table(CloudFormationModel):
|
|||||||
if scan_index_forward is False:
|
if scan_index_forward is False:
|
||||||
results.reverse()
|
results.reverse()
|
||||||
|
|
||||||
scanned_count = len(list(self.all_items()))
|
|
||||||
|
|
||||||
results = copy.deepcopy(results)
|
results = copy.deepcopy(results)
|
||||||
if index_name:
|
if index_name:
|
||||||
index = self.get_index(index_name)
|
index = self.get_index(index_name)
|
||||||
results = [index.project(r) for r in results]
|
results = [index.project(r) for r in results]
|
||||||
|
|
||||||
results, last_evaluated_key = self._trim_results(
|
|
||||||
results, limit, exclusive_start_key, scanned_index=index_name
|
|
||||||
)
|
|
||||||
|
|
||||||
if filter_expression is not None:
|
if filter_expression is not None:
|
||||||
results = [item for item in results if filter_expression.expr(item)]
|
results = [item for item in results if filter_expression.expr(item)]
|
||||||
|
|
||||||
@ -891,35 +912,6 @@ class Table(CloudFormationModel):
|
|||||||
range_key = DynamoType(range_key)
|
range_key = DynamoType(range_key)
|
||||||
return item.hash_key == hash_key and item.range_key == range_key
|
return item.hash_key == hash_key and item.range_key == range_key
|
||||||
|
|
||||||
def _trim_results(
|
|
||||||
self,
|
|
||||||
results: List[Item],
|
|
||||||
limit: int,
|
|
||||||
exclusive_start_key: Optional[Dict[str, Any]],
|
|
||||||
scanned_index: Optional[str] = None,
|
|
||||||
) -> Tuple[List[Item], Optional[Dict[str, Any]]]:
|
|
||||||
if exclusive_start_key is not None:
|
|
||||||
for i in range(len(results)):
|
|
||||||
if self._item_equals_dct(results[i], exclusive_start_key):
|
|
||||||
results = results[i + 1 :]
|
|
||||||
break
|
|
||||||
|
|
||||||
last_evaluated_key = None
|
|
||||||
item_size = sum(res.size() for res in results)
|
|
||||||
if item_size > RESULT_SIZE_LIMIT:
|
|
||||||
item_size = idx = 0
|
|
||||||
while item_size + results[idx].size() < RESULT_SIZE_LIMIT:
|
|
||||||
item_size += results[idx].size()
|
|
||||||
idx += 1
|
|
||||||
limit = min(limit, idx) if limit else idx
|
|
||||||
if limit and len(results) > limit:
|
|
||||||
results = results[:limit]
|
|
||||||
last_evaluated_key = self._get_last_evaluated_key(
|
|
||||||
last_result=results[-1], index_name=scanned_index
|
|
||||||
)
|
|
||||||
|
|
||||||
return results, last_evaluated_key
|
|
||||||
|
|
||||||
def _get_last_evaluated_key(
|
def _get_last_evaluated_key(
|
||||||
self, last_result: Item, index_name: Optional[str]
|
self, last_result: Item, index_name: Optional[str]
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
|
@ -689,18 +689,21 @@ def test_nested_projection_expression_using_query():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Test a query returning all items
|
# Test a query returning nested attributes
|
||||||
result = table.query(
|
result = table.query(
|
||||||
KeyConditionExpression=Key("name").eq("key1"),
|
KeyConditionExpression=Key("name").eq("key1"),
|
||||||
ProjectionExpression="nested.level1.id, nested.level2",
|
ProjectionExpression="nested.level1.id, nested.level2",
|
||||||
)["Items"][0]
|
)
|
||||||
|
assert result["ScannedCount"] == 1
|
||||||
|
item = result["Items"][0]
|
||||||
|
|
||||||
assert "nested" in result
|
assert "nested" in item
|
||||||
assert result["nested"] == {
|
assert item["nested"] == {
|
||||||
"level1": {"id": "id1"},
|
"level1": {"id": "id1"},
|
||||||
"level2": {"id": "id2", "include": "all"},
|
"level2": {"id": "id2", "include": "all"},
|
||||||
}
|
}
|
||||||
assert "foo" not in result
|
assert "foo" not in item
|
||||||
|
|
||||||
# Assert actual data has not been deleted
|
# Assert actual data has not been deleted
|
||||||
result = table.query(KeyConditionExpression=Key("name").eq("key1"))["Items"][0]
|
result = table.query(KeyConditionExpression=Key("name").eq("key1"))["Items"][0]
|
||||||
assert result == {
|
assert result == {
|
||||||
@ -1356,12 +1359,14 @@ def test_query_filter():
|
|||||||
table = dynamodb.Table("test1")
|
table = dynamodb.Table("test1")
|
||||||
response = table.query(KeyConditionExpression=Key("client").eq("client1"))
|
response = table.query(KeyConditionExpression=Key("client").eq("client1"))
|
||||||
assert response["Count"] == 2
|
assert response["Count"] == 2
|
||||||
|
assert response["ScannedCount"] == 2
|
||||||
|
|
||||||
response = table.query(
|
response = table.query(
|
||||||
KeyConditionExpression=Key("client").eq("client1"),
|
KeyConditionExpression=Key("client").eq("client1"),
|
||||||
FilterExpression=Attr("app").eq("app2"),
|
FilterExpression=Attr("app").eq("app2"),
|
||||||
)
|
)
|
||||||
assert response["Count"] == 1
|
assert response["Count"] == 1
|
||||||
|
assert response["ScannedCount"] == 2
|
||||||
assert response["Items"][0]["app"] == "app2"
|
assert response["Items"][0]["app"] == "app2"
|
||||||
response = table.query(
|
response = table.query(
|
||||||
KeyConditionExpression=Key("client").eq("client1"),
|
KeyConditionExpression=Key("client").eq("client1"),
|
||||||
|
@ -794,6 +794,7 @@ def test_boto3_query_gsi_range_comparison():
|
|||||||
ScanIndexForward=True,
|
ScanIndexForward=True,
|
||||||
IndexName="TestGSI",
|
IndexName="TestGSI",
|
||||||
)
|
)
|
||||||
|
assert results["ScannedCount"] == 3
|
||||||
expected = ["456", "789", "123"]
|
expected = ["456", "789", "123"]
|
||||||
for index, item in enumerate(results["Items"]):
|
for index, item in enumerate(results["Items"]):
|
||||||
assert item["subject"] == expected[index]
|
assert item["subject"] == expected[index]
|
||||||
@ -1077,6 +1078,7 @@ def test_query_pagination():
|
|||||||
|
|
||||||
page1 = table.query(KeyConditionExpression=Key("forum_name").eq("the-key"), Limit=6)
|
page1 = table.query(KeyConditionExpression=Key("forum_name").eq("the-key"), Limit=6)
|
||||||
assert page1["Count"] == 6
|
assert page1["Count"] == 6
|
||||||
|
assert page1["ScannedCount"] == 6
|
||||||
assert len(page1["Items"]) == 6
|
assert len(page1["Items"]) == 6
|
||||||
|
|
||||||
page2 = table.query(
|
page2 = table.query(
|
||||||
@ -1085,6 +1087,7 @@ def test_query_pagination():
|
|||||||
ExclusiveStartKey=page1["LastEvaluatedKey"],
|
ExclusiveStartKey=page1["LastEvaluatedKey"],
|
||||||
)
|
)
|
||||||
assert page2["Count"] == 4
|
assert page2["Count"] == 4
|
||||||
|
assert page2["ScannedCount"] == 4
|
||||||
assert len(page2["Items"]) == 4
|
assert len(page2["Items"]) == 4
|
||||||
assert "LastEvaluatedKey" not in page2
|
assert "LastEvaluatedKey" not in page2
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user