DynamoDB:query() now has improved support for KeyConditionExpression (#5449)

This commit is contained in:
Bert Blommers 2022-09-07 11:45:34 +00:00 committed by GitHub
parent 40ec430143
commit e76ffb3409
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 546 additions and 101 deletions

View File

@ -0,0 +1,286 @@
from enum import Enum
from moto.dynamodb.exceptions import MockValidationException
class KeyConditionExpressionTokenizer:
"""
Tokenizer for a KeyConditionExpression. Should be used as an iterator.
The final character to be returned will be an empty string, to notify the caller that we've reached the end.
"""
def __init__(self, expression):
self.expression = expression
self.token_pos = 0
def __iter__(self):
self.token_pos = 0
return self
def is_eof(self):
return self.peek() == ""
def peek(self):
"""
Peek the next character without changing the position
"""
try:
return self.expression[self.token_pos]
except IndexError:
return ""
def __next__(self):
"""
Returns the next character, or an empty string if we've reached the end of the string.
Calling this method again will result in a StopIterator
"""
try:
result = self.expression[self.token_pos]
self.token_pos += 1
return result
except IndexError:
if self.token_pos == len(self.expression):
self.token_pos += 1
return ""
raise StopIteration
def skip_characters(self, phrase, case_sensitive=False) -> None:
"""
Skip the characters in the supplied phrase.
If any other character is encountered instead, this will fail.
If we've already reached the end of the iterator, this will fail.
"""
for ch in phrase:
if case_sensitive:
assert self.expression[self.token_pos] == ch
else:
assert self.expression[self.token_pos] in [ch.lower(), ch.upper()]
self.token_pos += 1
def skip_white_space(self):
"""
Skip the any whitespace characters that are coming up
"""
try:
while self.peek() == " ":
self.token_pos += 1
except IndexError:
pass
class EXPRESSION_STAGES(Enum):
INITIAL_STAGE = "INITIAL_STAGE" # Can be a hash key, range key, or function
KEY_NAME = "KEY_NAME"
KEY_VALUE = "KEY_VALUE"
COMPARISON = "COMPARISON"
EOF = "EOF"
def get_key(schema, key_type):
keys = [key for key in schema if key["KeyType"] == key_type]
return keys[0]["AttributeName"] if keys else None
def parse_expression(
key_condition_expression,
expression_attribute_values,
expression_attribute_names,
schema,
):
"""
Parse a KeyConditionExpression using the provided expression attribute names/values
key_condition_expression: hashkey = :id AND :sk = val
expression_attribute_names: {":sk": "sortkey"}
expression_attribute_values: {":id": {"S": "some hash key"}}
schema: [{'AttributeName': 'hashkey', 'KeyType': 'HASH'}, {"AttributeName": "sortkey", "KeyType": "RANGE"}]
"""
current_stage: EXPRESSION_STAGES = None
current_phrase = ""
key_name = comparison = None
key_values = []
results = []
tokenizer = KeyConditionExpressionTokenizer(key_condition_expression)
for crnt_char in tokenizer:
if crnt_char == " ":
if current_stage == EXPRESSION_STAGES.INITIAL_STAGE:
tokenizer.skip_white_space()
if tokenizer.peek() == "(":
# begins_with(sk, :sk) and primary = :pk
# ^
continue
else:
# start_date < :sk and primary = :pk
# ^
key_name = expression_attribute_names.get(
current_phrase, current_phrase
)
current_phrase = ""
current_stage = EXPRESSION_STAGES.COMPARISON
tokenizer.skip_white_space()
elif current_stage == EXPRESSION_STAGES.KEY_VALUE:
# job_id = :id
# job_id = :id and ...
# pk=p and x=y
# pk=p and fn(x, y1, y1 )
# ^ --> ^
key_values.append(
expression_attribute_values.get(
current_phrase, {"S": current_phrase}
)
)
current_phrase = ""
if comparison.upper() != "BETWEEN" or len(key_values) == 2:
results.append((key_name, comparison, key_values))
key_values = []
tokenizer.skip_white_space()
if tokenizer.peek() == ")":
tokenizer.skip_characters(")")
current_stage = EXPRESSION_STAGES.EOF
break
elif tokenizer.is_eof():
break
tokenizer.skip_characters("AND", case_sensitive=False)
tokenizer.skip_white_space()
if comparison.upper() == "BETWEEN":
# We can expect another key_value, i.e. BETWEEN x and y
# We should add some validation, to not allow BETWEEN x and y and z and ..
pass
else:
current_stage = EXPRESSION_STAGES.INITIAL_STAGE
elif current_stage == EXPRESSION_STAGES.COMPARISON:
# hashkey = :id and sortkey = :sk
# hashkey = :id and sortkey BETWEEN x and y
# ^ --> ^
comparison = current_phrase
current_phrase = ""
current_stage = EXPRESSION_STAGES.KEY_VALUE
continue
if crnt_char in ["=", "<", ">"] and current_stage in [
EXPRESSION_STAGES.KEY_NAME,
EXPRESSION_STAGES.INITIAL_STAGE,
EXPRESSION_STAGES.COMPARISON,
]:
if current_stage in [
EXPRESSION_STAGES.KEY_NAME,
EXPRESSION_STAGES.INITIAL_STAGE,
]:
key_name = expression_attribute_names.get(
current_phrase, current_phrase
)
current_phrase = ""
if crnt_char in ["<", ">"] and tokenizer.peek() == "=":
comparison = crnt_char + tokenizer.__next__()
else:
comparison = crnt_char
tokenizer.skip_white_space()
current_stage = EXPRESSION_STAGES.KEY_VALUE
continue
if crnt_char in [","]:
if current_stage == EXPRESSION_STAGES.KEY_NAME:
# hashkey = :id and begins_with(sortkey, :sk)
# ^ --> ^
key_name = expression_attribute_names.get(
current_phrase, current_phrase
)
current_phrase = ""
current_stage = EXPRESSION_STAGES.KEY_VALUE
tokenizer.skip_white_space()
continue
else:
raise MockValidationException(
f'Invalid KeyConditionExpression: Syntax error; token: "{current_phrase}"'
)
if crnt_char in [")"]:
if current_stage == EXPRESSION_STAGES.KEY_VALUE:
# hashkey = :id and begins_with(sortkey, :sk)
# ^
value = expression_attribute_values.get(current_phrase, current_phrase)
current_phrase = ""
key_values.append(value)
results.append((key_name, comparison, key_values))
key_values = []
tokenizer.skip_white_space()
if tokenizer.is_eof() or tokenizer.peek() == ")":
break
else:
tokenizer.skip_characters("AND", case_sensitive=False)
tokenizer.skip_white_space()
current_stage = EXPRESSION_STAGES.INITIAL_STAGE
continue
if crnt_char in [""]:
# hashkey = :id
# hashkey = :id and sortkey = :sk
# ^
if current_stage == EXPRESSION_STAGES.KEY_VALUE:
key_values.append(
expression_attribute_values.get(
current_phrase, {"S": current_phrase}
)
)
results.append((key_name, comparison, key_values))
break
if crnt_char == "(":
# hashkey = :id and begins_with( sortkey, :sk)
# ^ --> ^
if current_stage in [EXPRESSION_STAGES.INITIAL_STAGE]:
if current_phrase != "begins_with":
raise MockValidationException(
f"Invalid KeyConditionExpression: Invalid function name; function: {current_phrase}"
)
comparison = current_phrase
current_phrase = ""
tokenizer.skip_white_space()
current_stage = EXPRESSION_STAGES.KEY_NAME
continue
if current_stage is None:
# (hash_key = :id .. )
# ^
continue
current_phrase += crnt_char
if current_stage is None:
current_stage = EXPRESSION_STAGES.INITIAL_STAGE
hash_value, range_comparison, range_values = validate_schema(results, schema)
return (
hash_value,
range_comparison.upper() if range_comparison else None,
range_values,
)
# Validate that the schema-keys are encountered in our query
def validate_schema(results, schema):
index_hash_key = get_key(schema, "HASH")
comparison, hash_value = next(
(
(comparison, value[0])
for key, comparison, value in results
if key == index_hash_key
),
(None, None),
)
if hash_value is None:
raise MockValidationException(
f"Query condition missed key schema element: {index_hash_key}"
)
if comparison != "=":
raise MockValidationException("Query key condition not supported")
index_range_key = get_key(schema, "RANGE")
range_key, range_comparison, range_values = next(
(
(key, comparison, values)
for key, comparison, values in results
if key == index_range_key
),
(None, None, []),
)
if index_range_key and len(results) > 1 and range_key != index_range_key:
raise MockValidationException(
f"Query condition missed key schema element: {index_range_key}"
)
return hash_value, range_comparison, range_values

View File

@ -1,12 +1,12 @@
import copy
import json
import re
import itertools
from functools import wraps
from moto.core.responses import BaseResponse
from moto.core.utils import camelcase_to_underscores, amz_crc32, amzn_request_id
from moto.dynamodb.parsing.key_condition_expression import parse_expression
from moto.dynamodb.parsing.reserved_keywords import ReservedKeywords
from .exceptions import (
MockValidationException,
@ -587,110 +587,15 @@ class DynamoHandler(BaseResponse):
filter_kwargs = {}
if key_condition_expression:
value_alias_map = self.body.get("ExpressionAttributeValues", {})
index_name = self.body.get("IndexName")
schema = self.dynamodb_backend.get_schema(
table_name=name, index_name=index_name
)
reverse_attribute_lookup = dict(
(v, k) for k, v in self.body.get("ExpressionAttributeNames", {}).items()
)
if " and " in key_condition_expression.lower():
expressions = re.split(
" AND ", key_condition_expression, maxsplit=1, flags=re.IGNORECASE
)
index_hash_key = [key for key in schema if key["KeyType"] == "HASH"][0]
hash_key_var = reverse_attribute_lookup.get(
index_hash_key["AttributeName"], index_hash_key["AttributeName"]
)
hash_key_regex = r"(^|[\s(]){0}\b".format(hash_key_var)
i, hash_key_expression = next(
(
(i, e)
for i, e in enumerate(expressions)
if re.search(hash_key_regex, e)
),
(None, None),
)
if hash_key_expression is None:
raise MockValidationException(
"Query condition missed key schema element: {}".format(
hash_key_var
)
)
hash_key_expression = hash_key_expression.strip("()")
expressions.pop(i)
# TODO implement more than one range expression and OR operators
range_key_expression = expressions[0].strip("()")
# Split expression, and account for all kinds of whitespacing around commas and brackets
range_key_expression_components = re.split(
r"\s*\(\s*|\s*,\s*|\s", range_key_expression
)
# Skip whitespace
range_key_expression_components = [
c for c in range_key_expression_components if c
]
range_comparison = range_key_expression_components[1]
if " and " in range_key_expression.lower():
range_comparison = "BETWEEN"
# [range_key, between, x, and, y]
range_values = [
value_alias_map[range_key_expression_components[2]],
value_alias_map[range_key_expression_components[4]],
]
supplied_range_key = range_key_expression_components[0]
elif "begins_with" in range_key_expression:
range_comparison = "BEGINS_WITH"
# [begins_with, range_key, x]
range_values = [
value_alias_map[range_key_expression_components[-1]]
]
supplied_range_key = range_key_expression_components[1]
elif "begins_with" in range_key_expression.lower():
function_used = range_key_expression[
range_key_expression.lower().index("begins_with") : len(
"begins_with"
)
]
raise MockValidationException(
"Invalid KeyConditionExpression: Invalid function name; function: {}".format(
function_used
)
)
else:
# [range_key, =, x]
range_values = [value_alias_map[range_key_expression_components[2]]]
supplied_range_key = range_key_expression_components[0]
supplied_range_key = expression_attribute_names.get(
supplied_range_key, supplied_range_key
)
range_keys = [
k["AttributeName"] for k in schema if k["KeyType"] == "RANGE"
]
if supplied_range_key not in range_keys:
raise MockValidationException(
"Query condition missed key schema element: {}".format(
range_keys[0]
)
)
else:
hash_key_expression = key_condition_expression.strip("()")
range_comparison = None
range_values = []
if not re.search("[^<>]=", hash_key_expression):
raise MockValidationException("Query key condition not supported")
hash_key_value_alias = hash_key_expression.split("=")[1].strip()
# Temporary fix until we get proper KeyConditionExpression function
hash_key = value_alias_map.get(
hash_key_value_alias, {"S": hash_key_value_alias}
hash_key, range_comparison, range_values = parse_expression(
key_condition_expression=key_condition_expression,
expression_attribute_names=expression_attribute_names,
expression_attribute_values=expression_attribute_values,
schema=schema,
)
else:
# 'KeyConditions': {u'forum_name': {u'ComparisonOperator': u'EQ', u'AttributeValueList': [{u'S': u'the-key'}]}}

View File

@ -100,6 +100,33 @@ def test_query_gsi_with_wrong_key_attribute_names_throws_exception():
)
@mock_dynamodb
def test_query_table_with_wrong_key_attribute_names_throws_exception():
item = {
"partitionKey": "pk-1",
"someAttribute": "lore ipsum",
}
dynamodb = boto3.resource("dynamodb", region_name="us-east-1")
dynamodb.create_table(
TableName="test-table", BillingMode="PAY_PER_REQUEST", **table_schema
)
table = dynamodb.Table("test-table")
table.put_item(Item=item)
# check using wrong name for sort key throws exception
with pytest.raises(ClientError) as exc:
table.query(
KeyConditionExpression="wrongName = :pk",
ExpressionAttributeValues={":pk": "pk"},
)["Items"]
err = exc.value.response["Error"]
err["Code"].should.equal("ValidationException")
err["Message"].should.equal(
"Query condition missed key schema element: partitionKey"
)
@mock_dynamodb
def test_empty_expressionattributenames():
ddb = boto3.resource("dynamodb", region_name="us-east-1")
@ -667,3 +694,31 @@ def test_batch_put_item_with_empty_value():
# Empty regular parameter workst just fine though
with table.batch_writer() as batch:
batch.put_item(Item={"pk": "sth", "sk": "else", "par": ""})
@mock_dynamodb
def test_query_begins_with_without_brackets():
client = boto3.client("dynamodb", region_name="us-east-1")
client.create_table(
TableName="test-table",
AttributeDefinitions=[
{"AttributeName": "pk", "AttributeType": "S"},
{"AttributeName": "sk", "AttributeType": "S"},
],
KeySchema=[
{"AttributeName": "pk", "KeyType": "HASH"},
{"AttributeName": "sk", "KeyType": "RANGE"},
],
ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123},
)
with pytest.raises(ClientError) as exc:
client.query(
TableName="test-table",
KeyConditionExpression="pk=:pk AND begins_with sk, :sk ",
ExpressionAttributeValues={":pk": {"S": "test1"}, ":sk": {"S": "test2"}},
)
err = exc.value.response["Error"]
err["Message"].should.equal(
'Invalid KeyConditionExpression: Syntax error; token: "sk"'
)
err["Code"].should.equal("ValidationException")

View File

@ -0,0 +1,198 @@
import pytest
from moto.dynamodb.exceptions import DynamodbException
from moto.dynamodb.parsing.key_condition_expression import parse_expression
class TestHashKey:
schema = [{"AttributeName": "job_id", "KeyType": "HASH"}]
@pytest.mark.parametrize("expression", ["job_id = :id", "job_id = :id "])
def test_hash_key_only(self, expression):
eav = {":id": {"S": "asdasdasd"}}
desired_hash_key, comparison, range_values = parse_expression(
expression_attribute_values=eav,
key_condition_expression=expression,
schema=self.schema,
expression_attribute_names=dict(),
)
desired_hash_key.should.equal(eav[":id"])
comparison.should.equal(None)
range_values.should.equal([])
def test_unknown_hash_key(self):
kce = "wrongName = :id"
eav = {":id": "pk"}
with pytest.raises(DynamodbException) as exc:
parse_expression(
expression_attribute_values=eav,
key_condition_expression=kce,
schema=self.schema,
expression_attribute_names=dict(),
)
exc.value.message.should.equal(
"Query condition missed key schema element: job_id"
)
def test_unknown_hash_value(self):
# TODO: is this correct? I'd assume that this should throw an error instead
# Revisit after test in exceptions.py passes
kce = "job_id = :unknown"
eav = {":id": {"S": "asdasdasd"}}
desired_hash_key, comparison, range_values = parse_expression(
expression_attribute_values=eav,
key_condition_expression=kce,
schema=self.schema,
expression_attribute_names=dict(),
)
desired_hash_key.should.equal({"S": ":unknown"})
comparison.should.equal(None)
range_values.should.equal([])
class TestHashAndRangeKey:
schema = [
{"AttributeName": "job_id", "KeyType": "HASH"},
{"AttributeName": "start_date", "KeyType": "RANGE"},
]
def test_unknown_hash_key(self):
kce = "wrongName = :id AND start_date = :sk"
eav = {":id": "pk", ":sk": "sk"}
with pytest.raises(DynamodbException) as exc:
parse_expression(
expression_attribute_values=eav,
key_condition_expression=kce,
schema=self.schema,
expression_attribute_names=dict(),
)
exc.value.message.should.equal(
"Query condition missed key schema element: job_id"
)
@pytest.mark.parametrize(
"expr",
[
"job_id = :id AND wrongName = :sk",
"job_id = :id AND begins_with ( wrongName , :sk )",
"job_id = :id AND wrongName BETWEEN :sk and :sk2",
],
)
def test_unknown_range_key(self, expr):
eav = {":id": "pk", ":sk": "sk", ":sk2": "sk"}
with pytest.raises(DynamodbException) as exc:
parse_expression(
expression_attribute_values=eav,
key_condition_expression=expr,
schema=self.schema,
expression_attribute_names=dict(),
)
exc.value.message.should.equal(
"Query condition missed key schema element: start_date"
)
@pytest.mark.parametrize(
"expr",
[
"job_id = :id AND begins_with(start_date,:sk)",
"job_id = :id AND begins_with(start_date, :sk)",
"job_id = :id AND begins_with( start_date,:sk)",
"job_id = :id AND begins_with( start_date, :sk)",
"job_id = :id AND begins_with ( start_date, :sk ) ",
],
)
def test_begin_with(self, expr):
eav = {":id": "pk", ":sk": "19"}
desired_hash_key, comparison, range_values = parse_expression(
expression_attribute_values=eav,
key_condition_expression=expr,
schema=self.schema,
expression_attribute_names=dict(),
)
desired_hash_key.should.equal("pk")
comparison.should.equal("BEGINS_WITH")
range_values.should.equal(["19"])
@pytest.mark.parametrize("fn", ["Begins_with", "Begins_With", "BEGINS_WITH"])
def test_begin_with__wrong_case(self, fn):
eav = {":id": "pk", ":sk": "19"}
with pytest.raises(DynamodbException) as exc:
parse_expression(
expression_attribute_values=eav,
key_condition_expression=f"job_id = :id AND {fn}(start_date,:sk)",
schema=self.schema,
expression_attribute_names=dict(),
)
exc.value.message.should.equal(
f"Invalid KeyConditionExpression: Invalid function name; function: {fn}"
)
@pytest.mark.parametrize(
"expr",
[
"job_id = :id and start_date BETWEEN :sk1 AND :sk2",
"job_id = :id and start_date BETWEEN :sk1 and :sk2",
"job_id = :id and start_date between :sk1 and :sk2 ",
],
)
def test_in_between(self, expr):
eav = {":id": "pk", ":sk1": "19", ":sk2": "21"}
desired_hash_key, comparison, range_values = parse_expression(
expression_attribute_values=eav,
key_condition_expression=expr,
schema=self.schema,
expression_attribute_names=dict(),
)
desired_hash_key.should.equal("pk")
comparison.should.equal("BETWEEN")
range_values.should.equal(["19", "21"])
@pytest.mark.parametrize("operator", [" < ", " <=", "= ", ">", ">="])
def test_numeric_comparisons(self, operator):
eav = {":id": "pk", ":sk": "19"}
expr = f"job_id = :id and start_date{operator}:sk"
desired_hash_key, comparison, range_values = parse_expression(
expression_attribute_values=eav,
key_condition_expression=expr,
schema=self.schema,
expression_attribute_names=dict(),
)
desired_hash_key.should.equal("pk")
comparison.should.equal(operator.strip())
range_values.should.equal(["19"])
@pytest.mark.parametrize(
"expr",
[
"start_date >= :sk and job_id = :id",
"start_date>:sk and job_id=:id",
"start_date=:sk and job_id = :id",
"begins_with(start_date,:sk) and job_id = :id",
],
)
def test_reverse_keys(self, expr):
eav = {":id": "pk", ":sk1": "19", ":sk2": "21"}
desired_hash_key, comparison, range_values = parse_expression(
expression_attribute_values=eav,
key_condition_expression=expr,
schema=self.schema,
expression_attribute_names=dict(),
)
desired_hash_key.should.equal("pk")
class TestNamesAndValues:
schema = [{"AttributeName": "job_id", "KeyType": "HASH"}]
def test_names_and_values(self):
kce = ":j = :id"
ean = {":j": "job_id"}
eav = {":id": {"S": "asdasdasd"}}
desired_hash_key, comparison, range_values = parse_expression(
expression_attribute_values=eav,
key_condition_expression=kce,
schema=self.schema,
expression_attribute_names=ean,
)
desired_hash_key.should.equal(eav[":id"])
comparison.should.equal(None)
range_values.should.equal([])

View File

@ -126,6 +126,7 @@ def test_only_return_consumed_capacity_when_required(
# QUERY_INDEX
args["IndexName"] = "job_name-index"
args["KeyConditionExpression"] = "job_name = :id"
response = client.query(**args)
validate_response(response, should_have_capacity, should_have_table, is_index=True)