Refactor DynamoDB update expressions (#2497)

* Refactor DynamoDB.update to use recursive method for nested updates

* Simplify DynamoDB.update_item logic
This commit is contained in:
Bert Blommers 2019-10-22 20:40:41 +01:00 committed by Jack Danger
parent 9e4860ccd8
commit 64cf1fc2c9
2 changed files with 147 additions and 148 deletions

View File

@ -34,14 +34,76 @@ def bytesize(val):
return len(str(val).encode('utf-8'))
def attribute_is_list(attr):
    """
    Split an optional list index off an attribute name.

    :param attr: attribute name, either plain (``attr``) or indexed (``attr[index]``)
    :return: tuple ``(name, index)``; ``index`` is the digits found between the
        brackets (as a string), or None when ``attr`` carries no index
    """
    indexed = re.match('(.+)\\[([0-9]+)\\]', attr)
    if indexed is None:
        # Not of the form attr[index]: hand the name back untouched
        return attr, None
    return indexed.group(1), indexed.group(2)
class DynamoType(object):
"""
http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DataModel.html#DataModelDataTypes
"""
def __init__(self, type_as_dict):
    """
    Build a DynamoType from a ``{type: value}`` dict or from another DynamoType.

    :param type_as_dict: either a single-entry dict whose key is the type
        code and whose value is the payload, or an existing DynamoType whose
        ``type`` and ``value`` are taken over
    """
    # Decide how to read the argument *before* touching it: unpacking it as
    # a dict first would defeat the DynamoType branch below.
    if isinstance(type_as_dict, DynamoType):
        self.type = type_as_dict.type
        self.value = type_as_dict.value
    else:
        self.type = list(type_as_dict)[0]
        self.value = list(type_as_dict.values())[0]
    # Wrap children so nested lists/maps consist of DynamoType objects
    if self.is_list():
        self.value = [DynamoType(val) for val in self.value]
    elif self.is_map():
        self.value = {k: DynamoType(v) for k, v in self.value.items()}
def set(self, key, new_value, index=None):
if index:
index = int(index)
if type(self.value) is not list:
raise InvalidUpdateExpression
if index >= len(self.value):
self.value.append(new_value)
# {'L': [DynamoType, ..]} ==> DynamoType.set()
self.value[min(index, len(self.value) - 1)].set(key, new_value)
else:
attr = (key or '').split('.').pop(0)
attr, list_index = attribute_is_list(attr)
if not key:
# {'S': value} ==> {'S': new_value}
self.value = new_value.value
else:
if attr not in self.value: # nonexistingattribute
type_of_new_attr = 'M' if '.' in key else new_value.type
self.value[attr] = DynamoType({type_of_new_attr: {}})
# {'M': {'foo': DynamoType}} ==> DynamoType.set(new_value)
self.value[attr].set('.'.join(key.split('.')[1:]), new_value, list_index)
def delete(self, key, index=None):
if index:
if not key:
if int(index) < len(self.value):
del self.value[int(index)]
elif '.' in key:
self.value[int(index)].delete('.'.join(key.split('.')[1:]))
else:
self.value[int(index)].delete(key)
else:
attr = key.split('.')[0]
attr, list_index = attribute_is_list(attr)
if list_index:
self.value[attr].delete('.'.join(key.split('.')[1:]), list_index)
elif '.' in key:
self.value[attr].delete('.'.join(key.split('.')[1:]))
else:
self.value.pop(key)
def __hash__(self):
return hash((self.type, self.value))
@ -98,7 +160,7 @@ class DynamoType(object):
if isinstance(key, int) and self.is_list():
idx = key
if idx >= 0 and idx < len(self.value):
if 0 <= idx < len(self.value):
return DynamoType(self.value[idx])
return None
@ -110,7 +172,7 @@ class DynamoType(object):
sub_type = self.type[0]
value_size = sum([DynamoType({sub_type: v}).size() for v in self.value])
elif self.is_list():
value_size = sum([DynamoType(v).size() for v in self.value])
value_size = sum([v.size() for v in self.value])
elif self.is_map():
value_size = sum([bytesize(k) + DynamoType(v).size() for k, v in self.value.items()])
elif type(self.value) == bool:
@ -162,22 +224,6 @@ class LimitedSizeDict(dict):
raise ItemSizeTooLarge
super(LimitedSizeDict, self).__setitem__(key, value)
def update(self, *args, **kwargs):
    """
    Copy entries into this mapping one at a time via item assignment.

    :param args: at most one positional argument, anything ``dict()`` accepts
    :param kwargs: additional entries, applied after the positional argument
    :raises TypeError: if more than one positional argument is given
    """
    if len(args) > 1:
        raise TypeError("update expected at most 1 arguments, "
                        "got %d" % len(args))
    if args:
        # Each entry goes through self[...] = ..., not a bulk copy
        for name, item in dict(args[0]).items():
            self[name] = item
    for name, item in kwargs.items():
        self[name] = item
def setdefault(self, key, value=None):
    """
    Return the entry for ``key``, first storing ``value`` if it is absent.

    :param key: key to look up
    :param value: stored via item assignment when ``key`` is missing
    :return: the entry held under ``key`` afterwards
    """
    if key in self:
        return self[key]
    # Stored through self[...] = ..., then read back
    self[key] = value
    return self[key]
class Item(BaseModel):
@ -236,72 +282,26 @@ class Item(BaseModel):
if action == "REMOVE":
key = value
attr, list_index = attribute_is_list(key.split('.')[0])
if '.' not in key:
list_index_update = re.match('(.+)\\[([0-9]+)\\]', key)
if list_index_update:
# We need to remove an item from a list (REMOVE listattr[0])
key_attr = self.attrs[list_index_update.group(1)]
list_index = int(list_index_update.group(2))
if key_attr.is_list():
if len(key_attr.value) > list_index:
del key_attr.value[list_index]
if list_index:
new_list = DynamoType(self.attrs[attr])
new_list.delete(None, list_index)
self.attrs[attr] = new_list
else:
self.attrs.pop(value, None)
else:
# Handle nested dict updates
key_parts = key.split('.')
attr = key_parts.pop(0)
if attr not in self.attrs:
raise ValueError
last_val = self.attrs[attr].value
for key_part in key_parts[:-1]:
list_index_update = re.match('(.+)\\[([0-9]+)\\]', key_part)
if list_index_update:
key_part = list_index_update.group(1) # listattr[1] ==> listattr
# Hack but it'll do, traverses into a dict
last_val_type = list(last_val.keys())
if last_val_type and last_val_type[0] == 'M':
last_val = last_val['M']
if key_part not in last_val:
last_val[key_part] = {'M': {}}
last_val = last_val[key_part]
if list_index_update:
last_val = last_val['L'][int(list_index_update.group(2))]
last_val_type = list(last_val.keys())
list_index_update = re.match('(.+)\\[([0-9]+)\\]', key_parts[-1])
if list_index_update:
# We need to remove an item from a list (REMOVE attr.listattr[0])
key_part = list_index_update.group(1) # listattr[1] ==> listattr
list_to_update = last_val[key_part]['L']
index_to_remove = int(list_index_update.group(2))
if index_to_remove < len(list_to_update):
del list_to_update[index_to_remove]
else:
if last_val_type and last_val_type[0] == 'M':
last_val['M'].pop(key_parts[-1], None)
else:
last_val.pop(key_parts[-1], None)
self.attrs[attr].delete('.'.join(key.split('.')[1:]))
elif action == 'SET':
key, value = value.split("=", 1)
key = key.strip()
value = value.strip()
# If not exists, changes value to a default if needed, else its the same as it was
if value.startswith('if_not_exists'):
# Function signature
match = re.match(r'.*if_not_exists\s*\((?P<path>.+),\s*(?P<default>.+)\).*', value)
if not match:
raise TypeError
path, value = match.groups()
# If it already exists, get its value so we dont overwrite it
if path in self.attrs:
value = self.attrs[path]
# check whether key is a list
attr, list_index = attribute_is_list(key.split('.')[0])
# If value not exists, changes value to a default if needed, else its the same as it was
value = self._get_default(value)
if type(value) != DynamoType:
if value in expression_attribute_values:
@ -311,55 +311,12 @@ class Item(BaseModel):
else:
dyn_value = value
if '.' not in key:
list_index_update = re.match('(.+)\\[([0-9]+)\\]', key)
if list_index_update:
key_attr = self.attrs[list_index_update.group(1)]
list_index = int(list_index_update.group(2))
if key_attr.is_list():
if len(key_attr.value) > list_index:
key_attr.value[list_index] = expression_attribute_values[value]
else:
key_attr.value.append(expression_attribute_values[value])
else:
raise InvalidUpdateExpression
else:
self.attrs[key] = dyn_value
if '.' in key and attr not in self.attrs:
raise ValueError # Setting nested attr not allowed if first attr does not exist yet
elif attr not in self.attrs:
self.attrs[attr] = dyn_value # set new top-level attribute
else:
# Handle nested dict updates
key_parts = key.split('.')
attr = key_parts.pop(0)
if attr not in self.attrs:
raise ValueError
last_val = self.attrs[attr].value
for key_part in key_parts:
list_index_update = re.match('(.+)\\[([0-9]+)\\]', key_part)
if list_index_update:
key_part = list_index_update.group(1) # listattr[1] ==> listattr
# Hack but it'll do, traverses into a dict
last_val_type = list(last_val.keys())
if last_val_type and last_val_type[0] == 'M':
last_val = last_val['M']
if key_part not in last_val:
last_val[key_part] = {'M': {}}
last_val = last_val[key_part]
current_type = list(last_val.keys())[0]
if list_index_update:
# We need to add an item to a list
list_index = int(list_index_update.group(2))
if len(last_val['L']) > list_index:
last_val['L'][list_index] = expression_attribute_values[value]
else:
last_val['L'].append(expression_attribute_values[value])
else:
# We have reference to a nested object but we cant just assign to it
if current_type == dyn_value.type:
last_val[current_type] = dyn_value.value
else:
last_val[dyn_value.type] = dyn_value.value
del last_val[current_type]
self.attrs[attr].set('.'.join(key.split('.')[1:]), dyn_value, list_index) # set value recursively
elif action == 'ADD':
key, value = value.split(" ", 1)
@ -413,6 +370,20 @@ class Item(BaseModel):
else:
raise NotImplementedError('{} update action not yet supported'.format(action))
def _get_default(self, value):
if value.startswith('if_not_exists'):
# Function signature
match = re.match(r'.*if_not_exists\s*\((?P<path>.+),\s*(?P<default>.+)\).*', value)
if not match:
raise TypeError
path, value = match.groups()
# If it already exists, get its value so we dont overwrite it
if path in self.attrs:
value = self.attrs[path]
return value
def update_with_attribute_updates(self, attribute_updates):
for attribute_name, update_action in attribute_updates.items():
action = update_action['Action']
@ -810,7 +781,6 @@ class Table(BaseModel):
else:
possible_results = [item for item in list(self.all_items()) if isinstance(
item, Item) and item.hash_key == hash_key]
if range_comparison:
if index_name and not index_range_key:
raise ValueError(

View File

@ -2161,20 +2161,11 @@ def test_condition_expression__attr_doesnt_exist():
client.create_table(
TableName='test',
KeySchema=[{'AttributeName': 'forum_name', 'KeyType': 'HASH'}],
AttributeDefinitions=[
{'AttributeName': 'forum_name', 'AttributeType': 'S'},
],
ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1},
)
client.put_item(
TableName='test',
Item={
'forum_name': {'S': 'foo'},
'ttl': {'N': 'bar'},
}
)
AttributeDefinitions=[{'AttributeName': 'forum_name', 'AttributeType': 'S'}],
ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1})
client.put_item(TableName='test',
Item={'forum_name': {'S': 'foo'}, 'ttl': {'N': 'bar'}})
def update_if_attr_doesnt_exist():
# Test nonexistent top-level attribute.
@ -2261,6 +2252,7 @@ def test_condition_expression__and_order():
}
)
@mock_dynamodb2
def test_query_gsi_with_range_key():
dynamodb = boto3.client('dynamodb', region_name='us-east-1')
@ -2510,13 +2502,15 @@ def test_index_with_unknown_attributes_should_fail():
def test_update_list_index__set_existing_index():
    # SET on an in-range index of a top-level list should overwrite only that element.
    table_name = 'test_list_index_access'
    client = create_table_with_list(table_name)
    # Seed the item whose list is about to be updated
    client.put_item(TableName=table_name,
                    Item={'id': {'S': 'foo'}, 'itemlist': {'L': [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}]}})
    client.update_item(TableName=table_name, Key={'id': {'S': 'foo'}},
                       UpdateExpression='set itemlist[1]=:Item',
                       ExpressionAttributeValues={':Item': {'S': 'bar2_update'}})
    #
    result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo'}})['Item']
    # NOTE(review): the plain asserts and the should.equal checks below verify the same values
    assert result['id'] == {'S': 'foo'}
    assert result['itemlist'] == {'L': [{'S': 'bar1'}, {'S': 'bar2_update'}, {'S': 'bar3'}]}
    result['id'].should.equal({'S': 'foo'})
    result['itemlist'].should.equal({'L': [{'S': 'bar1'}, {'S': 'bar2_update'}, {'S': 'bar3'}]})
