Add filter push-down, SQL aggregation/GROUP BY

This commit is contained in:
Hassan Abedi 2026-04-13 13:14:01 +02:00
parent 57a6eaaef6
commit e63a47f7bd
13 changed files with 1328 additions and 19 deletions

View File

@ -71,7 +71,7 @@ Quick examples:
- The chase engine should remain largely stateless; pass execution state explicitly.
- New chase variants should be composable with existing infrastructure.
- Existential variables generate labeled nulls (`Term::Null`).
- The current SQL support is intentionally narrow: `SELECT-FROM-WHERE-ORDER BY-LIMIT` over predicate-backed tables; equality and inequality predicates combined with `AND` and `OR`; comma-join style multi-table queries; table aliases; ordering by output-column names; integer and string literals.
- The current SQL support is intentionally narrow: `SELECT-FROM-WHERE-GROUP BY-ORDER BY-LIMIT` over predicate-backed tables; equality and inequality predicates combined with `AND` and `OR`; comma-join style multi-table queries; table aliases; ordering by output-column names; integer and string literals; `COUNT`, `SUM`, `MIN`, `MAX`, and `AVG` aggregates with optional `GROUP BY`.
- Stable SQL column names come from explicit catalog registration or the frontend `schema ...` command, including for empty tables; otherwise the default names are positional such as `c0` and `c1`.
- Single-table SQL queries may use the table name as a qualifier when no alias is present.
- Do not describe unsupported SQL features such as aggregates, grouping, or arbitrary expressions as implemented.

View File

@ -34,6 +34,15 @@ binaries = []
[dev-dependencies]
proptest = "1.6"
criterion = { version = "0.5", default-features = false }
[[bench]]
name = "chase"
harness = false
[[bench]]
name = "sql"
harness = false
[profile.release]
strip = "debuginfo"

View File

@ -16,7 +16,8 @@ execution boundaries.
- Script, REPL, and local web UI for experimentation
- Relational schema, catalog, logical-plan, and execution scaffolding
- Physical operator scaffolding with a small rule-based rewrite layer
- A minimal SQL slice for `SELECT-FROM-WHERE-ORDER BY-LIMIT` queries over predicate-backed tables
- A minimal SQL slice for `SELECT-FROM-WHERE-GROUP BY-ORDER BY-LIMIT` queries over predicate-backed tables, including `COUNT`, `SUM`, `MIN`, `MAX`, and `AVG` aggregates
- Filter push-down across joins in the physical rewrite pass
### Architecture
@ -143,6 +144,8 @@ WHERE Parent.child = Ancestor.parent
SELECT p.parent, q.child
FROM Parent AS p, Parent AS q
WHERE p.child = q.parent
SELECT COUNT(*) FROM Parent
SELECT dept, COUNT(*), SUM(salary) FROM Emp GROUP BY dept
```
In the REPL or script runner, use the `sql` command and end the statement with
@ -191,7 +194,7 @@ Current limits:
- `ORDER BY` supports output-column ordering with `ASC`/`DESC`
- `LIMIT` restricts the number of output rows
- literals include strings, integers, and `NULL`
- no aggregates
- aggregates: `COUNT(*)`, `COUNT(col)`, `SUM`, `MIN`, `MAX`, `AVG`, with optional `GROUP BY`
- projection aliases only via `AS`
Runnable SQL examples:
@ -212,6 +215,12 @@ cargo clippy --all-targets --all-features -- -D warnings
cargo fmt --check
```
Benchmarks live under `benches/` and can be run with:
```bash
cargo bench
```
### Notes
This repository is still centered on a rule-engine core. The new SQL-related

View File

@ -35,6 +35,8 @@ This document tracks the current state and next steps for the repository.
- [x] `!=`/`<>` inequality and `OR` disjunction in `WHERE` clauses
- [x] `LIMIT` clause for restricting output row count
- [x] Integer literal and `DataType::Integer` support
- [x] `COUNT`, `SUM`, `MIN`, `MAX`, `AVG` aggregates with `GROUP BY`
- [x] Filter push-down rewrite across `NestedLoopJoin` in the physical layer
### Near-Term Cleanup
@ -81,7 +83,7 @@ This document tracks the current state and next steps for the repository.
- [ ] Negative constraints
- [ ] Stratified negation in rule bodies
- [ ] Disjunctive heads
- [ ] Aggregation support in rule evaluation
- [ ] Aggregation support in rule evaluation (available in SQL; not yet exposed to chase rules)
- [x] Semi-naive evaluation
- [ ] Termination analysis helpers
@ -108,6 +110,6 @@ This document tracks the current state and next steps for the repository.
- [x] Property-based tests
- [x] Regression tests
- [x] Initial SQL pipeline tests
- [ ] Benchmark coverage
- [x] Benchmark coverage (chase and SQL pipeline via `cargo bench`)
- [ ] Snapshot-style frontend tests
- [ ] More planner/executor tests as those layers are added

98
benches/chase.rs Normal file
View File

@ -0,0 +1,98 @@
//! Benchmarks for the chase subsystem.
//!
//! These are designed to retroactively validate the semi-naive and Skolem
//! work and catch future regressions. Each workload runs several chase
//! variants over the same input so relative numbers are meaningful.
use criterion::{BatchSize, Criterion, criterion_group, criterion_main};
use query_engine::chase::rule::RuleBuilder;
use query_engine::chase::{ChaseConfig, ChaseVariant, Rule, chase_with_config};
use query_engine::{Atom, Instance, Term};
fn chain_edges(n: usize) -> Instance {
(0..n)
.map(|i| {
Atom::new(
"Edge",
vec![
Term::constant(format!("n{}", i)),
Term::constant(format!("n{}", i + 1)),
],
)
})
.collect()
}
fn transitive_closure_rules() -> Vec<Rule> {
let edge_to_path = RuleBuilder::new()
.when("Edge", vec![Term::var("X"), Term::var("Y")])
.then("Path", vec![Term::var("X"), Term::var("Y")])
.build();
let extend_path = RuleBuilder::new()
.when("Path", vec![Term::var("X"), Term::var("Y")])
.when("Edge", vec![Term::var("Y"), Term::var("Z")])
.then("Path", vec![Term::var("X"), Term::var("Z")])
.build();
vec![edge_to_path, extend_path]
}
fn bench_transitive_closure(c: &mut Criterion) {
let mut group = c.benchmark_group("transitive_closure_chain_20");
let instance = chain_edges(20);
let rules = transitive_closure_rules();
for (label, variant, semi) in [
("restricted_naive", ChaseVariant::Restricted, false),
("restricted_semi_naive", ChaseVariant::Restricted, true),
("standard_naive", ChaseVariant::Standard, false),
("standard_semi_naive", ChaseVariant::Standard, true),
] {
let config = ChaseConfig {
variant,
semi_naive: semi,
..Default::default()
};
group.bench_function(label, |b| {
b.iter_batched(
|| instance.clone(),
|inst| chase_with_config(inst, &rules, config.clone()),
BatchSize::SmallInput,
);
});
}
group.finish();
}
fn bench_existentials(c: &mut Criterion) {
let mut group = c.benchmark_group("existentials_50_people");
let instance: Instance = (0..50)
.map(|i| Atom::new("Person", vec![Term::constant(format!("p{}", i))]))
.collect();
let rule = RuleBuilder::new()
.when("Person", vec![Term::var("X")])
.then("HasId", vec![Term::var("X"), Term::var("Y")])
.build();
let rules = vec![rule];
for (label, variant) in [
("restricted", ChaseVariant::Restricted),
("skolem", ChaseVariant::Skolem),
] {
let config = ChaseConfig {
variant,
semi_naive: false,
..Default::default()
};
group.bench_function(label, |b| {
b.iter_batched(
|| instance.clone(),
|inst| chase_with_config(inst, &rules, config.clone()),
BatchSize::SmallInput,
);
});
}
group.finish();
}
criterion_group!(benches, bench_transitive_closure, bench_existentials);
criterion_main!(benches);

