diff --git a/moto/dynamodb/parsing/validators.py b/moto/dynamodb/parsing/validators.py index 53d90b1df..fd54c7c80 100644 --- a/moto/dynamodb/parsing/validators.py +++ b/moto/dynamodb/parsing/validators.py @@ -1,6 +1,7 @@ """ See docstring class Validator below for more details on validation """ + from abc import abstractmethod from copy import deepcopy from typing import Any, Callable, Dict, List, Type, Union @@ -12,6 +13,7 @@ from moto.dynamodb.exceptions import ( ExpressionAttributeNameNotDefined, ExpressionAttributeValueNotDefined, IncorrectOperandType, + InvalidAttributeTypeError, InvalidUpdateExpressionInvalidDocumentPath, MockValidationException, ProvidedKeyDoesNotExist, @@ -379,6 +381,34 @@ class EmptyStringKeyValueValidator(DepthFirstTraverser): # type: ignore[misc] return node +class TypeMismatchValidator(DepthFirstTraverser): # type: ignore[misc] + def __init__(self, key_attributes_type: List[Dict[str, str]]): + self.key_attributes_type = key_attributes_type + + def _processing_map( + self, + ) -> Dict[ + Type[UpdateExpressionSetAction], + Callable[[UpdateExpressionSetAction], UpdateExpressionSetAction], + ]: + return {UpdateExpressionSetAction: self.check_for_type_mismatch} + + def check_for_type_mismatch( + self, node: UpdateExpressionSetAction + ) -> UpdateExpressionSetAction: + """A node representing a SET action. Check that type matches with the definition""" + assert isinstance(node, UpdateExpressionSetAction) + assert len(node.children) == 2 + key = node.children[0].children[0].children[0] + val_node = node.children[1].children[0] + for dct in self.key_attributes_type: + if dct["AttributeName"] == key and dct["AttributeType"] != val_node.type: + raise InvalidAttributeTypeError( + key, dct["AttributeType"], val_node.type + ) + return node + + class UpdateHashRangeKeyValidator(DepthFirstTraverser): # type: ignore[misc] def __init__( self, @@ -464,6 +494,7 @@ class UpdateExpressionValidator(Validator): UpdateExpressionFunctionEvaluator(), NoneExistingPathChecker(), ExecuteOperations(), + TypeMismatchValidator(self.table.attr), EmptyStringKeyValueValidator(self.table.attribute_keys), ] return processors diff --git a/tests/test_dynamodb/test_dynamodb.py b/tests/test_dynamodb/test_dynamodb.py index 808d48d2f..8feeb8fa2 100644 --- a/tests/test_dynamodb/test_dynamodb.py +++ b/tests/test_dynamodb/test_dynamodb.py @@ -5880,3 +5880,80 @@ def test_invalid_projection_expressions(): ClientError, match="ProjectionExpression: Attribute name contains white space" ): client.scan(TableName=table_name, ProjectionExpression="not_a_keyword, na me") + + +@mock_aws +def test_update_item_with_global_secondary_index(): + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") + + # Create the DynamoDB table + dynamodb.create_table( + TableName="test", + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], + AttributeDefinitions=[ + {"AttributeName": "id", "AttributeType": "S"}, + {"AttributeName": "gsi_hash_key_s", "AttributeType": "S"}, + {"AttributeName": "gsi_hash_key_b", "AttributeType": "B"}, + {"AttributeName": "gsi_hash_key_n", "AttributeType": "N"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, + GlobalSecondaryIndexes=[ + { + "IndexName": "test_gsi_s", + "KeySchema": [ + {"AttributeName": "gsi_hash_key_s", "KeyType": "HASH"}, + ], + "Projection": {"ProjectionType": "ALL"}, + "ProvisionedThroughput": { + "ReadCapacityUnits": 1, + "WriteCapacityUnits": 1, + }, + }, + { + "IndexName": "test_gsi_b", + "KeySchema": [ + {"AttributeName": "gsi_hash_key_b", "KeyType": "HASH"}, + ], + "Projection": {"ProjectionType": "ALL"}, + "ProvisionedThroughput": { + "ReadCapacityUnits": 1, + "WriteCapacityUnits": 1, + }, + }, + { + "IndexName": "test_gsi_n", + "KeySchema": [ + {"AttributeName": "gsi_hash_key_n", "KeyType": "HASH"}, + ], + "Projection": {"ProjectionType": "ALL"}, + "ProvisionedThroughput": { + "ReadCapacityUnits": 1, + "WriteCapacityUnits": 1, + }, + }, + ], + ) + table = dynamodb.Table("test") + + table.put_item( + Item={"id": "test1"}, + ) + + for key_name, values in { + "gsi_hash_key_s": [None, 0, b"binary"], + "gsi_hash_key_b": [None, "", 0], + "gsi_hash_key_n": [None, "", b"binary"], + }.items(): + for v in values: + with pytest.raises(ClientError) as ex: + table.update_item( + Key={"id": "test1"}, + UpdateExpression=f"SET {key_name} = :gsi_hash_key", + ExpressionAttributeValues={":gsi_hash_key": v, ":abc": ""}, + ) + err = ex.value.response["Error"] + assert err["Code"] == "ValidationException" + assert ( + "One or more parameter values were invalid: Type mismatch" + in err["Message"] + )