moto/moto/dynamodb2/parsing/validators.py
pvbouwel ec731ac901 Improve DDB expressions support4: Execution using AST
Part of structured approach for UpdateExpressions:
 1) The expression gets parsed into a token list (tokenized)
 2) The token list gets transformed into an expression tree (AST)
 3) The AST gets validated (full semantic correctness)
 4) The AST gets processed to perform the update -> this commit

This commit uses the AST to execute the UpdateExpression.
All the existing tests pass. The only tests that have been
updated are in test_dynamodb_table_with_range_key.py because
they wrongly allowed adding a set to a path that doesn't exist.
This has been aligned to correspond to the behavior of AWS
DynamoDB.

This commit will resolve https://github.com/spulec/moto/issues/2806
Multiple tests have been implemented that verify this.
2020-04-26 15:59:12 +01:00

369 lines
14 KiB
Python

"""
See the docstring of class Validator below for more details on validation
"""
from abc import abstractmethod
from copy import deepcopy
from moto.dynamodb2.exceptions import (
AttributeIsReservedKeyword,
ExpressionAttributeValueNotDefined,
AttributeDoesNotExist,
ExpressionAttributeNameNotDefined,
IncorrectOperandType,
InvalidUpdateExpressionInvalidDocumentPath,
ProvidedKeyDoesNotExist,
)
from moto.dynamodb2.models import DynamoType
from moto.dynamodb2.parsing.ast_nodes import (
ExpressionAttribute,
UpdateExpressionPath,
UpdateExpressionSetAction,
UpdateExpressionAddAction,
UpdateExpressionDeleteAction,
UpdateExpressionRemoveAction,
DDBTypedValue,
ExpressionAttributeValue,
ExpressionAttributeName,
DepthFirstTraverser,
NoneExistingPath,
UpdateExpressionFunction,
ExpressionPathDescender,
UpdateExpressionValue,
ExpressionValueOperator,
ExpressionSelector,
)
from moto.dynamodb2.parsing.reserved_keywords import ReservedKeywords
class ExpressionAttributeValueProcessor(DepthFirstTraverser):
    """
    Replace every ExpressionAttributeValue node in the AST with a DDBTypedValue wrapping the
    value that was supplied for it in ExpressionAttributeValues.
    """

    def __init__(self, expression_attribute_values):
        # Lookup table: value placeholder name -> raw value supplied by the caller.
        self.expression_attribute_values = expression_attribute_values

    def _processing_map(self):
        return {
            ExpressionAttributeValue: self.replace_expression_attribute_value_with_value
        }

    def replace_expression_attribute_value_with_value(self, node):
        """
        Resolve an ExpressionAttributeValue placeholder to the value it refers to.

        Args:
            node(ExpressionAttributeValue): the placeholder node being replaced.

        Returns:
            DDBTypedValue: the supplied value wrapped in a DynamoType.

        Raises:
            ExpressionAttributeValueNotDefined: the placeholder has no entry in ExpressionAttributeValues.
        """
        assert isinstance(node, ExpressionAttributeValue)
        value_name = node.get_value_name()
        try:
            raw_value = self.expression_attribute_values[value_name]
        except KeyError:
            raise ExpressionAttributeValueNotDefined(attribute_value=value_name)
        wrapped = DynamoType(raw_value)
        return DDBTypedValue(wrapped)