@mock_dynamodb2
@ -2530,14 +2524,16 @@ def test_update_list_index__set_existing_nested_index():
ExpressionAttributeValues={':Item': {'S': 'bar2_update'}})
#
result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo2'}})['Item']
assert result['id'] == {'S': 'foo2'}
assert result['itemmap']['M']['itemlist']['L'] == [{'S': 'bar1'}, {'S': 'bar2_update'}, {'S': 'bar3'}]
result['id'].should.equal({'S': 'foo2'})
result['itemmap']['M']['itemlist']['L'].should.equal([{'S': 'bar1'}, {'S': 'bar2_update'}, {'S': 'bar3'}])
@mock_dynamodb2
def test_update_list_index__set_index_out_of_range():
table_name = 'test_list_index_access'
client = create_table_with_list(table_name)
client.put_item(TableName=table_name,
Item={'id': {'S': 'foo'}, 'itemlist': {'L': [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}]}})
client.update_item(TableName=table_name, Key={'id': {'S': 'foo'}},
UpdateExpression='set itemlist[10]=:Item',
ExpressionAttributeValues={':Item': {'S': 'bar10'}})
@ -2562,6 +2558,25 @@ def test_update_list_index__set_nested_index_out_of_range():
assert result['itemmap']['M']['itemlist']['L'] == [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}, {'S': 'bar10'}]
@mock_dynamodb2
def test_update_list_index__set_double_nested_index():
    """SET an attribute of a map that sits inside a list nested in a map."""
    name = 'test_list_index_access'
    dynamodb = create_table_with_list(name)
    item_key = {'id': {'S': 'foo2'}}
    seeded_entries = [{'M': {'foo': {'S': 'bar11'}, 'foos': {'S': 'bar12'}}},
                      {'M': {'foo': {'S': 'bar21'}, 'foos': {'S': 'bar21'}}}]
    dynamodb.put_item(TableName=name,
                      Item={'id': {'S': 'foo2'},
                            'itemmap': {'M': {'itemlist': {'L': seeded_entries}}}})
    dynamodb.update_item(TableName=name, Key=item_key,
                         UpdateExpression='set itemmap.itemlist[1].foos=:Item',
                         ExpressionAttributeValues={':Item': {'S': 'bar22'}})
    # Read the item back and inspect the nested list
    stored = dynamodb.get_item(TableName=name, Key=item_key)['Item']
    assert stored['id'] == {'S': 'foo2'}
    entries = stored['itemmap']['M']['itemlist']['L']
    len(entries).should.equal(2)
    entries[0].should.equal({'M': {'foo': {'S': 'bar11'}, 'foos': {'S': 'bar12'}}})  # unchanged
    entries[1].should.equal({'M': {'foo': {'S': 'bar21'}, 'foos': {'S': 'bar22'}}})  # updated
