Techdebt: MyPy DynamoDB (#5863)

This commit is contained in:
Bert Blommers 2023-01-22 21:06:41 -01:00 committed by GitHub
parent 90150d30c6
commit 00ad788975
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 1856 additions and 1640 deletions

View File

@ -1,12 +1,17 @@
import re
from collections import deque
from collections import namedtuple
from collections import deque, namedtuple
from typing import Any, Dict, List, Tuple, Deque, Optional, Iterable, Union
from moto.dynamodb.models.dynamo_type import Item
from moto.dynamodb.exceptions import ConditionAttributeIsReservedKeyword
from moto.dynamodb.parsing.reserved_keywords import ReservedKeywords
def get_filter_expression(expr, names, values):
def get_filter_expression(
expr: Optional[str],
names: Optional[Dict[str, str]],
values: Optional[Dict[str, str]],
) -> Union["Op", "Func"]:
"""
Parse a filter expression into an Op.
@ -18,7 +23,7 @@ def get_filter_expression(expr, names, values):
return parser.parse()
def get_expected(expected):
def get_expected(expected: Dict[str, Any]) -> Union["Op", "Func"]:
"""
Parse a filter expression into an Op.
@ -26,7 +31,7 @@ def get_expected(expected):
expr = 'Id > 5 AND attribute_exists(test) AND Id BETWEEN 5 AND 6 OR length < 6 AND contains(test, 1) AND 5 IN (4,5, 6) OR (Id < 5 AND 5 > Id)'
expr = 'Id > 5 AND Subs < 7'
"""
ops = {
ops: Dict[str, Any] = {
"EQ": OpEqual,
"NE": OpNotEqual,
"LE": OpLessThanOrEqual,
@ -43,7 +48,7 @@ def get_expected(expected):
}
# NOTE: Always uses ConditionalOperator=AND
conditions = []
conditions: List[Union["Op", "Func"]] = []
for key, cond in expected.items():
path = AttributePath([key])
if "Exists" in cond:
@ -66,26 +71,28 @@ def get_expected(expected):
for condition in conditions[1:]:
output = ConditionalOp(output, condition)
else:
return OpDefault(None, None)
return OpDefault(None, None) # type: ignore[arg-type]
return output
class Op(object):
class Op:
"""
Base class for a FilterExpression operator
"""
OP = ""
def __init__(self, lhs, rhs):
def __init__(
self, lhs: Union["Func", "Op", "Operand"], rhs: Union["Func", "Op", "Operand"]
):
self.lhs = lhs
self.rhs = rhs
def expr(self, item):
def expr(self, item: Optional[Item]) -> bool: # type: ignore
raise NotImplementedError(f"Expr not defined for {type(self)}")
def __repr__(self):
def __repr__(self) -> str:
return f"({self.lhs} {self.OP} {self.rhs})"
@ -125,7 +132,7 @@ COMPARISON_FUNCS = {
}
def get_comparison_func(range_comparison):
def get_comparison_func(range_comparison: str) -> Any:
return COMPARISON_FUNCS.get(range_comparison)
@ -136,15 +143,15 @@ class RecursionStopIteration(StopIteration):
class ConditionExpressionParser:
def __init__(
self,
condition_expression,
expression_attribute_names,
expression_attribute_values,
condition_expression: Optional[str],
expression_attribute_names: Optional[Dict[str, str]],
expression_attribute_values: Optional[Dict[str, str]],
):
self.condition_expression = condition_expression
self.expression_attribute_names = expression_attribute_names
self.expression_attribute_values = expression_attribute_values
def parse(self):
def parse(self) -> Union[Op, "Func"]:
"""Returns a syntax tree for the expression.
The tree, and all of the nodes in the tree are a tuple of
@ -181,7 +188,7 @@ class ConditionExpressionParser:
"""
if not self.condition_expression:
return OpDefault(None, None)
return OpDefault(None, None) # type: ignore[arg-type]
nodes = self._lex_condition_expression()
nodes = self._parse_paths(nodes)
# NOTE: The docs say that functions should be parsed after
@ -242,12 +249,12 @@ class ConditionExpressionParser:
Node = namedtuple("Node", ["nonterminal", "kind", "text", "value", "children"])
@classmethod
def raise_exception_if_keyword(cls, attribute):
def raise_exception_if_keyword(cls, attribute: str) -> None:
if attribute.upper() in ReservedKeywords.get_reserved_keywords():
raise ConditionAttributeIsReservedKeyword(attribute)
def _lex_condition_expression(self):
nodes = deque()
def _lex_condition_expression(self) -> Deque[Node]:
nodes: Deque[ConditionExpressionParser.Node] = deque()
remaining_expression = self.condition_expression
while remaining_expression:
node, remaining_expression = self._lex_one_node(remaining_expression)
@ -256,7 +263,7 @@ class ConditionExpressionParser:
nodes.append(node)
return nodes
def _lex_one_node(self, remaining_expression):
def _lex_one_node(self, remaining_expression: str) -> Tuple[Node, str]:
# TODO: Handle indexing like [1]
attribute_regex = r"(:|#)?[A-z0-9\-_]+"
patterns = [
@ -305,8 +312,8 @@ class ConditionExpressionParser:
return node, remaining_expression
def _parse_paths(self, nodes):
output = deque()
def _parse_paths(self, nodes: Deque[Node]) -> Deque[Node]:
output: Deque[ConditionExpressionParser.Node] = deque()
while nodes:
node = nodes.popleft()
@ -339,7 +346,7 @@ class ConditionExpressionParser:
output.append(node)
return output
def _parse_path_element(self, name):
def _parse_path_element(self, name: str) -> Node:
reserved = {
"and": self.Nonterminal.AND,
"or": self.Nonterminal.OR,
@ -416,11 +423,11 @@ class ConditionExpressionParser:
children=[],
)
def _lookup_expression_attribute_value(self, name):
return self.expression_attribute_values[name]
def _lookup_expression_attribute_value(self, name: str) -> str:
return self.expression_attribute_values[name] # type: ignore[index]
def _lookup_expression_attribute_name(self, name):
return self.expression_attribute_names[name]
def _lookup_expression_attribute_name(self, name: str) -> str:
return self.expression_attribute_names[name] # type: ignore[index]
# NOTE: The following constructions are ordered from high precedence to low precedence
# according to
@ -464,7 +471,7 @@ class ConditionExpressionParser:
# contains (path, operand)
# size (path)
def _matches(self, nodes, production):
def _matches(self, nodes: Deque[Node], production: List[str]) -> bool:
"""Check if the nodes start with the given production.
Parameters
@ -484,9 +491,9 @@ class ConditionExpressionParser:
return False
return True
def _apply_comparator(self, nodes):
def _apply_comparator(self, nodes: Deque[Node]) -> Deque[Node]:
"""Apply condition := operand comparator operand."""
output = deque()
output: Deque[ConditionExpressionParser.Node] = deque()
while nodes:
if self._matches(nodes, ["*", "COMPARATOR"]):
@ -511,9 +518,9 @@ class ConditionExpressionParser:
output.append(nodes.popleft())
return output
def _apply_in(self, nodes):
def _apply_in(self, nodes: Deque[Node]) -> Deque[Node]:
"""Apply condition := operand IN ( operand , ... )."""
output = deque()
output: Deque[ConditionExpressionParser.Node] = deque()
while nodes:
if self._matches(nodes, ["*", "IN"]):
self._assert(
@ -553,9 +560,9 @@ class ConditionExpressionParser:
output.append(nodes.popleft())
return output
def _apply_between(self, nodes):
def _apply_between(self, nodes: Deque[Node]) -> Deque[Node]:
"""Apply condition := operand BETWEEN operand AND operand."""
output = deque()
output: Deque[ConditionExpressionParser.Node] = deque()
while nodes:
if self._matches(nodes, ["*", "BETWEEN"]):
self._assert(
@ -584,9 +591,9 @@ class ConditionExpressionParser:
output.append(nodes.popleft())
return output
def _apply_functions(self, nodes):
def _apply_functions(self, nodes: Deque[Node]) -> Deque[Node]:
"""Apply condition := function_name (operand , ...)."""
output = deque()
output: Deque[ConditionExpressionParser.Node] = deque()
either_kind = {self.Kind.PATH, self.Kind.EXPRESSION_ATTRIBUTE_VALUE}
expected_argument_kind_map = {
"attribute_exists": [{self.Kind.PATH}],
@ -656,9 +663,11 @@ class ConditionExpressionParser:
output.append(nodes.popleft())
return output
def _apply_parens_and_booleans(self, nodes, left_paren=None):
def _apply_parens_and_booleans(
self, nodes: Deque[Node], left_paren: Any = None
) -> Deque[Node]:
"""Apply condition := ( condition ) and booleans."""
output = deque()
output: Deque[ConditionExpressionParser.Node] = deque()
while nodes:
if self._matches(nodes, ["LEFT_PAREN"]):
parsed = self._apply_parens_and_booleans(
@ -696,7 +705,7 @@ class ConditionExpressionParser:
self._assert(left_paren is None, "Unmatched ( at", list(output))
return self._apply_booleans(output)
def _apply_booleans(self, nodes):
def _apply_booleans(self, nodes: Deque[Node]) -> Deque[Node]:
"""Apply and, or, and not constructions."""
nodes = self._apply_not(nodes)
nodes = self._apply_and(nodes)
@ -710,9 +719,9 @@ class ConditionExpressionParser:
)
return nodes
def _apply_not(self, nodes):
def _apply_not(self, nodes: Deque[Node]) -> Deque[Node]:
"""Apply condition := NOT condition."""
output = deque()
output: Deque[ConditionExpressionParser.Node] = deque()
while nodes:
if self._matches(nodes, ["NOT"]):
self._assert(
@ -736,9 +745,9 @@ class ConditionExpressionParser:
return output
def _apply_and(self, nodes):
def _apply_and(self, nodes: Deque[Node]) -> Deque[Node]:
"""Apply condition := condition AND condition."""
output = deque()
output: Deque[ConditionExpressionParser.Node] = deque()
while nodes:
if self._matches(nodes, ["*", "AND"]):
self._assert(
@ -764,9 +773,9 @@ class ConditionExpressionParser:
return output
def _apply_or(self, nodes):
def _apply_or(self, nodes: Deque[Node]) -> Deque[Node]:
"""Apply condition := condition OR condition."""
output = deque()
output: Deque[ConditionExpressionParser.Node] = deque()
while nodes:
if self._matches(nodes, ["*", "OR"]):
self._assert(
@ -792,7 +801,7 @@ class ConditionExpressionParser:
return output
def _make_operand(self, node):
def _make_operand(self, node: Node) -> "Operand":
if node.kind == self.Kind.PATH:
return AttributePath([child.value for child in node.children])
elif node.kind == self.Kind.EXPRESSION_ATTRIBUTE_VALUE:
@ -807,7 +816,7 @@ class ConditionExpressionParser:
else: # pragma: no cover
raise ValueError(f"Unknown operand: {node}")
def _make_op_condition(self, node):
def _make_op_condition(self, node: Node) -> Union["Func", Op]:
if node.kind == self.Kind.OR:
lhs, rhs = node.children
return OpOr(self._make_op_condition(lhs), self._make_op_condition(rhs))
@ -847,21 +856,21 @@ class ConditionExpressionParser:
else: # pragma: no cover
raise ValueError(f"Unknown expression node kind {node.kind}")
def _assert(self, condition, message, nodes):
def _assert(self, condition: bool, message: str, nodes: Iterable[Node]) -> None:
if not condition:
raise ValueError(message + " " + " ".join([t.text for t in nodes]))
class Operand(object):
def expr(self, item):
class Operand:
def expr(self, item: Optional[Item]) -> Any: # type: ignore
raise NotImplementedError
def get_type(self, item):
def get_type(self, item: Optional[Item]) -> Optional[str]: # type: ignore
raise NotImplementedError
class AttributePath(Operand):
def __init__(self, path):
def __init__(self, path: List[Any]):
"""Initialize the AttributePath.
Parameters
@ -872,7 +881,7 @@ class AttributePath(Operand):
assert len(path) >= 1
self.path = path
def _get_attr(self, item):
def _get_attr(self, item: Optional[Item]) -> Any:
if item is None:
return None
@ -888,26 +897,26 @@ class AttributePath(Operand):
return attr
def expr(self, item):
def expr(self, item: Optional[Item]) -> Any:
attr = self._get_attr(item)
if attr is None:
return None
else:
return attr.cast_value
def get_type(self, item):
def get_type(self, item: Optional[Item]) -> Optional[str]:
attr = self._get_attr(item)
if attr is None:
return None
else:
return attr.type
def __repr__(self):
def __repr__(self) -> str:
return ".".join(self.path)
class AttributeValue(Operand):
def __init__(self, value):
def __init__(self, value: Dict[str, Any]):
"""Initialize the AttributePath.
Parameters
@ -919,7 +928,7 @@ class AttributeValue(Operand):
self.type = list(value.keys())[0]
self.value = value[self.type]
def expr(self, item):
def expr(self, item: Optional[Item]) -> Any:
# TODO: Reuse DynamoType code
if self.type == "N":
try:
@ -939,17 +948,17 @@ class AttributeValue(Operand):
return self.value
return self.value
def get_type(self, item):
def get_type(self, item: Optional[Item]) -> str:
return self.type
def __repr__(self):
def __repr__(self) -> str:
return repr(self.value)
class OpDefault(Op):
OP = "NONE"
def expr(self, item):
def expr(self, item: Optional[Item]) -> bool:
"""If no condition is specified, always True."""
return True
@ -957,21 +966,21 @@ class OpDefault(Op):
class OpNot(Op):
OP = "NOT"
def __init__(self, lhs):
super().__init__(lhs, None)
def __init__(self, lhs: Union["Func", Op]):
super().__init__(lhs, None) # type: ignore[arg-type]
def expr(self, item):
def expr(self, item: Optional[Item]) -> bool:
lhs = self.lhs.expr(item)
return not lhs
def __str__(self):
def __str__(self) -> str:
return f"({self.OP} {self.lhs})"
class OpAnd(Op):
OP = "AND"
def expr(self, item):
def expr(self, item: Optional[Item]) -> bool:
lhs = self.lhs.expr(item)
return lhs and self.rhs.expr(item)
@ -979,7 +988,7 @@ class OpAnd(Op):
class OpLessThan(Op):
OP = "<"
def expr(self, item):
def expr(self, item: Optional[Item]) -> bool:
lhs = self.lhs.expr(item)
rhs = self.rhs.expr(item)
# In python3 None is not a valid comparator when using < or > so must be handled specially
@ -992,7 +1001,7 @@ class OpLessThan(Op):
class OpGreaterThan(Op):
OP = ">"
def expr(self, item):
def expr(self, item: Optional[Item]) -> bool:
lhs = self.lhs.expr(item)
rhs = self.rhs.expr(item)
# In python3 None is not a valid comparator when using < or > so must be handled specially
@ -1005,7 +1014,7 @@ class OpGreaterThan(Op):
class OpEqual(Op):
OP = "="
def expr(self, item):
def expr(self, item: Optional[Item]) -> bool:
lhs = self.lhs.expr(item)
rhs = self.rhs.expr(item)
return lhs == rhs
@ -1014,7 +1023,7 @@ class OpEqual(Op):
class OpNotEqual(Op):
OP = "<>"
def expr(self, item):
def expr(self, item: Optional[Item]) -> bool:
lhs = self.lhs.expr(item)
rhs = self.rhs.expr(item)
return lhs != rhs
@ -1023,7 +1032,7 @@ class OpNotEqual(Op):
class OpLessThanOrEqual(Op):
OP = "<="
def expr(self, item):
def expr(self, item: Optional[Item]) -> bool:
lhs = self.lhs.expr(item)
rhs = self.rhs.expr(item)
# In python3 None is not a valid comparator when using < or > so must be handled specially
@ -1036,7 +1045,7 @@ class OpLessThanOrEqual(Op):
class OpGreaterThanOrEqual(Op):
OP = ">="
def expr(self, item):
def expr(self, item: Optional[Item]) -> bool:
lhs = self.lhs.expr(item)
rhs = self.rhs.expr(item)
# In python3 None is not a valid comparator when using < or > so must be handled specially
@ -1049,64 +1058,64 @@ class OpGreaterThanOrEqual(Op):
class OpOr(Op):
OP = "OR"
def expr(self, item):
def expr(self, item: Optional[Item]) -> bool:
lhs = self.lhs.expr(item)
return lhs or self.rhs.expr(item)
class Func(object):
class Func:
"""
Base class for a FilterExpression function
"""
FUNC = "Unknown"
def __init__(self, *arguments):
def __init__(self, *arguments: Any):
self.arguments = arguments
def expr(self, item):
def expr(self, item: Optional[Item]) -> bool:
raise NotImplementedError
def __repr__(self):
def __repr__(self) -> str:
return f"{self.FUNC}({' '.join([repr(arg) for arg in self.arguments])})"
class FuncAttrExists(Func):
FUNC = "attribute_exists"
def __init__(self, attribute):
def __init__(self, attribute: Operand):
self.attr = attribute
super().__init__(attribute)
def expr(self, item):
def expr(self, item: Optional[Item]) -> bool:
return self.attr.get_type(item) is not None
def FuncAttrNotExists(attribute):
def FuncAttrNotExists(attribute: Operand) -> Any:
return OpNot(FuncAttrExists(attribute))
class FuncAttrType(Func):
FUNC = "attribute_type"
def __init__(self, attribute, _type):
def __init__(self, attribute: Operand, _type: Func):
self.attr = attribute
self.type = _type
super().__init__(attribute, _type)
def expr(self, item):
return self.attr.get_type(item) == self.type.expr(item)
def expr(self, item: Optional[Item]) -> bool:
return self.attr.get_type(item) == self.type.expr(item) # type: ignore[comparison-overlap]
class FuncBeginsWith(Func):
FUNC = "begins_with"
def __init__(self, attribute, substr):
def __init__(self, attribute: Operand, substr: Operand):
self.attr = attribute
self.substr = substr
super().__init__(attribute, substr)
def expr(self, item):
def expr(self, item: Optional[Item]) -> bool:
if self.attr.get_type(item) != "S":
return False
if self.substr.get_type(item) != "S":
@ -1117,12 +1126,12 @@ class FuncBeginsWith(Func):
class FuncContains(Func):
FUNC = "contains"
def __init__(self, attribute, operand):
def __init__(self, attribute: Operand, operand: Operand):
self.attr = attribute
self.operand = operand
super().__init__(attribute, operand)
def expr(self, item):
def expr(self, item: Optional[Item]) -> bool:
if self.attr.get_type(item) in ("S", "SS", "NS", "BS", "L"):
try:
return self.operand.expr(item) in self.attr.expr(item)
@ -1131,18 +1140,18 @@ class FuncContains(Func):
return False
def FuncNotContains(attribute, operand):
def FuncNotContains(attribute: Operand, operand: Operand) -> OpNot:
return OpNot(FuncContains(attribute, operand))
class FuncSize(Func):
FUNC = "size"
def __init__(self, attribute):
def __init__(self, attribute: Operand):
self.attr = attribute
super().__init__(attribute)
def expr(self, item):
def expr(self, item: Optional[Item]) -> int: # type: ignore[override]
if self.attr.get_type(item) is None:
raise ValueError(f"Invalid attribute name {self.attr}")
@ -1154,13 +1163,13 @@ class FuncSize(Func):
class FuncBetween(Func):
FUNC = "BETWEEN"
def __init__(self, attribute, start, end):
def __init__(self, attribute: Operand, start: Operand, end: Operand):
self.attr = attribute
self.start = start
self.end = end
super().__init__(attribute, start, end)
def expr(self, item):
def expr(self, item: Optional[Item]) -> bool:
# In python3 None is not a valid comparator when using < or > so must be handled specially
start = self.start.expr(item)
attr = self.attr.expr(item)
@ -1183,12 +1192,12 @@ class FuncBetween(Func):
class FuncIn(Func):
FUNC = "IN"
def __init__(self, attribute, *possible_values):
def __init__(self, attribute: Operand, *possible_values: Any):
self.attr = attribute
self.possible_values = possible_values
super().__init__(attribute, *possible_values)
def expr(self, item):
def expr(self, item: Optional[Item]) -> bool:
for possible_value in self.possible_values:
if self.attr.expr(item) == possible_value.expr(item):
return True
@ -1205,7 +1214,7 @@ COMPARATOR_CLASS = {
"<>": OpNotEqual,
}
FUNC_CLASS = {
FUNC_CLASS: Dict[str, Any] = {
"attribute_exists": FuncAttrExists,
"attribute_not_exists": FuncAttrNotExists,
"attribute_type": FuncAttrType,

View File

@ -1,4 +1,5 @@
import json
from typing import Any, List, Optional
from moto.core.exceptions import JsonRESTError
from moto.dynamodb.limits import HASH_KEY_MAX_LENGTH, RANGE_KEY_MAX_LENGTH
@ -10,7 +11,7 @@ class DynamodbException(JsonRESTError):
class MockValidationException(DynamodbException):
error_type = "com.amazonaws.dynamodb.v20111205#ValidationException"
def __init__(self, message):
def __init__(self, message: str):
super().__init__(MockValidationException.error_type, message=message)
self.exception_msg = message
@ -24,14 +25,14 @@ class InvalidUpdateExpressionInvalidDocumentPath(MockValidationException):
"The document path provided in the update expression is invalid for update"
)
def __init__(self):
def __init__(self) -> None:
super().__init__(self.invalid_update_expression_msg)
class InvalidUpdateExpression(MockValidationException):
invalid_update_expr_msg = "Invalid UpdateExpression: {update_expression_error}"
def __init__(self, update_expression_error):
def __init__(self, update_expression_error: str):
self.update_expression_error = update_expression_error
super().__init__(
self.invalid_update_expr_msg.format(
@ -45,7 +46,7 @@ class InvalidConditionExpression(MockValidationException):
"Invalid ConditionExpression: {condition_expression_error}"
)
def __init__(self, condition_expression_error):
def __init__(self, condition_expression_error: str):
self.condition_expression_error = condition_expression_error
super().__init__(
self.invalid_condition_expr_msg.format(
@ -59,7 +60,7 @@ class ConditionAttributeIsReservedKeyword(InvalidConditionExpression):
"Attribute name is a reserved keyword; reserved keyword: {keyword}"
)
def __init__(self, keyword):
def __init__(self, keyword: str):
self.keyword = keyword
super().__init__(self.attribute_is_keyword_msg.format(keyword=keyword))
@ -69,7 +70,7 @@ class AttributeDoesNotExist(MockValidationException):
"The provided expression refers to an attribute that does not exist in the item"
)
def __init__(self):
def __init__(self) -> None:
super().__init__(self.attr_does_not_exist_msg)
@ -78,14 +79,14 @@ class ProvidedKeyDoesNotExist(MockValidationException):
"The provided key element does not match the schema"
)
def __init__(self):
def __init__(self) -> None:
super().__init__(self.provided_key_does_not_exist_msg)
class ExpressionAttributeNameNotDefined(InvalidUpdateExpression):
name_not_defined_msg = "An expression attribute name used in the document path is not defined; attribute name: {n}"
def __init__(self, attribute_name):
def __init__(self, attribute_name: str):
self.not_defined_attribute_name = attribute_name
super().__init__(self.name_not_defined_msg.format(n=attribute_name))
@ -95,7 +96,7 @@ class AttributeIsReservedKeyword(InvalidUpdateExpression):
"Attribute name is a reserved keyword; reserved keyword: {keyword}"
)
def __init__(self, keyword):
def __init__(self, keyword: str):
self.keyword = keyword
super().__init__(self.attribute_is_keyword_msg.format(keyword=keyword))
@ -103,7 +104,7 @@ class AttributeIsReservedKeyword(InvalidUpdateExpression):
class ExpressionAttributeValueNotDefined(InvalidUpdateExpression):
attr_value_not_defined_msg = "An expression attribute value used in expression is not defined; attribute value: {attribute_value}"
def __init__(self, attribute_value):
def __init__(self, attribute_value: str):
self.attribute_value = attribute_value
super().__init__(
self.attr_value_not_defined_msg.format(attribute_value=attribute_value)
@ -113,7 +114,7 @@ class ExpressionAttributeValueNotDefined(InvalidUpdateExpression):
class UpdateExprSyntaxError(InvalidUpdateExpression):
update_expr_syntax_error_msg = "Syntax error; {error_detail}"
def __init__(self, error_detail):
def __init__(self, error_detail: str):
self.error_detail = error_detail
super().__init__(
self.update_expr_syntax_error_msg.format(error_detail=error_detail)
@ -123,7 +124,7 @@ class UpdateExprSyntaxError(InvalidUpdateExpression):
class InvalidTokenException(UpdateExprSyntaxError):
token_detail_msg = 'token: "{token}", near: "{near}"'
def __init__(self, token, near):
def __init__(self, token: str, near: str):
self.token = token
self.near = near
super().__init__(self.token_detail_msg.format(token=token, near=near))
@ -134,7 +135,7 @@ class InvalidExpressionAttributeNameKey(MockValidationException):
'ExpressionAttributeNames contains invalid key: Syntax error; key: "{key}"'
)
def __init__(self, key):
def __init__(self, key: str):
self.key = key
super().__init__(self.invalid_expr_attr_name_msg.format(key=key))
@ -142,7 +143,7 @@ class InvalidExpressionAttributeNameKey(MockValidationException):
class ItemSizeTooLarge(MockValidationException):
item_size_too_large_msg = "Item size has exceeded the maximum allowed size"
def __init__(self):
def __init__(self) -> None:
super().__init__(self.item_size_too_large_msg)
@ -151,7 +152,7 @@ class ItemSizeToUpdateTooLarge(MockValidationException):
"Item size to update has exceeded the maximum allowed size"
)
def __init__(self):
def __init__(self) -> None:
super().__init__(self.item_size_to_update_too_large_msg)
@ -159,21 +160,21 @@ class HashKeyTooLong(MockValidationException):
# deliberately no space between of and {lim}
key_too_large_msg = f"One or more parameter values were invalid: Size of hashkey has exceeded the maximum size limit of{HASH_KEY_MAX_LENGTH} bytes"
def __init__(self):
def __init__(self) -> None:
super().__init__(self.key_too_large_msg)
class RangeKeyTooLong(MockValidationException):
key_too_large_msg = f"One or more parameter values were invalid: Aggregated size of all range keys has exceeded the size limit of {RANGE_KEY_MAX_LENGTH} bytes"
def __init__(self):
def __init__(self) -> None:
super().__init__(self.key_too_large_msg)
class IncorrectOperandType(InvalidUpdateExpression):
inv_operand_msg = "Incorrect operand type for operator or function; operator or function: {f}, operand type: {t}"
def __init__(self, operator_or_function, operand_type):
def __init__(self, operator_or_function: str, operand_type: str):
self.operator_or_function = operator_or_function
self.operand_type = operand_type
super().__init__(
@ -184,14 +185,14 @@ class IncorrectOperandType(InvalidUpdateExpression):
class IncorrectDataType(MockValidationException):
inc_data_type_msg = "An operand in the update expression has an incorrect data type"
def __init__(self):
def __init__(self) -> None:
super().__init__(self.inc_data_type_msg)
class ConditionalCheckFailed(DynamodbException):
error_type = "com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException"
def __init__(self, msg=None):
def __init__(self, msg: Optional[str] = None):
super().__init__(
ConditionalCheckFailed.error_type, msg or "The conditional request failed"
)
@ -201,7 +202,7 @@ class TransactionCanceledException(DynamodbException):
cancel_reason_msg = "Transaction cancelled, please refer cancellation reasons for specific reasons [{}]"
error_type = "com.amazonaws.dynamodb.v20120810#TransactionCanceledException"
def __init__(self, errors):
def __init__(self, errors: List[Any]):
msg = self.cancel_reason_msg.format(
", ".join([str(code) for code, _, _ in errors])
)
@ -224,7 +225,7 @@ class TransactionCanceledException(DynamodbException):
class MultipleTransactionsException(MockValidationException):
msg = "Transaction request cannot include multiple operations on one item"
def __init__(self):
def __init__(self) -> None:
super().__init__(self.msg)
@ -234,7 +235,7 @@ class TooManyTransactionsException(MockValidationException):
"Member must have length less than or equal to 100."
)
def __init__(self):
def __init__(self) -> None:
super().__init__(self.msg)
@ -243,26 +244,28 @@ class EmptyKeyAttributeException(MockValidationException):
# AWS has a different message for empty index keys
empty_index_msg = "One or more parameter values are not valid. The update expression attempted to update a secondary index key to a value that is not supported. The AttributeValue for a key attribute cannot contain an empty string value."
def __init__(self, key_in_index=False):
def __init__(self, key_in_index: bool = False):
super().__init__(self.empty_index_msg if key_in_index else self.empty_str_msg)
class UpdateHashRangeKeyException(MockValidationException):
msg = "One or more parameter values were invalid: Cannot update attribute {}. This attribute is part of the key"
def __init__(self, key_name):
def __init__(self, key_name: str):
super().__init__(self.msg.format(key_name))
class InvalidAttributeTypeError(MockValidationException):
msg = "One or more parameter values were invalid: Type mismatch for key {} expected: {} actual: {}"
def __init__(self, name, expected_type, actual_type):
def __init__(
self, name: Optional[str], expected_type: Optional[str], actual_type: str
):
super().__init__(self.msg.format(name, expected_type, actual_type))
class DuplicateUpdateExpression(InvalidUpdateExpression):
def __init__(self, names):
def __init__(self, names: List[str]):
super().__init__(
f"Two document paths overlap with each other; must remove or rewrite one of these paths; path one: [{names[0]}], path two: [{names[1]}]"
)
@ -271,54 +274,54 @@ class DuplicateUpdateExpression(InvalidUpdateExpression):
class TooManyAddClauses(InvalidUpdateExpression):
msg = 'The "ADD" section can only be used once in an update expression;'
def __init__(self):
def __init__(self) -> None:
super().__init__(self.msg)
class ResourceNotFoundException(JsonRESTError):
def __init__(self, msg=None):
def __init__(self, msg: Optional[str] = None):
err = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException"
super().__init__(err, msg or "Requested resource not found")
class TableNotFoundException(JsonRESTError):
def __init__(self, name):
def __init__(self, name: str):
err = "com.amazonaws.dynamodb.v20111205#TableNotFoundException"
super().__init__(err, f"Table not found: {name}")
class SourceTableNotFoundException(JsonRESTError):
def __init__(self, source_table_name):
def __init__(self, source_table_name: str):
er = "com.amazonaws.dynamodb.v20111205#SourceTableNotFoundException"
super().__init__(er, f"Source table not found: {source_table_name}")
class BackupNotFoundException(JsonRESTError):
def __init__(self, backup_arn):
def __init__(self, backup_arn: str):
er = "com.amazonaws.dynamodb.v20111205#BackupNotFoundException"
super().__init__(er, f"Backup not found: {backup_arn}")
class TableAlreadyExistsException(JsonRESTError):
def __init__(self, target_table_name):
def __init__(self, target_table_name: str):
er = "com.amazonaws.dynamodb.v20111205#TableAlreadyExistsException"
super().__init__(er, f"Table already exists: {target_table_name}")
class ResourceInUseException(JsonRESTError):
def __init__(self):
def __init__(self) -> None:
er = "com.amazonaws.dynamodb.v20111205#ResourceInUseException"
super().__init__(er, "Resource in use")
class StreamAlreadyEnabledException(JsonRESTError):
def __init__(self):
def __init__(self) -> None:
er = "com.amazonaws.dynamodb.v20111205#ResourceInUseException"
super().__init__(er, "Cannot enable stream")
class InvalidConversion(JsonRESTError):
def __init__(self):
def __init__(self) -> None:
er = "SerializationException"
super().__init__(er, "NUMBER_VALUE cannot be converted to String")
@ -328,10 +331,10 @@ class TransactWriteSingleOpException(MockValidationException):
"TransactItems can only contain one of Check, Put, Update or Delete"
)
def __init__(self):
def __init__(self) -> None:
super().__init__(self.there_can_be_only_one)
class SerializationException(DynamodbException):
def __init__(self, msg):
def __init__(self, msg: str):
super().__init__(error_type="SerializationException", message=msg)

File diff suppressed because it is too large Load Diff

View File

@ -1,9 +1,16 @@
from moto.dynamodb.comparisons import get_comparison_func
from moto.dynamodb.exceptions import IncorrectDataType
import decimal
from typing import Any, Dict, List, Union, Optional
from moto.core import BaseModel
from moto.dynamodb.exceptions import (
IncorrectDataType,
EmptyKeyAttributeException,
ItemSizeTooLarge,
)
from moto.dynamodb.models.utilities import bytesize
class DDBType(object):
class DDBType:
"""
Official documentation at https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_AttributeValue.html
"""
@ -20,7 +27,7 @@ class DDBType(object):
NULL = "NULL"
class DDBTypeConversion(object):
class DDBTypeConversion:
_human_type_mapping = {
val: key.replace("_", " ")
for key, val in DDBType.__dict__.items()
@ -28,13 +35,13 @@ class DDBTypeConversion(object):
}
@classmethod
def get_human_type(cls, abbreviated_type):
def get_human_type(cls, abbreviated_type: str) -> str:
"""
Args:
abbreviated_type(str): An attribute of DDBType
Returns:
str: The human readable form of the DDBType.
str: The human-readable form of the DDBType.
"""
return cls._human_type_mapping.get(abbreviated_type, abbreviated_type)
@ -44,19 +51,19 @@ class DynamoType(object):
http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DataModel.html#DataModelDataTypes
"""
def __init__(self, type_as_dict):
def __init__(self, type_as_dict: Union["DynamoType", Dict[str, Any]]):
if type(type_as_dict) == DynamoType:
self.type = type_as_dict.type
self.value = type_as_dict.value
self.type: str = type_as_dict.type
self.value: Any = type_as_dict.value
else:
self.type = list(type_as_dict)[0]
self.value = list(type_as_dict.values())[0]
self.type = list(type_as_dict)[0] # type: ignore[arg-type]
self.value = list(type_as_dict.values())[0] # type: ignore[union-attr]
if self.is_list():
self.value = [DynamoType(val) for val in self.value]
elif self.is_map():
self.value = dict((k, DynamoType(v)) for k, v in self.value.items())
def filter(self, projection_expressions):
def filter(self, projection_expressions: str) -> None:
nested_projections = [
expr[0 : expr.index(".")] for expr in projection_expressions if "." in expr
]
@ -78,31 +85,31 @@ class DynamoType(object):
for expr in expressions_to_delete:
self.value.pop(expr)
def __hash__(self):
def __hash__(self) -> int:
return hash((self.type, self.value))
def __eq__(self, other):
def __eq__(self, other: "DynamoType") -> bool: # type: ignore[override]
return self.type == other.type and self.value == other.value
def __ne__(self, other):
def __ne__(self, other: "DynamoType") -> bool: # type: ignore[override]
return self.type != other.type or self.value != other.value
def __lt__(self, other):
def __lt__(self, other: "DynamoType") -> bool:
return self.cast_value < other.cast_value
def __le__(self, other):
def __le__(self, other: "DynamoType") -> bool:
return self.cast_value <= other.cast_value
def __gt__(self, other):
def __gt__(self, other: "DynamoType") -> bool:
return self.cast_value > other.cast_value
def __ge__(self, other):
def __ge__(self, other: "DynamoType") -> bool:
return self.cast_value >= other.cast_value
def __repr__(self):
def __repr__(self) -> str:
return f"DynamoType: {self.to_json()}"
def __add__(self, other):
def __add__(self, other: "DynamoType") -> "DynamoType":
if self.type != other.type:
raise TypeError("Different types of operandi is not allowed.")
if self.is_number():
@ -112,7 +119,7 @@ class DynamoType(object):
else:
raise IncorrectDataType()
def __sub__(self, other):
def __sub__(self, other: "DynamoType") -> "DynamoType":
if self.type != other.type:
raise TypeError("Different types of operandi is not allowed.")
if self.type == DDBType.NUMBER:
@ -122,7 +129,7 @@ class DynamoType(object):
else:
raise TypeError("Sum only supported for Numbers.")
def __getitem__(self, item):
def __getitem__(self, item: "DynamoType") -> "DynamoType":
if isinstance(item, str):
# If our DynamoType is a map it should be subscriptable with a key
if self.type == DDBType.MAP:
@ -135,7 +142,7 @@ class DynamoType(object):
f"This DynamoType {self.type} is not subscriptable by a {type(item)}"
)
def __setitem__(self, key, value):
def __setitem__(self, key: Any, value: Any) -> None:
if isinstance(key, int):
if self.is_list():
if key >= len(self.value):
@ -150,7 +157,7 @@ class DynamoType(object):
raise NotImplementedError(f"No set_item for {type(key)}")
@property
def cast_value(self):
def cast_value(self) -> Any: # type: ignore[misc]
if self.is_number():
try:
return int(self.value)
@ -166,7 +173,7 @@ class DynamoType(object):
else:
return self.value
def child_attr(self, key):
def child_attr(self, key: Union[int, str, None]) -> Optional["DynamoType"]:
"""
Get Map or List children by key. str for Map, int for List.
@ -183,7 +190,7 @@ class DynamoType(object):
return None
def size(self):
def size(self) -> int:
if self.is_number():
value_size = len(str(self.value))
elif self.is_set():
@ -201,34 +208,204 @@ class DynamoType(object):
value_size = bytesize(self.value)
return value_size
def to_json(self):
def to_json(self) -> Dict[str, Any]:
return {self.type: self.value}
def compare(self, range_comparison, range_objs):
def compare(self, range_comparison: str, range_objs: List[Any]) -> bool:
"""
Compares this type against comparison filters
"""
from moto.dynamodb.comparisons import get_comparison_func
range_values = [obj.cast_value for obj in range_objs]
comparison_func = get_comparison_func(range_comparison)
return comparison_func(self.cast_value, *range_values)
def is_number(self):
def is_number(self) -> bool:
return self.type == DDBType.NUMBER
def is_set(self):
def is_set(self) -> bool:
return self.type in (DDBType.STRING_SET, DDBType.NUMBER_SET, DDBType.BINARY_SET)
def is_list(self):
def is_list(self) -> bool:
return self.type == DDBType.LIST
def is_map(self):
def is_map(self) -> bool:
return self.type == DDBType.MAP
def same_type(self, other):
def same_type(self, other: "DynamoType") -> bool:
return self.type == other.type
def pop(self, key, *args, **kwargs):
def pop(self, key: str, *args: Any, **kwargs: Any) -> None:
if self.is_map() or self.is_list():
self.value.pop(key, *args, **kwargs)
else:
raise TypeError(f"pop not supported for DynamoType {self.type}")
# https://github.com/getmoto/moto/issues/1874
# Ensure that the total size of an item does not exceed 400kb
class LimitedSizeDict(Dict[str, Any]):
def __init__(self, *args: Any, **kwargs: Any):
self.update(*args, **kwargs)
def __setitem__(self, key: str, value: Any) -> None:
current_item_size = sum(
[
item.size() if type(item) == DynamoType else bytesize(str(item))
for item in (list(self.keys()) + list(self.values()))
]
)
new_item_size = bytesize(key) + (
value.size() if type(value) == DynamoType else bytesize(str(value))
)
# Official limit is set to 400000 (400KB)
# Manual testing confirms that the actual limit is between 409 and 410KB
# We'll set the limit to something in between to be safe
if (current_item_size + new_item_size) > 405000:
raise ItemSizeTooLarge
super().__setitem__(key, value)
class Item(BaseModel):
def __init__(
self,
hash_key: DynamoType,
range_key: Optional[DynamoType],
attrs: Dict[str, Any],
):
self.hash_key = hash_key
self.range_key = range_key
self.attrs = LimitedSizeDict()
for key, value in attrs.items():
self.attrs[key] = DynamoType(value)
def __eq__(self, other: "Item") -> bool: # type: ignore[override]
return all(
[
self.hash_key == other.hash_key,
self.range_key == other.range_key, # type: ignore[operator]
self.attrs == other.attrs,
]
)
def __repr__(self) -> str:
return f"Item: {self.to_json()}"
def size(self) -> int:
return sum(bytesize(key) + value.size() for key, value in self.attrs.items())
def to_json(self) -> Dict[str, Any]:
attributes = {}
for attribute_key, attribute in self.attrs.items():
attributes[attribute_key] = {attribute.type: attribute.value}
return {"Attributes": attributes}
def describe_attrs(
self, attributes: Optional[Dict[str, Any]]
) -> Dict[str, Dict[str, Any]]:
if attributes:
included = {}
for key, value in self.attrs.items():
if key in attributes:
included[key] = value
else:
included = self.attrs
return {"Item": included}
def validate_no_empty_key_values(
self, attribute_updates: Dict[str, Any], key_attributes: List[str]
) -> None:
for attribute_name, update_action in attribute_updates.items():
action = update_action.get("Action") or "PUT" # PUT is default
if action == "DELETE":
continue
new_value = next(iter(update_action["Value"].values()))
if action == "PUT" and new_value == "" and attribute_name in key_attributes:
raise EmptyKeyAttributeException
def update_with_attribute_updates(self, attribute_updates: Dict[str, Any]) -> None:
for attribute_name, update_action in attribute_updates.items():
# Use default Action value, if no explicit Action is passed.
# Default value is 'Put', according to
# Boto3 DynamoDB.Client.update_item documentation.
action = update_action.get("Action", "PUT")
if action == "DELETE" and "Value" not in update_action:
if attribute_name in self.attrs:
del self.attrs[attribute_name]
continue
new_value = list(update_action["Value"].values())[0]
if action == "PUT":
# TODO deal with other types
if set(update_action["Value"].keys()) == set(["SS"]):
self.attrs[attribute_name] = DynamoType({"SS": new_value})
elif isinstance(new_value, list):
self.attrs[attribute_name] = DynamoType({"L": new_value})
elif isinstance(new_value, dict):
self.attrs[attribute_name] = DynamoType({"M": new_value})
elif set(update_action["Value"].keys()) == set(["N"]):
self.attrs[attribute_name] = DynamoType({"N": new_value})
elif set(update_action["Value"].keys()) == set(["NULL"]):
if attribute_name in self.attrs:
del self.attrs[attribute_name]
else:
self.attrs[attribute_name] = DynamoType({"S": new_value})
elif action == "ADD":
if set(update_action["Value"].keys()) == set(["N"]):
existing = self.attrs.get(attribute_name, DynamoType({"N": "0"}))
self.attrs[attribute_name] = DynamoType(
{
"N": str(
decimal.Decimal(existing.value)
+ decimal.Decimal(new_value)
)
}
)
elif set(update_action["Value"].keys()) == set(["SS"]):
existing = self.attrs.get(attribute_name, DynamoType({"SS": {}}))
new_set = set(existing.value).union(set(new_value))
self.attrs[attribute_name] = DynamoType({"SS": list(new_set)})
elif set(update_action["Value"].keys()) == {"L"}:
existing = self.attrs.get(attribute_name, DynamoType({"L": []}))
new_list = existing.value + new_value
self.attrs[attribute_name] = DynamoType({"L": new_list})
else:
# TODO: implement other data types
raise NotImplementedError(
"ADD not supported for %s"
% ", ".join(update_action["Value"].keys())
)
elif action == "DELETE":
if set(update_action["Value"].keys()) == set(["SS"]):
existing = self.attrs.get(attribute_name, DynamoType({"SS": {}}))
new_set = set(existing.value).difference(set(new_value))
self.attrs[attribute_name] = DynamoType({"SS": list(new_set)})
else:
raise NotImplementedError(
"ADD not supported for %s"
% ", ".join(update_action["Value"].keys())
)
else:
raise NotImplementedError(
f"{action} action not support for update_with_attribute_updates"
)
# Filter using projection_expression
# Ensure a deep copy is used to filter, otherwise actual data will be removed
def filter(self, projection_expression: str) -> None:
expressions = [x.strip() for x in projection_expression.split(",")]
top_level_expressions = [
expr[0 : expr.index(".")] for expr in expressions if "." in expr
]
for attr in list(self.attrs):
if attr not in expressions and attr not in top_level_expressions:
self.attrs.pop(attr)
if attr in top_level_expressions:
relevant_expressions = [
expr[len(attr + ".") :]
for expr in expressions
if expr.startswith(attr + ".")
]
self.attrs[attr].filter(relevant_expressions)

File diff suppressed because it is too large Load Diff

View File

@ -1,2 +1,16 @@
def bytesize(val):
import json
from typing import Any
class DynamoJsonEncoder(json.JSONEncoder):
def default(self, o: Any) -> Any:
if hasattr(o, "to_json"):
return o.to_json()
def dynamo_json_dump(dynamo_object: Any) -> str:
return json.dumps(dynamo_object, cls=DynamoJsonEncoder)
def bytesize(val: str) -> int:
return len(val.encode("utf-8"))

View File

@ -1,3 +1,4 @@
# type: ignore
import abc
from abc import abstractmethod
from collections import deque
@ -21,7 +22,7 @@ class Node(metaclass=abc.ABCMeta):
def set_parent(self, parent_node):
self.parent = parent_node
def validate(self):
def validate(self) -> None:
if self.type == "UpdateExpression":
nr_of_clauses = len(self.find_clauses([UpdateExpressionAddClause]))
if nr_of_clauses > 1:

View File

@ -1,13 +1,19 @@
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Union, Type
from moto.dynamodb.exceptions import (
IncorrectOperandType,
IncorrectDataType,
ProvidedKeyDoesNotExist,
)
from moto.dynamodb.models import DynamoType
from moto.dynamodb.models.dynamo_type import DDBTypeConversion, DDBType
from moto.dynamodb.parsing.ast_nodes import (
from moto.dynamodb.models.dynamo_type import (
DDBTypeConversion,
DDBType,
DynamoType,
Item,
)
from moto.dynamodb.parsing.ast_nodes import ( # type: ignore
Node,
UpdateExpressionSetAction,
UpdateExpressionDeleteAction,
UpdateExpressionRemoveAction,
@ -21,16 +27,18 @@ from moto.dynamodb.parsing.ast_nodes import (
from moto.dynamodb.parsing.validators import ExpressionPathResolver
class NodeExecutor(object):
def __init__(self, ast_node, expression_attribute_names):
class NodeExecutor:
def __init__(self, ast_node: Node, expression_attribute_names: Dict[str, str]):
self.node = ast_node
self.expression_attribute_names = expression_attribute_names
@abstractmethod
def execute(self, item):
def execute(self, item: Item) -> None:
pass
def get_item_part_for_path_nodes(self, item, path_nodes):
def get_item_part_for_path_nodes(
self, item: Item, path_nodes: List[Node]
) -> Union[DynamoType, Dict[str, Any]]:
"""
For a list of path nodes travers the item by following the path_nodes
Args:
@ -43,11 +51,13 @@ class NodeExecutor(object):
if len(path_nodes) == 0:
return item.attrs
else:
return ExpressionPathResolver(
return ExpressionPathResolver( # type: ignore
self.expression_attribute_names
).resolve_expression_path_nodes_to_dynamo_type(item, path_nodes)
def get_item_before_end_of_path(self, item):
def get_item_before_end_of_path(
self, item: Item
) -> Union[DynamoType, Dict[str, Any]]:
"""
Get the part ot the item where the item will perform the action. For most actions this should be the parent. As
that element will need to be modified by the action.
@ -61,7 +71,7 @@ class NodeExecutor(object):
item, self.get_path_expression_nodes()[:-1]
)
def get_item_at_end_of_path(self, item):
def get_item_at_end_of_path(self, item: Item) -> Union[DynamoType, Dict[str, Any]]:
"""
For a DELETE the path points at the stringset so we need to evaluate the full path.
Args:
@ -76,15 +86,15 @@ class NodeExecutor(object):
# that element will need to be modified by the action.
get_item_part_in_which_to_perform_action = get_item_before_end_of_path
def get_path_expression_nodes(self):
def get_path_expression_nodes(self) -> List[Node]:
update_expression_path = self.node.children[0]
assert isinstance(update_expression_path, UpdateExpressionPath)
return update_expression_path.children
def get_element_to_action(self):
def get_element_to_action(self) -> Node:
return self.get_path_expression_nodes()[-1]
def get_action_value(self):
def get_action_value(self) -> DynamoType:
"""
Returns:
@ -98,7 +108,7 @@ class NodeExecutor(object):
class SetExecutor(NodeExecutor):
def execute(self, item):
def execute(self, item: Item) -> None:
self.set(
item_part_to_modify_with_set=self.get_item_part_in_which_to_perform_action(
item
@ -109,13 +119,13 @@ class SetExecutor(NodeExecutor):
)
@classmethod
def set(
def set( # type: ignore[misc]
cls,
item_part_to_modify_with_set,
element_to_set,
value_to_set,
expression_attribute_names,
):
item_part_to_modify_with_set: Union[DynamoType, Dict[str, Any]],
element_to_set: Any,
value_to_set: Any,
expression_attribute_names: Dict[str, str],
) -> None:
if isinstance(element_to_set, ExpressionAttribute):
attribute_name = element_to_set.get_attribute_name()
item_part_to_modify_with_set[attribute_name] = value_to_set
@ -136,7 +146,7 @@ class SetExecutor(NodeExecutor):
class DeleteExecutor(NodeExecutor):
operator = "operator: DELETE"
def execute(self, item):
def execute(self, item: Item) -> None:
string_set_to_remove = self.get_action_value()
assert isinstance(string_set_to_remove, DynamoType)
if not string_set_to_remove.is_set():
@ -176,11 +186,11 @@ class DeleteExecutor(NodeExecutor):
f"Moto does not support deleting {type(element)} yet"
)
container = self.get_item_before_end_of_path(item)
del container[attribute_name]
del container[attribute_name] # type: ignore[union-attr]
class RemoveExecutor(NodeExecutor):
def execute(self, item):
def execute(self, item: Item) -> None:
element_to_remove = self.get_element_to_action()
if isinstance(element_to_remove, ExpressionAttribute):
attribute_name = element_to_remove.get_attribute_name()
@ -208,7 +218,7 @@ class RemoveExecutor(NodeExecutor):
class AddExecutor(NodeExecutor):
def execute(self, item):
def execute(self, item: Item) -> None:
value_to_add = self.get_action_value()
if isinstance(value_to_add, DynamoType):
if value_to_add.is_set():
@ -253,7 +263,7 @@ class AddExecutor(NodeExecutor):
raise IncorrectDataType()
class UpdateExpressionExecutor(object):
class UpdateExpressionExecutor:
execution_map = {
UpdateExpressionSetAction: SetExecutor,
UpdateExpressionAddAction: AddExecutor,
@ -261,12 +271,14 @@ class UpdateExpressionExecutor(object):
UpdateExpressionDeleteAction: DeleteExecutor,
}
def __init__(self, update_ast, item, expression_attribute_names):
def __init__(
self, update_ast: Node, item: Item, expression_attribute_names: Dict[str, str]
):
self.update_ast = update_ast
self.item = item
self.expression_attribute_names = expression_attribute_names
def execute(self, node=None):
def execute(self, node: Optional[Node] = None) -> None:
"""
As explained in moto.dynamodb.parsing.expressions.NestableExpressionParserMixin._create_node the order of nodes
in the AST can be translated of the order of statements in the expression. As such we can start at the root node
@ -286,12 +298,12 @@ class UpdateExpressionExecutor(object):
node_executor = self.get_specific_execution(node)
if node_executor is None:
for node in node.children:
self.execute(node)
for n in node.children:
self.execute(n)
else:
node_executor(node, self.expression_attribute_names).execute(self.item)
def get_specific_execution(self, node):
def get_specific_execution(self, node: Node) -> Optional[Type[NodeExecutor]]:
for node_class in self.execution_map:
if isinstance(node, node_class):
return self.execution_map[node_class]

View File

@ -1,3 +1,4 @@
# type: ignore
import logging
from abc import abstractmethod
import abc
@ -35,7 +36,7 @@ from moto.dynamodb.parsing.tokens import Token, ExpressionTokenizer
logger = logging.getLogger(__name__)
class NestableExpressionParserMixin(object):
class NestableExpressionParserMixin:
"""
For nodes that can be nested in themselves (recursive). Take for example UpdateExpression's grammar:

View File

@ -1,4 +1,5 @@
from enum import Enum
from typing import Any, List, Dict, Tuple, Optional
from moto.dynamodb.exceptions import MockValidationException
from moto.utilities.tokenizer import GenericTokenizer
@ -11,17 +12,17 @@ class EXPRESSION_STAGES(Enum):
EOF = "EOF"
def get_key(schema, key_type):
def get_key(schema: List[Dict[str, str]], key_type: str) -> Optional[str]:
keys = [key for key in schema if key["KeyType"] == key_type]
return keys[0]["AttributeName"] if keys else None
def parse_expression(
key_condition_expression,
expression_attribute_values,
expression_attribute_names,
schema,
):
key_condition_expression: str,
expression_attribute_values: Dict[str, str],
expression_attribute_names: Dict[str, str],
schema: List[Dict[str, str]],
) -> Tuple[Dict[str, Any], Optional[str], List[Dict[str, Any]]]:
"""
Parse a KeyConditionExpression using the provided expression attribute names/values
@ -31,11 +32,11 @@ def parse_expression(
schema: [{'AttributeName': 'hashkey', 'KeyType': 'HASH'}, {"AttributeName": "sortkey", "KeyType": "RANGE"}]
"""
current_stage: EXPRESSION_STAGES = None
current_stage: Optional[EXPRESSION_STAGES] = None
current_phrase = ""
key_name = comparison = None
key_name = comparison = ""
key_values = []
results = []
results: List[Tuple[str, str, Any]] = []
tokenizer = GenericTokenizer(key_condition_expression)
for crnt_char in tokenizer:
if crnt_char == " ":
@ -188,7 +189,9 @@ def parse_expression(
# Validate that the schema-keys are encountered in our query
def validate_schema(results, schema):
def validate_schema(
results: Any, schema: List[Dict[str, str]]
) -> Tuple[Dict[str, Any], Optional[str], List[Dict[str, Any]]]:
index_hash_key = get_key(schema, "HASH")
comparison, hash_value = next(
(
@ -219,4 +222,4 @@ def validate_schema(results, schema):
f"Query condition missed key schema element: {index_range_key}"
)
return hash_value, range_comparison, range_values
return hash_value, range_comparison, range_values # type: ignore[return-value]

View File

@ -1,27 +1,26 @@
from moto.utilities.utils import load_resource
from typing import List, Optional
from moto.utilities.utils import load_resource_as_str
class ReservedKeywords(list):
class ReservedKeywords:
"""
DynamoDB has an extensive list of keywords. Keywords are considered when validating the expression Tree.
Not earlier since an update expression like "SET path = VALUE 1" fails with:
'Invalid UpdateExpression: Syntax error; token: "1", near: "VALUE 1"'
"""
KEYWORDS = None
KEYWORDS: Optional[List[str]] = None
@classmethod
def get_reserved_keywords(cls):
def get_reserved_keywords(cls) -> List[str]:
if cls.KEYWORDS is None:
cls.KEYWORDS = cls._get_reserved_keywords()
return cls.KEYWORDS
@classmethod
def _get_reserved_keywords(cls):
def _get_reserved_keywords(cls) -> List[str]:
"""
Get a list of reserved keywords of DynamoDB
"""
reserved_keywords = load_resource(
__name__, "reserved_keywords.txt", as_json=False
)
reserved_keywords = load_resource_as_str(__name__, "reserved_keywords.txt")
return reserved_keywords.split()

View File

@ -1,4 +1,5 @@
import re
from typing import List, Union
from moto.dynamodb.exceptions import (
InvalidTokenException,
@ -6,7 +7,7 @@ from moto.dynamodb.exceptions import (
)
class Token(object):
class Token:
_TOKEN_INSTANCE = None
MINUS_SIGN = "-"
PLUS_SIGN = "+"
@ -53,7 +54,7 @@ class Token(object):
NUMBER: "Number",
}
def __init__(self, token_type, value):
def __init__(self, token_type: Union[int, str], value: str):
assert (
token_type in self.SPECIAL_CHARACTERS
or token_type in self.PLACEHOLDER_NAMES
@ -61,13 +62,13 @@ class Token(object):
self.type = token_type
self.value = value
def __repr__(self):
def __repr__(self) -> str:
if isinstance(self.type, int):
return f'Token("{self.PLACEHOLDER_NAMES[self.type]}", "{self.value}")'
else:
return f'Token("{self.type}", "{self.value}")'
def __eq__(self, other):
def __eq__(self, other: "Token") -> bool: # type: ignore[override]
return self.type == other.type and self.value == other.value
@ -94,22 +95,22 @@ class ExpressionTokenizer(object):
"""
@classmethod
def is_simple_token_character(cls, character):
def is_simple_token_character(cls, character: str) -> bool:
return character.isalnum() or character in ("_", ":", "#")
@classmethod
def is_possible_token_boundary(cls, character):
def is_possible_token_boundary(cls, character: str) -> bool:
return (
character in Token.SPECIAL_CHARACTERS
or not cls.is_simple_token_character(character)
)
@classmethod
def is_expression_attribute(cls, input_string):
def is_expression_attribute(cls, input_string: str) -> bool:
return re.compile("^[a-zA-Z0-9][a-zA-Z0-9_]*$").match(input_string) is not None
@classmethod
def is_expression_attribute_name(cls, input_string):
def is_expression_attribute_name(cls, input_string: str) -> bool:
"""
https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.ExpressionAttributeNames.html
An expression attribute name must begin with a pound sign (#), and be followed by one or more alphanumeric
@ -120,10 +121,10 @@ class ExpressionTokenizer(object):
)
@classmethod
def is_expression_attribute_value(cls, input_string):
def is_expression_attribute_value(cls, input_string: str) -> bool:
return re.compile("^:[a-zA-Z0-9_]*$").match(input_string) is not None
def raise_unexpected_token(self):
def raise_unexpected_token(self) -> None:
"""If during parsing an unexpected token is encountered"""
if len(self.token_list) == 0:
near = ""
@ -140,29 +141,29 @@ class ExpressionTokenizer(object):
problematic_token = self.staged_characters[0]
raise InvalidTokenException(problematic_token, near + self.staged_characters)
def __init__(self, input_expression_str):
def __init__(self, input_expression_str: str):
self.input_expression_str = input_expression_str
self.token_list = []
self.token_list: List[Token] = []
self.staged_characters = ""
@classmethod
def make_list(cls, input_expression_str):
def make_list(cls, input_expression_str: str) -> List[Token]:
assert isinstance(input_expression_str, str)
return ExpressionTokenizer(input_expression_str)._make_list()
def add_token(self, token_type, token_value):
def add_token(self, token_type: Union[int, str], token_value: str) -> None:
self.token_list.append(Token(token_type, token_value))
def add_token_from_stage(self, token_type):
def add_token_from_stage(self, token_type: int) -> None:
self.add_token(token_type, self.staged_characters)
self.staged_characters = ""
@classmethod
def is_numeric(cls, input_str):
def is_numeric(cls, input_str: str) -> bool:
return re.compile("[0-9]+").match(input_str) is not None
def process_staged_characters(self):
def process_staged_characters(self) -> None:
if len(self.staged_characters) == 0:
return
if self.staged_characters.startswith("#"):
@ -179,7 +180,7 @@ class ExpressionTokenizer(object):
else:
self.raise_unexpected_token()
def _make_list(self):
def _make_list(self) -> List[Token]:
"""
Just go through characters if a character is not a token boundary stage it for adding it as a grouped token
later if it is a tokenboundary process staged characters and then process the token boundary as well.

View File

@ -3,6 +3,7 @@ See docstring class Validator below for more details on validation
"""
from abc import abstractmethod
from copy import deepcopy
from typing import Any, Callable, Dict, List, Type, Union
from moto.dynamodb.exceptions import (
AttributeIsReservedKeyword,
@ -15,9 +16,12 @@ from moto.dynamodb.exceptions import (
EmptyKeyAttributeException,
UpdateHashRangeKeyException,
)
from moto.dynamodb.models import DynamoType
from moto.dynamodb.parsing.ast_nodes import (
from moto.dynamodb.models.dynamo_type import DynamoType, Item
from moto.dynamodb.models.table import Table
from moto.dynamodb.parsing.ast_nodes import ( # type: ignore
Node,
ExpressionAttribute,
UpdateExpressionClause,
UpdateExpressionPath,
UpdateExpressionSetAction,
UpdateExpressionAddAction,
@ -37,16 +41,23 @@ from moto.dynamodb.parsing.ast_nodes import (
from moto.dynamodb.parsing.reserved_keywords import ReservedKeywords
class ExpressionAttributeValueProcessor(DepthFirstTraverser):
def __init__(self, expression_attribute_values):
class ExpressionAttributeValueProcessor(DepthFirstTraverser): # type: ignore[misc]
def __init__(self, expression_attribute_values: Dict[str, Dict[str, Any]]):
self.expression_attribute_values = expression_attribute_values
def _processing_map(self):
def _processing_map(
self,
) -> Dict[
Type[ExpressionAttributeValue],
Callable[[ExpressionAttributeValue], DDBTypedValue],
]:
return {
ExpressionAttributeValue: self.replace_expression_attribute_value_with_value
}
def replace_expression_attribute_value_with_value(self, node):
def replace_expression_attribute_value_with_value(
self, node: ExpressionAttributeValue
) -> DDBTypedValue:
"""A node representing an Expression Attribute Value. Resolve and replace value"""
assert isinstance(node, ExpressionAttributeValue)
attribute_value_name = node.get_value_name()
@ -59,20 +70,24 @@ class ExpressionAttributeValueProcessor(DepthFirstTraverser):
return DDBTypedValue(DynamoType(target))
class ExpressionPathResolver(object):
def __init__(self, expression_attribute_names):
class ExpressionPathResolver:
def __init__(self, expression_attribute_names: Dict[str, str]):
self.expression_attribute_names = expression_attribute_names
@classmethod
def raise_exception_if_keyword(cls, attribute):
def raise_exception_if_keyword(cls, attribute: Any) -> None: # type: ignore[misc]
if attribute.upper() in ReservedKeywords.get_reserved_keywords():
raise AttributeIsReservedKeyword(attribute)
def resolve_expression_path(self, item, update_expression_path):
def resolve_expression_path(
self, item: Item, update_expression_path: UpdateExpressionPath
) -> Union[NoneExistingPath, DDBTypedValue]:
assert isinstance(update_expression_path, UpdateExpressionPath)
return self.resolve_expression_path_nodes(item, update_expression_path.children)
def resolve_expression_path_nodes(self, item, update_expression_path_nodes):
def resolve_expression_path_nodes(
self, item: Item, update_expression_path_nodes: List[Node]
) -> Union[NoneExistingPath, DDBTypedValue]:
target = item.attrs
for child in update_expression_path_nodes:
@ -100,7 +115,7 @@ class ExpressionPathResolver(object):
continue
elif isinstance(child, ExpressionSelector):
index = child.get_index()
if target.is_list():
if target.is_list(): # type: ignore
try:
target = target[index]
except IndexError:
@ -116,8 +131,8 @@ class ExpressionPathResolver(object):
return DDBTypedValue(target)
def resolve_expression_path_nodes_to_dynamo_type(
self, item, update_expression_path_nodes
):
self, item: Item, update_expression_path_nodes: List[Node]
) -> Any:
node = self.resolve_expression_path_nodes(item, update_expression_path_nodes)
if isinstance(node, NoneExistingPath):
raise ProvidedKeyDoesNotExist()
@ -125,19 +140,30 @@ class ExpressionPathResolver(object):
return node.get_value()
class ExpressionAttributeResolvingProcessor(DepthFirstTraverser):
def _processing_map(self):
class ExpressionAttributeResolvingProcessor(DepthFirstTraverser): # type: ignore[misc]
def _processing_map(
self,
) -> Dict[Type[UpdateExpressionClause], Callable[[DDBTypedValue], DDBTypedValue]]:
return {
UpdateExpressionSetAction: self.disable_resolving,
UpdateExpressionPath: self.process_expression_path_node,
}
def __init__(self, expression_attribute_names, item):
def __init__(self, expression_attribute_names: Dict[str, str], item: Item):
self.expression_attribute_names = expression_attribute_names
self.item = item
self.resolving = False
def pre_processing_of_child(self, parent_node, child_id):
def pre_processing_of_child(
self,
parent_node: Union[
UpdateExpressionSetAction,
UpdateExpressionRemoveAction,
UpdateExpressionDeleteAction,
UpdateExpressionAddAction,
],
child_id: int,
) -> None:
"""
We have to enable resolving if we are processing a child of UpdateExpressionSetAction that is not first.
Because first argument is path to be set, 2nd argument would be the value.
@ -156,11 +182,11 @@ class ExpressionAttributeResolvingProcessor(DepthFirstTraverser):
else:
self.resolving = True
def disable_resolving(self, node=None):
def disable_resolving(self, node: DDBTypedValue) -> DDBTypedValue:
self.resolving = False
return node
def process_expression_path_node(self, node):
def process_expression_path_node(self, node: DDBTypedValue) -> DDBTypedValue:
"""Resolve ExpressionAttribute if not part of a path and resolving is enabled."""
if self.resolving:
return self.resolve_expression_path(node)
@ -175,13 +201,15 @@ class ExpressionAttributeResolvingProcessor(DepthFirstTraverser):
return node
def resolve_expression_path(self, node):
def resolve_expression_path(
self, node: DDBTypedValue
) -> Union[NoneExistingPath, DDBTypedValue]:
return ExpressionPathResolver(
self.expression_attribute_names
).resolve_expression_path(self.item, node)
class UpdateExpressionFunctionEvaluator(DepthFirstTraverser):
class UpdateExpressionFunctionEvaluator(DepthFirstTraverser): # type: ignore[misc]
"""
At time of writing there are only 2 functions for DDB UpdateExpressions. They both are specific to the SET
expression as per the official AWS docs:
@ -189,10 +217,15 @@ class UpdateExpressionFunctionEvaluator(DepthFirstTraverser):
Expressions.UpdateExpressions.html#Expressions.UpdateExpressions.SET
"""
def _processing_map(self):
def _processing_map(
self,
) -> Dict[
Type[UpdateExpressionFunction],
Callable[[UpdateExpressionFunction], DDBTypedValue],
]:
return {UpdateExpressionFunction: self.process_function}
def process_function(self, node):
def process_function(self, node: UpdateExpressionFunction) -> DDBTypedValue:
assert isinstance(node, UpdateExpressionFunction)
function_name = node.get_function_name()
first_arg = node.get_nth_argument(1)
@ -217,7 +250,7 @@ class UpdateExpressionFunctionEvaluator(DepthFirstTraverser):
raise NotImplementedError(f"Unsupported function for moto {function_name}")
@classmethod
def get_list_from_ddb_typed_value(cls, node, function_name):
def get_list_from_ddb_typed_value(cls, node: DDBTypedValue, function_name: str) -> DynamoType: # type: ignore[misc]
assert isinstance(node, DDBTypedValue)
dynamo_value = node.get_value()
assert isinstance(dynamo_value, DynamoType)
@ -226,23 +259,25 @@ class UpdateExpressionFunctionEvaluator(DepthFirstTraverser):
return dynamo_value
class NoneExistingPathChecker(DepthFirstTraverser):
class NoneExistingPathChecker(DepthFirstTraverser): # type: ignore[misc]
"""
Pass through the AST and make sure there are no none-existing paths.
"""
def _processing_map(self):
def _processing_map(self) -> Dict[Type[NoneExistingPath], Callable[[Node], None]]:
return {NoneExistingPath: self.raise_none_existing_path}
def raise_none_existing_path(self, node):
def raise_none_existing_path(self, node: Node) -> None:
raise AttributeDoesNotExist
class ExecuteOperations(DepthFirstTraverser):
def _processing_map(self):
class ExecuteOperations(DepthFirstTraverser): # type: ignore[misc]
def _processing_map(
self,
) -> Dict[Type[UpdateExpressionValue], Callable[[Node], DDBTypedValue]]:
return {UpdateExpressionValue: self.process_update_expression_value}
def process_update_expression_value(self, node):
def process_update_expression_value(self, node: Node) -> DDBTypedValue:
"""
If an UpdateExpressionValue only has a single child the node will be replaced with the childe.
Otherwise it has 3 children and the middle one is an ExpressionValueOperator which details how to combine them
@ -273,14 +308,14 @@ class ExecuteOperations(DepthFirstTraverser):
)
@classmethod
def get_dynamo_value_from_ddb_typed_value(cls, node):
def get_dynamo_value_from_ddb_typed_value(cls, node: DDBTypedValue) -> DynamoType: # type: ignore[misc]
assert isinstance(node, DDBTypedValue)
dynamo_value = node.get_value()
assert isinstance(dynamo_value, DynamoType)
return dynamo_value
@classmethod
def get_sum(cls, left_operand, right_operand):
def get_sum(cls, left_operand: DynamoType, right_operand: DynamoType) -> DDBTypedValue: # type: ignore[misc]
"""
Args:
left_operand(DynamoType):
@ -295,7 +330,7 @@ class ExecuteOperations(DepthFirstTraverser):
raise IncorrectOperandType("+", left_operand.type)
@classmethod
def get_subtraction(cls, left_operand, right_operand):
def get_subtraction(cls, left_operand: DynamoType, right_operand: DynamoType) -> DDBTypedValue: # type: ignore[misc]
"""
Args:
left_operand(DynamoType):
@ -310,14 +345,21 @@ class ExecuteOperations(DepthFirstTraverser):
raise IncorrectOperandType("-", left_operand.type)
class EmptyStringKeyValueValidator(DepthFirstTraverser):
def __init__(self, key_attributes):
class EmptyStringKeyValueValidator(DepthFirstTraverser): # type: ignore[misc]
def __init__(self, key_attributes: List[str]):
self.key_attributes = key_attributes
def _processing_map(self):
def _processing_map(
self,
) -> Dict[
Type[UpdateExpressionSetAction],
Callable[[UpdateExpressionSetAction], UpdateExpressionSetAction],
]:
return {UpdateExpressionSetAction: self.check_for_empty_string_key_value}
def check_for_empty_string_key_value(self, node):
def check_for_empty_string_key_value(
self, node: UpdateExpressionSetAction
) -> UpdateExpressionSetAction:
"""A node representing a SET action. Check that keys are not being assigned empty strings"""
assert isinstance(node, UpdateExpressionSetAction)
assert len(node.children) == 2
@ -332,15 +374,26 @@ class EmptyStringKeyValueValidator(DepthFirstTraverser):
return node
class UpdateHashRangeKeyValidator(DepthFirstTraverser):
def __init__(self, table_key_attributes, expression_attribute_names):
class UpdateHashRangeKeyValidator(DepthFirstTraverser): # type: ignore[misc]
def __init__(
self,
table_key_attributes: List[str],
expression_attribute_names: Dict[str, str],
):
self.table_key_attributes = table_key_attributes
self.expression_attribute_names = expression_attribute_names
def _processing_map(self):
def _processing_map(
self,
) -> Dict[
Type[UpdateExpressionPath],
Callable[[UpdateExpressionPath], UpdateExpressionPath],
]:
return {UpdateExpressionPath: self.check_for_hash_or_range_key}
def check_for_hash_or_range_key(self, node):
def check_for_hash_or_range_key(
self, node: UpdateExpressionPath
) -> UpdateExpressionPath:
"""Check that hash and range keys are not updated"""
key_to_update = node.children[0].children[0]
key_to_update = self.expression_attribute_names.get(
@ -351,18 +404,18 @@ class UpdateHashRangeKeyValidator(DepthFirstTraverser):
return node
class Validator(object):
class Validator:
"""
A validator is used to validate expressions which are passed in as an AST.
"""
def __init__(
self,
expression,
expression_attribute_names,
expression_attribute_values,
item,
table,
expression: Node,
expression_attribute_names: Dict[str, str],
expression_attribute_values: Dict[str, Dict[str, Any]],
item: Item,
table: Table,
):
"""
Besides validation the Validator should also replace referenced parts of an item which is cheapest upon
@ -382,10 +435,10 @@ class Validator(object):
self.node_to_validate = deepcopy(expression)
@abstractmethod
def get_ast_processors(self):
def get_ast_processors(self) -> List[DepthFirstTraverser]: # type: ignore[misc]
"""Get the different processors that go through the AST tree and processes the nodes."""
def validate(self):
def validate(self) -> Node:
n = self.node_to_validate
for processor in self.processors:
n = processor.traverse(n)
@ -393,7 +446,7 @@ class Validator(object):
class UpdateExpressionValidator(Validator):
def get_ast_processors(self):
def get_ast_processors(self) -> List[DepthFirstTraverser]:
"""Get the different processors that go through the AST tree and processes the nodes."""
processors = [
UpdateHashRangeKeyValidator(

View File

@ -3,7 +3,9 @@ import json
import itertools
from functools import wraps
from typing import Any, Dict, List, Union, Callable, Optional
from moto.core.common_types import TYPE_RESPONSE
from moto.core.responses import BaseResponse
from moto.core.utils import camelcase_to_underscores
from moto.dynamodb.parsing.key_condition_expression import parse_expression
@ -13,17 +15,27 @@ from .exceptions import (
ResourceNotFoundException,
ConditionalCheckFailed,
)
from moto.dynamodb.models import dynamodb_backends, dynamo_json_dump, Table
from moto.dynamodb.models import dynamodb_backends, Table, DynamoDBBackend
from moto.dynamodb.models.utilities import dynamo_json_dump
from moto.utilities.aws_headers import amz_crc32, amzn_request_id
TRANSACTION_MAX_ITEMS = 25
def include_consumed_capacity(val=1.0):
def _inner(f):
def include_consumed_capacity(
val: float = 1.0,
) -> Callable[
[Callable[["DynamoHandler"], str]],
Callable[["DynamoHandler"], Union[str, TYPE_RESPONSE]],
]:
def _inner(
f: Callable[..., Union[str, TYPE_RESPONSE]]
) -> Callable[["DynamoHandler"], Union[str, TYPE_RESPONSE]]:
@wraps(f)
def _wrapper(*args, **kwargs):
def _wrapper(
*args: "DynamoHandler", **kwargs: None
) -> Union[str, TYPE_RESPONSE]:
(handler,) = args
expected_capacity = handler.body.get("ReturnConsumedCapacity", "NONE")
if expected_capacity not in ["NONE", "TOTAL", "INDEXES"]:
@ -67,7 +79,7 @@ def include_consumed_capacity(val=1.0):
return _inner
def get_empty_keys_on_put(field_updates, table: Table):
def get_empty_keys_on_put(field_updates: Dict[str, Any], table: Table) -> Optional[str]:
"""
Return the first key-name that has an empty value. None if all keys are filled
"""
@ -83,12 +95,12 @@ def get_empty_keys_on_put(field_updates, table: Table):
return next(
(keyname for keyname in key_names if keyname in empty_str_fields), None
)
return False
return None
def put_has_empty_attrs(field_updates, table):
def put_has_empty_attrs(field_updates: Dict[str, Any], table: Table) -> bool:
# Example invalid attribute: [{'M': {'SS': {'NS': []}}}]
def _validate_attr(attr: dict):
def _validate_attr(attr: Dict[str, Any]) -> bool:
if "NS" in attr and attr["NS"] == []:
return True
else:
@ -105,7 +117,7 @@ def put_has_empty_attrs(field_updates, table):
return False
def validate_put_has_gsi_keys_set_to_none(item, table: Table) -> None:
def validate_put_has_gsi_keys_set_to_none(item: Dict[str, Any], table: Table) -> None:
for gsi in table.global_indexes:
for attr in gsi.schema:
attr_name = attr["AttributeName"]
@ -115,7 +127,7 @@ def validate_put_has_gsi_keys_set_to_none(item, table: Table) -> None:
)
def check_projection_expression(expression):
def check_projection_expression(expression: str) -> None:
if expression.upper() in ReservedKeywords.get_reserved_keywords():
raise MockValidationException(
f"ProjectionExpression: Attribute name is a reserved keyword; reserved keyword: {expression}"
@ -131,10 +143,10 @@ def check_projection_expression(expression):
class DynamoHandler(BaseResponse):
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="dynamodb")
def get_endpoint_name(self, headers):
def get_endpoint_name(self, headers: Any) -> Optional[str]:
"""Parses request headers and extracts part od the X-Amz-Target
that corresponds to a method of DynamoHandler
@ -144,9 +156,10 @@ class DynamoHandler(BaseResponse):
match = headers.get("x-amz-target") or headers.get("X-Amz-Target")
if match:
return match.split(".")[1]
return None
@property
def dynamodb_backend(self):
def dynamodb_backend(self) -> DynamoDBBackend:
"""
:return: DynamoDB Backend
:rtype: moto.dynamodb.models.DynamoDBBackend
@ -155,7 +168,7 @@ class DynamoHandler(BaseResponse):
@amz_crc32
@amzn_request_id
def call_action(self):
def call_action(self) -> TYPE_RESPONSE:
self.body = json.loads(self.body or "{}")
endpoint = self.get_endpoint_name(self.headers)
if endpoint:
@ -171,7 +184,7 @@ class DynamoHandler(BaseResponse):
else:
return 404, self.response_headers, ""
def list_tables(self):
def list_tables(self) -> str:
body = self.body
limit = body.get("Limit", 100)
exclusive_start_table_name = body.get("ExclusiveStartTableName")
@ -179,13 +192,13 @@ class DynamoHandler(BaseResponse):
limit, exclusive_start_table_name
)
response = {"TableNames": tables}
response: Dict[str, Any] = {"TableNames": tables}
if last_eval:
response["LastEvaluatedTableName"] = last_eval
return dynamo_json_dump(response)
def create_table(self):
def create_table(self) -> str:
body = self.body
# get the table name
table_name = body["TableName"]
@ -243,7 +256,7 @@ class DynamoHandler(BaseResponse):
actual_attrs = [item["AttributeName"] for item in attr]
actual_attrs.sort()
if actual_attrs != expected_attrs:
return self._throw_attr_error(
self._throw_attr_error(
actual_attrs, expected_attrs, global_indexes or local_secondary_indexes
)
# get the stream specification
@ -265,8 +278,10 @@ class DynamoHandler(BaseResponse):
)
return dynamo_json_dump(table.describe())
def _throw_attr_error(self, actual_attrs, expected_attrs, indexes):
def dump_list(list_):
def _throw_attr_error(
self, actual_attrs: List[str], expected_attrs: List[str], indexes: bool
) -> None:
def dump_list(list_: List[str]) -> str:
return str(list_).replace("'", "")
err_head = "One or more parameter values were invalid: "
@ -315,28 +330,28 @@ class DynamoHandler(BaseResponse):
+ dump_list(actual_attrs)
)
def delete_table(self):
def delete_table(self) -> str:
name = self.body["TableName"]
table = self.dynamodb_backend.delete_table(name)
return dynamo_json_dump(table.describe())
def describe_endpoints(self):
def describe_endpoints(self) -> str:
response = {"Endpoints": self.dynamodb_backend.describe_endpoints()}
return dynamo_json_dump(response)
def tag_resource(self):
def tag_resource(self) -> str:
table_arn = self.body["ResourceArn"]
tags = self.body["Tags"]
self.dynamodb_backend.tag_resource(table_arn, tags)
return ""
def untag_resource(self):
def untag_resource(self) -> str:
table_arn = self.body["ResourceArn"]
tags = self.body["TagKeys"]
self.dynamodb_backend.untag_resource(table_arn, tags)
return ""
def list_tags_of_resource(self):
def list_tags_of_resource(self) -> str:
table_arn = self.body["ResourceArn"]
all_tags = self.dynamodb_backend.list_tags_of_resource(table_arn)
all_tag_keys = [tag["Key"] for tag in all_tags]
@ -354,7 +369,7 @@ class DynamoHandler(BaseResponse):
return json.dumps({"Tags": tags_resp, "NextToken": next_marker})
return json.dumps({"Tags": tags_resp})
def update_table(self):
def update_table(self) -> str:
name = self.body["TableName"]
attr_definitions = self.body.get("AttributeDefinitions", None)
global_index = self.body.get("GlobalSecondaryIndexUpdates", None)
@ -371,13 +386,13 @@ class DynamoHandler(BaseResponse):
)
return dynamo_json_dump(table.describe())
def describe_table(self):
def describe_table(self) -> str:
name = self.body["TableName"]
table = self.dynamodb_backend.describe_table(name)
return dynamo_json_dump(table)
@include_consumed_capacity()
def put_item(self):
def put_item(self) -> str:
name = self.body["TableName"]
item = self.body["Item"]
return_values = self.body.get("ReturnValues", "NONE")
@ -436,7 +451,7 @@ class DynamoHandler(BaseResponse):
item_dict.pop("Attributes", None)
return dynamo_json_dump(item_dict)
def batch_write_item(self):
def batch_write_item(self) -> str:
table_batches = self.body["RequestItems"]
put_requests = []
delete_requests = []
@ -478,7 +493,7 @@ class DynamoHandler(BaseResponse):
return dynamo_json_dump(response)
@include_consumed_capacity(0.5)
def get_item(self):
def get_item(self) -> str:
name = self.body["TableName"]
self.dynamodb_backend.get_table(name)
key = self.body["Key"]
@ -520,10 +535,14 @@ class DynamoHandler(BaseResponse):
# Item not found
return dynamo_json_dump({})
def batch_get_item(self):
def batch_get_item(self) -> str:
table_batches = self.body["RequestItems"]
results = {"ConsumedCapacity": [], "Responses": {}, "UnprocessedKeys": {}}
results: Dict[str, Any] = {
"ConsumedCapacity": [],
"Responses": {},
"UnprocessedKeys": {},
}
# Validation: Can only request up to 100 items at the same time
# Scenario 1: We're requesting more than a 100 keys from a single table
@ -582,7 +601,7 @@ class DynamoHandler(BaseResponse):
)
return dynamo_json_dump(results)
def _contains_duplicates(self, keys):
def _contains_duplicates(self, keys: List[str]) -> bool:
unique_keys = []
for k in keys:
if k in unique_keys:
@ -592,7 +611,7 @@ class DynamoHandler(BaseResponse):
return False
@include_consumed_capacity()
def query(self):
def query(self) -> str:
name = self.body["TableName"]
key_condition_expression = self.body.get("KeyConditionExpression")
projection_expression = self.body.get("ProjectionExpression")
@ -676,7 +695,7 @@ class DynamoHandler(BaseResponse):
**filter_kwargs,
)
result = {
result: Dict[str, Any] = {
"Count": len(items),
"ScannedCount": scanned_count,
}
@ -689,8 +708,10 @@ class DynamoHandler(BaseResponse):
return dynamo_json_dump(result)
def _adjust_projection_expression(self, projection_expression, expr_attr_names):
def _adjust(expression):
def _adjust_projection_expression(
self, projection_expression: str, expr_attr_names: Dict[str, str]
) -> str:
def _adjust(expression: str) -> str:
return (
expr_attr_names[expression]
if expression in expr_attr_names
@ -712,7 +733,7 @@ class DynamoHandler(BaseResponse):
return projection_expression
@include_consumed_capacity()
def scan(self):
def scan(self) -> str:
name = self.body["TableName"]
filters = {}
@ -760,7 +781,7 @@ class DynamoHandler(BaseResponse):
result["LastEvaluatedKey"] = last_evaluated_key
return dynamo_json_dump(result)
def delete_item(self):
def delete_item(self) -> str:
name = self.body["TableName"]
key = self.body["Key"]
return_values = self.body.get("ReturnValues", "NONE")
@ -795,7 +816,7 @@ class DynamoHandler(BaseResponse):
item_dict["ConsumedCapacityUnits"] = 0.5
return dynamo_json_dump(item_dict)
def update_item(self):
def update_item(self) -> str:
name = self.body["TableName"]
key = self.body["Key"]
return_values = self.body.get("ReturnValues", "NONE")
@ -870,7 +891,7 @@ class DynamoHandler(BaseResponse):
)
return dynamo_json_dump(item_dict)
def _build_updated_new_attributes(self, original, changed):
def _build_updated_new_attributes(self, original: Any, changed: Any) -> Any:
if type(changed) != type(original):
return changed
else:
@ -895,7 +916,7 @@ class DynamoHandler(BaseResponse):
else:
return changed
def describe_limits(self):
def describe_limits(self) -> str:
return json.dumps(
{
"AccountMaxReadCapacityUnits": 20000,
@ -905,7 +926,7 @@ class DynamoHandler(BaseResponse):
}
)
def update_time_to_live(self):
def update_time_to_live(self) -> str:
name = self.body["TableName"]
ttl_spec = self.body["TimeToLiveSpecification"]
@ -913,16 +934,16 @@ class DynamoHandler(BaseResponse):
return json.dumps({"TimeToLiveSpecification": ttl_spec})
def describe_time_to_live(self):
def describe_time_to_live(self) -> str:
name = self.body["TableName"]
ttl_spec = self.dynamodb_backend.describe_time_to_live(name)
return json.dumps({"TimeToLiveDescription": ttl_spec})
def transact_get_items(self):
def transact_get_items(self) -> str:
transact_items = self.body["TransactItems"]
responses = list()
responses: List[Dict[str, Any]] = list()
if len(transact_items) > TRANSACTION_MAX_ITEMS:
msg = "1 validation error detected: Value '["
@ -945,7 +966,7 @@ class DynamoHandler(BaseResponse):
raise MockValidationException(msg)
ret_consumed_capacity = self.body.get("ReturnConsumedCapacity", "NONE")
consumed_capacity = dict()
consumed_capacity: Dict[str, Any] = dict()
for transact_item in transact_items:
@ -957,7 +978,7 @@ class DynamoHandler(BaseResponse):
responses.append({})
continue
item_describe = item.describe_attrs(False)
item_describe = item.describe_attrs(attributes=None)
responses.append(item_describe)
table_capacity = consumed_capacity.get(table_name, {})
@ -981,20 +1002,20 @@ class DynamoHandler(BaseResponse):
return dynamo_json_dump(result)
def transact_write_items(self):
def transact_write_items(self) -> str:
transact_items = self.body["TransactItems"]
self.dynamodb_backend.transact_write_items(transact_items)
response = {"ConsumedCapacity": [], "ItemCollectionMetrics": {}}
response: Dict[str, Any] = {"ConsumedCapacity": [], "ItemCollectionMetrics": {}}
return dynamo_json_dump(response)
def describe_continuous_backups(self):
def describe_continuous_backups(self) -> str:
name = self.body["TableName"]
response = self.dynamodb_backend.describe_continuous_backups(name)
return json.dumps({"ContinuousBackupsDescription": response})
def update_continuous_backups(self):
def update_continuous_backups(self) -> str:
name = self.body["TableName"]
point_in_time_spec = self.body["PointInTimeRecoverySpecification"]
@ -1004,14 +1025,14 @@ class DynamoHandler(BaseResponse):
return json.dumps({"ContinuousBackupsDescription": response})
def list_backups(self):
def list_backups(self) -> str:
body = self.body
table_name = body.get("TableName")
backups = self.dynamodb_backend.list_backups(table_name)
response = {"BackupSummaries": [backup.summary for backup in backups]}
return dynamo_json_dump(response)
def create_backup(self):
def create_backup(self) -> str:
body = self.body
table_name = body.get("TableName")
backup_name = body.get("BackupName")
@ -1019,21 +1040,21 @@ class DynamoHandler(BaseResponse):
response = {"BackupDetails": backup.details}
return dynamo_json_dump(response)
def delete_backup(self):
def delete_backup(self) -> str:
body = self.body
backup_arn = body.get("BackupArn")
backup = self.dynamodb_backend.delete_backup(backup_arn)
response = {"BackupDescription": backup.description}
return dynamo_json_dump(response)
def describe_backup(self):
def describe_backup(self) -> str:
body = self.body
backup_arn = body.get("BackupArn")
backup = self.dynamodb_backend.describe_backup(backup_arn)
response = {"BackupDescription": backup.description}
return dynamo_json_dump(response)
def restore_table_from_backup(self):
def restore_table_from_backup(self) -> str:
body = self.body
target_table_name = body.get("TargetTableName")
backup_arn = body.get("BackupArn")
@ -1042,7 +1063,7 @@ class DynamoHandler(BaseResponse):
)
return dynamo_json_dump(restored_table.describe())
def restore_table_to_point_in_time(self):
def restore_table_to_point_in_time(self) -> str:
body = self.body
target_table_name = body.get("TargetTableName")
source_table_name = body.get("SourceTableName")

View File

@ -4,8 +4,9 @@ import base64
from typing import Any, Dict, Optional
from moto.core import BaseBackend, BackendDict, BaseModel
from moto.dynamodb.models import dynamodb_backends, DynamoJsonEncoder, DynamoDBBackend
from moto.dynamodb.models import Table, StreamShard
from moto.dynamodb.models import dynamodb_backends, DynamoDBBackend
from moto.dynamodb.models.table import Table, StreamShard
from moto.dynamodb.models.utilities import DynamoJsonEncoder
class ShardIterator(BaseModel):
@ -86,7 +87,7 @@ class DynamoDBStreamsBackend(BaseBackend):
"StreamStatus": (
"ENABLED" if table.latest_stream_label else "DISABLED"
),
"StreamViewType": table.stream_specification["StreamViewType"],
"StreamViewType": table.stream_specification["StreamViewType"], # type: ignore[index]
"CreationRequestDateTime": table.stream_shard.created_on.isoformat(), # type: ignore[union-attr]
"TableName": table.name,
"KeySchema": table.schema,

View File

@ -4,17 +4,17 @@ class GenericTokenizer:
The final character to be returned will be an empty string, to notify the caller that we've reached the end.
"""
def __init__(self, expression):
def __init__(self, expression: str):
self.expression = expression
self.token_pos = 0
def __iter__(self):
def __iter__(self) -> "GenericTokenizer":
return self
def is_eof(self):
def is_eof(self) -> bool:
return self.peek() == ""
def peek(self, length=1):
def peek(self, length: int = 1) -> str:
"""
Peek the next character without changing the position
"""
@ -23,7 +23,7 @@ class GenericTokenizer:
except IndexError:
return ""
def __next__(self):
def __next__(self) -> str:
"""
Returns the next character, or an empty string if we've reached the end of the string.
Calling this method again will result in a StopIterator
@ -38,7 +38,7 @@ class GenericTokenizer:
return ""
raise StopIteration
def skip_characters(self, phrase, case_sensitive=False) -> None:
def skip_characters(self, phrase: str, case_sensitive: bool = False) -> None:
"""
Skip the characters in the supplied phrase.
If any other character is encountered instead, this will fail.
@ -51,7 +51,7 @@ class GenericTokenizer:
assert self.expression[self.token_pos] in [ch.lower(), ch.upper()]
self.token_pos += 1
def skip_white_space(self):
def skip_white_space(self) -> None:
"""
Skip any whitespace characters that are coming up
"""

View File

@ -13,14 +13,19 @@ def str2bool(v):
return False
def load_resource(package: str, resource: str, as_json: bool = True) -> Dict[str, Any]:
def load_resource(package: str, resource: str) -> Dict[str, Any]:
"""
Open a file, and return the contents as JSON.
Usage:
load_resource(__name__, "resources/file.json")
"""
resource = pkgutil.get_data(package, resource)
return json.loads(resource) if as_json else resource.decode("utf-8")
return json.loads(resource)
def load_resource_as_str(package: str, resource: str) -> str:
resource = pkgutil.get_data(package, resource)
return resource.decode("utf-8")
def merge_multiple_dicts(*args: Any) -> Dict[str, any]:

View File

@ -229,7 +229,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy]
files= moto/a*,moto/b*,moto/c*,moto/databrew,moto/datapipeline,moto/datasync,moto/dax,moto/dms,moto/ds,moto/dynamodb_v20111205,moto/dynamodbstreams,moto/moto_api
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/moto_api
show_column_numbers=True
show_error_codes = True
disable_error_code=abstract