342 lines
11 KiB
Python
342 lines
11 KiB
Python
from __future__ import unicode_literals
|
|
from collections import defaultdict
|
|
import datetime
|
|
import json
|
|
|
|
from moto.compat import OrderedDict
|
|
from moto.core import BaseBackend, BaseModel
|
|
from moto.core.utils import unix_time
|
|
from .comparisons import get_comparison_func
|
|
|
|
|
|
class DynamoJsonEncoder(json.JSONEncoder):
    """JSON encoder that serializes objects exposing a ``to_json`` method."""

    def default(self, obj):
        """
        Return a JSON-serializable stand-in for ``obj``.

        Objects that have a ``to_json`` attribute are replaced by the result
        of calling it. Anything else is handed to the base implementation,
        which raises ``TypeError``. Without that fallback this method would
        return ``None`` and the unknown object would be silently written out
        as ``null``.
        """
        if hasattr(obj, 'to_json'):
            return obj.to_json()
        # Explicit base-class call (rather than zero-argument super()) keeps
        # this working on both Python 2 and Python 3.
        return json.JSONEncoder.default(self, obj)
|
|
|
|
|
|
def dynamo_json_dump(dynamo_object):
    """Serialize ``dynamo_object`` to a JSON string using DynamoJsonEncoder."""
    # json.dumps(obj, cls=X) with no other options builds X() with its default
    # settings and calls encode(); do that directly.
    encoder = DynamoJsonEncoder()
    return encoder.encode(dynamo_object)
|
|
|
|
|
|
class DynamoType(object):
    """
    A single typed value, built from a one-entry ``{type: value}`` dict.

    Instances are hashable and compare by ``(type, value)``.

    http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DataModel.html#DataModelDataTypes
    """

    def __init__(self, type_as_dict):
        """
        Unpack ``type_as_dict`` into ``self.type`` and ``self.value``.

        Only the first key and the first value of the dict are used.
        """
        self.type = list(type_as_dict.keys())[0]
        self.value = list(type_as_dict.values())[0]

    def __hash__(self):
        # Must stay consistent with __eq__: equal (type, value) => equal hash.
        return hash((self.type, self.value))

    def __eq__(self, other):
        """
        Two DynamoTypes are equal when both type and value match.

        Returns ``NotImplemented`` for anything that is not a DynamoType so
        that ``DynamoType(...) == None`` evaluates to ``False`` instead of
        raising ``AttributeError``.
        """
        if not isinstance(other, DynamoType):
            return NotImplemented
        return (
            self.type == other.type and
            self.value == other.value
        )

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; define it explicitly so
        # both operators agree on every supported interpreter.
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return equal
        return not equal

    def __repr__(self):
        return "DynamoType: {0}".format(self.to_json())

    def to_json(self):
        """Return the value in its ``{type: value}`` dict form."""
        return {self.type: self.value}

    def compare(self, range_comparison, range_objs):
        """
        Compares this type against comparison filters

        ``range_comparison`` is passed to ``get_comparison_func`` to obtain
        the comparison callable. That callable is invoked with this object's
        value followed by the ``.value`` of every entry in ``range_objs``,
        and its result is returned.
        """
        range_values = [obj.value for obj in range_objs]
        comparison_func = get_comparison_func(range_comparison)
        return comparison_func(self.value, *range_values)
