# Source: moto/moto/glue/utils.py (365 lines, 12 KiB, Python)

import abc
import operator
import re
from datetime import date, datetime
from itertools import repeat
from typing import Any, Dict, List, Optional, Union
from pyparsing import (
CaselessKeyword,
Forward,
OpAssoc,
ParserElement,
ParseResults,
QuotedString,
Suppress,
Word,
alphanums,
exceptions,
infix_notation,
one_of,
pyparsing_common,
)
try:
# TODO import directly when depending on pyparsing>=3.1.0
from pyparsing import DelimitedList
except ImportError:
# delimited_list is deprecated in favor of DelimitedList in pyparsing 3.1.0
from pyparsing import delimited_list as DelimitedList # type: ignore[assignment]
from .exceptions import (
InvalidInputException,
InvalidStateException,
)
def _cast(type_: str, value: Any) -> Union[date, datetime, float, int, str]:
    """Cast *value* to the Python representation of the Glue column *type_*.

    The value is stringified first, so partition values and expression
    literals are converted through the same path.

    Args:
        type_: Glue column type name (``int``, ``decimal``, ``string``,
            ``date``, ``timestamp`` and their listed aliases).
        value: The raw value; converted with ``str()`` before casting.

    Returns:
        ``int`` for integral types, ``float`` for ``decimal``, ``str`` for
        character types, ``datetime.date`` for ``date`` and, for
        ``timestamp``, an ``int`` count of nanoseconds since
        ``1970-01-01 00:00:00``.

    Raises:
        ValueError: If *value* cannot be parsed as *type_*.
        InvalidInputException: If *type_* is not a supported type.
    """
    # values are always cast from string to target type
    value = str(value)

    if type_ in ("bigint", "int", "smallint", "tinyint"):
        try:
            return int(value)  # no size is enforced
        except ValueError as err:
            raise ValueError(f'"{value}" is not an integer.') from err

    if type_ == "decimal":
        try:
            return float(value)
        except ValueError as err:
            raise ValueError(f"{value} is not a decimal.") from err

    if type_ in ("char", "string", "varchar"):
        return value  # no length is enforced

    if type_ == "date":
        try:
            return datetime.strptime(value, "%Y-%m-%d").date()
        except ValueError as err:
            raise ValueError(f"{value} is not a date.") from err

    if type_ == "timestamp":
        not_a_timestamp = (
            "Timestamp format must be yyyy-mm-dd hh:mm:ss[.fffffffff]"
            f" {value} is not a timestamp."
        )
        match = re.search(
            r"^(?P<timestamp>\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2})"
            r"(?P<nanos>\.\d{1,9})?$",
            value,
        )
        if match is None:
            raise ValueError(not_a_timestamp)
        try:
            timestamp = datetime.strptime(match.group("timestamp"), "%Y-%m-%d %H:%M:%S")
        except ValueError as err:
            raise ValueError(not_a_timestamp) from err

        # Use a nanosecond representation for timestamps. The wall-clock value
        # is measured against a fixed epoch with integer arithmetic only, so
        # the result does not depend on the host time zone (no DST gaps or
        # overlaps), never goes through the platform mktime(), and stays exact
        # for the whole range datetime supports.
        since_epoch = timestamp - datetime(1970, 1, 1)
        whole_seconds = since_epoch.days * 86_400 + since_epoch.seconds
        posix_nanoseconds = whole_seconds * 1_000_000_000

        nanos = match.group("nanos")
        if nanos is not None:
            # strip leading dot and right pad the fraction to nine digits
            posix_nanoseconds += int(nanos[1:].ljust(9, "0"))
        return posix_nanoseconds

    raise InvalidInputException("GetPartitions", f"Unknown type : '{type_}'")
def _escape_regex(pattern: str) -> str:
    """Backslash-escape regex metacharacters in *pattern*.

    The escaped character set is the one the original author took from
    Python 3.7 so that '%' is left untouched; '_' is not escaped either.
    """
    escaped_chars = "()[]{}?*+-|^$\\.&~# \t\n\r\v\f"
    # map every listed character to itself prefixed by a backslash
    translation = str.maketrans({char: "\\" + char for char in escaped_chars})
    return pattern.translate(translation)
class _Expr(abc.ABC):
    """Abstract node of a parsed partition filter expression."""

    @abc.abstractmethod
    def eval(  # type: ignore[misc]
        self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]
    ) -> Any:
        """Evaluate this node for one partition.

        Args:
            part_keys: Partition key definitions of the table.
            part_input: Partition input of the partition being tested.
        """
        raise NotImplementedError
