400 lines
12 KiB
Python
400 lines
12 KiB
Python
import datetime
import json
from collections import OrderedDict
from collections import defaultdict
from decimal import Context, Decimal

from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel
from moto.core.utils import unix_time

from .comparisons import get_comparison_func
|
|
|
|
|
|
class DynamoJsonEncoder(json.JSONEncoder):
    """JSON encoder that serializes any object exposing a ``to_json()`` method."""

    def default(self, o):
        """Return ``o.to_json()`` when available; otherwise defer to the base class.

        Deferring to ``json.JSONEncoder.default`` raises ``TypeError`` for
        unsupported objects instead of silently emitting ``null``.
        """
        if hasattr(o, "to_json"):
            return o.to_json()
        return super().default(o)
|
|
|
|
|
|
def dynamo_json_dump(dynamo_object):
    """Serialize *dynamo_object* to a JSON string using ``DynamoJsonEncoder``.

    Objects that expose ``to_json()`` are expanded through that method.
    """
    encoder_cls = DynamoJsonEncoder
    serialized = json.dumps(dynamo_object, cls=encoder_cls)
    return serialized
|
|
|
|
|
|
class DynamoType(object):
    """
    A single typed DynamoDB attribute value, e.g. ``{"S": "foo"}`` or ``{"N": "42"}``.

    http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DataModel.html#DataModelDataTypes
    """

    # DynamoDB set types: an ADD update merges members instead of replacing the value.
    _SET_TYPES = ("SS", "NS", "BS")

    def __init__(self, type_as_dict):
        """Wrap a ``{type: value}`` dict.

        Only the first entry is used; an empty dict raises ``IndexError``.
        """
        self.type, self.value = list(type_as_dict.items())[0]

    def __hash__(self):
        # Hashable only for scalar values (S/N/B); set values are lists.
        return hash((self.type, self.value))

    def __eq__(self, other):
        # Let Python fall back to identity for foreign types instead of
        # raising AttributeError on ``other.type``.
        if not isinstance(other, DynamoType):
            return NotImplemented
        return self.type == other.type and self.value == other.value

    def __repr__(self):
        return f"DynamoType: {self.to_json()}"

    def add(self, dyn_type):
        """Apply an ``ADD`` update of *dyn_type* to this value in place.

        * Set types (SS/NS/BS): members of *dyn_type* not already present are
          appended, preserving order. Membership is by string equality, so for
          NS "1" and "1.0" count as different members.
        * Numbers (N): the value becomes the string form of the sum. Integers
          are added exactly; other numbers are added as ``Decimal`` with 38
          significant digits (DynamoDB's number precision).
        Any other type is left unchanged.
        """
        if self.type in self._SET_TYPES:
            additions = dyn_type.value
            if isinstance(additions, str):
                # Tolerate a bare scalar instead of a list of set members.
                additions = [additions]
            merged = list(self.value)
            for element in additions:
                if element not in merged:
                    merged.append(element)
            # Rebind rather than mutate: self.value may alias a caller's list.
            self.value = merged
        elif self.type == "N":
            try:
                total = int(self.value) + int(dyn_type.value)
            except ValueError:
                total = Context(prec=38).add(
                    Decimal(self.value), Decimal(dyn_type.value)
                )
            self.value = str(total)

    def to_json(self):
        """Return the value in its wire form, ``{type: value}``."""
        return {self.type: self.value}

    def compare(self, range_comparison, range_objs):
        """
        Compares this type against comparison filters
        """
        range_values = [obj.value for obj in range_objs]
        comparison_func = get_comparison_func(range_comparison)
        return comparison_func(self.value, *range_values)
|
|
|
|
|
|
class Item(BaseModel):
    """A stored item: its key values plus every attribute wrapped as a DynamoType."""

    def __init__(self, hash_key, hash_key_type, range_key, range_key_type, attrs):
        self.hash_key = hash_key
        self.hash_key_type = hash_key_type
        self.range_key = range_key
        self.range_key_type = range_key_type
        # Attribute name -> DynamoType, in the iteration order of ``attrs``.
        self.attrs = {name: DynamoType(raw) for name, raw in attrs.items()}

    def __repr__(self):
        return f"Item: {self.to_json()}"

    def to_json(self):
        """Return ``{"Attributes": {name: value}}`` with the type tags stripped."""
        plain_values = {name: typed.value for name, typed in self.attrs.items()}
        return {"Attributes": plain_values}

    def describe_attrs(self, attributes):
        """Return ``{"Item": ...}`` holding the attributes named in *attributes*.

        When *attributes* is empty or None, the item's own ``attrs`` mapping is
        returned as-is (every attribute, not a copy).
        """
        if not attributes:
            return {"Item": self.attrs}
        selected = {
            name: typed for name, typed in self.attrs.items() if name in attributes
        }
        return {"Item": selected}
