DynamoDB: Allow ProjectionExpressions on attributes that contain a . (#6709)

This commit is contained in:
Bert Blommers 2023-08-21 21:52:58 +00:00 committed by GitHub
parent bc29ae2fc3
commit 9e3e5e947b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 60 additions and 58 deletions

View File

@ -301,11 +301,11 @@ class DynamoDBBackend(BaseBackend):
self,
table_name: str,
keys: Dict[str, Any],
projection_expression: Optional[str] = None,
projection_expressions: Optional[List[List[str]]] = None,
) -> Optional[Item]:
table = self.get_table(table_name)
hash_key, range_key = self.get_keys_value(table, keys)
return table.get_item(hash_key, range_key, projection_expression)
return table.get_item(hash_key, range_key, projection_expressions)
def query(
self,
@ -316,7 +316,7 @@ class DynamoDBBackend(BaseBackend):
limit: int,
exclusive_start_key: Dict[str, Any],
scan_index_forward: bool,
projection_expression: Optional[str],
projection_expressions: Optional[List[List[str]]],
index_name: Optional[str] = None,
expr_names: Optional[Dict[str, str]] = None,
expr_values: Optional[Dict[str, str]] = None,
@ -339,7 +339,7 @@ class DynamoDBBackend(BaseBackend):
limit,
exclusive_start_key,
scan_index_forward,
projection_expression,
projection_expressions,
index_name,
filter_expression_op,
**filter_kwargs,
@ -355,7 +355,7 @@ class DynamoDBBackend(BaseBackend):
expr_names: Dict[str, Any],
expr_values: Dict[str, Any],
index_name: str,
projection_expression: Optional[str],
projection_expression: Optional[List[List[str]]],
) -> Tuple[List[Item], int, Optional[Dict[str, Any]]]:
table = self.get_table(table_name)

View File

@ -418,13 +418,12 @@ class Item(BaseModel):
f"{action} action not support for update_with_attribute_updates"
)
def project(self, projection_expression: str) -> "Item":
def project(self, projection_expressions: List[List[str]]) -> "Item":
# Returns a new Item with only the dictionary-keys that match the provided projection_expression
# Will return an empty Item if the expression does not match anything
result: Dict[str, Any] = dict()
expressions = [x.strip() for x in projection_expression.split(",")]
for expr in expressions:
x = find_nested_key(expr.split("."), self.to_regular_json())
for expr in projection_expressions:
x = find_nested_key(expr, self.to_regular_json())
merge_dicts(result, x)
return Item(

View File

@ -50,12 +50,18 @@ class SecondaryIndex(BaseModel):
]
if projection_type == "KEYS_ONLY":
item = item.project(",".join(key_attributes))
# 'project' expects lists of lists of strings
# project([["attr1"], ["nested", "attr2"]]
#
# In our case, we need to convert
# ["key1", "key2"]
# into
# [["key1"], ["key2"]]
item = item.project([[attr] for attr in key_attributes])
elif projection_type == "INCLUDE":
allowed_attributes = key_attributes + self.projection.get(
"NonKeyAttributes", []
)
item = item.project(",".join(allowed_attributes))
allowed_attributes = key_attributes
allowed_attributes.extend(self.projection.get("NonKeyAttributes", []))
item = item.project([[attr] for attr in allowed_attributes])
# ALL is handled implicitly by not filtering
return item
@ -592,7 +598,7 @@ class Table(CloudFormationModel):
self,
hash_key: DynamoType,
range_key: Optional[DynamoType] = None,
projection_expression: Optional[str] = None,
projection_expression: Optional[List[List[str]]] = None,
) -> Optional[Item]:
if self.has_range_key and not range_key:
raise MockValidationException(
@ -637,7 +643,7 @@ class Table(CloudFormationModel):
limit: int,
exclusive_start_key: Dict[str, Any],
scan_index_forward: bool,
projection_expression: Optional[str],
projection_expressions: Optional[List[List[str]]],
index_name: Optional[str] = None,
filter_expression: Any = None,
**filter_kwargs: Any,
@ -754,8 +760,8 @@ class Table(CloudFormationModel):
if filter_expression is not None:
results = [item for item in results if filter_expression.expr(item)]
if projection_expression:
results = [r.project(projection_expression) for r in results]
if projection_expressions:
results = [r.project(projection_expressions) for r in results]
return results, scanned_count, last_evaluated_key
@ -799,7 +805,7 @@ class Table(CloudFormationModel):
exclusive_start_key: Dict[str, Any],
filter_expression: Any = None,
index_name: Optional[str] = None,
projection_expression: Optional[str] = None,
projection_expression: Optional[List[List[str]]] = None,
) -> Tuple[List[Item], int, Optional[Dict[str, Any]]]:
results = []
scanned_count = 0

