DynamoDB: Allow ProjectionExpressions on attributes that contain a . (#6709)

This commit is contained in:
Bert Blommers 2023-08-21 21:52:58 +00:00 committed by GitHub
parent bc29ae2fc3
commit 9e3e5e947b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 60 additions and 58 deletions

View File

@ -301,11 +301,11 @@ class DynamoDBBackend(BaseBackend):
self, self,
table_name: str, table_name: str,
keys: Dict[str, Any], keys: Dict[str, Any],
projection_expression: Optional[str] = None, projection_expressions: Optional[List[List[str]]] = None,
) -> Optional[Item]: ) -> Optional[Item]:
table = self.get_table(table_name) table = self.get_table(table_name)
hash_key, range_key = self.get_keys_value(table, keys) hash_key, range_key = self.get_keys_value(table, keys)
return table.get_item(hash_key, range_key, projection_expression) return table.get_item(hash_key, range_key, projection_expressions)
def query( def query(
self, self,
@ -316,7 +316,7 @@ class DynamoDBBackend(BaseBackend):
limit: int, limit: int,
exclusive_start_key: Dict[str, Any], exclusive_start_key: Dict[str, Any],
scan_index_forward: bool, scan_index_forward: bool,
projection_expression: Optional[str], projection_expressions: Optional[List[List[str]]],
index_name: Optional[str] = None, index_name: Optional[str] = None,
expr_names: Optional[Dict[str, str]] = None, expr_names: Optional[Dict[str, str]] = None,
expr_values: Optional[Dict[str, str]] = None, expr_values: Optional[Dict[str, str]] = None,
@ -339,7 +339,7 @@ class DynamoDBBackend(BaseBackend):
limit, limit,
exclusive_start_key, exclusive_start_key,
scan_index_forward, scan_index_forward,
projection_expression, projection_expressions,
index_name, index_name,
filter_expression_op, filter_expression_op,
**filter_kwargs, **filter_kwargs,
@ -355,7 +355,7 @@ class DynamoDBBackend(BaseBackend):
expr_names: Dict[str, Any], expr_names: Dict[str, Any],
expr_values: Dict[str, Any], expr_values: Dict[str, Any],
index_name: str, index_name: str,
projection_expression: Optional[str], projection_expression: Optional[List[List[str]]],
) -> Tuple[List[Item], int, Optional[Dict[str, Any]]]: ) -> Tuple[List[Item], int, Optional[Dict[str, Any]]]:
table = self.get_table(table_name) table = self.get_table(table_name)

View File

