From 62647ab1a3adc6d6a18bed106999fac5183d432d Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Fri, 2 Feb 2024 19:29:48 +0000 Subject: [PATCH] DynamoDB: Improve support for Update-statements in batch_execute-statement() (#7297) --- moto/dynamodb/models/__init__.py | 34 +++++---- setup.cfg | 18 ++--- .../test_dynamodb/test_dynamodb_statements.py | 69 +++++++++++++++++++ 3 files changed, 97 insertions(+), 24 deletions(-) diff --git a/moto/dynamodb/models/__init__.py b/moto/dynamodb/models/__init__.py index 1cf919be8..c3be0ff5d 100644 --- a/moto/dynamodb/models/__init__.py +++ b/moto/dynamodb/models/__init__.py @@ -855,27 +855,31 @@ class DynamoDBBackend(BaseBackend): } else: response["TableName"] = table_name - table = self.tables[table_name] - for required_attr in table.table_key_attrs: - if required_attr not in filter_keys: - response["Error"] = { - "Code": "ValidationError", - "Message": "Select statements within BatchExecuteStatement must specify the primary key in the where clause.", - } + if metadata.is_select_query(): + table = self.tables[table_name] + for required_attr in table.table_key_attrs: + if required_attr not in filter_keys: + response["Error"] = { + "Code": "ValidationError", + "Message": "Select statements within BatchExecuteStatement must specify the primary key in the where clause.", + } responses.append(response) # Execution for idx, stmt in enumerate(statements): if "Error" in responses[idx]: continue - items = self.execute_statement( - statement=stmt["Statement"], parameters=stmt.get("Parameters", []) - ) - # Statements should always contain a HashKey and SortKey - # An item with those keys may not exist - if items: - # But if it does, it will always only contain one item at most - responses[idx]["Item"] = items[0] + try: + items = self.execute_statement( + statement=stmt["Statement"], parameters=stmt.get("Parameters", []) + ) + # Statements should always contain a HashKey and SortKey + # An item with those keys may not exist + if items: + # But if it does, it will always only contain one item at most + responses[idx]["Item"] = items[0] + except Exception as e: + responses[idx] = {"Error": {"Code": e.name, "Message": e.message}} # type: ignore return responses diff --git a/setup.cfg b/setup.cfg index 61920adf9..14bd1e8e2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -54,7 +54,7 @@ all = openapi-spec-validator>=0.5.0 pyparsing>=3.0.7 jsondiff>=1.1.2 - py-partiql-parser==0.5.0 + py-partiql-parser==0.5.1 aws-xray-sdk!=0.96,>=0.93 setuptools multipart @@ -69,7 +69,7 @@ proxy = openapi-spec-validator>=0.5.0 pyparsing>=3.0.7 jsondiff>=1.1.2 - py-partiql-parser==0.5.0 + py-partiql-parser==0.5.1 aws-xray-sdk!=0.96,>=0.93 setuptools multipart @@ -84,7 +84,7 @@ server = openapi-spec-validator>=0.5.0 pyparsing>=3.0.7 jsondiff>=1.1.2 - py-partiql-parser==0.5.0 + py-partiql-parser==0.5.1 aws-xray-sdk!=0.96,>=0.93 setuptools flask!=2.2.0,!=2.2.1 @@ -122,7 +122,7 @@ cloudformation = openapi-spec-validator>=0.5.0 pyparsing>=3.0.7 jsondiff>=1.1.2 - py-partiql-parser==0.5.0 + py-partiql-parser==0.5.1 aws-xray-sdk!=0.96,>=0.93 setuptools cloudfront = @@ -145,10 +145,10 @@ dms = ds = dynamodb = docker>=3.0.0 - py-partiql-parser==0.5.0 + py-partiql-parser==0.5.1 dynamodbstreams = docker>=3.0.0 - py-partiql-parser==0.5.0 + py-partiql-parser==0.5.1 ebs = ec2 = sshpubkeys>=3.1.0 ec2instanceconnect = @@ -213,15 +213,15 @@ resourcegroupstaggingapi = openapi-spec-validator>=0.5.0 pyparsing>=3.0.7 jsondiff>=1.1.2 - py-partiql-parser==0.5.0 + py-partiql-parser==0.5.1 route53 = route53resolver = s3 = PyYAML>=5.1 - py-partiql-parser==0.5.0 + py-partiql-parser==0.5.1 s3crc32c = PyYAML>=5.1 - py-partiql-parser==0.5.0 + py-partiql-parser==0.5.1 crc32c s3control = sagemaker = diff --git a/tests/test_dynamodb/test_dynamodb_statements.py b/tests/test_dynamodb/test_dynamodb_statements.py index 6530cf129..33f47ccab 100644 --- a/tests/test_dynamodb/test_dynamodb_statements.py +++ b/tests/test_dynamodb/test_dynamodb_statements.py @@ -334,6 +334,75 @@ def test_update_data(table_name=None): assert item2 in items +@mock_aws +def test_batch_update__not_enough_parameters(): + ddb_cli = boto3.client("dynamodb", "us-east-1") + ddb_res = boto3.resource("dynamodb", "us-east-1") + ddb_res.create_table( + TableName="users", + KeySchema=[{"AttributeName": "username", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "username", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + + statements = [ + { + "Statement": 'UPDATE users SET "first_name" = ?, "last_name" = ? WHERE "username"= ?', + "Parameters": [{"S": "test5"}, {"S": "test6"}], + } + ] + resp = ddb_cli.batch_execute_statement(Statements=statements)["Responses"] + assert resp == [ + { + "Error": { + "Code": "ValidationError", + "Message": "Number of parameters in request and statement don't match.", + } + } + ] + + +@mock_aws +def test_batch_update(): + ddb_cli = boto3.client("dynamodb", "us-east-1") + ddb_res = boto3.resource("dynamodb", "us-east-1") + table = ddb_res.create_table( + TableName="users", + KeySchema=[{"AttributeName": "username", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "username", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + table.put_item( + Item={"username": "XXXX", "first_name": "test1", "last_name": "test2"} + ) + table.put_item( + Item={"username": "YYYY", "first_name": "test3", "last_name": "test4"} + ) + + statements = [ + { + "Statement": 'UPDATE users SET "first_name" = ?, "last_name" = ? WHERE "username"= ?', + "Parameters": [{"S": "test5"}, {"S": "test6"}, {"S": "XXXX"}], + }, + {"Statement": "DELETE FROM users WHERE username='YYYY'"}, + {"Statement": "INSERT INTO users value {'username': 'new'}"}, + ] + response = ddb_cli.batch_execute_statement(Statements=statements)["Responses"] + assert response == [ + {"TableName": "users"}, + {"TableName": "users"}, + {"TableName": "users"}, + ] + + users = ddb_res.Table("users").scan()["Items"] + assert len(users) == 2 + + # Changed + assert {"username": "XXXX", "first_name": "test5", "last_name": "test6"} in users + # New + assert {"username": "new"} in users + + @pytest.mark.aws_verified @dynamodb_aws_verified() def test_delete_data(table_name=None):