@mock_dynamodb2
def test_update_list_index__set_index_of_a_string():
table_name = 'test_list_index_access'
@ -2578,15 +2593,29 @@ def test_update_list_index__set_index_of_a_string():
'The document path provided in the update expression is invalid for update')
@mock_dynamodb2
def test_remove_top_level_attribute():
    """REMOVE of a top-level attribute should leave only the key attribute."""
    name = 'test_remove'
    dynamodb = create_table_with_list(name)
    item_key = {'id': {'S': 'foo'}}
    dynamodb.put_item(TableName=name,
                      Item={'id': {'S': 'foo'}, 'item': {'S': 'bar'}})
    dynamodb.update_item(TableName=name, Key=item_key, UpdateExpression='REMOVE item')
    # Only the hash key should be left on the stored item
    stored = dynamodb.get_item(TableName=name, Key=item_key)['Item']
    stored.should.equal({'id': {'S': 'foo'}})
@mock_dynamodb2
def test_remove_list_index__remove_existing_index():
    # REMOVE on an in-range index of a top-level list should drop only that element.
    table_name = 'test_list_index_access'
    client = create_table_with_list(table_name)
    # Seed the item whose list is about to be shortened
    client.put_item(TableName=table_name,
                    Item={'id': {'S': 'foo'}, 'itemlist': {'L': [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}]}})
    client.update_item(TableName=table_name, Key={'id': {'S': 'foo'}}, UpdateExpression='REMOVE itemlist[1]')
    #
    result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo'}})['Item']
    # NOTE(review): the plain asserts and the should.equal checks below verify the same values
    assert result['id'] == {'S': 'foo'}
    assert result['itemlist'] == {'L': [{'S': 'bar1'}, {'S': 'bar3'}]}
    result['id'].should.equal({'S': 'foo'})
    result['itemlist'].should.equal({'L': [{'S': 'bar1'}, {'S': 'bar3'}]})
