//! Minimal command language for the query-engine REPL and GUI. use crate::chase::rule::RuleBuilder; use crate::chase::{Atom, Rule, Term}; use crate::sql::ast::Select; use crate::sql::parser::parse_select; #[derive(Debug, Clone)] pub enum Command { Fact(Atom), Rule(Rule), Schema { table: String, columns: Vec }, Sql(Select), Run, Query(Vec), Explain(Vec), ShowFacts, ShowRules, Reset, Help, } pub fn parse_script(input: &str) -> Result, String> { let mut commands = Vec::new(); for (index, raw_line) in input.lines().enumerate() { let line = raw_line.trim(); if line.is_empty() || line.starts_with('#') { continue; } let command = parse_command(line).map_err(|err| format!("line {}: {}", index + 1, err))?; commands.push(command); } Ok(commands) } pub fn parse_command(input: &str) -> Result { let trimmed = input.trim(); if trimmed.eq_ignore_ascii_case("run") || trimmed.eq_ignore_ascii_case("run.") { return Ok(Command::Run); } if trimmed.eq_ignore_ascii_case("show facts") || trimmed.eq_ignore_ascii_case("show facts.") { return Ok(Command::ShowFacts); } if trimmed.eq_ignore_ascii_case("show rules") || trimmed.eq_ignore_ascii_case("show rules.") { return Ok(Command::ShowRules); } if trimmed.eq_ignore_ascii_case("reset") || trimmed.eq_ignore_ascii_case("reset.") { return Ok(Command::Reset); } if trimmed.eq_ignore_ascii_case("help") || trimmed.eq_ignore_ascii_case("help.") { return Ok(Command::Help); } if let Some(rest) = strip_keyword(trimmed, "sql") { let select = parse_select(trim_suffix(rest, ';')?).map_err(|err| err.to_string())?; return Ok(Command::Sql(select)); } if let Some(rest) = strip_keyword(trimmed, "schema") { let atom = parse_atom(trim_suffix(rest, '.')?)?; let columns = atom .terms .into_iter() .map(|term| match term { Term::Constant(name) => Ok(name), Term::Null(_) | Term::Variable(_) => { Err("schema columns must be constant identifiers".to_string()) } }) .collect::, _>>()?; return Ok(Command::Schema { table: atom.predicate, columns, }); } if let Some(rest) = strip_keyword(trimmed, "fact") { let atom = parse_atom(trim_suffix(rest, '.')?)?; if !atom.is_ground() { return Err("facts must be ground atoms".to_string()); } return Ok(Command::Fact(atom)); } if let Some(rest) = strip_keyword(trimmed, "rule") { let rule_text = trim_suffix(rest, '.')?; let arrow = find_top_level_arrow(rule_text) .ok_or_else(|| "rule must contain a top-level `->`".to_string())?; let body_text = rule_text[..arrow].trim(); let head_text = rule_text[arrow + 2..].trim(); if body_text.is_empty() || head_text.is_empty() { return Err("rule body and head must both be non-empty".to_string()); } let body = parse_atom_list(body_text)?; let head = parse_atom_list(head_text)?; let mut builder = RuleBuilder::new(); for atom in body { builder = builder.when(&atom.predicate, atom.terms); } for atom in head { builder = builder.then(&atom.predicate, atom.terms); } return Ok(Command::Rule(builder.build())); } if let Some(rest) = strip_keyword(trimmed, "query") { let atoms = parse_atom_list(trim_suffix(rest, '?')?)?; return Ok(Command::Query(atoms)); } if let Some(rest) = strip_keyword(trimmed, "explain") { let atoms = parse_atom_list(trim_suffix(rest, '?')?)?; return Ok(Command::Explain(atoms)); } Err("unknown command; try `help`".to_string()) } fn strip_keyword<'a>(input: &'a str, keyword: &str) -> Option<&'a str> { let prefix = input.get(..keyword.len())?; if !prefix.eq_ignore_ascii_case(keyword) { return None; } let rest = input.get(keyword.len()..)?; if rest.is_empty() { return Some(rest); } let mut chars = rest.chars(); let first = chars.next()?; if first.is_whitespace() { Some(rest.trim_start()) } else { None } } fn trim_suffix(input: &str, suffix: char) -> Result<&str, String> { let trimmed = input.trim(); if let Some(stripped) = trimmed.strip_suffix(suffix) { Ok(stripped.trim_end()) } else { Err(format!("command must end with `{}`", suffix)) } } fn parse_atom_list(input: &str) -> Result, String> { split_top_level(input, ',')? .into_iter() .map(parse_atom) .collect() } fn parse_atom(input: &str) -> Result { let trimmed = input.trim(); let open = trimmed .find('(') .ok_or_else(|| format!("expected `(` in atom `{}`", trimmed))?; let close = trimmed .rfind(')') .ok_or_else(|| format!("expected `)` in atom `{}`", trimmed))?; if close <= open { return Err(format!("malformed atom `{}`", trimmed)); } if close != trimmed.len() - 1 { return Err(format!("unexpected content after atom `{}`", trimmed)); } let predicate = trimmed[..open].trim(); validate_identifier(predicate, "predicate")?; let args = trimmed[open + 1..close].trim(); let terms = if args.is_empty() { Vec::new() } else { split_top_level(args, ',')? .into_iter() .map(parse_term) .collect::, _>>()? }; Ok(Atom::new(predicate, terms)) } fn parse_term(input: &str) -> Result { let trimmed = input.trim(); if trimmed.is_empty() { return Err("empty term".to_string()); } if let Some(var) = trimmed.strip_prefix('?') { validate_identifier(var, "variable")?; return Ok(Term::var(var)); } if trimmed.starts_with('"') { return parse_string_literal(trimmed).map(Term::constant); } if trimmed.chars().any(char::is_whitespace) { return Err(format!( "constants with spaces must be quoted: `{}`", trimmed )); } validate_identifier(trimmed, "constant")?; Ok(Term::constant(trimmed)) } fn parse_string_literal(input: &str) -> Result { if !input.ends_with('"') || input.len() < 2 { return Err(format!("unterminated string literal `{}`", input)); } let inner = &input[1..input.len() - 1]; let mut value = String::new(); let mut escaped = false; for ch in inner.chars() { if escaped { let translated = match ch { '\\' => '\\', '"' => '"', 'n' => '\n', 't' => '\t', other => { return Err(format!("unsupported escape sequence `\\{}`", other)); } }; value.push(translated); escaped = false; continue; } if ch == '\\' { escaped = true; } else { value.push(ch); } } if escaped { return Err("string literal ends with a trailing escape".to_string()); } Ok(value) } fn validate_identifier(value: &str, label: &str) -> Result<(), String> { if value.is_empty() { return Err(format!("{} cannot be empty", label)); } if value.chars().all(is_identifier_char) { Ok(()) } else { Err(format!("invalid {} `{}`", label, value)) } } fn is_identifier_char(ch: char) -> bool { ch.is_ascii_alphanumeric() || matches!(ch, '_' | '-' | ':') } fn find_top_level_arrow(input: &str) -> Option { let bytes = input.as_bytes(); let mut depth = 0usize; let mut in_string = false; let mut escaped = false; let mut index = 0usize; while index < bytes.len() { let ch = bytes[index] as char; if in_string { if escaped { escaped = false; } else if ch == '\\' { escaped = true; } else if ch == '"' { in_string = false; } index += 1; continue; } match ch { '"' => in_string = true, '(' => depth += 1, ')' => depth = depth.saturating_sub(1), '-' if depth == 0 && bytes.get(index + 1).copied() == Some(b'>') => { return Some(index); } _ => {} } index += 1; } None } fn split_top_level(input: &str, separator: char) -> Result, String> { let mut parts = Vec::new(); let mut depth = 0usize; let mut in_string = false; let mut escaped = false; let mut start = 0usize; for (index, ch) in input.char_indices() { if in_string { if escaped { escaped = false; } else if ch == '\\' { escaped = true; } else if ch == '"' { in_string = false; } continue; } match ch { '"' => in_string = true, '(' => depth += 1, ')' => { if depth == 0 { return Err(format!("unexpected `)` in `{}`", input)); } depth -= 1; } ch if ch == separator && depth == 0 => { let part = input[start..index].trim(); if part.is_empty() { return Err(format!("empty element in `{}`", input)); } parts.push(part); start = index + ch.len_utf8(); } _ => {} } } if in_string { return Err(format!("unterminated string literal in `{}`", input)); } if depth != 0 { return Err(format!("unbalanced parentheses in `{}`", input)); } let tail = input[start..].trim(); if tail.is_empty() { return Err(format!("empty element in `{}`", input)); } parts.push(tail); Ok(parts) } #[cfg(test)] mod tests { use super::*; #[test] fn parse_fact_command() { let command = parse_command(r#"fact Parent(alice, "bob smith")."#).unwrap(); match command { Command::Fact(atom) => { assert_eq!(atom.predicate, "Parent"); assert_eq!(atom.terms.len(), 2); } other => panic!("unexpected command: {:?}", other), } } #[test] fn parse_fact_command_rejects_variables() { let error = parse_command("fact Parent(?X, bob).").unwrap_err(); assert_eq!(error, "facts must be ground atoms"); } #[test] fn parse_rule_command() { let command = parse_command("rule P(?X), Q(?X, a) -> R(?X).").unwrap(); match command { Command::Rule(rule) => { assert_eq!(rule.body.len(), 2); assert_eq!(rule.head.len(), 1); } other => panic!("unexpected command: {:?}", other), } } #[test] fn parse_sql_command() { let command = parse_command("sql SELECT c0 FROM Parent WHERE c1 = 'bob';").unwrap(); match command { Command::Sql(select) => { assert_eq!(select.from, vec!["Parent".to_string()]); assert!(select.selection.is_some()); } other => panic!("unexpected command: {:?}", other), } } #[test] fn parse_sql_join_command() { let command = parse_command( "sql SELECT Parent.parent FROM Parent, Ancestor WHERE Parent.child = Ancestor.parent;", ) .unwrap(); match command { Command::Sql(select) => { assert_eq!( select.from, vec!["Parent".to_string(), "Ancestor".to_string()] ); } other => panic!("unexpected command: {:?}", other), } } #[test] fn parse_schema_command() { let command = parse_command("schema Parent(parent, child).").unwrap(); match command { Command::Schema { table, columns } => { assert_eq!(table, "Parent"); assert_eq!(columns, vec!["parent".to_string(), "child".to_string()]); } other => panic!("unexpected command: {:?}", other), } } #[test] fn parse_query_command() { let command = parse_command("query Ancestor(?X, ?Y), Parent(?Y, ?Z)?").unwrap(); match command { Command::Query(atoms) => assert_eq!(atoms.len(), 2), other => panic!("unexpected command: {:?}", other), } } #[test] fn parse_explain_command() { let command = parse_command("explain Ancestor(alice, carol)?").unwrap(); match command { Command::Explain(atoms) => assert_eq!(atoms.len(), 1), other => panic!("unexpected command: {:?}", other), } } #[test] fn parse_script_reports_line_numbers() { let error = parse_script("help\nbogus\nrun.").unwrap_err(); assert!(error.contains("line 2")); } }