diff --git a/moto/dynamodb2/models.py b/moto/dynamodb2/models.py index e8b5254ef..121f564a4 100644 --- a/moto/dynamodb2/models.py +++ b/moto/dynamodb2/models.py @@ -63,6 +63,16 @@ class DynamoType(object): elif self.is_map(): self.value = dict((k, DynamoType(v)) for k, v in self.value.items()) + def get(self, key): + if not key: + return self + else: + key_head = key.split(".")[0] + key_tail = ".".join(key.split(".")[1:]) + if key_head not in self.value: + self.value[key_head] = DynamoType({"NONE": None}) + return self.value[key_head].get(key_tail) + def set(self, key, new_value, index=None): if index: index = int(index) @@ -388,11 +398,19 @@ class Item(BaseModel): # created with only this value if it doesn't exist yet # New value must be of same set type as previous value elif dyn_value.is_set(): - existing = self.attrs.get(key, DynamoType({dyn_value.type: {}})) - if not existing.same_type(dyn_value): + key_head = key.split(".")[0] + key_tail = ".".join(key.split(".")[1:]) + if key_head not in self.attrs: + self.attrs[key_head] = DynamoType({dyn_value.type: {}}) + existing = self.attrs.get(key_head) + existing = existing.get(key_tail) + if existing.value and not existing.same_type(dyn_value): raise TypeError() - new_set = set(existing.value).union(dyn_value.value) - self.attrs[key] = DynamoType({existing.type: list(new_set)}) + new_set = set(existing.value or []).union(dyn_value.value) + existing.set( + key=None, + new_value=DynamoType({dyn_value.type: list(new_set)}), + ) else: # Number and Sets are the only supported types for ADD raise TypeError @@ -407,12 +425,18 @@ class Item(BaseModel): if not dyn_value.is_set(): raise TypeError - existing = self.attrs.get(key, None) + key_head = key.split(".")[0] + key_tail = ".".join(key.split(".")[1:]) + existing = self.attrs.get(key_head) + existing = existing.get(key_tail) if existing: if not existing.same_type(dyn_value): raise TypeError new_set = set(existing.value).difference(dyn_value.value) - self.attrs[key] = DynamoType({existing.type: list(new_set)}) + existing.set( + key=None, + new_value=DynamoType({existing.type: list(new_set)}), + ) else: raise NotImplementedError( "{} update action not yet supported".format(action) diff --git a/tests/test_dynamodb2/test_dynamodb_table_with_range_key.py b/tests/test_dynamodb2/test_dynamodb_table_with_range_key.py index b12b41ac0..7c7770874 100644 --- a/tests/test_dynamodb2/test_dynamodb_table_with_range_key.py +++ b/tests/test_dynamodb2/test_dynamodb_table_with_range_key.py @@ -1289,6 +1289,16 @@ def test_update_item_add_with_expression(): current_item["str_set"] = current_item["str_set"].union({"item4"}) dict(table.get_item(Key=item_key)["Item"]).should.equal(current_item) + # Update item to add a string value to a non-existing set + # Should just create the set in the background + table.update_item( + Key=item_key, + UpdateExpression="ADD non_existing_str_set :v", + ExpressionAttributeValues={":v": {"item4"}}, + ) + current_item["non_existing_str_set"] = {"item4"} + dict(table.get_item(Key=item_key)["Item"]).should.equal(current_item) + # Update item to add a num value to a num set table.update_item( Key=item_key, @@ -1336,6 +1346,69 @@ def test_update_item_add_with_expression(): ).should.have.raised(ClientError) +@mock_dynamodb2 +def test_update_item_add_with_nested_sets(): + table = _create_table_with_range_key() + + item_key = {"forum_name": "the-key", "subject": "123"} + current_item = { + "forum_name": "the-key", + "subject": "123", + "nested": {"str_set": {"item1", "item2", "item3"}}, + } + + # Put an entry in the DB to play with + table.put_item(Item=current_item) + + # Update item to add a string value to a nested string set + table.update_item( + Key=item_key, + UpdateExpression="ADD nested.str_set :v", + ExpressionAttributeValues={":v": {"item4"}}, + ) + current_item["nested"]["str_set"] = current_item["nested"]["str_set"].union( + {"item4"} + ) + dict(table.get_item(Key=item_key)["Item"]).should.equal(current_item) + + # Update item to add a string value to a non-existing set + # Should just create the set in the background + table.update_item( + Key=item_key, + UpdateExpression="ADD #ns.#ne :v", + ExpressionAttributeNames={"#ns": "nested", "#ne": "non_existing_str_set"}, + ExpressionAttributeValues={":v": {"new_item"}}, + ) + current_item["nested"]["non_existing_str_set"] = {"new_item"} + dict(table.get_item(Key=item_key)["Item"]).should.equal(current_item) + + +@mock_dynamodb2 +def test_update_item_delete_with_nested_sets(): + table = _create_table_with_range_key() + + item_key = {"forum_name": "the-key", "subject": "123"} + current_item = { + "forum_name": "the-key", + "subject": "123", + "nested": {"str_set": {"item1", "item2", "item3"}}, + } + + # Put an entry in the DB to play with + table.put_item(Item=current_item) + + # Update item to add a string value to a nested string set + table.update_item( + Key=item_key, + UpdateExpression="DELETE nested.str_set :v", + ExpressionAttributeValues={":v": {"item3"}}, + ) + current_item["nested"]["str_set"] = current_item["nested"]["str_set"].difference( + {"item3"} + ) + dict(table.get_item(Key=item_key)["Item"]).should.equal(current_item) + + @mock_dynamodb2 def test_update_item_delete_with_expression(): table = _create_table_with_range_key()