Support Expression in Glue's get_partitions (#5013)
This commit is contained in:
		
							parent
							
								
									d46987ec29
								
							
						
					
					
						commit
						794e940421
					
				@ -78,3 +78,22 @@ class CrawlerNotRunningException(GlueClientError):
 | 
				
			|||||||
class ConcurrentRunsExceededException(GlueClientError):
 | 
					class ConcurrentRunsExceededException(GlueClientError):
 | 
				
			||||||
    def __init__(self, msg):
 | 
					    def __init__(self, msg):
 | 
				
			||||||
        super().__init__("ConcurrentRunsExceededException", msg)
 | 
					        super().__init__("ConcurrentRunsExceededException", msg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class _InvalidOperationException(GlueClientError):
 | 
				
			||||||
 | 
					    def __init__(self, error_type, op, msg):
 | 
				
			||||||
 | 
					        super().__init__(
 | 
				
			||||||
 | 
					            error_type,
 | 
				
			||||||
 | 
					            "An error occurred (%s) when calling the %s operation: %s"
 | 
				
			||||||
 | 
					            % (error_type, op, msg),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class InvalidInputException(_InvalidOperationException):
 | 
				
			||||||
 | 
					    def __init__(self, op, msg):
 | 
				
			||||||
 | 
					        super().__init__("InvalidInputException", op, msg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class InvalidStateException(_InvalidOperationException):
 | 
				
			||||||
 | 
					    def __init__(self, op, msg):
 | 
				
			||||||
 | 
					        super().__init__("InvalidStateException", op, msg)
 | 
				
			||||||
 | 
				
			|||||||
@ -18,6 +18,7 @@ from .exceptions import (
 | 
				
			|||||||
    JobNotFoundException,
 | 
					    JobNotFoundException,
 | 
				
			||||||
    ConcurrentRunsExceededException,
 | 
					    ConcurrentRunsExceededException,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					from .utils import PartitionFilter
 | 
				
			||||||
from ..utilities.paginator import paginate
 | 
					from ..utilities.paginator import paginate
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -278,8 +279,19 @@ class FakeTable(BaseModel):
 | 
				
			|||||||
            raise PartitionAlreadyExistsException()
 | 
					            raise PartitionAlreadyExistsException()
 | 
				
			||||||
        self.partitions[str(partition.values)] = partition
 | 
					        self.partitions[str(partition.values)] = partition
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_partitions(self):
 | 
					    def get_partitions(self, expression):
 | 
				
			||||||
        return [p for str_part_values, p in self.partitions.items()]
 | 
					        """See https://docs.aws.amazon.com/glue/latest/webapi/API_GetPartitions.html
 | 
				
			||||||
 | 
					        for supported expressions.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Expression caveats:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        - Column names must consist of UPPERCASE, lowercase, dots and underscores only.
 | 
				
			||||||
 | 
					        - Nanosecond expressions on timestamp columns are rounded to microseconds.
 | 
				
			||||||
 | 
					        - Literal dates and timestamps must be valid, i.e. no support for February 31st.
 | 
				
			||||||
 | 
					        - LIKE expressions are converted to Python regexes, escaping special characters.
 | 
				
			||||||
 | 
					          Only % and _ wildcards are supported, and SQL escaping using [] does not work.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        return list(filter(PartitionFilter(expression, self), self.partitions.values()))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_partition(self, values):
 | 
					    def get_partition(self, values):
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
 | 
				
			|||||||
@ -129,13 +129,12 @@ class GlueResponse(BaseResponse):
 | 
				
			|||||||
    def get_partitions(self):
 | 
					    def get_partitions(self):
 | 
				
			||||||
        database_name = self.parameters.get("DatabaseName")
 | 
					        database_name = self.parameters.get("DatabaseName")
 | 
				
			||||||
        table_name = self.parameters.get("TableName")
 | 
					        table_name = self.parameters.get("TableName")
 | 
				
			||||||
        if "Expression" in self.parameters:
 | 
					        expression = self.parameters.get("Expression")
 | 
				
			||||||
            raise NotImplementedError(
 | 
					 | 
				
			||||||
                "Expression filtering in get_partitions is not implemented in moto"
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        table = self.glue_backend.get_table(database_name, table_name)
 | 
					        table = self.glue_backend.get_table(database_name, table_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return json.dumps({"Partitions": [p.as_dict() for p in table.get_partitions()]})
 | 
					        return json.dumps(
 | 
				
			||||||
 | 
					            {"Partitions": [p.as_dict() for p in table.get_partitions(expression)]}
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_partition(self):
 | 
					    def get_partition(self):
 | 
				
			||||||
        database_name = self.parameters.get("DatabaseName")
 | 
					        database_name = self.parameters.get("DatabaseName")
 | 
				
			||||||
 | 
				
			|||||||
@ -0,0 +1,354 @@
 | 
				
			|||||||
 | 
					import abc
 | 
				
			||||||
 | 
					import operator
 | 
				
			||||||
 | 
					import re
 | 
				
			||||||
 | 
					import warnings
 | 
				
			||||||
 | 
					from datetime import date, datetime, timedelta
 | 
				
			||||||
 | 
					from itertools import repeat
 | 
				
			||||||
 | 
					from typing import Any, Dict, List, Optional, Union
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from pyparsing import (
 | 
				
			||||||
 | 
					    CaselessKeyword,
 | 
				
			||||||
 | 
					    Forward,
 | 
				
			||||||
 | 
					    OpAssoc,
 | 
				
			||||||
 | 
					    ParserElement,
 | 
				
			||||||
 | 
					    ParseResults,
 | 
				
			||||||
 | 
					    QuotedString,
 | 
				
			||||||
 | 
					    Suppress,
 | 
				
			||||||
 | 
					    Word,
 | 
				
			||||||
 | 
					    alphanums,
 | 
				
			||||||
 | 
					    delimited_list,
 | 
				
			||||||
 | 
					    exceptions,
 | 
				
			||||||
 | 
					    infix_notation,
 | 
				
			||||||
 | 
					    one_of,
 | 
				
			||||||
 | 
					    pyparsing_common,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .exceptions import InvalidInputException, InvalidStateException
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _cast(type_: str, value: Any) -> Union[date, datetime, float, int, str]:
 | 
				
			||||||
 | 
					    # values are always cast from string to target type
 | 
				
			||||||
 | 
					    value = str(value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if type_ in ("bigint", "int", "smallint", "tinyint"):
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            return int(value)  # no size is enforced
 | 
				
			||||||
 | 
					        except ValueError:
 | 
				
			||||||
 | 
					            raise ValueError(f'"{value}" is not an integer.')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if type_ == "decimal":
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            return float(value)
 | 
				
			||||||
 | 
					        except ValueError:
 | 
				
			||||||
 | 
					            raise ValueError(f"{value} is not a decimal.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if type_ in ("char", "string", "varchar"):
 | 
				
			||||||
 | 
					        return value  # no length is enforced
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if type_ == "date":
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            return datetime.strptime(value, "%Y-%m-%d").date()
 | 
				
			||||||
 | 
					        except ValueError:
 | 
				
			||||||
 | 
					            raise ValueError(f"{value} is not a date.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if type_ == "timestamp":
 | 
				
			||||||
 | 
					        match = re.search(
 | 
				
			||||||
 | 
					            r"^(?P<timestamp>\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2})"
 | 
				
			||||||
 | 
					            r"(?P<nanos>\.\d{1,9})?$",
 | 
				
			||||||
 | 
					            value,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        if match is None:
 | 
				
			||||||
 | 
					            raise ValueError(
 | 
				
			||||||
 | 
					                "Timestamp format must be yyyy-mm-dd hh:mm:ss[.fffffffff]"
 | 
				
			||||||
 | 
					                f" {value} is not a timestamp."
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            timestamp = datetime.strptime(match.group("timestamp"), "%Y-%m-%d %H:%M:%S")
 | 
				
			||||||
 | 
					        except ValueError:
 | 
				
			||||||
 | 
					            raise ValueError(
 | 
				
			||||||
 | 
					                "Timestamp format must be yyyy-mm-dd hh:mm:ss[.fffffffff]"
 | 
				
			||||||
 | 
					                f" {value} is not a timestamp."
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        nanos = match.group("nanos")
 | 
				
			||||||
 | 
					        if nanos is not None:
 | 
				
			||||||
 | 
					            # strip leading dot, reverse and left pad with zeros to nanoseconds
 | 
				
			||||||
 | 
					            nanos = "".join(reversed(nanos[1:])).zfill(9)
 | 
				
			||||||
 | 
					            for i, nanoseconds in enumerate(nanos):
 | 
				
			||||||
 | 
					                microseconds = (int(nanoseconds) * 10**i) / 1000
 | 
				
			||||||
 | 
					                timestamp += timedelta(microseconds=round(microseconds))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return timestamp
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    raise InvalidInputException("GetPartitions", f"Unknown type : '{type_}'")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _escape_regex(pattern: str) -> str:
 | 
				
			||||||
 | 
					    """Taken from Python 3.7 to avoid escaping '%'."""
 | 
				
			||||||
 | 
					    _special_chars_map = {i: "\\" + chr(i) for i in b"()[]{}?*+-|^$\\.&~# \t\n\r\v\f"}
 | 
				
			||||||
 | 
					    return pattern.translate(_special_chars_map)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class _Expr(abc.ABC):
 | 
				
			||||||
 | 
					    @abc.abstractmethod
 | 
				
			||||||
 | 
					    def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> Any:
 | 
				
			||||||
 | 
					        raise NotImplementedError()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class _Ident(_Expr):
 | 
				
			||||||
 | 
					    def __init__(self, tokens: ParseResults):
 | 
				
			||||||
 | 
					        self.ident: str = tokens[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> Any:
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            return self._eval(part_keys, part_input)
 | 
				
			||||||
 | 
					        except ValueError as e:
 | 
				
			||||||
 | 
					            # existing partition values cannot be cast to current schema
 | 
				
			||||||
 | 
					            raise InvalidStateException("GetPartitions", str(e))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def leval(self, part_keys: List[Dict[str, str]], literal: Any) -> Any:
 | 
				
			||||||
 | 
					        # evaluate literal by simulating partition input
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            return self._eval(part_keys, part_input={"Values": repeat(literal)})
 | 
				
			||||||
 | 
					        except ValueError as e:
 | 
				
			||||||
 | 
					            # expression literal cannot be cast to current schema
 | 
				
			||||||
 | 
					            raise InvalidInputException("GetPartitions", str(e))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def type_(self, part_keys: List[Dict[str, str]]) -> str:
 | 
				
			||||||
 | 
					        for key in part_keys:
 | 
				
			||||||
 | 
					            if self.ident == key["Name"]:
 | 
				
			||||||
 | 
					                return key["Type"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        raise InvalidInputException("GetPartitions", f"Unknown column '{self.ident}'")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> Any:
 | 
				
			||||||
 | 
					        for key, value in zip(part_keys, part_input["Values"]):
 | 
				
			||||||
 | 
					            if self.ident == key["Name"]:
 | 
				
			||||||
 | 
					                return _cast(key["Type"], value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # also raised for unpartitioned tables
 | 
				
			||||||
 | 
					        raise InvalidInputException("GetPartitions", f"Unknown column '{self.ident}'")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class _IsNull(_Expr):
 | 
				
			||||||
 | 
					    def __init__(self, tokens: ParseResults):
 | 
				
			||||||
 | 
					        self.ident: _Ident = tokens[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool:
 | 
				
			||||||
 | 
					        return self.ident.eval(part_keys, part_input) is None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class _IsNotNull(_IsNull):
 | 
				
			||||||
 | 
					    def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool:
 | 
				
			||||||
 | 
					        return not super().eval(part_keys, part_input)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class _BinOp(_Expr):
 | 
				
			||||||
 | 
					    def __init__(self, tokens: ParseResults):
 | 
				
			||||||
 | 
					        self.ident: _Ident = tokens[0]
 | 
				
			||||||
 | 
					        self.bin_op: str = tokens[1]
 | 
				
			||||||
 | 
					        self.literal: Any = tokens[2]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool:
 | 
				
			||||||
 | 
					        ident = self.ident.eval(part_keys, part_input)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # simulate partition input for the lateral
 | 
				
			||||||
 | 
					        rhs = self.ident.leval(part_keys, self.literal)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return {
 | 
				
			||||||
 | 
					            "<>": operator.ne,
 | 
				
			||||||
 | 
					            ">=": operator.ge,
 | 
				
			||||||
 | 
					            "<=": operator.le,
 | 
				
			||||||
 | 
					            ">": operator.gt,
 | 
				
			||||||
 | 
					            "<": operator.lt,
 | 
				
			||||||
 | 
					            "=": operator.eq,
 | 
				
			||||||
 | 
					        }[self.bin_op](ident, rhs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class _Like(_Expr):
 | 
				
			||||||
 | 
					    def __init__(self, tokens: ParseResults):
 | 
				
			||||||
 | 
					        self.ident: _Ident = tokens[0]
 | 
				
			||||||
 | 
					        self.literal: str = tokens[2]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool:
 | 
				
			||||||
 | 
					        type_ = self.ident.type_(part_keys)
 | 
				
			||||||
 | 
					        if type_ in ("bigint", "int", "smallint", "tinyint"):
 | 
				
			||||||
 | 
					            raise InvalidInputException(
 | 
				
			||||||
 | 
					                "GetPartitions", "Integral data type doesn't support operation 'LIKE'"
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if type_ in ("date", "decimal", "timestamp"):
 | 
				
			||||||
 | 
					            raise InvalidInputException(
 | 
				
			||||||
 | 
					                "GetPartitions",
 | 
				
			||||||
 | 
					                f"{type_[0].upper()}{type_[1:]} data type"
 | 
				
			||||||
 | 
					                " doesn't support operation 'LIKE'",
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        ident = self.ident.eval(part_keys, part_input)
 | 
				
			||||||
 | 
					        assert isinstance(ident, str)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        pattern = _cast("string", self.literal)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # prepare SQL pattern for conversion to regex pattern
 | 
				
			||||||
 | 
					        pattern = _escape_regex(pattern)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # NOTE convert SQL wildcards to regex, no literal matches possible
 | 
				
			||||||
 | 
					        pattern = pattern.replace("_", ".").replace("%", ".*")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # LIKE clauses always start at the beginning
 | 
				
			||||||
 | 
					        pattern = "^" + pattern + "$"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return re.search(pattern, ident) is not None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class _NotLike(_Like):
 | 
				
			||||||
 | 
					    def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool:
 | 
				
			||||||
 | 
					        return not super().eval(part_keys, part_input)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class _In(_Expr):
 | 
				
			||||||
 | 
					    def __init__(self, tokens: ParseResults):
 | 
				
			||||||
 | 
					        self.ident: _Ident = tokens[0]
 | 
				
			||||||
 | 
					        self.values: List[Any] = tokens[2:]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool:
 | 
				
			||||||
 | 
					        ident = self.ident.eval(part_keys, part_input)
 | 
				
			||||||
 | 
					        values = (self.ident.leval(part_keys, value) for value in self.values)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return ident in values
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class _NotIn(_In):
 | 
				
			||||||
 | 
					    def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool:
 | 
				
			||||||
 | 
					        return not super().eval(part_keys, part_input)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class _Between(_Expr):
 | 
				
			||||||
 | 
					    def __init__(self, tokens: ParseResults):
 | 
				
			||||||
 | 
					        self.ident: _Ident = tokens[0]
 | 
				
			||||||
 | 
					        self.left: Any = tokens[2]
 | 
				
			||||||
 | 
					        self.right: Any = tokens[4]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool:
 | 
				
			||||||
 | 
					        ident = self.ident.eval(part_keys, part_input)
 | 
				
			||||||
 | 
					        left = self.ident.leval(part_keys, self.left)
 | 
				
			||||||
 | 
					        right = self.ident.leval(part_keys, self.right)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return left <= ident <= right or left > ident > right
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class _NotBetween(_Between):
 | 
				
			||||||
 | 
					    def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool:
 | 
				
			||||||
 | 
					        return not super().eval(part_keys, part_input)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class _BoolAnd(_Expr):
 | 
				
			||||||
 | 
					    def __init__(self, tokens: ParseResults) -> None:
 | 
				
			||||||
 | 
					        self.operands: List[_Expr] = tokens[0][0::2]  # skip 'and' between tokens
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool:
 | 
				
			||||||
 | 
					        return all(operand.eval(part_keys, part_input) for operand in self.operands)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class _BoolOr(_Expr):
 | 
				
			||||||
 | 
					    def __init__(self, tokens: ParseResults) -> None:
 | 
				
			||||||
 | 
					        self.operands: List[_Expr] = tokens[0][0::2]  # skip 'or' between tokens
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool:
 | 
				
			||||||
 | 
					        return any(operand.eval(part_keys, part_input) for operand in self.operands)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class _PartitionFilterExpressionCache:
 | 
				
			||||||
 | 
					    def __init__(self):
 | 
				
			||||||
 | 
					        # build grammar according to Glue.Client.get_partitions(Expression)
 | 
				
			||||||
 | 
					        lpar, rpar = map(Suppress, "()")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # NOTE these are AWS Athena column name best practices
 | 
				
			||||||
 | 
					        ident = Forward().set_name("ident")
 | 
				
			||||||
 | 
					        ident <<= Word(alphanums + "._").set_parse_action(_Ident) | lpar + ident + rpar
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        number = Forward().set_name("number")
 | 
				
			||||||
 | 
					        number <<= pyparsing_common.number | lpar + number + rpar
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        string = Forward().set_name("string")
 | 
				
			||||||
 | 
					        string <<= QuotedString(quote_char="'", esc_quote="''") | lpar + string + rpar
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        literal = (number | string).set_name("literal")
 | 
				
			||||||
 | 
					        literal_list = delimited_list(literal, min=1).set_name("list")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        bin_op = one_of("<> >= <= > < =").set_name("binary op")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        and_ = Forward()
 | 
				
			||||||
 | 
					        and_ <<= CaselessKeyword("and") | lpar + and_ + rpar
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        or_ = Forward()
 | 
				
			||||||
 | 
					        or_ <<= CaselessKeyword("or") | lpar + or_ + rpar
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        in_, between, like, not_, is_, null = map(
 | 
				
			||||||
 | 
					            CaselessKeyword, "in between like not is null".split()
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        not_ = Suppress(not_)  # only needed for matching
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cond = (
 | 
				
			||||||
 | 
					            (ident + is_ + null).set_parse_action(_IsNull)
 | 
				
			||||||
 | 
					            | (ident + is_ + not_ + null).set_parse_action(_IsNotNull)
 | 
				
			||||||
 | 
					            | (ident + bin_op + literal).set_parse_action(_BinOp)
 | 
				
			||||||
 | 
					            | (ident + like + string).set_parse_action(_Like)
 | 
				
			||||||
 | 
					            | (ident + not_ + like + string).set_parse_action(_NotLike)
 | 
				
			||||||
 | 
					            | (ident + in_ + lpar + literal_list + rpar).set_parse_action(_In)
 | 
				
			||||||
 | 
					            | (ident + not_ + in_ + lpar + literal_list + rpar).set_parse_action(_NotIn)
 | 
				
			||||||
 | 
					            | (ident + between + literal + and_ + literal).set_parse_action(_Between)
 | 
				
			||||||
 | 
					            | (ident + not_ + between + literal + and_ + literal).set_parse_action(
 | 
				
			||||||
 | 
					                _NotBetween
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        ).set_name("cond")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # conditions can be joined using 2-ary AND and/or OR
 | 
				
			||||||
 | 
					        expr = infix_notation(
 | 
				
			||||||
 | 
					            cond,
 | 
				
			||||||
 | 
					            [
 | 
				
			||||||
 | 
					                (and_, 2, OpAssoc.LEFT, _BoolAnd),
 | 
				
			||||||
 | 
					                (or_, 2, OpAssoc.LEFT, _BoolOr),
 | 
				
			||||||
 | 
					            ],
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self._expr = expr.set_name("expr")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self._cache: Dict[str, _Expr] = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get(self, expression: Optional[str]) -> Optional[_Expr]:
 | 
				
			||||||
 | 
					        if expression is None:
 | 
				
			||||||
 | 
					            return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if expression not in self._cache:
 | 
				
			||||||
 | 
					            ParserElement.enable_packrat()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                expr: ParseResults = self._expr.parse_string(expression, parse_all=True)
 | 
				
			||||||
 | 
					                self._cache[expression] = expr[0]
 | 
				
			||||||
 | 
					            except exceptions.ParseException:
 | 
				
			||||||
 | 
					                raise InvalidInputException(
 | 
				
			||||||
 | 
					                    "GetPartitions", f"Unsupported expression '{expression}'"
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return self._cache[expression]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					_PARTITION_FILTER_EXPRESSION_CACHE = _PartitionFilterExpressionCache()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class PartitionFilter:
 | 
				
			||||||
 | 
					    def __init__(self, expression: Optional[str], fake_table):
 | 
				
			||||||
 | 
					        self.expression = expression
 | 
				
			||||||
 | 
					        self.fake_table = fake_table
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __call__(self, fake_partition) -> bool:
 | 
				
			||||||
 | 
					        expression = _PARTITION_FILTER_EXPRESSION_CACHE.get(self.expression)
 | 
				
			||||||
 | 
					        if expression is None:
 | 
				
			||||||
 | 
					            return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        warnings.warn("Expression filtering is experimental")
 | 
				
			||||||
 | 
					        return expression.eval(
 | 
				
			||||||
 | 
					            part_keys=self.fake_table.versions[-1].get("PartitionKeys", []),
 | 
				
			||||||
 | 
					            part_input=fake_partition.partition_input,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
							
								
								
									
										3
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										3
									
								
								setup.py
									
									
									
									
									
								
							@ -54,6 +54,7 @@ _dep_aws_xray_sdk = "aws-xray-sdk!=0.96,>=0.93"
 | 
				
			|||||||
_dep_idna = "idna<4,>=2.5"
 | 
					_dep_idna = "idna<4,>=2.5"
 | 
				
			||||||
_dep_cfn_lint = "cfn-lint>=0.4.0"
 | 
					_dep_cfn_lint = "cfn-lint>=0.4.0"
 | 
				
			||||||
_dep_sshpubkeys = "sshpubkeys>=3.1.0"
 | 
					_dep_sshpubkeys = "sshpubkeys>=3.1.0"
 | 
				
			||||||
 | 
					_dep_pyparsing = "pyparsing>=3.0.0"
 | 
				
			||||||
_setuptools = "setuptools"
 | 
					_setuptools = "setuptools"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
all_extra_deps = [
 | 
					all_extra_deps = [
 | 
				
			||||||
@ -67,6 +68,7 @@ all_extra_deps = [
 | 
				
			|||||||
    _dep_idna,
 | 
					    _dep_idna,
 | 
				
			||||||
    _dep_cfn_lint,
 | 
					    _dep_cfn_lint,
 | 
				
			||||||
    _dep_sshpubkeys,
 | 
					    _dep_sshpubkeys,
 | 
				
			||||||
 | 
					    _dep_pyparsing,
 | 
				
			||||||
    _setuptools,
 | 
					    _setuptools,
 | 
				
			||||||
]
 | 
					]
 | 
				
			||||||
all_server_deps = all_extra_deps + ["flask", "flask-cors"]
 | 
					all_server_deps = all_extra_deps + ["flask", "flask-cors"]
 | 
				
			||||||
@ -88,6 +90,7 @@ extras_per_service.update(
 | 
				
			|||||||
        "cloudformation": [_dep_docker, _dep_PyYAML, _dep_cfn_lint],
 | 
					        "cloudformation": [_dep_docker, _dep_PyYAML, _dep_cfn_lint],
 | 
				
			||||||
        "cognitoidp": [_dep_python_jose, _dep_python_jose_ecdsa_pin],
 | 
					        "cognitoidp": [_dep_python_jose, _dep_python_jose_ecdsa_pin],
 | 
				
			||||||
        "ec2": [_dep_sshpubkeys],
 | 
					        "ec2": [_dep_sshpubkeys],
 | 
				
			||||||
 | 
					        "glue": [_dep_pyparsing],
 | 
				
			||||||
        "iotdata": [_dep_jsondiff],
 | 
					        "iotdata": [_dep_jsondiff],
 | 
				
			||||||
        "s3": [_dep_PyYAML],
 | 
					        "s3": [_dep_PyYAML],
 | 
				
			||||||
        "ses": [],
 | 
					        "ses": [],
 | 
				
			||||||
 | 
				
			|||||||
@ -67,6 +67,15 @@ def get_table_version(client, database_name, table_name, version_id):
 | 
				
			|||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def create_column(name, type_, comment=None, parameters=None):
 | 
				
			||||||
 | 
					    column = {"Name": name, "Type": type_}
 | 
				
			||||||
 | 
					    if comment is not None:
 | 
				
			||||||
 | 
					        column["Comment"] = comment
 | 
				
			||||||
 | 
					    if parameters is not None:
 | 
				
			||||||
 | 
					        column["Parameters"] = parameters
 | 
				
			||||||
 | 
					    return column
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def create_partition_input(database_name, table_name, values=None, columns=None):
 | 
					def create_partition_input(database_name, table_name, values=None, columns=None):
 | 
				
			||||||
    root_path = "s3://my-bucket/{database_name}/{table_name}".format(
 | 
					    root_path = "s3://my-bucket/{database_name}/{table_name}".format(
 | 
				
			||||||
        database_name=database_name, table_name=table_name
 | 
					        database_name=database_name, table_name=table_name
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										381
									
								
								tests/test_glue/test_partition_filter.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										381
									
								
								tests/test_glue/test_partition_filter.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,381 @@
 | 
				
			|||||||
 | 
					from unittest import SkipTest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import boto3
 | 
				
			||||||
 | 
					import pytest
 | 
				
			||||||
 | 
					import sure  # noqa # pylint: disable=unused-import
 | 
				
			||||||
 | 
					from botocore.client import ClientError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from moto import mock_glue, settings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from . import helpers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@mock_glue
 | 
				
			||||||
 | 
					def test_get_partitions_expression_unknown_column():
 | 
				
			||||||
 | 
					    client = boto3.client("glue", region_name="us-east-1")
 | 
				
			||||||
 | 
					    database_name = "myspecialdatabase"
 | 
				
			||||||
 | 
					    table_name = "myfirsttable"
 | 
				
			||||||
 | 
					    values = ["2018-10-01"]
 | 
				
			||||||
 | 
					    columns = [helpers.create_column("date_col", "date")]
 | 
				
			||||||
 | 
					    helpers.create_database(client, database_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    helpers.create_table(client, database_name, table_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    helpers.create_partition(
 | 
				
			||||||
 | 
					        client, database_name, table_name, values=values, columns=columns
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with pytest.raises(ClientError) as exc:
 | 
				
			||||||
 | 
					        client.get_partitions(
 | 
				
			||||||
 | 
					            DatabaseName=database_name,
 | 
				
			||||||
 | 
					            TableName=table_name,
 | 
				
			||||||
 | 
					            Expression="unknown_col IS NULL",
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    exc.value.response["Error"]["Code"].should.equal("InvalidInputException")
 | 
				
			||||||
 | 
					    exc.value.response["Error"]["Message"].should.match("Unknown column 'unknown_col'")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@mock_glue
 | 
				
			||||||
 | 
					def test_get_partitions_expression_int_column():
 | 
				
			||||||
 | 
					    client = boto3.client("glue", region_name="us-east-1")
 | 
				
			||||||
 | 
					    database_name = "myspecialdatabase"
 | 
				
			||||||
 | 
					    table_name = "myfirsttable"
 | 
				
			||||||
 | 
					    columns = [helpers.create_column("int_col", "int")]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    helpers.create_database(client, database_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    args = (client, database_name, table_name)
 | 
				
			||||||
 | 
					    helpers.create_table(*args, partition_keys=columns)
 | 
				
			||||||
 | 
					    helpers.create_partition(*args, values=["1"], columns=columns)
 | 
				
			||||||
 | 
					    helpers.create_partition(*args, values=["2"], columns=columns)
 | 
				
			||||||
 | 
					    helpers.create_partition(*args, values=["3"], columns=columns)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    kwargs = {"DatabaseName": database_name, "TableName": table_name}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    response = client.get_partitions(**kwargs)
 | 
				
			||||||
 | 
					    partitions = response["Partitions"]
 | 
				
			||||||
 | 
					    partitions.should.have.length_of(3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    int_col_is_two_expressions = (
 | 
				
			||||||
 | 
					        "int_col = 2",
 | 
				
			||||||
 | 
					        "int_col = '2'",
 | 
				
			||||||
 | 
					        "int_col IN (2)",
 | 
				
			||||||
 | 
					        "int_col in (6, '4', 2)",
 | 
				
			||||||
 | 
					        "int_col between 2 AND 2",
 | 
				
			||||||
 | 
					        "int_col > 1 AND int_col < 3",
 | 
				
			||||||
 | 
					        "int_col >= 2 and int_col <> 3",
 | 
				
			||||||
 | 
					        "(int_col) = ((2)) (OR) (((int_col))) = (2)",
 | 
				
			||||||
 | 
					        "int_col IS NOT NULL and int_col = 2",
 | 
				
			||||||
 | 
					        "int_col not IN (1, 3)",
 | 
				
			||||||
 | 
					        "int_col NOT BETWEEN 1 AND 1 and int_col NOT BETWEEN 3 AND 3",
 | 
				
			||||||
 | 
					        "int_col = 4 OR int_col = 5 AND int_col = '-1' OR int_col = 0 OR int_col = '2'",
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for expression in int_col_is_two_expressions:
 | 
				
			||||||
 | 
					        response = client.get_partitions(**kwargs, Expression=expression)
 | 
				
			||||||
 | 
					        partitions = response["Partitions"]
 | 
				
			||||||
 | 
					        partitions.should.have.length_of(1)
 | 
				
			||||||
 | 
					        partition = partitions[0]
 | 
				
			||||||
 | 
					        partition["Values"].should.equal(["2"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    bad_int_expressions = ("int_col = 'test'", "int_col in (2.5)")
 | 
				
			||||||
 | 
					    for expression in bad_int_expressions:
 | 
				
			||||||
 | 
					        with pytest.raises(ClientError) as exc:
 | 
				
			||||||
 | 
					            client.get_partitions(**kwargs, Expression=expression)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        exc.value.response["Error"]["Code"].should.equal("InvalidInputException")
 | 
				
			||||||
 | 
					        exc.value.response["Error"]["Message"].should.match("is not an integer")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with pytest.raises(ClientError) as exc:
 | 
				
			||||||
 | 
					        client.get_partitions(**kwargs, Expression="int_col LIKE '2'")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    exc.value.response["Error"]["Code"].should.equal("InvalidInputException")
 | 
				
			||||||
 | 
					    exc.value.response["Error"]["Message"].should.match(
 | 
				
			||||||
 | 
					        "Integral data type doesn't support operation 'LIKE'"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@mock_glue
 | 
				
			||||||
 | 
					def test_get_partitions_expression_decimal_column():
 | 
				
			||||||
 | 
					    client = boto3.client("glue", region_name="us-east-1")
 | 
				
			||||||
 | 
					    database_name = "myspecialdatabase"
 | 
				
			||||||
 | 
					    table_name = "myfirsttable"
 | 
				
			||||||
 | 
					    columns = [helpers.create_column("decimal_col", "decimal")]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    helpers.create_database(client, database_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    args = (client, database_name, table_name)
 | 
				
			||||||
 | 
					    helpers.create_table(*args, partition_keys=columns)
 | 
				
			||||||
 | 
					    helpers.create_partition(*args, values=["1.2"], columns=columns)
 | 
				
			||||||
 | 
					    helpers.create_partition(*args, values=["2.6"], columns=columns)
 | 
				
			||||||
 | 
					    helpers.create_partition(*args, values=["3e14"], columns=columns)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    kwargs = {"DatabaseName": database_name, "TableName": table_name}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    response = client.get_partitions(**kwargs)
 | 
				
			||||||
 | 
					    partitions = response["Partitions"]
 | 
				
			||||||
 | 
					    partitions.should.have.length_of(3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    decimal_col_is_two_point_six_expressions = (
 | 
				
			||||||
 | 
					        "decimal_col = 2.6",
 | 
				
			||||||
 | 
					        "decimal_col = '2.6'",
 | 
				
			||||||
 | 
					        "decimal_col IN (2.6)",
 | 
				
			||||||
 | 
					        "decimal_col in (6, '4', 2.6)",
 | 
				
			||||||
 | 
					        "decimal_col between 2 AND 3e10",
 | 
				
			||||||
 | 
					        "decimal_col > 1.5 AND decimal_col < 3",
 | 
				
			||||||
 | 
					        "decimal_col >= 2 and decimal_col <> '3e14'",
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for expression in decimal_col_is_two_point_six_expressions:
 | 
				
			||||||
 | 
					        response = client.get_partitions(**kwargs, Expression=expression)
 | 
				
			||||||
 | 
					        partitions = response["Partitions"]
 | 
				
			||||||
 | 
					        partitions.should.have.length_of(1)
 | 
				
			||||||
 | 
					        partition = partitions[0]
 | 
				
			||||||
 | 
					        partition["Values"].should.equal(["2.6"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    bad_decimal_expressions = ("decimal_col = 'test'",)
 | 
				
			||||||
 | 
					    for expression in bad_decimal_expressions:
 | 
				
			||||||
 | 
					        with pytest.raises(ClientError) as exc:
 | 
				
			||||||
 | 
					            client.get_partitions(**kwargs, Expression=expression)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        exc.value.response["Error"]["Code"].should.equal("InvalidInputException")
 | 
				
			||||||
 | 
					        exc.value.response["Error"]["Message"].should.match("is not a decimal")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with pytest.raises(ClientError) as exc:
 | 
				
			||||||
 | 
					        client.get_partitions(**kwargs, Expression="decimal_col LIKE '2'")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    exc.value.response["Error"]["Code"].should.equal("InvalidInputException")
 | 
				
			||||||
 | 
					    exc.value.response["Error"]["Message"].should.match(
 | 
				
			||||||
 | 
					        "Decimal data type doesn't support operation 'LIKE'"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@mock_glue
 | 
				
			||||||
 | 
					def test_get_partitions_expression_string_column():
 | 
				
			||||||
 | 
					    client = boto3.client("glue", region_name="us-east-1")
 | 
				
			||||||
 | 
					    database_name = "myspecialdatabase"
 | 
				
			||||||
 | 
					    table_name = "myfirsttable"
 | 
				
			||||||
 | 
					    columns = [helpers.create_column("string_col", "string")]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    helpers.create_database(client, database_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    args = (client, database_name, table_name)
 | 
				
			||||||
 | 
					    helpers.create_table(*args, partition_keys=columns)
 | 
				
			||||||
 | 
					    helpers.create_partition(*args, values=["one"], columns=columns)
 | 
				
			||||||
 | 
					    helpers.create_partition(*args, values=["two"], columns=columns)
 | 
				
			||||||
 | 
					    helpers.create_partition(*args, values=["2"], columns=columns)
 | 
				
			||||||
 | 
					    helpers.create_partition(*args, values=["three"], columns=columns)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    kwargs = {"DatabaseName": database_name, "TableName": table_name}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    response = client.get_partitions(**kwargs)
 | 
				
			||||||
 | 
					    partitions = response["Partitions"]
 | 
				
			||||||
 | 
					    partitions.should.have.length_of(4)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    string_col_is_two_expressions = (
 | 
				
			||||||
 | 
					        "string_col = 'two'",
 | 
				
			||||||
 | 
					        "string_col = 2",
 | 
				
			||||||
 | 
					        "string_col IN (1, 2, 3)",
 | 
				
			||||||
 | 
					        "string_col IN ('1', '2', '3')",
 | 
				
			||||||
 | 
					        "string_col IN ('test', 'two', '3')",
 | 
				
			||||||
 | 
					        "string_col between 'twn' AND 'twp'",
 | 
				
			||||||
 | 
					        "string_col > '1' AND string_col < '3'",
 | 
				
			||||||
 | 
					        "string_col LIKE 'two'",
 | 
				
			||||||
 | 
					        "string_col LIKE 't_o'",
 | 
				
			||||||
 | 
					        "string_col LIKE 't__'",
 | 
				
			||||||
 | 
					        "string_col LIKE '%wo'",
 | 
				
			||||||
 | 
					        "string_col NOT LIKE '%e' AND string_col not like '_'",
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for expression in string_col_is_two_expressions:
 | 
				
			||||||
 | 
					        response = client.get_partitions(**kwargs, Expression=expression)
 | 
				
			||||||
 | 
					        partitions = response["Partitions"]
 | 
				
			||||||
 | 
					        partitions.should.have.length_of(1)
 | 
				
			||||||
 | 
					        partition = partitions[0]
 | 
				
			||||||
 | 
					        partition["Values"].should.be.within((["two"], ["2"]))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with pytest.raises(ClientError) as exc:
 | 
				
			||||||
 | 
					        client.get_partitions(**kwargs, Expression="unknown_col LIKE 'two'")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    exc.value.response["Error"]["Code"].should.equal("InvalidInputException")
 | 
				
			||||||
 | 
					    exc.value.response["Error"]["Message"].should.match("Unknown column 'unknown_col'")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@mock_glue
 | 
				
			||||||
 | 
					def test_get_partitions_expression_date_column():
 | 
				
			||||||
 | 
					    client = boto3.client("glue", region_name="us-east-1")
 | 
				
			||||||
 | 
					    database_name = "myspecialdatabase"
 | 
				
			||||||
 | 
					    table_name = "myfirsttable"
 | 
				
			||||||
 | 
					    columns = [helpers.create_column("date_col", "date")]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    helpers.create_database(client, database_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    args = (client, database_name, table_name)
 | 
				
			||||||
 | 
					    helpers.create_table(*args, partition_keys=columns)
 | 
				
			||||||
 | 
					    helpers.create_partition(*args, values=["2022-01-01"], columns=columns)
 | 
				
			||||||
 | 
					    helpers.create_partition(*args, values=["2022-02-01"], columns=columns)
 | 
				
			||||||
 | 
					    helpers.create_partition(*args, values=["2022-03-01"], columns=columns)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    kwargs = {"DatabaseName": database_name, "TableName": table_name}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    response = client.get_partitions(**kwargs)
 | 
				
			||||||
 | 
					    partitions = response["Partitions"]
 | 
				
			||||||
 | 
					    partitions.should.have.length_of(3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    date_col_is_february_expressions = (
 | 
				
			||||||
 | 
					        "date_col = '2022-02-01'",
 | 
				
			||||||
 | 
					        "date_col IN ('2022-02-01')",
 | 
				
			||||||
 | 
					        "date_col in ('2024-02-29', '2022-02-01', '2022-02-02')",
 | 
				
			||||||
 | 
					        "date_col between '2022-01-15' AND '2022-02-15'",
 | 
				
			||||||
 | 
					        "date_col > '2022-01-15' AND date_col < '2022-02-15'",
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for expression in date_col_is_february_expressions:
 | 
				
			||||||
 | 
					        response = client.get_partitions(**kwargs, Expression=expression)
 | 
				
			||||||
 | 
					        partitions = response["Partitions"]
 | 
				
			||||||
 | 
					        partitions.should.have.length_of(1)
 | 
				
			||||||
 | 
					        partition = partitions[0]
 | 
				
			||||||
 | 
					        partition["Values"].should.equal(["2022-02-01"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    bad_date_expressions = ("date_col = 'test'", "date_col = '2022-02-32'")
 | 
				
			||||||
 | 
					    for expression in bad_date_expressions:
 | 
				
			||||||
 | 
					        with pytest.raises(ClientError) as exc:
 | 
				
			||||||
 | 
					            client.get_partitions(**kwargs, Expression=expression)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        exc.value.response["Error"]["Code"].should.equal("InvalidInputException")
 | 
				
			||||||
 | 
					        exc.value.response["Error"]["Message"].should.match("is not a date")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with pytest.raises(ClientError) as exc:
 | 
				
			||||||
 | 
					        client.get_partitions(**kwargs, Expression="date_col LIKE '2022-02-01'")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    exc.value.response["Error"]["Code"].should.equal("InvalidInputException")
 | 
				
			||||||
 | 
					    exc.value.response["Error"]["Message"].should.match(
 | 
				
			||||||
 | 
					        "Date data type doesn't support operation 'LIKE'"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@mock_glue
 | 
				
			||||||
 | 
					def test_get_partitions_expression_timestamp_column():
 | 
				
			||||||
 | 
					    client = boto3.client("glue", region_name="us-east-1")
 | 
				
			||||||
 | 
					    database_name = "myspecialdatabase"
 | 
				
			||||||
 | 
					    table_name = "myfirsttable"
 | 
				
			||||||
 | 
					    columns = [helpers.create_column("timestamp_col", "timestamp")]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    helpers.create_database(client, database_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    args = (client, database_name, table_name)
 | 
				
			||||||
 | 
					    helpers.create_table(*args, partition_keys=columns)
 | 
				
			||||||
 | 
					    helpers.create_partition(*args, values=["2022-01-01 12:34:56.789"], columns=columns)
 | 
				
			||||||
 | 
					    helpers.create_partition(
 | 
				
			||||||
 | 
					        *args, values=["2022-02-01 00:00:00.000000"], columns=columns
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    helpers.create_partition(
 | 
				
			||||||
 | 
					        *args, values=["2022-03-01 21:00:12.3456789"], columns=columns
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    kwargs = {"DatabaseName": database_name, "TableName": table_name}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    response = client.get_partitions(**kwargs)
 | 
				
			||||||
 | 
					    partitions = response["Partitions"]
 | 
				
			||||||
 | 
					    partitions.should.have.length_of(3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    timestamp_col_is_february_expressions = (
 | 
				
			||||||
 | 
					        "timestamp_col = '2022-02-01 00:00:00'",
 | 
				
			||||||
 | 
					        "timestamp_col = '2022-02-01 00:00:00.0'",
 | 
				
			||||||
 | 
					        "timestamp_col = '2022-02-01 00:00:00.000000000'",
 | 
				
			||||||
 | 
					        "timestamp_col IN ('2022-02-01 00:00:00.000')",
 | 
				
			||||||
 | 
					        "timestamp_col between '2022-01-15 00:00:00' AND '2022-02-15 00:00:00'",
 | 
				
			||||||
 | 
					        "timestamp_col > '2022-01-15 00:00:00' AND "
 | 
				
			||||||
 | 
					        "timestamp_col < '2022-02-15 00:00:00'",
 | 
				
			||||||
 | 
					        # these expressions only work because of rounding to microseconds
 | 
				
			||||||
 | 
					        "timestamp_col = '2022-01-31 23:59:59.999999999'",
 | 
				
			||||||
 | 
					        "timestamp_col = '2022-02-01 00:00:00.00000001'",
 | 
				
			||||||
 | 
					        "timestamp_col > '2022-01-31 23:59:59.999999499' AND"
 | 
				
			||||||
 | 
					        " timestamp_col < '2022-02-01 00:00:00.0000009'",
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for expression in timestamp_col_is_february_expressions:
 | 
				
			||||||
 | 
					        response = client.get_partitions(**kwargs, Expression=expression)
 | 
				
			||||||
 | 
					        partitions = response["Partitions"]
 | 
				
			||||||
 | 
					        partitions.should.have.length_of(1)
 | 
				
			||||||
 | 
					        partition = partitions[0]
 | 
				
			||||||
 | 
					        partition["Values"].should.equal(["2022-02-01 00:00:00.000000"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    bad_timestamp_expressions = (
 | 
				
			||||||
 | 
					        "timestamp_col = '2022-02-01'",
 | 
				
			||||||
 | 
					        "timestamp_col = '2022-02-15 00:00:00.'",
 | 
				
			||||||
 | 
					        "timestamp_col = '2022-02-32 00:00:00'",
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    for expression in bad_timestamp_expressions:
 | 
				
			||||||
 | 
					        with pytest.raises(ClientError) as exc:
 | 
				
			||||||
 | 
					            client.get_partitions(**kwargs, Expression=expression)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        exc.value.response["Error"]["Code"].should.equal("InvalidInputException")
 | 
				
			||||||
 | 
					        exc.value.response["Error"]["Message"].should.match("is not a timestamp")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with pytest.raises(ClientError) as exc:
 | 
				
			||||||
 | 
					        client.get_partitions(
 | 
				
			||||||
 | 
					            **kwargs, Expression="timestamp_col LIKE '2022-02-01 00:00:00'"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    exc.value.response["Error"]["Code"].should.equal("InvalidInputException")
 | 
				
			||||||
 | 
					    exc.value.response["Error"]["Message"].should.match(
 | 
				
			||||||
 | 
					        "Timestamp data type doesn't support operation 'LIKE'"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@mock_glue
 | 
				
			||||||
 | 
					def test_get_partition_expression_warnings_and_exceptions():
 | 
				
			||||||
 | 
					    if settings.TEST_SERVER_MODE:
 | 
				
			||||||
 | 
					        raise SkipTest("Cannot catch warnings in server mode")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    client = boto3.client("glue", region_name="us-east-1")
 | 
				
			||||||
 | 
					    database_name = "myspecialdatabase"
 | 
				
			||||||
 | 
					    table_name = "myfirsttable"
 | 
				
			||||||
 | 
					    columns = [
 | 
				
			||||||
 | 
					        helpers.create_column("string_col", "string"),
 | 
				
			||||||
 | 
					        helpers.create_column("int_col", "int"),
 | 
				
			||||||
 | 
					        helpers.create_column("float_col", "float"),
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    helpers.create_database(client, database_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    args = (client, database_name, table_name)
 | 
				
			||||||
 | 
					    helpers.create_table(*args, partition_keys=columns)
 | 
				
			||||||
 | 
					    helpers.create_partition(*args, values=["test", "int", "3.14"], columns=columns)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    kwargs = {"DatabaseName": database_name, "TableName": table_name}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with pytest.warns(match="Expression filtering is experimental"):
 | 
				
			||||||
 | 
					        response = client.get_partitions(**kwargs, Expression="string_col = 'test'")
 | 
				
			||||||
 | 
					        partitions = response["Partitions"]
 | 
				
			||||||
 | 
					        partitions.should.have.length_of(1)
 | 
				
			||||||
 | 
					        partition = partitions[0]
 | 
				
			||||||
 | 
					        partition["Values"].should.equal(["test", "int", "3.14"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with pytest.raises(ClientError) as exc:
 | 
				
			||||||
 | 
					        client.get_partitions(**kwargs, Expression="float_col = 3.14")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    exc.value.response["Error"]["Code"].should.equal("InvalidInputException")
 | 
				
			||||||
 | 
					    exc.value.response["Error"]["Message"].should.match("Unknown type : 'float'")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with pytest.raises(ClientError) as exc:
 | 
				
			||||||
 | 
					        client.get_partitions(**kwargs, Expression="int_col = 2")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    exc.value.response["Error"]["Code"].should.equal("InvalidStateException")
 | 
				
			||||||
 | 
					    exc.value.response["Error"]["Message"].should.match('"int" is not an integer')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with pytest.raises(ClientError) as exc:
 | 
				
			||||||
 | 
					        client.get_partitions(**kwargs, Expression="unknown_col = 'test'")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    exc.value.response["Error"]["Code"].should.equal("InvalidInputException")
 | 
				
			||||||
 | 
					    exc.value.response["Error"]["Message"].should.match("Unknown column 'unknown_col'")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with pytest.raises(ClientError) as exc:
 | 
				
			||||||
 | 
					        client.get_partitions(
 | 
				
			||||||
 | 
					            **kwargs, Expression="string_col IS test' AND not parsable"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    exc.value.response["Error"]["Code"].should.equal("InvalidInputException")
 | 
				
			||||||
 | 
					    exc.value.response["Error"]["Message"].should.match("Unsupported expression")
 | 
				
			||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user