class ExpressionPathResolver(object):
    """
    Resolve the nodes of an UpdateExpressionPath against the attributes of an item.
    """

    def __init__(self, expression_attribute_names):
        # Lookup table: ExpressionAttributeName placeholder -> real attribute name.
        self.expression_attribute_names = expression_attribute_names

    @classmethod
    def raise_exception_if_keyword(cls, attribute):
        """Raise AttributeIsReservedKeyword if `attribute` (compared upper-cased) is a reserved keyword."""
        if attribute.upper() in ReservedKeywords.get_reserved_keywords():
            raise AttributeIsReservedKeyword(attribute)

    def resolve_expression_path(self, item, update_expression_path):
        """Resolve the children of `update_expression_path` against `item` (see resolve_expression_path_nodes)."""
        assert isinstance(update_expression_path, UpdateExpressionPath)
        return self.resolve_expression_path_nodes(item, update_expression_path.children)

    def resolve_expression_path_nodes(self, item, update_expression_path_nodes):
        """
        Walk the given path nodes, starting from `item.attrs`.

        Args:
            item(Item): the item whose attributes are traversed.
            update_expression_path_nodes(list): attribute, attribute-name placeholder, descender and
                index-selector nodes making up the path.

        Returns:
            DDBTypedValue: wrapping the value found at the end of the path, or
            NoneExistingPath: when the path does not exist in the item. It is created with
                `creatable=True` only when the missing element is the final one, because an assignment can
                create that element (or append to a list when the index is out of bounds).

        Raises:
            ExpressionAttributeNameNotDefined: a placeholder is missing from ExpressionAttributeNames.
            AttributeIsReservedKeyword: a literal attribute name is a reserved keyword.
            InvalidUpdateExpressionInvalidDocumentPath: an index selector is applied to a non-list value.
            NotImplementedError: an unsupported path node type is encountered.
        """
        target = item.attrs
        last_position = len(update_expression_path_nodes) - 1
        for position, child in enumerate(update_expression_path_nodes):
            # Identify the final path element by position rather than with `==`, so the
            # creatable/non-creatable decision does not depend on how path nodes implement equality.
            is_last = position == last_position

            # First replace placeholder with attribute_name
            attr_name = None
            if isinstance(child, ExpressionAttributeName):
                attr_placeholder = child.get_attribute_name_placeholder()
                try:
                    attr_name = self.expression_attribute_names[attr_placeholder]
                except KeyError:
                    raise ExpressionAttributeNameNotDefined(attr_placeholder)
            elif isinstance(child, ExpressionAttribute):
                attr_name = child.get_attribute_name()
                self.raise_exception_if_keyword(attr_name)

            if attr_name is not None:
                # Resolve attribute_name; TypeError covers indexing into a value that is not a map.
                try:
                    target = target[attr_name]
                except (KeyError, TypeError):
                    if is_last:
                        return NoneExistingPath(creatable=True)
                    return NoneExistingPath()
            elif isinstance(child, ExpressionPathDescender):
                continue
            elif isinstance(child, ExpressionSelector):
                index = child.get_index()
                if not target.is_list():
                    raise InvalidUpdateExpressionInvalidDocumentPath()
                try:
                    target = target[index]
                except IndexError:
                    # When a list goes out of bounds when assigning that is no problem when at the
                    # assignment side. It will just append to the list.
                    if is_last:
                        return NoneExistingPath(creatable=True)
                    return NoneExistingPath()
            else:
                raise NotImplementedError(
                    "Path resolution for {t}".format(t=type(child))
                )
        return DDBTypedValue(target)

    def resolve_expression_path_nodes_to_dynamo_type(
        self, item, update_expression_path_nodes
    ):
        """
        Resolve the path nodes and return the raw value found at the end of the path.

        Raises:
            ProvidedKeyDoesNotExist: if the path does not exist in the item.
        """
        node = self.resolve_expression_path_nodes(item, update_expression_path_nodes)
        if isinstance(node, NoneExistingPath):
            raise ProvidedKeyDoesNotExist()
        assert isinstance(node, DDBTypedValue)
        return node.get_value()
class ExpressionAttributeResolvingProcessor(DepthFirstTraverser):
    """
    Resolve UpdateExpressionPath nodes against the item being updated.

    Paths that appear as operands of an action are replaced by the value they point to. The path that
    an action acts upon (its first child) is left in place. It is still resolved, to check that it either
    exists or can be created.
    """

    def _processing_map(self):
        return {
            UpdateExpressionSetAction: self.disable_resolving,
            UpdateExpressionPath: self.process_expression_path_node,
        }

    def __init__(self, expression_attribute_names, item):
        self.expression_attribute_names = expression_attribute_names
        self.item = item
        # When True, path nodes are replaced by their resolved values.
        self.resolving = False

    def pre_processing_of_child(self, parent_node, child_id):
        """
        Before visiting a child of a SET/REMOVE/DELETE/ADD action, decide whether paths get resolved.

        Child 0 is the path being acted upon and must stay a path. Any later child is an operand whose
        value is needed, so resolving is switched on for it.
        """
        action_types = (
            UpdateExpressionSetAction,
            UpdateExpressionRemoveAction,
            UpdateExpressionDeleteAction,
            UpdateExpressionAddAction,
        )
        if not isinstance(parent_node, action_types):
            return
        self.resolving = child_id != 0

    def disable_resolving(self, node=None):
        """Switch resolving off and hand back `node` unchanged (mapped to UpdateExpressionSetAction nodes)."""
        self.resolving = False
        return node

    def process_expression_path_node(self, node):
        """
        Handle an UpdateExpressionPath node.

        While resolving, the node is replaced by the result of resolving it against the item. Otherwise
        the original node is kept, but only after verifying that the path is present or creatable.

        Raises:
            InvalidUpdateExpressionInvalidDocumentPath: a non-resolved path is missing and not creatable.
        """
        resolved = self.resolve_expression_path(node)
        if self.resolving:
            return resolved
        if isinstance(resolved, NoneExistingPath) and not resolved.is_creatable():
            raise InvalidUpdateExpressionInvalidDocumentPath()
        return node

    def resolve_expression_path(self, node):
        """Resolve `node` against this processor's item using its ExpressionAttributeNames."""
        resolver = ExpressionPathResolver(self.expression_attribute_names)
        return resolver.resolve_expression_path(self.item, node)
