use std::collections::HashSet; use std::error::Error; use std::fmt; use crate::catalog::{CatalogError, PredicateCatalog}; use crate::planner::logical::{ AggregateExpr as PlanAggregateExpr, LogicalExpr, LogicalPlan, NamedExpr, SortDirection as LogicalSortDirection, SortKey, }; use crate::relational::{DataType, Field, Schema, Value}; use crate::sql::ast::{ AggregateArg, AggregateFunc, BinaryOp, Expr, Literal, OrderByItem, Select, SelectItem, SortDirection, TableRef, }; /// Errors returned when translating SQL AST into a logical plan. #[derive(Debug)] pub enum PlannerError { /// Catalog lookup failed. Catalog(CatalogError), /// A referenced column does not exist in the input schema. UnknownColumn(String), /// A table or alias name appears more than once in one query. DuplicateSourceName(String), /// The current `ORDER BY` subset only supports output column names. UnsupportedOrderBy, /// The parser or AST contains a wildcard mixed with other projection items. MixedWildcardProjection, /// A `GROUP BY` expression is not a simple column reference. UnsupportedGroupBy, /// A projected column is neither aggregated nor present in `GROUP BY`. ProjectionNotGrouped(String), /// An aggregate expression appears in an unsupported position. UnsupportedAggregate, /// `COUNT(*)` was used with a non-count aggregate function. StarArgNotAllowed, } impl fmt::Display for PlannerError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Catalog(err) => write!(f, "catalog error: {}", err), Self::UnknownColumn(column) => write!(f, "unknown column `{}`", column), Self::DuplicateSourceName(name) => { write!(f, "source name `{}` appears more than once", name) } Self::UnsupportedOrderBy => { write!(f, "only output column names are supported in ORDER BY") } Self::MixedWildcardProjection => { write!( f, "wildcard projections cannot be combined with other items" ) } Self::UnsupportedGroupBy => { write!(f, "only bare column references are supported in GROUP BY") } Self::ProjectionNotGrouped(name) => { write!(f, "column `{}` is not aggregated and not in GROUP BY", name) } Self::UnsupportedAggregate => { write!(f, "aggregate expressions are only allowed in SELECT items") } Self::StarArgNotAllowed => { write!(f, "`*` is only allowed as the argument to COUNT") } } } } impl Error for PlannerError { fn source(&self) -> Option<&(dyn Error + 'static)> { match self { Self::Catalog(err) => Some(err), Self::UnknownColumn(_) | Self::DuplicateSourceName(_) | Self::UnsupportedOrderBy | Self::MixedWildcardProjection | Self::UnsupportedGroupBy | Self::ProjectionNotGrouped(_) | Self::UnsupportedAggregate | Self::StarArgNotAllowed => None, } } } impl From for PlannerError { fn from(value: CatalogError) -> Self { Self::Catalog(value) } } /// Plan a parsed `SELECT` statement into the current logical plan subset. pub fn plan_select( select: &Select, catalog: &PredicateCatalog, ) -> Result { let (mut plan, input_schema) = plan_from_tables(&select.from, catalog)?; if let Some(selection) = &select.selection { let predicate = plan_expr(selection, &input_schema, &select.from)?; plan = LogicalPlan::Filter { input: Box::new(plan), predicate, }; } let is_aggregate_query = !select.group_by.is_empty() || select.projection.iter().any(|item| match item { SelectItem::Expr { expr, .. } => contains_aggregate(expr), SelectItem::Wildcard => false, }); if is_aggregate_query { plan = plan_aggregate(plan, &input_schema, select)?; } else if !is_wildcard_projection(&select.projection) { let mut expressions = Vec::new(); let mut fields = Vec::new(); for (index, item) in select.projection.iter().enumerate() { match item { SelectItem::Expr { expr, alias } => { let planned_expr = plan_expr(expr, &input_schema, &select.from)?; let output_name = alias .clone() .unwrap_or_else(|| default_projection_name(expr, index + 1)); let (data_type, nullable) = projection_metadata(expr, &input_schema, &select.from)?; expressions.push(NamedExpr { name: output_name.clone(), expr: planned_expr, }); fields.push(Field::new(output_name, data_type, nullable)); } SelectItem::Wildcard => return Err(PlannerError::MixedWildcardProjection), } } plan = LogicalPlan::Project { input: Box::new(plan), expressions, schema: Schema::new(fields), }; } let output_schema = plan.output_schema().clone(); plan = maybe_apply_sort(plan, output_schema, &select.order_by, &select.from)?; if let Some(count) = select.limit { plan = LogicalPlan::Limit { input: Box::new(plan), count, }; } Ok(plan) } fn contains_aggregate(expr: &Expr) -> bool { match expr { Expr::Aggregate { .. } => true, Expr::Binary { left, right, .. } => contains_aggregate(left) || contains_aggregate(right), Expr::Identifier(_) | Expr::Literal(_) => false, } } fn plan_aggregate( input: LogicalPlan, input_schema: &Schema, select: &Select, ) -> Result { // Resolve GROUP BY expressions to column names. let mut group_by_cols = Vec::new(); for expr in &select.group_by { match expr { Expr::Identifier(name) => { let resolved = resolve_column_name(name, input_schema, &select.from)?; group_by_cols.push(resolved); } _ => return Err(PlannerError::UnsupportedGroupBy), } } // Walk the projection, collecting aggregate expressions and verifying // non-aggregate column references are in GROUP BY. let mut aggregates: Vec = Vec::new(); let mut projection_items: Vec<(String, ProjectionSource)> = Vec::new(); for (index, item) in select.projection.iter().enumerate() { match item { SelectItem::Wildcard => return Err(PlannerError::MixedWildcardProjection), SelectItem::Expr { expr, alias } => { let output_name = alias .clone() .unwrap_or_else(|| default_projection_name(expr, index + 1)); let source = plan_aggregate_projection( expr, input_schema, select, &group_by_cols, &mut aggregates, )?; projection_items.push((output_name, source)); } } } // Build the Aggregate node's output schema: group_by columns followed by // aggregate outputs. let mut agg_fields = Vec::new(); for col in &group_by_cols { let field_index = input_schema .index_of(col) .ok_or_else(|| PlannerError::UnknownColumn(col.clone()))?; let field = &input_schema.fields()[field_index]; agg_fields.push(Field::new( col.clone(), field.data_type().clone(), field.nullable(), )); } for agg in &aggregates { let (dtype, nullable) = aggregate_output_type(agg, input_schema)?; agg_fields.push(Field::new(agg.name.clone(), dtype, nullable)); } let agg_schema = Schema::new(agg_fields); let aggregate_plan = LogicalPlan::Aggregate { input: Box::new(input), group_by: group_by_cols.clone(), aggregates, schema: agg_schema.clone(), }; // Build the final Project over the aggregate output. let mut expressions = Vec::new(); let mut fields = Vec::new(); for (name, source) in projection_items { let (expr, dtype, nullable) = match source { ProjectionSource::GroupColumn(col) => { let index = agg_schema .index_of(&col) .ok_or_else(|| PlannerError::UnknownColumn(col.clone()))?; let field = &agg_schema.fields()[index]; ( LogicalExpr::Column(col), field.data_type().clone(), field.nullable(), ) } ProjectionSource::AggregateColumn(col) => { let index = agg_schema .index_of(&col) .ok_or_else(|| PlannerError::UnknownColumn(col.clone()))?; let field = &agg_schema.fields()[index]; ( LogicalExpr::Column(col), field.data_type().clone(), field.nullable(), ) } ProjectionSource::Literal(value) => { let (dtype, nullable) = literal_metadata(&value); (LogicalExpr::Literal(value), dtype, nullable) } }; expressions.push(NamedExpr { name: name.clone(), expr, }); fields.push(Field::new(name, dtype, nullable)); } Ok(LogicalPlan::Project { input: Box::new(aggregate_plan), expressions, schema: Schema::new(fields), }) } #[derive(Debug, Clone)] enum ProjectionSource { GroupColumn(String), AggregateColumn(String), Literal(Value), } fn plan_aggregate_projection( expr: &Expr, input_schema: &Schema, select: &Select, group_by_cols: &[String], aggregates: &mut Vec, ) -> Result { match expr { Expr::Aggregate { func, arg } => { let arg_col = match arg { AggregateArg::Star => { if !matches!(func, AggregateFunc::Count) { return Err(PlannerError::StarArgNotAllowed); } None } AggregateArg::Expr(inner) => match inner.as_ref() { Expr::Identifier(name) => { Some(resolve_column_name(name, input_schema, &select.from)?) } _ => return Err(PlannerError::UnsupportedAggregate), }, }; let synthetic_name = format!("__agg_{}", aggregates.len()); aggregates.push(PlanAggregateExpr { name: synthetic_name.clone(), func: *func, arg: arg_col, }); Ok(ProjectionSource::AggregateColumn(synthetic_name)) } Expr::Identifier(name) => { let resolved = resolve_column_name(name, input_schema, &select.from)?; if !group_by_cols.contains(&resolved) { return Err(PlannerError::ProjectionNotGrouped(name.clone())); } Ok(ProjectionSource::GroupColumn(resolved)) } Expr::Literal(literal) => Ok(ProjectionSource::Literal(plan_literal(literal))), Expr::Binary { .. } => Err(PlannerError::UnsupportedAggregate), } } fn aggregate_output_type( agg: &PlanAggregateExpr, input_schema: &Schema, ) -> Result<(DataType, bool), PlannerError> { match agg.func { AggregateFunc::Count => Ok((DataType::Integer, false)), AggregateFunc::Sum | AggregateFunc::Avg => Ok((DataType::Integer, true)), AggregateFunc::Min | AggregateFunc::Max => { if let Some(col) = &agg.arg { let index = input_schema .index_of(col) .ok_or_else(|| PlannerError::UnknownColumn(col.clone()))?; let field = &input_schema.fields()[index]; Ok((field.data_type().clone(), true)) } else { Ok((DataType::Text, true)) } } } } fn literal_metadata(value: &Value) -> (DataType, bool) { match value { Value::Text(_) => (DataType::Text, false), Value::Integer(_) => (DataType::Integer, false), Value::Boolean(_) => (DataType::Boolean, false), Value::Null => (DataType::Text, true), } } fn is_wildcard_projection(items: &[SelectItem]) -> bool { matches!(items, [SelectItem::Wildcard]) } fn plan_from_tables( tables: &[TableRef], catalog: &PredicateCatalog, ) -> Result<(LogicalPlan, Schema), PlannerError> { let mut seen = HashSet::new(); let mut table_iter = tables.iter(); let first = table_iter.next().ok_or_else(|| { PlannerError::Catalog(CatalogError::UnknownTable("".to_string())) })?; let first_name = source_name(first); if !seen.insert(first_name.clone()) { return Err(PlannerError::DuplicateSourceName(first_name)); } let first_schema = input_schema_for_table(first, catalog, should_qualify_columns(first, tables))?; let mut plan = LogicalPlan::Scan { table: first.name.clone(), schema: first_schema.clone(), }; let mut combined_schema = first_schema; for table in table_iter { let qualified_name = source_name(table); if !seen.insert(qualified_name.clone()) { return Err(PlannerError::DuplicateSourceName(qualified_name)); } let right_schema = input_schema_for_table(table, catalog, should_qualify_columns(table, tables))?; let join_schema = combine_schemas(&combined_schema, &right_schema); let right_plan = LogicalPlan::Scan { table: table.name.clone(), schema: right_schema.clone(), }; plan = LogicalPlan::CrossJoin { left: Box::new(plan), right: Box::new(right_plan), schema: join_schema.clone(), }; combined_schema = join_schema; } Ok((plan, combined_schema)) } fn plan_expr( expr: &Expr, schema: &Schema, tables: &[TableRef], ) -> Result { match expr { Expr::Identifier(name) => { let resolved = resolve_column_name(name, schema, tables)?; Ok(LogicalExpr::Column(resolved)) } Expr::Literal(literal) => Ok(LogicalExpr::Literal(plan_literal(literal))), Expr::Binary { left, op, right } => match op { BinaryOp::Eq => Ok(LogicalExpr::Eq( Box::new(plan_expr(left, schema, tables)?), Box::new(plan_expr(right, schema, tables)?), )), BinaryOp::Ne => Ok(LogicalExpr::Ne( Box::new(plan_expr(left, schema, tables)?), Box::new(plan_expr(right, schema, tables)?), )), BinaryOp::And => Ok(LogicalExpr::And( Box::new(plan_expr(left, schema, tables)?), Box::new(plan_expr(right, schema, tables)?), )), BinaryOp::Or => Ok(LogicalExpr::Or( Box::new(plan_expr(left, schema, tables)?), Box::new(plan_expr(right, schema, tables)?), )), }, Expr::Aggregate { .. } => Err(PlannerError::UnsupportedAggregate), } } fn maybe_apply_sort( plan: LogicalPlan, schema: Schema, order_by: &[OrderByItem], tables: &[TableRef], ) -> Result { if order_by.is_empty() { return Ok(plan); } let mut keys = Vec::new(); for item in order_by { let column = match &item.expr { Expr::Identifier(name) => name.clone(), _ => return Err(PlannerError::UnsupportedOrderBy), }; let column = resolve_column_name(&column, &schema, tables)?; keys.push(SortKey { column, direction: match item.direction { SortDirection::Asc => LogicalSortDirection::Asc, SortDirection::Desc => LogicalSortDirection::Desc, }, }); } Ok(LogicalPlan::Sort { input: Box::new(plan), keys, schema, }) } fn plan_literal(literal: &Literal) -> Value { match literal { Literal::String(value) => Value::text(value.clone()), Literal::Integer(n) => Value::Integer(*n), Literal::Null => Value::Null, } } fn projection_metadata( expr: &Expr, schema: &Schema, tables: &[TableRef], ) -> Result<(DataType, bool), PlannerError> { match expr { Expr::Identifier(name) => { let resolved = resolve_column_name(name, schema, tables)?; let index = schema .index_of(&resolved) .ok_or_else(|| PlannerError::UnknownColumn(name.clone()))?; let field = &schema.fields()[index]; Ok((field.data_type().clone(), field.nullable())) } Expr::Literal(Literal::String(_)) => Ok((DataType::Text, false)), Expr::Literal(Literal::Integer(_)) => Ok((DataType::Integer, false)), Expr::Literal(Literal::Null) => Ok((DataType::Text, true)), Expr::Binary { .. } => Ok((DataType::Boolean, true)), Expr::Aggregate { .. } => Err(PlannerError::UnsupportedAggregate), } } fn resolve_column_name( name: &str, schema: &Schema, tables: &[TableRef], ) -> Result { if schema.index_of(name).is_some() { return Ok(name.to_string()); } if let Some((table_name, column_name)) = name.rsplit_once('.') && tables.len() == 1 && tables[0].alias.is_none() && tables[0].name == table_name && schema.index_of(column_name).is_some() { return Ok(column_name.to_string()); } Err(PlannerError::UnknownColumn(name.to_string())) } fn default_projection_name(expr: &Expr, ordinal: usize) -> String { match expr { Expr::Aggregate { func, arg } => { let func_name = match func { AggregateFunc::Count => "COUNT", AggregateFunc::Sum => "SUM", AggregateFunc::Min => "MIN", AggregateFunc::Max => "MAX", AggregateFunc::Avg => "AVG", }; let arg_str = match arg { AggregateArg::Star => "*".to_string(), AggregateArg::Expr(inner) => match inner.as_ref() { Expr::Identifier(name) => name.clone(), _ => format!("expr{}", ordinal), }, }; format!("{}({})", func_name, arg_str) } Expr::Identifier(name) => name.clone(), Expr::Literal(_) | Expr::Binary { .. } => format!("expr{}", ordinal), } } fn input_schema_for_table( table: &TableRef, catalog: &PredicateCatalog, qualify_columns: bool, ) -> Result { let schema = catalog.schema_for(&table.name)?.clone(); if !qualify_columns { return Ok(schema); } let qualifier = source_name(table); let fields = schema .fields() .iter() .map(|field| { Field::new( format!("{}.{}", qualifier, field.name()), field.data_type().clone(), field.nullable(), ) }) .collect(); Ok(Schema::new(fields)) } fn should_qualify_columns(table: &TableRef, tables: &[TableRef]) -> bool { table.alias.is_some() || tables.len() > 1 } fn source_name(table: &TableRef) -> String { table.alias.clone().unwrap_or_else(|| table.name.clone()) } fn combine_schemas(left: &Schema, right: &Schema) -> Schema { let mut fields = left.fields().to_vec(); fields.extend_from_slice(right.fields()); Schema::new(fields) } #[cfg(test)] mod tests { use super::*; use crate::catalog::PredicateCatalog; use crate::chase::{Atom, Instance, Term}; use crate::sql::parser::{ParseError, parse_select}; #[test] fn plans_projection_and_filter() { let instance: Instance = vec![Atom::new( "Parent", vec![Term::constant("alice"), Term::constant("bob")], )] .into_iter() .collect(); let catalog = PredicateCatalog::from_instance(&instance).unwrap(); let select = parse_select("SELECT c0 FROM Parent WHERE c1 = 'bob'").unwrap(); let plan = plan_select(&select, &catalog).unwrap(); assert_eq!(plan.output_schema().len(), 1); } #[test] fn plans_aliases_and_literal_projection() { let instance: Instance = vec![Atom::new( "Parent", vec![Term::constant("alice"), Term::constant("bob")], )] .into_iter() .collect(); let catalog = PredicateCatalog::from_instance(&instance).unwrap(); let select = parse_select("SELECT c0 AS parent_name, 'seed' AS label, NULL FROM Parent").unwrap(); let plan = plan_select(&select, &catalog).unwrap(); let schema = plan.output_schema(); assert_eq!(schema.len(), 3); assert_eq!(schema.fields()[0].name(), "parent_name"); assert_eq!(schema.fields()[1].name(), "label"); assert_eq!(schema.fields()[2].name(), "expr3"); assert_eq!(schema.fields()[1].data_type(), &DataType::Text); } #[test] fn plans_multi_table_select_with_qualified_columns() { let instance: Instance = vec![ Atom::new( "Parent", vec![Term::constant("alice"), Term::constant("bob")], ), Atom::new( "Ancestor", vec![Term::constant("bob"), Term::constant("carol")], ), ] .into_iter() .collect(); let mut catalog = PredicateCatalog::from_instance(&instance).unwrap(); catalog .rename_columns("Parent", ["parent", "child"]) .unwrap(); catalog .rename_columns("Ancestor", ["parent", "child"]) .unwrap(); let select = parse_select( "SELECT Parent.parent, Ancestor.child FROM Parent, Ancestor \ WHERE Parent.child = Ancestor.parent", ) .unwrap(); let plan = plan_select(&select, &catalog).unwrap(); let schema = plan.output_schema(); assert_eq!(schema.len(), 2); assert_eq!(schema.fields()[0].name(), "Parent.parent"); assert_eq!(schema.fields()[1].name(), "Ancestor.child"); } #[test] fn plans_self_join_with_table_aliases() { let instance: Instance = vec![ Atom::new( "Parent", vec![Term::constant("alice"), Term::constant("bob")], ), Atom::new( "Parent", vec![Term::constant("bob"), Term::constant("carol")], ), ] .into_iter() .collect(); let mut catalog = PredicateCatalog::from_instance(&instance).unwrap(); catalog .rename_columns("Parent", ["parent", "child"]) .unwrap(); let select = parse_select( "SELECT p.parent, q.child FROM Parent AS p, Parent AS q \ WHERE p.child = q.parent", ) .unwrap(); let plan = plan_select(&select, &catalog).unwrap(); let schema = plan.output_schema(); assert_eq!(schema.len(), 2); assert_eq!(schema.fields()[0].name(), "p.parent"); assert_eq!(schema.fields()[1].name(), "q.child"); } #[test] fn plans_single_table_with_alias() { let instance: Instance = vec![Atom::new( "Parent", vec![Term::constant("alice"), Term::constant("bob")], )] .into_iter() .collect(); let mut catalog = PredicateCatalog::from_instance(&instance).unwrap(); catalog .rename_columns("Parent", ["parent", "child"]) .unwrap(); let select = parse_select("SELECT p.parent FROM Parent AS p WHERE p.child = 'bob'").unwrap(); let plan = plan_select(&select, &catalog).unwrap(); let schema = plan.output_schema(); assert_eq!(schema.len(), 1); assert_eq!(schema.fields()[0].name(), "p.parent"); } #[test] fn plans_single_table_with_qualified_table_name() { let instance: Instance = vec![Atom::new( "Parent", vec![Term::constant("alice"), Term::constant("bob")], )] .into_iter() .collect(); let mut catalog = PredicateCatalog::from_instance(&instance).unwrap(); catalog .rename_columns("Parent", ["parent", "child"]) .unwrap(); let select = parse_select("SELECT Parent.parent FROM Parent WHERE Parent.child = 'bob'").unwrap(); let plan = plan_select(&select, &catalog).unwrap(); let schema = plan.output_schema(); assert_eq!(schema.len(), 1); assert_eq!(schema.fields()[0].name(), "Parent.parent"); } #[test] fn plans_conjunctive_filter() { let instance: Instance = vec![Atom::new( "Parent", vec![Term::constant("alice"), Term::constant("bob")], )] .into_iter() .collect(); let catalog = PredicateCatalog::from_instance(&instance).unwrap(); let select = parse_select("SELECT c0 FROM Parent WHERE c1 = 'bob' AND c0 = 'alice'").unwrap(); let plan = plan_select(&select, &catalog).unwrap(); match plan { LogicalPlan::Project { input, .. } => match *input { LogicalPlan::Filter { predicate, .. } => { assert!(matches!(predicate, LogicalExpr::And(_, _))); } other => panic!("unexpected input plan: {:?}", other), }, other => panic!("unexpected plan: {:?}", other), } } #[test] fn plans_order_by_after_projection() { let instance: Instance = vec![ Atom::new( "Parent", vec![Term::constant("alice"), Term::constant("bob")], ), Atom::new( "Parent", vec![Term::constant("bob"), Term::constant("carol")], ), ] .into_iter() .collect(); let catalog = PredicateCatalog::from_instance(&instance).unwrap(); let select = parse_select("SELECT c0 FROM Parent ORDER BY c0 DESC").unwrap(); let plan = plan_select(&select, &catalog).unwrap(); match plan { LogicalPlan::Sort { keys, input, .. } => { assert_eq!(keys.len(), 1); assert_eq!(keys[0].column, "c0"); assert!(matches!(keys[0].direction, LogicalSortDirection::Desc)); assert!(matches!(*input, LogicalPlan::Project { .. })); } other => panic!("unexpected plan: {:?}", other), } } #[test] fn rejects_mixed_wildcard_projection() { let instance: Instance = vec![Atom::new( "Parent", vec![Term::constant("alice"), Term::constant("bob")], )] .into_iter() .collect(); let catalog = PredicateCatalog::from_instance(&instance).unwrap(); let select = parse_select("SELECT *, c0 FROM Parent").unwrap_err(); assert_eq!(select, ParseError::MixedWildcardProjection); let malformed = Select { projection: vec![ SelectItem::Wildcard, SelectItem::Expr { expr: Expr::Identifier("c0".to_string()), alias: None, }, ], from: vec![TableRef { name: "Parent".to_string(), alias: None, }], selection: None, group_by: Vec::new(), order_by: Vec::new(), limit: None, }; let error = plan_select(&malformed, &catalog).unwrap_err(); assert_eq!( error.to_string(), "wildcard projections cannot be combined with other items" ); } }