From 6e7edd50574334ab7d066cd389e20aa5471a3ae8 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Thu, 8 Jun 2023 17:10:14 +0000 Subject: [PATCH] DynamoDB: Support projection expressions in lists (#6375) --- moto/dynamodb/models/dynamo_type.py | 78 ++++++------- moto/dynamodb/models/table.py | 26 ++--- moto/dynamodb/models/utilities.py | 110 +++++++++++++++++- .../exceptions/test_dynamodb_exceptions.py | 2 - tests/test_dynamodb/models/test_item.py | 110 ++++++++++++++++++ tests/test_dynamodb/models/test_utilities.py | 75 ++++++++++++ tests/test_dynamodb/test_dynamodb.py | 18 ++- 7 files changed, 354 insertions(+), 65 deletions(-) create mode 100644 tests/test_dynamodb/models/test_item.py create mode 100644 tests/test_dynamodb/models/test_utilities.py diff --git a/moto/dynamodb/models/dynamo_type.py b/moto/dynamodb/models/dynamo_type.py index 38ba27e99..80798a63c 100644 --- a/moto/dynamodb/models/dynamo_type.py +++ b/moto/dynamodb/models/dynamo_type.py @@ -1,4 +1,7 @@ +import copy import decimal + +from botocore.utils import merge_dicts from boto3.dynamodb.types import TypeDeserializer, TypeSerializer from typing import Any, Dict, List, Union, Optional @@ -8,7 +11,7 @@ from moto.dynamodb.exceptions import ( EmptyKeyAttributeException, ItemSizeTooLarge, ) -from moto.dynamodb.models.utilities import bytesize +from .utilities import bytesize, find_nested_key deserializer = TypeDeserializer() serializer = TypeSerializer() @@ -67,28 +70,6 @@ class DynamoType(object): elif self.is_map(): self.value = dict((k, DynamoType(v)) for k, v in self.value.items()) - def filter(self, projection_expressions: str) -> None: - nested_projections = [ - expr[0 : expr.index(".")] for expr in projection_expressions if "." in expr - ] - if self.is_map(): - expressions_to_delete = [] - for attr in self.value: - if ( - attr not in projection_expressions - and attr not in nested_projections - ): - expressions_to_delete.append(attr) - elif attr in nested_projections: - relevant_expressions = [ - expr[len(attr + ".") :] - for expr in projection_expressions - if expr.startswith(attr + ".") - ] - self.value[attr].filter(relevant_expressions) - for expr in expressions_to_delete: - self.value.pop(expr) - def __hash__(self) -> int: return hash((self.type, self.value)) @@ -213,8 +194,26 @@ class DynamoType(object): return value_size def to_json(self) -> Dict[str, Any]: + # Returns a regular JSON object where the value can still be/contain a DynamoType return {self.type: self.value} + def to_regular_json(self) -> Dict[str, Any]: + # Returns a regular JSON object in full + value = copy.deepcopy(self.value) + if isinstance(value, dict): + for key, nested_value in value.items(): + value[key] = ( + nested_value.to_regular_json() + if isinstance(nested_value, DynamoType) + else nested_value + ) + if isinstance(value, list): + value = [ + val.to_regular_json() if isinstance(val, DynamoType) else val + for val in value + ] + return {self.type: value} + def compare(self, range_comparison: str, range_objs: List[Any]) -> bool: """ Compares this type against comparison filters @@ -310,7 +309,7 @@ class Item(BaseModel): def to_regular_json(self) -> Dict[str, Any]: attributes = {} for key, attribute in self.attrs.items(): - attributes[key] = deserializer.deserialize(attribute.to_json()) + attributes[key] = deserializer.deserialize(attribute.to_regular_json()) return attributes def describe_attrs( @@ -412,20 +411,19 @@ class Item(BaseModel): f"{action} action not support for update_with_attribute_updates" ) - # Filter using projection_expression - # Ensure a deep copy is used to filter, otherwise actual data will be removed - def filter(self, projection_expression: str) -> None: + def project(self, projection_expression: str) -> "Item": + # Returns a new Item with only the dictionary-keys that match the provided projection_expression + # Will return an empty Item if the expression does not match anything + result: Dict[str, Any] = dict() expressions = [x.strip() for x in projection_expression.split(",")] - top_level_expressions = [ - expr[0 : expr.index(".")] for expr in expressions if "." in expr - ] - for attr in list(self.attrs): - if attr not in expressions and attr not in top_level_expressions: - self.attrs.pop(attr) - if attr in top_level_expressions: - relevant_expressions = [ - expr[len(attr + ".") :] - for expr in expressions - if expr.startswith(attr + ".") - ] - self.attrs[attr].filter(relevant_expressions) + for expr in expressions: + x = find_nested_key(expr.split("."), self.to_regular_json()) + merge_dicts(result, x) + + return Item( + hash_key=self.hash_key, + range_key=self.range_key, + # 'result' is a normal Python dictionary ({'key': 'value'} + # We need to convert that into DynamoDB dictionary ({'M': {'key': {'S': 'value'}}}) + attrs=serializer.serialize(result)["M"], + ) diff --git a/moto/dynamodb/models/table.py b/moto/dynamodb/models/table.py index 69e25d341..5b854ac96 100644 --- a/moto/dynamodb/models/table.py +++ b/moto/dynamodb/models/table.py @@ -50,12 +50,12 @@ class SecondaryIndex(BaseModel): ] if projection_type == "KEYS_ONLY": - item.filter(",".join(key_attributes)) + item = item.project(",".join(key_attributes)) elif projection_type == "INCLUDE": allowed_attributes = key_attributes + self.projection.get( "NonKeyAttributes", [] ) - item.filter(",".join(allowed_attributes)) + item = item.project(",".join(allowed_attributes)) # ALL is handled implicitly by not filtering return item @@ -607,11 +607,7 @@ class Table(CloudFormationModel): result = self.items[hash_key] if projection_expression and result: - result = copy.deepcopy(result) - result.filter(projection_expression) - - if not result: - raise KeyError + result = result.project(projection_expression) return result except KeyError: @@ -728,9 +724,7 @@ class Table(CloudFormationModel): results = possible_results if index_name: - if index_range_key: - # Convert to float if necessary to ensure proper ordering def conv(x: DynamoType) -> Any: return float(x.value) if x.type == "N" else x.value @@ -751,8 +745,7 @@ class Table(CloudFormationModel): results = copy.deepcopy(results) if index_name: index = self.get_index(index_name) - for result in results: - index.project(result) + results = [index.project(r) for r in results] results, last_evaluated_key = self._trim_results( results, limit, exclusive_start_key, scanned_index=index_name @@ -762,8 +755,7 @@ class Table(CloudFormationModel): results = [item for item in results if filter_expression.expr(item)] if projection_expression: - for result in results: - result.filter(projection_expression) + results = [r.project(projection_expression) for r in results] return results, scanned_count, last_evaluated_key @@ -788,7 +780,6 @@ class Table(CloudFormationModel): return indexes_by_name[index_name] def has_idx_items(self, index_name: str) -> Iterator[Item]: - idx = self.get_index(index_name) idx_col_set = set([i["AttributeName"] for i in idx.schema]) @@ -848,8 +839,7 @@ class Table(CloudFormationModel): results = copy.deepcopy(results) if index_name: index = self.get_index(index_name) - for result in results: - index.project(result) + results = [index.project(r) for r in results] results, last_evaluated_key = self._trim_results( results, limit, exclusive_start_key, scanned_index=index_name @@ -859,9 +849,7 @@ class Table(CloudFormationModel): results = [item for item in results if filter_expression.expr(item)] if projection_expression: - results = copy.deepcopy(results) - for result in results: - result.filter(projection_expression) + results = [r.project(projection_expression) for r in results] return results, scanned_count, last_evaluated_key diff --git a/moto/dynamodb/models/utilities.py b/moto/dynamodb/models/utilities.py index 0aa6b59bd..227f34bc0 100644 --- a/moto/dynamodb/models/utilities.py +++ b/moto/dynamodb/models/utilities.py @@ -1,5 +1,6 @@ import json -from typing import Any +import re +from typing import Any, Dict, List, Optional class DynamoJsonEncoder(json.JSONEncoder): @@ -14,3 +15,110 @@ def dynamo_json_dump(dynamo_object: Any) -> str: def bytesize(val: str) -> int: return len(val.encode("utf-8")) + + +def find_nested_key( + keys: List[str], + dct: Dict[str, Any], + processed_keys: Optional[List[str]] = None, + result: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + """ + keys : A list of keys that may be present in the provided dictionary + ["level1", "level2"] + dct : A dictionary that we want to inspect + {"level1": {"level2": "val", "irrelevant": ..} + + processed_keys: + Should not be set by the caller, only by recursive invocations. + Example value: ["level1"] + result: + Should not be set by the caller, only by recursive invocations + Example value: {"level1": {}} + + returns: {"level1": {"level2": "val"}} + """ + if result is None: + result = {} + if processed_keys is None: + processed_keys = [] + + # A key can refer to a list-item: 'level1[1].level2' + is_list_expression = re.match(pattern=r"(.+)\[(\d+)\]$", string=keys[0]) + + if len(keys) == 1: + # Set 'current_key' and 'value' + # or return an empty dictionary if the key does not exist in our dictionary + if is_list_expression: + current_key = is_list_expression.group(1) + idx = int(is_list_expression.group(2)) + if ( + current_key in dct + and isinstance(dct[current_key], list) + and len(dct[current_key]) >= idx + ): + value = [dct[current_key][idx]] + else: + return {} + elif keys[0] in dct: + current_key = keys[0] + value = dct[current_key] + else: + return {} + + # We may have already processed some keys + # Dig into the result to find the appropriate key to append the value to + # + # result: {'level1': {'level2': {}}} + # processed_keys: ['level1', 'level2'] + # --> + # result: {'level1': {'level2': value}} + temp_result = result + for key in processed_keys: + if isinstance(temp_result, list): + temp_result = temp_result[0][key] + else: + temp_result = temp_result[key] + if isinstance(temp_result, list): + temp_result.append({current_key: value}) + else: + temp_result[current_key] = value + return result + else: + # Set 'current_key' + # or return an empty dictionary if the key does not exist in our dictionary + if is_list_expression: + current_key = is_list_expression.group(1) + idx = int(is_list_expression.group(2)) + if ( + current_key in dct + and isinstance(dct[current_key], list) + and len(dct[current_key]) >= idx + ): + pass + else: + return {} + elif keys[0] in dct: + current_key = keys[0] + else: + return {} + + # Append the 'current_key' to the dictionary that is our result (so far) + # {'level1': {}} --> {'level1': {current_key: {}} + temp_result = result + for key in processed_keys: + temp_result = temp_result[key] + if isinstance(temp_result, list): + temp_result.append({current_key: [] if is_list_expression else {}}) + else: + temp_result[current_key] = [] if is_list_expression else {} + remaining_dct = ( + dct[current_key][idx] if is_list_expression else dct[current_key] + ) + + return find_nested_key( + keys[1:], + remaining_dct, + processed_keys=processed_keys + [current_key], + result=result, + ) diff --git a/tests/test_dynamodb/exceptions/test_dynamodb_exceptions.py b/tests/test_dynamodb/exceptions/test_dynamodb_exceptions.py index 47386a102..861fc034c 100644 --- a/tests/test_dynamodb/exceptions/test_dynamodb_exceptions.py +++ b/tests/test_dynamodb/exceptions/test_dynamodb_exceptions.py @@ -199,7 +199,6 @@ def test_update_item_range_key_set(): @mock_dynamodb def test_batch_get_item_non_existing_table(): - client = boto3.client("dynamodb", region_name="us-west-2") with pytest.raises(client.exceptions.ResourceNotFoundException) as exc: @@ -768,7 +767,6 @@ def test_query_begins_with_without_brackets(): @mock_dynamodb def test_transact_write_items_multiple_operations_fail(): - # Setup schema = { "KeySchema": [{"AttributeName": "id", "KeyType": "HASH"}], diff --git a/tests/test_dynamodb/models/test_item.py b/tests/test_dynamodb/models/test_item.py new file mode 100644 index 000000000..cd745cb4d --- /dev/null +++ b/tests/test_dynamodb/models/test_item.py @@ -0,0 +1,110 @@ +from moto.dynamodb.models.dynamo_type import DynamoType, Item +from moto.dynamodb.models.dynamo_type import serializer + + +class TestFindNestedKeys: + def setup(self): + self.dct = { + "simplestring": "val", + "nesteddict": { + "level21": {"ll31": "val", "ll32": "val"}, + "level22": {"ll31": "val", "ll32": "val"}, + "nestedlist": [ + {"ll21": {"ll31": "val", "ll32": "val"}}, + {"ll22": {"ll31": "val", "ll32": "val"}}, + ], + }, + "rootlist": [ + {"ll21": {"ll31": "val", "ll32": "val"}}, + {"ll22": {"ll31": "val", "ll32": "val"}}, + ], + } + x = serializer.serialize(self.dct)["M"] + self.item = Item( + hash_key=DynamoType({"pk": {"S": "v"}}), range_key=None, attrs=x + ) + + def _project(self, expression, result): + x = self.item.project(expression) + y = Item( + hash_key=DynamoType({"pk": {"S": "v"}}), + range_key=None, + attrs=serializer.serialize(result)["M"], + ) + assert x == y + + def test_find_nothing(self): + self._project("", result={}) + + def test_find_unknown_key(self): + self._project("unknown", result={}) + + def test_project_single_key_string(self): + self._project("simplestring", result={"simplestring": "val"}) + + def test_project_single_key_dict(self): + self._project( + "nesteddict", + result={ + "nesteddict": { + "level21": {"ll31": "val", "ll32": "val"}, + "level22": {"ll31": "val", "ll32": "val"}, + "nestedlist": [ + {"ll21": {"ll31": "val", "ll32": "val"}}, + {"ll22": {"ll31": "val", "ll32": "val"}}, + ], + } + }, + ) + + def test_project_nested_key(self): + self._project( + "nesteddict.level21", + result={"nesteddict": {"level21": {"ll31": "val", "ll32": "val"}}}, + ) + + def test_project_multi_level_nested_key(self): + self._project( + "nesteddict.level21.ll32", + result={"nesteddict": {"level21": {"ll32": "val"}}}, + ) + + def test_project_nested_key__partial_fix(self): + self._project("nesteddict.levelunknown", result={}) + + def test_project_nested_key__partial_fix2(self): + self._project("nesteddict.unknown.unknown2", result={}) + + def test_list_index(self): + self._project( + "rootlist[0]", + result={"rootlist": [{"ll21": {"ll31": "val", "ll32": "val"}}]}, + ) + + def test_nested_list_index(self): + self._project( + "nesteddict.nestedlist[1]", + result={ + "nesteddict": {"nestedlist": [{"ll22": {"ll31": "val", "ll32": "val"}}]} + }, + ) + + def test_nested_obj_in_list(self): + self._project( + "nesteddict.nestedlist[1].ll22.ll31", + result={"nesteddict": {"nestedlist": [{"ll22": {"ll31": "val"}}]}}, + ) + + def test_list_unknown_indexes(self): + self._project("nesteddict.nestedlist[25]", result={}) + + def test_multiple_projections(self): + self._project( + "nesteddict.nestedlist[1].ll22,rootlist[0]", + result={ + "nesteddict": { + "nestedlist": [{"ll22": {"ll31": "val", "ll32": "val"}}] + }, + "rootlist": [{"ll21": {"ll31": "val", "ll32": "val"}}], + }, + ) diff --git a/tests/test_dynamodb/models/test_utilities.py b/tests/test_dynamodb/models/test_utilities.py new file mode 100644 index 000000000..a9b4efec9 --- /dev/null +++ b/tests/test_dynamodb/models/test_utilities.py @@ -0,0 +1,75 @@ +from moto.dynamodb.models.utilities import find_nested_key + + +class TestFindDictionaryKeys: + def setup(self): + self.item = { + "simplestring": "val", + "nesteddict": { + "level21": {"level3.1": "val", "level3.2": "val"}, + "level22": {"level3.1": "val", "level3.2": "val"}, + "nestedlist": [ + {"ll21": {"ll3.1": "val", "ll3.2": "val"}}, + {"ll22": {"ll3.1": "val", "ll3.2": "val"}}, + ], + }, + "rootlist": [ + {"ll21": {"ll3.1": "val", "ll3.2": "val"}}, + {"ll22": {"ll3.1": "val", "ll3.2": "val"}}, + ], + } + + def test_find_nothing(self): + assert find_nested_key([""], self.item) == {} + + def test_find_unknown_key(self): + assert find_nested_key(["unknown"], self.item) == {} + + def test_project_single_key_string(self): + assert find_nested_key(["simplestring"], self.item) == {"simplestring": "val"} + + def test_project_single_key_dict(self): + assert find_nested_key(["nesteddict"], self.item) == { + "nesteddict": { + "level21": {"level3.1": "val", "level3.2": "val"}, + "level22": {"level3.1": "val", "level3.2": "val"}, + "nestedlist": [ + {"ll21": {"ll3.1": "val", "ll3.2": "val"}}, + {"ll22": {"ll3.1": "val", "ll3.2": "val"}}, + ], + } + } + + def test_project_nested_key(self): + assert find_nested_key(["nesteddict", "level21"], self.item) == { + "nesteddict": {"level21": {"level3.1": "val", "level3.2": "val"}} + } + + def test_project_multi_level_nested_key(self): + assert find_nested_key(["nesteddict", "level21", "level3.2"], self.item) == { + "nesteddict": {"level21": {"level3.2": "val"}} + } + + def test_project_nested_key__partial_fix(self): + assert find_nested_key(["nesteddict", "levelunknown"], self.item) == {} + + def test_project_nested_key__partial_fix2(self): + assert find_nested_key(["nesteddict", "unknown", "unknown2"], self.item) == {} + + def test_list_index(self): + assert find_nested_key(["rootlist[0]"], self.item) == { + "rootlist": [{"ll21": {"ll3.1": "val", "ll3.2": "val"}}] + } + + def test_nested_list_index(self): + assert find_nested_key(["nesteddict", "nestedlist[1]"], self.item) == { + "nesteddict": {"nestedlist": [{"ll22": {"ll3.1": "val", "ll3.2": "val"}}]} + } + + def test_nested_obj_in_list(self): + assert find_nested_key( + ["nesteddict", "nestedlist[1]", "ll22", "ll3.1"], self.item + ) == {"nesteddict": {"nestedlist": [{"ll22": {"ll3.1": "val"}}]}} + + def test_list_unknown_indexes(self): + assert find_nested_key(["nesteddict", "nestedlist[25]"], self.item) == {} diff --git a/tests/test_dynamodb/test_dynamodb.py b/tests/test_dynamodb/test_dynamodb.py index 546c95a3b..8de5cc32f 100644 --- a/tests/test_dynamodb/test_dynamodb.py +++ b/tests/test_dynamodb/test_dynamodb.py @@ -895,7 +895,10 @@ def test_nested_projection_expression_using_get_item_with_attr_expression(): "nested": { "level1": {"id": "id1", "att": "irrelevant"}, "level2": {"id": "id2", "include": "all"}, - "level3": {"id": "irrelevant"}, + "level3": { + "id": "irrelevant", + "children": [{"Name": "child_a"}, {"Name": "child_b"}], + }, }, "foo": "bar", } @@ -926,11 +929,21 @@ def test_nested_projection_expression_using_get_item_with_attr_expression(): "nested": { "level1": {"id": "id1", "att": "irrelevant"}, "level2": {"id": "id2", "include": "all"}, - "level3": {"id": "irrelevant"}, + "level3": { + "id": "irrelevant", + "children": [{"Name": "child_a"}, {"Name": "child_b"}], + }, }, } ) + # Test a get_item retrieving children + result = table.get_item( + Key={"forum_name": "key1"}, + ProjectionExpression="nested.level3.children[0].Name", + )["Item"] + result.should.equal({"nested": {"level3": {"children": [{"Name": "child_a"}]}}}) + @mock_dynamodb def test_nested_projection_expression_using_query_with_attr_expression_names(): @@ -3400,7 +3413,6 @@ def test_query_catches_when_no_filters(): @mock_dynamodb def test_invalid_transact_get_items(): - dynamodb = boto3.resource("dynamodb", region_name="us-east-1") dynamodb.create_table( TableName="test1",