The AWS API represents a set object as a list of values. Internally moto also represents a set as a list. This means that when we do value comparisons, the order of the values can cause a set equality test to fail.
1011 lines
39 KiB
Python
1011 lines
39 KiB
Python
from __future__ import unicode_literals
|
|
from collections import defaultdict
|
|
import copy
|
|
import datetime
|
|
import decimal
|
|
import json
|
|
import re
|
|
import uuid
|
|
|
|
import boto3
|
|
from moto.compat import OrderedDict
|
|
from moto.core import BaseBackend, BaseModel
|
|
from moto.core.utils import unix_time
|
|
from moto.core.exceptions import JsonRESTError
|
|
from .comparisons import get_comparison_func, get_filter_expression, Op
|
|
|
|
|
|
class DynamoJsonEncoder(json.JSONEncoder):
|
|
|
|
def default(self, obj):
|
|
if hasattr(obj, 'to_json'):
|
|
return obj.to_json()
|
|
|
|
|
|
def dynamo_json_dump(dynamo_object):
|
|
return json.dumps(dynamo_object, cls=DynamoJsonEncoder)
|
|
|
|
|
|
class DynamoType(object):
|
|
"""
|
|
http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DataModel.html#DataModelDataTypes
|
|
"""
|
|
|
|
def __init__(self, type_as_dict):
|
|
self.type = list(type_as_dict)[0]
|
|
self.value = list(type_as_dict.values())[0]
|
|
|
|
def __hash__(self):
|
|
return hash((self.type, self.value))
|
|
|
|
def __eq__(self, other):
|
|
return (
|
|
self.type == other.type and
|
|
self.value == other.value
|
|
)
|
|
|
|
def __lt__(self, other):
|
|
return self.value < other.value
|
|
|
|
def __le__(self, other):
|
|
return self.value <= other.value
|
|
|
|
def __gt__(self, other):
|
|
return self.value > other.value
|
|
|
|
def __ge__(self, other):
|
|
return self.value >= other.value
|
|
|
|
def __repr__(self):
|
|
return "DynamoType: {0}".format(self.to_json())
|
|
|
|
@property
|
|
def cast_value(self):
|
|
if self.is_number():
|
|
try:
|
|
return int(self.value)
|
|
except ValueError:
|
|
return float(self.value)
|
|
elif self.is_set():
|
|
return set(self.value)
|
|
else:
|
|
return self.value
|
|
|
|
def to_json(self):
|
|
return {self.type: self.value}
|
|
|
|
def compare(self, range_comparison, range_objs):
|
|
"""
|
|
Compares this type against comparison filters
|
|
"""
|
|
range_values = [obj.cast_value for obj in range_objs]
|
|
comparison_func = get_comparison_func(range_comparison)
|
|
return comparison_func(self.cast_value, *range_values)
|
|
|
|
def is_number(self):
|
|
return self.type == 'N'
|
|
|
|
def is_set(self):
|
|
return self.type == 'SS' or self.type == 'NS' or self.type == 'BS'
|
|
|
|
def same_type(self, other):
|
|
return self.type == other.type
|
|
|
|
|
|
class Item(BaseModel):
|
|
|
|
def __init__(self, hash_key, hash_key_type, range_key, range_key_type, attrs):
|
|
self.hash_key = hash_key
|
|
self.hash_key_type = hash_key_type
|
|
self.range_key = range_key
|
|
self.range_key_type = range_key_type
|
|
|
|
self.attrs = {}
|
|
for key, value in attrs.items():
|
|
self.attrs[key] = DynamoType(value)
|
|
|
|
def __repr__(self):
|
|
return "Item: {0}".format(self.to_json())
|
|
|
|
def to_json(self):
|
|
attributes = {}
|
|
for attribute_key, attribute in self.attrs.items():
|
|
attributes[attribute_key] = {
|
|
attribute.type: attribute.value
|
|
}
|
|
|
|
return {
|
|
"Attributes": attributes
|
|
}
|
|
|
|
def describe_attrs(self, attributes):
|
|
if attributes:
|
|
included = {}
|
|
for key, value in self.attrs.items():
|
|
if key in attributes:
|
|
included[key] = value
|
|
else:
|
|
included = self.attrs
|
|
return {
|
|
"Item": included
|
|
}
|
|
|
|
def update(self, update_expression, expression_attribute_names, expression_attribute_values):
|
|
# Update subexpressions are identifiable by the operator keyword, so split on that and
|
|
# get rid of the empty leading string.
|
|
parts = [p for p in re.split(r'\b(SET|REMOVE|ADD|DELETE)\b', update_expression, flags=re.I) if p]
|
|
# make sure that we correctly found only operator/value pairs
|
|
assert len(parts) % 2 == 0, "Mismatched operators and values in update expression: '{}'".format(update_expression)
|
|
for action, valstr in zip(parts[:-1:2], parts[1::2]):
|
|
action = action.upper()
|
|
|
|
# "Should" retain arguments in side (...)
|
|
values = re.split(r',(?![^(]*\))', valstr)
|
|
for value in values:
|
|
# A Real value
|
|
value = value.lstrip(":").rstrip(",").strip()
|
|
for k, v in expression_attribute_names.items():
|
|
value = re.sub(r'{0}\b'.format(k), v, value)
|
|
|
|
if action == "REMOVE":
|
|
self.attrs.pop(value, None)
|
|
elif action == 'SET':
|
|
key, value = value.split("=", 1)
|
|
key = key.strip()
|
|
value = value.strip()
|
|
|
|
# If not exists, changes value to a default if needed, else its the same as it was
|
|
if value.startswith('if_not_exists'):
|
|
# Function signature
|
|
match = re.match(r'.*if_not_exists\s*\((?P<path>.+),\s*(?P<default>.+)\).*', value)
|
|
if not match:
|
|
raise TypeError
|
|
|
|
path, value = match.groups()
|
|
|
|
# If it already exists, get its value so we dont overwrite it
|
|
if path in self.attrs:
|
|
value = self.attrs[path]
|
|
|
|
if type(value) != DynamoType:
|
|
if value in expression_attribute_values:
|
|
value = DynamoType(expression_attribute_values[value])
|
|
else:
|
|
value = DynamoType({"S": value})
|
|
|
|
if '.' not in key:
|
|
self.attrs[key] = value
|
|
else:
|
|
# Handle nested dict updates
|
|
key_parts = key.split('.')
|
|
attr = key_parts.pop(0)
|
|
if attr not in self.attrs:
|
|
raise ValueError
|
|
|
|
last_val = self.attrs[attr].value
|
|
for key_part in key_parts:
|
|
# Hack but it'll do, traverses into a dict
|
|
last_val_type = list(last_val.keys())
|
|
if last_val_type and last_val_type[0] == 'M':
|
|
last_val = last_val['M']
|
|
|
|
if key_part not in last_val:
|
|
last_val[key_part] = {'M': {}}
|
|
|
|
last_val = last_val[key_part]
|
|
|
|
# We have reference to a nested object but we cant just assign to it
|
|
current_type = list(last_val.keys())[0]
|
|
if current_type == value.type:
|
|
last_val[current_type] = value.value
|
|
else:
|
|
last_val[value.type] = value.value
|
|
del last_val[current_type]
|
|
|
|
elif action == 'ADD':
|
|
key, value = value.split(" ", 1)
|
|
key = key.strip()
|
|
value_str = value.strip()
|
|
if value_str in expression_attribute_values:
|
|
dyn_value = DynamoType(expression_attribute_values[value])
|
|
else:
|
|
raise TypeError
|
|
|
|
# Handle adding numbers - value gets added to existing value,
|
|
# or added to 0 if it doesn't exist yet
|
|
if dyn_value.is_number():
|
|
existing = self.attrs.get(key, DynamoType({"N": '0'}))
|
|
if not existing.same_type(dyn_value):
|
|
raise TypeError()
|
|
self.attrs[key] = DynamoType({"N": str(
|
|
decimal.Decimal(existing.value) +
|
|
decimal.Decimal(dyn_value.value)
|
|
)})
|
|
|
|
# Handle adding sets - value is added to the set, or set is
|
|
# created with only this value if it doesn't exist yet
|
|
# New value must be of same set type as previous value
|
|
elif dyn_value.is_set():
|
|
existing = self.attrs.get(key, DynamoType({dyn_value.type: {}}))
|
|
if not existing.same_type(dyn_value):
|
|
raise TypeError()
|
|
new_set = set(existing.value).union(dyn_value.value)
|
|
self.attrs[key] = DynamoType({existing.type: list(new_set)})
|
|
else: # Number and Sets are the only supported types for ADD
|
|
raise TypeError
|
|
|
|
elif action == 'DELETE':
|
|
key, value = value.split(" ", 1)
|
|
key = key.strip()
|
|
value_str = value.strip()
|
|
if value_str in expression_attribute_values:
|
|
dyn_value = DynamoType(expression_attribute_values[value])
|
|
else:
|
|
raise TypeError
|
|
|
|
if not dyn_value.is_set():
|
|
raise TypeError
|
|
existing = self.attrs.get(key, None)
|
|
if existing:
|
|
if not existing.same_type(dyn_value):
|
|
raise TypeError
|
|
new_set = set(existing.value).difference(dyn_value.value)
|
|
self.attrs[key] = DynamoType({existing.type: list(new_set)})
|
|
else:
|
|
raise NotImplementedError('{} update action not yet supported'.format(action))
|
|
|
|
def update_with_attribute_updates(self, attribute_updates):
|
|
for attribute_name, update_action in attribute_updates.items():
|
|
action = update_action['Action']
|
|
if action == 'DELETE' and 'Value' not in update_action:
|
|
if attribute_name in self.attrs:
|
|
del self.attrs[attribute_name]
|
|
continue
|
|
new_value = list(update_action['Value'].values())[0]
|
|
if action == 'PUT':
|
|
# TODO deal with other types
|
|
if isinstance(new_value, list) or isinstance(new_value, set):
|
|
self.attrs[attribute_name] = DynamoType({"SS": new_value})
|
|
elif isinstance(new_value, dict):
|
|
self.attrs[attribute_name] = DynamoType({"M": new_value})
|
|
elif set(update_action['Value'].keys()) == set(['N']):
|
|
self.attrs[attribute_name] = DynamoType({"N": new_value})
|
|
elif set(update_action['Value'].keys()) == set(['NULL']):
|
|
if attribute_name in self.attrs:
|
|
del self.attrs[attribute_name]
|
|
else:
|
|
self.attrs[attribute_name] = DynamoType({"S": new_value})
|
|
elif action == 'ADD':
|
|
if set(update_action['Value'].keys()) == set(['N']):
|
|
existing = self.attrs.get(
|
|
attribute_name, DynamoType({"N": '0'}))
|
|
self.attrs[attribute_name] = DynamoType({"N": str(
|
|
decimal.Decimal(existing.value) +
|
|
decimal.Decimal(new_value)
|
|
)})
|
|
elif set(update_action['Value'].keys()) == set(['SS']):
|
|
existing = self.attrs.get(attribute_name, DynamoType({"SS": {}}))
|
|
new_set = set(existing.value).union(set(new_value))
|
|
self.attrs[attribute_name] = DynamoType({
|
|
"SS": list(new_set)
|
|
})
|
|
else:
|
|
# TODO: implement other data types
|
|
raise NotImplementedError(
|
|
'ADD not supported for %s' % ', '.join(update_action['Value'].keys()))
|
|
|
|
|
|
class StreamRecord(BaseModel):
|
|
def __init__(self, table, stream_type, event_name, old, new, seq):
|
|
old_a = old.to_json()['Attributes'] if old is not None else {}
|
|
new_a = new.to_json()['Attributes'] if new is not None else {}
|
|
|
|
rec = old if old is not None else new
|
|
keys = {table.hash_key_attr: rec.hash_key.to_json()}
|
|
if table.range_key_attr is not None:
|
|
keys[table.range_key_attr] = rec.range_key.to_json()
|
|
|
|
self.record = {
|
|
'eventID': uuid.uuid4().hex,
|
|
'eventName': event_name,
|
|
'eventSource': 'aws:dynamodb',
|
|
'eventVersion': '1.0',
|
|
'awsRegion': 'us-east-1',
|
|
'dynamodb': {
|
|
'StreamViewType': stream_type,
|
|
'ApproximateCreationDateTime': datetime.datetime.utcnow().isoformat(),
|
|
'SequenceNumber': seq,
|
|
'SizeBytes': 1,
|
|
'Keys': keys
|
|
}
|
|
}
|
|
|
|
if stream_type in ('NEW_IMAGE', 'NEW_AND_OLD_IMAGES'):
|
|
self.record['dynamodb']['NewImage'] = new_a
|
|
if stream_type in ('OLD_IMAGE', 'NEW_AND_OLD_IMAGES'):
|
|
self.record['dynamodb']['OldImage'] = old_a
|
|
|
|
# This is a substantial overestimate but it's the easiest to do now
|
|
self.record['dynamodb']['SizeBytes'] = len(
|
|
json.dumps(self.record['dynamodb']))
|
|
|
|
def to_json(self):
|
|
return self.record
|
|
|
|
|
|
class StreamShard(BaseModel):
|
|
def __init__(self, table):
|
|
self.table = table
|
|
self.id = 'shardId-00000001541626099285-f35f62ef'
|
|
self.starting_sequence_number = 1100000000017454423009
|
|
self.items = []
|
|
self.created_on = datetime.datetime.utcnow()
|
|
|
|
def to_json(self):
|
|
return {
|
|
'ShardId': self.id,
|
|
'SequenceNumberRange': {
|
|
'StartingSequenceNumber': str(self.starting_sequence_number)
|
|
}
|
|
}
|
|
|
|
def add(self, old, new):
|
|
t = self.table.stream_specification['StreamViewType']
|
|
if old is None:
|
|
event_name = 'INSERT'
|
|
elif new is None:
|
|
event_name = 'DELETE'
|
|
else:
|
|
event_name = 'MODIFY'
|
|
seq = len(self.items) + self.starting_sequence_number
|
|
self.items.append(
|
|
StreamRecord(self.table, t, event_name, old, new, seq))
|
|
|
|
def get(self, start, quantity):
|
|
start -= self.starting_sequence_number
|
|
assert start >= 0
|
|
end = start + quantity
|
|
return [i.to_json() for i in self.items[start:end]]
|
|
|
|
|
|
class Table(BaseModel):
|
|
|
|
def __init__(self, table_name, schema=None, attr=None, throughput=None, indexes=None, global_indexes=None, streams=None):
|
|
self.name = table_name
|
|
self.attr = attr
|
|
self.schema = schema
|
|
self.range_key_attr = None
|
|
self.hash_key_attr = None
|
|
self.range_key_type = None
|
|
self.hash_key_type = None
|
|
for elem in schema:
|
|
if elem["KeyType"] == "HASH":
|
|
self.hash_key_attr = elem["AttributeName"]
|
|
self.hash_key_type = elem["KeyType"]
|
|
else:
|
|
self.range_key_attr = elem["AttributeName"]
|
|
self.range_key_type = elem["KeyType"]
|
|
if throughput is None:
|
|
self.throughput = {
|
|
'WriteCapacityUnits': 10, 'ReadCapacityUnits': 10}
|
|
else:
|
|
self.throughput = throughput
|
|
self.throughput["NumberOfDecreasesToday"] = 0
|
|
self.indexes = indexes
|
|
self.global_indexes = global_indexes if global_indexes else []
|
|
self.created_at = datetime.datetime.utcnow()
|
|
self.items = defaultdict(dict)
|
|
self.table_arn = self._generate_arn(table_name)
|
|
self.tags = []
|
|
self.ttl = {
|
|
'TimeToLiveStatus': 'DISABLED' # One of 'ENABLING'|'DISABLING'|'ENABLED'|'DISABLED',
|
|
# 'AttributeName': 'string' # Can contain this
|
|
}
|
|
self.set_stream_specification(streams)
|
|
|
|
def _generate_arn(self, name):
|
|
return 'arn:aws:dynamodb:us-east-1:123456789011:table/' + name
|
|
|
|
def set_stream_specification(self, streams):
|
|
self.stream_specification = streams
|
|
if streams and (streams.get('StreamEnabled') or streams.get('StreamViewType')):
|
|
self.stream_specification['StreamEnabled'] = True
|
|
self.latest_stream_label = datetime.datetime.utcnow().isoformat()
|
|
self.stream_shard = StreamShard(self)
|
|
else:
|
|
self.stream_specification = {'StreamEnabled': False}
|
|
self.latest_stream_label = None
|
|
self.stream_shard = None
|
|
|
|
def describe(self, base_key='TableDescription'):
|
|
results = {
|
|
base_key: {
|
|
'AttributeDefinitions': self.attr,
|
|
'ProvisionedThroughput': self.throughput,
|
|
'TableSizeBytes': 0,
|
|
'TableName': self.name,
|
|
'TableStatus': 'ACTIVE',
|
|
'TableArn': self.table_arn,
|
|
'KeySchema': self.schema,
|
|
'ItemCount': len(self),
|
|
'CreationDateTime': unix_time(self.created_at),
|
|
'GlobalSecondaryIndexes': [index for index in self.global_indexes],
|
|
'LocalSecondaryIndexes': [index for index in self.indexes],
|
|
}
|
|
}
|
|
if self.stream_specification and self.stream_specification['StreamEnabled']:
|
|
results[base_key]['StreamSpecification'] = self.stream_specification
|
|
if self.latest_stream_label:
|
|
results[base_key]['LatestStreamLabel'] = self.latest_stream_label
|
|
results[base_key]['LatestStreamArn'] = self.table_arn + '/stream/' + self.latest_stream_label
|
|
return results
|
|
|
|
def __len__(self):
|
|
count = 0
|
|
for key, value in self.items.items():
|
|
if self.has_range_key:
|
|
count += len(value)
|
|
else:
|
|
count += 1
|
|
return count
|
|
|
|
@property
|
|
def hash_key_names(self):
|
|
keys = [self.hash_key_attr]
|
|
for index in self.global_indexes:
|
|
hash_key = None
|
|
for key in index['KeySchema']:
|
|
if key['KeyType'] == 'HASH':
|
|
hash_key = key['AttributeName']
|
|
keys.append(hash_key)
|
|
return keys
|
|
|
|
@property
|
|
def range_key_names(self):
|
|
keys = [self.range_key_attr]
|
|
for index in self.global_indexes:
|
|
range_key = None
|
|
for key in index['KeySchema']:
|
|
if key['KeyType'] == 'RANGE':
|
|
range_key = keys.append(key['AttributeName'])
|
|
keys.append(range_key)
|
|
return keys
|
|
|
|
def put_item(self, item_attrs, expected=None, overwrite=False):
|
|
hash_value = DynamoType(item_attrs.get(self.hash_key_attr))
|
|
if self.has_range_key:
|
|
range_value = DynamoType(item_attrs.get(self.range_key_attr))
|
|
else:
|
|
range_value = None
|
|
|
|
if expected is None:
|
|
expected = {}
|
|
lookup_range_value = range_value
|
|
else:
|
|
expected_range_value = expected.get(
|
|
self.range_key_attr, {}).get("Value")
|
|
if(expected_range_value is None):
|
|
lookup_range_value = range_value
|
|
else:
|
|
lookup_range_value = DynamoType(expected_range_value)
|
|
current = self.get_item(hash_value, lookup_range_value)
|
|
|
|
item = Item(hash_value, self.hash_key_type, range_value,
|
|
self.range_key_type, item_attrs)
|
|
|
|
if not overwrite:
|
|
if current is None:
|
|
current_attr = {}
|
|
elif hasattr(current, 'attrs'):
|
|
current_attr = current.attrs
|
|
else:
|
|
current_attr = current
|
|
|
|
for key, val in expected.items():
|
|
if 'Exists' in val and val['Exists'] is False \
|
|
or 'ComparisonOperator' in val and val['ComparisonOperator'] == 'NULL':
|
|
if key in current_attr:
|
|
raise ValueError("The conditional request failed")
|
|
elif key not in current_attr:
|
|
raise ValueError("The conditional request failed")
|
|
elif 'Value' in val and DynamoType(val['Value']).value != current_attr[key].value:
|
|
raise ValueError("The conditional request failed")
|
|
elif 'ComparisonOperator' in val:
|
|
dynamo_types = [
|
|
DynamoType(ele) for ele in
|
|
val.get("AttributeValueList", [])
|
|
]
|
|
if not current_attr[key].compare(val['ComparisonOperator'], dynamo_types):
|
|
raise ValueError('The conditional request failed')
|
|
if range_value:
|
|
self.items[hash_value][range_value] = item
|
|
else:
|
|
self.items[hash_value] = item
|
|
|
|
if self.stream_shard is not None:
|
|
self.stream_shard.add(current, item)
|
|
|
|
return item
|
|
|
|
def __nonzero__(self):
|
|
return True
|
|
|
|
def __bool__(self):
|
|
return self.__nonzero__()
|
|
|
|
@property
|
|
def has_range_key(self):
|
|
return self.range_key_attr is not None
|
|
|
|
def get_item(self, hash_key, range_key=None):
|
|
if self.has_range_key and not range_key:
|
|
raise ValueError(
|
|
"Table has a range key, but no range key was passed into get_item")
|
|
try:
|
|
if range_key:
|
|
return self.items[hash_key][range_key]
|
|
|
|
if hash_key in self.items:
|
|
return self.items[hash_key]
|
|
|
|
raise KeyError
|
|
except KeyError:
|
|
return None
|
|
|
|
def delete_item(self, hash_key, range_key):
|
|
try:
|
|
if range_key:
|
|
item = self.items[hash_key].pop(range_key)
|
|
else:
|
|
item = self.items.pop(hash_key)
|
|
|
|
if self.stream_shard is not None:
|
|
self.stream_shard.add(item, None)
|
|
|
|
return item
|
|
except KeyError:
|
|
return None
|
|
|
|
def query(self, hash_key, range_comparison, range_objs, limit,
|
|
exclusive_start_key, scan_index_forward, projection_expression,
|
|
index_name=None, filter_expression=None, **filter_kwargs):
|
|
results = []
|
|
if index_name:
|
|
all_indexes = (self.global_indexes or []) + (self.indexes or [])
|
|
indexes_by_name = dict((i['IndexName'], i) for i in all_indexes)
|
|
if index_name not in indexes_by_name:
|
|
raise ValueError('Invalid index: %s for table: %s. Available indexes are: %s' % (
|
|
index_name, self.name, ', '.join(indexes_by_name.keys())
|
|
))
|
|
|
|
index = indexes_by_name[index_name]
|
|
try:
|
|
index_hash_key = [key for key in index[
|
|
'KeySchema'] if key['KeyType'] == 'HASH'][0]
|
|
except IndexError:
|
|
raise ValueError('Missing Hash Key. KeySchema: %s' %
|
|
index['KeySchema'])
|
|
|
|
possible_results = []
|
|
for item in self.all_items():
|
|
if not isinstance(item, Item):
|
|
continue
|
|
item_hash_key = item.attrs.get(index_hash_key['AttributeName'])
|
|
if item_hash_key and item_hash_key == hash_key:
|
|
possible_results.append(item)
|
|
else:
|
|
possible_results = [item for item in list(self.all_items()) if isinstance(
|
|
item, Item) and item.hash_key == hash_key]
|
|
|
|
if index_name:
|
|
try:
|
|
index_range_key = [key for key in index[
|
|
'KeySchema'] if key['KeyType'] == 'RANGE'][0]
|
|
except IndexError:
|
|
index_range_key = None
|
|
|
|
if range_comparison:
|
|
if index_name and not index_range_key:
|
|
raise ValueError(
|
|
'Range Key comparison but no range key found for index: %s' % index_name)
|
|
|
|
elif index_name:
|
|
for result in possible_results:
|
|
if result.attrs.get(index_range_key['AttributeName']).compare(range_comparison, range_objs):
|
|
results.append(result)
|
|
else:
|
|
for result in possible_results:
|
|
if result.range_key.compare(range_comparison, range_objs):
|
|
results.append(result)
|
|
|
|
if filter_kwargs:
|
|
for result in possible_results:
|
|
for field, value in filter_kwargs.items():
|
|
dynamo_types = [DynamoType(ele) for ele in value[
|
|
"AttributeValueList"]]
|
|
if result.attrs.get(field).compare(value['ComparisonOperator'], dynamo_types):
|
|
results.append(result)
|
|
|
|
if not range_comparison and not filter_kwargs:
|
|
# If we're not filtering on range key or on an index return all
|
|
# values
|
|
results = possible_results
|
|
|
|
if index_name:
|
|
|
|
if index_range_key:
|
|
results.sort(key=lambda item: item.attrs[index_range_key['AttributeName']].value
|
|
if item.attrs.get(index_range_key['AttributeName']) else None)
|
|
else:
|
|
results.sort(key=lambda item: item.range_key)
|
|
|
|
if scan_index_forward is False:
|
|
results.reverse()
|
|
|
|
scanned_count = len(list(self.all_items()))
|
|
|
|
if filter_expression is not None:
|
|
results = [item for item in results if filter_expression.expr(item)]
|
|
|
|
if projection_expression:
|
|
expressions = [x.strip() for x in projection_expression.split(',')]
|
|
results = copy.deepcopy(results)
|
|
for result in results:
|
|
for attr in list(result.attrs):
|
|
if attr not in expressions:
|
|
result.attrs.pop(attr)
|
|
|
|
results, last_evaluated_key = self._trim_results(results, limit,
|
|
exclusive_start_key)
|
|
return results, scanned_count, last_evaluated_key
|
|
|
|
def all_items(self):
|
|
for hash_set in self.items.values():
|
|
if self.range_key_attr:
|
|
for item in hash_set.values():
|
|
yield item
|
|
else:
|
|
yield hash_set
|
|
|
|
def scan(self, filters, limit, exclusive_start_key, filter_expression=None):
|
|
results = []
|
|
scanned_count = 0
|
|
|
|
for item in self.all_items():
|
|
scanned_count += 1
|
|
passes_all_conditions = True
|
|
for attribute_name, (comparison_operator, comparison_objs) in filters.items():
|
|
attribute = item.attrs.get(attribute_name)
|
|
|
|
if attribute:
|
|
# Attribute found
|
|
if not attribute.compare(comparison_operator, comparison_objs):
|
|
passes_all_conditions = False
|
|
break
|
|
elif comparison_operator == 'NULL':
|
|
# Comparison is NULL and we don't have the attribute
|
|
continue
|
|
else:
|
|
# No attribute found and comparison is no NULL. This item
|
|
# fails
|
|
passes_all_conditions = False
|
|
break
|
|
|
|
if filter_expression is not None:
|
|
passes_all_conditions &= filter_expression.expr(item)
|
|
|
|
if passes_all_conditions:
|
|
results.append(item)
|
|
|
|
results, last_evaluated_key = self._trim_results(results, limit,
|
|
exclusive_start_key)
|
|
return results, scanned_count, last_evaluated_key
|
|
|
|
def _trim_results(self, results, limit, exclusive_start_key):
|
|
if exclusive_start_key is not None:
|
|
hash_key = DynamoType(exclusive_start_key.get(self.hash_key_attr))
|
|
range_key = exclusive_start_key.get(self.range_key_attr)
|
|
if range_key is not None:
|
|
range_key = DynamoType(range_key)
|
|
for i in range(len(results)):
|
|
if results[i].hash_key == hash_key and results[i].range_key == range_key:
|
|
results = results[i + 1:]
|
|
break
|
|
|
|
last_evaluated_key = None
|
|
if limit and len(results) > limit:
|
|
results = results[:limit]
|
|
last_evaluated_key = {
|
|
self.hash_key_attr: results[-1].hash_key
|
|
}
|
|
if results[-1].range_key is not None:
|
|
last_evaluated_key[self.range_key_attr] = results[-1].range_key
|
|
|
|
return results, last_evaluated_key
|
|
|
|
def lookup(self, *args, **kwargs):
|
|
if not self.schema:
|
|
self.describe()
|
|
for x, arg in enumerate(args):
|
|
kwargs[self.schema[x].name] = arg
|
|
ret = self.get_item(**kwargs)
|
|
if not ret.keys():
|
|
return None
|
|
return ret
|
|
|
|
|
|
class DynamoDBBackend(BaseBackend):
|
|
|
|
def __init__(self, region_name=None):
|
|
self.region_name = region_name
|
|
self.tables = OrderedDict()
|
|
|
|
def reset(self):
|
|
region_name = self.region_name
|
|
|
|
self.__dict__ = {}
|
|
self.__init__(region_name)
|
|
|
|
def create_table(self, name, **params):
|
|
if name in self.tables:
|
|
return None
|
|
table = Table(name, **params)
|
|
self.tables[name] = table
|
|
return table
|
|
|
|
def delete_table(self, name):
|
|
return self.tables.pop(name, None)
|
|
|
|
def tag_resource(self, table_arn, tags):
|
|
for table in self.tables:
|
|
if self.tables[table].table_arn == table_arn:
|
|
self.tables[table].tags.extend(tags)
|
|
|
|
def untag_resource(self, table_arn, tag_keys):
|
|
for table in self.tables:
|
|
if self.tables[table].table_arn == table_arn:
|
|
self.tables[table].tags = [tag for tag in self.tables[table].tags if tag['Key'] not in tag_keys]
|
|
|
|
def list_tags_of_resource(self, table_arn):
|
|
required_table = None
|
|
for table in self.tables:
|
|
if self.tables[table].table_arn == table_arn:
|
|
required_table = self.tables[table]
|
|
return required_table.tags
|
|
|
|
def update_table_throughput(self, name, throughput):
|
|
table = self.tables[name]
|
|
table.throughput = throughput
|
|
return table
|
|
|
|
def update_table_streams(self, name, stream_specification):
|
|
table = self.tables[name]
|
|
if (stream_specification.get('StreamEnabled') or stream_specification.get('StreamViewType')) and table.latest_stream_label:
|
|
raise ValueError('Table already has stream enabled')
|
|
table.set_stream_specification(stream_specification)
|
|
return table
|
|
|
|
def update_table_global_indexes(self, name, global_index_updates):
|
|
table = self.tables[name]
|
|
gsis_by_name = dict((i['IndexName'], i) for i in table.global_indexes)
|
|
for gsi_update in global_index_updates:
|
|
gsi_to_create = gsi_update.get('Create')
|
|
gsi_to_update = gsi_update.get('Update')
|
|
gsi_to_delete = gsi_update.get('Delete')
|
|
|
|
if gsi_to_delete:
|
|
index_name = gsi_to_delete['IndexName']
|
|
if index_name not in gsis_by_name:
|
|
raise ValueError('Global Secondary Index does not exist, but tried to delete: %s' %
|
|
gsi_to_delete['IndexName'])
|
|
|
|
del gsis_by_name[index_name]
|
|
|
|
if gsi_to_update:
|
|
index_name = gsi_to_update['IndexName']
|
|
if index_name not in gsis_by_name:
|
|
raise ValueError('Global Secondary Index does not exist, but tried to update: %s' %
|
|
gsi_to_update['IndexName'])
|
|
gsis_by_name[index_name].update(gsi_to_update)
|
|
|
|
if gsi_to_create:
|
|
if gsi_to_create['IndexName'] in gsis_by_name:
|
|
raise ValueError(
|
|
'Global Secondary Index already exists: %s' % gsi_to_create['IndexName'])
|
|
|
|
gsis_by_name[gsi_to_create['IndexName']] = gsi_to_create
|
|
|
|
# in python 3.6, dict.values() returns a dict_values object, but we expect it to be a list in other
|
|
# parts of the codebase
|
|
table.global_indexes = list(gsis_by_name.values())
|
|
return table
|
|
|
|
def put_item(self, table_name, item_attrs, expected=None, overwrite=False):
|
|
table = self.tables.get(table_name)
|
|
if not table:
|
|
return None
|
|
return table.put_item(item_attrs, expected, overwrite)
|
|
|
|
def get_table_keys_name(self, table_name, keys):
|
|
"""
|
|
Given a set of keys, extracts the key and range key
|
|
"""
|
|
table = self.tables.get(table_name)
|
|
if not table:
|
|
return None, None
|
|
else:
|
|
if len(keys) == 1:
|
|
for key in keys:
|
|
if key in table.hash_key_names:
|
|
return key, None
|
|
# for potential_hash, potential_range in zip(table.hash_key_names, table.range_key_names):
|
|
# if set([potential_hash, potential_range]) == set(keys):
|
|
# return potential_hash, potential_range
|
|
potential_hash, potential_range = None, None
|
|
for key in set(keys):
|
|
if key in table.hash_key_names:
|
|
potential_hash = key
|
|
elif key in table.range_key_names:
|
|
potential_range = key
|
|
return potential_hash, potential_range
|
|
|
|
def get_keys_value(self, table, keys):
|
|
if table.hash_key_attr not in keys or (table.has_range_key and table.range_key_attr not in keys):
|
|
raise ValueError(
|
|
"Table has a range key, but no range key was passed into get_item")
|
|
hash_key = DynamoType(keys[table.hash_key_attr])
|
|
range_key = DynamoType(
|
|
keys[table.range_key_attr]) if table.has_range_key else None
|
|
return hash_key, range_key
|
|
|
|
def get_table(self, table_name):
|
|
return self.tables.get(table_name)
|
|
|
|
def get_item(self, table_name, keys):
|
|
table = self.get_table(table_name)
|
|
if not table:
|
|
raise ValueError("No table found")
|
|
hash_key, range_key = self.get_keys_value(table, keys)
|
|
return table.get_item(hash_key, range_key)
|
|
|
|
def query(self, table_name, hash_key_dict, range_comparison, range_value_dicts,
|
|
limit, exclusive_start_key, scan_index_forward, projection_expression, index_name=None,
|
|
expr_names=None, expr_values=None, filter_expression=None,
|
|
**filter_kwargs):
|
|
table = self.tables.get(table_name)
|
|
if not table:
|
|
return None, None
|
|
|
|
hash_key = DynamoType(hash_key_dict)
|
|
range_values = [DynamoType(range_value)
|
|
for range_value in range_value_dicts]
|
|
|
|
if filter_expression is not None:
|
|
filter_expression = get_filter_expression(filter_expression, expr_names, expr_values)
|
|
else:
|
|
filter_expression = Op(None, None) # Will always eval to true
|
|
|
|
return table.query(hash_key, range_comparison, range_values, limit,
|
|
exclusive_start_key, scan_index_forward, projection_expression, index_name, filter_expression, **filter_kwargs)
|
|
|
|
def scan(self, table_name, filters, limit, exclusive_start_key, filter_expression, expr_names, expr_values):
|
|
table = self.tables.get(table_name)
|
|
if not table:
|
|
return None, None, None
|
|
|
|
scan_filters = {}
|
|
for key, (comparison_operator, comparison_values) in filters.items():
|
|
dynamo_types = [DynamoType(value) for value in comparison_values]
|
|
scan_filters[key] = (comparison_operator, dynamo_types)
|
|
|
|
if filter_expression is not None:
|
|
filter_expression = get_filter_expression(filter_expression, expr_names, expr_values)
|
|
else:
|
|
filter_expression = Op(None, None) # Will always eval to true
|
|
|
|
return table.scan(scan_filters, limit, exclusive_start_key, filter_expression)
|
|
|
|
def update_item(self, table_name, key, update_expression, attribute_updates, expression_attribute_names,
|
|
expression_attribute_values, expected=None):
|
|
table = self.get_table(table_name)
|
|
|
|
if all([table.hash_key_attr in key, table.range_key_attr in key]):
|
|
# Covers cases where table has hash and range keys, ``key`` param
|
|
# will be a dict
|
|
hash_value = DynamoType(key[table.hash_key_attr])
|
|
range_value = DynamoType(key[table.range_key_attr])
|
|
elif table.hash_key_attr in key:
|
|
# Covers tables that have a range key where ``key`` param is a dict
|
|
hash_value = DynamoType(key[table.hash_key_attr])
|
|
range_value = None
|
|
else:
|
|
# Covers other cases
|
|
hash_value = DynamoType(key)
|
|
range_value = None
|
|
|
|
item = table.get_item(hash_value, range_value)
|
|
|
|
if item is None:
|
|
item_attr = {}
|
|
elif hasattr(item, 'attrs'):
|
|
item_attr = item.attrs
|
|
else:
|
|
item_attr = item
|
|
|
|
if not expected:
|
|
expected = {}
|
|
|
|
for key, val in expected.items():
|
|
if 'Exists' in val and val['Exists'] is False \
|
|
or 'ComparisonOperator' in val and val['ComparisonOperator'] == 'NULL':
|
|
if key in item_attr:
|
|
raise ValueError("The conditional request failed")
|
|
elif key not in item_attr:
|
|
raise ValueError("The conditional request failed")
|
|
elif 'Value' in val and DynamoType(val['Value']).value != item_attr[key].value:
|
|
raise ValueError("The conditional request failed")
|
|
elif 'ComparisonOperator' in val:
|
|
dynamo_types = [
|
|
DynamoType(ele) for ele in
|
|
val.get("AttributeValueList", [])
|
|
]
|
|
if not item_attr[key].compare(val['ComparisonOperator'], dynamo_types):
|
|
raise ValueError('The conditional request failed')
|
|
|
|
# Update does not fail on new items, so create one
|
|
if item is None:
|
|
data = {
|
|
table.hash_key_attr: {
|
|
hash_value.type: hash_value.value,
|
|
},
|
|
}
|
|
if range_value:
|
|
data.update({
|
|
table.range_key_attr: {
|
|
range_value.type: range_value.value,
|
|
}
|
|
})
|
|
|
|
table.put_item(data)
|
|
item = table.get_item(hash_value, range_value)
|
|
|
|
if update_expression:
|
|
item.update(update_expression, expression_attribute_names,
|
|
expression_attribute_values)
|
|
else:
|
|
item.update_with_attribute_updates(attribute_updates)
|
|
return item
|
|
|
|
def delete_item(self, table_name, keys):
|
|
table = self.get_table(table_name)
|
|
if not table:
|
|
return None
|
|
hash_key, range_key = self.get_keys_value(table, keys)
|
|
return table.delete_item(hash_key, range_key)
|
|
|
|
def update_ttl(self, table_name, ttl_spec):
|
|
table = self.tables.get(table_name)
|
|
if table is None:
|
|
raise JsonRESTError('ResourceNotFound', 'Table not found')
|
|
|
|
if 'Enabled' not in ttl_spec or 'AttributeName' not in ttl_spec:
|
|
raise JsonRESTError('InvalidParameterValue',
|
|
'TimeToLiveSpecification does not contain Enabled and AttributeName')
|
|
|
|
if ttl_spec['Enabled']:
|
|
table.ttl['TimeToLiveStatus'] = 'ENABLED'
|
|
else:
|
|
table.ttl['TimeToLiveStatus'] = 'DISABLED'
|
|
table.ttl['AttributeName'] = ttl_spec['AttributeName']
|
|
|
|
def describe_ttl(self, table_name):
|
|
table = self.tables.get(table_name)
|
|
if table is None:
|
|
raise JsonRESTError('ResourceNotFound', 'Table not found')
|
|
|
|
return table.ttl
|
|
|
|
|
|
available_regions = boto3.session.Session().get_available_regions("dynamodb")
|
|
dynamodb_backends = {region: DynamoDBBackend(region_name=region) for region in available_regions}
|