class _Ident(_Expr):
    """Column reference inside a partition filter expression."""

    def __init__(self, tokens: ParseResults):
        # first token is the matched column name
        self.ident: str = tokens[0]

    def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> Any:
        """Return this column's value for the partition, cast to the column type.

        Raises:
            InvalidStateException: If the stored partition value cannot be
                cast to the type in the current schema.
            InvalidInputException: If the column is not a partition key.
        """
        try:
            result = self._eval(part_keys, part_input)
        except ValueError as err:
            # existing partition values cannot be cast to current schema
            raise InvalidStateException("GetPartitions", str(err))
        return result

    def leval(self, part_keys: List[Dict[str, str]], literal: Any) -> Any:
        """Cast an expression *literal* to this column's type.

        Raises:
            InvalidInputException: If the literal cannot be cast to the
                column type, or the column is not a partition key.
        """
        # evaluate literal by simulating partition input: every position of
        # "Values" yields the literal, whichever key the column lines up with
        simulated_input = {"Values": repeat(literal)}
        try:
            result = self._eval(part_keys, simulated_input)
        except ValueError as err:
            # expression literal cannot be cast to current schema
            raise InvalidInputException("GetPartitions", str(err))
        return result

    def type_(self, part_keys: List[Dict[str, str]]) -> str:
        """Return the declared type of this column.

        Raises:
            InvalidInputException: If the column is not a partition key.
        """
        for column in part_keys:
            if self.ident != column["Name"]:
                continue
            return column["Type"]
        raise InvalidInputException("GetPartitions", f"Unknown column '{self.ident}'")

    def _eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> Any:
        """Find this column among the keys and cast the value at the same position."""
        for column, raw_value in zip(part_keys, part_input["Values"]):
            if self.ident != column["Name"]:
                continue
            return _cast(column["Type"], raw_value)
        # also raised for unpartitioned tables
        raise InvalidInputException("GetPartitions", f"Unknown column '{self.ident}'")
class _IsNull(_Expr):
    """``<ident> IS NULL`` condition."""

    def __init__(self, tokens: ParseResults):
        # first token is the column reference
        self.ident: _Ident = tokens[0]

    def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool:
        """Return True when the evaluated column value is ``None``."""
        column_value = self.ident.eval(part_keys, part_input)
        return column_value is None
class _IsNotNull(_IsNull):
    """``<ident> IS NOT NULL`` condition; the negation of ``_IsNull``."""

    def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool:
        """Return True when the evaluated column value is not ``None``."""
        is_null = super().eval(part_keys, part_input)
        return not is_null
class _BinOp(_Expr):
    """``<ident> <op> <literal>`` comparison with one of ``<> >= <= > < =``."""

    def __init__(self, tokens: ParseResults):
        column, symbol, literal = tokens[0], tokens[1], tokens[2]
        self.ident: _Ident = column
        self.bin_op: str = symbol
        self.literal: Any = literal

    def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool:
        """Compare the partition's column value against the literal."""
        lhs = self.ident.eval(part_keys, part_input)
        # simulate partition input for the literal so it is cast like the column
        rhs = self.ident.leval(part_keys, self.literal)
        comparators = {
            "=": operator.eq,
            "<>": operator.ne,
            "<": operator.lt,
            "<=": operator.le,
            ">": operator.gt,
            ">=": operator.ge,
        }
        compare = comparators[self.bin_op]
        return compare(lhs, rhs)
class _Like(_Expr):
    """``<ident> LIKE '<pattern>'`` condition for character-typed columns.

    ``_`` matches exactly one character and ``%`` matches any run of
    characters; there is no way to match either of them literally.
    """

    def __init__(self, tokens: ParseResults):
        # tokens are: column reference, the LIKE keyword, the pattern string
        self.ident: _Ident = tokens[0]
        self.literal: str = tokens[2]

    def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool:
        """Return True when the whole column value matches the LIKE pattern.

        Raises:
            InvalidInputException: If the column is unknown or its type is
                integral, ``date``, ``decimal`` or ``timestamp``.
        """
        type_ = self.ident.type_(part_keys)
        if type_ in ("bigint", "int", "smallint", "tinyint"):
            raise InvalidInputException(
                "GetPartitions", "Integral data type doesn't support operation 'LIKE'"
            )
        if type_ in ("date", "decimal", "timestamp"):
            raise InvalidInputException(
                "GetPartitions",
                f"{type_[0].upper()}{type_[1:]} data type"
                " doesn't support operation 'LIKE'",
            )

        ident = self.ident.eval(part_keys, part_input)
        assert isinstance(ident, str)

        pattern = _cast("string", self.literal)
        # prepare SQL pattern for conversion to regex pattern
        pattern = _escape_regex(pattern)  # type: ignore
        # NOTE convert SQL wildcards to regex, no literal matches possible
        pattern = pattern.replace("_", ".").replace("%", ".*")

        # LIKE clauses always cover the whole value. fullmatch() is used
        # instead of "^...$" because "$" also matches just before a trailing
        # newline, and DOTALL lets the wildcards span line breaks in the value.
        return re.fullmatch(pattern, ident, flags=re.DOTALL) is not None
