402 lines
13 KiB
Python

import json
from collections import OrderedDict, defaultdict
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from moto.core.base_backend import BackendDict, BaseBackend
from moto.core.common_models import BaseModel
from moto.core.utils import unix_time, utcnow
from .comparisons import get_comparison_func
class DynamoJsonEncoder(json.JSONEncoder):
    """JSON encoder for model objects that know how to serialise themselves.

    Any object exposing a ``to_json()`` method is encoded as whatever that
    method returns. Every other non-JSON-native object is encoded as
    ``null`` rather than raising ``TypeError``.
    """

    def default(self, o: Any) -> Optional[str]:
        if not hasattr(o, "to_json"):
            # Deliberately lenient: unknown objects become JSON null.
            return None
        return o.to_json()
def dynamo_json_dump(dynamo_object: Any) -> str:
    """Serialise ``dynamo_object`` to a JSON string via DynamoJsonEncoder.

    Uses the encoder's default settings, which are the same ones
    ``json.dumps(obj, cls=DynamoJsonEncoder)`` would pass.
    """
    encoder = DynamoJsonEncoder()
    return encoder.encode(dynamo_object)
class DynamoType(object):
    """
    A single typed DynamoDB attribute value such as ``{"S": "abc"}`` or ``{"N": "1"}``.

    http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DataModel.html#DataModelDataTypes
    """

    def __init__(self, type_as_dict: Dict[str, Any]):
        # Only the first key/value pair of the dict is used.
        self.type = list(type_as_dict.keys())[0]
        self.value = list(type_as_dict.values())[0]

    def __hash__(self) -> int:
        # Works for scalar values; set types hold lists and are unhashable.
        return hash((self.type, self.value))

    def __eq__(self, other: Any) -> bool:
        try:
            return self.type == other.type and self.value == other.value
        except AttributeError:
            # Not a DynamoType-like object (e.g. None): let Python fall back
            # instead of crashing, so ``DynamoType(...) == None`` is False.
            return NotImplemented

    def __repr__(self) -> str:
        return f"DynamoType: {self.to_json()}"

    def add(self, dyn_type: "DynamoType") -> None:
        """
        Apply an UpdateItem ``ADD`` of ``dyn_type`` to this value, in place.

        * Set types (SS/NS/BS): each member of ``dyn_type.value`` that is not
          already present is appended. A single scalar member is also accepted.
        * ``N``: the two numbers are summed. Whole numbers use exact integer
          arithmetic. Fractional/exponent numbers use ``decimal`` with 38
          significant digits (DynamoDB's number precision).
        * Any other type is left unchanged.

        Raises:
            ValueError: an ``N`` operand is not a valid number.
        """
        if self.type in ("SS", "NS", "BS"):
            additions = dyn_type.value
            if not isinstance(additions, list):
                additions = [additions]
            for member in additions:
                # Sets never hold duplicates.
                if member not in self.value:
                    self.value.append(member)
        elif self.type == "N":
            try:
                self.value = str(int(self.value) + int(dyn_type.value))
            except ValueError:
                import decimal

                try:
                    left = decimal.Decimal(self.value)
                    right = decimal.Decimal(dyn_type.value)
                except decimal.InvalidOperation as err:
                    raise ValueError(
                        f"Cannot ADD non-numeric values {self.value!r} and {dyn_type.value!r}"
                    ) from err
                self.value = str(decimal.Context(prec=38).add(left, right))

    def to_json(self) -> Dict[str, Any]:
        """Return the ``{type: value}`` wire form of this attribute."""
        return {self.type: self.value}

    def compare(self, range_comparison: str, range_objs: List["DynamoType"]) -> Any:
        """
        Compares this type against comparison filters.

        ``range_comparison`` names the operator passed to
        ``get_comparison_func``. The raw values of ``range_objs`` are its
        extra operands.
        """
        range_values = [obj.value for obj in range_objs]
        comparison_func = get_comparison_func(range_comparison)
        return comparison_func(self.value, *range_values)
