diff --git a/moto/dynamodb/comparisons.py b/moto/dynamodb/comparisons.py index 6ba5d9052..94fbe2370 100644 --- a/moto/dynamodb/comparisons.py +++ b/moto/dynamodb/comparisons.py @@ -1,12 +1,17 @@ import re -from collections import deque -from collections import namedtuple +from collections import deque, namedtuple +from typing import Any, Dict, List, Tuple, Deque, Optional, Iterable, Union +from moto.dynamodb.models.dynamo_type import Item from moto.dynamodb.exceptions import ConditionAttributeIsReservedKeyword from moto.dynamodb.parsing.reserved_keywords import ReservedKeywords -def get_filter_expression(expr, names, values): +def get_filter_expression( + expr: Optional[str], + names: Optional[Dict[str, str]], + values: Optional[Dict[str, str]], +) -> Union["Op", "Func"]: """ Parse a filter expression into an Op. @@ -18,7 +23,7 @@ def get_filter_expression(expr, names, values): return parser.parse() -def get_expected(expected): +def get_expected(expected: Dict[str, Any]) -> Union["Op", "Func"]: """ Parse a filter expression into an Op. @@ -26,7 +31,7 @@ def get_expected(expected): expr = 'Id > 5 AND attribute_exists(test) AND Id BETWEEN 5 AND 6 OR length < 6 AND contains(test, 1) AND 5 IN (4,5, 6) OR (Id < 5 AND 5 > Id)' expr = 'Id > 5 AND Subs < 7' """ - ops = { + ops: Dict[str, Any] = { "EQ": OpEqual, "NE": OpNotEqual, "LE": OpLessThanOrEqual, @@ -43,7 +48,7 @@ def get_expected(expected): } # NOTE: Always uses ConditionalOperator=AND - conditions = [] + conditions: List[Union["Op", "Func"]] = [] for key, cond in expected.items(): path = AttributePath([key]) if "Exists" in cond: @@ -66,26 +71,28 @@ def get_expected(expected): for condition in conditions[1:]: output = ConditionalOp(output, condition) else: - return OpDefault(None, None) + return OpDefault(None, None) # type: ignore[arg-type] return output -class Op(object): +class Op: """ Base class for a FilterExpression operator """ OP = "" - def __init__(self, lhs, rhs): + def __init__( + self, lhs: Union["Func", "Op", "Operand"], rhs: Union["Func", "Op", "Operand"] + ): self.lhs = lhs self.rhs = rhs - def expr(self, item): + def expr(self, item: Optional[Item]) -> bool: # type: ignore raise NotImplementedError(f"Expr not defined for {type(self)}") - def __repr__(self): + def __repr__(self) -> str: return f"({self.lhs} {self.OP} {self.rhs})" @@ -125,7 +132,7 @@ COMPARISON_FUNCS = { } -def get_comparison_func(range_comparison): +def get_comparison_func(range_comparison: str) -> Any: return COMPARISON_FUNCS.get(range_comparison) @@ -136,15 +143,15 @@ class RecursionStopIteration(StopIteration): class ConditionExpressionParser: def __init__( self, - condition_expression, - expression_attribute_names, - expression_attribute_values, + condition_expression: Optional[str], + expression_attribute_names: Optional[Dict[str, str]], + expression_attribute_values: Optional[Dict[str, str]], ): self.condition_expression = condition_expression self.expression_attribute_names = expression_attribute_names self.expression_attribute_values = expression_attribute_values - def parse(self): + def parse(self) -> Union[Op, "Func"]: """Returns a syntax tree for the expression. 
The tree, and all of the nodes in the tree are a tuple of @@ -181,7 +188,7 @@ class ConditionExpressionParser: """ if not self.condition_expression: - return OpDefault(None, None) + return OpDefault(None, None) # type: ignore[arg-type] nodes = self._lex_condition_expression() nodes = self._parse_paths(nodes) # NOTE: The docs say that functions should be parsed after @@ -242,12 +249,12 @@ class ConditionExpressionParser: Node = namedtuple("Node", ["nonterminal", "kind", "text", "value", "children"]) @classmethod - def raise_exception_if_keyword(cls, attribute): + def raise_exception_if_keyword(cls, attribute: str) -> None: if attribute.upper() in ReservedKeywords.get_reserved_keywords(): raise ConditionAttributeIsReservedKeyword(attribute) - def _lex_condition_expression(self): - nodes = deque() + def _lex_condition_expression(self) -> Deque[Node]: + nodes: Deque[ConditionExpressionParser.Node] = deque() remaining_expression = self.condition_expression while remaining_expression: node, remaining_expression = self._lex_one_node(remaining_expression) @@ -256,7 +263,7 @@ class ConditionExpressionParser: nodes.append(node) return nodes - def _lex_one_node(self, remaining_expression): + def _lex_one_node(self, remaining_expression: str) -> Tuple[Node, str]: # TODO: Handle indexing like [1] attribute_regex = r"(:|#)?[A-z0-9\-_]+" patterns = [ @@ -305,8 +312,8 @@ class ConditionExpressionParser: return node, remaining_expression - def _parse_paths(self, nodes): - output = deque() + def _parse_paths(self, nodes: Deque[Node]) -> Deque[Node]: + output: Deque[ConditionExpressionParser.Node] = deque() while nodes: node = nodes.popleft() @@ -339,7 +346,7 @@ class ConditionExpressionParser: output.append(node) return output - def _parse_path_element(self, name): + def _parse_path_element(self, name: str) -> Node: reserved = { "and": self.Nonterminal.AND, "or": self.Nonterminal.OR, @@ -416,11 +423,11 @@ class ConditionExpressionParser: children=[], ) - def _lookup_expression_attribute_value(self, name): - return self.expression_attribute_values[name] + def _lookup_expression_attribute_value(self, name: str) -> str: + return self.expression_attribute_values[name] # type: ignore[index] - def _lookup_expression_attribute_name(self, name): - return self.expression_attribute_names[name] + def _lookup_expression_attribute_name(self, name: str) -> str: + return self.expression_attribute_names[name] # type: ignore[index] # NOTE: The following constructions are ordered from high precedence to low precedence # according to @@ -464,7 +471,7 @@ class ConditionExpressionParser: # contains (path, operand) # size (path) - def _matches(self, nodes, production): + def _matches(self, nodes: Deque[Node], production: List[str]) -> bool: """Check if the nodes start with the given production. Parameters @@ -484,9 +491,9 @@ class ConditionExpressionParser: return False return True - def _apply_comparator(self, nodes): + def _apply_comparator(self, nodes: Deque[Node]) -> Deque[Node]: """Apply condition := operand comparator operand.""" - output = deque() + output: Deque[ConditionExpressionParser.Node] = deque() while nodes: if self._matches(nodes, ["*", "COMPARATOR"]): @@ -511,9 +518,9 @@ class ConditionExpressionParser: output.append(nodes.popleft()) return output - def _apply_in(self, nodes): + def _apply_in(self, nodes: Deque[Node]) -> Deque[Node]: """Apply condition := operand IN ( operand , ... 
).""" - output = deque() + output: Deque[ConditionExpressionParser.Node] = deque() while nodes: if self._matches(nodes, ["*", "IN"]): self._assert( @@ -553,9 +560,9 @@ class ConditionExpressionParser: output.append(nodes.popleft()) return output - def _apply_between(self, nodes): + def _apply_between(self, nodes: Deque[Node]) -> Deque[Node]: """Apply condition := operand BETWEEN operand AND operand.""" - output = deque() + output: Deque[ConditionExpressionParser.Node] = deque() while nodes: if self._matches(nodes, ["*", "BETWEEN"]): self._assert( @@ -584,9 +591,9 @@ class ConditionExpressionParser: output.append(nodes.popleft()) return output - def _apply_functions(self, nodes): + def _apply_functions(self, nodes: Deque[Node]) -> Deque[Node]: """Apply condition := function_name (operand , ...).""" - output = deque() + output: Deque[ConditionExpressionParser.Node] = deque() either_kind = {self.Kind.PATH, self.Kind.EXPRESSION_ATTRIBUTE_VALUE} expected_argument_kind_map = { "attribute_exists": [{self.Kind.PATH}], @@ -656,9 +663,11 @@ class ConditionExpressionParser: output.append(nodes.popleft()) return output - def _apply_parens_and_booleans(self, nodes, left_paren=None): + def _apply_parens_and_booleans( + self, nodes: Deque[Node], left_paren: Any = None + ) -> Deque[Node]: """Apply condition := ( condition ) and booleans.""" - output = deque() + output: Deque[ConditionExpressionParser.Node] = deque() while nodes: if self._matches(nodes, ["LEFT_PAREN"]): parsed = self._apply_parens_and_booleans( @@ -696,7 +705,7 @@ class ConditionExpressionParser: self._assert(left_paren is None, "Unmatched ( at", list(output)) return self._apply_booleans(output) - def _apply_booleans(self, nodes): + def _apply_booleans(self, nodes: Deque[Node]) -> Deque[Node]: """Apply and, or, and not constructions.""" nodes = self._apply_not(nodes) nodes = self._apply_and(nodes) @@ -710,9 +719,9 @@ class ConditionExpressionParser: ) return nodes - def _apply_not(self, nodes): + def _apply_not(self, nodes: Deque[Node]) -> Deque[Node]: """Apply condition := NOT condition.""" - output = deque() + output: Deque[ConditionExpressionParser.Node] = deque() while nodes: if self._matches(nodes, ["NOT"]): self._assert( @@ -736,9 +745,9 @@ class ConditionExpressionParser: return output - def _apply_and(self, nodes): + def _apply_and(self, nodes: Deque[Node]) -> Deque[Node]: """Apply condition := condition AND condition.""" - output = deque() + output: Deque[ConditionExpressionParser.Node] = deque() while nodes: if self._matches(nodes, ["*", "AND"]): self._assert( @@ -764,9 +773,9 @@ class ConditionExpressionParser: return output - def _apply_or(self, nodes): + def _apply_or(self, nodes: Deque[Node]) -> Deque[Node]: """Apply condition := condition OR condition.""" - output = deque() + output: Deque[ConditionExpressionParser.Node] = deque() while nodes: if self._matches(nodes, ["*", "OR"]): self._assert( @@ -792,7 +801,7 @@ class ConditionExpressionParser: return output - def _make_operand(self, node): + def _make_operand(self, node: Node) -> "Operand": if node.kind == self.Kind.PATH: return AttributePath([child.value for child in node.children]) elif node.kind == self.Kind.EXPRESSION_ATTRIBUTE_VALUE: @@ -807,7 +816,7 @@ class ConditionExpressionParser: else: # pragma: no cover raise ValueError(f"Unknown operand: {node}") - def _make_op_condition(self, node): + def _make_op_condition(self, node: Node) -> Union["Func", Op]: if node.kind == self.Kind.OR: lhs, rhs = node.children return OpOr(self._make_op_condition(lhs), 
self._make_op_condition(rhs)) @@ -847,21 +856,21 @@ class ConditionExpressionParser: else: # pragma: no cover raise ValueError(f"Unknown expression node kind {node.kind}") - def _assert(self, condition, message, nodes): + def _assert(self, condition: bool, message: str, nodes: Iterable[Node]) -> None: if not condition: raise ValueError(message + " " + " ".join([t.text for t in nodes])) -class Operand(object): - def expr(self, item): +class Operand: + def expr(self, item: Optional[Item]) -> Any: # type: ignore raise NotImplementedError - def get_type(self, item): + def get_type(self, item: Optional[Item]) -> Optional[str]: # type: ignore raise NotImplementedError class AttributePath(Operand): - def __init__(self, path): + def __init__(self, path: List[Any]): """Initialize the AttributePath. Parameters @@ -872,7 +881,7 @@ class AttributePath(Operand): assert len(path) >= 1 self.path = path - def _get_attr(self, item): + def _get_attr(self, item: Optional[Item]) -> Any: if item is None: return None @@ -888,26 +897,26 @@ class AttributePath(Operand): return attr - def expr(self, item): + def expr(self, item: Optional[Item]) -> Any: attr = self._get_attr(item) if attr is None: return None else: return attr.cast_value - def get_type(self, item): + def get_type(self, item: Optional[Item]) -> Optional[str]: attr = self._get_attr(item) if attr is None: return None else: return attr.type - def __repr__(self): + def __repr__(self) -> str: return ".".join(self.path) class AttributeValue(Operand): - def __init__(self, value): + def __init__(self, value: Dict[str, Any]): """Initialize the AttributePath. Parameters @@ -919,7 +928,7 @@ class AttributeValue(Operand): self.type = list(value.keys())[0] self.value = value[self.type] - def expr(self, item): + def expr(self, item: Optional[Item]) -> Any: # TODO: Reuse DynamoType code if self.type == "N": try: @@ -939,17 +948,17 @@ class AttributeValue(Operand): return self.value return self.value - def get_type(self, item): + def get_type(self, item: Optional[Item]) -> str: return self.type - def __repr__(self): + def __repr__(self) -> str: return repr(self.value) class OpDefault(Op): OP = "NONE" - def expr(self, item): + def expr(self, item: Optional[Item]) -> bool: """If no condition is specified, always True.""" return True @@ -957,21 +966,21 @@ class OpDefault(Op): class OpNot(Op): OP = "NOT" - def __init__(self, lhs): - super().__init__(lhs, None) + def __init__(self, lhs: Union["Func", Op]): + super().__init__(lhs, None) # type: ignore[arg-type] - def expr(self, item): + def expr(self, item: Optional[Item]) -> bool: lhs = self.lhs.expr(item) return not lhs - def __str__(self): + def __str__(self) -> str: return f"({self.OP} {self.lhs})" class OpAnd(Op): OP = "AND" - def expr(self, item): + def expr(self, item: Optional[Item]) -> bool: lhs = self.lhs.expr(item) return lhs and self.rhs.expr(item) @@ -979,7 +988,7 @@ class OpAnd(Op): class OpLessThan(Op): OP = "<" - def expr(self, item): + def expr(self, item: Optional[Item]) -> bool: lhs = self.lhs.expr(item) rhs = self.rhs.expr(item) # In python3 None is not a valid comparator when using < or > so must be handled specially @@ -992,7 +1001,7 @@ class OpLessThan(Op): class OpGreaterThan(Op): OP = ">" - def expr(self, item): + def expr(self, item: Optional[Item]) -> bool: lhs = self.lhs.expr(item) rhs = self.rhs.expr(item) # In python3 None is not a valid comparator when using < or > so must be handled specially @@ -1005,7 +1014,7 @@ class OpGreaterThan(Op): class OpEqual(Op): OP = "=" - def expr(self, item): 
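As an aside, the freshly annotated parser entry point and the Op.expr protocol shown in this file can be smoke-tested directly. A minimal sketch, assuming moto is importable and that Item accepts raw attribute dicts as in this diff; note that at runtime the values mapping carries typed AttributeValue dicts, so the new Dict[str, str] annotation is looser than what actually flows through:

    # Hypothetical smoke test for the typed parser above; not part of the patch.
    from moto.dynamodb.comparisons import get_filter_expression
    from moto.dynamodb.models.dynamo_type import Item

    op = get_filter_expression(
        "Id > :min AND attribute_exists(#n)",
        names={"#n": "Name"},
        values={":min": {"N": "5"}},  # runtime shape is a typed AttributeValue dict
    )
    item = Item(
        hash_key=None,  # keys are irrelevant to expr(); attrs drive evaluation
        range_key=None,
        attrs={"Id": {"N": "7"}, "Name": {"S": "moto"}},
    )
    assert op.expr(item)  # 7 > 5 and "Name" exists
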
+ def expr(self, item: Optional[Item]) -> bool: lhs = self.lhs.expr(item) rhs = self.rhs.expr(item) return lhs == rhs @@ -1014,7 +1023,7 @@ class OpEqual(Op): class OpNotEqual(Op): OP = "<>" - def expr(self, item): + def expr(self, item: Optional[Item]) -> bool: lhs = self.lhs.expr(item) rhs = self.rhs.expr(item) return lhs != rhs @@ -1023,7 +1032,7 @@ class OpNotEqual(Op): class OpLessThanOrEqual(Op): OP = "<=" - def expr(self, item): + def expr(self, item: Optional[Item]) -> bool: lhs = self.lhs.expr(item) rhs = self.rhs.expr(item) # In python3 None is not a valid comparator when using < or > so must be handled specially @@ -1036,7 +1045,7 @@ class OpLessThanOrEqual(Op): class OpGreaterThanOrEqual(Op): OP = ">=" - def expr(self, item): + def expr(self, item: Optional[Item]) -> bool: lhs = self.lhs.expr(item) rhs = self.rhs.expr(item) # In python3 None is not a valid comparator when using < or > so must be handled specially @@ -1049,64 +1058,64 @@ class OpGreaterThanOrEqual(Op): class OpOr(Op): OP = "OR" - def expr(self, item): + def expr(self, item: Optional[Item]) -> bool: lhs = self.lhs.expr(item) return lhs or self.rhs.expr(item) -class Func(object): +class Func: """ Base class for a FilterExpression function """ FUNC = "Unknown" - def __init__(self, *arguments): + def __init__(self, *arguments: Any): self.arguments = arguments - def expr(self, item): + def expr(self, item: Optional[Item]) -> bool: raise NotImplementedError - def __repr__(self): + def __repr__(self) -> str: return f"{self.FUNC}({' '.join([repr(arg) for arg in self.arguments])})" class FuncAttrExists(Func): FUNC = "attribute_exists" - def __init__(self, attribute): + def __init__(self, attribute: Operand): self.attr = attribute super().__init__(attribute) - def expr(self, item): + def expr(self, item: Optional[Item]) -> bool: return self.attr.get_type(item) is not None -def FuncAttrNotExists(attribute): +def FuncAttrNotExists(attribute: Operand) -> Any: return OpNot(FuncAttrExists(attribute)) class FuncAttrType(Func): FUNC = "attribute_type" - def __init__(self, attribute, _type): + def __init__(self, attribute: Operand, _type: Func): self.attr = attribute self.type = _type super().__init__(attribute, _type) - def expr(self, item): - return self.attr.get_type(item) == self.type.expr(item) + def expr(self, item: Optional[Item]) -> bool: + return self.attr.get_type(item) == self.type.expr(item) # type: ignore[comparison-overlap] class FuncBeginsWith(Func): FUNC = "begins_with" - def __init__(self, attribute, substr): + def __init__(self, attribute: Operand, substr: Operand): self.attr = attribute self.substr = substr super().__init__(attribute, substr) - def expr(self, item): + def expr(self, item: Optional[Item]) -> bool: if self.attr.get_type(item) != "S": return False if self.substr.get_type(item) != "S": @@ -1117,12 +1126,12 @@ class FuncBeginsWith(Func): class FuncContains(Func): FUNC = "contains" - def __init__(self, attribute, operand): + def __init__(self, attribute: Operand, operand: Operand): self.attr = attribute self.operand = operand super().__init__(attribute, operand) - def expr(self, item): + def expr(self, item: Optional[Item]) -> bool: if self.attr.get_type(item) in ("S", "SS", "NS", "BS", "L"): try: return self.operand.expr(item) in self.attr.expr(item) @@ -1131,18 +1140,18 @@ class FuncContains(Func): return False -def FuncNotContains(attribute, operand): +def FuncNotContains(attribute: Operand, operand: Operand) -> OpNot: return OpNot(FuncContains(attribute, operand)) class FuncSize(Func): FUNC = 
"size" - def __init__(self, attribute): + def __init__(self, attribute: Operand): self.attr = attribute super().__init__(attribute) - def expr(self, item): + def expr(self, item: Optional[Item]) -> int: # type: ignore[override] if self.attr.get_type(item) is None: raise ValueError(f"Invalid attribute name {self.attr}") @@ -1154,13 +1163,13 @@ class FuncSize(Func): class FuncBetween(Func): FUNC = "BETWEEN" - def __init__(self, attribute, start, end): + def __init__(self, attribute: Operand, start: Operand, end: Operand): self.attr = attribute self.start = start self.end = end super().__init__(attribute, start, end) - def expr(self, item): + def expr(self, item: Optional[Item]) -> bool: # In python3 None is not a valid comparator when using < or > so must be handled specially start = self.start.expr(item) attr = self.attr.expr(item) @@ -1183,12 +1192,12 @@ class FuncBetween(Func): class FuncIn(Func): FUNC = "IN" - def __init__(self, attribute, *possible_values): + def __init__(self, attribute: Operand, *possible_values: Any): self.attr = attribute self.possible_values = possible_values super().__init__(attribute, *possible_values) - def expr(self, item): + def expr(self, item: Optional[Item]) -> bool: for possible_value in self.possible_values: if self.attr.expr(item) == possible_value.expr(item): return True @@ -1205,7 +1214,7 @@ COMPARATOR_CLASS = { "<>": OpNotEqual, } -FUNC_CLASS = { +FUNC_CLASS: Dict[str, Any] = { "attribute_exists": FuncAttrExists, "attribute_not_exists": FuncAttrNotExists, "attribute_type": FuncAttrType, diff --git a/moto/dynamodb/exceptions.py b/moto/dynamodb/exceptions.py index 30aba61ce..87642779a 100644 --- a/moto/dynamodb/exceptions.py +++ b/moto/dynamodb/exceptions.py @@ -1,4 +1,5 @@ import json +from typing import Any, List, Optional from moto.core.exceptions import JsonRESTError from moto.dynamodb.limits import HASH_KEY_MAX_LENGTH, RANGE_KEY_MAX_LENGTH @@ -10,7 +11,7 @@ class DynamodbException(JsonRESTError): class MockValidationException(DynamodbException): error_type = "com.amazonaws.dynamodb.v20111205#ValidationException" - def __init__(self, message): + def __init__(self, message: str): super().__init__(MockValidationException.error_type, message=message) self.exception_msg = message @@ -24,14 +25,14 @@ class InvalidUpdateExpressionInvalidDocumentPath(MockValidationException): "The document path provided in the update expression is invalid for update" ) - def __init__(self): + def __init__(self) -> None: super().__init__(self.invalid_update_expression_msg) class InvalidUpdateExpression(MockValidationException): invalid_update_expr_msg = "Invalid UpdateExpression: {update_expression_error}" - def __init__(self, update_expression_error): + def __init__(self, update_expression_error: str): self.update_expression_error = update_expression_error super().__init__( self.invalid_update_expr_msg.format( @@ -45,7 +46,7 @@ class InvalidConditionExpression(MockValidationException): "Invalid ConditionExpression: {condition_expression_error}" ) - def __init__(self, condition_expression_error): + def __init__(self, condition_expression_error: str): self.condition_expression_error = condition_expression_error super().__init__( self.invalid_condition_expr_msg.format( @@ -59,7 +60,7 @@ class ConditionAttributeIsReservedKeyword(InvalidConditionExpression): "Attribute name is a reserved keyword; reserved keyword: {keyword}" ) - def __init__(self, keyword): + def __init__(self, keyword: str): self.keyword = keyword 
super().__init__(self.attribute_is_keyword_msg.format(keyword=keyword)) @@ -69,7 +70,7 @@ class AttributeDoesNotExist(MockValidationException): "The provided expression refers to an attribute that does not exist in the item" ) - def __init__(self): + def __init__(self) -> None: super().__init__(self.attr_does_not_exist_msg) @@ -78,14 +79,14 @@ class ProvidedKeyDoesNotExist(MockValidationException): "The provided key element does not match the schema" ) - def __init__(self): + def __init__(self) -> None: super().__init__(self.provided_key_does_not_exist_msg) class ExpressionAttributeNameNotDefined(InvalidUpdateExpression): name_not_defined_msg = "An expression attribute name used in the document path is not defined; attribute name: {n}" - def __init__(self, attribute_name): + def __init__(self, attribute_name: str): self.not_defined_attribute_name = attribute_name super().__init__(self.name_not_defined_msg.format(n=attribute_name)) @@ -95,7 +96,7 @@ class AttributeIsReservedKeyword(InvalidUpdateExpression): "Attribute name is a reserved keyword; reserved keyword: {keyword}" ) - def __init__(self, keyword): + def __init__(self, keyword: str): self.keyword = keyword super().__init__(self.attribute_is_keyword_msg.format(keyword=keyword)) @@ -103,7 +104,7 @@ class AttributeIsReservedKeyword(InvalidUpdateExpression): class ExpressionAttributeValueNotDefined(InvalidUpdateExpression): attr_value_not_defined_msg = "An expression attribute value used in expression is not defined; attribute value: {attribute_value}" - def __init__(self, attribute_value): + def __init__(self, attribute_value: str): self.attribute_value = attribute_value super().__init__( self.attr_value_not_defined_msg.format(attribute_value=attribute_value) @@ -113,7 +114,7 @@ class ExpressionAttributeValueNotDefined(InvalidUpdateExpression): class UpdateExprSyntaxError(InvalidUpdateExpression): update_expr_syntax_error_msg = "Syntax error; {error_detail}" - def __init__(self, error_detail): + def __init__(self, error_detail: str): self.error_detail = error_detail super().__init__( self.update_expr_syntax_error_msg.format(error_detail=error_detail) @@ -123,7 +124,7 @@ class UpdateExprSyntaxError(InvalidUpdateExpression): class InvalidTokenException(UpdateExprSyntaxError): token_detail_msg = 'token: "{token}", near: "{near}"' - def __init__(self, token, near): + def __init__(self, token: str, near: str): self.token = token self.near = near super().__init__(self.token_detail_msg.format(token=token, near=near)) @@ -134,7 +135,7 @@ class InvalidExpressionAttributeNameKey(MockValidationException): 'ExpressionAttributeNames contains invalid key: Syntax error; key: "{key}"' ) - def __init__(self, key): + def __init__(self, key: str): self.key = key super().__init__(self.invalid_expr_attr_name_msg.format(key=key)) @@ -142,7 +143,7 @@ class InvalidExpressionAttributeNameKey(MockValidationException): class ItemSizeTooLarge(MockValidationException): item_size_too_large_msg = "Item size has exceeded the maximum allowed size" - def __init__(self): + def __init__(self) -> None: super().__init__(self.item_size_too_large_msg) @@ -151,7 +152,7 @@ class ItemSizeToUpdateTooLarge(MockValidationException): "Item size to update has exceeded the maximum allowed size" ) - def __init__(self): + def __init__(self) -> None: super().__init__(self.item_size_to_update_too_large_msg) @@ -159,21 +160,21 @@ class HashKeyTooLong(MockValidationException): # deliberately no space between of and {lim} key_too_large_msg = f"One or more parameter values were invalid: Size 
of hashkey has exceeded the maximum size limit of{HASH_KEY_MAX_LENGTH} bytes" - def __init__(self): + def __init__(self) -> None: super().__init__(self.key_too_large_msg) class RangeKeyTooLong(MockValidationException): key_too_large_msg = f"One or more parameter values were invalid: Aggregated size of all range keys has exceeded the size limit of {RANGE_KEY_MAX_LENGTH} bytes" - def __init__(self): + def __init__(self) -> None: super().__init__(self.key_too_large_msg) class IncorrectOperandType(InvalidUpdateExpression): inv_operand_msg = "Incorrect operand type for operator or function; operator or function: {f}, operand type: {t}" - def __init__(self, operator_or_function, operand_type): + def __init__(self, operator_or_function: str, operand_type: str): self.operator_or_function = operator_or_function self.operand_type = operand_type super().__init__( @@ -184,14 +185,14 @@ class IncorrectOperandType(InvalidUpdateExpression): class IncorrectDataType(MockValidationException): inc_data_type_msg = "An operand in the update expression has an incorrect data type" - def __init__(self): + def __init__(self) -> None: super().__init__(self.inc_data_type_msg) class ConditionalCheckFailed(DynamodbException): error_type = "com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException" - def __init__(self, msg=None): + def __init__(self, msg: Optional[str] = None): super().__init__( ConditionalCheckFailed.error_type, msg or "The conditional request failed" ) @@ -201,7 +202,7 @@ class TransactionCanceledException(DynamodbException): cancel_reason_msg = "Transaction cancelled, please refer cancellation reasons for specific reasons [{}]" error_type = "com.amazonaws.dynamodb.v20120810#TransactionCanceledException" - def __init__(self, errors): + def __init__(self, errors: List[Any]): msg = self.cancel_reason_msg.format( ", ".join([str(code) for code, _, _ in errors]) ) @@ -224,7 +225,7 @@ class TransactionCanceledException(DynamodbException): class MultipleTransactionsException(MockValidationException): msg = "Transaction request cannot include multiple operations on one item" - def __init__(self): + def __init__(self) -> None: super().__init__(self.msg) @@ -234,7 +235,7 @@ class TooManyTransactionsException(MockValidationException): "Member must have length less than or equal to 100." ) - def __init__(self): + def __init__(self) -> None: super().__init__(self.msg) @@ -243,26 +244,28 @@ class EmptyKeyAttributeException(MockValidationException): # AWS has a different message for empty index keys empty_index_msg = "One or more parameter values are not valid. The update expression attempted to update a secondary index key to a value that is not supported. The AttributeValue for a key attribute cannot contain an empty string value." - def __init__(self, key_in_index=False): + def __init__(self, key_in_index: bool = False): super().__init__(self.empty_index_msg if key_in_index else self.empty_str_msg) class UpdateHashRangeKeyException(MockValidationException): msg = "One or more parameter values were invalid: Cannot update attribute {}. 
This attribute is part of the key" - def __init__(self, key_name): + def __init__(self, key_name: str): super().__init__(self.msg.format(key_name)) class InvalidAttributeTypeError(MockValidationException): msg = "One or more parameter values were invalid: Type mismatch for key {} expected: {} actual: {}" - def __init__(self, name, expected_type, actual_type): + def __init__( + self, name: Optional[str], expected_type: Optional[str], actual_type: str + ): super().__init__(self.msg.format(name, expected_type, actual_type)) class DuplicateUpdateExpression(InvalidUpdateExpression): - def __init__(self, names): + def __init__(self, names: List[str]): super().__init__( f"Two document paths overlap with each other; must remove or rewrite one of these paths; path one: [{names[0]}], path two: [{names[1]}]" ) @@ -271,54 +274,54 @@ class DuplicateUpdateExpression(InvalidUpdateExpression): class TooManyAddClauses(InvalidUpdateExpression): msg = 'The "ADD" section can only be used once in an update expression;' - def __init__(self): + def __init__(self) -> None: super().__init__(self.msg) class ResourceNotFoundException(JsonRESTError): - def __init__(self, msg=None): + def __init__(self, msg: Optional[str] = None): err = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" super().__init__(err, msg or "Requested resource not found") class TableNotFoundException(JsonRESTError): - def __init__(self, name): + def __init__(self, name: str): err = "com.amazonaws.dynamodb.v20111205#TableNotFoundException" super().__init__(err, f"Table not found: {name}") class SourceTableNotFoundException(JsonRESTError): - def __init__(self, source_table_name): + def __init__(self, source_table_name: str): er = "com.amazonaws.dynamodb.v20111205#SourceTableNotFoundException" super().__init__(er, f"Source table not found: {source_table_name}") class BackupNotFoundException(JsonRESTError): - def __init__(self, backup_arn): + def __init__(self, backup_arn: str): er = "com.amazonaws.dynamodb.v20111205#BackupNotFoundException" super().__init__(er, f"Backup not found: {backup_arn}") class TableAlreadyExistsException(JsonRESTError): - def __init__(self, target_table_name): + def __init__(self, target_table_name: str): er = "com.amazonaws.dynamodb.v20111205#TableAlreadyExistsException" super().__init__(er, f"Table already exists: {target_table_name}") class ResourceInUseException(JsonRESTError): - def __init__(self): + def __init__(self) -> None: er = "com.amazonaws.dynamodb.v20111205#ResourceInUseException" super().__init__(er, "Resource in use") class StreamAlreadyEnabledException(JsonRESTError): - def __init__(self): + def __init__(self) -> None: er = "com.amazonaws.dynamodb.v20111205#ResourceInUseException" super().__init__(er, "Cannot enable stream") class InvalidConversion(JsonRESTError): - def __init__(self): + def __init__(self) -> None: er = "SerializationException" super().__init__(er, "NUMBER_VALUE cannot be converted to String") @@ -328,10 +331,10 @@ class TransactWriteSingleOpException(MockValidationException): "TransactItems can only contain one of Check, Put, Update or Delete" ) - def __init__(self): + def __init__(self) -> None: super().__init__(self.there_can_be_only_one) class SerializationException(DynamodbException): - def __init__(self, msg): + def __init__(self, msg: str): super().__init__(error_type="SerializationException", message=msg) diff --git a/moto/dynamodb/models/__init__.py b/moto/dynamodb/models/__init__.py index 9360a4155..7b03cb1ba 100644 --- a/moto/dynamodb/models/__init__.py +++ 
b/moto/dynamodb/models/__init__.py @@ -1,26 +1,17 @@ -from collections import defaultdict import copy -import datetime -import decimal -import json import re from collections import OrderedDict -from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel -from moto.core.utils import unix_time, unix_time_millis +from typing import Any, Dict, Optional, List, Tuple, Union, Set +from moto.core import BaseBackend, BackendDict +from moto.core.utils import unix_time from moto.core.exceptions import JsonRESTError -from moto.dynamodb.comparisons import get_filter_expression -from moto.dynamodb.comparisons import get_expected +from moto.dynamodb.comparisons import get_filter_expression, get_expected from moto.dynamodb.exceptions import ( - InvalidIndexNameError, ItemSizeTooLarge, ItemSizeToUpdateTooLarge, - HashKeyTooLong, - RangeKeyTooLong, ConditionalCheckFailed, TransactionCanceledException, - EmptyKeyAttributeException, - InvalidAttributeTypeError, MultipleTransactionsException, TooManyTransactionsException, TableNotFoundException, @@ -31,1169 +22,31 @@ from moto.dynamodb.exceptions import ( ResourceInUseException, StreamAlreadyEnabledException, MockValidationException, - InvalidConversion, TransactWriteSingleOpException, - SerializationException, ) -from moto.dynamodb.models.utilities import bytesize -from moto.dynamodb.models.dynamo_type import DynamoType +from moto.dynamodb.models.dynamo_type import DynamoType, Item +from moto.dynamodb.models.table import ( + Table, + RestoredTable, + GlobalSecondaryIndex, + RestoredPITTable, + Backup, +) from moto.dynamodb.parsing.executors import UpdateExpressionExecutor -from moto.dynamodb.parsing.expressions import UpdateExpressionParser +from moto.dynamodb.parsing.expressions import UpdateExpressionParser # type: ignore from moto.dynamodb.parsing.validators import UpdateExpressionValidator -from moto.dynamodb.limits import HASH_KEY_MAX_LENGTH, RANGE_KEY_MAX_LENGTH -from moto.moto_api._internal import mock_random - - -class DynamoJsonEncoder(json.JSONEncoder): - def default(self, o): - if hasattr(o, "to_json"): - return o.to_json() - - -def dynamo_json_dump(dynamo_object): - return json.dumps(dynamo_object, cls=DynamoJsonEncoder) - - -# https://github.com/getmoto/moto/issues/1874 -# Ensure that the total size of an item does not exceed 400kb -class LimitedSizeDict(dict): - def __init__(self, *args, **kwargs): - self.update(*args, **kwargs) - - def __setitem__(self, key, value): - current_item_size = sum( - [ - item.size() if type(item) == DynamoType else bytesize(str(item)) - for item in (list(self.keys()) + list(self.values())) - ] - ) - new_item_size = bytesize(key) + ( - value.size() if type(value) == DynamoType else bytesize(str(value)) - ) - # Official limit is set to 400000 (400KB) - # Manual testing confirms that the actual limit is between 409 and 410KB - # We'll set the limit to something in between to be safe - if (current_item_size + new_item_size) > 405000: - raise ItemSizeTooLarge - super().__setitem__(key, value) - - -class Item(BaseModel): - def __init__(self, hash_key, range_key, attrs): - self.hash_key = hash_key - self.range_key = range_key - - self.attrs = LimitedSizeDict() - for key, value in attrs.items(): - self.attrs[key] = DynamoType(value) - - def __eq__(self, other): - return all( - [ - self.hash_key == other.hash_key, - self.range_key == other.range_key, - self.attrs == other.attrs, - ] - ) - - def __repr__(self): - return f"Item: {self.to_json()}" - - def size(self): - return sum(bytesize(key) + value.size() 
for key, value in self.attrs.items()) - - def to_json(self): - attributes = {} - for attribute_key, attribute in self.attrs.items(): - attributes[attribute_key] = {attribute.type: attribute.value} - - return {"Attributes": attributes} - - def describe_attrs(self, attributes): - if attributes: - included = {} - for key, value in self.attrs.items(): - if key in attributes: - included[key] = value - else: - included = self.attrs - return {"Item": included} - - def validate_no_empty_key_values(self, attribute_updates, key_attributes): - for attribute_name, update_action in attribute_updates.items(): - action = update_action.get("Action") or "PUT" # PUT is default - if action == "DELETE": - continue - new_value = next(iter(update_action["Value"].values())) - if action == "PUT" and new_value == "" and attribute_name in key_attributes: - raise EmptyKeyAttributeException - - def update_with_attribute_updates(self, attribute_updates): - for attribute_name, update_action in attribute_updates.items(): - # Use default Action value, if no explicit Action is passed. - # Default value is 'Put', according to - # Boto3 DynamoDB.Client.update_item documentation. - action = update_action.get("Action", "PUT") - if action == "DELETE" and "Value" not in update_action: - if attribute_name in self.attrs: - del self.attrs[attribute_name] - continue - new_value = list(update_action["Value"].values())[0] - if action == "PUT": - # TODO deal with other types - if set(update_action["Value"].keys()) == set(["SS"]): - self.attrs[attribute_name] = DynamoType({"SS": new_value}) - elif isinstance(new_value, list): - self.attrs[attribute_name] = DynamoType({"L": new_value}) - elif isinstance(new_value, dict): - self.attrs[attribute_name] = DynamoType({"M": new_value}) - elif set(update_action["Value"].keys()) == set(["N"]): - self.attrs[attribute_name] = DynamoType({"N": new_value}) - elif set(update_action["Value"].keys()) == set(["NULL"]): - if attribute_name in self.attrs: - del self.attrs[attribute_name] - else: - self.attrs[attribute_name] = DynamoType({"S": new_value}) - elif action == "ADD": - if set(update_action["Value"].keys()) == set(["N"]): - existing = self.attrs.get(attribute_name, DynamoType({"N": "0"})) - self.attrs[attribute_name] = DynamoType( - { - "N": str( - decimal.Decimal(existing.value) - + decimal.Decimal(new_value) - ) - } - ) - elif set(update_action["Value"].keys()) == set(["SS"]): - existing = self.attrs.get(attribute_name, DynamoType({"SS": {}})) - new_set = set(existing.value).union(set(new_value)) - self.attrs[attribute_name] = DynamoType({"SS": list(new_set)}) - elif set(update_action["Value"].keys()) == {"L"}: - existing = self.attrs.get(attribute_name, DynamoType({"L": []})) - new_list = existing.value + new_value - self.attrs[attribute_name] = DynamoType({"L": new_list}) - else: - # TODO: implement other data types - raise NotImplementedError( - "ADD not supported for %s" - % ", ".join(update_action["Value"].keys()) - ) - elif action == "DELETE": - if set(update_action["Value"].keys()) == set(["SS"]): - existing = self.attrs.get(attribute_name, DynamoType({"SS": {}})) - new_set = set(existing.value).difference(set(new_value)) - self.attrs[attribute_name] = DynamoType({"SS": list(new_set)}) - else: - raise NotImplementedError( - "ADD not supported for %s" - % ", ".join(update_action["Value"].keys()) - ) - else: - raise NotImplementedError( - f"{action} action not support for update_with_attribute_updates" - ) - - # Filter using projection_expression - # Ensure a deep copy is used to filter, 
otherwise actual data will be removed - def filter(self, projection_expression): - expressions = [x.strip() for x in projection_expression.split(",")] - top_level_expressions = [ - expr[0 : expr.index(".")] for expr in expressions if "." in expr - ] - for attr in list(self.attrs): - if attr not in expressions and attr not in top_level_expressions: - self.attrs.pop(attr) - if attr in top_level_expressions: - relevant_expressions = [ - expr[len(attr + ".") :] - for expr in expressions - if expr.startswith(attr + ".") - ] - self.attrs[attr].filter(relevant_expressions) - - -class StreamRecord(BaseModel): - def __init__(self, table, stream_type, event_name, old, new, seq): - old_a = old.to_json()["Attributes"] if old is not None else {} - new_a = new.to_json()["Attributes"] if new is not None else {} - - rec = old if old is not None else new - keys = {table.hash_key_attr: rec.hash_key.to_json()} - if table.range_key_attr is not None: - keys[table.range_key_attr] = rec.range_key.to_json() - - self.record = { - "eventID": mock_random.uuid4().hex, - "eventName": event_name, - "eventSource": "aws:dynamodb", - "eventVersion": "1.0", - "awsRegion": "us-east-1", - "dynamodb": { - "StreamViewType": stream_type, - "ApproximateCreationDateTime": datetime.datetime.utcnow().isoformat(), - "SequenceNumber": str(seq), - "SizeBytes": 1, - "Keys": keys, - }, - } - - if stream_type in ("NEW_IMAGE", "NEW_AND_OLD_IMAGES"): - self.record["dynamodb"]["NewImage"] = new_a - if stream_type in ("OLD_IMAGE", "NEW_AND_OLD_IMAGES"): - self.record["dynamodb"]["OldImage"] = old_a - - # This is a substantial overestimate but it's the easiest to do now - self.record["dynamodb"]["SizeBytes"] = len( - dynamo_json_dump(self.record["dynamodb"]) - ) - - def to_json(self): - return self.record - - -class StreamShard(BaseModel): - def __init__(self, account_id, table): - self.account_id = account_id - self.table = table - self.id = "shardId-00000001541626099285-f35f62ef" - self.starting_sequence_number = 1100000000017454423009 - self.items = [] - self.created_on = datetime.datetime.utcnow() - - def to_json(self): - return { - "ShardId": self.id, - "SequenceNumberRange": { - "StartingSequenceNumber": str(self.starting_sequence_number) - }, - } - - def add(self, old, new): - t = self.table.stream_specification["StreamViewType"] - if old is None: - event_name = "INSERT" - elif new is None: - event_name = "REMOVE" - else: - event_name = "MODIFY" - seq = len(self.items) + self.starting_sequence_number - self.items.append(StreamRecord(self.table, t, event_name, old, new, seq)) - result = None - from moto.awslambda import lambda_backends - - for arn, esm in self.table.lambda_event_source_mappings.items(): - region = arn[ - len("arn:aws:lambda:") : arn.index(":", len("arn:aws:lambda:")) - ] - - result = lambda_backends[self.account_id][region].send_dynamodb_items( - arn, self.items, esm.event_source_arn - ) - - if result: - self.items = [] - - def get(self, start, quantity): - start -= self.starting_sequence_number - assert start >= 0 - end = start + quantity - return [i.to_json() for i in self.items[start:end]] - - -class SecondaryIndex(BaseModel): - def project(self, item): - """ - Enforces the ProjectionType of this Index (LSI/GSI) - Removes any non-wanted attributes from the item - :param item: - :return: - """ - if self.projection: - projection_type = self.projection.get("ProjectionType", None) - key_attributes = self.table_key_attrs + [ - key["AttributeName"] for key in self.schema - ] - - if projection_type == "KEYS_ONLY": - 
item.filter(",".join(key_attributes)) - elif projection_type == "INCLUDE": - allowed_attributes = key_attributes + self.projection.get( - "NonKeyAttributes", [] - ) - item.filter(",".join(allowed_attributes)) - # ALL is handled implicitly by not filtering - return item - - -class LocalSecondaryIndex(SecondaryIndex): - def __init__(self, index_name, schema, projection, table_key_attrs): - self.name = index_name - self.schema = schema - self.projection = projection - self.table_key_attrs = table_key_attrs - - def describe(self): - return { - "IndexName": self.name, - "KeySchema": self.schema, - "Projection": self.projection, - } - - @staticmethod - def create(dct, table_key_attrs): - return LocalSecondaryIndex( - index_name=dct["IndexName"], - schema=dct["KeySchema"], - projection=dct["Projection"], - table_key_attrs=table_key_attrs, - ) - - -class GlobalSecondaryIndex(SecondaryIndex): - def __init__( - self, - index_name, - schema, - projection, - table_key_attrs, - status="ACTIVE", - throughput=None, - ): - self.name = index_name - self.schema = schema - self.projection = projection - self.table_key_attrs = table_key_attrs - self.status = status - self.throughput = throughput or { - "ReadCapacityUnits": 0, - "WriteCapacityUnits": 0, - } - - def describe(self): - return { - "IndexName": self.name, - "KeySchema": self.schema, - "Projection": self.projection, - "IndexStatus": self.status, - "ProvisionedThroughput": self.throughput, - } - - @staticmethod - def create(dct, table_key_attrs): - return GlobalSecondaryIndex( - index_name=dct["IndexName"], - schema=dct["KeySchema"], - projection=dct["Projection"], - table_key_attrs=table_key_attrs, - throughput=dct.get("ProvisionedThroughput", None), - ) - - def update(self, u): - self.name = u.get("IndexName", self.name) - self.schema = u.get("KeySchema", self.schema) - self.projection = u.get("Projection", self.projection) - self.throughput = u.get("ProvisionedThroughput", self.throughput) - - -class Table(CloudFormationModel): - def __init__( - self, - table_name, - account_id, - region, - schema=None, - attr=None, - throughput=None, - billing_mode=None, - indexes=None, - global_indexes=None, - streams=None, - sse_specification=None, - tags=None, - ): - self.name = table_name - self.account_id = account_id - self.region_name = region - self.attr = attr - self.schema = schema - self.range_key_attr = None - self.hash_key_attr = None - self.range_key_type = None - self.hash_key_type = None - for elem in schema: - attr_type = [ - a["AttributeType"] - for a in attr - if a["AttributeName"] == elem["AttributeName"] - ][0] - if elem["KeyType"] == "HASH": - self.hash_key_attr = elem["AttributeName"] - self.hash_key_type = attr_type - else: - self.range_key_attr = elem["AttributeName"] - self.range_key_type = attr_type - self.table_key_attrs = [ - key for key in (self.hash_key_attr, self.range_key_attr) if key - ] - self.billing_mode = billing_mode - if throughput is None: - self.throughput = {"WriteCapacityUnits": 0, "ReadCapacityUnits": 0} - else: - self.throughput = throughput - self.throughput["NumberOfDecreasesToday"] = 0 - self.indexes = [ - LocalSecondaryIndex.create(i, self.table_key_attrs) - for i in (indexes if indexes else []) - ] - self.global_indexes = [ - GlobalSecondaryIndex.create(i, self.table_key_attrs) - for i in (global_indexes if global_indexes else []) - ] - self.created_at = datetime.datetime.utcnow() - self.items = defaultdict(dict) - self.table_arn = self._generate_arn(table_name) - self.tags = tags or [] - self.ttl = { - 
"TimeToLiveStatus": "DISABLED" # One of 'ENABLING'|'DISABLING'|'ENABLED'|'DISABLED', - # 'AttributeName': 'string' # Can contain this - } - self.stream_specification = {"StreamEnabled": False} - self.latest_stream_label = None - self.stream_shard = None - self.set_stream_specification(streams) - self.lambda_event_source_mappings = {} - self.continuous_backups = { - "ContinuousBackupsStatus": "ENABLED", # One of 'ENABLED'|'DISABLED', it's enabled by default - "PointInTimeRecoveryDescription": { - "PointInTimeRecoveryStatus": "DISABLED" # One of 'ENABLED'|'DISABLED' - }, - } - self.sse_specification = sse_specification - if sse_specification and "KMSMasterKeyId" not in self.sse_specification: - self.sse_specification["KMSMasterKeyId"] = self._get_default_encryption_key( - account_id, region - ) - - def _get_default_encryption_key(self, account_id, region): - from moto.kms import kms_backends - - # https://aws.amazon.com/kms/features/#AWS_Service_Integration - # An AWS managed CMK is created automatically when you first create - # an encrypted resource using an AWS service integrated with KMS. - kms = kms_backends[account_id][region] - ddb_alias = "alias/aws/dynamodb" - if not kms.alias_exists(ddb_alias): - key = kms.create_key( - policy="", - key_usage="ENCRYPT_DECRYPT", - key_spec="SYMMETRIC_DEFAULT", - description="Default master key that protects my DynamoDB table storage", - tags=None, - ) - kms.add_alias(key.id, ddb_alias) - ebs_key = kms.describe_key(ddb_alias) - return ebs_key.arn - - @classmethod - def has_cfn_attr(cls, attr): - return attr in ["Arn", "StreamArn"] - - def get_cfn_attribute(self, attribute_name): - from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - - if attribute_name == "Arn": - return self.table_arn - elif attribute_name == "StreamArn" and self.stream_specification: - return self.describe()["TableDescription"]["LatestStreamArn"] - - raise UnformattedGetAttTemplateException() - - @property - def physical_resource_id(self): - return self.name - - @property - def attribute_keys(self): - # A set of all the hash or range attributes for all indexes - def keys_from_index(idx): - schema = idx.schema - return [attr["AttributeName"] for attr in schema] - - fieldnames = copy.copy(self.table_key_attrs) - for idx in self.indexes + self.global_indexes: - fieldnames += keys_from_index(idx) - return fieldnames - - @staticmethod - def cloudformation_name_type(): - return "TableName" - - @staticmethod - def cloudformation_type(): - # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-dynamodb-table.html - return "AWS::DynamoDB::Table" - - @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): - properties = cloudformation_json["Properties"] - params = {} - - if "KeySchema" in properties: - params["schema"] = properties["KeySchema"] - if "AttributeDefinitions" in properties: - params["attr"] = properties["AttributeDefinitions"] - if "GlobalSecondaryIndexes" in properties: - params["global_indexes"] = properties["GlobalSecondaryIndexes"] - if "ProvisionedThroughput" in properties: - params["throughput"] = properties["ProvisionedThroughput"] - if "LocalSecondaryIndexes" in properties: - params["indexes"] = properties["LocalSecondaryIndexes"] - if "StreamSpecification" in properties: - params["streams"] = properties["StreamSpecification"] - - table = dynamodb_backends[account_id][region_name].create_table( - name=resource_name, **params - ) - return table - 
- @classmethod - def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name - ): - table = dynamodb_backends[account_id][region_name].delete_table( - name=resource_name - ) - return table - - def _generate_arn(self, name): - return f"arn:aws:dynamodb:{self.region_name}:{self.account_id}:table/{name}" - - def set_stream_specification(self, streams): - self.stream_specification = streams - if streams and (streams.get("StreamEnabled") or streams.get("StreamViewType")): - self.stream_specification["StreamEnabled"] = True - self.latest_stream_label = datetime.datetime.utcnow().isoformat() - self.stream_shard = StreamShard(self.account_id, self) - else: - self.stream_specification = {"StreamEnabled": False} - - def describe(self, base_key="TableDescription"): - results = { - base_key: { - "AttributeDefinitions": self.attr, - "ProvisionedThroughput": self.throughput, - "BillingModeSummary": {"BillingMode": self.billing_mode}, - "TableSizeBytes": 0, - "TableName": self.name, - "TableStatus": "ACTIVE", - "TableArn": self.table_arn, - "KeySchema": self.schema, - "ItemCount": len(self), - "CreationDateTime": unix_time(self.created_at), - "GlobalSecondaryIndexes": [ - index.describe() for index in self.global_indexes - ], - "LocalSecondaryIndexes": [index.describe() for index in self.indexes], - } - } - if self.latest_stream_label: - results[base_key]["LatestStreamLabel"] = self.latest_stream_label - results[base_key][ - "LatestStreamArn" - ] = f"{self.table_arn}/stream/{self.latest_stream_label}" - if self.stream_specification and self.stream_specification["StreamEnabled"]: - results[base_key]["StreamSpecification"] = self.stream_specification - if self.sse_specification and self.sse_specification.get("Enabled") is True: - results[base_key]["SSEDescription"] = { - "Status": "ENABLED", - "SSEType": "KMS", - "KMSMasterKeyArn": self.sse_specification.get("KMSMasterKeyId"), - } - return results - - def __len__(self): - return sum( - [(len(value) if self.has_range_key else 1) for value in self.items.values()] - ) - - @property - def hash_key_names(self): - keys = [self.hash_key_attr] - for index in self.global_indexes: - hash_key = None - for key in index.schema: - if key["KeyType"] == "HASH": - hash_key = key["AttributeName"] - keys.append(hash_key) - return keys - - @property - def range_key_names(self): - keys = [self.range_key_attr] - for index in self.global_indexes: - range_key = None - for key in index.schema: - if key["KeyType"] == "RANGE": - range_key = keys.append(key["AttributeName"]) - keys.append(range_key) - return keys - - def _validate_key_sizes(self, item_attrs): - for hash_name in self.hash_key_names: - hash_value = item_attrs.get(hash_name) - if hash_value: - if DynamoType(hash_value).size() > HASH_KEY_MAX_LENGTH: - raise HashKeyTooLong - for range_name in self.range_key_names: - range_value = item_attrs.get(range_name) - if range_value: - if DynamoType(range_value).size() > RANGE_KEY_MAX_LENGTH: - raise RangeKeyTooLong - - def _validate_item_types(self, item_attrs): - for key, value in item_attrs.items(): - if type(value) == dict: - self._validate_item_types(value) - elif type(value) == int and key == "N": - raise InvalidConversion - if key == "S": - # This scenario is usually caught by boto3, but the user can disable parameter validation - # Which is why we need to catch it 'server-side' as well - if type(value) == int: - raise SerializationException( - "NUMBER_VALUE cannot be converted to String" - ) - if type(value) == dict: - raise 
SerializationException( - "Start of structure or map found where not expected" - ) - - def put_item( - self, - item_attrs, - expected=None, - condition_expression=None, - expression_attribute_names=None, - expression_attribute_values=None, - overwrite=False, - ): - if self.hash_key_attr not in item_attrs.keys(): - raise MockValidationException( - "One or more parameter values were invalid: Missing the key " - + self.hash_key_attr - + " in the item" - ) - hash_value = DynamoType(item_attrs.get(self.hash_key_attr)) - if self.has_range_key: - if self.range_key_attr not in item_attrs.keys(): - raise MockValidationException( - "One or more parameter values were invalid: Missing the key " - + self.range_key_attr - + " in the item" - ) - range_value = DynamoType(item_attrs.get(self.range_key_attr)) - else: - range_value = None - - if hash_value.type != self.hash_key_type: - raise InvalidAttributeTypeError( - self.hash_key_attr, - expected_type=self.hash_key_type, - actual_type=hash_value.type, - ) - if range_value and range_value.type != self.range_key_type: - raise InvalidAttributeTypeError( - self.range_key_attr, - expected_type=self.range_key_type, - actual_type=range_value.type, - ) - - self._validate_item_types(item_attrs) - self._validate_key_sizes(item_attrs) - - if expected is None: - expected = {} - lookup_range_value = range_value - else: - expected_range_value = expected.get(self.range_key_attr, {}).get("Value") - if expected_range_value is None: - lookup_range_value = range_value - else: - lookup_range_value = DynamoType(expected_range_value) - current = self.get_item(hash_value, lookup_range_value) - item = Item(hash_value, range_value, item_attrs) - - if not overwrite: - if not get_expected(expected).expr(current): - raise ConditionalCheckFailed - condition_op = get_filter_expression( - condition_expression, - expression_attribute_names, - expression_attribute_values, - ) - if not condition_op.expr(current): - raise ConditionalCheckFailed - - if range_value: - self.items[hash_value][range_value] = item - else: - self.items[hash_value] = item - - if self.stream_shard is not None: - self.stream_shard.add(current, item) - - return item - - def __nonzero__(self): - return True - - def __bool__(self): - return self.__nonzero__() - - @property - def has_range_key(self): - return self.range_key_attr is not None - - def get_item(self, hash_key, range_key=None, projection_expression=None): - if self.has_range_key and not range_key: - raise MockValidationException( - "Table has a range key, but no range key was passed into get_item" - ) - try: - result = None - - if range_key: - result = self.items[hash_key][range_key] - elif hash_key in self.items: - result = self.items[hash_key] - - if projection_expression and result: - result = copy.deepcopy(result) - result.filter(projection_expression) - - if not result: - raise KeyError - - return result - except KeyError: - return None - - def delete_item(self, hash_key, range_key): - try: - if range_key: - item = self.items[hash_key].pop(range_key) - else: - item = self.items.pop(hash_key) - - if self.stream_shard is not None: - self.stream_shard.add(item, None) - - return item - except KeyError: - return None - - def query( - self, - hash_key, - range_comparison, - range_objs, - limit, - exclusive_start_key, - scan_index_forward, - projection_expression, - index_name=None, - filter_expression=None, - **filter_kwargs, - ): - results = [] - - if index_name: - all_indexes = self.all_indexes() - indexes_by_name = dict((i.name, i) for i in all_indexes) - 
if index_name not in indexes_by_name: - all_indexes = ", ".join(indexes_by_name.keys()) - raise MockValidationException( - f"Invalid index: {index_name} for table: {self.name}. Available indexes are: {all_indexes}" - ) - - index = indexes_by_name[index_name] - try: - index_hash_key = [ - key for key in index.schema if key["KeyType"] == "HASH" - ][0] - except IndexError: - raise MockValidationException( - f"Missing Hash Key. KeySchema: {index.name}" - ) - - try: - index_range_key = [ - key for key in index.schema if key["KeyType"] == "RANGE" - ][0] - except IndexError: - index_range_key = None - - possible_results = [] - for item in self.all_items(): - if not isinstance(item, Item): - continue - item_hash_key = item.attrs.get(index_hash_key["AttributeName"]) - if index_range_key is None: - if item_hash_key and item_hash_key == hash_key: - possible_results.append(item) - else: - item_range_key = item.attrs.get(index_range_key["AttributeName"]) - if item_hash_key and item_hash_key == hash_key and item_range_key: - possible_results.append(item) - else: - possible_results = [ - item - for item in list(self.all_items()) - if isinstance(item, Item) and item.hash_key == hash_key - ] - - if range_comparison: - if index_name and not index_range_key: - raise ValueError( - "Range Key comparison but no range key found for index: %s" - % index_name - ) - - elif index_name: - for result in possible_results: - if result.attrs.get(index_range_key["AttributeName"]).compare( - range_comparison, range_objs - ): - results.append(result) - else: - for result in possible_results: - if result.range_key.compare(range_comparison, range_objs): - results.append(result) - - if filter_kwargs: - for result in possible_results: - for field, value in filter_kwargs.items(): - dynamo_types = [ - DynamoType(ele) for ele in value["AttributeValueList"] - ] - if result.attrs.get(field).compare( - value["ComparisonOperator"], dynamo_types - ): - results.append(result) - - if not range_comparison and not filter_kwargs: - # If we're not filtering on range key or on an index return all - # values - results = possible_results - - if index_name: - - if index_range_key: - - # Convert to float if necessary to ensure proper ordering - def conv(x): - return float(x.value) if x.type == "N" else x.value - - results.sort( - key=lambda item: conv(item.attrs[index_range_key["AttributeName"]]) - if item.attrs.get(index_range_key["AttributeName"]) - else None - ) - else: - results.sort(key=lambda item: item.range_key) - - if scan_index_forward is False: - results.reverse() - - scanned_count = len(list(self.all_items())) - - results = copy.deepcopy(results) - if index_name: - index = self.get_index(index_name) - for result in results: - index.project(result) - - results, last_evaluated_key = self._trim_results( - results, limit, exclusive_start_key, scanned_index=index_name - ) - - if filter_expression is not None: - results = [item for item in results if filter_expression.expr(item)] - - if projection_expression: - for result in results: - result.filter(projection_expression) - - return results, scanned_count, last_evaluated_key - - def all_items(self): - for hash_set in self.items.values(): - if self.range_key_attr: - for item in hash_set.values(): - yield item - else: - yield hash_set - - def all_indexes(self): - return (self.global_indexes or []) + (self.indexes or []) - - def get_index(self, index_name, error_if_not=False): - all_indexes = self.all_indexes() - indexes_by_name = dict((i.name, i) for i in all_indexes) - if error_if_not and 
index_name not in indexes_by_name: - raise InvalidIndexNameError( - f"The table does not have the specified index: {index_name}" - ) - return indexes_by_name[index_name] - - def has_idx_items(self, index_name): - - idx = self.get_index(index_name) - idx_col_set = set([i["AttributeName"] for i in idx.schema]) - - for hash_set in self.items.values(): - if self.range_key_attr: - for item in hash_set.values(): - if idx_col_set.issubset(set(item.attrs)): - yield item - else: - if idx_col_set.issubset(set(hash_set.attrs)): - yield hash_set - - def scan( - self, - filters, - limit, - exclusive_start_key, - filter_expression=None, - index_name=None, - projection_expression=None, - ): - results = [] - scanned_count = 0 - - if index_name: - self.get_index(index_name, error_if_not=True) - items = self.has_idx_items(index_name) - else: - items = self.all_items() - - for item in items: - scanned_count += 1 - passes_all_conditions = True - for ( - attribute_name, - (comparison_operator, comparison_objs), - ) in filters.items(): - attribute = item.attrs.get(attribute_name) - - if attribute: - # Attribute found - if not attribute.compare(comparison_operator, comparison_objs): - passes_all_conditions = False - break - elif comparison_operator == "NULL": - # Comparison is NULL and we don't have the attribute - continue - else: - # No attribute found and comparison is no NULL. This item - # fails - passes_all_conditions = False - break - - if passes_all_conditions: - results.append(item) - - results, last_evaluated_key = self._trim_results( - results, limit, exclusive_start_key, scanned_index=index_name - ) - - if filter_expression is not None: - results = [item for item in results if filter_expression.expr(item)] - - if projection_expression: - results = copy.deepcopy(results) - for result in results: - result.filter(projection_expression) - - return results, scanned_count, last_evaluated_key - - def _trim_results(self, results, limit, exclusive_start_key, scanned_index=None): - if exclusive_start_key is not None: - hash_key = DynamoType(exclusive_start_key.get(self.hash_key_attr)) - range_key = exclusive_start_key.get(self.range_key_attr) - if range_key is not None: - range_key = DynamoType(range_key) - for i in range(len(results)): - if ( - results[i].hash_key == hash_key - and results[i].range_key == range_key - ): - results = results[i + 1 :] - break - - last_evaluated_key = None - size_limit = 1000000 # DynamoDB has a 1MB size limit - item_size = sum(res.size() for res in results) - if item_size > size_limit: - item_size = idx = 0 - while item_size + results[idx].size() < size_limit: - item_size += results[idx].size() - idx += 1 - limit = min(limit, idx) if limit else idx - if limit and len(results) > limit: - results = results[:limit] - last_evaluated_key = {self.hash_key_attr: results[-1].hash_key} - if results[-1].range_key is not None: - last_evaluated_key[self.range_key_attr] = results[-1].range_key - - if scanned_index: - idx = self.get_index(scanned_index) - idx_col_list = [i["AttributeName"] for i in idx.schema] - for col in idx_col_list: - last_evaluated_key[col] = results[-1].attrs[col] - - return results, last_evaluated_key - - def delete(self, account_id, region_name): - dynamodb_backends[account_id][region_name].delete_table(self.name) - - -class RestoredTable(Table): - def __init__(self, name, account_id, region, backup): - params = self._parse_params_from_backup(backup) - super().__init__(name, account_id=account_id, region=region, **params) - self.indexes = 
copy.deepcopy(backup.table.indexes) - self.global_indexes = copy.deepcopy(backup.table.global_indexes) - self.items = copy.deepcopy(backup.table.items) - # Restore Attrs - self.source_backup_arn = backup.arn - self.source_table_arn = backup.table.table_arn - self.restore_date_time = self.created_at - - @staticmethod - def _parse_params_from_backup(backup): - params = { - "schema": copy.deepcopy(backup.table.schema), - "attr": copy.deepcopy(backup.table.attr), - "throughput": copy.deepcopy(backup.table.throughput), - } - return params - - def describe(self, base_key="TableDescription"): - result = super().describe(base_key=base_key) - result[base_key]["RestoreSummary"] = { - "SourceBackupArn": self.source_backup_arn, - "SourceTableArn": self.source_table_arn, - "RestoreDateTime": unix_time(self.restore_date_time), - "RestoreInProgress": False, - } - return result - - -class RestoredPITTable(Table): - def __init__(self, name, account_id, region, source): - params = self._parse_params_from_table(source) - super().__init__(name, account_id=account_id, region=region, **params) - self.indexes = copy.deepcopy(source.indexes) - self.global_indexes = copy.deepcopy(source.global_indexes) - self.items = copy.deepcopy(source.items) - # Restore Attrs - self.source_table_arn = source.table_arn - self.restore_date_time = self.created_at - - @staticmethod - def _parse_params_from_table(table): - params = { - "schema": copy.deepcopy(table.schema), - "attr": copy.deepcopy(table.attr), - "throughput": copy.deepcopy(table.throughput), - } - return params - - def describe(self, base_key="TableDescription"): - result = super().describe(base_key=base_key) - result[base_key]["RestoreSummary"] = { - "SourceTableArn": self.source_table_arn, - "RestoreDateTime": unix_time(self.restore_date_time), - "RestoreInProgress": False, - } - return result - - -class Backup(object): - def __init__(self, backend, name, table, status=None, type_=None): - self.backend = backend - self.name = name - self.table = copy.deepcopy(table) - self.status = status or "AVAILABLE" - self.type = type_ or "USER" - self.creation_date_time = datetime.datetime.utcnow() - self.identifier = self._make_identifier() - - def _make_identifier(self): - timestamp = int(unix_time_millis(self.creation_date_time)) - timestamp_padded = str("0" + str(timestamp))[-16:16] - guid = str(mock_random.uuid4()) - guid_shortened = guid[:8] - return f"{timestamp_padded}-{guid_shortened}" - - @property - def arn(self): - return f"arn:aws:dynamodb:{self.backend.region_name}:{self.backend.account_id}:table/{self.table.name}/backup/{self.identifier}" - - @property - def details(self): - details = { - "BackupArn": self.arn, - "BackupName": self.name, - "BackupSizeBytes": 123, - "BackupStatus": self.status, - "BackupType": self.type, - "BackupCreationDateTime": unix_time(self.creation_date_time), - } - return details - - @property - def summary(self): - summary = { - "TableName": self.table.name, - # 'TableId': 'string', - "TableArn": self.table.table_arn, - "BackupArn": self.arn, - "BackupName": self.name, - "BackupCreationDateTime": unix_time(self.creation_date_time), - # 'BackupExpiryDateTime': datetime(2015, 1, 1), - "BackupStatus": self.status, - "BackupType": self.type, - "BackupSizeBytes": 123, - } - return summary - - @property - def description(self): - source_table_details = self.table.describe()["TableDescription"] - source_table_details["TableCreationDateTime"] = source_table_details[ - "CreationDateTime" - ] - description = { - "BackupDetails": self.details, - 
"SourceTableDetails": source_table_details, - } - return description class DynamoDBBackend(BaseBackend): - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.tables = OrderedDict() - self.backups = OrderedDict() + self.tables: Dict[str, Table] = OrderedDict() + self.backups: Dict[str, Backup] = OrderedDict() @staticmethod - def default_vpc_endpoint_service(service_region, zones): + def default_vpc_endpoint_service( + service_region: str, zones: List[str] + ) -> List[Dict[str, str]]: """Default VPC endpoint service.""" # No 'vpce' in the base endpoint DNS name return BaseBackend.default_vpc_endpoint_service_factory( @@ -1205,7 +58,7 @@ class DynamoDBBackend(BaseBackend): base_endpoint_dns_names=[f"dynamodb.{service_region}.amazonaws.com"], ) - def create_table(self, name, **params): + def create_table(self, name: str, **params: Any) -> Table: if name in self.tables: raise ResourceInUseException table = Table( @@ -1214,12 +67,12 @@ class DynamoDBBackend(BaseBackend): self.tables[name] = table return table - def delete_table(self, name): + def delete_table(self, name: str) -> Table: if name not in self.tables: raise ResourceNotFoundException - return self.tables.pop(name, None) + return self.tables.pop(name) - def describe_endpoints(self): + def describe_endpoints(self) -> List[Dict[str, Union[int, str]]]: return [ { "Address": f"dynamodb.{self.region_name}.amazonaws.com", @@ -1227,26 +80,28 @@ class DynamoDBBackend(BaseBackend): } ] - def tag_resource(self, table_arn, tags): + def tag_resource(self, table_arn: str, tags: List[Dict[str, str]]) -> None: for table in self.tables: if self.tables[table].table_arn == table_arn: self.tables[table].tags.extend(tags) - def untag_resource(self, table_arn, tag_keys): + def untag_resource(self, table_arn: str, tag_keys: List[str]) -> None: for table in self.tables: if self.tables[table].table_arn == table_arn: self.tables[table].tags = [ tag for tag in self.tables[table].tags if tag["Key"] not in tag_keys ] - def list_tags_of_resource(self, table_arn): + def list_tags_of_resource(self, table_arn: str) -> List[Dict[str, str]]: for table in self.tables: if self.tables[table].table_arn == table_arn: return self.tables[table].tags raise ResourceNotFoundException - def list_tables(self, limit, exclusive_start_table_name): - all_tables = list(self.tables.keys()) + def list_tables( + self, limit: int, exclusive_start_table_name: str + ) -> Tuple[List[str], Optional[str]]: + all_tables: List[str] = list(self.tables.keys()) if exclusive_start_table_name: try: @@ -1267,19 +122,19 @@ class DynamoDBBackend(BaseBackend): return tables, tables[-1] return tables, None - def describe_table(self, name): + def describe_table(self, name: str) -> Dict[str, Any]: table = self.get_table(name) return table.describe(base_key="Table") def update_table( self, - name, - attr_definitions, - global_index, - throughput, - billing_mode, - stream_spec, - ): + name: str, + attr_definitions: List[Dict[str, str]], + global_index: List[Dict[str, Any]], + throughput: Dict[str, Any], + billing_mode: str, + stream_spec: Dict[str, Any], + ) -> Table: table = self.get_table(name) if attr_definitions: table.attr = attr_definitions @@ -1293,17 +148,19 @@ class DynamoDBBackend(BaseBackend): table = self.update_table_streams(name, stream_spec) return table - def update_table_throughput(self, name, throughput): + def update_table_throughput(self, name: str, throughput: Dict[str, int]) -> Table: table 
= self.tables[name] table.throughput = throughput return table - def update_table_billing_mode(self, name, billing_mode): + def update_table_billing_mode(self, name: str, billing_mode: str) -> Table: table = self.tables[name] table.billing_mode = billing_mode return table - def update_table_streams(self, name, stream_specification): + def update_table_streams( + self, name: str, stream_specification: Dict[str, Any] + ) -> Table: table = self.tables[name] if ( stream_specification.get("StreamEnabled") @@ -1313,7 +170,9 @@ class DynamoDBBackend(BaseBackend): table.set_stream_specification(stream_specification) return table - def update_table_global_indexes(self, name, global_index_updates): + def update_table_global_indexes( + self, name: str, global_index_updates: List[Dict[str, Any]] + ) -> Table: table = self.tables[name] gsis_by_name = dict((i.name, i) for i in table.global_indexes) for gsi_update in global_index_updates: @@ -1356,14 +215,14 @@ class DynamoDBBackend(BaseBackend): def put_item( self, - table_name, - item_attrs, - expected=None, - condition_expression=None, - expression_attribute_names=None, - expression_attribute_values=None, - overwrite=False, - ): + table_name: str, + item_attrs: Dict[str, Any], + expected: Optional[Dict[str, Any]] = None, + condition_expression: Optional[str] = None, + expression_attribute_names: Optional[Dict[str, Any]] = None, + expression_attribute_values: Optional[Dict[str, Any]] = None, + overwrite: bool = False, + ) -> Item: table = self.get_table(table_name) return table.put_item( item_attrs, @@ -1374,7 +233,9 @@ class DynamoDBBackend(BaseBackend): overwrite, ) - def get_table_keys_name(self, table_name, keys): + def get_table_keys_name( + self, table_name: str, keys: Dict[str, Any] + ) -> Tuple[Optional[str], Optional[str]]: """ Given a set of keys, extracts the key and range key """ @@ -1395,7 +256,9 @@ class DynamoDBBackend(BaseBackend): potential_range = key return potential_hash, potential_range - def get_keys_value(self, table, keys): + def get_keys_value( + self, table: Table, keys: Dict[str, Any] + ) -> Tuple[DynamoType, Optional[DynamoType]]: if table.hash_key_attr not in keys or ( table.has_range_key and table.range_key_attr not in keys ): @@ -1403,57 +266,64 @@ class DynamoDBBackend(BaseBackend): raise MockValidationException("Validation Exception") hash_key = DynamoType(keys[table.hash_key_attr]) range_key = ( - DynamoType(keys[table.range_key_attr]) if table.has_range_key else None + DynamoType(keys[table.range_key_attr]) if table.range_key_attr else None ) return hash_key, range_key - def get_schema(self, table_name, index_name): + def get_schema( + self, table_name: str, index_name: Optional[str] + ) -> List[Dict[str, Any]]: table = self.get_table(table_name) if index_name: all_indexes = (table.global_indexes or []) + (table.indexes or []) indexes_by_name = dict((i.name, i) for i in all_indexes) if index_name not in indexes_by_name: - all_indexes = ", ".join(indexes_by_name.keys()) + all_index_names = ", ".join(indexes_by_name.keys()) raise ResourceNotFoundException( - f"Invalid index: {index_name} for table: {table_name}. Available indexes are: {all_indexes}" + f"Invalid index: {index_name} for table: {table_name}. 
Available indexes are: {all_index_names}" ) return indexes_by_name[index_name].schema else: return table.schema - def get_table(self, table_name) -> Table: + def get_table(self, table_name: str) -> Table: if table_name not in self.tables: raise ResourceNotFoundException() - return self.tables.get(table_name) + return self.tables[table_name] - def get_item(self, table_name, keys, projection_expression=None): + def get_item( + self, + table_name: str, + keys: Dict[str, Any], + projection_expression: Optional[str] = None, + ) -> Optional[Item]: table = self.get_table(table_name) hash_key, range_key = self.get_keys_value(table, keys) return table.get_item(hash_key, range_key, projection_expression) def query( self, - table_name, - hash_key_dict, - range_comparison, - range_value_dicts, - limit, - exclusive_start_key, - scan_index_forward, - projection_expression, - index_name=None, - expr_names=None, - expr_values=None, - filter_expression=None, - **filter_kwargs, - ): + table_name: str, + hash_key_dict: Dict[str, Any], + range_comparison: Optional[str], + range_value_dicts: List[Dict[str, Any]], + limit: int, + exclusive_start_key: Dict[str, Any], + scan_index_forward: bool, + projection_expression: str, + index_name: Optional[str] = None, + expr_names: Optional[Dict[str, str]] = None, + expr_values: Optional[Dict[str, str]] = None, + filter_expression: Optional[str] = None, + **filter_kwargs: Any, + ) -> Tuple[List[Item], int, Optional[Dict[str, Any]]]: table = self.get_table(table_name) hash_key = DynamoType(hash_key_dict) range_values = [DynamoType(range_value) for range_value in range_value_dicts] - filter_expression = get_filter_expression( + filter_expression_op = get_filter_expression( filter_expression, expr_names, expr_values ) @@ -1466,30 +336,30 @@ class DynamoDBBackend(BaseBackend): scan_index_forward, projection_expression, index_name, - filter_expression, + filter_expression_op, **filter_kwargs, ) def scan( self, - table_name, - filters, - limit, - exclusive_start_key, - filter_expression, - expr_names, - expr_values, - index_name, - projection_expression, - ): + table_name: str, + filters: Dict[str, Any], + limit: int, + exclusive_start_key: Dict[str, Any], + filter_expression: str, + expr_names: Dict[str, Any], + expr_values: Dict[str, Any], + index_name: str, + projection_expression: str, + ) -> Tuple[List[Item], int, Optional[Dict[str, Any]]]: table = self.get_table(table_name) - scan_filters = {} + scan_filters: Dict[str, Any] = {} for key, (comparison_operator, comparison_values) in filters.items(): dynamo_types = [DynamoType(value) for value in comparison_values] scan_filters[key] = (comparison_operator, dynamo_types) - filter_expression = get_filter_expression( + filter_expression_op = get_filter_expression( filter_expression, expr_names, expr_values ) @@ -1497,22 +367,22 @@ class DynamoDBBackend(BaseBackend): scan_filters, limit, exclusive_start_key, - filter_expression, + filter_expression_op, index_name, projection_expression, ) def update_item( self, - table_name, - key, - update_expression, - expression_attribute_names, - expression_attribute_values, - attribute_updates=None, - expected=None, - condition_expression=None, - ): + table_name: str, + key: Dict[str, Dict[str, Any]], + update_expression: str, + expression_attribute_names: Dict[str, Any], + expression_attribute_values: Dict[str, Any], + attribute_updates: Optional[Dict[str, Any]] = None, + expected: Optional[Dict[str, Any]] = None, + condition_expression: Optional[str] = None, + ) -> Item: table = 
self.get_table(table_name) # Support spaces between operators in an update expression @@ -1527,7 +397,7 @@ class DynamoDBBackend(BaseBackend): # Covers cases where table has hash and range keys, ``key`` param # will be a dict hash_value = DynamoType(key[table.hash_key_attr]) - range_value = DynamoType(key[table.range_key_attr]) + range_value = DynamoType(key[table.range_key_attr]) # type: ignore[index] elif table.hash_key_attr in key: # Covers tables that have a range key where ``key`` param is a dict hash_value = DynamoType(key[table.hash_key_attr]) @@ -1565,48 +435,50 @@ class DynamoDBBackend(BaseBackend): item=item, table=table, ).validate() - data = {table.hash_key_attr: {hash_value.type: hash_value.value}} + data: Dict[str, Any] = { + table.hash_key_attr: {hash_value.type: hash_value.value} + } if range_value: data.update( - {table.range_key_attr: {range_value.type: range_value.value}} + {table.range_key_attr: {range_value.type: range_value.value}} # type: ignore[dict-item] ) table.put_item(data) item = table.get_item(hash_value, range_value) if attribute_updates: - item.validate_no_empty_key_values(attribute_updates, table.attribute_keys) + item.validate_no_empty_key_values(attribute_updates, table.attribute_keys) # type: ignore[union-attr] if update_expression: validator = UpdateExpressionValidator( update_expression_ast, expression_attribute_names=expression_attribute_names, expression_attribute_values=expression_attribute_values, - item=item, + item=item, # type: ignore[arg-type] table=table, ) validated_ast = validator.validate() validated_ast.normalize() try: UpdateExpressionExecutor( - validated_ast, item, expression_attribute_names + validated_ast, item, expression_attribute_names # type: ignore[arg-type] ).execute() except ItemSizeTooLarge: raise ItemSizeToUpdateTooLarge() else: - item.update_with_attribute_updates(attribute_updates) + item.update_with_attribute_updates(attribute_updates) # type: ignore if table.stream_shard is not None: table.stream_shard.add(orig_item, item) - return item + return item # type: ignore[return-value] def delete_item( self, - table_name, - key, - expression_attribute_names=None, - expression_attribute_values=None, - condition_expression=None, - ): + table_name: str, + key: Dict[str, Any], + expression_attribute_names: Optional[Dict[str, Any]] = None, + expression_attribute_values: Optional[Dict[str, Any]] = None, + condition_expression: Optional[str] = None, + ) -> Optional[Item]: table = self.get_table(table_name) hash_value, range_value = self.get_keys_value(table, key) @@ -1617,12 +489,12 @@ class DynamoDBBackend(BaseBackend): expression_attribute_names, expression_attribute_values, ) - if not condition_op.expr(item): + if not condition_op.expr(item): # type: ignore[arg-type] raise ConditionalCheckFailed return table.delete_item(hash_value, range_value) - def update_time_to_live(self, table_name, ttl_spec): + def update_time_to_live(self, table_name: str, ttl_spec: Dict[str, Any]) -> None: table = self.tables.get(table_name) if table is None: raise JsonRESTError("ResourceNotFound", "Table not found") @@ -1639,27 +511,29 @@ class DynamoDBBackend(BaseBackend): table.ttl["TimeToLiveStatus"] = "DISABLED" table.ttl["AttributeName"] = ttl_spec["AttributeName"] - def describe_time_to_live(self, table_name): + def describe_time_to_live(self, table_name: str) -> Dict[str, Any]: table = self.tables.get(table_name) if table is None: raise JsonRESTError("ResourceNotFound", "Table not found") return table.ttl - def transact_write_items(self, 
transact_items): + def transact_write_items(self, transact_items: List[Dict[str, Any]]) -> None: if len(transact_items) > 100: raise TooManyTransactionsException() # Create a backup in case any of the transactions fail original_table_state = copy.deepcopy(self.tables) - target_items = set() + target_items: Set[Tuple[str, str]] = set() - def check_unicity(table_name, key): + def check_unicity(table_name: str, key: Dict[str, Any]) -> None: item = (str(table_name), str(key)) if item in target_items: raise MultipleTransactionsException() target_items.add(item) - errors = [] # [(Code, Message, Item), ..] + errors: List[ + Union[Tuple[str, str, Dict[str, Any]], Tuple[None, None, None]] + ] = [] # [(Code, Message, Item), ..] for item in transact_items: # check transact writes are not performing multiple operations # in the same item @@ -1686,7 +560,7 @@ class DynamoDBBackend(BaseBackend): expression_attribute_names, expression_attribute_values, ) - if not condition_op.expr(current): + if not condition_op.expr(current): # type: ignore[arg-type] raise ConditionalCheckFailed() elif "Put" in item: item = item["Put"] @@ -1766,13 +640,13 @@ class DynamoDBBackend(BaseBackend): self.tables = original_table_state raise MultipleTransactionsException() except Exception as e: # noqa: E722 Do not use bare except - errors.append((type(e).__name__, e.message, item)) + errors.append((type(e).__name__, e.message, item)) # type: ignore[attr-defined] if any([code is not None for code, _, _ in errors]): # Rollback to the original state, and reraise the errors self.tables = original_table_state raise TransactionCanceledException(errors) - def describe_continuous_backups(self, table_name): + def describe_continuous_backups(self, table_name: str) -> Dict[str, Any]: try: table = self.get_table(table_name) except ResourceNotFoundException: @@ -1780,7 +654,9 @@ class DynamoDBBackend(BaseBackend): return table.continuous_backups - def update_continuous_backups(self, table_name, point_in_time_spec): + def update_continuous_backups( + self, table_name: str, point_in_time_spec: Dict[str, Any] + ) -> Dict[str, Any]: try: table = self.get_table(table_name) except ResourceNotFoundException: @@ -1805,27 +681,27 @@ class DynamoDBBackend(BaseBackend): return table.continuous_backups - def get_backup(self, backup_arn): + def get_backup(self, backup_arn: str) -> Backup: if backup_arn not in self.backups: raise BackupNotFoundException(backup_arn) - return self.backups.get(backup_arn) + return self.backups[backup_arn] - def list_backups(self, table_name): + def list_backups(self, table_name: str) -> List[Backup]: backups = list(self.backups.values()) if table_name is not None: backups = [backup for backup in backups if backup.table.name == table_name] return backups - def create_backup(self, table_name, backup_name): + def create_backup(self, table_name: str, backup_name: str) -> Backup: try: table = self.get_table(table_name) except ResourceNotFoundException: raise TableNotFoundException(table_name) - backup = Backup(self, backup_name, table) + backup = Backup(self.account_id, self.region_name, backup_name, table) self.backups[backup.arn] = backup return backup - def delete_backup(self, backup_arn): + def delete_backup(self, backup_arn: str) -> Backup: backup = self.get_backup(backup_arn) if backup is None: raise KeyError() @@ -1833,13 +709,15 @@ class DynamoDBBackend(BaseBackend): backup_deleted.status = "DELETED" return backup_deleted - def describe_backup(self, backup_arn): + def describe_backup(self, backup_arn: str) -> Backup: 
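+        # Usage sketch (hypothetical, not part of this change; assumes the
+        # table "my-table" already exists in this backend):
+        #
+        #   backend = dynamodb_backends[account_id][region_name]
+        #   backup = backend.create_backup("my-table", "my-backup")
+        #   assert backend.describe_backup(backup.arn) is backup
+        #
+        # get_backup raises BackupNotFoundException for unknown ARNs, so the
+        # None-check below is defensive and should not normally trigger.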
backup = self.get_backup(backup_arn) if backup is None: raise KeyError() return backup - def restore_table_from_backup(self, target_table_name, backup_arn): + def restore_table_from_backup( + self, target_table_name: str, backup_arn: str + ) -> RestoredTable: backup = self.get_backup(backup_arn) if target_table_name in self.tables: raise TableAlreadyExistsException(target_table_name) @@ -1852,7 +730,9 @@ class DynamoDBBackend(BaseBackend): self.tables[target_table_name] = new_table return new_table - def restore_table_to_point_in_time(self, target_table_name, source_table_name): + def restore_table_to_point_in_time( + self, target_table_name: str, source_table_name: str + ) -> RestoredPITTable: """ Currently this only accepts the source and target table elements, and will copy all items from the source without respect to other arguments. @@ -1879,13 +759,13 @@ class DynamoDBBackend(BaseBackend): # TODO: Move logic here ###################### - def batch_get_item(self): + def batch_get_item(self) -> None: pass - def batch_write_item(self): + def batch_write_item(self) -> None: pass - def transact_get_items(self): + def transact_get_items(self) -> None: pass diff --git a/moto/dynamodb/models/dynamo_type.py b/moto/dynamodb/models/dynamo_type.py index 672407f26..ee9f7b425 100644 --- a/moto/dynamodb/models/dynamo_type.py +++ b/moto/dynamodb/models/dynamo_type.py @@ -1,9 +1,16 @@ -from moto.dynamodb.comparisons import get_comparison_func -from moto.dynamodb.exceptions import IncorrectDataType +import decimal +from typing import Any, Dict, List, Union, Optional + +from moto.core import BaseModel +from moto.dynamodb.exceptions import ( + IncorrectDataType, + EmptyKeyAttributeException, + ItemSizeTooLarge, +) from moto.dynamodb.models.utilities import bytesize -class DDBType(object): +class DDBType: """ Official documentation at https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_AttributeValue.html """ @@ -20,7 +27,7 @@ class DDBType(object): NULL = "NULL" -class DDBTypeConversion(object): +class DDBTypeConversion: _human_type_mapping = { val: key.replace("_", " ") for key, val in DDBType.__dict__.items() @@ -28,13 +35,13 @@ class DDBTypeConversion(object): } @classmethod - def get_human_type(cls, abbreviated_type): + def get_human_type(cls, abbreviated_type: str) -> str: """ Args: abbreviated_type(str): An attribute of DDBType Returns: - str: The human readable form of the DDBType. + str: The human-readable form of the DDBType. 
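+
+        Example (relies on the abbreviation mapping above, e.g. "SS" for
+        STRING_SET; unknown abbreviations pass through unchanged)::
+
+            DDBTypeConversion.get_human_type("SS")  # -> "STRING SET"
+            DDBTypeConversion.get_human_type("??")  # -> "??"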
""" return cls._human_type_mapping.get(abbreviated_type, abbreviated_type) @@ -44,19 +51,19 @@ class DynamoType(object): http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DataModel.html#DataModelDataTypes """ - def __init__(self, type_as_dict): + def __init__(self, type_as_dict: Union["DynamoType", Dict[str, Any]]): if type(type_as_dict) == DynamoType: - self.type = type_as_dict.type - self.value = type_as_dict.value + self.type: str = type_as_dict.type + self.value: Any = type_as_dict.value else: - self.type = list(type_as_dict)[0] - self.value = list(type_as_dict.values())[0] + self.type = list(type_as_dict)[0] # type: ignore[arg-type] + self.value = list(type_as_dict.values())[0] # type: ignore[union-attr] if self.is_list(): self.value = [DynamoType(val) for val in self.value] elif self.is_map(): self.value = dict((k, DynamoType(v)) for k, v in self.value.items()) - def filter(self, projection_expressions): + def filter(self, projection_expressions: str) -> None: nested_projections = [ expr[0 : expr.index(".")] for expr in projection_expressions if "." in expr ] @@ -78,31 +85,31 @@ class DynamoType(object): for expr in expressions_to_delete: self.value.pop(expr) - def __hash__(self): + def __hash__(self) -> int: return hash((self.type, self.value)) - def __eq__(self, other): + def __eq__(self, other: "DynamoType") -> bool: # type: ignore[override] return self.type == other.type and self.value == other.value - def __ne__(self, other): + def __ne__(self, other: "DynamoType") -> bool: # type: ignore[override] return self.type != other.type or self.value != other.value - def __lt__(self, other): + def __lt__(self, other: "DynamoType") -> bool: return self.cast_value < other.cast_value - def __le__(self, other): + def __le__(self, other: "DynamoType") -> bool: return self.cast_value <= other.cast_value - def __gt__(self, other): + def __gt__(self, other: "DynamoType") -> bool: return self.cast_value > other.cast_value - def __ge__(self, other): + def __ge__(self, other: "DynamoType") -> bool: return self.cast_value >= other.cast_value - def __repr__(self): + def __repr__(self) -> str: return f"DynamoType: {self.to_json()}" - def __add__(self, other): + def __add__(self, other: "DynamoType") -> "DynamoType": if self.type != other.type: raise TypeError("Different types of operandi is not allowed.") if self.is_number(): @@ -112,7 +119,7 @@ class DynamoType(object): else: raise IncorrectDataType() - def __sub__(self, other): + def __sub__(self, other: "DynamoType") -> "DynamoType": if self.type != other.type: raise TypeError("Different types of operandi is not allowed.") if self.type == DDBType.NUMBER: @@ -122,7 +129,7 @@ class DynamoType(object): else: raise TypeError("Sum only supported for Numbers.") - def __getitem__(self, item): + def __getitem__(self, item: "DynamoType") -> "DynamoType": if isinstance(item, str): # If our DynamoType is a map it should be subscriptable with a key if self.type == DDBType.MAP: @@ -135,7 +142,7 @@ class DynamoType(object): f"This DynamoType {self.type} is not subscriptable by a {type(item)}" ) - def __setitem__(self, key, value): + def __setitem__(self, key: Any, value: Any) -> None: if isinstance(key, int): if self.is_list(): if key >= len(self.value): @@ -150,7 +157,7 @@ class DynamoType(object): raise NotImplementedError(f"No set_item for {type(key)}") @property - def cast_value(self): + def cast_value(self) -> Any: # type: ignore[misc] if self.is_number(): try: return int(self.value) @@ -166,7 +173,7 @@ class DynamoType(object): else: return 
self.value - def child_attr(self, key): + def child_attr(self, key: Union[int, str, None]) -> Optional["DynamoType"]: """ Get Map or List children by key. str for Map, int for List. @@ -183,7 +190,7 @@ class DynamoType(object): return None - def size(self): + def size(self) -> int: if self.is_number(): value_size = len(str(self.value)) elif self.is_set(): @@ -201,34 +208,204 @@ class DynamoType(object): value_size = bytesize(self.value) return value_size - def to_json(self): + def to_json(self) -> Dict[str, Any]: return {self.type: self.value} - def compare(self, range_comparison, range_objs): + def compare(self, range_comparison: str, range_objs: List[Any]) -> bool: """ Compares this type against comparison filters """ + from moto.dynamodb.comparisons import get_comparison_func + range_values = [obj.cast_value for obj in range_objs] comparison_func = get_comparison_func(range_comparison) return comparison_func(self.cast_value, *range_values) - def is_number(self): + def is_number(self) -> bool: return self.type == DDBType.NUMBER - def is_set(self): + def is_set(self) -> bool: return self.type in (DDBType.STRING_SET, DDBType.NUMBER_SET, DDBType.BINARY_SET) - def is_list(self): + def is_list(self) -> bool: return self.type == DDBType.LIST - def is_map(self): + def is_map(self) -> bool: return self.type == DDBType.MAP - def same_type(self, other): + def same_type(self, other: "DynamoType") -> bool: return self.type == other.type - def pop(self, key, *args, **kwargs): + def pop(self, key: str, *args: Any, **kwargs: Any) -> None: if self.is_map() or self.is_list(): self.value.pop(key, *args, **kwargs) else: raise TypeError(f"pop not supported for DynamoType {self.type}") + + +# https://github.com/getmoto/moto/issues/1874 +# Ensure that the total size of an item does not exceed 400kb +class LimitedSizeDict(Dict[str, Any]): + def __init__(self, *args: Any, **kwargs: Any): + self.update(*args, **kwargs) + + def __setitem__(self, key: str, value: Any) -> None: + current_item_size = sum( + [ + item.size() if type(item) == DynamoType else bytesize(str(item)) + for item in (list(self.keys()) + list(self.values())) + ] + ) + new_item_size = bytesize(key) + ( + value.size() if type(value) == DynamoType else bytesize(str(value)) + ) + # Official limit is set to 400000 (400KB) + # Manual testing confirms that the actual limit is between 409 and 410KB + # We'll set the limit to something in between to be safe + if (current_item_size + new_item_size) > 405000: + raise ItemSizeTooLarge + super().__setitem__(key, value) + + +class Item(BaseModel): + def __init__( + self, + hash_key: DynamoType, + range_key: Optional[DynamoType], + attrs: Dict[str, Any], + ): + self.hash_key = hash_key + self.range_key = range_key + + self.attrs = LimitedSizeDict() + for key, value in attrs.items(): + self.attrs[key] = DynamoType(value) + + def __eq__(self, other: "Item") -> bool: # type: ignore[override] + return all( + [ + self.hash_key == other.hash_key, + self.range_key == other.range_key, # type: ignore[operator] + self.attrs == other.attrs, + ] + ) + + def __repr__(self) -> str: + return f"Item: {self.to_json()}" + + def size(self) -> int: + return sum(bytesize(key) + value.size() for key, value in self.attrs.items()) + + def to_json(self) -> Dict[str, Any]: + attributes = {} + for attribute_key, attribute in self.attrs.items(): + attributes[attribute_key] = {attribute.type: attribute.value} + + return {"Attributes": attributes} + + def describe_attrs( + self, attributes: Optional[Dict[str, Any]] + ) -> Dict[str, 
Dict[str, Any]]:
+        if attributes:
+            included = {}
+            for key, value in self.attrs.items():
+                if key in attributes:
+                    included[key] = value
+        else:
+            included = self.attrs
+        return {"Item": included}
+
+    def validate_no_empty_key_values(
+        self, attribute_updates: Dict[str, Any], key_attributes: List[str]
+    ) -> None:
+        for attribute_name, update_action in attribute_updates.items():
+            action = update_action.get("Action") or "PUT"  # PUT is default
+            if action == "DELETE":
+                continue
+            new_value = next(iter(update_action["Value"].values()))
+            if action == "PUT" and new_value == "" and attribute_name in key_attributes:
+                raise EmptyKeyAttributeException
+
+    def update_with_attribute_updates(self, attribute_updates: Dict[str, Any]) -> None:
+        for attribute_name, update_action in attribute_updates.items():
+            # Use the default Action value if no explicit Action is passed.
+            # Default value is 'Put', according to
+            # Boto3 DynamoDB.Client.update_item documentation.
+            action = update_action.get("Action", "PUT")
+            if action == "DELETE" and "Value" not in update_action:
+                if attribute_name in self.attrs:
+                    del self.attrs[attribute_name]
+                continue
+            new_value = list(update_action["Value"].values())[0]
+            if action == "PUT":
+                # TODO deal with other types
+                if set(update_action["Value"].keys()) == set(["SS"]):
+                    self.attrs[attribute_name] = DynamoType({"SS": new_value})
+                elif isinstance(new_value, list):
+                    self.attrs[attribute_name] = DynamoType({"L": new_value})
+                elif isinstance(new_value, dict):
+                    self.attrs[attribute_name] = DynamoType({"M": new_value})
+                elif set(update_action["Value"].keys()) == set(["N"]):
+                    self.attrs[attribute_name] = DynamoType({"N": new_value})
+                elif set(update_action["Value"].keys()) == set(["NULL"]):
+                    if attribute_name in self.attrs:
+                        del self.attrs[attribute_name]
+                else:
+                    self.attrs[attribute_name] = DynamoType({"S": new_value})
+            elif action == "ADD":
+                if set(update_action["Value"].keys()) == set(["N"]):
+                    existing = self.attrs.get(attribute_name, DynamoType({"N": "0"}))
+                    self.attrs[attribute_name] = DynamoType(
+                        {
+                            "N": str(
+                                decimal.Decimal(existing.value)
+                                + decimal.Decimal(new_value)
+                            )
+                        }
+                    )
+                elif set(update_action["Value"].keys()) == set(["SS"]):
+                    existing = self.attrs.get(attribute_name, DynamoType({"SS": {}}))
+                    new_set = set(existing.value).union(set(new_value))
+                    self.attrs[attribute_name] = DynamoType({"SS": list(new_set)})
+                elif set(update_action["Value"].keys()) == {"L"}:
+                    existing = self.attrs.get(attribute_name, DynamoType({"L": []}))
+                    new_list = existing.value + new_value
+                    self.attrs[attribute_name] = DynamoType({"L": new_list})
+                else:
+                    # TODO: implement other data types
+                    raise NotImplementedError(
+                        "ADD not supported for %s"
+                        % ", ".join(update_action["Value"].keys())
+                    )
+            elif action == "DELETE":
+                if set(update_action["Value"].keys()) == set(["SS"]):
+                    existing = self.attrs.get(attribute_name, DynamoType({"SS": {}}))
+                    new_set = set(existing.value).difference(set(new_value))
+                    self.attrs[attribute_name] = DynamoType({"SS": list(new_set)})
+                else:
+                    raise NotImplementedError(
+                        "DELETE not supported for %s"
+                        % ", ".join(update_action["Value"].keys())
+                    )
+            else:
+                raise NotImplementedError(
+                    f"{action} action not supported for update_with_attribute_updates"
+                )
+
+    # Filter using projection_expression
+    # Ensure a deep copy is used to filter, otherwise actual data will be removed
+    def filter(self, projection_expression: str) -> None:
+        expressions = [x.strip() for x in projection_expression.split(",")]
+        top_level_expressions = [
+
expr[0 : expr.index(".")] for expr in expressions if "." in expr + ] + for attr in list(self.attrs): + if attr not in expressions and attr not in top_level_expressions: + self.attrs.pop(attr) + if attr in top_level_expressions: + relevant_expressions = [ + expr[len(attr + ".") :] + for expr in expressions + if expr.startswith(attr + ".") + ] + self.attrs[attr].filter(relevant_expressions) diff --git a/moto/dynamodb/models/table.py b/moto/dynamodb/models/table.py new file mode 100644 index 000000000..483897246 --- /dev/null +++ b/moto/dynamodb/models/table.py @@ -0,0 +1,1036 @@ +from collections import defaultdict +import copy +import datetime + +from typing import Any, Dict, Optional, List, Tuple, Iterator, Sequence +from moto.core import BaseModel, CloudFormationModel +from moto.core.utils import unix_time, unix_time_millis +from moto.dynamodb.comparisons import get_filter_expression, get_expected +from moto.dynamodb.exceptions import ( + InvalidIndexNameError, + HashKeyTooLong, + RangeKeyTooLong, + ConditionalCheckFailed, + InvalidAttributeTypeError, + MockValidationException, + InvalidConversion, + SerializationException, +) +from moto.dynamodb.models.utilities import dynamo_json_dump +from moto.dynamodb.models.dynamo_type import DynamoType, Item +from moto.dynamodb.limits import HASH_KEY_MAX_LENGTH, RANGE_KEY_MAX_LENGTH +from moto.moto_api._internal import mock_random + + +class SecondaryIndex(BaseModel): + def __init__( + self, + index_name: str, + schema: List[Dict[str, str]], + projection: Dict[str, Any], + table_key_attrs: List[str], + ): + self.name = index_name + self.schema = schema + self.table_key_attrs = table_key_attrs + self.projection = projection + + def project(self, item: Item) -> Item: + """ + Enforces the ProjectionType of this Index (LSI/GSI) + Removes any non-wanted attributes from the item + :param item: + :return: + """ + if self.projection: + projection_type = self.projection.get("ProjectionType", None) + key_attributes = self.table_key_attrs + [ + key["AttributeName"] for key in self.schema + ] + + if projection_type == "KEYS_ONLY": + item.filter(",".join(key_attributes)) + elif projection_type == "INCLUDE": + allowed_attributes = key_attributes + self.projection.get( + "NonKeyAttributes", [] + ) + item.filter(",".join(allowed_attributes)) + # ALL is handled implicitly by not filtering + return item + + +class LocalSecondaryIndex(SecondaryIndex): + def describe(self) -> Dict[str, Any]: + return { + "IndexName": self.name, + "KeySchema": self.schema, + "Projection": self.projection, + } + + @staticmethod + def create(dct: Dict[str, Any], table_key_attrs: List[str]) -> "LocalSecondaryIndex": # type: ignore[misc] + return LocalSecondaryIndex( + index_name=dct["IndexName"], + schema=dct["KeySchema"], + projection=dct["Projection"], + table_key_attrs=table_key_attrs, + ) + + +class GlobalSecondaryIndex(SecondaryIndex): + def __init__( + self, + index_name: str, + schema: List[Dict[str, str]], + projection: Dict[str, Any], + table_key_attrs: List[str], + status: str = "ACTIVE", + throughput: Optional[Dict[str, Any]] = None, + ): + super().__init__(index_name, schema, projection, table_key_attrs) + self.status = status + self.throughput = throughput or { + "ReadCapacityUnits": 0, + "WriteCapacityUnits": 0, + } + + def describe(self) -> Dict[str, Any]: + return { + "IndexName": self.name, + "KeySchema": self.schema, + "Projection": self.projection, + "IndexStatus": self.status, + "ProvisionedThroughput": self.throughput, + } + + @staticmethod + def create(dct: Dict[str, 
Any], table_key_attrs: List[str]) -> "GlobalSecondaryIndex": # type: ignore[misc] + return GlobalSecondaryIndex( + index_name=dct["IndexName"], + schema=dct["KeySchema"], + projection=dct["Projection"], + table_key_attrs=table_key_attrs, + throughput=dct.get("ProvisionedThroughput", None), + ) + + def update(self, u: Dict[str, Any]) -> None: + self.name = u.get("IndexName", self.name) + self.schema = u.get("KeySchema", self.schema) + self.projection = u.get("Projection", self.projection) + self.throughput = u.get("ProvisionedThroughput", self.throughput) + + +class StreamRecord(BaseModel): + def __init__( + self, + table: "Table", + stream_type: str, + event_name: str, + old: Optional[Item], + new: Optional[Item], + seq: int, + ): + old_a = old.to_json()["Attributes"] if old is not None else {} + new_a = new.to_json()["Attributes"] if new is not None else {} + + rec = old if old is not None else new + keys = {table.hash_key_attr: rec.hash_key.to_json()} # type: ignore[union-attr] + if table.range_key_attr is not None and rec is not None: + keys[table.range_key_attr] = rec.range_key.to_json() # type: ignore + + self.record: Dict[str, Any] = { + "eventID": mock_random.uuid4().hex, + "eventName": event_name, + "eventSource": "aws:dynamodb", + "eventVersion": "1.0", + "awsRegion": "us-east-1", + "dynamodb": { + "StreamViewType": stream_type, + "ApproximateCreationDateTime": datetime.datetime.utcnow().isoformat(), + "SequenceNumber": str(seq), + "SizeBytes": 1, + "Keys": keys, + }, + } + + if stream_type in ("NEW_IMAGE", "NEW_AND_OLD_IMAGES"): + self.record["dynamodb"]["NewImage"] = new_a + if stream_type in ("OLD_IMAGE", "NEW_AND_OLD_IMAGES"): + self.record["dynamodb"]["OldImage"] = old_a + + # This is a substantial overestimate but it's the easiest to do now + self.record["dynamodb"]["SizeBytes"] = len( + dynamo_json_dump(self.record["dynamodb"]) + ) + + def to_json(self) -> Dict[str, Any]: + return self.record + + +class StreamShard(BaseModel): + def __init__(self, account_id: str, table: "Table"): + self.account_id = account_id + self.table = table + self.id = "shardId-00000001541626099285-f35f62ef" + self.starting_sequence_number = 1100000000017454423009 + self.items: List[StreamRecord] = [] + self.created_on = datetime.datetime.utcnow() + + def to_json(self) -> Dict[str, Any]: + return { + "ShardId": self.id, + "SequenceNumberRange": { + "StartingSequenceNumber": str(self.starting_sequence_number) + }, + } + + def add(self, old: Optional[Item], new: Optional[Item]) -> None: + t = self.table.stream_specification["StreamViewType"] # type: ignore + if old is None: + event_name = "INSERT" + elif new is None: + event_name = "REMOVE" + else: + event_name = "MODIFY" + seq = len(self.items) + self.starting_sequence_number + self.items.append(StreamRecord(self.table, t, event_name, old, new, seq)) + result = None + from moto.awslambda import lambda_backends + + for arn, esm in self.table.lambda_event_source_mappings.items(): + region = arn[ + len("arn:aws:lambda:") : arn.index(":", len("arn:aws:lambda:")) + ] + + result = lambda_backends[self.account_id][region].send_dynamodb_items( + arn, self.items, esm.event_source_arn + ) + + if result: + self.items = [] + + def get(self, start: int, quantity: int) -> List[Dict[str, Any]]: + start -= self.starting_sequence_number + assert start >= 0 + end = start + quantity + return [i.to_json() for i in self.items[start:end]] + + +class Table(CloudFormationModel): + def __init__( + self, + table_name: str, + account_id: str, + region: str, + schema: 
List[Dict[str, Any]], + attr: List[Dict[str, str]], + throughput: Optional[Dict[str, int]] = None, + billing_mode: Optional[str] = None, + indexes: Optional[List[Dict[str, Any]]] = None, + global_indexes: Optional[List[Dict[str, Any]]] = None, + streams: Optional[Dict[str, Any]] = None, + sse_specification: Optional[Dict[str, Any]] = None, + tags: Optional[List[Dict[str, str]]] = None, + ): + self.name = table_name + self.account_id = account_id + self.region_name = region + self.attr = attr + self.schema = schema + self.range_key_attr: Optional[str] = None + self.hash_key_attr: str = "" + self.range_key_type: Optional[str] = None + self.hash_key_type: str = "" + for elem in schema: + attr_type = [ + a["AttributeType"] + for a in attr + if a["AttributeName"] == elem["AttributeName"] + ][0] + if elem["KeyType"] == "HASH": + self.hash_key_attr = elem["AttributeName"] + self.hash_key_type = attr_type + else: + self.range_key_attr = elem["AttributeName"] + self.range_key_type = attr_type + self.table_key_attrs = [ + key for key in (self.hash_key_attr, self.range_key_attr) if key is not None + ] + self.billing_mode = billing_mode + if throughput is None: + self.throughput = {"WriteCapacityUnits": 0, "ReadCapacityUnits": 0} + else: + self.throughput = throughput + self.throughput["NumberOfDecreasesToday"] = 0 + self.indexes = [ + LocalSecondaryIndex.create(i, self.table_key_attrs) + for i in (indexes if indexes else []) + ] + self.global_indexes = [ + GlobalSecondaryIndex.create(i, self.table_key_attrs) + for i in (global_indexes if global_indexes else []) + ] + self.created_at = datetime.datetime.utcnow() + self.items = defaultdict(dict) # type: ignore # [hash: DynamoType] or [hash: [range: DynamoType]] + self.table_arn = self._generate_arn(table_name) + self.tags = tags or [] + self.ttl = { + "TimeToLiveStatus": "DISABLED" # One of 'ENABLING'|'DISABLING'|'ENABLED'|'DISABLED', + # 'AttributeName': 'string' # Can contain this + } + self.stream_specification: Optional[Dict[str, Any]] = {"StreamEnabled": False} + self.latest_stream_label: Optional[str] = None + self.stream_shard: Optional[StreamShard] = None + self.set_stream_specification(streams) + self.lambda_event_source_mappings: Dict[str, Any] = {} + self.continuous_backups: Dict[str, Any] = { + "ContinuousBackupsStatus": "ENABLED", # One of 'ENABLED'|'DISABLED', it's enabled by default + "PointInTimeRecoveryDescription": { + "PointInTimeRecoveryStatus": "DISABLED" # One of 'ENABLED'|'DISABLED' + }, + } + self.sse_specification = sse_specification + if self.sse_specification and "KMSMasterKeyId" not in self.sse_specification: + self.sse_specification["KMSMasterKeyId"] = self._get_default_encryption_key( + account_id, region + ) + + def _get_default_encryption_key(self, account_id: str, region: str) -> str: + from moto.kms import kms_backends + + # https://aws.amazon.com/kms/features/#AWS_Service_Integration + # An AWS managed CMK is created automatically when you first create + # an encrypted resource using an AWS service integrated with KMS. 
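+        # The lookup below mirrors that behaviour: resolve the well-known
+        # "alias/aws/dynamodb" alias, lazily creating the key and alias on
+        # first use, and return the key ARN for the table's SSEDescription.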
+ kms = kms_backends[account_id][region] + ddb_alias = "alias/aws/dynamodb" + if not kms.alias_exists(ddb_alias): + key = kms.create_key( + policy="", + key_usage="ENCRYPT_DECRYPT", + key_spec="SYMMETRIC_DEFAULT", + description="Default master key that protects my DynamoDB table storage", + tags=None, + ) + kms.add_alias(key.id, ddb_alias) + ebs_key = kms.describe_key(ddb_alias) + return ebs_key.arn + + @classmethod + def has_cfn_attr(cls, attr: str) -> bool: + return attr in ["Arn", "StreamArn"] + + def get_cfn_attribute(self, attribute_name: str) -> Any: # type: ignore[misc] + from moto.cloudformation.exceptions import UnformattedGetAttTemplateException + + if attribute_name == "Arn": + return self.table_arn + elif attribute_name == "StreamArn" and self.stream_specification: + return self.describe()["TableDescription"]["LatestStreamArn"] + + raise UnformattedGetAttTemplateException() + + @property + def physical_resource_id(self) -> str: + return self.name + + @property + def attribute_keys(self) -> List[str]: + # A set of all the hash or range attributes for all indexes + def keys_from_index(idx: SecondaryIndex) -> List[str]: + schema = idx.schema + return [attr["AttributeName"] for attr in schema] + + fieldnames = copy.copy(self.table_key_attrs) + for idx in self.indexes + self.global_indexes: + fieldnames += keys_from_index(idx) + return fieldnames + + @staticmethod + def cloudformation_name_type() -> str: + return "TableName" + + @staticmethod + def cloudformation_type() -> str: + # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-dynamodb-table.html + return "AWS::DynamoDB::Table" + + @classmethod + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Dict[str, Any], + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "Table": + from moto.dynamodb.models import dynamodb_backends + + properties = cloudformation_json["Properties"] + params = {} + + if "KeySchema" in properties: + params["schema"] = properties["KeySchema"] + if "AttributeDefinitions" in properties: + params["attr"] = properties["AttributeDefinitions"] + if "GlobalSecondaryIndexes" in properties: + params["global_indexes"] = properties["GlobalSecondaryIndexes"] + if "ProvisionedThroughput" in properties: + params["throughput"] = properties["ProvisionedThroughput"] + if "LocalSecondaryIndexes" in properties: + params["indexes"] = properties["LocalSecondaryIndexes"] + if "StreamSpecification" in properties: + params["streams"] = properties["StreamSpecification"] + + table = dynamodb_backends[account_id][region_name].create_table( + name=resource_name, **params + ) + return table + + @classmethod + def delete_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Dict[str, Any], + account_id: str, + region_name: str, + ) -> None: + from moto.dynamodb.models import dynamodb_backends + + dynamodb_backends[account_id][region_name].delete_table(name=resource_name) + + def _generate_arn(self, name: str) -> str: + return f"arn:aws:dynamodb:{self.region_name}:{self.account_id}:table/{name}" + + def set_stream_specification(self, streams: Optional[Dict[str, Any]]) -> None: + self.stream_specification = streams + if ( + self.stream_specification + and streams + and (streams.get("StreamEnabled") or streams.get("StreamViewType")) + ): + self.stream_specification["StreamEnabled"] = True + self.latest_stream_label = datetime.datetime.utcnow().isoformat() + self.stream_shard = 
StreamShard(self.account_id, self) + else: + self.stream_specification = {"StreamEnabled": False} + + def describe(self, base_key: str = "TableDescription") -> Dict[str, Any]: + results: Dict[str, Any] = { + base_key: { + "AttributeDefinitions": self.attr, + "ProvisionedThroughput": self.throughput, + "BillingModeSummary": {"BillingMode": self.billing_mode}, + "TableSizeBytes": 0, + "TableName": self.name, + "TableStatus": "ACTIVE", + "TableArn": self.table_arn, + "KeySchema": self.schema, + "ItemCount": len(self), + "CreationDateTime": unix_time(self.created_at), + "GlobalSecondaryIndexes": [ + index.describe() for index in self.global_indexes + ], + "LocalSecondaryIndexes": [index.describe() for index in self.indexes], + } + } + if self.latest_stream_label: + results[base_key]["LatestStreamLabel"] = self.latest_stream_label + results[base_key][ + "LatestStreamArn" + ] = f"{self.table_arn}/stream/{self.latest_stream_label}" + if self.stream_specification and self.stream_specification["StreamEnabled"]: + results[base_key]["StreamSpecification"] = self.stream_specification + if self.sse_specification and self.sse_specification.get("Enabled") is True: + results[base_key]["SSEDescription"] = { + "Status": "ENABLED", + "SSEType": "KMS", + "KMSMasterKeyArn": self.sse_specification.get("KMSMasterKeyId"), + } + return results + + def __len__(self) -> int: + return sum( + [(len(value) if self.has_range_key else 1) for value in self.items.values()] + ) + + @property + def hash_key_names(self) -> List[str]: + keys = [self.hash_key_attr] + for index in self.global_indexes: + for key in index.schema: + if key["KeyType"] == "HASH": + keys.append(key["AttributeName"]) + return keys + + @property + def range_key_names(self) -> List[str]: + keys = [self.range_key_attr] + for index in self.global_indexes: + for key in index.schema: + if key["KeyType"] == "RANGE": + keys.append(key["AttributeName"]) + return keys # type: ignore[return-value] + + def _validate_key_sizes(self, item_attrs: Dict[str, Any]) -> None: + for hash_name in self.hash_key_names: + hash_value = item_attrs.get(hash_name) + if hash_value: + if DynamoType(hash_value).size() > HASH_KEY_MAX_LENGTH: + raise HashKeyTooLong + for range_name in self.range_key_names: + range_value = item_attrs.get(range_name) + if range_value: + if DynamoType(range_value).size() > RANGE_KEY_MAX_LENGTH: + raise RangeKeyTooLong + + def _validate_item_types(self, item_attrs: Dict[str, Any]) -> None: + for key, value in item_attrs.items(): + if type(value) == dict: + self._validate_item_types(value) + elif type(value) == int and key == "N": + raise InvalidConversion + if key == "S": + # This scenario is usually caught by boto3, but the user can disable parameter validation + # Which is why we need to catch it 'server-side' as well + if type(value) == int: + raise SerializationException( + "NUMBER_VALUE cannot be converted to String" + ) + if type(value) == dict: + raise SerializationException( + "Start of structure or map found where not expected" + ) + + def put_item( + self, + item_attrs: Dict[str, Any], + expected: Optional[Dict[str, Any]] = None, + condition_expression: Optional[str] = None, + expression_attribute_names: Optional[Dict[str, str]] = None, + expression_attribute_values: Optional[Dict[str, Any]] = None, + overwrite: bool = False, + ) -> Item: + if self.hash_key_attr not in item_attrs.keys(): + raise MockValidationException( + "One or more parameter values were invalid: Missing the key " + + self.hash_key_attr + + " in the item" + ) + hash_value = 
DynamoType(item_attrs[self.hash_key_attr]) + if self.range_key_attr is not None: + if self.range_key_attr not in item_attrs.keys(): + raise MockValidationException( + f"One or more parameter values were invalid: Missing the key {self.range_key_attr} in the item" + ) + range_value = DynamoType(item_attrs[self.range_key_attr]) + else: + range_value = None + + if hash_value.type != self.hash_key_type: + raise InvalidAttributeTypeError( + self.hash_key_attr, + expected_type=self.hash_key_type, + actual_type=hash_value.type, + ) + if range_value and range_value.type != self.range_key_type: + raise InvalidAttributeTypeError( + self.range_key_attr, + expected_type=self.range_key_type, + actual_type=range_value.type, + ) + + self._validate_item_types(item_attrs) + self._validate_key_sizes(item_attrs) + + if expected is None: + expected = {} + lookup_range_value = range_value + else: + expected_range_value = expected.get(self.range_key_attr, {}).get("Value") # type: ignore + if expected_range_value is None: + lookup_range_value = range_value + else: + lookup_range_value = DynamoType(expected_range_value) + current = self.get_item(hash_value, lookup_range_value) + item = Item(hash_value, range_value, item_attrs) + + if not overwrite: + if not get_expected(expected).expr(current): + raise ConditionalCheckFailed + condition_op = get_filter_expression( + condition_expression, + expression_attribute_names, + expression_attribute_values, + ) + if not condition_op.expr(current): + raise ConditionalCheckFailed + + if range_value: + self.items[hash_value][range_value] = item + else: + self.items[hash_value] = item # type: ignore[assignment] + + if self.stream_shard is not None: + self.stream_shard.add(current, item) + + return item + + def __nonzero__(self) -> bool: + return True + + def __bool__(self) -> bool: + return self.__nonzero__() + + @property + def has_range_key(self) -> bool: + return self.range_key_attr is not None + + def get_item( + self, + hash_key: DynamoType, + range_key: Optional[DynamoType] = None, + projection_expression: Optional[str] = None, + ) -> Optional[Item]: + if self.has_range_key and not range_key: + raise MockValidationException( + "Table has a range key, but no range key was passed into get_item" + ) + try: + result = None + + if range_key: + result = self.items[hash_key][range_key] + elif hash_key in self.items: + result = self.items[hash_key] + + if projection_expression and result: + result = copy.deepcopy(result) + result.filter(projection_expression) + + if not result: + raise KeyError + + return result + except KeyError: + return None + + def delete_item( + self, hash_key: DynamoType, range_key: Optional[DynamoType] + ) -> Optional[Item]: + try: + if range_key: + item = self.items[hash_key].pop(range_key) + else: + item = self.items.pop(hash_key) + + if self.stream_shard is not None: + self.stream_shard.add(item, None) + + return item + except KeyError: + return None + + def query( + self, + hash_key: DynamoType, + range_comparison: Optional[str], + range_objs: List[DynamoType], + limit: int, + exclusive_start_key: Dict[str, Any], + scan_index_forward: bool, + projection_expression: str, + index_name: Optional[str] = None, + filter_expression: Any = None, + **filter_kwargs: Any, + ) -> Tuple[List[Item], int, Optional[Dict[str, Any]]]: + results = [] + + if index_name: + all_indexes = self.all_indexes() + indexes_by_name = dict((i.name, i) for i in all_indexes) + if index_name not in indexes_by_name: + all_names = ", ".join(indexes_by_name.keys()) + raise 
MockValidationException( + f"Invalid index: {index_name} for table: {self.name}. Available indexes are: {all_names}" + ) + + index = indexes_by_name[index_name] + try: + index_hash_key = [ + key for key in index.schema if key["KeyType"] == "HASH" + ][0] + except IndexError: + raise MockValidationException( + f"Missing Hash Key. KeySchema: {index.name}" + ) + + try: + index_range_key = [ + key for key in index.schema if key["KeyType"] == "RANGE" + ][0] + except IndexError: + index_range_key = None + + possible_results = [] + for item in self.all_items(): + if not isinstance(item, Item): + continue + item_hash_key = item.attrs.get(index_hash_key["AttributeName"]) + if index_range_key is None: + if item_hash_key and item_hash_key == hash_key: + possible_results.append(item) + else: + item_range_key = item.attrs.get(index_range_key["AttributeName"]) + if item_hash_key and item_hash_key == hash_key and item_range_key: + possible_results.append(item) + else: + possible_results = [ + item + for item in list(self.all_items()) + if isinstance(item, Item) and item.hash_key == hash_key + ] + + if range_comparison: + if index_name and not index_range_key: + raise ValueError( + "Range Key comparison but no range key found for index: %s" + % index_name + ) + + elif index_name: + for result in possible_results: + if result.attrs.get(index_range_key["AttributeName"]).compare( # type: ignore + range_comparison, range_objs + ): + results.append(result) + else: + for result in possible_results: + if result.range_key.compare(range_comparison, range_objs): # type: ignore[union-attr] + results.append(result) + + if filter_kwargs: + for result in possible_results: + for field, value in filter_kwargs.items(): + dynamo_types = [ + DynamoType(ele) for ele in value["AttributeValueList"] + ] + if result.attrs.get(field).compare( # type: ignore[union-attr] + value["ComparisonOperator"], dynamo_types + ): + results.append(result) + + if not range_comparison and not filter_kwargs: + # If we're not filtering on range key or on an index return all + # values + results = possible_results + + if index_name: + + if index_range_key: + + # Convert to float if necessary to ensure proper ordering + def conv(x: DynamoType) -> Any: + return float(x.value) if x.type == "N" else x.value + + results.sort( + key=lambda item: conv(item.attrs[index_range_key["AttributeName"]]) # type: ignore + if item.attrs.get(index_range_key["AttributeName"]) # type: ignore + else None + ) + else: + results.sort(key=lambda item: item.range_key) # type: ignore + + if scan_index_forward is False: + results.reverse() + + scanned_count = len(list(self.all_items())) + + results = copy.deepcopy(results) + if index_name: + index = self.get_index(index_name) + for result in results: + index.project(result) + + results, last_evaluated_key = self._trim_results( + results, limit, exclusive_start_key, scanned_index=index_name + ) + + if filter_expression is not None: + results = [item for item in results if filter_expression.expr(item)] + + if projection_expression: + for result in results: + result.filter(projection_expression) + + return results, scanned_count, last_evaluated_key + + def all_items(self) -> Iterator[Item]: + for hash_set in self.items.values(): + if self.range_key_attr: + for item in hash_set.values(): + yield item + else: + yield hash_set # type: ignore + + def all_indexes(self) -> Sequence[SecondaryIndex]: + return (self.global_indexes or []) + (self.indexes or []) # type: ignore + + def get_index(self, index_name: str, error_if_not: bool = 
False) -> SecondaryIndex: + all_indexes = self.all_indexes() + indexes_by_name = dict((i.name, i) for i in all_indexes) + if error_if_not and index_name not in indexes_by_name: + raise InvalidIndexNameError( + f"The table does not have the specified index: {index_name}" + ) + return indexes_by_name[index_name] + + def has_idx_items(self, index_name: str) -> Iterator[Item]: + + idx = self.get_index(index_name) + idx_col_set = set([i["AttributeName"] for i in idx.schema]) + + for hash_set in self.items.values(): + if self.range_key_attr: + for item in hash_set.values(): + if idx_col_set.issubset(set(item.attrs)): + yield item + else: + if idx_col_set.issubset(set(hash_set.attrs)): # type: ignore + yield hash_set # type: ignore + + def scan( + self, + filters: Dict[str, Any], + limit: int, + exclusive_start_key: Dict[str, Any], + filter_expression: Any = None, + index_name: Optional[str] = None, + projection_expression: Optional[str] = None, + ) -> Tuple[List[Item], int, Optional[Dict[str, Any]]]: + results = [] + scanned_count = 0 + + if index_name: + self.get_index(index_name, error_if_not=True) + items = self.has_idx_items(index_name) + else: + items = self.all_items() + + for item in items: + scanned_count += 1 + passes_all_conditions = True + for ( + attribute_name, + (comparison_operator, comparison_objs), + ) in filters.items(): + attribute = item.attrs.get(attribute_name) + + if attribute: + # Attribute found + if not attribute.compare(comparison_operator, comparison_objs): + passes_all_conditions = False + break + elif comparison_operator == "NULL": + # Comparison is NULL and we don't have the attribute + continue + else: + # No attribute found and comparison is not NULL. This item + # fails + passes_all_conditions = False + break + + if passes_all_conditions: + results.append(item) + + results, last_evaluated_key = self._trim_results( + results, limit, exclusive_start_key, scanned_index=index_name + ) + + if filter_expression is not None: + results = [item for item in results if filter_expression.expr(item)] + + if projection_expression: + results = copy.deepcopy(results) + for result in results: + result.filter(projection_expression) + + return results, scanned_count, last_evaluated_key + + def _trim_results( + self, + results: List[Item], + limit: int, + exclusive_start_key: Optional[Dict[str, Any]], + scanned_index: Optional[str] = None, + ) -> Tuple[List[Item], Optional[Dict[str, Any]]]: + if exclusive_start_key is not None: + hash_key = DynamoType(exclusive_start_key.get(self.hash_key_attr)) # type: ignore[arg-type] + range_key = ( + exclusive_start_key.get(self.range_key_attr) + if self.range_key_attr + else None + ) + if range_key is not None: + range_key = DynamoType(range_key) + for i in range(len(results)): + if ( + results[i].hash_key == hash_key + and results[i].range_key == range_key + ): + results = results[i + 1 :] + break + + last_evaluated_key = None + size_limit = 1000000 # DynamoDB has a 1MB size limit + item_size = sum(res.size() for res in results) + if item_size > size_limit: + item_size = idx = 0 + while item_size + results[idx].size() < size_limit: + item_size += results[idx].size() + idx += 1 + limit = min(limit, idx) if limit else idx + if limit and len(results) > limit: + results = results[:limit] + last_evaluated_key = {self.hash_key_attr: results[-1].hash_key} + if self.range_key_attr is not None and results[-1].range_key is not None: + last_evaluated_key[self.range_key_attr] = results[-1].range_key + + if scanned_index: + index =
self.get_index(scanned_index) + idx_col_list = [i["AttributeName"] for i in index.schema] + for col in idx_col_list: + last_evaluated_key[col] = results[-1].attrs[col] + + return results, last_evaluated_key + + def delete(self, account_id: str, region_name: str) -> None: + from moto.dynamodb.models import dynamodb_backends + + dynamodb_backends[account_id][region_name].delete_table(self.name) + + +class Backup: + def __init__( + self, + account_id: str, + region_name: str, + name: str, + table: Table, + status: Optional[str] = None, + type_: Optional[str] = None, + ): + self.region_name = region_name + self.account_id = account_id + self.name = name + self.table = copy.deepcopy(table) + self.status = status or "AVAILABLE" + self.type = type_ or "USER" + self.creation_date_time = datetime.datetime.utcnow() + self.identifier = self._make_identifier() + + def _make_identifier(self) -> str: + timestamp = int(unix_time_millis(self.creation_date_time)) + timestamp_padded = str("0" + str(timestamp))[-16:16] + guid = str(mock_random.uuid4()) + guid_shortened = guid[:8] + return f"{timestamp_padded}-{guid_shortened}" + + @property + def arn(self) -> str: + return f"arn:aws:dynamodb:{self.region_name}:{self.account_id}:table/{self.table.name}/backup/{self.identifier}" + + @property + def details(self) -> Dict[str, Any]: # type: ignore[misc] + return { + "BackupArn": self.arn, + "BackupName": self.name, + "BackupSizeBytes": 123, + "BackupStatus": self.status, + "BackupType": self.type, + "BackupCreationDateTime": unix_time(self.creation_date_time), + } + + @property + def summary(self) -> Dict[str, Any]: # type: ignore[misc] + return { + "TableName": self.table.name, + # 'TableId': 'string', + "TableArn": self.table.table_arn, + "BackupArn": self.arn, + "BackupName": self.name, + "BackupCreationDateTime": unix_time(self.creation_date_time), + # 'BackupExpiryDateTime': datetime(2015, 1, 1), + "BackupStatus": self.status, + "BackupType": self.type, + "BackupSizeBytes": 123, + } + + @property + def description(self) -> Dict[str, Any]: # type: ignore[misc] + source_table_details = self.table.describe()["TableDescription"] + source_table_details["TableCreationDateTime"] = source_table_details[ + "CreationDateTime" + ] + description = { + "BackupDetails": self.details, + "SourceTableDetails": source_table_details, + } + return description + + +class RestoredTable(Table): + def __init__(self, name: str, account_id: str, region: str, backup: "Backup"): + params = self._parse_params_from_backup(backup) + super().__init__(name, account_id=account_id, region=region, **params) + self.indexes = copy.deepcopy(backup.table.indexes) + self.global_indexes = copy.deepcopy(backup.table.global_indexes) + self.items = copy.deepcopy(backup.table.items) + # Restore Attrs + self.source_backup_arn = backup.arn + self.source_table_arn = backup.table.table_arn + self.restore_date_time = self.created_at + + def _parse_params_from_backup(self, backup: "Backup") -> Dict[str, Any]: + return { + "schema": copy.deepcopy(backup.table.schema), + "attr": copy.deepcopy(backup.table.attr), + "throughput": copy.deepcopy(backup.table.throughput), + } + + def describe(self, base_key: str = "TableDescription") -> Dict[str, Any]: + result = super().describe(base_key=base_key) + result[base_key]["RestoreSummary"] = { + "SourceBackupArn": self.source_backup_arn, + "SourceTableArn": self.source_table_arn, + "RestoreDateTime": unix_time(self.restore_date_time), + "RestoreInProgress": False, + } + return result + + +class RestoredPITTable(Table): 
+ def __init__(self, name: str, account_id: str, region: str, source: Table): + params = self._parse_params_from_table(source) + super().__init__(name, account_id=account_id, region=region, **params) + self.indexes = copy.deepcopy(source.indexes) + self.global_indexes = copy.deepcopy(source.global_indexes) + self.items = copy.deepcopy(source.items) + # Restore Attrs + self.source_table_arn = source.table_arn + self.restore_date_time = self.created_at + + def _parse_params_from_table(self, table: Table) -> Dict[str, Any]: + return { + "schema": copy.deepcopy(table.schema), + "attr": copy.deepcopy(table.attr), + "throughput": copy.deepcopy(table.throughput), + } + + def describe(self, base_key: str = "TableDescription") -> Dict[str, Any]: + result = super().describe(base_key=base_key) + result[base_key]["RestoreSummary"] = { + "SourceTableArn": self.source_table_arn, + "RestoreDateTime": unix_time(self.restore_date_time), + "RestoreInProgress": False, + } + return result diff --git a/moto/dynamodb/models/utilities.py b/moto/dynamodb/models/utilities.py index 28c6676f5..0aa6b59bd 100644 --- a/moto/dynamodb/models/utilities.py +++ b/moto/dynamodb/models/utilities.py @@ -1,2 +1,16 @@ -def bytesize(val): +import json +from typing import Any + + +class DynamoJsonEncoder(json.JSONEncoder): + def default(self, o: Any) -> Any: + if hasattr(o, "to_json"): + return o.to_json() + + +def dynamo_json_dump(dynamo_object: Any) -> str: + return json.dumps(dynamo_object, cls=DynamoJsonEncoder) + + +def bytesize(val: str) -> int: return len(val.encode("utf-8")) diff --git a/moto/dynamodb/parsing/ast_nodes.py b/moto/dynamodb/parsing/ast_nodes.py index a021d3ca6..93ad8551c 100644 --- a/moto/dynamodb/parsing/ast_nodes.py +++ b/moto/dynamodb/parsing/ast_nodes.py @@ -1,3 +1,4 @@ +# type: ignore import abc from abc import abstractmethod from collections import deque @@ -21,7 +22,7 @@ class Node(metaclass=abc.ABCMeta): def set_parent(self, parent_node): self.parent = parent_node - def validate(self): + def validate(self) -> None: if self.type == "UpdateExpression": nr_of_clauses = len(self.find_clauses([UpdateExpressionAddClause])) if nr_of_clauses > 1: diff --git a/moto/dynamodb/parsing/executors.py b/moto/dynamodb/parsing/executors.py index 9b8df9602..6065c7361 100644 --- a/moto/dynamodb/parsing/executors.py +++ b/moto/dynamodb/parsing/executors.py @@ -1,13 +1,19 @@ from abc import abstractmethod +from typing import Any, Dict, List, Optional, Union, Type from moto.dynamodb.exceptions import ( IncorrectOperandType, IncorrectDataType, ProvidedKeyDoesNotExist, ) -from moto.dynamodb.models import DynamoType -from moto.dynamodb.models.dynamo_type import DDBTypeConversion, DDBType -from moto.dynamodb.parsing.ast_nodes import ( +from moto.dynamodb.models.dynamo_type import ( + DDBTypeConversion, + DDBType, + DynamoType, + Item, +) +from moto.dynamodb.parsing.ast_nodes import ( # type: ignore + Node, UpdateExpressionSetAction, UpdateExpressionDeleteAction, UpdateExpressionRemoveAction, @@ -21,16 +27,18 @@ from moto.dynamodb.parsing.ast_nodes import ( from moto.dynamodb.parsing.validators import ExpressionPathResolver -class NodeExecutor(object): - def __init__(self, ast_node, expression_attribute_names): +class NodeExecutor: + def __init__(self, ast_node: Node, expression_attribute_names: Dict[str, str]): self.node = ast_node self.expression_attribute_names = expression_attribute_names @abstractmethod - def execute(self, item): + def execute(self, item: Item) -> None: pass - def get_item_part_for_path_nodes(self, item, 
path_nodes): + def get_item_part_for_path_nodes( + self, item: Item, path_nodes: List[Node] + ) -> Union[DynamoType, Dict[str, Any]]: """ For a list of path nodes traverse the item by following the path_nodes Args: @@ -43,11 +51,13 @@ class NodeExecutor(object): if len(path_nodes) == 0: return item.attrs else: - return ExpressionPathResolver( + return ExpressionPathResolver( # type: ignore self.expression_attribute_names ).resolve_expression_path_nodes_to_dynamo_type(item, path_nodes) - def get_item_before_end_of_path(self, item): + def get_item_before_end_of_path( + self, item: Item + ) -> Union[DynamoType, Dict[str, Any]]: """ Get the part of the item where the item will perform the action. For most actions this should be the parent, as that element will need to be modified by the action. @@ -61,7 +71,7 @@ class NodeExecutor(object): item, self.get_path_expression_nodes()[:-1] ) - def get_item_at_end_of_path(self, item): + def get_item_at_end_of_path(self, item: Item) -> Union[DynamoType, Dict[str, Any]]: """ For a DELETE the path points at the stringset so we need to evaluate the full path. Args: @@ -76,15 +86,15 @@ class NodeExecutor(object): # that element will need to be modified by the action. get_item_part_in_which_to_perform_action = get_item_before_end_of_path - def get_path_expression_nodes(self): + def get_path_expression_nodes(self) -> List[Node]: update_expression_path = self.node.children[0] assert isinstance(update_expression_path, UpdateExpressionPath) return update_expression_path.children - def get_element_to_action(self): + def get_element_to_action(self) -> Node: return self.get_path_expression_nodes()[-1] - def get_action_value(self): + def get_action_value(self) -> DynamoType: """ Returns: @@ -98,7 +108,7 @@ class NodeExecutor(object): class SetExecutor(NodeExecutor): - def execute(self, item): + def execute(self, item: Item) -> None: self.set( item_part_to_modify_with_set=self.get_item_part_in_which_to_perform_action( item @@ -109,13 +119,13 @@ class SetExecutor(NodeExecutor): ) @classmethod - def set( + def set( # type: ignore[misc] cls, - item_part_to_modify_with_set, - element_to_set, - value_to_set, - expression_attribute_names, - ): + item_part_to_modify_with_set: Union[DynamoType, Dict[str, Any]], + element_to_set: Any, + value_to_set: Any, + expression_attribute_names: Dict[str, str], + ) -> None: if isinstance(element_to_set, ExpressionAttribute): attribute_name = element_to_set.get_attribute_name() item_part_to_modify_with_set[attribute_name] = value_to_set @@ -136,7 +146,7 @@ class SetExecutor(NodeExecutor): class DeleteExecutor(NodeExecutor): operator = "operator: DELETE" - def execute(self, item): + def execute(self, item: Item) -> None: string_set_to_remove = self.get_action_value() assert isinstance(string_set_to_remove, DynamoType) if not string_set_to_remove.is_set(): @@ -176,11 +186,11 @@ class DeleteExecutor(NodeExecutor): f"Moto does not support deleting {type(element)} yet" ) container = self.get_item_before_end_of_path(item) - del container[attribute_name] + del container[attribute_name] # type: ignore[union-attr] class RemoveExecutor(NodeExecutor): - def execute(self, item): + def execute(self, item: Item) -> None: element_to_remove = self.get_element_to_action() if isinstance(element_to_remove, ExpressionAttribute): attribute_name = element_to_remove.get_attribute_name() @@ -208,7 +218,7 @@ class RemoveExecutor(NodeExecutor): class AddExecutor(NodeExecutor): - def execute(self, item): + def execute(self, item: Item) -> None: value_to_add =
self.get_action_value() + if isinstance(value_to_add, DynamoType): + if value_to_add.is_set(): @@ -253,7 +263,7 @@ class AddExecutor(NodeExecutor): raise IncorrectDataType() -class UpdateExpressionExecutor(object): +class UpdateExpressionExecutor: execution_map = { UpdateExpressionSetAction: SetExecutor, UpdateExpressionAddAction: AddExecutor, @@ -261,12 +271,14 @@ class UpdateExpressionExecutor(object): UpdateExpressionDeleteAction: DeleteExecutor, } - def __init__(self, update_ast, item, expression_attribute_names): + def __init__( + self, update_ast: Node, item: Item, expression_attribute_names: Dict[str, str] + ): self.update_ast = update_ast self.item = item self.expression_attribute_names = expression_attribute_names - def execute(self, node=None): + def execute(self, node: Optional[Node] = None) -> None: """ As explained in moto.dynamodb.parsing.expressions.NestableExpressionParserMixin._create_node the order of nodes in the AST can be translated to the order of statements in the expression. As such we can start at the root node @@ -286,12 +298,12 @@ class UpdateExpressionExecutor(object): node_executor = self.get_specific_execution(node) if node_executor is None: - for node in node.children: - self.execute(node) + for n in node.children: + self.execute(n) else: node_executor(node, self.expression_attribute_names).execute(self.item) - def get_specific_execution(self, node): + def get_specific_execution(self, node: Node) -> Optional[Type[NodeExecutor]]: for node_class in self.execution_map: if isinstance(node, node_class): return self.execution_map[node_class] diff --git a/moto/dynamodb/parsing/expressions.py b/moto/dynamodb/parsing/expressions.py index a14021b1d..1561f439c 100644 --- a/moto/dynamodb/parsing/expressions.py +++ b/moto/dynamodb/parsing/expressions.py @@ -1,3 +1,4 @@ +# type: ignore import logging from abc import abstractmethod import abc @@ -35,7 +36,7 @@ from moto.dynamodb.parsing.tokens import Token, ExpressionTokenizer logger = logging.getLogger(__name__) -class NestableExpressionParserMixin(object): +class NestableExpressionParserMixin: """ For nodes that can be nested in themselves (recursive).
Take for example UpdateExpression's grammar: diff --git a/moto/dynamodb/parsing/key_condition_expression.py b/moto/dynamodb/parsing/key_condition_expression.py index 3bdbaa7e1..3e8a9c10c 100644 --- a/moto/dynamodb/parsing/key_condition_expression.py +++ b/moto/dynamodb/parsing/key_condition_expression.py @@ -1,4 +1,5 @@ from enum import Enum +from typing import Any, List, Dict, Tuple, Optional from moto.dynamodb.exceptions import MockValidationException from moto.utilities.tokenizer import GenericTokenizer @@ -11,17 +12,17 @@ class EXPRESSION_STAGES(Enum): EOF = "EOF" -def get_key(schema, key_type): +def get_key(schema: List[Dict[str, str]], key_type: str) -> Optional[str]: keys = [key for key in schema if key["KeyType"] == key_type] return keys[0]["AttributeName"] if keys else None def parse_expression( - key_condition_expression, - expression_attribute_values, - expression_attribute_names, - schema, -): + key_condition_expression: str, + expression_attribute_values: Dict[str, str], + expression_attribute_names: Dict[str, str], + schema: List[Dict[str, str]], +) -> Tuple[Dict[str, Any], Optional[str], List[Dict[str, Any]]]: """ Parse a KeyConditionExpression using the provided expression attribute names/values @@ -31,11 +32,11 @@ def parse_expression( schema: [{'AttributeName': 'hashkey', 'KeyType': 'HASH'}, {"AttributeName": "sortkey", "KeyType": "RANGE"}] """ - current_stage: EXPRESSION_STAGES = None + current_stage: Optional[EXPRESSION_STAGES] = None current_phrase = "" - key_name = comparison = None + key_name = comparison = "" key_values = [] - results = [] + results: List[Tuple[str, str, Any]] = [] tokenizer = GenericTokenizer(key_condition_expression) for crnt_char in tokenizer: if crnt_char == " ": @@ -188,7 +189,9 @@ def parse_expression( # Validate that the schema-keys are encountered in our query -def validate_schema(results, schema): +def validate_schema( + results: Any, schema: List[Dict[str, str]] +) -> Tuple[Dict[str, Any], Optional[str], List[Dict[str, Any]]]: index_hash_key = get_key(schema, "HASH") comparison, hash_value = next( ( @@ -219,4 +222,4 @@ def validate_schema(results, schema): f"Query condition missed key schema element: {index_range_key}" ) - return hash_value, range_comparison, range_values + return hash_value, range_comparison, range_values # type: ignore[return-value] diff --git a/moto/dynamodb/parsing/reserved_keywords.py b/moto/dynamodb/parsing/reserved_keywords.py index 7fa8ddb15..9a59adf4f 100644 --- a/moto/dynamodb/parsing/reserved_keywords.py +++ b/moto/dynamodb/parsing/reserved_keywords.py @@ -1,27 +1,26 @@ -from moto.utilities.utils import load_resource +from typing import List, Optional +from moto.utilities.utils import load_resource_as_str -class ReservedKeywords(list): +class ReservedKeywords: """ DynamoDB has an extensive list of keywords. Keywords are considered when validating the expression Tree. 
Not earlier since an update expression like "SET path = VALUE 1" fails with: 'Invalid UpdateExpression: Syntax error; token: "1", near: "VALUE 1"' """ - KEYWORDS = None + KEYWORDS: Optional[List[str]] = None @classmethod - def get_reserved_keywords(cls): + def get_reserved_keywords(cls) -> List[str]: if cls.KEYWORDS is None: cls.KEYWORDS = cls._get_reserved_keywords() return cls.KEYWORDS @classmethod - def _get_reserved_keywords(cls): + def _get_reserved_keywords(cls) -> List[str]: """ Get a list of reserved keywords of DynamoDB """ - reserved_keywords = load_resource( - __name__, "reserved_keywords.txt", as_json=False - ) + reserved_keywords = load_resource_as_str(__name__, "reserved_keywords.txt") return reserved_keywords.split() diff --git a/moto/dynamodb/parsing/tokens.py b/moto/dynamodb/parsing/tokens.py index aada81264..3ccc9f8bd 100644 --- a/moto/dynamodb/parsing/tokens.py +++ b/moto/dynamodb/parsing/tokens.py @@ -1,4 +1,5 @@ import re +from typing import List, Union from moto.dynamodb.exceptions import ( InvalidTokenException, @@ -6,7 +7,7 @@ from moto.dynamodb.exceptions import ( ) -class Token(object): +class Token: _TOKEN_INSTANCE = None MINUS_SIGN = "-" PLUS_SIGN = "+" @@ -53,7 +54,7 @@ class Token(object): NUMBER: "Number", } - def __init__(self, token_type, value): + def __init__(self, token_type: Union[int, str], value: str): assert ( token_type in self.SPECIAL_CHARACTERS or token_type in self.PLACEHOLDER_NAMES @@ -61,13 +62,13 @@ class Token(object): self.type = token_type self.value = value - def __repr__(self): + def __repr__(self) -> str: if isinstance(self.type, int): return f'Token("{self.PLACEHOLDER_NAMES[self.type]}", "{self.value}")' else: return f'Token("{self.type}", "{self.value}")' - def __eq__(self, other): + def __eq__(self, other: "Token") -> bool: # type: ignore[override] return self.type == other.type and self.value == other.value @@ -94,22 +95,22 @@ class ExpressionTokenizer(object): """ @classmethod - def is_simple_token_character(cls, character): + def is_simple_token_character(cls, character: str) -> bool: return character.isalnum() or character in ("_", ":", "#") @classmethod - def is_possible_token_boundary(cls, character): + def is_possible_token_boundary(cls, character: str) -> bool: return ( character in Token.SPECIAL_CHARACTERS or not cls.is_simple_token_character(character) ) @classmethod - def is_expression_attribute(cls, input_string): + def is_expression_attribute(cls, input_string: str) -> bool: return re.compile("^[a-zA-Z0-9][a-zA-Z0-9_]*$").match(input_string) is not None @classmethod - def is_expression_attribute_name(cls, input_string): + def is_expression_attribute_name(cls, input_string: str) -> bool: """ https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.ExpressionAttributeNames.html An expression attribute name must begin with a pound sign (#), and be followed by one or more alphanumeric @@ -120,10 +121,10 @@ class ExpressionTokenizer(object): ) @classmethod - def is_expression_attribute_value(cls, input_string): + def is_expression_attribute_value(cls, input_string: str) -> bool: return re.compile("^:[a-zA-Z0-9_]*$").match(input_string) is not None - def raise_unexpected_token(self): + def raise_unexpected_token(self) -> None: """If during parsing an unexpected token is encountered""" if len(self.token_list) == 0: near = "" @@ -140,29 +141,29 @@ class ExpressionTokenizer(object): problematic_token = self.staged_characters[0] raise InvalidTokenException(problematic_token, near + self.staged_characters) - def 
__init__(self, input_expression_str): + def __init__(self, input_expression_str: str): self.input_expression_str = input_expression_str - self.token_list = [] + self.token_list: List[Token] = [] self.staged_characters = "" @classmethod - def make_list(cls, input_expression_str): + def make_list(cls, input_expression_str: str) -> List[Token]: assert isinstance(input_expression_str, str) return ExpressionTokenizer(input_expression_str)._make_list() - def add_token(self, token_type, token_value): + def add_token(self, token_type: Union[int, str], token_value: str) -> None: self.token_list.append(Token(token_type, token_value)) - def add_token_from_stage(self, token_type): + def add_token_from_stage(self, token_type: int) -> None: self.add_token(token_type, self.staged_characters) self.staged_characters = "" @classmethod - def is_numeric(cls, input_str): + def is_numeric(cls, input_str: str) -> bool: return re.compile("[0-9]+").match(input_str) is not None - def process_staged_characters(self): + def process_staged_characters(self) -> None: if len(self.staged_characters) == 0: return if self.staged_characters.startswith("#"): @@ -179,7 +180,7 @@ class ExpressionTokenizer(object): else: self.raise_unexpected_token() - def _make_list(self): + def _make_list(self) -> List[Token]: """ Just go through the characters: if a character is not a token boundary, stage it so it can be added as part of a grouped token later; if it is a token boundary, process the staged characters and then process the token boundary as well. diff --git a/moto/dynamodb/parsing/validators.py b/moto/dynamodb/parsing/validators.py index d328867e8..2116db772 100644 --- a/moto/dynamodb/parsing/validators.py +++ b/moto/dynamodb/parsing/validators.py @@ -3,6 +3,7 @@ See docstring class Validator below for more details on validation from abc import abstractmethod from copy import deepcopy +from typing import Any, Callable, Dict, List, Type, Union from moto.dynamodb.exceptions import ( AttributeIsReservedKeyword, @@ -15,9 +16,12 @@ from moto.dynamodb.exceptions import ( EmptyKeyAttributeException, UpdateHashRangeKeyException, ) -from moto.dynamodb.models import DynamoType -from moto.dynamodb.parsing.ast_nodes import ( +from moto.dynamodb.models.dynamo_type import DynamoType, Item +from moto.dynamodb.models.table import Table +from moto.dynamodb.parsing.ast_nodes import ( # type: ignore + Node, ExpressionAttribute, + UpdateExpressionClause, UpdateExpressionPath, UpdateExpressionSetAction, UpdateExpressionAddAction, @@ -37,16 +41,23 @@ from moto.dynamodb.parsing.ast_nodes import ( from moto.dynamodb.parsing.reserved_keywords import ReservedKeywords -class ExpressionAttributeValueProcessor(DepthFirstTraverser): - def __init__(self, expression_attribute_values): +class ExpressionAttributeValueProcessor(DepthFirstTraverser): # type: ignore[misc] + def __init__(self, expression_attribute_values: Dict[str, Dict[str, Any]]): self.expression_attribute_values = expression_attribute_values - def _processing_map(self): + def _processing_map( + self, + ) -> Dict[ + Type[ExpressionAttributeValue], + Callable[[ExpressionAttributeValue], DDBTypedValue], + ]: return { ExpressionAttributeValue: self.replace_expression_attribute_value_with_value } - def replace_expression_attribute_value_with_value(self, node): + def replace_expression_attribute_value_with_value( + self, node: ExpressionAttributeValue + ) -> DDBTypedValue: """A node representing an Expression Attribute Value.
Resolve and replace value""" assert isinstance(node, ExpressionAttributeValue) attribute_value_name = node.get_value_name() @@ -59,20 +70,24 @@ class ExpressionAttributeValueProcessor(DepthFirstTraverser): return DDBTypedValue(DynamoType(target)) -class ExpressionPathResolver(object): - def __init__(self, expression_attribute_names): +class ExpressionPathResolver: + def __init__(self, expression_attribute_names: Dict[str, str]): self.expression_attribute_names = expression_attribute_names @classmethod - def raise_exception_if_keyword(cls, attribute): + def raise_exception_if_keyword(cls, attribute: Any) -> None: # type: ignore[misc] if attribute.upper() in ReservedKeywords.get_reserved_keywords(): raise AttributeIsReservedKeyword(attribute) - def resolve_expression_path(self, item, update_expression_path): + def resolve_expression_path( + self, item: Item, update_expression_path: UpdateExpressionPath + ) -> Union[NoneExistingPath, DDBTypedValue]: assert isinstance(update_expression_path, UpdateExpressionPath) return self.resolve_expression_path_nodes(item, update_expression_path.children) - def resolve_expression_path_nodes(self, item, update_expression_path_nodes): + def resolve_expression_path_nodes( + self, item: Item, update_expression_path_nodes: List[Node] + ) -> Union[NoneExistingPath, DDBTypedValue]: target = item.attrs for child in update_expression_path_nodes: @@ -100,7 +115,7 @@ class ExpressionPathResolver(object): continue elif isinstance(child, ExpressionSelector): index = child.get_index() - if target.is_list(): + if target.is_list(): # type: ignore try: target = target[index] except IndexError: @@ -116,8 +131,8 @@ class ExpressionPathResolver(object): return DDBTypedValue(target) def resolve_expression_path_nodes_to_dynamo_type( - self, item, update_expression_path_nodes - ): + self, item: Item, update_expression_path_nodes: List[Node] + ) -> Any: node = self.resolve_expression_path_nodes(item, update_expression_path_nodes) if isinstance(node, NoneExistingPath): raise ProvidedKeyDoesNotExist() @@ -125,19 +140,30 @@ class ExpressionPathResolver(object): return node.get_value() -class ExpressionAttributeResolvingProcessor(DepthFirstTraverser): - def _processing_map(self): +class ExpressionAttributeResolvingProcessor(DepthFirstTraverser): # type: ignore[misc] + def _processing_map( + self, + ) -> Dict[Type[UpdateExpressionClause], Callable[[DDBTypedValue], DDBTypedValue]]: return { UpdateExpressionSetAction: self.disable_resolving, UpdateExpressionPath: self.process_expression_path_node, } - def __init__(self, expression_attribute_names, item): + def __init__(self, expression_attribute_names: Dict[str, str], item: Item): self.expression_attribute_names = expression_attribute_names self.item = item self.resolving = False - def pre_processing_of_child(self, parent_node, child_id): + def pre_processing_of_child( + self, + parent_node: Union[ + UpdateExpressionSetAction, + UpdateExpressionRemoveAction, + UpdateExpressionDeleteAction, + UpdateExpressionAddAction, + ], + child_id: int, + ) -> None: """ We have to enable resolving if we are processing a child of UpdateExpressionSetAction that is not the first, because the first argument is the path to be set and the 2nd argument would be the value.
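For orientation, the resolution rule described in the docstring above can be shown with a small standalone sketch: the first child of a SET action is the write target and only receives #name substitution, while every later operand is fully resolved against the item. This is a simplified illustration only; the resolve_set_action helper and its toy inputs are hypothetical stand-ins, not moto's actual AST or traverser classes.

# Hypothetical, simplified sketch of the SET-action resolution rule above;
# it does not use moto's real AST classes.
from typing import Any, Dict, List, Tuple


def resolve_set_action(
    item_attrs: Dict[str, Any],
    names: Dict[str, str],
    values: Dict[str, Any],
    action: Tuple[str, List[str]],
) -> Tuple[str, List[Any]]:
    path, operands = action
    # Child 0 is the path to be set: substitute '#name' placeholders only.
    target = names.get(path, path)
    resolved: List[Any] = []
    # Children 1..n are values: resolve ':value' placeholders and attribute reads.
    for operand in operands:
        if operand.startswith(":"):
            resolved.append(values[operand])
        else:
            resolved.append(item_attrs[names.get(operand, operand)])
    return target, resolved


# 'SET #c = #c + :delta' -> write target 'counter', operands [5, 1]
print(resolve_set_action({"counter": 5}, {"#c": "counter"}, {":delta": 1}, ("#c", ["#c", ":delta"])))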
@@ -156,11 +182,11 @@ class ExpressionAttributeResolvingProcessor(DepthFirstTraverser): else: self.resolving = True - def disable_resolving(self, node=None): + def disable_resolving(self, node: DDBTypedValue) -> DDBTypedValue: self.resolving = False return node - def process_expression_path_node(self, node): + def process_expression_path_node(self, node: DDBTypedValue) -> DDBTypedValue: """Resolve ExpressionAttribute if not part of a path and resolving is enabled.""" if self.resolving: return self.resolve_expression_path(node) @@ -175,13 +201,15 @@ class ExpressionAttributeResolvingProcessor(DepthFirstTraverser): return node - def resolve_expression_path(self, node): + def resolve_expression_path( + self, node: DDBTypedValue + ) -> Union[NoneExistingPath, DDBTypedValue]: return ExpressionPathResolver( self.expression_attribute_names ).resolve_expression_path(self.item, node) -class UpdateExpressionFunctionEvaluator(DepthFirstTraverser): +class UpdateExpressionFunctionEvaluator(DepthFirstTraverser): # type: ignore[misc] """ At time of writing there are only 2 functions for DDB UpdateExpressions. They both are specific to the SET expression as per the official AWS docs: @@ -189,10 +217,15 @@ Expressions.UpdateExpressions.html#Expressions.UpdateExpressions.SET """ - def _processing_map(self): + def _processing_map( + self, + ) -> Dict[ + Type[UpdateExpressionFunction], + Callable[[UpdateExpressionFunction], DDBTypedValue], + ]: return {UpdateExpressionFunction: self.process_function} - def process_function(self, node): + def process_function(self, node: UpdateExpressionFunction) -> DDBTypedValue: assert isinstance(node, UpdateExpressionFunction) function_name = node.get_function_name() first_arg = node.get_nth_argument(1) @@ -217,7 +250,7 @@ class UpdateExpressionFunctionEvaluator(DepthFirstTraverser): raise NotImplementedError(f"Unsupported function for moto {function_name}") @classmethod - def get_list_from_ddb_typed_value(cls, node, function_name): + def get_list_from_ddb_typed_value(cls, node: DDBTypedValue, function_name: str) -> DynamoType: # type: ignore[misc] assert isinstance(node, DDBTypedValue) dynamo_value = node.get_value() assert isinstance(dynamo_value, DynamoType) @@ -226,23 +259,25 @@ class UpdateExpressionFunctionEvaluator(DepthFirstTraverser): return dynamo_value -class NoneExistingPathChecker(DepthFirstTraverser): +class NoneExistingPathChecker(DepthFirstTraverser): # type: ignore[misc] """ Pass through the AST and make sure there are no none-existing paths. """ - def _processing_map(self): + def _processing_map(self) -> Dict[Type[NoneExistingPath], Callable[[Node], None]]: return {NoneExistingPath: self.raise_none_existing_path} - def raise_none_existing_path(self, node): + def raise_none_existing_path(self, node: Node) -> None: raise AttributeDoesNotExist -class ExecuteOperations(DepthFirstTraverser): - def _processing_map(self): +class ExecuteOperations(DepthFirstTraverser): # type: ignore[misc] + def _processing_map( + self, + ) -> Dict[Type[UpdateExpressionValue], Callable[[Node], DDBTypedValue]]: return {UpdateExpressionValue: self.process_update_expression_value} - def process_update_expression_value(self, node): + def process_update_expression_value(self, node: Node) -> DDBTypedValue: """ If an UpdateExpressionValue only has a single child, the node will be replaced with the child.
Otherwise it has 3 children and the middle one is an ExpressionValueOperator which details how to combine them @@ -273,14 +308,14 @@ class ExecuteOperations(DepthFirstTraverser): ) @classmethod - def get_dynamo_value_from_ddb_typed_value(cls, node): + def get_dynamo_value_from_ddb_typed_value(cls, node: DDBTypedValue) -> DynamoType: # type: ignore[misc] assert isinstance(node, DDBTypedValue) dynamo_value = node.get_value() assert isinstance(dynamo_value, DynamoType) return dynamo_value @classmethod - def get_sum(cls, left_operand, right_operand): + def get_sum(cls, left_operand: DynamoType, right_operand: DynamoType) -> DDBTypedValue: # type: ignore[misc] """ Args: left_operand(DynamoType): @@ -295,7 +330,7 @@ class ExecuteOperations(DepthFirstTraverser): raise IncorrectOperandType("+", left_operand.type) @classmethod - def get_subtraction(cls, left_operand, right_operand): + def get_subtraction(cls, left_operand: DynamoType, right_operand: DynamoType) -> DDBTypedValue: # type: ignore[misc] """ Args: left_operand(DynamoType): @@ -310,14 +345,21 @@ class ExecuteOperations(DepthFirstTraverser): raise IncorrectOperandType("-", left_operand.type) -class EmptyStringKeyValueValidator(DepthFirstTraverser): - def __init__(self, key_attributes): +class EmptyStringKeyValueValidator(DepthFirstTraverser): # type: ignore[misc] + def __init__(self, key_attributes: List[str]): self.key_attributes = key_attributes - def _processing_map(self): + def _processing_map( + self, + ) -> Dict[ + Type[UpdateExpressionSetAction], + Callable[[UpdateExpressionSetAction], UpdateExpressionSetAction], + ]: return {UpdateExpressionSetAction: self.check_for_empty_string_key_value} - def check_for_empty_string_key_value(self, node): + def check_for_empty_string_key_value( + self, node: UpdateExpressionSetAction + ) -> UpdateExpressionSetAction: """A node representing a SET action. Check that keys are not being assigned empty strings""" assert isinstance(node, UpdateExpressionSetAction) assert len(node.children) == 2 @@ -332,15 +374,26 @@ class EmptyStringKeyValueValidator(DepthFirstTraverser): return node -class UpdateHashRangeKeyValidator(DepthFirstTraverser): - def __init__(self, table_key_attributes, expression_attribute_names): +class UpdateHashRangeKeyValidator(DepthFirstTraverser): # type: ignore[misc] + def __init__( + self, + table_key_attributes: List[str], + expression_attribute_names: Dict[str, str], + ): self.table_key_attributes = table_key_attributes self.expression_attribute_names = expression_attribute_names - def _processing_map(self): + def _processing_map( + self, + ) -> Dict[ + Type[UpdateExpressionPath], + Callable[[UpdateExpressionPath], UpdateExpressionPath], + ]: return {UpdateExpressionPath: self.check_for_hash_or_range_key} - def check_for_hash_or_range_key(self, node): + def check_for_hash_or_range_key( + self, node: UpdateExpressionPath + ) -> UpdateExpressionPath: """Check that hash and range keys are not updated""" key_to_update = node.children[0].children[0] key_to_update = self.expression_attribute_names.get( @@ -351,18 +404,18 @@ class UpdateHashRangeKeyValidator(DepthFirstTraverser): return node -class Validator(object): +class Validator: """ A validator is used to validate expressions which are passed in as an AST. 
""" def __init__( self, - expression, - expression_attribute_names, - expression_attribute_values, - item, - table, + expression: Node, + expression_attribute_names: Dict[str, str], + expression_attribute_values: Dict[str, Dict[str, Any]], + item: Item, + table: Table, ): """ Besides validation the Validator should also replace referenced parts of an item which is cheapest upon @@ -382,10 +435,10 @@ class Validator(object): self.node_to_validate = deepcopy(expression) @abstractmethod - def get_ast_processors(self): + def get_ast_processors(self) -> List[DepthFirstTraverser]: # type: ignore[misc] """Get the different processors that go through the AST tree and processes the nodes.""" - def validate(self): + def validate(self) -> Node: n = self.node_to_validate for processor in self.processors: n = processor.traverse(n) @@ -393,7 +446,7 @@ class Validator(object): class UpdateExpressionValidator(Validator): - def get_ast_processors(self): + def get_ast_processors(self) -> List[DepthFirstTraverser]: """Get the different processors that go through the AST tree and processes the nodes.""" processors = [ UpdateHashRangeKeyValidator( diff --git a/moto/dynamodb/responses.py b/moto/dynamodb/responses.py index 0144c835b..54f3abe72 100644 --- a/moto/dynamodb/responses.py +++ b/moto/dynamodb/responses.py @@ -3,7 +3,9 @@ import json import itertools from functools import wraps +from typing import Any, Dict, List, Union, Callable, Optional +from moto.core.common_types import TYPE_RESPONSE from moto.core.responses import BaseResponse from moto.core.utils import camelcase_to_underscores from moto.dynamodb.parsing.key_condition_expression import parse_expression @@ -13,17 +15,27 @@ from .exceptions import ( ResourceNotFoundException, ConditionalCheckFailed, ) -from moto.dynamodb.models import dynamodb_backends, dynamo_json_dump, Table +from moto.dynamodb.models import dynamodb_backends, Table, DynamoDBBackend +from moto.dynamodb.models.utilities import dynamo_json_dump from moto.utilities.aws_headers import amz_crc32, amzn_request_id TRANSACTION_MAX_ITEMS = 25 -def include_consumed_capacity(val=1.0): - def _inner(f): +def include_consumed_capacity( + val: float = 1.0, +) -> Callable[ + [Callable[["DynamoHandler"], str]], + Callable[["DynamoHandler"], Union[str, TYPE_RESPONSE]], +]: + def _inner( + f: Callable[..., Union[str, TYPE_RESPONSE]] + ) -> Callable[["DynamoHandler"], Union[str, TYPE_RESPONSE]]: @wraps(f) - def _wrapper(*args, **kwargs): + def _wrapper( + *args: "DynamoHandler", **kwargs: None + ) -> Union[str, TYPE_RESPONSE]: (handler,) = args expected_capacity = handler.body.get("ReturnConsumedCapacity", "NONE") if expected_capacity not in ["NONE", "TOTAL", "INDEXES"]: @@ -67,7 +79,7 @@ def include_consumed_capacity(val=1.0): return _inner -def get_empty_keys_on_put(field_updates, table: Table): +def get_empty_keys_on_put(field_updates: Dict[str, Any], table: Table) -> Optional[str]: """ Return the first key-name that has an empty value. 
None if all keys are filled """ @@ -83,12 +95,12 @@ return next( (keyname for keyname in key_names if keyname in empty_str_fields), None ) - return False + return None -def put_has_empty_attrs(field_updates, table): +def put_has_empty_attrs(field_updates: Dict[str, Any], table: Table) -> bool: # Example invalid attribute: [{'M': {'SS': {'NS': []}}}] - def _validate_attr(attr: dict): + def _validate_attr(attr: Dict[str, Any]) -> bool: if "NS" in attr and attr["NS"] == []: return True else: @@ -105,7 +117,7 @@ def put_has_empty_attrs(field_updates, table): return False -def validate_put_has_gsi_keys_set_to_none(item, table: Table) -> None: +def validate_put_has_gsi_keys_set_to_none(item: Dict[str, Any], table: Table) -> None: for gsi in table.global_indexes: for attr in gsi.schema: attr_name = attr["AttributeName"] @@ -115,7 +127,7 @@ ) -def check_projection_expression(expression): +def check_projection_expression(expression: str) -> None: if expression.upper() in ReservedKeywords.get_reserved_keywords(): raise MockValidationException( f"ProjectionExpression: Attribute name is a reserved keyword; reserved keyword: {expression}" @@ -131,10 +143,10 @@ class DynamoHandler(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="dynamodb") - def get_endpoint_name(self, headers): + def get_endpoint_name(self, headers: Any) -> Optional[str]: """Parses request headers and extracts part of the X-Amz-Target that corresponds to a method of DynamoHandler @@ -144,9 +156,10 @@ match = headers.get("x-amz-target") or headers.get("X-Amz-Target") if match: return match.split(".")[1] + return None @property - def dynamodb_backend(self): + def dynamodb_backend(self) -> DynamoDBBackend: """ :return: DynamoDB Backend :rtype: moto.dynamodb.models.DynamoDBBackend @@ -155,7 +168,7 @@ @amz_crc32 @amzn_request_id - def call_action(self): + def call_action(self) -> TYPE_RESPONSE: self.body = json.loads(self.body or "{}") endpoint = self.get_endpoint_name(self.headers) if endpoint: @@ -171,7 +184,7 @@ else: return 404, self.response_headers, "" - def list_tables(self): + def list_tables(self) -> str: body = self.body limit = body.get("Limit", 100) exclusive_start_table_name = body.get("ExclusiveStartTableName") @@ -179,13 +192,13 @@ limit, exclusive_start_table_name ) - response = {"TableNames": tables} + response: Dict[str, Any] = {"TableNames": tables} if last_eval: response["LastEvaluatedTableName"] = last_eval return dynamo_json_dump(response) - def create_table(self): + def create_table(self) -> str: body = self.body # get the table name table_name = body["TableName"] @@ -243,7 +256,7 @@ actual_attrs = [item["AttributeName"] for item in attr] actual_attrs.sort() if actual_attrs != expected_attrs: - return self._throw_attr_error( + self._throw_attr_error( actual_attrs, expected_attrs, global_indexes or local_secondary_indexes ) # get the stream specification @@ -265,8 +278,10 @@ ) return dynamo_json_dump(table.describe()) - def _throw_attr_error(self, actual_attrs, expected_attrs, indexes): - def dump_list(list_): + def _throw_attr_error( + self, actual_attrs: List[str], expected_attrs: List[str],
indexes: bool + ) -> None: + def dump_list(list_: List[str]) -> str: return str(list_).replace("'", "") err_head = "One or more parameter values were invalid: " @@ -315,28 +330,28 @@ class DynamoHandler(BaseResponse): + dump_list(actual_attrs) ) - def delete_table(self): + def delete_table(self) -> str: name = self.body["TableName"] table = self.dynamodb_backend.delete_table(name) return dynamo_json_dump(table.describe()) - def describe_endpoints(self): + def describe_endpoints(self) -> str: response = {"Endpoints": self.dynamodb_backend.describe_endpoints()} return dynamo_json_dump(response) - def tag_resource(self): + def tag_resource(self) -> str: table_arn = self.body["ResourceArn"] tags = self.body["Tags"] self.dynamodb_backend.tag_resource(table_arn, tags) return "" - def untag_resource(self): + def untag_resource(self) -> str: table_arn = self.body["ResourceArn"] tags = self.body["TagKeys"] self.dynamodb_backend.untag_resource(table_arn, tags) return "" - def list_tags_of_resource(self): + def list_tags_of_resource(self) -> str: table_arn = self.body["ResourceArn"] all_tags = self.dynamodb_backend.list_tags_of_resource(table_arn) all_tag_keys = [tag["Key"] for tag in all_tags] @@ -354,7 +369,7 @@ class DynamoHandler(BaseResponse): return json.dumps({"Tags": tags_resp, "NextToken": next_marker}) return json.dumps({"Tags": tags_resp}) - def update_table(self): + def update_table(self) -> str: name = self.body["TableName"] attr_definitions = self.body.get("AttributeDefinitions", None) global_index = self.body.get("GlobalSecondaryIndexUpdates", None) @@ -371,13 +386,13 @@ class DynamoHandler(BaseResponse): ) return dynamo_json_dump(table.describe()) - def describe_table(self): + def describe_table(self) -> str: name = self.body["TableName"] table = self.dynamodb_backend.describe_table(name) return dynamo_json_dump(table) @include_consumed_capacity() - def put_item(self): + def put_item(self) -> str: name = self.body["TableName"] item = self.body["Item"] return_values = self.body.get("ReturnValues", "NONE") @@ -436,7 +451,7 @@ class DynamoHandler(BaseResponse): item_dict.pop("Attributes", None) return dynamo_json_dump(item_dict) - def batch_write_item(self): + def batch_write_item(self) -> str: table_batches = self.body["RequestItems"] put_requests = [] delete_requests = [] @@ -478,7 +493,7 @@ class DynamoHandler(BaseResponse): return dynamo_json_dump(response) @include_consumed_capacity(0.5) - def get_item(self): + def get_item(self) -> str: name = self.body["TableName"] self.dynamodb_backend.get_table(name) key = self.body["Key"] @@ -520,10 +535,14 @@ class DynamoHandler(BaseResponse): # Item not found return dynamo_json_dump({}) - def batch_get_item(self): + def batch_get_item(self) -> str: table_batches = self.body["RequestItems"] - results = {"ConsumedCapacity": [], "Responses": {}, "UnprocessedKeys": {}} + results: Dict[str, Any] = { + "ConsumedCapacity": [], + "Responses": {}, + "UnprocessedKeys": {}, + } # Validation: Can only request up to 100 items at the same time # Scenario 1: We're requesting more than a 100 keys from a single table @@ -582,7 +601,7 @@ class DynamoHandler(BaseResponse): ) return dynamo_json_dump(results) - def _contains_duplicates(self, keys): + def _contains_duplicates(self, keys: List[str]) -> bool: unique_keys = [] for k in keys: if k in unique_keys: @@ -592,7 +611,7 @@ class DynamoHandler(BaseResponse): return False @include_consumed_capacity() - def query(self): + def query(self) -> str: name = self.body["TableName"] key_condition_expression = 
self.body.get("KeyConditionExpression") projection_expression = self.body.get("ProjectionExpression") @@ -676,7 +695,7 @@ class DynamoHandler(BaseResponse): **filter_kwargs, ) - result = { + result: Dict[str, Any] = { "Count": len(items), "ScannedCount": scanned_count, } @@ -689,8 +708,10 @@ class DynamoHandler(BaseResponse): return dynamo_json_dump(result) - def _adjust_projection_expression(self, projection_expression, expr_attr_names): - def _adjust(expression): + def _adjust_projection_expression( + self, projection_expression: str, expr_attr_names: Dict[str, str] + ) -> str: + def _adjust(expression: str) -> str: return ( expr_attr_names[expression] if expression in expr_attr_names @@ -712,7 +733,7 @@ class DynamoHandler(BaseResponse): return projection_expression @include_consumed_capacity() - def scan(self): + def scan(self) -> str: name = self.body["TableName"] filters = {} @@ -760,7 +781,7 @@ class DynamoHandler(BaseResponse): result["LastEvaluatedKey"] = last_evaluated_key return dynamo_json_dump(result) - def delete_item(self): + def delete_item(self) -> str: name = self.body["TableName"] key = self.body["Key"] return_values = self.body.get("ReturnValues", "NONE") @@ -795,7 +816,7 @@ class DynamoHandler(BaseResponse): item_dict["ConsumedCapacityUnits"] = 0.5 return dynamo_json_dump(item_dict) - def update_item(self): + def update_item(self) -> str: name = self.body["TableName"] key = self.body["Key"] return_values = self.body.get("ReturnValues", "NONE") @@ -870,7 +891,7 @@ class DynamoHandler(BaseResponse): ) return dynamo_json_dump(item_dict) - def _build_updated_new_attributes(self, original, changed): + def _build_updated_new_attributes(self, original: Any, changed: Any) -> Any: if type(changed) != type(original): return changed else: @@ -895,7 +916,7 @@ class DynamoHandler(BaseResponse): else: return changed - def describe_limits(self): + def describe_limits(self) -> str: return json.dumps( { "AccountMaxReadCapacityUnits": 20000, @@ -905,7 +926,7 @@ class DynamoHandler(BaseResponse): } ) - def update_time_to_live(self): + def update_time_to_live(self) -> str: name = self.body["TableName"] ttl_spec = self.body["TimeToLiveSpecification"] @@ -913,16 +934,16 @@ class DynamoHandler(BaseResponse): return json.dumps({"TimeToLiveSpecification": ttl_spec}) - def describe_time_to_live(self): + def describe_time_to_live(self) -> str: name = self.body["TableName"] ttl_spec = self.dynamodb_backend.describe_time_to_live(name) return json.dumps({"TimeToLiveDescription": ttl_spec}) - def transact_get_items(self): + def transact_get_items(self) -> str: transact_items = self.body["TransactItems"] - responses = list() + responses: List[Dict[str, Any]] = list() if len(transact_items) > TRANSACTION_MAX_ITEMS: msg = "1 validation error detected: Value '[" @@ -945,7 +966,7 @@ class DynamoHandler(BaseResponse): raise MockValidationException(msg) ret_consumed_capacity = self.body.get("ReturnConsumedCapacity", "NONE") - consumed_capacity = dict() + consumed_capacity: Dict[str, Any] = dict() for transact_item in transact_items: @@ -957,7 +978,7 @@ class DynamoHandler(BaseResponse): responses.append({}) continue - item_describe = item.describe_attrs(False) + item_describe = item.describe_attrs(attributes=None) responses.append(item_describe) table_capacity = consumed_capacity.get(table_name, {}) @@ -981,20 +1002,20 @@ class DynamoHandler(BaseResponse): return dynamo_json_dump(result) - def transact_write_items(self): + def transact_write_items(self) -> str: transact_items = self.body["TransactItems"] 
self.dynamodb_backend.transact_write_items(transact_items) - response = {"ConsumedCapacity": [], "ItemCollectionMetrics": {}} + response: Dict[str, Any] = {"ConsumedCapacity": [], "ItemCollectionMetrics": {}} return dynamo_json_dump(response) - def describe_continuous_backups(self): + def describe_continuous_backups(self) -> str: name = self.body["TableName"] response = self.dynamodb_backend.describe_continuous_backups(name) return json.dumps({"ContinuousBackupsDescription": response}) - def update_continuous_backups(self): + def update_continuous_backups(self) -> str: name = self.body["TableName"] point_in_time_spec = self.body["PointInTimeRecoverySpecification"] @@ -1004,14 +1025,14 @@ class DynamoHandler(BaseResponse): return json.dumps({"ContinuousBackupsDescription": response}) - def list_backups(self): + def list_backups(self) -> str: body = self.body table_name = body.get("TableName") backups = self.dynamodb_backend.list_backups(table_name) response = {"BackupSummaries": [backup.summary for backup in backups]} return dynamo_json_dump(response) - def create_backup(self): + def create_backup(self) -> str: body = self.body table_name = body.get("TableName") backup_name = body.get("BackupName") @@ -1019,21 +1040,21 @@ class DynamoHandler(BaseResponse): response = {"BackupDetails": backup.details} return dynamo_json_dump(response) - def delete_backup(self): + def delete_backup(self) -> str: body = self.body backup_arn = body.get("BackupArn") backup = self.dynamodb_backend.delete_backup(backup_arn) response = {"BackupDescription": backup.description} return dynamo_json_dump(response) - def describe_backup(self): + def describe_backup(self) -> str: body = self.body backup_arn = body.get("BackupArn") backup = self.dynamodb_backend.describe_backup(backup_arn) response = {"BackupDescription": backup.description} return dynamo_json_dump(response) - def restore_table_from_backup(self): + def restore_table_from_backup(self) -> str: body = self.body target_table_name = body.get("TargetTableName") backup_arn = body.get("BackupArn") @@ -1042,7 +1063,7 @@ class DynamoHandler(BaseResponse): ) return dynamo_json_dump(restored_table.describe()) - def restore_table_to_point_in_time(self): + def restore_table_to_point_in_time(self) -> str: body = self.body target_table_name = body.get("TargetTableName") source_table_name = body.get("SourceTableName") diff --git a/moto/dynamodbstreams/models.py b/moto/dynamodbstreams/models.py index a4764fd36..e0f176db4 100644 --- a/moto/dynamodbstreams/models.py +++ b/moto/dynamodbstreams/models.py @@ -4,8 +4,9 @@ import base64 from typing import Any, Dict, Optional from moto.core import BaseBackend, BackendDict, BaseModel -from moto.dynamodb.models import dynamodb_backends, DynamoJsonEncoder, DynamoDBBackend -from moto.dynamodb.models import Table, StreamShard +from moto.dynamodb.models import dynamodb_backends, DynamoDBBackend +from moto.dynamodb.models.table import Table, StreamShard +from moto.dynamodb.models.utilities import DynamoJsonEncoder class ShardIterator(BaseModel): @@ -86,7 +87,7 @@ class DynamoDBStreamsBackend(BaseBackend): "StreamStatus": ( "ENABLED" if table.latest_stream_label else "DISABLED" ), - "StreamViewType": table.stream_specification["StreamViewType"], + "StreamViewType": table.stream_specification["StreamViewType"], # type: ignore[index] "CreationRequestDateTime": table.stream_shard.created_on.isoformat(), # type: ignore[union-attr] "TableName": table.name, "KeySchema": table.schema, diff --git a/moto/utilities/tokenizer.py 
b/moto/utilities/tokenizer.py index a6e76dc7f..7872f8a7c 100644 --- a/moto/utilities/tokenizer.py +++ b/moto/utilities/tokenizer.py @@ -4,17 +4,17 @@ class GenericTokenizer: The final character to be returned will be an empty string, to notify the caller that we've reached the end. """ - def __init__(self, expression): + def __init__(self, expression: str): self.expression = expression self.token_pos = 0 - def __iter__(self): + def __iter__(self) -> "GenericTokenizer": return self - def is_eof(self): + def is_eof(self) -> bool: return self.peek() == "" - def peek(self, length=1): + def peek(self, length: int = 1) -> str: """ Peek the next character without changing the position """ @@ -23,7 +23,7 @@ class GenericTokenizer: except IndexError: return "" - def __next__(self): + def __next__(self) -> str: """ Returns the next character, or an empty string if we've reached the end of the string. Calling this method again will result in a StopIteration @@ -38,7 +38,7 @@ class GenericTokenizer: return "" raise StopIteration - def skip_characters(self, phrase, case_sensitive=False) -> None: + def skip_characters(self, phrase: str, case_sensitive: bool = False) -> None: """ Skip the characters in the supplied phrase. If any other character is encountered instead, this will fail. @@ -51,7 +51,7 @@ class GenericTokenizer: assert self.expression[self.token_pos] in [ch.lower(), ch.upper()] self.token_pos += 1 - def skip_white_space(self): + def skip_white_space(self) -> None: """ Skip any whitespace characters that are coming up """ diff --git a/moto/utilities/utils.py b/moto/utilities/utils.py index ce75b988e..7ff515706 100644 --- a/moto/utilities/utils.py +++ b/moto/utilities/utils.py @@ -13,14 +13,19 @@ def str2bool(v): return False -def load_resource(package: str, resource: str, as_json: bool = True) -> Dict[str, Any]: +def load_resource(package: str, resource: str) -> Dict[str, Any]: """ Open a file, and return the contents as JSON. Usage: load_resource(__name__, "resources/file.json") """ resource = pkgutil.get_data(package, resource) - return json.loads(resource) if as_json else resource.decode("utf-8") + return json.loads(resource) + + +def load_resource_as_str(package: str, resource: str) -> str: + resource = pkgutil.get_data(package, resource) + return resource.decode("utf-8") def merge_multiple_dicts(*args: Any) -> Dict[str, any]: diff --git a/setup.cfg b/setup.cfg index 5c2b55b82..b07a359dc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -229,7 +229,7 @@ disable = W,C,R,E enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import [mypy] -files= moto/a*,moto/b*,moto/c*,moto/databrew,moto/datapipeline,moto/datasync,moto/dax,moto/dms,moto/ds,moto/dynamodb_v20111205,moto/dynamodbstreams,moto/moto_api +files= moto/a*,moto/b*,moto/c*,moto/d*,moto/moto_api show_column_numbers=True show_error_codes = True disable_error_code=abstract
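As a closing illustration of the serialization hook this patch adds in moto/dynamodb/models/utilities.py: dynamo_json_dump routes any object exposing a to_json() method through DynamoJsonEncoder.default. A minimal sketch, assuming the patch above is applied; the Wrapper class is a hypothetical stand-in for moto's DynamoType/Item objects.

# Minimal sketch, assuming this patch is applied. Wrapper is a hypothetical
# stand-in for moto's own objects; anything with a to_json() method is
# serialized through DynamoJsonEncoder.default.
from typing import Dict

from moto.dynamodb.models.utilities import dynamo_json_dump


class Wrapper:
    def __init__(self, value: str):
        self.value = value

    def to_json(self) -> Dict[str, str]:
        return {"S": self.value}


print(dynamo_json_dump({"Item": Wrapper("hello")}))  # {"Item": {"S": "hello"}}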