Techdebt: MyPy DynamoDB (#5863)

This commit is contained in:
Bert Blommers 2023-01-22 21:06:41 -01:00 committed by GitHub
parent 90150d30c6
commit 00ad788975
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 1856 additions and 1640 deletions

View File

@ -1,12 +1,17 @@
import re import re
from collections import deque from collections import deque, namedtuple
from collections import namedtuple from typing import Any, Dict, List, Tuple, Deque, Optional, Iterable, Union
from moto.dynamodb.models.dynamo_type import Item
from moto.dynamodb.exceptions import ConditionAttributeIsReservedKeyword from moto.dynamodb.exceptions import ConditionAttributeIsReservedKeyword
from moto.dynamodb.parsing.reserved_keywords import ReservedKeywords from moto.dynamodb.parsing.reserved_keywords import ReservedKeywords
def get_filter_expression(expr, names, values): def get_filter_expression(
expr: Optional[str],
names: Optional[Dict[str, str]],
values: Optional[Dict[str, str]],
) -> Union["Op", "Func"]:
""" """
Parse a filter expression into an Op. Parse a filter expression into an Op.
@ -18,7 +23,7 @@ def get_filter_expression(expr, names, values):
return parser.parse() return parser.parse()
def get_expected(expected): def get_expected(expected: Dict[str, Any]) -> Union["Op", "Func"]:
""" """
Parse a filter expression into an Op. Parse a filter expression into an Op.
@ -26,7 +31,7 @@ def get_expected(expected):
expr = 'Id > 5 AND attribute_exists(test) AND Id BETWEEN 5 AND 6 OR length < 6 AND contains(test, 1) AND 5 IN (4,5, 6) OR (Id < 5 AND 5 > Id)' expr = 'Id > 5 AND attribute_exists(test) AND Id BETWEEN 5 AND 6 OR length < 6 AND contains(test, 1) AND 5 IN (4,5, 6) OR (Id < 5 AND 5 > Id)'
expr = 'Id > 5 AND Subs < 7' expr = 'Id > 5 AND Subs < 7'
""" """
ops = { ops: Dict[str, Any] = {
"EQ": OpEqual, "EQ": OpEqual,
"NE": OpNotEqual, "NE": OpNotEqual,
"LE": OpLessThanOrEqual, "LE": OpLessThanOrEqual,
@ -43,7 +48,7 @@ def get_expected(expected):
} }
# NOTE: Always uses ConditionalOperator=AND # NOTE: Always uses ConditionalOperator=AND
conditions = [] conditions: List[Union["Op", "Func"]] = []
for key, cond in expected.items(): for key, cond in expected.items():
path = AttributePath([key]) path = AttributePath([key])
if "Exists" in cond: if "Exists" in cond:
@ -66,26 +71,28 @@ def get_expected(expected):
for condition in conditions[1:]: for condition in conditions[1:]:
output = ConditionalOp(output, condition) output = ConditionalOp(output, condition)
else: else:
return OpDefault(None, None) return OpDefault(None, None) # type: ignore[arg-type]
return output return output
class Op(object): class Op:
""" """
Base class for a FilterExpression operator Base class for a FilterExpression operator
""" """
OP = "" OP = ""
def __init__(self, lhs, rhs): def __init__(
self, lhs: Union["Func", "Op", "Operand"], rhs: Union["Func", "Op", "Operand"]
):
self.lhs = lhs self.lhs = lhs
self.rhs = rhs self.rhs = rhs
def expr(self, item): def expr(self, item: Optional[Item]) -> bool: # type: ignore
raise NotImplementedError(f"Expr not defined for {type(self)}") raise NotImplementedError(f"Expr not defined for {type(self)}")
def __repr__(self): def __repr__(self) -> str:
return f"({self.lhs} {self.OP} {self.rhs})" return f"({self.lhs} {self.OP} {self.rhs})"
@ -125,7 +132,7 @@ COMPARISON_FUNCS = {
} }
def get_comparison_func(range_comparison): def get_comparison_func(range_comparison: str) -> Any:
return COMPARISON_FUNCS.get(range_comparison) return COMPARISON_FUNCS.get(range_comparison)
@ -136,15 +143,15 @@ class RecursionStopIteration(StopIteration):
class ConditionExpressionParser: class ConditionExpressionParser:
def __init__( def __init__(
self, self,
condition_expression, condition_expression: Optional[str],
expression_attribute_names, expression_attribute_names: Optional[Dict[str, str]],
expression_attribute_values, expression_attribute_values: Optional[Dict[str, str]],
): ):
self.condition_expression = condition_expression self.condition_expression = condition_expression
self.expression_attribute_names = expression_attribute_names self.expression_attribute_names = expression_attribute_names
self.expression_attribute_values = expression_attribute_values self.expression_attribute_values = expression_attribute_values
def parse(self): def parse(self) -> Union[Op, "Func"]:
"""Returns a syntax tree for the expression. """Returns a syntax tree for the expression.
The tree, and all of the nodes in the tree are a tuple of The tree, and all of the nodes in the tree are a tuple of
@ -181,7 +188,7 @@ class ConditionExpressionParser:
""" """
if not self.condition_expression: if not self.condition_expression:
return OpDefault(None, None) return OpDefault(None, None) # type: ignore[arg-type]
nodes = self._lex_condition_expression() nodes = self._lex_condition_expression()
nodes = self._parse_paths(nodes) nodes = self._parse_paths(nodes)
# NOTE: The docs say that functions should be parsed after # NOTE: The docs say that functions should be parsed after
@ -242,12 +249,12 @@ class ConditionExpressionParser:
Node = namedtuple("Node", ["nonterminal", "kind", "text", "value", "children"]) Node = namedtuple("Node", ["nonterminal", "kind", "text", "value", "children"])
@classmethod @classmethod
def raise_exception_if_keyword(cls, attribute): def raise_exception_if_keyword(cls, attribute: str) -> None:
if attribute.upper() in ReservedKeywords.get_reserved_keywords(): if attribute.upper() in ReservedKeywords.get_reserved_keywords():
raise ConditionAttributeIsReservedKeyword(attribute) raise ConditionAttributeIsReservedKeyword(attribute)
def _lex_condition_expression(self): def _lex_condition_expression(self) -> Deque[Node]:
nodes = deque() nodes: Deque[ConditionExpressionParser.Node] = deque()
remaining_expression = self.condition_expression remaining_expression = self.condition_expression
while remaining_expression: while remaining_expression:
node, remaining_expression = self._lex_one_node(remaining_expression) node, remaining_expression = self._lex_one_node(remaining_expression)
@ -256,7 +263,7 @@ class ConditionExpressionParser:
nodes.append(node) nodes.append(node)
return nodes return nodes
def _lex_one_node(self, remaining_expression): def _lex_one_node(self, remaining_expression: str) -> Tuple[Node, str]:
# TODO: Handle indexing like [1] # TODO: Handle indexing like [1]
attribute_regex = r"(:|#)?[A-z0-9\-_]+" attribute_regex = r"(:|#)?[A-z0-9\-_]+"
patterns = [ patterns = [
@ -305,8 +312,8 @@ class ConditionExpressionParser:
return node, remaining_expression return node, remaining_expression
def _parse_paths(self, nodes): def _parse_paths(self, nodes: Deque[Node]) -> Deque[Node]:
output = deque() output: Deque[ConditionExpressionParser.Node] = deque()
while nodes: while nodes:
node = nodes.popleft() node = nodes.popleft()
@ -339,7 +346,7 @@ class ConditionExpressionParser:
output.append(node) output.append(node)
return output return output
def _parse_path_element(self, name): def _parse_path_element(self, name: str) -> Node:
reserved = { reserved = {
"and": self.Nonterminal.AND, "and": self.Nonterminal.AND,
"or": self.Nonterminal.OR, "or": self.Nonterminal.OR,
@ -416,11 +423,11 @@ class ConditionExpressionParser:
children=[], children=[],
) )
def _lookup_expression_attribute_value(self, name): def _lookup_expression_attribute_value(self, name: str) -> str:
return self.expression_attribute_values[name] return self.expression_attribute_values[name] # type: ignore[index]
def _lookup_expression_attribute_name(self, name): def _lookup_expression_attribute_name(self, name: str) -> str:
return self.expression_attribute_names[name] return self.expression_attribute_names[name] # type: ignore[index]
# NOTE: The following constructions are ordered from high precedence to low precedence # NOTE: The following constructions are ordered from high precedence to low precedence
# according to # according to
@ -464,7 +471,7 @@ class ConditionExpressionParser:
# contains (path, operand) # contains (path, operand)
# size (path) # size (path)
def _matches(self, nodes, production): def _matches(self, nodes: Deque[Node], production: List[str]) -> bool:
"""Check if the nodes start with the given production. """Check if the nodes start with the given production.
Parameters Parameters
@ -484,9 +491,9 @@ class ConditionExpressionParser:
return False return False
return True return True
def _apply_comparator(self, nodes): def _apply_comparator(self, nodes: Deque[Node]) -> Deque[Node]:
"""Apply condition := operand comparator operand.""" """Apply condition := operand comparator operand."""
output = deque() output: Deque[ConditionExpressionParser.Node] = deque()
while nodes: while nodes:
if self._matches(nodes, ["*", "COMPARATOR"]): if self._matches(nodes, ["*", "COMPARATOR"]):
@ -511,9 +518,9 @@ class ConditionExpressionParser:
output.append(nodes.popleft()) output.append(nodes.popleft())
return output return output
def _apply_in(self, nodes): def _apply_in(self, nodes: Deque[Node]) -> Deque[Node]:
"""Apply condition := operand IN ( operand , ... ).""" """Apply condition := operand IN ( operand , ... )."""
output = deque() output: Deque[ConditionExpressionParser.Node] = deque()
while nodes: while nodes:
if self._matches(nodes, ["*", "IN"]): if self._matches(nodes, ["*", "IN"]):
self._assert( self._assert(
@ -553,9 +560,9 @@ class ConditionExpressionParser:
output.append(nodes.popleft()) output.append(nodes.popleft())
return output return output
def _apply_between(self, nodes): def _apply_between(self, nodes: Deque[Node]) -> Deque[Node]:
"""Apply condition := operand BETWEEN operand AND operand.""" """Apply condition := operand BETWEEN operand AND operand."""
output = deque() output: Deque[ConditionExpressionParser.Node] = deque()
while nodes: while nodes:
if self._matches(nodes, ["*", "BETWEEN"]): if self._matches(nodes, ["*", "BETWEEN"]):
self._assert( self._assert(
@ -584,9 +591,9 @@ class ConditionExpressionParser:
output.append(nodes.popleft()) output.append(nodes.popleft())
return output return output
def _apply_functions(self, nodes): def _apply_functions(self, nodes: Deque[Node]) -> Deque[Node]:
"""Apply condition := function_name (operand , ...).""" """Apply condition := function_name (operand , ...)."""
output = deque() output: Deque[ConditionExpressionParser.Node] = deque()
either_kind = {self.Kind.PATH, self.Kind.EXPRESSION_ATTRIBUTE_VALUE} either_kind = {self.Kind.PATH, self.Kind.EXPRESSION_ATTRIBUTE_VALUE}
expected_argument_kind_map = { expected_argument_kind_map = {
"attribute_exists": [{self.Kind.PATH}], "attribute_exists": [{self.Kind.PATH}],
@ -656,9 +663,11 @@ class ConditionExpressionParser:
output.append(nodes.popleft()) output.append(nodes.popleft())
return output return output
def _apply_parens_and_booleans(self, nodes, left_paren=None): def _apply_parens_and_booleans(
self, nodes: Deque[Node], left_paren: Any = None
) -> Deque[Node]:
"""Apply condition := ( condition ) and booleans.""" """Apply condition := ( condition ) and booleans."""
output = deque() output: Deque[ConditionExpressionParser.Node] = deque()
while nodes: while nodes:
if self._matches(nodes, ["LEFT_PAREN"]): if self._matches(nodes, ["LEFT_PAREN"]):
parsed = self._apply_parens_and_booleans( parsed = self._apply_parens_and_booleans(
@ -696,7 +705,7 @@ class ConditionExpressionParser:
self._assert(left_paren is None, "Unmatched ( at", list(output)) self._assert(left_paren is None, "Unmatched ( at", list(output))
return self._apply_booleans(output) return self._apply_booleans(output)
def _apply_booleans(self, nodes): def _apply_booleans(self, nodes: Deque[Node]) -> Deque[Node]:
"""Apply and, or, and not constructions.""" """Apply and, or, and not constructions."""
nodes = self._apply_not(nodes) nodes = self._apply_not(nodes)
nodes = self._apply_and(nodes) nodes = self._apply_and(nodes)
@ -710,9 +719,9 @@ class ConditionExpressionParser:
) )
return nodes return nodes
def _apply_not(self, nodes): def _apply_not(self, nodes: Deque[Node]) -> Deque[Node]:
"""Apply condition := NOT condition.""" """Apply condition := NOT condition."""
output = deque() output: Deque[ConditionExpressionParser.Node] = deque()
while nodes: while nodes:
if self._matches(nodes, ["NOT"]): if self._matches(nodes, ["NOT"]):
self._assert( self._assert(
@ -736,9 +745,9 @@ class ConditionExpressionParser:
return output return output
def _apply_and(self, nodes): def _apply_and(self, nodes: Deque[Node]) -> Deque[Node]:
"""Apply condition := condition AND condition.""" """Apply condition := condition AND condition."""
output = deque() output: Deque[ConditionExpressionParser.Node] = deque()
while nodes: while nodes:
if self._matches(nodes, ["*", "AND"]): if self._matches(nodes, ["*", "AND"]):
self._assert( self._assert(
@ -764,9 +773,9 @@ class ConditionExpressionParser:
return output return output
def _apply_or(self, nodes): def _apply_or(self, nodes: Deque[Node]) -> Deque[Node]:
"""Apply condition := condition OR condition.""" """Apply condition := condition OR condition."""
output = deque() output: Deque[ConditionExpressionParser.Node] = deque()
while nodes: while nodes:
if self._matches(nodes, ["*", "OR"]): if self._matches(nodes, ["*", "OR"]):
self._assert( self._assert(
@ -792,7 +801,7 @@ class ConditionExpressionParser:
return output return output
def _make_operand(self, node): def _make_operand(self, node: Node) -> "Operand":
if node.kind == self.Kind.PATH: if node.kind == self.Kind.PATH:
return AttributePath([child.value for child in node.children]) return AttributePath([child.value for child in node.children])
elif node.kind == self.Kind.EXPRESSION_ATTRIBUTE_VALUE: elif node.kind == self.Kind.EXPRESSION_ATTRIBUTE_VALUE:
@ -807,7 +816,7 @@ class ConditionExpressionParser:
else: # pragma: no cover else: # pragma: no cover
raise ValueError(f"Unknown operand: {node}") raise ValueError(f"Unknown operand: {node}")
def _make_op_condition(self, node): def _make_op_condition(self, node: Node) -> Union["Func", Op]:
if node.kind == self.Kind.OR: if node.kind == self.Kind.OR:
lhs, rhs = node.children lhs, rhs = node.children
return OpOr(self._make_op_condition(lhs), self._make_op_condition(rhs)) return OpOr(self._make_op_condition(lhs), self._make_op_condition(rhs))
@ -847,21 +856,21 @@ class ConditionExpressionParser:
else: # pragma: no cover else: # pragma: no cover
raise ValueError(f"Unknown expression node kind {node.kind}") raise ValueError(f"Unknown expression node kind {node.kind}")
def _assert(self, condition, message, nodes): def _assert(self, condition: bool, message: str, nodes: Iterable[Node]) -> None:
if not condition: if not condition:
raise ValueError(message + " " + " ".join([t.text for t in nodes])) raise ValueError(message + " " + " ".join([t.text for t in nodes]))
class Operand(object): class Operand:
def expr(self, item): def expr(self, item: Optional[Item]) -> Any: # type: ignore
raise NotImplementedError raise NotImplementedError
def get_type(self, item): def get_type(self, item: Optional[Item]) -> Optional[str]: # type: ignore
raise NotImplementedError raise NotImplementedError
class AttributePath(Operand): class AttributePath(Operand):
def __init__(self, path): def __init__(self, path: List[Any]):
"""Initialize the AttributePath. """Initialize the AttributePath.
Parameters Parameters
@ -872,7 +881,7 @@ class AttributePath(Operand):
assert len(path) >= 1 assert len(path) >= 1
self.path = path self.path = path
def _get_attr(self, item): def _get_attr(self, item: Optional[Item]) -> Any:
if item is None: if item is None:
return None return None
@ -888,26 +897,26 @@ class AttributePath(Operand):
return attr return attr
def expr(self, item): def expr(self, item: Optional[Item]) -> Any:
attr = self._get_attr(item) attr = self._get_attr(item)
if attr is None: if attr is None:
return None return None
else: else:
return attr.cast_value return attr.cast_value
def get_type(self, item): def get_type(self, item: Optional[Item]) -> Optional[str]:
attr = self._get_attr(item) attr = self._get_attr(item)
if attr is None: if attr is None:
return None return None
else: else:
return attr.type return attr.type
def __repr__(self): def __repr__(self) -> str:
return ".".join(self.path) return ".".join(self.path)
class AttributeValue(Operand): class AttributeValue(Operand):
def __init__(self, value): def __init__(self, value: Dict[str, Any]):
"""Initialize the AttributePath. """Initialize the AttributePath.
Parameters Parameters
@ -919,7 +928,7 @@ class AttributeValue(Operand):
self.type = list(value.keys())[0] self.type = list(value.keys())[0]
self.value = value[self.type] self.value = value[self.type]
def expr(self, item): def expr(self, item: Optional[Item]) -> Any:
# TODO: Reuse DynamoType code # TODO: Reuse DynamoType code
if self.type == "N": if self.type == "N":
try: try:
@ -939,17 +948,17 @@ class AttributeValue(Operand):
return self.value return self.value
return self.value return self.value
def get_type(self, item): def get_type(self, item: Optional[Item]) -> str:
return self.type return self.type
def __repr__(self): def __repr__(self) -> str:
return repr(self.value) return repr(self.value)
class OpDefault(Op): class OpDefault(Op):
OP = "NONE" OP = "NONE"
def expr(self, item): def expr(self, item: Optional[Item]) -> bool:
"""If no condition is specified, always True.""" """If no condition is specified, always True."""
return True return True
@ -957,21 +966,21 @@ class OpDefault(Op):
class OpNot(Op): class OpNot(Op):
OP = "NOT" OP = "NOT"
def __init__(self, lhs): def __init__(self, lhs: Union["Func", Op]):
super().__init__(lhs, None) super().__init__(lhs, None) # type: ignore[arg-type]
def expr(self, item): def expr(self, item: Optional[Item]) -> bool:
lhs = self.lhs.expr(item) lhs = self.lhs.expr(item)
return not lhs return not lhs
def __str__(self): def __str__(self) -> str:
return f"({self.OP} {self.lhs})" return f"({self.OP} {self.lhs})"
class OpAnd(Op): class OpAnd(Op):
OP = "AND" OP = "AND"
def expr(self, item): def expr(self, item: Optional[Item]) -> bool:
lhs = self.lhs.expr(item) lhs = self.lhs.expr(item)
return lhs and self.rhs.expr(item) return lhs and self.rhs.expr(item)
@ -979,7 +988,7 @@ class OpAnd(Op):
class OpLessThan(Op): class OpLessThan(Op):
OP = "<" OP = "<"
def expr(self, item): def expr(self, item: Optional[Item]) -> bool:
lhs = self.lhs.expr(item) lhs = self.lhs.expr(item)
rhs = self.rhs.expr(item) rhs = self.rhs.expr(item)
# In python3 None is not a valid comparator when using < or > so must be handled specially # In python3 None is not a valid comparator when using < or > so must be handled specially
@ -992,7 +1001,7 @@ class OpLessThan(Op):
class OpGreaterThan(Op): class OpGreaterThan(Op):
OP = ">" OP = ">"
def expr(self, item): def expr(self, item: Optional[Item]) -> bool:
lhs = self.lhs.expr(item) lhs = self.lhs.expr(item)
rhs = self.rhs.expr(item) rhs = self.rhs.expr(item)
# In python3 None is not a valid comparator when using < or > so must be handled specially # In python3 None is not a valid comparator when using < or > so must be handled specially
@ -1005,7 +1014,7 @@ class OpGreaterThan(Op):
class OpEqual(Op): class OpEqual(Op):
OP = "=" OP = "="
def expr(self, item): def expr(self, item: Optional[Item]) -> bool:
lhs = self.lhs.expr(item) lhs = self.lhs.expr(item)
rhs = self.rhs.expr(item) rhs = self.rhs.expr(item)
return lhs == rhs return lhs == rhs
@ -1014,7 +1023,7 @@ class OpEqual(Op):
class OpNotEqual(Op): class OpNotEqual(Op):
OP = "<>" OP = "<>"
def expr(self, item): def expr(self, item: Optional[Item]) -> bool:
lhs = self.lhs.expr(item) lhs = self.lhs.expr(item)
rhs = self.rhs.expr(item) rhs = self.rhs.expr(item)
return lhs != rhs return lhs != rhs
@ -1023,7 +1032,7 @@ class OpNotEqual(Op):
class OpLessThanOrEqual(Op): class OpLessThanOrEqual(Op):
OP = "<=" OP = "<="
def expr(self, item): def expr(self, item: Optional[Item]) -> bool:
lhs = self.lhs.expr(item) lhs = self.lhs.expr(item)
rhs = self.rhs.expr(item) rhs = self.rhs.expr(item)
# In python3 None is not a valid comparator when using < or > so must be handled specially # In python3 None is not a valid comparator when using < or > so must be handled specially
@ -1036,7 +1045,7 @@ class OpLessThanOrEqual(Op):
class OpGreaterThanOrEqual(Op): class OpGreaterThanOrEqual(Op):
OP = ">=" OP = ">="
def expr(self, item): def expr(self, item: Optional[Item]) -> bool:
lhs = self.lhs.expr(item) lhs = self.lhs.expr(item)
rhs = self.rhs.expr(item) rhs = self.rhs.expr(item)
# In python3 None is not a valid comparator when using < or > so must be handled specially # In python3 None is not a valid comparator when using < or > so must be handled specially
@ -1049,64 +1058,64 @@ class OpGreaterThanOrEqual(Op):
class OpOr(Op): class OpOr(Op):
OP = "OR" OP = "OR"
def expr(self, item): def expr(self, item: Optional[Item]) -> bool:
lhs = self.lhs.expr(item) lhs = self.lhs.expr(item)
return lhs or self.rhs.expr(item) return lhs or self.rhs.expr(item)
class Func(object): class Func:
""" """
Base class for a FilterExpression function Base class for a FilterExpression function
""" """
FUNC = "Unknown" FUNC = "Unknown"
def __init__(self, *arguments): def __init__(self, *arguments: Any):
self.arguments = arguments self.arguments = arguments
def expr(self, item): def expr(self, item: Optional[Item]) -> bool:
raise NotImplementedError raise NotImplementedError
def __repr__(self): def __repr__(self) -> str:
return f"{self.FUNC}({' '.join([repr(arg) for arg in self.arguments])})" return f"{self.FUNC}({' '.join([repr(arg) for arg in self.arguments])})"
class FuncAttrExists(Func): class FuncAttrExists(Func):
FUNC = "attribute_exists" FUNC = "attribute_exists"
def __init__(self, attribute): def __init__(self, attribute: Operand):
self.attr = attribute self.attr = attribute
super().__init__(attribute) super().__init__(attribute)
def expr(self, item): def expr(self, item: Optional[Item]) -> bool:
return self.attr.get_type(item) is not None return self.attr.get_type(item) is not None
def FuncAttrNotExists(attribute): def FuncAttrNotExists(attribute: Operand) -> Any:
return OpNot(FuncAttrExists(attribute)) return OpNot(FuncAttrExists(attribute))
class FuncAttrType(Func): class FuncAttrType(Func):
FUNC = "attribute_type" FUNC = "attribute_type"
def __init__(self, attribute, _type): def __init__(self, attribute: Operand, _type: Func):
self.attr = attribute self.attr = attribute
self.type = _type self.type = _type
super().__init__(attribute, _type) super().__init__(attribute, _type)
def expr(self, item): def expr(self, item: Optional[Item]) -> bool:
return self.attr.get_type(item) == self.type.expr(item) return self.attr.get_type(item) == self.type.expr(item) # type: ignore[comparison-overlap]
class FuncBeginsWith(Func): class FuncBeginsWith(Func):
FUNC = "begins_with" FUNC = "begins_with"
def __init__(self, attribute, substr): def __init__(self, attribute: Operand, substr: Operand):
self.attr = attribute self.attr = attribute
self.substr = substr self.substr = substr
super().__init__(attribute, substr) super().__init__(attribute, substr)
def expr(self, item): def expr(self, item: Optional[Item]) -> bool:
if self.attr.get_type(item) != "S": if self.attr.get_type(item) != "S":
return False return False
if self.substr.get_type(item) != "S": if self.substr.get_type(item) != "S":
@ -1117,12 +1126,12 @@ class FuncBeginsWith(Func):
class FuncContains(Func): class FuncContains(Func):
FUNC = "contains" FUNC = "contains"
def __init__(self, attribute, operand): def __init__(self, attribute: Operand, operand: Operand):
self.attr = attribute self.attr = attribute
self.operand = operand self.operand = operand
super().__init__(attribute, operand) super().__init__(attribute, operand)
def expr(self, item): def expr(self, item: Optional[Item]) -> bool:
if self.attr.get_type(item) in ("S", "SS", "NS", "BS", "L"): if self.attr.get_type(item) in ("S", "SS", "NS", "BS", "L"):
try: try:
return self.operand.expr(item) in self.attr.expr(item) return self.operand.expr(item) in self.attr.expr(item)
@ -1131,18 +1140,18 @@ class FuncContains(Func):
return False return False
def FuncNotContains(attribute, operand): def FuncNotContains(attribute: Operand, operand: Operand) -> OpNot:
return OpNot(FuncContains(attribute, operand)) return OpNot(FuncContains(attribute, operand))
class FuncSize(Func): class FuncSize(Func):
FUNC = "size" FUNC = "size"
def __init__(self, attribute): def __init__(self, attribute: Operand):
self.attr = attribute self.attr = attribute
super().__init__(attribute) super().__init__(attribute)
def expr(self, item): def expr(self, item: Optional[Item]) -> int: # type: ignore[override]
if self.attr.get_type(item) is None: if self.attr.get_type(item) is None:
raise ValueError(f"Invalid attribute name {self.attr}") raise ValueError(f"Invalid attribute name {self.attr}")
@ -1154,13 +1163,13 @@ class FuncSize(Func):
class FuncBetween(Func): class FuncBetween(Func):
FUNC = "BETWEEN" FUNC = "BETWEEN"
def __init__(self, attribute, start, end): def __init__(self, attribute: Operand, start: Operand, end: Operand):
self.attr = attribute self.attr = attribute
self.start = start self.start = start
self.end = end self.end = end
super().__init__(attribute, start, end) super().__init__(attribute, start, end)
def expr(self, item): def expr(self, item: Optional[Item]) -> bool:
# In python3 None is not a valid comparator when using < or > so must be handled specially # In python3 None is not a valid comparator when using < or > so must be handled specially
start = self.start.expr(item) start = self.start.expr(item)
attr = self.attr.expr(item) attr = self.attr.expr(item)
@ -1183,12 +1192,12 @@ class FuncBetween(Func):
class FuncIn(Func): class FuncIn(Func):
FUNC = "IN" FUNC = "IN"
def __init__(self, attribute, *possible_values): def __init__(self, attribute: Operand, *possible_values: Any):
self.attr = attribute self.attr = attribute
self.possible_values = possible_values self.possible_values = possible_values
super().__init__(attribute, *possible_values) super().__init__(attribute, *possible_values)
def expr(self, item): def expr(self, item: Optional[Item]) -> bool:
for possible_value in self.possible_values: for possible_value in self.possible_values:
if self.attr.expr(item) == possible_value.expr(item): if self.attr.expr(item) == possible_value.expr(item):
return True return True
@ -1205,7 +1214,7 @@ COMPARATOR_CLASS = {
"<>": OpNotEqual, "<>": OpNotEqual,
} }
FUNC_CLASS = { FUNC_CLASS: Dict[str, Any] = {
"attribute_exists": FuncAttrExists, "attribute_exists": FuncAttrExists,
"attribute_not_exists": FuncAttrNotExists, "attribute_not_exists": FuncAttrNotExists,
"attribute_type": FuncAttrType, "attribute_type": FuncAttrType,

View File

@ -1,4 +1,5 @@
import json import json
from typing import Any, List, Optional
from moto.core.exceptions import JsonRESTError from moto.core.exceptions import JsonRESTError
from moto.dynamodb.limits import HASH_KEY_MAX_LENGTH, RANGE_KEY_MAX_LENGTH from moto.dynamodb.limits import HASH_KEY_MAX_LENGTH, RANGE_KEY_MAX_LENGTH
@ -10,7 +11,7 @@ class DynamodbException(JsonRESTError):
class MockValidationException(DynamodbException): class MockValidationException(DynamodbException):
error_type = "com.amazonaws.dynamodb.v20111205#ValidationException" error_type = "com.amazonaws.dynamodb.v20111205#ValidationException"
def __init__(self, message): def __init__(self, message: str):
super().__init__(MockValidationException.error_type, message=message) super().__init__(MockValidationException.error_type, message=message)
self.exception_msg = message self.exception_msg = message
@ -24,14 +25,14 @@ class InvalidUpdateExpressionInvalidDocumentPath(MockValidationException):
"The document path provided in the update expression is invalid for update" "The document path provided in the update expression is invalid for update"
) )
def __init__(self): def __init__(self) -> None:
super().__init__(self.invalid_update_expression_msg) super().__init__(self.invalid_update_expression_msg)
class InvalidUpdateExpression(MockValidationException): class InvalidUpdateExpression(MockValidationException):
invalid_update_expr_msg = "Invalid UpdateExpression: {update_expression_error}" invalid_update_expr_msg = "Invalid UpdateExpression: {update_expression_error}"
def __init__(self, update_expression_error): def __init__(self, update_expression_error: str):
self.update_expression_error = update_expression_error self.update_expression_error = update_expression_error
super().__init__( super().__init__(
self.invalid_update_expr_msg.format( self.invalid_update_expr_msg.format(
@ -45,7 +46,7 @@ class InvalidConditionExpression(MockValidationException):
"Invalid ConditionExpression: {condition_expression_error}" "Invalid ConditionExpression: {condition_expression_error}"
) )
def __init__(self, condition_expression_error): def __init__(self, condition_expression_error: str):
self.condition_expression_error = condition_expression_error self.condition_expression_error = condition_expression_error
super().__init__( super().__init__(
self.invalid_condition_expr_msg.format( self.invalid_condition_expr_msg.format(
@ -59,7 +60,7 @@ class ConditionAttributeIsReservedKeyword(InvalidConditionExpression):
"Attribute name is a reserved keyword; reserved keyword: {keyword}" "Attribute name is a reserved keyword; reserved keyword: {keyword}"
) )
def __init__(self, keyword): def __init__(self, keyword: str):
self.keyword = keyword self.keyword = keyword
super().__init__(self.attribute_is_keyword_msg.format(keyword=keyword)) super().__init__(self.attribute_is_keyword_msg.format(keyword=keyword))
@ -69,7 +70,7 @@ class AttributeDoesNotExist(MockValidationException):
"The provided expression refers to an attribute that does not exist in the item" "The provided expression refers to an attribute that does not exist in the item"
) )
def __init__(self): def __init__(self) -> None:
super().__init__(self.attr_does_not_exist_msg) super().__init__(self.attr_does_not_exist_msg)
@ -78,14 +79,14 @@ class ProvidedKeyDoesNotExist(MockValidationException):
"The provided key element does not match the schema" "The provided key element does not match the schema"
) )
def __init__(self): def __init__(self) -> None:
super().__init__(self.provided_key_does_not_exist_msg) super().__init__(self.provided_key_does_not_exist_msg)
class ExpressionAttributeNameNotDefined(InvalidUpdateExpression): class ExpressionAttributeNameNotDefined(InvalidUpdateExpression):
name_not_defined_msg = "An expression attribute name used in the document path is not defined; attribute name: {n}" name_not_defined_msg = "An expression attribute name used in the document path is not defined; attribute name: {n}"
def __init__(self, attribute_name): def __init__(self, attribute_name: str):
self.not_defined_attribute_name = attribute_name self.not_defined_attribute_name = attribute_name
super().__init__(self.name_not_defined_msg.format(n=attribute_name)) super().__init__(self.name_not_defined_msg.format(n=attribute_name))
@ -95,7 +96,7 @@ class AttributeIsReservedKeyword(InvalidUpdateExpression):
"Attribute name is a reserved keyword; reserved keyword: {keyword}" "Attribute name is a reserved keyword; reserved keyword: {keyword}"
) )
def __init__(self, keyword): def __init__(self, keyword: str):
self.keyword = keyword self.keyword = keyword
super().__init__(self.attribute_is_keyword_msg.format(keyword=keyword)) super().__init__(self.attribute_is_keyword_msg.format(keyword=keyword))
@ -103,7 +104,7 @@ class AttributeIsReservedKeyword(InvalidUpdateExpression):
class ExpressionAttributeValueNotDefined(InvalidUpdateExpression): class ExpressionAttributeValueNotDefined(InvalidUpdateExpression):
attr_value_not_defined_msg = "An expression attribute value used in expression is not defined; attribute value: {attribute_value}" attr_value_not_defined_msg = "An expression attribute value used in expression is not defined; attribute value: {attribute_value}"
def __init__(self, attribute_value): def __init__(self, attribute_value: str):
self.attribute_value = attribute_value self.attribute_value = attribute_value
super().__init__( super().__init__(
self.attr_value_not_defined_msg.format(attribute_value=attribute_value) self.attr_value_not_defined_msg.format(attribute_value=attribute_value)
@ -113,7 +114,7 @@ class ExpressionAttributeValueNotDefined(InvalidUpdateExpression):
class UpdateExprSyntaxError(InvalidUpdateExpression): class UpdateExprSyntaxError(InvalidUpdateExpression):
update_expr_syntax_error_msg = "Syntax error; {error_detail}" update_expr_syntax_error_msg = "Syntax error; {error_detail}"
def __init__(self, error_detail): def __init__(self, error_detail: str):
self.error_detail = error_detail self.error_detail = error_detail
super().__init__( super().__init__(
self.update_expr_syntax_error_msg.format(error_detail=error_detail) self.update_expr_syntax_error_msg.format(error_detail=error_detail)
@ -123,7 +124,7 @@ class UpdateExprSyntaxError(InvalidUpdateExpression):
class InvalidTokenException(UpdateExprSyntaxError): class InvalidTokenException(UpdateExprSyntaxError):
token_detail_msg = 'token: "{token}", near: "{near}"' token_detail_msg = 'token: "{token}", near: "{near}"'
def __init__(self, token, near): def __init__(self, token: str, near: str):
self.token = token self.token = token
self.near = near self.near = near
super().__init__(self.token_detail_msg.format(token=token, near=near)) super().__init__(self.token_detail_msg.format(token=token, near=near))
@ -134,7 +135,7 @@ class InvalidExpressionAttributeNameKey(MockValidationException):
'ExpressionAttributeNames contains invalid key: Syntax error; key: "{key}"' 'ExpressionAttributeNames contains invalid key: Syntax error; key: "{key}"'
) )
def __init__(self, key): def __init__(self, key: str):
self.key = key self.key = key
super().__init__(self.invalid_expr_attr_name_msg.format(key=key)) super().__init__(self.invalid_expr_attr_name_msg.format(key=key))
@ -142,7 +143,7 @@ class InvalidExpressionAttributeNameKey(MockValidationException):
class ItemSizeTooLarge(MockValidationException): class ItemSizeTooLarge(MockValidationException):
item_size_too_large_msg = "Item size has exceeded the maximum allowed size" item_size_too_large_msg = "Item size has exceeded the maximum allowed size"
def __init__(self): def __init__(self) -> None:
super().__init__(self.item_size_too_large_msg) super().__init__(self.item_size_too_large_msg)
@ -151,7 +152,7 @@ class ItemSizeToUpdateTooLarge(MockValidationException):
"Item size to update has exceeded the maximum allowed size" "Item size to update has exceeded the maximum allowed size"
) )
def __init__(self): def __init__(self) -> None:
super().__init__(self.item_size_to_update_too_large_msg) super().__init__(self.item_size_to_update_too_large_msg)
@ -159,21 +160,21 @@ class HashKeyTooLong(MockValidationException):
# deliberately no space between of and {lim} # deliberately no space between of and {lim}
key_too_large_msg = f"One or more parameter values were invalid: Size of hashkey has exceeded the maximum size limit of{HASH_KEY_MAX_LENGTH} bytes" key_too_large_msg = f"One or more parameter values were invalid: Size of hashkey has exceeded the maximum size limit of{HASH_KEY_MAX_LENGTH} bytes"
def __init__(self): def __init__(self) -> None:
super().__init__(self.key_too_large_msg) super().__init__(self.key_too_large_msg)
class RangeKeyTooLong(MockValidationException): class RangeKeyTooLong(MockValidationException):
key_too_large_msg = f"One or more parameter values were invalid: Aggregated size of all range keys has exceeded the size limit of {RANGE_KEY_MAX_LENGTH} bytes" key_too_large_msg = f"One or more parameter values were invalid: Aggregated size of all range keys has exceeded the size limit of {RANGE_KEY_MAX_LENGTH} bytes"
def __init__(self): def __init__(self) -> None:
super().__init__(self.key_too_large_msg) super().__init__(self.key_too_large_msg)
class IncorrectOperandType(InvalidUpdateExpression): class IncorrectOperandType(InvalidUpdateExpression):
inv_operand_msg = "Incorrect operand type for operator or function; operator or function: {f}, operand type: {t}" inv_operand_msg = "Incorrect operand type for operator or function; operator or function: {f}, operand type: {t}"
def __init__(self, operator_or_function, operand_type): def __init__(self, operator_or_function: str, operand_type: str):
self.operator_or_function = operator_or_function self.operator_or_function = operator_or_function
self.operand_type = operand_type self.operand_type = operand_type
super().__init__( super().__init__(
@ -184,14 +185,14 @@ class IncorrectOperandType(InvalidUpdateExpression):
class IncorrectDataType(MockValidationException): class IncorrectDataType(MockValidationException):
inc_data_type_msg = "An operand in the update expression has an incorrect data type" inc_data_type_msg = "An operand in the update expression has an incorrect data type"
def __init__(self): def __init__(self) -> None:
super().__init__(self.inc_data_type_msg) super().__init__(self.inc_data_type_msg)
class ConditionalCheckFailed(DynamodbException): class ConditionalCheckFailed(DynamodbException):
error_type = "com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException" error_type = "com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException"
def __init__(self, msg=None): def __init__(self, msg: Optional[str] = None):
super().__init__( super().__init__(
ConditionalCheckFailed.error_type, msg or "The conditional request failed" ConditionalCheckFailed.error_type, msg or "The conditional request failed"
) )
@ -201,7 +202,7 @@ class TransactionCanceledException(DynamodbException):
cancel_reason_msg = "Transaction cancelled, please refer cancellation reasons for specific reasons [{}]" cancel_reason_msg = "Transaction cancelled, please refer cancellation reasons for specific reasons [{}]"
error_type = "com.amazonaws.dynamodb.v20120810#TransactionCanceledException" error_type = "com.amazonaws.dynamodb.v20120810#TransactionCanceledException"
def __init__(self, errors): def __init__(self, errors: List[Any]):
msg = self.cancel_reason_msg.format( msg = self.cancel_reason_msg.format(
", ".join([str(code) for code, _, _ in errors]) ", ".join([str(code) for code, _, _ in errors])
) )
@ -224,7 +225,7 @@ class TransactionCanceledException(DynamodbException):
class MultipleTransactionsException(MockValidationException): class MultipleTransactionsException(MockValidationException):
msg = "Transaction request cannot include multiple operations on one item" msg = "Transaction request cannot include multiple operations on one item"
def __init__(self): def __init__(self) -> None:
super().__init__(self.msg) super().__init__(self.msg)
@ -234,7 +235,7 @@ class TooManyTransactionsException(MockValidationException):
"Member must have length less than or equal to 100." "Member must have length less than or equal to 100."
) )
def __init__(self): def __init__(self) -> None:
super().__init__(self.msg) super().__init__(self.msg)
@ -243,26 +244,28 @@ class EmptyKeyAttributeException(MockValidationException):
# AWS has a different message for empty index keys # AWS has a different message for empty index keys
empty_index_msg = "One or more parameter values are not valid. The update expression attempted to update a secondary index key to a value that is not supported. The AttributeValue for a key attribute cannot contain an empty string value." empty_index_msg = "One or more parameter values are not valid. The update expression attempted to update a secondary index key to a value that is not supported. The AttributeValue for a key attribute cannot contain an empty string value."
def __init__(self, key_in_index=False): def __init__(self, key_in_index: bool = False):
super().__init__(self.empty_index_msg if key_in_index else self.empty_str_msg) super().__init__(self.empty_index_msg if key_in_index else self.empty_str_msg)
class UpdateHashRangeKeyException(MockValidationException): class UpdateHashRangeKeyException(MockValidationException):
msg = "One or more parameter values were invalid: Cannot update attribute {}. This attribute is part of the key" msg = "One or more parameter values were invalid: Cannot update attribute {}. This attribute is part of the key"
def __init__(self, key_name): def __init__(self, key_name: str):
super().__init__(self.msg.format(key_name)) super().__init__(self.msg.format(key_name))
class InvalidAttributeTypeError(MockValidationException): class InvalidAttributeTypeError(MockValidationException):
msg = "One or more parameter values were invalid: Type mismatch for key {} expected: {} actual: {}" msg = "One or more parameter values were invalid: Type mismatch for key {} expected: {} actual: {}"
def __init__(self, name, expected_type, actual_type): def __init__(
self, name: Optional[str], expected_type: Optional[str], actual_type: str
):
super().__init__(self.msg.format(name, expected_type, actual_type)) super().__init__(self.msg.format(name, expected_type, actual_type))
class DuplicateUpdateExpression(InvalidUpdateExpression): class DuplicateUpdateExpression(InvalidUpdateExpression):
def __init__(self, names): def __init__(self, names: List[str]):
super().__init__( super().__init__(
f"Two document paths overlap with each other; must remove or rewrite one of these paths; path one: [{names[0]}], path two: [{names[1]}]" f"Two document paths overlap with each other; must remove or rewrite one of these paths; path one: [{names[0]}], path two: [{names[1]}]"
) )
@ -271,54 +274,54 @@ class DuplicateUpdateExpression(InvalidUpdateExpression):
class TooManyAddClauses(InvalidUpdateExpression): class TooManyAddClauses(InvalidUpdateExpression):
msg = 'The "ADD" section can only be used once in an update expression;' msg = 'The "ADD" section can only be used once in an update expression;'
def __init__(self): def __init__(self) -> None:
super().__init__(self.msg) super().__init__(self.msg)
class ResourceNotFoundException(JsonRESTError): class ResourceNotFoundException(JsonRESTError):
def __init__(self, msg=None): def __init__(self, msg: Optional[str] = None):
err = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" err = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException"
super().__init__(err, msg or "Requested resource not found") super().__init__(err, msg or "Requested resource not found")
class TableNotFoundException(JsonRESTError): class TableNotFoundException(JsonRESTError):
def __init__(self, name): def __init__(self, name: str):
err = "com.amazonaws.dynamodb.v20111205#TableNotFoundException" err = "com.amazonaws.dynamodb.v20111205#TableNotFoundException"
super().__init__(err, f"Table not found: {name}") super().__init__(err, f"Table not found: {name}")
class SourceTableNotFoundException(JsonRESTError): class SourceTableNotFoundException(JsonRESTError):
def __init__(self, source_table_name): def __init__(self, source_table_name: str):
er = "com.amazonaws.dynamodb.v20111205#SourceTableNotFoundException" er = "com.amazonaws.dynamodb.v20111205#SourceTableNotFoundException"
super().__init__(er, f"Source table not found: {source_table_name}") super().__init__(er, f"Source table not found: {source_table_name}")
class BackupNotFoundException(JsonRESTError): class BackupNotFoundException(JsonRESTError):
def __init__(self, backup_arn): def __init__(self, backup_arn: str):
er = "com.amazonaws.dynamodb.v20111205#BackupNotFoundException" er = "com.amazonaws.dynamodb.v20111205#BackupNotFoundException"
super().__init__(er, f"Backup not found: {backup_arn}") super().__init__(er, f"Backup not found: {backup_arn}")
class TableAlreadyExistsException(JsonRESTError): class TableAlreadyExistsException(JsonRESTError):
def __init__(self, target_table_name): def __init__(self, target_table_name: str):
er = "com.amazonaws.dynamodb.v20111205#TableAlreadyExistsException" er = "com.amazonaws.dynamodb.v20111205#TableAlreadyExistsException"
super().__init__(er, f"Table already exists: {target_table_name}") super().__init__(er, f"Table already exists: {target_table_name}")
class ResourceInUseException(JsonRESTError): class ResourceInUseException(JsonRESTError):
def __init__(self): def __init__(self) -> None:
er = "com.amazonaws.dynamodb.v20111205#ResourceInUseException" er = "com.amazonaws.dynamodb.v20111205#ResourceInUseException"
super().__init__(er, "Resource in use") super().__init__(er, "Resource in use")
class StreamAlreadyEnabledException(JsonRESTError): class StreamAlreadyEnabledException(JsonRESTError):
def __init__(self): def __init__(self) -> None:
er = "com.amazonaws.dynamodb.v20111205#ResourceInUseException" er = "com.amazonaws.dynamodb.v20111205#ResourceInUseException"
super().__init__(er, "Cannot enable stream") super().__init__(er, "Cannot enable stream")
class InvalidConversion(JsonRESTError): class InvalidConversion(JsonRESTError):
def __init__(self): def __init__(self) -> None:
er = "SerializationException" er = "SerializationException"
super().__init__(er, "NUMBER_VALUE cannot be converted to String") super().__init__(er, "NUMBER_VALUE cannot be converted to String")
@ -328,10 +331,10 @@ class TransactWriteSingleOpException(MockValidationException):
"TransactItems can only contain one of Check, Put, Update or Delete" "TransactItems can only contain one of Check, Put, Update or Delete"
) )
def __init__(self): def __init__(self) -> None:
super().__init__(self.there_can_be_only_one) super().__init__(self.there_can_be_only_one)
class SerializationException(DynamodbException): class SerializationException(DynamodbException):
def __init__(self, msg): def __init__(self, msg: str):
super().__init__(error_type="SerializationException", message=msg) super().__init__(error_type="SerializationException", message=msg)

File diff suppressed because it is too large Load Diff

View File

@ -1,9 +1,16 @@
from moto.dynamodb.comparisons import get_comparison_func import decimal
from moto.dynamodb.exceptions import IncorrectDataType from typing import Any, Dict, List, Union, Optional
from moto.core import BaseModel
from moto.dynamodb.exceptions import (
IncorrectDataType,
EmptyKeyAttributeException,
ItemSizeTooLarge,
)
from moto.dynamodb.models.utilities import bytesize from moto.dynamodb.models.utilities import bytesize
class DDBType(object): class DDBType:
""" """
Official documentation at https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_AttributeValue.html Official documentation at https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_AttributeValue.html
""" """
@ -20,7 +27,7 @@ class DDBType(object):
NULL = "NULL" NULL = "NULL"
class DDBTypeConversion(object): class DDBTypeConversion:
_human_type_mapping = { _human_type_mapping = {
val: key.replace("_", " ") val: key.replace("_", " ")
for key, val in DDBType.__dict__.items() for key, val in DDBType.__dict__.items()
@ -28,13 +35,13 @@ class DDBTypeConversion(object):
} }
@classmethod @classmethod
def get_human_type(cls, abbreviated_type): def get_human_type(cls, abbreviated_type: str) -> str:
""" """
Args: Args:
abbreviated_type(str): An attribute of DDBType abbreviated_type(str): An attribute of DDBType
Returns: Returns:
str: The human readable form of the DDBType. str: The human-readable form of the DDBType.
""" """
return cls._human_type_mapping.get(abbreviated_type, abbreviated_type) return cls._human_type_mapping.get(abbreviated_type, abbreviated_type)
@ -44,19 +51,19 @@ class DynamoType(object):
http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DataModel.html#DataModelDataTypes http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DataModel.html#DataModelDataTypes
""" """
def __init__(self, type_as_dict): def __init__(self, type_as_dict: Union["DynamoType", Dict[str, Any]]):
if type(type_as_dict) == DynamoType: if type(type_as_dict) == DynamoType:
self.type = type_as_dict.type self.type: str = type_as_dict.type
self.value = type_as_dict.value self.value: Any = type_as_dict.value
else: else:
self.type = list(type_as_dict)[0] self.type = list(type_as_dict)[0] # type: ignore[arg-type]
self.value = list(type_as_dict.values())[0] self.value = list(type_as_dict.values())[0] # type: ignore[union-attr]
if self.is_list(): if self.is_list():
self.value = [DynamoType(val) for val in self.value] self.value = [DynamoType(val) for val in self.value]
elif self.is_map(): elif self.is_map():
self.value = dict((k, DynamoType(v)) for k, v in self.value.items()) self.value = dict((k, DynamoType(v)) for k, v in self.value.items())
def filter(self, projection_expressions): def filter(self, projection_expressions: str) -> None:
nested_projections = [ nested_projections = [
expr[0 : expr.index(".")] for expr in projection_expressions if "." in expr expr[0 : expr.index(".")] for expr in projection_expressions if "." in expr
] ]
@ -78,31 +85,31 @@ class DynamoType(object):
for expr in expressions_to_delete: for expr in expressions_to_delete:
self.value.pop(expr) self.value.pop(expr)
def __hash__(self): def __hash__(self) -> int:
return hash((self.type, self.value)) return hash((self.type, self.value))
def __eq__(self, other): def __eq__(self, other: "DynamoType") -> bool: # type: ignore[override]
return self.type == other.type and self.value == other.value return self.type == other.type and self.value == other.value
def __ne__(self, other): def __ne__(self, other: "DynamoType") -> bool: # type: ignore[override]
return self.type != other.type or self.value != other.value return self.type != other.type or self.value != other.value
def __lt__(self, other): def __lt__(self, other: "DynamoType") -> bool:
return self.cast_value < other.cast_value return self.cast_value < other.cast_value
def __le__(self, other): def __le__(self, other: "DynamoType") -> bool:
return self.cast_value <= other.cast_value return self.cast_value <= other.cast_value
def __gt__(self, other): def __gt__(self, other: "DynamoType") -> bool:
return self.cast_value > other.cast_value return self.cast_value > other.cast_value
def __ge__(self, other): def __ge__(self, other: "DynamoType") -> bool:
return self.cast_value >= other.cast_value return self.cast_value >= other.cast_value
def __repr__(self): def __repr__(self) -> str:
return f"DynamoType: {self.to_json()}" return f"DynamoType: {self.to_json()}"
def __add__(self, other): def __add__(self, other: "DynamoType") -> "DynamoType":
if self.type != other.type: if self.type != other.type:
raise TypeError("Different types of operandi is not allowed.") raise TypeError("Different types of operandi is not allowed.")
if self.is_number(): if self.is_number():
@ -112,7 +119,7 @@ class DynamoType(object):
else: else:
raise IncorrectDataType() raise IncorrectDataType()
def __sub__(self, other): def __sub__(self, other: "DynamoType") -> "DynamoType":
if self.type != other.type: if self.type != other.type:
raise TypeError("Different types of operandi is not allowed.") raise TypeError("Different types of operandi is not allowed.")
if self.type == DDBType.NUMBER: if self.type == DDBType.NUMBER:
@ -122,7 +129,7 @@ class DynamoType(object):
else: else:
raise TypeError("Sum only supported for Numbers.") raise TypeError("Sum only supported for Numbers.")
def __getitem__(self, item): def __getitem__(self, item: "DynamoType") -> "DynamoType":
if isinstance(item, str): if isinstance(item, str):
# If our DynamoType is a map it should be subscriptable with a key # If our DynamoType is a map it should be subscriptable with a key
if self.type == DDBType.MAP: if self.type == DDBType.MAP:
@ -135,7 +142,7 @@ class DynamoType(object):
f"This DynamoType {self.type} is not subscriptable by a {type(item)}" f"This DynamoType {self.type} is not subscriptable by a {type(item)}"
) )
def __setitem__(self, key, value): def __setitem__(self, key: Any, value: Any) -> None:
if isinstance(key, int): if isinstance(key, int):
if self.is_list(): if self.is_list():
if key >= len(self.value): if key >= len(self.value):
@ -150,7 +157,7 @@ class DynamoType(object):
raise NotImplementedError(f"No set_item for {type(key)}") raise NotImplementedError(f"No set_item for {type(key)}")
@property @property
def cast_value(self): def cast_value(self) -> Any: # type: ignore[misc]
if self.is_number(): if self.is_number():
try: try:
return int(self.value) return int(self.value)
@ -166,7 +173,7 @@ class DynamoType(object):
else: else:
return self.value return self.value
def child_attr(self, key): def child_attr(self, key: Union[int, str, None]) -> Optional["DynamoType"]:
""" """
Get Map or List children by key. str for Map, int for List. Get Map or List children by key. str for Map, int for List.
@ -183,7 +190,7 @@ class DynamoType(object):
return None return None
def size(self): def size(self) -> int:
if self.is_number(): if self.is_number():
value_size = len(str(self.value)) value_size = len(str(self.value))
elif self.is_set(): elif self.is_set():
@ -201,34 +208,204 @@ class DynamoType(object):
value_size = bytesize(self.value) value_size = bytesize(self.value)
return value_size return value_size
def to_json(self): def to_json(self) -> Dict[str, Any]:
return {self.type: self.value} return {self.type: self.value}
def compare(self, range_comparison, range_objs): def compare(self, range_comparison: str, range_objs: List[Any]) -> bool:
""" """
Compares this type against comparison filters Compares this type against comparison filters
""" """
from moto.dynamodb.comparisons import get_comparison_func
range_values = [obj.cast_value for obj in range_objs] range_values = [obj.cast_value for obj in range_objs]
comparison_func = get_comparison_func(range_comparison) comparison_func = get_comparison_func(range_comparison)
return comparison_func(self.cast_value, *range_values) return comparison_func(self.cast_value, *range_values)
def is_number(self): def is_number(self) -> bool:
return self.type == DDBType.NUMBER return self.type == DDBType.NUMBER
def is_set(self): def is_set(self) -> bool:
return self.type in (DDBType.STRING_SET, DDBType.NUMBER_SET, DDBType.BINARY_SET) return self.type in (DDBType.STRING_SET, DDBType.NUMBER_SET, DDBType.BINARY_SET)
def is_list(self): def is_list(self) -> bool:
return self.type == DDBType.LIST return self.type == DDBType.LIST
def is_map(self): def is_map(self) -> bool:
return self.type == DDBType.MAP return self.type == DDBType.MAP
def same_type(self, other): def same_type(self, other: "DynamoType") -> bool:
return self.type == other.type return self.type == other.type
def pop(self, key, *args, **kwargs): def pop(self, key: str, *args: Any, **kwargs: Any) -> None:
if self.is_map() or self.is_list(): if self.is_map() or self.is_list():
self.value.pop(key, *args, **kwargs) self.value.pop(key, *args, **kwargs)
else: else:
raise TypeError(f"pop not supported for DynamoType {self.type}") raise TypeError(f"pop not supported for DynamoType {self.type}")
# https://github.com/getmoto/moto/issues/1874
# Ensure that the total size of an item does not exceed 400kb
class LimitedSizeDict(Dict[str, Any]):
def __init__(self, *args: Any, **kwargs: Any):
self.update(*args, **kwargs)
def __setitem__(self, key: str, value: Any) -> None:
current_item_size = sum(
[
item.size() if type(item) == DynamoType else bytesize(str(item))
for item in (list(self.keys()) + list(self.values()))
]
)
new_item_size = bytesize(key) + (
value.size() if type(value) == DynamoType else bytesize(str(value))
)
# Official limit is set to 400000 (400KB)
# Manual testing confirms that the actual limit is between 409 and 410KB
# We'll set the limit to something in between to be safe
if (current_item_size + new_item_size) > 405000:
raise ItemSizeTooLarge
super().__setitem__(key, value)
class Item(BaseModel):
def __init__(
self,
hash_key: DynamoType,
range_key: Optional[DynamoType],
attrs: Dict[str, Any],
):
self.hash_key = hash_key
self.range_key = range_key
self.attrs = LimitedSizeDict()
for key, value in attrs.items():
self.attrs[key] = DynamoType(value)
def __eq__(self, other: "Item") -> bool: # type: ignore[override]
return all(
[
self.hash_key == other.hash_key,
self.range_key == other.range_key, # type: ignore[operator]
self.attrs == other.attrs,
]
)
def __repr__(self) -> str:
return f"Item: {self.to_json()}"
def size(self) -> int:
return sum(bytesize(key) + value.size() for key, value in self.attrs.items())
def to_json(self) -> Dict[str, Any]:
attributes = {}
for attribute_key, attribute in self.attrs.items():
attributes[attribute_key] = {attribute.type: attribute.value}
return {"Attributes": attributes}
def describe_attrs(
self, attributes: Optional[Dict[str, Any]]
) -> Dict[str, Dict[str, Any]]:
if attributes:
included = {}
for key, value in self.attrs.items():
if key in attributes:
included[key] = value
else:
included = self.attrs
return {"Item": included}
def validate_no_empty_key_values(
self, attribute_updates: Dict[str, Any], key_attributes: List[str]
) -> None:
for attribute_name, update_action in attribute_updates.items():
action = update_action.get("Action") or "PUT" # PUT is default
if action == "DELETE":
continue
new_value = next(iter(update_action["Value"].values()))
if action == "PUT" and new_value == "" and attribute_name in key_attributes:
raise EmptyKeyAttributeException
def update_with_attribute_updates(self, attribute_updates: Dict[str, Any]) -> None:
for attribute_name, update_action in attribute_updates.items():
# Use default Action value, if no explicit Action is passed.
# Default value is 'Put', according to
# Boto3 DynamoDB.Client.update_item documentation.
action = update_action.get("Action", "PUT")
if action == "DELETE" and "Value" not in update_action:
if attribute_name in self.attrs:
del self.attrs[attribute_name]
continue
new_value = list(update_action["Value"].values())[0]
if action == "PUT":
# TODO deal with other types
if set(update_action["Value"].keys()) == set(["SS"]):
self.attrs[attribute_name] = DynamoType({"SS": new_value})
elif isinstance(new_value, list):
self.attrs[attribute_name] = DynamoType({"L": new_value})
elif isinstance(new_value, dict):
self.attrs[attribute_name] = DynamoType({"M": new_value})
elif set(update_action["Value"].keys()) == set(["N"]):
self.attrs[attribute_name] = DynamoType({"N": new_value})
elif set(update_action["Value"].keys()) == set(["NULL"]):
if attribute_name in self.attrs:
del self.attrs[attribute_name]
else:
self.attrs[attribute_name] = DynamoType({"S": new_value})
elif action == "ADD":
if set(update_action["Value"].keys()) == set(["N"]):
existing = self.attrs.get(attribute_name, DynamoType({"N": "0"}))
self.attrs[attribute_name] = DynamoType(
{
"N": str(
decimal.Decimal(existing.value)
+ decimal.Decimal(new_value)
)
}
)
elif set(update_action["Value"].keys()) == set(["SS"]):
existing = self.attrs.get(attribute_name, DynamoType({"SS": {}}))
new_set = set(existing.value).union(set(new_value))
self.attrs[attribute_name] = DynamoType({"SS": list(new_set)})
elif set(update_action["Value"].keys()) == {"L"}:
existing = self.attrs.get(attribute_name, DynamoType({"L": []}))
new_list = existing.value + new_value
self.attrs[attribute_name] = DynamoType({"L": new_list})
else:
# TODO: implement other data types
raise NotImplementedError(
"ADD not supported for %s"
% ", ".join(update_action["Value"].keys())
)
elif action == "DELETE":
if set(update_action["Value"].keys()) == set(["SS"]):
existing = self.attrs.get(attribute_name, DynamoType({"SS": {}}))
new_set = set(existing.value).difference(set(new_value))
self.attrs[attribute_name] = DynamoType({"SS": list(new_set)})
else:
raise NotImplementedError(
"ADD not supported for %s"
% ", ".join(update_action["Value"].keys())
)
else:
raise NotImplementedError(
f"{action} action not support for update_with_attribute_updates"
)
# Filter using projection_expression
# Ensure a deep copy is used to filter, otherwise actual data will be removed
def filter(self, projection_expression: str) -> None:
expressions = [x.strip() for x in projection_expression.split(",")]
top_level_expressions = [
expr[0 : expr.index(".")] for expr in expressions if "." in expr
]
for attr in list(self.attrs):
if attr not in expressions and attr not in top_level_expressions:
self.attrs.pop(attr)
if attr in top_level_expressions:
relevant_expressions = [
expr[len(attr + ".") :]
for expr in expressions
if expr.startswith(attr + ".")
]
self.attrs[attr].filter(relevant_expressions)

File diff suppressed because it is too large Load Diff

View File

@ -1,2 +1,16 @@
def bytesize(val): import json
from typing import Any
class DynamoJsonEncoder(json.JSONEncoder):
def default(self, o: Any) -> Any:
if hasattr(o, "to_json"):
return o.to_json()
def dynamo_json_dump(dynamo_object: Any) -> str:
return json.dumps(dynamo_object, cls=DynamoJsonEncoder)
def bytesize(val: str) -> int:
return len(val.encode("utf-8")) return len(val.encode("utf-8"))

View File

@ -1,3 +1,4 @@
# type: ignore
import abc import abc
from abc import abstractmethod from abc import abstractmethod
from collections import deque from collections import deque
@ -21,7 +22,7 @@ class Node(metaclass=abc.ABCMeta):
def set_parent(self, parent_node): def set_parent(self, parent_node):
self.parent = parent_node self.parent = parent_node
def validate(self): def validate(self) -> None:
if self.type == "UpdateExpression": if self.type == "UpdateExpression":
nr_of_clauses = len(self.find_clauses([UpdateExpressionAddClause])) nr_of_clauses = len(self.find_clauses([UpdateExpressionAddClause]))
if nr_of_clauses > 1: if nr_of_clauses > 1:

View File

@ -1,13 +1,19 @@
from abc import abstractmethod from abc import abstractmethod
from typing import Any, Dict, List, Optional, Union, Type
from moto.dynamodb.exceptions import ( from moto.dynamodb.exceptions import (
IncorrectOperandType, IncorrectOperandType,
IncorrectDataType, IncorrectDataType,
ProvidedKeyDoesNotExist, ProvidedKeyDoesNotExist,
) )
from moto.dynamodb.models import DynamoType from moto.dynamodb.models.dynamo_type import (
from moto.dynamodb.models.dynamo_type import DDBTypeConversion, DDBType DDBTypeConversion,
from moto.dynamodb.parsing.ast_nodes import ( DDBType,
DynamoType,
Item,
)
from moto.dynamodb.parsing.ast_nodes import ( # type: ignore
Node,
UpdateExpressionSetAction, UpdateExpressionSetAction,
UpdateExpressionDeleteAction, UpdateExpressionDeleteAction,
UpdateExpressionRemoveAction, UpdateExpressionRemoveAction,
@ -21,16 +27,18 @@ from moto.dynamodb.parsing.ast_nodes import (
from moto.dynamodb.parsing.validators import ExpressionPathResolver from moto.dynamodb.parsing.validators import ExpressionPathResolver
class NodeExecutor(object): class NodeExecutor:
def __init__(self, ast_node, expression_attribute_names): def __init__(self, ast_node: Node, expression_attribute_names: Dict[str, str]):
self.node = ast_node self.node = ast_node
self.expression_attribute_names = expression_attribute_names self.expression_attribute_names = expression_attribute_names
@abstractmethod @abstractmethod
def execute(self, item): def execute(self, item: Item) -> None:
pass pass
def get_item_part_for_path_nodes(self, item, path_nodes): def get_item_part_for_path_nodes(
self, item: Item, path_nodes: List[Node]
) -> Union[DynamoType, Dict[str, Any]]:
""" """
For a list of path nodes travers the item by following the path_nodes For a list of path nodes travers the item by following the path_nodes
Args: Args:
@ -43,11 +51,13 @@ class NodeExecutor(object):
if len(path_nodes) == 0: if len(path_nodes) == 0:
return item.attrs return item.attrs
else: else:
return ExpressionPathResolver( return ExpressionPathResolver( # type: ignore
self.expression_attribute_names self.expression_attribute_names
).resolve_expression_path_nodes_to_dynamo_type(item, path_nodes) ).resolve_expression_path_nodes_to_dynamo_type(item, path_nodes)
def get_item_before_end_of_path(self, item): def get_item_before_end_of_path(
self, item: Item
) -> Union[DynamoType, Dict[str, Any]]:
""" """
Get the part ot the item where the item will perform the action. For most actions this should be the parent. As Get the part ot the item where the item will perform the action. For most actions this should be the parent. As
that element will need to be modified by the action. that element will need to be modified by the action.
@ -61,7 +71,7 @@ class NodeExecutor(object):
item, self.get_path_expression_nodes()[:-1] item, self.get_path_expression_nodes()[:-1]
) )
def get_item_at_end_of_path(self, item): def get_item_at_end_of_path(self, item: Item) -> Union[DynamoType, Dict[str, Any]]:
""" """
For a DELETE the path points at the stringset so we need to evaluate the full path. For a DELETE the path points at the stringset so we need to evaluate the full path.
Args: Args:
@ -76,15 +86,15 @@ class NodeExecutor(object):
# that element will need to be modified by the action. # that element will need to be modified by the action.
get_item_part_in_which_to_perform_action = get_item_before_end_of_path get_item_part_in_which_to_perform_action = get_item_before_end_of_path
def get_path_expression_nodes(self): def get_path_expression_nodes(self) -> List[Node]:
update_expression_path = self.node.children[0] update_expression_path = self.node.children[0]
assert isinstance(update_expression_path, UpdateExpressionPath) assert isinstance(update_expression_path, UpdateExpressionPath)
return update_expression_path.children return update_expression_path.children
def get_element_to_action(self): def get_element_to_action(self) -> Node:
return self.get_path_expression_nodes()[-1] return self.get_path_expression_nodes()[-1]
def get_action_value(self): def get_action_value(self) -> DynamoType:
""" """
Returns: Returns:
@ -98,7 +108,7 @@ class NodeExecutor(object):
class SetExecutor(NodeExecutor): class SetExecutor(NodeExecutor):
def execute(self, item): def execute(self, item: Item) -> None:
self.set( self.set(
item_part_to_modify_with_set=self.get_item_part_in_which_to_perform_action( item_part_to_modify_with_set=self.get_item_part_in_which_to_perform_action(
item item
@ -109,13 +119,13 @@ class SetExecutor(NodeExecutor):
) )
@classmethod @classmethod
def set( def set( # type: ignore[misc]
cls, cls,
item_part_to_modify_with_set, item_part_to_modify_with_set: Union[DynamoType, Dict[str, Any]],
element_to_set, element_to_set: Any,
value_to_set, value_to_set: Any,
expression_attribute_names, expression_attribute_names: Dict[str, str],
): ) -> None:
if isinstance(element_to_set, ExpressionAttribute): if isinstance(element_to_set, ExpressionAttribute):
attribute_name = element_to_set.get_attribute_name() attribute_name = element_to_set.get_attribute_name()
item_part_to_modify_with_set[attribute_name] = value_to_set item_part_to_modify_with_set[attribute_name] = value_to_set
@ -136,7 +146,7 @@ class SetExecutor(NodeExecutor):
class DeleteExecutor(NodeExecutor): class DeleteExecutor(NodeExecutor):
operator = "operator: DELETE" operator = "operator: DELETE"
def execute(self, item): def execute(self, item: Item) -> None:
string_set_to_remove = self.get_action_value() string_set_to_remove = self.get_action_value()
assert isinstance(string_set_to_remove, DynamoType) assert isinstance(string_set_to_remove, DynamoType)
if not string_set_to_remove.is_set(): if not string_set_to_remove.is_set():
@ -176,11 +186,11 @@ class DeleteExecutor(NodeExecutor):
f"Moto does not support deleting {type(element)} yet" f"Moto does not support deleting {type(element)} yet"
) )
container = self.get_item_before_end_of_path(item) container = self.get_item_before_end_of_path(item)
del container[attribute_name] del container[attribute_name] # type: ignore[union-attr]
class RemoveExecutor(NodeExecutor): class RemoveExecutor(NodeExecutor):
def execute(self, item): def execute(self, item: Item) -> None:
element_to_remove = self.get_element_to_action() element_to_remove = self.get_element_to_action()
if isinstance(element_to_remove, ExpressionAttribute): if isinstance(element_to_remove, ExpressionAttribute):
attribute_name = element_to_remove.get_attribute_name() attribute_name = element_to_remove.get_attribute_name()
@ -208,7 +218,7 @@ class RemoveExecutor(NodeExecutor):
class AddExecutor(NodeExecutor): class AddExecutor(NodeExecutor):
def execute(self, item): def execute(self, item: Item) -> None:
value_to_add = self.get_action_value() value_to_add = self.get_action_value()
if isinstance(value_to_add, DynamoType): if isinstance(value_to_add, DynamoType):
if value_to_add.is_set(): if value_to_add.is_set():
@ -253,7 +263,7 @@ class AddExecutor(NodeExecutor):
raise IncorrectDataType() raise IncorrectDataType()
class UpdateExpressionExecutor(object): class UpdateExpressionExecutor:
execution_map = { execution_map = {
UpdateExpressionSetAction: SetExecutor, UpdateExpressionSetAction: SetExecutor,
UpdateExpressionAddAction: AddExecutor, UpdateExpressionAddAction: AddExecutor,
@ -261,12 +271,14 @@ class UpdateExpressionExecutor(object):
UpdateExpressionDeleteAction: DeleteExecutor, UpdateExpressionDeleteAction: DeleteExecutor,
} }
def __init__(self, update_ast, item, expression_attribute_names): def __init__(
self, update_ast: Node, item: Item, expression_attribute_names: Dict[str, str]
):
self.update_ast = update_ast self.update_ast = update_ast
self.item = item self.item = item
self.expression_attribute_names = expression_attribute_names self.expression_attribute_names = expression_attribute_names
def execute(self, node=None): def execute(self, node: Optional[Node] = None) -> None:
""" """
As explained in moto.dynamodb.parsing.expressions.NestableExpressionParserMixin._create_node the order of nodes As explained in moto.dynamodb.parsing.expressions.NestableExpressionParserMixin._create_node the order of nodes
in the AST can be translated of the order of statements in the expression. As such we can start at the root node in the AST can be translated of the order of statements in the expression. As such we can start at the root node
@ -286,12 +298,12 @@ class UpdateExpressionExecutor(object):
node_executor = self.get_specific_execution(node) node_executor = self.get_specific_execution(node)
if node_executor is None: if node_executor is None:
for node in node.children: for n in node.children:
self.execute(node) self.execute(n)
else: else:
node_executor(node, self.expression_attribute_names).execute(self.item) node_executor(node, self.expression_attribute_names).execute(self.item)
def get_specific_execution(self, node): def get_specific_execution(self, node: Node) -> Optional[Type[NodeExecutor]]:
for node_class in self.execution_map: for node_class in self.execution_map:
if isinstance(node, node_class): if isinstance(node, node_class):
return self.execution_map[node_class] return self.execution_map[node_class]

View File

@ -1,3 +1,4 @@
# type: ignore
import logging import logging
from abc import abstractmethod from abc import abstractmethod
import abc import abc
@ -35,7 +36,7 @@ from moto.dynamodb.parsing.tokens import Token, ExpressionTokenizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class NestableExpressionParserMixin(object): class NestableExpressionParserMixin:
""" """
For nodes that can be nested in themselves (recursive). Take for example UpdateExpression's grammar: For nodes that can be nested in themselves (recursive). Take for example UpdateExpression's grammar:

View File

@ -1,4 +1,5 @@
from enum import Enum from enum import Enum
from typing import Any, List, Dict, Tuple, Optional
from moto.dynamodb.exceptions import MockValidationException from moto.dynamodb.exceptions import MockValidationException
from moto.utilities.tokenizer import GenericTokenizer from moto.utilities.tokenizer import GenericTokenizer
@ -11,17 +12,17 @@ class EXPRESSION_STAGES(Enum):
EOF = "EOF" EOF = "EOF"
def get_key(schema, key_type): def get_key(schema: List[Dict[str, str]], key_type: str) -> Optional[str]:
keys = [key for key in schema if key["KeyType"] == key_type] keys = [key for key in schema if key["KeyType"] == key_type]
return keys[0]["AttributeName"] if keys else None return keys[0]["AttributeName"] if keys else None
def parse_expression( def parse_expression(
key_condition_expression, key_condition_expression: str,
expression_attribute_values, expression_attribute_values: Dict[str, str],
expression_attribute_names, expression_attribute_names: Dict[str, str],
schema, schema: List[Dict[str, str]],
): ) -> Tuple[Dict[str, Any], Optional[str], List[Dict[str, Any]]]:
""" """
Parse a KeyConditionExpression using the provided expression attribute names/values Parse a KeyConditionExpression using the provided expression attribute names/values
@ -31,11 +32,11 @@ def parse_expression(
schema: [{'AttributeName': 'hashkey', 'KeyType': 'HASH'}, {"AttributeName": "sortkey", "KeyType": "RANGE"}] schema: [{'AttributeName': 'hashkey', 'KeyType': 'HASH'}, {"AttributeName": "sortkey", "KeyType": "RANGE"}]
""" """
current_stage: EXPRESSION_STAGES = None current_stage: Optional[EXPRESSION_STAGES] = None
current_phrase = "" current_phrase = ""
key_name = comparison = None key_name = comparison = ""
key_values = [] key_values = []
results = [] results: List[Tuple[str, str, Any]] = []
tokenizer = GenericTokenizer(key_condition_expression) tokenizer = GenericTokenizer(key_condition_expression)
for crnt_char in tokenizer: for crnt_char in tokenizer:
if crnt_char == " ": if crnt_char == " ":
@ -188,7 +189,9 @@ def parse_expression(
# Validate that the schema-keys are encountered in our query # Validate that the schema-keys are encountered in our query
def validate_schema(results, schema): def validate_schema(
results: Any, schema: List[Dict[str, str]]
) -> Tuple[Dict[str, Any], Optional[str], List[Dict[str, Any]]]:
index_hash_key = get_key(schema, "HASH") index_hash_key = get_key(schema, "HASH")
comparison, hash_value = next( comparison, hash_value = next(
( (
@ -219,4 +222,4 @@ def validate_schema(results, schema):
f"Query condition missed key schema element: {index_range_key}" f"Query condition missed key schema element: {index_range_key}"
) )
return hash_value, range_comparison, range_values return hash_value, range_comparison, range_values # type: ignore[return-value]

View File

@ -1,27 +1,26 @@
from moto.utilities.utils import load_resource from typing import List, Optional
from moto.utilities.utils import load_resource_as_str
class ReservedKeywords(list): class ReservedKeywords:
""" """
DynamoDB has an extensive list of keywords. Keywords are considered when validating the expression Tree. DynamoDB has an extensive list of keywords. Keywords are considered when validating the expression Tree.
Not earlier since an update expression like "SET path = VALUE 1" fails with: Not earlier since an update expression like "SET path = VALUE 1" fails with:
'Invalid UpdateExpression: Syntax error; token: "1", near: "VALUE 1"' 'Invalid UpdateExpression: Syntax error; token: "1", near: "VALUE 1"'
""" """
KEYWORDS = None KEYWORDS: Optional[List[str]] = None
@classmethod @classmethod
def get_reserved_keywords(cls): def get_reserved_keywords(cls) -> List[str]:
if cls.KEYWORDS is None: if cls.KEYWORDS is None:
cls.KEYWORDS = cls._get_reserved_keywords() cls.KEYWORDS = cls._get_reserved_keywords()
return cls.KEYWORDS return cls.KEYWORDS
@classmethod @classmethod
def _get_reserved_keywords(cls): def _get_reserved_keywords(cls) -> List[str]:
""" """
Get a list of reserved keywords of DynamoDB Get a list of reserved keywords of DynamoDB
""" """
reserved_keywords = load_resource( reserved_keywords = load_resource_as_str(__name__, "reserved_keywords.txt")
__name__, "reserved_keywords.txt", as_json=False
)
return reserved_keywords.split() return reserved_keywords.split()

View File

@ -1,4 +1,5 @@
import re import re
from typing import List, Union
from moto.dynamodb.exceptions import ( from moto.dynamodb.exceptions import (
InvalidTokenException, InvalidTokenException,
@ -6,7 +7,7 @@ from moto.dynamodb.exceptions import (
) )
class Token(object): class Token:
_TOKEN_INSTANCE = None _TOKEN_INSTANCE = None
MINUS_SIGN = "-" MINUS_SIGN = "-"
PLUS_SIGN = "+" PLUS_SIGN = "+"
@ -53,7 +54,7 @@ class Token(object):
NUMBER: "Number", NUMBER: "Number",
} }
def __init__(self, token_type, value): def __init__(self, token_type: Union[int, str], value: str):
assert ( assert (
token_type in self.SPECIAL_CHARACTERS token_type in self.SPECIAL_CHARACTERS
or token_type in self.PLACEHOLDER_NAMES or token_type in self.PLACEHOLDER_NAMES
@ -61,13 +62,13 @@ class Token(object):
self.type = token_type self.type = token_type
self.value = value self.value = value
def __repr__(self): def __repr__(self) -> str:
if isinstance(self.type, int): if isinstance(self.type, int):
return f'Token("{self.PLACEHOLDER_NAMES[self.type]}", "{self.value}")' return f'Token("{self.PLACEHOLDER_NAMES[self.type]}", "{self.value}")'
else: else:
return f'Token("{self.type}", "{self.value}")' return f'Token("{self.type}", "{self.value}")'
def __eq__(self, other): def __eq__(self, other: "Token") -> bool: # type: ignore[override]
return self.type == other.type and self.value == other.value return self.type == other.type and self.value == other.value
@ -94,22 +95,22 @@ class ExpressionTokenizer(object):
""" """
@classmethod @classmethod
def is_simple_token_character(cls, character): def is_simple_token_character(cls, character: str) -> bool:
return character.isalnum() or character in ("_", ":", "#") return character.isalnum() or character in ("_", ":", "#")
@classmethod @classmethod
def is_possible_token_boundary(cls, character): def is_possible_token_boundary(cls, character: str) -> bool:
return ( return (
character in Token.SPECIAL_CHARACTERS character in Token.SPECIAL_CHARACTERS
or not cls.is_simple_token_character(character) or not cls.is_simple_token_character(character)
) )
@classmethod @classmethod
def is_expression_attribute(cls, input_string): def is_expression_attribute(cls, input_string: str) -> bool:
return re.compile("^[a-zA-Z0-9][a-zA-Z0-9_]*$").match(input_string) is not None return re.compile("^[a-zA-Z0-9][a-zA-Z0-9_]*$").match(input_string) is not None
@classmethod @classmethod
def is_expression_attribute_name(cls, input_string): def is_expression_attribute_name(cls, input_string: str) -> bool:
""" """
https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.ExpressionAttributeNames.html https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.ExpressionAttributeNames.html
An expression attribute name must begin with a pound sign (#), and be followed by one or more alphanumeric An expression attribute name must begin with a pound sign (#), and be followed by one or more alphanumeric
@ -120,10 +121,10 @@ class ExpressionTokenizer(object):
) )
@classmethod @classmethod
def is_expression_attribute_value(cls, input_string): def is_expression_attribute_value(cls, input_string: str) -> bool:
return re.compile("^:[a-zA-Z0-9_]*$").match(input_string) is not None return re.compile("^:[a-zA-Z0-9_]*$").match(input_string) is not None
def raise_unexpected_token(self): def raise_unexpected_token(self) -> None:
"""If during parsing an unexpected token is encountered""" """If during parsing an unexpected token is encountered"""
if len(self.token_list) == 0: if len(self.token_list) == 0:
near = "" near = ""
@ -140,29 +141,29 @@ class ExpressionTokenizer(object):
problematic_token = self.staged_characters[0] problematic_token = self.staged_characters[0]
raise InvalidTokenException(problematic_token, near + self.staged_characters) raise InvalidTokenException(problematic_token, near + self.staged_characters)
def __init__(self, input_expression_str): def __init__(self, input_expression_str: str):
self.input_expression_str = input_expression_str self.input_expression_str = input_expression_str
self.token_list = [] self.token_list: List[Token] = []
self.staged_characters = "" self.staged_characters = ""
@classmethod @classmethod
def make_list(cls, input_expression_str): def make_list(cls, input_expression_str: str) -> List[Token]:
assert isinstance(input_expression_str, str) assert isinstance(input_expression_str, str)
return ExpressionTokenizer(input_expression_str)._make_list() return ExpressionTokenizer(input_expression_str)._make_list()
def add_token(self, token_type, token_value): def add_token(self, token_type: Union[int, str], token_value: str) -> None:
self.token_list.append(Token(token_type, token_value)) self.token_list.append(Token(token_type, token_value))
def add_token_from_stage(self, token_type): def add_token_from_stage(self, token_type: int) -> None:
self.add_token(token_type, self.staged_characters) self.add_token(token_type, self.staged_characters)
self.staged_characters = "" self.staged_characters = ""
@classmethod @classmethod
def is_numeric(cls, input_str): def is_numeric(cls, input_str: str) -> bool:
return re.compile("[0-9]+").match(input_str) is not None return re.compile("[0-9]+").match(input_str) is not None
def process_staged_characters(self): def process_staged_characters(self) -> None:
if len(self.staged_characters) == 0: if len(self.staged_characters) == 0:
return return
if self.staged_characters.startswith("#"): if self.staged_characters.startswith("#"):
@ -179,7 +180,7 @@ class ExpressionTokenizer(object):
else: else:
self.raise_unexpected_token() self.raise_unexpected_token()
def _make_list(self): def _make_list(self) -> List[Token]:
""" """
Just go through characters if a character is not a token boundary stage it for adding it as a grouped token Just go through characters if a character is not a token boundary stage it for adding it as a grouped token
later if it is a tokenboundary process staged characters and then process the token boundary as well. later if it is a tokenboundary process staged characters and then process the token boundary as well.

View File

@ -3,6 +3,7 @@ See docstring class Validator below for more details on validation
""" """
from abc import abstractmethod from abc import abstractmethod
from copy import deepcopy from copy import deepcopy
from typing import Any, Callable, Dict, List, Type, Union
from moto.dynamodb.exceptions import ( from moto.dynamodb.exceptions import (
AttributeIsReservedKeyword, AttributeIsReservedKeyword,
@ -15,9 +16,12 @@ from moto.dynamodb.exceptions import (
EmptyKeyAttributeException, EmptyKeyAttributeException,
UpdateHashRangeKeyException, UpdateHashRangeKeyException,
) )
from moto.dynamodb.models import DynamoType from moto.dynamodb.models.dynamo_type import DynamoType, Item
from moto.dynamodb.parsing.ast_nodes import ( from moto.dynamodb.models.table import Table
from moto.dynamodb.parsing.ast_nodes import ( # type: ignore
Node,
ExpressionAttribute, ExpressionAttribute,
UpdateExpressionClause,
UpdateExpressionPath, UpdateExpressionPath,
UpdateExpressionSetAction, UpdateExpressionSetAction,
UpdateExpressionAddAction, UpdateExpressionAddAction,
@ -37,16 +41,23 @@ from moto.dynamodb.parsing.ast_nodes import (
from moto.dynamodb.parsing.reserved_keywords import ReservedKeywords from moto.dynamodb.parsing.reserved_keywords import ReservedKeywords
class ExpressionAttributeValueProcessor(DepthFirstTraverser): class ExpressionAttributeValueProcessor(DepthFirstTraverser): # type: ignore[misc]
def __init__(self, expression_attribute_values): def __init__(self, expression_attribute_values: Dict[str, Dict[str, Any]]):
self.expression_attribute_values = expression_attribute_values self.expression_attribute_values = expression_attribute_values
def _processing_map(self): def _processing_map(
self,
) -> Dict[
Type[ExpressionAttributeValue],
Callable[[ExpressionAttributeValue], DDBTypedValue],
]:
return { return {
ExpressionAttributeValue: self.replace_expression_attribute_value_with_value ExpressionAttributeValue: self.replace_expression_attribute_value_with_value
} }
def replace_expression_attribute_value_with_value(self, node): def replace_expression_attribute_value_with_value(
self, node: ExpressionAttributeValue
) -> DDBTypedValue:
"""A node representing an Expression Attribute Value. Resolve and replace value""" """A node representing an Expression Attribute Value. Resolve and replace value"""
assert isinstance(node, ExpressionAttributeValue) assert isinstance(node, ExpressionAttributeValue)
attribute_value_name = node.get_value_name() attribute_value_name = node.get_value_name()
@ -59,20 +70,24 @@ class ExpressionAttributeValueProcessor(DepthFirstTraverser):
return DDBTypedValue(DynamoType(target)) return DDBTypedValue(DynamoType(target))
class ExpressionPathResolver(object): class ExpressionPathResolver:
def __init__(self, expression_attribute_names): def __init__(self, expression_attribute_names: Dict[str, str]):
self.expression_attribute_names = expression_attribute_names self.expression_attribute_names = expression_attribute_names
@classmethod @classmethod
def raise_exception_if_keyword(cls, attribute): def raise_exception_if_keyword(cls, attribute: Any) -> None: # type: ignore[misc]
if attribute.upper() in ReservedKeywords.get_reserved_keywords(): if attribute.upper() in ReservedKeywords.get_reserved_keywords():
raise AttributeIsReservedKeyword(attribute) raise AttributeIsReservedKeyword(attribute)
def resolve_expression_path(self, item, update_expression_path): def resolve_expression_path(
self, item: Item, update_expression_path: UpdateExpressionPath
) -> Union[NoneExistingPath, DDBTypedValue]:
assert isinstance(update_expression_path, UpdateExpressionPath) assert isinstance(update_expression_path, UpdateExpressionPath)
return self.resolve_expression_path_nodes(item, update_expression_path.children) return self.resolve_expression_path_nodes(item, update_expression_path.children)
def resolve_expression_path_nodes(self, item, update_expression_path_nodes): def resolve_expression_path_nodes(
self, item: Item, update_expression_path_nodes: List[Node]
) -> Union[NoneExistingPath, DDBTypedValue]:
target = item.attrs target = item.attrs
for child in update_expression_path_nodes: for child in update_expression_path_nodes:
@ -100,7 +115,7 @@ class ExpressionPathResolver(object):
continue continue
elif isinstance(child, ExpressionSelector): elif isinstance(child, ExpressionSelector):
index = child.get_index() index = child.get_index()
if target.is_list(): if target.is_list(): # type: ignore
try: try:
target = target[index] target = target[index]
except IndexError: except IndexError:
@ -116,8 +131,8 @@ class ExpressionPathResolver(object):
return DDBTypedValue(target) return DDBTypedValue(target)
def resolve_expression_path_nodes_to_dynamo_type( def resolve_expression_path_nodes_to_dynamo_type(
self, item, update_expression_path_nodes self, item: Item, update_expression_path_nodes: List[Node]
): ) -> Any:
node = self.resolve_expression_path_nodes(item, update_expression_path_nodes) node = self.resolve_expression_path_nodes(item, update_expression_path_nodes)
if isinstance(node, NoneExistingPath): if isinstance(node, NoneExistingPath):
raise ProvidedKeyDoesNotExist() raise ProvidedKeyDoesNotExist()
@ -125,19 +140,30 @@ class ExpressionPathResolver(object):
return node.get_value() return node.get_value()
class ExpressionAttributeResolvingProcessor(DepthFirstTraverser): class ExpressionAttributeResolvingProcessor(DepthFirstTraverser): # type: ignore[misc]
def _processing_map(self): def _processing_map(
self,
) -> Dict[Type[UpdateExpressionClause], Callable[[DDBTypedValue], DDBTypedValue]]:
return { return {
UpdateExpressionSetAction: self.disable_resolving, UpdateExpressionSetAction: self.disable_resolving,
UpdateExpressionPath: self.process_expression_path_node, UpdateExpressionPath: self.process_expression_path_node,
} }
def __init__(self, expression_attribute_names, item): def __init__(self, expression_attribute_names: Dict[str, str], item: Item):
self.expression_attribute_names = expression_attribute_names self.expression_attribute_names = expression_attribute_names
self.item = item self.item = item
self.resolving = False self.resolving = False
def pre_processing_of_child(self, parent_node, child_id): def pre_processing_of_child(
self,
parent_node: Union[
UpdateExpressionSetAction,
UpdateExpressionRemoveAction,
UpdateExpressionDeleteAction,
UpdateExpressionAddAction,
],
child_id: int,
) -> None:
""" """
We have to enable resolving if we are processing a child of UpdateExpressionSetAction that is not first. We have to enable resolving if we are processing a child of UpdateExpressionSetAction that is not first.
Because first argument is path to be set, 2nd argument would be the value. Because first argument is path to be set, 2nd argument would be the value.
@ -156,11 +182,11 @@ class ExpressionAttributeResolvingProcessor(DepthFirstTraverser):
else: else:
self.resolving = True self.resolving = True
def disable_resolving(self, node=None): def disable_resolving(self, node: DDBTypedValue) -> DDBTypedValue:
self.resolving = False self.resolving = False
return node return node
def process_expression_path_node(self, node): def process_expression_path_node(self, node: DDBTypedValue) -> DDBTypedValue:
"""Resolve ExpressionAttribute if not part of a path and resolving is enabled.""" """Resolve ExpressionAttribute if not part of a path and resolving is enabled."""
if self.resolving: if self.resolving:
return self.resolve_expression_path(node) return self.resolve_expression_path(node)
@ -175,13 +201,15 @@ class ExpressionAttributeResolvingProcessor(DepthFirstTraverser):
return node return node
def resolve_expression_path(self, node): def resolve_expression_path(
self, node: DDBTypedValue
) -> Union[NoneExistingPath, DDBTypedValue]:
return ExpressionPathResolver( return ExpressionPathResolver(
self.expression_attribute_names self.expression_attribute_names
).resolve_expression_path(self.item, node) ).resolve_expression_path(self.item, node)
class UpdateExpressionFunctionEvaluator(DepthFirstTraverser): class UpdateExpressionFunctionEvaluator(DepthFirstTraverser): # type: ignore[misc]
""" """
At time of writing there are only 2 functions for DDB UpdateExpressions. They both are specific to the SET At time of writing there are only 2 functions for DDB UpdateExpressions. They both are specific to the SET
expression as per the official AWS docs: expression as per the official AWS docs:
@ -189,10 +217,15 @@ class UpdateExpressionFunctionEvaluator(DepthFirstTraverser):
Expressions.UpdateExpressions.html#Expressions.UpdateExpressions.SET Expressions.UpdateExpressions.html#Expressions.UpdateExpressions.SET
""" """
def _processing_map(self): def _processing_map(
self,
) -> Dict[
Type[UpdateExpressionFunction],
Callable[[UpdateExpressionFunction], DDBTypedValue],
]:
return {UpdateExpressionFunction: self.process_function} return {UpdateExpressionFunction: self.process_function}
def process_function(self, node): def process_function(self, node: UpdateExpressionFunction) -> DDBTypedValue:
assert isinstance(node, UpdateExpressionFunction) assert isinstance(node, UpdateExpressionFunction)
function_name = node.get_function_name() function_name = node.get_function_name()
first_arg = node.get_nth_argument(1) first_arg = node.get_nth_argument(1)
@ -217,7 +250,7 @@ class UpdateExpressionFunctionEvaluator(DepthFirstTraverser):
raise NotImplementedError(f"Unsupported function for moto {function_name}") raise NotImplementedError(f"Unsupported function for moto {function_name}")
@classmethod @classmethod
def get_list_from_ddb_typed_value(cls, node, function_name): def get_list_from_ddb_typed_value(cls, node: DDBTypedValue, function_name: str) -> DynamoType: # type: ignore[misc]
assert isinstance(node, DDBTypedValue) assert isinstance(node, DDBTypedValue)
dynamo_value = node.get_value() dynamo_value = node.get_value()
assert isinstance(dynamo_value, DynamoType) assert isinstance(dynamo_value, DynamoType)
@ -226,23 +259,25 @@ class UpdateExpressionFunctionEvaluator(DepthFirstTraverser):
return dynamo_value return dynamo_value
class NoneExistingPathChecker(DepthFirstTraverser): class NoneExistingPathChecker(DepthFirstTraverser): # type: ignore[misc]
""" """
Pass through the AST and make sure there are no none-existing paths. Pass through the AST and make sure there are no none-existing paths.
""" """
def _processing_map(self): def _processing_map(self) -> Dict[Type[NoneExistingPath], Callable[[Node], None]]:
return {NoneExistingPath: self.raise_none_existing_path} return {NoneExistingPath: self.raise_none_existing_path}
def raise_none_existing_path(self, node): def raise_none_existing_path(self, node: Node) -> None:
raise AttributeDoesNotExist raise AttributeDoesNotExist
class ExecuteOperations(DepthFirstTraverser): class ExecuteOperations(DepthFirstTraverser): # type: ignore[misc]
def _processing_map(self): def _processing_map(
self,
) -> Dict[Type[UpdateExpressionValue], Callable[[Node], DDBTypedValue]]:
return {UpdateExpressionValue: self.process_update_expression_value} return {UpdateExpressionValue: self.process_update_expression_value}
def process_update_expression_value(self, node): def process_update_expression_value(self, node: Node) -> DDBTypedValue:
""" """
If an UpdateExpressionValue only has a single child the node will be replaced with the childe. If an UpdateExpressionValue only has a single child the node will be replaced with the childe.
Otherwise it has 3 children and the middle one is an ExpressionValueOperator which details how to combine them Otherwise it has 3 children and the middle one is an ExpressionValueOperator which details how to combine them
@ -273,14 +308,14 @@ class ExecuteOperations(DepthFirstTraverser):
) )
@classmethod @classmethod
def get_dynamo_value_from_ddb_typed_value(cls, node): def get_dynamo_value_from_ddb_typed_value(cls, node: DDBTypedValue) -> DynamoType: # type: ignore[misc]
assert isinstance(node, DDBTypedValue) assert isinstance(node, DDBTypedValue)
dynamo_value = node.get_value() dynamo_value = node.get_value()
assert isinstance(dynamo_value, DynamoType) assert isinstance(dynamo_value, DynamoType)
return dynamo_value return dynamo_value
@classmethod @classmethod
def get_sum(cls, left_operand, right_operand): def get_sum(cls, left_operand: DynamoType, right_operand: DynamoType) -> DDBTypedValue: # type: ignore[misc]
""" """
Args: Args:
left_operand(DynamoType): left_operand(DynamoType):
@ -295,7 +330,7 @@ class ExecuteOperations(DepthFirstTraverser):
raise IncorrectOperandType("+", left_operand.type) raise IncorrectOperandType("+", left_operand.type)
@classmethod @classmethod
def get_subtraction(cls, left_operand, right_operand): def get_subtraction(cls, left_operand: DynamoType, right_operand: DynamoType) -> DDBTypedValue: # type: ignore[misc]
""" """
Args: Args:
left_operand(DynamoType): left_operand(DynamoType):
@ -310,14 +345,21 @@ class ExecuteOperations(DepthFirstTraverser):
raise IncorrectOperandType("-", left_operand.type) raise IncorrectOperandType("-", left_operand.type)
class EmptyStringKeyValueValidator(DepthFirstTraverser): class EmptyStringKeyValueValidator(DepthFirstTraverser): # type: ignore[misc]
def __init__(self, key_attributes): def __init__(self, key_attributes: List[str]):
self.key_attributes = key_attributes self.key_attributes = key_attributes
def _processing_map(self): def _processing_map(
self,
) -> Dict[
Type[UpdateExpressionSetAction],
Callable[[UpdateExpressionSetAction], UpdateExpressionSetAction],
]:
return {UpdateExpressionSetAction: self.check_for_empty_string_key_value} return {UpdateExpressionSetAction: self.check_for_empty_string_key_value}
def check_for_empty_string_key_value(self, node): def check_for_empty_string_key_value(
self, node: UpdateExpressionSetAction
) -> UpdateExpressionSetAction:
"""A node representing a SET action. Check that keys are not being assigned empty strings""" """A node representing a SET action. Check that keys are not being assigned empty strings"""
assert isinstance(node, UpdateExpressionSetAction) assert isinstance(node, UpdateExpressionSetAction)
assert len(node.children) == 2 assert len(node.children) == 2
@ -332,15 +374,26 @@ class EmptyStringKeyValueValidator(DepthFirstTraverser):
return node return node
class UpdateHashRangeKeyValidator(DepthFirstTraverser): class UpdateHashRangeKeyValidator(DepthFirstTraverser): # type: ignore[misc]
def __init__(self, table_key_attributes, expression_attribute_names): def __init__(
self,
table_key_attributes: List[str],
expression_attribute_names: Dict[str, str],
):
self.table_key_attributes = table_key_attributes self.table_key_attributes = table_key_attributes
self.expression_attribute_names = expression_attribute_names self.expression_attribute_names = expression_attribute_names
def _processing_map(self): def _processing_map(
self,
) -> Dict[
Type[UpdateExpressionPath],
Callable[[UpdateExpressionPath], UpdateExpressionPath],
]:
return {UpdateExpressionPath: self.check_for_hash_or_range_key} return {UpdateExpressionPath: self.check_for_hash_or_range_key}
def check_for_hash_or_range_key(self, node): def check_for_hash_or_range_key(
self, node: UpdateExpressionPath
) -> UpdateExpressionPath:
"""Check that hash and range keys are not updated""" """Check that hash and range keys are not updated"""
key_to_update = node.children[0].children[0] key_to_update = node.children[0].children[0]
key_to_update = self.expression_attribute_names.get( key_to_update = self.expression_attribute_names.get(
@ -351,18 +404,18 @@ class UpdateHashRangeKeyValidator(DepthFirstTraverser):
return node return node
class Validator(object): class Validator:
""" """
A validator is used to validate expressions which are passed in as an AST. A validator is used to validate expressions which are passed in as an AST.
""" """
def __init__( def __init__(
self, self,
expression, expression: Node,
expression_attribute_names, expression_attribute_names: Dict[str, str],
expression_attribute_values, expression_attribute_values: Dict[str, Dict[str, Any]],
item, item: Item,
table, table: Table,
): ):
""" """
Besides validation the Validator should also replace referenced parts of an item which is cheapest upon Besides validation the Validator should also replace referenced parts of an item which is cheapest upon
@ -382,10 +435,10 @@ class Validator(object):
self.node_to_validate = deepcopy(expression) self.node_to_validate = deepcopy(expression)
@abstractmethod @abstractmethod
def get_ast_processors(self): def get_ast_processors(self) -> List[DepthFirstTraverser]: # type: ignore[misc]
"""Get the different processors that go through the AST tree and processes the nodes.""" """Get the different processors that go through the AST tree and processes the nodes."""
def validate(self): def validate(self) -> Node:
n = self.node_to_validate n = self.node_to_validate
for processor in self.processors: for processor in self.processors:
n = processor.traverse(n) n = processor.traverse(n)
@ -393,7 +446,7 @@ class Validator(object):
class UpdateExpressionValidator(Validator): class UpdateExpressionValidator(Validator):
def get_ast_processors(self): def get_ast_processors(self) -> List[DepthFirstTraverser]:
"""Get the different processors that go through the AST tree and processes the nodes.""" """Get the different processors that go through the AST tree and processes the nodes."""
processors = [ processors = [
UpdateHashRangeKeyValidator( UpdateHashRangeKeyValidator(

View File

@ -3,7 +3,9 @@ import json
import itertools import itertools
from functools import wraps from functools import wraps
from typing import Any, Dict, List, Union, Callable, Optional
from moto.core.common_types import TYPE_RESPONSE
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.core.utils import camelcase_to_underscores from moto.core.utils import camelcase_to_underscores
from moto.dynamodb.parsing.key_condition_expression import parse_expression from moto.dynamodb.parsing.key_condition_expression import parse_expression
@ -13,17 +15,27 @@ from .exceptions import (
ResourceNotFoundException, ResourceNotFoundException,
ConditionalCheckFailed, ConditionalCheckFailed,
) )
from moto.dynamodb.models import dynamodb_backends, dynamo_json_dump, Table from moto.dynamodb.models import dynamodb_backends, Table, DynamoDBBackend
from moto.dynamodb.models.utilities import dynamo_json_dump
from moto.utilities.aws_headers import amz_crc32, amzn_request_id from moto.utilities.aws_headers import amz_crc32, amzn_request_id
TRANSACTION_MAX_ITEMS = 25 TRANSACTION_MAX_ITEMS = 25
def include_consumed_capacity(val=1.0): def include_consumed_capacity(
def _inner(f): val: float = 1.0,
) -> Callable[
[Callable[["DynamoHandler"], str]],
Callable[["DynamoHandler"], Union[str, TYPE_RESPONSE]],
]:
def _inner(
f: Callable[..., Union[str, TYPE_RESPONSE]]
) -> Callable[["DynamoHandler"], Union[str, TYPE_RESPONSE]]:
@wraps(f) @wraps(f)
def _wrapper(*args, **kwargs): def _wrapper(
*args: "DynamoHandler", **kwargs: None
) -> Union[str, TYPE_RESPONSE]:
(handler,) = args (handler,) = args
expected_capacity = handler.body.get("ReturnConsumedCapacity", "NONE") expected_capacity = handler.body.get("ReturnConsumedCapacity", "NONE")
if expected_capacity not in ["NONE", "TOTAL", "INDEXES"]: if expected_capacity not in ["NONE", "TOTAL", "INDEXES"]:
@ -67,7 +79,7 @@ def include_consumed_capacity(val=1.0):
return _inner return _inner
def get_empty_keys_on_put(field_updates, table: Table): def get_empty_keys_on_put(field_updates: Dict[str, Any], table: Table) -> Optional[str]:
""" """
Return the first key-name that has an empty value. None if all keys are filled Return the first key-name that has an empty value. None if all keys are filled
""" """
@ -83,12 +95,12 @@ def get_empty_keys_on_put(field_updates, table: Table):
return next( return next(
(keyname for keyname in key_names if keyname in empty_str_fields), None (keyname for keyname in key_names if keyname in empty_str_fields), None
) )
return False return None
def put_has_empty_attrs(field_updates, table): def put_has_empty_attrs(field_updates: Dict[str, Any], table: Table) -> bool:
# Example invalid attribute: [{'M': {'SS': {'NS': []}}}] # Example invalid attribute: [{'M': {'SS': {'NS': []}}}]
def _validate_attr(attr: dict): def _validate_attr(attr: Dict[str, Any]) -> bool:
if "NS" in attr and attr["NS"] == []: if "NS" in attr and attr["NS"] == []:
return True return True
else: else:
@ -105,7 +117,7 @@ def put_has_empty_attrs(field_updates, table):
return False return False
def validate_put_has_gsi_keys_set_to_none(item, table: Table) -> None: def validate_put_has_gsi_keys_set_to_none(item: Dict[str, Any], table: Table) -> None:
for gsi in table.global_indexes: for gsi in table.global_indexes:
for attr in gsi.schema: for attr in gsi.schema:
attr_name = attr["AttributeName"] attr_name = attr["AttributeName"]
@ -115,7 +127,7 @@ def validate_put_has_gsi_keys_set_to_none(item, table: Table) -> None:
) )
def check_projection_expression(expression): def check_projection_expression(expression: str) -> None:
if expression.upper() in ReservedKeywords.get_reserved_keywords(): if expression.upper() in ReservedKeywords.get_reserved_keywords():
raise MockValidationException( raise MockValidationException(
f"ProjectionExpression: Attribute name is a reserved keyword; reserved keyword: {expression}" f"ProjectionExpression: Attribute name is a reserved keyword; reserved keyword: {expression}"
@ -131,10 +143,10 @@ def check_projection_expression(expression):
class DynamoHandler(BaseResponse): class DynamoHandler(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="dynamodb") super().__init__(service_name="dynamodb")
def get_endpoint_name(self, headers): def get_endpoint_name(self, headers: Any) -> Optional[str]:
"""Parses request headers and extracts part od the X-Amz-Target """Parses request headers and extracts part od the X-Amz-Target
that corresponds to a method of DynamoHandler that corresponds to a method of DynamoHandler
@ -144,9 +156,10 @@ class DynamoHandler(BaseResponse):
match = headers.get("x-amz-target") or headers.get("X-Amz-Target") match = headers.get("x-amz-target") or headers.get("X-Amz-Target")
if match: if match:
return match.split(".")[1] return match.split(".")[1]
return None
@property @property
def dynamodb_backend(self): def dynamodb_backend(self) -> DynamoDBBackend:
""" """
:return: DynamoDB Backend :return: DynamoDB Backend
:rtype: moto.dynamodb.models.DynamoDBBackend :rtype: moto.dynamodb.models.DynamoDBBackend
@ -155,7 +168,7 @@ class DynamoHandler(BaseResponse):
@amz_crc32 @amz_crc32
@amzn_request_id @amzn_request_id
def call_action(self): def call_action(self) -> TYPE_RESPONSE:
self.body = json.loads(self.body or "{}") self.body = json.loads(self.body or "{}")
endpoint = self.get_endpoint_name(self.headers) endpoint = self.get_endpoint_name(self.headers)
if endpoint: if endpoint:
@ -171,7 +184,7 @@ class DynamoHandler(BaseResponse):
else: else:
return 404, self.response_headers, "" return 404, self.response_headers, ""
def list_tables(self): def list_tables(self) -> str:
body = self.body body = self.body
limit = body.get("Limit", 100) limit = body.get("Limit", 100)
exclusive_start_table_name = body.get("ExclusiveStartTableName") exclusive_start_table_name = body.get("ExclusiveStartTableName")
@ -179,13 +192,13 @@ class DynamoHandler(BaseResponse):
limit, exclusive_start_table_name limit, exclusive_start_table_name
) )
response = {"TableNames": tables} response: Dict[str, Any] = {"TableNames": tables}
if last_eval: if last_eval:
response["LastEvaluatedTableName"] = last_eval response["LastEvaluatedTableName"] = last_eval
return dynamo_json_dump(response) return dynamo_json_dump(response)
def create_table(self): def create_table(self) -> str:
body = self.body body = self.body
# get the table name # get the table name
table_name = body["TableName"] table_name = body["TableName"]
@ -243,7 +256,7 @@ class DynamoHandler(BaseResponse):
actual_attrs = [item["AttributeName"] for item in attr] actual_attrs = [item["AttributeName"] for item in attr]
actual_attrs.sort() actual_attrs.sort()
if actual_attrs != expected_attrs: if actual_attrs != expected_attrs:
return self._throw_attr_error( self._throw_attr_error(
actual_attrs, expected_attrs, global_indexes or local_secondary_indexes actual_attrs, expected_attrs, global_indexes or local_secondary_indexes
) )
# get the stream specification # get the stream specification
@ -265,8 +278,10 @@ class DynamoHandler(BaseResponse):
) )
return dynamo_json_dump(table.describe()) return dynamo_json_dump(table.describe())
def _throw_attr_error(self, actual_attrs, expected_attrs, indexes): def _throw_attr_error(
def dump_list(list_): self, actual_attrs: List[str], expected_attrs: List[str], indexes: bool
) -> None:
def dump_list(list_: List[str]) -> str:
return str(list_).replace("'", "") return str(list_).replace("'", "")
err_head = "One or more parameter values were invalid: " err_head = "One or more parameter values were invalid: "
@ -315,28 +330,28 @@ class DynamoHandler(BaseResponse):
+ dump_list(actual_attrs) + dump_list(actual_attrs)
) )
def delete_table(self): def delete_table(self) -> str:
name = self.body["TableName"] name = self.body["TableName"]
table = self.dynamodb_backend.delete_table(name) table = self.dynamodb_backend.delete_table(name)
return dynamo_json_dump(table.describe()) return dynamo_json_dump(table.describe())
def describe_endpoints(self): def describe_endpoints(self) -> str:
response = {"Endpoints": self.dynamodb_backend.describe_endpoints()} response = {"Endpoints": self.dynamodb_backend.describe_endpoints()}
return dynamo_json_dump(response) return dynamo_json_dump(response)
def tag_resource(self): def tag_resource(self) -> str:
table_arn = self.body["ResourceArn"] table_arn = self.body["ResourceArn"]
tags = self.body["Tags"] tags = self.body["Tags"]
self.dynamodb_backend.tag_resource(table_arn, tags) self.dynamodb_backend.tag_resource(table_arn, tags)
return "" return ""
def untag_resource(self): def untag_resource(self) -> str:
table_arn = self.body["ResourceArn"] table_arn = self.body["ResourceArn"]
tags = self.body["TagKeys"] tags = self.body["TagKeys"]
self.dynamodb_backend.untag_resource(table_arn, tags) self.dynamodb_backend.untag_resource(table_arn, tags)
return "" return ""
def list_tags_of_resource(self): def list_tags_of_resource(self) -> str:
table_arn = self.body["ResourceArn"] table_arn = self.body["ResourceArn"]
all_tags = self.dynamodb_backend.list_tags_of_resource(table_arn) all_tags = self.dynamodb_backend.list_tags_of_resource(table_arn)
all_tag_keys = [tag["Key"] for tag in all_tags] all_tag_keys = [tag["Key"] for tag in all_tags]
@ -354,7 +369,7 @@ class DynamoHandler(BaseResponse):
return json.dumps({"Tags": tags_resp, "NextToken": next_marker}) return json.dumps({"Tags": tags_resp, "NextToken": next_marker})
return json.dumps({"Tags": tags_resp}) return json.dumps({"Tags": tags_resp})
def update_table(self): def update_table(self) -> str:
name = self.body["TableName"] name = self.body["TableName"]
attr_definitions = self.body.get("AttributeDefinitions", None) attr_definitions = self.body.get("AttributeDefinitions", None)
global_index = self.body.get("GlobalSecondaryIndexUpdates", None) global_index = self.body.get("GlobalSecondaryIndexUpdates", None)
@ -371,13 +386,13 @@ class DynamoHandler(BaseResponse):
) )
return dynamo_json_dump(table.describe()) return dynamo_json_dump(table.describe())
def describe_table(self): def describe_table(self) -> str:
name = self.body["TableName"] name = self.body["TableName"]
table = self.dynamodb_backend.describe_table(name) table = self.dynamodb_backend.describe_table(name)
return dynamo_json_dump(table) return dynamo_json_dump(table)
@include_consumed_capacity() @include_consumed_capacity()
def put_item(self): def put_item(self) -> str:
name = self.body["TableName"] name = self.body["TableName"]
item = self.body["Item"] item = self.body["Item"]
return_values = self.body.get("ReturnValues", "NONE") return_values = self.body.get("ReturnValues", "NONE")
@ -436,7 +451,7 @@ class DynamoHandler(BaseResponse):
item_dict.pop("Attributes", None) item_dict.pop("Attributes", None)
return dynamo_json_dump(item_dict) return dynamo_json_dump(item_dict)
def batch_write_item(self): def batch_write_item(self) -> str:
table_batches = self.body["RequestItems"] table_batches = self.body["RequestItems"]
put_requests = [] put_requests = []
delete_requests = [] delete_requests = []
@ -478,7 +493,7 @@ class DynamoHandler(BaseResponse):
return dynamo_json_dump(response) return dynamo_json_dump(response)
@include_consumed_capacity(0.5) @include_consumed_capacity(0.5)
def get_item(self): def get_item(self) -> str:
name = self.body["TableName"] name = self.body["TableName"]
self.dynamodb_backend.get_table(name) self.dynamodb_backend.get_table(name)
key = self.body["Key"] key = self.body["Key"]
@ -520,10 +535,14 @@ class DynamoHandler(BaseResponse):
# Item not found # Item not found
return dynamo_json_dump({}) return dynamo_json_dump({})
def batch_get_item(self): def batch_get_item(self) -> str:
table_batches = self.body["RequestItems"] table_batches = self.body["RequestItems"]
results = {"ConsumedCapacity": [], "Responses": {}, "UnprocessedKeys": {}} results: Dict[str, Any] = {
"ConsumedCapacity": [],
"Responses": {},
"UnprocessedKeys": {},
}
# Validation: Can only request up to 100 items at the same time # Validation: Can only request up to 100 items at the same time
# Scenario 1: We're requesting more than a 100 keys from a single table # Scenario 1: We're requesting more than a 100 keys from a single table
@ -582,7 +601,7 @@ class DynamoHandler(BaseResponse):
) )
return dynamo_json_dump(results) return dynamo_json_dump(results)
def _contains_duplicates(self, keys): def _contains_duplicates(self, keys: List[str]) -> bool:
unique_keys = [] unique_keys = []
for k in keys: for k in keys:
if k in unique_keys: if k in unique_keys:
@ -592,7 +611,7 @@ class DynamoHandler(BaseResponse):
return False return False
@include_consumed_capacity() @include_consumed_capacity()
def query(self): def query(self) -> str:
name = self.body["TableName"] name = self.body["TableName"]
key_condition_expression = self.body.get("KeyConditionExpression") key_condition_expression = self.body.get("KeyConditionExpression")
projection_expression = self.body.get("ProjectionExpression") projection_expression = self.body.get("ProjectionExpression")
@ -676,7 +695,7 @@ class DynamoHandler(BaseResponse):
**filter_kwargs, **filter_kwargs,
) )
result = { result: Dict[str, Any] = {
"Count": len(items), "Count": len(items),
"ScannedCount": scanned_count, "ScannedCount": scanned_count,
} }
@ -689,8 +708,10 @@ class DynamoHandler(BaseResponse):
return dynamo_json_dump(result) return dynamo_json_dump(result)
def _adjust_projection_expression(self, projection_expression, expr_attr_names): def _adjust_projection_expression(
def _adjust(expression): self, projection_expression: str, expr_attr_names: Dict[str, str]
) -> str:
def _adjust(expression: str) -> str:
return ( return (
expr_attr_names[expression] expr_attr_names[expression]
if expression in expr_attr_names if expression in expr_attr_names
@ -712,7 +733,7 @@ class DynamoHandler(BaseResponse):
return projection_expression return projection_expression
@include_consumed_capacity() @include_consumed_capacity()
def scan(self): def scan(self) -> str:
name = self.body["TableName"] name = self.body["TableName"]
filters = {} filters = {}
@ -760,7 +781,7 @@ class DynamoHandler(BaseResponse):
result["LastEvaluatedKey"] = last_evaluated_key result["LastEvaluatedKey"] = last_evaluated_key
return dynamo_json_dump(result) return dynamo_json_dump(result)
def delete_item(self): def delete_item(self) -> str:
name = self.body["TableName"] name = self.body["TableName"]
key = self.body["Key"] key = self.body["Key"]
return_values = self.body.get("ReturnValues", "NONE") return_values = self.body.get("ReturnValues", "NONE")
@ -795,7 +816,7 @@ class DynamoHandler(BaseResponse):
item_dict["ConsumedCapacityUnits"] = 0.5 item_dict["ConsumedCapacityUnits"] = 0.5
return dynamo_json_dump(item_dict) return dynamo_json_dump(item_dict)
def update_item(self): def update_item(self) -> str:
name = self.body["TableName"] name = self.body["TableName"]
key = self.body["Key"] key = self.body["Key"]
return_values = self.body.get("ReturnValues", "NONE") return_values = self.body.get("ReturnValues", "NONE")
@ -870,7 +891,7 @@ class DynamoHandler(BaseResponse):
) )
return dynamo_json_dump(item_dict) return dynamo_json_dump(item_dict)
def _build_updated_new_attributes(self, original, changed): def _build_updated_new_attributes(self, original: Any, changed: Any) -> Any:
if type(changed) != type(original): if type(changed) != type(original):
return changed return changed
else: else:
@ -895,7 +916,7 @@ class DynamoHandler(BaseResponse):
else: else:
return changed return changed
def describe_limits(self): def describe_limits(self) -> str:
return json.dumps( return json.dumps(
{ {
"AccountMaxReadCapacityUnits": 20000, "AccountMaxReadCapacityUnits": 20000,
@ -905,7 +926,7 @@ class DynamoHandler(BaseResponse):
} }
) )
def update_time_to_live(self): def update_time_to_live(self) -> str:
name = self.body["TableName"] name = self.body["TableName"]
ttl_spec = self.body["TimeToLiveSpecification"] ttl_spec = self.body["TimeToLiveSpecification"]
@ -913,16 +934,16 @@ class DynamoHandler(BaseResponse):
return json.dumps({"TimeToLiveSpecification": ttl_spec}) return json.dumps({"TimeToLiveSpecification": ttl_spec})
def describe_time_to_live(self): def describe_time_to_live(self) -> str:
name = self.body["TableName"] name = self.body["TableName"]
ttl_spec = self.dynamodb_backend.describe_time_to_live(name) ttl_spec = self.dynamodb_backend.describe_time_to_live(name)
return json.dumps({"TimeToLiveDescription": ttl_spec}) return json.dumps({"TimeToLiveDescription": ttl_spec})
def transact_get_items(self): def transact_get_items(self) -> str:
transact_items = self.body["TransactItems"] transact_items = self.body["TransactItems"]
responses = list() responses: List[Dict[str, Any]] = list()
if len(transact_items) > TRANSACTION_MAX_ITEMS: if len(transact_items) > TRANSACTION_MAX_ITEMS:
msg = "1 validation error detected: Value '[" msg = "1 validation error detected: Value '["
@ -945,7 +966,7 @@ class DynamoHandler(BaseResponse):
raise MockValidationException(msg) raise MockValidationException(msg)
ret_consumed_capacity = self.body.get("ReturnConsumedCapacity", "NONE") ret_consumed_capacity = self.body.get("ReturnConsumedCapacity", "NONE")
consumed_capacity = dict() consumed_capacity: Dict[str, Any] = dict()
for transact_item in transact_items: for transact_item in transact_items:
@ -957,7 +978,7 @@ class DynamoHandler(BaseResponse):
responses.append({}) responses.append({})
continue continue
item_describe = item.describe_attrs(False) item_describe = item.describe_attrs(attributes=None)
responses.append(item_describe) responses.append(item_describe)
table_capacity = consumed_capacity.get(table_name, {}) table_capacity = consumed_capacity.get(table_name, {})
@ -981,20 +1002,20 @@ class DynamoHandler(BaseResponse):
return dynamo_json_dump(result) return dynamo_json_dump(result)
def transact_write_items(self): def transact_write_items(self) -> str:
transact_items = self.body["TransactItems"] transact_items = self.body["TransactItems"]
self.dynamodb_backend.transact_write_items(transact_items) self.dynamodb_backend.transact_write_items(transact_items)
response = {"ConsumedCapacity": [], "ItemCollectionMetrics": {}} response: Dict[str, Any] = {"ConsumedCapacity": [], "ItemCollectionMetrics": {}}
return dynamo_json_dump(response) return dynamo_json_dump(response)
def describe_continuous_backups(self): def describe_continuous_backups(self) -> str:
name = self.body["TableName"] name = self.body["TableName"]
response = self.dynamodb_backend.describe_continuous_backups(name) response = self.dynamodb_backend.describe_continuous_backups(name)
return json.dumps({"ContinuousBackupsDescription": response}) return json.dumps({"ContinuousBackupsDescription": response})
def update_continuous_backups(self): def update_continuous_backups(self) -> str:
name = self.body["TableName"] name = self.body["TableName"]
point_in_time_spec = self.body["PointInTimeRecoverySpecification"] point_in_time_spec = self.body["PointInTimeRecoverySpecification"]
@ -1004,14 +1025,14 @@ class DynamoHandler(BaseResponse):
return json.dumps({"ContinuousBackupsDescription": response}) return json.dumps({"ContinuousBackupsDescription": response})
def list_backups(self): def list_backups(self) -> str:
body = self.body body = self.body
table_name = body.get("TableName") table_name = body.get("TableName")
backups = self.dynamodb_backend.list_backups(table_name) backups = self.dynamodb_backend.list_backups(table_name)
response = {"BackupSummaries": [backup.summary for backup in backups]} response = {"BackupSummaries": [backup.summary for backup in backups]}
return dynamo_json_dump(response) return dynamo_json_dump(response)
def create_backup(self): def create_backup(self) -> str:
body = self.body body = self.body
table_name = body.get("TableName") table_name = body.get("TableName")
backup_name = body.get("BackupName") backup_name = body.get("BackupName")
@ -1019,21 +1040,21 @@ class DynamoHandler(BaseResponse):
response = {"BackupDetails": backup.details} response = {"BackupDetails": backup.details}
return dynamo_json_dump(response) return dynamo_json_dump(response)
def delete_backup(self): def delete_backup(self) -> str:
body = self.body body = self.body
backup_arn = body.get("BackupArn") backup_arn = body.get("BackupArn")
backup = self.dynamodb_backend.delete_backup(backup_arn) backup = self.dynamodb_backend.delete_backup(backup_arn)
response = {"BackupDescription": backup.description} response = {"BackupDescription": backup.description}
return dynamo_json_dump(response) return dynamo_json_dump(response)
def describe_backup(self): def describe_backup(self) -> str:
body = self.body body = self.body
backup_arn = body.get("BackupArn") backup_arn = body.get("BackupArn")
backup = self.dynamodb_backend.describe_backup(backup_arn) backup = self.dynamodb_backend.describe_backup(backup_arn)
response = {"BackupDescription": backup.description} response = {"BackupDescription": backup.description}
return dynamo_json_dump(response) return dynamo_json_dump(response)
def restore_table_from_backup(self): def restore_table_from_backup(self) -> str:
body = self.body body = self.body
target_table_name = body.get("TargetTableName") target_table_name = body.get("TargetTableName")
backup_arn = body.get("BackupArn") backup_arn = body.get("BackupArn")
@ -1042,7 +1063,7 @@ class DynamoHandler(BaseResponse):
) )
return dynamo_json_dump(restored_table.describe()) return dynamo_json_dump(restored_table.describe())
def restore_table_to_point_in_time(self): def restore_table_to_point_in_time(self) -> str:
body = self.body body = self.body
target_table_name = body.get("TargetTableName") target_table_name = body.get("TargetTableName")
source_table_name = body.get("SourceTableName") source_table_name = body.get("SourceTableName")

View File

@ -4,8 +4,9 @@ import base64
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from moto.core import BaseBackend, BackendDict, BaseModel from moto.core import BaseBackend, BackendDict, BaseModel
from moto.dynamodb.models import dynamodb_backends, DynamoJsonEncoder, DynamoDBBackend from moto.dynamodb.models import dynamodb_backends, DynamoDBBackend
from moto.dynamodb.models import Table, StreamShard from moto.dynamodb.models.table import Table, StreamShard
from moto.dynamodb.models.utilities import DynamoJsonEncoder
class ShardIterator(BaseModel): class ShardIterator(BaseModel):
@ -86,7 +87,7 @@ class DynamoDBStreamsBackend(BaseBackend):
"StreamStatus": ( "StreamStatus": (
"ENABLED" if table.latest_stream_label else "DISABLED" "ENABLED" if table.latest_stream_label else "DISABLED"
), ),
"StreamViewType": table.stream_specification["StreamViewType"], "StreamViewType": table.stream_specification["StreamViewType"], # type: ignore[index]
"CreationRequestDateTime": table.stream_shard.created_on.isoformat(), # type: ignore[union-attr] "CreationRequestDateTime": table.stream_shard.created_on.isoformat(), # type: ignore[union-attr]
"TableName": table.name, "TableName": table.name,
"KeySchema": table.schema, "KeySchema": table.schema,

View File

@ -4,17 +4,17 @@ class GenericTokenizer:
The final character to be returned will be an empty string, to notify the caller that we've reached the end. The final character to be returned will be an empty string, to notify the caller that we've reached the end.
""" """
def __init__(self, expression): def __init__(self, expression: str):
self.expression = expression self.expression = expression
self.token_pos = 0 self.token_pos = 0
def __iter__(self): def __iter__(self) -> "GenericTokenizer":
return self return self
def is_eof(self): def is_eof(self) -> bool:
return self.peek() == "" return self.peek() == ""
def peek(self, length=1): def peek(self, length: int = 1) -> str:
""" """
Peek the next character without changing the position Peek the next character without changing the position
""" """
@ -23,7 +23,7 @@ class GenericTokenizer:
except IndexError: except IndexError:
return "" return ""
def __next__(self): def __next__(self) -> str:
""" """
Returns the next character, or an empty string if we've reached the end of the string. Returns the next character, or an empty string if we've reached the end of the string.
Calling this method again will result in a StopIterator Calling this method again will result in a StopIterator
@ -38,7 +38,7 @@ class GenericTokenizer:
return "" return ""
raise StopIteration raise StopIteration
def skip_characters(self, phrase, case_sensitive=False) -> None: def skip_characters(self, phrase: str, case_sensitive: bool = False) -> None:
""" """
Skip the characters in the supplied phrase. Skip the characters in the supplied phrase.
If any other character is encountered instead, this will fail. If any other character is encountered instead, this will fail.
@ -51,7 +51,7 @@ class GenericTokenizer:
assert self.expression[self.token_pos] in [ch.lower(), ch.upper()] assert self.expression[self.token_pos] in [ch.lower(), ch.upper()]
self.token_pos += 1 self.token_pos += 1
def skip_white_space(self): def skip_white_space(self) -> None:
""" """
Skip any whitespace characters that are coming up Skip any whitespace characters that are coming up
""" """

View File

@ -13,14 +13,19 @@ def str2bool(v):
return False return False
def load_resource(package: str, resource: str, as_json: bool = True) -> Dict[str, Any]: def load_resource(package: str, resource: str) -> Dict[str, Any]:
""" """
Open a file, and return the contents as JSON. Open a file, and return the contents as JSON.
Usage: Usage:
load_resource(__name__, "resources/file.json") load_resource(__name__, "resources/file.json")
""" """
resource = pkgutil.get_data(package, resource) resource = pkgutil.get_data(package, resource)
return json.loads(resource) if as_json else resource.decode("utf-8") return json.loads(resource)
def load_resource_as_str(package: str, resource: str) -> str:
resource = pkgutil.get_data(package, resource)
return resource.decode("utf-8")
def merge_multiple_dicts(*args: Any) -> Dict[str, any]: def merge_multiple_dicts(*args: Any) -> Dict[str, any]:

View File

@ -229,7 +229,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy] [mypy]
files= moto/a*,moto/b*,moto/c*,moto/databrew,moto/datapipeline,moto/datasync,moto/dax,moto/dms,moto/ds,moto/dynamodb_v20111205,moto/dynamodbstreams,moto/moto_api files= moto/a*,moto/b*,moto/c*,moto/d*,moto/moto_api
show_column_numbers=True show_column_numbers=True
show_error_codes = True show_error_codes = True
disable_error_code=abstract disable_error_code=abstract