class Item(BaseModel):
    """A stored item: its key values plus every attribute wrapped in a DynamoType."""

    def __init__(
        self,
        hash_key: DynamoType,
        hash_key_type: str,
        range_key: Optional[DynamoType],
        range_key_type: Optional[str],
        attrs: Dict[str, Any],
    ):
        self.hash_key = hash_key
        self.hash_key_type = hash_key_type
        self.range_key = range_key
        self.range_key_type = range_key_type
        # Each raw ``{"<type>": <value>}`` attribute becomes a DynamoType.
        self.attrs = {name: DynamoType(raw) for name, raw in attrs.items()}

    def __repr__(self) -> str:
        return f"Item: {self.to_json()}"

    def to_json(self) -> Dict[str, Any]:
        """Return ``{"Attributes": {name: value}}`` with the type tags stripped."""
        plain_values = {name: dyn.value for name, dyn in self.attrs.items()}
        return {"Attributes": plain_values}

    def describe_attrs(self, attributes: List[str]) -> Dict[str, Any]:
        """
        Return ``{"Item": ...}`` mapping attribute names to DynamoTypes.

        A falsy ``attributes`` (None or empty) selects every attribute. In that
        case the item's own ``attrs`` dict is returned, not a copy.
        """
        if not attributes:
            return {"Item": self.attrs}
        selected = {
            name: dyn for name, dyn in self.attrs.items() if name in attributes
        }
        return {"Item": selected}
