402 lines
13 KiB
Python
402 lines
13 KiB
Python
import json
|
|
from collections import OrderedDict, defaultdict
|
|
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
|
|
|
from moto.core.base_backend import BackendDict, BaseBackend
|
|
from moto.core.common_models import BaseModel
|
|
from moto.core.utils import unix_time, utcnow
|
|
|
|
from .comparisons import get_comparison_func
|
|
|
|
|
|
class DynamoJsonEncoder(json.JSONEncoder):
    """JSON encoder that serialises any object exposing a ``to_json()`` method."""

    def default(self, o: Any) -> Optional[str]:
        """Return ``o.to_json()`` when ``o`` has one; otherwise ``None``.

        Returning ``None`` means objects without ``to_json`` are emitted as
        JSON ``null`` instead of raising ``TypeError``.
        """
        if not hasattr(o, "to_json"):
            return None
        return o.to_json()
|
|
|
|
|
|
def dynamo_json_dump(dynamo_object: Any) -> str:
    """Serialise *dynamo_object* to a JSON string, expanding ``to_json()`` objects."""
    encoder = DynamoJsonEncoder()
    return encoder.encode(dynamo_object)
|
|
|
|
|
|
class DynamoType(object):
    """
    http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DataModel.html#DataModelDataTypes

    Wraps one typed attribute value supplied as ``{"<type tag>": <value>}``,
    e.g. ``{"S": "abc"}`` or ``{"SS": ["a", "b"]}``.
    """

    def __init__(self, type_as_dict: Dict[str, Any]):
        # Only the first entry of the mapping is used.
        self.type = list(type_as_dict.keys())[0]
        self.value = list(type_as_dict.values())[0]

    def __hash__(self) -> int:
        # Needs a hashable value: fine for scalar key types, not for list-valued sets.
        return hash((self.type, self.value))

    def __eq__(self, other: Any) -> bool:
        # Duck-typed comparison; objects without .type/.value defer to Python's
        # default handling instead of raising AttributeError.
        try:
            return self.type == other.type and self.value == other.value
        except AttributeError:
            return NotImplemented

    def __repr__(self) -> str:
        return f"DynamoType: {self.to_json()}"

    def add(self, dyn_type: "DynamoType") -> None:
        """Apply an UpdateItem ``ADD`` of *dyn_type* to this value, in place.

        * Set types (``SS``/``NS``/``BS``): merge the incoming element(s) into
          the set without duplicating elements already present. A list-valued
          *dyn_type* (e.g. ``{"SS": ["a"]}``) contributes each of its elements;
          a scalar value contributes itself.
        * ``N``: add the two numbers and store the sum as a string.
        * Any other type is left unchanged.
        """
        if self.type in ("SS", "NS", "BS"):
            incoming = (
                dyn_type.value if isinstance(dyn_type.value, list) else [dyn_type.value]
            )
            # Build a new list so caller-supplied request data is not mutated.
            merged = list(self.value)
            for element in incoming:
                if element not in merged:
                    merged.append(element)
            self.value = merged
        elif self.type == "N":
            self.value = self._add_numbers(self.value, dyn_type.value)

    @staticmethod
    def _add_numbers(left: str, right: str) -> str:
        """Return the sum of two DynamoDB number strings as a string.

        Integers are added exactly with ``int``. Non-integral inputs such as
        ``"1.5"`` fall back to ``Decimal`` with 38 significant digits, which is
        DynamoDB's number precision. The result is normalised, so ``"4.0"``
        becomes ``"4"``.
        """
        from decimal import Decimal, localcontext

        try:
            return str(int(left) + int(right))
        except ValueError:
            pass
        with localcontext() as ctx:
            ctx.prec = 38
            total = Decimal(left) + Decimal(right)
            return format(total.normalize(), "f")

    def to_json(self) -> Dict[str, Any]:
        """Return the ``{"<type tag>": <value>}`` form of this value."""
        return {self.type: self.value}

    def compare(self, range_comparison: str, range_objs: List["DynamoType"]) -> Any:
        """
        Compares this type against comparison filters
        """
        range_values = [obj.value for obj in range_objs]
        comparison_func = get_comparison_func(range_comparison)
        return comparison_func(self.value, *range_values)
