diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f28083221..40da55ccf 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -2,6 +2,10 @@ Moto has a [Code of Conduct](https://github.com/spulec/moto/blob/master/CODE_OF_CONDUCT.md), you can expect to be treated with respect at all times when interacting with this project. +## Running the tests locally + +Moto has a Makefile which has some helpful commands for getting setup. You should be able to run `make init` to install the dependencies and then `make test` to run the tests. + ## Is there a missing feature? Moto is easier to contribute to than you probably think. There's [a list of which endpoints have been implemented](https://github.com/spulec/moto/blob/master/IMPLEMENTATION_COVERAGE.md) and we invite you to add new endpoints to existing services or to add new services. diff --git a/Makefile b/Makefile index de08c6f74..2a7249760 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ endif init: @python setup.py develop - @pip install -r requirements.txt + @pip install -r requirements-dev.txt lint: flake8 moto diff --git a/moto/__init__.py b/moto/__init__.py index 9c974f00d..a6f35069e 100644 --- a/moto/__init__.py +++ b/moto/__init__.py @@ -3,7 +3,7 @@ import logging # logging.getLogger('boto').setLevel(logging.CRITICAL) __title__ = 'moto' -__version__ = '1.3.9' +__version__ = '1.3.11' from .acm import mock_acm # flake8: noqa from .apigateway import mock_apigateway, mock_apigateway_deprecated # flake8: noqa diff --git a/moto/awslambda/models.py b/moto/awslambda/models.py index 8dfa4724a..784d86b0b 100644 --- a/moto/awslambda/models.py +++ b/moto/awslambda/models.py @@ -231,6 +231,10 @@ class LambdaFunction(BaseModel): config.update({"VpcId": "vpc-123abc"}) return config + @property + def physical_resource_id(self): + return self.function_name + def __repr__(self): return json.dumps(self.get_configuration()) diff --git a/moto/dynamodb2/comparisons.py b/moto/dynamodb2/comparisons.py index 6d37345fe..1a4633e64 100644 --- a/moto/dynamodb2/comparisons.py +++ b/moto/dynamodb2/comparisons.py @@ -1,6 +1,94 @@ from __future__ import unicode_literals import re import six +import re +from collections import deque +from collections import namedtuple + + +def get_filter_expression(expr, names, values): + """ + Parse a filter expression into an Op. + + Examples + expr = 'Id > 5 AND attribute_exists(test) AND Id BETWEEN 5 AND 6 OR length < 6 AND contains(test, 1) AND 5 IN (4,5, 6) OR (Id < 5 AND 5 > Id)' + expr = 'Id > 5 AND Subs < 7' + """ + parser = ConditionExpressionParser(expr, names, values) + return parser.parse() + + +def get_expected(expected): + """ + Parse a filter expression into an Op. + + Examples + expr = 'Id > 5 AND attribute_exists(test) AND Id BETWEEN 5 AND 6 OR length < 6 AND contains(test, 1) AND 5 IN (4,5, 6) OR (Id < 5 AND 5 > Id)' + expr = 'Id > 5 AND Subs < 7' + """ + ops = { + 'EQ': OpEqual, + 'NE': OpNotEqual, + 'LE': OpLessThanOrEqual, + 'LT': OpLessThan, + 'GE': OpGreaterThanOrEqual, + 'GT': OpGreaterThan, + 'NOT_NULL': FuncAttrExists, + 'NULL': FuncAttrNotExists, + 'CONTAINS': FuncContains, + 'NOT_CONTAINS': FuncNotContains, + 'BEGINS_WITH': FuncBeginsWith, + 'IN': FuncIn, + 'BETWEEN': FuncBetween, + } + + # NOTE: Always uses ConditionalOperator=AND + conditions = [] + for key, cond in expected.items(): + path = AttributePath([key]) + if 'Exists' in cond: + if cond['Exists']: + conditions.append(FuncAttrExists(path)) + else: + conditions.append(FuncAttrNotExists(path)) + elif 'Value' in cond: + conditions.append(OpEqual(path, AttributeValue(cond['Value']))) + elif 'ComparisonOperator' in cond: + operator_name = cond['ComparisonOperator'] + values = [ + AttributeValue(v) + for v in cond.get("AttributeValueList", [])] + OpClass = ops[operator_name] + conditions.append(OpClass(path, *values)) + + # NOTE: Ignore ConditionalOperator + ConditionalOp = OpAnd + if conditions: + output = conditions[0] + for condition in conditions[1:]: + output = ConditionalOp(output, condition) + else: + return OpDefault(None, None) + + return output + + +class Op(object): + """ + Base class for a FilterExpression operator + """ + OP = '' + + def __init__(self, lhs, rhs): + self.lhs = lhs + self.rhs = rhs + + def expr(self, item): + raise NotImplementedError("Expr not defined for {0}".format(type(self))) + + def __repr__(self): + return '({0} {1} {2})'.format(self.lhs, self.OP, self.rhs) + # TODO add tests for all of these EQ_FUNCTION = lambda item_value, test_value: item_value == test_value # flake8: noqa @@ -49,292 +137,799 @@ class RecursionStopIteration(StopIteration): pass -def get_filter_expression(expr, names, values): - # Examples - # expr = 'Id > 5 AND attribute_exists(test) AND Id BETWEEN 5 AND 6 OR length < 6 AND contains(test, 1) AND 5 IN (4,5, 6) OR (Id < 5 AND 5 > Id)' - # expr = 'Id > 5 AND Subs < 7' - if names is None: - names = {} - if values is None: - values = {} +class ConditionExpressionParser: + def __init__(self, condition_expression, expression_attribute_names, + expression_attribute_values): + self.condition_expression = condition_expression + self.expression_attribute_names = expression_attribute_names + self.expression_attribute_values = expression_attribute_values - # Do substitutions - for key, value in names.items(): - expr = expr.replace(key, value) + def parse(self): + """Returns a syntax tree for the expression. - # Store correct types of values for use later - values_map = {} - for key, value in values.items(): - if 'N' in value: - values_map[key] = float(value['N']) - elif 'BOOL' in value: - values_map[key] = value['BOOL'] - elif 'S' in value: - values_map[key] = value['S'] - elif 'NS' in value: - values_map[key] = tuple(value['NS']) - elif 'SS' in value: - values_map[key] = tuple(value['SS']) - elif 'L' in value: - values_map[key] = tuple(value['L']) + The tree, and all of the nodes in the tree are a tuple of + - kind: str + - children/value: + list of nodes for parent nodes + value for leaf nodes + + Raises ValueError if the condition expression is invalid + Raises KeyError if expression attribute names/values are invalid + + Here are the types of nodes that can be returned. + The types of child nodes are denoted with a colon (:). + An arbitrary number of children is denoted with ... + + Condition: + ('OR', [lhs : Condition, rhs : Condition]) + ('AND', [lhs: Condition, rhs: Condition]) + ('NOT', [argument: Condition]) + ('PARENTHESES', [argument: Condition]) + ('FUNCTION', [('LITERAL', function_name: str), argument: Operand, ...]) + ('BETWEEN', [query: Operand, low: Operand, high: Operand]) + ('IN', [query: Operand, possible_value: Operand, ...]) + ('COMPARISON', [lhs: Operand, ('LITERAL', comparator: str), rhs: Operand]) + + Operand: + ('EXPRESSION_ATTRIBUTE_VALUE', value: dict, e.g. {'S': 'foobar'}) + ('PATH', [('LITERAL', path_element: str), ...]) + NOTE: Expression attribute names will be expanded + ('FUNCTION', [('LITERAL', 'size'), argument: Operand]) + + Literal: + ('LITERAL', value: str) + + """ + if not self.condition_expression: + return OpDefault(None, None) + nodes = self._lex_condition_expression() + nodes = self._parse_paths(nodes) + # NOTE: The docs say that functions should be parsed after + # IN, BETWEEN, and comparisons like <=. + # However, these expressions are invalid as function arguments, + # so it is okay to parse functions first. This needs to be done + # to interpret size() correctly as an operand. + nodes = self._apply_functions(nodes) + nodes = self._apply_comparator(nodes) + nodes = self._apply_in(nodes) + nodes = self._apply_between(nodes) + nodes = self._apply_parens_and_booleans(nodes) + node = nodes[0] + op = self._make_op_condition(node) + return op + + class Kind: + """Enum defining types of nodes in the syntax tree.""" + + # Condition nodes + # --------------- + OR = 'OR' + AND = 'AND' + NOT = 'NOT' + PARENTHESES = 'PARENTHESES' + FUNCTION = 'FUNCTION' + BETWEEN = 'BETWEEN' + IN = 'IN' + COMPARISON = 'COMPARISON' + + # Operand nodes + # ------------- + EXPRESSION_ATTRIBUTE_VALUE = 'EXPRESSION_ATTRIBUTE_VALUE' + PATH = 'PATH' + + # Literal nodes + # -------------- + LITERAL = 'LITERAL' + + + class Nonterminal: + """Enum defining nonterminals for productions.""" + + CONDITION = 'CONDITION' + OPERAND = 'OPERAND' + COMPARATOR = 'COMPARATOR' + FUNCTION_NAME = 'FUNCTION_NAME' + IDENTIFIER = 'IDENTIFIER' + AND = 'AND' + OR = 'OR' + NOT = 'NOT' + BETWEEN = 'BETWEEN' + IN = 'IN' + COMMA = 'COMMA' + LEFT_PAREN = 'LEFT_PAREN' + RIGHT_PAREN = 'RIGHT_PAREN' + WHITESPACE = 'WHITESPACE' + + + Node = namedtuple('Node', ['nonterminal', 'kind', 'text', 'value', 'children']) + + def _lex_condition_expression(self): + nodes = deque() + remaining_expression = self.condition_expression + while remaining_expression: + node, remaining_expression = \ + self._lex_one_node(remaining_expression) + if node.nonterminal == self.Nonterminal.WHITESPACE: + continue + nodes.append(node) + return nodes + + def _lex_one_node(self, remaining_expression): + # TODO: Handle indexing like [1] + attribute_regex = '(:|#)?[A-z0-9\-_]+' + patterns = [( + self.Nonterminal.WHITESPACE, re.compile('^ +') + ), ( + self.Nonterminal.COMPARATOR, re.compile( + '^(' + # Put long expressions first for greedy matching + '<>|' + '<=|' + '>=|' + '=|' + '<|' + '>)'), + ), ( + self.Nonterminal.OPERAND, re.compile( + '^' + + attribute_regex + '(\.' + attribute_regex + '|\[[0-9]\])*') + ), ( + self.Nonterminal.COMMA, re.compile('^,') + ), ( + self.Nonterminal.LEFT_PAREN, re.compile('^\(') + ), ( + self.Nonterminal.RIGHT_PAREN, re.compile('^\)') + )] + + for nonterminal, pattern in patterns: + match = pattern.match(remaining_expression) + if match: + match_text = match.group() + break + else: # pragma: no cover + raise ValueError("Cannot parse condition starting at: " + + remaining_expression) + + value = match_text + node = self.Node( + nonterminal=nonterminal, + kind=self.Kind.LITERAL, + text=match_text, + value=match_text, + children=[]) + + remaining_expression = remaining_expression[len(match_text):] + + return node, remaining_expression + + def _parse_paths(self, nodes): + output = deque() + + while nodes: + node = nodes.popleft() + + if node.nonterminal == self.Nonterminal.OPERAND: + path = node.value.replace('[', '.[').split('.') + children = [ + self._parse_path_element(name) + for name in path] + if len(children) == 1: + child = children[0] + if child.nonterminal != self.Nonterminal.IDENTIFIER: + output.append(child) + continue + else: + for child in children: + self._assert( + child.nonterminal == self.Nonterminal.IDENTIFIER, + "Cannot use %s in path" % child.text, [node]) + output.append(self.Node( + nonterminal=self.Nonterminal.OPERAND, + kind=self.Kind.PATH, + text=node.text, + value=None, + children=children)) + else: + output.append(node) + return output + + def _parse_path_element(self, name): + reserved = { + 'and': self.Nonterminal.AND, + 'or': self.Nonterminal.OR, + 'in': self.Nonterminal.IN, + 'between': self.Nonterminal.BETWEEN, + 'not': self.Nonterminal.NOT, + } + + functions = { + 'attribute_exists', + 'attribute_not_exists', + 'attribute_type', + 'begins_with', + 'contains', + 'size', + } + + + if name.lower() in reserved: + # e.g. AND + nonterminal = reserved[name.lower()] + return self.Node( + nonterminal=nonterminal, + kind=self.Kind.LITERAL, + text=name, + value=name, + children=[]) + elif name in functions: + # e.g. attribute_exists + return self.Node( + nonterminal=self.Nonterminal.FUNCTION_NAME, + kind=self.Kind.LITERAL, + text=name, + value=name, + children=[]) + elif name.startswith(':'): + # e.g. :value0 + return self.Node( + nonterminal=self.Nonterminal.OPERAND, + kind=self.Kind.EXPRESSION_ATTRIBUTE_VALUE, + text=name, + value=self._lookup_expression_attribute_value(name), + children=[]) + elif name.startswith('#'): + # e.g. #name0 + return self.Node( + nonterminal=self.Nonterminal.IDENTIFIER, + kind=self.Kind.LITERAL, + text=name, + value=self._lookup_expression_attribute_name(name), + children=[]) + elif name.startswith('['): + # e.g. [123] + if not name.endswith(']'): # pragma: no cover + raise ValueError("Bad path element %s" % name) + return self.Node( + nonterminal=self.Nonterminal.IDENTIFIER, + kind=self.Kind.LITERAL, + text=name, + value=int(name[1:-1]), + children=[]) else: - raise NotImplementedError() + # e.g. ItemId + return self.Node( + nonterminal=self.Nonterminal.IDENTIFIER, + kind=self.Kind.LITERAL, + text=name, + value=name, + children=[]) - # Remove all spaces, tbf we could just skip them in the next step. - # The number of known options is really small so we can do a fair bit of cheating - expr = list(expr.strip()) + def _lookup_expression_attribute_value(self, name): + return self.expression_attribute_values[name] - # DodgyTokenisation stage 1 - def is_value(val): - return val not in ('<', '>', '=', '(', ')') + def _lookup_expression_attribute_name(self, name): + return self.expression_attribute_names[name] - def contains_keyword(val): - for kw in ('BETWEEN', 'IN', 'AND', 'OR', 'NOT'): - if kw in val: - return kw - return None + # NOTE: The following constructions are ordered from high precedence to low precedence + # according to + # https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.OperatorsAndFunctions.html#Expressions.OperatorsAndFunctions.Precedence + # + # = <> < <= > >= + # IN + # BETWEEN + # attribute_exists attribute_not_exists begins_with contains + # Parentheses + # NOT + # AND + # OR + # + # The grammar is taken from + # https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.OperatorsAndFunctions.html#Expressions.OperatorsAndFunctions.Syntax + # + # condition-expression ::= + # operand comparator operand + # operand BETWEEN operand AND operand + # operand IN ( operand (',' operand (, ...) )) + # function + # condition AND condition + # condition OR condition + # NOT condition + # ( condition ) + # + # comparator ::= + # = + # <> + # < + # <= + # > + # >= + # + # function ::= + # attribute_exists (path) + # attribute_not_exists (path) + # attribute_type (path, type) + # begins_with (path, substr) + # contains (path, operand) + # size (path) - def is_function(val): - return val in ('attribute_exists', 'attribute_not_exists', 'attribute_type', 'begins_with', 'contains', 'size') + def _matches(self, nodes, production): + """Check if the nodes start with the given production. - # Does the main part of splitting between sections of characters - tokens = [] - stack = '' - while len(expr) > 0: - current_char = expr.pop(0) + Parameters + ---------- + nodes: list of Node + production: list of str + The name of a Nonterminal, or '*' for anything - if current_char == ' ': - if len(stack) > 0: - tokens.append(stack) - stack = '' - elif current_char == ',': # Split params , - if len(stack) > 0: - tokens.append(stack) - stack = '' - elif is_value(current_char): - stack += current_char + """ + if len(nodes) < len(production): + return False + for i in range(len(production)): + if production[i] == '*': + continue + expected = getattr(self.Nonterminal, production[i]) + if nodes[i].nonterminal != expected: + return False + return True - kw = contains_keyword(stack) - if kw is not None: - # We have a kw in the stack, could be AND or something like 5AND - tmp = stack.replace(kw, '') - if len(tmp) > 0: - tokens.append(tmp) - tokens.append(kw) - stack = '' - else: - if len(stack) > 0: - tokens.append(stack) - tokens.append(current_char) - stack = '' - if len(stack) > 0: - tokens.append(stack) + def _apply_comparator(self, nodes): + """Apply condition := operand comparator operand.""" + output = deque() - def is_op(val): - return val in ('<', '>', '=', '>=', '<=', '<>', 'BETWEEN', 'IN', 'AND', 'OR', 'NOT') + while nodes: + if self._matches(nodes, ['*', 'COMPARATOR']): + self._assert( + self._matches(nodes, ['OPERAND', 'COMPARATOR', 'OPERAND']), + "Bad comparison", list(nodes)[:3]) + lhs = nodes.popleft() + comparator = nodes.popleft() + rhs = nodes.popleft() + nodes.appendleft(self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.COMPARISON, + text=" ".join([ + lhs.text, + comparator.text, + rhs.text]), + value=None, + children=[lhs, comparator, rhs])) + else: + output.append(nodes.popleft()) + return output - # DodgyTokenisation stage 2, it groups together some elements to make RPN'ing it later easier. - def handle_token(token, tokens2, token_iterator): - # ok so this essentially groups up some tokens to make later parsing easier, - # when it encounters brackets it will recurse and then unrecurse when RecursionStopIteration is raised. - if token == ')': - raise RecursionStopIteration() # Should be recursive so this should work - elif token == '(': - temp_list = [] - - try: + def _apply_in(self, nodes): + """Apply condition := operand IN ( operand , ... ).""" + output = deque() + while nodes: + if self._matches(nodes, ['*', 'IN']): + self._assert( + self._matches(nodes, ['OPERAND', 'IN', 'LEFT_PAREN']), + "Bad IN expression", list(nodes)[:3]) + lhs = nodes.popleft() + in_node = nodes.popleft() + left_paren = nodes.popleft() + all_children = [lhs, in_node, left_paren] + rhs = [] while True: - next_token = six.next(token_iterator) - handle_token(next_token, temp_list, token_iterator) - except RecursionStopIteration: - pass # Continue - except StopIteration: - ValueError('Malformed filter expression, type1') - - # Sigh, we only want to group a tuple if it doesnt contain operators - if any([is_op(item) for item in temp_list]): - # Its an expression - tokens2.append('(') - tokens2.extend(temp_list) - tokens2.append(')') + if self._matches(nodes, ['OPERAND', 'COMMA']): + operand = nodes.popleft() + separator = nodes.popleft() + all_children += [operand, separator] + rhs.append(operand) + elif self._matches(nodes, ['OPERAND', 'RIGHT_PAREN']): + operand = nodes.popleft() + separator = nodes.popleft() + all_children += [operand, separator] + rhs.append(operand) + break # Close + else: + self._assert( + False, + "Bad IN expression starting at", nodes) + nodes.appendleft(self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.IN, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs] + rhs)) else: - tokens2.append(tuple(temp_list)) - elif token == 'BETWEEN': - field = tokens2.pop() - # if values map contains a number, it would be a float - # so we need to int() it anyway - op1 = six.next(token_iterator) - op1 = int(values_map.get(op1, op1)) - and_op = six.next(token_iterator) - assert and_op == 'AND' - op2 = six.next(token_iterator) - op2 = int(values_map.get(op2, op2)) - tokens2.append(['between', field, op1, op2]) - elif is_function(token): - function_list = [token] + output.append(nodes.popleft()) + return output - lbracket = six.next(token_iterator) - assert lbracket == '(' - - next_token = six.next(token_iterator) - while next_token != ')': - if next_token in values_map: - next_token = values_map[next_token] - function_list.append(next_token) - next_token = six.next(token_iterator) - - tokens2.append(function_list) - else: - # Convert tokens back to real types - if token in values_map: - token = values_map[token] - - # Need to join >= <= <> - if len(tokens2) > 0 and ((tokens2[-1] == '>' and token == '=') or (tokens2[-1] == '<' and token == '=') or (tokens2[-1] == '<' and token == '>')): - tokens2.append(tokens2.pop() + token) + def _apply_between(self, nodes): + """Apply condition := operand BETWEEN operand AND operand.""" + output = deque() + while nodes: + if self._matches(nodes, ['*', 'BETWEEN']): + self._assert( + self._matches(nodes, ['OPERAND', 'BETWEEN', 'OPERAND', + 'AND', 'OPERAND']), + "Bad BETWEEN expression", list(nodes)[:5]) + lhs = nodes.popleft() + between_node = nodes.popleft() + low = nodes.popleft() + and_node = nodes.popleft() + high = nodes.popleft() + all_children = [lhs, between_node, low, and_node, high] + nodes.appendleft(self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.BETWEEN, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs, low, high])) else: - tokens2.append(token) + output.append(nodes.popleft()) + return output - tokens2 = [] - token_iterator = iter(tokens) - for token in token_iterator: - handle_token(token, tokens2, token_iterator) - - # Start of the Shunting-Yard algorithm. <-- Proper beast algorithm! - def is_number(val): - return val not in ('<', '>', '=', '>=', '<=', '<>', 'BETWEEN', 'IN', 'AND', 'OR', 'NOT') - - OPS = {'<': 5, '>': 5, '=': 5, '>=': 5, '<=': 5, '<>': 5, 'IN': 8, 'AND': 11, 'OR': 12, 'NOT': 10, 'BETWEEN': 9, '(': 100, ')': 100} - - def shunting_yard(token_list): - output = [] - op_stack = [] - - # Basically takes in an infix notation calculation, converts it to a reverse polish notation where there is no - # ambiguity on which order operators are applied. - while len(token_list) > 0: - token = token_list.pop(0) - - if token == '(': - op_stack.append(token) - elif token == ')': - while len(op_stack) > 0 and op_stack[-1] != '(': - output.append(op_stack.pop()) - lbracket = op_stack.pop() - assert lbracket == '(' - - elif is_number(token): - output.append(token) + def _apply_functions(self, nodes): + """Apply condition := function_name (operand , ...).""" + output = deque() + either_kind = {self.Kind.PATH, self.Kind.EXPRESSION_ATTRIBUTE_VALUE} + expected_argument_kind_map = { + 'attribute_exists': [{self.Kind.PATH}], + 'attribute_not_exists': [{self.Kind.PATH}], + 'attribute_type': [either_kind, {self.Kind.EXPRESSION_ATTRIBUTE_VALUE}], + 'begins_with': [either_kind, either_kind], + 'contains': [either_kind, either_kind], + 'size': [{self.Kind.PATH}], + } + while nodes: + if self._matches(nodes, ['FUNCTION_NAME']): + self._assert( + self._matches(nodes, ['FUNCTION_NAME', 'LEFT_PAREN', + 'OPERAND', '*']), + "Bad function expression at", list(nodes)[:4]) + function_name = nodes.popleft() + left_paren = nodes.popleft() + all_children = [function_name, left_paren] + arguments = [] + while True: + if self._matches(nodes, ['OPERAND', 'COMMA']): + operand = nodes.popleft() + separator = nodes.popleft() + all_children += [operand, separator] + arguments.append(operand) + elif self._matches(nodes, ['OPERAND', 'RIGHT_PAREN']): + operand = nodes.popleft() + separator = nodes.popleft() + all_children += [operand, separator] + arguments.append(operand) + break # Close paren + else: + self._assert( + False, + "Bad function expression", all_children + list(nodes)[:2]) + expected_kinds = expected_argument_kind_map[function_name.value] + self._assert( + len(arguments) == len(expected_kinds), + "Wrong number of arguments in", all_children) + for i in range(len(expected_kinds)): + self._assert( + arguments[i].kind in expected_kinds[i], + "Wrong type for argument %d in" % i, all_children) + if function_name.value == 'size': + nonterminal = self.Nonterminal.OPERAND + else: + nonterminal = self.Nonterminal.CONDITION + nodes.appendleft(self.Node( + nonterminal=nonterminal, + kind=self.Kind.FUNCTION, + text=" ".join([t.text for t in all_children]), + value=None, + children=[function_name] + arguments)) else: - # Must be operator kw + output.append(nodes.popleft()) + return output - # Cheat, NOT is our only RIGHT associative operator, should really have dict of operator associativity - while len(op_stack) > 0 and OPS[op_stack[-1]] <= OPS[token] and op_stack[-1] != 'NOT': - output.append(op_stack.pop()) - op_stack.append(token) - while len(op_stack) > 0: - output.append(op_stack.pop()) + def _apply_parens_and_booleans(self, nodes, left_paren=None): + """Apply condition := ( condition ) and booleans.""" + output = deque() + while nodes: + if self._matches(nodes, ['LEFT_PAREN']): + parsed = self._apply_parens_and_booleans(nodes, left_paren=nodes.popleft()) + self._assert( + len(parsed) >= 1, + "Failed to close parentheses at", nodes) + parens = parsed.popleft() + self._assert( + parens.kind == self.Kind.PARENTHESES, + "Failed to close parentheses at", nodes) + output.append(parens) + nodes = parsed + elif self._matches(nodes, ['RIGHT_PAREN']): + self._assert( + left_paren is not None, + "Unmatched ) at", nodes) + close_paren = nodes.popleft() + children = self._apply_booleans(output) + all_children = [left_paren] + list(children) + [close_paren] + return deque([ + self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.PARENTHESES, + text=" ".join([t.text for t in all_children]), + value=None, + children=list(children), + )] + list(nodes)) + else: + output.append(nodes.popleft()) + + self._assert( + left_paren is None, + "Unmatched ( at", list(output)) + return self._apply_booleans(output) + + def _apply_booleans(self, nodes): + """Apply and, or, and not constructions.""" + nodes = self._apply_not(nodes) + nodes = self._apply_and(nodes) + nodes = self._apply_or(nodes) + # The expression should reduce to a single condition + self._assert( + len(nodes) == 1, + "Unexpected expression at", list(nodes)[1:]) + self._assert( + nodes[0].nonterminal == self.Nonterminal.CONDITION, + "Incomplete condition", nodes) + return nodes + + def _apply_not(self, nodes): + """Apply condition := NOT condition.""" + output = deque() + while nodes: + if self._matches(nodes, ['NOT']): + self._assert( + self._matches(nodes, ['NOT', 'CONDITION']), + "Bad NOT expression", list(nodes)[:2]) + not_node = nodes.popleft() + child = nodes.popleft() + nodes.appendleft(self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.NOT, + text=" ".join([not_node.text, child.text]), + value=None, + children=[child])) + else: + output.append(nodes.popleft()) return output - output = shunting_yard(tokens2) - - # Hacky function to convert dynamo functions (which are represented as lists) to their Class equivalent - def to_func(val): - if isinstance(val, list): - func_name = val.pop(0) - # Expand rest of the list to arguments - val = FUNC_CLASS[func_name](*val) - - return val - - # Simple reverse polish notation execution. Builts up a nested filter object. - # The filter object then takes a dynamo item and returns true/false - stack = [] - for token in output: - if is_op(token): - op_cls = OP_CLASS[token] - - if token == 'NOT': - op1 = stack.pop() - op2 = True + def _apply_and(self, nodes): + """Apply condition := condition AND condition.""" + output = deque() + while nodes: + if self._matches(nodes, ['*', 'AND']): + self._assert( + self._matches(nodes, ['CONDITION', 'AND', 'CONDITION']), + "Bad AND expression", list(nodes)[:3]) + lhs = nodes.popleft() + and_node = nodes.popleft() + rhs = nodes.popleft() + all_children = [lhs, and_node, rhs] + nodes.appendleft(self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.AND, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs, rhs])) else: - op2 = stack.pop() - op1 = stack.pop() + output.append(nodes.popleft()) - stack.append(op_cls(op1, op2)) + return output + + def _apply_or(self, nodes): + """Apply condition := condition OR condition.""" + output = deque() + while nodes: + if self._matches(nodes, ['*', 'OR']): + self._assert( + self._matches(nodes, ['CONDITION', 'OR', 'CONDITION']), + "Bad OR expression", list(nodes)[:3]) + lhs = nodes.popleft() + or_node = nodes.popleft() + rhs = nodes.popleft() + all_children = [lhs, or_node, rhs] + nodes.appendleft(self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.OR, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs, rhs])) + else: + output.append(nodes.popleft()) + + return output + + def _make_operand(self, node): + if node.kind == self.Kind.PATH: + return AttributePath([child.value for child in node.children]) + elif node.kind == self.Kind.EXPRESSION_ATTRIBUTE_VALUE: + return AttributeValue(node.value) + elif node.kind == self.Kind.FUNCTION: + # size() + function_node = node.children[0] + arguments = node.children[1:] + function_name = function_node.value + arguments = [self._make_operand(arg) for arg in arguments] + return FUNC_CLASS[function_name](*arguments) + else: # pragma: no cover + raise ValueError("Unknown operand: %r" % node) + + + def _make_op_condition(self, node): + if node.kind == self.Kind.OR: + lhs, rhs = node.children + return OpOr( + self._make_op_condition(lhs), + self._make_op_condition(rhs)) + elif node.kind == self.Kind.AND: + lhs, rhs = node.children + return OpAnd( + self._make_op_condition(lhs), + self._make_op_condition(rhs)) + elif node.kind == self.Kind.NOT: + child, = node.children + return OpNot(self._make_op_condition(child)) + elif node.kind == self.Kind.PARENTHESES: + child, = node.children + return self._make_op_condition(child) + elif node.kind == self.Kind.FUNCTION: + function_node = node.children[0] + arguments = node.children[1:] + function_name = function_node.value + arguments = [self._make_operand(arg) for arg in arguments] + return FUNC_CLASS[function_name](*arguments) + elif node.kind == self.Kind.BETWEEN: + query, low, high = node.children + return FuncBetween( + self._make_operand(query), + self._make_operand(low), + self._make_operand(high)) + elif node.kind == self.Kind.IN: + query = node.children[0] + possible_values = node.children[1:] + query = self._make_operand(query) + possible_values = [self._make_operand(v) for v in possible_values] + return FuncIn(query, *possible_values) + elif node.kind == self.Kind.COMPARISON: + lhs, comparator, rhs = node.children + return COMPARATOR_CLASS[comparator.value]( + self._make_operand(lhs), + self._make_operand(rhs)) + else: # pragma: no cover + raise ValueError("Unknown expression node kind %r" % node.kind) + + def _print_debug(self, nodes): # pragma: no cover + print('ROOT') + for node in nodes: + self._print_node_recursive(node, depth=1) + + def _print_node_recursive(self, node, depth=0): # pragma: no cover + if len(node.children) > 0: + print(' ' * depth, node.nonterminal, node.kind) + for child in node.children: + self._print_node_recursive(child, depth=depth + 1) else: - stack.append(to_func(token)) - - result = stack.pop(0) - if len(stack) > 0: - raise ValueError('Malformed filter expression, type2') - - return result + print(' ' * depth, node.nonterminal, node.kind, node.value) -class Op(object): - """ - Base class for a FilterExpression operator - """ - OP = '' - def __init__(self, lhs, rhs): - self.lhs = lhs - self.rhs = rhs + def _assert(self, condition, message, nodes): + if not condition: + raise ValueError(message + " " + " ".join([t.text for t in nodes])) + + +class Operand(object): + def expr(self, item): + raise NotImplementedError + + def get_type(self, item): + raise NotImplementedError + + +class AttributePath(Operand): + def __init__(self, path): + """Initialize the AttributePath. + + Parameters + ---------- + path: list of int/str - def _lhs(self, item): """ - :type item: moto.dynamodb2.models.Item - """ - lhs = self.lhs - if isinstance(self.lhs, (Op, Func)): - lhs = self.lhs.expr(item) - elif isinstance(self.lhs, six.string_types): - try: - lhs = item.attrs[self.lhs].cast_value - except Exception: - pass + assert len(path) >= 1 + self.path = path - return lhs + def _get_attr(self, item): + if item is None: + return None - def _rhs(self, item): - rhs = self.rhs - if isinstance(self.rhs, (Op, Func)): - rhs = self.rhs.expr(item) - elif isinstance(self.rhs, six.string_types): - try: - rhs = item.attrs[self.rhs].cast_value - except Exception: - pass - return rhs + base = self.path[0] + if base not in item.attrs: + return None + attr = item.attrs[base] + + for name in self.path[1:]: + attr = attr.child_attr(name) + if attr is None: + return None + + return attr def expr(self, item): - return True + attr = self._get_attr(item) + if attr is None: + return None + else: + return attr.cast_value + + def get_type(self, item): + attr = self._get_attr(item) + if attr is None: + return None + else: + return attr.type def __repr__(self): - return '({0} {1} {2})'.format(self.lhs, self.OP, self.rhs) + return ".".join(self.path) -class Func(object): - """ - Base class for a FilterExpression function - """ - FUNC = 'Unknown' +class AttributeValue(Operand): + def __init__(self, value): + """Initialize the AttributePath. + + Parameters + ---------- + value: dict + e.g. {'N': '1.234'} + + """ + self.type = list(value.keys())[0] + self.value = value[self.type] def expr(self, item): - return True + # TODO: Reuse DynamoType code + if self.type == 'N': + try: + return int(self.value) + except ValueError: + return float(self.value) + elif self.type in ['SS', 'NS', 'BS']: + sub_type = self.type[0] + return set([AttributeValue({sub_type: v}).expr(item) + for v in self.value]) + elif self.type == 'L': + return [AttributeValue(v).expr(item) for v in self.value] + elif self.type == 'M': + return dict([ + (k, AttributeValue(v).expr(item)) + for k, v in self.value.items()]) + else: + return self.value + return self.value + + def get_type(self, item): + return self.type def __repr__(self): - return 'Func(...)'.format(self.FUNC) + return repr(self.value) + + +class OpDefault(Op): + OP = 'NONE' + + def expr(self, item): + """If no condition is specified, always True.""" + return True class OpNot(Op): OP = 'NOT' - def expr(self, item): - lhs = self._lhs(item) + def __init__(self, lhs): + super(OpNot, self).__init__(lhs, None) + def expr(self, item): + lhs = self.lhs.expr(item) return not lhs def __str__(self): @@ -345,8 +940,8 @@ class OpAnd(Op): OP = 'AND' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs and rhs @@ -354,8 +949,8 @@ class OpLessThan(Op): OP = '<' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs < rhs @@ -363,8 +958,8 @@ class OpGreaterThan(Op): OP = '>' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs > rhs @@ -372,8 +967,8 @@ class OpEqual(Op): OP = '=' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs == rhs @@ -381,8 +976,8 @@ class OpNotEqual(Op): OP = '<>' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs != rhs @@ -390,8 +985,8 @@ class OpLessThanOrEqual(Op): OP = '<=' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs <= rhs @@ -399,8 +994,8 @@ class OpGreaterThanOrEqual(Op): OP = '>=' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs >= rhs @@ -408,18 +1003,27 @@ class OpOr(Op): OP = 'OR' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs or rhs -class OpIn(Op): - OP = 'IN' +class Func(object): + """ + Base class for a FilterExpression function + """ + FUNC = 'Unknown' + + def __init__(self, *arguments): + self.arguments = arguments def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) - return lhs in rhs + raise NotImplementedError + + def __repr__(self): + return '{0}({1})'.format( + self.FUNC, + " ".join([repr(arg) for arg in self.arguments])) class FuncAttrExists(Func): @@ -427,19 +1031,14 @@ class FuncAttrExists(Func): def __init__(self, attribute): self.attr = attribute + super(FuncAttrExists, self).__init__(attribute) def expr(self, item): - return self.attr in item.attrs + return self.attr.get_type(item) is not None -class FuncAttrNotExists(Func): - FUNC = 'attribute_not_exists' - - def __init__(self, attribute): - self.attr = attribute - - def expr(self, item): - return self.attr not in item.attrs +def FuncAttrNotExists(attribute): + return OpNot(FuncAttrExists(attribute)) class FuncAttrType(Func): @@ -448,9 +1047,10 @@ class FuncAttrType(Func): def __init__(self, attribute, _type): self.attr = attribute self.type = _type + super(FuncAttrType, self).__init__(attribute, _type) def expr(self, item): - return self.attr in item.attrs and item.attrs[self.attr].type == self.type + return self.attr.get_type(item) == self.type.expr(item) class FuncBeginsWith(Func): @@ -459,9 +1059,14 @@ class FuncBeginsWith(Func): def __init__(self, attribute, substr): self.attr = attribute self.substr = substr + super(FuncBeginsWith, self).__init__(attribute, substr) def expr(self, item): - return self.attr in item.attrs and item.attrs[self.attr].type == 'S' and item.attrs[self.attr].value.startswith(self.substr) + if self.attr.get_type(item) != 'S': + return False + if self.substr.get_type(item) != 'S': + return False + return self.attr.expr(item).startswith(self.substr.expr(item)) class FuncContains(Func): @@ -470,51 +1075,67 @@ class FuncContains(Func): def __init__(self, attribute, operand): self.attr = attribute self.operand = operand + super(FuncContains, self).__init__(attribute, operand) def expr(self, item): - if self.attr not in item.attrs: - return False - - if item.attrs[self.attr].type in ('S', 'SS', 'NS', 'BS', 'L', 'M'): - return self.operand in item.attrs[self.attr].value + if self.attr.get_type(item) in ('S', 'SS', 'NS', 'BS', 'L'): + try: + return self.operand.expr(item) in self.attr.expr(item) + except TypeError: + return False return False +def FuncNotContains(attribute, operand): + return OpNot(FuncContains(attribute, operand)) + + class FuncSize(Func): - FUNC = 'contains' + FUNC = 'size' def __init__(self, attribute): self.attr = attribute + super(FuncSize, self).__init__(attribute) def expr(self, item): - if self.attr not in item.attrs: + if self.attr.get_type(item) is None: raise ValueError('Invalid attribute name {0}'.format(self.attr)) - if item.attrs[self.attr].type in ('S', 'SS', 'NS', 'B', 'BS', 'L', 'M'): - return len(item.attrs[self.attr].value) + if self.attr.get_type(item) in ('S', 'SS', 'NS', 'B', 'BS', 'L', 'M'): + return len(self.attr.expr(item)) raise ValueError('Invalid filter expression') class FuncBetween(Func): - FUNC = 'between' + FUNC = 'BETWEEN' def __init__(self, attribute, start, end): self.attr = attribute self.start = start self.end = end + super(FuncBetween, self).__init__(attribute, start, end) def expr(self, item): - if self.attr not in item.attrs: - raise ValueError('Invalid attribute name {0}'.format(self.attr)) - - return self.start <= item.attrs[self.attr].cast_value <= self.end + return self.start.expr(item) <= self.attr.expr(item) <= self.end.expr(item) -OP_CLASS = { - 'NOT': OpNot, - 'AND': OpAnd, - 'OR': OpOr, - 'IN': OpIn, +class FuncIn(Func): + FUNC = 'IN' + + def __init__(self, attribute, *possible_values): + self.attr = attribute + self.possible_values = possible_values + super(FuncIn, self).__init__(attribute, *possible_values) + + def expr(self, item): + for possible_value in self.possible_values: + if self.attr.expr(item) == possible_value.expr(item): + return True + + return False + + +COMPARATOR_CLASS = { '<': OpLessThan, '>': OpGreaterThan, '<=': OpLessThanOrEqual, diff --git a/moto/dynamodb2/models.py b/moto/dynamodb2/models.py index bfbb654b4..29e90e7dc 100644 --- a/moto/dynamodb2/models.py +++ b/moto/dynamodb2/models.py @@ -6,13 +6,16 @@ import decimal import json import re import uuid +import six import boto3 from moto.compat import OrderedDict from moto.core import BaseBackend, BaseModel from moto.core.utils import unix_time from moto.core.exceptions import JsonRESTError -from .comparisons import get_comparison_func, get_filter_expression, Op +from .comparisons import get_comparison_func +from .comparisons import get_filter_expression +from .comparisons import get_expected from .exceptions import InvalidIndexNameError @@ -68,10 +71,34 @@ class DynamoType(object): except ValueError: return float(self.value) elif self.is_set(): - return set(self.value) + sub_type = self.type[0] + return set([DynamoType({sub_type: v}).cast_value + for v in self.value]) + elif self.is_list(): + return [DynamoType(v).cast_value for v in self.value] + elif self.is_map(): + return dict([ + (k, DynamoType(v).cast_value) + for k, v in self.value.items()]) else: return self.value + def child_attr(self, key): + """ + Get Map or List children by key. str for Map, int for List. + + Returns DynamoType or None. + """ + if isinstance(key, six.string_types) and self.is_map() and key in self.value: + return DynamoType(self.value[key]) + + if isinstance(key, int) and self.is_list(): + idx = key + if idx >= 0 and idx < len(self.value): + return DynamoType(self.value[idx]) + + return None + def to_json(self): return {self.type: self.value} @@ -89,6 +116,12 @@ class DynamoType(object): def is_set(self): return self.type == 'SS' or self.type == 'NS' or self.type == 'BS' + def is_list(self): + return self.type == 'L' + + def is_map(self): + return self.type == 'M' + def same_type(self, other): return self.type == other.type @@ -504,7 +537,9 @@ class Table(BaseModel): keys.append(range_key) return keys - def put_item(self, item_attrs, expected=None, overwrite=False): + def put_item(self, item_attrs, expected=None, condition_expression=None, + expression_attribute_names=None, + expression_attribute_values=None, overwrite=False): hash_value = DynamoType(item_attrs.get(self.hash_key_attr)) if self.has_range_key: range_value = DynamoType(item_attrs.get(self.range_key_attr)) @@ -527,29 +562,15 @@ class Table(BaseModel): self.range_key_type, item_attrs) if not overwrite: - if current is None: - current_attr = {} - elif hasattr(current, 'attrs'): - current_attr = current.attrs - else: - current_attr = current + if not get_expected(expected).expr(current): + raise ValueError('The conditional request failed') + condition_op = get_filter_expression( + condition_expression, + expression_attribute_names, + expression_attribute_values) + if not condition_op.expr(current): + raise ValueError('The conditional request failed') - for key, val in expected.items(): - if 'Exists' in val and val['Exists'] is False \ - or 'ComparisonOperator' in val and val['ComparisonOperator'] == 'NULL': - if key in current_attr: - raise ValueError("The conditional request failed") - elif key not in current_attr: - raise ValueError("The conditional request failed") - elif 'Value' in val and DynamoType(val['Value']).value != current_attr[key].value: - raise ValueError("The conditional request failed") - elif 'ComparisonOperator' in val: - dynamo_types = [ - DynamoType(ele) for ele in - val.get("AttributeValueList", []) - ] - if not current_attr[key].compare(val['ComparisonOperator'], dynamo_types): - raise ValueError('The conditional request failed') if range_value: self.items[hash_value][range_value] = item else: @@ -902,11 +923,15 @@ class DynamoDBBackend(BaseBackend): table.global_indexes = list(gsis_by_name.values()) return table - def put_item(self, table_name, item_attrs, expected=None, overwrite=False): + def put_item(self, table_name, item_attrs, expected=None, + condition_expression=None, expression_attribute_names=None, + expression_attribute_values=None, overwrite=False): table = self.tables.get(table_name) if not table: return None - return table.put_item(item_attrs, expected, overwrite) + return table.put_item(item_attrs, expected, condition_expression, + expression_attribute_names, + expression_attribute_values, overwrite) def get_table_keys_name(self, table_name, keys): """ @@ -962,10 +987,7 @@ class DynamoDBBackend(BaseBackend): range_values = [DynamoType(range_value) for range_value in range_value_dicts] - if filter_expression is not None: - filter_expression = get_filter_expression(filter_expression, expr_names, expr_values) - else: - filter_expression = Op(None, None) # Will always eval to true + filter_expression = get_filter_expression(filter_expression, expr_names, expr_values) return table.query(hash_key, range_comparison, range_values, limit, exclusive_start_key, scan_index_forward, projection_expression, index_name, filter_expression, **filter_kwargs) @@ -980,17 +1002,14 @@ class DynamoDBBackend(BaseBackend): dynamo_types = [DynamoType(value) for value in comparison_values] scan_filters[key] = (comparison_operator, dynamo_types) - if filter_expression is not None: - filter_expression = get_filter_expression(filter_expression, expr_names, expr_values) - else: - filter_expression = Op(None, None) # Will always eval to true + filter_expression = get_filter_expression(filter_expression, expr_names, expr_values) projection_expression = ','.join([expr_names.get(attr, attr) for attr in projection_expression.replace(' ', '').split(',')]) return table.scan(scan_filters, limit, exclusive_start_key, filter_expression, index_name, projection_expression) def update_item(self, table_name, key, update_expression, attribute_updates, expression_attribute_names, - expression_attribute_values, expected=None): + expression_attribute_values, expected=None, condition_expression=None): table = self.get_table(table_name) if all([table.hash_key_attr in key, table.range_key_attr in key]): @@ -1009,32 +1028,17 @@ class DynamoDBBackend(BaseBackend): item = table.get_item(hash_value, range_value) - if item is None: - item_attr = {} - elif hasattr(item, 'attrs'): - item_attr = item.attrs - else: - item_attr = item - if not expected: expected = {} - for key, val in expected.items(): - if 'Exists' in val and val['Exists'] is False \ - or 'ComparisonOperator' in val and val['ComparisonOperator'] == 'NULL': - if key in item_attr: - raise ValueError("The conditional request failed") - elif key not in item_attr: - raise ValueError("The conditional request failed") - elif 'Value' in val and DynamoType(val['Value']).value != item_attr[key].value: - raise ValueError("The conditional request failed") - elif 'ComparisonOperator' in val: - dynamo_types = [ - DynamoType(ele) for ele in - val.get("AttributeValueList", []) - ] - if not item_attr[key].compare(val['ComparisonOperator'], dynamo_types): - raise ValueError('The conditional request failed') + if not get_expected(expected).expr(item): + raise ValueError('The conditional request failed') + condition_op = get_filter_expression( + condition_expression, + expression_attribute_names, + expression_attribute_values) + if not condition_op.expr(item): + raise ValueError('The conditional request failed') # Update does not fail on new items, so create one if item is None: diff --git a/moto/dynamodb2/responses.py b/moto/dynamodb2/responses.py index 5dde432d5..d34b176a7 100644 --- a/moto/dynamodb2/responses.py +++ b/moto/dynamodb2/responses.py @@ -32,67 +32,6 @@ def get_empty_str_error(): )) -def condition_expression_to_expected(condition_expression, expression_attribute_names, expression_attribute_values): - """ - Limited condition expression syntax parsing. - Supports Global Negation ex: NOT(inner expressions). - Supports simple AND conditions ex: cond_a AND cond_b and cond_c. - Atomic expressions supported are attribute_exists(key), attribute_not_exists(key) and #key = :value. - """ - expected = {} - if condition_expression and 'OR' not in condition_expression: - reverse_re = re.compile('^NOT\s*\((.*)\)$') - reverse_m = reverse_re.match(condition_expression.strip()) - - reverse = False - if reverse_m: - reverse = True - condition_expression = reverse_m.group(1) - - cond_items = [c.strip() for c in condition_expression.split('AND')] - if cond_items: - exists_re = re.compile('^attribute_exists\s*\((.*)\)$') - not_exists_re = re.compile( - '^attribute_not_exists\s*\((.*)\)$') - equals_re = re.compile('^(#?\w+)\s*=\s*(\:?\w+)') - - for cond in cond_items: - exists_m = exists_re.match(cond) - not_exists_m = not_exists_re.match(cond) - equals_m = equals_re.match(cond) - - if exists_m: - attribute_name = expression_attribute_names_lookup(exists_m.group(1), expression_attribute_names) - expected[attribute_name] = {'Exists': True if not reverse else False} - elif not_exists_m: - attribute_name = expression_attribute_names_lookup(not_exists_m.group(1), expression_attribute_names) - expected[attribute_name] = {'Exists': False if not reverse else True} - elif equals_m: - attribute_name = expression_attribute_names_lookup(equals_m.group(1), expression_attribute_names) - attribute_value = expression_attribute_values_lookup(equals_m.group(2), expression_attribute_values) - expected[attribute_name] = { - 'AttributeValueList': [attribute_value], - 'ComparisonOperator': 'EQ' if not reverse else 'NEQ'} - - return expected - - -def expression_attribute_names_lookup(attribute_name, expression_attribute_names): - if attribute_name.startswith('#') and attribute_name in expression_attribute_names: - return expression_attribute_names[attribute_name] - else: - return attribute_name - - -def expression_attribute_values_lookup(attribute_value, expression_attribute_values): - if isinstance(attribute_value, six.string_types) and \ - attribute_value.startswith(':') and\ - attribute_value in expression_attribute_values: - return expression_attribute_values[attribute_value] - else: - return attribute_value - - class DynamoHandler(BaseResponse): def get_endpoint_name(self, headers): @@ -288,18 +227,18 @@ class DynamoHandler(BaseResponse): # Attempt to parse simple ConditionExpressions into an Expected # expression - if not expected: - condition_expression = self.body.get('ConditionExpression') - expression_attribute_names = self.body.get('ExpressionAttributeNames', {}) - expression_attribute_values = self.body.get('ExpressionAttributeValues', {}) - expected = condition_expression_to_expected(condition_expression, - expression_attribute_names, - expression_attribute_values) - if expected: - overwrite = False + condition_expression = self.body.get('ConditionExpression') + expression_attribute_names = self.body.get('ExpressionAttributeNames', {}) + expression_attribute_values = self.body.get('ExpressionAttributeValues', {}) + + if condition_expression: + overwrite = False try: - result = self.dynamodb_backend.put_item(name, item, expected, overwrite) + result = self.dynamodb_backend.put_item( + name, item, expected, condition_expression, + expression_attribute_names, expression_attribute_values, + overwrite) except ValueError: er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException' return self.error(er, 'A condition specified in the operation could not be evaluated.') @@ -626,7 +565,7 @@ class DynamoHandler(BaseResponse): name = self.body['TableName'] key = self.body['Key'] return_values = self.body.get('ReturnValues', 'NONE') - update_expression = self.body.get('UpdateExpression') + update_expression = self.body.get('UpdateExpression', '').strip() attribute_updates = self.body.get('AttributeUpdates') expression_attribute_names = self.body.get( 'ExpressionAttributeNames', {}) @@ -653,13 +592,9 @@ class DynamoHandler(BaseResponse): # Attempt to parse simple ConditionExpressions into an Expected # expression - if not expected: - condition_expression = self.body.get('ConditionExpression') - expression_attribute_names = self.body.get('ExpressionAttributeNames', {}) - expression_attribute_values = self.body.get('ExpressionAttributeValues', {}) - expected = condition_expression_to_expected(condition_expression, - expression_attribute_names, - expression_attribute_values) + condition_expression = self.body.get('ConditionExpression') + expression_attribute_names = self.body.get('ExpressionAttributeNames', {}) + expression_attribute_values = self.body.get('ExpressionAttributeValues', {}) # Support spaces between operators in an update expression # E.g. `a = b + c` -> `a=b+c` @@ -670,7 +605,7 @@ class DynamoHandler(BaseResponse): try: item = self.dynamodb_backend.update_item( name, key, update_expression, attribute_updates, expression_attribute_names, - expression_attribute_values, expected + expression_attribute_values, expected, condition_expression ) except ValueError: er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException' diff --git a/moto/rds2/exceptions.py b/moto/rds2/exceptions.py index 0e716310e..e82ae7077 100644 --- a/moto/rds2/exceptions.py +++ b/moto/rds2/exceptions.py @@ -60,6 +60,15 @@ class DBParameterGroupNotFoundError(RDSClientError): 'DB Parameter Group {0} not found.'.format(db_parameter_group_name)) +class OptionGroupNotFoundFaultError(RDSClientError): + + def __init__(self, option_group_name): + super(OptionGroupNotFoundFaultError, self).__init__( + 'OptionGroupNotFoundFault', + 'Specified OptionGroupName: {0} not found.'.format(option_group_name) + ) + + class InvalidDBClusterStateFaultError(RDSClientError): def __init__(self, database_identifier): diff --git a/moto/rds2/models.py b/moto/rds2/models.py index fee004f76..81b346fdb 100644 --- a/moto/rds2/models.py +++ b/moto/rds2/models.py @@ -20,6 +20,7 @@ from .exceptions import (RDSClientError, DBSecurityGroupNotFoundError, DBSubnetGroupNotFoundError, DBParameterGroupNotFoundError, + OptionGroupNotFoundFaultError, InvalidDBClusterStateFaultError, InvalidDBInstanceStateError, SnapshotQuotaExceededError, @@ -70,6 +71,7 @@ class Database(BaseModel): self.port = Database.default_port(self.engine) self.db_instance_identifier = kwargs.get('db_instance_identifier') self.db_name = kwargs.get("db_name") + self.instance_create_time = iso_8601_datetime_with_milliseconds(datetime.datetime.now()) self.publicly_accessible = kwargs.get("publicly_accessible") if self.publicly_accessible is None: self.publicly_accessible = True @@ -99,6 +101,8 @@ class Database(BaseModel): 'preferred_backup_window', '13:14-13:44') self.license_model = kwargs.get('license_model', 'general-public-license') self.option_group_name = kwargs.get('option_group_name', None) + if self.option_group_name and self.option_group_name not in rds2_backends[self.region].option_groups: + raise OptionGroupNotFoundFaultError(self.option_group_name) self.default_option_groups = {"MySQL": "default.mysql5.6", "mysql": "default.mysql5.6", "postgres": "default.postgres9.3" @@ -148,6 +152,7 @@ class Database(BaseModel): {{ database.db_instance_identifier }} {{ database.dbi_resource_id }} + {{ database.instance_create_time }} 03:50-04:20 wed:06:38-wed:07:08 @@ -173,6 +178,10 @@ class Database(BaseModel): {{ database.license_model }} {{ database.engine_version }} + + {{ database.option_group_name }} + in-sync + {% for db_parameter_group in database.db_parameter_groups() %} @@ -373,7 +382,7 @@ class Database(BaseModel): "Address": "{{ database.address }}", "Port": "{{ database.port }}" }, - "InstanceCreateTime": null, + "InstanceCreateTime": "{{ database.instance_create_time }}", "Iops": null, "ReadReplicaDBInstanceIdentifiers": [{%- for replica in database.replicas -%} {%- if not loop.first -%},{%- endif -%} @@ -873,13 +882,16 @@ class RDS2Backend(BaseBackend): def create_option_group(self, option_group_kwargs): option_group_id = option_group_kwargs['name'] - valid_option_group_engines = {'mysql': ['5.6'], - 'oracle-se1': ['11.2'], - 'oracle-se': ['11.2'], - 'oracle-ee': ['11.2'], + valid_option_group_engines = {'mariadb': ['10.0', '10.1', '10.2', '10.3'], + 'mysql': ['5.5', '5.6', '5.7', '8.0'], + 'oracle-se2': ['11.2', '12.1', '12.2'], + 'oracle-se1': ['11.2', '12.1', '12.2'], + 'oracle-se': ['11.2', '12.1', '12.2'], + 'oracle-ee': ['11.2', '12.1', '12.2'], 'sqlserver-se': ['10.50', '11.00'], - 'sqlserver-ee': ['10.50', '11.00'] - } + 'sqlserver-ee': ['10.50', '11.00'], + 'sqlserver-ex': ['10.50', '11.00'], + 'sqlserver-web': ['10.50', '11.00']} if option_group_kwargs['name'] in self.option_groups: raise RDSClientError('OptionGroupAlreadyExistsFault', 'An option group named {0} already exists.'.format(option_group_kwargs['name'])) @@ -905,8 +917,7 @@ class RDS2Backend(BaseBackend): if option_group_name in self.option_groups: return self.option_groups.pop(option_group_name) else: - raise RDSClientError( - 'OptionGroupNotFoundFault', 'Specified OptionGroupName: {0} not found.'.format(option_group_name)) + raise OptionGroupNotFoundFaultError(option_group_name) def describe_option_groups(self, option_group_kwargs): option_group_list = [] @@ -935,8 +946,7 @@ class RDS2Backend(BaseBackend): else: option_group_list.append(option_group) if not len(option_group_list): - raise RDSClientError('OptionGroupNotFoundFault', - 'Specified OptionGroupName: {0} not found.'.format(option_group_kwargs['name'])) + raise OptionGroupNotFoundFaultError(option_group_kwargs['name']) return option_group_list[marker:max_records + marker] @staticmethod @@ -965,8 +975,7 @@ class RDS2Backend(BaseBackend): def modify_option_group(self, option_group_name, options_to_include=None, options_to_remove=None, apply_immediately=None): if option_group_name not in self.option_groups: - raise RDSClientError('OptionGroupNotFoundFault', - 'Specified OptionGroupName: {0} not found.'.format(option_group_name)) + raise OptionGroupNotFoundFaultError(option_group_name) if not options_to_include and not options_to_remove: raise RDSClientError('InvalidParameterValue', 'At least one option must be added, modified, or removed.') diff --git a/moto/rds2/responses.py b/moto/rds2/responses.py index 66d4e0c52..e92625635 100644 --- a/moto/rds2/responses.py +++ b/moto/rds2/responses.py @@ -34,7 +34,7 @@ class RDS2Response(BaseResponse): "master_user_password": self._get_param('MasterUserPassword'), "master_username": self._get_param('MasterUsername'), "multi_az": self._get_bool_param("MultiAZ"), - # OptionGroupName + "option_group_name": self._get_param("OptionGroupName"), "port": self._get_param('Port'), # PreferredBackupWindow # PreferredMaintenanceWindow diff --git a/moto/route53/models.py b/moto/route53/models.py index 5ed1c1476..681a9d6ff 100644 --- a/moto/route53/models.py +++ b/moto/route53/models.py @@ -85,6 +85,7 @@ class RecordSet(BaseModel): self.health_check = kwargs.get('HealthCheckId') self.hosted_zone_name = kwargs.get('HostedZoneName') self.hosted_zone_id = kwargs.get('HostedZoneId') + self.alias_target = kwargs.get('AliasTarget') @classmethod def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): @@ -143,6 +144,13 @@ class RecordSet(BaseModel): {% if record_set.ttl %} {{ record_set.ttl }} {% endif %} + {% if record_set.alias_target %} + + {{ record_set.alias_target['HostedZoneId'] }} + {{ record_set.alias_target['DNSName'] }} + {{ record_set.alias_target['EvaluateTargetHealth'] }} + + {% else %} {% for record in record_set.records %} @@ -150,6 +158,7 @@ class RecordSet(BaseModel): {% endfor %} + {% endif %} {% if record_set.health_check %} {{ record_set.health_check }} {% endif %} diff --git a/moto/route53/responses.py b/moto/route53/responses.py index bf705c87f..f933c575a 100644 --- a/moto/route53/responses.py +++ b/moto/route53/responses.py @@ -134,10 +134,7 @@ class Route53(BaseResponse): # Depending on how many records there are, this may # or may not be a list resource_records = [resource_records] - record_values = [x['Value'] for x in resource_records] - elif 'AliasTarget' in record_set: - record_values = [record_set['AliasTarget']['DNSName']] - record_set['ResourceRecords'] = record_values + record_set['ResourceRecords'] = [x['Value'] for x in resource_records] if action == 'CREATE': the_zone.add_rrset(record_set) else: diff --git a/moto/s3/responses.py b/moto/s3/responses.py index 46d811f81..a052e4cfb 100644 --- a/moto/s3/responses.py +++ b/moto/s3/responses.py @@ -807,7 +807,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): body = b'' if method == 'GET': - return self._key_response_get(bucket_name, query, key_name, headers) + return self._key_response_get(bucket_name, query, key_name, headers=request.headers) elif method == 'PUT': return self._key_response_put(request, body, bucket_name, query, key_name, headers) elif method == 'HEAD': @@ -842,10 +842,15 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): parts=parts ) version_id = query.get('versionId', [None])[0] + if_modified_since = headers.get('If-Modified-Since', None) key = self.backend.get_key( bucket_name, key_name, version_id=version_id) if key is None: raise MissingKey(key_name) + if if_modified_since: + if_modified_since = str_to_rfc_1123_datetime(if_modified_since) + if if_modified_since and key.last_modified < if_modified_since: + return 304, response_headers, 'Not Modified' if 'acl' in query: template = self.response_template(S3_OBJECT_ACL_RESPONSE) return 200, response_headers, template.render(obj=key) diff --git a/moto/ses/feedback.py b/moto/ses/feedback.py new file mode 100644 index 000000000..2d32f9ce0 --- /dev/null +++ b/moto/ses/feedback.py @@ -0,0 +1,81 @@ +""" +SES Feedback messages +Extracted from https://docs.aws.amazon.com/ses/latest/DeveloperGuide/notification-contents.html +""" +COMMON_MAIL = { + "notificationType": "Bounce, Complaint, or Delivery.", + "mail": { + "timestamp": "2018-10-08T14:05:45 +0000", + "messageId": "000001378603177f-7a5433e7-8edb-42ae-af10-f0181f34d6ee-000000", + "source": "sender@example.com", + "sourceArn": "arn:aws:ses:us-west-2:888888888888:identity/example.com", + "sourceIp": "127.0.3.0", + "sendingAccountId": "123456789012", + "destination": [ + "recipient@example.com" + ], + "headersTruncated": False, + "headers": [ + { + "name": "From", + "value": "\"Sender Name\" " + }, + { + "name": "To", + "value": "\"Recipient Name\" " + } + ], + "commonHeaders": { + "from": [ + "Sender Name " + ], + "date": "Mon, 08 Oct 2018 14:05:45 +0000", + "to": [ + "Recipient Name " + ], + "messageId": " custom-message-ID", + "subject": "Message sent using Amazon SES" + } + } +} +BOUNCE = { + "bounceType": "Permanent", + "bounceSubType": "General", + "bouncedRecipients": [ + { + "status": "5.0.0", + "action": "failed", + "diagnosticCode": "smtp; 550 user unknown", + "emailAddress": "recipient1@example.com" + }, + { + "status": "4.0.0", + "action": "delayed", + "emailAddress": "recipient2@example.com" + } + ], + "reportingMTA": "example.com", + "timestamp": "2012-05-25T14:59:38.605Z", + "feedbackId": "000001378603176d-5a4b5ad9-6f30-4198-a8c3-b1eb0c270a1d-000000", + "remoteMtaIp": "127.0.2.0" +} +COMPLAINT = { + "userAgent": "AnyCompany Feedback Loop (V0.01)", + "complainedRecipients": [ + { + "emailAddress": "recipient1@example.com" + } + ], + "complaintFeedbackType": "abuse", + "arrivalDate": "2009-12-03T04:24:21.000-05:00", + "timestamp": "2012-05-25T14:59:38.623Z", + "feedbackId": "000001378603177f-18c07c78-fa81-4a58-9dd1-fedc3cb8f49a-000000" +} +DELIVERY = { + "timestamp": "2014-05-28T22:41:01.184Z", + "processingTimeMillis": 546, + "recipients": ["success@simulator.amazonses.com"], + "smtpResponse": "250 ok: Message 64111812 accepted", + "reportingMTA": "a8-70.smtp-out.amazonses.com", + "remoteMtaIp": "127.0.2.0" +} diff --git a/moto/ses/models.py b/moto/ses/models.py index 71fe9d9a1..0544ac278 100644 --- a/moto/ses/models.py +++ b/moto/ses/models.py @@ -4,13 +4,41 @@ import email from email.utils import parseaddr from moto.core import BaseBackend, BaseModel +from moto.sns.models import sns_backends from .exceptions import MessageRejectedError from .utils import get_random_message_id - +from .feedback import COMMON_MAIL, BOUNCE, COMPLAINT, DELIVERY RECIPIENT_LIMIT = 50 +class SESFeedback(BaseModel): + + BOUNCE = "Bounce" + COMPLAINT = "Complaint" + DELIVERY = "Delivery" + + SUCCESS_ADDR = "success" + BOUNCE_ADDR = "bounce" + COMPLAINT_ADDR = "complaint" + + FEEDBACK_SUCCESS_MSG = {"test": "success"} + FEEDBACK_BOUNCE_MSG = {"test": "bounce"} + FEEDBACK_COMPLAINT_MSG = {"test": "complaint"} + + @staticmethod + def generate_message(msg_type): + msg = dict(COMMON_MAIL) + if msg_type == SESFeedback.BOUNCE: + msg["bounce"] = BOUNCE + elif msg_type == SESFeedback.COMPLAINT: + msg["complaint"] = COMPLAINT + elif msg_type == SESFeedback.DELIVERY: + msg["delivery"] = DELIVERY + + return msg + + class Message(BaseModel): def __init__(self, message_id, source, subject, body, destinations): @@ -48,6 +76,7 @@ class SESBackend(BaseBackend): self.domains = [] self.sent_messages = [] self.sent_message_count = 0 + self.sns_topics = {} def _is_verified_address(self, source): _, address = parseaddr(source) @@ -77,7 +106,7 @@ class SESBackend(BaseBackend): else: self.domains.remove(identity) - def send_email(self, source, subject, body, destinations): + def send_email(self, source, subject, body, destinations, region): recipient_count = sum(map(len, destinations.values())) if recipient_count > RECIPIENT_LIMIT: raise MessageRejectedError('Too many recipients.') @@ -86,13 +115,46 @@ class SESBackend(BaseBackend): "Email address not verified %s" % source ) + self.__process_sns_feedback__(source, destinations, region) + message_id = get_random_message_id() message = Message(message_id, source, subject, body, destinations) self.sent_messages.append(message) self.sent_message_count += recipient_count return message - def send_raw_email(self, source, destinations, raw_data): + def __type_of_message__(self, destinations): + """Checks the destination for any special address that could indicate delivery, complaint or bounce + like in SES simualtor""" + alladdress = destinations.get("ToAddresses", []) + destinations.get("CcAddresses", []) + destinations.get("BccAddresses", []) + for addr in alladdress: + if SESFeedback.SUCCESS_ADDR in addr: + return SESFeedback.DELIVERY + elif SESFeedback.COMPLAINT_ADDR in addr: + return SESFeedback.COMPLAINT + elif SESFeedback.BOUNCE_ADDR in addr: + return SESFeedback.BOUNCE + + return None + + def __generate_feedback__(self, msg_type): + """Generates the SNS message for the feedback""" + return SESFeedback.generate_message(msg_type) + + def __process_sns_feedback__(self, source, destinations, region): + domain = str(source) + if "@" in domain: + domain = domain.split("@")[1] + if domain in self.sns_topics: + msg_type = self.__type_of_message__(destinations) + if msg_type is not None: + sns_topic = self.sns_topics[domain].get(msg_type, None) + if sns_topic is not None: + message = self.__generate_feedback__(msg_type) + if message: + sns_backends[region].publish(sns_topic, message) + + def send_raw_email(self, source, destinations, raw_data, region): if source is not None: _, source_email_address = parseaddr(source) if source_email_address not in self.addresses: @@ -122,6 +184,8 @@ class SESBackend(BaseBackend): if recipient_count > RECIPIENT_LIMIT: raise MessageRejectedError('Too many recipients.') + self.__process_sns_feedback__(source, destinations, region) + self.sent_message_count += recipient_count message_id = get_random_message_id() message = RawMessage(message_id, source, destinations, raw_data) @@ -131,5 +195,16 @@ class SESBackend(BaseBackend): def get_send_quota(self): return SESQuota(self.sent_message_count) + def set_identity_notification_topic(self, identity, notification_type, sns_topic): + identity_sns_topics = self.sns_topics.get(identity, {}) + if sns_topic is None: + del identity_sns_topics[notification_type] + else: + identity_sns_topics[notification_type] = sns_topic + + self.sns_topics[identity] = identity_sns_topics + + return {} + ses_backend = SESBackend() diff --git a/moto/ses/responses.py b/moto/ses/responses.py index bdf873836..d2dda55f1 100644 --- a/moto/ses/responses.py +++ b/moto/ses/responses.py @@ -70,7 +70,7 @@ class EmailResponse(BaseResponse): break destinations[dest_type].append(address[0]) - message = ses_backend.send_email(source, subject, body, destinations) + message = ses_backend.send_email(source, subject, body, destinations, self.region) template = self.response_template(SEND_EMAIL_RESPONSE) return template.render(message=message) @@ -92,7 +92,7 @@ class EmailResponse(BaseResponse): break destinations.append(address[0]) - message = ses_backend.send_raw_email(source, destinations, raw_data) + message = ses_backend.send_raw_email(source, destinations, raw_data, self.region) template = self.response_template(SEND_RAW_EMAIL_RESPONSE) return template.render(message=message) @@ -101,6 +101,18 @@ class EmailResponse(BaseResponse): template = self.response_template(GET_SEND_QUOTA_RESPONSE) return template.render(quota=quota) + def set_identity_notification_topic(self): + + identity = self.querystring.get("Identity")[0] + not_type = self.querystring.get("NotificationType")[0] + sns_topic = self.querystring.get("SnsTopic") + if sns_topic: + sns_topic = sns_topic[0] + + ses_backend.set_identity_notification_topic(identity, not_type, sns_topic) + template = self.response_template(SET_IDENTITY_NOTIFICATION_TOPIC_RESPONSE) + return template.render() + VERIFY_EMAIL_IDENTITY = """ @@ -200,3 +212,10 @@ GET_SEND_QUOTA_RESPONSE = """ + + + 47e0ef1a-9bf2-11e1-9279-0100e8cf109a + +""" diff --git a/tests/test_dynamodb2/test_dynamodb.py b/tests/test_dynamodb2/test_dynamodb.py index faa467aab..f5afc1e7e 100644 --- a/tests/test_dynamodb2/test_dynamodb.py +++ b/tests/test_dynamodb2/test_dynamodb.py @@ -838,44 +838,47 @@ def test_filter_expression(): filter_expr.expr(row1).should.be(True) # NOT test 2 - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('NOT (Id = :v0)', {}, {':v0': {'N': 8}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('NOT (Id = :v0)', {}, {':v0': {'N': '8'}}) filter_expr.expr(row1).should.be(False) # Id = 8 so should be false # AND test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id > :v0 AND Subs < :v1', {}, {':v0': {'N': 5}, ':v1': {'N': 7}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id > :v0 AND Subs < :v1', {}, {':v0': {'N': '5'}, ':v1': {'N': '7'}}) filter_expr.expr(row1).should.be(True) filter_expr.expr(row2).should.be(False) # OR test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id = :v0 OR Id=:v1', {}, {':v0': {'N': 5}, ':v1': {'N': 8}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id = :v0 OR Id=:v1', {}, {':v0': {'N': '5'}, ':v1': {'N': '8'}}) filter_expr.expr(row1).should.be(True) # BETWEEN test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id BETWEEN :v0 AND :v1', {}, {':v0': {'N': 5}, ':v1': {'N': 10}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id BETWEEN :v0 AND :v1', {}, {':v0': {'N': '5'}, ':v1': {'N': '10'}}) filter_expr.expr(row1).should.be(True) # PAREN test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id = :v0 AND (Subs = :v0 OR Subs = :v1)', {}, {':v0': {'N': 8}, ':v1': {'N': 5}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id = :v0 AND (Subs = :v0 OR Subs = :v1)', {}, {':v0': {'N': '8'}, ':v1': {'N': '5'}}) filter_expr.expr(row1).should.be(True) # IN test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id IN :v0', {}, {':v0': {'NS': [7, 8, 9]}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id IN (:v0, :v1, :v2)', {}, { + ':v0': {'N': '7'}, + ':v1': {'N': '8'}, + ':v2': {'N': '9'}}) filter_expr.expr(row1).should.be(True) # attribute function tests (with extra spaces) filter_expr = moto.dynamodb2.comparisons.get_filter_expression('attribute_exists(Id) AND attribute_not_exists (User)', {}, {}) filter_expr.expr(row1).should.be(True) - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('attribute_type(Id, N)', {}, {}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('attribute_type(Id, :v0)', {}, {':v0': {'S': 'N'}}) filter_expr.expr(row1).should.be(True) # beginswith function test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('begins_with(Desc, Some)', {}, {}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('begins_with(Desc, :v0)', {}, {':v0': {'S': 'Some'}}) filter_expr.expr(row1).should.be(True) filter_expr.expr(row2).should.be(False) # contains function test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('contains(KV, test1)', {}, {}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('contains(KV, :v0)', {}, {':v0': {'S': 'test1'}}) filter_expr.expr(row1).should.be(True) filter_expr.expr(row2).should.be(False) @@ -916,14 +919,26 @@ def test_query_filter(): TableName='test1', Item={ 'client': {'S': 'client1'}, - 'app': {'S': 'app1'} + 'app': {'S': 'app1'}, + 'nested': {'M': { + 'version': {'S': 'version1'}, + 'contents': {'L': [ + {'S': 'value1'}, {'S': 'value2'}, + ]}, + }}, } ) client.put_item( TableName='test1', Item={ 'client': {'S': 'client1'}, - 'app': {'S': 'app2'} + 'app': {'S': 'app2'}, + 'nested': {'M': { + 'version': {'S': 'version2'}, + 'contents': {'L': [ + {'S': 'value1'}, {'S': 'value2'}, + ]}, + }}, } ) @@ -945,6 +960,18 @@ def test_query_filter(): ) assert response['Count'] == 2 + response = table.query( + KeyConditionExpression=Key('client').eq('client1'), + FilterExpression=Attr('nested.version').contains('version') + ) + assert response['Count'] == 2 + + response = table.query( + KeyConditionExpression=Key('client').eq('client1'), + FilterExpression=Attr('nested.contents[0]').eq('value1') + ) + assert response['Count'] == 2 + @mock_dynamodb2 def test_scan_filter(): @@ -1223,7 +1250,7 @@ def test_delete_item(): with assert_raises(ClientError) as ex: table.delete_item(Key={'client': 'client1', 'app': 'app1'}, ReturnValues='ALL_NEW') - + # Test deletion and returning old value response = table.delete_item(Key={'client': 'client1', 'app': 'app1'}, ReturnValues='ALL_OLD') response['Attributes'].should.contain('client') @@ -1526,7 +1553,7 @@ def test_put_return_attributes(): ReturnValues='NONE' ) assert 'Attributes' not in r - + r = dynamodb.put_item( TableName='moto-test', Item={'id': {'S': 'foo'}, 'col1': {'S': 'val2'}}, @@ -1543,7 +1570,7 @@ def test_put_return_attributes(): ex.exception.response['Error']['Code'].should.equal('ValidationException') ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) ex.exception.response['Error']['Message'].should.equal('Return values set to invalid value') - + @mock_dynamodb2 def test_query_global_secondary_index_when_created_via_update_table_resource(): @@ -1651,7 +1678,7 @@ def test_dynamodb_streams_1(): 'StreamViewType': 'NEW_AND_OLD_IMAGES' } ) - + assert 'StreamSpecification' in resp['TableDescription'] assert resp['TableDescription']['StreamSpecification'] == { 'StreamEnabled': True, @@ -1659,11 +1686,11 @@ def test_dynamodb_streams_1(): } assert 'LatestStreamLabel' in resp['TableDescription'] assert 'LatestStreamArn' in resp['TableDescription'] - + resp = conn.delete_table(TableName='test-streams') assert 'StreamSpecification' in resp['TableDescription'] - + @mock_dynamodb2 def test_dynamodb_streams_2(): @@ -1694,11 +1721,10 @@ def test_dynamodb_streams_2(): assert 'LatestStreamLabel' in resp['TableDescription'] assert 'LatestStreamArn' in resp['TableDescription'] - + @mock_dynamodb2 def test_condition_expressions(): client = boto3.client('dynamodb', region_name='us-east-1') - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') # Create the DynamoDB table. client.create_table( @@ -1751,6 +1777,57 @@ def test_condition_expressions(): } ) + client.put_item( + TableName='test1', + Item={ + 'client': {'S': 'client1'}, + 'app': {'S': 'app1'}, + 'match': {'S': 'match'}, + 'existing': {'S': 'existing'}, + }, + ConditionExpression='attribute_exists(#nonexistent) OR attribute_exists(#existing)', + ExpressionAttributeNames={ + '#nonexistent': 'nope', + '#existing': 'existing' + } + ) + + client.put_item( + TableName='test1', + Item={ + 'client': {'S': 'client1'}, + 'app': {'S': 'app1'}, + 'match': {'S': 'match'}, + 'existing': {'S': 'existing'}, + }, + ConditionExpression='#client BETWEEN :a AND :z', + ExpressionAttributeNames={ + '#client': 'client', + }, + ExpressionAttributeValues={ + ':a': {'S': 'a'}, + ':z': {'S': 'z'}, + } + ) + + client.put_item( + TableName='test1', + Item={ + 'client': {'S': 'client1'}, + 'app': {'S': 'app1'}, + 'match': {'S': 'match'}, + 'existing': {'S': 'existing'}, + }, + ConditionExpression='#client IN (:client1, :client2)', + ExpressionAttributeNames={ + '#client': 'client', + }, + ExpressionAttributeValues={ + ':client1': {'S': 'client1'}, + ':client2': {'S': 'client2'}, + } + ) + with assert_raises(client.exceptions.ConditionalCheckFailedException): client.put_item( TableName='test1', @@ -1803,6 +1880,89 @@ def test_condition_expressions(): } ) + # Make sure update_item honors ConditionExpression as well + client.update_item( + TableName='test1', + Key={ + 'client': {'S': 'client1'}, + 'app': {'S': 'app1'}, + }, + UpdateExpression='set #match=:match', + ConditionExpression='attribute_exists(#existing)', + ExpressionAttributeNames={ + '#existing': 'existing', + '#match': 'match', + }, + ExpressionAttributeValues={ + ':match': {'S': 'match'} + } + ) + + with assert_raises(client.exceptions.ConditionalCheckFailedException): + client.update_item( + TableName='test1', + Key={ + 'client': { 'S': 'client1'}, + 'app': { 'S': 'app1'}, + }, + UpdateExpression='set #match=:match', + ConditionExpression='attribute_not_exists(#existing)', + ExpressionAttributeValues={ + ':match': {'S': 'match'} + }, + ExpressionAttributeNames={ + '#existing': 'existing', + '#match': 'match', + }, + ) + + +@mock_dynamodb2 +def test_condition_expression__attr_doesnt_exist(): + client = boto3.client('dynamodb', region_name='us-east-1') + + client.create_table( + TableName='test', + KeySchema=[{'AttributeName': 'forum_name', 'KeyType': 'HASH'}], + AttributeDefinitions=[ + {'AttributeName': 'forum_name', 'AttributeType': 'S'}, + ], + ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1}, + ) + + client.put_item( + TableName='test', + Item={ + 'forum_name': {'S': 'foo'}, + 'ttl': {'N': 'bar'}, + } + ) + + + def update_if_attr_doesnt_exist(): + # Test nonexistent top-level attribute. + client.update_item( + TableName='test', + Key={ + 'forum_name': {'S': 'the-key'}, + 'subject': {'S': 'the-subject'}, + }, + UpdateExpression='set #new_state=:new_state, #ttl=:ttl', + ConditionExpression='attribute_not_exists(#new_state)', + ExpressionAttributeNames={'#new_state': 'foobar', '#ttl': 'ttl'}, + ExpressionAttributeValues={ + ':new_state': {'S': 'some-value'}, + ':ttl': {'N': '12345.67'}, + }, + ReturnValues='ALL_NEW', + ) + + update_if_attr_doesnt_exist() + + # Second time should fail + with assert_raises(client.exceptions.ConditionalCheckFailedException): + update_if_attr_doesnt_exist() + @mock_dynamodb2 def test_query_gsi_with_range_key(): diff --git a/tests/test_rds2/test_rds2.py b/tests/test_rds2/test_rds2.py index a25b53196..8ea296c2c 100644 --- a/tests/test_rds2/test_rds2.py +++ b/tests/test_rds2/test_rds2.py @@ -34,6 +34,39 @@ def test_create_database(): db_instance['IAMDatabaseAuthenticationEnabled'].should.equal(False) db_instance['DbiResourceId'].should.contain("db-") db_instance['CopyTagsToSnapshot'].should.equal(False) + db_instance['InstanceCreateTime'].should.be.a("datetime.datetime") + + +@mock_rds2 +def test_create_database_non_existing_option_group(): + conn = boto3.client('rds', region_name='us-west-2') + database = conn.create_db_instance.when.called_with( + DBInstanceIdentifier='db-master-1', + AllocatedStorage=10, + Engine='postgres', + DBName='staging-postgres', + DBInstanceClass='db.m1.small', + OptionGroupName='non-existing').should.throw(ClientError) + + +@mock_rds2 +def test_create_database_with_option_group(): + conn = boto3.client('rds', region_name='us-west-2') + conn.create_option_group(OptionGroupName='my-og', + EngineName='mysql', + MajorEngineVersion='5.6', + OptionGroupDescription='test option group') + database = conn.create_db_instance(DBInstanceIdentifier='db-master-1', + AllocatedStorage=10, + Engine='postgres', + DBName='staging-postgres', + DBInstanceClass='db.m1.small', + OptionGroupName='my-og') + db_instance = database['DBInstance'] + db_instance['AllocatedStorage'].should.equal(10) + db_instance['DBInstanceClass'].should.equal('db.m1.small') + db_instance['DBName'].should.equal('staging-postgres') + db_instance['OptionGroupMemberships'][0]['OptionGroupName'].should.equal('my-og') @mock_rds2 @@ -204,6 +237,7 @@ def test_get_databases_paginated(): resp3 = conn.describe_db_instances(MaxRecords=100) resp3["DBInstances"].should.have.length_of(51) + @mock_rds2 def test_describe_non_existant_database(): conn = boto3.client('rds', region_name='us-west-2') diff --git a/tests/test_route53/test_route53.py b/tests/test_route53/test_route53.py index f43657dad..ca652af88 100644 --- a/tests/test_route53/test_route53.py +++ b/tests/test_route53/test_route53.py @@ -173,14 +173,16 @@ def test_alias_rrset(): changes.commit() rrsets = conn.get_all_rrsets(zoneid, type="A") - rrset_records = [(rr_set.name, rr) for rr_set in rrsets for rr in rr_set.resource_records] - rrset_records.should.have.length_of(2) - rrset_records.should.contain(('foo.alias.testdns.aws.com.', 'foo.testdns.aws.com')) - rrset_records.should.contain(('bar.alias.testdns.aws.com.', 'bar.testdns.aws.com')) - rrsets[0].resource_records[0].should.equal('foo.testdns.aws.com') + alias_targets = [rr_set.alias_dns_name for rr_set in rrsets] + alias_targets.should.have.length_of(2) + alias_targets.should.contain('foo.testdns.aws.com') + alias_targets.should.contain('bar.testdns.aws.com') + rrsets[0].alias_dns_name.should.equal('foo.testdns.aws.com') + rrsets[0].resource_records.should.have.length_of(0) rrsets = conn.get_all_rrsets(zoneid, type="CNAME") rrsets.should.have.length_of(1) - rrsets[0].resource_records[0].should.equal('bar.testdns.aws.com') + rrsets[0].alias_dns_name.should.equal('bar.testdns.aws.com') + rrsets[0].resource_records.should.have.length_of(0) @mock_route53_deprecated @@ -583,6 +585,39 @@ def test_change_resource_record_sets_crud_valid(): cname_record_detail['TTL'].should.equal(60) cname_record_detail['ResourceRecords'].should.equal([{'Value': '192.168.1.1'}]) + # Update to add Alias. + cname_alias_record_endpoint_payload = { + 'Comment': 'Update to Alias prod.redis.db', + 'Changes': [ + { + 'Action': 'UPSERT', + 'ResourceRecordSet': { + 'Name': 'prod.redis.db.', + 'Type': 'A', + 'TTL': 60, + 'AliasTarget': { + 'HostedZoneId': hosted_zone_id, + 'DNSName': 'prod.redis.alias.', + 'EvaluateTargetHealth': False, + } + } + } + ] + } + conn.change_resource_record_sets(HostedZoneId=hosted_zone_id, ChangeBatch=cname_alias_record_endpoint_payload) + + response = conn.list_resource_record_sets(HostedZoneId=hosted_zone_id) + cname_alias_record_detail = response['ResourceRecordSets'][0] + cname_alias_record_detail['Name'].should.equal('prod.redis.db.') + cname_alias_record_detail['Type'].should.equal('A') + cname_alias_record_detail['TTL'].should.equal(60) + cname_alias_record_detail['AliasTarget'].should.equal({ + 'HostedZoneId': hosted_zone_id, + 'DNSName': 'prod.redis.alias.', + 'EvaluateTargetHealth': False, + }) + cname_alias_record_detail.should_not.contain('ResourceRecords') + # Delete record with wrong type. delete_payload = { 'Comment': 'delete prod.redis.db', diff --git a/tests/test_s3/test_s3.py b/tests/test_s3/test_s3.py index f26964ab7..697c47865 100644 --- a/tests/test_s3/test_s3.py +++ b/tests/test_s3/test_s3.py @@ -1596,6 +1596,28 @@ def test_boto3_delete_versioned_bucket(): client.delete_bucket(Bucket='blah') +@mock_s3 +def test_boto3_get_object_if_modified_since(): + s3 = boto3.client('s3', region_name='us-east-1') + bucket_name = "blah" + s3.create_bucket(Bucket=bucket_name) + + key = 'hello.txt' + + s3.put_object( + Bucket=bucket_name, + Key=key, + Body='test' + ) + + with assert_raises(botocore.exceptions.ClientError) as err: + s3.get_object( + Bucket=bucket_name, + Key=key, + IfModifiedSince=datetime.datetime.utcnow() + datetime.timedelta(hours=1) + ) + e = err.exception + e.response['Error'].should.equal({'Code': '304', 'Message': 'Not Modified'}) @mock_s3 def test_boto3_head_object_if_modified_since(): diff --git a/tests/test_ses/test_ses_sns_boto3.py b/tests/test_ses/test_ses_sns_boto3.py new file mode 100644 index 000000000..37f79a8b0 --- /dev/null +++ b/tests/test_ses/test_ses_sns_boto3.py @@ -0,0 +1,114 @@ +from __future__ import unicode_literals + +import boto3 +import json +from botocore.exceptions import ClientError +from six.moves.email_mime_multipart import MIMEMultipart +from six.moves.email_mime_text import MIMEText + +import sure # noqa +from nose import tools +from moto import mock_ses, mock_sns, mock_sqs +from moto.ses.models import SESFeedback + + +@mock_ses +def test_enable_disable_ses_sns_communication(): + conn = boto3.client('ses', region_name='us-east-1') + conn.set_identity_notification_topic( + Identity='test.com', + NotificationType='Bounce', + SnsTopic='the-arn' + ) + conn.set_identity_notification_topic( + Identity='test.com', + NotificationType='Bounce' + ) + + +def __setup_feedback_env__(ses_conn, sns_conn, sqs_conn, domain, topic, queue, region, expected_msg): + """Setup the AWS environment to test the SES SNS Feedback""" + # Environment setup + # Create SQS queue + sqs_conn.create_queue(QueueName=queue) + # Create SNS topic + create_topic_response = sns_conn.create_topic(Name=topic) + topic_arn = create_topic_response["TopicArn"] + # Subscribe the SNS topic to the SQS queue + sns_conn.subscribe(TopicArn=topic_arn, + Protocol="sqs", + Endpoint="arn:aws:sqs:%s:123456789012:%s" % (region, queue)) + # Verify SES domain + ses_conn.verify_domain_identity(Domain=domain) + # Setup SES notification topic + if expected_msg is not None: + ses_conn.set_identity_notification_topic( + Identity=domain, + NotificationType=expected_msg, + SnsTopic=topic_arn + ) + + +def __test_sns_feedback__(addr, expected_msg): + region_name = "us-east-1" + ses_conn = boto3.client('ses', region_name=region_name) + sns_conn = boto3.client('sns', region_name=region_name) + sqs_conn = boto3.resource('sqs', region_name=region_name) + domain = "example.com" + topic = "bounce-arn-feedback" + queue = "feedback-test-queue" + + __setup_feedback_env__(ses_conn, sns_conn, sqs_conn, domain, topic, queue, region_name, expected_msg) + + # Send the message + kwargs = dict( + Source="test@" + domain, + Destination={ + "ToAddresses": [addr + "@" + domain], + "CcAddresses": ["test_cc@" + domain], + "BccAddresses": ["test_bcc@" + domain], + }, + Message={ + "Subject": {"Data": "test subject"}, + "Body": {"Text": {"Data": "test body"}} + } + ) + ses_conn.send_email(**kwargs) + + # Wait for messages in the queues + queue = sqs_conn.get_queue_by_name(QueueName=queue) + messages = queue.receive_messages(MaxNumberOfMessages=1) + if expected_msg is not None: + msg = messages[0].body + msg = json.loads(msg) + assert msg["Message"] == SESFeedback.generate_message(expected_msg) + else: + assert len(messages) == 0 + + +@mock_sqs +@mock_sns +@mock_ses +def test_no_sns_feedback(): + __test_sns_feedback__("test", None) + + +@mock_sqs +@mock_sns +@mock_ses +def test_sns_feedback_bounce(): + __test_sns_feedback__(SESFeedback.BOUNCE_ADDR, SESFeedback.BOUNCE) + + +@mock_sqs +@mock_sns +@mock_ses +def test_sns_feedback_complaint(): + __test_sns_feedback__(SESFeedback.COMPLAINT_ADDR, SESFeedback.COMPLAINT) + + +@mock_sqs +@mock_sns +@mock_ses +def test_sns_feedback_delivery(): + __test_sns_feedback__(SESFeedback.SUCCESS_ADDR, SESFeedback.DELIVERY)