diff --git a/moto/dynamodb2/models.py b/moto/dynamodb2/models.py index 43059265c..cdefa0f58 100644 --- a/moto/dynamodb2/models.py +++ b/moto/dynamodb2/models.py @@ -1,6 +1,7 @@ from __future__ import unicode_literals from collections import defaultdict import datetime +import decimal import json from moto.compat import OrderedDict @@ -142,6 +143,16 @@ class Item(object): del self.attrs[attribute_name] else: self.attrs[attribute_name] = DynamoType({"S": new_value}) + elif action == 'ADD': + if set(update_action['Value'].keys()) == set(['N']): + existing = self.attrs.get(attribute_name, DynamoType({"N": '0'})) + self.attrs[attribute_name] = DynamoType({"N": str( + decimal.Decimal(existing.value) + + decimal.Decimal(new_value) + )}) + else: + # TODO: implement other data types + raise NotImplementedError('ADD not supported for %s' % ', '.join(update_action['Value'].keys())) class Table(object): @@ -202,18 +213,22 @@ class Table(object): def hash_key_names(self): keys = [self.hash_key_attr] for index in self.global_indexes: + hash_key = None for key in index['KeySchema']: if key['KeyType'] == 'HASH': - keys.append(key['AttributeName']) + hash_key = key['AttributeName'] + keys.append(hash_key) return keys @property def range_key_names(self): keys = [self.range_key_attr] for index in self.global_indexes: + range_key = None for key in index['KeySchema']: if key['KeyType'] == 'RANGE': - keys.append(key['AttributeName']) + range_key = keys.append(key['AttributeName']) + keys.append(range_key) return keys def put_item(self, item_attrs, expected=None, overwrite=False): @@ -276,8 +291,11 @@ class Table(object): try: if range_key: return self.items[hash_key][range_key] - else: + + if hash_key in self.items: return self.items[hash_key] + + raise KeyError except KeyError: return None @@ -462,13 +480,15 @@ class DynamoDBBackend(BaseBackend): if not table: return None, None else: - hash_key = range_key = None - for key in keys: - if key in table.hash_key_names: - hash_key = key - elif key in table.range_key_names: - range_key = key - return hash_key, range_key + if len(keys) == 1: + for key in keys: + if key in table.hash_key_names: + return key, None + + for potential_hash, potential_range in zip(table.hash_key_names, table.range_key_names): + if set([potential_hash, potential_range]) == set(keys): + return potential_hash, potential_range + return None, None def get_keys_value(self, table, keys): if table.hash_key_attr not in keys or (table.has_range_key and table.range_key_attr not in keys): @@ -526,6 +546,23 @@ class DynamoDBBackend(BaseBackend): range_value = None item = table.get_item(hash_value, range_value) + # Update does not fail on new items, so create one + if item is None: + data = { + table.hash_key_attr: { + hash_value.type: hash_value.value, + }, + } + if range_value: + data.update({ + table.range_key_attr: { + range_value.type: range_value.value, + } + }) + + table.put_item(data) + item = table.get_item(hash_value, range_value) + if update_expression: item.update(update_expression) else: diff --git a/moto/dynamodb2/responses.py b/moto/dynamodb2/responses.py index 08a03c08c..0770acedc 100644 --- a/moto/dynamodb2/responses.py +++ b/moto/dynamodb2/responses.py @@ -402,8 +402,12 @@ class DynamoHandler(BaseResponse): key = self.body['Key'] update_expression = self.body.get('UpdateExpression') attribute_updates = self.body.get('AttributeUpdates') + existing_item = dynamodb_backend2.get_item(name, key) item = dynamodb_backend2.update_item(name, key, update_expression, attribute_updates) item_dict = item.to_json() item_dict['ConsumedCapacityUnits'] = 0.5 + if not existing_item: + item_dict['Attributes'] = {} + return dynamo_json_dump(item_dict) diff --git a/tests/test_dynamodb2/test_dynamodb_table_with_range_key.py b/tests/test_dynamodb2/test_dynamodb_table_with_range_key.py index bcda68f6d..d223be76c 100644 --- a/tests/test_dynamodb2/test_dynamodb_table_with_range_key.py +++ b/tests/test_dynamodb2/test_dynamodb_table_with_range_key.py @@ -941,6 +941,98 @@ def test_update_item_range_key_set(): }) + +@mock_dynamodb2 +def test_update_item_does_not_exist_is_created(): + table = _create_table_with_range_key() + + item_key = {'forum_name': 'the-key', 'subject': '123'} + result = table.update_item( + Key=item_key, + AttributeUpdates={ + 'username': { + 'Action': u'PUT', + 'Value': 'johndoe2' + }, + 'created': { + 'Action': u'PUT', + 'Value': Decimal('4'), + }, + 'mapfield': { + 'Action': u'PUT', + 'Value': {'key': 'value'}, + } + }, + ReturnValues='ALL_OLD', + ) + + assert not result.get('Attributes') + + returned_item = dict((k, str(v) if isinstance(v, Decimal) else v) + for k, v in table.get_item(Key=item_key)['Item'].items()) + dict(returned_item).should.equal({ + 'username': "johndoe2", + 'forum_name': 'the-key', + 'subject': '123', + 'created': '4', + 'mapfield': {'key': 'value'}, + }) + + +@mock_dynamodb2 +def test_update_item_add_value(): + table = _create_table_with_range_key() + + table.put_item(Item={ + 'forum_name': 'the-key', + 'subject': '123', + 'numeric_field': Decimal('-1'), + }) + + item_key = {'forum_name': 'the-key', 'subject': '123'} + table.update_item( + Key=item_key, + AttributeUpdates={ + 'numeric_field': { + 'Action': u'ADD', + 'Value': Decimal('2'), + }, + }, + ) + + returned_item = dict((k, str(v) if isinstance(v, Decimal) else v) + for k, v in table.get_item(Key=item_key)['Item'].items()) + dict(returned_item).should.equal({ + 'numeric_field': '1', + 'forum_name': 'the-key', + 'subject': '123', + }) + + +@mock_dynamodb2 +def test_update_item_add_value_does_not_exist_is_created(): + table = _create_table_with_range_key() + + item_key = {'forum_name': 'the-key', 'subject': '123'} + table.update_item( + Key=item_key, + AttributeUpdates={ + 'numeric_field': { + 'Action': u'ADD', + 'Value': Decimal('2'), + }, + }, + ) + + returned_item = dict((k, str(v) if isinstance(v, Decimal) else v) + for k, v in table.get_item(Key=item_key)['Item'].items()) + dict(returned_item).should.equal({ + 'numeric_field': '2', + 'forum_name': 'the-key', + 'subject': '123', + }) + + @mock_dynamodb2 def test_boto3_query_gsi_range_comparison(): table = _create_table_with_range_key() diff --git a/tests/test_dynamodb2/test_dynamodb_table_without_range_key.py b/tests/test_dynamodb2/test_dynamodb_table_without_range_key.py index 36aac9a87..676694932 100644 --- a/tests/test_dynamodb2/test_dynamodb_table_without_range_key.py +++ b/tests/test_dynamodb2/test_dynamodb_table_without_range_key.py @@ -431,7 +431,7 @@ def test_update_item_remove(): } table.put_item(data=data) key_map = { - "S": "steve" + 'username': {"S": "steve"} } # Then remove the SentBy field @@ -456,7 +456,7 @@ def test_update_item_set(): } table.put_item(data=data) key_map = { - "S": "steve" + 'username': {"S": "steve"} } conn.update_item("messages", key_map, update_expression="SET foo=:bar, blah=:baz REMOVE :SentBy")