115
benches/sql.rs Normal file
View File

@ -0,0 +1,115 @@
//! Benchmarks for the SQL pipeline.
//!
//! Focus areas: scans, single-column filters, multi-table joins with and
//! without filter push-down, and GROUP BY aggregation.
use criterion::{BatchSize, Criterion, criterion_group, criterion_main};
use query_engine::catalog::PredicateCatalog;
use query_engine::execution::TableStore;
use query_engine::execution::execute;
use query_engine::execution::physical::{execute_physical, plan_physical, rewrite_physical};
use query_engine::planner::sql::plan_select;
use query_engine::relational::{DataType, Field, Row, Schema, Value};
use query_engine::sql::parser::parse_select;
use query_engine::{Atom, Instance, Term};
fn edges_instance(n: usize) -> Instance {
(0..n)
.map(|i| {
Atom::new(
"L",
vec![
Term::constant(format!("a{}", i)),
Term::constant(format!("b{}", i)),
],
)
})
.chain((0..n).map(|i| {
Atom::new(
"R",
vec![
Term::constant(format!("b{}", i)),
Term::constant(format!("c{}", i)),
],
)
}))
.collect()
}
fn bench_filter_pushdown_join(c: &mut Criterion) {
let mut group = c.benchmark_group("filter_pushdown_join_100");
let instance = edges_instance(100);
let mut catalog = PredicateCatalog::from_instance(&instance).unwrap();
catalog.rename_columns("L", ["a", "b"]).unwrap();
catalog.rename_columns("R", ["b", "c"]).unwrap();
let select = parse_select("SELECT L.a, R.c FROM L, R WHERE L.b = R.b AND L.a = 'a42'").unwrap();
let logical = plan_select(&select, &catalog).unwrap();
group.bench_function("logical_direct_execute", |b| {
b.iter(|| execute(&logical, &instance).unwrap());
});
let physical_raw = plan_physical(&logical);
group.bench_function("physical_no_rewrite", |b| {
b.iter(|| execute_physical(&physical_raw, &instance).unwrap());
});
let physical_rewritten = rewrite_physical(plan_physical(&logical));
group.bench_function("physical_with_pushdown", |b| {
b.iter(|| execute_physical(&physical_rewritten, &instance).unwrap());
});
group.finish();
}
fn bench_group_by_aggregation(c: &mut Criterion) {
let mut group = c.benchmark_group("group_by_aggregation_1000");
let schema = Schema::new(vec![
Field::new("dept", DataType::Text, false),
Field::new("salary", DataType::Integer, false),
]);
let mut store = TableStore::new();
let rows: Vec<Row> = (0..1000)
.map(|i| {
let dept = format!("d{}", i % 10);
Row::new(vec![Value::text(dept), Value::Integer((i as i64) * 10)])
})
.collect();
store.insert("Emp", schema.clone(), rows);
let mut catalog = PredicateCatalog::new();
catalog.register_table("Emp", schema);
let select =
parse_select("SELECT dept, COUNT(*), SUM(salary), AVG(salary) FROM Emp GROUP BY dept")
.unwrap();
let logical = plan_select(&select, &catalog).unwrap();
group.bench_function("logical_direct", |b| {
b.iter_batched(
|| (),
|_| execute(&logical, &store).unwrap(),
BatchSize::SmallInput,
);
});
let physical = rewrite_physical(plan_physical(&logical));
group.bench_function("physical", |b| {
b.iter_batched(
|| (),
|_| execute_physical(&physical, &store).unwrap(),
BatchSize::SmallInput,
);
});
group.finish();
}
criterion_group!(
benches,
bench_filter_pushdown_join,
bench_group_by_aggregation
);
criterion_main!(benches);

View File