|
|
|
|
|
|
class Item(BaseModel):
    """One stored item: its key values and its attributes as DynamoTypes."""

    def __init__(self, hash_key, hash_key_type, range_key, range_key_type, attrs):
        """
        Record the key values/types and wrap every entry of ``attrs``.

        ``attrs`` maps attribute names to raw ``{type: value}`` dicts; each
        one is converted to a DynamoType and kept in ``self.attrs``.
        """
        self.hash_key, self.hash_key_type = hash_key, hash_key_type
        self.range_key, self.range_key_type = range_key, range_key_type
        self.attrs = dict(
            (name, DynamoType(raw)) for name, raw in attrs.items()
        )

    def __repr__(self):
        return "Item: {0}".format(self.to_json())

    def to_json(self):
        """Return ``{"Attributes": {...}}`` mapping names to bare values."""
        bare_values = dict(
            (name, typed.value) for name, typed in self.attrs.items()
        )
        return {"Attributes": bare_values}

    def describe_attrs(self, attributes):
        """
        Return ``{"Item": {...}}`` holding the requested attributes.

        When ``attributes`` is empty or None the full attribute dict is
        returned as-is; otherwise only the names contained in ``attributes``
        are included.
        """
        if not attributes:
            return {"Item": self.attrs}

        selected = dict(
            (name, typed) for name, typed in self.attrs.items()
            if name in attributes
        )
        return {"Item": selected}