|
|
|
|
|
|
class Table(CloudFormationModel):
    """In-memory table for the legacy (2011-12-05) DynamoDB API.

    ``items`` maps a hash-key DynamoType either to an Item (hash-only tables)
    or to a dict of range-key DynamoType -> Item (hash-and-range tables).
    """

    def __init__(
        self,
        account_id,
        name,
        hash_key_attr,
        hash_key_type,
        range_key_attr=None,
        range_key_type=None,
        read_capacity=None,
        write_capacity=None,
        region_name="us-east-1",
    ):
        self.account_id = account_id
        self.name = name
        self.hash_key_attr = hash_key_attr
        self.hash_key_type = hash_key_type
        self.range_key_attr = range_key_attr
        self.range_key_type = range_key_type
        self.read_capacity = read_capacity
        self.write_capacity = write_capacity
        # Only used when building ARNs (see get_cfn_attribute).
        self.region_name = region_name
        # Naive UTC timestamp; converted with unix_time() in describe.
        self.created_at = datetime.datetime.utcnow()
        # A defaultdict lets put_item create range buckets on demand. Reads
        # below use .get()/pop(..., None) so a lookup miss never inserts a
        # phantom empty bucket (which would count as an item and break scan).
        self.items = defaultdict(dict)

    @property
    def has_range_key(self):
        return self.range_key_attr is not None

    @property
    def describe(self):
        """Return the DescribeTable-style payload for this table."""
        results = {
            "Table": {
                "CreationDateTime": unix_time(self.created_at),
                "KeySchema": {
                    "HashKeyElement": {
                        "AttributeName": self.hash_key_attr,
                        "AttributeType": self.hash_key_type,
                    }
                },
                "ProvisionedThroughput": {
                    "ReadCapacityUnits": self.read_capacity,
                    "WriteCapacityUnits": self.write_capacity,
                },
                "TableName": self.name,
                "TableStatus": "ACTIVE",
                "ItemCount": len(self),
                "TableSizeBytes": 0,
            }
        }
        if self.has_range_key:
            results["Table"]["KeySchema"]["RangeKeyElement"] = {
                "AttributeName": self.range_key_attr,
                "AttributeType": self.range_key_type,
            }
        return results

    @staticmethod
    def cloudformation_name_type():
        return "TableName"

    @staticmethod
    def cloudformation_type():
        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-dynamodb-table.html
        return "AWS::DynamoDB::Table"

    @classmethod
    def create_from_cloudformation_json(
        cls, resource_name, cloudformation_json, account_id, region_name, **kwargs
    ):
        """Build a Table from an ``AWS::DynamoDB::Table`` resource definition.

        The HASH key is required (IndexError if absent). An optional RANGE key
        and ``ProvisionedThroughput`` capacities are applied when present.
        """
        properties = cloudformation_json["Properties"]
        key_schema = properties["KeySchema"]
        definitions = properties["AttributeDefinitions"]

        def attribute_type(attr_name):
            # First matching AttributeDefinitions entry; IndexError if undefined.
            return [
                d["AttributeType"] for d in definitions if d["AttributeName"] == attr_name
            ][0]

        key_attr = [
            k["AttributeName"] for k in key_schema if k["KeyType"] == "HASH"
        ][0]
        spec = {
            "account_id": account_id,
            "name": properties["TableName"],
            "hash_key_attr": key_attr,
            "hash_key_type": attribute_type(key_attr),
            "region_name": region_name,
        }

        range_attr = next(
            (k["AttributeName"] for k in key_schema if k["KeyType"] == "RANGE"), None
        )
        if range_attr is not None:
            spec["range_key_attr"] = range_attr
            spec["range_key_type"] = attribute_type(range_attr)

        # Capacities are passed through as given in the template.
        throughput = properties.get("ProvisionedThroughput") or {}
        spec["read_capacity"] = throughput.get("ReadCapacityUnits")
        spec["write_capacity"] = throughput.get("WriteCapacityUnits")

        return cls(**spec)

    def __len__(self):
        """Number of stored items (range buckets counted by their size)."""
        return sum(
            len(bucket) if self.has_range_key else 1 for bucket in self.items.values()
        )

    def __nonzero__(self):
        # Tables are always truthy, even when empty (overrides __len__-based truth).
        return True

    def __bool__(self):
        return self.__nonzero__()

    def put_item(self, item_attrs):
        """Store (or overwrite) an item built from raw typed *item_attrs*; return it."""
        hash_value = DynamoType(item_attrs.get(self.hash_key_attr))
        if self.has_range_key:
            range_value = DynamoType(item_attrs.get(self.range_key_attr))
        else:
            range_value = None

        item = Item(
            hash_value, self.hash_key_type, range_value, self.range_key_type, item_attrs
        )

        if range_value is not None:
            self.items[hash_value][range_value] = item
        else:
            self.items[hash_value] = item
        return item

    def get_item(self, hash_key, range_key):
        """Return the stored item for the given keys, or None if absent.

        Raises ValueError when the table has a range key but none is supplied.
        """
        if self.has_range_key and not range_key:
            raise ValueError(
                "Table has a range key, but no range key was passed into get_item"
            )
        # .get() so a miss does not insert an empty bucket into the defaultdict.
        bucket = self.items.get(hash_key)
        if bucket is None:
            return None
        if range_key:
            try:
                return bucket[range_key]
            except KeyError:
                return None
        return bucket

    def query(self, hash_key, range_comparison, range_objs):
        """Return ``(items, last_page)`` for *hash_key*, filtered on the range key.

        Without *range_comparison* every candidate is returned.
        """
        last_page = True  # Once pagination is implemented, change this

        if self.range_key_attr:
            possible_results = list(self.items.get(hash_key, {}).values())
        else:
            # NOTE(review): hash-only tables return every item regardless of
            # hash_key; kept as-is because callers may rely on it.
            possible_results = list(self.all_items())

        if range_comparison:
            results = [
                result
                for result in possible_results
                if result.range_key.compare(range_comparison, range_objs)
            ]
        else:
            # If we're not filtering on range key, return all values
            results = possible_results
        return results, last_page

    def all_items(self):
        """Yield every stored Item."""
        for hash_set in self.items.values():
            if self.range_key_attr:
                yield from hash_set.values()
            else:
                yield hash_set

    def scan(self, filters):
        """Return ``(matching_items, scanned_count, last_page)`` for a full scan.

        *filters* maps attribute name -> ``(comparison_operator, [DynamoType, ...])``.
        An item matches when every filter passes; a missing attribute only
        passes a ``"NULL"`` filter.
        """
        results = []
        scanned_count = 0
        last_page = True  # Once pagination is implemented, change this

        for result in self.all_items():
            scanned_count += 1
            passes_all_conditions = True
            for (
                attribute_name,
                (comparison_operator, comparison_objs),
            ) in filters.items():
                attribute = result.attrs.get(attribute_name)

                if attribute is not None:
                    # Attribute found
                    if not attribute.compare(comparison_operator, comparison_objs):
                        passes_all_conditions = False
                        break
                elif comparison_operator == "NULL":
                    # Comparison is NULL and we don't have the attribute
                    continue
                else:
                    # No attribute found and comparison is not NULL: item fails
                    passes_all_conditions = False
                    break

            if passes_all_conditions:
                results.append(result)

        return results, scanned_count, last_page

    def delete_item(self, hash_key, range_key):
        """Remove and return the item for the given keys, or None if absent."""
        if range_key:
            # .get() so a miss does not insert an empty bucket into the defaultdict.
            bucket = self.items.get(hash_key)
            if bucket is None:
                return None
            try:
                return bucket.pop(range_key)
            except KeyError:
                return None
        return self.items.pop(hash_key, None)

    def update_item(self, hash_key, range_key, attr_updates):
        """Apply legacy ``AttributeUpdates`` to an existing item and return it.

        *attr_updates* maps attribute name -> ``{"Action": ..., "Value": ...}``.
        Returns None if the item does not exist.
        """
        item = self.get_item(hash_key, range_key)
        if not item:
            return None

        for attr, update in attr_updates.items():
            # The 2011-12-05 API treats an omitted Action as PUT.
            action = update.get("Action", "PUT")
            if action == "PUT":
                item.attrs[attr] = DynamoType(update["Value"])
            elif action == "DELETE":
                # NOTE: removes the whole attribute even when a Value is given;
                # deleting an absent attribute is a no-op.
                item.attrs.pop(attr, None)
            elif action == "ADD":
                new_value = DynamoType(update["Value"])
                if attr in item.attrs:
                    item.attrs[attr].add(new_value)
                else:
                    # ADD on a missing attribute creates it with the given value.
                    item.attrs[attr] = new_value
        return item

    @classmethod
    def has_cfn_attr(cls, attr):
        return attr in ["StreamArn"]

    def get_cfn_attribute(self, attribute_name):
        """Resolve a CloudFormation ``Fn::GetAtt`` attribute for this table."""
        from moto.cloudformation.exceptions import UnformattedGetAttTemplateException

        if attribute_name == "StreamArn":
            region = self.region_name
            time = "2000-01-01T00:00:00.000"
            return f"arn:aws:dynamodb:{region}:{self.account_id}:table/{self.name}/stream/{time}"
        raise UnformattedGetAttTemplateException()