class _NotLike(_Like):
    """``<ident> NOT LIKE '<pattern>'`` condition; the negation of ``_Like``."""

    def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool:
        """Return True when the column value does not match the pattern."""
        matched = super().eval(part_keys, part_input)
        return not matched
class _In(_Expr):
    """``<ident> IN (<literal>, ...)`` condition."""

    def __init__(self, tokens: ParseResults):
        # tokens are: column reference, the IN keyword, then the listed literals
        self.ident: _Ident = tokens[0]
        self.values: List[Any] = tokens[2:]

    def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool:
        """Return True when the column value equals one of the listed literals.

        Literals are cast one at a time and the scan stops at the first hit,
        so literals after a match are never cast.
        """
        target = self.ident.eval(part_keys, part_input)
        for raw_literal in self.values:
            candidate = self.ident.leval(part_keys, raw_literal)
            # same test the ``in`` operator applies: identity first, then equality
            if candidate is target or candidate == target:
                return True
        return False
class _NotIn(_In):
    """``<ident> NOT IN (<literal>, ...)`` condition; the negation of ``_In``."""

    def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool:
        """Return True when the column value equals none of the listed literals."""
        contained = super().eval(part_keys, part_input)
        return not contained
class _Between(_Expr):
    """``<ident> BETWEEN <literal> AND <literal>`` condition."""

    def __init__(self, tokens: ParseResults):
        # tokens are: column, BETWEEN keyword, first bound, AND keyword, second bound
        self.ident: _Ident = tokens[0]
        self.left: Any = tokens[2]
        self.right: Any = tokens[4]

    def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool:
        """Return True when the column value lies between the two bounds.

        Bounds given in ascending order are inclusive; bounds given in
        descending order are exclusive.
        """
        value = self.ident.eval(part_keys, part_input)
        first = self.ident.leval(part_keys, self.left)
        second = self.ident.leval(part_keys, self.right)
        if first <= value <= second:
            return True
        return first > value > second
class _NotBetween(_Between):
    """``<ident> NOT BETWEEN ... AND ...`` condition; the negation of ``_Between``."""

    def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool:
        """Return True when ``_Between`` would return False."""
        within = super().eval(part_keys, part_input)
        return not within
class _BoolAnd(_Expr):
    """Conjunction of conditions joined by ``AND``."""

    def __init__(self, tokens: ParseResults) -> None:
        group = tokens[0]
        # operands sit at the even positions; skip 'and' between tokens
        self.operands: List[_Expr] = group[::2]

    def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool:
        """Return True when every operand is truthy; stops at the first falsy one."""
        for operand in self.operands:
            if not operand.eval(part_keys, part_input):
                return False
        return True
class _BoolOr(_Expr):
    """Disjunction of conditions joined by ``OR``."""

    def __init__(self, tokens: ParseResults) -> None:
        group = tokens[0]
        # operands sit at the even positions; skip 'or' between tokens
        self.operands: List[_Expr] = group[::2]

    def eval(self, part_keys: List[Dict[str, str]], part_input: Dict[str, Any]) -> bool:
        """Return True when any operand is truthy; stops at the first truthy one."""
        for operand in self.operands:
            if operand.eval(part_keys, part_input):
                return True
        return False
