diff --git a/moto/dynamodb2/models.py b/moto/dynamodb2/models.py index bdc5743c2..5915d6eea 100644 --- a/moto/dynamodb2/models.py +++ b/moto/dynamodb2/models.py @@ -57,7 +57,7 @@ class DynamoType(object): @property def cast_value(self): - if self.type == 'N': + if self.is_number(): try: return int(self.value) except ValueError: @@ -76,6 +76,15 @@ class DynamoType(object): comparison_func = get_comparison_func(range_comparison) return comparison_func(self.cast_value, *range_values) + def is_number(self): + return self.type == 'N' + + def is_set(self): + return self.type == 'SS' or self.type == 'NS' or self.type == 'BS' + + def same_type(self, other): + return self.type == other.type + class Item(BaseModel): @@ -140,6 +149,55 @@ class Item(BaseModel): self.attrs[key] = DynamoType(expression_attribute_values[value]) else: self.attrs[key] = DynamoType({"S": value}) + elif action == 'ADD': + key, value = value.split(" ", 1) + key = key.strip() + value_str = value.strip() + if value_str in expression_attribute_values: + dyn_value = DynamoType(expression_attribute_values[value]) + else: + raise TypeError + + # Handle adding numbers - value gets added to existing value, + # or added to 0 if it doesn't exist yet + if dyn_value.is_number(): + existing = self.attrs.get(key, DynamoType({"N": '0'})) + if not existing.same_type(dyn_value): + raise TypeError() + self.attrs[key] = DynamoType({"N": str( + decimal.Decimal(existing.value) + + decimal.Decimal(dyn_value.value) + )}) + + # Handle adding sets - value is added to the set, or set is + # created with only this value if it doesn't exist yet + # New value must be of same set type as previous value + elif dyn_value.is_set(): + existing = self.attrs.get(key, DynamoType({dyn_value.type: {}})) + if not existing.same_type(dyn_value): + raise TypeError() + new_set = set(existing.value).union(dyn_value.value) + self.attrs[key] = DynamoType({existing.type: list(new_set)}) + else: # Number and Sets are the only supported types for ADD + raise TypeError + + elif action == 'DELETE': + key, value = value.split(" ", 1) + key = key.strip() + value_str = value.strip() + if value_str in expression_attribute_values: + dyn_value = DynamoType(expression_attribute_values[value]) + else: + raise TypeError + + if not dyn_value.is_set(): + raise TypeError + existing = self.attrs.get(key, None) + if existing: + if not existing.same_type(dyn_value): + raise TypeError + new_set = set(existing.value).difference(dyn_value.value) + self.attrs[key] = DynamoType({existing.type: list(new_set)}) else: raise NotImplementedError('{} update action not yet supported'.format(action)) diff --git a/moto/dynamodb2/responses.py b/moto/dynamodb2/responses.py index 586a1db7b..c3cb4ef72 100644 --- a/moto/dynamodb2/responses.py +++ b/moto/dynamodb2/responses.py @@ -497,6 +497,9 @@ class DynamoHandler(BaseResponse): except ValueError: er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException' return self.error(er) + except TypeError: + er = 'com.amazonaws.dynamodb.v20111205#ValidationException' + return self.error(er) item_dict = item.to_json() item_dict['ConsumedCapacityUnits'] = 0.5 diff --git a/tests/test_dynamodb2/test_dynamodb_table_with_range_key.py b/tests/test_dynamodb2/test_dynamodb_table_with_range_key.py index 93dc5b3ef..a9ab298b7 100644 --- a/tests/test_dynamodb2/test_dynamodb_table_with_range_key.py +++ b/tests/test_dynamodb2/test_dynamodb_table_with_range_key.py @@ -5,6 +5,7 @@ from decimal import Decimal import boto import boto3 from boto3.dynamodb.conditions import Key +from botocore.exceptions import ClientError import sure # noqa from freezegun import freeze_time from moto import mock_dynamodb2, mock_dynamodb2_deprecated @@ -1190,6 +1191,14 @@ def _create_table_with_range_key(): 'AttributeName': 'subject', 'AttributeType': 'S' }, + { + 'AttributeName': 'username', + 'AttributeType': 'S' + }, + { + 'AttributeName': 'created', + 'AttributeType': 'N' + } ], ProvisionedThroughput={ 'ReadCapacityUnits': 5, @@ -1392,6 +1401,155 @@ def test_update_item_with_expression(): 'subject': '123', }) +@mock_dynamodb2 +def test_update_item_add_with_expression(): + table = _create_table_with_range_key() + + item_key = {'forum_name': 'the-key', 'subject': '123'} + current_item = { + 'forum_name': 'the-key', + 'subject': '123', + 'str_set': {'item1', 'item2', 'item3'}, + 'num_set': {1, 2, 3}, + 'num_val': 6 + } + + # Put an entry in the DB to play with + table.put_item(Item=current_item) + + # Update item to add a string value to a string set + table.update_item( + Key=item_key, + UpdateExpression='ADD str_set :v', + ExpressionAttributeValues={ + ':v': {'item4'} + } + ) + current_item['str_set'] = current_item['str_set'].union({'item4'}) + dict(table.get_item(Key=item_key)['Item']).should.equal(current_item) + + # Update item to add a num value to a num set + table.update_item( + Key=item_key, + UpdateExpression='ADD num_set :v', + ExpressionAttributeValues={ + ':v': {6} + } + ) + current_item['num_set'] = current_item['num_set'].union({6}) + dict(table.get_item(Key=item_key)['Item']).should.equal(current_item) + + # Update item to add a value to a number value + table.update_item( + Key=item_key, + UpdateExpression='ADD num_val :v', + ExpressionAttributeValues={ + ':v': 20 + } + ) + current_item['num_val'] = current_item['num_val'] + 20 + dict(table.get_item(Key=item_key)['Item']).should.equal(current_item) + + # Attempt to add a number value to a string set, should raise Client Error + table.update_item.when.called_with( + Key=item_key, + UpdateExpression='ADD str_set :v', + ExpressionAttributeValues={ + ':v': 20 + } + ).should.have.raised(ClientError) + dict(table.get_item(Key=item_key)['Item']).should.equal(current_item) + + # Attempt to add a number set to the string set, should raise a ClientError + table.update_item.when.called_with( + Key=item_key, + UpdateExpression='ADD str_set :v', + ExpressionAttributeValues={ + ':v': { 20 } + } + ).should.have.raised(ClientError) + dict(table.get_item(Key=item_key)['Item']).should.equal(current_item) + + # Attempt to update with a bad expression + table.update_item.when.called_with( + Key=item_key, + UpdateExpression='ADD str_set bad_value' + ).should.have.raised(ClientError) + + # Attempt to add a string value instead of a string set + table.update_item.when.called_with( + Key=item_key, + UpdateExpression='ADD str_set :v', + ExpressionAttributeValues={ + ':v': 'new_string' + } + ).should.have.raised(ClientError) + + +@mock_dynamodb2 +def test_update_item_delete_with_expression(): + table = _create_table_with_range_key() + + item_key = {'forum_name': 'the-key', 'subject': '123'} + current_item = { + 'forum_name': 'the-key', + 'subject': '123', + 'str_set': {'item1', 'item2', 'item3'}, + 'num_set': {1, 2, 3}, + 'num_val': 6 + } + + # Put an entry in the DB to play with + table.put_item(Item=current_item) + + # Update item to delete a string value from a string set + table.update_item( + Key=item_key, + UpdateExpression='DELETE str_set :v', + ExpressionAttributeValues={ + ':v': {'item2'} + } + ) + current_item['str_set'] = current_item['str_set'].difference({'item2'}) + dict(table.get_item(Key=item_key)['Item']).should.equal(current_item) + + # Update item to delete a num value from a num set + table.update_item( + Key=item_key, + UpdateExpression='DELETE num_set :v', + ExpressionAttributeValues={ + ':v': {2} + } + ) + current_item['num_set'] = current_item['num_set'].difference({2}) + dict(table.get_item(Key=item_key)['Item']).should.equal(current_item) + + # Try to delete on a number, this should fail + table.update_item.when.called_with( + Key=item_key, + UpdateExpression='DELETE num_val :v', + ExpressionAttributeValues={ + ':v': 20 + } + ).should.have.raised(ClientError) + dict(table.get_item(Key=item_key)['Item']).should.equal(current_item) + + # Try to delete a string set from a number set + table.update_item.when.called_with( + Key=item_key, + UpdateExpression='DELETE num_set :v', + ExpressionAttributeValues={ + ':v': {'del_str'} + } + ).should.have.raised(ClientError) + dict(table.get_item(Key=item_key)['Item']).should.equal(current_item) + + # Attempt to update with a bad expression + table.update_item.when.called_with( + Key=item_key, + UpdateExpression='DELETE num_val badvalue' + ).should.have.raised(ClientError) + @mock_dynamodb2 def test_boto3_query_gsi_range_comparison():