DynamoDB: Support projection expressions in lists (#6375)

This commit is contained in:
Bert Blommers 2023-06-08 17:10:14 +00:00 committed by GitHub
parent 6fac7de646
commit 6e7edd5057
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 354 additions and 65 deletions

View File

@ -1,4 +1,7 @@
import copy
import decimal
from botocore.utils import merge_dicts
from boto3.dynamodb.types import TypeDeserializer, TypeSerializer
from typing import Any, Dict, List, Union, Optional
@ -8,7 +11,7 @@ from moto.dynamodb.exceptions import (
EmptyKeyAttributeException,
ItemSizeTooLarge,
)
from moto.dynamodb.models.utilities import bytesize
from .utilities import bytesize, find_nested_key
deserializer = TypeDeserializer()
serializer = TypeSerializer()
@ -67,28 +70,6 @@ class DynamoType(object):
elif self.is_map():
self.value = dict((k, DynamoType(v)) for k, v in self.value.items())
def filter(self, projection_expressions: str) -> None:
nested_projections = [
expr[0 : expr.index(".")] for expr in projection_expressions if "." in expr
]
if self.is_map():
expressions_to_delete = []
for attr in self.value:
if (
attr not in projection_expressions
and attr not in nested_projections
):
expressions_to_delete.append(attr)
elif attr in nested_projections:
relevant_expressions = [
expr[len(attr + ".") :]
for expr in projection_expressions
if expr.startswith(attr + ".")
]
self.value[attr].filter(relevant_expressions)
for expr in expressions_to_delete:
self.value.pop(expr)
def __hash__(self) -> int:
return hash((self.type, self.value))
@ -213,8 +194,26 @@ class DynamoType(object):
return value_size
def to_json(self) -> Dict[str, Any]:
# Returns a regular JSON object where the value can still be/contain a DynamoType
return {self.type: self.value}
def to_regular_json(self) -> Dict[str, Any]:
# Returns a regular JSON object in full
value = copy.deepcopy(self.value)
if isinstance(value, dict):
for key, nested_value in value.items():
value[key] = (
nested_value.to_regular_json()
if isinstance(nested_value, DynamoType)
else nested_value
)
if isinstance(value, list):
value = [
val.to_regular_json() if isinstance(val, DynamoType) else val
for val in value
]
return {self.type: value}
def compare(self, range_comparison: str, range_objs: List[Any]) -> bool:
"""
Compares this type against comparison filters
@ -310,7 +309,7 @@ class Item(BaseModel):
def to_regular_json(self) -> Dict[str, Any]:
attributes = {}
for key, attribute in self.attrs.items():
attributes[key] = deserializer.deserialize(attribute.to_json())
attributes[key] = deserializer.deserialize(attribute.to_regular_json())
return attributes
def describe_attrs(
@ -412,20 +411,19 @@ class Item(BaseModel):
f"{action} action not support for update_with_attribute_updates"
)
# Filter using projection_expression
# Ensure a deep copy is used to filter, otherwise actual data will be removed
def filter(self, projection_expression: str) -> None:
def project(self, projection_expression: str) -> "Item":
# Returns a new Item with only the dictionary-keys that match the provided projection_expression
# Will return an empty Item if the expression does not match anything
result: Dict[str, Any] = dict()
expressions = [x.strip() for x in projection_expression.split(",")]
top_level_expressions = [
expr[0 : expr.index(".")] for expr in expressions if "." in expr
]
for attr in list(self.attrs):
if attr not in expressions and attr not in top_level_expressions:
self.attrs.pop(attr)
if attr in top_level_expressions:
relevant_expressions = [
expr[len(attr + ".") :]
for expr in expressions
if expr.startswith(attr + ".")
]
self.attrs[attr].filter(relevant_expressions)
for expr in expressions:
x = find_nested_key(expr.split("."), self.to_regular_json())
merge_dicts(result, x)
return Item(
hash_key=self.hash_key,
range_key=self.range_key,
# 'result' is a normal Python dictionary ({'key': 'value'}
# We need to convert that into DynamoDB dictionary ({'M': {'key': {'S': 'value'}}})
attrs=serializer.serialize(result)["M"],
)

View File

@ -50,12 +50,12 @@ class SecondaryIndex(BaseModel):
]
if projection_type == "KEYS_ONLY":
item.filter(",".join(key_attributes))
item = item.project(",".join(key_attributes))
elif projection_type == "INCLUDE":
allowed_attributes = key_attributes + self.projection.get(
"NonKeyAttributes", []
)
item.filter(",".join(allowed_attributes))
item = item.project(",".join(allowed_attributes))
# ALL is handled implicitly by not filtering
return item
@ -607,11 +607,7 @@ class Table(CloudFormationModel):
result = self.items[hash_key]
if projection_expression and result:
result = copy.deepcopy(result)
result.filter(projection_expression)
if not result:
raise KeyError
result = result.project(projection_expression)
return result
except KeyError:
@ -728,9 +724,7 @@ class Table(CloudFormationModel):
results = possible_results
if index_name:
if index_range_key:
# Convert to float if necessary to ensure proper ordering
def conv(x: DynamoType) -> Any:
return float(x.value) if x.type == "N" else x.value
@ -751,8 +745,7 @@ class Table(CloudFormationModel):
results = copy.deepcopy(results)
if index_name:
index = self.get_index(index_name)
for result in results:
index.project(result)
results = [index.project(r) for r in results]
results, last_evaluated_key = self._trim_results(
results, limit, exclusive_start_key, scanned_index=index_name
@ -762,8 +755,7 @@ class Table(CloudFormationModel):
results = [item for item in results if filter_expression.expr(item)]
if projection_expression:
for result in results:
result.filter(projection_expression)
results = [r.project(projection_expression) for r in results]
return results, scanned_count, last_evaluated_key
@ -788,7 +780,6 @@ class Table(CloudFormationModel):
return indexes_by_name[index_name]
def has_idx_items(self, index_name: str) -> Iterator[Item]:
idx = self.get_index(index_name)
idx_col_set = set([i["AttributeName"] for i in idx.schema])
@ -848,8 +839,7 @@ class Table(CloudFormationModel):
results = copy.deepcopy(results)
if index_name:
index = self.get_index(index_name)
for result in results:
index.project(result)
results = [index.project(r) for r in results]
results, last_evaluated_key = self._trim_results(
results, limit, exclusive_start_key, scanned_index=index_name
@ -859,9 +849,7 @@ class Table(CloudFormationModel):
results = [item for item in results if filter_expression.expr(item)]
if projection_expression:
results = copy.deepcopy(results)
for result in results:
result.filter(projection_expression)
results = [r.project(projection_expression) for r in results]
return results, scanned_count, last_evaluated_key

View File

@ -1,5 +1,6 @@
import json
from typing import Any
import re
from typing import Any, Dict, List, Optional
class DynamoJsonEncoder(json.JSONEncoder):
@ -14,3 +15,110 @@ def dynamo_json_dump(dynamo_object: Any) -> str:
def bytesize(val: str) -> int:
return len(val.encode("utf-8"))
def find_nested_key(
keys: List[str],
dct: Dict[str, Any],
processed_keys: Optional[List[str]] = None,
result: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""
keys : A list of keys that may be present in the provided dictionary
["level1", "level2"]
dct : A dictionary that we want to inspect
{"level1": {"level2": "val", "irrelevant": ..}
processed_keys:
Should not be set by the caller, only by recursive invocations.
Example value: ["level1"]
result:
Should not be set by the caller, only by recursive invocations
Example value: {"level1": {}}
returns: {"level1": {"level2": "val"}}
"""
if result is None:
result = {}
if processed_keys is None:
processed_keys = []
# A key can refer to a list-item: 'level1[1].level2'
is_list_expression = re.match(pattern=r"(.+)\[(\d+)\]$", string=keys[0])
if len(keys) == 1:
# Set 'current_key' and 'value'
# or return an empty dictionary if the key does not exist in our dictionary
if is_list_expression:
current_key = is_list_expression.group(1)
idx = int(is_list_expression.group(2))
if (
current_key in dct
and isinstance(dct[current_key], list)
and len(dct[current_key]) >= idx
):
value = [dct[current_key][idx]]
else:
return {}
elif keys[0] in dct:
current_key = keys[0]
value = dct[current_key]
else:
return {}
# We may have already processed some keys
# Dig into the result to find the appropriate key to append the value to
#
# result: {'level1': {'level2': {}}}
# processed_keys: ['level1', 'level2']
# -->
# result: {'level1': {'level2': value}}
temp_result = result
for key in processed_keys:
if isinstance(temp_result, list):
temp_result = temp_result[0][key]
else:
temp_result = temp_result[key]
if isinstance(temp_result, list):
temp_result.append({current_key: value})
else:
temp_result[current_key] = value
return result
else:
# Set 'current_key'
# or return an empty dictionary if the key does not exist in our dictionary
if is_list_expression:
current_key = is_list_expression.group(1)
idx = int(is_list_expression.group(2))
if (
current_key in dct
and isinstance(dct[current_key], list)
and len(dct[current_key]) >= idx
):
pass
else:
return {}
elif keys[0] in dct:
current_key = keys[0]
else:
return {}
# Append the 'current_key' to the dictionary that is our result (so far)
# {'level1': {}} --> {'level1': {current_key: {}}
temp_result = result
for key in processed_keys:
temp_result = temp_result[key]
if isinstance(temp_result, list):
temp_result.append({current_key: [] if is_list_expression else {}})
else:
temp_result[current_key] = [] if is_list_expression else {}
remaining_dct = (
dct[current_key][idx] if is_list_expression else dct[current_key]
)
return find_nested_key(
keys[1:],
remaining_dct,
processed_keys=processed_keys + [current_key],
result=result,
)

View File

@ -199,7 +199,6 @@ def test_update_item_range_key_set():
@mock_dynamodb
def test_batch_get_item_non_existing_table():
client = boto3.client("dynamodb", region_name="us-west-2")
with pytest.raises(client.exceptions.ResourceNotFoundException) as exc:
@ -768,7 +767,6 @@ def test_query_begins_with_without_brackets():
@mock_dynamodb
def test_transact_write_items_multiple_operations_fail():
# Setup
schema = {
"KeySchema": [{"AttributeName": "id", "KeyType": "HASH"}],

View File

@ -0,0 +1,110 @@
from moto.dynamodb.models.dynamo_type import DynamoType, Item
from moto.dynamodb.models.dynamo_type import serializer
class TestFindNestedKeys:
def setup(self):
self.dct = {
"simplestring": "val",
"nesteddict": {
"level21": {"ll31": "val", "ll32": "val"},
"level22": {"ll31": "val", "ll32": "val"},
"nestedlist": [
{"ll21": {"ll31": "val", "ll32": "val"}},
{"ll22": {"ll31": "val", "ll32": "val"}},
],
},
"rootlist": [
{"ll21": {"ll31": "val", "ll32": "val"}},
{"ll22": {"ll31": "val", "ll32": "val"}},
],
}
x = serializer.serialize(self.dct)["M"]
self.item = Item(
hash_key=DynamoType({"pk": {"S": "v"}}), range_key=None, attrs=x
)
def _project(self, expression, result):
x = self.item.project(expression)
y = Item(
hash_key=DynamoType({"pk": {"S": "v"}}),
range_key=None,
attrs=serializer.serialize(result)["M"],
)
assert x == y
def test_find_nothing(self):
self._project("", result={})
def test_find_unknown_key(self):
self._project("unknown", result={})
def test_project_single_key_string(self):
self._project("simplestring", result={"simplestring": "val"})
def test_project_single_key_dict(self):
self._project(
"nesteddict",
result={
"nesteddict": {
"level21": {"ll31": "val", "ll32": "val"},
"level22": {"ll31": "val", "ll32": "val"},
"nestedlist": [
{"ll21": {"ll31": "val", "ll32": "val"}},
{"ll22": {"ll31": "val", "ll32": "val"}},
],
}
},
)
def test_project_nested_key(self):
self._project(
"nesteddict.level21",
result={"nesteddict": {"level21": {"ll31": "val", "ll32": "val"}}},
)
def test_project_multi_level_nested_key(self):
self._project(
"nesteddict.level21.ll32",
result={"nesteddict": {"level21": {"ll32": "val"}}},
)
def test_project_nested_key__partial_fix(self):
self._project("nesteddict.levelunknown", result={})
def test_project_nested_key__partial_fix2(self):
self._project("nesteddict.unknown.unknown2", result={})
def test_list_index(self):
self._project(
"rootlist[0]",
result={"rootlist": [{"ll21": {"ll31": "val", "ll32": "val"}}]},
)
def test_nested_list_index(self):
self._project(
"nesteddict.nestedlist[1]",
result={
"nesteddict": {"nestedlist": [{"ll22": {"ll31": "val", "ll32": "val"}}]}
},
)
def test_nested_obj_in_list(self):
self._project(
"nesteddict.nestedlist[1].ll22.ll31",
result={"nesteddict": {"nestedlist": [{"ll22": {"ll31": "val"}}]}},
)
def test_list_unknown_indexes(self):
self._project("nesteddict.nestedlist[25]", result={})
def test_multiple_projections(self):
self._project(
"nesteddict.nestedlist[1].ll22,rootlist[0]",
result={
"nesteddict": {
"nestedlist": [{"ll22": {"ll31": "val", "ll32": "val"}}]
},
"rootlist": [{"ll21": {"ll31": "val", "ll32": "val"}}],
},
)

View File

@ -0,0 +1,75 @@
from moto.dynamodb.models.utilities import find_nested_key
class TestFindDictionaryKeys:
def setup(self):
self.item = {
"simplestring": "val",
"nesteddict": {
"level21": {"level3.1": "val", "level3.2": "val"},
"level22": {"level3.1": "val", "level3.2": "val"},
"nestedlist": [
{"ll21": {"ll3.1": "val", "ll3.2": "val"}},
{"ll22": {"ll3.1": "val", "ll3.2": "val"}},
],
},
"rootlist": [
{"ll21": {"ll3.1": "val", "ll3.2": "val"}},
{"ll22": {"ll3.1": "val", "ll3.2": "val"}},
],
}
def test_find_nothing(self):
assert find_nested_key([""], self.item) == {}
def test_find_unknown_key(self):
assert find_nested_key(["unknown"], self.item) == {}
def test_project_single_key_string(self):
assert find_nested_key(["simplestring"], self.item) == {"simplestring": "val"}
def test_project_single_key_dict(self):
assert find_nested_key(["nesteddict"], self.item) == {
"nesteddict": {
"level21": {"level3.1": "val", "level3.2": "val"},
"level22": {"level3.1": "val", "level3.2": "val"},
"nestedlist": [
{"ll21": {"ll3.1": "val", "ll3.2": "val"}},
{"ll22": {"ll3.1": "val", "ll3.2": "val"}},
],
}
}
def test_project_nested_key(self):
assert find_nested_key(["nesteddict", "level21"], self.item) == {
"nesteddict": {"level21": {"level3.1": "val", "level3.2": "val"}}
}
def test_project_multi_level_nested_key(self):
assert find_nested_key(["nesteddict", "level21", "level3.2"], self.item) == {
"nesteddict": {"level21": {"level3.2": "val"}}
}
def test_project_nested_key__partial_fix(self):
assert find_nested_key(["nesteddict", "levelunknown"], self.item) == {}
def test_project_nested_key__partial_fix2(self):
assert find_nested_key(["nesteddict", "unknown", "unknown2"], self.item) == {}
def test_list_index(self):
assert find_nested_key(["rootlist[0]"], self.item) == {
"rootlist": [{"ll21": {"ll3.1": "val", "ll3.2": "val"}}]
}
def test_nested_list_index(self):
assert find_nested_key(["nesteddict", "nestedlist[1]"], self.item) == {
"nesteddict": {"nestedlist": [{"ll22": {"ll3.1": "val", "ll3.2": "val"}}]}
}
def test_nested_obj_in_list(self):
assert find_nested_key(
["nesteddict", "nestedlist[1]", "ll22", "ll3.1"], self.item
) == {"nesteddict": {"nestedlist": [{"ll22": {"ll3.1": "val"}}]}}
def test_list_unknown_indexes(self):
assert find_nested_key(["nesteddict", "nestedlist[25]"], self.item) == {}

View File

@ -895,7 +895,10 @@ def test_nested_projection_expression_using_get_item_with_attr_expression():
"nested": {
"level1": {"id": "id1", "att": "irrelevant"},
"level2": {"id": "id2", "include": "all"},
"level3": {"id": "irrelevant"},
"level3": {
"id": "irrelevant",
"children": [{"Name": "child_a"}, {"Name": "child_b"}],
},
},
"foo": "bar",
}
@ -926,11 +929,21 @@ def test_nested_projection_expression_using_get_item_with_attr_expression():
"nested": {
"level1": {"id": "id1", "att": "irrelevant"},
"level2": {"id": "id2", "include": "all"},
"level3": {"id": "irrelevant"},
"level3": {
"id": "irrelevant",
"children": [{"Name": "child_a"}, {"Name": "child_b"}],
},
},
}
)
# Test a get_item retrieving children
result = table.get_item(
Key={"forum_name": "key1"},
ProjectionExpression="nested.level3.children[0].Name",
)["Item"]
result.should.equal({"nested": {"level3": {"children": [{"Name": "child_a"}]}}})
@mock_dynamodb
def test_nested_projection_expression_using_query_with_attr_expression_names():
@ -3400,7 +3413,6 @@ def test_query_catches_when_no_filters():
@mock_dynamodb
def test_invalid_transact_get_items():
dynamodb = boto3.resource("dynamodb", region_name="us-east-1")
dynamodb.create_table(
TableName="test1",