@ -418,13 +418,12 @@ class Item(BaseModel):
f"{action} action not support for update_with_attribute_updates" f"{action} action not support for update_with_attribute_updates"
) )
def project(self, projection_expression: str) -> "Item": def project(self, projection_expressions: List[List[str]]) -> "Item":
# Returns a new Item with only the dictionary-keys that match the provided projection_expression # Returns a new Item with only the dictionary-keys that match the provided projection_expression
# Will return an empty Item if the expression does not match anything # Will return an empty Item if the expression does not match anything
result: Dict[str, Any] = dict() result: Dict[str, Any] = dict()
expressions = [x.strip() for x in projection_expression.split(",")] for expr in projection_expressions:
for expr in expressions: x = find_nested_key(expr, self.to_regular_json())
x = find_nested_key(expr.split("."), self.to_regular_json())
merge_dicts(result, x) merge_dicts(result, x)
return Item( return Item(

View File

@ -50,12 +50,18 @@ class SecondaryIndex(BaseModel):
] ]
if projection_type == "KEYS_ONLY": if projection_type == "KEYS_ONLY":
item = item.project(",".join(key_attributes)) # 'project' expects lists of lists of strings
# project([["attr1"], ["nested", "attr2"]]
#
# In our case, we need to convert
# ["key1", "key2"]
# into
# [["key1"], ["key2"]]
item = item.project([[attr] for attr in key_attributes])
elif projection_type == "INCLUDE": elif projection_type == "INCLUDE":
allowed_attributes = key_attributes + self.projection.get( allowed_attributes = key_attributes
"NonKeyAttributes", [] allowed_attributes.extend(self.projection.get("NonKeyAttributes", []))
) item = item.project([[attr] for attr in allowed_attributes])
item = item.project(",".join(allowed_attributes))
# ALL is handled implicitly by not filtering # ALL is handled implicitly by not filtering
return item return item
@ -592,7 +598,7 @@ class Table(CloudFormationModel):
self, self,
hash_key: DynamoType, hash_key: DynamoType,
range_key: Optional[DynamoType] = None, range_key: Optional[DynamoType] = None,
projection_expression: Optional[str] = None, projection_expression: Optional[List[List[str]]] = None,
) -> Optional[Item]: ) -> Optional[Item]:
if self.has_range_key and not range_key: if self.has_range_key and not range_key:
raise MockValidationException( raise MockValidationException(
@ -637,7 +643,7 @@ class Table(CloudFormationModel):
limit: int, limit: int,
exclusive_start_key: Dict[str, Any], exclusive_start_key: Dict[str, Any],
scan_index_forward: bool, scan_index_forward: bool,
projection_expression: Optional[str], projection_expressions: Optional[List[List[str]]],
index_name: Optional[str] = None, index_name: Optional[str] = None,
filter_expression: Any = None, filter_expression: Any = None,
**filter_kwargs: Any, **filter_kwargs: Any,
@ -754,8 +760,8 @@ class Table(CloudFormationModel):
if filter_expression is not None: if filter_expression is not None:
results = [item for item in results if filter_expression.expr(item)] results = [item for item in results if filter_expression.expr(item)]
if projection_expression: if projection_expressions:
results = [r.project(projection_expression) for r in results] results = [r.project(projection_expressions) for r in results]
return results, scanned_count, last_evaluated_key return results, scanned_count, last_evaluated_key
@ -799,7 +805,7 @@ class Table(CloudFormationModel):
exclusive_start_key: Dict[str, Any], exclusive_start_key: Dict[str, Any],
filter_expression: Any = None, filter_expression: Any = None,
index_name: Optional[str] = None, index_name: Optional[str] = None,
projection_expression: Optional[str] = None, projection_expression: Optional[List[List[str]]] = None,
) -> Tuple[List[Item], int, Optional[Dict[str, Any]]]: ) -> Tuple[List[Item], int, Optional[Dict[str, Any]]]:
results = [] results = []
scanned_count = 0 scanned_count = 0

View File