@ -12,8 +12,11 @@ use std::error::Error;
use std::fmt;
use crate::chase::{Instance, Term};
use crate::planner::logical::{LogicalExpr, LogicalPlan, SortDirection, SortKey};
use crate::planner::logical::{
AggregateExpr as PlanAggregateExpr, LogicalExpr, LogicalPlan, SortDirection, SortKey,
};
use crate::relational::{ResultSet, Row, Schema, Value};
use crate::sql::ast::AggregateFunc;
pub use physical::{
NamedPhysicalExpr, PhysicalPlan, execute_physical, plan_physical, rewrite_physical,
@ -133,6 +136,195 @@ pub fn execute(plan: &LogicalPlan, source: &dyn DataSource) -> Result<ResultSet,
let rows = result.rows().iter().take(*count).cloned().collect();
Ok(ResultSet::new(result.schema().clone(), rows))
}
LogicalPlan::Aggregate {
input,
group_by,
aggregates,
schema,
} => {
let result = execute(input, source)?;
let rows = compute_aggregate(result.rows(), result.schema(), group_by, aggregates)?;
Ok(ResultSet::new(schema.clone(), rows))
}
}
}
/// Evaluate group-by + aggregates over a row set, returning one output row
/// per distinct group key. The output row layout is: group_by column values
/// followed by aggregate output values.
pub(crate) fn compute_aggregate(
rows: &[Row],
input_schema: &Schema,
group_by: &[String],
aggregates: &[PlanAggregateExpr],
) -> Result<Vec<Row>, ExecutionError> {
let group_indexes = group_by
.iter()
.map(|name| {
input_schema
.index_of(name)
.ok_or_else(|| ExecutionError::UnknownColumn(name.clone()))
})
.collect::<Result<Vec<_>, _>>()?;
// Each aggregate holds an optional input column index (None means COUNT(*)).
let agg_indexes = aggregates
.iter()
.map(|agg| {
agg.arg
.as_ref()
.map(|col| {
input_schema
.index_of(col)
.ok_or_else(|| ExecutionError::UnknownColumn(col.clone()))
})
.transpose()
})
.collect::<Result<Vec<Option<usize>>, _>>()?;
// Preserve first-seen group order so single-group output is deterministic.
let mut order: Vec<Vec<Value>> = Vec::new();
let mut groups: std::collections::HashMap<Vec<Value>, Vec<AggregateState>> =
std::collections::HashMap::new();
for row in rows {
let key: Vec<Value> = group_indexes
.iter()
.map(|i| row.get(*i).cloned().unwrap_or(Value::Null))
.collect();
let states = groups.entry(key.clone()).or_insert_with(|| {
order.push(key.clone());
aggregates
.iter()
.map(|agg| AggregateState::new(agg.func))
.collect()
});
for (state, index_opt) in states.iter_mut().zip(agg_indexes.iter()) {
let value = match index_opt {
Some(i) => row.get(*i).cloned().unwrap_or(Value::Null),
None => Value::Null, // COUNT(*) observes each row
};
state.observe(&value, index_opt.is_none());
}
}
// If the user wrote an aggregate with no GROUP BY and no input rows, we
// still need one output row (all-null plus zero counts).
if rows.is_empty() && group_by.is_empty() && !aggregates.is_empty() {
order.push(Vec::new());
groups.insert(
Vec::new(),
aggregates
.iter()
.map(|agg| AggregateState::new(agg.func))
.collect(),
);
}
let mut out_rows = Vec::new();
for key in order {
let states = groups.remove(&key).unwrap_or_default();
let mut values = key;
for state in &states {
values.push(state.finalize());
}
out_rows.push(Row::new(values));
}
Ok(out_rows)
}
#[derive(Debug)]
pub(crate) enum AggregateState {
Count(i64),
Sum(Option<i64>),
Min(Option<Value>),
Max(Option<Value>),
Avg { sum: i64, count: i64 },
}
impl AggregateState {
pub(crate) fn new(func: AggregateFunc) -> Self {
match func {
AggregateFunc::Count => Self::Count(0),
AggregateFunc::Sum => Self::Sum(None),
AggregateFunc::Min => Self::Min(None),
AggregateFunc::Max => Self::Max(None),
AggregateFunc::Avg => Self::Avg { sum: 0, count: 0 },
}
}
pub(crate) fn observe(&mut self, value: &Value, is_count_star: bool) {
match self {
Self::Count(c) => {
if is_count_star || !matches!(value, Value::Null) {
*c += 1;
}
}
Self::Sum(total) => {
if let Value::Integer(n) = value {
*total = Some(total.unwrap_or(0) + n);
}
}
Self::Min(current) => {
if !matches!(value, Value::Null) {
match current {
None => *current = Some(value.clone()),
Some(existing) => {
if compare_values_for_agg(value, existing) == std::cmp::Ordering::Less {
*existing = value.clone();
}
}
}
}
}
Self::Max(current) => {
if !matches!(value, Value::Null) {
match current {
None => *current = Some(value.clone()),
Some(existing) => {
if compare_values_for_agg(value, existing)
== std::cmp::Ordering::Greater
{
*existing = value.clone();
}
}
}
}
}
Self::Avg { sum, count } => {
if let Value::Integer(n) = value {
*sum += n;
*count += 1;
}
}
}
}
pub(crate) fn finalize(&self) -> Value {
match self {
Self::Count(c) => Value::Integer(*c),
Self::Sum(total) => total.map(Value::Integer).unwrap_or(Value::Null),
Self::Min(v) | Self::Max(v) => v.clone().unwrap_or(Value::Null),
Self::Avg { sum, count } => {
if *count == 0 {
Value::Null
} else {
Value::Integer(sum / count)
}
}
}
}
}
fn compare_values_for_agg(left: &Value, right: &Value) -> std::cmp::Ordering {
match (left, right) {
(Value::Integer(a), Value::Integer(b)) => a.cmp(b),
(Value::Text(a), Value::Text(b)) => a.cmp(b),
(Value::Boolean(a), Value::Boolean(b)) => a.cmp(b),
_ => std::cmp::Ordering::Equal,
}
}

View File

