191 lines
5.9 KiB
Rust
Raw Normal View History

//! Atom operator: scan a [`Table`] under an [`AtomPattern`] and return a
//! binding [`Relation`].
//!
//! An atom pattern specifies, for each table column, either a variable to bind
//! or a literal that the cell must equal. A variable appearing in more than one
//! column forces those cells to be equal (so `Edge(X, X)` keeps only
//! self-loops). The output relation has one column per distinct variable, in
//! first-occurrence order.
use std::collections::HashMap;
use crate::{relation::Relation, table::Table, value::Value};
/// A single column constraint inside an [`AtomPattern`].
///
/// - [`Term::Var`] binds the cell to a named variable. Every column that
///   reuses the same name must hold an equal cell.
/// - [`Term::Lit`] requires the cell to equal the given value.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Term {
    /// A variable to bind, identified by its name.
    Var(String),
    /// A literal the cell must equal.
    Lit(Value),
}
#[derive(Debug, Clone)]
pub struct AtomPattern {
pub columns: Vec<Term>,
}
/// Scan `table` and return every row that satisfies `pattern`, projected onto
/// the pattern's distinct variables.
///
/// How each pattern column is applied:
/// - A [`Term::Lit`] column keeps only rows whose cell equals the literal.
/// - A [`Term::Var`] column binds that variable. If the same name appears
///   again in a later column, the later cell must equal the cell at the
///   variable's first occurrence.
///
/// The returned [`Relation`] has one column per distinct variable, named and
/// ordered by first occurrence. Its rows follow the order of the matching
/// rows in `table`.
///
/// # Panics
/// Panics if `pattern.columns.len() != table.arity`.
#[must_use]
pub fn scan_atom(table: &Table, pattern: &AtomPattern) -> Relation {
    assert_eq!(
        pattern.columns.len(),
        table.arity,
        "pattern arity mismatch: pattern has {}, table has {}",
        pattern.columns.len(),
        table.arity,
    );

    // Compile the pattern once, so the per-row work below is only indexed
    // comparisons and never a name lookup.
    let mut first_col_of: HashMap<&str, usize> = HashMap::new();
    let mut var_names: Vec<String> = Vec::new();
    let mut projection: Vec<usize> = Vec::new();
    let mut must_equal: Vec<(usize, usize)> = Vec::new();
    let mut literals: Vec<(usize, &Value)> = Vec::new();

    for (col, term) in pattern.columns.iter().enumerate() {
        match term {
            Term::Lit(expected) => literals.push((col, expected)),
            Term::Var(name) => match first_col_of.get(name.as_str()) {
                // Repeat occurrence: this cell must equal the first one.
                Some(&first) => must_equal.push((first, col)),
                // First occurrence: this column becomes an output column.
                None => {
                    first_col_of.insert(name.as_str(), col);
                    var_names.push(name.clone());
                    projection.push(col);
                }
            },
        }
    }

    let mut output = Relation::new(var_names);
    // Literal checks run first, then repeated-variable checks. Both short-circuit.
    let matching_rows = table.rows.iter().filter(|row| {
        literals.iter().all(|&(col, expected)| row[col] == *expected)
            && must_equal.iter().all(|&(first, col)| row[col] == row[first])
    });
    for row in matching_rows {
        let projected: Vec<Value> = projection.iter().map(|&col| row[col].clone()).collect();
        output.push(projected);
    }
    output
}
#[cfg(test)]
mod tests {
    //! Unit tests for [`scan_atom`]. They cover literal filtering,
    //! repeated-variable equality, first-occurrence projection, the empty
    //! result, and the documented arity-mismatch panic.
    use super::*;

    fn var(name: &str) -> Term {
        Term::Var(name.to_string())
    }

    fn lit(value: i64) -> Term {
        Term::Lit(Value::Int(value))
    }

    fn int(value: i64) -> Value {
        Value::Int(value)
    }

    #[test]
    fn repeated_variable_keeps_only_self_loops() {
        let edge = Table::from_rows(
            2,
            vec![
                vec![int(1), int(2)],
                vec![int(2), int(2)],
                vec![int(3), int(3)],
                vec![int(1), int(1)],
            ],
        );
        let pattern = AtomPattern {
            columns: vec![var("X"), var("X")],
        };
        let result = scan_atom(&edge, &pattern);
        assert_eq!(result.columns, vec!["X".to_string()]);
        assert_eq!(result.rows, vec![vec![int(2)], vec![int(3)], vec![int(1)]]);
    }