|
|
|
|
|
|
class Table(BaseModel):
    """
    In-memory table addressed by a hash key and, optionally, a range key.

    ``self.items`` maps each hash key either directly to its Item (tables
    without a range key) or to a ``{range_key: Item}`` dict (tables with one).
    """

    def __init__(self, name, hash_key_attr, hash_key_type,
                 range_key_attr=None, range_key_type=None, read_capacity=None,
                 write_capacity=None):
        self.name = name
        self.hash_key_attr = hash_key_attr
        self.hash_key_type = hash_key_type
        self.range_key_attr = range_key_attr
        self.range_key_type = range_key_type
        self.read_capacity = read_capacity
        self.write_capacity = write_capacity
        self.created_at = datetime.datetime.utcnow()
        # defaultdict lets put_item write items[hash][range] without first
        # creating the inner dict. Read/delete paths must NOT index it
        # directly: indexing a missing key inserts an empty dict as a side
        # effect, so they use .get()/.pop() instead.
        self.items = defaultdict(dict)

    @property
    def has_range_key(self):
        """True when the table was created with a range key attribute."""
        return self.range_key_attr is not None

    @property
    def describe(self):
        """
        Build the ``{"Table": {...}}`` description dict.

        The status is always reported as ``"ACTIVE"`` and the size as 0
        bytes; ``RangeKeyElement`` is only present when the table has a
        range key.
        """
        results = {
            "Table": {
                "CreationDateTime": unix_time(self.created_at),
                "KeySchema": {
                    "HashKeyElement": {
                        "AttributeName": self.hash_key_attr,
                        "AttributeType": self.hash_key_type
                    },
                },
                "ProvisionedThroughput": {
                    "ReadCapacityUnits": self.read_capacity,
                    "WriteCapacityUnits": self.write_capacity
                },
                "TableName": self.name,
                "TableStatus": "ACTIVE",
                "ItemCount": len(self),
                "TableSizeBytes": 0,
            }
        }
        if self.has_range_key:
            results["Table"]["KeySchema"]["RangeKeyElement"] = {
                "AttributeName": self.range_key_attr,
                "AttributeType": self.range_key_type
            }
        return results

    @classmethod
    def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
        """
        Alternate constructor from a CloudFormation resource definition.

        Reads ``TableName``, the ``HASH`` entry of ``KeySchema`` and the
        matching ``AttributeDefinitions`` entry from
        ``cloudformation_json['Properties']``. ``resource_name`` and
        ``region_name`` are accepted but unused.
        """
        properties = cloudformation_json['Properties']
        key_attr = [i['AttributeName'] for i in properties['KeySchema'] if i['KeyType'] == 'HASH'][0]
        key_type = [i['AttributeType'] for i in properties['AttributeDefinitions'] if i['AttributeName'] == key_attr][0]
        spec = {
            'name': properties['TableName'],
            'hash_key_attr': key_attr,
            'hash_key_type': key_type
        }
        # TODO: optional properties still missing:
        # range_key_attr, range_key_type, read_capacity, write_capacity
        # Build through cls so subclasses get an instance of their own type.
        return cls(**spec)

    def __len__(self):
        """Number of stored items (summed over range keys when present)."""
        if self.has_range_key:
            return sum(len(bucket) for bucket in self.items.values())
        return len(self.items)

    def __nonzero__(self):
        # Defining __len__ would otherwise make an empty table falsy.
        return True

    def __bool__(self):
        return self.__nonzero__()

    def put_item(self, item_attrs):
        """
        Store (or overwrite) an item built from ``item_attrs`` and return it.

        ``item_attrs`` maps attribute names to raw ``{type: value}`` dicts
        and must contain the hash key attribute (and the range key attribute
        when the table has one).
        """
        hash_value = DynamoType(item_attrs.get(self.hash_key_attr))
        if self.has_range_key:
            range_value = DynamoType(item_attrs.get(self.range_key_attr))
        else:
            range_value = None

        item = Item(hash_value, self.hash_key_type, range_value,
                    self.range_key_type, item_attrs)

        if range_value:
            self.items[hash_value][range_value] = item
        else:
            self.items[hash_value] = item
        return item

    def get_item(self, hash_key, range_key):
        """
        Return the item stored under the given key(s), or None if absent.

        Raises ValueError when the table has a range key but ``range_key``
        is not supplied.
        """
        if self.has_range_key and not range_key:
            raise ValueError(
                "Table has a range key, but no range key was passed into get_item")

        # .get() rather than indexing: on a defaultdict, indexing a missing
        # hash key would never raise KeyError. It would insert and return an
        # empty dict, which on a hash-only table is then miscounted by
        # __len__ and yielded by all_items() as if it were an Item.
        stored = self.items.get(hash_key)
        if stored is None or not range_key:
            return stored
        try:
            return stored[range_key]
        except KeyError:
            return None

    def query(self, hash_key, range_comparison, range_objs):
        """
        Return ``(items, last_page)`` for the items stored under ``hash_key``.

        When ``range_comparison`` is given, only items whose range key
        satisfies it against ``range_objs`` are kept. ``items`` is always a
        list and ``last_page`` is always True.
        """
        last_page = True  # Once pagination is implemented, change this

        if self.range_key_attr:
            # Default to an empty bucket without inserting one for unknown
            # hash keys.
            possible_results = list(self.items.get(hash_key, {}).values())
        else:
            # A hash-only table holds at most one item per hash key; restrict
            # the result to that key instead of returning every item.
            match = self.items.get(hash_key)
            possible_results = [] if match is None else [match]

        if range_comparison:
            results = [
                result for result in possible_results
                if result.range_key.compare(range_comparison, range_objs)
            ]
        else:
            # If we're not filtering on range key, return all values
            results = possible_results
        return results, last_page

    def all_items(self):
        """Yield every stored item, flattening the per-hash-key range dicts."""
        for hash_set in self.items.values():
            if self.range_key_attr:
                for item in hash_set.values():
                    yield item
            else:
                yield hash_set

    def scan(self, filters):
        """
        Return ``(items, scanned_count, last_page)`` for a full table scan.

        ``filters`` maps an attribute name to a
        ``(comparison_operator, comparison_objs)`` pair. An item is kept only
        if it passes every filter; ``scanned_count`` counts every item
        examined and ``last_page`` is always True.
        """
        results = []
        scanned_count = 0
        last_page = True  # Once pagination is implemented, change this

        for result in self.all_items():
            scanned_count += 1
            passes_all_conditions = True
            for attribute_name, (comparison_operator, comparison_objs) in filters.items():
                attribute = result.attrs.get(attribute_name)

                if attribute:
                    # Attribute found
                    if not attribute.compare(comparison_operator, comparison_objs):
                        passes_all_conditions = False
                        break
                elif comparison_operator == 'NULL':
                    # Comparison is NULL and we don't have the attribute
                    continue
                else:
                    # No attribute found and comparison is not NULL. This
                    # item fails
                    passes_all_conditions = False
                    break

            if passes_all_conditions:
                results.append(result)

        return results, scanned_count, last_page

    def delete_item(self, hash_key, range_key):
        """Remove and return the item under the given key(s); None if absent."""
        if not range_key:
            return self.items.pop(hash_key, None)

        # .get() so that deleting from an unknown hash key does not leave an
        # empty bucket behind in the defaultdict.
        bucket = self.items.get(hash_key)
        if bucket is None:
            return None
        return bucket.pop(range_key, None)

    def get_cfn_attribute(self, attribute_name):
        """
        Resolve a CloudFormation ``Fn::GetAtt`` attribute for this table.

        Only ``StreamArn`` is supported; it is built from a fixed region,
        account id and timestamp. Any other name raises
        UnformattedGetAttTemplateException.
        """
        # Imported here rather than at module level (presumably to avoid a
        # circular import with moto.cloudformation -- TODO confirm).
        from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
        if attribute_name == 'StreamArn':
            region = 'us-east-1'
            time = '2000-01-01T00:00:00.000'
            return 'arn:aws:dynamodb:{0}:123456789012:table/{1}/stream/{2}'.format(region, self.name, time)
        raise UnformattedGetAttTemplateException()