|
|
|
|
|
|
class Item(BaseModel):
    """A stored item: its key values and types, plus every attribute as a DynamoType."""

    def __init__(
        self,
        hash_key: DynamoType,
        hash_key_type: str,
        range_key: Optional[DynamoType],
        range_key_type: Optional[str],
        attrs: Dict[str, Any],
    ):
        self.hash_key = hash_key
        self.hash_key_type = hash_key_type
        self.range_key = range_key
        self.range_key_type = range_key_type
        # Each raw {"<type tag>": <value>} attribute is wrapped in a DynamoType.
        self.attrs: Dict[str, DynamoType] = {
            attr_name: DynamoType(raw_value) for attr_name, raw_value in attrs.items()
        }

    def __repr__(self) -> str:
        return "Item: {}".format(self.to_json())

    def to_json(self) -> Dict[str, Any]:
        """Return ``{"Attributes": {name: value}}`` with the type tags dropped."""
        plain_values = {attr_name: dyn.value for attr_name, dyn in self.attrs.items()}
        return {"Attributes": plain_values}

    def describe_attrs(self, attributes: List[str]) -> Dict[str, Any]:
        """Return ``{"Item": {...}}``.

        The result is limited to the names in *attributes*. When *attributes*
        is empty or ``None``, all attributes are returned.
        """
        if not attributes:
            return {"Item": self.attrs}
        selected = {
            attr_name: dyn
            for attr_name, dyn in self.attrs.items()
            if attr_name in attributes
        }
        return {"Item": selected}