@ -15,10 +15,12 @@
use std::cmp::Ordering;
use crate::planner::logical::{LogicalExpr, LogicalPlan, SortDirection, SortKey};
use crate::planner::logical::{
AggregateExpr as PlanAggregateExpr, LogicalExpr, LogicalPlan, SortDirection, SortKey,
};
use crate::relational::{ResultSet, Row, Schema, Value};
use super::{DataSource, ExecutionError};
use super::{DataSource, ExecutionError, compute_aggregate};
/// A physical plan node in the current execution subset.
#[derive(Debug, Clone, PartialEq, Eq)]
@ -54,6 +56,13 @@ pub enum PhysicalPlan {
input: Box<PhysicalPlan>,
count: usize,
},
/// Compute aggregates per group key using an in-memory hash map.
HashAggregate {
input: Box<PhysicalPlan>,
group_by: Vec<String>,
aggregates: Vec<PlanAggregateExpr>,
schema: Schema,
},
}
/// A named physical expression in a projection.
@ -75,6 +84,7 @@ impl PhysicalPlan {
Self::Sort { schema, .. } => schema,
Self::Project { schema, .. } => schema,
Self::Limit { input, .. } => input.output_schema(),
Self::HashAggregate { schema, .. } => schema,
}
}
}
@ -131,19 +141,209 @@ pub fn plan_physical(plan: &LogicalPlan) -> PhysicalPlan {
input: Box::new(plan_physical(input)),
count: *count,
},
LogicalPlan::Aggregate {
input,
group_by,
aggregates,
schema,
} => PhysicalPlan::HashAggregate {
input: Box::new(plan_physical(input)),
group_by: group_by.clone(),
aggregates: aggregates.clone(),
schema: schema.clone(),
},
}
}
/// Apply rule-based rewrites to a physical plan.
///
/// Today the only rewrite is `combine_adjacent_limits`, which collapses
/// `Limit(Limit(child, n), m)` into `Limit(child, min(n, m))`. Future
/// rewrites belong here as additional functions composed in this entry
/// point.
/// Current rewrites:
/// - [`combine_adjacent_limits`] collapses `Limit(Limit(child, n), m)` into
/// `Limit(child, min(n, m))`.
/// - [`push_filter_below_join`] pushes conjuncts of a `Filter` below a
/// `NestedLoopJoin` when they reference only one side's columns, so the
/// join sees fewer rows.
pub fn rewrite_physical(plan: PhysicalPlan) -> PhysicalPlan {
let plan = push_filter_below_join(plan);
combine_adjacent_limits(plan)
}
/// Push conjuncts of a `Filter` below a `NestedLoopJoin` when each conjunct
/// references only columns from one side of the join. Conjuncts that mention
/// both sides remain above the join.
fn push_filter_below_join(plan: PhysicalPlan) -> PhysicalPlan {
match plan {
PhysicalPlan::Filter { input, predicate } => {
let pushed_input = push_filter_below_join(*input);
match pushed_input {
PhysicalPlan::NestedLoopJoin {
left,
right,
schema,
} => {
let left_cols: Vec<String> = left
.output_schema()
.fields()
.iter()
.map(|f| f.name().to_string())
.collect();
let right_cols: Vec<String> = right
.output_schema()
.fields()
.iter()
.map(|f| f.name().to_string())
.collect();
let mut left_conjuncts: Vec<LogicalExpr> = Vec::new();
let mut right_conjuncts: Vec<LogicalExpr> = Vec::new();
let mut remaining: Vec<LogicalExpr> = Vec::new();
for conjunct in split_conjuncts(predicate) {
let refs = collect_column_refs(&conjunct);
let all_left = refs.iter().all(|c| left_cols.contains(c));
let all_right = refs.iter().all(|c| right_cols.contains(c));
if !refs.is_empty() && all_left {
left_conjuncts.push(conjunct);
} else if !refs.is_empty() && all_right {
right_conjuncts.push(conjunct);
} else {
remaining.push(conjunct);
}
}
let left = if let Some(pred) = combine_conjuncts(left_conjuncts) {
Box::new(PhysicalPlan::Filter {
input: left,
predicate: pred,
})
} else {
left
};
let right = if let Some(pred) = combine_conjuncts(right_conjuncts) {
Box::new(PhysicalPlan::Filter {
input: right,
predicate: pred,
})
} else {
right
};
// Recurse so pushed filters below the join continue to
// push through deeper joins if any.
let left = Box::new(push_filter_below_join(*left));
let right = Box::new(push_filter_below_join(*right));
let joined = PhysicalPlan::NestedLoopJoin {
left,
right,
schema,
};
match combine_conjuncts(remaining) {
Some(pred) => PhysicalPlan::Filter {
input: Box::new(joined),
predicate: pred,
},
None => joined,
}
}
other => PhysicalPlan::Filter {
input: Box::new(other),
predicate,
},
}
}
PhysicalPlan::NestedLoopJoin {
left,
right,
schema,
} => PhysicalPlan::NestedLoopJoin {
left: Box::new(push_filter_below_join(*left)),
right: Box::new(push_filter_below_join(*right)),
schema,
},
PhysicalPlan::Sort {
input,
keys,
schema,
} => PhysicalPlan::Sort {
input: Box::new(push_filter_below_join(*input)),
keys,
schema,
},
PhysicalPlan::Project {
input,
expressions,
schema,
} => PhysicalPlan::Project {
input: Box::new(push_filter_below_join(*input)),
expressions,
schema,
},
PhysicalPlan::Limit { input, count } => PhysicalPlan::Limit {
input: Box::new(push_filter_below_join(*input)),
count,
},
PhysicalPlan::HashAggregate {
input,
group_by,
aggregates,
schema,
} => PhysicalPlan::HashAggregate {
input: Box::new(push_filter_below_join(*input)),
group_by,
aggregates,
schema,
},
leaf @ PhysicalPlan::SeqScan { .. } => leaf,
}
}
fn split_conjuncts(expr: LogicalExpr) -> Vec<LogicalExpr> {
let mut out = Vec::new();
let mut stack = vec![expr];
while let Some(node) = stack.pop() {
match node {
LogicalExpr::And(left, right) => {
stack.push(*right);
stack.push(*left);
}
other => out.push(other),
}
}
out
}
fn combine_conjuncts(mut conjuncts: Vec<LogicalExpr>) -> Option<LogicalExpr> {
if conjuncts.is_empty() {
return None;
}
let mut combined = conjuncts.remove(0);
for next in conjuncts {
combined = LogicalExpr::And(Box::new(combined), Box::new(next));
}
Some(combined)
}
fn collect_column_refs(expr: &LogicalExpr) -> Vec<String> {
let mut out = Vec::new();
fn walk(expr: &LogicalExpr, out: &mut Vec<String>) {
match expr {
LogicalExpr::Column(name) => out.push(name.clone()),
LogicalExpr::Literal(_) => {}
LogicalExpr::Eq(left, right)
| LogicalExpr::Ne(left, right)
| LogicalExpr::And(left, right)
| LogicalExpr::Or(left, right) => {
walk(left, out);
walk(right, out);
}
}
}
walk(expr, &mut out);
out
}
fn combine_adjacent_limits(plan: PhysicalPlan) -> PhysicalPlan {
match plan {
PhysicalPlan::Limit { input, count } => {
@ -193,6 +393,17 @@ fn combine_adjacent_limits(plan: PhysicalPlan) -> PhysicalPlan {
expressions,
schema,
},
PhysicalPlan::HashAggregate {
input,
group_by,
aggregates,
schema,
} => PhysicalPlan::HashAggregate {
input: Box::new(combine_adjacent_limits(*input)),
group_by,
aggregates,
schema,
},
leaf @ PhysicalPlan::SeqScan { .. } => leaf,
}
}
@ -265,6 +476,16 @@ pub fn execute_physical(
let rows = result.rows().iter().take(*count).cloned().collect();
Ok(ResultSet::new(result.schema().clone(), rows))
}
PhysicalPlan::HashAggregate {
input,
group_by,
aggregates,
schema,
} => {
let result = execute_physical(input, source)?;
let rows = compute_aggregate(result.rows(), result.schema(), group_by, aggregates)?;
Ok(ResultSet::new(schema.clone(), rows))
}
}
}
@ -467,4 +688,136 @@ mod tests {
assert_eq!(result.rows().len(), 1);
assert_eq!(result.rows()[0].values()[0], Value::text("alice"));
}
#[test]
fn rewrite_pushes_single_side_filter_below_join() {
let left_schema = Schema::new(vec![
Field::new("Parent.parent", DataType::Text, false),
Field::new("Parent.child", DataType::Text, false),
]);
let right_schema = Schema::new(vec![
Field::new("Ancestor.parent", DataType::Text, false),
Field::new("Ancestor.child", DataType::Text, false),
]);
let join_schema = Schema::new(
left_schema
.fields()
.iter()
.chain(right_schema.fields())
.cloned()
.collect(),
);
// Filter(
// NestedLoopJoin(Parent, Ancestor),
// Parent.parent = 'alice' AND Parent.child = Ancestor.parent,
// )
let plan = PhysicalPlan::Filter {
input: Box::new(PhysicalPlan::NestedLoopJoin {
left: Box::new(PhysicalPlan::SeqScan {
table: "Parent".to_string(),
schema: left_schema,
}),
right: Box::new(PhysicalPlan::SeqScan {
table: "Ancestor".to_string(),
schema: right_schema,
}),
schema: join_schema,
}),
predicate: LogicalExpr::And(
Box::new(LogicalExpr::Eq(
Box::new(LogicalExpr::Column("Parent.parent".to_string())),
Box::new(LogicalExpr::Literal(Value::text("alice"))),
)),
Box::new(LogicalExpr::Eq(
Box::new(LogicalExpr::Column("Parent.child".to_string())),
Box::new(LogicalExpr::Column("Ancestor.parent".to_string())),
)),
),
};
let rewritten = rewrite_physical(plan);
match rewritten {
// The Parent.parent = 'alice' predicate should be pushed onto the
// left side; the join predicate should remain above.
PhysicalPlan::Filter { input, .. } => match *input {
PhysicalPlan::NestedLoopJoin { left, .. } => {
assert!(matches!(*left, PhysicalPlan::Filter { .. }));
}
other => panic!("expected NestedLoopJoin under Filter, got {:?}", other),
},
other => panic!("expected outer Filter, got {:?}", other),
}
}
#[test]
fn rewrite_push_filter_preserves_semantics_on_join() {
// Two three-row tables, join predicate filters down to one row.
// Push-down should not change the row count or values.
struct TwoTable;
impl DataSource for TwoTable {
fn scan(&self, table: &str, schema: &Schema) -> Result<ResultSet, ExecutionError> {
let rows = match table {
"L" => vec![
Row::new(vec![Value::text("alice"), Value::text("bob")]),
Row::new(vec![Value::text("bob"), Value::text("carol")]),
Row::new(vec![Value::text("carol"), Value::text("dave")]),
],
"R" => vec![
Row::new(vec![Value::text("bob"), Value::text("x")]),
Row::new(vec![Value::text("carol"), Value::text("y")]),
Row::new(vec![Value::text("eve"), Value::text("z")]),
],
_ => Vec::new(),
};
Ok(ResultSet::new(schema.clone(), rows))
}
}
let left_schema = Schema::new(vec![
Field::new("L.a", DataType::Text, false),
Field::new("L.b", DataType::Text, false),
]);
let right_schema = Schema::new(vec![
Field::new("R.a", DataType::Text, false),
Field::new("R.b", DataType::Text, false),
]);
let join_schema = Schema::new(
left_schema
.fields()
.iter()
.chain(right_schema.fields())
.cloned()
.collect(),
);
let plan = PhysicalPlan::Filter {
input: Box::new(PhysicalPlan::NestedLoopJoin {
left: Box::new(PhysicalPlan::SeqScan {
table: "L".to_string(),
schema: left_schema,
}),
right: Box::new(PhysicalPlan::SeqScan {
table: "R".to_string(),
schema: right_schema,
}),
schema: join_schema,
}),
predicate: LogicalExpr::And(
Box::new(LogicalExpr::Eq(
Box::new(LogicalExpr::Column("L.a".to_string())),
Box::new(LogicalExpr::Literal(Value::text("bob"))),
)),
Box::new(LogicalExpr::Eq(
Box::new(LogicalExpr::Column("L.b".to_string())),
Box::new(LogicalExpr::Column("R.a".to_string())),
)),
),
};
let before = execute_physical(&plan, &TwoTable).unwrap();
let after = execute_physical(&rewrite_physical(plan.clone()), &TwoTable).unwrap();
assert_eq!(before.rows().len(), after.rows().len());
assert_eq!(before.rows(), after.rows());
}
}

