From e76ffb3409c07e4416723c87972e6ac2730488b1 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Wed, 7 Sep 2022 11:45:34 +0000 Subject: [PATCH] DynamoDB:query() now has improved support for KeyConditionExpression (#5449) --- .../parsing/key_condition_expression.py | 286 ++++++++++++++++++ moto/dynamodb/responses.py | 107 +------ .../exceptions/test_dynamodb_exceptions.py | 55 ++++ .../test_key_condition_expression_parser.py | 198 ++++++++++++ .../test_dynamodb_consumedcapacity.py | 1 + 5 files changed, 546 insertions(+), 101 deletions(-) create mode 100644 moto/dynamodb/parsing/key_condition_expression.py create mode 100644 tests/test_dynamodb/models/test_key_condition_expression_parser.py diff --git a/moto/dynamodb/parsing/key_condition_expression.py b/moto/dynamodb/parsing/key_condition_expression.py new file mode 100644 index 000000000..cb9953450 --- /dev/null +++ b/moto/dynamodb/parsing/key_condition_expression.py @@ -0,0 +1,286 @@ +from enum import Enum +from moto.dynamodb.exceptions import MockValidationException + + +class KeyConditionExpressionTokenizer: + """ + Tokenizer for a KeyConditionExpression. Should be used as an iterator. + The final character to be returned will be an empty string, to notify the caller that we've reached the end. + """ + + def __init__(self, expression): + self.expression = expression + self.token_pos = 0 + + def __iter__(self): + self.token_pos = 0 + return self + + def is_eof(self): + return self.peek() == "" + + def peek(self): + """ + Peek the next character without changing the position + """ + try: + return self.expression[self.token_pos] + except IndexError: + return "" + + def __next__(self): + """ + Returns the next character, or an empty string if we've reached the end of the string. + Calling this method again will result in a StopIterator + """ + try: + result = self.expression[self.token_pos] + self.token_pos += 1 + return result + except IndexError: + if self.token_pos == len(self.expression): + self.token_pos += 1 + return "" + raise StopIteration + + def skip_characters(self, phrase, case_sensitive=False) -> None: + """ + Skip the characters in the supplied phrase. + If any other character is encountered instead, this will fail. + If we've already reached the end of the iterator, this will fail. + """ + for ch in phrase: + if case_sensitive: + assert self.expression[self.token_pos] == ch + else: + assert self.expression[self.token_pos] in [ch.lower(), ch.upper()] + self.token_pos += 1 + + def skip_white_space(self): + """ + Skip the any whitespace characters that are coming up + """ + try: + while self.peek() == " ": + self.token_pos += 1 + except IndexError: + pass + + +class EXPRESSION_STAGES(Enum): + INITIAL_STAGE = "INITIAL_STAGE" # Can be a hash key, range key, or function + KEY_NAME = "KEY_NAME" + KEY_VALUE = "KEY_VALUE" + COMPARISON = "COMPARISON" + EOF = "EOF" + + +def get_key(schema, key_type): + keys = [key for key in schema if key["KeyType"] == key_type] + return keys[0]["AttributeName"] if keys else None + + +def parse_expression( + key_condition_expression, + expression_attribute_values, + expression_attribute_names, + schema, +): + """ + Parse a KeyConditionExpression using the provided expression attribute names/values + + key_condition_expression: hashkey = :id AND :sk = val + expression_attribute_names: {":sk": "sortkey"} + expression_attribute_values: {":id": {"S": "some hash key"}} + schema: [{'AttributeName': 'hashkey', 'KeyType': 'HASH'}, {"AttributeName": "sortkey", "KeyType": "RANGE"}] + """ + + current_stage: EXPRESSION_STAGES = None + current_phrase = "" + key_name = comparison = None + key_values = [] + results = [] + tokenizer = KeyConditionExpressionTokenizer(key_condition_expression) + for crnt_char in tokenizer: + if crnt_char == " ": + if current_stage == EXPRESSION_STAGES.INITIAL_STAGE: + tokenizer.skip_white_space() + if tokenizer.peek() == "(": + # begins_with(sk, :sk) and primary = :pk + # ^ + continue + else: + # start_date < :sk and primary = :pk + # ^ + key_name = expression_attribute_names.get( + current_phrase, current_phrase + ) + current_phrase = "" + current_stage = EXPRESSION_STAGES.COMPARISON + tokenizer.skip_white_space() + elif current_stage == EXPRESSION_STAGES.KEY_VALUE: + # job_id = :id + # job_id = :id and ... + # pk=p and x=y + # pk=p and fn(x, y1, y1 ) + # ^ --> ^ + key_values.append( + expression_attribute_values.get( + current_phrase, {"S": current_phrase} + ) + ) + current_phrase = "" + if comparison.upper() != "BETWEEN" or len(key_values) == 2: + results.append((key_name, comparison, key_values)) + key_values = [] + tokenizer.skip_white_space() + if tokenizer.peek() == ")": + tokenizer.skip_characters(")") + current_stage = EXPRESSION_STAGES.EOF + break + elif tokenizer.is_eof(): + break + tokenizer.skip_characters("AND", case_sensitive=False) + tokenizer.skip_white_space() + if comparison.upper() == "BETWEEN": + # We can expect another key_value, i.e. BETWEEN x and y + # We should add some validation, to not allow BETWEEN x and y and z and .. + pass + else: + current_stage = EXPRESSION_STAGES.INITIAL_STAGE + elif current_stage == EXPRESSION_STAGES.COMPARISON: + # hashkey = :id and sortkey = :sk + # hashkey = :id and sortkey BETWEEN x and y + # ^ --> ^ + comparison = current_phrase + current_phrase = "" + current_stage = EXPRESSION_STAGES.KEY_VALUE + continue + if crnt_char in ["=", "<", ">"] and current_stage in [ + EXPRESSION_STAGES.KEY_NAME, + EXPRESSION_STAGES.INITIAL_STAGE, + EXPRESSION_STAGES.COMPARISON, + ]: + if current_stage in [ + EXPRESSION_STAGES.KEY_NAME, + EXPRESSION_STAGES.INITIAL_STAGE, + ]: + key_name = expression_attribute_names.get( + current_phrase, current_phrase + ) + current_phrase = "" + if crnt_char in ["<", ">"] and tokenizer.peek() == "=": + comparison = crnt_char + tokenizer.__next__() + else: + comparison = crnt_char + tokenizer.skip_white_space() + current_stage = EXPRESSION_STAGES.KEY_VALUE + continue + if crnt_char in [","]: + if current_stage == EXPRESSION_STAGES.KEY_NAME: + # hashkey = :id and begins_with(sortkey, :sk) + # ^ --> ^ + key_name = expression_attribute_names.get( + current_phrase, current_phrase + ) + current_phrase = "" + current_stage = EXPRESSION_STAGES.KEY_VALUE + tokenizer.skip_white_space() + continue + else: + raise MockValidationException( + f'Invalid KeyConditionExpression: Syntax error; token: "{current_phrase}"' + ) + if crnt_char in [")"]: + if current_stage == EXPRESSION_STAGES.KEY_VALUE: + # hashkey = :id and begins_with(sortkey, :sk) + # ^ + value = expression_attribute_values.get(current_phrase, current_phrase) + current_phrase = "" + key_values.append(value) + results.append((key_name, comparison, key_values)) + key_values = [] + tokenizer.skip_white_space() + if tokenizer.is_eof() or tokenizer.peek() == ")": + break + else: + tokenizer.skip_characters("AND", case_sensitive=False) + tokenizer.skip_white_space() + current_stage = EXPRESSION_STAGES.INITIAL_STAGE + continue + if crnt_char in [""]: + # hashkey = :id + # hashkey = :id and sortkey = :sk + # ^ + if current_stage == EXPRESSION_STAGES.KEY_VALUE: + key_values.append( + expression_attribute_values.get( + current_phrase, {"S": current_phrase} + ) + ) + results.append((key_name, comparison, key_values)) + break + if crnt_char == "(": + # hashkey = :id and begins_with( sortkey, :sk) + # ^ --> ^ + if current_stage in [EXPRESSION_STAGES.INITIAL_STAGE]: + if current_phrase != "begins_with": + raise MockValidationException( + f"Invalid KeyConditionExpression: Invalid function name; function: {current_phrase}" + ) + comparison = current_phrase + current_phrase = "" + tokenizer.skip_white_space() + current_stage = EXPRESSION_STAGES.KEY_NAME + continue + if current_stage is None: + # (hash_key = :id .. ) + # ^ + continue + + current_phrase += crnt_char + if current_stage is None: + current_stage = EXPRESSION_STAGES.INITIAL_STAGE + + hash_value, range_comparison, range_values = validate_schema(results, schema) + + return ( + hash_value, + range_comparison.upper() if range_comparison else None, + range_values, + ) + + +# Validate that the schema-keys are encountered in our query +def validate_schema(results, schema): + index_hash_key = get_key(schema, "HASH") + comparison, hash_value = next( + ( + (comparison, value[0]) + for key, comparison, value in results + if key == index_hash_key + ), + (None, None), + ) + if hash_value is None: + raise MockValidationException( + f"Query condition missed key schema element: {index_hash_key}" + ) + if comparison != "=": + raise MockValidationException("Query key condition not supported") + + index_range_key = get_key(schema, "RANGE") + range_key, range_comparison, range_values = next( + ( + (key, comparison, values) + for key, comparison, values in results + if key == index_range_key + ), + (None, None, []), + ) + if index_range_key and len(results) > 1 and range_key != index_range_key: + raise MockValidationException( + f"Query condition missed key schema element: {index_range_key}" + ) + + return hash_value, range_comparison, range_values diff --git a/moto/dynamodb/responses.py b/moto/dynamodb/responses.py index 39fcc72f5..3d6007b9a 100644 --- a/moto/dynamodb/responses.py +++ b/moto/dynamodb/responses.py @@ -1,12 +1,12 @@ import copy import json -import re import itertools from functools import wraps from moto.core.responses import BaseResponse from moto.core.utils import camelcase_to_underscores, amz_crc32, amzn_request_id +from moto.dynamodb.parsing.key_condition_expression import parse_expression from moto.dynamodb.parsing.reserved_keywords import ReservedKeywords from .exceptions import ( MockValidationException, @@ -587,110 +587,15 @@ class DynamoHandler(BaseResponse): filter_kwargs = {} if key_condition_expression: - value_alias_map = self.body.get("ExpressionAttributeValues", {}) - index_name = self.body.get("IndexName") schema = self.dynamodb_backend.get_schema( table_name=name, index_name=index_name ) - - reverse_attribute_lookup = dict( - (v, k) for k, v in self.body.get("ExpressionAttributeNames", {}).items() - ) - - if " and " in key_condition_expression.lower(): - expressions = re.split( - " AND ", key_condition_expression, maxsplit=1, flags=re.IGNORECASE - ) - - index_hash_key = [key for key in schema if key["KeyType"] == "HASH"][0] - hash_key_var = reverse_attribute_lookup.get( - index_hash_key["AttributeName"], index_hash_key["AttributeName"] - ) - hash_key_regex = r"(^|[\s(]){0}\b".format(hash_key_var) - i, hash_key_expression = next( - ( - (i, e) - for i, e in enumerate(expressions) - if re.search(hash_key_regex, e) - ), - (None, None), - ) - if hash_key_expression is None: - raise MockValidationException( - "Query condition missed key schema element: {}".format( - hash_key_var - ) - ) - hash_key_expression = hash_key_expression.strip("()") - expressions.pop(i) - - # TODO implement more than one range expression and OR operators - range_key_expression = expressions[0].strip("()") - # Split expression, and account for all kinds of whitespacing around commas and brackets - range_key_expression_components = re.split( - r"\s*\(\s*|\s*,\s*|\s", range_key_expression - ) - # Skip whitespace - range_key_expression_components = [ - c for c in range_key_expression_components if c - ] - range_comparison = range_key_expression_components[1] - - if " and " in range_key_expression.lower(): - range_comparison = "BETWEEN" - # [range_key, between, x, and, y] - range_values = [ - value_alias_map[range_key_expression_components[2]], - value_alias_map[range_key_expression_components[4]], - ] - supplied_range_key = range_key_expression_components[0] - elif "begins_with" in range_key_expression: - range_comparison = "BEGINS_WITH" - # [begins_with, range_key, x] - range_values = [ - value_alias_map[range_key_expression_components[-1]] - ] - supplied_range_key = range_key_expression_components[1] - elif "begins_with" in range_key_expression.lower(): - function_used = range_key_expression[ - range_key_expression.lower().index("begins_with") : len( - "begins_with" - ) - ] - raise MockValidationException( - "Invalid KeyConditionExpression: Invalid function name; function: {}".format( - function_used - ) - ) - else: - # [range_key, =, x] - range_values = [value_alias_map[range_key_expression_components[2]]] - supplied_range_key = range_key_expression_components[0] - - supplied_range_key = expression_attribute_names.get( - supplied_range_key, supplied_range_key - ) - range_keys = [ - k["AttributeName"] for k in schema if k["KeyType"] == "RANGE" - ] - if supplied_range_key not in range_keys: - raise MockValidationException( - "Query condition missed key schema element: {}".format( - range_keys[0] - ) - ) - else: - hash_key_expression = key_condition_expression.strip("()") - range_comparison = None - range_values = [] - - if not re.search("[^<>]=", hash_key_expression): - raise MockValidationException("Query key condition not supported") - hash_key_value_alias = hash_key_expression.split("=")[1].strip() - # Temporary fix until we get proper KeyConditionExpression function - hash_key = value_alias_map.get( - hash_key_value_alias, {"S": hash_key_value_alias} + hash_key, range_comparison, range_values = parse_expression( + key_condition_expression=key_condition_expression, + expression_attribute_names=expression_attribute_names, + expression_attribute_values=expression_attribute_values, + schema=schema, ) else: # 'KeyConditions': {u'forum_name': {u'ComparisonOperator': u'EQ', u'AttributeValueList': [{u'S': u'the-key'}]}} diff --git a/tests/test_dynamodb/exceptions/test_dynamodb_exceptions.py b/tests/test_dynamodb/exceptions/test_dynamodb_exceptions.py index 8ce7ce1e1..e81f0e2e8 100644 --- a/tests/test_dynamodb/exceptions/test_dynamodb_exceptions.py +++ b/tests/test_dynamodb/exceptions/test_dynamodb_exceptions.py @@ -100,6 +100,33 @@ def test_query_gsi_with_wrong_key_attribute_names_throws_exception(): ) +@mock_dynamodb +def test_query_table_with_wrong_key_attribute_names_throws_exception(): + item = { + "partitionKey": "pk-1", + "someAttribute": "lore ipsum", + } + + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") + dynamodb.create_table( + TableName="test-table", BillingMode="PAY_PER_REQUEST", **table_schema + ) + table = dynamodb.Table("test-table") + table.put_item(Item=item) + + # check using wrong name for sort key throws exception + with pytest.raises(ClientError) as exc: + table.query( + KeyConditionExpression="wrongName = :pk", + ExpressionAttributeValues={":pk": "pk"}, + )["Items"] + err = exc.value.response["Error"] + err["Code"].should.equal("ValidationException") + err["Message"].should.equal( + "Query condition missed key schema element: partitionKey" + ) + + @mock_dynamodb def test_empty_expressionattributenames(): ddb = boto3.resource("dynamodb", region_name="us-east-1") @@ -667,3 +694,31 @@ def test_batch_put_item_with_empty_value(): # Empty regular parameter workst just fine though with table.batch_writer() as batch: batch.put_item(Item={"pk": "sth", "sk": "else", "par": ""}) + + +@mock_dynamodb +def test_query_begins_with_without_brackets(): + client = boto3.client("dynamodb", region_name="us-east-1") + client.create_table( + TableName="test-table", + AttributeDefinitions=[ + {"AttributeName": "pk", "AttributeType": "S"}, + {"AttributeName": "sk", "AttributeType": "S"}, + ], + KeySchema=[ + {"AttributeName": "pk", "KeyType": "HASH"}, + {"AttributeName": "sk", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, + ) + with pytest.raises(ClientError) as exc: + client.query( + TableName="test-table", + KeyConditionExpression="pk=:pk AND begins_with sk, :sk ", + ExpressionAttributeValues={":pk": {"S": "test1"}, ":sk": {"S": "test2"}}, + ) + err = exc.value.response["Error"] + err["Message"].should.equal( + 'Invalid KeyConditionExpression: Syntax error; token: "sk"' + ) + err["Code"].should.equal("ValidationException") diff --git a/tests/test_dynamodb/models/test_key_condition_expression_parser.py b/tests/test_dynamodb/models/test_key_condition_expression_parser.py new file mode 100644 index 000000000..0c924ebf9 --- /dev/null +++ b/tests/test_dynamodb/models/test_key_condition_expression_parser.py @@ -0,0 +1,198 @@ +import pytest +from moto.dynamodb.exceptions import DynamodbException +from moto.dynamodb.parsing.key_condition_expression import parse_expression + + +class TestHashKey: + schema = [{"AttributeName": "job_id", "KeyType": "HASH"}] + + @pytest.mark.parametrize("expression", ["job_id = :id", "job_id = :id "]) + def test_hash_key_only(self, expression): + eav = {":id": {"S": "asdasdasd"}} + desired_hash_key, comparison, range_values = parse_expression( + expression_attribute_values=eav, + key_condition_expression=expression, + schema=self.schema, + expression_attribute_names=dict(), + ) + desired_hash_key.should.equal(eav[":id"]) + comparison.should.equal(None) + range_values.should.equal([]) + + def test_unknown_hash_key(self): + kce = "wrongName = :id" + eav = {":id": "pk"} + with pytest.raises(DynamodbException) as exc: + parse_expression( + expression_attribute_values=eav, + key_condition_expression=kce, + schema=self.schema, + expression_attribute_names=dict(), + ) + exc.value.message.should.equal( + "Query condition missed key schema element: job_id" + ) + + def test_unknown_hash_value(self): + # TODO: is this correct? I'd assume that this should throw an error instead + # Revisit after test in exceptions.py passes + kce = "job_id = :unknown" + eav = {":id": {"S": "asdasdasd"}} + desired_hash_key, comparison, range_values = parse_expression( + expression_attribute_values=eav, + key_condition_expression=kce, + schema=self.schema, + expression_attribute_names=dict(), + ) + desired_hash_key.should.equal({"S": ":unknown"}) + comparison.should.equal(None) + range_values.should.equal([]) + + +class TestHashAndRangeKey: + schema = [ + {"AttributeName": "job_id", "KeyType": "HASH"}, + {"AttributeName": "start_date", "KeyType": "RANGE"}, + ] + + def test_unknown_hash_key(self): + kce = "wrongName = :id AND start_date = :sk" + eav = {":id": "pk", ":sk": "sk"} + with pytest.raises(DynamodbException) as exc: + parse_expression( + expression_attribute_values=eav, + key_condition_expression=kce, + schema=self.schema, + expression_attribute_names=dict(), + ) + exc.value.message.should.equal( + "Query condition missed key schema element: job_id" + ) + + @pytest.mark.parametrize( + "expr", + [ + "job_id = :id AND wrongName = :sk", + "job_id = :id AND begins_with ( wrongName , :sk )", + "job_id = :id AND wrongName BETWEEN :sk and :sk2", + ], + ) + def test_unknown_range_key(self, expr): + eav = {":id": "pk", ":sk": "sk", ":sk2": "sk"} + with pytest.raises(DynamodbException) as exc: + parse_expression( + expression_attribute_values=eav, + key_condition_expression=expr, + schema=self.schema, + expression_attribute_names=dict(), + ) + exc.value.message.should.equal( + "Query condition missed key schema element: start_date" + ) + + @pytest.mark.parametrize( + "expr", + [ + "job_id = :id AND begins_with(start_date,:sk)", + "job_id = :id AND begins_with(start_date, :sk)", + "job_id = :id AND begins_with( start_date,:sk)", + "job_id = :id AND begins_with( start_date, :sk)", + "job_id = :id AND begins_with ( start_date, :sk ) ", + ], + ) + def test_begin_with(self, expr): + eav = {":id": "pk", ":sk": "19"} + desired_hash_key, comparison, range_values = parse_expression( + expression_attribute_values=eav, + key_condition_expression=expr, + schema=self.schema, + expression_attribute_names=dict(), + ) + desired_hash_key.should.equal("pk") + comparison.should.equal("BEGINS_WITH") + range_values.should.equal(["19"]) + + @pytest.mark.parametrize("fn", ["Begins_with", "Begins_With", "BEGINS_WITH"]) + def test_begin_with__wrong_case(self, fn): + eav = {":id": "pk", ":sk": "19"} + with pytest.raises(DynamodbException) as exc: + parse_expression( + expression_attribute_values=eav, + key_condition_expression=f"job_id = :id AND {fn}(start_date,:sk)", + schema=self.schema, + expression_attribute_names=dict(), + ) + exc.value.message.should.equal( + f"Invalid KeyConditionExpression: Invalid function name; function: {fn}" + ) + + @pytest.mark.parametrize( + "expr", + [ + "job_id = :id and start_date BETWEEN :sk1 AND :sk2", + "job_id = :id and start_date BETWEEN :sk1 and :sk2", + "job_id = :id and start_date between :sk1 and :sk2 ", + ], + ) + def test_in_between(self, expr): + eav = {":id": "pk", ":sk1": "19", ":sk2": "21"} + desired_hash_key, comparison, range_values = parse_expression( + expression_attribute_values=eav, + key_condition_expression=expr, + schema=self.schema, + expression_attribute_names=dict(), + ) + desired_hash_key.should.equal("pk") + comparison.should.equal("BETWEEN") + range_values.should.equal(["19", "21"]) + + @pytest.mark.parametrize("operator", [" < ", " <=", "= ", ">", ">="]) + def test_numeric_comparisons(self, operator): + eav = {":id": "pk", ":sk": "19"} + expr = f"job_id = :id and start_date{operator}:sk" + desired_hash_key, comparison, range_values = parse_expression( + expression_attribute_values=eav, + key_condition_expression=expr, + schema=self.schema, + expression_attribute_names=dict(), + ) + desired_hash_key.should.equal("pk") + comparison.should.equal(operator.strip()) + range_values.should.equal(["19"]) + + @pytest.mark.parametrize( + "expr", + [ + "start_date >= :sk and job_id = :id", + "start_date>:sk and job_id=:id", + "start_date=:sk and job_id = :id", + "begins_with(start_date,:sk) and job_id = :id", + ], + ) + def test_reverse_keys(self, expr): + eav = {":id": "pk", ":sk1": "19", ":sk2": "21"} + desired_hash_key, comparison, range_values = parse_expression( + expression_attribute_values=eav, + key_condition_expression=expr, + schema=self.schema, + expression_attribute_names=dict(), + ) + desired_hash_key.should.equal("pk") + + +class TestNamesAndValues: + schema = [{"AttributeName": "job_id", "KeyType": "HASH"}] + + def test_names_and_values(self): + kce = ":j = :id" + ean = {":j": "job_id"} + eav = {":id": {"S": "asdasdasd"}} + desired_hash_key, comparison, range_values = parse_expression( + expression_attribute_values=eav, + key_condition_expression=kce, + schema=self.schema, + expression_attribute_names=ean, + ) + desired_hash_key.should.equal(eav[":id"]) + comparison.should.equal(None) + range_values.should.equal([]) diff --git a/tests/test_dynamodb/test_dynamodb_consumedcapacity.py b/tests/test_dynamodb/test_dynamodb_consumedcapacity.py index 0d367c75a..41d3a13c0 100644 --- a/tests/test_dynamodb/test_dynamodb_consumedcapacity.py +++ b/tests/test_dynamodb/test_dynamodb_consumedcapacity.py @@ -126,6 +126,7 @@ def test_only_return_consumed_capacity_when_required( # QUERY_INDEX args["IndexName"] = "job_name-index" + args["KeyConditionExpression"] = "job_name = :id" response = client.query(**args) validate_response(response, should_have_capacity, should_have_table, is_index=True)