|
|
|
|
|
|
class Table(BaseModel):
    """In-memory table for the DynamoDB 2011-12-05 API mock.

    ``self.items`` is keyed by the hash-key ``DynamoType``:

    * When the table has a range key, each value is a dict mapping the
      range-key ``DynamoType`` to an ``Item``.
    * Otherwise each value is the ``Item`` itself.
    """

    def __init__(
        self,
        account_id: str,
        name: str,
        hash_key_attr: str,
        hash_key_type: str,
        range_key_attr: Optional[str] = None,
        range_key_type: Optional[str] = None,
        read_capacity: Optional[str] = None,
        write_capacity: Optional[str] = None,
    ):
        self.account_id = account_id
        self.name = name
        self.hash_key_attr = hash_key_attr
        self.hash_key_type = hash_key_type
        self.range_key_attr = range_key_attr
        self.range_key_type = range_key_type
        self.read_capacity = read_capacity
        self.write_capacity = write_capacity
        self.created_at = utcnow()
        # The defaultdict lets put_item create the inner range-key dict on first
        # write. Read paths below use .get() rather than [] so that looking up a
        # missing key never inserts an empty placeholder entry. A placeholder
        # would inflate __len__ and break all_items/scan.
        self.items: Dict[DynamoType, Union[Item, Dict[DynamoType, Item]]] = defaultdict(
            dict
        )

    @property
    def has_range_key(self) -> bool:
        """Whether this table was created with a range key attribute."""
        return self.range_key_attr is not None

    @property
    def describe(self) -> Dict[str, Any]:  # type: ignore[misc]
        """Build the ``{"Table": {...}}`` description.

        It holds the key schema, throughput, status and item count.
        """
        results: Dict[str, Any] = {
            "Table": {
                "CreationDateTime": unix_time(self.created_at),
                "KeySchema": {
                    "HashKeyElement": {
                        "AttributeName": self.hash_key_attr,
                        "AttributeType": self.hash_key_type,
                    }
                },
                "ProvisionedThroughput": {
                    "ReadCapacityUnits": self.read_capacity,
                    "WriteCapacityUnits": self.write_capacity,
                },
                "TableName": self.name,
                "TableStatus": "ACTIVE",
                "ItemCount": len(self),
                "TableSizeBytes": 0,
            }
        }
        if self.has_range_key:
            results["Table"]["KeySchema"]["RangeKeyElement"] = {
                "AttributeName": self.range_key_attr,
                "AttributeType": self.range_key_type,
            }
        return results

    def __len__(self) -> int:
        """Number of stored items; range tables count every item in each bucket."""
        return sum(
            (len(value) if self.has_range_key else 1)  # type: ignore
            for value in self.items.values()
        )

    def __nonzero__(self) -> bool:
        return True

    def __bool__(self) -> bool:
        # Always truthy, even when empty (otherwise __len__ would decide).
        return self.__nonzero__()

    def put_item(self, item_attrs: Dict[str, Any]) -> Item:
        """Store an item built from *item_attrs* and return it.

        Any existing item with the same key is replaced.
        """
        hash_value = DynamoType(item_attrs.get(self.hash_key_attr))  # type: ignore[arg-type]
        range_value: Optional[DynamoType] = None
        if self.has_range_key:
            range_value = DynamoType(item_attrs.get(self.range_key_attr))  # type: ignore[arg-type]

        item = Item(
            hash_value, self.hash_key_type, range_value, self.range_key_type, item_attrs
        )

        if range_value is not None:
            self.items[hash_value][range_value] = item  # type: ignore[index]
        else:
            self.items[hash_value] = item
        return item

    def get_item(
        self, hash_key: DynamoType, range_key: Optional[DynamoType]
    ) -> Optional[Item]:
        """Return the item stored under the given key, or ``None`` if absent.

        Raises:
            ValueError: the table has a range key but *range_key* is missing.
        """
        if self.has_range_key and not range_key:
            raise ValueError(
                "Table has a range key, but no range key was passed into get_item"
            )
        if range_key:
            return self.items.get(hash_key, {}).get(range_key)  # type: ignore
        return self.items.get(hash_key)  # type: ignore

    def query(
        self, hash_key: DynamoType, range_comparison: str, range_objs: Any
    ) -> Tuple[Iterable[Item], bool]:
        """Return ``(items, last_page)`` for a Query.

        Candidate items:

        * Range tables: the items stored under *hash_key*.
        * Hash-only tables: every item; *hash_key* is not consulted.

        When *range_comparison* is given, only candidates whose range key
        satisfies it are kept. ``last_page`` is always ``True`` because
        pagination is not implemented.

        NOTE(review): a *range_comparison* on a hash-only table dereferences
        ``item.range_key`` (``None``) and raises ``AttributeError``.
        """
        last_page = True  # Once pagination is implemented, change this

        if self.has_range_key:
            possible_results = list(self.items.get(hash_key, {}).values())  # type: ignore[union-attr]
        else:
            possible_results = list(self.all_items())

        if range_comparison:
            results = [
                result
                for result in possible_results
                if result.range_key.compare(range_comparison, range_objs)  # type: ignore[union-attr]
            ]
        else:
            # If we're not filtering on range key, return all values
            results = possible_results
        return results, last_page

    def all_items(self) -> Iterable[Item]:
        """Yield every stored item, flattening the per-hash-key range buckets."""
        for hash_set in self.items.values():
            if self.has_range_key:
                yield from hash_set.values()  # type: ignore
            else:
                yield hash_set  # type: ignore[misc]

    def scan(self, filters: Dict[str, Any]) -> Tuple[List[Item], int, bool]:
        """Return ``(matching items, scanned count, last_page)`` for a Scan.

        *filters* maps an attribute name to
        ``(comparison_operator, [DynamoType, ...])``. An item matches only when
        every filter passes. A missing attribute passes only a ``"NULL"``
        filter. ``last_page`` is always ``True``.
        """
        results = []
        scanned_count = 0
        last_page = True  # Once pagination is implemented, change this

        for result in self.all_items():
            scanned_count += 1
            passes_all_conditions = True
            for attribute_name, (comparison_operator, comparison_objs) in filters.items():
                attribute = result.attrs.get(attribute_name)

                if attribute is not None:
                    # Attribute found
                    if not attribute.compare(comparison_operator, comparison_objs):
                        passes_all_conditions = False
                        break
                elif comparison_operator == "NULL":
                    # Comparison is NULL and we don't have the attribute
                    continue
                else:
                    # No attribute found and comparison is not NULL: item fails
                    passes_all_conditions = False
                    break

            if passes_all_conditions:
                results.append(result)

        return results, scanned_count, last_page

    def delete_item(
        self, hash_key: DynamoType, range_key: Optional[DynamoType]
    ) -> Optional[Item]:
        """Remove and return the item stored under the given key, or ``None``.

        NOTE(review): on a range table called without *range_key*, this removes
        and returns the whole bucket dict stored under *hash_key*. That matches
        the historical behaviour.
        """
        if range_key:
            return self.items.get(hash_key, {}).pop(range_key, None)  # type: ignore
        return self.items.pop(hash_key, None)  # type: ignore

    def update_item(
        self,
        hash_key: DynamoType,
        range_key: Optional[DynamoType],
        attr_updates: Dict[str, Any],
    ) -> Optional[Item]:
        """Apply UpdateItem ``AttributeUpdates`` to an existing item.

        *attr_updates* maps an attribute name to
        ``{"Action": ..., "Value": {...}}``. A missing ``Action`` means
        ``"PUT"``. The actions are:

        * ``PUT``: replace the attribute with ``Value``.
        * ``DELETE``: remove the attribute; this is a no-op if it is absent.
        * ``ADD``: add ``Value`` via ``DynamoType.add``, or create the attribute
          from ``Value`` when it does not exist yet.

        Returns the updated item, or ``None`` when no item has this key.
        """
        item = self.get_item(hash_key, range_key)
        if item is None:
            return None

        for attr, update in attr_updates.items():
            action = update.get("Action", "PUT")
            if action == "PUT":
                item.attrs[attr] = DynamoType(update["Value"])
            elif action == "DELETE":
                item.attrs.pop(attr, None)
            elif action == "ADD":
                new_value = DynamoType(update["Value"])
                existing = item.attrs.get(attr)
                if existing is None:
                    item.attrs[attr] = new_value
                else:
                    existing.add(new_value)
        return item