View File

@ -1,4 +1,5 @@
use crate::relational::{Schema, Value};
use crate::sql::ast::AggregateFunc;
/// Sort direction for the logical `Sort` operator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@ -44,6 +45,17 @@ pub struct SortKey {
pub direction: SortDirection,
}
/// A single aggregate output in a logical `Aggregate` operator.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AggregateExpr {
/// Output column name for this aggregate.
pub name: String,
/// Aggregate function to apply.
pub func: AggregateFunc,
/// Source column name for the aggregate input, or `None` for `COUNT(*)`.
pub arg: Option<String>,
}
/// A logical plan in the current execution subset.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LogicalPlan {
@ -60,6 +72,14 @@ pub enum LogicalPlan {
input: Box<LogicalPlan>,
predicate: LogicalExpr,
},
/// Group rows by a list of columns and compute aggregates per group.
/// The output schema is `group_by` columns followed by aggregate outputs.
Aggregate {
input: Box<LogicalPlan>,
group_by: Vec<String>,
aggregates: Vec<AggregateExpr>,
schema: Schema,
},
/// Sort rows by one or more output columns.
Sort {
input: Box<LogicalPlan>,
@ -86,6 +106,7 @@ impl LogicalPlan {
Self::Scan { schema, .. } => schema,
Self::CrossJoin { schema, .. } => schema,
Self::Filter { input, .. } => input.output_schema(),
Self::Aggregate { schema, .. } => schema,
Self::Sort { schema, .. } => schema,
Self::Project { schema, .. } => schema,
Self::Limit { input, .. } => input.output_schema(),

View File

@ -4,11 +4,13 @@ use std::fmt;
use crate::catalog::{CatalogError, PredicateCatalog};
use crate::planner::logical::{
LogicalExpr, LogicalPlan, NamedExpr, SortDirection as LogicalSortDirection, SortKey,
AggregateExpr as PlanAggregateExpr, LogicalExpr, LogicalPlan, NamedExpr,
SortDirection as LogicalSortDirection, SortKey,
};
use crate::relational::{DataType, Field, Schema, Value};
use crate::sql::ast::{
BinaryOp, Expr, Literal, OrderByItem, Select, SelectItem, SortDirection, TableRef,
AggregateArg, AggregateFunc, BinaryOp, Expr, Literal, OrderByItem, Select, SelectItem,
SortDirection, TableRef,
};
/// Errors returned when translating SQL AST into a logical plan.
@ -24,6 +26,14 @@ pub enum PlannerError {
UnsupportedOrderBy,
/// The parser or AST contains a wildcard mixed with other projection items.
MixedWildcardProjection,
/// A `GROUP BY` expression is not a simple column reference.
UnsupportedGroupBy,
/// A projected column is neither aggregated nor present in `GROUP BY`.
ProjectionNotGrouped(String),
/// An aggregate expression appears in an unsupported position.
UnsupportedAggregate,
/// `COUNT(*)` was used with a non-count aggregate function.
StarArgNotAllowed,
}
impl fmt::Display for PlannerError {
@ -43,6 +53,18 @@ impl fmt::Display for PlannerError {
"wildcard projections cannot be combined with other items"
)
}
Self::UnsupportedGroupBy => {
write!(f, "only bare column references are supported in GROUP BY")
}
Self::ProjectionNotGrouped(name) => {
write!(f, "column `{}` is not aggregated and not in GROUP BY", name)
}
Self::UnsupportedAggregate => {
write!(f, "aggregate expressions are only allowed in SELECT items")
}
Self::StarArgNotAllowed => {
write!(f, "`*` is only allowed as the argument to COUNT")
}
}
}
}
@ -54,7 +76,11 @@ impl Error for PlannerError {
Self::UnknownColumn(_)
| Self::DuplicateSourceName(_)
| Self::UnsupportedOrderBy
| Self::MixedWildcardProjection => None,
| Self::MixedWildcardProjection
| Self::UnsupportedGroupBy
| Self::ProjectionNotGrouped(_)
| Self::UnsupportedAggregate
| Self::StarArgNotAllowed => None,
}
}
}
@ -80,7 +106,15 @@ pub fn plan_select(
};
}
if !is_wildcard_projection(&select.projection) {
let is_aggregate_query = !select.group_by.is_empty()
|| select.projection.iter().any(|item| match item {
SelectItem::Expr { expr, .. } => contains_aggregate(expr),
SelectItem::Wildcard => false,
});
if is_aggregate_query {
plan = plan_aggregate(plan, &input_schema, select)?;
} else if !is_wildcard_projection(&select.projection) {
let mut expressions = Vec::new();
let mut fields = Vec::new();
for (index, item) in select.projection.iter().enumerate() {
@ -122,6 +156,208 @@ pub fn plan_select(
Ok(plan)
}
fn contains_aggregate(expr: &Expr) -> bool {
match expr {
Expr::Aggregate { .. } => true,
Expr::Binary { left, right, .. } => contains_aggregate(left) || contains_aggregate(right),
Expr::Identifier(_) | Expr::Literal(_) => false,
}
}
fn plan_aggregate(
input: LogicalPlan,
input_schema: &Schema,
select: &Select,
) -> Result<LogicalPlan, PlannerError> {
// Resolve GROUP BY expressions to column names.
let mut group_by_cols = Vec::new();
for expr in &select.group_by {
match expr {
Expr::Identifier(name) => {
let resolved = resolve_column_name(name, input_schema, &select.from)?;
group_by_cols.push(resolved);
}
_ => return Err(PlannerError::UnsupportedGroupBy),
}
}
// Walk the projection, collecting aggregate expressions and verifying
// non-aggregate column references are in GROUP BY.
let mut aggregates: Vec<PlanAggregateExpr> = Vec::new();
let mut projection_items: Vec<(String, ProjectionSource)> = Vec::new();
for (index, item) in select.projection.iter().enumerate() {
match item {
SelectItem::Wildcard => return Err(PlannerError::MixedWildcardProjection),
SelectItem::Expr { expr, alias } => {
let output_name = alias
.clone()
.unwrap_or_else(|| default_projection_name(expr, index + 1));
let source = plan_aggregate_projection(
expr,
input_schema,
select,
&group_by_cols,
&mut aggregates,
)?;
projection_items.push((output_name, source));
}
}
}
// Build the Aggregate node's output schema: group_by columns followed by
// aggregate outputs.
let mut agg_fields = Vec::new();
for col in &group_by_cols {
let field_index = input_schema
.index_of(col)
.ok_or_else(|| PlannerError::UnknownColumn(col.clone()))?;
let field = &input_schema.fields()[field_index];
agg_fields.push(Field::new(
col.clone(),
field.data_type().clone(),
field.nullable(),
));
}
for agg in &aggregates {
let (dtype, nullable) = aggregate_output_type(agg, input_schema)?;
agg_fields.push(Field::new(agg.name.clone(), dtype, nullable));
}
let agg_schema = Schema::new(agg_fields);
let aggregate_plan = LogicalPlan::Aggregate {
input: Box::new(input),
group_by: group_by_cols.clone(),
aggregates,
schema: agg_schema.clone(),
};
// Build the final Project over the aggregate output.
let mut expressions = Vec::new();
let mut fields = Vec::new();
for (name, source) in projection_items {
let (expr, dtype, nullable) = match source {
ProjectionSource::GroupColumn(col) => {
let index = agg_schema
.index_of(&col)
.ok_or_else(|| PlannerError::UnknownColumn(col.clone()))?;
let field = &agg_schema.fields()[index];
(
LogicalExpr::Column(col),
field.data_type().clone(),
field.nullable(),
)
}
ProjectionSource::AggregateColumn(col) => {
let index = agg_schema
.index_of(&col)
.ok_or_else(|| PlannerError::UnknownColumn(col.clone()))?;
let field = &agg_schema.fields()[index];
(
LogicalExpr::Column(col),
field.data_type().clone(),
field.nullable(),
)
}
ProjectionSource::Literal(value) => {
let (dtype, nullable) = literal_metadata(&value);
(LogicalExpr::Literal(value), dtype, nullable)
}
};
expressions.push(NamedExpr {
name: name.clone(),
expr,
});
fields.push(Field::new(name, dtype, nullable));
}
Ok(LogicalPlan::Project {
input: Box::new(aggregate_plan),
expressions,
schema: Schema::new(fields),
})
}
#[derive(Debug, Clone)]
enum ProjectionSource {
GroupColumn(String),
AggregateColumn(String),
Literal(Value),
}
fn plan_aggregate_projection(
expr: &Expr,
input_schema: &Schema,
select: &Select,
group_by_cols: &[String],
aggregates: &mut Vec<PlanAggregateExpr>,
) -> Result<ProjectionSource, PlannerError> {
match expr {
Expr::Aggregate { func, arg } => {
let arg_col = match arg {
AggregateArg::Star => {
if !matches!(func, AggregateFunc::Count) {
return Err(PlannerError::StarArgNotAllowed);
}
None
}
AggregateArg::Expr(inner) => match inner.as_ref() {
Expr::Identifier(name) => {
Some(resolve_column_name(name, input_schema, &select.from)?)
}
_ => return Err(PlannerError::UnsupportedAggregate),
},
};
let synthetic_name = format!("__agg_{}", aggregates.len());
aggregates.push(PlanAggregateExpr {
name: synthetic_name.clone(),
func: *func,
arg: arg_col,
});
Ok(ProjectionSource::AggregateColumn(synthetic_name))
}
Expr::Identifier(name) => {
let resolved = resolve_column_name(name, input_schema, &select.from)?;
if !group_by_cols.contains(&resolved) {
return Err(PlannerError::ProjectionNotGrouped(name.clone()));
}
Ok(ProjectionSource::GroupColumn(resolved))
}
Expr::Literal(literal) => Ok(ProjectionSource::Literal(plan_literal(literal))),
Expr::Binary { .. } => Err(PlannerError::UnsupportedAggregate),
}
}
fn aggregate_output_type(
agg: &PlanAggregateExpr,
input_schema: &Schema,
) -> Result<(DataType, bool), PlannerError> {
match agg.func {
AggregateFunc::Count => Ok((DataType::Integer, false)),
AggregateFunc::Sum | AggregateFunc::Avg => Ok((DataType::Integer, true)),
AggregateFunc::Min | AggregateFunc::Max => {
if let Some(col) = &agg.arg {
let index = input_schema
.index_of(col)
.ok_or_else(|| PlannerError::UnknownColumn(col.clone()))?;
let field = &input_schema.fields()[index];
Ok((field.data_type().clone(), true))
} else {
Ok((DataType::Text, true))
}
}
}
}
fn literal_metadata(value: &Value) -> (DataType, bool) {
match value {
Value::Text(_) => (DataType::Text, false),
Value::Integer(_) => (DataType::Integer, false),
Value::Boolean(_) => (DataType::Boolean, false),
Value::Null => (DataType::Text, true),
}
}
fn is_wildcard_projection(items: &[SelectItem]) -> bool {
matches!(items, [SelectItem::Wildcard])
}
@ -202,6 +438,7 @@ fn plan_expr(
Box::new(plan_expr(right, schema, tables)?),
)),
},
Expr::Aggregate { .. } => Err(PlannerError::UnsupportedAggregate),
}
}
@ -264,6 +501,7 @@ fn projection_metadata(
Expr::Literal(Literal::Integer(_)) => Ok((DataType::Integer, false)),
Expr::Literal(Literal::Null) => Ok((DataType::Text, true)),
Expr::Binary { .. } => Ok((DataType::Boolean, true)),
Expr::Aggregate { .. } => Err(PlannerError::UnsupportedAggregate),
}
}
@ -290,6 +528,23 @@ fn resolve_column_name(
fn default_projection_name(expr: &Expr, ordinal: usize) -> String {
match expr {
Expr::Aggregate { func, arg } => {
let func_name = match func {
AggregateFunc::Count => "COUNT",
AggregateFunc::Sum => "SUM",
AggregateFunc::Min => "MIN",
AggregateFunc::Max => "MAX",
AggregateFunc::Avg => "AVG",
};
let arg_str = match arg {
AggregateArg::Star => "*".to_string(),
AggregateArg::Expr(inner) => match inner.as_ref() {
Expr::Identifier(name) => name.clone(),
_ => format!("expr{}", ordinal),
},
};
format!("{}({})", func_name, arg_str)
}
Expr::Identifier(name) => name.clone(),
Expr::Literal(_) | Expr::Binary { .. } => format!("expr{}", ordinal),
}
@ -566,6 +821,7 @@ mod tests {
alias: None,
}],
selection: None,
group_by: Vec::new(),
order_by: Vec::new(),
limit: None,
};