@ -556,11 +556,11 @@ class DynamoHandler(BaseResponse):
) )
expression_attribute_names = expression_attribute_names or {} expression_attribute_names = expression_attribute_names or {}
projection_expression = self._adjust_projection_expression( projection_expressions = self._adjust_projection_expression(
projection_expression, expression_attribute_names projection_expression, expression_attribute_names
) )
item = self.dynamodb_backend.get_item(name, key, projection_expression) item = self.dynamodb_backend.get_item(name, key, projection_expressions)
if item: if item:
item_dict = item.describe_attrs(attributes=None) item_dict = item.describe_attrs(attributes=None)
return dynamo_json_dump(item_dict) return dynamo_json_dump(item_dict)
@ -608,14 +608,14 @@ class DynamoHandler(BaseResponse):
"ExpressionAttributeNames", {} "ExpressionAttributeNames", {}
) )
projection_expression = self._adjust_projection_expression( projection_expressions = self._adjust_projection_expression(
projection_expression, expression_attribute_names projection_expression, expression_attribute_names
) )
results["Responses"][table_name] = [] results["Responses"][table_name] = []
for key in keys: for key in keys:
item = self.dynamodb_backend.get_item( item = self.dynamodb_backend.get_item(
table_name, key, projection_expression table_name, key, projection_expressions
) )
if item: if item:
# A single operation can retrieve up to 16 MB of data [and] returns a partial result if the response size limit is exceeded # A single operation can retrieve up to 16 MB of data [and] returns a partial result if the response size limit is exceeded
@ -652,7 +652,7 @@ class DynamoHandler(BaseResponse):
filter_expression = self._get_filter_expression() filter_expression = self._get_filter_expression()
expression_attribute_values = self.body.get("ExpressionAttributeValues", {}) expression_attribute_values = self.body.get("ExpressionAttributeValues", {})
projection_expression = self._adjust_projection_expression( projection_expressions = self._adjust_projection_expression(
projection_expression, expression_attribute_names projection_expression, expression_attribute_names
) )
@ -720,7 +720,7 @@ class DynamoHandler(BaseResponse):
limit, limit,
exclusive_start_key, exclusive_start_key,
scan_index_forward, scan_index_forward,
projection_expression, projection_expressions,
index_name=index_name, index_name=index_name,
expr_names=expression_attribute_names, expr_names=expression_attribute_names,
expr_values=expression_attribute_values, expr_values=expression_attribute_values,
@ -743,27 +743,24 @@ class DynamoHandler(BaseResponse):
def _adjust_projection_expression( def _adjust_projection_expression(
self, projection_expression: Optional[str], expr_attr_names: Dict[str, str] self, projection_expression: Optional[str], expr_attr_names: Dict[str, str]
) -> Optional[str]: ) -> List[List[str]]:
"""
lvl1.lvl2.attr1,lvl1.attr2 --> [["lvl1", "lvl2", "attr1"], ["lvl1", "attr2]]
"""
def _adjust(expression: str) -> str: def _adjust(expression: str) -> str:
return ( return (expr_attr_names or {}).get(expression, expression)
expr_attr_names[expression]
if expression in expr_attr_names
else expression
)
if projection_expression: if projection_expression:
expressions = [x.strip() for x in projection_expression.split(",")] expressions = [x.strip() for x in projection_expression.split(",")]
for expression in expressions: for expression in expressions:
check_projection_expression(expression) check_projection_expression(expression)
if expr_attr_names: return [
return ",".join( [_adjust(expr) for expr in nested_expr.split(".")]
[ for nested_expr in expressions
".".join([_adjust(expr) for expr in nested_expr.split(".")]) ]
for nested_expr in expressions
]
)
return projection_expression return []
@include_consumed_capacity() @include_consumed_capacity()
def scan(self) -> str: def scan(self) -> str:
@ -786,7 +783,7 @@ class DynamoHandler(BaseResponse):
limit = self.body.get("Limit") limit = self.body.get("Limit")
index_name = self.body.get("IndexName") index_name = self.body.get("IndexName")
projection_expression = self._adjust_projection_expression( projection_expressions = self._adjust_projection_expression(
projection_expression, expression_attribute_names projection_expression, expression_attribute_names
) )
@ -800,7 +797,7 @@ class DynamoHandler(BaseResponse):
expression_attribute_names, expression_attribute_names,
expression_attribute_values, expression_attribute_values,
index_name, index_name,
projection_expression, projection_expressions,
) )
except ValueError as err: except ValueError as err:
raise MockValidationException(f"Bad Filter Expression: {err}") raise MockValidationException(f"Bad Filter Expression: {err}")

View File

