diff --git a/moto/dynamodb_v20111205/comparisons.py b/moto/dynamodb_v20111205/comparisons.py index f31b9d5c3..883f10203 100644 --- a/moto/dynamodb_v20111205/comparisons.py +++ b/moto/dynamodb_v20111205/comparisons.py @@ -1,3 +1,6 @@ +from typing import Callable, Any + + # TODO add tests for all of these COMPARISON_FUNCS = { "EQ": lambda item_value, test_value: item_value == test_value, @@ -18,5 +21,5 @@ COMPARISON_FUNCS = { } -def get_comparison_func(range_comparison): - return COMPARISON_FUNCS.get(range_comparison) +def get_comparison_func(range_comparison: str) -> Callable[..., Any]: + return COMPARISON_FUNCS.get(range_comparison) # type: ignore[return-value] diff --git a/moto/dynamodb_v20111205/models.py b/moto/dynamodb_v20111205/models.py index 696323d9c..b7fc906a3 100644 --- a/moto/dynamodb_v20111205/models.py +++ b/moto/dynamodb_v20111205/models.py @@ -1,4 +1,5 @@ from collections import defaultdict +from typing import Any, Dict, Optional, List, Union, Tuple, Iterable import datetime import json @@ -9,12 +10,12 @@ from .comparisons import get_comparison_func class DynamoJsonEncoder(json.JSONEncoder): - def default(self, o): + def default(self, o: Any) -> Optional[str]: # type: ignore[return] if hasattr(o, "to_json"): return o.to_json() -def dynamo_json_dump(dynamo_object): +def dynamo_json_dump(dynamo_object: Any) -> str: return json.dumps(dynamo_object, cls=DynamoJsonEncoder) @@ -23,29 +24,29 @@ class DynamoType(object): http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DataModel.html#DataModelDataTypes """ - def __init__(self, type_as_dict): + def __init__(self, type_as_dict: Dict[str, Any]): self.type = list(type_as_dict.keys())[0] self.value = list(type_as_dict.values())[0] - def __hash__(self): + def __hash__(self) -> int: return hash((self.type, self.value)) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return self.type == other.type and self.value == other.value - def __repr__(self): + def __repr__(self) -> str: return f"DynamoType: {self.to_json()}" - def add(self, dyn_type): + def add(self, dyn_type: "DynamoType") -> None: if self.type == "SS": self.value.append(dyn_type.value) if self.type == "N": self.value = str(int(self.value) + int(dyn_type.value)) - def to_json(self): + def to_json(self) -> Dict[str, Any]: return {self.type: self.value} - def compare(self, range_comparison, range_objs): + def compare(self, range_comparison: str, range_objs: List["DynamoType"]) -> Any: """ Compares this type against comparison filters """ @@ -55,7 +56,14 @@ class DynamoType(object): class Item(BaseModel): - def __init__(self, hash_key, hash_key_type, range_key, range_key_type, attrs): + def __init__( + self, + hash_key: DynamoType, + hash_key_type: str, + range_key: Optional[DynamoType], + range_key_type: Optional[str], + attrs: Dict[str, Any], + ): self.hash_key = hash_key self.hash_key_type = hash_key_type self.range_key = range_key @@ -65,17 +73,17 @@ class Item(BaseModel): for key, value in attrs.items(): self.attrs[key] = DynamoType(value) - def __repr__(self): + def __repr__(self) -> str: return f"Item: {self.to_json()}" - def to_json(self): + def to_json(self) -> Dict[str, Any]: attributes = {} for attribute_key, attribute in self.attrs.items(): attributes[attribute_key] = attribute.value return {"Attributes": attributes} - def describe_attrs(self, attributes): + def describe_attrs(self, attributes: List[str]) -> Dict[str, Any]: if attributes: included = {} for key, value in self.attrs.items(): @@ -89,14 +97,14 @@ class Item(BaseModel): class Table(CloudFormationModel): def __init__( self, - account_id, - name, - hash_key_attr, - hash_key_type, - range_key_attr=None, - range_key_type=None, - read_capacity=None, - write_capacity=None, + account_id: str, + name: str, + hash_key_attr: str, + hash_key_type: str, + range_key_attr: Optional[str] = None, + range_key_type: Optional[str] = None, + read_capacity: Optional[str] = None, + write_capacity: Optional[str] = None, ): self.account_id = account_id self.name = name @@ -107,15 +115,17 @@ class Table(CloudFormationModel): self.read_capacity = read_capacity self.write_capacity = write_capacity self.created_at = datetime.datetime.utcnow() - self.items = defaultdict(dict) + self.items: Dict[DynamoType, Union[Item, Dict[DynamoType, Item]]] = defaultdict( + dict + ) @property - def has_range_key(self): + def has_range_key(self) -> bool: return self.range_key_attr is not None @property - def describe(self): - results = { + def describe(self) -> Dict[str, Any]: # type: ignore[misc] + results: Dict[str, Any] = { "Table": { "CreationDateTime": unix_time(self.created_at), "KeySchema": { @@ -142,18 +152,23 @@ class Table(CloudFormationModel): return results @staticmethod - def cloudformation_name_type(): + def cloudformation_name_type() -> str: return "TableName" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-dynamodb-table.html return "AWS::DynamoDB::Table" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Dict[str, Any], + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "Table": properties = cloudformation_json["Properties"] key_attr = [ i["AttributeName"] @@ -175,21 +190,21 @@ class Table(CloudFormationModel): # range_key_attr, range_key_type, read_capacity, write_capacity return Table(**spec) - def __len__(self): + def __len__(self) -> int: return sum( - [(len(value) if self.has_range_key else 1) for value in self.items.values()] + [(len(value) if self.has_range_key else 1) for value in self.items.values()] # type: ignore ) - def __nonzero__(self): + def __nonzero__(self) -> bool: return True - def __bool__(self): + def __bool__(self) -> bool: return self.__nonzero__() - def put_item(self, item_attrs): - hash_value = DynamoType(item_attrs.get(self.hash_key_attr)) + def put_item(self, item_attrs: Dict[str, Any]) -> Item: + hash_value = DynamoType(item_attrs.get(self.hash_key_attr)) # type: ignore[arg-type] if self.has_range_key: - range_value = DynamoType(item_attrs.get(self.range_key_attr)) + range_value: Optional[DynamoType] = DynamoType(item_attrs.get(self.range_key_attr)) # type: ignore[arg-type] else: range_value = None @@ -198,51 +213,55 @@ class Table(CloudFormationModel): ) if range_value: - self.items[hash_value][range_value] = item + self.items[hash_value][range_value] = item # type: ignore[index] else: self.items[hash_value] = item return item - def get_item(self, hash_key, range_key): + def get_item( + self, hash_key: DynamoType, range_key: Optional[DynamoType] + ) -> Optional[Item]: if self.has_range_key and not range_key: raise ValueError( "Table has a range key, but no range key was passed into get_item" ) try: if range_key: - return self.items[hash_key][range_key] + return self.items[hash_key][range_key] # type: ignore else: - return self.items[hash_key] + return self.items[hash_key] # type: ignore except KeyError: return None - def query(self, hash_key, range_comparison, range_objs): + def query( + self, hash_key: DynamoType, range_comparison: str, range_objs: Any + ) -> Tuple[Iterable[Item], bool]: results = [] last_page = True # Once pagination is implemented, change this if self.range_key_attr: - possible_results = self.items[hash_key].values() + possible_results = self.items[hash_key].values() # type: ignore[union-attr] else: possible_results = list(self.all_items()) if range_comparison: for result in possible_results: - if result.range_key.compare(range_comparison, range_objs): + if result.range_key.compare(range_comparison, range_objs): # type: ignore[union-attr] results.append(result) else: # If we're not filtering on range key, return all values - results = possible_results + results = possible_results # type: ignore[assignment] return results, last_page - def all_items(self): + def all_items(self) -> Iterable[Item]: for hash_set in self.items.values(): if self.range_key_attr: - for item in hash_set.values(): + for item in hash_set.values(): # type: ignore yield item else: - yield hash_set + yield hash_set # type: ignore[misc] - def scan(self, filters): + def scan(self, filters: Dict[str, Any]) -> Tuple[List[Item], int, bool]: results = [] scanned_count = 0 last_page = True # Once pagination is implemented, change this @@ -275,16 +294,23 @@ class Table(CloudFormationModel): return results, scanned_count, last_page - def delete_item(self, hash_key, range_key): + def delete_item( + self, hash_key: DynamoType, range_key: Optional[DynamoType] + ) -> Optional[Item]: try: if range_key: - return self.items[hash_key].pop(range_key) + return self.items[hash_key].pop(range_key) # type: ignore else: - return self.items.pop(hash_key) + return self.items.pop(hash_key) # type: ignore except KeyError: return None - def update_item(self, hash_key, range_key, attr_updates): + def update_item( + self, + hash_key: DynamoType, + range_key: Optional[DynamoType], + attr_updates: Dict[str, Any], + ) -> Optional[Item]: item = self.get_item(hash_key, range_key) if not item: return None @@ -299,10 +325,10 @@ class Table(CloudFormationModel): return item @classmethod - def has_cfn_attr(cls, attr): + def has_cfn_attr(cls, attr: str) -> bool: return attr in ["StreamArn"] - def get_cfn_attribute(self, attribute_name): + def get_cfn_attribute(self, attribute_name: str) -> str: from moto.cloudformation.exceptions import UnformattedGetAttTemplateException if attribute_name == "StreamArn": @@ -313,32 +339,39 @@ class Table(CloudFormationModel): class DynamoDBBackend(BaseBackend): - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.tables = OrderedDict() + self.tables: Dict[str, Table] = OrderedDict() - def create_table(self, name, **params): + def create_table(self, name: str, **params: Any) -> Table: table = Table(self.account_id, name, **params) self.tables[name] = table return table - def delete_table(self, name): + def delete_table(self, name: str) -> Optional[Table]: return self.tables.pop(name, None) - def update_table_throughput(self, name, new_read_units, new_write_units): + def update_table_throughput( + self, name: str, new_read_units: str, new_write_units: str + ) -> Table: table = self.tables[name] table.read_capacity = new_read_units table.write_capacity = new_write_units return table - def put_item(self, table_name, item_attrs): + def put_item(self, table_name: str, item_attrs: Dict[str, Any]) -> Optional[Item]: table = self.tables.get(table_name) if not table: return None return table.put_item(item_attrs) - def get_item(self, table_name, hash_key_dict, range_key_dict): + def get_item( + self, + table_name: str, + hash_key_dict: Dict[str, Any], + range_key_dict: Optional[Dict[str, Any]], + ) -> Optional[Item]: table = self.tables.get(table_name) if not table: return None @@ -348,7 +381,13 @@ class DynamoDBBackend(BaseBackend): return table.get_item(hash_key, range_key) - def query(self, table_name, hash_key_dict, range_comparison, range_value_dicts): + def query( + self, + table_name: str, + hash_key_dict: Dict[str, Any], + range_comparison: str, + range_value_dicts: List[Dict[str, Any]], + ) -> Tuple[Optional[Iterable[Item]], Optional[bool]]: table = self.tables.get(table_name) if not table: return None, None @@ -358,7 +397,9 @@ class DynamoDBBackend(BaseBackend): return table.query(hash_key, range_comparison, range_values) - def scan(self, table_name, filters): + def scan( + self, table_name: str, filters: Dict[str, Any] + ) -> Tuple[Optional[List[Item]], Optional[int], Optional[bool]]: table = self.tables.get(table_name) if not table: return None, None, None @@ -370,7 +411,12 @@ class DynamoDBBackend(BaseBackend): return table.scan(scan_filters) - def delete_item(self, table_name, hash_key_dict, range_key_dict): + def delete_item( + self, + table_name: str, + hash_key_dict: Dict[str, Any], + range_key_dict: Optional[Dict[str, Any]], + ) -> Optional[Item]: table = self.tables.get(table_name) if not table: return None @@ -380,7 +426,13 @@ class DynamoDBBackend(BaseBackend): return table.delete_item(hash_key, range_key) - def update_item(self, table_name, hash_key_dict, range_key_dict, attr_updates): + def update_item( + self, + table_name: str, + hash_key_dict: Dict[str, Any], + range_key_dict: Optional[Dict[str, Any]], + attr_updates: Dict[str, Any], + ) -> Optional[Item]: table = self.tables.get(table_name) if not table: return None diff --git a/moto/dynamodb_v20111205/responses.py b/moto/dynamodb_v20111205/responses.py index 9c7b82732..84b98da7d 100644 --- a/moto/dynamodb_v20111205/responses.py +++ b/moto/dynamodb_v20111205/responses.py @@ -1,15 +1,17 @@ import json +from typing import Any, Dict, Optional, Union +from moto.core.common_types import TYPE_RESPONSE from moto.core.responses import BaseResponse from moto.core.utils import camelcase_to_underscores -from .models import dynamodb_backends, dynamo_json_dump +from .models import dynamodb_backends, DynamoDBBackend, dynamo_json_dump class DynamoHandler(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="dynamodb") - def get_endpoint_name(self, headers): + def get_endpoint_name(self, headers: Dict[str, str]) -> Optional[str]: # type: ignore[return] """Parses request headers and extracts part od the X-Amz-Target that corresponds to a method of DynamoHandler @@ -20,10 +22,10 @@ class DynamoHandler(BaseResponse): if match: return match.split(".")[1] - def error(self, type_, status=400): + def error(self, type_: str, status: int = 400) -> TYPE_RESPONSE: return status, self.response_headers, dynamo_json_dump({"__type": type_}) - def call_action(self): + def call_action(self) -> TYPE_RESPONSE: self.body = json.loads(self.body or "{}") endpoint = self.get_endpoint_name(self.headers) if endpoint: @@ -40,10 +42,10 @@ class DynamoHandler(BaseResponse): return 404, self.response_headers, "" @property - def backend(self): + def backend(self) -> DynamoDBBackend: return dynamodb_backends[self.current_account]["global"] - def list_tables(self): + def list_tables(self) -> str: body = self.body limit = body.get("Limit") if body.get("ExclusiveStartTableName"): @@ -56,12 +58,12 @@ class DynamoHandler(BaseResponse): tables = all_tables[start : start + limit] else: tables = all_tables[start:] - response = {"TableNames": tables} + response: Dict[str, Any] = {"TableNames": tables} if limit and len(all_tables) > start + limit: response["LastEvaluatedTableName"] = tables[-1] return dynamo_json_dump(response) - def create_table(self): + def create_table(self) -> str: body = self.body name = body["TableName"] @@ -89,7 +91,7 @@ class DynamoHandler(BaseResponse): ) return dynamo_json_dump(table.describe) - def delete_table(self): + def delete_table(self) -> Union[str, TYPE_RESPONSE]: name = self.body["TableName"] table = self.backend.delete_table(name) if table: @@ -98,7 +100,7 @@ class DynamoHandler(BaseResponse): er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" return self.error(er) - def update_table(self): + def update_table(self) -> str: name = self.body["TableName"] throughput = self.body["ProvisionedThroughput"] new_read_units = throughput["ReadCapacityUnits"] @@ -108,7 +110,7 @@ class DynamoHandler(BaseResponse): ) return dynamo_json_dump(table.describe) - def describe_table(self): + def describe_table(self) -> Union[TYPE_RESPONSE, str]: name = self.body["TableName"] try: table = self.backend.tables[name] @@ -117,7 +119,7 @@ class DynamoHandler(BaseResponse): return self.error(er) return dynamo_json_dump(table.describe) - def put_item(self): + def put_item(self) -> Union[TYPE_RESPONSE, str]: name = self.body["TableName"] item = self.body["Item"] result = self.backend.put_item(name, item) @@ -129,7 +131,7 @@ class DynamoHandler(BaseResponse): er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" return self.error(er) - def batch_write_item(self): + def batch_write_item(self) -> str: table_batches = self.body["RequestItems"] for table_name, table_requests in table_batches.items(): @@ -156,7 +158,7 @@ class DynamoHandler(BaseResponse): return dynamo_json_dump(response) - def get_item(self): + def get_item(self) -> Union[TYPE_RESPONSE, str]: name = self.body["TableName"] key = self.body["Key"] hash_key = key["HashKeyElement"] @@ -176,10 +178,10 @@ class DynamoHandler(BaseResponse): er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" return self.error(er, status=404) - def batch_get_item(self): + def batch_get_item(self) -> str: table_batches = self.body["RequestItems"] - results = {"Responses": {"UnprocessedKeys": {}}} + results: Dict[str, Any] = {"Responses": {"UnprocessedKeys": {}}} for table_name, table_request in table_batches.items(): items = [] @@ -198,7 +200,7 @@ class DynamoHandler(BaseResponse): } return dynamo_json_dump(results) - def query(self): + def query(self) -> Union[TYPE_RESPONSE, str]: name = self.body["TableName"] hash_key = self.body["HashKeyValue"] range_condition = self.body.get("RangeKeyCondition") @@ -216,7 +218,7 @@ class DynamoHandler(BaseResponse): return self.error(er) result = { - "Count": len(items), + "Count": len(items), # type: ignore[arg-type] "Items": [item.attrs for item in items], "ConsumedCapacityUnits": 1, } @@ -229,7 +231,7 @@ class DynamoHandler(BaseResponse): # } return dynamo_json_dump(result) - def scan(self): + def scan(self) -> Union[TYPE_RESPONSE, str]: name = self.body["TableName"] filters = {} @@ -262,7 +264,7 @@ class DynamoHandler(BaseResponse): # } return dynamo_json_dump(result) - def delete_item(self): + def delete_item(self) -> Union[TYPE_RESPONSE, str]: name = self.body["TableName"] key = self.body["Key"] hash_key = key["HashKeyElement"] @@ -280,7 +282,7 @@ class DynamoHandler(BaseResponse): er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" return self.error(er) - def update_item(self): + def update_item(self) -> Union[TYPE_RESPONSE, str]: name = self.body["TableName"] key = self.body["Key"] hash_key = key["HashKeyElement"] diff --git a/setup.cfg b/setup.cfg index df0761ab6..6416857cb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,7 +18,7 @@ disable = W,C,R,E enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import [mypy] -files= moto/a*,moto/b*,moto/c*,moto/databrew,moto/datapipeline,moto/datasync,moto/dax,moto/dms,moto/ds,moto/dynamodbstreams,moto/moto_api +files= moto/a*,moto/b*,moto/c*,moto/databrew,moto/datapipeline,moto/datasync,moto/dax,moto/dms,moto/ds,moto/dynamodb_v20111205,moto/dynamodbstreams,moto/moto_api show_column_numbers=True show_error_codes = True disable_error_code=abstract