View File

@ -556,11 +556,11 @@ class DynamoHandler(BaseResponse):
)
expression_attribute_names = expression_attribute_names or {}
projection_expression = self._adjust_projection_expression(
projection_expressions = self._adjust_projection_expression(
projection_expression, expression_attribute_names
)
item = self.dynamodb_backend.get_item(name, key, projection_expression)
item = self.dynamodb_backend.get_item(name, key, projection_expressions)
if item:
item_dict = item.describe_attrs(attributes=None)
return dynamo_json_dump(item_dict)
@ -608,14 +608,14 @@ class DynamoHandler(BaseResponse):
"ExpressionAttributeNames", {}
)
projection_expression = self._adjust_projection_expression(
projection_expressions = self._adjust_projection_expression(
projection_expression, expression_attribute_names
)
results["Responses"][table_name] = []
for key in keys:
item = self.dynamodb_backend.get_item(
table_name, key, projection_expression
table_name, key, projection_expressions
)
if item:
# A single operation can retrieve up to 16 MB of data [and] returns a partial result if the response size limit is exceeded
@ -652,7 +652,7 @@ class DynamoHandler(BaseResponse):
filter_expression = self._get_filter_expression()
expression_attribute_values = self.body.get("ExpressionAttributeValues", {})
projection_expression = self._adjust_projection_expression(
projection_expressions = self._adjust_projection_expression(
projection_expression, expression_attribute_names
)
@ -720,7 +720,7 @@ class DynamoHandler(BaseResponse):
limit,
exclusive_start_key,
scan_index_forward,
projection_expression,
projection_expressions,
index_name=index_name,
expr_names=expression_attribute_names,
expr_values=expression_attribute_values,
@ -743,27 +743,24 @@ class DynamoHandler(BaseResponse):
def _adjust_projection_expression(
self, projection_expression: Optional[str], expr_attr_names: Dict[str, str]
) -> Optional[str]:
) -> List[List[str]]:
"""
lvl1.lvl2.attr1,lvl1.attr2 --> [["lvl1", "lvl2", "attr1"], ["lvl1", "attr2]]
"""
def _adjust(expression: str) -> str:
return (
expr_attr_names[expression]
if expression in expr_attr_names
else expression
)
return (expr_attr_names or {}).get(expression, expression)
if projection_expression:
expressions = [x.strip() for x in projection_expression.split(",")]
for expression in expressions:
check_projection_expression(expression)
if expr_attr_names:
return ",".join(
[
".".join([_adjust(expr) for expr in nested_expr.split(".")])
for nested_expr in expressions
]
)
return [
[_adjust(expr) for expr in nested_expr.split(".")]
for nested_expr in expressions
]
return projection_expression
return []
@include_consumed_capacity()
def scan(self) -> str:
@ -786,7 +783,7 @@ class DynamoHandler(BaseResponse):
limit = self.body.get("Limit")
index_name = self.body.get("IndexName")
projection_expression = self._adjust_projection_expression(
projection_expressions = self._adjust_projection_expression(
projection_expression, expression_attribute_names
)
@ -800,7 +797,7 @@ class DynamoHandler(BaseResponse):
expression_attribute_names,
expression_attribute_values,
index_name,
projection_expression,
projection_expressions,
)
except ValueError as err:
raise MockValidationException(f"Bad Filter Expression: {err}")

View File

@ -34,17 +34,17 @@ class TestFindNestedKeys:
assert x == y
def test_find_nothing(self):
self._project("", result={})
self._project([[""]], result={})
def test_find_unknown_key(self):
self._project("unknown", result={})
self._project([["unknown"]], result={})
def test_project_single_key_string(self):
self._project("simplestring", result={"simplestring": "val"})
self._project([["simplestring"]], result={"simplestring": "val"})
def test_project_single_key_dict(self):
self._project(
"nesteddict",
[["nesteddict"]],
result={
"nesteddict": {
"level21": {"ll31": "val", "ll32": "val"},
@ -59,31 +59,31 @@ class TestFindNestedKeys:
def test_project_nested_key(self):
self._project(
"nesteddict.level21",
[["nesteddict", "level21"]],
result={"nesteddict": {"level21": {"ll31": "val", "ll32": "val"}}},
)
def test_project_multi_level_nested_key(self):
self._project(
"nesteddict.level21.ll32",
[["nesteddict", "level21", "ll32"]],
result={"nesteddict": {"level21": {"ll32": "val"}}},
)
def test_project_nested_key__partial_fix(self):
self._project("nesteddict.levelunknown", result={})
self._project([["nesteddict", "levelunknown"]], result={})
def test_project_nested_key__partial_fix2(self):
self._project("nesteddict.unknown.unknown2", result={})
self._project([["nesteddict", "unknown", "unknown2"]], result={})
def test_list_index(self):
self._project(
"rootlist[0]",
[["rootlist[0]"]],
result={"rootlist": [{"ll21": {"ll31": "val", "ll32": "val"}}]},
)
def test_nested_list_index(self):
self._project(
"nesteddict.nestedlist[1]",
[["nesteddict", "nestedlist[1]"]],
result={
"nesteddict": {"nestedlist": [{"ll22": {"ll31": "val", "ll32": "val"}}]}
},
@ -91,16 +91,16 @@ class TestFindNestedKeys:
def test_nested_obj_in_list(self):
self._project(
"nesteddict.nestedlist[1].ll22.ll31",
[["nesteddict", "nestedlist[1]", "ll22", "ll31"]],
result={"nesteddict": {"nestedlist": [{"ll22": {"ll31": "val"}}]}},
)
def test_list_unknown_indexes(self):
self._project("nesteddict.nestedlist[25]", result={})
self._project([["nesteddict", "nestedlist[25]"]], result={})
def test_multiple_projections(self):
self._project(
"nesteddict.nestedlist[1].ll22,rootlist[0]",
[["nesteddict", "nestedlist[1]", "ll22"], ["rootlist[0]"]],
result={
"nesteddict": {
"nestedlist": [{"ll22": {"ll31": "val", "ll32": "val"}}]

View File

@ -886,7 +886,7 @@ def test_nested_projection_expression_using_get_item_with_attr_expression():
"forum_name": "key1",
"nested": {
"level1": {"id": "id1", "att": "irrelevant"},
"level2": {"id": "id2", "include": "all"},
"level.2": {"id": "id2", "include": "all"},
"level3": {
"id": "irrelevant",
"children": [{"Name": "child_a"}, {"Name": "child_b"}],
@ -907,10 +907,10 @@ def test_nested_projection_expression_using_get_item_with_attr_expression():
result = table.get_item(
Key={"forum_name": "key1"},
ProjectionExpression="#nst.level1.id, #nst.#lvl2",
ExpressionAttributeNames={"#nst": "nested", "#lvl2": "level2"},
ExpressionAttributeNames={"#nst": "nested", "#lvl2": "level.2"},
)["Item"]
assert result == {
"nested": {"level1": {"id": "id1"}, "level2": {"id": "id2", "include": "all"}}
"nested": {"level1": {"id": "id1"}, "level.2": {"id": "id2", "include": "all"}}
}
# Assert actual data has not been deleted
result = table.get_item(Key={"forum_name": "key1"})["Item"]
@ -919,7 +919,7 @@ def test_nested_projection_expression_using_get_item_with_attr_expression():
"forum_name": "key1",
"nested": {
"level1": {"id": "id1", "att": "irrelevant"},
"level2": {"id": "id2", "include": "all"},
"level.2": {"id": "id2", "include": "all"},
"level3": {
"id": "irrelevant",
"children": [{"Name": "child_a"}, {"Name": "child_b"}],