@ -34,17 +34,17 @@ class TestFindNestedKeys:
assert x == y assert x == y
def test_find_nothing(self): def test_find_nothing(self):
self._project("", result={}) self._project([[""]], result={})
def test_find_unknown_key(self): def test_find_unknown_key(self):
self._project("unknown", result={}) self._project([["unknown"]], result={})
def test_project_single_key_string(self): def test_project_single_key_string(self):
self._project("simplestring", result={"simplestring": "val"}) self._project([["simplestring"]], result={"simplestring": "val"})
def test_project_single_key_dict(self): def test_project_single_key_dict(self):
self._project( self._project(
"nesteddict", [["nesteddict"]],
result={ result={
"nesteddict": { "nesteddict": {
"level21": {"ll31": "val", "ll32": "val"}, "level21": {"ll31": "val", "ll32": "val"},
@ -59,31 +59,31 @@ class TestFindNestedKeys:
def test_project_nested_key(self): def test_project_nested_key(self):
self._project( self._project(
"nesteddict.level21", [["nesteddict", "level21"]],
result={"nesteddict": {"level21": {"ll31": "val", "ll32": "val"}}}, result={"nesteddict": {"level21": {"ll31": "val", "ll32": "val"}}},
) )
def test_project_multi_level_nested_key(self): def test_project_multi_level_nested_key(self):
self._project( self._project(
"nesteddict.level21.ll32", [["nesteddict", "level21", "ll32"]],
result={"nesteddict": {"level21": {"ll32": "val"}}}, result={"nesteddict": {"level21": {"ll32": "val"}}},
) )
def test_project_nested_key__partial_fix(self): def test_project_nested_key__partial_fix(self):
self._project("nesteddict.levelunknown", result={}) self._project([["nesteddict", "levelunknown"]], result={})
def test_project_nested_key__partial_fix2(self): def test_project_nested_key__partial_fix2(self):
self._project("nesteddict.unknown.unknown2", result={}) self._project([["nesteddict", "unknown", "unknown2"]], result={})
def test_list_index(self): def test_list_index(self):
self._project( self._project(
"rootlist[0]", [["rootlist[0]"]],
result={"rootlist": [{"ll21": {"ll31": "val", "ll32": "val"}}]}, result={"rootlist": [{"ll21": {"ll31": "val", "ll32": "val"}}]},
) )
def test_nested_list_index(self): def test_nested_list_index(self):
self._project( self._project(
"nesteddict.nestedlist[1]", [["nesteddict", "nestedlist[1]"]],
result={ result={
"nesteddict": {"nestedlist": [{"ll22": {"ll31": "val", "ll32": "val"}}]} "nesteddict": {"nestedlist": [{"ll22": {"ll31": "val", "ll32": "val"}}]}
}, },
@ -91,16 +91,16 @@ class TestFindNestedKeys:
def test_nested_obj_in_list(self): def test_nested_obj_in_list(self):
self._project( self._project(
"nesteddict.nestedlist[1].ll22.ll31", [["nesteddict", "nestedlist[1]", "ll22", "ll31"]],
result={"nesteddict": {"nestedlist": [{"ll22": {"ll31": "val"}}]}}, result={"nesteddict": {"nestedlist": [{"ll22": {"ll31": "val"}}]}},
) )
def test_list_unknown_indexes(self): def test_list_unknown_indexes(self):
self._project("nesteddict.nestedlist[25]", result={}) self._project([["nesteddict", "nestedlist[25]"]], result={})
def test_multiple_projections(self): def test_multiple_projections(self):
self._project( self._project(
"nesteddict.nestedlist[1].ll22,rootlist[0]", [["nesteddict", "nestedlist[1]", "ll22"], ["rootlist[0]"]],
result={ result={
"nesteddict": { "nesteddict": {
"nestedlist": [{"ll22": {"ll31": "val", "ll32": "val"}}] "nestedlist": [{"ll22": {"ll31": "val", "ll32": "val"}}]

View File

@ -886,7 +886,7 @@ def test_nested_projection_expression_using_get_item_with_attr_expression():
"forum_name": "key1", "forum_name": "key1",
"nested": { "nested": {
"level1": {"id": "id1", "att": "irrelevant"}, "level1": {"id": "id1", "att": "irrelevant"},
"level2": {"id": "id2", "include": "all"}, "level.2": {"id": "id2", "include": "all"},
"level3": { "level3": {
"id": "irrelevant", "id": "irrelevant",
"children": [{"Name": "child_a"}, {"Name": "child_b"}], "children": [{"Name": "child_a"}, {"Name": "child_b"}],
@ -907,10 +907,10 @@ def test_nested_projection_expression_using_get_item_with_attr_expression():
result = table.get_item( result = table.get_item(
Key={"forum_name": "key1"}, Key={"forum_name": "key1"},
ProjectionExpression="#nst.level1.id, #nst.#lvl2", ProjectionExpression="#nst.level1.id, #nst.#lvl2",
ExpressionAttributeNames={"#nst": "nested", "#lvl2": "level2"}, ExpressionAttributeNames={"#nst": "nested", "#lvl2": "level.2"},
)["Item"] )["Item"]
assert result == { assert result == {
"nested": {"level1": {"id": "id1"}, "level2": {"id": "id2", "include": "all"}} "nested": {"level1": {"id": "id1"}, "level.2": {"id": "id2", "include": "all"}}
} }
# Assert actual data has not been deleted # Assert actual data has not been deleted
result = table.get_item(Key={"forum_name": "key1"})["Item"] result = table.get_item(Key={"forum_name": "key1"})["Item"]
@ -919,7 +919,7 @@ def test_nested_projection_expression_using_get_item_with_attr_expression():
"forum_name": "key1", "forum_name": "key1",
"nested": { "nested": {
"level1": {"id": "id1", "att": "irrelevant"}, "level1": {"id": "id1", "att": "irrelevant"},
"level2": {"id": "id2", "include": "all"}, "level.2": {"id": "id2", "include": "all"},
"level3": { "level3": {
"id": "irrelevant", "id": "irrelevant",
"children": [{"Name": "child_a"}, {"Name": "child_b"}], "children": [{"Name": "child_a"}, {"Name": "child_b"}],