diff --git a/moto/dynamodb2/comparisons.py b/moto/dynamodb2/comparisons.py index ac78d45ba..06d992602 100644 --- a/moto/dynamodb2/comparisons.py +++ b/moto/dynamodb2/comparisons.py @@ -19,6 +19,63 @@ def get_filter_expression(expr, names, values): return parser.parse() +def get_expected(expected): + """ + Parse a filter expression into an Op. + + Examples + expr = 'Id > 5 AND attribute_exists(test) AND Id BETWEEN 5 AND 6 OR length < 6 AND contains(test, 1) AND 5 IN (4,5, 6) OR (Id < 5 AND 5 > Id)' + expr = 'Id > 5 AND Subs < 7' + """ + ops = { + 'EQ': OpEqual, + 'NE': OpNotEqual, + 'LE': OpLessThanOrEqual, + 'LT': OpLessThan, + 'GE': OpGreaterThanOrEqual, + 'GT': OpGreaterThan, + 'NOT_NULL': FuncAttrExists, + 'NULL': FuncAttrNotExists, + 'CONTAINS': FuncContains, + 'NOT_CONTAINS': FuncNotContains, + 'BEGINS_WITH': FuncBeginsWith, + 'IN': FuncIn, + 'BETWEEN': FuncBetween, + } + + # NOTE: Always uses ConditionalOperator=AND + conditions = [] + for key, cond in expected.items(): + path = AttributePath([key]) + if 'Exists' in cond: + if cond['Exists']: + conditions.append(FuncAttrExists(path)) + else: + conditions.append(FuncAttrNotExists(path)) + elif 'Value' in cond: + conditions.append(OpEqual(path, AttributeValue(cond['Value']))) + elif 'ComparisonOperator' in cond: + operator_name = cond['ComparisonOperator'] + values = [ + AttributeValue(v) + for v in cond.get("AttributeValueList", [])] + print(path, values) + OpClass = ops[operator_name] + conditions.append(OpClass(path, *values)) + + # NOTE: Ignore ConditionalOperator + ConditionalOp = OpAnd + if conditions: + output = conditions[0] + for condition in conditions[1:]: + output = ConditionalOp(output, condition) + else: + return OpDefault(None, None) + + print("EXPECTED:", expected, output) + return output + + class Op(object): """ Base class for a FilterExpression operator @@ -782,14 +839,19 @@ class AttributePath(Operand): self.path = path def _get_attr(self, item): + if item is None: + return None + base = self.path[0] if base not in item.attrs: return None attr = item.attrs[base] + for name in self.path[1:]: attr = attr.child_attr(name) if attr is None: return None + return attr def expr(self, item): @@ -807,7 +869,7 @@ class AttributePath(Operand): return attr.type def __repr__(self): - return self.path + return ".".join(self.path) class AttributeValue(Operand): @@ -821,23 +883,27 @@ class AttributeValue(Operand): """ self.type = list(value.keys())[0] - if 'N' in value: - self.value = float(value['N']) - elif 'BOOL' in value: - self.value = value['BOOL'] - elif 'S' in value: - self.value = value['S'] - elif 'NS' in value: - self.value = tuple(value['NS']) - elif 'SS' in value: - self.value = tuple(value['SS']) - elif 'L' in value: - self.value = tuple(value['L']) - else: - # TODO: Handle all attribute types - raise NotImplementedError() + self.value = value[self.type] def expr(self, item): + # TODO: Reuse DynamoType code + if self.type == 'N': + try: + return int(self.value) + except ValueError: + return float(self.value) + elif self.type in ['SS', 'NS', 'BS']: + sub_type = self.type[0] + return set([AttributeValue({sub_type: v}).expr(item) + for v in self.value]) + elif self.type == 'L': + return [AttributeValue(v).expr(item) for v in self.value] + elif self.type == 'M': + return dict([ + (k, AttributeValue(v).expr(item)) + for k, v in self.value.items()]) + else: + return self.value return self.value def get_type(self, item): @@ -976,15 +1042,8 @@ class FuncAttrExists(Func): return self.attr.get_type(item) is not None -class FuncAttrNotExists(Func): - FUNC = 'attribute_not_exists' - - def __init__(self, attribute): - self.attr = attribute - super().__init__(attribute) - - def expr(self, item): - return self.attr.get_type(item) is None +def FuncAttrNotExists(attribute): + return OpNot(FuncAttrExists(attribute), None) class FuncAttrType(Func): @@ -1024,13 +1083,20 @@ class FuncContains(Func): super().__init__(attribute, operand) def expr(self, item): - if self.attr.get_type(item) in ('S', 'SS', 'NS', 'BS', 'L', 'M'): - return self.operand.expr(item) in self.attr.expr(item) + if self.attr.get_type(item) in ('S', 'SS', 'NS', 'BS', 'L'): + try: + return self.operand.expr(item) in self.attr.expr(item) + except TypeError: + return False return False +def FuncNotContains(attribute, operand): + return OpNot(FuncContains(attribute, operand), None) + + class FuncSize(Func): - FUNC = 'contains' + FUNC = 'size' def __init__(self, attribute): self.attr = attribute diff --git a/moto/dynamodb2/models.py b/moto/dynamodb2/models.py index 300479e9a..bdf59df1f 100644 --- a/moto/dynamodb2/models.py +++ b/moto/dynamodb2/models.py @@ -13,6 +13,9 @@ from moto.core import BaseBackend, BaseModel from moto.core.utils import unix_time from moto.core.exceptions import JsonRESTError from .comparisons import get_comparison_func, get_filter_expression, Op +from .comparisons import get_comparison_func +from .comparisons import get_filter_expression +from .comparisons import get_expected from .exceptions import InvalidIndexNameError @@ -557,29 +560,9 @@ class Table(BaseModel): self.range_key_type, item_attrs) if not overwrite: - if current is None: - current_attr = {} - elif hasattr(current, 'attrs'): - current_attr = current.attrs - else: - current_attr = current + if not get_expected(expected).expr(current): + raise ValueError('The conditional request failed') - for key, val in expected.items(): - if 'Exists' in val and val['Exists'] is False \ - or 'ComparisonOperator' in val and val['ComparisonOperator'] == 'NULL': - if key in current_attr: - raise ValueError("The conditional request failed") - elif key not in current_attr: - raise ValueError("The conditional request failed") - elif 'Value' in val and DynamoType(val['Value']).value != current_attr[key].value: - raise ValueError("The conditional request failed") - elif 'ComparisonOperator' in val: - dynamo_types = [ - DynamoType(ele) for ele in - val.get("AttributeValueList", []) - ] - if not current_attr[key].compare(val['ComparisonOperator'], dynamo_types): - raise ValueError('The conditional request failed') if range_value: self.items[hash_value][range_value] = item else: @@ -1024,32 +1007,11 @@ class DynamoDBBackend(BaseBackend): item = table.get_item(hash_value, range_value) - if item is None: - item_attr = {} - elif hasattr(item, 'attrs'): - item_attr = item.attrs - else: - item_attr = item - if not expected: expected = {} - for key, val in expected.items(): - if 'Exists' in val and val['Exists'] is False \ - or 'ComparisonOperator' in val and val['ComparisonOperator'] == 'NULL': - if key in item_attr: - raise ValueError("The conditional request failed") - elif key not in item_attr: - raise ValueError("The conditional request failed") - elif 'Value' in val and DynamoType(val['Value']).value != item_attr[key].value: - raise ValueError("The conditional request failed") - elif 'ComparisonOperator' in val: - dynamo_types = [ - DynamoType(ele) for ele in - val.get("AttributeValueList", []) - ] - if not item_attr[key].compare(val['ComparisonOperator'], dynamo_types): - raise ValueError('The conditional request failed') + if not get_expected(expected).expr(item): + raise ValueError('The conditional request failed') # Update does not fail on new items, so create one if item is None: