Using Ops for dynamodb condition expressions
This commit is contained in:
parent
2712654518
commit
57b668c832
@ -59,7 +59,6 @@ def get_expected(expected):
|
|||||||
values = [
|
values = [
|
||||||
AttributeValue(v)
|
AttributeValue(v)
|
||||||
for v in cond.get("AttributeValueList", [])]
|
for v in cond.get("AttributeValueList", [])]
|
||||||
print(path, values)
|
|
||||||
OpClass = ops[operator_name]
|
OpClass = ops[operator_name]
|
||||||
conditions.append(OpClass(path, *values))
|
conditions.append(OpClass(path, *values))
|
||||||
|
|
||||||
@ -72,7 +71,6 @@ def get_expected(expected):
|
|||||||
else:
|
else:
|
||||||
return OpDefault(None, None)
|
return OpDefault(None, None)
|
||||||
|
|
||||||
print("EXPECTED:", expected, output)
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
@ -486,7 +484,7 @@ class ConditionExpressionParser:
|
|||||||
lhs = nodes.popleft()
|
lhs = nodes.popleft()
|
||||||
comparator = nodes.popleft()
|
comparator = nodes.popleft()
|
||||||
rhs = nodes.popleft()
|
rhs = nodes.popleft()
|
||||||
output.append(self.Node(
|
nodes.appendleft(self.Node(
|
||||||
nonterminal=self.Nonterminal.CONDITION,
|
nonterminal=self.Nonterminal.CONDITION,
|
||||||
kind=self.Kind.COMPARISON,
|
kind=self.Kind.COMPARISON,
|
||||||
text=" ".join([
|
text=" ".join([
|
||||||
@ -528,7 +526,7 @@ class ConditionExpressionParser:
|
|||||||
self._assert(
|
self._assert(
|
||||||
False,
|
False,
|
||||||
"Bad IN expression starting at", nodes)
|
"Bad IN expression starting at", nodes)
|
||||||
output.append(self.Node(
|
nodes.appendleft(self.Node(
|
||||||
nonterminal=self.Nonterminal.CONDITION,
|
nonterminal=self.Nonterminal.CONDITION,
|
||||||
kind=self.Kind.IN,
|
kind=self.Kind.IN,
|
||||||
text=" ".join([t.text for t in all_children]),
|
text=" ".join([t.text for t in all_children]),
|
||||||
@ -553,7 +551,7 @@ class ConditionExpressionParser:
|
|||||||
and_node = nodes.popleft()
|
and_node = nodes.popleft()
|
||||||
high = nodes.popleft()
|
high = nodes.popleft()
|
||||||
all_children = [lhs, between_node, low, and_node, high]
|
all_children = [lhs, between_node, low, and_node, high]
|
||||||
output.append(self.Node(
|
nodes.appendleft(self.Node(
|
||||||
nonterminal=self.Nonterminal.CONDITION,
|
nonterminal=self.Nonterminal.CONDITION,
|
||||||
kind=self.Kind.BETWEEN,
|
kind=self.Kind.BETWEEN,
|
||||||
text=" ".join([t.text for t in all_children]),
|
text=" ".join([t.text for t in all_children]),
|
||||||
@ -613,7 +611,7 @@ class ConditionExpressionParser:
|
|||||||
nonterminal = self.Nonterminal.OPERAND
|
nonterminal = self.Nonterminal.OPERAND
|
||||||
else:
|
else:
|
||||||
nonterminal = self.Nonterminal.CONDITION
|
nonterminal = self.Nonterminal.CONDITION
|
||||||
output.append(self.Node(
|
nodes.appendleft(self.Node(
|
||||||
nonterminal=nonterminal,
|
nonterminal=nonterminal,
|
||||||
kind=self.Kind.FUNCTION,
|
kind=self.Kind.FUNCTION,
|
||||||
text=" ".join([t.text for t in all_children]),
|
text=" ".join([t.text for t in all_children]),
|
||||||
@ -685,7 +683,7 @@ class ConditionExpressionParser:
|
|||||||
"Bad NOT expression", list(nodes)[:2])
|
"Bad NOT expression", list(nodes)[:2])
|
||||||
not_node = nodes.popleft()
|
not_node = nodes.popleft()
|
||||||
child = nodes.popleft()
|
child = nodes.popleft()
|
||||||
output.append(self.Node(
|
nodes.appendleft(self.Node(
|
||||||
nonterminal=self.Nonterminal.CONDITION,
|
nonterminal=self.Nonterminal.CONDITION,
|
||||||
kind=self.Kind.NOT,
|
kind=self.Kind.NOT,
|
||||||
text=" ".join([not_node.text, child.text]),
|
text=" ".join([not_node.text, child.text]),
|
||||||
@ -708,7 +706,7 @@ class ConditionExpressionParser:
|
|||||||
and_node = nodes.popleft()
|
and_node = nodes.popleft()
|
||||||
rhs = nodes.popleft()
|
rhs = nodes.popleft()
|
||||||
all_children = [lhs, and_node, rhs]
|
all_children = [lhs, and_node, rhs]
|
||||||
output.append(self.Node(
|
nodes.appendleft(self.Node(
|
||||||
nonterminal=self.Nonterminal.CONDITION,
|
nonterminal=self.Nonterminal.CONDITION,
|
||||||
kind=self.Kind.AND,
|
kind=self.Kind.AND,
|
||||||
text=" ".join([t.text for t in all_children]),
|
text=" ".join([t.text for t in all_children]),
|
||||||
@ -731,7 +729,7 @@ class ConditionExpressionParser:
|
|||||||
or_node = nodes.popleft()
|
or_node = nodes.popleft()
|
||||||
rhs = nodes.popleft()
|
rhs = nodes.popleft()
|
||||||
all_children = [lhs, or_node, rhs]
|
all_children = [lhs, or_node, rhs]
|
||||||
output.append(self.Node(
|
nodes.appendleft(self.Node(
|
||||||
nonterminal=self.Nonterminal.CONDITION,
|
nonterminal=self.Nonterminal.CONDITION,
|
||||||
kind=self.Kind.OR,
|
kind=self.Kind.OR,
|
||||||
text=" ".join([t.text for t in all_children]),
|
text=" ".join([t.text for t in all_children]),
|
||||||
|
@ -537,7 +537,9 @@ class Table(BaseModel):
|
|||||||
keys.append(range_key)
|
keys.append(range_key)
|
||||||
return keys
|
return keys
|
||||||
|
|
||||||
def put_item(self, item_attrs, expected=None, overwrite=False):
|
def put_item(self, item_attrs, expected=None, condition_expression=None,
|
||||||
|
expression_attribute_names=None,
|
||||||
|
expression_attribute_values=None, overwrite=False):
|
||||||
hash_value = DynamoType(item_attrs.get(self.hash_key_attr))
|
hash_value = DynamoType(item_attrs.get(self.hash_key_attr))
|
||||||
if self.has_range_key:
|
if self.has_range_key:
|
||||||
range_value = DynamoType(item_attrs.get(self.range_key_attr))
|
range_value = DynamoType(item_attrs.get(self.range_key_attr))
|
||||||
@ -562,6 +564,12 @@ class Table(BaseModel):
|
|||||||
if not overwrite:
|
if not overwrite:
|
||||||
if not get_expected(expected).expr(current):
|
if not get_expected(expected).expr(current):
|
||||||
raise ValueError('The conditional request failed')
|
raise ValueError('The conditional request failed')
|
||||||
|
condition_op = get_filter_expression(
|
||||||
|
condition_expression,
|
||||||
|
expression_attribute_names,
|
||||||
|
expression_attribute_values)
|
||||||
|
if not condition_op.expr(current):
|
||||||
|
raise ValueError('The conditional request failed')
|
||||||
|
|
||||||
if range_value:
|
if range_value:
|
||||||
self.items[hash_value][range_value] = item
|
self.items[hash_value][range_value] = item
|
||||||
@ -907,11 +915,15 @@ class DynamoDBBackend(BaseBackend):
|
|||||||
table.global_indexes = list(gsis_by_name.values())
|
table.global_indexes = list(gsis_by_name.values())
|
||||||
return table
|
return table
|
||||||
|
|
||||||
def put_item(self, table_name, item_attrs, expected=None, overwrite=False):
|
def put_item(self, table_name, item_attrs, expected=None,
|
||||||
|
condition_expression=None, expression_attribute_names=None,
|
||||||
|
expression_attribute_values=None, overwrite=False):
|
||||||
table = self.tables.get(table_name)
|
table = self.tables.get(table_name)
|
||||||
if not table:
|
if not table:
|
||||||
return None
|
return None
|
||||||
return table.put_item(item_attrs, expected, overwrite)
|
return table.put_item(item_attrs, expected, condition_expression,
|
||||||
|
expression_attribute_names,
|
||||||
|
expression_attribute_values, overwrite)
|
||||||
|
|
||||||
def get_table_keys_name(self, table_name, keys):
|
def get_table_keys_name(self, table_name, keys):
|
||||||
"""
|
"""
|
||||||
@ -988,7 +1000,7 @@ class DynamoDBBackend(BaseBackend):
|
|||||||
return table.scan(scan_filters, limit, exclusive_start_key, filter_expression, index_name)
|
return table.scan(scan_filters, limit, exclusive_start_key, filter_expression, index_name)
|
||||||
|
|
||||||
def update_item(self, table_name, key, update_expression, attribute_updates, expression_attribute_names,
|
def update_item(self, table_name, key, update_expression, attribute_updates, expression_attribute_names,
|
||||||
expression_attribute_values, expected=None):
|
expression_attribute_values, expected=None, condition_expression=None):
|
||||||
table = self.get_table(table_name)
|
table = self.get_table(table_name)
|
||||||
|
|
||||||
if all([table.hash_key_attr in key, table.range_key_attr in key]):
|
if all([table.hash_key_attr in key, table.range_key_attr in key]):
|
||||||
@ -1012,6 +1024,12 @@ class DynamoDBBackend(BaseBackend):
|
|||||||
|
|
||||||
if not get_expected(expected).expr(item):
|
if not get_expected(expected).expr(item):
|
||||||
raise ValueError('The conditional request failed')
|
raise ValueError('The conditional request failed')
|
||||||
|
condition_op = get_filter_expression(
|
||||||
|
condition_expression,
|
||||||
|
expression_attribute_names,
|
||||||
|
expression_attribute_values)
|
||||||
|
if not condition_op.expr(current):
|
||||||
|
raise ValueError('The conditional request failed')
|
||||||
|
|
||||||
# Update does not fail on new items, so create one
|
# Update does not fail on new items, so create one
|
||||||
if item is None:
|
if item is None:
|
||||||
|
@ -288,18 +288,18 @@ class DynamoHandler(BaseResponse):
|
|||||||
|
|
||||||
# Attempt to parse simple ConditionExpressions into an Expected
|
# Attempt to parse simple ConditionExpressions into an Expected
|
||||||
# expression
|
# expression
|
||||||
if not expected:
|
condition_expression = self.body.get('ConditionExpression')
|
||||||
condition_expression = self.body.get('ConditionExpression')
|
expression_attribute_names = self.body.get('ExpressionAttributeNames', {})
|
||||||
expression_attribute_names = self.body.get('ExpressionAttributeNames', {})
|
expression_attribute_values = self.body.get('ExpressionAttributeValues', {})
|
||||||
expression_attribute_values = self.body.get('ExpressionAttributeValues', {})
|
|
||||||
expected = condition_expression_to_expected(condition_expression,
|
if condition_expression:
|
||||||
expression_attribute_names,
|
overwrite = False
|
||||||
expression_attribute_values)
|
|
||||||
if expected:
|
|
||||||
overwrite = False
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = self.dynamodb_backend.put_item(name, item, expected, overwrite)
|
result = self.dynamodb_backend.put_item(
|
||||||
|
name, item, expected, condition_expression,
|
||||||
|
expression_attribute_names, expression_attribute_values,
|
||||||
|
overwrite)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException'
|
er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException'
|
||||||
return self.error(er, 'A condition specified in the operation could not be evaluated.')
|
return self.error(er, 'A condition specified in the operation could not be evaluated.')
|
||||||
@ -652,13 +652,9 @@ class DynamoHandler(BaseResponse):
|
|||||||
|
|
||||||
# Attempt to parse simple ConditionExpressions into an Expected
|
# Attempt to parse simple ConditionExpressions into an Expected
|
||||||
# expression
|
# expression
|
||||||
if not expected:
|
condition_expression = self.body.get('ConditionExpression')
|
||||||
condition_expression = self.body.get('ConditionExpression')
|
expression_attribute_names = self.body.get('ExpressionAttributeNames', {})
|
||||||
expression_attribute_names = self.body.get('ExpressionAttributeNames', {})
|
expression_attribute_values = self.body.get('ExpressionAttributeValues', {})
|
||||||
expression_attribute_values = self.body.get('ExpressionAttributeValues', {})
|
|
||||||
expected = condition_expression_to_expected(condition_expression,
|
|
||||||
expression_attribute_names,
|
|
||||||
expression_attribute_values)
|
|
||||||
|
|
||||||
# Support spaces between operators in an update expression
|
# Support spaces between operators in an update expression
|
||||||
# E.g. `a = b + c` -> `a=b+c`
|
# E.g. `a = b + c` -> `a=b+c`
|
||||||
@ -669,7 +665,7 @@ class DynamoHandler(BaseResponse):
|
|||||||
try:
|
try:
|
||||||
item = self.dynamodb_backend.update_item(
|
item = self.dynamodb_backend.update_item(
|
||||||
name, key, update_expression, attribute_updates, expression_attribute_names,
|
name, key, update_expression, attribute_updates, expression_attribute_names,
|
||||||
expression_attribute_values, expected
|
expression_attribute_values, expected, condition_expression
|
||||||
)
|
)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException'
|
er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException'
|
||||||
|
@ -1616,6 +1616,21 @@ def test_condition_expressions():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
client.put_item(
|
||||||
|
TableName='test1',
|
||||||
|
Item={
|
||||||
|
'client': {'S': 'client1'},
|
||||||
|
'app': {'S': 'app1'},
|
||||||
|
'match': {'S': 'match'},
|
||||||
|
'existing': {'S': 'existing'},
|
||||||
|
},
|
||||||
|
ConditionExpression='attribute_exists(#nonexistent) OR attribute_exists(#existing)',
|
||||||
|
ExpressionAttributeNames={
|
||||||
|
'#nonexistent': 'nope',
|
||||||
|
'#existing': 'existing'
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
with assert_raises(client.exceptions.ConditionalCheckFailedException):
|
with assert_raises(client.exceptions.ConditionalCheckFailedException):
|
||||||
client.put_item(
|
client.put_item(
|
||||||
TableName='test1',
|
TableName='test1',
|
||||||
|
Loading…
x
Reference in New Issue
Block a user