@mock_dynamodb2
@ -2598,8 +2627,8 @@ def test_remove_list_index__remove_existing_nested_index():
client.update_item(TableName=table_name, Key={'id': {'S': 'foo2'}}, UpdateExpression='REMOVE itemmap.itemlist[1]')
#
result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo2'}})['Item']
assert result['id'] == {'S': 'foo2'}
assert result['itemmap']['M']['itemlist']['L'] == [{'S': 'bar1'}]
result['id'].should.equal({'S': 'foo2'})
result['itemmap']['M']['itemlist']['L'].should.equal([{'S': 'bar1'}])
@mock_dynamodb2
@ -2626,6 +2655,8 @@ def test_remove_list_index__remove_existing_double_nested_index():
def test_remove_list_index__remove_index_out_of_range():
table_name = 'test_list_index_access'
client = create_table_with_list(table_name)
client.put_item(TableName=table_name,
Item={'id': {'S': 'foo'}, 'itemlist': {'L': [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}]}})
client.update_item(TableName=table_name, Key={'id': {'S': 'foo'}}, UpdateExpression='REMOVE itemlist[10]')
#
result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo'}})['Item']
@ -2639,8 +2670,6 @@ def create_table_with_list(table_name):
KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'}],
AttributeDefinitions=[{'AttributeName': 'id', 'AttributeType': 'S'}],
BillingMode='PAY_PER_REQUEST')
client.put_item(TableName=table_name,
Item={'id': {'S': 'foo'}, 'itemlist': {'L': [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}]}})
return client