diff --git a/moto/glue/exceptions.py b/moto/glue/exceptions.py index ec4338577..003734b11 100644 --- a/moto/glue/exceptions.py +++ b/moto/glue/exceptions.py @@ -78,3 +78,22 @@ class CrawlerNotRunningException(GlueClientError): class ConcurrentRunsExceededException(GlueClientError): def __init__(self, msg): super().__init__("ConcurrentRunsExceededException", msg) + + +class _InvalidOperationException(GlueClientError): + def __init__(self, error_type, op, msg): + super().__init__( + error_type, + "An error occurred (%s) when calling the %s operation: %s" + % (error_type, op, msg), + ) + + +class InvalidInputException(_InvalidOperationException): + def __init__(self, op, msg): + super().__init__("InvalidInputException", op, msg) + + +class InvalidStateException(_InvalidOperationException): + def __init__(self, op, msg): + super().__init__("InvalidStateException", op, msg) diff --git a/moto/glue/models.py b/moto/glue/models.py index f627c8fb7..0c00d817d 100644 --- a/moto/glue/models.py +++ b/moto/glue/models.py @@ -18,6 +18,7 @@ from .exceptions import ( JobNotFoundException, ConcurrentRunsExceededException, ) +from .utils import PartitionFilter from ..utilities.paginator import paginate @@ -278,8 +279,19 @@ class FakeTable(BaseModel): raise PartitionAlreadyExistsException() self.partitions[str(partition.values)] = partition - def get_partitions(self): - return [p for str_part_values, p in self.partitions.items()] + def get_partitions(self, expression): + """See https://docs.aws.amazon.com/glue/latest/webapi/API_GetPartitions.html + for supported expressions. + + Expression caveats: + + - Column names must consist of UPPERCASE, lowercase, dots and underscores only. + - Nanosecond expressions on timestamp columns are rounded to microseconds. + - Literal dates and timestamps must be valid, i.e. no support for February 31st. + - LIKE expressions are converted to Python regexes, escaping special characters. + Only % and _ wildcards are supported, and SQL escaping using [] does not work. + """ + return list(filter(PartitionFilter(expression, self), self.partitions.values())) def get_partition(self, values): try: diff --git a/moto/glue/responses.py b/moto/glue/responses.py index efc9335c7..78a97c651 100644 --- a/moto/glue/responses.py +++ b/moto/glue/responses.py @@ -129,13 +129,12 @@ class GlueResponse(BaseResponse): def get_partitions(self): database_name = self.parameters.get("DatabaseName") table_name = self.parameters.get("TableName") - if "Expression" in self.parameters: - raise NotImplementedError( - "Expression filtering in get_partitions is not implemented in moto" - ) + expression = self.parameters.get("Expression") table = self.glue_backend.get_table(database_name, table_name) - return json.dumps({"Partitions": [p.as_dict() for p in table.get_partitions()]}) + return json.dumps( + {"Partitions": [p.as_dict() for p in table.get_partitions(expression)]} + ) def get_partition(self): database_name = self.parameters.get("DatabaseName") diff --git a/moto/glue/utils.py b/moto/glue/utils.py index e69de29bb..99f039405 100644 --- a/moto/glue/utils.py +++ b/moto/glue/utils.py @@ -0,0 +1,354 @@ +import abc +import operator +import re +import warnings +from datetime import date, datetime, timedelta +from itertools import repeat +from typing import Any, Dict, List, Optional, Union + +from pyparsing import ( + CaselessKeyword, + Forward, + OpAssoc, + ParserElement, + ParseResults, + QuotedString, + Suppress, + Word, + alphanums, + delimited_list, + exceptions, + infix_notation, + one_of, + pyparsing_common, +) + +from .exceptions import InvalidInputException, InvalidStateException + + +def _cast(type_: str, value: Any) -> Union[date, datetime, float, int, str]: + # values are always cast from string to target type + value = str(value) + + if type_ in ("bigint", "int", "smallint", "tinyint"): + try: + return int(value) # no size is enforced + except ValueError: + raise ValueError(f'"{value}" is not an integer.') + + if type_ == "decimal": + try: + return float(value) + except ValueError: + raise ValueError(f"{value} is not a decimal.") + + if type_ in ("char", "string", "varchar"): + return value # no length is enforced + + if type_ == "date": + try: + return datetime.strptime(value, "%Y-%m-%d").date() + except ValueError: + raise ValueError(f"{value} is not a date.") + + if type_ == "timestamp": + match = re.search( + r"^(?P\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2})" + r"(?P\.\d{1,9})?$", + value, + ) + if match is None: + raise ValueError( + "Timestamp format must be yyyy-mm-dd hh:mm:ss[.fffffffff]" + f" {value} is not a timestamp." + ) + + try: + timestamp = datetime.strptime(match.group("timestamp"), "%Y-%m-%d %H:%M:%S") + except ValueError: + raise ValueError( + "Timestamp format must be yyyy-mm-dd hh:mm:ss[.fffffffff]" + f" {value} is not a timestamp." + ) + + nanos = match.group("nanos") + if nanos is not None: + # strip leading dot, reverse and left pad with zeros to nanoseconds + nanos = "".join(reversed(nanos[1:])).zfill(9) + for i, nanoseconds in enumerate(nanos): + microseconds = (int(nanoseconds) * 10**i) / 1000 + timestamp += timedelta(microseconds=round(microseconds)) + + return timestamp + + raise InvalidInputException("GetPartitions", f"Unknown type : '{type_}'") + + +def _escape_regex(pattern: str) -> str: + """Taken from Python 3.7 to avoid escaping '%'.""" + _special_chars_map = {i: "\\" + chr(i) for i in b"()[]{}?*+-|^$\\.&~# \t\n\r\v\f"} + return pattern.translate(_special_chars_map) + + +class _Expr(abc.ABC): + @abc.abstractmethod + def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> Any: + raise NotImplementedError() + + +class _Ident(_Expr): + def __init__(self, tokens: ParseResults): + self.ident: str = tokens[0] + + def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> Any: + try: + return self._eval(part_keys, part_input) + except ValueError as e: + # existing partition values cannot be cast to current schema + raise InvalidStateException("GetPartitions", str(e)) + + def leval(self, part_keys: List[Dict[str, str]], literal: Any) -> Any: + # evaluate literal by simulating partition input + try: + return self._eval(part_keys, part_input={"Values": repeat(literal)}) + except ValueError as e: + # expression literal cannot be cast to current schema + raise InvalidInputException("GetPartitions", str(e)) + + def type_(self, part_keys: List[Dict[str, str]]) -> str: + for key in part_keys: + if self.ident == key["Name"]: + return key["Type"] + + raise InvalidInputException("GetPartitions", f"Unknown column '{self.ident}'") + + def _eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> Any: + for key, value in zip(part_keys, part_input["Values"]): + if self.ident == key["Name"]: + return _cast(key["Type"], value) + + # also raised for unpartitioned tables + raise InvalidInputException("GetPartitions", f"Unknown column '{self.ident}'") + + +class _IsNull(_Expr): + def __init__(self, tokens: ParseResults): + self.ident: _Ident = tokens[0] + + def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool: + return self.ident.eval(part_keys, part_input) is None + + +class _IsNotNull(_IsNull): + def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool: + return not super().eval(part_keys, part_input) + + +class _BinOp(_Expr): + def __init__(self, tokens: ParseResults): + self.ident: _Ident = tokens[0] + self.bin_op: str = tokens[1] + self.literal: Any = tokens[2] + + def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool: + ident = self.ident.eval(part_keys, part_input) + + # simulate partition input for the lateral + rhs = self.ident.leval(part_keys, self.literal) + + return { + "<>": operator.ne, + ">=": operator.ge, + "<=": operator.le, + ">": operator.gt, + "<": operator.lt, + "=": operator.eq, + }[self.bin_op](ident, rhs) + + +class _Like(_Expr): + def __init__(self, tokens: ParseResults): + self.ident: _Ident = tokens[0] + self.literal: str = tokens[2] + + def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool: + type_ = self.ident.type_(part_keys) + if type_ in ("bigint", "int", "smallint", "tinyint"): + raise InvalidInputException( + "GetPartitions", "Integral data type doesn't support operation 'LIKE'" + ) + + if type_ in ("date", "decimal", "timestamp"): + raise InvalidInputException( + "GetPartitions", + f"{type_[0].upper()}{type_[1:]} data type" + " doesn't support operation 'LIKE'", + ) + + ident = self.ident.eval(part_keys, part_input) + assert isinstance(ident, str) + + pattern = _cast("string", self.literal) + + # prepare SQL pattern for conversion to regex pattern + pattern = _escape_regex(pattern) + + # NOTE convert SQL wildcards to regex, no literal matches possible + pattern = pattern.replace("_", ".").replace("%", ".*") + + # LIKE clauses always start at the beginning + pattern = "^" + pattern + "$" + + return re.search(pattern, ident) is not None + + +class _NotLike(_Like): + def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool: + return not super().eval(part_keys, part_input) + + +class _In(_Expr): + def __init__(self, tokens: ParseResults): + self.ident: _Ident = tokens[0] + self.values: List[Any] = tokens[2:] + + def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool: + ident = self.ident.eval(part_keys, part_input) + values = (self.ident.leval(part_keys, value) for value in self.values) + + return ident in values + + +class _NotIn(_In): + def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool: + return not super().eval(part_keys, part_input) + + +class _Between(_Expr): + def __init__(self, tokens: ParseResults): + self.ident: _Ident = tokens[0] + self.left: Any = tokens[2] + self.right: Any = tokens[4] + + def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool: + ident = self.ident.eval(part_keys, part_input) + left = self.ident.leval(part_keys, self.left) + right = self.ident.leval(part_keys, self.right) + + return left <= ident <= right or left > ident > right + + +class _NotBetween(_Between): + def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool: + return not super().eval(part_keys, part_input) + + +class _BoolAnd(_Expr): + def __init__(self, tokens: ParseResults) -> None: + self.operands: List[_Expr] = tokens[0][0::2] # skip 'and' between tokens + + def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool: + return all(operand.eval(part_keys, part_input) for operand in self.operands) + + +class _BoolOr(_Expr): + def __init__(self, tokens: ParseResults) -> None: + self.operands: List[_Expr] = tokens[0][0::2] # skip 'or' between tokens + + def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool: + return any(operand.eval(part_keys, part_input) for operand in self.operands) + + +class _PartitionFilterExpressionCache: + def __init__(self): + # build grammar according to Glue.Client.get_partitions(Expression) + lpar, rpar = map(Suppress, "()") + + # NOTE these are AWS Athena column name best practices + ident = Forward().set_name("ident") + ident <<= Word(alphanums + "._").set_parse_action(_Ident) | lpar + ident + rpar + + number = Forward().set_name("number") + number <<= pyparsing_common.number | lpar + number + rpar + + string = Forward().set_name("string") + string <<= QuotedString(quote_char="'", esc_quote="''") | lpar + string + rpar + + literal = (number | string).set_name("literal") + literal_list = delimited_list(literal, min=1).set_name("list") + + bin_op = one_of("<> >= <= > < =").set_name("binary op") + + and_ = Forward() + and_ <<= CaselessKeyword("and") | lpar + and_ + rpar + + or_ = Forward() + or_ <<= CaselessKeyword("or") | lpar + or_ + rpar + + in_, between, like, not_, is_, null = map( + CaselessKeyword, "in between like not is null".split() + ) + not_ = Suppress(not_) # only needed for matching + + cond = ( + (ident + is_ + null).set_parse_action(_IsNull) + | (ident + is_ + not_ + null).set_parse_action(_IsNotNull) + | (ident + bin_op + literal).set_parse_action(_BinOp) + | (ident + like + string).set_parse_action(_Like) + | (ident + not_ + like + string).set_parse_action(_NotLike) + | (ident + in_ + lpar + literal_list + rpar).set_parse_action(_In) + | (ident + not_ + in_ + lpar + literal_list + rpar).set_parse_action(_NotIn) + | (ident + between + literal + and_ + literal).set_parse_action(_Between) + | (ident + not_ + between + literal + and_ + literal).set_parse_action( + _NotBetween + ) + ).set_name("cond") + + # conditions can be joined using 2-ary AND and/or OR + expr = infix_notation( + cond, + [ + (and_, 2, OpAssoc.LEFT, _BoolAnd), + (or_, 2, OpAssoc.LEFT, _BoolOr), + ], + ) + self._expr = expr.set_name("expr") + + self._cache: Dict[str, _Expr] = {} + + def get(self, expression: Optional[str]) -> Optional[_Expr]: + if expression is None: + return None + + if expression not in self._cache: + ParserElement.enable_packrat() + + try: + expr: ParseResults = self._expr.parse_string(expression, parse_all=True) + self._cache[expression] = expr[0] + except exceptions.ParseException: + raise InvalidInputException( + "GetPartitions", f"Unsupported expression '{expression}'" + ) + + return self._cache[expression] + + +_PARTITION_FILTER_EXPRESSION_CACHE = _PartitionFilterExpressionCache() + + +class PartitionFilter: + def __init__(self, expression: Optional[str], fake_table): + self.expression = expression + self.fake_table = fake_table + + def __call__(self, fake_partition) -> bool: + expression = _PARTITION_FILTER_EXPRESSION_CACHE.get(self.expression) + if expression is None: + return True + + warnings.warn("Expression filtering is experimental") + return expression.eval( + part_keys=self.fake_table.versions[-1].get("PartitionKeys", []), + part_input=fake_partition.partition_input, + ) diff --git a/setup.py b/setup.py index 31cb595d8..860747c0b 100755 --- a/setup.py +++ b/setup.py @@ -54,6 +54,7 @@ _dep_aws_xray_sdk = "aws-xray-sdk!=0.96,>=0.93" _dep_idna = "idna<4,>=2.5" _dep_cfn_lint = "cfn-lint>=0.4.0" _dep_sshpubkeys = "sshpubkeys>=3.1.0" +_dep_pyparsing = "pyparsing>=3.0.0" _setuptools = "setuptools" all_extra_deps = [ @@ -67,6 +68,7 @@ all_extra_deps = [ _dep_idna, _dep_cfn_lint, _dep_sshpubkeys, + _dep_pyparsing, _setuptools, ] all_server_deps = all_extra_deps + ["flask", "flask-cors"] @@ -88,6 +90,7 @@ extras_per_service.update( "cloudformation": [_dep_docker, _dep_PyYAML, _dep_cfn_lint], "cognitoidp": [_dep_python_jose, _dep_python_jose_ecdsa_pin], "ec2": [_dep_sshpubkeys], + "glue": [_dep_pyparsing], "iotdata": [_dep_jsondiff], "s3": [_dep_PyYAML], "ses": [], diff --git a/tests/test_glue/helpers.py b/tests/test_glue/helpers.py index 465f7c942..7293b99bf 100644 --- a/tests/test_glue/helpers.py +++ b/tests/test_glue/helpers.py @@ -67,6 +67,15 @@ def get_table_version(client, database_name, table_name, version_id): ) +def create_column(name, type_, comment=None, parameters=None): + column = {"Name": name, "Type": type_} + if comment is not None: + column["Comment"] = comment + if parameters is not None: + column["Parameters"] = parameters + return column + + def create_partition_input(database_name, table_name, values=None, columns=None): root_path = "s3://my-bucket/{database_name}/{table_name}".format( database_name=database_name, table_name=table_name diff --git a/tests/test_glue/test_partition_filter.py b/tests/test_glue/test_partition_filter.py new file mode 100644 index 000000000..b1b70c6f0 --- /dev/null +++ b/tests/test_glue/test_partition_filter.py @@ -0,0 +1,381 @@ +from unittest import SkipTest + +import boto3 +import pytest +import sure # noqa # pylint: disable=unused-import +from botocore.client import ClientError + +from moto import mock_glue, settings + +from . import helpers + + +@mock_glue +def test_get_partitions_expression_unknown_column(): + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" + values = ["2018-10-01"] + columns = [helpers.create_column("date_col", "date")] + helpers.create_database(client, database_name) + + helpers.create_table(client, database_name, table_name) + + helpers.create_partition( + client, database_name, table_name, values=values, columns=columns + ) + + with pytest.raises(ClientError) as exc: + client.get_partitions( + DatabaseName=database_name, + TableName=table_name, + Expression="unknown_col IS NULL", + ) + + exc.value.response["Error"]["Code"].should.equal("InvalidInputException") + exc.value.response["Error"]["Message"].should.match("Unknown column 'unknown_col'") + + +@mock_glue +def test_get_partitions_expression_int_column(): + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" + columns = [helpers.create_column("int_col", "int")] + + helpers.create_database(client, database_name) + + args = (client, database_name, table_name) + helpers.create_table(*args, partition_keys=columns) + helpers.create_partition(*args, values=["1"], columns=columns) + helpers.create_partition(*args, values=["2"], columns=columns) + helpers.create_partition(*args, values=["3"], columns=columns) + + kwargs = {"DatabaseName": database_name, "TableName": table_name} + + response = client.get_partitions(**kwargs) + partitions = response["Partitions"] + partitions.should.have.length_of(3) + + int_col_is_two_expressions = ( + "int_col = 2", + "int_col = '2'", + "int_col IN (2)", + "int_col in (6, '4', 2)", + "int_col between 2 AND 2", + "int_col > 1 AND int_col < 3", + "int_col >= 2 and int_col <> 3", + "(int_col) = ((2)) (OR) (((int_col))) = (2)", + "int_col IS NOT NULL and int_col = 2", + "int_col not IN (1, 3)", + "int_col NOT BETWEEN 1 AND 1 and int_col NOT BETWEEN 3 AND 3", + "int_col = 4 OR int_col = 5 AND int_col = '-1' OR int_col = 0 OR int_col = '2'", + ) + + for expression in int_col_is_two_expressions: + response = client.get_partitions(**kwargs, Expression=expression) + partitions = response["Partitions"] + partitions.should.have.length_of(1) + partition = partitions[0] + partition["Values"].should.equal(["2"]) + + bad_int_expressions = ("int_col = 'test'", "int_col in (2.5)") + for expression in bad_int_expressions: + with pytest.raises(ClientError) as exc: + client.get_partitions(**kwargs, Expression=expression) + + exc.value.response["Error"]["Code"].should.equal("InvalidInputException") + exc.value.response["Error"]["Message"].should.match("is not an integer") + + with pytest.raises(ClientError) as exc: + client.get_partitions(**kwargs, Expression="int_col LIKE '2'") + + exc.value.response["Error"]["Code"].should.equal("InvalidInputException") + exc.value.response["Error"]["Message"].should.match( + "Integral data type doesn't support operation 'LIKE'" + ) + + +@mock_glue +def test_get_partitions_expression_decimal_column(): + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" + columns = [helpers.create_column("decimal_col", "decimal")] + + helpers.create_database(client, database_name) + + args = (client, database_name, table_name) + helpers.create_table(*args, partition_keys=columns) + helpers.create_partition(*args, values=["1.2"], columns=columns) + helpers.create_partition(*args, values=["2.6"], columns=columns) + helpers.create_partition(*args, values=["3e14"], columns=columns) + + kwargs = {"DatabaseName": database_name, "TableName": table_name} + + response = client.get_partitions(**kwargs) + partitions = response["Partitions"] + partitions.should.have.length_of(3) + + decimal_col_is_two_point_six_expressions = ( + "decimal_col = 2.6", + "decimal_col = '2.6'", + "decimal_col IN (2.6)", + "decimal_col in (6, '4', 2.6)", + "decimal_col between 2 AND 3e10", + "decimal_col > 1.5 AND decimal_col < 3", + "decimal_col >= 2 and decimal_col <> '3e14'", + ) + + for expression in decimal_col_is_two_point_six_expressions: + response = client.get_partitions(**kwargs, Expression=expression) + partitions = response["Partitions"] + partitions.should.have.length_of(1) + partition = partitions[0] + partition["Values"].should.equal(["2.6"]) + + bad_decimal_expressions = ("decimal_col = 'test'",) + for expression in bad_decimal_expressions: + with pytest.raises(ClientError) as exc: + client.get_partitions(**kwargs, Expression=expression) + + exc.value.response["Error"]["Code"].should.equal("InvalidInputException") + exc.value.response["Error"]["Message"].should.match("is not a decimal") + + with pytest.raises(ClientError) as exc: + client.get_partitions(**kwargs, Expression="decimal_col LIKE '2'") + + exc.value.response["Error"]["Code"].should.equal("InvalidInputException") + exc.value.response["Error"]["Message"].should.match( + "Decimal data type doesn't support operation 'LIKE'" + ) + + +@mock_glue +def test_get_partitions_expression_string_column(): + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" + columns = [helpers.create_column("string_col", "string")] + + helpers.create_database(client, database_name) + + args = (client, database_name, table_name) + helpers.create_table(*args, partition_keys=columns) + helpers.create_partition(*args, values=["one"], columns=columns) + helpers.create_partition(*args, values=["two"], columns=columns) + helpers.create_partition(*args, values=["2"], columns=columns) + helpers.create_partition(*args, values=["three"], columns=columns) + + kwargs = {"DatabaseName": database_name, "TableName": table_name} + + response = client.get_partitions(**kwargs) + partitions = response["Partitions"] + partitions.should.have.length_of(4) + + string_col_is_two_expressions = ( + "string_col = 'two'", + "string_col = 2", + "string_col IN (1, 2, 3)", + "string_col IN ('1', '2', '3')", + "string_col IN ('test', 'two', '3')", + "string_col between 'twn' AND 'twp'", + "string_col > '1' AND string_col < '3'", + "string_col LIKE 'two'", + "string_col LIKE 't_o'", + "string_col LIKE 't__'", + "string_col LIKE '%wo'", + "string_col NOT LIKE '%e' AND string_col not like '_'", + ) + + for expression in string_col_is_two_expressions: + response = client.get_partitions(**kwargs, Expression=expression) + partitions = response["Partitions"] + partitions.should.have.length_of(1) + partition = partitions[0] + partition["Values"].should.be.within((["two"], ["2"])) + + with pytest.raises(ClientError) as exc: + client.get_partitions(**kwargs, Expression="unknown_col LIKE 'two'") + + exc.value.response["Error"]["Code"].should.equal("InvalidInputException") + exc.value.response["Error"]["Message"].should.match("Unknown column 'unknown_col'") + + +@mock_glue +def test_get_partitions_expression_date_column(): + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" + columns = [helpers.create_column("date_col", "date")] + + helpers.create_database(client, database_name) + + args = (client, database_name, table_name) + helpers.create_table(*args, partition_keys=columns) + helpers.create_partition(*args, values=["2022-01-01"], columns=columns) + helpers.create_partition(*args, values=["2022-02-01"], columns=columns) + helpers.create_partition(*args, values=["2022-03-01"], columns=columns) + + kwargs = {"DatabaseName": database_name, "TableName": table_name} + + response = client.get_partitions(**kwargs) + partitions = response["Partitions"] + partitions.should.have.length_of(3) + + date_col_is_february_expressions = ( + "date_col = '2022-02-01'", + "date_col IN ('2022-02-01')", + "date_col in ('2024-02-29', '2022-02-01', '2022-02-02')", + "date_col between '2022-01-15' AND '2022-02-15'", + "date_col > '2022-01-15' AND date_col < '2022-02-15'", + ) + + for expression in date_col_is_february_expressions: + response = client.get_partitions(**kwargs, Expression=expression) + partitions = response["Partitions"] + partitions.should.have.length_of(1) + partition = partitions[0] + partition["Values"].should.equal(["2022-02-01"]) + + bad_date_expressions = ("date_col = 'test'", "date_col = '2022-02-32'") + for expression in bad_date_expressions: + with pytest.raises(ClientError) as exc: + client.get_partitions(**kwargs, Expression=expression) + + exc.value.response["Error"]["Code"].should.equal("InvalidInputException") + exc.value.response["Error"]["Message"].should.match("is not a date") + + with pytest.raises(ClientError) as exc: + client.get_partitions(**kwargs, Expression="date_col LIKE '2022-02-01'") + + exc.value.response["Error"]["Code"].should.equal("InvalidInputException") + exc.value.response["Error"]["Message"].should.match( + "Date data type doesn't support operation 'LIKE'" + ) + + +@mock_glue +def test_get_partitions_expression_timestamp_column(): + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" + columns = [helpers.create_column("timestamp_col", "timestamp")] + + helpers.create_database(client, database_name) + + args = (client, database_name, table_name) + helpers.create_table(*args, partition_keys=columns) + helpers.create_partition(*args, values=["2022-01-01 12:34:56.789"], columns=columns) + helpers.create_partition( + *args, values=["2022-02-01 00:00:00.000000"], columns=columns + ) + helpers.create_partition( + *args, values=["2022-03-01 21:00:12.3456789"], columns=columns + ) + + kwargs = {"DatabaseName": database_name, "TableName": table_name} + + response = client.get_partitions(**kwargs) + partitions = response["Partitions"] + partitions.should.have.length_of(3) + + timestamp_col_is_february_expressions = ( + "timestamp_col = '2022-02-01 00:00:00'", + "timestamp_col = '2022-02-01 00:00:00.0'", + "timestamp_col = '2022-02-01 00:00:00.000000000'", + "timestamp_col IN ('2022-02-01 00:00:00.000')", + "timestamp_col between '2022-01-15 00:00:00' AND '2022-02-15 00:00:00'", + "timestamp_col > '2022-01-15 00:00:00' AND " + "timestamp_col < '2022-02-15 00:00:00'", + # these expressions only work because of rounding to microseconds + "timestamp_col = '2022-01-31 23:59:59.999999999'", + "timestamp_col = '2022-02-01 00:00:00.00000001'", + "timestamp_col > '2022-01-31 23:59:59.999999499' AND" + " timestamp_col < '2022-02-01 00:00:00.0000009'", + ) + + for expression in timestamp_col_is_february_expressions: + response = client.get_partitions(**kwargs, Expression=expression) + partitions = response["Partitions"] + partitions.should.have.length_of(1) + partition = partitions[0] + partition["Values"].should.equal(["2022-02-01 00:00:00.000000"]) + + bad_timestamp_expressions = ( + "timestamp_col = '2022-02-01'", + "timestamp_col = '2022-02-15 00:00:00.'", + "timestamp_col = '2022-02-32 00:00:00'", + ) + for expression in bad_timestamp_expressions: + with pytest.raises(ClientError) as exc: + client.get_partitions(**kwargs, Expression=expression) + + exc.value.response["Error"]["Code"].should.equal("InvalidInputException") + exc.value.response["Error"]["Message"].should.match("is not a timestamp") + + with pytest.raises(ClientError) as exc: + client.get_partitions( + **kwargs, Expression="timestamp_col LIKE '2022-02-01 00:00:00'" + ) + + exc.value.response["Error"]["Code"].should.equal("InvalidInputException") + exc.value.response["Error"]["Message"].should.match( + "Timestamp data type doesn't support operation 'LIKE'" + ) + + +@mock_glue +def test_get_partition_expression_warnings_and_exceptions(): + if settings.TEST_SERVER_MODE: + raise SkipTest("Cannot catch warnings in server mode") + + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" + columns = [ + helpers.create_column("string_col", "string"), + helpers.create_column("int_col", "int"), + helpers.create_column("float_col", "float"), + ] + + helpers.create_database(client, database_name) + + args = (client, database_name, table_name) + helpers.create_table(*args, partition_keys=columns) + helpers.create_partition(*args, values=["test", "int", "3.14"], columns=columns) + + kwargs = {"DatabaseName": database_name, "TableName": table_name} + + with pytest.warns(match="Expression filtering is experimental"): + response = client.get_partitions(**kwargs, Expression="string_col = 'test'") + partitions = response["Partitions"] + partitions.should.have.length_of(1) + partition = partitions[0] + partition["Values"].should.equal(["test", "int", "3.14"]) + + with pytest.raises(ClientError) as exc: + client.get_partitions(**kwargs, Expression="float_col = 3.14") + + exc.value.response["Error"]["Code"].should.equal("InvalidInputException") + exc.value.response["Error"]["Message"].should.match("Unknown type : 'float'") + + with pytest.raises(ClientError) as exc: + client.get_partitions(**kwargs, Expression="int_col = 2") + + exc.value.response["Error"]["Code"].should.equal("InvalidStateException") + exc.value.response["Error"]["Message"].should.match('"int" is not an integer') + + with pytest.raises(ClientError) as exc: + client.get_partitions(**kwargs, Expression="unknown_col = 'test'") + + exc.value.response["Error"]["Code"].should.equal("InvalidInputException") + exc.value.response["Error"]["Message"].should.match("Unknown column 'unknown_col'") + + with pytest.raises(ClientError) as exc: + client.get_partitions( + **kwargs, Expression="string_col IS test' AND not parsable" + ) + + exc.value.response["Error"]["Code"].should.equal("InvalidInputException") + exc.value.response["Error"]["Message"].should.match("Unsupported expression")