From 0abcdd7e4358cb902c320b008d3c04bde07b749e Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Thu, 26 Feb 2026 19:28:22 +0100 Subject: Add Rust implementation of SCFG decoder Rust port of the Ruby prototype decoder with performance optimizations for real Hiero-style grammars: - Rule indexing by first terminal/NT symbol for fast lookup - Chart symbol interning (u16 IDs) instead of string hashing - Passive chart index by (symbol, left) for direct right-endpoint lookup - Items store rule index instead of cloned rule data Includes CKY+ parser, chart-to-hypergraph conversion, Viterbi decoding, derivation extraction, and JSON hypergraph I/O. Self-filling step in parse uses grammar lookup (not just remaining active items) to handle rules that were consumed during the parse loop or skipped by the has_any_at optimization. Produces identical output to the Ruby prototype on all test examples. Co-Authored-By: Claude Opus 4.6 --- rs/src/chart_to_hg.rs | 95 +++++++++++ rs/src/grammar.rs | 393 +++++++++++++++++++++++++++++++++++++++++++++ rs/src/hypergraph.rs | 115 +++++++++++++ rs/src/hypergraph_algos.rs | 215 +++++++++++++++++++++++++ rs/src/hypergraph_io.rs | 110 +++++++++++++ rs/src/lib.rs | 8 + rs/src/main.rs | 110 +++++++++++++ rs/src/parse.rs | 347 +++++++++++++++++++++++++++++++++++++++ rs/src/semiring.rs | 39 +++++ rs/src/sparse_vector.rs | 86 ++++++++++ 10 files changed, 1518 insertions(+) create mode 100644 rs/src/chart_to_hg.rs create mode 100644 rs/src/grammar.rs create mode 100644 rs/src/hypergraph.rs create mode 100644 rs/src/hypergraph_algos.rs create mode 100644 rs/src/hypergraph_io.rs create mode 100644 rs/src/lib.rs create mode 100644 rs/src/main.rs create mode 100644 rs/src/parse.rs create mode 100644 rs/src/semiring.rs create mode 100644 rs/src/sparse_vector.rs (limited to 'rs/src') diff --git a/rs/src/chart_to_hg.rs b/rs/src/chart_to_hg.rs new file mode 100644 index 0000000..95f642e --- /dev/null +++ b/rs/src/chart_to_hg.rs @@ -0,0 +1,95 @@ +use std::collections::HashMap; + +use crate::grammar::{Grammar, Rule, Symbol}; +use crate::hypergraph::{Hypergraph, NodeId}; +use crate::parse::{Chart, visit}; +use crate::sparse_vector::SparseVector; + +pub fn chart_to_hg(chart: &Chart, n: usize, weights: &SparseVector, grammar: &Grammar) -> Hypergraph { + let mut hg = Hypergraph::new(); + let mut seen: HashMap = HashMap::new(); + + // Add root node with id=-1 + let root_nid = hg.add_node(-1, "root", -1, -1); + + let mut next_id: i64 = 0; + + // First pass: create nodes + visit(1, 0, n, 0, |i, j| { + for item in chart.at(i, j) { + let lhs_sym = grammar.rules[item.rule_idx].lhs.nt_symbol(); + let key = format!("{},{},{}", lhs_sym, i, j); + if !seen.contains_key(&key) { + let nid = hg.add_node(next_id, lhs_sym, i as i32, j as i32); + seen.insert(key, nid); + next_id += 1; + } + } + }); + + // Second pass: create edges + visit(1, 0, n, 0, |i, j| { + for item in chart.at(i, j) { + let rule = &grammar.rules[item.rule_idx]; + let head_key = format!("{},{},{}", rule.lhs.nt_symbol(), i, j); + let head_nid = *seen.get(&head_key).unwrap(); + + // Build tails + let tails: Vec = if item.tail_spans.iter().all(|s| s.is_none()) + || item.tail_spans.iter().all(|s| { + s.as_ref() + .map_or(true, |sp| sp.left == -1 && sp.right == -1) + }) + { + if rule.rhs.iter().any(|s| s.is_nt()) { + build_tails(&rule.rhs, &item.tail_spans, &seen, root_nid) + } else { + vec![root_nid] + } + } else { + build_tails(&rule.rhs, &item.tail_spans, &seen, root_nid) + }; + + let score = weights.dot(&rule.f).exp(); + let rule_clone = Rule::new( + rule.lhs.clone(), + rule.rhs.clone(), + rule.target.clone(), + rule.map.clone(), + rule.f.clone(), + ); + + hg.add_edge(head_nid, tails, score, rule.f.clone(), rule_clone); + } + }); + + hg +} + +fn build_tails( + rhs: &[Symbol], + tail_spans: &[Option], + seen: &HashMap, + root_nid: NodeId, +) -> Vec { + let mut tails = Vec::new(); + let mut has_nt = false; + for (idx, sym) in rhs.iter().enumerate() { + if let Symbol::NT { symbol, .. } = sym { + has_nt = true; + if let Some(Some(span)) = tail_spans.get(idx) { + if span.left >= 0 && span.right >= 0 { + let key = format!("{},{},{}", symbol, span.left, span.right); + if let Some(&nid) = seen.get(&key) { + tails.push(nid); + } + } + } + } + } + if !has_nt { + vec![root_nid] + } else { + tails + } +} diff --git a/rs/src/grammar.rs b/rs/src/grammar.rs new file mode 100644 index 0000000..cef49ee --- /dev/null +++ b/rs/src/grammar.rs @@ -0,0 +1,393 @@ +use std::fmt; +use std::fs; +use std::io::{self, BufRead}; + +use crate::sparse_vector::SparseVector; + +#[derive(Debug, Clone, PartialEq)] +pub enum Symbol { + T(String), + NT { symbol: String, index: i32 }, +} + +impl Symbol { + pub fn parse(s: &str) -> Symbol { + let s = s.trim(); + if s.starts_with('[') && s.ends_with(']') { + let inner = &s[1..s.len() - 1]; + if let Some((sym, idx)) = inner.split_once(',') { + Symbol::NT { + symbol: sym.trim().to_string(), + index: idx.trim().parse::().unwrap() - 1, + } + } else { + Symbol::NT { + symbol: inner.trim().to_string(), + index: -1, + } + } + } else { + Symbol::T(s.to_string()) + } + } + + pub fn is_nt(&self) -> bool { + matches!(self, Symbol::NT { .. }) + } + + pub fn is_t(&self) -> bool { + matches!(self, Symbol::T(_)) + } + + pub fn word(&self) -> &str { + match self { + Symbol::T(w) => w, + Symbol::NT { .. } => panic!("word() called on NT"), + } + } + + pub fn nt_symbol(&self) -> &str { + match self { + Symbol::NT { symbol, .. } => symbol, + Symbol::T(_) => panic!("nt_symbol() called on T"), + } + } + + pub fn nt_index(&self) -> i32 { + match self { + Symbol::NT { index, .. } => *index, + Symbol::T(_) => panic!("nt_index() called on T"), + } + } +} + +impl fmt::Display for Symbol { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Symbol::T(w) => write!(f, "{}", w), + Symbol::NT { symbol, index } if *index >= 0 => { + write!(f, "[{},{}]", symbol, index + 1) + } + Symbol::NT { symbol, .. } => write!(f, "[{}]", symbol), + } + } +} + +fn splitpipe(s: &str, n: usize) -> Vec<&str> { + let mut parts = Vec::new(); + let mut rest = s; + for _ in 0..n { + if let Some(pos) = rest.find("|||") { + parts.push(rest[..pos].trim()); + rest = rest[pos + 3..].trim(); + } else { + break; + } + } + parts.push(rest.trim()); + parts +} + +#[derive(Debug, Clone)] +pub struct Rule { + pub lhs: Symbol, + pub rhs: Vec, + pub target: Vec, + pub map: Vec, + pub f: SparseVector, + pub arity: usize, +} + +impl Rule { + pub fn new( + lhs: Symbol, + rhs: Vec, + target: Vec, + map: Vec, + f: SparseVector, + ) -> Self { + let arity = rhs.iter().filter(|s| s.is_nt()).count(); + Self { + lhs, + rhs, + target, + map, + f, + arity, + } + } + + pub fn from_str(s: &str) -> Self { + let parts = splitpipe(s, 3); + let lhs = Symbol::parse(parts[0]); + + let mut map = Vec::new(); + let mut arity = 0; + let rhs: Vec = parts[1] + .split_whitespace() + .map(|tok| { + let sym = Symbol::parse(tok); + if let Symbol::NT { index, .. } = &sym { + map.push(*index); + arity += 1; + } + sym + }) + .collect(); + + let target: Vec = parts[2] + .split_whitespace() + .map(|tok| Symbol::parse(tok)) + .collect(); + + let f = if parts.len() > 3 && !parts[3].is_empty() { + SparseVector::from_kv(parts[3], '=', ' ') + } else { + SparseVector::new() + }; + + Self { + lhs, + rhs, + target, + map, + f, + arity, + } + } +} + +impl fmt::Display for Rule { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{} ||| {} ||| {}", + self.lhs, + self.rhs + .iter() + .map(|s| s.to_string()) + .collect::>() + .join(" "), + self.target + .iter() + .map(|s| s.to_string()) + .collect::>() + .join(" "), + ) + } +} + +pub struct Grammar { + pub rules: Vec, + pub start_nt: Vec, + pub start_t: Vec, + pub flat: Vec, + pub start_t_by_word: std::collections::HashMap>, + pub flat_by_first_word: std::collections::HashMap>, + pub start_nt_by_symbol: std::collections::HashMap>, +} + +impl Grammar { + pub fn from_file(path: &str) -> io::Result { + let file = fs::File::open(path)?; + let reader = io::BufReader::new(file); + let mut rules = Vec::new(); + let mut start_nt = Vec::new(); + let mut start_t = Vec::new(); + let mut flat = Vec::new(); + let mut start_t_by_word: std::collections::HashMap> = + std::collections::HashMap::new(); + let mut flat_by_first_word: std::collections::HashMap> = + std::collections::HashMap::new(); + let mut start_nt_by_symbol: std::collections::HashMap> = + std::collections::HashMap::new(); + + for line in reader.lines() { + let line = line?; + let line = line.trim().to_string(); + if line.is_empty() { + continue; + } + let idx = rules.len(); + let rule = Rule::from_str(&line); + if rule.rhs.first().map_or(false, |s| s.is_nt()) { + start_nt_by_symbol + .entry(rule.rhs[0].nt_symbol().to_string()) + .or_default() + .push(idx); + start_nt.push(idx); + } else if rule.arity == 0 { + flat_by_first_word + .entry(rule.rhs[0].word().to_string()) + .or_default() + .push(idx); + flat.push(idx); + } else { + start_t_by_word + .entry(rule.rhs[0].word().to_string()) + .or_default() + .push(idx); + start_t.push(idx); + } + rules.push(rule); + } + + Ok(Self { + rules, + start_nt, + start_t, + flat, + start_t_by_word, + flat_by_first_word, + start_nt_by_symbol, + }) + } + + pub fn add_glue_rules(&mut self) { + let symbols: Vec = self + .rules + .iter() + .map(|r| r.lhs.nt_symbol().to_string()) + .filter(|s| s != "S") + .collect::>() + .into_iter() + .collect(); + + let mut once = false; + for symbol in symbols { + let idx = self.rules.len(); + self.rules.push(Rule::new( + Symbol::NT { + symbol: "S".to_string(), + index: -1, + }, + vec![Symbol::NT { + symbol: symbol.clone(), + index: 0, + }], + vec![Symbol::NT { + symbol: symbol.clone(), + index: 0, + }], + vec![0], + SparseVector::new(), + )); + self.start_nt.push(idx); + self.start_nt_by_symbol + .entry(symbol) + .or_default() + .push(idx); + once = true; + } + + if once { + let idx = self.rules.len(); + self.rules.push(Rule::new( + Symbol::NT { + symbol: "S".to_string(), + index: -1, + }, + vec![ + Symbol::NT { + symbol: "S".to_string(), + index: 0, + }, + Symbol::NT { + symbol: "X".to_string(), + index: 1, + }, + ], + vec![ + Symbol::NT { + symbol: "S".to_string(), + index: 0, + }, + Symbol::NT { + symbol: "X".to_string(), + index: 1, + }, + ], + vec![0, 1], + SparseVector::new(), + )); + self.start_nt.push(idx); + self.start_nt_by_symbol + .entry("S".to_string()) + .or_default() + .push(idx); + } + } + + pub fn add_pass_through_rules(&mut self, words: &[String]) { + for word in words { + let idx = self.rules.len(); + self.rules.push(Rule::new( + Symbol::NT { + symbol: "X".to_string(), + index: -1, + }, + vec![Symbol::T(word.clone())], + vec![Symbol::T(word.clone())], + vec![], + SparseVector::new(), + )); + self.flat.push(idx); + self.flat_by_first_word + .entry(word.clone()) + .or_default() + .push(idx); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_symbol_parse_t() { + let s = Symbol::parse("hello"); + assert_eq!(s, Symbol::T("hello".to_string())); + } + + #[test] + fn test_symbol_parse_nt() { + let s = Symbol::parse("[NP,1]"); + assert_eq!( + s, + Symbol::NT { + symbol: "NP".to_string(), + index: 0 + } + ); + } + + #[test] + fn test_symbol_parse_nt_no_index() { + let s = Symbol::parse("[S]"); + assert_eq!( + s, + Symbol::NT { + symbol: "S".to_string(), + index: -1 + } + ); + } + + #[test] + fn test_rule_from_str() { + let r = Rule::from_str("[NP] ||| ein [NN,1] ||| a [NN,1] ||| logp=0 use_a=1.0"); + assert_eq!(r.lhs.nt_symbol(), "NP"); + assert_eq!(r.rhs.len(), 2); + assert_eq!(r.target.len(), 2); + assert_eq!(r.map, vec![0]); + assert_eq!(r.arity, 1); + assert_eq!(*r.f.map.get("use_a").unwrap(), 1.0); + } + + #[test] + fn test_rule_display() { + let r = Rule::from_str("[S] ||| [NP,1] [VP,2] ||| [NP,1] [VP,2] ||| logp=0"); + assert_eq!(r.to_string(), "[S] ||| [NP,1] [VP,2] ||| [NP,1] [VP,2]"); + } +} diff --git a/rs/src/hypergraph.rs b/rs/src/hypergraph.rs new file mode 100644 index 0000000..90069b0 --- /dev/null +++ b/rs/src/hypergraph.rs @@ -0,0 +1,115 @@ +use crate::grammar::Rule; +use crate::sparse_vector::SparseVector; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct NodeId(pub usize); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct EdgeId(pub usize); + +#[derive(Debug, Clone)] +pub struct Node { + pub id: i64, + pub symbol: String, + pub left: i32, + pub right: i32, + pub outgoing: Vec, + pub incoming: Vec, + pub score: f64, +} + +impl Node { + pub fn new(id: i64, symbol: &str, left: i32, right: i32) -> Self { + Self { + id, + symbol: symbol.to_string(), + left, + right, + outgoing: Vec::new(), + incoming: Vec::new(), + score: 0.0, + } + } +} + +#[derive(Debug, Clone)] +pub struct Edge { + pub head: NodeId, + pub tails: Vec, + pub score: f64, + pub f: SparseVector, + pub mark: usize, + pub rule: Rule, +} + +impl Edge { + pub fn arity(&self) -> usize { + self.tails.len() + } + + pub fn marked(&self) -> bool { + self.arity() == self.mark + } +} + +#[derive(Debug)] +pub struct Hypergraph { + pub nodes: Vec, + pub edges: Vec, + pub nodes_by_id: std::collections::HashMap, +} + +impl Hypergraph { + pub fn new() -> Self { + Self { + nodes: Vec::new(), + edges: Vec::new(), + nodes_by_id: std::collections::HashMap::new(), + } + } + + pub fn add_node(&mut self, id: i64, symbol: &str, left: i32, right: i32) -> NodeId { + let nid = NodeId(self.nodes.len()); + self.nodes.push(Node::new(id, symbol, left, right)); + self.nodes_by_id.insert(id, nid); + nid + } + + pub fn add_edge( + &mut self, + head: NodeId, + tails: Vec, + score: f64, + f: SparseVector, + rule: Rule, + ) -> EdgeId { + let eid = EdgeId(self.edges.len()); + self.edges.push(Edge { + head, + tails: tails.clone(), + score, + f, + mark: 0, + rule, + }); + self.nodes[head.0].incoming.push(eid); + for &t in &tails { + self.nodes[t.0].outgoing.push(eid); + } + eid + } + + pub fn node(&self, nid: NodeId) -> &Node { + &self.nodes[nid.0] + } + + pub fn edge(&self, eid: EdgeId) -> &Edge { + &self.edges[eid.0] + } + + pub fn reset(&mut self) { + for e in &mut self.edges { + e.mark = 0; + } + } +} diff --git a/rs/src/hypergraph_algos.rs b/rs/src/hypergraph_algos.rs new file mode 100644 index 0000000..a4878b6 --- /dev/null +++ b/rs/src/hypergraph_algos.rs @@ -0,0 +1,215 @@ +use std::collections::{BTreeSet, HashSet, VecDeque}; + +use crate::grammar::Symbol; +use crate::hypergraph::{EdgeId, Hypergraph, NodeId}; +use crate::semiring::Semiring; + +pub fn topological_sort(hg: &mut Hypergraph) -> Vec { + let mut sorted = Vec::new(); + let mut queue: VecDeque = VecDeque::new(); + + // Start with nodes that have no incoming edges + for (i, node) in hg.nodes.iter().enumerate() { + if node.incoming.is_empty() { + queue.push_back(NodeId(i)); + } + } + + while let Some(nid) = queue.pop_front() { + sorted.push(nid); + let outgoing = hg.nodes[nid.0].outgoing.clone(); + for &eid in &outgoing { + let edge = &mut hg.edges[eid.0]; + if edge.marked() { + continue; + } + edge.mark += 1; + if edge.marked() { + let head = edge.head; + // Check if all incoming edges of head are marked + let all_marked = hg.nodes[head.0] + .incoming + .iter() + .all(|&ie| hg.edges[ie.0].marked()); + if all_marked { + queue.push_back(head); + } + } + } + } + + sorted +} + +pub fn viterbi_path(hg: &mut Hypergraph) -> (Vec, f64) { + let toposorted = topological_sort(hg); + + // Init + for node in &mut hg.nodes { + node.score = S::null(); + } + if let Some(&first) = toposorted.first() { + hg.nodes[first.0].score = S::one(); + } + + let mut best_path: Vec = Vec::new(); + + for &nid in &toposorted { + let incoming = hg.nodes[nid.0].incoming.clone(); + let mut best_edge: Option = None; + for &eid in &incoming { + let edge = &hg.edges[eid.0]; + let mut s = S::one(); + for &tid in &edge.tails { + s = S::multiply(s, hg.nodes[tid.0].score); + } + let candidate = S::multiply(s, edge.score); + if hg.nodes[nid.0].score < candidate { + best_edge = Some(eid); + } + hg.nodes[nid.0].score = S::add(hg.nodes[nid.0].score, candidate); + } + if let Some(e) = best_edge { + best_path.push(e); + } + } + + let final_score = toposorted + .last() + .map(|&nid| hg.nodes[nid.0].score) + .unwrap_or(0.0); + + (best_path, final_score) +} + +pub fn derive(hg: &Hypergraph, path: &[EdgeId], cur: NodeId, carry: &mut Vec) { + // Find edge in path whose head matches cur + let node = &hg.nodes[cur.0]; + let edge_id = path + .iter() + .find(|&&eid| { + let e = &hg.edges[eid.0]; + let h = &hg.nodes[e.head.0]; + h.symbol == node.symbol && h.left == node.left && h.right == node.right + }) + .expect("derive: no matching edge found"); + + let edge = &hg.edges[edge_id.0]; + for sym in &edge.rule.target { + match sym { + Symbol::NT { index, .. } => { + // Find which tail to recurse into using map + let tail_idx = edge + .rule + .map + .iter() + .position(|&m| m == *index) + .expect("derive: NT index not found in map"); + derive(hg, path, edge.tails[tail_idx], carry); + } + Symbol::T(word) => { + carry.push(word.clone()); + } + } + } +} + +pub fn all_paths(hg: &mut Hypergraph) -> Vec> { + let toposorted = topological_sort(hg); + + let mut paths: Vec> = vec![vec![]]; + + for &nid in &toposorted { + let incoming = hg.nodes[nid.0].incoming.clone(); + if incoming.is_empty() { + continue; + } + let mut new_paths = Vec::new(); + while let Some(p) = paths.pop() { + for &eid in &incoming { + let mut np = p.clone(); + np.push(eid); + new_paths.push(np); + } + } + paths = new_paths; + } + + // Dedup by reachable edge set + let mut seen: HashSet> = HashSet::new(); + paths + .into_iter() + .filter(|p| { + if p.is_empty() { + return false; + } + let mut reachable = BTreeSet::new(); + mark_reachable(hg, p, *p.last().unwrap(), &mut reachable); + let key: Vec = reachable.into_iter().map(|eid| eid.0).collect(); + seen.insert(key) + }) + .collect() +} + +fn mark_reachable( + hg: &Hypergraph, + path: &[EdgeId], + edge_id: EdgeId, + used: &mut BTreeSet, +) { + used.insert(edge_id); + let edge = &hg.edges[edge_id.0]; + for &tail_nid in &edge.tails { + // Find edge in path whose head is this tail node + if let Some(&child_eid) = path.iter().find(|&&eid| hg.edges[eid.0].head == tail_nid) { + if !used.contains(&child_eid) { + mark_reachable(hg, path, child_eid, used); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::hypergraph_io::read_hypergraph_from_json; + use crate::semiring::ViterbiSemiring; + + #[test] + fn test_viterbi_toy() { + let mut hg = read_hypergraph_from_json("../example/toy/toy.json", true).unwrap(); + let (path, score) = viterbi_path::(&mut hg); + let mut carry = Vec::new(); + let last_head = hg.edges[path.last().unwrap().0].head; + derive(&hg, &path, last_head, &mut carry); + + let translation = carry.join(" "); + let log_score = score.ln(); + + assert_eq!(translation, "i saw a small shell"); + assert!((log_score - (-0.5)).abs() < 1e-9); + } + + #[test] + fn test_all_paths_toy() { + let mut hg = read_hypergraph_from_json("../example/toy/toy.json", true).unwrap(); + + let paths = all_paths(&mut hg); + // The toy hypergraph should have multiple distinct paths + assert!(paths.len() > 1); + + // Collect all translations + let mut translations: Vec = Vec::new(); + for p in &paths { + let mut carry = Vec::new(); + let last_head = hg.edges[p.last().unwrap().0].head; + derive(&hg, p, last_head, &mut carry); + translations.push(carry.join(" ")); + } + + assert!(translations.contains(&"i saw a small shell".to_string())); + assert!(translations.contains(&"i saw a small house".to_string())); + assert!(translations.contains(&"i saw a little shell".to_string())); + assert!(translations.contains(&"i saw a little house".to_string())); + } +} diff --git a/rs/src/hypergraph_io.rs b/rs/src/hypergraph_io.rs new file mode 100644 index 0000000..5aa9f3b --- /dev/null +++ b/rs/src/hypergraph_io.rs @@ -0,0 +1,110 @@ +use std::fs; +use std::io; + +use crate::grammar::Rule; +use crate::hypergraph::{Hypergraph, NodeId}; +use crate::sparse_vector::SparseVector; + +pub fn read_hypergraph_from_json(path: &str, log_weights: bool) -> io::Result { + let content = fs::read_to_string(path)?; + let parsed: serde_json::Value = serde_json::from_str(&content) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + + let mut hg = Hypergraph::new(); + + let weights = match parsed.get("weights").and_then(|w| w.as_object()) { + Some(obj) => SparseVector::from_hash(obj), + None => SparseVector::new(), + }; + + // Add nodes + if let Some(nodes) = parsed.get("nodes").and_then(|n| n.as_array()) { + for n in nodes { + let id = n["id"].as_i64().unwrap(); + let cat = n["cat"].as_str().unwrap_or(""); + let span = n["span"].as_array().unwrap(); + let left = span[0].as_i64().unwrap() as i32; + let right = span[1].as_i64().unwrap() as i32; + hg.add_node(id, cat, left, right); + } + } + + // Add edges + if let Some(edges) = parsed.get("edges").and_then(|e| e.as_array()) { + for e in edges { + let head_id = e["head"].as_i64().unwrap(); + let head_nid = *hg.nodes_by_id.get(&head_id).unwrap(); + + let tails: Vec = e["tails"] + .as_array() + .unwrap() + .iter() + .map(|t| *hg.nodes_by_id.get(&t.as_i64().unwrap()).unwrap()) + .collect(); + + let f = match e.get("f").and_then(|f| f.as_object()) { + Some(obj) => SparseVector::from_hash(obj), + None => SparseVector::new(), + }; + + let rule_str = e["rule"].as_str().unwrap(); + let rule = Rule::from_str(rule_str); + + let score = if log_weights { + weights.dot(&f).exp() + } else { + weights.dot(&f) + }; + + hg.add_edge(head_nid, tails, score, f, rule); + } + } + + Ok(hg) +} + +pub fn write_hypergraph_to_json(hg: &Hypergraph, weights: &SparseVector) -> String { + let mut json_s = String::from("{\n"); + json_s.push_str(&format!( + "\"weights\":{},\n", + serde_json::to_string(&weights.to_json()).unwrap() + )); + + json_s.push_str("\"nodes\":\n[\n"); + let node_strs: Vec = hg + .nodes + .iter() + .map(|n| { + format!( + "{{ \"id\":{}, \"cat\":\"{}\", \"span\":[{},{}] }}", + n.id, + n.symbol.replace('"', "\\\""), + n.left, + n.right + ) + }) + .collect(); + json_s.push_str(&node_strs.join(",\n")); + json_s.push_str("\n],\n"); + + json_s.push_str("\"edges\":\n[\n"); + let edge_strs: Vec = hg + .edges + .iter() + .map(|e| { + let head_id = hg.nodes[e.head.0].id; + let tail_ids: Vec = e.tails.iter().map(|t| hg.nodes[t.0].id).collect(); + format!( + "{{ \"head\":{}, \"rule\":\"{}\", \"tails\":{:?}, \"f\":{} }}", + head_id, + e.rule.to_string().replace('"', "\\\""), + tail_ids, + serde_json::to_string(&e.f.to_json()).unwrap() + ) + }) + .collect(); + json_s.push_str(&edge_strs.join(",\n")); + json_s.push_str("\n]\n}\n"); + + json_s +} diff --git a/rs/src/lib.rs b/rs/src/lib.rs new file mode 100644 index 0000000..d446e19 --- /dev/null +++ b/rs/src/lib.rs @@ -0,0 +1,8 @@ +pub mod chart_to_hg; +pub mod grammar; +pub mod hypergraph; +pub mod hypergraph_algos; +pub mod hypergraph_io; +pub mod parse; +pub mod semiring; +pub mod sparse_vector; diff --git a/rs/src/main.rs b/rs/src/main.rs new file mode 100644 index 0000000..f03dbf5 --- /dev/null +++ b/rs/src/main.rs @@ -0,0 +1,110 @@ +use std::fs; +use std::io::{self, BufRead}; + +use clap::Parser; + +use odenwald::chart_to_hg::chart_to_hg; +use odenwald::grammar::Grammar; +use odenwald::hypergraph_algos::{derive, viterbi_path}; +use odenwald::parse::{self, Chart}; +use odenwald::semiring::ViterbiSemiring; +use odenwald::sparse_vector::SparseVector; + +#[derive(Parser)] +#[command(name = "odenwald")] +struct Cli { + /// Grammar file + #[arg(short = 'g', long)] + grammar: String, + + /// Weights file + #[arg(short = 'w', long)] + weights: String, + + /// Input file (default: stdin) + #[arg(short = 'i', long, default_value = "-")] + input: String, + + /// Add glue rules + #[arg(short = 'l', long)] + add_glue: bool, + + /// Add pass-through rules + #[arg(short = 'p', long)] + add_pass_through: bool, +} + +fn main() { + let cli = Cli::parse(); + + eprintln!("> reading grammar '{}'", cli.grammar); + let mut grammar = Grammar::from_file(&cli.grammar).expect("Failed to read grammar"); + + if cli.add_glue { + eprintln!(">> adding glue rules"); + grammar.add_glue_rules(); + } + + let weights_content = fs::read_to_string(&cli.weights).expect("Failed to read weights"); + let weights = SparseVector::from_kv(&weights_content, ' ', '\n'); + + let input_lines: Vec = if cli.input == "-" { + let stdin = io::stdin(); + stdin.lock().lines().map(|l| l.unwrap()).collect() + } else { + let content = fs::read_to_string(&cli.input).expect("Failed to read input"); + content.lines().map(|l| l.to_string()).collect() + }; + + for line in &input_lines { + let line = line.trim(); + if line.is_empty() { + continue; + } + let input: Vec = line.split_whitespace().map(|s| s.to_string()).collect(); + let n = input.len(); + + if cli.add_pass_through { + eprintln!(">> adding pass-through rules"); + grammar.add_pass_through_rules(&input); + } + + eprintln!("> initializing charts"); + let mut passive_chart = Chart::new(n); + let mut active_chart = Chart::new(n); + parse::init(&input, n, &mut active_chart, &mut passive_chart, &grammar); + + eprintln!("> parsing"); + parse::parse(&input, n, &mut active_chart, &mut passive_chart, &grammar); + + // Count passive chart items + let mut chart_items = 0; + let mut chart_cells = 0; + parse::visit(1, 0, n, 0, |i, j| { + let items = passive_chart.at(i, j).len(); + if items > 0 { + chart_cells += 1; + chart_items += items; + } + }); + eprintln!("> chart: {} items in {} cells", chart_items, chart_cells); + + let mut hg = chart_to_hg(&passive_chart, n, &weights, &grammar); + eprintln!("> hg: {} nodes, {} edges", hg.nodes.len(), hg.edges.len()); + + eprintln!("> viterbi"); + let (path, score) = viterbi_path::(&mut hg); + + if path.is_empty() { + eprintln!("WARNING: no parse found"); + println!("NO_PARSE ||| 0.0"); + continue; + } + + let last_head = hg.edges[path.last().unwrap().0].head; + let mut carry = Vec::new(); + derive(&hg, &path, last_head, &mut carry); + + println!("{} ||| {}", carry.join(" "), score.ln()); + } +} diff --git a/rs/src/parse.rs b/rs/src/parse.rs new file mode 100644 index 0000000..7f64a89 --- /dev/null +++ b/rs/src/parse.rs @@ -0,0 +1,347 @@ +use crate::grammar::{Grammar, Symbol}; + +#[derive(Debug, Clone)] +pub struct Span { + pub left: i32, + pub right: i32, +} + +#[derive(Debug, Clone)] +pub struct Item { + pub rule_idx: usize, + pub left: usize, + pub right: usize, + pub dot: usize, + pub tail_spans: Vec>, +} + +impl Item { + pub fn from_rule(rule_idx: usize, rhs: &[Symbol], left: usize, right: usize, dot: usize) -> Self { + let tail_spans = rhs + .iter() + .map(|sym| { + if sym.is_nt() { + Some(Span { + left: -1, + right: -1, + }) + } else { + None + } + }) + .collect(); + + Self { + rule_idx, + left, + right, + dot, + tail_spans, + } + } + + pub fn advance(&self, left: usize, right: usize, dot: usize) -> Self { + Self { + rule_idx: self.rule_idx, + left, + right, + dot, + tail_spans: self.tail_spans.clone(), + } + } +} + +pub struct Chart { + _n: usize, + m: Vec>>, + b: std::collections::HashSet<(u16, u16, u16)>, + /// Index: (symbol_id, left) -> sorted vec of right endpoints + spans_by_left: std::collections::HashMap<(u16, u16), Vec>, + sym_to_id: std::collections::HashMap, + next_sym_id: u16, +} + +impl Chart { + pub fn new(n: usize) -> Self { + let mut m = Vec::with_capacity(n + 1); + for _ in 0..=n { + let mut row = Vec::with_capacity(n + 1); + for _ in 0..=n { + row.push(Vec::new()); + } + m.push(row); + } + Self { + _n: n, + m, + b: std::collections::HashSet::new(), + spans_by_left: std::collections::HashMap::new(), + sym_to_id: std::collections::HashMap::new(), + next_sym_id: 0, + } + } + + fn sym_id(&mut self, symbol: &str) -> u16 { + if let Some(&id) = self.sym_to_id.get(symbol) { + id + } else { + let id = self.next_sym_id; + self.next_sym_id += 1; + self.sym_to_id.insert(symbol.to_string(), id); + id + } + } + + fn sym_id_lookup(&self, symbol: &str) -> Option { + self.sym_to_id.get(symbol).copied() + } + + pub fn at(&self, i: usize, j: usize) -> &Vec { + &self.m[i][j] + } + + pub fn at_mut(&mut self, i: usize, j: usize) -> &mut Vec { + &mut self.m[i][j] + } + + pub fn add(&mut self, item: Item, i: usize, j: usize, symbol: &str) { + let sid = self.sym_id(symbol); + if self.b.insert((i as u16, j as u16, sid)) { + self.spans_by_left + .entry((sid, i as u16)) + .or_default() + .push(j as u16); + } + self.m[i][j].push(item); + } + + pub fn has(&self, symbol: &str, i: usize, j: usize) -> bool { + if let Some(&sid) = self.sym_to_id.get(symbol) { + self.b.contains(&(i as u16, j as u16, sid)) + } else { + false + } + } + + /// Returns right endpoints where (symbol, left, right) exists + pub fn rights_for(&self, symbol: &str, left: usize) -> &[u16] { + static EMPTY: Vec = Vec::new(); + if let Some(sid) = self.sym_id_lookup(symbol) { + if let Some(rights) = self.spans_by_left.get(&(sid, left as u16)) { + return rights; + } + } + &EMPTY + } + + /// Check if any entry for (symbol, left, *) exists + pub fn has_any_at(&self, symbol: &str, left: usize) -> bool { + if let Some(sid) = self.sym_id_lookup(symbol) { + self.spans_by_left.contains_key(&(sid, left as u16)) + } else { + false + } + } +} + +/// Visit spans bottom-up: from span size `from` up. +/// Yields (left, right) pairs. +pub fn visit(from: usize, l: usize, r: usize, x: usize, mut f: F) { + for span in from..=(r - x) { + for k in l..=(r - span) { + f(k, k + span); + } + } +} + +pub fn init(input: &[String], n: usize, _active_chart: &mut Chart, passive_chart: &mut Chart, grammar: &Grammar) { + for i in 0..n { + if let Some(matching) = grammar.flat_by_first_word.get(&input[i]) { + for &fi in matching { + let rule = &grammar.rules[fi]; + let rhs_len = rule.rhs.len(); + if i + rhs_len > n { + continue; + } + let matches = rule + .rhs + .iter() + .enumerate() + .skip(1) + .all(|(k, s)| s.word() == input[i + k]); + if matches { + let item = Item::from_rule(fi, &rule.rhs, i, i + rhs_len, rhs_len); + passive_chart.add(item, i, i + rhs_len, rule.lhs.nt_symbol()); + } + } + } + } +} + +fn scan(item: &mut Item, input: &[String], limit: usize, grammar: &Grammar) -> bool { + let rhs = &grammar.rules[item.rule_idx].rhs; + while item.dot < rhs.len() && rhs[item.dot].is_t() { + if item.right == limit { + return false; + } + if rhs[item.dot].word() == input[item.right] { + item.dot += 1; + item.right += 1; + } else { + return false; + } + } + true +} + +pub fn parse( + input: &[String], + n: usize, + active_chart: &mut Chart, + passive_chart: &mut Chart, + grammar: &Grammar, +) { + visit(1, 0, n, 0, |i, j| { + // Try to apply rules starting with T + if let Some(matching) = grammar.start_t_by_word.get(&input[i]) { + for &ri in matching { + let mut new_item = Item::from_rule(ri, &grammar.rules[ri].rhs, i, i, 0); + if scan(&mut new_item, input, j, grammar) { + active_chart.at_mut(i, j).push(new_item); + } + } + } + + // Seed active chart with rules starting with NT + for (symbol, rule_indices) in &grammar.start_nt_by_symbol { + if !passive_chart.has_any_at(symbol, i) { + continue; + } + for &ri in rule_indices { + let rule = &grammar.rules[ri]; + if rule.rhs.len() > j - i { + continue; + } + active_chart + .at_mut(i, j) + .push(Item::from_rule(ri, &rule.rhs, i, i, 0)); + } + } + + // Parse + let mut new_symbols: Vec = Vec::new(); + let mut remaining_items: Vec = Vec::new(); + + while !active_chart.at(i, j).is_empty() { + let active_item = active_chart.at_mut(i, j).pop().unwrap(); + let mut advanced = false; + + let active_rhs = &grammar.rules[active_item.rule_idx].rhs; + if active_item.dot < active_rhs.len() && active_rhs[active_item.dot].is_nt() { + let wanted_symbol = active_rhs[active_item.dot].nt_symbol(); + let k = active_item.right; + + // Use index to directly get matching right endpoints + let rights: Vec = passive_chart + .rights_for(wanted_symbol, k) + .iter() + .filter(|&&r| (r as usize) <= j) + .copied() + .collect(); + + for l16 in rights { + let l = l16 as usize; + + let mut new_item = + active_item.advance(active_item.left, l, active_item.dot + 1); + new_item.tail_spans[new_item.dot - 1] = Some(Span { + left: k as i32, + right: l as i32, + }); + + let new_rhs = &grammar.rules[new_item.rule_idx].rhs; + if scan(&mut new_item, input, j, grammar) { + if new_item.dot == new_rhs.len() { + if new_item.left == i && new_item.right == j { + let sym_str = grammar.rules[new_item.rule_idx].lhs.nt_symbol(); + if !new_symbols.iter().any(|s| s == sym_str) { + new_symbols.push(sym_str.to_string()); + } + passive_chart.add(new_item, i, j, sym_str); + advanced = true; + } + } else if new_item.right + (new_rhs.len() - new_item.dot) <= j { + active_chart.at_mut(i, j).push(new_item); + advanced = true; + } + } + } + } + + if !advanced { + remaining_items.push(active_item); + } + } + + // Self-filling step + let mut si = 0; + while si < new_symbols.len() { + let s = new_symbols[si].clone(); + + // Try start_nt rules from the grammar for this new symbol. + // This handles rules that weren't seeded in step 2 because + // the symbol didn't exist in the passive chart at that point. + if let Some(rule_indices) = grammar.start_nt_by_symbol.get(&s) { + for &ri in rule_indices { + let rule = &grammar.rules[ri]; + if rule.rhs.len() > j - i { + continue; + } + let mut new_item = Item::from_rule(ri, &rule.rhs, i, i, 0); + new_item.dot = 1; + new_item.right = j; + new_item.tail_spans[0] = Some(Span { + left: i as i32, + right: j as i32, + }); + if new_item.dot == rule.rhs.len() { + let sym_str = rule.lhs.nt_symbol(); + if !new_symbols.iter().any(|ns| ns == sym_str) { + new_symbols.push(sym_str.to_string()); + } + passive_chart.add(new_item, i, j, sym_str); + } + } + } + + // Also check remaining active items (for rules that were + // seeded but couldn't advance during the parse loop) + for item in &remaining_items { + if item.dot != 0 { + continue; + } + let item_rhs = &grammar.rules[item.rule_idx].rhs; + if !item_rhs[item.dot].is_nt() { + continue; + } + if item_rhs[item.dot].nt_symbol() != s { + continue; + } + let mut new_item = item.advance(i, j, item.dot + 1); + new_item.tail_spans[new_item.dot - 1] = Some(Span { + left: i as i32, + right: j as i32, + }); + let new_rhs = &grammar.rules[new_item.rule_idx].rhs; + if new_item.dot == new_rhs.len() { + let sym_str = grammar.rules[new_item.rule_idx].lhs.nt_symbol(); + if !new_symbols.iter().any(|ns| ns == sym_str) { + new_symbols.push(sym_str.to_string()); + } + passive_chart.add(new_item, i, j, sym_str); + } + } + si += 1; + } + }); +} diff --git a/rs/src/semiring.rs b/rs/src/semiring.rs new file mode 100644 index 0000000..9a8fe3e --- /dev/null +++ b/rs/src/semiring.rs @@ -0,0 +1,39 @@ +pub trait Semiring { + fn one() -> f64; + fn null() -> f64; + fn add(a: f64, b: f64) -> f64; + fn multiply(a: f64, b: f64) -> f64; +} + +pub struct ViterbiSemiring; + +impl Semiring for ViterbiSemiring { + fn one() -> f64 { + 1.0 + } + + fn null() -> f64 { + 0.0 + } + + fn add(a: f64, b: f64) -> f64 { + a.max(b) + } + + fn multiply(a: f64, b: f64) -> f64 { + a * b + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_viterbi() { + assert_eq!(ViterbiSemiring::one(), 1.0); + assert_eq!(ViterbiSemiring::null(), 0.0); + assert_eq!(ViterbiSemiring::add(0.3, 0.7), 0.7); + assert_eq!(ViterbiSemiring::multiply(0.5, 0.6), 0.3); + } +} diff --git a/rs/src/sparse_vector.rs b/rs/src/sparse_vector.rs new file mode 100644 index 0000000..4e62f95 --- /dev/null +++ b/rs/src/sparse_vector.rs @@ -0,0 +1,86 @@ +use std::collections::HashMap; + +#[derive(Debug, Clone, Default)] +pub struct SparseVector { + pub map: HashMap, +} + +impl SparseVector { + pub fn new() -> Self { + Self { + map: HashMap::new(), + } + } + + pub fn from_kv(s: &str, kv_sep: char, pair_sep: char) -> Self { + let mut map = HashMap::new(); + for pair in s.split(pair_sep) { + let pair = pair.trim(); + if pair.is_empty() { + continue; + } + if let Some((k, v)) = pair.split_once(kv_sep) { + if let Ok(val) = v.trim().parse::() { + map.insert(k.trim().to_string(), val); + } + } + } + Self { map } + } + + pub fn from_hash(h: &serde_json::Map) -> Self { + let mut map = HashMap::new(); + for (k, v) in h { + if let Some(val) = v.as_f64() { + map.insert(k.clone(), val); + } + } + Self { map } + } + + pub fn dot(&self, other: &SparseVector) -> f64 { + let mut sum = 0.0; + for (k, v) in &self.map { + if let Some(ov) = other.map.get(k) { + sum += v * ov; + } + } + sum + } + + pub fn to_json(&self) -> serde_json::Value { + let map: serde_json::Map = self + .map + .iter() + .map(|(k, v)| (k.clone(), serde_json::Value::from(*v))) + .collect(); + serde_json::Value::Object(map) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_from_kv() { + let sv = SparseVector::from_kv("logp 2\nuse_house 0\nuse_shell 1", ' ', '\n'); + assert_eq!(sv.map["logp"], 2.0); + assert_eq!(sv.map["use_house"], 0.0); + assert_eq!(sv.map["use_shell"], 1.0); + } + + #[test] + fn test_dot() { + let a = SparseVector::from_kv("x=1 y=2", '=', ' '); + let b = SparseVector::from_kv("x=3 y=4 z=5", '=', ' '); + assert_eq!(a.dot(&b), 11.0); + } + + #[test] + fn test_empty_dot() { + let a = SparseVector::new(); + let b = SparseVector::from_kv("x=1", '=', ' '); + assert_eq!(a.dot(&b), 0.0); + } +} -- cgit v1.2.3