Refactor DynamoDB update expressions (#2497)
* Refactor DynamoDB.update to use recursive method for nested updates * Simplify DynamoDB.update_item logic
This commit is contained in:
parent
9e4860ccd8
commit
64cf1fc2c9
@ -34,14 +34,76 @@ def bytesize(val):
|
||||
return len(str(val).encode('utf-8'))
|
||||
|
||||
|
||||
def attribute_is_list(attr):
|
||||
"""
|
||||
Checks if attribute denotes a list, and returns the regular expression if so
|
||||
:param attr: attr or attr[index]
|
||||
:return: attr, re or None
|
||||
"""
|
||||
list_index_update = re.match('(.+)\\[([0-9]+)\\]', attr)
|
||||
if list_index_update:
|
||||
attr = list_index_update.group(1)
|
||||
return attr, list_index_update.group(2) if list_index_update else None
|
||||
|
||||
|
||||
class DynamoType(object):
|
||||
"""
|
||||
http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DataModel.html#DataModelDataTypes
|
||||
"""
|
||||
|
||||
def __init__(self, type_as_dict):
|
||||
self.type = list(type_as_dict)[0]
|
||||
self.value = list(type_as_dict.values())[0]
|
||||
if type(type_as_dict) == DynamoType:
|
||||
self.type = type_as_dict.type
|
||||
self.value = type_as_dict.value
|
||||
else:
|
||||
self.type = list(type_as_dict)[0]
|
||||
self.value = list(type_as_dict.values())[0]
|
||||
if self.is_list():
|
||||
self.value = [DynamoType(val) for val in self.value]
|
||||
elif self.is_map():
|
||||
self.value = dict((k, DynamoType(v)) for k, v in self.value.items())
|
||||
|
||||
def set(self, key, new_value, index=None):
|
||||
if index:
|
||||
index = int(index)
|
||||
if type(self.value) is not list:
|
||||
raise InvalidUpdateExpression
|
||||
if index >= len(self.value):
|
||||
self.value.append(new_value)
|
||||
# {'L': [DynamoType, ..]} ==> DynamoType.set()
|
||||
self.value[min(index, len(self.value) - 1)].set(key, new_value)
|
||||
else:
|
||||
attr = (key or '').split('.').pop(0)
|
||||
attr, list_index = attribute_is_list(attr)
|
||||
if not key:
|
||||
# {'S': value} ==> {'S': new_value}
|
||||
self.value = new_value.value
|
||||
else:
|
||||
if attr not in self.value: # nonexistingattribute
|
||||
type_of_new_attr = 'M' if '.' in key else new_value.type
|
||||
self.value[attr] = DynamoType({type_of_new_attr: {}})
|
||||
# {'M': {'foo': DynamoType}} ==> DynamoType.set(new_value)
|
||||
self.value[attr].set('.'.join(key.split('.')[1:]), new_value, list_index)
|
||||
|
||||
def delete(self, key, index=None):
|
||||
if index:
|
||||
if not key:
|
||||
if int(index) < len(self.value):
|
||||
del self.value[int(index)]
|
||||
elif '.' in key:
|
||||
self.value[int(index)].delete('.'.join(key.split('.')[1:]))
|
||||
else:
|
||||
self.value[int(index)].delete(key)
|
||||
else:
|
||||
attr = key.split('.')[0]
|
||||
attr, list_index = attribute_is_list(attr)
|
||||
|
||||
if list_index:
|
||||
self.value[attr].delete('.'.join(key.split('.')[1:]), list_index)
|
||||
elif '.' in key:
|
||||
self.value[attr].delete('.'.join(key.split('.')[1:]))
|
||||
else:
|
||||
self.value.pop(key)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.type, self.value))
|
||||
@ -98,7 +160,7 @@ class DynamoType(object):
|
||||
|
||||
if isinstance(key, int) and self.is_list():
|
||||
idx = key
|
||||
if idx >= 0 and idx < len(self.value):
|
||||
if 0 <= idx < len(self.value):
|
||||
return DynamoType(self.value[idx])
|
||||
|
||||
return None
|
||||
@ -110,7 +172,7 @@ class DynamoType(object):
|
||||
sub_type = self.type[0]
|
||||
value_size = sum([DynamoType({sub_type: v}).size() for v in self.value])
|
||||
elif self.is_list():
|
||||
value_size = sum([DynamoType(v).size() for v in self.value])
|
||||
value_size = sum([v.size() for v in self.value])
|
||||
elif self.is_map():
|
||||
value_size = sum([bytesize(k) + DynamoType(v).size() for k, v in self.value.items()])
|
||||
elif type(self.value) == bool:
|
||||
@ -162,22 +224,6 @@ class LimitedSizeDict(dict):
|
||||
raise ItemSizeTooLarge
|
||||
super(LimitedSizeDict, self).__setitem__(key, value)
|
||||
|
||||
def update(self, *args, **kwargs):
|
||||
if args:
|
||||
if len(args) > 1:
|
||||
raise TypeError("update expected at most 1 arguments, "
|
||||
"got %d" % len(args))
|
||||
other = dict(args[0])
|
||||
for key in other:
|
||||
self[key] = other[key]
|
||||
for key in kwargs:
|
||||
self[key] = kwargs[key]
|
||||
|
||||
def setdefault(self, key, value=None):
|
||||
if key not in self:
|
||||
self[key] = value
|
||||
return self[key]
|
||||
|
||||
|
||||
class Item(BaseModel):
|
||||
|
||||
@ -236,72 +282,26 @@ class Item(BaseModel):
|
||||
|
||||
if action == "REMOVE":
|
||||
key = value
|
||||
attr, list_index = attribute_is_list(key.split('.')[0])
|
||||
if '.' not in key:
|
||||
list_index_update = re.match('(.+)\\[([0-9]+)\\]', key)
|
||||
if list_index_update:
|
||||
# We need to remove an item from a list (REMOVE listattr[0])
|
||||
key_attr = self.attrs[list_index_update.group(1)]
|
||||
list_index = int(list_index_update.group(2))
|
||||
if key_attr.is_list():
|
||||
if len(key_attr.value) > list_index:
|
||||
del key_attr.value[list_index]
|
||||
if list_index:
|
||||
new_list = DynamoType(self.attrs[attr])
|
||||
new_list.delete(None, list_index)
|
||||
self.attrs[attr] = new_list
|
||||
else:
|
||||
self.attrs.pop(value, None)
|
||||
else:
|
||||
# Handle nested dict updates
|
||||
key_parts = key.split('.')
|
||||
attr = key_parts.pop(0)
|
||||
if attr not in self.attrs:
|
||||
raise ValueError
|
||||
|
||||
last_val = self.attrs[attr].value
|
||||
for key_part in key_parts[:-1]:
|
||||
list_index_update = re.match('(.+)\\[([0-9]+)\\]', key_part)
|
||||
if list_index_update:
|
||||
key_part = list_index_update.group(1) # listattr[1] ==> listattr
|
||||
# Hack but it'll do, traverses into a dict
|
||||
last_val_type = list(last_val.keys())
|
||||
if last_val_type and last_val_type[0] == 'M':
|
||||
last_val = last_val['M']
|
||||
|
||||
if key_part not in last_val:
|
||||
last_val[key_part] = {'M': {}}
|
||||
|
||||
last_val = last_val[key_part]
|
||||
if list_index_update:
|
||||
last_val = last_val['L'][int(list_index_update.group(2))]
|
||||
|
||||
last_val_type = list(last_val.keys())
|
||||
list_index_update = re.match('(.+)\\[([0-9]+)\\]', key_parts[-1])
|
||||
if list_index_update:
|
||||
# We need to remove an item from a list (REMOVE attr.listattr[0])
|
||||
key_part = list_index_update.group(1) # listattr[1] ==> listattr
|
||||
list_to_update = last_val[key_part]['L']
|
||||
index_to_remove = int(list_index_update.group(2))
|
||||
if index_to_remove < len(list_to_update):
|
||||
del list_to_update[index_to_remove]
|
||||
else:
|
||||
if last_val_type and last_val_type[0] == 'M':
|
||||
last_val['M'].pop(key_parts[-1], None)
|
||||
else:
|
||||
last_val.pop(key_parts[-1], None)
|
||||
self.attrs[attr].delete('.'.join(key.split('.')[1:]))
|
||||
elif action == 'SET':
|
||||
key, value = value.split("=", 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
|
||||
# If not exists, changes value to a default if needed, else its the same as it was
|
||||
if value.startswith('if_not_exists'):
|
||||
# Function signature
|
||||
match = re.match(r'.*if_not_exists\s*\((?P<path>.+),\s*(?P<default>.+)\).*', value)
|
||||
if not match:
|
||||
raise TypeError
|
||||
|
||||
path, value = match.groups()
|
||||
|
||||
# If it already exists, get its value so we dont overwrite it
|
||||
if path in self.attrs:
|
||||
value = self.attrs[path]
|
||||
# check whether key is a list
|
||||
attr, list_index = attribute_is_list(key.split('.')[0])
|
||||
# If value not exists, changes value to a default if needed, else its the same as it was
|
||||
value = self._get_default(value)
|
||||
|
||||
if type(value) != DynamoType:
|
||||
if value in expression_attribute_values:
|
||||
@ -311,55 +311,12 @@ class Item(BaseModel):
|
||||
else:
|
||||
dyn_value = value
|
||||
|
||||
if '.' not in key:
|
||||
list_index_update = re.match('(.+)\\[([0-9]+)\\]', key)
|
||||
if list_index_update:
|
||||
key_attr = self.attrs[list_index_update.group(1)]
|
||||
list_index = int(list_index_update.group(2))
|
||||
if key_attr.is_list():
|
||||
if len(key_attr.value) > list_index:
|
||||
key_attr.value[list_index] = expression_attribute_values[value]
|
||||
else:
|
||||
key_attr.value.append(expression_attribute_values[value])
|
||||
else:
|
||||
raise InvalidUpdateExpression
|
||||
else:
|
||||
self.attrs[key] = dyn_value
|
||||
if '.' in key and attr not in self.attrs:
|
||||
raise ValueError # Setting nested attr not allowed if first attr does not exist yet
|
||||
elif attr not in self.attrs:
|
||||
self.attrs[attr] = dyn_value # set new top-level attribute
|
||||
else:
|
||||
# Handle nested dict updates
|
||||
key_parts = key.split('.')
|
||||
attr = key_parts.pop(0)
|
||||
if attr not in self.attrs:
|
||||
raise ValueError
|
||||
last_val = self.attrs[attr].value
|
||||
for key_part in key_parts:
|
||||
list_index_update = re.match('(.+)\\[([0-9]+)\\]', key_part)
|
||||
if list_index_update:
|
||||
key_part = list_index_update.group(1) # listattr[1] ==> listattr
|
||||
# Hack but it'll do, traverses into a dict
|
||||
last_val_type = list(last_val.keys())
|
||||
if last_val_type and last_val_type[0] == 'M':
|
||||
last_val = last_val['M']
|
||||
|
||||
if key_part not in last_val:
|
||||
last_val[key_part] = {'M': {}}
|
||||
last_val = last_val[key_part]
|
||||
|
||||
current_type = list(last_val.keys())[0]
|
||||
if list_index_update:
|
||||
# We need to add an item to a list
|
||||
list_index = int(list_index_update.group(2))
|
||||
if len(last_val['L']) > list_index:
|
||||
last_val['L'][list_index] = expression_attribute_values[value]
|
||||
else:
|
||||
last_val['L'].append(expression_attribute_values[value])
|
||||
else:
|
||||
# We have reference to a nested object but we cant just assign to it
|
||||
if current_type == dyn_value.type:
|
||||
last_val[current_type] = dyn_value.value
|
||||
else:
|
||||
last_val[dyn_value.type] = dyn_value.value
|
||||
del last_val[current_type]
|
||||
self.attrs[attr].set('.'.join(key.split('.')[1:]), dyn_value, list_index) # set value recursively
|
||||
|
||||
elif action == 'ADD':
|
||||
key, value = value.split(" ", 1)
|
||||
@ -413,6 +370,20 @@ class Item(BaseModel):
|
||||
else:
|
||||
raise NotImplementedError('{} update action not yet supported'.format(action))
|
||||
|
||||
def _get_default(self, value):
|
||||
if value.startswith('if_not_exists'):
|
||||
# Function signature
|
||||
match = re.match(r'.*if_not_exists\s*\((?P<path>.+),\s*(?P<default>.+)\).*', value)
|
||||
if not match:
|
||||
raise TypeError
|
||||
|
||||
path, value = match.groups()
|
||||
|
||||
# If it already exists, get its value so we dont overwrite it
|
||||
if path in self.attrs:
|
||||
value = self.attrs[path]
|
||||
return value
|
||||
|
||||
def update_with_attribute_updates(self, attribute_updates):
|
||||
for attribute_name, update_action in attribute_updates.items():
|
||||
action = update_action['Action']
|
||||
@ -810,7 +781,6 @@ class Table(BaseModel):
|
||||
else:
|
||||
possible_results = [item for item in list(self.all_items()) if isinstance(
|
||||
item, Item) and item.hash_key == hash_key]
|
||||
|
||||
if range_comparison:
|
||||
if index_name and not index_range_key:
|
||||
raise ValueError(
|
||||
|
@ -2161,20 +2161,11 @@ def test_condition_expression__attr_doesnt_exist():
|
||||
client.create_table(
|
||||
TableName='test',
|
||||
KeySchema=[{'AttributeName': 'forum_name', 'KeyType': 'HASH'}],
|
||||
AttributeDefinitions=[
|
||||
{'AttributeName': 'forum_name', 'AttributeType': 'S'},
|
||||
],
|
||||
ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1},
|
||||
)
|
||||
|
||||
client.put_item(
|
||||
TableName='test',
|
||||
Item={
|
||||
'forum_name': {'S': 'foo'},
|
||||
'ttl': {'N': 'bar'},
|
||||
}
|
||||
)
|
||||
AttributeDefinitions=[{'AttributeName': 'forum_name', 'AttributeType': 'S'}],
|
||||
ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1})
|
||||
|
||||
client.put_item(TableName='test',
|
||||
Item={'forum_name': {'S': 'foo'}, 'ttl': {'N': 'bar'}})
|
||||
|
||||
def update_if_attr_doesnt_exist():
|
||||
# Test nonexistent top-level attribute.
|
||||
@ -2261,6 +2252,7 @@ def test_condition_expression__and_order():
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@mock_dynamodb2
|
||||
def test_query_gsi_with_range_key():
|
||||
dynamodb = boto3.client('dynamodb', region_name='us-east-1')
|
||||
@ -2510,13 +2502,15 @@ def test_index_with_unknown_attributes_should_fail():
|
||||
def test_update_list_index__set_existing_index():
|
||||
table_name = 'test_list_index_access'
|
||||
client = create_table_with_list(table_name)
|
||||
client.put_item(TableName=table_name,
|
||||
Item={'id': {'S': 'foo'}, 'itemlist': {'L': [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}]}})
|
||||
client.update_item(TableName=table_name, Key={'id': {'S': 'foo'}},
|
||||
UpdateExpression='set itemlist[1]=:Item',
|
||||
ExpressionAttributeValues={':Item': {'S': 'bar2_update'}})
|
||||
#
|
||||
result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo'}})['Item']
|
||||
assert result['id'] == {'S': 'foo'}
|
||||
assert result['itemlist'] == {'L': [{'S': 'bar1'}, {'S': 'bar2_update'}, {'S': 'bar3'}]}
|
||||
result['id'].should.equal({'S': 'foo'})
|
||||
result['itemlist'].should.equal({'L': [{'S': 'bar1'}, {'S': 'bar2_update'}, {'S': 'bar3'}]})
|
||||
|
||||
|
||||
@mock_dynamodb2
|
||||
@ -2530,14 +2524,16 @@ def test_update_list_index__set_existing_nested_index():
|
||||
ExpressionAttributeValues={':Item': {'S': 'bar2_update'}})
|
||||
#
|
||||
result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo2'}})['Item']
|
||||
assert result['id'] == {'S': 'foo2'}
|
||||
assert result['itemmap']['M']['itemlist']['L'] == [{'S': 'bar1'}, {'S': 'bar2_update'}, {'S': 'bar3'}]
|
||||
result['id'].should.equal({'S': 'foo2'})
|
||||
result['itemmap']['M']['itemlist']['L'].should.equal([{'S': 'bar1'}, {'S': 'bar2_update'}, {'S': 'bar3'}])
|
||||
|
||||
|
||||
@mock_dynamodb2
|
||||
def test_update_list_index__set_index_out_of_range():
|
||||
table_name = 'test_list_index_access'
|
||||
client = create_table_with_list(table_name)
|
||||
client.put_item(TableName=table_name,
|
||||
Item={'id': {'S': 'foo'}, 'itemlist': {'L': [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}]}})
|
||||
client.update_item(TableName=table_name, Key={'id': {'S': 'foo'}},
|
||||
UpdateExpression='set itemlist[10]=:Item',
|
||||
ExpressionAttributeValues={':Item': {'S': 'bar10'}})
|
||||
@ -2562,6 +2558,25 @@ def test_update_list_index__set_nested_index_out_of_range():
|
||||
assert result['itemmap']['M']['itemlist']['L'] == [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}, {'S': 'bar10'}]
|
||||
|
||||
|
||||
@mock_dynamodb2
|
||||
def test_update_list_index__set_double_nested_index():
|
||||
table_name = 'test_list_index_access'
|
||||
client = create_table_with_list(table_name)
|
||||
client.put_item(TableName=table_name,
|
||||
Item={'id': {'S': 'foo2'},
|
||||
'itemmap': {'M': {'itemlist': {'L': [{'M': {'foo': {'S': 'bar11'}, 'foos': {'S': 'bar12'}}},
|
||||
{'M': {'foo': {'S': 'bar21'}, 'foos': {'S': 'bar21'}}}]}}}})
|
||||
client.update_item(TableName=table_name, Key={'id': {'S': 'foo2'}},
|
||||
UpdateExpression='set itemmap.itemlist[1].foos=:Item',
|
||||
ExpressionAttributeValues={':Item': {'S': 'bar22'}})
|
||||
#
|
||||
result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo2'}})['Item']
|
||||
assert result['id'] == {'S': 'foo2'}
|
||||
len(result['itemmap']['M']['itemlist']['L']).should.equal(2)
|
||||
result['itemmap']['M']['itemlist']['L'][0].should.equal({'M': {'foo': {'S': 'bar11'}, 'foos': {'S': 'bar12'}}}) # unchanged
|
||||
result['itemmap']['M']['itemlist']['L'][1].should.equal({'M': {'foo': {'S': 'bar21'}, 'foos': {'S': 'bar22'}}}) # updated
|
||||
|
||||
|
||||
@mock_dynamodb2
|
||||
def test_update_list_index__set_index_of_a_string():
|
||||
table_name = 'test_list_index_access'
|
||||
@ -2578,15 +2593,29 @@ def test_update_list_index__set_index_of_a_string():
|
||||
'The document path provided in the update expression is invalid for update')
|
||||
|
||||
|
||||
@mock_dynamodb2
|
||||
def test_remove_top_level_attribute():
|
||||
table_name = 'test_remove'
|
||||
client = create_table_with_list(table_name)
|
||||
client.put_item(TableName=table_name,
|
||||
Item={'id': {'S': 'foo'}, 'item': {'S': 'bar'}})
|
||||
client.update_item(TableName=table_name, Key={'id': {'S': 'foo'}}, UpdateExpression='REMOVE item')
|
||||
#
|
||||
result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo'}})['Item']
|
||||
result.should.equal({'id': {'S': 'foo'}})
|
||||
|
||||
|
||||
@mock_dynamodb2
|
||||
def test_remove_list_index__remove_existing_index():
|
||||
table_name = 'test_list_index_access'
|
||||
client = create_table_with_list(table_name)
|
||||
client.put_item(TableName=table_name,
|
||||
Item={'id': {'S': 'foo'}, 'itemlist': {'L': [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}]}})
|
||||
client.update_item(TableName=table_name, Key={'id': {'S': 'foo'}}, UpdateExpression='REMOVE itemlist[1]')
|
||||
#
|
||||
result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo'}})['Item']
|
||||
assert result['id'] == {'S': 'foo'}
|
||||
assert result['itemlist'] == {'L': [{'S': 'bar1'}, {'S': 'bar3'}]}
|
||||
result['id'].should.equal({'S': 'foo'})
|
||||
result['itemlist'].should.equal({'L': [{'S': 'bar1'}, {'S': 'bar3'}]})
|
||||
|
||||
|
||||
@mock_dynamodb2
|
||||
@ -2598,8 +2627,8 @@ def test_remove_list_index__remove_existing_nested_index():
|
||||
client.update_item(TableName=table_name, Key={'id': {'S': 'foo2'}}, UpdateExpression='REMOVE itemmap.itemlist[1]')
|
||||
#
|
||||
result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo2'}})['Item']
|
||||
assert result['id'] == {'S': 'foo2'}
|
||||
assert result['itemmap']['M']['itemlist']['L'] == [{'S': 'bar1'}]
|
||||
result['id'].should.equal({'S': 'foo2'})
|
||||
result['itemmap']['M']['itemlist']['L'].should.equal([{'S': 'bar1'}])
|
||||
|
||||
|
||||
@mock_dynamodb2
|
||||
@ -2626,6 +2655,8 @@ def test_remove_list_index__remove_existing_double_nested_index():
|
||||
def test_remove_list_index__remove_index_out_of_range():
|
||||
table_name = 'test_list_index_access'
|
||||
client = create_table_with_list(table_name)
|
||||
client.put_item(TableName=table_name,
|
||||
Item={'id': {'S': 'foo'}, 'itemlist': {'L': [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}]}})
|
||||
client.update_item(TableName=table_name, Key={'id': {'S': 'foo'}}, UpdateExpression='REMOVE itemlist[10]')
|
||||
#
|
||||
result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo'}})['Item']
|
||||
@ -2639,8 +2670,6 @@ def create_table_with_list(table_name):
|
||||
KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'}],
|
||||
AttributeDefinitions=[{'AttributeName': 'id', 'AttributeType': 'S'}],
|
||||
BillingMode='PAY_PER_REQUEST')
|
||||
client.put_item(TableName=table_name,
|
||||
Item={'id': {'S': 'foo'}, 'itemlist': {'L': [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}]}})
|
||||
return client
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user