diff --git a/moto/dynamodb/models/__init__.py b/moto/dynamodb/models/__init__.py index 9be813030..704f02753 100644 --- a/moto/dynamodb/models/__init__.py +++ b/moto/dynamodb/models/__init__.py @@ -301,11 +301,11 @@ class DynamoDBBackend(BaseBackend): self, table_name: str, keys: Dict[str, Any], - projection_expression: Optional[str] = None, + projection_expressions: Optional[List[List[str]]] = None, ) -> Optional[Item]: table = self.get_table(table_name) hash_key, range_key = self.get_keys_value(table, keys) - return table.get_item(hash_key, range_key, projection_expression) + return table.get_item(hash_key, range_key, projection_expressions) def query( self, @@ -316,7 +316,7 @@ class DynamoDBBackend(BaseBackend): limit: int, exclusive_start_key: Dict[str, Any], scan_index_forward: bool, - projection_expression: Optional[str], + projection_expressions: Optional[List[List[str]]], index_name: Optional[str] = None, expr_names: Optional[Dict[str, str]] = None, expr_values: Optional[Dict[str, str]] = None, @@ -339,7 +339,7 @@ class DynamoDBBackend(BaseBackend): limit, exclusive_start_key, scan_index_forward, - projection_expression, + projection_expressions, index_name, filter_expression_op, **filter_kwargs, @@ -355,7 +355,7 @@ class DynamoDBBackend(BaseBackend): expr_names: Dict[str, Any], expr_values: Dict[str, Any], index_name: str, - projection_expression: Optional[str], + projection_expression: Optional[List[List[str]]], ) -> Tuple[List[Item], int, Optional[Dict[str, Any]]]: table = self.get_table(table_name) diff --git a/moto/dynamodb/models/dynamo_type.py b/moto/dynamodb/models/dynamo_type.py index 0ea58b49d..992100c06 100644 --- a/moto/dynamodb/models/dynamo_type.py +++ b/moto/dynamodb/models/dynamo_type.py @@ -418,13 +418,12 @@ class Item(BaseModel): f"{action} action not support for update_with_attribute_updates" ) - def project(self, projection_expression: str) -> "Item": + def project(self, projection_expressions: List[List[str]]) -> "Item": # Returns a new Item with only the dictionary-keys that match the provided projection_expression # Will return an empty Item if the expression does not match anything result: Dict[str, Any] = dict() - expressions = [x.strip() for x in projection_expression.split(",")] - for expr in expressions: - x = find_nested_key(expr.split("."), self.to_regular_json()) + for expr in projection_expressions: + x = find_nested_key(expr, self.to_regular_json()) merge_dicts(result, x) return Item( diff --git a/moto/dynamodb/models/table.py b/moto/dynamodb/models/table.py index 78f18e809..45037afc9 100644 --- a/moto/dynamodb/models/table.py +++ b/moto/dynamodb/models/table.py @@ -50,12 +50,18 @@ class SecondaryIndex(BaseModel): ] if projection_type == "KEYS_ONLY": - item = item.project(",".join(key_attributes)) + # 'project' expects lists of lists of strings + # project([["attr1"], ["nested", "attr2"]] + # + # In our case, we need to convert + # ["key1", "key2"] + # into + # [["key1"], ["key2"]] + item = item.project([[attr] for attr in key_attributes]) elif projection_type == "INCLUDE": - allowed_attributes = key_attributes + self.projection.get( - "NonKeyAttributes", [] - ) - item = item.project(",".join(allowed_attributes)) + allowed_attributes = key_attributes + allowed_attributes.extend(self.projection.get("NonKeyAttributes", [])) + item = item.project([[attr] for attr in allowed_attributes]) # ALL is handled implicitly by not filtering return item @@ -592,7 +598,7 @@ class Table(CloudFormationModel): self, hash_key: DynamoType, range_key: Optional[DynamoType] = None, - projection_expression: Optional[str] = None, + projection_expression: Optional[List[List[str]]] = None, ) -> Optional[Item]: if self.has_range_key and not range_key: raise MockValidationException( @@ -637,7 +643,7 @@ class Table(CloudFormationModel): limit: int, exclusive_start_key: Dict[str, Any], scan_index_forward: bool, - projection_expression: Optional[str], + projection_expressions: Optional[List[List[str]]], index_name: Optional[str] = None, filter_expression: Any = None, **filter_kwargs: Any, @@ -754,8 +760,8 @@ class Table(CloudFormationModel): if filter_expression is not None: results = [item for item in results if filter_expression.expr(item)] - if projection_expression: - results = [r.project(projection_expression) for r in results] + if projection_expressions: + results = [r.project(projection_expressions) for r in results] return results, scanned_count, last_evaluated_key @@ -799,7 +805,7 @@ class Table(CloudFormationModel): exclusive_start_key: Dict[str, Any], filter_expression: Any = None, index_name: Optional[str] = None, - projection_expression: Optional[str] = None, + projection_expression: Optional[List[List[str]]] = None, ) -> Tuple[List[Item], int, Optional[Dict[str, Any]]]: results = [] scanned_count = 0 diff --git a/moto/dynamodb/responses.py b/moto/dynamodb/responses.py index 8f94fcb6f..d64a10e0b 100644 --- a/moto/dynamodb/responses.py +++ b/moto/dynamodb/responses.py @@ -556,11 +556,11 @@ class DynamoHandler(BaseResponse): ) expression_attribute_names = expression_attribute_names or {} - projection_expression = self._adjust_projection_expression( + projection_expressions = self._adjust_projection_expression( projection_expression, expression_attribute_names ) - item = self.dynamodb_backend.get_item(name, key, projection_expression) + item = self.dynamodb_backend.get_item(name, key, projection_expressions) if item: item_dict = item.describe_attrs(attributes=None) return dynamo_json_dump(item_dict) @@ -608,14 +608,14 @@ class DynamoHandler(BaseResponse): "ExpressionAttributeNames", {} ) - projection_expression = self._adjust_projection_expression( + projection_expressions = self._adjust_projection_expression( projection_expression, expression_attribute_names ) results["Responses"][table_name] = [] for key in keys: item = self.dynamodb_backend.get_item( - table_name, key, projection_expression + table_name, key, projection_expressions ) if item: # A single operation can retrieve up to 16 MB of data [and] returns a partial result if the response size limit is exceeded @@ -652,7 +652,7 @@ class DynamoHandler(BaseResponse): filter_expression = self._get_filter_expression() expression_attribute_values = self.body.get("ExpressionAttributeValues", {}) - projection_expression = self._adjust_projection_expression( + projection_expressions = self._adjust_projection_expression( projection_expression, expression_attribute_names ) @@ -720,7 +720,7 @@ class DynamoHandler(BaseResponse): limit, exclusive_start_key, scan_index_forward, - projection_expression, + projection_expressions, index_name=index_name, expr_names=expression_attribute_names, expr_values=expression_attribute_values, @@ -743,27 +743,24 @@ class DynamoHandler(BaseResponse): def _adjust_projection_expression( self, projection_expression: Optional[str], expr_attr_names: Dict[str, str] - ) -> Optional[str]: + ) -> List[List[str]]: + """ + lvl1.lvl2.attr1,lvl1.attr2 --> [["lvl1", "lvl2", "attr1"], ["lvl1", "attr2]] + """ + def _adjust(expression: str) -> str: - return ( - expr_attr_names[expression] - if expression in expr_attr_names - else expression - ) + return (expr_attr_names or {}).get(expression, expression) if projection_expression: expressions = [x.strip() for x in projection_expression.split(",")] for expression in expressions: check_projection_expression(expression) - if expr_attr_names: - return ",".join( - [ - ".".join([_adjust(expr) for expr in nested_expr.split(".")]) - for nested_expr in expressions - ] - ) + return [ + [_adjust(expr) for expr in nested_expr.split(".")] + for nested_expr in expressions + ] - return projection_expression + return [] @include_consumed_capacity() def scan(self) -> str: @@ -786,7 +783,7 @@ class DynamoHandler(BaseResponse): limit = self.body.get("Limit") index_name = self.body.get("IndexName") - projection_expression = self._adjust_projection_expression( + projection_expressions = self._adjust_projection_expression( projection_expression, expression_attribute_names ) @@ -800,7 +797,7 @@ class DynamoHandler(BaseResponse): expression_attribute_names, expression_attribute_values, index_name, - projection_expression, + projection_expressions, ) except ValueError as err: raise MockValidationException(f"Bad Filter Expression: {err}") diff --git a/tests/test_dynamodb/models/test_item.py b/tests/test_dynamodb/models/test_item.py index f3ed31097..e26a3375a 100644 --- a/tests/test_dynamodb/models/test_item.py +++ b/tests/test_dynamodb/models/test_item.py @@ -34,17 +34,17 @@ class TestFindNestedKeys: assert x == y def test_find_nothing(self): - self._project("", result={}) + self._project([[""]], result={}) def test_find_unknown_key(self): - self._project("unknown", result={}) + self._project([["unknown"]], result={}) def test_project_single_key_string(self): - self._project("simplestring", result={"simplestring": "val"}) + self._project([["simplestring"]], result={"simplestring": "val"}) def test_project_single_key_dict(self): self._project( - "nesteddict", + [["nesteddict"]], result={ "nesteddict": { "level21": {"ll31": "val", "ll32": "val"}, @@ -59,31 +59,31 @@ class TestFindNestedKeys: def test_project_nested_key(self): self._project( - "nesteddict.level21", + [["nesteddict", "level21"]], result={"nesteddict": {"level21": {"ll31": "val", "ll32": "val"}}}, ) def test_project_multi_level_nested_key(self): self._project( - "nesteddict.level21.ll32", + [["nesteddict", "level21", "ll32"]], result={"nesteddict": {"level21": {"ll32": "val"}}}, ) def test_project_nested_key__partial_fix(self): - self._project("nesteddict.levelunknown", result={}) + self._project([["nesteddict", "levelunknown"]], result={}) def test_project_nested_key__partial_fix2(self): - self._project("nesteddict.unknown.unknown2", result={}) + self._project([["nesteddict", "unknown", "unknown2"]], result={}) def test_list_index(self): self._project( - "rootlist[0]", + [["rootlist[0]"]], result={"rootlist": [{"ll21": {"ll31": "val", "ll32": "val"}}]}, ) def test_nested_list_index(self): self._project( - "nesteddict.nestedlist[1]", + [["nesteddict", "nestedlist[1]"]], result={ "nesteddict": {"nestedlist": [{"ll22": {"ll31": "val", "ll32": "val"}}]} }, @@ -91,16 +91,16 @@ class TestFindNestedKeys: def test_nested_obj_in_list(self): self._project( - "nesteddict.nestedlist[1].ll22.ll31", + [["nesteddict", "nestedlist[1]", "ll22", "ll31"]], result={"nesteddict": {"nestedlist": [{"ll22": {"ll31": "val"}}]}}, ) def test_list_unknown_indexes(self): - self._project("nesteddict.nestedlist[25]", result={}) + self._project([["nesteddict", "nestedlist[25]"]], result={}) def test_multiple_projections(self): self._project( - "nesteddict.nestedlist[1].ll22,rootlist[0]", + [["nesteddict", "nestedlist[1]", "ll22"], ["rootlist[0]"]], result={ "nesteddict": { "nestedlist": [{"ll22": {"ll31": "val", "ll32": "val"}}] diff --git a/tests/test_dynamodb/test_dynamodb.py b/tests/test_dynamodb/test_dynamodb.py index e0787bf0a..81e4dc3e6 100644 --- a/tests/test_dynamodb/test_dynamodb.py +++ b/tests/test_dynamodb/test_dynamodb.py @@ -886,7 +886,7 @@ def test_nested_projection_expression_using_get_item_with_attr_expression(): "forum_name": "key1", "nested": { "level1": {"id": "id1", "att": "irrelevant"}, - "level2": {"id": "id2", "include": "all"}, + "level.2": {"id": "id2", "include": "all"}, "level3": { "id": "irrelevant", "children": [{"Name": "child_a"}, {"Name": "child_b"}], @@ -907,10 +907,10 @@ def test_nested_projection_expression_using_get_item_with_attr_expression(): result = table.get_item( Key={"forum_name": "key1"}, ProjectionExpression="#nst.level1.id, #nst.#lvl2", - ExpressionAttributeNames={"#nst": "nested", "#lvl2": "level2"}, + ExpressionAttributeNames={"#nst": "nested", "#lvl2": "level.2"}, )["Item"] assert result == { - "nested": {"level1": {"id": "id1"}, "level2": {"id": "id2", "include": "all"}} + "nested": {"level1": {"id": "id1"}, "level.2": {"id": "id2", "include": "all"}} } # Assert actual data has not been deleted result = table.get_item(Key={"forum_name": "key1"})["Item"] @@ -919,7 +919,7 @@ def test_nested_projection_expression_using_get_item_with_attr_expression(): "forum_name": "key1", "nested": { "level1": {"id": "id1", "att": "irrelevant"}, - "level2": {"id": "id2", "include": "all"}, + "level.2": {"id": "id2", "include": "all"}, "level3": { "id": "irrelevant", "children": [{"Name": "child_a"}, {"Name": "child_b"}],