class Table(BaseModel):
    """
    In-memory model of a DynamoDB table (2011-12-05 API).

    ``self.items`` is keyed by hash key. For a hash-only table each value is
    an ``Item``. For a hash-and-range table each value is a dict mapping
    range key -> ``Item``.
    """

    def __init__(
        self,
        account_id: str,
        name: str,
        hash_key_attr: str,
        hash_key_type: str,
        range_key_attr: Optional[str] = None,
        range_key_type: Optional[str] = None,
        read_capacity: Optional[str] = None,
        write_capacity: Optional[str] = None,
    ):
        self.account_id = account_id
        self.name = name
        self.hash_key_attr = hash_key_attr
        self.hash_key_type = hash_key_type
        self.range_key_attr = range_key_attr
        self.range_key_type = range_key_type
        self.read_capacity = read_capacity
        self.write_capacity = write_capacity
        self.created_at = utcnow()
        # defaultdict(dict) lets put_item create a range-key bucket on first
        # write. Read/delete paths use .get()/.pop() rather than
        # ``self.items[key]`` so a lookup of an unknown hash key never inserts
        # an empty bucket. On a hash-only table that bucket would otherwise be
        # counted by __len__ and yielded by all_items() as a bogus item.
        self.items: Dict[DynamoType, Union[Item, Dict[DynamoType, Item]]] = defaultdict(
            dict
        )

    @property
    def has_range_key(self) -> bool:
        """True when the key schema includes a range key."""
        return self.range_key_attr is not None

    @property
    def describe(self) -> Dict[str, Any]:  # type: ignore[misc]
        """DescribeTable-style payload for this table."""
        results: Dict[str, Any] = {
            "Table": {
                "CreationDateTime": unix_time(self.created_at),
                "KeySchema": {
                    "HashKeyElement": {
                        "AttributeName": self.hash_key_attr,
                        "AttributeType": self.hash_key_type,
                    }
                },
                "ProvisionedThroughput": {
                    "ReadCapacityUnits": self.read_capacity,
                    "WriteCapacityUnits": self.write_capacity,
                },
                "TableName": self.name,
                "TableStatus": "ACTIVE",
                "ItemCount": len(self),
                "TableSizeBytes": 0,
            }
        }
        if self.has_range_key:
            results["Table"]["KeySchema"]["RangeKeyElement"] = {
                "AttributeName": self.range_key_attr,
                "AttributeType": self.range_key_type,
            }
        return results

    def __len__(self) -> int:
        """Number of stored items (items in range-key buckets counted individually)."""
        return sum(
            (len(value) if self.has_range_key else 1) for value in self.items.values()  # type: ignore
        )

    def __nonzero__(self) -> bool:
        # Python 2 spelling kept for compatibility. A table is always truthy,
        # even when empty; otherwise __len__ would make it falsy.
        return True

    def __bool__(self) -> bool:
        return self.__nonzero__()

    def put_item(self, item_attrs: Dict[str, Any]) -> Item:
        """Store (or overwrite) the item described by raw ``item_attrs`` and return it."""
        hash_value = DynamoType(item_attrs.get(self.hash_key_attr))  # type: ignore[arg-type]
        range_value: Optional[DynamoType] = None
        if self.has_range_key:
            range_value = DynamoType(item_attrs.get(self.range_key_attr))  # type: ignore[arg-type]
        item = Item(
            hash_value, self.hash_key_type, range_value, self.range_key_type, item_attrs
        )
        if range_value is not None:
            # The defaultdict creates the bucket on the first write for this hash key.
            self.items[hash_value][range_value] = item  # type: ignore[index]
        else:
            self.items[hash_value] = item
        return item

    def get_item(
        self, hash_key: DynamoType, range_key: Optional[DynamoType]
    ) -> Optional[Item]:
        """
        Return the item stored under the given key, or None if there is none.

        Raises:
            ValueError: the table has a range key but ``range_key`` is missing.
        """
        if self.has_range_key and not range_key:
            raise ValueError(
                "Table has a range key, but no range key was passed into get_item"
            )
        if range_key:
            try:
                return self.items.get(hash_key, {})[range_key]  # type: ignore
            except KeyError:
                return None
        return self.items.get(hash_key)  # type: ignore

    def query(
        self, hash_key: DynamoType, range_comparison: str, range_objs: Any
    ) -> Tuple[Iterable[Item], bool]:
        """
        Return ``(items, last_page)``: items under ``hash_key`` whose range key matches.

        If ``range_comparison`` is empty, every candidate is returned.
        Pagination is not implemented, so ``last_page`` is always True.
        """
        last_page = True  # Once pagination is implemented, change this
        if self.range_key_attr:
            # .get() so querying an unknown hash key does not create a bucket.
            possible_results = list(self.items.get(hash_key, {}).values())  # type: ignore[union-attr]
        else:
            possible_results = list(self.all_items())
        if not range_comparison:
            # If we're not filtering on range key, return all values
            return possible_results, last_page
        results = [
            result
            for result in possible_results
            if result.range_key.compare(range_comparison, range_objs)  # type: ignore[union-attr]
        ]
        return results, last_page

    def all_items(self) -> Iterable[Item]:
        """Yield every stored item, flattening range-key buckets."""
        for hash_set in self.items.values():
            if self.range_key_attr:
                for item in hash_set.values():  # type: ignore
                    yield item
            else:
                yield hash_set  # type: ignore[misc]

    def scan(self, filters: Dict[str, Any]) -> Tuple[List[Item], int, bool]:
        """
        Return ``(matches, scanned_count, last_page)`` for a full-table scan.

        ``filters`` maps attribute name -> ``(comparison_operator,
        comparison_objs)``. An item matches only if it passes every filter.
        A missing attribute passes only the ``NULL`` operator.
        """
        results = []
        scanned_count = 0
        last_page = True  # Once pagination is implemented, change this
        for result in self.all_items():
            scanned_count += 1
            passes_all_conditions = True
            for (
                attribute_name,
                (comparison_operator, comparison_objs),
            ) in filters.items():
                attribute = result.attrs.get(attribute_name)
                if attribute:
                    # Attribute found
                    if not attribute.compare(comparison_operator, comparison_objs):
                        passes_all_conditions = False
                        break
                elif comparison_operator == "NULL":
                    # Comparison is NULL and we don't have the attribute
                    continue
                else:
                    # No attribute found and comparison is not NULL. This item
                    # fails
                    passes_all_conditions = False
                    break
            if passes_all_conditions:
                results.append(result)
        return results, scanned_count, last_page

    def delete_item(
        self, hash_key: DynamoType, range_key: Optional[DynamoType]
    ) -> Optional[Item]:
        """
        Remove and return what is stored under the key, or None if absent.

        Without ``range_key`` the whole hash-key entry is removed. On a
        range-key table that entry is the dict of all items under ``hash_key``.
        """
        if range_key:
            # .get() so deleting from an unknown hash key does not create a bucket.
            return self.items.get(hash_key, {}).pop(range_key, None)  # type: ignore
        return self.items.pop(hash_key, None)  # type: ignore

    def update_item(
        self,
        hash_key: DynamoType,
        range_key: Optional[DynamoType],
        attr_updates: Dict[str, Any],
    ) -> Optional[Item]:
        """
        Apply ``attr_updates`` (attribute name -> {"Action", "Value"}) to an item.

        Returns the updated item, or None if no item exists under the key.
        ``PUT`` replaces the attribute. ``DELETE`` removes it and is a no-op
        if it is absent. ``ADD`` combines with the existing value via
        ``DynamoType.add``, or creates the attribute when it is absent.
        """
        item = self.get_item(hash_key, range_key)
        if not item:
            return None
        for attr, update in attr_updates.items():
            action = update["Action"]
            if action == "PUT":
                item.attrs[attr] = DynamoType(update["Value"])
            elif action == "DELETE":
                # NOTE(review): DELETE with a Value (remove set members) is not
                # modelled; the whole attribute is removed.
                item.attrs.pop(attr, None)
            elif action == "ADD":
                if attr in item.attrs:
                    item.attrs[attr].add(DynamoType(update["Value"]))
                else:
                    item.attrs[attr] = DynamoType(update["Value"])
        return item
