diff --git a/moto/dynamodb2/models.py b/moto/dynamodb2/models.py index 8e5ecd4b6..5df0c0c28 100644 --- a/moto/dynamodb2/models.py +++ b/moto/dynamodb2/models.py @@ -105,6 +105,26 @@ class Item(object): "Item": included } + def update(self, update_expression): + ACTION_VALUES = ['SET', 'REMOVE'] + + action = None + for value in update_expression.split(): + if value in ACTION_VALUES: + # An action + action = value + continue + else: + # A Real value + value = value.lstrip(":").rstrip(",") + + if action == "REMOVE": + self.attrs.pop(value, None) + elif action == 'SET': + key, value = value.split("=:") + # TODO deal with other types + self.attrs[key] = DynamoType({"S": value}) + class Table(object): @@ -182,7 +202,7 @@ class Table(object): def has_range_key(self): return self.range_key_attr is not None - def get_item(self, hash_key, range_key): + def get_item(self, hash_key, range_key=None): if self.has_range_key and not range_key: raise ValueError("Table has a range key, but no range key was passed into get_item") try: @@ -293,8 +313,11 @@ class DynamoDBBackend(BaseBackend): range_key = DynamoType(keys[table.range_key_attr]) if table.has_range_key else None return hash_key, range_key + def get_table(self, table_name): + return self.tables.get(table_name) + def get_item(self, table_name, keys): - table = self.tables.get(table_name) + table = self.get_table(table_name) if not table: return None hash_key, range_key = self.get_keys_value(table, keys) @@ -322,6 +345,14 @@ class DynamoDBBackend(BaseBackend): return table.scan(scan_filters) + def update_item(self, table_name, key, update_expression): + table = self.get_table(table_name) + + hash_value = DynamoType(key) + item = table.get_item(hash_value) + item.update(update_expression) + return item + def delete_item(self, table_name, keys): table = self.tables.get(table_name) if not table: diff --git a/moto/dynamodb2/responses.py b/moto/dynamodb2/responses.py index 9d6d31f73..4cc064bf6 100644 --- a/moto/dynamodb2/responses.py +++ b/moto/dynamodb2/responses.py @@ -315,3 +315,13 @@ class DynamoHandler(BaseResponse): else: er = 'com.amazonaws.dynamodb.v20120810#ConditionalCheckFailedException' return self.error(er) + + def update_item(self): + name = self.body['TableName'] + key = self.body['Key'] + update_expression = self.body['UpdateExpression'] + item = dynamodb_backend2.update_item(name, key, update_expression) + + item_dict = item.to_json() + item_dict['ConsumedCapacityUnits'] = 0.5 + return dynamo_json_dump(item_dict) diff --git a/tests/test_dynamodb2/test_dynamodb_table_without_range_key.py b/tests/test_dynamodb2/test_dynamodb_table_without_range_key.py index 15348a24a..9c68b0cd4 100644 --- a/tests/test_dynamodb2/test_dynamodb_table_without_range_key.py +++ b/tests/test_dynamodb2/test_dynamodb_table_without_range_key.py @@ -384,3 +384,55 @@ def test_get_special_item(): table.put_item(data=data) returned_item = table.get_item(**{'date-joined': 127549192}) dict(returned_item).should.equal(data) + + +@mock_dynamodb2 +def test_update_item_remove(): + conn = boto.dynamodb2.connect_to_region("us-west-2") + table = Table.create('messages', schema=[ + HashKey('username') + ]) + + data = { + 'username': "steve", + 'SentBy': 'User A', + 'SentTo': 'User B', + } + table.put_item(data=data) + key_map = { + "S": "steve" + } + + # Then remove the SentBy field + conn.update_item("messages", key_map, update_expression="REMOVE :SentBy, :SentTo") + + returned_item = table.get_item(username="steve") + dict(returned_item).should.equal({ + 'username': "steve", + }) + + +@mock_dynamodb2 +def test_update_item_set(): + conn = boto.dynamodb2.connect_to_region("us-west-2") + table = Table.create('messages', schema=[ + HashKey('username') + ]) + + data = { + 'username': "steve", + 'SentBy': 'User A', + } + table.put_item(data=data) + key_map = { + "S": "steve" + } + + conn.update_item("messages", key_map, update_expression="SET foo=:bar, blah=:baz REMOVE :SentBy") + + returned_item = table.get_item(username="steve") + dict(returned_item).should.equal({ + 'username': "steve", + 'foo': 'bar', + 'blah': 'baz', + })