View File

@ -1,4 +1,5 @@
/// A parsed `SELECT-FROM-WHERE-ORDER BY-LIMIT` statement in the current SQL subset.
/// A parsed `SELECT-FROM-WHERE-GROUP BY-ORDER BY-LIMIT` statement in the
/// current SQL subset.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Select {
/// Output expressions requested by the query.
@ -7,6 +8,8 @@ pub struct Select {
pub from: Vec<TableRef>,
/// Optional filter predicate.
pub selection: Option<Expr>,
/// Grouping columns. Empty means no `GROUP BY` clause.
pub group_by: Vec<Expr>,
/// Optional output ordering.
pub order_by: Vec<OrderByItem>,
/// Optional row limit.
@ -53,6 +56,36 @@ pub enum Expr {
op: BinaryOp,
right: Box<Expr>,
},
/// An aggregate function applied to an argument.
Aggregate {
func: AggregateFunc,
arg: AggregateArg,
},
}
/// An aggregate function in the current SQL subset.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFunc {
/// Row count (with `*`) or count of non-null values (with a column).
Count,
/// Sum of integer values.
Sum,
/// Minimum value.
Min,
/// Maximum value.
Max,
/// Arithmetic mean of integer values.
Avg,
}
/// The argument to an aggregate function: either `*` (only valid for
/// `COUNT`) or an expression.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AggregateArg {
/// `COUNT(*)` style argument.
Star,
/// An expression argument such as `SUM(col)`.
Expr(Box<Expr>),
}
/// A SQL literal in the current subset.