|
|
|
|
|
|
class DynamoDBBackend(BaseBackend):
    """Registry of legacy-API tables, keyed by table name in insertion order."""

    def __init__(self, region_name, account_id):
        super().__init__(region_name, account_id)
        self.tables = OrderedDict()

    @staticmethod
    def _optional_key(key_dict):
        """Wrap *key_dict* as a DynamoType, or return None when it is falsy."""
        return DynamoType(key_dict) if key_dict else None

    def create_table(self, name, **params):
        """Create (or replace) table *name* and return it."""
        new_table = Table(self.account_id, name, **params)
        self.tables[name] = new_table
        return new_table

    def delete_table(self, name):
        """Remove table *name*; return it, or None when it did not exist."""
        return self.tables.pop(name, None)

    def update_table_throughput(self, name, new_read_units, new_write_units):
        """Set provisioned capacity on table *name* and return it (KeyError if absent)."""
        target = self.tables[name]
        target.read_capacity, target.write_capacity = new_read_units, new_write_units
        return target

    def put_item(self, table_name, item_attrs):
        """Store an item in *table_name*; None when the table does not exist."""
        table = self.tables.get(table_name)
        return table.put_item(item_attrs) if table else None

    def get_item(self, table_name, hash_key_dict, range_key_dict):
        """Fetch one item by its raw typed keys; None when the table is missing."""
        table = self.tables.get(table_name)
        if table is None:
            return None
        return table.get_item(
            DynamoType(hash_key_dict), self._optional_key(range_key_dict)
        )

    def query(self, table_name, hash_key_dict, range_comparison, range_value_dicts):
        """Query *table_name*; ``(None, None)`` when the table is missing."""
        table = self.tables.get(table_name)
        if table is None:
            return None, None
        hash_key = DynamoType(hash_key_dict)
        range_values = list(map(DynamoType, range_value_dicts))
        return table.query(hash_key, range_comparison, range_values)

    def scan(self, table_name, filters):
        """Scan *table_name*; ``(None, None, None)`` when the table is missing."""
        table = self.tables.get(table_name)
        if table is None:
            return None, None, None
        typed_filters = {
            attr_name: (operator, [DynamoType(raw) for raw in raw_values])
            for attr_name, (operator, raw_values) in filters.items()
        }
        return table.scan(typed_filters)

    def delete_item(self, table_name, hash_key_dict, range_key_dict):
        """Delete one item by its raw typed keys; None when the table is missing."""
        table = self.tables.get(table_name)
        if table is None:
            return None
        return table.delete_item(
            DynamoType(hash_key_dict), self._optional_key(range_key_dict)
        )

    def update_item(self, table_name, hash_key_dict, range_key_dict, attr_updates):
        """Apply attribute updates to one item; None when the table is missing."""
        table = self.tables.get(table_name)
        if table is None:
            return None
        return table.update_item(
            DynamoType(hash_key_dict), self._optional_key(range_key_dict), attr_updates
        )
|
|
|
|
|
|
# Backends for the legacy 2011-12-05 DynamoDB API.
# NOTE(review): with use_boto3_regions=False the backends are presumably keyed
# only by the extra "global" region rather than every boto3 region — confirm
# against BackendDict.
dynamodb_backends = BackendDict(
    DynamoDBBackend,
    "dynamodb_v20111205",
    additional_regions=["global"],
    use_boto3_regions=False,
)
|