class DynamoDBBackend(BaseBackend):
    """Backend holding the tables of the 2011-12-05 DynamoDB API, keyed by table name.

    Table-level operations return None (or a tuple of Nones) when the named
    table does not exist. The exception is ``update_table_throughput``, which
    raises KeyError.
    """

    def __init__(self, region_name: str, account_id: str):
        super().__init__(region_name, account_id)
        # Insertion-ordered mapping of table name -> Table.
        self.tables: Dict[str, Table] = OrderedDict()

    @staticmethod
    def _build_keys(
        hash_key_dict: Dict[str, Any], range_key_dict: Optional[Dict[str, Any]]
    ) -> Tuple[DynamoType, Optional[DynamoType]]:
        """Wrap raw key dicts as DynamoTypes. A falsy range dict means "no range key"."""
        hash_key = DynamoType(hash_key_dict)
        if not range_key_dict:
            return hash_key, None
        return hash_key, DynamoType(range_key_dict)

    def create_table(self, name: str, **params: Any) -> Table:
        new_table = Table(self.account_id, name, **params)
        self.tables[name] = new_table
        return new_table

    def delete_table(self, name: str) -> Optional[Table]:
        return self.tables.pop(name, None)

    def update_table_throughput(
        self, name: str, new_read_units: str, new_write_units: str
    ) -> Table:
        target = self.tables[name]
        target.read_capacity, target.write_capacity = new_read_units, new_write_units
        return target

    def put_item(self, table_name: str, item_attrs: Dict[str, Any]) -> Optional[Item]:
        table = self.tables.get(table_name)
        return table.put_item(item_attrs) if table else None

    def get_item(
        self,
        table_name: str,
        hash_key_dict: Dict[str, Any],
        range_key_dict: Optional[Dict[str, Any]],
    ) -> Optional[Item]:
        table = self.tables.get(table_name)
        if not table:
            return None
        hash_key, range_key = self._build_keys(hash_key_dict, range_key_dict)
        return table.get_item(hash_key, range_key)

    def query(
        self,
        table_name: str,
        hash_key_dict: Dict[str, Any],
        range_comparison: str,
        range_value_dicts: List[Dict[str, Any]],
    ) -> Tuple[Optional[Iterable[Item]], Optional[bool]]:
        table = self.tables.get(table_name)
        if not table:
            return None, None
        return table.query(
            DynamoType(hash_key_dict),
            range_comparison,
            [DynamoType(raw_value) for raw_value in range_value_dicts],
        )

    def scan(
        self, table_name: str, filters: Dict[str, Any]
    ) -> Tuple[Optional[List[Item]], Optional[int], Optional[bool]]:
        table = self.tables.get(table_name)
        if not table:
            return None, None, None
        # Convert each filter's raw comparison values into DynamoTypes.
        scan_filters = {
            attr_name: (operator, [DynamoType(raw) for raw in raw_values])
            for attr_name, (operator, raw_values) in filters.items()
        }
        return table.scan(scan_filters)

    def delete_item(
        self,
        table_name: str,
        hash_key_dict: Dict[str, Any],
        range_key_dict: Optional[Dict[str, Any]],
    ) -> Optional[Item]:
        table = self.tables.get(table_name)
        if not table:
            return None
        hash_key, range_key = self._build_keys(hash_key_dict, range_key_dict)
        return table.delete_item(hash_key, range_key)

    def update_item(
        self,
        table_name: str,
        hash_key_dict: Dict[str, Any],
        range_key_dict: Optional[Dict[str, Any]],
        attr_updates: Dict[str, Any],
    ) -> Optional[Item]:
        table = self.tables.get(table_name)
        if not table:
            return None
        hash_key, range_key = self._build_keys(hash_key_dict, range_key_dict)
        return table.update_item(hash_key, range_key, attr_updates)
# Module-level BackendDict of DynamoDBBackend instances registered under the
# service name "dynamodb_v20111205"; the name suggests the legacy 2011-12-05
# DynamoDB API.
# NOTE(review): use_boto3_regions=False presumably means regions are not taken
# from boto3's endpoint data, leaving only the explicit "global" region from
# additional_regions. Confirm against moto's BackendDict.
dynamodb_backends = BackendDict(
    DynamoDBBackend,
    "dynamodb_v20111205",
    use_boto3_regions=False,
    additional_regions=["global"],
)