class UpdateExpressionFunctionEvaluator(DepthFirstTraverser):
    """
    At time of writing there are only 2 functions for DDB UpdateExpressions. They both are specific to the SET
    expression as per the official AWS docs:
    https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/
    Expressions.UpdateExpressions.html#Expressions.UpdateExpressions.SET
    """

    def _processing_map(self):
        return {UpdateExpressionFunction: self.process_function}

    def process_function(self, node):
        """
        Evaluate an UpdateExpressionFunction node whose arguments have already been resolved.

        Args:
            node(UpdateExpressionFunction): the function call to evaluate.

        Returns:
            DDBTypedValue or NoneExistingPath: for if_not_exists, the first argument when it exists and the
                second one otherwise. For list_append, a new list holding the elements of the first list
                followed by copies of the elements of the second.

        Raises:
            AttributeDoesNotExist: an argument of list_append refers to a path absent from the item.
            IncorrectOperandType: an argument of list_append is not a list.
            NotImplementedError: the function name is not supported.
        """
        assert isinstance(node, UpdateExpressionFunction)
        function_name = node.get_function_name()
        first_arg = node.get_nth_argument(1)
        second_arg = node.get_nth_argument(2)

        if function_name == "if_not_exists":
            if isinstance(first_arg, NoneExistingPath):
                result = second_arg
            else:
                result = first_arg
            assert isinstance(result, (DDBTypedValue, NoneExistingPath))
            return result
        elif function_name == "list_append":
            # list_append cannot operate on a missing attribute. Report it as a validation error instead of
            # tripping the DDBTypedValue assertion in get_list_from_ddb_typed_value.
            if isinstance(first_arg, NoneExistingPath) or isinstance(
                second_arg, NoneExistingPath
            ):
                raise AttributeDoesNotExist()
            # The first list is copied so the item's own attribute is never mutated in place.
            first_list = deepcopy(
                self.get_list_from_ddb_typed_value(first_arg, function_name)
            )
            second_list = self.get_list_from_ddb_typed_value(second_arg, function_name)
            for list_element in second_list.value:
                # Copy appended elements as well, so the result never shares DynamoType objects with the
                # attribute the second argument was resolved from.
                first_list.value.append(deepcopy(list_element))
            return DDBTypedValue(first_list)
        else:
            raise NotImplementedError(
                "Unsupported function for moto {name}".format(name=function_name)
            )

    @classmethod
    def get_list_from_ddb_typed_value(cls, node, function_name):
        """
        Unwrap a DDBTypedValue that must hold a list.

        Raises:
            IncorrectOperandType: the wrapped DynamoType is not a list.
        """
        assert isinstance(node, DDBTypedValue)
        dynamo_value = node.get_value()
        assert isinstance(dynamo_value, DynamoType)
        if not dynamo_value.is_list():
            raise IncorrectOperandType(function_name, dynamo_value.type)
        return dynamo_value
class NoneExistingPathChecker(DepthFirstTraverser):
    """
    Traverse the AST and reject it if any NoneExistingPath node (a reference to an attribute that the
    item does not have) is still present.
    """

    def _processing_map(self):
        handlers = {NoneExistingPath: self.raise_none_existing_path}
        return handlers

    def raise_none_existing_path(self, node):
        """Always raise AttributeDoesNotExist; `node` itself is not inspected."""
        raise AttributeDoesNotExist()
