Refactor DynamoDB update expressions (#2497)

* Refactor DynamoDB.update to use recursive method for nested updates

* Simplify DynamoDB.update_item logic
This commit is contained in:
Bert Blommers 2019-10-22 20:40:41 +01:00 committed by Jack Danger
parent 9e4860ccd8
commit 64cf1fc2c9
2 changed files with 147 additions and 148 deletions

View File

@ -34,14 +34,76 @@ def bytesize(val):
return len(str(val).encode('utf-8')) return len(str(val).encode('utf-8'))
def attribute_is_list(attr):
"""
Checks if attribute denotes a list, and returns the regular expression if so
:param attr: attr or attr[index]
:return: attr, re or None
"""
list_index_update = re.match('(.+)\\[([0-9]+)\\]', attr)
if list_index_update:
attr = list_index_update.group(1)
return attr, list_index_update.group(2) if list_index_update else None
class DynamoType(object): class DynamoType(object):
""" """
http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DataModel.html#DataModelDataTypes http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DataModel.html#DataModelDataTypes
""" """
def __init__(self, type_as_dict): def __init__(self, type_as_dict):
self.type = list(type_as_dict)[0] if type(type_as_dict) == DynamoType:
self.value = list(type_as_dict.values())[0] self.type = type_as_dict.type
self.value = type_as_dict.value
else:
self.type = list(type_as_dict)[0]
self.value = list(type_as_dict.values())[0]
if self.is_list():
self.value = [DynamoType(val) for val in self.value]
elif self.is_map():
self.value = dict((k, DynamoType(v)) for k, v in self.value.items())
def set(self, key, new_value, index=None):
if index:
index = int(index)
if type(self.value) is not list:
raise InvalidUpdateExpression
if index >= len(self.value):
self.value.append(new_value)
# {'L': [DynamoType, ..]} ==> DynamoType.set()
self.value[min(index, len(self.value) - 1)].set(key, new_value)
else:
attr = (key or '').split('.').pop(0)
attr, list_index = attribute_is_list(attr)
if not key:
# {'S': value} ==> {'S': new_value}
self.value = new_value.value
else:
if attr not in self.value: # nonexistingattribute
type_of_new_attr = 'M' if '.' in key else new_value.type
self.value[attr] = DynamoType({type_of_new_attr: {}})
# {'M': {'foo': DynamoType}} ==> DynamoType.set(new_value)
self.value[attr].set('.'.join(key.split('.')[1:]), new_value, list_index)
def delete(self, key, index=None):
if index:
if not key:
if int(index) < len(self.value):
del self.value[int(index)]
elif '.' in key:
self.value[int(index)].delete('.'.join(key.split('.')[1:]))
else:
self.value[int(index)].delete(key)
else:
attr = key.split('.')[0]
attr, list_index = attribute_is_list(attr)
if list_index:
self.value[attr].delete('.'.join(key.split('.')[1:]), list_index)
elif '.' in key:
self.value[attr].delete('.'.join(key.split('.')[1:]))
else:
self.value.pop(key)
def __hash__(self): def __hash__(self):
return hash((self.type, self.value)) return hash((self.type, self.value))
@ -98,7 +160,7 @@ class DynamoType(object):
if isinstance(key, int) and self.is_list(): if isinstance(key, int) and self.is_list():
idx = key idx = key
if idx >= 0 and idx < len(self.value): if 0 <= idx < len(self.value):
return DynamoType(self.value[idx]) return DynamoType(self.value[idx])
return None return None
@ -110,7 +172,7 @@ class DynamoType(object):
sub_type = self.type[0] sub_type = self.type[0]
value_size = sum([DynamoType({sub_type: v}).size() for v in self.value]) value_size = sum([DynamoType({sub_type: v}).size() for v in self.value])
elif self.is_list(): elif self.is_list():
value_size = sum([DynamoType(v).size() for v in self.value]) value_size = sum([v.size() for v in self.value])
elif self.is_map(): elif self.is_map():
value_size = sum([bytesize(k) + DynamoType(v).size() for k, v in self.value.items()]) value_size = sum([bytesize(k) + DynamoType(v).size() for k, v in self.value.items()])
elif type(self.value) == bool: elif type(self.value) == bool:
@ -162,22 +224,6 @@ class LimitedSizeDict(dict):
raise ItemSizeTooLarge raise ItemSizeTooLarge
super(LimitedSizeDict, self).__setitem__(key, value) super(LimitedSizeDict, self).__setitem__(key, value)
def update(self, *args, **kwargs):
if args:
if len(args) > 1:
raise TypeError("update expected at most 1 arguments, "
"got %d" % len(args))
other = dict(args[0])
for key in other:
self[key] = other[key]
for key in kwargs:
self[key] = kwargs[key]
def setdefault(self, key, value=None):
if key not in self:
self[key] = value
return self[key]
class Item(BaseModel): class Item(BaseModel):
@ -236,72 +282,26 @@ class Item(BaseModel):
if action == "REMOVE": if action == "REMOVE":
key = value key = value
attr, list_index = attribute_is_list(key.split('.')[0])
if '.' not in key: if '.' not in key:
list_index_update = re.match('(.+)\\[([0-9]+)\\]', key) if list_index:
if list_index_update: new_list = DynamoType(self.attrs[attr])
# We need to remove an item from a list (REMOVE listattr[0]) new_list.delete(None, list_index)
key_attr = self.attrs[list_index_update.group(1)] self.attrs[attr] = new_list
list_index = int(list_index_update.group(2))
if key_attr.is_list():
if len(key_attr.value) > list_index:
del key_attr.value[list_index]
else: else:
self.attrs.pop(value, None) self.attrs.pop(value, None)
else: else:
# Handle nested dict updates # Handle nested dict updates
key_parts = key.split('.') self.attrs[attr].delete('.'.join(key.split('.')[1:]))
attr = key_parts.pop(0)
if attr not in self.attrs:
raise ValueError
last_val = self.attrs[attr].value
for key_part in key_parts[:-1]:
list_index_update = re.match('(.+)\\[([0-9]+)\\]', key_part)
if list_index_update:
key_part = list_index_update.group(1) # listattr[1] ==> listattr
# Hack but it'll do, traverses into a dict
last_val_type = list(last_val.keys())
if last_val_type and last_val_type[0] == 'M':
last_val = last_val['M']
if key_part not in last_val:
last_val[key_part] = {'M': {}}
last_val = last_val[key_part]
if list_index_update:
last_val = last_val['L'][int(list_index_update.group(2))]
last_val_type = list(last_val.keys())
list_index_update = re.match('(.+)\\[([0-9]+)\\]', key_parts[-1])
if list_index_update:
# We need to remove an item from a list (REMOVE attr.listattr[0])
key_part = list_index_update.group(1) # listattr[1] ==> listattr
list_to_update = last_val[key_part]['L']
index_to_remove = int(list_index_update.group(2))
if index_to_remove < len(list_to_update):
del list_to_update[index_to_remove]
else:
if last_val_type and last_val_type[0] == 'M':
last_val['M'].pop(key_parts[-1], None)
else:
last_val.pop(key_parts[-1], None)
elif action == 'SET': elif action == 'SET':
key, value = value.split("=", 1) key, value = value.split("=", 1)
key = key.strip() key = key.strip()
value = value.strip() value = value.strip()
# If not exists, changes value to a default if needed, else its the same as it was # check whether key is a list
if value.startswith('if_not_exists'): attr, list_index = attribute_is_list(key.split('.')[0])
# Function signature # If value not exists, changes value to a default if needed, else its the same as it was
match = re.match(r'.*if_not_exists\s*\((?P<path>.+),\s*(?P<default>.+)\).*', value) value = self._get_default(value)
if not match:
raise TypeError
path, value = match.groups()
# If it already exists, get its value so we dont overwrite it
if path in self.attrs:
value = self.attrs[path]
if type(value) != DynamoType: if type(value) != DynamoType:
if value in expression_attribute_values: if value in expression_attribute_values:
@ -311,55 +311,12 @@ class Item(BaseModel):
else: else:
dyn_value = value dyn_value = value
if '.' not in key: if '.' in key and attr not in self.attrs:
list_index_update = re.match('(.+)\\[([0-9]+)\\]', key) raise ValueError # Setting nested attr not allowed if first attr does not exist yet
if list_index_update: elif attr not in self.attrs:
key_attr = self.attrs[list_index_update.group(1)] self.attrs[attr] = dyn_value # set new top-level attribute
list_index = int(list_index_update.group(2))
if key_attr.is_list():
if len(key_attr.value) > list_index:
key_attr.value[list_index] = expression_attribute_values[value]
else:
key_attr.value.append(expression_attribute_values[value])
else:
raise InvalidUpdateExpression
else:
self.attrs[key] = dyn_value
else: else:
# Handle nested dict updates self.attrs[attr].set('.'.join(key.split('.')[1:]), dyn_value, list_index) # set value recursively
key_parts = key.split('.')
attr = key_parts.pop(0)
if attr not in self.attrs:
raise ValueError
last_val = self.attrs[attr].value
for key_part in key_parts:
list_index_update = re.match('(.+)\\[([0-9]+)\\]', key_part)
if list_index_update:
key_part = list_index_update.group(1) # listattr[1] ==> listattr
# Hack but it'll do, traverses into a dict
last_val_type = list(last_val.keys())
if last_val_type and last_val_type[0] == 'M':
last_val = last_val['M']
if key_part not in last_val:
last_val[key_part] = {'M': {}}
last_val = last_val[key_part]
current_type = list(last_val.keys())[0]
if list_index_update:
# We need to add an item to a list
list_index = int(list_index_update.group(2))
if len(last_val['L']) > list_index:
last_val['L'][list_index] = expression_attribute_values[value]
else:
last_val['L'].append(expression_attribute_values[value])
else:
# We have reference to a nested object but we cant just assign to it
if current_type == dyn_value.type:
last_val[current_type] = dyn_value.value
else:
last_val[dyn_value.type] = dyn_value.value
del last_val[current_type]
elif action == 'ADD': elif action == 'ADD':
key, value = value.split(" ", 1) key, value = value.split(" ", 1)
@ -413,6 +370,20 @@ class Item(BaseModel):
else: else:
raise NotImplementedError('{} update action not yet supported'.format(action)) raise NotImplementedError('{} update action not yet supported'.format(action))
def _get_default(self, value):
if value.startswith('if_not_exists'):
# Function signature
match = re.match(r'.*if_not_exists\s*\((?P<path>.+),\s*(?P<default>.+)\).*', value)
if not match:
raise TypeError
path, value = match.groups()
# If it already exists, get its value so we dont overwrite it
if path in self.attrs:
value = self.attrs[path]
return value
def update_with_attribute_updates(self, attribute_updates): def update_with_attribute_updates(self, attribute_updates):
for attribute_name, update_action in attribute_updates.items(): for attribute_name, update_action in attribute_updates.items():
action = update_action['Action'] action = update_action['Action']
@ -810,7 +781,6 @@ class Table(BaseModel):
else: else:
possible_results = [item for item in list(self.all_items()) if isinstance( possible_results = [item for item in list(self.all_items()) if isinstance(
item, Item) and item.hash_key == hash_key] item, Item) and item.hash_key == hash_key]
if range_comparison: if range_comparison:
if index_name and not index_range_key: if index_name and not index_range_key:
raise ValueError( raise ValueError(

View File

@ -2161,20 +2161,11 @@ def test_condition_expression__attr_doesnt_exist():
client.create_table( client.create_table(
TableName='test', TableName='test',
KeySchema=[{'AttributeName': 'forum_name', 'KeyType': 'HASH'}], KeySchema=[{'AttributeName': 'forum_name', 'KeyType': 'HASH'}],
AttributeDefinitions=[ AttributeDefinitions=[{'AttributeName': 'forum_name', 'AttributeType': 'S'}],
{'AttributeName': 'forum_name', 'AttributeType': 'S'}, ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1})
],
ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1},
)
client.put_item(
TableName='test',
Item={
'forum_name': {'S': 'foo'},
'ttl': {'N': 'bar'},
}
)
client.put_item(TableName='test',
Item={'forum_name': {'S': 'foo'}, 'ttl': {'N': 'bar'}})
def update_if_attr_doesnt_exist(): def update_if_attr_doesnt_exist():
# Test nonexistent top-level attribute. # Test nonexistent top-level attribute.
@ -2261,6 +2252,7 @@ def test_condition_expression__and_order():
} }
) )
@mock_dynamodb2 @mock_dynamodb2
def test_query_gsi_with_range_key(): def test_query_gsi_with_range_key():
dynamodb = boto3.client('dynamodb', region_name='us-east-1') dynamodb = boto3.client('dynamodb', region_name='us-east-1')
@ -2510,13 +2502,15 @@ def test_index_with_unknown_attributes_should_fail():
def test_update_list_index__set_existing_index(): def test_update_list_index__set_existing_index():
table_name = 'test_list_index_access' table_name = 'test_list_index_access'
client = create_table_with_list(table_name) client = create_table_with_list(table_name)
client.put_item(TableName=table_name,
Item={'id': {'S': 'foo'}, 'itemlist': {'L': [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}]}})
client.update_item(TableName=table_name, Key={'id': {'S': 'foo'}}, client.update_item(TableName=table_name, Key={'id': {'S': 'foo'}},
UpdateExpression='set itemlist[1]=:Item', UpdateExpression='set itemlist[1]=:Item',
ExpressionAttributeValues={':Item': {'S': 'bar2_update'}}) ExpressionAttributeValues={':Item': {'S': 'bar2_update'}})
# #
result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo'}})['Item'] result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo'}})['Item']
assert result['id'] == {'S': 'foo'} result['id'].should.equal({'S': 'foo'})
assert result['itemlist'] == {'L': [{'S': 'bar1'}, {'S': 'bar2_update'}, {'S': 'bar3'}]} result['itemlist'].should.equal({'L': [{'S': 'bar1'}, {'S': 'bar2_update'}, {'S': 'bar3'}]})
@mock_dynamodb2 @mock_dynamodb2
@ -2530,14 +2524,16 @@ def test_update_list_index__set_existing_nested_index():
ExpressionAttributeValues={':Item': {'S': 'bar2_update'}}) ExpressionAttributeValues={':Item': {'S': 'bar2_update'}})
# #
result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo2'}})['Item'] result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo2'}})['Item']
assert result['id'] == {'S': 'foo2'} result['id'].should.equal({'S': 'foo2'})
assert result['itemmap']['M']['itemlist']['L'] == [{'S': 'bar1'}, {'S': 'bar2_update'}, {'S': 'bar3'}] result['itemmap']['M']['itemlist']['L'].should.equal([{'S': 'bar1'}, {'S': 'bar2_update'}, {'S': 'bar3'}])
@mock_dynamodb2 @mock_dynamodb2
def test_update_list_index__set_index_out_of_range(): def test_update_list_index__set_index_out_of_range():
table_name = 'test_list_index_access' table_name = 'test_list_index_access'
client = create_table_with_list(table_name) client = create_table_with_list(table_name)
client.put_item(TableName=table_name,
Item={'id': {'S': 'foo'}, 'itemlist': {'L': [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}]}})
client.update_item(TableName=table_name, Key={'id': {'S': 'foo'}}, client.update_item(TableName=table_name, Key={'id': {'S': 'foo'}},
UpdateExpression='set itemlist[10]=:Item', UpdateExpression='set itemlist[10]=:Item',
ExpressionAttributeValues={':Item': {'S': 'bar10'}}) ExpressionAttributeValues={':Item': {'S': 'bar10'}})
@ -2562,6 +2558,25 @@ def test_update_list_index__set_nested_index_out_of_range():
assert result['itemmap']['M']['itemlist']['L'] == [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}, {'S': 'bar10'}] assert result['itemmap']['M']['itemlist']['L'] == [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}, {'S': 'bar10'}]
@mock_dynamodb2
def test_update_list_index__set_double_nested_index():
table_name = 'test_list_index_access'
client = create_table_with_list(table_name)
client.put_item(TableName=table_name,
Item={'id': {'S': 'foo2'},
'itemmap': {'M': {'itemlist': {'L': [{'M': {'foo': {'S': 'bar11'}, 'foos': {'S': 'bar12'}}},
{'M': {'foo': {'S': 'bar21'}, 'foos': {'S': 'bar21'}}}]}}}})
client.update_item(TableName=table_name, Key={'id': {'S': 'foo2'}},
UpdateExpression='set itemmap.itemlist[1].foos=:Item',
ExpressionAttributeValues={':Item': {'S': 'bar22'}})
#
result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo2'}})['Item']
assert result['id'] == {'S': 'foo2'}
len(result['itemmap']['M']['itemlist']['L']).should.equal(2)
result['itemmap']['M']['itemlist']['L'][0].should.equal({'M': {'foo': {'S': 'bar11'}, 'foos': {'S': 'bar12'}}}) # unchanged
result['itemmap']['M']['itemlist']['L'][1].should.equal({'M': {'foo': {'S': 'bar21'}, 'foos': {'S': 'bar22'}}}) # updated
@mock_dynamodb2 @mock_dynamodb2
def test_update_list_index__set_index_of_a_string(): def test_update_list_index__set_index_of_a_string():
table_name = 'test_list_index_access' table_name = 'test_list_index_access'
@ -2578,15 +2593,29 @@ def test_update_list_index__set_index_of_a_string():
'The document path provided in the update expression is invalid for update') 'The document path provided in the update expression is invalid for update')
@mock_dynamodb2
def test_remove_top_level_attribute():
table_name = 'test_remove'
client = create_table_with_list(table_name)
client.put_item(TableName=table_name,
Item={'id': {'S': 'foo'}, 'item': {'S': 'bar'}})
client.update_item(TableName=table_name, Key={'id': {'S': 'foo'}}, UpdateExpression='REMOVE item')
#
result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo'}})['Item']
result.should.equal({'id': {'S': 'foo'}})
@mock_dynamodb2 @mock_dynamodb2
def test_remove_list_index__remove_existing_index(): def test_remove_list_index__remove_existing_index():
table_name = 'test_list_index_access' table_name = 'test_list_index_access'
client = create_table_with_list(table_name) client = create_table_with_list(table_name)
client.put_item(TableName=table_name,
Item={'id': {'S': 'foo'}, 'itemlist': {'L': [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}]}})
client.update_item(TableName=table_name, Key={'id': {'S': 'foo'}}, UpdateExpression='REMOVE itemlist[1]') client.update_item(TableName=table_name, Key={'id': {'S': 'foo'}}, UpdateExpression='REMOVE itemlist[1]')
# #
result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo'}})['Item'] result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo'}})['Item']
assert result['id'] == {'S': 'foo'} result['id'].should.equal({'S': 'foo'})
assert result['itemlist'] == {'L': [{'S': 'bar1'}, {'S': 'bar3'}]} result['itemlist'].should.equal({'L': [{'S': 'bar1'}, {'S': 'bar3'}]})
@mock_dynamodb2 @mock_dynamodb2
@ -2598,8 +2627,8 @@ def test_remove_list_index__remove_existing_nested_index():
client.update_item(TableName=table_name, Key={'id': {'S': 'foo2'}}, UpdateExpression='REMOVE itemmap.itemlist[1]') client.update_item(TableName=table_name, Key={'id': {'S': 'foo2'}}, UpdateExpression='REMOVE itemmap.itemlist[1]')
# #
result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo2'}})['Item'] result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo2'}})['Item']
assert result['id'] == {'S': 'foo2'} result['id'].should.equal({'S': 'foo2'})
assert result['itemmap']['M']['itemlist']['L'] == [{'S': 'bar1'}] result['itemmap']['M']['itemlist']['L'].should.equal([{'S': 'bar1'}])
@mock_dynamodb2 @mock_dynamodb2
@ -2626,6 +2655,8 @@ def test_remove_list_index__remove_existing_double_nested_index():
def test_remove_list_index__remove_index_out_of_range(): def test_remove_list_index__remove_index_out_of_range():
table_name = 'test_list_index_access' table_name = 'test_list_index_access'
client = create_table_with_list(table_name) client = create_table_with_list(table_name)
client.put_item(TableName=table_name,
Item={'id': {'S': 'foo'}, 'itemlist': {'L': [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}]}})
client.update_item(TableName=table_name, Key={'id': {'S': 'foo'}}, UpdateExpression='REMOVE itemlist[10]') client.update_item(TableName=table_name, Key={'id': {'S': 'foo'}}, UpdateExpression='REMOVE itemlist[10]')
# #
result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo'}})['Item'] result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo'}})['Item']
@ -2639,8 +2670,6 @@ def create_table_with_list(table_name):
KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'}], KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'}],
AttributeDefinitions=[{'AttributeName': 'id', 'AttributeType': 'S'}], AttributeDefinitions=[{'AttributeName': 'id', 'AttributeType': 'S'}],
BillingMode='PAY_PER_REQUEST') BillingMode='PAY_PER_REQUEST')
client.put_item(TableName=table_name,
Item={'id': {'S': 'foo'}, 'itemlist': {'L': [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}]}})
return client return client