diff --git a/moto/dynamodb2/models.py b/moto/dynamodb2/models.py index 6651895dc..d4704bad7 100644 --- a/moto/dynamodb2/models.py +++ b/moto/dynamodb2/models.py @@ -34,14 +34,76 @@ def bytesize(val): return len(str(val).encode('utf-8')) +def attribute_is_list(attr): + """ + Checks if attribute denotes a list, and returns the regular expression if so + :param attr: attr or attr[index] + :return: attr, re or None + """ + list_index_update = re.match('(.+)\\[([0-9]+)\\]', attr) + if list_index_update: + attr = list_index_update.group(1) + return attr, list_index_update.group(2) if list_index_update else None + + class DynamoType(object): """ http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DataModel.html#DataModelDataTypes """ def __init__(self, type_as_dict): - self.type = list(type_as_dict)[0] - self.value = list(type_as_dict.values())[0] + if type(type_as_dict) == DynamoType: + self.type = type_as_dict.type + self.value = type_as_dict.value + else: + self.type = list(type_as_dict)[0] + self.value = list(type_as_dict.values())[0] + if self.is_list(): + self.value = [DynamoType(val) for val in self.value] + elif self.is_map(): + self.value = dict((k, DynamoType(v)) for k, v in self.value.items()) + + def set(self, key, new_value, index=None): + if index: + index = int(index) + if type(self.value) is not list: + raise InvalidUpdateExpression + if index >= len(self.value): + self.value.append(new_value) + # {'L': [DynamoType, ..]} ==> DynamoType.set() + self.value[min(index, len(self.value) - 1)].set(key, new_value) + else: + attr = (key or '').split('.').pop(0) + attr, list_index = attribute_is_list(attr) + if not key: + # {'S': value} ==> {'S': new_value} + self.value = new_value.value + else: + if attr not in self.value: # nonexistingattribute + type_of_new_attr = 'M' if '.' in key else new_value.type + self.value[attr] = DynamoType({type_of_new_attr: {}}) + # {'M': {'foo': DynamoType}} ==> DynamoType.set(new_value) + self.value[attr].set('.'.join(key.split('.')[1:]), new_value, list_index) + + def delete(self, key, index=None): + if index: + if not key: + if int(index) < len(self.value): + del self.value[int(index)] + elif '.' in key: + self.value[int(index)].delete('.'.join(key.split('.')[1:])) + else: + self.value[int(index)].delete(key) + else: + attr = key.split('.')[0] + attr, list_index = attribute_is_list(attr) + + if list_index: + self.value[attr].delete('.'.join(key.split('.')[1:]), list_index) + elif '.' in key: + self.value[attr].delete('.'.join(key.split('.')[1:])) + else: + self.value.pop(key) def __hash__(self): return hash((self.type, self.value)) @@ -98,7 +160,7 @@ class DynamoType(object): if isinstance(key, int) and self.is_list(): idx = key - if idx >= 0 and idx < len(self.value): + if 0 <= idx < len(self.value): return DynamoType(self.value[idx]) return None @@ -110,7 +172,7 @@ class DynamoType(object): sub_type = self.type[0] value_size = sum([DynamoType({sub_type: v}).size() for v in self.value]) elif self.is_list(): - value_size = sum([DynamoType(v).size() for v in self.value]) + value_size = sum([v.size() for v in self.value]) elif self.is_map(): value_size = sum([bytesize(k) + DynamoType(v).size() for k, v in self.value.items()]) elif type(self.value) == bool: @@ -162,22 +224,6 @@ class LimitedSizeDict(dict): raise ItemSizeTooLarge super(LimitedSizeDict, self).__setitem__(key, value) - def update(self, *args, **kwargs): - if args: - if len(args) > 1: - raise TypeError("update expected at most 1 arguments, " - "got %d" % len(args)) - other = dict(args[0]) - for key in other: - self[key] = other[key] - for key in kwargs: - self[key] = kwargs[key] - - def setdefault(self, key, value=None): - if key not in self: - self[key] = value - return self[key] - class Item(BaseModel): @@ -236,72 +282,26 @@ class Item(BaseModel): if action == "REMOVE": key = value + attr, list_index = attribute_is_list(key.split('.')[0]) if '.' not in key: - list_index_update = re.match('(.+)\\[([0-9]+)\\]', key) - if list_index_update: - # We need to remove an item from a list (REMOVE listattr[0]) - key_attr = self.attrs[list_index_update.group(1)] - list_index = int(list_index_update.group(2)) - if key_attr.is_list(): - if len(key_attr.value) > list_index: - del key_attr.value[list_index] + if list_index: + new_list = DynamoType(self.attrs[attr]) + new_list.delete(None, list_index) + self.attrs[attr] = new_list else: self.attrs.pop(value, None) else: # Handle nested dict updates - key_parts = key.split('.') - attr = key_parts.pop(0) - if attr not in self.attrs: - raise ValueError - - last_val = self.attrs[attr].value - for key_part in key_parts[:-1]: - list_index_update = re.match('(.+)\\[([0-9]+)\\]', key_part) - if list_index_update: - key_part = list_index_update.group(1) # listattr[1] ==> listattr - # Hack but it'll do, traverses into a dict - last_val_type = list(last_val.keys()) - if last_val_type and last_val_type[0] == 'M': - last_val = last_val['M'] - - if key_part not in last_val: - last_val[key_part] = {'M': {}} - - last_val = last_val[key_part] - if list_index_update: - last_val = last_val['L'][int(list_index_update.group(2))] - - last_val_type = list(last_val.keys()) - list_index_update = re.match('(.+)\\[([0-9]+)\\]', key_parts[-1]) - if list_index_update: - # We need to remove an item from a list (REMOVE attr.listattr[0]) - key_part = list_index_update.group(1) # listattr[1] ==> listattr - list_to_update = last_val[key_part]['L'] - index_to_remove = int(list_index_update.group(2)) - if index_to_remove < len(list_to_update): - del list_to_update[index_to_remove] - else: - if last_val_type and last_val_type[0] == 'M': - last_val['M'].pop(key_parts[-1], None) - else: - last_val.pop(key_parts[-1], None) + self.attrs[attr].delete('.'.join(key.split('.')[1:])) elif action == 'SET': key, value = value.split("=", 1) key = key.strip() value = value.strip() - # If not exists, changes value to a default if needed, else its the same as it was - if value.startswith('if_not_exists'): - # Function signature - match = re.match(r'.*if_not_exists\s*\((?P.+),\s*(?P.+)\).*', value) - if not match: - raise TypeError - - path, value = match.groups() - - # If it already exists, get its value so we dont overwrite it - if path in self.attrs: - value = self.attrs[path] + # check whether key is a list + attr, list_index = attribute_is_list(key.split('.')[0]) + # If value not exists, changes value to a default if needed, else its the same as it was + value = self._get_default(value) if type(value) != DynamoType: if value in expression_attribute_values: @@ -311,55 +311,12 @@ class Item(BaseModel): else: dyn_value = value - if '.' not in key: - list_index_update = re.match('(.+)\\[([0-9]+)\\]', key) - if list_index_update: - key_attr = self.attrs[list_index_update.group(1)] - list_index = int(list_index_update.group(2)) - if key_attr.is_list(): - if len(key_attr.value) > list_index: - key_attr.value[list_index] = expression_attribute_values[value] - else: - key_attr.value.append(expression_attribute_values[value]) - else: - raise InvalidUpdateExpression - else: - self.attrs[key] = dyn_value + if '.' in key and attr not in self.attrs: + raise ValueError # Setting nested attr not allowed if first attr does not exist yet + elif attr not in self.attrs: + self.attrs[attr] = dyn_value # set new top-level attribute else: - # Handle nested dict updates - key_parts = key.split('.') - attr = key_parts.pop(0) - if attr not in self.attrs: - raise ValueError - last_val = self.attrs[attr].value - for key_part in key_parts: - list_index_update = re.match('(.+)\\[([0-9]+)\\]', key_part) - if list_index_update: - key_part = list_index_update.group(1) # listattr[1] ==> listattr - # Hack but it'll do, traverses into a dict - last_val_type = list(last_val.keys()) - if last_val_type and last_val_type[0] == 'M': - last_val = last_val['M'] - - if key_part not in last_val: - last_val[key_part] = {'M': {}} - last_val = last_val[key_part] - - current_type = list(last_val.keys())[0] - if list_index_update: - # We need to add an item to a list - list_index = int(list_index_update.group(2)) - if len(last_val['L']) > list_index: - last_val['L'][list_index] = expression_attribute_values[value] - else: - last_val['L'].append(expression_attribute_values[value]) - else: - # We have reference to a nested object but we cant just assign to it - if current_type == dyn_value.type: - last_val[current_type] = dyn_value.value - else: - last_val[dyn_value.type] = dyn_value.value - del last_val[current_type] + self.attrs[attr].set('.'.join(key.split('.')[1:]), dyn_value, list_index) # set value recursively elif action == 'ADD': key, value = value.split(" ", 1) @@ -413,6 +370,20 @@ class Item(BaseModel): else: raise NotImplementedError('{} update action not yet supported'.format(action)) + def _get_default(self, value): + if value.startswith('if_not_exists'): + # Function signature + match = re.match(r'.*if_not_exists\s*\((?P.+),\s*(?P.+)\).*', value) + if not match: + raise TypeError + + path, value = match.groups() + + # If it already exists, get its value so we dont overwrite it + if path in self.attrs: + value = self.attrs[path] + return value + def update_with_attribute_updates(self, attribute_updates): for attribute_name, update_action in attribute_updates.items(): action = update_action['Action'] @@ -810,7 +781,6 @@ class Table(BaseModel): else: possible_results = [item for item in list(self.all_items()) if isinstance( item, Item) and item.hash_key == hash_key] - if range_comparison: if index_name and not index_range_key: raise ValueError( diff --git a/tests/test_dynamodb2/test_dynamodb.py b/tests/test_dynamodb2/test_dynamodb.py index 3d9914f14..08ba6428c 100644 --- a/tests/test_dynamodb2/test_dynamodb.py +++ b/tests/test_dynamodb2/test_dynamodb.py @@ -2161,20 +2161,11 @@ def test_condition_expression__attr_doesnt_exist(): client.create_table( TableName='test', KeySchema=[{'AttributeName': 'forum_name', 'KeyType': 'HASH'}], - AttributeDefinitions=[ - {'AttributeName': 'forum_name', 'AttributeType': 'S'}, - ], - ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1}, - ) - - client.put_item( - TableName='test', - Item={ - 'forum_name': {'S': 'foo'}, - 'ttl': {'N': 'bar'}, - } - ) + AttributeDefinitions=[{'AttributeName': 'forum_name', 'AttributeType': 'S'}], + ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1}) + client.put_item(TableName='test', + Item={'forum_name': {'S': 'foo'}, 'ttl': {'N': 'bar'}}) def update_if_attr_doesnt_exist(): # Test nonexistent top-level attribute. @@ -2261,6 +2252,7 @@ def test_condition_expression__and_order(): } ) + @mock_dynamodb2 def test_query_gsi_with_range_key(): dynamodb = boto3.client('dynamodb', region_name='us-east-1') @@ -2510,13 +2502,15 @@ def test_index_with_unknown_attributes_should_fail(): def test_update_list_index__set_existing_index(): table_name = 'test_list_index_access' client = create_table_with_list(table_name) + client.put_item(TableName=table_name, + Item={'id': {'S': 'foo'}, 'itemlist': {'L': [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}]}}) client.update_item(TableName=table_name, Key={'id': {'S': 'foo'}}, UpdateExpression='set itemlist[1]=:Item', ExpressionAttributeValues={':Item': {'S': 'bar2_update'}}) # result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo'}})['Item'] - assert result['id'] == {'S': 'foo'} - assert result['itemlist'] == {'L': [{'S': 'bar1'}, {'S': 'bar2_update'}, {'S': 'bar3'}]} + result['id'].should.equal({'S': 'foo'}) + result['itemlist'].should.equal({'L': [{'S': 'bar1'}, {'S': 'bar2_update'}, {'S': 'bar3'}]}) @mock_dynamodb2 @@ -2530,14 +2524,16 @@ def test_update_list_index__set_existing_nested_index(): ExpressionAttributeValues={':Item': {'S': 'bar2_update'}}) # result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo2'}})['Item'] - assert result['id'] == {'S': 'foo2'} - assert result['itemmap']['M']['itemlist']['L'] == [{'S': 'bar1'}, {'S': 'bar2_update'}, {'S': 'bar3'}] + result['id'].should.equal({'S': 'foo2'}) + result['itemmap']['M']['itemlist']['L'].should.equal([{'S': 'bar1'}, {'S': 'bar2_update'}, {'S': 'bar3'}]) @mock_dynamodb2 def test_update_list_index__set_index_out_of_range(): table_name = 'test_list_index_access' client = create_table_with_list(table_name) + client.put_item(TableName=table_name, + Item={'id': {'S': 'foo'}, 'itemlist': {'L': [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}]}}) client.update_item(TableName=table_name, Key={'id': {'S': 'foo'}}, UpdateExpression='set itemlist[10]=:Item', ExpressionAttributeValues={':Item': {'S': 'bar10'}}) @@ -2562,6 +2558,25 @@ def test_update_list_index__set_nested_index_out_of_range(): assert result['itemmap']['M']['itemlist']['L'] == [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}, {'S': 'bar10'}] +@mock_dynamodb2 +def test_update_list_index__set_double_nested_index(): + table_name = 'test_list_index_access' + client = create_table_with_list(table_name) + client.put_item(TableName=table_name, + Item={'id': {'S': 'foo2'}, + 'itemmap': {'M': {'itemlist': {'L': [{'M': {'foo': {'S': 'bar11'}, 'foos': {'S': 'bar12'}}}, + {'M': {'foo': {'S': 'bar21'}, 'foos': {'S': 'bar21'}}}]}}}}) + client.update_item(TableName=table_name, Key={'id': {'S': 'foo2'}}, + UpdateExpression='set itemmap.itemlist[1].foos=:Item', + ExpressionAttributeValues={':Item': {'S': 'bar22'}}) + # + result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo2'}})['Item'] + assert result['id'] == {'S': 'foo2'} + len(result['itemmap']['M']['itemlist']['L']).should.equal(2) + result['itemmap']['M']['itemlist']['L'][0].should.equal({'M': {'foo': {'S': 'bar11'}, 'foos': {'S': 'bar12'}}}) # unchanged + result['itemmap']['M']['itemlist']['L'][1].should.equal({'M': {'foo': {'S': 'bar21'}, 'foos': {'S': 'bar22'}}}) # updated + + @mock_dynamodb2 def test_update_list_index__set_index_of_a_string(): table_name = 'test_list_index_access' @@ -2578,15 +2593,29 @@ def test_update_list_index__set_index_of_a_string(): 'The document path provided in the update expression is invalid for update') +@mock_dynamodb2 +def test_remove_top_level_attribute(): + table_name = 'test_remove' + client = create_table_with_list(table_name) + client.put_item(TableName=table_name, + Item={'id': {'S': 'foo'}, 'item': {'S': 'bar'}}) + client.update_item(TableName=table_name, Key={'id': {'S': 'foo'}}, UpdateExpression='REMOVE item') + # + result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo'}})['Item'] + result.should.equal({'id': {'S': 'foo'}}) + + @mock_dynamodb2 def test_remove_list_index__remove_existing_index(): table_name = 'test_list_index_access' client = create_table_with_list(table_name) + client.put_item(TableName=table_name, + Item={'id': {'S': 'foo'}, 'itemlist': {'L': [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}]}}) client.update_item(TableName=table_name, Key={'id': {'S': 'foo'}}, UpdateExpression='REMOVE itemlist[1]') # result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo'}})['Item'] - assert result['id'] == {'S': 'foo'} - assert result['itemlist'] == {'L': [{'S': 'bar1'}, {'S': 'bar3'}]} + result['id'].should.equal({'S': 'foo'}) + result['itemlist'].should.equal({'L': [{'S': 'bar1'}, {'S': 'bar3'}]}) @mock_dynamodb2 @@ -2598,8 +2627,8 @@ def test_remove_list_index__remove_existing_nested_index(): client.update_item(TableName=table_name, Key={'id': {'S': 'foo2'}}, UpdateExpression='REMOVE itemmap.itemlist[1]') # result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo2'}})['Item'] - assert result['id'] == {'S': 'foo2'} - assert result['itemmap']['M']['itemlist']['L'] == [{'S': 'bar1'}] + result['id'].should.equal({'S': 'foo2'}) + result['itemmap']['M']['itemlist']['L'].should.equal([{'S': 'bar1'}]) @mock_dynamodb2 @@ -2626,6 +2655,8 @@ def test_remove_list_index__remove_existing_double_nested_index(): def test_remove_list_index__remove_index_out_of_range(): table_name = 'test_list_index_access' client = create_table_with_list(table_name) + client.put_item(TableName=table_name, + Item={'id': {'S': 'foo'}, 'itemlist': {'L': [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}]}}) client.update_item(TableName=table_name, Key={'id': {'S': 'foo'}}, UpdateExpression='REMOVE itemlist[10]') # result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo'}})['Item'] @@ -2639,8 +2670,6 @@ def create_table_with_list(table_name): KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'}], AttributeDefinitions=[{'AttributeName': 'id', 'AttributeType': 'S'}], BillingMode='PAY_PER_REQUEST') - client.put_item(TableName=table_name, - Item={'id': {'S': 'foo'}, 'itemlist': {'L': [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}]}}) return client