diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index f28083221..40da55ccf 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -2,6 +2,10 @@
Moto has a [Code of Conduct](https://github.com/spulec/moto/blob/master/CODE_OF_CONDUCT.md), you can expect to be treated with respect at all times when interacting with this project.
+## Running the tests locally
+
+Moto has a Makefile which has some helpful commands for getting setup. You should be able to run `make init` to install the dependencies and then `make test` to run the tests.
+
## Is there a missing feature?
Moto is easier to contribute to than you probably think. There's [a list of which endpoints have been implemented](https://github.com/spulec/moto/blob/master/IMPLEMENTATION_COVERAGE.md) and we invite you to add new endpoints to existing services or to add new services.
diff --git a/Makefile b/Makefile
index de08c6f74..2a7249760 100644
--- a/Makefile
+++ b/Makefile
@@ -10,7 +10,7 @@ endif
init:
@python setup.py develop
- @pip install -r requirements.txt
+ @pip install -r requirements-dev.txt
lint:
flake8 moto
diff --git a/moto/__init__.py b/moto/__init__.py
index 9c974f00d..a6f35069e 100644
--- a/moto/__init__.py
+++ b/moto/__init__.py
@@ -3,7 +3,7 @@ import logging
# logging.getLogger('boto').setLevel(logging.CRITICAL)
__title__ = 'moto'
-__version__ = '1.3.9'
+__version__ = '1.3.11'
from .acm import mock_acm # flake8: noqa
from .apigateway import mock_apigateway, mock_apigateway_deprecated # flake8: noqa
diff --git a/moto/awslambda/models.py b/moto/awslambda/models.py
index 8dfa4724a..784d86b0b 100644
--- a/moto/awslambda/models.py
+++ b/moto/awslambda/models.py
@@ -231,6 +231,10 @@ class LambdaFunction(BaseModel):
config.update({"VpcId": "vpc-123abc"})
return config
+ @property
+ def physical_resource_id(self):
+ return self.function_name
+
def __repr__(self):
return json.dumps(self.get_configuration())
diff --git a/moto/dynamodb2/comparisons.py b/moto/dynamodb2/comparisons.py
index 6d37345fe..1a4633e64 100644
--- a/moto/dynamodb2/comparisons.py
+++ b/moto/dynamodb2/comparisons.py
@@ -1,6 +1,94 @@
from __future__ import unicode_literals
import re
import six
+import re
+from collections import deque
+from collections import namedtuple
+
+
+def get_filter_expression(expr, names, values):
+ """
+ Parse a filter expression into an Op.
+
+ Examples
+ expr = 'Id > 5 AND attribute_exists(test) AND Id BETWEEN 5 AND 6 OR length < 6 AND contains(test, 1) AND 5 IN (4,5, 6) OR (Id < 5 AND 5 > Id)'
+ expr = 'Id > 5 AND Subs < 7'
+ """
+ parser = ConditionExpressionParser(expr, names, values)
+ return parser.parse()
+
+
+def get_expected(expected):
+ """
+ Parse a filter expression into an Op.
+
+ Examples
+ expr = 'Id > 5 AND attribute_exists(test) AND Id BETWEEN 5 AND 6 OR length < 6 AND contains(test, 1) AND 5 IN (4,5, 6) OR (Id < 5 AND 5 > Id)'
+ expr = 'Id > 5 AND Subs < 7'
+ """
+ ops = {
+ 'EQ': OpEqual,
+ 'NE': OpNotEqual,
+ 'LE': OpLessThanOrEqual,
+ 'LT': OpLessThan,
+ 'GE': OpGreaterThanOrEqual,
+ 'GT': OpGreaterThan,
+ 'NOT_NULL': FuncAttrExists,
+ 'NULL': FuncAttrNotExists,
+ 'CONTAINS': FuncContains,
+ 'NOT_CONTAINS': FuncNotContains,
+ 'BEGINS_WITH': FuncBeginsWith,
+ 'IN': FuncIn,
+ 'BETWEEN': FuncBetween,
+ }
+
+ # NOTE: Always uses ConditionalOperator=AND
+ conditions = []
+ for key, cond in expected.items():
+ path = AttributePath([key])
+ if 'Exists' in cond:
+ if cond['Exists']:
+ conditions.append(FuncAttrExists(path))
+ else:
+ conditions.append(FuncAttrNotExists(path))
+ elif 'Value' in cond:
+ conditions.append(OpEqual(path, AttributeValue(cond['Value'])))
+ elif 'ComparisonOperator' in cond:
+ operator_name = cond['ComparisonOperator']
+ values = [
+ AttributeValue(v)
+ for v in cond.get("AttributeValueList", [])]
+ OpClass = ops[operator_name]
+ conditions.append(OpClass(path, *values))
+
+ # NOTE: Ignore ConditionalOperator
+ ConditionalOp = OpAnd
+ if conditions:
+ output = conditions[0]
+ for condition in conditions[1:]:
+ output = ConditionalOp(output, condition)
+ else:
+ return OpDefault(None, None)
+
+ return output
+
+
+class Op(object):
+ """
+ Base class for a FilterExpression operator
+ """
+ OP = ''
+
+ def __init__(self, lhs, rhs):
+ self.lhs = lhs
+ self.rhs = rhs
+
+ def expr(self, item):
+ raise NotImplementedError("Expr not defined for {0}".format(type(self)))
+
+ def __repr__(self):
+ return '({0} {1} {2})'.format(self.lhs, self.OP, self.rhs)
+
# TODO add tests for all of these
EQ_FUNCTION = lambda item_value, test_value: item_value == test_value # flake8: noqa
@@ -49,292 +137,799 @@ class RecursionStopIteration(StopIteration):
pass
-def get_filter_expression(expr, names, values):
- # Examples
- # expr = 'Id > 5 AND attribute_exists(test) AND Id BETWEEN 5 AND 6 OR length < 6 AND contains(test, 1) AND 5 IN (4,5, 6) OR (Id < 5 AND 5 > Id)'
- # expr = 'Id > 5 AND Subs < 7'
- if names is None:
- names = {}
- if values is None:
- values = {}
+class ConditionExpressionParser:
+ def __init__(self, condition_expression, expression_attribute_names,
+ expression_attribute_values):
+ self.condition_expression = condition_expression
+ self.expression_attribute_names = expression_attribute_names
+ self.expression_attribute_values = expression_attribute_values
- # Do substitutions
- for key, value in names.items():
- expr = expr.replace(key, value)
+ def parse(self):
+ """Returns a syntax tree for the expression.
- # Store correct types of values for use later
- values_map = {}
- for key, value in values.items():
- if 'N' in value:
- values_map[key] = float(value['N'])
- elif 'BOOL' in value:
- values_map[key] = value['BOOL']
- elif 'S' in value:
- values_map[key] = value['S']
- elif 'NS' in value:
- values_map[key] = tuple(value['NS'])
- elif 'SS' in value:
- values_map[key] = tuple(value['SS'])
- elif 'L' in value:
- values_map[key] = tuple(value['L'])
+ The tree, and all of the nodes in the tree are a tuple of
+ - kind: str
+ - children/value:
+ list of nodes for parent nodes
+ value for leaf nodes
+
+ Raises ValueError if the condition expression is invalid
+ Raises KeyError if expression attribute names/values are invalid
+
+ Here are the types of nodes that can be returned.
+ The types of child nodes are denoted with a colon (:).
+ An arbitrary number of children is denoted with ...
+
+ Condition:
+ ('OR', [lhs : Condition, rhs : Condition])
+ ('AND', [lhs: Condition, rhs: Condition])
+ ('NOT', [argument: Condition])
+ ('PARENTHESES', [argument: Condition])
+ ('FUNCTION', [('LITERAL', function_name: str), argument: Operand, ...])
+ ('BETWEEN', [query: Operand, low: Operand, high: Operand])
+ ('IN', [query: Operand, possible_value: Operand, ...])
+ ('COMPARISON', [lhs: Operand, ('LITERAL', comparator: str), rhs: Operand])
+
+ Operand:
+ ('EXPRESSION_ATTRIBUTE_VALUE', value: dict, e.g. {'S': 'foobar'})
+ ('PATH', [('LITERAL', path_element: str), ...])
+ NOTE: Expression attribute names will be expanded
+ ('FUNCTION', [('LITERAL', 'size'), argument: Operand])
+
+ Literal:
+ ('LITERAL', value: str)
+
+ """
+ if not self.condition_expression:
+ return OpDefault(None, None)
+ nodes = self._lex_condition_expression()
+ nodes = self._parse_paths(nodes)
+ # NOTE: The docs say that functions should be parsed after
+ # IN, BETWEEN, and comparisons like <=.
+ # However, these expressions are invalid as function arguments,
+ # so it is okay to parse functions first. This needs to be done
+ # to interpret size() correctly as an operand.
+ nodes = self._apply_functions(nodes)
+ nodes = self._apply_comparator(nodes)
+ nodes = self._apply_in(nodes)
+ nodes = self._apply_between(nodes)
+ nodes = self._apply_parens_and_booleans(nodes)
+ node = nodes[0]
+ op = self._make_op_condition(node)
+ return op
+
+ class Kind:
+ """Enum defining types of nodes in the syntax tree."""
+
+ # Condition nodes
+ # ---------------
+ OR = 'OR'
+ AND = 'AND'
+ NOT = 'NOT'
+ PARENTHESES = 'PARENTHESES'
+ FUNCTION = 'FUNCTION'
+ BETWEEN = 'BETWEEN'
+ IN = 'IN'
+ COMPARISON = 'COMPARISON'
+
+ # Operand nodes
+ # -------------
+ EXPRESSION_ATTRIBUTE_VALUE = 'EXPRESSION_ATTRIBUTE_VALUE'
+ PATH = 'PATH'
+
+ # Literal nodes
+ # --------------
+ LITERAL = 'LITERAL'
+
+
+ class Nonterminal:
+ """Enum defining nonterminals for productions."""
+
+ CONDITION = 'CONDITION'
+ OPERAND = 'OPERAND'
+ COMPARATOR = 'COMPARATOR'
+ FUNCTION_NAME = 'FUNCTION_NAME'
+ IDENTIFIER = 'IDENTIFIER'
+ AND = 'AND'
+ OR = 'OR'
+ NOT = 'NOT'
+ BETWEEN = 'BETWEEN'
+ IN = 'IN'
+ COMMA = 'COMMA'
+ LEFT_PAREN = 'LEFT_PAREN'
+ RIGHT_PAREN = 'RIGHT_PAREN'
+ WHITESPACE = 'WHITESPACE'
+
+
+ Node = namedtuple('Node', ['nonterminal', 'kind', 'text', 'value', 'children'])
+
+ def _lex_condition_expression(self):
+ nodes = deque()
+ remaining_expression = self.condition_expression
+ while remaining_expression:
+ node, remaining_expression = \
+ self._lex_one_node(remaining_expression)
+ if node.nonterminal == self.Nonterminal.WHITESPACE:
+ continue
+ nodes.append(node)
+ return nodes
+
+ def _lex_one_node(self, remaining_expression):
+ # TODO: Handle indexing like [1]
+ attribute_regex = '(:|#)?[A-z0-9\-_]+'
+ patterns = [(
+ self.Nonterminal.WHITESPACE, re.compile('^ +')
+ ), (
+ self.Nonterminal.COMPARATOR, re.compile(
+ '^('
+ # Put long expressions first for greedy matching
+ '<>|'
+ '<=|'
+ '>=|'
+ '=|'
+ '<|'
+ '>)'),
+ ), (
+ self.Nonterminal.OPERAND, re.compile(
+ '^' +
+ attribute_regex + '(\.' + attribute_regex + '|\[[0-9]\])*')
+ ), (
+ self.Nonterminal.COMMA, re.compile('^,')
+ ), (
+ self.Nonterminal.LEFT_PAREN, re.compile('^\(')
+ ), (
+ self.Nonterminal.RIGHT_PAREN, re.compile('^\)')
+ )]
+
+ for nonterminal, pattern in patterns:
+ match = pattern.match(remaining_expression)
+ if match:
+ match_text = match.group()
+ break
+ else: # pragma: no cover
+ raise ValueError("Cannot parse condition starting at: " +
+ remaining_expression)
+
+ value = match_text
+ node = self.Node(
+ nonterminal=nonterminal,
+ kind=self.Kind.LITERAL,
+ text=match_text,
+ value=match_text,
+ children=[])
+
+ remaining_expression = remaining_expression[len(match_text):]
+
+ return node, remaining_expression
+
+ def _parse_paths(self, nodes):
+ output = deque()
+
+ while nodes:
+ node = nodes.popleft()
+
+ if node.nonterminal == self.Nonterminal.OPERAND:
+ path = node.value.replace('[', '.[').split('.')
+ children = [
+ self._parse_path_element(name)
+ for name in path]
+ if len(children) == 1:
+ child = children[0]
+ if child.nonterminal != self.Nonterminal.IDENTIFIER:
+ output.append(child)
+ continue
+ else:
+ for child in children:
+ self._assert(
+ child.nonterminal == self.Nonterminal.IDENTIFIER,
+ "Cannot use %s in path" % child.text, [node])
+ output.append(self.Node(
+ nonterminal=self.Nonterminal.OPERAND,
+ kind=self.Kind.PATH,
+ text=node.text,
+ value=None,
+ children=children))
+ else:
+ output.append(node)
+ return output
+
+ def _parse_path_element(self, name):
+ reserved = {
+ 'and': self.Nonterminal.AND,
+ 'or': self.Nonterminal.OR,
+ 'in': self.Nonterminal.IN,
+ 'between': self.Nonterminal.BETWEEN,
+ 'not': self.Nonterminal.NOT,
+ }
+
+ functions = {
+ 'attribute_exists',
+ 'attribute_not_exists',
+ 'attribute_type',
+ 'begins_with',
+ 'contains',
+ 'size',
+ }
+
+
+ if name.lower() in reserved:
+ # e.g. AND
+ nonterminal = reserved[name.lower()]
+ return self.Node(
+ nonterminal=nonterminal,
+ kind=self.Kind.LITERAL,
+ text=name,
+ value=name,
+ children=[])
+ elif name in functions:
+ # e.g. attribute_exists
+ return self.Node(
+ nonterminal=self.Nonterminal.FUNCTION_NAME,
+ kind=self.Kind.LITERAL,
+ text=name,
+ value=name,
+ children=[])
+ elif name.startswith(':'):
+ # e.g. :value0
+ return self.Node(
+ nonterminal=self.Nonterminal.OPERAND,
+ kind=self.Kind.EXPRESSION_ATTRIBUTE_VALUE,
+ text=name,
+ value=self._lookup_expression_attribute_value(name),
+ children=[])
+ elif name.startswith('#'):
+ # e.g. #name0
+ return self.Node(
+ nonterminal=self.Nonterminal.IDENTIFIER,
+ kind=self.Kind.LITERAL,
+ text=name,
+ value=self._lookup_expression_attribute_name(name),
+ children=[])
+ elif name.startswith('['):
+ # e.g. [123]
+ if not name.endswith(']'): # pragma: no cover
+ raise ValueError("Bad path element %s" % name)
+ return self.Node(
+ nonterminal=self.Nonterminal.IDENTIFIER,
+ kind=self.Kind.LITERAL,
+ text=name,
+ value=int(name[1:-1]),
+ children=[])
else:
- raise NotImplementedError()
+ # e.g. ItemId
+ return self.Node(
+ nonterminal=self.Nonterminal.IDENTIFIER,
+ kind=self.Kind.LITERAL,
+ text=name,
+ value=name,
+ children=[])
- # Remove all spaces, tbf we could just skip them in the next step.
- # The number of known options is really small so we can do a fair bit of cheating
- expr = list(expr.strip())
+ def _lookup_expression_attribute_value(self, name):
+ return self.expression_attribute_values[name]
- # DodgyTokenisation stage 1
- def is_value(val):
- return val not in ('<', '>', '=', '(', ')')
+ def _lookup_expression_attribute_name(self, name):
+ return self.expression_attribute_names[name]
- def contains_keyword(val):
- for kw in ('BETWEEN', 'IN', 'AND', 'OR', 'NOT'):
- if kw in val:
- return kw
- return None
+ # NOTE: The following constructions are ordered from high precedence to low precedence
+ # according to
+ # https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.OperatorsAndFunctions.html#Expressions.OperatorsAndFunctions.Precedence
+ #
+ # = <> < <= > >=
+ # IN
+ # BETWEEN
+ # attribute_exists attribute_not_exists begins_with contains
+ # Parentheses
+ # NOT
+ # AND
+ # OR
+ #
+ # The grammar is taken from
+ # https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.OperatorsAndFunctions.html#Expressions.OperatorsAndFunctions.Syntax
+ #
+ # condition-expression ::=
+ # operand comparator operand
+ # operand BETWEEN operand AND operand
+ # operand IN ( operand (',' operand (, ...) ))
+ # function
+ # condition AND condition
+ # condition OR condition
+ # NOT condition
+ # ( condition )
+ #
+ # comparator ::=
+ # =
+ # <>
+ # <
+ # <=
+ # >
+ # >=
+ #
+ # function ::=
+ # attribute_exists (path)
+ # attribute_not_exists (path)
+ # attribute_type (path, type)
+ # begins_with (path, substr)
+ # contains (path, operand)
+ # size (path)
- def is_function(val):
- return val in ('attribute_exists', 'attribute_not_exists', 'attribute_type', 'begins_with', 'contains', 'size')
+ def _matches(self, nodes, production):
+ """Check if the nodes start with the given production.
- # Does the main part of splitting between sections of characters
- tokens = []
- stack = ''
- while len(expr) > 0:
- current_char = expr.pop(0)
+ Parameters
+ ----------
+ nodes: list of Node
+ production: list of str
+ The name of a Nonterminal, or '*' for anything
- if current_char == ' ':
- if len(stack) > 0:
- tokens.append(stack)
- stack = ''
- elif current_char == ',': # Split params ,
- if len(stack) > 0:
- tokens.append(stack)
- stack = ''
- elif is_value(current_char):
- stack += current_char
+ """
+ if len(nodes) < len(production):
+ return False
+ for i in range(len(production)):
+ if production[i] == '*':
+ continue
+ expected = getattr(self.Nonterminal, production[i])
+ if nodes[i].nonterminal != expected:
+ return False
+ return True
- kw = contains_keyword(stack)
- if kw is not None:
- # We have a kw in the stack, could be AND or something like 5AND
- tmp = stack.replace(kw, '')
- if len(tmp) > 0:
- tokens.append(tmp)
- tokens.append(kw)
- stack = ''
- else:
- if len(stack) > 0:
- tokens.append(stack)
- tokens.append(current_char)
- stack = ''
- if len(stack) > 0:
- tokens.append(stack)
+ def _apply_comparator(self, nodes):
+ """Apply condition := operand comparator operand."""
+ output = deque()
- def is_op(val):
- return val in ('<', '>', '=', '>=', '<=', '<>', 'BETWEEN', 'IN', 'AND', 'OR', 'NOT')
+ while nodes:
+ if self._matches(nodes, ['*', 'COMPARATOR']):
+ self._assert(
+ self._matches(nodes, ['OPERAND', 'COMPARATOR', 'OPERAND']),
+ "Bad comparison", list(nodes)[:3])
+ lhs = nodes.popleft()
+ comparator = nodes.popleft()
+ rhs = nodes.popleft()
+ nodes.appendleft(self.Node(
+ nonterminal=self.Nonterminal.CONDITION,
+ kind=self.Kind.COMPARISON,
+ text=" ".join([
+ lhs.text,
+ comparator.text,
+ rhs.text]),
+ value=None,
+ children=[lhs, comparator, rhs]))
+ else:
+ output.append(nodes.popleft())
+ return output
- # DodgyTokenisation stage 2, it groups together some elements to make RPN'ing it later easier.
- def handle_token(token, tokens2, token_iterator):
- # ok so this essentially groups up some tokens to make later parsing easier,
- # when it encounters brackets it will recurse and then unrecurse when RecursionStopIteration is raised.
- if token == ')':
- raise RecursionStopIteration() # Should be recursive so this should work
- elif token == '(':
- temp_list = []
-
- try:
+ def _apply_in(self, nodes):
+ """Apply condition := operand IN ( operand , ... )."""
+ output = deque()
+ while nodes:
+ if self._matches(nodes, ['*', 'IN']):
+ self._assert(
+ self._matches(nodes, ['OPERAND', 'IN', 'LEFT_PAREN']),
+ "Bad IN expression", list(nodes)[:3])
+ lhs = nodes.popleft()
+ in_node = nodes.popleft()
+ left_paren = nodes.popleft()
+ all_children = [lhs, in_node, left_paren]
+ rhs = []
while True:
- next_token = six.next(token_iterator)
- handle_token(next_token, temp_list, token_iterator)
- except RecursionStopIteration:
- pass # Continue
- except StopIteration:
- ValueError('Malformed filter expression, type1')
-
- # Sigh, we only want to group a tuple if it doesnt contain operators
- if any([is_op(item) for item in temp_list]):
- # Its an expression
- tokens2.append('(')
- tokens2.extend(temp_list)
- tokens2.append(')')
+ if self._matches(nodes, ['OPERAND', 'COMMA']):
+ operand = nodes.popleft()
+ separator = nodes.popleft()
+ all_children += [operand, separator]
+ rhs.append(operand)
+ elif self._matches(nodes, ['OPERAND', 'RIGHT_PAREN']):
+ operand = nodes.popleft()
+ separator = nodes.popleft()
+ all_children += [operand, separator]
+ rhs.append(operand)
+ break # Close
+ else:
+ self._assert(
+ False,
+ "Bad IN expression starting at", nodes)
+ nodes.appendleft(self.Node(
+ nonterminal=self.Nonterminal.CONDITION,
+ kind=self.Kind.IN,
+ text=" ".join([t.text for t in all_children]),
+ value=None,
+ children=[lhs] + rhs))
else:
- tokens2.append(tuple(temp_list))
- elif token == 'BETWEEN':
- field = tokens2.pop()
- # if values map contains a number, it would be a float
- # so we need to int() it anyway
- op1 = six.next(token_iterator)
- op1 = int(values_map.get(op1, op1))
- and_op = six.next(token_iterator)
- assert and_op == 'AND'
- op2 = six.next(token_iterator)
- op2 = int(values_map.get(op2, op2))
- tokens2.append(['between', field, op1, op2])
- elif is_function(token):
- function_list = [token]
+ output.append(nodes.popleft())
+ return output
- lbracket = six.next(token_iterator)
- assert lbracket == '('
-
- next_token = six.next(token_iterator)
- while next_token != ')':
- if next_token in values_map:
- next_token = values_map[next_token]
- function_list.append(next_token)
- next_token = six.next(token_iterator)
-
- tokens2.append(function_list)
- else:
- # Convert tokens back to real types
- if token in values_map:
- token = values_map[token]
-
- # Need to join >= <= <>
- if len(tokens2) > 0 and ((tokens2[-1] == '>' and token == '=') or (tokens2[-1] == '<' and token == '=') or (tokens2[-1] == '<' and token == '>')):
- tokens2.append(tokens2.pop() + token)
+ def _apply_between(self, nodes):
+ """Apply condition := operand BETWEEN operand AND operand."""
+ output = deque()
+ while nodes:
+ if self._matches(nodes, ['*', 'BETWEEN']):
+ self._assert(
+ self._matches(nodes, ['OPERAND', 'BETWEEN', 'OPERAND',
+ 'AND', 'OPERAND']),
+ "Bad BETWEEN expression", list(nodes)[:5])
+ lhs = nodes.popleft()
+ between_node = nodes.popleft()
+ low = nodes.popleft()
+ and_node = nodes.popleft()
+ high = nodes.popleft()
+ all_children = [lhs, between_node, low, and_node, high]
+ nodes.appendleft(self.Node(
+ nonterminal=self.Nonterminal.CONDITION,
+ kind=self.Kind.BETWEEN,
+ text=" ".join([t.text for t in all_children]),
+ value=None,
+ children=[lhs, low, high]))
else:
- tokens2.append(token)
+ output.append(nodes.popleft())
+ return output
- tokens2 = []
- token_iterator = iter(tokens)
- for token in token_iterator:
- handle_token(token, tokens2, token_iterator)
-
- # Start of the Shunting-Yard algorithm. <-- Proper beast algorithm!
- def is_number(val):
- return val not in ('<', '>', '=', '>=', '<=', '<>', 'BETWEEN', 'IN', 'AND', 'OR', 'NOT')
-
- OPS = {'<': 5, '>': 5, '=': 5, '>=': 5, '<=': 5, '<>': 5, 'IN': 8, 'AND': 11, 'OR': 12, 'NOT': 10, 'BETWEEN': 9, '(': 100, ')': 100}
-
- def shunting_yard(token_list):
- output = []
- op_stack = []
-
- # Basically takes in an infix notation calculation, converts it to a reverse polish notation where there is no
- # ambiguity on which order operators are applied.
- while len(token_list) > 0:
- token = token_list.pop(0)
-
- if token == '(':
- op_stack.append(token)
- elif token == ')':
- while len(op_stack) > 0 and op_stack[-1] != '(':
- output.append(op_stack.pop())
- lbracket = op_stack.pop()
- assert lbracket == '('
-
- elif is_number(token):
- output.append(token)
+ def _apply_functions(self, nodes):
+ """Apply condition := function_name (operand , ...)."""
+ output = deque()
+ either_kind = {self.Kind.PATH, self.Kind.EXPRESSION_ATTRIBUTE_VALUE}
+ expected_argument_kind_map = {
+ 'attribute_exists': [{self.Kind.PATH}],
+ 'attribute_not_exists': [{self.Kind.PATH}],
+ 'attribute_type': [either_kind, {self.Kind.EXPRESSION_ATTRIBUTE_VALUE}],
+ 'begins_with': [either_kind, either_kind],
+ 'contains': [either_kind, either_kind],
+ 'size': [{self.Kind.PATH}],
+ }
+ while nodes:
+ if self._matches(nodes, ['FUNCTION_NAME']):
+ self._assert(
+ self._matches(nodes, ['FUNCTION_NAME', 'LEFT_PAREN',
+ 'OPERAND', '*']),
+ "Bad function expression at", list(nodes)[:4])
+ function_name = nodes.popleft()
+ left_paren = nodes.popleft()
+ all_children = [function_name, left_paren]
+ arguments = []
+ while True:
+ if self._matches(nodes, ['OPERAND', 'COMMA']):
+ operand = nodes.popleft()
+ separator = nodes.popleft()
+ all_children += [operand, separator]
+ arguments.append(operand)
+ elif self._matches(nodes, ['OPERAND', 'RIGHT_PAREN']):
+ operand = nodes.popleft()
+ separator = nodes.popleft()
+ all_children += [operand, separator]
+ arguments.append(operand)
+ break # Close paren
+ else:
+ self._assert(
+ False,
+ "Bad function expression", all_children + list(nodes)[:2])
+ expected_kinds = expected_argument_kind_map[function_name.value]
+ self._assert(
+ len(arguments) == len(expected_kinds),
+ "Wrong number of arguments in", all_children)
+ for i in range(len(expected_kinds)):
+ self._assert(
+ arguments[i].kind in expected_kinds[i],
+ "Wrong type for argument %d in" % i, all_children)
+ if function_name.value == 'size':
+ nonterminal = self.Nonterminal.OPERAND
+ else:
+ nonterminal = self.Nonterminal.CONDITION
+ nodes.appendleft(self.Node(
+ nonterminal=nonterminal,
+ kind=self.Kind.FUNCTION,
+ text=" ".join([t.text for t in all_children]),
+ value=None,
+ children=[function_name] + arguments))
else:
- # Must be operator kw
+ output.append(nodes.popleft())
+ return output
- # Cheat, NOT is our only RIGHT associative operator, should really have dict of operator associativity
- while len(op_stack) > 0 and OPS[op_stack[-1]] <= OPS[token] and op_stack[-1] != 'NOT':
- output.append(op_stack.pop())
- op_stack.append(token)
- while len(op_stack) > 0:
- output.append(op_stack.pop())
+ def _apply_parens_and_booleans(self, nodes, left_paren=None):
+ """Apply condition := ( condition ) and booleans."""
+ output = deque()
+ while nodes:
+ if self._matches(nodes, ['LEFT_PAREN']):
+ parsed = self._apply_parens_and_booleans(nodes, left_paren=nodes.popleft())
+ self._assert(
+ len(parsed) >= 1,
+ "Failed to close parentheses at", nodes)
+ parens = parsed.popleft()
+ self._assert(
+ parens.kind == self.Kind.PARENTHESES,
+ "Failed to close parentheses at", nodes)
+ output.append(parens)
+ nodes = parsed
+ elif self._matches(nodes, ['RIGHT_PAREN']):
+ self._assert(
+ left_paren is not None,
+ "Unmatched ) at", nodes)
+ close_paren = nodes.popleft()
+ children = self._apply_booleans(output)
+ all_children = [left_paren] + list(children) + [close_paren]
+ return deque([
+ self.Node(
+ nonterminal=self.Nonterminal.CONDITION,
+ kind=self.Kind.PARENTHESES,
+ text=" ".join([t.text for t in all_children]),
+ value=None,
+ children=list(children),
+ )] + list(nodes))
+ else:
+ output.append(nodes.popleft())
+
+ self._assert(
+ left_paren is None,
+ "Unmatched ( at", list(output))
+ return self._apply_booleans(output)
+
+ def _apply_booleans(self, nodes):
+ """Apply and, or, and not constructions."""
+ nodes = self._apply_not(nodes)
+ nodes = self._apply_and(nodes)
+ nodes = self._apply_or(nodes)
+ # The expression should reduce to a single condition
+ self._assert(
+ len(nodes) == 1,
+ "Unexpected expression at", list(nodes)[1:])
+ self._assert(
+ nodes[0].nonterminal == self.Nonterminal.CONDITION,
+ "Incomplete condition", nodes)
+ return nodes
+
+ def _apply_not(self, nodes):
+ """Apply condition := NOT condition."""
+ output = deque()
+ while nodes:
+ if self._matches(nodes, ['NOT']):
+ self._assert(
+ self._matches(nodes, ['NOT', 'CONDITION']),
+ "Bad NOT expression", list(nodes)[:2])
+ not_node = nodes.popleft()
+ child = nodes.popleft()
+ nodes.appendleft(self.Node(
+ nonterminal=self.Nonterminal.CONDITION,
+ kind=self.Kind.NOT,
+ text=" ".join([not_node.text, child.text]),
+ value=None,
+ children=[child]))
+ else:
+ output.append(nodes.popleft())
return output
- output = shunting_yard(tokens2)
-
- # Hacky function to convert dynamo functions (which are represented as lists) to their Class equivalent
- def to_func(val):
- if isinstance(val, list):
- func_name = val.pop(0)
- # Expand rest of the list to arguments
- val = FUNC_CLASS[func_name](*val)
-
- return val
-
- # Simple reverse polish notation execution. Builts up a nested filter object.
- # The filter object then takes a dynamo item and returns true/false
- stack = []
- for token in output:
- if is_op(token):
- op_cls = OP_CLASS[token]
-
- if token == 'NOT':
- op1 = stack.pop()
- op2 = True
+ def _apply_and(self, nodes):
+ """Apply condition := condition AND condition."""
+ output = deque()
+ while nodes:
+ if self._matches(nodes, ['*', 'AND']):
+ self._assert(
+ self._matches(nodes, ['CONDITION', 'AND', 'CONDITION']),
+ "Bad AND expression", list(nodes)[:3])
+ lhs = nodes.popleft()
+ and_node = nodes.popleft()
+ rhs = nodes.popleft()
+ all_children = [lhs, and_node, rhs]
+ nodes.appendleft(self.Node(
+ nonterminal=self.Nonterminal.CONDITION,
+ kind=self.Kind.AND,
+ text=" ".join([t.text for t in all_children]),
+ value=None,
+ children=[lhs, rhs]))
else:
- op2 = stack.pop()
- op1 = stack.pop()
+ output.append(nodes.popleft())
- stack.append(op_cls(op1, op2))
+ return output
+
+ def _apply_or(self, nodes):
+ """Apply condition := condition OR condition."""
+ output = deque()
+ while nodes:
+ if self._matches(nodes, ['*', 'OR']):
+ self._assert(
+ self._matches(nodes, ['CONDITION', 'OR', 'CONDITION']),
+ "Bad OR expression", list(nodes)[:3])
+ lhs = nodes.popleft()
+ or_node = nodes.popleft()
+ rhs = nodes.popleft()
+ all_children = [lhs, or_node, rhs]
+ nodes.appendleft(self.Node(
+ nonterminal=self.Nonterminal.CONDITION,
+ kind=self.Kind.OR,
+ text=" ".join([t.text for t in all_children]),
+ value=None,
+ children=[lhs, rhs]))
+ else:
+ output.append(nodes.popleft())
+
+ return output
+
+ def _make_operand(self, node):
+ if node.kind == self.Kind.PATH:
+ return AttributePath([child.value for child in node.children])
+ elif node.kind == self.Kind.EXPRESSION_ATTRIBUTE_VALUE:
+ return AttributeValue(node.value)
+ elif node.kind == self.Kind.FUNCTION:
+ # size()
+ function_node = node.children[0]
+ arguments = node.children[1:]
+ function_name = function_node.value
+ arguments = [self._make_operand(arg) for arg in arguments]
+ return FUNC_CLASS[function_name](*arguments)
+ else: # pragma: no cover
+ raise ValueError("Unknown operand: %r" % node)
+
+
+ def _make_op_condition(self, node):
+ if node.kind == self.Kind.OR:
+ lhs, rhs = node.children
+ return OpOr(
+ self._make_op_condition(lhs),
+ self._make_op_condition(rhs))
+ elif node.kind == self.Kind.AND:
+ lhs, rhs = node.children
+ return OpAnd(
+ self._make_op_condition(lhs),
+ self._make_op_condition(rhs))
+ elif node.kind == self.Kind.NOT:
+ child, = node.children
+ return OpNot(self._make_op_condition(child))
+ elif node.kind == self.Kind.PARENTHESES:
+ child, = node.children
+ return self._make_op_condition(child)
+ elif node.kind == self.Kind.FUNCTION:
+ function_node = node.children[0]
+ arguments = node.children[1:]
+ function_name = function_node.value
+ arguments = [self._make_operand(arg) for arg in arguments]
+ return FUNC_CLASS[function_name](*arguments)
+ elif node.kind == self.Kind.BETWEEN:
+ query, low, high = node.children
+ return FuncBetween(
+ self._make_operand(query),
+ self._make_operand(low),
+ self._make_operand(high))
+ elif node.kind == self.Kind.IN:
+ query = node.children[0]
+ possible_values = node.children[1:]
+ query = self._make_operand(query)
+ possible_values = [self._make_operand(v) for v in possible_values]
+ return FuncIn(query, *possible_values)
+ elif node.kind == self.Kind.COMPARISON:
+ lhs, comparator, rhs = node.children
+ return COMPARATOR_CLASS[comparator.value](
+ self._make_operand(lhs),
+ self._make_operand(rhs))
+ else: # pragma: no cover
+ raise ValueError("Unknown expression node kind %r" % node.kind)
+
+ def _print_debug(self, nodes): # pragma: no cover
+ print('ROOT')
+ for node in nodes:
+ self._print_node_recursive(node, depth=1)
+
+ def _print_node_recursive(self, node, depth=0): # pragma: no cover
+ if len(node.children) > 0:
+ print(' ' * depth, node.nonterminal, node.kind)
+ for child in node.children:
+ self._print_node_recursive(child, depth=depth + 1)
else:
- stack.append(to_func(token))
-
- result = stack.pop(0)
- if len(stack) > 0:
- raise ValueError('Malformed filter expression, type2')
-
- return result
+ print(' ' * depth, node.nonterminal, node.kind, node.value)
-class Op(object):
- """
- Base class for a FilterExpression operator
- """
- OP = ''
- def __init__(self, lhs, rhs):
- self.lhs = lhs
- self.rhs = rhs
+ def _assert(self, condition, message, nodes):
+ if not condition:
+ raise ValueError(message + " " + " ".join([t.text for t in nodes]))
+
+
+class Operand(object):
+ def expr(self, item):
+ raise NotImplementedError
+
+ def get_type(self, item):
+ raise NotImplementedError
+
+
+class AttributePath(Operand):
+ def __init__(self, path):
+ """Initialize the AttributePath.
+
+ Parameters
+ ----------
+ path: list of int/str
- def _lhs(self, item):
"""
- :type item: moto.dynamodb2.models.Item
- """
- lhs = self.lhs
- if isinstance(self.lhs, (Op, Func)):
- lhs = self.lhs.expr(item)
- elif isinstance(self.lhs, six.string_types):
- try:
- lhs = item.attrs[self.lhs].cast_value
- except Exception:
- pass
+ assert len(path) >= 1
+ self.path = path
- return lhs
+ def _get_attr(self, item):
+ if item is None:
+ return None
- def _rhs(self, item):
- rhs = self.rhs
- if isinstance(self.rhs, (Op, Func)):
- rhs = self.rhs.expr(item)
- elif isinstance(self.rhs, six.string_types):
- try:
- rhs = item.attrs[self.rhs].cast_value
- except Exception:
- pass
- return rhs
+ base = self.path[0]
+ if base not in item.attrs:
+ return None
+ attr = item.attrs[base]
+
+ for name in self.path[1:]:
+ attr = attr.child_attr(name)
+ if attr is None:
+ return None
+
+ return attr
def expr(self, item):
- return True
+ attr = self._get_attr(item)
+ if attr is None:
+ return None
+ else:
+ return attr.cast_value
+
+ def get_type(self, item):
+ attr = self._get_attr(item)
+ if attr is None:
+ return None
+ else:
+ return attr.type
def __repr__(self):
- return '({0} {1} {2})'.format(self.lhs, self.OP, self.rhs)
+ return ".".join(self.path)
-class Func(object):
- """
- Base class for a FilterExpression function
- """
- FUNC = 'Unknown'
+class AttributeValue(Operand):
+ def __init__(self, value):
+ """Initialize the AttributePath.
+
+ Parameters
+ ----------
+ value: dict
+ e.g. {'N': '1.234'}
+
+ """
+ self.type = list(value.keys())[0]
+ self.value = value[self.type]
def expr(self, item):
- return True
+ # TODO: Reuse DynamoType code
+ if self.type == 'N':
+ try:
+ return int(self.value)
+ except ValueError:
+ return float(self.value)
+ elif self.type in ['SS', 'NS', 'BS']:
+ sub_type = self.type[0]
+ return set([AttributeValue({sub_type: v}).expr(item)
+ for v in self.value])
+ elif self.type == 'L':
+ return [AttributeValue(v).expr(item) for v in self.value]
+ elif self.type == 'M':
+ return dict([
+ (k, AttributeValue(v).expr(item))
+ for k, v in self.value.items()])
+ else:
+ return self.value
+ return self.value
+
+ def get_type(self, item):
+ return self.type
def __repr__(self):
- return 'Func(...)'.format(self.FUNC)
+ return repr(self.value)
+
+
+class OpDefault(Op):
+ OP = 'NONE'
+
+ def expr(self, item):
+ """If no condition is specified, always True."""
+ return True
class OpNot(Op):
OP = 'NOT'
- def expr(self, item):
- lhs = self._lhs(item)
+ def __init__(self, lhs):
+ super(OpNot, self).__init__(lhs, None)
+ def expr(self, item):
+ lhs = self.lhs.expr(item)
return not lhs
def __str__(self):
@@ -345,8 +940,8 @@ class OpAnd(Op):
OP = 'AND'
def expr(self, item):
- lhs = self._lhs(item)
- rhs = self._rhs(item)
+ lhs = self.lhs.expr(item)
+ rhs = self.rhs.expr(item)
return lhs and rhs
@@ -354,8 +949,8 @@ class OpLessThan(Op):
OP = '<'
def expr(self, item):
- lhs = self._lhs(item)
- rhs = self._rhs(item)
+ lhs = self.lhs.expr(item)
+ rhs = self.rhs.expr(item)
return lhs < rhs
@@ -363,8 +958,8 @@ class OpGreaterThan(Op):
OP = '>'
def expr(self, item):
- lhs = self._lhs(item)
- rhs = self._rhs(item)
+ lhs = self.lhs.expr(item)
+ rhs = self.rhs.expr(item)
return lhs > rhs
@@ -372,8 +967,8 @@ class OpEqual(Op):
OP = '='
def expr(self, item):
- lhs = self._lhs(item)
- rhs = self._rhs(item)
+ lhs = self.lhs.expr(item)
+ rhs = self.rhs.expr(item)
return lhs == rhs
@@ -381,8 +976,8 @@ class OpNotEqual(Op):
OP = '<>'
def expr(self, item):
- lhs = self._lhs(item)
- rhs = self._rhs(item)
+ lhs = self.lhs.expr(item)
+ rhs = self.rhs.expr(item)
return lhs != rhs
@@ -390,8 +985,8 @@ class OpLessThanOrEqual(Op):
OP = '<='
def expr(self, item):
- lhs = self._lhs(item)
- rhs = self._rhs(item)
+ lhs = self.lhs.expr(item)
+ rhs = self.rhs.expr(item)
return lhs <= rhs
@@ -399,8 +994,8 @@ class OpGreaterThanOrEqual(Op):
OP = '>='
def expr(self, item):
- lhs = self._lhs(item)
- rhs = self._rhs(item)
+ lhs = self.lhs.expr(item)
+ rhs = self.rhs.expr(item)
return lhs >= rhs
@@ -408,18 +1003,27 @@ class OpOr(Op):
OP = 'OR'
def expr(self, item):
- lhs = self._lhs(item)
- rhs = self._rhs(item)
+ lhs = self.lhs.expr(item)
+ rhs = self.rhs.expr(item)
return lhs or rhs
-class OpIn(Op):
- OP = 'IN'
+class Func(object):
+ """
+ Base class for a FilterExpression function
+ """
+ FUNC = 'Unknown'
+
+ def __init__(self, *arguments):
+ self.arguments = arguments
def expr(self, item):
- lhs = self._lhs(item)
- rhs = self._rhs(item)
- return lhs in rhs
+ raise NotImplementedError
+
+ def __repr__(self):
+ return '{0}({1})'.format(
+ self.FUNC,
+ " ".join([repr(arg) for arg in self.arguments]))
class FuncAttrExists(Func):
@@ -427,19 +1031,14 @@ class FuncAttrExists(Func):
def __init__(self, attribute):
self.attr = attribute
+ super(FuncAttrExists, self).__init__(attribute)
def expr(self, item):
- return self.attr in item.attrs
+ return self.attr.get_type(item) is not None
-class FuncAttrNotExists(Func):
- FUNC = 'attribute_not_exists'
-
- def __init__(self, attribute):
- self.attr = attribute
-
- def expr(self, item):
- return self.attr not in item.attrs
+def FuncAttrNotExists(attribute):
+ return OpNot(FuncAttrExists(attribute))
class FuncAttrType(Func):
@@ -448,9 +1047,10 @@ class FuncAttrType(Func):
def __init__(self, attribute, _type):
self.attr = attribute
self.type = _type
+ super(FuncAttrType, self).__init__(attribute, _type)
def expr(self, item):
- return self.attr in item.attrs and item.attrs[self.attr].type == self.type
+ return self.attr.get_type(item) == self.type.expr(item)
class FuncBeginsWith(Func):
@@ -459,9 +1059,14 @@ class FuncBeginsWith(Func):
def __init__(self, attribute, substr):
self.attr = attribute
self.substr = substr
+ super(FuncBeginsWith, self).__init__(attribute, substr)
def expr(self, item):
- return self.attr in item.attrs and item.attrs[self.attr].type == 'S' and item.attrs[self.attr].value.startswith(self.substr)
+ if self.attr.get_type(item) != 'S':
+ return False
+ if self.substr.get_type(item) != 'S':
+ return False
+ return self.attr.expr(item).startswith(self.substr.expr(item))
class FuncContains(Func):
@@ -470,51 +1075,67 @@ class FuncContains(Func):
def __init__(self, attribute, operand):
self.attr = attribute
self.operand = operand
+ super(FuncContains, self).__init__(attribute, operand)
def expr(self, item):
- if self.attr not in item.attrs:
- return False
-
- if item.attrs[self.attr].type in ('S', 'SS', 'NS', 'BS', 'L', 'M'):
- return self.operand in item.attrs[self.attr].value
+ if self.attr.get_type(item) in ('S', 'SS', 'NS', 'BS', 'L'):
+ try:
+ return self.operand.expr(item) in self.attr.expr(item)
+ except TypeError:
+ return False
return False
+def FuncNotContains(attribute, operand):
+ return OpNot(FuncContains(attribute, operand))
+
+
class FuncSize(Func):
- FUNC = 'contains'
+ FUNC = 'size'
def __init__(self, attribute):
self.attr = attribute
+ super(FuncSize, self).__init__(attribute)
def expr(self, item):
- if self.attr not in item.attrs:
+ if self.attr.get_type(item) is None:
raise ValueError('Invalid attribute name {0}'.format(self.attr))
- if item.attrs[self.attr].type in ('S', 'SS', 'NS', 'B', 'BS', 'L', 'M'):
- return len(item.attrs[self.attr].value)
+ if self.attr.get_type(item) in ('S', 'SS', 'NS', 'B', 'BS', 'L', 'M'):
+ return len(self.attr.expr(item))
raise ValueError('Invalid filter expression')
class FuncBetween(Func):
- FUNC = 'between'
+ FUNC = 'BETWEEN'
def __init__(self, attribute, start, end):
self.attr = attribute
self.start = start
self.end = end
+ super(FuncBetween, self).__init__(attribute, start, end)
def expr(self, item):
- if self.attr not in item.attrs:
- raise ValueError('Invalid attribute name {0}'.format(self.attr))
-
- return self.start <= item.attrs[self.attr].cast_value <= self.end
+ return self.start.expr(item) <= self.attr.expr(item) <= self.end.expr(item)
-OP_CLASS = {
- 'NOT': OpNot,
- 'AND': OpAnd,
- 'OR': OpOr,
- 'IN': OpIn,
+class FuncIn(Func):
+ FUNC = 'IN'
+
+ def __init__(self, attribute, *possible_values):
+ self.attr = attribute
+ self.possible_values = possible_values
+ super(FuncIn, self).__init__(attribute, *possible_values)
+
+ def expr(self, item):
+ for possible_value in self.possible_values:
+ if self.attr.expr(item) == possible_value.expr(item):
+ return True
+
+ return False
+
+
+COMPARATOR_CLASS = {
'<': OpLessThan,
'>': OpGreaterThan,
'<=': OpLessThanOrEqual,
diff --git a/moto/dynamodb2/models.py b/moto/dynamodb2/models.py
index bfbb654b4..29e90e7dc 100644
--- a/moto/dynamodb2/models.py
+++ b/moto/dynamodb2/models.py
@@ -6,13 +6,16 @@ import decimal
import json
import re
import uuid
+import six
import boto3
from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel
from moto.core.utils import unix_time
from moto.core.exceptions import JsonRESTError
-from .comparisons import get_comparison_func, get_filter_expression, Op
+from .comparisons import get_comparison_func
+from .comparisons import get_filter_expression
+from .comparisons import get_expected
from .exceptions import InvalidIndexNameError
@@ -68,10 +71,34 @@ class DynamoType(object):
except ValueError:
return float(self.value)
elif self.is_set():
- return set(self.value)
+ sub_type = self.type[0]
+ return set([DynamoType({sub_type: v}).cast_value
+ for v in self.value])
+ elif self.is_list():
+ return [DynamoType(v).cast_value for v in self.value]
+ elif self.is_map():
+ return dict([
+ (k, DynamoType(v).cast_value)
+ for k, v in self.value.items()])
else:
return self.value
+ def child_attr(self, key):
+ """
+ Get Map or List children by key. str for Map, int for List.
+
+ Returns DynamoType or None.
+ """
+ if isinstance(key, six.string_types) and self.is_map() and key in self.value:
+ return DynamoType(self.value[key])
+
+ if isinstance(key, int) and self.is_list():
+ idx = key
+ if idx >= 0 and idx < len(self.value):
+ return DynamoType(self.value[idx])
+
+ return None
+
def to_json(self):
return {self.type: self.value}
@@ -89,6 +116,12 @@ class DynamoType(object):
def is_set(self):
return self.type == 'SS' or self.type == 'NS' or self.type == 'BS'
+ def is_list(self):
+ return self.type == 'L'
+
+ def is_map(self):
+ return self.type == 'M'
+
def same_type(self, other):
return self.type == other.type
@@ -504,7 +537,9 @@ class Table(BaseModel):
keys.append(range_key)
return keys
- def put_item(self, item_attrs, expected=None, overwrite=False):
+ def put_item(self, item_attrs, expected=None, condition_expression=None,
+ expression_attribute_names=None,
+ expression_attribute_values=None, overwrite=False):
hash_value = DynamoType(item_attrs.get(self.hash_key_attr))
if self.has_range_key:
range_value = DynamoType(item_attrs.get(self.range_key_attr))
@@ -527,29 +562,15 @@ class Table(BaseModel):
self.range_key_type, item_attrs)
if not overwrite:
- if current is None:
- current_attr = {}
- elif hasattr(current, 'attrs'):
- current_attr = current.attrs
- else:
- current_attr = current
+ if not get_expected(expected).expr(current):
+ raise ValueError('The conditional request failed')
+ condition_op = get_filter_expression(
+ condition_expression,
+ expression_attribute_names,
+ expression_attribute_values)
+ if not condition_op.expr(current):
+ raise ValueError('The conditional request failed')
- for key, val in expected.items():
- if 'Exists' in val and val['Exists'] is False \
- or 'ComparisonOperator' in val and val['ComparisonOperator'] == 'NULL':
- if key in current_attr:
- raise ValueError("The conditional request failed")
- elif key not in current_attr:
- raise ValueError("The conditional request failed")
- elif 'Value' in val and DynamoType(val['Value']).value != current_attr[key].value:
- raise ValueError("The conditional request failed")
- elif 'ComparisonOperator' in val:
- dynamo_types = [
- DynamoType(ele) for ele in
- val.get("AttributeValueList", [])
- ]
- if not current_attr[key].compare(val['ComparisonOperator'], dynamo_types):
- raise ValueError('The conditional request failed')
if range_value:
self.items[hash_value][range_value] = item
else:
@@ -902,11 +923,15 @@ class DynamoDBBackend(BaseBackend):
table.global_indexes = list(gsis_by_name.values())
return table
- def put_item(self, table_name, item_attrs, expected=None, overwrite=False):
+ def put_item(self, table_name, item_attrs, expected=None,
+ condition_expression=None, expression_attribute_names=None,
+ expression_attribute_values=None, overwrite=False):
table = self.tables.get(table_name)
if not table:
return None
- return table.put_item(item_attrs, expected, overwrite)
+ return table.put_item(item_attrs, expected, condition_expression,
+ expression_attribute_names,
+ expression_attribute_values, overwrite)
def get_table_keys_name(self, table_name, keys):
"""
@@ -962,10 +987,7 @@ class DynamoDBBackend(BaseBackend):
range_values = [DynamoType(range_value)
for range_value in range_value_dicts]
- if filter_expression is not None:
- filter_expression = get_filter_expression(filter_expression, expr_names, expr_values)
- else:
- filter_expression = Op(None, None) # Will always eval to true
+ filter_expression = get_filter_expression(filter_expression, expr_names, expr_values)
return table.query(hash_key, range_comparison, range_values, limit,
exclusive_start_key, scan_index_forward, projection_expression, index_name, filter_expression, **filter_kwargs)
@@ -980,17 +1002,14 @@ class DynamoDBBackend(BaseBackend):
dynamo_types = [DynamoType(value) for value in comparison_values]
scan_filters[key] = (comparison_operator, dynamo_types)
- if filter_expression is not None:
- filter_expression = get_filter_expression(filter_expression, expr_names, expr_values)
- else:
- filter_expression = Op(None, None) # Will always eval to true
+ filter_expression = get_filter_expression(filter_expression, expr_names, expr_values)
projection_expression = ','.join([expr_names.get(attr, attr) for attr in projection_expression.replace(' ', '').split(',')])
return table.scan(scan_filters, limit, exclusive_start_key, filter_expression, index_name, projection_expression)
def update_item(self, table_name, key, update_expression, attribute_updates, expression_attribute_names,
- expression_attribute_values, expected=None):
+ expression_attribute_values, expected=None, condition_expression=None):
table = self.get_table(table_name)
if all([table.hash_key_attr in key, table.range_key_attr in key]):
@@ -1009,32 +1028,17 @@ class DynamoDBBackend(BaseBackend):
item = table.get_item(hash_value, range_value)
- if item is None:
- item_attr = {}
- elif hasattr(item, 'attrs'):
- item_attr = item.attrs
- else:
- item_attr = item
-
if not expected:
expected = {}
- for key, val in expected.items():
- if 'Exists' in val and val['Exists'] is False \
- or 'ComparisonOperator' in val and val['ComparisonOperator'] == 'NULL':
- if key in item_attr:
- raise ValueError("The conditional request failed")
- elif key not in item_attr:
- raise ValueError("The conditional request failed")
- elif 'Value' in val and DynamoType(val['Value']).value != item_attr[key].value:
- raise ValueError("The conditional request failed")
- elif 'ComparisonOperator' in val:
- dynamo_types = [
- DynamoType(ele) for ele in
- val.get("AttributeValueList", [])
- ]
- if not item_attr[key].compare(val['ComparisonOperator'], dynamo_types):
- raise ValueError('The conditional request failed')
+ if not get_expected(expected).expr(item):
+ raise ValueError('The conditional request failed')
+ condition_op = get_filter_expression(
+ condition_expression,
+ expression_attribute_names,
+ expression_attribute_values)
+ if not condition_op.expr(item):
+ raise ValueError('The conditional request failed')
# Update does not fail on new items, so create one
if item is None:
diff --git a/moto/dynamodb2/responses.py b/moto/dynamodb2/responses.py
index 5dde432d5..d34b176a7 100644
--- a/moto/dynamodb2/responses.py
+++ b/moto/dynamodb2/responses.py
@@ -32,67 +32,6 @@ def get_empty_str_error():
))
-def condition_expression_to_expected(condition_expression, expression_attribute_names, expression_attribute_values):
- """
- Limited condition expression syntax parsing.
- Supports Global Negation ex: NOT(inner expressions).
- Supports simple AND conditions ex: cond_a AND cond_b and cond_c.
- Atomic expressions supported are attribute_exists(key), attribute_not_exists(key) and #key = :value.
- """
- expected = {}
- if condition_expression and 'OR' not in condition_expression:
- reverse_re = re.compile('^NOT\s*\((.*)\)$')
- reverse_m = reverse_re.match(condition_expression.strip())
-
- reverse = False
- if reverse_m:
- reverse = True
- condition_expression = reverse_m.group(1)
-
- cond_items = [c.strip() for c in condition_expression.split('AND')]
- if cond_items:
- exists_re = re.compile('^attribute_exists\s*\((.*)\)$')
- not_exists_re = re.compile(
- '^attribute_not_exists\s*\((.*)\)$')
- equals_re = re.compile('^(#?\w+)\s*=\s*(\:?\w+)')
-
- for cond in cond_items:
- exists_m = exists_re.match(cond)
- not_exists_m = not_exists_re.match(cond)
- equals_m = equals_re.match(cond)
-
- if exists_m:
- attribute_name = expression_attribute_names_lookup(exists_m.group(1), expression_attribute_names)
- expected[attribute_name] = {'Exists': True if not reverse else False}
- elif not_exists_m:
- attribute_name = expression_attribute_names_lookup(not_exists_m.group(1), expression_attribute_names)
- expected[attribute_name] = {'Exists': False if not reverse else True}
- elif equals_m:
- attribute_name = expression_attribute_names_lookup(equals_m.group(1), expression_attribute_names)
- attribute_value = expression_attribute_values_lookup(equals_m.group(2), expression_attribute_values)
- expected[attribute_name] = {
- 'AttributeValueList': [attribute_value],
- 'ComparisonOperator': 'EQ' if not reverse else 'NEQ'}
-
- return expected
-
-
-def expression_attribute_names_lookup(attribute_name, expression_attribute_names):
- if attribute_name.startswith('#') and attribute_name in expression_attribute_names:
- return expression_attribute_names[attribute_name]
- else:
- return attribute_name
-
-
-def expression_attribute_values_lookup(attribute_value, expression_attribute_values):
- if isinstance(attribute_value, six.string_types) and \
- attribute_value.startswith(':') and\
- attribute_value in expression_attribute_values:
- return expression_attribute_values[attribute_value]
- else:
- return attribute_value
-
-
class DynamoHandler(BaseResponse):
def get_endpoint_name(self, headers):
@@ -288,18 +227,18 @@ class DynamoHandler(BaseResponse):
# Attempt to parse simple ConditionExpressions into an Expected
# expression
- if not expected:
- condition_expression = self.body.get('ConditionExpression')
- expression_attribute_names = self.body.get('ExpressionAttributeNames', {})
- expression_attribute_values = self.body.get('ExpressionAttributeValues', {})
- expected = condition_expression_to_expected(condition_expression,
- expression_attribute_names,
- expression_attribute_values)
- if expected:
- overwrite = False
+ condition_expression = self.body.get('ConditionExpression')
+ expression_attribute_names = self.body.get('ExpressionAttributeNames', {})
+ expression_attribute_values = self.body.get('ExpressionAttributeValues', {})
+
+ if condition_expression:
+ overwrite = False
try:
- result = self.dynamodb_backend.put_item(name, item, expected, overwrite)
+ result = self.dynamodb_backend.put_item(
+ name, item, expected, condition_expression,
+ expression_attribute_names, expression_attribute_values,
+ overwrite)
except ValueError:
er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException'
return self.error(er, 'A condition specified in the operation could not be evaluated.')
@@ -626,7 +565,7 @@ class DynamoHandler(BaseResponse):
name = self.body['TableName']
key = self.body['Key']
return_values = self.body.get('ReturnValues', 'NONE')
- update_expression = self.body.get('UpdateExpression')
+ update_expression = self.body.get('UpdateExpression', '').strip()
attribute_updates = self.body.get('AttributeUpdates')
expression_attribute_names = self.body.get(
'ExpressionAttributeNames', {})
@@ -653,13 +592,9 @@ class DynamoHandler(BaseResponse):
# Attempt to parse simple ConditionExpressions into an Expected
# expression
- if not expected:
- condition_expression = self.body.get('ConditionExpression')
- expression_attribute_names = self.body.get('ExpressionAttributeNames', {})
- expression_attribute_values = self.body.get('ExpressionAttributeValues', {})
- expected = condition_expression_to_expected(condition_expression,
- expression_attribute_names,
- expression_attribute_values)
+ condition_expression = self.body.get('ConditionExpression')
+ expression_attribute_names = self.body.get('ExpressionAttributeNames', {})
+ expression_attribute_values = self.body.get('ExpressionAttributeValues', {})
# Support spaces between operators in an update expression
# E.g. `a = b + c` -> `a=b+c`
@@ -670,7 +605,7 @@ class DynamoHandler(BaseResponse):
try:
item = self.dynamodb_backend.update_item(
name, key, update_expression, attribute_updates, expression_attribute_names,
- expression_attribute_values, expected
+ expression_attribute_values, expected, condition_expression
)
except ValueError:
er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException'
diff --git a/moto/rds2/exceptions.py b/moto/rds2/exceptions.py
index 0e716310e..e82ae7077 100644
--- a/moto/rds2/exceptions.py
+++ b/moto/rds2/exceptions.py
@@ -60,6 +60,15 @@ class DBParameterGroupNotFoundError(RDSClientError):
'DB Parameter Group {0} not found.'.format(db_parameter_group_name))
+class OptionGroupNotFoundFaultError(RDSClientError):
+
+ def __init__(self, option_group_name):
+ super(OptionGroupNotFoundFaultError, self).__init__(
+ 'OptionGroupNotFoundFault',
+ 'Specified OptionGroupName: {0} not found.'.format(option_group_name)
+ )
+
+
class InvalidDBClusterStateFaultError(RDSClientError):
def __init__(self, database_identifier):
diff --git a/moto/rds2/models.py b/moto/rds2/models.py
index fee004f76..81b346fdb 100644
--- a/moto/rds2/models.py
+++ b/moto/rds2/models.py
@@ -20,6 +20,7 @@ from .exceptions import (RDSClientError,
DBSecurityGroupNotFoundError,
DBSubnetGroupNotFoundError,
DBParameterGroupNotFoundError,
+ OptionGroupNotFoundFaultError,
InvalidDBClusterStateFaultError,
InvalidDBInstanceStateError,
SnapshotQuotaExceededError,
@@ -70,6 +71,7 @@ class Database(BaseModel):
self.port = Database.default_port(self.engine)
self.db_instance_identifier = kwargs.get('db_instance_identifier')
self.db_name = kwargs.get("db_name")
+ self.instance_create_time = iso_8601_datetime_with_milliseconds(datetime.datetime.now())
self.publicly_accessible = kwargs.get("publicly_accessible")
if self.publicly_accessible is None:
self.publicly_accessible = True
@@ -99,6 +101,8 @@ class Database(BaseModel):
'preferred_backup_window', '13:14-13:44')
self.license_model = kwargs.get('license_model', 'general-public-license')
self.option_group_name = kwargs.get('option_group_name', None)
+ if self.option_group_name and self.option_group_name not in rds2_backends[self.region].option_groups:
+ raise OptionGroupNotFoundFaultError(self.option_group_name)
self.default_option_groups = {"MySQL": "default.mysql5.6",
"mysql": "default.mysql5.6",
"postgres": "default.postgres9.3"
@@ -148,6 +152,7 @@ class Database(BaseModel):
{{ database.db_instance_identifier }}
{{ database.dbi_resource_id }}
+ {{ database.instance_create_time }}
03:50-04:20
wed:06:38-wed:07:08
@@ -173,6 +178,10 @@ class Database(BaseModel):
{{ database.license_model }}
{{ database.engine_version }}
+
+ {{ database.option_group_name }}
+ in-sync
+
{% for db_parameter_group in database.db_parameter_groups() %}
@@ -373,7 +382,7 @@ class Database(BaseModel):
"Address": "{{ database.address }}",
"Port": "{{ database.port }}"
},
- "InstanceCreateTime": null,
+ "InstanceCreateTime": "{{ database.instance_create_time }}",
"Iops": null,
"ReadReplicaDBInstanceIdentifiers": [{%- for replica in database.replicas -%}
{%- if not loop.first -%},{%- endif -%}
@@ -873,13 +882,16 @@ class RDS2Backend(BaseBackend):
def create_option_group(self, option_group_kwargs):
option_group_id = option_group_kwargs['name']
- valid_option_group_engines = {'mysql': ['5.6'],
- 'oracle-se1': ['11.2'],
- 'oracle-se': ['11.2'],
- 'oracle-ee': ['11.2'],
+ valid_option_group_engines = {'mariadb': ['10.0', '10.1', '10.2', '10.3'],
+ 'mysql': ['5.5', '5.6', '5.7', '8.0'],
+ 'oracle-se2': ['11.2', '12.1', '12.2'],
+ 'oracle-se1': ['11.2', '12.1', '12.2'],
+ 'oracle-se': ['11.2', '12.1', '12.2'],
+ 'oracle-ee': ['11.2', '12.1', '12.2'],
'sqlserver-se': ['10.50', '11.00'],
- 'sqlserver-ee': ['10.50', '11.00']
- }
+ 'sqlserver-ee': ['10.50', '11.00'],
+ 'sqlserver-ex': ['10.50', '11.00'],
+ 'sqlserver-web': ['10.50', '11.00']}
if option_group_kwargs['name'] in self.option_groups:
raise RDSClientError('OptionGroupAlreadyExistsFault',
'An option group named {0} already exists.'.format(option_group_kwargs['name']))
@@ -905,8 +917,7 @@ class RDS2Backend(BaseBackend):
if option_group_name in self.option_groups:
return self.option_groups.pop(option_group_name)
else:
- raise RDSClientError(
- 'OptionGroupNotFoundFault', 'Specified OptionGroupName: {0} not found.'.format(option_group_name))
+ raise OptionGroupNotFoundFaultError(option_group_name)
def describe_option_groups(self, option_group_kwargs):
option_group_list = []
@@ -935,8 +946,7 @@ class RDS2Backend(BaseBackend):
else:
option_group_list.append(option_group)
if not len(option_group_list):
- raise RDSClientError('OptionGroupNotFoundFault',
- 'Specified OptionGroupName: {0} not found.'.format(option_group_kwargs['name']))
+ raise OptionGroupNotFoundFaultError(option_group_kwargs['name'])
return option_group_list[marker:max_records + marker]
@staticmethod
@@ -965,8 +975,7 @@ class RDS2Backend(BaseBackend):
def modify_option_group(self, option_group_name, options_to_include=None, options_to_remove=None, apply_immediately=None):
if option_group_name not in self.option_groups:
- raise RDSClientError('OptionGroupNotFoundFault',
- 'Specified OptionGroupName: {0} not found.'.format(option_group_name))
+ raise OptionGroupNotFoundFaultError(option_group_name)
if not options_to_include and not options_to_remove:
raise RDSClientError('InvalidParameterValue',
'At least one option must be added, modified, or removed.')
diff --git a/moto/rds2/responses.py b/moto/rds2/responses.py
index 66d4e0c52..e92625635 100644
--- a/moto/rds2/responses.py
+++ b/moto/rds2/responses.py
@@ -34,7 +34,7 @@ class RDS2Response(BaseResponse):
"master_user_password": self._get_param('MasterUserPassword'),
"master_username": self._get_param('MasterUsername'),
"multi_az": self._get_bool_param("MultiAZ"),
- # OptionGroupName
+ "option_group_name": self._get_param("OptionGroupName"),
"port": self._get_param('Port'),
# PreferredBackupWindow
# PreferredMaintenanceWindow
diff --git a/moto/route53/models.py b/moto/route53/models.py
index 5ed1c1476..681a9d6ff 100644
--- a/moto/route53/models.py
+++ b/moto/route53/models.py
@@ -85,6 +85,7 @@ class RecordSet(BaseModel):
self.health_check = kwargs.get('HealthCheckId')
self.hosted_zone_name = kwargs.get('HostedZoneName')
self.hosted_zone_id = kwargs.get('HostedZoneId')
+ self.alias_target = kwargs.get('AliasTarget')
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
@@ -143,6 +144,13 @@ class RecordSet(BaseModel):
{% if record_set.ttl %}
{{ record_set.ttl }}
{% endif %}
+ {% if record_set.alias_target %}
+
+ {{ record_set.alias_target['HostedZoneId'] }}
+ {{ record_set.alias_target['DNSName'] }}
+ {{ record_set.alias_target['EvaluateTargetHealth'] }}
+
+ {% else %}
{% for record in record_set.records %}
@@ -150,6 +158,7 @@ class RecordSet(BaseModel):
{% endfor %}
+ {% endif %}
{% if record_set.health_check %}
{{ record_set.health_check }}
{% endif %}
diff --git a/moto/route53/responses.py b/moto/route53/responses.py
index bf705c87f..f933c575a 100644
--- a/moto/route53/responses.py
+++ b/moto/route53/responses.py
@@ -134,10 +134,7 @@ class Route53(BaseResponse):
# Depending on how many records there are, this may
# or may not be a list
resource_records = [resource_records]
- record_values = [x['Value'] for x in resource_records]
- elif 'AliasTarget' in record_set:
- record_values = [record_set['AliasTarget']['DNSName']]
- record_set['ResourceRecords'] = record_values
+ record_set['ResourceRecords'] = [x['Value'] for x in resource_records]
if action == 'CREATE':
the_zone.add_rrset(record_set)
else:
diff --git a/moto/s3/responses.py b/moto/s3/responses.py
index 46d811f81..a052e4cfb 100644
--- a/moto/s3/responses.py
+++ b/moto/s3/responses.py
@@ -807,7 +807,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
body = b''
if method == 'GET':
- return self._key_response_get(bucket_name, query, key_name, headers)
+ return self._key_response_get(bucket_name, query, key_name, headers=request.headers)
elif method == 'PUT':
return self._key_response_put(request, body, bucket_name, query, key_name, headers)
elif method == 'HEAD':
@@ -842,10 +842,15 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
parts=parts
)
version_id = query.get('versionId', [None])[0]
+ if_modified_since = headers.get('If-Modified-Since', None)
key = self.backend.get_key(
bucket_name, key_name, version_id=version_id)
if key is None:
raise MissingKey(key_name)
+ if if_modified_since:
+ if_modified_since = str_to_rfc_1123_datetime(if_modified_since)
+ if if_modified_since and key.last_modified < if_modified_since:
+ return 304, response_headers, 'Not Modified'
if 'acl' in query:
template = self.response_template(S3_OBJECT_ACL_RESPONSE)
return 200, response_headers, template.render(obj=key)
diff --git a/moto/ses/feedback.py b/moto/ses/feedback.py
new file mode 100644
index 000000000..2d32f9ce0
--- /dev/null
+++ b/moto/ses/feedback.py
@@ -0,0 +1,81 @@
+"""
+SES Feedback messages
+Extracted from https://docs.aws.amazon.com/ses/latest/DeveloperGuide/notification-contents.html
+"""
+COMMON_MAIL = {
+ "notificationType": "Bounce, Complaint, or Delivery.",
+ "mail": {
+ "timestamp": "2018-10-08T14:05:45 +0000",
+ "messageId": "000001378603177f-7a5433e7-8edb-42ae-af10-f0181f34d6ee-000000",
+ "source": "sender@example.com",
+ "sourceArn": "arn:aws:ses:us-west-2:888888888888:identity/example.com",
+ "sourceIp": "127.0.3.0",
+ "sendingAccountId": "123456789012",
+ "destination": [
+ "recipient@example.com"
+ ],
+ "headersTruncated": False,
+ "headers": [
+ {
+ "name": "From",
+ "value": "\"Sender Name\" "
+ },
+ {
+ "name": "To",
+ "value": "\"Recipient Name\" "
+ }
+ ],
+ "commonHeaders": {
+ "from": [
+ "Sender Name "
+ ],
+ "date": "Mon, 08 Oct 2018 14:05:45 +0000",
+ "to": [
+ "Recipient Name "
+ ],
+ "messageId": " custom-message-ID",
+ "subject": "Message sent using Amazon SES"
+ }
+ }
+}
+BOUNCE = {
+ "bounceType": "Permanent",
+ "bounceSubType": "General",
+ "bouncedRecipients": [
+ {
+ "status": "5.0.0",
+ "action": "failed",
+ "diagnosticCode": "smtp; 550 user unknown",
+ "emailAddress": "recipient1@example.com"
+ },
+ {
+ "status": "4.0.0",
+ "action": "delayed",
+ "emailAddress": "recipient2@example.com"
+ }
+ ],
+ "reportingMTA": "example.com",
+ "timestamp": "2012-05-25T14:59:38.605Z",
+ "feedbackId": "000001378603176d-5a4b5ad9-6f30-4198-a8c3-b1eb0c270a1d-000000",
+ "remoteMtaIp": "127.0.2.0"
+}
+COMPLAINT = {
+ "userAgent": "AnyCompany Feedback Loop (V0.01)",
+ "complainedRecipients": [
+ {
+ "emailAddress": "recipient1@example.com"
+ }
+ ],
+ "complaintFeedbackType": "abuse",
+ "arrivalDate": "2009-12-03T04:24:21.000-05:00",
+ "timestamp": "2012-05-25T14:59:38.623Z",
+ "feedbackId": "000001378603177f-18c07c78-fa81-4a58-9dd1-fedc3cb8f49a-000000"
+}
+DELIVERY = {
+ "timestamp": "2014-05-28T22:41:01.184Z",
+ "processingTimeMillis": 546,
+ "recipients": ["success@simulator.amazonses.com"],
+ "smtpResponse": "250 ok: Message 64111812 accepted",
+ "reportingMTA": "a8-70.smtp-out.amazonses.com",
+ "remoteMtaIp": "127.0.2.0"
+}
diff --git a/moto/ses/models.py b/moto/ses/models.py
index 71fe9d9a1..0544ac278 100644
--- a/moto/ses/models.py
+++ b/moto/ses/models.py
@@ -4,13 +4,41 @@ import email
from email.utils import parseaddr
from moto.core import BaseBackend, BaseModel
+from moto.sns.models import sns_backends
from .exceptions import MessageRejectedError
from .utils import get_random_message_id
-
+from .feedback import COMMON_MAIL, BOUNCE, COMPLAINT, DELIVERY
RECIPIENT_LIMIT = 50
+class SESFeedback(BaseModel):
+
+ BOUNCE = "Bounce"
+ COMPLAINT = "Complaint"
+ DELIVERY = "Delivery"
+
+ SUCCESS_ADDR = "success"
+ BOUNCE_ADDR = "bounce"
+ COMPLAINT_ADDR = "complaint"
+
+ FEEDBACK_SUCCESS_MSG = {"test": "success"}
+ FEEDBACK_BOUNCE_MSG = {"test": "bounce"}
+ FEEDBACK_COMPLAINT_MSG = {"test": "complaint"}
+
+ @staticmethod
+ def generate_message(msg_type):
+ msg = dict(COMMON_MAIL)
+ if msg_type == SESFeedback.BOUNCE:
+ msg["bounce"] = BOUNCE
+ elif msg_type == SESFeedback.COMPLAINT:
+ msg["complaint"] = COMPLAINT
+ elif msg_type == SESFeedback.DELIVERY:
+ msg["delivery"] = DELIVERY
+
+ return msg
+
+
class Message(BaseModel):
def __init__(self, message_id, source, subject, body, destinations):
@@ -48,6 +76,7 @@ class SESBackend(BaseBackend):
self.domains = []
self.sent_messages = []
self.sent_message_count = 0
+ self.sns_topics = {}
def _is_verified_address(self, source):
_, address = parseaddr(source)
@@ -77,7 +106,7 @@ class SESBackend(BaseBackend):
else:
self.domains.remove(identity)
- def send_email(self, source, subject, body, destinations):
+ def send_email(self, source, subject, body, destinations, region):
recipient_count = sum(map(len, destinations.values()))
if recipient_count > RECIPIENT_LIMIT:
raise MessageRejectedError('Too many recipients.')
@@ -86,13 +115,46 @@ class SESBackend(BaseBackend):
"Email address not verified %s" % source
)
+ self.__process_sns_feedback__(source, destinations, region)
+
message_id = get_random_message_id()
message = Message(message_id, source, subject, body, destinations)
self.sent_messages.append(message)
self.sent_message_count += recipient_count
return message
- def send_raw_email(self, source, destinations, raw_data):
+ def __type_of_message__(self, destinations):
+ """Checks the destination for any special address that could indicate delivery, complaint or bounce
+ like in SES simualtor"""
+ alladdress = destinations.get("ToAddresses", []) + destinations.get("CcAddresses", []) + destinations.get("BccAddresses", [])
+ for addr in alladdress:
+ if SESFeedback.SUCCESS_ADDR in addr:
+ return SESFeedback.DELIVERY
+ elif SESFeedback.COMPLAINT_ADDR in addr:
+ return SESFeedback.COMPLAINT
+ elif SESFeedback.BOUNCE_ADDR in addr:
+ return SESFeedback.BOUNCE
+
+ return None
+
+ def __generate_feedback__(self, msg_type):
+ """Generates the SNS message for the feedback"""
+ return SESFeedback.generate_message(msg_type)
+
+ def __process_sns_feedback__(self, source, destinations, region):
+ domain = str(source)
+ if "@" in domain:
+ domain = domain.split("@")[1]
+ if domain in self.sns_topics:
+ msg_type = self.__type_of_message__(destinations)
+ if msg_type is not None:
+ sns_topic = self.sns_topics[domain].get(msg_type, None)
+ if sns_topic is not None:
+ message = self.__generate_feedback__(msg_type)
+ if message:
+ sns_backends[region].publish(sns_topic, message)
+
+ def send_raw_email(self, source, destinations, raw_data, region):
if source is not None:
_, source_email_address = parseaddr(source)
if source_email_address not in self.addresses:
@@ -122,6 +184,8 @@ class SESBackend(BaseBackend):
if recipient_count > RECIPIENT_LIMIT:
raise MessageRejectedError('Too many recipients.')
+ self.__process_sns_feedback__(source, destinations, region)
+
self.sent_message_count += recipient_count
message_id = get_random_message_id()
message = RawMessage(message_id, source, destinations, raw_data)
@@ -131,5 +195,16 @@ class SESBackend(BaseBackend):
def get_send_quota(self):
return SESQuota(self.sent_message_count)
+ def set_identity_notification_topic(self, identity, notification_type, sns_topic):
+ identity_sns_topics = self.sns_topics.get(identity, {})
+ if sns_topic is None:
+ del identity_sns_topics[notification_type]
+ else:
+ identity_sns_topics[notification_type] = sns_topic
+
+ self.sns_topics[identity] = identity_sns_topics
+
+ return {}
+
ses_backend = SESBackend()
diff --git a/moto/ses/responses.py b/moto/ses/responses.py
index bdf873836..d2dda55f1 100644
--- a/moto/ses/responses.py
+++ b/moto/ses/responses.py
@@ -70,7 +70,7 @@ class EmailResponse(BaseResponse):
break
destinations[dest_type].append(address[0])
- message = ses_backend.send_email(source, subject, body, destinations)
+ message = ses_backend.send_email(source, subject, body, destinations, self.region)
template = self.response_template(SEND_EMAIL_RESPONSE)
return template.render(message=message)
@@ -92,7 +92,7 @@ class EmailResponse(BaseResponse):
break
destinations.append(address[0])
- message = ses_backend.send_raw_email(source, destinations, raw_data)
+ message = ses_backend.send_raw_email(source, destinations, raw_data, self.region)
template = self.response_template(SEND_RAW_EMAIL_RESPONSE)
return template.render(message=message)
@@ -101,6 +101,18 @@ class EmailResponse(BaseResponse):
template = self.response_template(GET_SEND_QUOTA_RESPONSE)
return template.render(quota=quota)
+ def set_identity_notification_topic(self):
+
+ identity = self.querystring.get("Identity")[0]
+ not_type = self.querystring.get("NotificationType")[0]
+ sns_topic = self.querystring.get("SnsTopic")
+ if sns_topic:
+ sns_topic = sns_topic[0]
+
+ ses_backend.set_identity_notification_topic(identity, not_type, sns_topic)
+ template = self.response_template(SET_IDENTITY_NOTIFICATION_TOPIC_RESPONSE)
+ return template.render()
+
VERIFY_EMAIL_IDENTITY = """
@@ -200,3 +212,10 @@ GET_SEND_QUOTA_RESPONSE = """
+
+
+ 47e0ef1a-9bf2-11e1-9279-0100e8cf109a
+
+"""
diff --git a/tests/test_dynamodb2/test_dynamodb.py b/tests/test_dynamodb2/test_dynamodb.py
index faa467aab..f5afc1e7e 100644
--- a/tests/test_dynamodb2/test_dynamodb.py
+++ b/tests/test_dynamodb2/test_dynamodb.py
@@ -838,44 +838,47 @@ def test_filter_expression():
filter_expr.expr(row1).should.be(True)
# NOT test 2
- filter_expr = moto.dynamodb2.comparisons.get_filter_expression('NOT (Id = :v0)', {}, {':v0': {'N': 8}})
+ filter_expr = moto.dynamodb2.comparisons.get_filter_expression('NOT (Id = :v0)', {}, {':v0': {'N': '8'}})
filter_expr.expr(row1).should.be(False) # Id = 8 so should be false
# AND test
- filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id > :v0 AND Subs < :v1', {}, {':v0': {'N': 5}, ':v1': {'N': 7}})
+ filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id > :v0 AND Subs < :v1', {}, {':v0': {'N': '5'}, ':v1': {'N': '7'}})
filter_expr.expr(row1).should.be(True)
filter_expr.expr(row2).should.be(False)
# OR test
- filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id = :v0 OR Id=:v1', {}, {':v0': {'N': 5}, ':v1': {'N': 8}})
+ filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id = :v0 OR Id=:v1', {}, {':v0': {'N': '5'}, ':v1': {'N': '8'}})
filter_expr.expr(row1).should.be(True)
# BETWEEN test
- filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id BETWEEN :v0 AND :v1', {}, {':v0': {'N': 5}, ':v1': {'N': 10}})
+ filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id BETWEEN :v0 AND :v1', {}, {':v0': {'N': '5'}, ':v1': {'N': '10'}})
filter_expr.expr(row1).should.be(True)
# PAREN test
- filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id = :v0 AND (Subs = :v0 OR Subs = :v1)', {}, {':v0': {'N': 8}, ':v1': {'N': 5}})
+ filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id = :v0 AND (Subs = :v0 OR Subs = :v1)', {}, {':v0': {'N': '8'}, ':v1': {'N': '5'}})
filter_expr.expr(row1).should.be(True)
# IN test
- filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id IN :v0', {}, {':v0': {'NS': [7, 8, 9]}})
+ filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id IN (:v0, :v1, :v2)', {}, {
+ ':v0': {'N': '7'},
+ ':v1': {'N': '8'},
+ ':v2': {'N': '9'}})
filter_expr.expr(row1).should.be(True)
# attribute function tests (with extra spaces)
filter_expr = moto.dynamodb2.comparisons.get_filter_expression('attribute_exists(Id) AND attribute_not_exists (User)', {}, {})
filter_expr.expr(row1).should.be(True)
- filter_expr = moto.dynamodb2.comparisons.get_filter_expression('attribute_type(Id, N)', {}, {})
+ filter_expr = moto.dynamodb2.comparisons.get_filter_expression('attribute_type(Id, :v0)', {}, {':v0': {'S': 'N'}})
filter_expr.expr(row1).should.be(True)
# beginswith function test
- filter_expr = moto.dynamodb2.comparisons.get_filter_expression('begins_with(Desc, Some)', {}, {})
+ filter_expr = moto.dynamodb2.comparisons.get_filter_expression('begins_with(Desc, :v0)', {}, {':v0': {'S': 'Some'}})
filter_expr.expr(row1).should.be(True)
filter_expr.expr(row2).should.be(False)
# contains function test
- filter_expr = moto.dynamodb2.comparisons.get_filter_expression('contains(KV, test1)', {}, {})
+ filter_expr = moto.dynamodb2.comparisons.get_filter_expression('contains(KV, :v0)', {}, {':v0': {'S': 'test1'}})
filter_expr.expr(row1).should.be(True)
filter_expr.expr(row2).should.be(False)
@@ -916,14 +919,26 @@ def test_query_filter():
TableName='test1',
Item={
'client': {'S': 'client1'},
- 'app': {'S': 'app1'}
+ 'app': {'S': 'app1'},
+ 'nested': {'M': {
+ 'version': {'S': 'version1'},
+ 'contents': {'L': [
+ {'S': 'value1'}, {'S': 'value2'},
+ ]},
+ }},
}
)
client.put_item(
TableName='test1',
Item={
'client': {'S': 'client1'},
- 'app': {'S': 'app2'}
+ 'app': {'S': 'app2'},
+ 'nested': {'M': {
+ 'version': {'S': 'version2'},
+ 'contents': {'L': [
+ {'S': 'value1'}, {'S': 'value2'},
+ ]},
+ }},
}
)
@@ -945,6 +960,18 @@ def test_query_filter():
)
assert response['Count'] == 2
+ response = table.query(
+ KeyConditionExpression=Key('client').eq('client1'),
+ FilterExpression=Attr('nested.version').contains('version')
+ )
+ assert response['Count'] == 2
+
+ response = table.query(
+ KeyConditionExpression=Key('client').eq('client1'),
+ FilterExpression=Attr('nested.contents[0]').eq('value1')
+ )
+ assert response['Count'] == 2
+
@mock_dynamodb2
def test_scan_filter():
@@ -1223,7 +1250,7 @@ def test_delete_item():
with assert_raises(ClientError) as ex:
table.delete_item(Key={'client': 'client1', 'app': 'app1'},
ReturnValues='ALL_NEW')
-
+
# Test deletion and returning old value
response = table.delete_item(Key={'client': 'client1', 'app': 'app1'}, ReturnValues='ALL_OLD')
response['Attributes'].should.contain('client')
@@ -1526,7 +1553,7 @@ def test_put_return_attributes():
ReturnValues='NONE'
)
assert 'Attributes' not in r
-
+
r = dynamodb.put_item(
TableName='moto-test',
Item={'id': {'S': 'foo'}, 'col1': {'S': 'val2'}},
@@ -1543,7 +1570,7 @@ def test_put_return_attributes():
ex.exception.response['Error']['Code'].should.equal('ValidationException')
ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400)
ex.exception.response['Error']['Message'].should.equal('Return values set to invalid value')
-
+
@mock_dynamodb2
def test_query_global_secondary_index_when_created_via_update_table_resource():
@@ -1651,7 +1678,7 @@ def test_dynamodb_streams_1():
'StreamViewType': 'NEW_AND_OLD_IMAGES'
}
)
-
+
assert 'StreamSpecification' in resp['TableDescription']
assert resp['TableDescription']['StreamSpecification'] == {
'StreamEnabled': True,
@@ -1659,11 +1686,11 @@ def test_dynamodb_streams_1():
}
assert 'LatestStreamLabel' in resp['TableDescription']
assert 'LatestStreamArn' in resp['TableDescription']
-
+
resp = conn.delete_table(TableName='test-streams')
assert 'StreamSpecification' in resp['TableDescription']
-
+
@mock_dynamodb2
def test_dynamodb_streams_2():
@@ -1694,11 +1721,10 @@ def test_dynamodb_streams_2():
assert 'LatestStreamLabel' in resp['TableDescription']
assert 'LatestStreamArn' in resp['TableDescription']
-
+
@mock_dynamodb2
def test_condition_expressions():
client = boto3.client('dynamodb', region_name='us-east-1')
- dynamodb = boto3.resource('dynamodb', region_name='us-east-1')
# Create the DynamoDB table.
client.create_table(
@@ -1751,6 +1777,57 @@ def test_condition_expressions():
}
)
+ client.put_item(
+ TableName='test1',
+ Item={
+ 'client': {'S': 'client1'},
+ 'app': {'S': 'app1'},
+ 'match': {'S': 'match'},
+ 'existing': {'S': 'existing'},
+ },
+ ConditionExpression='attribute_exists(#nonexistent) OR attribute_exists(#existing)',
+ ExpressionAttributeNames={
+ '#nonexistent': 'nope',
+ '#existing': 'existing'
+ }
+ )
+
+ client.put_item(
+ TableName='test1',
+ Item={
+ 'client': {'S': 'client1'},
+ 'app': {'S': 'app1'},
+ 'match': {'S': 'match'},
+ 'existing': {'S': 'existing'},
+ },
+ ConditionExpression='#client BETWEEN :a AND :z',
+ ExpressionAttributeNames={
+ '#client': 'client',
+ },
+ ExpressionAttributeValues={
+ ':a': {'S': 'a'},
+ ':z': {'S': 'z'},
+ }
+ )
+
+ client.put_item(
+ TableName='test1',
+ Item={
+ 'client': {'S': 'client1'},
+ 'app': {'S': 'app1'},
+ 'match': {'S': 'match'},
+ 'existing': {'S': 'existing'},
+ },
+ ConditionExpression='#client IN (:client1, :client2)',
+ ExpressionAttributeNames={
+ '#client': 'client',
+ },
+ ExpressionAttributeValues={
+ ':client1': {'S': 'client1'},
+ ':client2': {'S': 'client2'},
+ }
+ )
+
with assert_raises(client.exceptions.ConditionalCheckFailedException):
client.put_item(
TableName='test1',
@@ -1803,6 +1880,89 @@ def test_condition_expressions():
}
)
+ # Make sure update_item honors ConditionExpression as well
+ client.update_item(
+ TableName='test1',
+ Key={
+ 'client': {'S': 'client1'},
+ 'app': {'S': 'app1'},
+ },
+ UpdateExpression='set #match=:match',
+ ConditionExpression='attribute_exists(#existing)',
+ ExpressionAttributeNames={
+ '#existing': 'existing',
+ '#match': 'match',
+ },
+ ExpressionAttributeValues={
+ ':match': {'S': 'match'}
+ }
+ )
+
+ with assert_raises(client.exceptions.ConditionalCheckFailedException):
+ client.update_item(
+ TableName='test1',
+ Key={
+ 'client': { 'S': 'client1'},
+ 'app': { 'S': 'app1'},
+ },
+ UpdateExpression='set #match=:match',
+ ConditionExpression='attribute_not_exists(#existing)',
+ ExpressionAttributeValues={
+ ':match': {'S': 'match'}
+ },
+ ExpressionAttributeNames={
+ '#existing': 'existing',
+ '#match': 'match',
+ },
+ )
+
+
+@mock_dynamodb2
+def test_condition_expression__attr_doesnt_exist():
+ client = boto3.client('dynamodb', region_name='us-east-1')
+
+ client.create_table(
+ TableName='test',
+ KeySchema=[{'AttributeName': 'forum_name', 'KeyType': 'HASH'}],
+ AttributeDefinitions=[
+ {'AttributeName': 'forum_name', 'AttributeType': 'S'},
+ ],
+ ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1},
+ )
+
+ client.put_item(
+ TableName='test',
+ Item={
+ 'forum_name': {'S': 'foo'},
+ 'ttl': {'N': 'bar'},
+ }
+ )
+
+
+ def update_if_attr_doesnt_exist():
+ # Test nonexistent top-level attribute.
+ client.update_item(
+ TableName='test',
+ Key={
+ 'forum_name': {'S': 'the-key'},
+ 'subject': {'S': 'the-subject'},
+ },
+ UpdateExpression='set #new_state=:new_state, #ttl=:ttl',
+ ConditionExpression='attribute_not_exists(#new_state)',
+ ExpressionAttributeNames={'#new_state': 'foobar', '#ttl': 'ttl'},
+ ExpressionAttributeValues={
+ ':new_state': {'S': 'some-value'},
+ ':ttl': {'N': '12345.67'},
+ },
+ ReturnValues='ALL_NEW',
+ )
+
+ update_if_attr_doesnt_exist()
+
+ # Second time should fail
+ with assert_raises(client.exceptions.ConditionalCheckFailedException):
+ update_if_attr_doesnt_exist()
+
@mock_dynamodb2
def test_query_gsi_with_range_key():
diff --git a/tests/test_rds2/test_rds2.py b/tests/test_rds2/test_rds2.py
index a25b53196..8ea296c2c 100644
--- a/tests/test_rds2/test_rds2.py
+++ b/tests/test_rds2/test_rds2.py
@@ -34,6 +34,39 @@ def test_create_database():
db_instance['IAMDatabaseAuthenticationEnabled'].should.equal(False)
db_instance['DbiResourceId'].should.contain("db-")
db_instance['CopyTagsToSnapshot'].should.equal(False)
+ db_instance['InstanceCreateTime'].should.be.a("datetime.datetime")
+
+
+@mock_rds2
+def test_create_database_non_existing_option_group():
+ conn = boto3.client('rds', region_name='us-west-2')
+ database = conn.create_db_instance.when.called_with(
+ DBInstanceIdentifier='db-master-1',
+ AllocatedStorage=10,
+ Engine='postgres',
+ DBName='staging-postgres',
+ DBInstanceClass='db.m1.small',
+ OptionGroupName='non-existing').should.throw(ClientError)
+
+
+@mock_rds2
+def test_create_database_with_option_group():
+ conn = boto3.client('rds', region_name='us-west-2')
+ conn.create_option_group(OptionGroupName='my-og',
+ EngineName='mysql',
+ MajorEngineVersion='5.6',
+ OptionGroupDescription='test option group')
+ database = conn.create_db_instance(DBInstanceIdentifier='db-master-1',
+ AllocatedStorage=10,
+ Engine='postgres',
+ DBName='staging-postgres',
+ DBInstanceClass='db.m1.small',
+ OptionGroupName='my-og')
+ db_instance = database['DBInstance']
+ db_instance['AllocatedStorage'].should.equal(10)
+ db_instance['DBInstanceClass'].should.equal('db.m1.small')
+ db_instance['DBName'].should.equal('staging-postgres')
+ db_instance['OptionGroupMemberships'][0]['OptionGroupName'].should.equal('my-og')
@mock_rds2
@@ -204,6 +237,7 @@ def test_get_databases_paginated():
resp3 = conn.describe_db_instances(MaxRecords=100)
resp3["DBInstances"].should.have.length_of(51)
+
@mock_rds2
def test_describe_non_existant_database():
conn = boto3.client('rds', region_name='us-west-2')
diff --git a/tests/test_route53/test_route53.py b/tests/test_route53/test_route53.py
index f43657dad..ca652af88 100644
--- a/tests/test_route53/test_route53.py
+++ b/tests/test_route53/test_route53.py
@@ -173,14 +173,16 @@ def test_alias_rrset():
changes.commit()
rrsets = conn.get_all_rrsets(zoneid, type="A")
- rrset_records = [(rr_set.name, rr) for rr_set in rrsets for rr in rr_set.resource_records]
- rrset_records.should.have.length_of(2)
- rrset_records.should.contain(('foo.alias.testdns.aws.com.', 'foo.testdns.aws.com'))
- rrset_records.should.contain(('bar.alias.testdns.aws.com.', 'bar.testdns.aws.com'))
- rrsets[0].resource_records[0].should.equal('foo.testdns.aws.com')
+ alias_targets = [rr_set.alias_dns_name for rr_set in rrsets]
+ alias_targets.should.have.length_of(2)
+ alias_targets.should.contain('foo.testdns.aws.com')
+ alias_targets.should.contain('bar.testdns.aws.com')
+ rrsets[0].alias_dns_name.should.equal('foo.testdns.aws.com')
+ rrsets[0].resource_records.should.have.length_of(0)
rrsets = conn.get_all_rrsets(zoneid, type="CNAME")
rrsets.should.have.length_of(1)
- rrsets[0].resource_records[0].should.equal('bar.testdns.aws.com')
+ rrsets[0].alias_dns_name.should.equal('bar.testdns.aws.com')
+ rrsets[0].resource_records.should.have.length_of(0)
@mock_route53_deprecated
@@ -583,6 +585,39 @@ def test_change_resource_record_sets_crud_valid():
cname_record_detail['TTL'].should.equal(60)
cname_record_detail['ResourceRecords'].should.equal([{'Value': '192.168.1.1'}])
+ # Update to add Alias.
+ cname_alias_record_endpoint_payload = {
+ 'Comment': 'Update to Alias prod.redis.db',
+ 'Changes': [
+ {
+ 'Action': 'UPSERT',
+ 'ResourceRecordSet': {
+ 'Name': 'prod.redis.db.',
+ 'Type': 'A',
+ 'TTL': 60,
+ 'AliasTarget': {
+ 'HostedZoneId': hosted_zone_id,
+ 'DNSName': 'prod.redis.alias.',
+ 'EvaluateTargetHealth': False,
+ }
+ }
+ }
+ ]
+ }
+ conn.change_resource_record_sets(HostedZoneId=hosted_zone_id, ChangeBatch=cname_alias_record_endpoint_payload)
+
+ response = conn.list_resource_record_sets(HostedZoneId=hosted_zone_id)
+ cname_alias_record_detail = response['ResourceRecordSets'][0]
+ cname_alias_record_detail['Name'].should.equal('prod.redis.db.')
+ cname_alias_record_detail['Type'].should.equal('A')
+ cname_alias_record_detail['TTL'].should.equal(60)
+ cname_alias_record_detail['AliasTarget'].should.equal({
+ 'HostedZoneId': hosted_zone_id,
+ 'DNSName': 'prod.redis.alias.',
+ 'EvaluateTargetHealth': False,
+ })
+ cname_alias_record_detail.should_not.contain('ResourceRecords')
+
# Delete record with wrong type.
delete_payload = {
'Comment': 'delete prod.redis.db',
diff --git a/tests/test_s3/test_s3.py b/tests/test_s3/test_s3.py
index f26964ab7..697c47865 100644
--- a/tests/test_s3/test_s3.py
+++ b/tests/test_s3/test_s3.py
@@ -1596,6 +1596,28 @@ def test_boto3_delete_versioned_bucket():
client.delete_bucket(Bucket='blah')
+@mock_s3
+def test_boto3_get_object_if_modified_since():
+ s3 = boto3.client('s3', region_name='us-east-1')
+ bucket_name = "blah"
+ s3.create_bucket(Bucket=bucket_name)
+
+ key = 'hello.txt'
+
+ s3.put_object(
+ Bucket=bucket_name,
+ Key=key,
+ Body='test'
+ )
+
+ with assert_raises(botocore.exceptions.ClientError) as err:
+ s3.get_object(
+ Bucket=bucket_name,
+ Key=key,
+ IfModifiedSince=datetime.datetime.utcnow() + datetime.timedelta(hours=1)
+ )
+ e = err.exception
+ e.response['Error'].should.equal({'Code': '304', 'Message': 'Not Modified'})
@mock_s3
def test_boto3_head_object_if_modified_since():
diff --git a/tests/test_ses/test_ses_sns_boto3.py b/tests/test_ses/test_ses_sns_boto3.py
new file mode 100644
index 000000000..37f79a8b0
--- /dev/null
+++ b/tests/test_ses/test_ses_sns_boto3.py
@@ -0,0 +1,114 @@
+from __future__ import unicode_literals
+
+import boto3
+import json
+from botocore.exceptions import ClientError
+from six.moves.email_mime_multipart import MIMEMultipart
+from six.moves.email_mime_text import MIMEText
+
+import sure # noqa
+from nose import tools
+from moto import mock_ses, mock_sns, mock_sqs
+from moto.ses.models import SESFeedback
+
+
+@mock_ses
+def test_enable_disable_ses_sns_communication():
+ conn = boto3.client('ses', region_name='us-east-1')
+ conn.set_identity_notification_topic(
+ Identity='test.com',
+ NotificationType='Bounce',
+ SnsTopic='the-arn'
+ )
+ conn.set_identity_notification_topic(
+ Identity='test.com',
+ NotificationType='Bounce'
+ )
+
+
+def __setup_feedback_env__(ses_conn, sns_conn, sqs_conn, domain, topic, queue, region, expected_msg):
+ """Setup the AWS environment to test the SES SNS Feedback"""
+ # Environment setup
+ # Create SQS queue
+ sqs_conn.create_queue(QueueName=queue)
+ # Create SNS topic
+ create_topic_response = sns_conn.create_topic(Name=topic)
+ topic_arn = create_topic_response["TopicArn"]
+ # Subscribe the SNS topic to the SQS queue
+ sns_conn.subscribe(TopicArn=topic_arn,
+ Protocol="sqs",
+ Endpoint="arn:aws:sqs:%s:123456789012:%s" % (region, queue))
+ # Verify SES domain
+ ses_conn.verify_domain_identity(Domain=domain)
+ # Setup SES notification topic
+ if expected_msg is not None:
+ ses_conn.set_identity_notification_topic(
+ Identity=domain,
+ NotificationType=expected_msg,
+ SnsTopic=topic_arn
+ )
+
+
+def __test_sns_feedback__(addr, expected_msg):
+ region_name = "us-east-1"
+ ses_conn = boto3.client('ses', region_name=region_name)
+ sns_conn = boto3.client('sns', region_name=region_name)
+ sqs_conn = boto3.resource('sqs', region_name=region_name)
+ domain = "example.com"
+ topic = "bounce-arn-feedback"
+ queue = "feedback-test-queue"
+
+ __setup_feedback_env__(ses_conn, sns_conn, sqs_conn, domain, topic, queue, region_name, expected_msg)
+
+ # Send the message
+ kwargs = dict(
+ Source="test@" + domain,
+ Destination={
+ "ToAddresses": [addr + "@" + domain],
+ "CcAddresses": ["test_cc@" + domain],
+ "BccAddresses": ["test_bcc@" + domain],
+ },
+ Message={
+ "Subject": {"Data": "test subject"},
+ "Body": {"Text": {"Data": "test body"}}
+ }
+ )
+ ses_conn.send_email(**kwargs)
+
+ # Wait for messages in the queues
+ queue = sqs_conn.get_queue_by_name(QueueName=queue)
+ messages = queue.receive_messages(MaxNumberOfMessages=1)
+ if expected_msg is not None:
+ msg = messages[0].body
+ msg = json.loads(msg)
+ assert msg["Message"] == SESFeedback.generate_message(expected_msg)
+ else:
+ assert len(messages) == 0
+
+
+@mock_sqs
+@mock_sns
+@mock_ses
+def test_no_sns_feedback():
+ __test_sns_feedback__("test", None)
+
+
+@mock_sqs
+@mock_sns
+@mock_ses
+def test_sns_feedback_bounce():
+ __test_sns_feedback__(SESFeedback.BOUNCE_ADDR, SESFeedback.BOUNCE)
+
+
+@mock_sqs
+@mock_sns
+@mock_ses
+def test_sns_feedback_complaint():
+ __test_sns_feedback__(SESFeedback.COMPLAINT_ADDR, SESFeedback.COMPLAINT)
+
+
+@mock_sqs
+@mock_sns
+@mock_ses
+def test_sns_feedback_delivery():
+ __test_sns_feedback__(SESFeedback.SUCCESS_ADDR, SESFeedback.DELIVERY)