diff --git a/moto/dynamodb2/models.py b/moto/dynamodb2/models.py index 56a8fb4c0..99c965612 100644 --- a/moto/dynamodb2/models.py +++ b/moto/dynamodb2/models.py @@ -191,7 +191,7 @@ class Table(object): keys.append(key['AttributeName']) return keys - def put_item(self, item_attrs): + def put_item(self, item_attrs, expected = None, overwrite = False): hash_value = DynamoType(item_attrs.get(self.hash_key_attr)) if self.has_range_key: range_value = DynamoType(item_attrs.get(self.range_key_attr)) @@ -200,6 +200,35 @@ class Table(object): item = Item(hash_value, self.hash_key_type, range_value, self.range_key_type, item_attrs) + if not overwrite: + if expected is None: + expected = {} + lookup_range_value = range_value + else: + expected_range_value = expected.get(self.range_key_attr, {}).get("Value") + if(expected_range_value is None): + lookup_range_value = range_value + else: + lookup_range_value = DynamoType(expected_range_value) + + current = self.get_item(hash_value, lookup_range_value) + + if current is None: + current_attr = {} + elif hasattr(current,'attrs'): + current_attr = current.attrs + else: + current_attr = current + + for key, val in expected.items(): + if 'Exists' in val and val['Exists'] == False: + if key in current_attr: + raise ValueError("The conditional request failed") + elif key not in current_attr: + raise ValueError("The conditional request failed") + elif DynamoType(val['Value']).value != current_attr[key].value: + raise ValueError("The conditional request failed") + if range_value: self.items[hash_value][range_value] = item else: @@ -317,11 +346,11 @@ class DynamoDBBackend(BaseBackend): table.throughput = throughput return table - def put_item(self, table_name, item_attrs): + def put_item(self, table_name, item_attrs, expected = None, overwrite = False): table = self.tables.get(table_name) if not table: return None - return table.put_item(item_attrs) + return table.put_item(item_attrs, expected, overwrite) def get_table_keys_name(self, table_name, keys): """ diff --git a/moto/dynamodb2/responses.py b/moto/dynamodb2/responses.py index 8cee08ebe..5713910d9 100644 --- a/moto/dynamodb2/responses.py +++ b/moto/dynamodb2/responses.py @@ -134,7 +134,17 @@ class DynamoHandler(BaseResponse): def put_item(self): name = self.body['TableName'] item = self.body['Item'] - result = dynamodb_backend2.put_item(name, item) + overwrite = 'Expected' not in self.body + if not overwrite: + expected = self.body['Expected'] + else: + expected = None + + try: + result = dynamodb_backend2.put_item(name, item, expected, overwrite) + except Exception: + er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException' + return self.error(er) if result: item_dict = result.to_json() diff --git a/tests/test_dynamodb2/test_dynamodb_table_with_range_key.py b/tests/test_dynamodb2/test_dynamodb_table_with_range_key.py index a08b4b521..e7ae60e7e 100644 --- a/tests/test_dynamodb2/test_dynamodb_table_with_range_key.py +++ b/tests/test_dynamodb2/test_dynamodb_table_with_range_key.py @@ -10,6 +10,7 @@ try: from boto.dynamodb2.fields import GlobalAllIndex, HashKey, RangeKey from boto.dynamodb2.table import Item, Table from boto.dynamodb2.exceptions import ValidationException + from boto.dynamodb2.exceptions import ConditionalCheckFailedException except ImportError: pass @@ -553,3 +554,52 @@ def test_lookup(): message = table.lookup(hash_key, range_key) message.get('test_hash').should.equal(Decimal(hash_key)) message.get('test_range').should.equal(Decimal(range_key)) + + +@mock_dynamodb2 +def test_failed_overwrite(): + from decimal import Decimal + table = Table.create('messages', schema=[ + HashKey('id'), + RangeKey('range'), + ], throughput={ + 'read': 7, + 'write': 3, + }) + + data1 = {'id': '123', 'range': 'abc', 'data':'678'} + table.put_item(data=data1) + + data2 = {'id': '123', 'range': 'abc', 'data':'345'} + table.put_item(data=data2, overwrite = True) + + data3 = {'id': '123', 'range': 'abc', 'data':'812'} + table.put_item.when.called_with(data=data3).should.throw(ConditionalCheckFailedException) + + returned_item = table.lookup('123', 'abc') + dict(returned_item).should.equal(data2) + + data4 = {'id': '123', 'range': 'ghi', 'data':812} + table.put_item(data=data4) + + returned_item = table.lookup('123', 'ghi') + dict(returned_item).should.equal(data4) + + +@mock_dynamodb2 +def test_conflicting_writes(): + table = Table.create('messages', schema=[ + HashKey('id'), + RangeKey('range'), + ]) + + item_data = {'id': '123', 'range':'abc', 'data':'678'} + item1 = Item(table, item_data) + item2 = Item(table, item_data) + item1.save() + + item1['data'] = '579' + item2['data'] = '912' + + item1.save() + item2.save.when.called_with().should.throw(ConditionalCheckFailedException) diff --git a/tests/test_dynamodb2/test_dynamodb_table_without_range_key.py b/tests/test_dynamodb2/test_dynamodb_table_without_range_key.py index 87e56f8d1..0f845759e 100644 --- a/tests/test_dynamodb2/test_dynamodb_table_without_range_key.py +++ b/tests/test_dynamodb2/test_dynamodb_table_without_range_key.py @@ -10,6 +10,7 @@ try: from boto.dynamodb2.fields import HashKey from boto.dynamodb2.table import Table from boto.dynamodb2.table import Item + from boto.dynamodb2.exceptions import ConditionalCheckFailedException except ImportError: pass @@ -437,3 +438,50 @@ def test_update_item_set(): 'foo': 'bar', 'blah': 'baz', }) + + +@mock_dynamodb2 +def test_failed_overwrite(): + from decimal import Decimal + table = Table.create('messages', schema=[ + HashKey('id'), + ], throughput={ + 'read': 7, + 'write': 3, + }) + + data1 = {'id': '123', 'data':'678'} + table.put_item(data=data1) + + data2 = {'id': '123', 'data':'345'} + table.put_item(data=data2, overwrite = True) + + data3 = {'id': '123', 'data':'812'} + table.put_item.when.called_with(data=data3).should.throw(ConditionalCheckFailedException) + + returned_item = table.lookup('123') + dict(returned_item).should.equal(data2) + + data4 = {'id': '124', 'data':812} + table.put_item(data=data4) + + returned_item = table.lookup('124') + dict(returned_item).should.equal(data4) + + +@mock_dynamodb2 +def test_conflicting_writes(): + table = Table.create('messages', schema=[ + HashKey('id'), + ]) + + item_data = {'id': '123', 'data':'678'} + item1 = Item(table, item_data) + item2 = Item(table, item_data) + item1.save() + + item1['data'] = '579' + item2['data'] = '912' + + item1.save() + item2.save.when.called_with().should.throw(ConditionalCheckFailedException) \ No newline at end of file