View File

@ -2,7 +2,8 @@ use std::error::Error;
use std::fmt;
use super::ast::{
BinaryOp, Expr, Literal, OrderByItem, Select, SelectItem, SortDirection, TableRef,
AggregateArg, AggregateFunc, BinaryOp, Expr, Literal, OrderByItem, Select, SelectItem,
SortDirection, TableRef,
};
/// Errors returned by the minimal SQL parser.
@ -50,11 +51,14 @@ enum Token {
Desc,
Null,
Limit,
Group,
Identifier(String),
String(String),
Integer(usize),
Star,
Comma,
LParen,
RParen,
Eq,
Ne,
}
@ -87,6 +91,13 @@ impl Parser {
} else {
None
};
let group_by = if self.peek() == Some(&Token::Group) {
self.index += 1;
self.expect_keyword(Token::By, "BY")?;
self.parse_group_by()?
} else {
Vec::new()
};
let order_by = if self.peek() == Some(&Token::Order) {
self.index += 1;
self.expect_keyword(Token::By, "BY")?;
@ -110,11 +121,25 @@ impl Parser {
projection,
from,
selection,
group_by,
order_by,
limit,
})
}
fn parse_group_by(&mut self) -> Result<Vec<Expr>, ParseError> {
let mut items = Vec::new();
loop {
items.push(self.parse_operand()?);
if self.peek() == Some(&Token::Comma) {
self.index += 1;
continue;
}
break;
}
Ok(items)
}
fn parse_projection(&mut self) -> Result<Vec<SelectItem>, ParseError> {
let mut items = Vec::new();
@ -262,7 +287,13 @@ impl Parser {
fn parse_operand(&mut self) -> Result<Expr, ParseError> {
match self.next().ok_or(ParseError::UnexpectedEnd)? {
Token::Identifier(name) => Ok(Expr::Identifier(name)),
Token::Identifier(name) => {
if self.peek() == Some(&Token::LParen) {
self.parse_function_call(name)
} else {
Ok(Expr::Identifier(name))
}
}
Token::String(value) => Ok(Expr::Literal(Literal::String(value))),
Token::Integer(n) => Ok(Expr::Literal(Literal::Integer(n as i64))),
Token::Null => Ok(Expr::Literal(Literal::Null)),
@ -270,6 +301,31 @@ impl Parser {
}
}
fn parse_function_call(&mut self, name: String) -> Result<Expr, ParseError> {
self.expect_keyword(Token::LParen, "(")?;
let func = match name.to_ascii_uppercase().as_str() {
"COUNT" => AggregateFunc::Count,
"SUM" => AggregateFunc::Sum,
"MIN" => AggregateFunc::Min,
"MAX" => AggregateFunc::Max,
"AVG" => AggregateFunc::Avg,
_ => return Err(ParseError::UnexpectedToken(name)),
};
let arg = if self.peek() == Some(&Token::Star) {
self.index += 1;
if !matches!(func, AggregateFunc::Count) {
return Err(ParseError::UnexpectedToken("*".to_string()));
}
AggregateArg::Star
} else {
AggregateArg::Expr(Box::new(self.parse_operand()?))
};
self.expect_keyword(Token::RParen, ")")?;
Ok(Expr::Aggregate { func, arg })
}
fn expect_keyword(&mut self, token: Token, label: &'static str) -> Result<(), ParseError> {
let next = self.next().ok_or(ParseError::UnexpectedEnd)?;
if next == token {
@ -325,6 +381,14 @@ fn tokenize(input: &str) -> Result<Vec<Token>, ParseError> {
chars.next();
tokens.push(Token::Comma);
}
'(' => {
chars.next();
tokens.push(Token::LParen);
}
')' => {
chars.next();
tokens.push(Token::RParen);
}
'!' => {
chars.next();
if chars.peek() == Some(&'=') {
@ -367,6 +431,7 @@ fn tokenize(input: &str) -> Result<Vec<Token>, ParseError> {
"DESC" => Token::Desc,
"NULL" => Token::Null,
"LIMIT" => Token::Limit,
"GROUP" => Token::Group,
_ => Token::Identifier(ident),
};
tokens.push(token);
@ -462,6 +527,9 @@ fn render_token(token: &Token) -> String {
Token::String(value) => format!("'{}'", value),
Token::Star => "*".to_string(),
Token::Comma => ",".to_string(),
Token::LParen => "(".to_string(),
Token::RParen => ")".to_string(),
Token::Group => "GROUP".to_string(),
Token::Eq => "=".to_string(),
Token::Ne => "!=".to_string(),
}

View File

@ -360,3 +360,156 @@ fn execute_with_table_store_scans_in_memory_rows() {
assert_eq!(result.rows().len(), 1);
assert_eq!(format!("{}", result.rows()[0].values()[0]), "bob");
}
#[test]
fn count_star_no_group_by() {
let instance = parent_instance();
let catalog = PredicateCatalog::from_instance(&instance).unwrap();
let select = parse_select("SELECT COUNT(*) FROM Parent").unwrap();
let plan = plan_select(&select, &catalog).unwrap();
let result = execute(&plan, &instance).unwrap();
assert_eq!(result.rows().len(), 1);
assert_eq!(format!("{}", result.rows()[0].values()[0]), "2");
}
#[test]
fn count_star_group_by_one_column() {
use query_engine::execution::TableStore;
use query_engine::relational::{DataType, Field, Row, Schema, Value};
let schema = Schema::new(vec![
Field::new("dept", DataType::Text, false),
Field::new("name", DataType::Text, false),
]);
let mut store = TableStore::new();
store.insert(
"Emp",
schema.clone(),
vec![
Row::new(vec![Value::text("eng"), Value::text("alice")]),
Row::new(vec![Value::text("eng"), Value::text("bob")]),
Row::new(vec![Value::text("sales"), Value::text("carol")]),
],
);
let mut catalog = PredicateCatalog::new();
catalog.register_table("Emp", schema);
let select = parse_select("SELECT dept, COUNT(*) FROM Emp GROUP BY dept").unwrap();
let plan = plan_select(&select, &catalog).unwrap();
let result = execute(&plan, &store).unwrap();
assert_eq!(result.rows().len(), 2);
let mut rows: Vec<(String, String)> = result
.rows()
.iter()
.map(|row| {
(
format!("{}", row.values()[0]),
format!("{}", row.values()[1]),
)
})
.collect();
rows.sort();
assert_eq!(
rows,
vec![
("eng".to_string(), "2".to_string()),
("sales".to_string(), "1".to_string()),
]
);
}
#[test]
fn sum_min_max_avg_over_integer_column() {
use query_engine::execution::TableStore;
use query_engine::relational::{DataType, Field, Row, Schema, Value};
let schema = Schema::new(vec![
Field::new("dept", DataType::Text, false),
Field::new("salary", DataType::Integer, false),
]);
let mut store = TableStore::new();
store.insert(
"Emp",
schema.clone(),
vec![
Row::new(vec![Value::text("eng"), Value::Integer(100)]),
Row::new(vec![Value::text("eng"), Value::Integer(200)]),
Row::new(vec![Value::text("sales"), Value::Integer(50)]),
],
);
let mut catalog = PredicateCatalog::new();
catalog.register_table("Emp", schema);
let select = parse_select(
"SELECT dept, SUM(salary), MIN(salary), MAX(salary), AVG(salary) FROM Emp GROUP BY dept",
)
.unwrap();
let plan = plan_select(&select, &catalog).unwrap();
let result = execute(&plan, &store).unwrap();
assert_eq!(result.rows().len(), 2);
let mut rows: Vec<(String, String, String, String, String)> = result
.rows()
.iter()
.map(|row| {
(
format!("{}", row.values()[0]),
format!("{}", row.values()[1]),
format!("{}", row.values()[2]),
format!("{}", row.values()[3]),
format!("{}", row.values()[4]),
)
})
.collect();
rows.sort();
assert_eq!(
rows[0],
(
"eng".to_string(),
"300".to_string(),
"100".to_string(),
"200".to_string(),
"150".to_string(),
)
);
assert_eq!(
rows[1],
(
"sales".to_string(),
"50".to_string(),
"50".to_string(),
"50".to_string(),
"50".to_string(),
)
);
}
#[test]
fn projection_not_in_group_by_errors() {
use query_engine::execution::TableStore;
use query_engine::relational::{DataType, Field, Schema};
let schema = Schema::new(vec![
Field::new("dept", DataType::Text, false),
Field::new("name", DataType::Text, false),
]);
let mut store = TableStore::new();
store.insert("Emp", schema.clone(), Vec::new());
let mut catalog = PredicateCatalog::new();
catalog.register_table("Emp", schema);
let select = parse_select("SELECT dept, name FROM Emp GROUP BY dept").unwrap();
let err = plan_select(&select, &catalog).unwrap_err();
assert!(
err.to_string()
.contains("not aggregated and not in GROUP BY")
);
}