diff --git a/AGENTS.md b/AGENTS.md index 74ae1a1..c88021d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -71,7 +71,7 @@ Quick examples: - The chase engine should remain largely stateless; pass execution state explicitly. - New chase variants should be composable with existing infrastructure. - Existential variables generate labeled nulls (`Term::Null`). -- The current SQL support is intentionally narrow: `SELECT-FROM-WHERE-ORDER BY-LIMIT` over predicate-backed tables; equality and inequality predicates combined with `AND` and `OR`; comma-join style multi-table queries; table aliases; ordering by output-column names; integer and string literals. +- The current SQL support is intentionally narrow: `SELECT-FROM-WHERE-GROUP BY-ORDER BY-LIMIT` over predicate-backed tables; equality and inequality predicates combined with `AND` and `OR`; comma-join style multi-table queries; table aliases; ordering by output-column names; integer and string literals; `COUNT`, `SUM`, `MIN`, `MAX`, and `AVG` aggregates with optional `GROUP BY`. - Stable SQL column names come from explicit catalog registration or the frontend `schema ...` command, including for empty tables; otherwise the default names are positional such as `c0` and `c1`. - Single-table SQL queries may use the table name as a qualifier when no alias is present. - Do not describe unsupported SQL features such as aggregates, grouping, or arbitrary expressions as implemented. diff --git a/Cargo.toml b/Cargo.toml index d3b3496..76dd8fb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,15 @@ binaries = [] [dev-dependencies] proptest = "1.6" +criterion = { version = "0.5", default-features = false } + +[[bench]] +name = "chase" +harness = false + +[[bench]] +name = "sql" +harness = false [profile.release] strip = "debuginfo" diff --git a/README.md b/README.md index 3ee9d04..0c29170 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,8 @@ execution boundaries. - Script, REPL, and local web UI for experimentation - Relational schema, catalog, logical-plan, and execution scaffolding - Physical operator scaffolding with a small rule-based rewrite layer -- A minimal SQL slice for `SELECT-FROM-WHERE-ORDER BY-LIMIT` queries over predicate-backed tables +- A minimal SQL slice for `SELECT-FROM-WHERE-GROUP BY-ORDER BY-LIMIT` queries over predicate-backed tables, including `COUNT`, `SUM`, `MIN`, `MAX`, and `AVG` aggregates +- Filter push-down across joins in the physical rewrite pass ### Architecture @@ -143,6 +144,8 @@ WHERE Parent.child = Ancestor.parent SELECT p.parent, q.child FROM Parent AS p, Parent AS q WHERE p.child = q.parent +SELECT COUNT(*) FROM Parent +SELECT dept, COUNT(*), SUM(salary) FROM Emp GROUP BY dept ``` In the REPL or script runner, use the `sql` command and end the statement with @@ -191,7 +194,7 @@ Current limits: - `ORDER BY` supports output-column ordering with `ASC`/`DESC` - `LIMIT` restricts the number of output rows - literals include strings, integers, and `NULL` -- no aggregates +- aggregates: `COUNT(*)`, `COUNT(col)`, `SUM`, `MIN`, `MAX`, `AVG`, with optional `GROUP BY` - projection aliases only via `AS` Runnable SQL examples: @@ -212,6 +215,12 @@ cargo clippy --all-targets --all-features -- -D warnings cargo fmt --check ``` +Benchmarks live under `benches/` and can be run with: + +```bash +cargo bench +``` + ### Notes This repository is still centered on a rule-engine core. The new SQL-related diff --git a/ROADMAP.md b/ROADMAP.md index 2851e4e..aa7abb2 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -35,6 +35,8 @@ This document tracks the current state and next steps for the repository. - [x] `!=`/`<>` inequality and `OR` disjunction in `WHERE` clauses - [x] `LIMIT` clause for restricting output row count - [x] Integer literal and `DataType::Integer` support +- [x] `COUNT`, `SUM`, `MIN`, `MAX`, `AVG` aggregates with `GROUP BY` +- [x] Filter push-down rewrite across `NestedLoopJoin` in the physical layer ### Near-Term Cleanup @@ -81,7 +83,7 @@ This document tracks the current state and next steps for the repository. - [ ] Negative constraints - [ ] Stratified negation in rule bodies - [ ] Disjunctive heads -- [ ] Aggregation support in rule evaluation +- [ ] Aggregation support in rule evaluation (available in SQL; not yet exposed to chase rules) - [x] Semi-naive evaluation - [ ] Termination analysis helpers @@ -108,6 +110,6 @@ This document tracks the current state and next steps for the repository. - [x] Property-based tests - [x] Regression tests - [x] Initial SQL pipeline tests -- [ ] Benchmark coverage +- [x] Benchmark coverage (chase and SQL pipeline via `cargo bench`) - [ ] Snapshot-style frontend tests - [ ] More planner/executor tests as those layers are added diff --git a/benches/chase.rs b/benches/chase.rs new file mode 100644 index 0000000..7c5e19d --- /dev/null +++ b/benches/chase.rs @@ -0,0 +1,98 @@ +//! Benchmarks for the chase subsystem. +//! +//! These are designed to retroactively validate the semi-naive and Skolem +//! work and catch future regressions. Each workload runs several chase +//! variants over the same input so relative numbers are meaningful. + +use criterion::{BatchSize, Criterion, criterion_group, criterion_main}; +use query_engine::chase::rule::RuleBuilder; +use query_engine::chase::{ChaseConfig, ChaseVariant, Rule, chase_with_config}; +use query_engine::{Atom, Instance, Term}; + +fn chain_edges(n: usize) -> Instance { + (0..n) + .map(|i| { + Atom::new( + "Edge", + vec![ + Term::constant(format!("n{}", i)), + Term::constant(format!("n{}", i + 1)), + ], + ) + }) + .collect() +} + +fn transitive_closure_rules() -> Vec { + let edge_to_path = RuleBuilder::new() + .when("Edge", vec![Term::var("X"), Term::var("Y")]) + .then("Path", vec![Term::var("X"), Term::var("Y")]) + .build(); + let extend_path = RuleBuilder::new() + .when("Path", vec![Term::var("X"), Term::var("Y")]) + .when("Edge", vec![Term::var("Y"), Term::var("Z")]) + .then("Path", vec![Term::var("X"), Term::var("Z")]) + .build(); + vec![edge_to_path, extend_path] +} + +fn bench_transitive_closure(c: &mut Criterion) { + let mut group = c.benchmark_group("transitive_closure_chain_20"); + let instance = chain_edges(20); + let rules = transitive_closure_rules(); + + for (label, variant, semi) in [ + ("restricted_naive", ChaseVariant::Restricted, false), + ("restricted_semi_naive", ChaseVariant::Restricted, true), + ("standard_naive", ChaseVariant::Standard, false), + ("standard_semi_naive", ChaseVariant::Standard, true), + ] { + let config = ChaseConfig { + variant, + semi_naive: semi, + ..Default::default() + }; + group.bench_function(label, |b| { + b.iter_batched( + || instance.clone(), + |inst| chase_with_config(inst, &rules, config.clone()), + BatchSize::SmallInput, + ); + }); + } + group.finish(); +} + +fn bench_existentials(c: &mut Criterion) { + let mut group = c.benchmark_group("existentials_50_people"); + let instance: Instance = (0..50) + .map(|i| Atom::new("Person", vec![Term::constant(format!("p{}", i))])) + .collect(); + let rule = RuleBuilder::new() + .when("Person", vec![Term::var("X")]) + .then("HasId", vec![Term::var("X"), Term::var("Y")]) + .build(); + let rules = vec![rule]; + + for (label, variant) in [ + ("restricted", ChaseVariant::Restricted), + ("skolem", ChaseVariant::Skolem), + ] { + let config = ChaseConfig { + variant, + semi_naive: false, + ..Default::default() + }; + group.bench_function(label, |b| { + b.iter_batched( + || instance.clone(), + |inst| chase_with_config(inst, &rules, config.clone()), + BatchSize::SmallInput, + ); + }); + } + group.finish(); +} + +criterion_group!(benches, bench_transitive_closure, bench_existentials); +criterion_main!(benches); diff --git a/benches/sql.rs b/benches/sql.rs new file mode 100644 index 0000000..6338695 --- /dev/null +++ b/benches/sql.rs @@ -0,0 +1,115 @@ +//! Benchmarks for the SQL pipeline. +//! +//! Focus areas: scans, single-column filters, multi-table joins with and +//! without filter push-down, and GROUP BY aggregation. + +use criterion::{BatchSize, Criterion, criterion_group, criterion_main}; +use query_engine::catalog::PredicateCatalog; +use query_engine::execution::TableStore; +use query_engine::execution::execute; +use query_engine::execution::physical::{execute_physical, plan_physical, rewrite_physical}; +use query_engine::planner::sql::plan_select; +use query_engine::relational::{DataType, Field, Row, Schema, Value}; +use query_engine::sql::parser::parse_select; +use query_engine::{Atom, Instance, Term}; + +fn edges_instance(n: usize) -> Instance { + (0..n) + .map(|i| { + Atom::new( + "L", + vec![ + Term::constant(format!("a{}", i)), + Term::constant(format!("b{}", i)), + ], + ) + }) + .chain((0..n).map(|i| { + Atom::new( + "R", + vec![ + Term::constant(format!("b{}", i)), + Term::constant(format!("c{}", i)), + ], + ) + })) + .collect() +} + +fn bench_filter_pushdown_join(c: &mut Criterion) { + let mut group = c.benchmark_group("filter_pushdown_join_100"); + let instance = edges_instance(100); + let mut catalog = PredicateCatalog::from_instance(&instance).unwrap(); + catalog.rename_columns("L", ["a", "b"]).unwrap(); + catalog.rename_columns("R", ["b", "c"]).unwrap(); + + let select = parse_select("SELECT L.a, R.c FROM L, R WHERE L.b = R.b AND L.a = 'a42'").unwrap(); + let logical = plan_select(&select, &catalog).unwrap(); + + group.bench_function("logical_direct_execute", |b| { + b.iter(|| execute(&logical, &instance).unwrap()); + }); + + let physical_raw = plan_physical(&logical); + group.bench_function("physical_no_rewrite", |b| { + b.iter(|| execute_physical(&physical_raw, &instance).unwrap()); + }); + + let physical_rewritten = rewrite_physical(plan_physical(&logical)); + group.bench_function("physical_with_pushdown", |b| { + b.iter(|| execute_physical(&physical_rewritten, &instance).unwrap()); + }); + + group.finish(); +} + +fn bench_group_by_aggregation(c: &mut Criterion) { + let mut group = c.benchmark_group("group_by_aggregation_1000"); + + let schema = Schema::new(vec![ + Field::new("dept", DataType::Text, false), + Field::new("salary", DataType::Integer, false), + ]); + let mut store = TableStore::new(); + let rows: Vec = (0..1000) + .map(|i| { + let dept = format!("d{}", i % 10); + Row::new(vec![Value::text(dept), Value::Integer((i as i64) * 10)]) + }) + .collect(); + store.insert("Emp", schema.clone(), rows); + + let mut catalog = PredicateCatalog::new(); + catalog.register_table("Emp", schema); + + let select = + parse_select("SELECT dept, COUNT(*), SUM(salary), AVG(salary) FROM Emp GROUP BY dept") + .unwrap(); + let logical = plan_select(&select, &catalog).unwrap(); + + group.bench_function("logical_direct", |b| { + b.iter_batched( + || (), + |_| execute(&logical, &store).unwrap(), + BatchSize::SmallInput, + ); + }); + + let physical = rewrite_physical(plan_physical(&logical)); + group.bench_function("physical", |b| { + b.iter_batched( + || (), + |_| execute_physical(&physical, &store).unwrap(), + BatchSize::SmallInput, + ); + }); + + group.finish(); +} + +criterion_group!( + benches, + bench_filter_pushdown_join, + bench_group_by_aggregation +); +criterion_main!(benches); diff --git a/src/execution/mod.rs b/src/execution/mod.rs index 5f1cd7c..92e3d47 100644 --- a/src/execution/mod.rs +++ b/src/execution/mod.rs @@ -12,8 +12,11 @@ use std::error::Error; use std::fmt; use crate::chase::{Instance, Term}; -use crate::planner::logical::{LogicalExpr, LogicalPlan, SortDirection, SortKey}; +use crate::planner::logical::{ + AggregateExpr as PlanAggregateExpr, LogicalExpr, LogicalPlan, SortDirection, SortKey, +}; use crate::relational::{ResultSet, Row, Schema, Value}; +use crate::sql::ast::AggregateFunc; pub use physical::{ NamedPhysicalExpr, PhysicalPlan, execute_physical, plan_physical, rewrite_physical, @@ -133,6 +136,195 @@ pub fn execute(plan: &LogicalPlan, source: &dyn DataSource) -> Result { + let result = execute(input, source)?; + let rows = compute_aggregate(result.rows(), result.schema(), group_by, aggregates)?; + Ok(ResultSet::new(schema.clone(), rows)) + } + } +} + +/// Evaluate group-by + aggregates over a row set, returning one output row +/// per distinct group key. The output row layout is: group_by column values +/// followed by aggregate output values. +pub(crate) fn compute_aggregate( + rows: &[Row], + input_schema: &Schema, + group_by: &[String], + aggregates: &[PlanAggregateExpr], +) -> Result, ExecutionError> { + let group_indexes = group_by + .iter() + .map(|name| { + input_schema + .index_of(name) + .ok_or_else(|| ExecutionError::UnknownColumn(name.clone())) + }) + .collect::, _>>()?; + + // Each aggregate holds an optional input column index (None means COUNT(*)). + let agg_indexes = aggregates + .iter() + .map(|agg| { + agg.arg + .as_ref() + .map(|col| { + input_schema + .index_of(col) + .ok_or_else(|| ExecutionError::UnknownColumn(col.clone())) + }) + .transpose() + }) + .collect::>, _>>()?; + + // Preserve first-seen group order so single-group output is deterministic. + let mut order: Vec> = Vec::new(); + let mut groups: std::collections::HashMap, Vec> = + std::collections::HashMap::new(); + + for row in rows { + let key: Vec = group_indexes + .iter() + .map(|i| row.get(*i).cloned().unwrap_or(Value::Null)) + .collect(); + + let states = groups.entry(key.clone()).or_insert_with(|| { + order.push(key.clone()); + aggregates + .iter() + .map(|agg| AggregateState::new(agg.func)) + .collect() + }); + + for (state, index_opt) in states.iter_mut().zip(agg_indexes.iter()) { + let value = match index_opt { + Some(i) => row.get(*i).cloned().unwrap_or(Value::Null), + None => Value::Null, // COUNT(*) observes each row + }; + state.observe(&value, index_opt.is_none()); + } + } + + // If the user wrote an aggregate with no GROUP BY and no input rows, we + // still need one output row (all-null plus zero counts). + if rows.is_empty() && group_by.is_empty() && !aggregates.is_empty() { + order.push(Vec::new()); + groups.insert( + Vec::new(), + aggregates + .iter() + .map(|agg| AggregateState::new(agg.func)) + .collect(), + ); + } + + let mut out_rows = Vec::new(); + for key in order { + let states = groups.remove(&key).unwrap_or_default(); + let mut values = key; + for state in &states { + values.push(state.finalize()); + } + out_rows.push(Row::new(values)); + } + + Ok(out_rows) +} + +#[derive(Debug)] +pub(crate) enum AggregateState { + Count(i64), + Sum(Option), + Min(Option), + Max(Option), + Avg { sum: i64, count: i64 }, +} + +impl AggregateState { + pub(crate) fn new(func: AggregateFunc) -> Self { + match func { + AggregateFunc::Count => Self::Count(0), + AggregateFunc::Sum => Self::Sum(None), + AggregateFunc::Min => Self::Min(None), + AggregateFunc::Max => Self::Max(None), + AggregateFunc::Avg => Self::Avg { sum: 0, count: 0 }, + } + } + + pub(crate) fn observe(&mut self, value: &Value, is_count_star: bool) { + match self { + Self::Count(c) => { + if is_count_star || !matches!(value, Value::Null) { + *c += 1; + } + } + Self::Sum(total) => { + if let Value::Integer(n) = value { + *total = Some(total.unwrap_or(0) + n); + } + } + Self::Min(current) => { + if !matches!(value, Value::Null) { + match current { + None => *current = Some(value.clone()), + Some(existing) => { + if compare_values_for_agg(value, existing) == std::cmp::Ordering::Less { + *existing = value.clone(); + } + } + } + } + } + Self::Max(current) => { + if !matches!(value, Value::Null) { + match current { + None => *current = Some(value.clone()), + Some(existing) => { + if compare_values_for_agg(value, existing) + == std::cmp::Ordering::Greater + { + *existing = value.clone(); + } + } + } + } + } + Self::Avg { sum, count } => { + if let Value::Integer(n) = value { + *sum += n; + *count += 1; + } + } + } + } + + pub(crate) fn finalize(&self) -> Value { + match self { + Self::Count(c) => Value::Integer(*c), + Self::Sum(total) => total.map(Value::Integer).unwrap_or(Value::Null), + Self::Min(v) | Self::Max(v) => v.clone().unwrap_or(Value::Null), + Self::Avg { sum, count } => { + if *count == 0 { + Value::Null + } else { + Value::Integer(sum / count) + } + } + } + } +} + +fn compare_values_for_agg(left: &Value, right: &Value) -> std::cmp::Ordering { + match (left, right) { + (Value::Integer(a), Value::Integer(b)) => a.cmp(b), + (Value::Text(a), Value::Text(b)) => a.cmp(b), + (Value::Boolean(a), Value::Boolean(b)) => a.cmp(b), + _ => std::cmp::Ordering::Equal, } } diff --git a/src/execution/physical.rs b/src/execution/physical.rs index 0e467c8..3ff850d 100644 --- a/src/execution/physical.rs +++ b/src/execution/physical.rs @@ -15,10 +15,12 @@ use std::cmp::Ordering; -use crate::planner::logical::{LogicalExpr, LogicalPlan, SortDirection, SortKey}; +use crate::planner::logical::{ + AggregateExpr as PlanAggregateExpr, LogicalExpr, LogicalPlan, SortDirection, SortKey, +}; use crate::relational::{ResultSet, Row, Schema, Value}; -use super::{DataSource, ExecutionError}; +use super::{DataSource, ExecutionError, compute_aggregate}; /// A physical plan node in the current execution subset. #[derive(Debug, Clone, PartialEq, Eq)] @@ -54,6 +56,13 @@ pub enum PhysicalPlan { input: Box, count: usize, }, + /// Compute aggregates per group key using an in-memory hash map. + HashAggregate { + input: Box, + group_by: Vec, + aggregates: Vec, + schema: Schema, + }, } /// A named physical expression in a projection. @@ -75,6 +84,7 @@ impl PhysicalPlan { Self::Sort { schema, .. } => schema, Self::Project { schema, .. } => schema, Self::Limit { input, .. } => input.output_schema(), + Self::HashAggregate { schema, .. } => schema, } } } @@ -131,19 +141,209 @@ pub fn plan_physical(plan: &LogicalPlan) -> PhysicalPlan { input: Box::new(plan_physical(input)), count: *count, }, + LogicalPlan::Aggregate { + input, + group_by, + aggregates, + schema, + } => PhysicalPlan::HashAggregate { + input: Box::new(plan_physical(input)), + group_by: group_by.clone(), + aggregates: aggregates.clone(), + schema: schema.clone(), + }, } } /// Apply rule-based rewrites to a physical plan. /// -/// Today the only rewrite is `combine_adjacent_limits`, which collapses -/// `Limit(Limit(child, n), m)` into `Limit(child, min(n, m))`. Future -/// rewrites belong here as additional functions composed in this entry -/// point. +/// Current rewrites: +/// - [`combine_adjacent_limits`] collapses `Limit(Limit(child, n), m)` into +/// `Limit(child, min(n, m))`. +/// - [`push_filter_below_join`] pushes conjuncts of a `Filter` below a +/// `NestedLoopJoin` when they reference only one side's columns, so the +/// join sees fewer rows. pub fn rewrite_physical(plan: PhysicalPlan) -> PhysicalPlan { + let plan = push_filter_below_join(plan); combine_adjacent_limits(plan) } +/// Push conjuncts of a `Filter` below a `NestedLoopJoin` when each conjunct +/// references only columns from one side of the join. Conjuncts that mention +/// both sides remain above the join. +fn push_filter_below_join(plan: PhysicalPlan) -> PhysicalPlan { + match plan { + PhysicalPlan::Filter { input, predicate } => { + let pushed_input = push_filter_below_join(*input); + match pushed_input { + PhysicalPlan::NestedLoopJoin { + left, + right, + schema, + } => { + let left_cols: Vec = left + .output_schema() + .fields() + .iter() + .map(|f| f.name().to_string()) + .collect(); + let right_cols: Vec = right + .output_schema() + .fields() + .iter() + .map(|f| f.name().to_string()) + .collect(); + + let mut left_conjuncts: Vec = Vec::new(); + let mut right_conjuncts: Vec = Vec::new(); + let mut remaining: Vec = Vec::new(); + + for conjunct in split_conjuncts(predicate) { + let refs = collect_column_refs(&conjunct); + let all_left = refs.iter().all(|c| left_cols.contains(c)); + let all_right = refs.iter().all(|c| right_cols.contains(c)); + if !refs.is_empty() && all_left { + left_conjuncts.push(conjunct); + } else if !refs.is_empty() && all_right { + right_conjuncts.push(conjunct); + } else { + remaining.push(conjunct); + } + } + + let left = if let Some(pred) = combine_conjuncts(left_conjuncts) { + Box::new(PhysicalPlan::Filter { + input: left, + predicate: pred, + }) + } else { + left + }; + let right = if let Some(pred) = combine_conjuncts(right_conjuncts) { + Box::new(PhysicalPlan::Filter { + input: right, + predicate: pred, + }) + } else { + right + }; + + // Recurse so pushed filters below the join continue to + // push through deeper joins if any. + let left = Box::new(push_filter_below_join(*left)); + let right = Box::new(push_filter_below_join(*right)); + + let joined = PhysicalPlan::NestedLoopJoin { + left, + right, + schema, + }; + + match combine_conjuncts(remaining) { + Some(pred) => PhysicalPlan::Filter { + input: Box::new(joined), + predicate: pred, + }, + None => joined, + } + } + other => PhysicalPlan::Filter { + input: Box::new(other), + predicate, + }, + } + } + PhysicalPlan::NestedLoopJoin { + left, + right, + schema, + } => PhysicalPlan::NestedLoopJoin { + left: Box::new(push_filter_below_join(*left)), + right: Box::new(push_filter_below_join(*right)), + schema, + }, + PhysicalPlan::Sort { + input, + keys, + schema, + } => PhysicalPlan::Sort { + input: Box::new(push_filter_below_join(*input)), + keys, + schema, + }, + PhysicalPlan::Project { + input, + expressions, + schema, + } => PhysicalPlan::Project { + input: Box::new(push_filter_below_join(*input)), + expressions, + schema, + }, + PhysicalPlan::Limit { input, count } => PhysicalPlan::Limit { + input: Box::new(push_filter_below_join(*input)), + count, + }, + PhysicalPlan::HashAggregate { + input, + group_by, + aggregates, + schema, + } => PhysicalPlan::HashAggregate { + input: Box::new(push_filter_below_join(*input)), + group_by, + aggregates, + schema, + }, + leaf @ PhysicalPlan::SeqScan { .. } => leaf, + } +} + +fn split_conjuncts(expr: LogicalExpr) -> Vec { + let mut out = Vec::new(); + let mut stack = vec![expr]; + while let Some(node) = stack.pop() { + match node { + LogicalExpr::And(left, right) => { + stack.push(*right); + stack.push(*left); + } + other => out.push(other), + } + } + out +} + +fn combine_conjuncts(mut conjuncts: Vec) -> Option { + if conjuncts.is_empty() { + return None; + } + let mut combined = conjuncts.remove(0); + for next in conjuncts { + combined = LogicalExpr::And(Box::new(combined), Box::new(next)); + } + Some(combined) +} + +fn collect_column_refs(expr: &LogicalExpr) -> Vec { + let mut out = Vec::new(); + fn walk(expr: &LogicalExpr, out: &mut Vec) { + match expr { + LogicalExpr::Column(name) => out.push(name.clone()), + LogicalExpr::Literal(_) => {} + LogicalExpr::Eq(left, right) + | LogicalExpr::Ne(left, right) + | LogicalExpr::And(left, right) + | LogicalExpr::Or(left, right) => { + walk(left, out); + walk(right, out); + } + } + } + walk(expr, &mut out); + out +} + fn combine_adjacent_limits(plan: PhysicalPlan) -> PhysicalPlan { match plan { PhysicalPlan::Limit { input, count } => { @@ -193,6 +393,17 @@ fn combine_adjacent_limits(plan: PhysicalPlan) -> PhysicalPlan { expressions, schema, }, + PhysicalPlan::HashAggregate { + input, + group_by, + aggregates, + schema, + } => PhysicalPlan::HashAggregate { + input: Box::new(combine_adjacent_limits(*input)), + group_by, + aggregates, + schema, + }, leaf @ PhysicalPlan::SeqScan { .. } => leaf, } } @@ -265,6 +476,16 @@ pub fn execute_physical( let rows = result.rows().iter().take(*count).cloned().collect(); Ok(ResultSet::new(result.schema().clone(), rows)) } + PhysicalPlan::HashAggregate { + input, + group_by, + aggregates, + schema, + } => { + let result = execute_physical(input, source)?; + let rows = compute_aggregate(result.rows(), result.schema(), group_by, aggregates)?; + Ok(ResultSet::new(schema.clone(), rows)) + } } } @@ -467,4 +688,136 @@ mod tests { assert_eq!(result.rows().len(), 1); assert_eq!(result.rows()[0].values()[0], Value::text("alice")); } + + #[test] + fn rewrite_pushes_single_side_filter_below_join() { + let left_schema = Schema::new(vec![ + Field::new("Parent.parent", DataType::Text, false), + Field::new("Parent.child", DataType::Text, false), + ]); + let right_schema = Schema::new(vec![ + Field::new("Ancestor.parent", DataType::Text, false), + Field::new("Ancestor.child", DataType::Text, false), + ]); + let join_schema = Schema::new( + left_schema + .fields() + .iter() + .chain(right_schema.fields()) + .cloned() + .collect(), + ); + + // Filter( + // NestedLoopJoin(Parent, Ancestor), + // Parent.parent = 'alice' AND Parent.child = Ancestor.parent, + // ) + let plan = PhysicalPlan::Filter { + input: Box::new(PhysicalPlan::NestedLoopJoin { + left: Box::new(PhysicalPlan::SeqScan { + table: "Parent".to_string(), + schema: left_schema, + }), + right: Box::new(PhysicalPlan::SeqScan { + table: "Ancestor".to_string(), + schema: right_schema, + }), + schema: join_schema, + }), + predicate: LogicalExpr::And( + Box::new(LogicalExpr::Eq( + Box::new(LogicalExpr::Column("Parent.parent".to_string())), + Box::new(LogicalExpr::Literal(Value::text("alice"))), + )), + Box::new(LogicalExpr::Eq( + Box::new(LogicalExpr::Column("Parent.child".to_string())), + Box::new(LogicalExpr::Column("Ancestor.parent".to_string())), + )), + ), + }; + + let rewritten = rewrite_physical(plan); + match rewritten { + // The Parent.parent = 'alice' predicate should be pushed onto the + // left side; the join predicate should remain above. + PhysicalPlan::Filter { input, .. } => match *input { + PhysicalPlan::NestedLoopJoin { left, .. } => { + assert!(matches!(*left, PhysicalPlan::Filter { .. })); + } + other => panic!("expected NestedLoopJoin under Filter, got {:?}", other), + }, + other => panic!("expected outer Filter, got {:?}", other), + } + } + + #[test] + fn rewrite_push_filter_preserves_semantics_on_join() { + // Two three-row tables, join predicate filters down to one row. + // Push-down should not change the row count or values. + struct TwoTable; + impl DataSource for TwoTable { + fn scan(&self, table: &str, schema: &Schema) -> Result { + let rows = match table { + "L" => vec![ + Row::new(vec![Value::text("alice"), Value::text("bob")]), + Row::new(vec![Value::text("bob"), Value::text("carol")]), + Row::new(vec![Value::text("carol"), Value::text("dave")]), + ], + "R" => vec![ + Row::new(vec![Value::text("bob"), Value::text("x")]), + Row::new(vec![Value::text("carol"), Value::text("y")]), + Row::new(vec![Value::text("eve"), Value::text("z")]), + ], + _ => Vec::new(), + }; + Ok(ResultSet::new(schema.clone(), rows)) + } + } + + let left_schema = Schema::new(vec![ + Field::new("L.a", DataType::Text, false), + Field::new("L.b", DataType::Text, false), + ]); + let right_schema = Schema::new(vec![ + Field::new("R.a", DataType::Text, false), + Field::new("R.b", DataType::Text, false), + ]); + let join_schema = Schema::new( + left_schema + .fields() + .iter() + .chain(right_schema.fields()) + .cloned() + .collect(), + ); + + let plan = PhysicalPlan::Filter { + input: Box::new(PhysicalPlan::NestedLoopJoin { + left: Box::new(PhysicalPlan::SeqScan { + table: "L".to_string(), + schema: left_schema, + }), + right: Box::new(PhysicalPlan::SeqScan { + table: "R".to_string(), + schema: right_schema, + }), + schema: join_schema, + }), + predicate: LogicalExpr::And( + Box::new(LogicalExpr::Eq( + Box::new(LogicalExpr::Column("L.a".to_string())), + Box::new(LogicalExpr::Literal(Value::text("bob"))), + )), + Box::new(LogicalExpr::Eq( + Box::new(LogicalExpr::Column("L.b".to_string())), + Box::new(LogicalExpr::Column("R.a".to_string())), + )), + ), + }; + + let before = execute_physical(&plan, &TwoTable).unwrap(); + let after = execute_physical(&rewrite_physical(plan.clone()), &TwoTable).unwrap(); + assert_eq!(before.rows().len(), after.rows().len()); + assert_eq!(before.rows(), after.rows()); + } } diff --git a/src/planner/logical.rs b/src/planner/logical.rs index 44dabd3..099df56 100644 --- a/src/planner/logical.rs +++ b/src/planner/logical.rs @@ -1,4 +1,5 @@ use crate::relational::{Schema, Value}; +use crate::sql::ast::AggregateFunc; /// Sort direction for the logical `Sort` operator. #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -44,6 +45,17 @@ pub struct SortKey { pub direction: SortDirection, } +/// A single aggregate output in a logical `Aggregate` operator. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct AggregateExpr { + /// Output column name for this aggregate. + pub name: String, + /// Aggregate function to apply. + pub func: AggregateFunc, + /// Source column name for the aggregate input, or `None` for `COUNT(*)`. + pub arg: Option, +} + /// A logical plan in the current execution subset. #[derive(Debug, Clone, PartialEq, Eq)] pub enum LogicalPlan { @@ -60,6 +72,14 @@ pub enum LogicalPlan { input: Box, predicate: LogicalExpr, }, + /// Group rows by a list of columns and compute aggregates per group. + /// The output schema is `group_by` columns followed by aggregate outputs. + Aggregate { + input: Box, + group_by: Vec, + aggregates: Vec, + schema: Schema, + }, /// Sort rows by one or more output columns. Sort { input: Box, @@ -86,6 +106,7 @@ impl LogicalPlan { Self::Scan { schema, .. } => schema, Self::CrossJoin { schema, .. } => schema, Self::Filter { input, .. } => input.output_schema(), + Self::Aggregate { schema, .. } => schema, Self::Sort { schema, .. } => schema, Self::Project { schema, .. } => schema, Self::Limit { input, .. } => input.output_schema(), diff --git a/src/planner/sql.rs b/src/planner/sql.rs index d24401e..b2f66fd 100644 --- a/src/planner/sql.rs +++ b/src/planner/sql.rs @@ -4,11 +4,13 @@ use std::fmt; use crate::catalog::{CatalogError, PredicateCatalog}; use crate::planner::logical::{ - LogicalExpr, LogicalPlan, NamedExpr, SortDirection as LogicalSortDirection, SortKey, + AggregateExpr as PlanAggregateExpr, LogicalExpr, LogicalPlan, NamedExpr, + SortDirection as LogicalSortDirection, SortKey, }; use crate::relational::{DataType, Field, Schema, Value}; use crate::sql::ast::{ - BinaryOp, Expr, Literal, OrderByItem, Select, SelectItem, SortDirection, TableRef, + AggregateArg, AggregateFunc, BinaryOp, Expr, Literal, OrderByItem, Select, SelectItem, + SortDirection, TableRef, }; /// Errors returned when translating SQL AST into a logical plan. @@ -24,6 +26,14 @@ pub enum PlannerError { UnsupportedOrderBy, /// The parser or AST contains a wildcard mixed with other projection items. MixedWildcardProjection, + /// A `GROUP BY` expression is not a simple column reference. + UnsupportedGroupBy, + /// A projected column is neither aggregated nor present in `GROUP BY`. + ProjectionNotGrouped(String), + /// An aggregate expression appears in an unsupported position. + UnsupportedAggregate, + /// `COUNT(*)` was used with a non-count aggregate function. + StarArgNotAllowed, } impl fmt::Display for PlannerError { @@ -43,6 +53,18 @@ impl fmt::Display for PlannerError { "wildcard projections cannot be combined with other items" ) } + Self::UnsupportedGroupBy => { + write!(f, "only bare column references are supported in GROUP BY") + } + Self::ProjectionNotGrouped(name) => { + write!(f, "column `{}` is not aggregated and not in GROUP BY", name) + } + Self::UnsupportedAggregate => { + write!(f, "aggregate expressions are only allowed in SELECT items") + } + Self::StarArgNotAllowed => { + write!(f, "`*` is only allowed as the argument to COUNT") + } } } } @@ -54,7 +76,11 @@ impl Error for PlannerError { Self::UnknownColumn(_) | Self::DuplicateSourceName(_) | Self::UnsupportedOrderBy - | Self::MixedWildcardProjection => None, + | Self::MixedWildcardProjection + | Self::UnsupportedGroupBy + | Self::ProjectionNotGrouped(_) + | Self::UnsupportedAggregate + | Self::StarArgNotAllowed => None, } } } @@ -80,7 +106,15 @@ pub fn plan_select( }; } - if !is_wildcard_projection(&select.projection) { + let is_aggregate_query = !select.group_by.is_empty() + || select.projection.iter().any(|item| match item { + SelectItem::Expr { expr, .. } => contains_aggregate(expr), + SelectItem::Wildcard => false, + }); + + if is_aggregate_query { + plan = plan_aggregate(plan, &input_schema, select)?; + } else if !is_wildcard_projection(&select.projection) { let mut expressions = Vec::new(); let mut fields = Vec::new(); for (index, item) in select.projection.iter().enumerate() { @@ -122,6 +156,208 @@ pub fn plan_select( Ok(plan) } +fn contains_aggregate(expr: &Expr) -> bool { + match expr { + Expr::Aggregate { .. } => true, + Expr::Binary { left, right, .. } => contains_aggregate(left) || contains_aggregate(right), + Expr::Identifier(_) | Expr::Literal(_) => false, + } +} + +fn plan_aggregate( + input: LogicalPlan, + input_schema: &Schema, + select: &Select, +) -> Result { + // Resolve GROUP BY expressions to column names. + let mut group_by_cols = Vec::new(); + for expr in &select.group_by { + match expr { + Expr::Identifier(name) => { + let resolved = resolve_column_name(name, input_schema, &select.from)?; + group_by_cols.push(resolved); + } + _ => return Err(PlannerError::UnsupportedGroupBy), + } + } + + // Walk the projection, collecting aggregate expressions and verifying + // non-aggregate column references are in GROUP BY. + let mut aggregates: Vec = Vec::new(); + let mut projection_items: Vec<(String, ProjectionSource)> = Vec::new(); + + for (index, item) in select.projection.iter().enumerate() { + match item { + SelectItem::Wildcard => return Err(PlannerError::MixedWildcardProjection), + SelectItem::Expr { expr, alias } => { + let output_name = alias + .clone() + .unwrap_or_else(|| default_projection_name(expr, index + 1)); + let source = plan_aggregate_projection( + expr, + input_schema, + select, + &group_by_cols, + &mut aggregates, + )?; + projection_items.push((output_name, source)); + } + } + } + + // Build the Aggregate node's output schema: group_by columns followed by + // aggregate outputs. + let mut agg_fields = Vec::new(); + for col in &group_by_cols { + let field_index = input_schema + .index_of(col) + .ok_or_else(|| PlannerError::UnknownColumn(col.clone()))?; + let field = &input_schema.fields()[field_index]; + agg_fields.push(Field::new( + col.clone(), + field.data_type().clone(), + field.nullable(), + )); + } + for agg in &aggregates { + let (dtype, nullable) = aggregate_output_type(agg, input_schema)?; + agg_fields.push(Field::new(agg.name.clone(), dtype, nullable)); + } + let agg_schema = Schema::new(agg_fields); + + let aggregate_plan = LogicalPlan::Aggregate { + input: Box::new(input), + group_by: group_by_cols.clone(), + aggregates, + schema: agg_schema.clone(), + }; + + // Build the final Project over the aggregate output. + let mut expressions = Vec::new(); + let mut fields = Vec::new(); + for (name, source) in projection_items { + let (expr, dtype, nullable) = match source { + ProjectionSource::GroupColumn(col) => { + let index = agg_schema + .index_of(&col) + .ok_or_else(|| PlannerError::UnknownColumn(col.clone()))?; + let field = &agg_schema.fields()[index]; + ( + LogicalExpr::Column(col), + field.data_type().clone(), + field.nullable(), + ) + } + ProjectionSource::AggregateColumn(col) => { + let index = agg_schema + .index_of(&col) + .ok_or_else(|| PlannerError::UnknownColumn(col.clone()))?; + let field = &agg_schema.fields()[index]; + ( + LogicalExpr::Column(col), + field.data_type().clone(), + field.nullable(), + ) + } + ProjectionSource::Literal(value) => { + let (dtype, nullable) = literal_metadata(&value); + (LogicalExpr::Literal(value), dtype, nullable) + } + }; + expressions.push(NamedExpr { + name: name.clone(), + expr, + }); + fields.push(Field::new(name, dtype, nullable)); + } + + Ok(LogicalPlan::Project { + input: Box::new(aggregate_plan), + expressions, + schema: Schema::new(fields), + }) +} + +#[derive(Debug, Clone)] +enum ProjectionSource { + GroupColumn(String), + AggregateColumn(String), + Literal(Value), +} + +fn plan_aggregate_projection( + expr: &Expr, + input_schema: &Schema, + select: &Select, + group_by_cols: &[String], + aggregates: &mut Vec, +) -> Result { + match expr { + Expr::Aggregate { func, arg } => { + let arg_col = match arg { + AggregateArg::Star => { + if !matches!(func, AggregateFunc::Count) { + return Err(PlannerError::StarArgNotAllowed); + } + None + } + AggregateArg::Expr(inner) => match inner.as_ref() { + Expr::Identifier(name) => { + Some(resolve_column_name(name, input_schema, &select.from)?) + } + _ => return Err(PlannerError::UnsupportedAggregate), + }, + }; + let synthetic_name = format!("__agg_{}", aggregates.len()); + aggregates.push(PlanAggregateExpr { + name: synthetic_name.clone(), + func: *func, + arg: arg_col, + }); + Ok(ProjectionSource::AggregateColumn(synthetic_name)) + } + Expr::Identifier(name) => { + let resolved = resolve_column_name(name, input_schema, &select.from)?; + if !group_by_cols.contains(&resolved) { + return Err(PlannerError::ProjectionNotGrouped(name.clone())); + } + Ok(ProjectionSource::GroupColumn(resolved)) + } + Expr::Literal(literal) => Ok(ProjectionSource::Literal(plan_literal(literal))), + Expr::Binary { .. } => Err(PlannerError::UnsupportedAggregate), + } +} + +fn aggregate_output_type( + agg: &PlanAggregateExpr, + input_schema: &Schema, +) -> Result<(DataType, bool), PlannerError> { + match agg.func { + AggregateFunc::Count => Ok((DataType::Integer, false)), + AggregateFunc::Sum | AggregateFunc::Avg => Ok((DataType::Integer, true)), + AggregateFunc::Min | AggregateFunc::Max => { + if let Some(col) = &agg.arg { + let index = input_schema + .index_of(col) + .ok_or_else(|| PlannerError::UnknownColumn(col.clone()))?; + let field = &input_schema.fields()[index]; + Ok((field.data_type().clone(), true)) + } else { + Ok((DataType::Text, true)) + } + } + } +} + +fn literal_metadata(value: &Value) -> (DataType, bool) { + match value { + Value::Text(_) => (DataType::Text, false), + Value::Integer(_) => (DataType::Integer, false), + Value::Boolean(_) => (DataType::Boolean, false), + Value::Null => (DataType::Text, true), + } +} + fn is_wildcard_projection(items: &[SelectItem]) -> bool { matches!(items, [SelectItem::Wildcard]) } @@ -202,6 +438,7 @@ fn plan_expr( Box::new(plan_expr(right, schema, tables)?), )), }, + Expr::Aggregate { .. } => Err(PlannerError::UnsupportedAggregate), } } @@ -264,6 +501,7 @@ fn projection_metadata( Expr::Literal(Literal::Integer(_)) => Ok((DataType::Integer, false)), Expr::Literal(Literal::Null) => Ok((DataType::Text, true)), Expr::Binary { .. } => Ok((DataType::Boolean, true)), + Expr::Aggregate { .. } => Err(PlannerError::UnsupportedAggregate), } } @@ -290,6 +528,23 @@ fn resolve_column_name( fn default_projection_name(expr: &Expr, ordinal: usize) -> String { match expr { + Expr::Aggregate { func, arg } => { + let func_name = match func { + AggregateFunc::Count => "COUNT", + AggregateFunc::Sum => "SUM", + AggregateFunc::Min => "MIN", + AggregateFunc::Max => "MAX", + AggregateFunc::Avg => "AVG", + }; + let arg_str = match arg { + AggregateArg::Star => "*".to_string(), + AggregateArg::Expr(inner) => match inner.as_ref() { + Expr::Identifier(name) => name.clone(), + _ => format!("expr{}", ordinal), + }, + }; + format!("{}({})", func_name, arg_str) + } Expr::Identifier(name) => name.clone(), Expr::Literal(_) | Expr::Binary { .. } => format!("expr{}", ordinal), } @@ -566,6 +821,7 @@ mod tests { alias: None, }], selection: None, + group_by: Vec::new(), order_by: Vec::new(), limit: None, }; diff --git a/src/sql/ast.rs b/src/sql/ast.rs index 05a27bc..9a5cb5b 100644 --- a/src/sql/ast.rs +++ b/src/sql/ast.rs @@ -1,4 +1,5 @@ -/// A parsed `SELECT-FROM-WHERE-ORDER BY-LIMIT` statement in the current SQL subset. +/// A parsed `SELECT-FROM-WHERE-GROUP BY-ORDER BY-LIMIT` statement in the +/// current SQL subset. #[derive(Debug, Clone, PartialEq, Eq)] pub struct Select { /// Output expressions requested by the query. @@ -7,6 +8,8 @@ pub struct Select { pub from: Vec, /// Optional filter predicate. pub selection: Option, + /// Grouping columns. Empty means no `GROUP BY` clause. + pub group_by: Vec, /// Optional output ordering. pub order_by: Vec, /// Optional row limit. @@ -53,6 +56,36 @@ pub enum Expr { op: BinaryOp, right: Box, }, + /// An aggregate function applied to an argument. + Aggregate { + func: AggregateFunc, + arg: AggregateArg, + }, +} + +/// An aggregate function in the current SQL subset. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AggregateFunc { + /// Row count (with `*`) or count of non-null values (with a column). + Count, + /// Sum of integer values. + Sum, + /// Minimum value. + Min, + /// Maximum value. + Max, + /// Arithmetic mean of integer values. + Avg, +} + +/// The argument to an aggregate function: either `*` (only valid for +/// `COUNT`) or an expression. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AggregateArg { + /// `COUNT(*)` style argument. + Star, + /// An expression argument such as `SUM(col)`. + Expr(Box), } /// A SQL literal in the current subset. diff --git a/src/sql/parser.rs b/src/sql/parser.rs index eb7b89a..0fc7bb0 100644 --- a/src/sql/parser.rs +++ b/src/sql/parser.rs @@ -2,7 +2,8 @@ use std::error::Error; use std::fmt; use super::ast::{ - BinaryOp, Expr, Literal, OrderByItem, Select, SelectItem, SortDirection, TableRef, + AggregateArg, AggregateFunc, BinaryOp, Expr, Literal, OrderByItem, Select, SelectItem, + SortDirection, TableRef, }; /// Errors returned by the minimal SQL parser. @@ -50,11 +51,14 @@ enum Token { Desc, Null, Limit, + Group, Identifier(String), String(String), Integer(usize), Star, Comma, + LParen, + RParen, Eq, Ne, } @@ -87,6 +91,13 @@ impl Parser { } else { None }; + let group_by = if self.peek() == Some(&Token::Group) { + self.index += 1; + self.expect_keyword(Token::By, "BY")?; + self.parse_group_by()? + } else { + Vec::new() + }; let order_by = if self.peek() == Some(&Token::Order) { self.index += 1; self.expect_keyword(Token::By, "BY")?; @@ -110,11 +121,25 @@ impl Parser { projection, from, selection, + group_by, order_by, limit, }) } + fn parse_group_by(&mut self) -> Result, ParseError> { + let mut items = Vec::new(); + loop { + items.push(self.parse_operand()?); + if self.peek() == Some(&Token::Comma) { + self.index += 1; + continue; + } + break; + } + Ok(items) + } + fn parse_projection(&mut self) -> Result, ParseError> { let mut items = Vec::new(); @@ -262,7 +287,13 @@ impl Parser { fn parse_operand(&mut self) -> Result { match self.next().ok_or(ParseError::UnexpectedEnd)? { - Token::Identifier(name) => Ok(Expr::Identifier(name)), + Token::Identifier(name) => { + if self.peek() == Some(&Token::LParen) { + self.parse_function_call(name) + } else { + Ok(Expr::Identifier(name)) + } + } Token::String(value) => Ok(Expr::Literal(Literal::String(value))), Token::Integer(n) => Ok(Expr::Literal(Literal::Integer(n as i64))), Token::Null => Ok(Expr::Literal(Literal::Null)), @@ -270,6 +301,31 @@ impl Parser { } } + fn parse_function_call(&mut self, name: String) -> Result { + self.expect_keyword(Token::LParen, "(")?; + let func = match name.to_ascii_uppercase().as_str() { + "COUNT" => AggregateFunc::Count, + "SUM" => AggregateFunc::Sum, + "MIN" => AggregateFunc::Min, + "MAX" => AggregateFunc::Max, + "AVG" => AggregateFunc::Avg, + _ => return Err(ParseError::UnexpectedToken(name)), + }; + + let arg = if self.peek() == Some(&Token::Star) { + self.index += 1; + if !matches!(func, AggregateFunc::Count) { + return Err(ParseError::UnexpectedToken("*".to_string())); + } + AggregateArg::Star + } else { + AggregateArg::Expr(Box::new(self.parse_operand()?)) + }; + + self.expect_keyword(Token::RParen, ")")?; + Ok(Expr::Aggregate { func, arg }) + } + fn expect_keyword(&mut self, token: Token, label: &'static str) -> Result<(), ParseError> { let next = self.next().ok_or(ParseError::UnexpectedEnd)?; if next == token { @@ -325,6 +381,14 @@ fn tokenize(input: &str) -> Result, ParseError> { chars.next(); tokens.push(Token::Comma); } + '(' => { + chars.next(); + tokens.push(Token::LParen); + } + ')' => { + chars.next(); + tokens.push(Token::RParen); + } '!' => { chars.next(); if chars.peek() == Some(&'=') { @@ -367,6 +431,7 @@ fn tokenize(input: &str) -> Result, ParseError> { "DESC" => Token::Desc, "NULL" => Token::Null, "LIMIT" => Token::Limit, + "GROUP" => Token::Group, _ => Token::Identifier(ident), }; tokens.push(token); @@ -462,6 +527,9 @@ fn render_token(token: &Token) -> String { Token::String(value) => format!("'{}'", value), Token::Star => "*".to_string(), Token::Comma => ",".to_string(), + Token::LParen => "(".to_string(), + Token::RParen => ")".to_string(), + Token::Group => "GROUP".to_string(), Token::Eq => "=".to_string(), Token::Ne => "!=".to_string(), } diff --git a/tests/sql_pipeline_tests.rs b/tests/sql_pipeline_tests.rs index 2bad1c2..6ab8182 100644 --- a/tests/sql_pipeline_tests.rs +++ b/tests/sql_pipeline_tests.rs @@ -360,3 +360,156 @@ fn execute_with_table_store_scans_in_memory_rows() { assert_eq!(result.rows().len(), 1); assert_eq!(format!("{}", result.rows()[0].values()[0]), "bob"); } + +#[test] +fn count_star_no_group_by() { + let instance = parent_instance(); + let catalog = PredicateCatalog::from_instance(&instance).unwrap(); + let select = parse_select("SELECT COUNT(*) FROM Parent").unwrap(); + let plan = plan_select(&select, &catalog).unwrap(); + let result = execute(&plan, &instance).unwrap(); + + assert_eq!(result.rows().len(), 1); + assert_eq!(format!("{}", result.rows()[0].values()[0]), "2"); +} + +#[test] +fn count_star_group_by_one_column() { + use query_engine::execution::TableStore; + use query_engine::relational::{DataType, Field, Row, Schema, Value}; + + let schema = Schema::new(vec![ + Field::new("dept", DataType::Text, false), + Field::new("name", DataType::Text, false), + ]); + + let mut store = TableStore::new(); + store.insert( + "Emp", + schema.clone(), + vec![ + Row::new(vec![Value::text("eng"), Value::text("alice")]), + Row::new(vec![Value::text("eng"), Value::text("bob")]), + Row::new(vec![Value::text("sales"), Value::text("carol")]), + ], + ); + + let mut catalog = PredicateCatalog::new(); + catalog.register_table("Emp", schema); + + let select = parse_select("SELECT dept, COUNT(*) FROM Emp GROUP BY dept").unwrap(); + let plan = plan_select(&select, &catalog).unwrap(); + let result = execute(&plan, &store).unwrap(); + + assert_eq!(result.rows().len(), 2); + let mut rows: Vec<(String, String)> = result + .rows() + .iter() + .map(|row| { + ( + format!("{}", row.values()[0]), + format!("{}", row.values()[1]), + ) + }) + .collect(); + rows.sort(); + assert_eq!( + rows, + vec![ + ("eng".to_string(), "2".to_string()), + ("sales".to_string(), "1".to_string()), + ] + ); +} + +#[test] +fn sum_min_max_avg_over_integer_column() { + use query_engine::execution::TableStore; + use query_engine::relational::{DataType, Field, Row, Schema, Value}; + + let schema = Schema::new(vec![ + Field::new("dept", DataType::Text, false), + Field::new("salary", DataType::Integer, false), + ]); + + let mut store = TableStore::new(); + store.insert( + "Emp", + schema.clone(), + vec![ + Row::new(vec![Value::text("eng"), Value::Integer(100)]), + Row::new(vec![Value::text("eng"), Value::Integer(200)]), + Row::new(vec![Value::text("sales"), Value::Integer(50)]), + ], + ); + + let mut catalog = PredicateCatalog::new(); + catalog.register_table("Emp", schema); + + let select = parse_select( + "SELECT dept, SUM(salary), MIN(salary), MAX(salary), AVG(salary) FROM Emp GROUP BY dept", + ) + .unwrap(); + let plan = plan_select(&select, &catalog).unwrap(); + let result = execute(&plan, &store).unwrap(); + assert_eq!(result.rows().len(), 2); + + let mut rows: Vec<(String, String, String, String, String)> = result + .rows() + .iter() + .map(|row| { + ( + format!("{}", row.values()[0]), + format!("{}", row.values()[1]), + format!("{}", row.values()[2]), + format!("{}", row.values()[3]), + format!("{}", row.values()[4]), + ) + }) + .collect(); + rows.sort(); + assert_eq!( + rows[0], + ( + "eng".to_string(), + "300".to_string(), + "100".to_string(), + "200".to_string(), + "150".to_string(), + ) + ); + assert_eq!( + rows[1], + ( + "sales".to_string(), + "50".to_string(), + "50".to_string(), + "50".to_string(), + "50".to_string(), + ) + ); +} + +#[test] +fn projection_not_in_group_by_errors() { + use query_engine::execution::TableStore; + use query_engine::relational::{DataType, Field, Schema}; + + let schema = Schema::new(vec![ + Field::new("dept", DataType::Text, false), + Field::new("name", DataType::Text, false), + ]); + + let mut store = TableStore::new(); + store.insert("Emp", schema.clone(), Vec::new()); + + let mut catalog = PredicateCatalog::new(); + catalog.register_table("Emp", schema); + + let select = parse_select("SELECT dept, name FROM Emp GROUP BY dept").unwrap(); + let err = plan_select(&select, &catalog).unwrap_err(); + assert!( + err.to_string() + .contains("not aggregated and not in GROUP BY") + ); +}