|
|
|
|
|
|
class DynamoDBBackend(BaseBackend):
    """Backend for the legacy DynamoDB 2011-12-05 API: a registry of tables by name.

    Table-scoped item operations return ``None`` (or a tuple of ``None``
    values) when the named table does not exist. ``update_table_throughput``
    instead raises ``KeyError``.
    """

    def __init__(self, region_name: str, account_id: str):
        super().__init__(region_name, account_id)
        self.tables: Dict[str, Table] = OrderedDict()

    def create_table(self, name: str, **params: Any) -> Table:
        """Create (or replace) table *name*; *params* are forwarded to ``Table``."""
        self.tables[name] = Table(self.account_id, name, **params)
        return self.tables[name]

    def delete_table(self, name: str) -> Optional[Table]:
        """Remove and return table *name*, or ``None`` if it does not exist."""
        if name not in self.tables:
            return None
        return self.tables.pop(name)

    def update_table_throughput(
        self, name: str, new_read_units: str, new_write_units: str
    ) -> Table:
        """Overwrite the read/write capacity of table *name* and return it."""
        target = self.tables[name]
        target.read_capacity, target.write_capacity = new_read_units, new_write_units
        return target

    def put_item(self, table_name: str, item_attrs: Dict[str, Any]) -> Optional[Item]:
        """Store *item_attrs* in *table_name*; ``None`` if the table is unknown."""
        target = self.tables.get(table_name)
        return None if target is None else target.put_item(item_attrs)

    def get_item(
        self,
        table_name: str,
        hash_key_dict: Dict[str, Any],
        range_key_dict: Optional[Dict[str, Any]],
    ) -> Optional[Item]:
        """Look up one item by its raw key dicts; ``None`` if table or item is missing."""
        target = self.tables.get(table_name)
        if target is None:
            return None
        hash_key = DynamoType(hash_key_dict)
        return target.get_item(
            hash_key, DynamoType(range_key_dict) if range_key_dict else None
        )

    def query(
        self,
        table_name: str,
        hash_key_dict: Dict[str, Any],
        range_comparison: str,
        range_value_dicts: List[Dict[str, Any]],
    ) -> Tuple[Optional[Iterable[Item]], Optional[bool]]:
        """Run ``Table.query``; returns ``(None, None)`` for an unknown table."""
        target = self.tables.get(table_name)
        if target is None:
            return None, None
        hash_key = DynamoType(hash_key_dict)
        range_values = list(map(DynamoType, range_value_dicts))
        return target.query(hash_key, range_comparison, range_values)

    def scan(
        self, table_name: str, filters: Dict[str, Any]
    ) -> Tuple[Optional[List[Item]], Optional[int], Optional[bool]]:
        """Run ``Table.scan`` after wrapping each filter's raw values in DynamoType.

        Returns ``(None, None, None)`` for an unknown table.
        """
        target = self.tables.get(table_name)
        if target is None:
            return None, None, None
        scan_filters = {
            attr_name: (operator, [DynamoType(raw) for raw in raw_values])
            for attr_name, (operator, raw_values) in filters.items()
        }
        return target.scan(scan_filters)

    def delete_item(
        self,
        table_name: str,
        hash_key_dict: Dict[str, Any],
        range_key_dict: Optional[Dict[str, Any]],
    ) -> Optional[Item]:
        """Delete one item by its raw key dicts; ``None`` if the table is unknown."""
        target = self.tables.get(table_name)
        if target is None:
            return None
        hash_key = DynamoType(hash_key_dict)
        return target.delete_item(
            hash_key, DynamoType(range_key_dict) if range_key_dict else None
        )

    def update_item(
        self,
        table_name: str,
        hash_key_dict: Dict[str, Any],
        range_key_dict: Optional[Dict[str, Any]],
        attr_updates: Dict[str, Any],
    ) -> Optional[Item]:
        """Apply *attr_updates* to one item; ``None`` if the table is unknown."""
        target = self.tables.get(table_name)
        if target is None:
            return None
        hash_key = DynamoType(hash_key_dict)
        return target.update_item(
            hash_key,
            DynamoType(range_key_dict) if range_key_dict else None,
            attr_updates,
        )
|
|
|
|
|
|
# Module-level registry of DynamoDBBackend instances for the legacy
# "dynamodb_v20111205" service name. Region selection is configured with
# use_boto3_regions=False and an extra "global" region.
# NOTE(review): presumably this means backends exist only for "global" rather
# than for every boto3-known region. Confirm against BackendDict in
# moto.core.base_backend.
dynamodb_backends = BackendDict(
    DynamoDBBackend,
    "dynamodb_v20111205",
    use_boto3_regions=False,
    additional_regions=["global"],
)
|