query-engine/src/sql/parser.rs
2026-04-13 13:14:01 +02:00

807 lines
24 KiB
Rust

use std::error::Error;
use std::fmt;
use super::ast::{
AggregateArg, AggregateFunc, BinaryOp, Expr, Literal, OrderByItem, Select, SelectItem,
SortDirection, TableRef,
};
/// Errors returned by the minimal SQL parser.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParseError {
    /// The input ran out while more tokens were required.
    UnexpectedEnd,
    /// A specific keyword or punctuation mark (given by its spelling) was required.
    ExpectedToken(&'static str),
    /// An identifier (table, column or alias name) was required.
    ExpectedIdentifier,
    /// A token (or character) appeared where it is not allowed; carries its spelling.
    UnexpectedToken(String),
    /// `*` was listed together with other projection items.
    MixedWildcardProjection,
    /// A `'...'` literal was not closed before the end of input.
    UnterminatedString,
}
impl fmt::Display for ParseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ParseError::ExpectedToken(expected) => write!(f, "expected `{}`", expected),
            ParseError::UnexpectedToken(found) => write!(f, "unexpected token `{}`", found),
            ParseError::UnexpectedEnd => f.write_str("unexpected end of input"),
            ParseError::ExpectedIdentifier => f.write_str("expected identifier"),
            ParseError::MixedWildcardProjection => {
                f.write_str("wildcard projections cannot be combined with other items")
            }
            ParseError::UnterminatedString => f.write_str("unterminated string literal"),
        }
    }
}
impl Error for ParseError {}
/// A single lexical unit produced by `tokenize` and consumed by `Parser`.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Token {
    // Keywords; the lexer matches these case-insensitively.
    Select,
    From,
    Where,
    Group,
    Order,
    By,
    Asc,
    Desc,
    Limit,
    As,
    And,
    Or,
    Null,
    // Words and literals.
    /// Any word that is not a keyword, spelled exactly as in the input.
    Identifier(String),
    /// The contents of a `'...'` literal, with doubled quotes collapsed.
    String(String),
    /// A run of ASCII digits.
    Integer(usize),
    // Punctuation and comparison operators.
    Star,
    Comma,
    LParen,
    RParen,
    /// `=`
    Eq,
    /// `!=` or `<>`
    Ne,
}
/// Parse a single `SELECT` statement in the supported SQL subset.
///
/// Clauses must appear in this order: `SELECT ... FROM ...`, then the optional
/// `WHERE`, `GROUP BY`, `ORDER BY` and `LIMIT` clauses. Any input left over
/// after the last clause is rejected.
///
/// # Errors
///
/// Returns a [`ParseError`] for lexical problems (such as an unterminated
/// string literal) and for any token sequence that does not fit the grammar.
pub fn parse_select(input: &str) -> Result<Select, ParseError> {
    tokenize(input).and_then(|tokens| Parser::new(tokens).parse_select())
}
/// Recursive-descent parser over the token stream produced by `tokenize`.
///
/// `index` is the position in `tokens` of the next token not yet consumed.
struct Parser {
    tokens: Vec<Token>,
    index: usize,
}
impl Parser {
    fn new(tokens: Vec<Token>) -> Self {
        Self { tokens, index: 0 }
    }
    /// Parse one complete statement and reject anything left over.
    ///
    /// Grammar (clauses in this order):
    /// `SELECT items FROM tables [WHERE expr] [GROUP BY operands]
    /// [ORDER BY items] [LIMIT integer]`
    fn parse_select(&mut self) -> Result<Select, ParseError> {
        self.expect_keyword(Token::Select, "SELECT")?;
        let projection = self.parse_projection()?;
        self.expect_keyword(Token::From, "FROM")?;
        let from = self.parse_from_list()?;
        let selection = if self.eat(&Token::Where) {
            Some(self.parse_expr()?)
        } else {
            None
        };
        let group_by = if self.eat(&Token::Group) {
            self.expect_keyword(Token::By, "BY")?;
            self.parse_group_by()?
        } else {
            Vec::new()
        };
        let order_by = if self.eat(&Token::Order) {
            self.expect_keyword(Token::By, "BY")?;
            self.parse_order_by()?
        } else {
            Vec::new()
        };
        let limit = if self.eat(&Token::Limit) {
            Some(self.expect_integer()?)
        } else {
            None
        };
        // Every clause is optional, so anything still unconsumed is an error.
        if let Some(token) = self.peek() {
            return Err(ParseError::UnexpectedToken(render_token(token)));
        }
        Ok(Select {
            projection,
            from,
            selection,
            group_by,
            order_by,
            limit,
        })
    }
    /// Parse the comma-separated operands following `GROUP BY`.
    fn parse_group_by(&mut self) -> Result<Vec<Expr>, ParseError> {
        self.parse_comma_list(Self::parse_operand)
    }
    /// Parse the comma-separated `SELECT` list.
    ///
    /// A bare `*` is only accepted as the sole projection item.
    fn parse_projection(&mut self) -> Result<Vec<SelectItem>, ParseError> {
        let items = self.parse_comma_list(Self::parse_select_item)?;
        let has_wildcard = items.iter().any(|item| matches!(item, SelectItem::Wildcard));
        if has_wildcard && items.len() > 1 {
            return Err(ParseError::MixedWildcardProjection);
        }
        Ok(items)
    }
    /// Parse one projection item: `*`, or an operand with an optional `AS alias`.
    fn parse_select_item(&mut self) -> Result<SelectItem, ParseError> {
        if self.eat(&Token::Star) {
            return Ok(SelectItem::Wildcard);
        }
        let expr = self.parse_operand()?;
        let alias = self.parse_alias()?;
        Ok(SelectItem::Expr { expr, alias })
    }
    /// Parse the comma-separated table list following `FROM`.
    fn parse_from_list(&mut self) -> Result<Vec<TableRef>, ParseError> {
        self.parse_comma_list(Self::parse_table_ref)
    }
    /// Parse `name [AS alias]`.
    fn parse_table_ref(&mut self) -> Result<TableRef, ParseError> {
        let name = self.expect_identifier()?;
        let alias = self.parse_alias()?;
        Ok(TableRef { name, alias })
    }
    /// Parse an optional `AS identifier` suffix. Only the explicit `AS` form
    /// is supported; a bare trailing identifier is not treated as an alias.
    fn parse_alias(&mut self) -> Result<Option<String>, ParseError> {
        if self.eat(&Token::As) {
            self.expect_identifier().map(Some)
        } else {
            Ok(None)
        }
    }
    /// Parse a boolean expression. `OR` binds loosest and is left-associative.
    fn parse_expr(&mut self) -> Result<Expr, ParseError> {
        let mut expr = self.parse_and()?;
        while self.eat(&Token::Or) {
            let right = self.parse_and()?;
            expr = Self::binary(expr, BinaryOp::Or, right);
        }
        Ok(expr)
    }
    /// Parse a chain of comparisons joined by `AND` (binds tighter than `OR`).
    fn parse_and(&mut self) -> Result<Expr, ParseError> {
        let mut expr = self.parse_equality()?;
        while self.eat(&Token::And) {
            let right = self.parse_equality()?;
            expr = Self::binary(expr, BinaryOp::And, right);
        }
        Ok(expr)
    }
    /// Parse the comma-separated items following `ORDER BY`.
    fn parse_order_by(&mut self) -> Result<Vec<OrderByItem>, ParseError> {
        self.parse_comma_list(Self::parse_order_by_item)
    }
    /// Parse `operand [ASC | DESC]`; the direction defaults to ascending.
    fn parse_order_by_item(&mut self) -> Result<OrderByItem, ParseError> {
        let expr = self.parse_operand()?;
        let direction = if self.eat(&Token::Desc) {
            SortDirection::Desc
        } else {
            // ASC is the default, so an explicit `ASC` is simply skipped if present.
            self.eat(&Token::Asc);
            SortDirection::Asc
        };
        Ok(OrderByItem { expr, direction })
    }
    /// Parse `operand (= | != | <>) operand`.
    ///
    /// Every predicate must be a comparison; a bare operand is an error.
    fn parse_equality(&mut self) -> Result<Expr, ParseError> {
        let left = self.parse_operand()?;
        let op = match self.next().ok_or(ParseError::UnexpectedEnd)? {
            Token::Eq => BinaryOp::Eq,
            Token::Ne => BinaryOp::Ne,
            other => return Err(ParseError::UnexpectedToken(render_token(&other))),
        };
        let right = self.parse_operand()?;
        Ok(Self::binary(left, op, right))
    }
    /// Parse a single value: an identifier, an aggregate call, or a string,
    /// integer or `NULL` literal.
    fn parse_operand(&mut self) -> Result<Expr, ParseError> {
        match self.next().ok_or(ParseError::UnexpectedEnd)? {
            Token::Identifier(name) => {
                if self.peek() == Some(&Token::LParen) {
                    self.parse_function_call(name)
                } else {
                    Ok(Expr::Identifier(name))
                }
            }
            Token::String(value) => Ok(Expr::Literal(Literal::String(value))),
            // The lexer yields `usize`; reject values that do not fit in the
            // literal's `i64` rather than letting an `as` cast wrap them negative.
            Token::Integer(n) => i64::try_from(n)
                .map(|value| Expr::Literal(Literal::Integer(value)))
                .map_err(|_| ParseError::UnexpectedToken(n.to_string())),
            Token::Null => Ok(Expr::Literal(Literal::Null)),
            other => Err(ParseError::UnexpectedToken(render_token(&other))),
        }
    }
    /// Parse `( * | operand )` after an aggregate name (`name` has already been
    /// consumed). Names match case-insensitively; only `COUNT` accepts `*`.
    fn parse_function_call(&mut self, name: String) -> Result<Expr, ParseError> {
        self.expect_keyword(Token::LParen, "(")?;
        let func = match name.to_ascii_uppercase().as_str() {
            "COUNT" => AggregateFunc::Count,
            "SUM" => AggregateFunc::Sum,
            "MIN" => AggregateFunc::Min,
            "MAX" => AggregateFunc::Max,
            "AVG" => AggregateFunc::Avg,
            _ => return Err(ParseError::UnexpectedToken(name)),
        };
        let arg = if self.eat(&Token::Star) {
            if !matches!(func, AggregateFunc::Count) {
                return Err(ParseError::UnexpectedToken("*".to_string()));
            }
            AggregateArg::Star
        } else {
            AggregateArg::Expr(Box::new(self.parse_operand()?))
        };
        self.expect_keyword(Token::RParen, ")")?;
        Ok(Expr::Aggregate { func, arg })
    }
    /// Parse one or more items with `parse_item`, separated by commas.
    fn parse_comma_list<T, F>(&mut self, mut parse_item: F) -> Result<Vec<T>, ParseError>
    where
        F: FnMut(&mut Self) -> Result<T, ParseError>,
    {
        let mut items = vec![parse_item(self)?];
        while self.eat(&Token::Comma) {
            items.push(parse_item(self)?);
        }
        Ok(items)
    }
    /// Build a boxed binary expression node.
    fn binary(left: Expr, op: BinaryOp, right: Expr) -> Expr {
        Expr::Binary {
            left: Box::new(left),
            op,
            right: Box::new(right),
        }
    }
    /// Consume the next token if it equals `expected`; return whether it did.
    fn eat(&mut self, expected: &Token) -> bool {
        let matched = self.peek() == Some(expected);
        if matched {
            self.index += 1;
        }
        matched
    }
    /// Consume the next token and require it to equal `token`; `label` is
    /// the spelling reported in the error.
    fn expect_keyword(&mut self, token: Token, label: &'static str) -> Result<(), ParseError> {
        let next = self.next().ok_or(ParseError::UnexpectedEnd)?;
        if next == token {
            Ok(())
        } else {
            Err(ParseError::ExpectedToken(label))
        }
    }
    /// Consume the next token and require it to be an identifier.
    fn expect_identifier(&mut self) -> Result<String, ParseError> {
        match self.next().ok_or(ParseError::UnexpectedEnd)? {
            Token::Identifier(name) => Ok(name),
            _ => Err(ParseError::ExpectedIdentifier),
        }
    }
    /// Consume the next token and require it to be an integer literal.
    fn expect_integer(&mut self) -> Result<usize, ParseError> {
        match self.next().ok_or(ParseError::UnexpectedEnd)? {
            Token::Integer(n) => Ok(n),
            other => Err(ParseError::UnexpectedToken(render_token(&other))),
        }
    }
    /// Look at the next token without consuming it.
    fn peek(&self) -> Option<&Token> {
        self.tokens.get(self.index)
    }
    /// Consume and return the next token, or `None` at end of input.
    fn next(&mut self) -> Option<Token> {
        let token = self.tokens.get(self.index).cloned();
        if token.is_some() {
            self.index += 1;
        }
        token
    }
}
fn tokenize(input: &str) -> Result<Vec<Token>, ParseError> {
let mut chars = input.chars().peekable();
let mut tokens = Vec::new();
while let Some(ch) = chars.peek().copied() {
if ch.is_whitespace() {
chars.next();
continue;
}
match ch {
'*' => {
chars.next();
tokens.push(Token::Star);
}
',' => {
chars.next();
tokens.push(Token::Comma);
}
'(' => {
chars.next();
tokens.push(Token::LParen);
}
')' => {
chars.next();
tokens.push(Token::RParen);
}
'!' => {
chars.next();
if chars.peek() == Some(&'=') {
chars.next();
tokens.push(Token::Ne);
} else {
return Err(ParseError::UnexpectedToken("!".to_string()));
}
}
'<' => {
chars.next();
if chars.peek() == Some(&'>') {
chars.next();
tokens.push(Token::Ne);
} else {
return Err(ParseError::UnexpectedToken("<".to_string()));
}
}
'=' => {
chars.next();
tokens.push(Token::Eq);
}
'\'' => tokens.push(Token::String(parse_string(&mut chars)?)),
ch if ch.is_ascii_digit() => {
let number = parse_integer(&mut chars);
tokens.push(Token::Integer(number));
}
ch if is_identifier_start(ch) => {
let ident = parse_identifier(&mut chars);
let token = match ident.to_ascii_uppercase().as_str() {
"SELECT" => Token::Select,
"FROM" => Token::From,
"WHERE" => Token::Where,
"AS" => Token::As,
"AND" => Token::And,
"OR" => Token::Or,
"ORDER" => Token::Order,
"BY" => Token::By,
"ASC" => Token::Asc,
"DESC" => Token::Desc,
"NULL" => Token::Null,
"LIMIT" => Token::Limit,
"GROUP" => Token::Group,
_ => Token::Identifier(ident),
};
tokens.push(token);
}
other => return Err(ParseError::UnexpectedToken(other.to_string())),
}
}
Ok(tokens)
}
fn parse_string<I>(chars: &mut std::iter::Peekable<I>) -> Result<String, ParseError>
where
I: Iterator<Item = char>,
{
let mut value = String::new();
let quote = chars.next();
if quote != Some('\'') {
return Err(ParseError::ExpectedToken("'"));
}
while let Some(ch) = chars.next() {
if ch == '\'' {
if chars.peek() == Some(&'\'') {
chars.next();
value.push('\'');
continue;
}
return Ok(value);
}
value.push(ch);
}
Err(ParseError::UnterminatedString)
}
fn parse_identifier<I>(chars: &mut std::iter::Peekable<I>) -> String
where
I: Iterator<Item = char>,
{
let mut ident = String::new();
while let Some(ch) = chars.peek().copied() {
if is_identifier_part(ch) {
ident.push(ch);
chars.next();
} else {
break;
}
}
ident
}
fn parse_integer<I>(chars: &mut std::iter::Peekable<I>) -> usize
where
I: Iterator<Item = char>,
{
let mut value: usize = 0;
while let Some(ch) = chars.peek().copied() {
if ch.is_ascii_digit() {
value = value * 10 + (ch as usize - '0' as usize);
chars.next();
} else {
break;
}
}
value
}
/// Whether `ch` may begin an identifier or keyword: an ASCII letter or `_`.
fn is_identifier_start(ch: char) -> bool {
    matches!(ch, 'a'..='z' | 'A'..='Z' | '_')
}
/// Whether `ch` may continue an identifier once it has started.
///
/// Besides ASCII letters, digits and `_`, this accepts `-`, `:` and `.`, so
/// dotted column references (`Parent.child`) and frontend-style names
/// (`Employee-Records:2025`) each lex as a single identifier.
fn is_identifier_part(ch: char) -> bool {
    matches!(ch, 'a'..='z' | 'A'..='Z' | '0'..='9' | '_' | '-' | ':' | '.')
}
/// Render a token as it would be written in SQL, for use in error messages.
///
/// String literals are re-quoted with embedded `'` doubled, so the rendering
/// is valid input again. `Token::Ne` always renders as `!=`, even if the
/// source spelled it `<>`, because the token does not record which form was
/// used.
fn render_token(token: &Token) -> String {
    match token {
        Token::Identifier(name) => name.clone(),
        Token::Integer(n) => n.to_string(),
        // The lexer collapsed `''` to `'`; undo that so the text round-trips.
        Token::String(value) => format!("'{}'", value.replace('\'', "''")),
        Token::Select => "SELECT".to_string(),
        Token::From => "FROM".to_string(),
        Token::Where => "WHERE".to_string(),
        Token::As => "AS".to_string(),
        Token::And => "AND".to_string(),
        Token::Or => "OR".to_string(),
        Token::Order => "ORDER".to_string(),
        Token::By => "BY".to_string(),
        Token::Asc => "ASC".to_string(),
        Token::Desc => "DESC".to_string(),
        Token::Null => "NULL".to_string(),
        Token::Limit => "LIMIT".to_string(),
        Token::Group => "GROUP".to_string(),
        Token::Star => "*".to_string(),
        Token::Comma => ",".to_string(),
        Token::LParen => "(".to_string(),
        Token::RParen => ")".to_string(),
        Token::Eq => "=".to_string(),
        Token::Ne => "!=".to_string(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn parses_select_with_filter() {
        let select = parse_select("SELECT c0 FROM Parent WHERE c1 = 'bob'").unwrap();
        assert_eq!(
            select.from,
            vec![TableRef {
                name: "Parent".to_string(),
                alias: None,
            }]
        );
        assert_eq!(select.projection.len(), 1);
        assert!(select.selection.is_some());
        assert!(select.order_by.is_empty());
    }
    #[test]
    fn parses_projection_aliases_and_literals() {
        let select =
            parse_select("SELECT c0 AS parent_name, 'seed' AS label, NULL FROM Parent").unwrap();
        assert_eq!(select.projection.len(), 3);
        assert_eq!(
            select.projection[0],
            SelectItem::Expr {
                expr: Expr::Identifier("c0".to_string()),
                alias: Some("parent_name".to_string()),
            }
        );
        assert_eq!(
            select.projection[1],
            SelectItem::Expr {
                expr: Expr::Literal(Literal::String("seed".to_string())),
                alias: Some("label".to_string()),
            }
        );
        assert_eq!(
            select.projection[2],
            SelectItem::Expr {
                expr: Expr::Literal(Literal::Null),
                alias: None,
            }
        );
    }
    #[test]
    fn parses_multi_table_select_with_qualified_columns() {
        let select = parse_select(
            "SELECT Parent.parent, Ancestor.child FROM Parent, Ancestor \
             WHERE Parent.child = Ancestor.parent",
        )
        .unwrap();
        assert_eq!(
            select.from,
            vec![
                TableRef {
                    name: "Parent".to_string(),
                    alias: None,
                },
                TableRef {
                    name: "Ancestor".to_string(),
                    alias: None,
                }
            ]
        );
        assert_eq!(
            select.projection[0],
            SelectItem::Expr {
                expr: Expr::Identifier("Parent.parent".to_string()),
                alias: None,
            }
        );
    }
    #[test]
    fn parses_table_aliases() {
        let select = parse_select(
            "SELECT p.parent, a.child FROM Parent AS p, Ancestor AS a \
             WHERE p.child = a.parent",
        )
        .unwrap();
        assert_eq!(
            select.from,
            vec![
                TableRef {
                    name: "Parent".to_string(),
                    alias: Some("p".to_string()),
                },
                TableRef {
                    name: "Ancestor".to_string(),
                    alias: Some("a".to_string()),
                }
            ]
        );
    }
    #[test]
    fn parses_conjunctive_where_clause() {
        let select =
            parse_select("SELECT c0 FROM Parent WHERE c1 = 'bob' AND c0 = 'alice'").unwrap();
        assert_eq!(
            select.selection,
            Some(Expr::Binary {
                left: Box::new(Expr::Binary {
                    left: Box::new(Expr::Identifier("c1".to_string())),
                    op: BinaryOp::Eq,
                    right: Box::new(Expr::Literal(Literal::String("bob".to_string()))),
                }),
                op: BinaryOp::And,
                right: Box::new(Expr::Binary {
                    left: Box::new(Expr::Identifier("c0".to_string())),
                    op: BinaryOp::Eq,
                    right: Box::new(Expr::Literal(Literal::String("alice".to_string()))),
                }),
            })
        );
    }
    #[test]
    fn parses_order_by_clause() {
        let select = parse_select("SELECT c0 FROM Parent ORDER BY c0 DESC, c1 ASC").unwrap();
        assert_eq!(
            select.order_by,
            vec![
                OrderByItem {
                    expr: Expr::Identifier("c0".to_string()),
                    direction: SortDirection::Desc,
                },
                OrderByItem {
                    expr: Expr::Identifier("c1".to_string()),
                    direction: SortDirection::Asc,
                },
            ]
        );
    }
    #[test]
    fn parses_frontend_style_identifiers() {
        let select = parse_select("SELECT * FROM Employee-Records:2025").unwrap();
        assert_eq!(select.from.len(), 1);
        assert_eq!(select.from[0].name, "Employee-Records:2025");
    }
    #[test]
    fn rejects_mixed_wildcard_projection() {
        let error = parse_select("SELECT *, c0 FROM Parent").unwrap_err();
        assert_eq!(error, ParseError::MixedWildcardProjection);
    }
    #[test]
    fn parses_not_equal_with_bang_eq() {
        let select = parse_select("SELECT c0 FROM Parent WHERE c1 != 'bob'").unwrap();
        assert_eq!(
            select.selection,
            Some(Expr::Binary {
                left: Box::new(Expr::Identifier("c1".to_string())),
                op: BinaryOp::Ne,
                right: Box::new(Expr::Literal(Literal::String("bob".to_string()))),
            })
        );
    }
    #[test]
    fn parses_not_equal_with_diamond() {
        let select = parse_select("SELECT c0 FROM Parent WHERE c1 <> 'bob'").unwrap();
        assert_eq!(
            select.selection,
            Some(Expr::Binary {
                left: Box::new(Expr::Identifier("c1".to_string())),
                op: BinaryOp::Ne,
                right: Box::new(Expr::Literal(Literal::String("bob".to_string()))),
            })
        );
    }
    #[test]
    fn parses_or_expression() {
        let select =
            parse_select("SELECT c0 FROM Parent WHERE c0 = 'alice' OR c0 = 'bob'").unwrap();
        assert_eq!(
            select.selection,
            Some(Expr::Binary {
                left: Box::new(Expr::Binary {
                    left: Box::new(Expr::Identifier("c0".to_string())),
                    op: BinaryOp::Eq,
                    right: Box::new(Expr::Literal(Literal::String("alice".to_string()))),
                }),
                op: BinaryOp::Or,
                right: Box::new(Expr::Binary {
                    left: Box::new(Expr::Identifier("c0".to_string())),
                    op: BinaryOp::Eq,
                    right: Box::new(Expr::Literal(Literal::String("bob".to_string()))),
                }),
            })
        );
    }
    #[test]
    fn parses_integer_literal_in_expression() {
        let select = parse_select("SELECT c0 FROM Parent WHERE c0 = 42").unwrap();
        assert_eq!(
            select.selection,
            Some(Expr::Binary {
                left: Box::new(Expr::Identifier("c0".to_string())),
                op: BinaryOp::Eq,
                right: Box::new(Expr::Literal(Literal::Integer(42))),
            })
        );
    }
    #[test]
    fn parses_limit_clause() {
        let select = parse_select("SELECT c0 FROM Parent LIMIT 5").unwrap();
        assert_eq!(select.limit, Some(5));
    }
    #[test]
    fn parses_order_by_with_limit() {
        let select = parse_select("SELECT c0 FROM Parent ORDER BY c0 DESC LIMIT 1").unwrap();
        assert_eq!(select.order_by.len(), 1);
        assert_eq!(select.limit, Some(1));
    }
    #[test]
    fn parses_or_with_and_precedence() {
        // AND binds tighter than OR: a = '1' OR b = '2' AND c = '3'
        // should parse as: a = '1' OR (b = '2' AND c = '3')
        let select =
            parse_select("SELECT c0 FROM Parent WHERE c0 = '1' OR c1 = '2' AND c0 = '3'").unwrap();
        assert_eq!(
            select.selection,
            Some(Expr::Binary {
                left: Box::new(Expr::Binary {
                    left: Box::new(Expr::Identifier("c0".to_string())),
                    op: BinaryOp::Eq,
                    right: Box::new(Expr::Literal(Literal::String("1".to_string()))),
                }),
                op: BinaryOp::Or,
                right: Box::new(Expr::Binary {
                    left: Box::new(Expr::Binary {
                        left: Box::new(Expr::Identifier("c1".to_string())),
                        op: BinaryOp::Eq,
                        right: Box::new(Expr::Literal(Literal::String("2".to_string()))),
                    }),
                    op: BinaryOp::And,
                    right: Box::new(Expr::Binary {
                        left: Box::new(Expr::Identifier("c0".to_string())),
                        op: BinaryOp::Eq,
                        right: Box::new(Expr::Literal(Literal::String("3".to_string()))),
                    }),
                }),
            })
        );
    }
    #[test]
    fn parses_group_by_with_aggregates() {
        let select =
            parse_select("SELECT c0, COUNT(*) AS n, sum(c1) FROM Parent GROUP BY c0").unwrap();
        assert_eq!(select.group_by, vec![Expr::Identifier("c0".to_string())]);
        assert_eq!(
            select.projection[1],
            SelectItem::Expr {
                expr: Expr::Aggregate {
                    func: AggregateFunc::Count,
                    arg: AggregateArg::Star,
                },
                alias: Some("n".to_string()),
            }
        );
        // Function names are matched case-insensitively.
        assert_eq!(
            select.projection[2],
            SelectItem::Expr {
                expr: Expr::Aggregate {
                    func: AggregateFunc::Sum,
                    arg: AggregateArg::Expr(Box::new(Expr::Identifier("c1".to_string()))),
                },
                alias: None,
            }
        );
    }
    #[test]
    fn rejects_star_in_non_count_aggregate() {
        let error = parse_select("SELECT SUM(*) FROM Parent").unwrap_err();
        assert_eq!(error, ParseError::UnexpectedToken("*".to_string()));
    }
    #[test]
    fn rejects_unknown_function() {
        let error = parse_select("SELECT foo(c0) FROM Parent").unwrap_err();
        assert_eq!(error, ParseError::UnexpectedToken("foo".to_string()));
    }
    #[test]
    fn unescapes_doubled_quotes_in_strings() {
        let select = parse_select("SELECT 'it''s' FROM Parent").unwrap();
        assert_eq!(
            select.projection[0],
            SelectItem::Expr {
                expr: Expr::Literal(Literal::String("it's".to_string())),
                alias: None,
            }
        );
    }
    #[test]
    fn rejects_unterminated_string() {
        let error = parse_select("SELECT 'oops FROM Parent").unwrap_err();
        assert_eq!(error, ParseError::UnterminatedString);
    }
    #[test]
    fn rejects_trailing_tokens() {
        let error = parse_select("SELECT c0 FROM Parent c1").unwrap_err();
        assert_eq!(error, ParseError::UnexpectedToken("c1".to_string()));
    }
    #[test]
    fn rejects_order_without_by() {
        let error = parse_select("SELECT c0 FROM Parent ORDER c0").unwrap_err();
        assert_eq!(error, ParseError::ExpectedToken("BY"));
    }
    #[test]
    fn rejects_bare_operand_predicate() {
        let error = parse_select("SELECT c0 FROM Parent WHERE c0 AND c1 = 1").unwrap_err();
        assert_eq!(error, ParseError::UnexpectedToken("AND".to_string()));
    }
    #[test]
    fn reports_unexpected_end() {
        let error = parse_select("SELECT c0 FROM").unwrap_err();
        assert_eq!(error, ParseError::UnexpectedEnd);
    }
    #[test]
    fn keywords_are_case_insensitive() {
        let select = parse_select("select c0 from Parent where c0 = 1 limit 2").unwrap();
        assert_eq!(
            select.selection,
            Some(Expr::Binary {
                left: Box::new(Expr::Identifier("c0".to_string())),
                op: BinaryOp::Eq,
                right: Box::new(Expr::Literal(Literal::Integer(1))),
            })
        );
        assert_eq!(select.limit, Some(2));
    }
}