|
|
|
|
|
|
class DynamoDBBackend(BaseBackend):
    """
    Holds tables by name and forwards item operations to the right table.

    Item-level methods return None (or a tuple of Nones matching the usual
    return shape) when the named table does not exist.
    """

    def __init__(self):
        self.tables = OrderedDict()

    def create_table(self, name, **params):
        """Create a Table from ``params``, register it under ``name``, return it."""
        new_table = Table(name, **params)
        self.tables[name] = new_table
        return new_table

    def delete_table(self, name):
        """Unregister and return the table called ``name`` (None if unknown)."""
        removed = self.tables.pop(name, None)
        return removed

    def update_table_throughput(self, name, new_read_units, new_write_units):
        """Set the read/write capacity of table ``name`` and return the table."""
        target = self.tables[name]
        target.read_capacity, target.write_capacity = new_read_units, new_write_units
        return target

    def put_item(self, table_name, item_attrs):
        """Store ``item_attrs`` in the named table and return the new item."""
        table = self.tables.get(table_name)
        return table.put_item(item_attrs) if table else None

    def get_item(self, table_name, hash_key_dict, range_key_dict):
        """Look up one item; the key dicts are raw ``{type: value}`` dicts."""
        table = self.tables.get(table_name)
        if table:
            hash_key = DynamoType(hash_key_dict)
            range_key = None
            if range_key_dict:
                range_key = DynamoType(range_key_dict)
            return table.get_item(hash_key, range_key)
        return None

    def query(self, table_name, hash_key_dict, range_comparison, range_value_dicts):
        """Run Table.query after wrapping the raw key/value dicts in DynamoTypes."""
        table = self.tables.get(table_name)
        if table:
            hash_key = DynamoType(hash_key_dict)
            range_values = list(map(DynamoType, range_value_dicts))
            return table.query(hash_key, range_comparison, range_values)
        return None, None

    def scan(self, table_name, filters):
        """
        Run Table.scan with every comparison value wrapped in a DynamoType.

        ``filters`` maps an attribute name to a
        ``(comparison_operator, comparison_values)`` pair.
        """
        table = self.tables.get(table_name)
        if table:
            scan_filters = dict(
                (name, (operator, [DynamoType(raw) for raw in raw_values]))
                for name, (operator, raw_values) in filters.items()
            )
            return table.scan(scan_filters)
        return None, None, None

    def delete_item(self, table_name, hash_key_dict, range_key_dict):
        """Delete one item; the key dicts are raw ``{type: value}`` dicts."""
        table = self.tables.get(table_name)
        if table:
            hash_key = DynamoType(hash_key_dict)
            range_key = None
            if range_key_dict:
                range_key = DynamoType(range_key_dict)
            return table.delete_item(hash_key, range_key)
        return None
|
|
|
|
|
|
# Module-level backend instance, created once when this module is imported.
dynamodb_backend = DynamoDBBackend()
|