Using Ops for dynamodb condition expressions

This commit is contained in:
Matthew Stevens 2019-04-01 16:48:00 -04:00 committed by Garrett Heel
parent 2712654518
commit 57b668c832
4 changed files with 58 additions and 31 deletions

View File

@ -59,7 +59,6 @@ def get_expected(expected):
values = [
AttributeValue(v)
for v in cond.get("AttributeValueList", [])]
print(path, values)
OpClass = ops[operator_name]
conditions.append(OpClass(path, *values))
@ -72,7 +71,6 @@ def get_expected(expected):
else:
return OpDefault(None, None)
print("EXPECTED:", expected, output)
return output
@ -486,7 +484,7 @@ class ConditionExpressionParser:
lhs = nodes.popleft()
comparator = nodes.popleft()
rhs = nodes.popleft()
output.append(self.Node(
nodes.appendleft(self.Node(
nonterminal=self.Nonterminal.CONDITION,
kind=self.Kind.COMPARISON,
text=" ".join([
@ -528,7 +526,7 @@ class ConditionExpressionParser:
self._assert(
False,
"Bad IN expression starting at", nodes)
output.append(self.Node(
nodes.appendleft(self.Node(
nonterminal=self.Nonterminal.CONDITION,
kind=self.Kind.IN,
text=" ".join([t.text for t in all_children]),
@ -553,7 +551,7 @@ class ConditionExpressionParser:
and_node = nodes.popleft()
high = nodes.popleft()
all_children = [lhs, between_node, low, and_node, high]
output.append(self.Node(
nodes.appendleft(self.Node(
nonterminal=self.Nonterminal.CONDITION,
kind=self.Kind.BETWEEN,
text=" ".join([t.text for t in all_children]),
@ -613,7 +611,7 @@ class ConditionExpressionParser:
nonterminal = self.Nonterminal.OPERAND
else:
nonterminal = self.Nonterminal.CONDITION
output.append(self.Node(
nodes.appendleft(self.Node(
nonterminal=nonterminal,
kind=self.Kind.FUNCTION,
text=" ".join([t.text for t in all_children]),
@ -685,7 +683,7 @@ class ConditionExpressionParser:
"Bad NOT expression", list(nodes)[:2])
not_node = nodes.popleft()
child = nodes.popleft()
output.append(self.Node(
nodes.appendleft(self.Node(
nonterminal=self.Nonterminal.CONDITION,
kind=self.Kind.NOT,
text=" ".join([not_node.text, child.text]),
@ -708,7 +706,7 @@ class ConditionExpressionParser:
and_node = nodes.popleft()
rhs = nodes.popleft()
all_children = [lhs, and_node, rhs]
output.append(self.Node(
nodes.appendleft(self.Node(
nonterminal=self.Nonterminal.CONDITION,
kind=self.Kind.AND,
text=" ".join([t.text for t in all_children]),
@ -731,7 +729,7 @@ class ConditionExpressionParser:
or_node = nodes.popleft()
rhs = nodes.popleft()
all_children = [lhs, or_node, rhs]
output.append(self.Node(
nodes.appendleft(self.Node(
nonterminal=self.Nonterminal.CONDITION,
kind=self.Kind.OR,
text=" ".join([t.text for t in all_children]),

View File

@ -537,7 +537,9 @@ class Table(BaseModel):
keys.append(range_key)
return keys
def put_item(self, item_attrs, expected=None, overwrite=False):
def put_item(self, item_attrs, expected=None, condition_expression=None,
expression_attribute_names=None,
expression_attribute_values=None, overwrite=False):
hash_value = DynamoType(item_attrs.get(self.hash_key_attr))
if self.has_range_key:
range_value = DynamoType(item_attrs.get(self.range_key_attr))
@ -562,6 +564,12 @@ class Table(BaseModel):
if not overwrite:
if not get_expected(expected).expr(current):
raise ValueError('The conditional request failed')
condition_op = get_filter_expression(
condition_expression,
expression_attribute_names,
expression_attribute_values)
if not condition_op.expr(current):
raise ValueError('The conditional request failed')
if range_value:
self.items[hash_value][range_value] = item
@ -907,11 +915,15 @@ class DynamoDBBackend(BaseBackend):
table.global_indexes = list(gsis_by_name.values())
return table
def put_item(self, table_name, item_attrs, expected=None, overwrite=False):
def put_item(self, table_name, item_attrs, expected=None,
condition_expression=None, expression_attribute_names=None,
expression_attribute_values=None, overwrite=False):
table = self.tables.get(table_name)
if not table:
return None
return table.put_item(item_attrs, expected, overwrite)
return table.put_item(item_attrs, expected, condition_expression,
expression_attribute_names,
expression_attribute_values, overwrite)
def get_table_keys_name(self, table_name, keys):
"""
@ -988,7 +1000,7 @@ class DynamoDBBackend(BaseBackend):
return table.scan(scan_filters, limit, exclusive_start_key, filter_expression, index_name)
def update_item(self, table_name, key, update_expression, attribute_updates, expression_attribute_names,
expression_attribute_values, expected=None):
expression_attribute_values, expected=None, condition_expression=None):
table = self.get_table(table_name)
if all([table.hash_key_attr in key, table.range_key_attr in key]):
@ -1012,6 +1024,12 @@ class DynamoDBBackend(BaseBackend):
if not get_expected(expected).expr(item):
raise ValueError('The conditional request failed')
condition_op = get_filter_expression(
condition_expression,
expression_attribute_names,
expression_attribute_values)
if not condition_op.expr(current):
raise ValueError('The conditional request failed')
# Update does not fail on new items, so create one
if item is None:

View File

@ -288,18 +288,18 @@ class DynamoHandler(BaseResponse):
# Attempt to parse simple ConditionExpressions into an Expected
# expression
if not expected:
condition_expression = self.body.get('ConditionExpression')
expression_attribute_names = self.body.get('ExpressionAttributeNames', {})
expression_attribute_values = self.body.get('ExpressionAttributeValues', {})
expected = condition_expression_to_expected(condition_expression,
expression_attribute_names,
expression_attribute_values)
if expected:
if condition_expression:
overwrite = False
try:
result = self.dynamodb_backend.put_item(name, item, expected, overwrite)
result = self.dynamodb_backend.put_item(
name, item, expected, condition_expression,
expression_attribute_names, expression_attribute_values,
overwrite)
except ValueError:
er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException'
return self.error(er, 'A condition specified in the operation could not be evaluated.')
@ -652,13 +652,9 @@ class DynamoHandler(BaseResponse):
# Attempt to parse simple ConditionExpressions into an Expected
# expression
if not expected:
condition_expression = self.body.get('ConditionExpression')
expression_attribute_names = self.body.get('ExpressionAttributeNames', {})
expression_attribute_values = self.body.get('ExpressionAttributeValues', {})
expected = condition_expression_to_expected(condition_expression,
expression_attribute_names,
expression_attribute_values)
# Support spaces between operators in an update expression
# E.g. `a = b + c` -> `a=b+c`
@ -669,7 +665,7 @@ class DynamoHandler(BaseResponse):
try:
item = self.dynamodb_backend.update_item(
name, key, update_expression, attribute_updates, expression_attribute_names,
expression_attribute_values, expected
expression_attribute_values, expected, condition_expression
)
except ValueError:
er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException'

View File

@ -1616,6 +1616,21 @@ def test_condition_expressions():
}
)
client.put_item(
TableName='test1',
Item={
'client': {'S': 'client1'},
'app': {'S': 'app1'},
'match': {'S': 'match'},
'existing': {'S': 'existing'},
},
ConditionExpression='attribute_exists(#nonexistent) OR attribute_exists(#existing)',
ExpressionAttributeNames={
'#nonexistent': 'nope',
'#existing': 'existing'
}
)
with assert_raises(client.exceptions.ConditionalCheckFailedException):
client.put_item(
TableName='test1',