    #[test]
    fn literal_filters_rows_to_match() {
        let edge = Table::from_rows(
            2,
            vec![
                vec![int(1), int(2)],
                vec![int(2), int(3)],
                vec![int(1), int(4)],
            ],
        );
        let pattern = AtomPattern {
            columns: vec![lit(1), var("Y")],
        };
        let result = scan_atom(&edge, &pattern);
        assert_eq!(result.columns, vec!["Y".to_string()]);
        assert_eq!(result.rows, vec![vec![int(2)], vec![int(4)]]);
    }

    #[test]
    fn distinct_variables_project_in_first_occurrence_order() {
        let triples = Table::from_rows(
            3,
            vec![vec![int(1), int(2), int(3)], vec![int(4), int(5), int(6)]],
        );
        let pattern = AtomPattern {
            columns: vec![var("A"), var("B"), var("C")],
        };
        let result = scan_atom(&triples, &pattern);
        assert_eq!(
            result.columns,
            vec!["A".to_string(), "B".to_string(), "C".to_string()],
        );
        assert_eq!(
            result.rows,
            vec![vec![int(1), int(2), int(3)], vec![int(4), int(5), int(6)]],
        );
    }

    #[test]
    fn variable_repeated_three_times_requires_all_equal() {
        let triples = Table::from_rows(
            3,
            vec![
                vec![int(1), int(1), int(1)],
                vec![int(1), int(1), int(2)],
                vec![int(2), int(2), int(2)],
                vec![int(1), int(2), int(1)],
            ],
        );
        let pattern = AtomPattern {
            columns: vec![var("X"), var("X"), var("X")],
        };
        let result = scan_atom(&triples, &pattern);
        assert_eq!(result.columns, vec!["X".to_string()]);
        assert_eq!(result.rows, vec![vec![int(1)], vec![int(2)]]);
    }

    #[test]
    fn literal_filter_repeated_var_and_projection_combine() {
        // Pattern: [Lit(1), Var("X"), Lit(2), Var("X")].
        // Keep rows where col0 == 1, col2 == 2, and col1 == col3.
        // Output is one column [X], bound to col1 (the first occurrence).
        let table = Table::from_rows(
            4,
            vec![
                vec![int(1), int(7), int(2), int(7)],
                vec![int(1), int(7), int(2), int(8)],
                vec![int(0), int(7), int(2), int(7)],
                vec![int(1), int(7), int(3), int(7)],
                vec![int(1), int(9), int(2), int(9)],
            ],
        );
        let pattern = AtomPattern {
            columns: vec![lit(1), var("X"), lit(2), var("X")],
        };
        let result = scan_atom(&table, &pattern);
        assert_eq!(result.columns, vec!["X".to_string()]);
        assert_eq!(result.rows, vec![vec![int(7)], vec![int(9)]]);
    }

    #[test]
    fn repeated_variable_is_not_projected_twice_among_other_vars() {
        // Pattern: [X, Y, X]. Output is [X, Y]. Column 2 is only an
        // equality check against column 0 and must not appear in the output.
        let triples = Table::from_rows(
            3,
            vec![
                vec![int(1), int(2), int(1)],
                vec![int(1), int(2), int(3)],
                vec![int(4), int(5), int(4)],
            ],
        );
        let pattern = AtomPattern {
            columns: vec![var("X"), var("Y"), var("X")],
        };
        let result = scan_atom(&triples, &pattern);
        assert_eq!(result.columns, vec!["X".to_string(), "Y".to_string()]);
        assert_eq!(
            result.rows,
            vec![vec![int(1), int(2)], vec![int(4), int(5)]],
        );
    }

    #[test]
    fn no_matching_rows_yields_empty_relation_with_columns() {
        let edge = Table::from_rows(
            2,
            vec![vec![int(1), int(2)], vec![int(3), int(4)]],
        );
        let pattern = AtomPattern {
            columns: vec![lit(5), var("Y")],
        };
        let result = scan_atom(&edge, &pattern);
        // The output schema is still derived from the pattern even when
        // nothing matches.
        assert_eq!(result.columns, vec!["Y".to_string()]);
        assert_eq!(result.rows, Vec::<Vec<Value>>::new());
    }

    #[test]
    #[should_panic(expected = "pattern arity mismatch")]
    fn arity_mismatch_panics() {
        // Exercises the documented `# Panics` contract of `scan_atom`.
        let edge = Table::from_rows(2, vec![vec![int(1), int(2)]]);
        let pattern = AtomPattern {
            columns: vec![var("X"), var("Y"), var("Z")],
        };
        let _ = scan_atom(&edge, &pattern);
    }
}