class ExecuteOperations(DepthFirstTraverser):
    """Collapse UpdateExpressionValue nodes, evaluating binary `+` and `-` operations."""

    def _processing_map(self):
        return {UpdateExpressionValue: self.process_update_expression_value}

    def process_update_expression_value(self, node):
        """
        Replace an UpdateExpressionValue node by its result.

        A node with one child is replaced by that child. A node with three children is evaluated as
        `<left> <operator> <right>`, where the middle child is an ExpressionValueOperator.

        Args:
            node(Node): the UpdateExpressionValue to evaluate.

        Returns:
            Node: the single child, or a DDBTypedValue holding the result of the operation.

        Raises:
            IncorrectOperandType: the operands are not supported by the operator.
            NotImplementedError: unknown operator, or a child count other than 1 or 3.
        """
        assert isinstance(node, UpdateExpressionValue)
        children = node.children
        child_count = len(children)
        if child_count == 1:
            return children[0]
        if child_count != 3:
            raise NotImplementedError(
                "UpdateExpressionValue only has implementations for 1 or 3 children."
            )

        left_node, operator_node, right_node = children
        assert isinstance(operator_node, ExpressionValueOperator)
        operator = operator_node.get_operator()
        left_operand = self.get_dynamo_value_from_ddb_typed_value(left_node)
        right_operand = self.get_dynamo_value_from_ddb_typed_value(right_node)

        if operator == "+":
            return self.get_sum(left_operand, right_operand)
        if operator == "-":
            return self.get_subtraction(left_operand, right_operand)
        raise NotImplementedError(
            "Moto does not support operator {operator}".format(operator=operator)
        )

    @classmethod
    def get_dynamo_value_from_ddb_typed_value(cls, node):
        """Return the DynamoType wrapped by a DDBTypedValue node."""
        assert isinstance(node, DDBTypedValue)
        wrapped = node.get_value()
        assert isinstance(wrapped, DynamoType)
        return wrapped

    @classmethod
    def get_sum(cls, left_operand, right_operand):
        """
        Add two DynamoType values.

        Args:
            left_operand(DynamoType):
            right_operand(DynamoType):

        Returns:
            DDBTypedValue: the sum.

        Raises:
            IncorrectOperandType: the addition raised TypeError (reported with the left operand's type).
        """
        try:
            outcome = DDBTypedValue(left_operand + right_operand)
        except TypeError:
            raise IncorrectOperandType("+", left_operand.type)
        return outcome

    @classmethod
    def get_subtraction(cls, left_operand, right_operand):
        """
        Subtract one DynamoType value from another.

        Args:
            left_operand(DynamoType):
            right_operand(DynamoType):

        Returns:
            DDBTypedValue: the difference.

        Raises:
            IncorrectOperandType: the subtraction raised TypeError (reported with the left operand's type).
        """
        try:
            outcome = DDBTypedValue(left_operand - right_operand)
        except TypeError:
            raise IncorrectOperandType("-", left_operand.type)
        return outcome
class Validator(object):
    """
    A validator is used to validate expressions which are passed in as an AST.

    Subclasses provide the processors via get_ast_processors(). validate() runs them in order over a
    private copy of the expression and returns the transformed tree.
    """

    def __init__(
        self, expression, expression_attribute_names, expression_attribute_values, item
    ):
        """
        Besides validation the Validator should also replace referenced parts of an item which is cheapest upon
        validation.

        Args:
            expression(Node): The root node of the AST representing the expression to be validated
            expression_attribute_names(ExpressionAttributeNames):
            expression_attribute_values(ExpressionAttributeValues):
            item(Item): The item which will be updated (pointed to by Key of update_item)
        """
        self.expression_attribute_names = expression_attribute_names
        self.expression_attribute_values = expression_attribute_values
        self.item = item
        self.processors = self.get_ast_processors()
        # Processors replace nodes as they go; work on a copy so the caller's AST is never altered.
        self.node_to_validate = deepcopy(expression)

    @abstractmethod
    def get_ast_processors(self):
        """
        Get the different processors that go through the AST tree and processes the nodes.

        Must be overridden. This class does not use ABCMeta, so @abstractmethod is not enforced at
        instantiation. Raising here makes a missing override fail immediately and clearly, instead of
        leaving `self.processors` as None and failing later inside validate().
        """
        raise NotImplementedError(
            "{cls} must implement get_ast_processors".format(cls=type(self).__name__)
        )

    def validate(self):
        """
        Apply every processor, in order, to the copied AST.

        Returns:
            Node: the AST as returned by the last processor's traverse().
        """
        node = self.node_to_validate
        for processor in self.processors:
            node = processor.traverse(node)
        return node
class UpdateExpressionValidator(Validator):
    """Validator for UpdateExpression ASTs."""

    def get_ast_processors(self):
        """
        Build the processors for an UpdateExpression, listed in the order Validator.validate applies them:
        substitute attribute values, resolve paths against the item, evaluate functions, reject leftover
        non-existing paths, then evaluate + / - operations.
        """
        value_processor = ExpressionAttributeValueProcessor(
            self.expression_attribute_values
        )
        path_processor = ExpressionAttributeResolvingProcessor(
            self.expression_attribute_names, self.item
        )
        return [
            value_processor,
            path_processor,
            UpdateExpressionFunctionEvaluator(),
            NoneExistingPathChecker(),
            ExecuteOperations(),
        ]