class _PartitionFilterExpressionCache:
    """Parser for partition filter expressions that memoizes parsed trees."""

    def __init__(self) -> None:
        # build grammar according to Glue.Client.get_partitions(Expression)
        open_paren = Suppress("(")
        close_paren = Suppress(")")

        def parenthesizable(base: ParserElement, name: Optional[str] = None) -> Forward:
            """Return an element matching *base*, optionally wrapped in parentheses."""
            element = Forward().set_name(name) if name is not None else Forward()
            element <<= base | open_paren + element + close_paren  # type: ignore
            return element

        # NOTE these are AWS Athena column name best practices
        column = parenthesizable(
            Word(alphanums + "._").set_parse_action(_Ident), "ident"
        )
        number = parenthesizable(pyparsing_common.number, "number")
        string = parenthesizable(
            QuotedString(quote_char="'", esc_quote="''"), "string"
        )

        literal = (number | string).set_name("literal")
        literal_list = DelimitedList(literal, min=1).set_name("list")

        comparison = one_of("<> >= <= > < =").set_name("binary op")

        and_kw = parenthesizable(CaselessKeyword("and"))
        or_kw = parenthesizable(CaselessKeyword("or"))

        in_kw = CaselessKeyword("in")
        between_kw = CaselessKeyword("between")
        like_kw = CaselessKeyword("like")
        is_kw = CaselessKeyword("is")
        null_kw = CaselessKeyword("null")
        # 'not' is suppressed: only needed for matching, never reaches the actions
        not_kw = Suppress(CaselessKeyword("not"))  # type: ignore

        # (pattern, parse action) pairs; alternatives are tried in this order
        alternatives = [
            (column + is_kw + null_kw, _IsNull),
            (column + is_kw + not_kw + null_kw, _IsNotNull),
            (column + comparison + literal, _BinOp),
            (column + like_kw + string, _Like),
            (column + not_kw + like_kw + string, _NotLike),
            (column + in_kw + open_paren + literal_list + close_paren, _In),
            (column + not_kw + in_kw + open_paren + literal_list + close_paren, _NotIn),
            (column + between_kw + literal + and_kw + literal, _Between),
            (column + not_kw + between_kw + literal + and_kw + literal, _NotBetween),
        ]
        branches = [
            pattern.set_parse_action(action) for pattern, action in alternatives
        ]
        condition = branches[0]
        for branch in branches[1:]:
            condition = condition | branch
        condition = condition.set_name("cond")

        # conditions can be joined using 2-ary AND and/or OR
        boolean_operators = [
            (and_kw, 2, OpAssoc.LEFT, _BoolAnd),
            (or_kw, 2, OpAssoc.LEFT, _BoolOr),
        ]
        grammar = infix_notation(condition, boolean_operators)

        self._expr = grammar.set_name("expr")
        self._cache: Dict[str, _Expr] = {}

    def get(self, expression: Optional[str]) -> Optional[_Expr]:
        """Return the parsed tree for *expression*, parsing it on first use.

        Returns:
            ``None`` when *expression* is ``None``, otherwise the root
            ``_Expr`` node (served from the cache after the first call).

        Raises:
            InvalidInputException: If *expression* does not match the grammar.
        """
        if expression is None:
            return None

        try:
            return self._cache[expression]
        except KeyError:
            pass  # not parsed yet, fall through to the parser

        ParserElement.enable_packrat()
        try:
            parsed: ParseResults = self._expr.parse_string(expression, parse_all=True)
        except exceptions.ParseException:
            raise InvalidInputException(
                "GetPartitions", f"Unsupported expression '{expression}'"
            )
        root = parsed[0]
        self._cache[expression] = root
        return root
# Module-level cache instance, created once at import time; PartitionFilter
# looks up (and thereby parses on first use) its expression through it.
_PARTITION_FILTER_EXPRESSION_CACHE = _PartitionFilterExpressionCache()
class PartitionFilter:
    """Callable predicate that tests partitions against a filter expression."""

    def __init__(self, expression: Optional[str], fake_table: Any):
        """Store the filter *expression* and the table whose partitions are tested.

        Args:
            expression: Filter expression, or ``None`` to accept every partition.
            fake_table: Object whose ``versions`` mapping is read on each call.
        """
        self.fake_table = fake_table
        self.expression = expression

    def __call__(self, fake_partition: Any) -> bool:
        """Return whether *fake_partition* satisfies the expression.

        The partition keys are taken from the last entry of
        ``fake_table.versions``; a missing ``"PartitionKeys"`` entry is
        treated as an empty list.
        """
        parsed = _PARTITION_FILTER_EXPRESSION_CACHE.get(self.expression)
        if parsed is not None:
            latest_version = list(self.fake_table.versions.values())[-1]
            partition_keys = latest_version.get("PartitionKeys", [])
            return parsed.eval(
                part_keys=partition_keys,
                part_input=fake_partition.partition_input,
            )
        # no expression means no filtering
        return True