Techdebt: MyPy DynamoDB v20111205 (#5799)

This commit is contained in:
Bert Blommers 2022-12-22 09:58:08 -01:00 committed by GitHub
parent fb0a4d64c8
commit 5920d1a8ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 148 additions and 91 deletions

View File

@ -1,3 +1,6 @@
from typing import Callable, Any
# TODO add tests for all of these
COMPARISON_FUNCS = {
"EQ": lambda item_value, test_value: item_value == test_value,
@ -18,5 +21,5 @@ COMPARISON_FUNCS = {
}
def get_comparison_func(range_comparison):
return COMPARISON_FUNCS.get(range_comparison)
def get_comparison_func(range_comparison: str) -> Callable[..., Any]:
return COMPARISON_FUNCS.get(range_comparison) # type: ignore[return-value]

View File

@ -1,4 +1,5 @@
from collections import defaultdict
from typing import Any, Dict, Optional, List, Union, Tuple, Iterable
import datetime
import json
@ -9,12 +10,12 @@ from .comparisons import get_comparison_func
class DynamoJsonEncoder(json.JSONEncoder):
def default(self, o):
def default(self, o: Any) -> Optional[str]: # type: ignore[return]
if hasattr(o, "to_json"):
return o.to_json()
def dynamo_json_dump(dynamo_object):
def dynamo_json_dump(dynamo_object: Any) -> str:
return json.dumps(dynamo_object, cls=DynamoJsonEncoder)
@ -23,29 +24,29 @@ class DynamoType(object):
http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DataModel.html#DataModelDataTypes
"""
def __init__(self, type_as_dict):
def __init__(self, type_as_dict: Dict[str, Any]):
self.type = list(type_as_dict.keys())[0]
self.value = list(type_as_dict.values())[0]
def __hash__(self):
def __hash__(self) -> int:
return hash((self.type, self.value))
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
return self.type == other.type and self.value == other.value
def __repr__(self):
def __repr__(self) -> str:
return f"DynamoType: {self.to_json()}"
def add(self, dyn_type):
def add(self, dyn_type: "DynamoType") -> None:
if self.type == "SS":
self.value.append(dyn_type.value)
if self.type == "N":
self.value = str(int(self.value) + int(dyn_type.value))
def to_json(self):
def to_json(self) -> Dict[str, Any]:
return {self.type: self.value}
def compare(self, range_comparison, range_objs):
def compare(self, range_comparison: str, range_objs: List["DynamoType"]) -> Any:
"""
Compares this type against comparison filters
"""
@ -55,7 +56,14 @@ class DynamoType(object):
class Item(BaseModel):
def __init__(self, hash_key, hash_key_type, range_key, range_key_type, attrs):
def __init__(
self,
hash_key: DynamoType,
hash_key_type: str,
range_key: Optional[DynamoType],
range_key_type: Optional[str],
attrs: Dict[str, Any],
):
self.hash_key = hash_key
self.hash_key_type = hash_key_type
self.range_key = range_key
@ -65,17 +73,17 @@ class Item(BaseModel):
for key, value in attrs.items():
self.attrs[key] = DynamoType(value)
def __repr__(self):
def __repr__(self) -> str:
return f"Item: {self.to_json()}"
def to_json(self):
def to_json(self) -> Dict[str, Any]:
attributes = {}
for attribute_key, attribute in self.attrs.items():
attributes[attribute_key] = attribute.value
return {"Attributes": attributes}
def describe_attrs(self, attributes):
def describe_attrs(self, attributes: List[str]) -> Dict[str, Any]:
if attributes:
included = {}
for key, value in self.attrs.items():
@ -89,14 +97,14 @@ class Item(BaseModel):
class Table(CloudFormationModel):
def __init__(
self,
account_id,
name,
hash_key_attr,
hash_key_type,
range_key_attr=None,
range_key_type=None,
read_capacity=None,
write_capacity=None,
account_id: str,
name: str,
hash_key_attr: str,
hash_key_type: str,
range_key_attr: Optional[str] = None,
range_key_type: Optional[str] = None,
read_capacity: Optional[str] = None,
write_capacity: Optional[str] = None,
):
self.account_id = account_id
self.name = name
@ -107,15 +115,17 @@ class Table(CloudFormationModel):
self.read_capacity = read_capacity
self.write_capacity = write_capacity
self.created_at = datetime.datetime.utcnow()
self.items = defaultdict(dict)
self.items: Dict[DynamoType, Union[Item, Dict[DynamoType, Item]]] = defaultdict(
dict
)
@property
def has_range_key(self):
def has_range_key(self) -> bool:
return self.range_key_attr is not None
@property
def describe(self):
results = {
def describe(self) -> Dict[str, Any]: # type: ignore[misc]
results: Dict[str, Any] = {
"Table": {
"CreationDateTime": unix_time(self.created_at),
"KeySchema": {
@ -142,18 +152,23 @@ class Table(CloudFormationModel):
return results
@staticmethod
def cloudformation_name_type():
def cloudformation_name_type() -> str:
return "TableName"
@staticmethod
def cloudformation_type():
def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-dynamodb-table.html
return "AWS::DynamoDB::Table"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs
):
def create_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Dict[str, Any],
account_id: str,
region_name: str,
**kwargs: Any,
) -> "Table":
properties = cloudformation_json["Properties"]
key_attr = [
i["AttributeName"]
@ -175,21 +190,21 @@ class Table(CloudFormationModel):
# range_key_attr, range_key_type, read_capacity, write_capacity
return Table(**spec)
def __len__(self):
def __len__(self) -> int:
return sum(
[(len(value) if self.has_range_key else 1) for value in self.items.values()]
[(len(value) if self.has_range_key else 1) for value in self.items.values()] # type: ignore
)
def __nonzero__(self):
def __nonzero__(self) -> bool:
return True
def __bool__(self):
def __bool__(self) -> bool:
return self.__nonzero__()
def put_item(self, item_attrs):
hash_value = DynamoType(item_attrs.get(self.hash_key_attr))
def put_item(self, item_attrs: Dict[str, Any]) -> Item:
hash_value = DynamoType(item_attrs.get(self.hash_key_attr)) # type: ignore[arg-type]
if self.has_range_key:
range_value = DynamoType(item_attrs.get(self.range_key_attr))
range_value: Optional[DynamoType] = DynamoType(item_attrs.get(self.range_key_attr)) # type: ignore[arg-type]
else:
range_value = None
@ -198,51 +213,55 @@ class Table(CloudFormationModel):
)
if range_value:
self.items[hash_value][range_value] = item
self.items[hash_value][range_value] = item # type: ignore[index]
else:
self.items[hash_value] = item
return item
def get_item(self, hash_key, range_key):
def get_item(
self, hash_key: DynamoType, range_key: Optional[DynamoType]
) -> Optional[Item]:
if self.has_range_key and not range_key:
raise ValueError(
"Table has a range key, but no range key was passed into get_item"
)
try:
if range_key:
return self.items[hash_key][range_key]
return self.items[hash_key][range_key] # type: ignore
else:
return self.items[hash_key]
return self.items[hash_key] # type: ignore
except KeyError:
return None
def query(self, hash_key, range_comparison, range_objs):
def query(
self, hash_key: DynamoType, range_comparison: str, range_objs: Any
) -> Tuple[Iterable[Item], bool]:
results = []
last_page = True # Once pagination is implemented, change this
if self.range_key_attr:
possible_results = self.items[hash_key].values()
possible_results = self.items[hash_key].values() # type: ignore[union-attr]
else:
possible_results = list(self.all_items())
if range_comparison:
for result in possible_results:
if result.range_key.compare(range_comparison, range_objs):
if result.range_key.compare(range_comparison, range_objs): # type: ignore[union-attr]
results.append(result)
else:
# If we're not filtering on range key, return all values
results = possible_results
results = possible_results # type: ignore[assignment]
return results, last_page
def all_items(self):
def all_items(self) -> Iterable[Item]:
for hash_set in self.items.values():
if self.range_key_attr:
for item in hash_set.values():
for item in hash_set.values(): # type: ignore
yield item
else:
yield hash_set
yield hash_set # type: ignore[misc]
def scan(self, filters):
def scan(self, filters: Dict[str, Any]) -> Tuple[List[Item], int, bool]:
results = []
scanned_count = 0
last_page = True # Once pagination is implemented, change this
@ -275,16 +294,23 @@ class Table(CloudFormationModel):
return results, scanned_count, last_page
def delete_item(self, hash_key, range_key):
def delete_item(
self, hash_key: DynamoType, range_key: Optional[DynamoType]
) -> Optional[Item]:
try:
if range_key:
return self.items[hash_key].pop(range_key)
return self.items[hash_key].pop(range_key) # type: ignore
else:
return self.items.pop(hash_key)
return self.items.pop(hash_key) # type: ignore
except KeyError:
return None
def update_item(self, hash_key, range_key, attr_updates):
def update_item(
self,
hash_key: DynamoType,
range_key: Optional[DynamoType],
attr_updates: Dict[str, Any],
) -> Optional[Item]:
item = self.get_item(hash_key, range_key)
if not item:
return None
@ -299,10 +325,10 @@ class Table(CloudFormationModel):
return item
@classmethod
def has_cfn_attr(cls, attr):
def has_cfn_attr(cls, attr: str) -> bool:
return attr in ["StreamArn"]
def get_cfn_attribute(self, attribute_name):
def get_cfn_attribute(self, attribute_name: str) -> str:
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
if attribute_name == "StreamArn":
@ -313,32 +339,39 @@ class Table(CloudFormationModel):
class DynamoDBBackend(BaseBackend):
def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self.tables = OrderedDict()
self.tables: Dict[str, Table] = OrderedDict()
def create_table(self, name, **params):
def create_table(self, name: str, **params: Any) -> Table:
table = Table(self.account_id, name, **params)
self.tables[name] = table
return table
def delete_table(self, name):
def delete_table(self, name: str) -> Optional[Table]:
return self.tables.pop(name, None)
def update_table_throughput(self, name, new_read_units, new_write_units):
def update_table_throughput(
self, name: str, new_read_units: str, new_write_units: str
) -> Table:
table = self.tables[name]
table.read_capacity = new_read_units
table.write_capacity = new_write_units
return table
def put_item(self, table_name, item_attrs):
def put_item(self, table_name: str, item_attrs: Dict[str, Any]) -> Optional[Item]:
table = self.tables.get(table_name)
if not table:
return None
return table.put_item(item_attrs)
def get_item(self, table_name, hash_key_dict, range_key_dict):
def get_item(
self,
table_name: str,
hash_key_dict: Dict[str, Any],
range_key_dict: Optional[Dict[str, Any]],
) -> Optional[Item]:
table = self.tables.get(table_name)
if not table:
return None
@ -348,7 +381,13 @@ class DynamoDBBackend(BaseBackend):
return table.get_item(hash_key, range_key)
def query(self, table_name, hash_key_dict, range_comparison, range_value_dicts):
def query(
self,
table_name: str,
hash_key_dict: Dict[str, Any],
range_comparison: str,
range_value_dicts: List[Dict[str, Any]],
) -> Tuple[Optional[Iterable[Item]], Optional[bool]]:
table = self.tables.get(table_name)
if not table:
return None, None
@ -358,7 +397,9 @@ class DynamoDBBackend(BaseBackend):
return table.query(hash_key, range_comparison, range_values)
def scan(self, table_name, filters):
def scan(
self, table_name: str, filters: Dict[str, Any]
) -> Tuple[Optional[List[Item]], Optional[int], Optional[bool]]:
table = self.tables.get(table_name)
if not table:
return None, None, None
@ -370,7 +411,12 @@ class DynamoDBBackend(BaseBackend):
return table.scan(scan_filters)
def delete_item(self, table_name, hash_key_dict, range_key_dict):
def delete_item(
self,
table_name: str,
hash_key_dict: Dict[str, Any],
range_key_dict: Optional[Dict[str, Any]],
) -> Optional[Item]:
table = self.tables.get(table_name)
if not table:
return None
@ -380,7 +426,13 @@ class DynamoDBBackend(BaseBackend):
return table.delete_item(hash_key, range_key)
def update_item(self, table_name, hash_key_dict, range_key_dict, attr_updates):
def update_item(
self,
table_name: str,
hash_key_dict: Dict[str, Any],
range_key_dict: Optional[Dict[str, Any]],
attr_updates: Dict[str, Any],
) -> Optional[Item]:
table = self.tables.get(table_name)
if not table:
return None

View File

@ -1,15 +1,17 @@
import json
from typing import Any, Dict, Optional, Union
from moto.core.common_types import TYPE_RESPONSE
from moto.core.responses import BaseResponse
from moto.core.utils import camelcase_to_underscores
from .models import dynamodb_backends, dynamo_json_dump
from .models import dynamodb_backends, DynamoDBBackend, dynamo_json_dump
class DynamoHandler(BaseResponse):
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="dynamodb")
def get_endpoint_name(self, headers):
def get_endpoint_name(self, headers: Dict[str, str]) -> Optional[str]: # type: ignore[return]
"""Parses request headers and extracts part od the X-Amz-Target
that corresponds to a method of DynamoHandler
@ -20,10 +22,10 @@ class DynamoHandler(BaseResponse):
if match:
return match.split(".")[1]
def error(self, type_, status=400):
def error(self, type_: str, status: int = 400) -> TYPE_RESPONSE:
return status, self.response_headers, dynamo_json_dump({"__type": type_})
def call_action(self):
def call_action(self) -> TYPE_RESPONSE:
self.body = json.loads(self.body or "{}")
endpoint = self.get_endpoint_name(self.headers)
if endpoint:
@ -40,10 +42,10 @@ class DynamoHandler(BaseResponse):
return 404, self.response_headers, ""
@property
def backend(self):
def backend(self) -> DynamoDBBackend:
return dynamodb_backends[self.current_account]["global"]
def list_tables(self):
def list_tables(self) -> str:
body = self.body
limit = body.get("Limit")
if body.get("ExclusiveStartTableName"):
@ -56,12 +58,12 @@ class DynamoHandler(BaseResponse):
tables = all_tables[start : start + limit]
else:
tables = all_tables[start:]
response = {"TableNames": tables}
response: Dict[str, Any] = {"TableNames": tables}
if limit and len(all_tables) > start + limit:
response["LastEvaluatedTableName"] = tables[-1]
return dynamo_json_dump(response)
def create_table(self):
def create_table(self) -> str:
body = self.body
name = body["TableName"]
@ -89,7 +91,7 @@ class DynamoHandler(BaseResponse):
)
return dynamo_json_dump(table.describe)
def delete_table(self):
def delete_table(self) -> Union[str, TYPE_RESPONSE]:
name = self.body["TableName"]
table = self.backend.delete_table(name)
if table:
@ -98,7 +100,7 @@ class DynamoHandler(BaseResponse):
er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException"
return self.error(er)
def update_table(self):
def update_table(self) -> str:
name = self.body["TableName"]
throughput = self.body["ProvisionedThroughput"]
new_read_units = throughput["ReadCapacityUnits"]
@ -108,7 +110,7 @@ class DynamoHandler(BaseResponse):
)
return dynamo_json_dump(table.describe)
def describe_table(self):
def describe_table(self) -> Union[TYPE_RESPONSE, str]:
name = self.body["TableName"]
try:
table = self.backend.tables[name]
@ -117,7 +119,7 @@ class DynamoHandler(BaseResponse):
return self.error(er)
return dynamo_json_dump(table.describe)
def put_item(self):
def put_item(self) -> Union[TYPE_RESPONSE, str]:
name = self.body["TableName"]
item = self.body["Item"]
result = self.backend.put_item(name, item)
@ -129,7 +131,7 @@ class DynamoHandler(BaseResponse):
er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException"
return self.error(er)
def batch_write_item(self):
def batch_write_item(self) -> str:
table_batches = self.body["RequestItems"]
for table_name, table_requests in table_batches.items():
@ -156,7 +158,7 @@ class DynamoHandler(BaseResponse):
return dynamo_json_dump(response)
def get_item(self):
def get_item(self) -> Union[TYPE_RESPONSE, str]:
name = self.body["TableName"]
key = self.body["Key"]
hash_key = key["HashKeyElement"]
@ -176,10 +178,10 @@ class DynamoHandler(BaseResponse):
er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException"
return self.error(er, status=404)
def batch_get_item(self):
def batch_get_item(self) -> str:
table_batches = self.body["RequestItems"]
results = {"Responses": {"UnprocessedKeys": {}}}
results: Dict[str, Any] = {"Responses": {"UnprocessedKeys": {}}}
for table_name, table_request in table_batches.items():
items = []
@ -198,7 +200,7 @@ class DynamoHandler(BaseResponse):
}
return dynamo_json_dump(results)
def query(self):
def query(self) -> Union[TYPE_RESPONSE, str]:
name = self.body["TableName"]
hash_key = self.body["HashKeyValue"]
range_condition = self.body.get("RangeKeyCondition")
@ -216,7 +218,7 @@ class DynamoHandler(BaseResponse):
return self.error(er)
result = {
"Count": len(items),
"Count": len(items), # type: ignore[arg-type]
"Items": [item.attrs for item in items],
"ConsumedCapacityUnits": 1,
}
@ -229,7 +231,7 @@ class DynamoHandler(BaseResponse):
# }
return dynamo_json_dump(result)
def scan(self):
def scan(self) -> Union[TYPE_RESPONSE, str]:
name = self.body["TableName"]
filters = {}
@ -262,7 +264,7 @@ class DynamoHandler(BaseResponse):
# }
return dynamo_json_dump(result)
def delete_item(self):
def delete_item(self) -> Union[TYPE_RESPONSE, str]:
name = self.body["TableName"]
key = self.body["Key"]
hash_key = key["HashKeyElement"]
@ -280,7 +282,7 @@ class DynamoHandler(BaseResponse):
er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException"
return self.error(er)
def update_item(self):
def update_item(self) -> Union[TYPE_RESPONSE, str]:
name = self.body["TableName"]
key = self.body["Key"]
hash_key = key["HashKeyElement"]

View File

@ -18,7 +18,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy]
files= moto/a*,moto/b*,moto/c*,moto/databrew,moto/datapipeline,moto/datasync,moto/dax,moto/dms,moto/ds,moto/dynamodbstreams,moto/moto_api
files= moto/a*,moto/b*,moto/c*,moto/databrew,moto/datapipeline,moto/datasync,moto/dax,moto/dms,moto/ds,moto/dynamodb_v20111205,moto/dynamodbstreams,moto/moto_api
show_column_numbers=True
show_error_codes = True
disable_error_code=abstract