diff options
| author | Patrick Simianer <patrick@lilt.com> | 2026-02-26 19:28:22 +0100 |
|---|---|---|
| committer | Patrick Simianer <patrick@lilt.com> | 2026-02-26 19:28:22 +0100 |
| commit | 0abcdd7e4358cb902c320b008d3c04bde07b749e (patch) | |
| tree | f26bd36cc16b792ef4acf5450ef9293b55179167 | /rs |
| parent | 4e62908a1757f83ff703399252ad50758c4eb237 (diff) | |
Add Rust implementation of SCFG decoder
Rust port of the Ruby prototype decoder with performance
optimizations for real Hiero-style grammars:
- Rule indexing by first terminal/NT symbol for fast lookup
- Chart symbol interning (u16 IDs) instead of string hashing
- Passive chart index by (symbol, left) for direct right-endpoint lookup
- Items store rule index instead of cloned rule data
Includes CKY+ parser, chart-to-hypergraph conversion, Viterbi
decoding, derivation extraction, and JSON hypergraph I/O.
Self-filling step in parse uses grammar lookup (not just
remaining active items) to handle rules that were consumed
during the parse loop or skipped by the has_any_at optimization.
Produces identical output to the Ruby prototype on all test examples.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'rs')
| -rw-r--r-- | rs/.gitignore | 1 |
| -rw-r--r-- | rs/Cargo.lock | 249 |
| -rw-r--r-- | rs/Cargo.toml | 9 |
| -rw-r--r-- | rs/README.md | 39 |
| -rw-r--r-- | rs/src/chart_to_hg.rs | 95 |
| -rw-r--r-- | rs/src/grammar.rs | 393 |
| -rw-r--r-- | rs/src/hypergraph.rs | 115 |
| -rw-r--r-- | rs/src/hypergraph_algos.rs | 215 |
| -rw-r--r-- | rs/src/hypergraph_io.rs | 110 |
| -rw-r--r-- | rs/src/lib.rs | 8 |
| -rw-r--r-- | rs/src/main.rs | 110 |
| -rw-r--r-- | rs/src/parse.rs | 347 |
| -rw-r--r-- | rs/src/semiring.rs | 39 |
| -rw-r--r-- | rs/src/sparse_vector.rs | 86 |
14 files changed, 1816 insertions, 0 deletions
diff --git a/rs/.gitignore b/rs/.gitignore new file mode 100644 index 0000000..ea8c4bf --- /dev/null +++ b/rs/.gitignore @@ -0,0 +1 @@ +/target diff --git a/rs/Cargo.lock b/rs/Cargo.lock new file mode 100644 index 0000000..a113e8a --- /dev/null +++ b/rs/Cargo.lock @@ -0,0 +1,249 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "anstream" +version = "0.6.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" + +[[package]] +name = "anstyle-parse" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" +dependencies = [ + "anstyle", + "once_cell_polyfill", + "windows-sys", +] + +[[package]] +name = "clap" +version = "4.5.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2797f34da339ce31042b27d23607e051786132987f595b02ba4f6a6dffb7030a" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = 
"4.5.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24a241312cea5059b13574bb9b3861cabf758b879c15190b37b6d6fd63ab6876" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.5.55" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a92793da1a46a5f2a02a6f4c46c6496b28c43638adea8306fcb0caa1634f24e5" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "clap_lex" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a822ea5bc7590f9d40f1ba12c0dc3c2760f3482c6984db1573ad11031420831" + +[[package]] +name = "colorchoice" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "is_terminal_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" + +[[package]] +name = "itoa" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "odenwald" +version = "0.1.0" +dependencies = [ + "clap", + "serde", + "serde_json", +] + +[[package]] +name = "once_cell_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b2ebcf727b7760c461f091f9f0f539b77b8e87f2fd88131e7f1b433b3cece4" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + 
"proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" diff --git a/rs/Cargo.toml b/rs/Cargo.toml new file mode 100644 index 0000000..e37f06f --- /dev/null +++ b/rs/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "odenwald" +version = "0.1.0" +edition = "2021" + +[dependencies] +clap = { version = "4", features = ["derive"] } +serde = { version = "1", features = ["derive"] } +serde_json = "1" diff --git a/rs/README.md b/rs/README.md new file mode 100644 index 0000000..5daa458 --- /dev/null +++ b/rs/README.md @@ -0,0 +1,39 @@ +# odenwald + +Rust implementation of the Odenwald SCFG (synchronous context-free grammar) machine translation decoder. 
+ +## Build + +``` +cargo build --release +``` + +## Usage + +``` +odenwald -g <grammar> -w <weights> [-i <input>] [-l] [-p] +``` + +- `-g, --grammar` — grammar file (required) +- `-w, --weights` — weights file (required) +- `-i, --input` — input file (default: stdin) +- `-l, --add-glue` — add glue rules +- `-p, --add-pass-through` — add pass-through rules + +Output: `translation ||| log_score` per input line. + +## Examples + +``` +cargo run -- -g ../example/toy/grammar -w ../example/toy/weights.toy -i ../example/toy/in -l +# → i saw a small shell ||| -0.5 + +cargo run -- -g ../example/toy-reorder/grammar -w ../example/toy-reorder/weights.toy -i ../example/toy-reorder/in -l +# → he reads the book ||| -1.5 +``` + +## Tests + +``` +cargo test +``` diff --git a/rs/src/chart_to_hg.rs b/rs/src/chart_to_hg.rs new file mode 100644 index 0000000..95f642e --- /dev/null +++ b/rs/src/chart_to_hg.rs @@ -0,0 +1,95 @@ +use std::collections::HashMap; + +use crate::grammar::{Grammar, Rule, Symbol}; +use crate::hypergraph::{Hypergraph, NodeId}; +use crate::parse::{Chart, visit}; +use crate::sparse_vector::SparseVector; + +pub fn chart_to_hg(chart: &Chart, n: usize, weights: &SparseVector, grammar: &Grammar) -> Hypergraph { + let mut hg = Hypergraph::new(); + let mut seen: HashMap<String, NodeId> = HashMap::new(); + + // Add root node with id=-1 + let root_nid = hg.add_node(-1, "root", -1, -1); + + let mut next_id: i64 = 0; + + // First pass: create nodes + visit(1, 0, n, 0, |i, j| { + for item in chart.at(i, j) { + let lhs_sym = grammar.rules[item.rule_idx].lhs.nt_symbol(); + let key = format!("{},{},{}", lhs_sym, i, j); + if !seen.contains_key(&key) { + let nid = hg.add_node(next_id, lhs_sym, i as i32, j as i32); + seen.insert(key, nid); + next_id += 1; + } + } + }); + + // Second pass: create edges + visit(1, 0, n, 0, |i, j| { + for item in chart.at(i, j) { + let rule = &grammar.rules[item.rule_idx]; + let head_key = format!("{},{},{}", rule.lhs.nt_symbol(), i, j); + let 
head_nid = *seen.get(&head_key).unwrap(); + + // Build tails + let tails: Vec<NodeId> = if item.tail_spans.iter().all(|s| s.is_none()) + || item.tail_spans.iter().all(|s| { + s.as_ref() + .map_or(true, |sp| sp.left == -1 && sp.right == -1) + }) + { + if rule.rhs.iter().any(|s| s.is_nt()) { + build_tails(&rule.rhs, &item.tail_spans, &seen, root_nid) + } else { + vec![root_nid] + } + } else { + build_tails(&rule.rhs, &item.tail_spans, &seen, root_nid) + }; + + let score = weights.dot(&rule.f).exp(); + let rule_clone = Rule::new( + rule.lhs.clone(), + rule.rhs.clone(), + rule.target.clone(), + rule.map.clone(), + rule.f.clone(), + ); + + hg.add_edge(head_nid, tails, score, rule.f.clone(), rule_clone); + } + }); + + hg +} + +fn build_tails( + rhs: &[Symbol], + tail_spans: &[Option<crate::parse::Span>], + seen: &HashMap<String, NodeId>, + root_nid: NodeId, +) -> Vec<NodeId> { + let mut tails = Vec::new(); + let mut has_nt = false; + for (idx, sym) in rhs.iter().enumerate() { + if let Symbol::NT { symbol, .. 
} = sym { + has_nt = true; + if let Some(Some(span)) = tail_spans.get(idx) { + if span.left >= 0 && span.right >= 0 { + let key = format!("{},{},{}", symbol, span.left, span.right); + if let Some(&nid) = seen.get(&key) { + tails.push(nid); + } + } + } + } + } + if !has_nt { + vec![root_nid] + } else { + tails + } +} diff --git a/rs/src/grammar.rs b/rs/src/grammar.rs new file mode 100644 index 0000000..cef49ee --- /dev/null +++ b/rs/src/grammar.rs @@ -0,0 +1,393 @@ +use std::fmt; +use std::fs; +use std::io::{self, BufRead}; + +use crate::sparse_vector::SparseVector; + +#[derive(Debug, Clone, PartialEq)] +pub enum Symbol { + T(String), + NT { symbol: String, index: i32 }, +} + +impl Symbol { + pub fn parse(s: &str) -> Symbol { + let s = s.trim(); + if s.starts_with('[') && s.ends_with(']') { + let inner = &s[1..s.len() - 1]; + if let Some((sym, idx)) = inner.split_once(',') { + Symbol::NT { + symbol: sym.trim().to_string(), + index: idx.trim().parse::<i32>().unwrap() - 1, + } + } else { + Symbol::NT { + symbol: inner.trim().to_string(), + index: -1, + } + } + } else { + Symbol::T(s.to_string()) + } + } + + pub fn is_nt(&self) -> bool { + matches!(self, Symbol::NT { .. }) + } + + pub fn is_t(&self) -> bool { + matches!(self, Symbol::T(_)) + } + + pub fn word(&self) -> &str { + match self { + Symbol::T(w) => w, + Symbol::NT { .. } => panic!("word() called on NT"), + } + } + + pub fn nt_symbol(&self) -> &str { + match self { + Symbol::NT { symbol, .. } => symbol, + Symbol::T(_) => panic!("nt_symbol() called on T"), + } + } + + pub fn nt_index(&self) -> i32 { + match self { + Symbol::NT { index, .. } => *index, + Symbol::T(_) => panic!("nt_index() called on T"), + } + } +} + +impl fmt::Display for Symbol { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Symbol::T(w) => write!(f, "{}", w), + Symbol::NT { symbol, index } if *index >= 0 => { + write!(f, "[{},{}]", symbol, index + 1) + } + Symbol::NT { symbol, .. 
} => write!(f, "[{}]", symbol), + } + } +} + +fn splitpipe(s: &str, n: usize) -> Vec<&str> { + let mut parts = Vec::new(); + let mut rest = s; + for _ in 0..n { + if let Some(pos) = rest.find("|||") { + parts.push(rest[..pos].trim()); + rest = rest[pos + 3..].trim(); + } else { + break; + } + } + parts.push(rest.trim()); + parts +} + +#[derive(Debug, Clone)] +pub struct Rule { + pub lhs: Symbol, + pub rhs: Vec<Symbol>, + pub target: Vec<Symbol>, + pub map: Vec<i32>, + pub f: SparseVector, + pub arity: usize, +} + +impl Rule { + pub fn new( + lhs: Symbol, + rhs: Vec<Symbol>, + target: Vec<Symbol>, + map: Vec<i32>, + f: SparseVector, + ) -> Self { + let arity = rhs.iter().filter(|s| s.is_nt()).count(); + Self { + lhs, + rhs, + target, + map, + f, + arity, + } + } + + pub fn from_str(s: &str) -> Self { + let parts = splitpipe(s, 3); + let lhs = Symbol::parse(parts[0]); + + let mut map = Vec::new(); + let mut arity = 0; + let rhs: Vec<Symbol> = parts[1] + .split_whitespace() + .map(|tok| { + let sym = Symbol::parse(tok); + if let Symbol::NT { index, .. 
} = &sym { + map.push(*index); + arity += 1; + } + sym + }) + .collect(); + + let target: Vec<Symbol> = parts[2] + .split_whitespace() + .map(|tok| Symbol::parse(tok)) + .collect(); + + let f = if parts.len() > 3 && !parts[3].is_empty() { + SparseVector::from_kv(parts[3], '=', ' ') + } else { + SparseVector::new() + }; + + Self { + lhs, + rhs, + target, + map, + f, + arity, + } + } +} + +impl fmt::Display for Rule { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{} ||| {} ||| {}", + self.lhs, + self.rhs + .iter() + .map(|s| s.to_string()) + .collect::<Vec<_>>() + .join(" "), + self.target + .iter() + .map(|s| s.to_string()) + .collect::<Vec<_>>() + .join(" "), + ) + } +} + +pub struct Grammar { + pub rules: Vec<Rule>, + pub start_nt: Vec<usize>, + pub start_t: Vec<usize>, + pub flat: Vec<usize>, + pub start_t_by_word: std::collections::HashMap<String, Vec<usize>>, + pub flat_by_first_word: std::collections::HashMap<String, Vec<usize>>, + pub start_nt_by_symbol: std::collections::HashMap<String, Vec<usize>>, +} + +impl Grammar { + pub fn from_file(path: &str) -> io::Result<Self> { + let file = fs::File::open(path)?; + let reader = io::BufReader::new(file); + let mut rules = Vec::new(); + let mut start_nt = Vec::new(); + let mut start_t = Vec::new(); + let mut flat = Vec::new(); + let mut start_t_by_word: std::collections::HashMap<String, Vec<usize>> = + std::collections::HashMap::new(); + let mut flat_by_first_word: std::collections::HashMap<String, Vec<usize>> = + std::collections::HashMap::new(); + let mut start_nt_by_symbol: std::collections::HashMap<String, Vec<usize>> = + std::collections::HashMap::new(); + + for line in reader.lines() { + let line = line?; + let line = line.trim().to_string(); + if line.is_empty() { + continue; + } + let idx = rules.len(); + let rule = Rule::from_str(&line); + if rule.rhs.first().map_or(false, |s| s.is_nt()) { + start_nt_by_symbol + .entry(rule.rhs[0].nt_symbol().to_string()) + .or_default() + 
.push(idx); + start_nt.push(idx); + } else if rule.arity == 0 { + flat_by_first_word + .entry(rule.rhs[0].word().to_string()) + .or_default() + .push(idx); + flat.push(idx); + } else { + start_t_by_word + .entry(rule.rhs[0].word().to_string()) + .or_default() + .push(idx); + start_t.push(idx); + } + rules.push(rule); + } + + Ok(Self { + rules, + start_nt, + start_t, + flat, + start_t_by_word, + flat_by_first_word, + start_nt_by_symbol, + }) + } + + pub fn add_glue_rules(&mut self) { + let symbols: Vec<String> = self + .rules + .iter() + .map(|r| r.lhs.nt_symbol().to_string()) + .filter(|s| s != "S") + .collect::<std::collections::HashSet<_>>() + .into_iter() + .collect(); + + let mut once = false; + for symbol in symbols { + let idx = self.rules.len(); + self.rules.push(Rule::new( + Symbol::NT { + symbol: "S".to_string(), + index: -1, + }, + vec![Symbol::NT { + symbol: symbol.clone(), + index: 0, + }], + vec![Symbol::NT { + symbol: symbol.clone(), + index: 0, + }], + vec![0], + SparseVector::new(), + )); + self.start_nt.push(idx); + self.start_nt_by_symbol + .entry(symbol) + .or_default() + .push(idx); + once = true; + } + + if once { + let idx = self.rules.len(); + self.rules.push(Rule::new( + Symbol::NT { + symbol: "S".to_string(), + index: -1, + }, + vec![ + Symbol::NT { + symbol: "S".to_string(), + index: 0, + }, + Symbol::NT { + symbol: "X".to_string(), + index: 1, + }, + ], + vec![ + Symbol::NT { + symbol: "S".to_string(), + index: 0, + }, + Symbol::NT { + symbol: "X".to_string(), + index: 1, + }, + ], + vec![0, 1], + SparseVector::new(), + )); + self.start_nt.push(idx); + self.start_nt_by_symbol + .entry("S".to_string()) + .or_default() + .push(idx); + } + } + + pub fn add_pass_through_rules(&mut self, words: &[String]) { + for word in words { + let idx = self.rules.len(); + self.rules.push(Rule::new( + Symbol::NT { + symbol: "X".to_string(), + index: -1, + }, + vec![Symbol::T(word.clone())], + vec![Symbol::T(word.clone())], + vec![], + SparseVector::new(), 
+ )); + self.flat.push(idx); + self.flat_by_first_word + .entry(word.clone()) + .or_default() + .push(idx); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_symbol_parse_t() { + let s = Symbol::parse("hello"); + assert_eq!(s, Symbol::T("hello".to_string())); + } + + #[test] + fn test_symbol_parse_nt() { + let s = Symbol::parse("[NP,1]"); + assert_eq!( + s, + Symbol::NT { + symbol: "NP".to_string(), + index: 0 + } + ); + } + + #[test] + fn test_symbol_parse_nt_no_index() { + let s = Symbol::parse("[S]"); + assert_eq!( + s, + Symbol::NT { + symbol: "S".to_string(), + index: -1 + } + ); + } + + #[test] + fn test_rule_from_str() { + let r = Rule::from_str("[NP] ||| ein [NN,1] ||| a [NN,1] ||| logp=0 use_a=1.0"); + assert_eq!(r.lhs.nt_symbol(), "NP"); + assert_eq!(r.rhs.len(), 2); + assert_eq!(r.target.len(), 2); + assert_eq!(r.map, vec![0]); + assert_eq!(r.arity, 1); + assert_eq!(*r.f.map.get("use_a").unwrap(), 1.0); + } + + #[test] + fn test_rule_display() { + let r = Rule::from_str("[S] ||| [NP,1] [VP,2] ||| [NP,1] [VP,2] ||| logp=0"); + assert_eq!(r.to_string(), "[S] ||| [NP,1] [VP,2] ||| [NP,1] [VP,2]"); + } +} diff --git a/rs/src/hypergraph.rs b/rs/src/hypergraph.rs new file mode 100644 index 0000000..90069b0 --- /dev/null +++ b/rs/src/hypergraph.rs @@ -0,0 +1,115 @@ +use crate::grammar::Rule; +use crate::sparse_vector::SparseVector; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct NodeId(pub usize); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct EdgeId(pub usize); + +#[derive(Debug, Clone)] +pub struct Node { + pub id: i64, + pub symbol: String, + pub left: i32, + pub right: i32, + pub outgoing: Vec<EdgeId>, + pub incoming: Vec<EdgeId>, + pub score: f64, +} + +impl Node { + pub fn new(id: i64, symbol: &str, left: i32, right: i32) -> Self { + Self { + id, + symbol: symbol.to_string(), + left, + right, + outgoing: Vec::new(), + incoming: Vec::new(), + score: 0.0, + } 
+ } +} + +#[derive(Debug, Clone)] +pub struct Edge { + pub head: NodeId, + pub tails: Vec<NodeId>, + pub score: f64, + pub f: SparseVector, + pub mark: usize, + pub rule: Rule, +} + +impl Edge { + pub fn arity(&self) -> usize { + self.tails.len() + } + + pub fn marked(&self) -> bool { + self.arity() == self.mark + } +} + +#[derive(Debug)] +pub struct Hypergraph { + pub nodes: Vec<Node>, + pub edges: Vec<Edge>, + pub nodes_by_id: std::collections::HashMap<i64, NodeId>, +} + +impl Hypergraph { + pub fn new() -> Self { + Self { + nodes: Vec::new(), + edges: Vec::new(), + nodes_by_id: std::collections::HashMap::new(), + } + } + + pub fn add_node(&mut self, id: i64, symbol: &str, left: i32, right: i32) -> NodeId { + let nid = NodeId(self.nodes.len()); + self.nodes.push(Node::new(id, symbol, left, right)); + self.nodes_by_id.insert(id, nid); + nid + } + + pub fn add_edge( + &mut self, + head: NodeId, + tails: Vec<NodeId>, + score: f64, + f: SparseVector, + rule: Rule, + ) -> EdgeId { + let eid = EdgeId(self.edges.len()); + self.edges.push(Edge { + head, + tails: tails.clone(), + score, + f, + mark: 0, + rule, + }); + self.nodes[head.0].incoming.push(eid); + for &t in &tails { + self.nodes[t.0].outgoing.push(eid); + } + eid + } + + pub fn node(&self, nid: NodeId) -> &Node { + &self.nodes[nid.0] + } + + pub fn edge(&self, eid: EdgeId) -> &Edge { + &self.edges[eid.0] + } + + pub fn reset(&mut self) { + for e in &mut self.edges { + e.mark = 0; + } + } +} diff --git a/rs/src/hypergraph_algos.rs b/rs/src/hypergraph_algos.rs new file mode 100644 index 0000000..a4878b6 --- /dev/null +++ b/rs/src/hypergraph_algos.rs @@ -0,0 +1,215 @@ +use std::collections::{BTreeSet, HashSet, VecDeque}; + +use crate::grammar::Symbol; +use crate::hypergraph::{EdgeId, Hypergraph, NodeId}; +use crate::semiring::Semiring; + +pub fn topological_sort(hg: &mut Hypergraph) -> Vec<NodeId> { + let mut sorted = Vec::new(); + let mut queue: VecDeque<NodeId> = VecDeque::new(); + + // Start with nodes that 
have no incoming edges + for (i, node) in hg.nodes.iter().enumerate() { + if node.incoming.is_empty() { + queue.push_back(NodeId(i)); + } + } + + while let Some(nid) = queue.pop_front() { + sorted.push(nid); + let outgoing = hg.nodes[nid.0].outgoing.clone(); + for &eid in &outgoing { + let edge = &mut hg.edges[eid.0]; + if edge.marked() { + continue; + } + edge.mark += 1; + if edge.marked() { + let head = edge.head; + // Check if all incoming edges of head are marked + let all_marked = hg.nodes[head.0] + .incoming + .iter() + .all(|&ie| hg.edges[ie.0].marked()); + if all_marked { + queue.push_back(head); + } + } + } + } + + sorted +} + +pub fn viterbi_path<S: Semiring>(hg: &mut Hypergraph) -> (Vec<EdgeId>, f64) { + let toposorted = topological_sort(hg); + + // Init + for node in &mut hg.nodes { + node.score = S::null(); + } + if let Some(&first) = toposorted.first() { + hg.nodes[first.0].score = S::one(); + } + + let mut best_path: Vec<EdgeId> = Vec::new(); + + for &nid in &toposorted { + let incoming = hg.nodes[nid.0].incoming.clone(); + let mut best_edge: Option<EdgeId> = None; + for &eid in &incoming { + let edge = &hg.edges[eid.0]; + let mut s = S::one(); + for &tid in &edge.tails { + s = S::multiply(s, hg.nodes[tid.0].score); + } + let candidate = S::multiply(s, edge.score); + if hg.nodes[nid.0].score < candidate { + best_edge = Some(eid); + } + hg.nodes[nid.0].score = S::add(hg.nodes[nid.0].score, candidate); + } + if let Some(e) = best_edge { + best_path.push(e); + } + } + + let final_score = toposorted + .last() + .map(|&nid| hg.nodes[nid.0].score) + .unwrap_or(0.0); + + (best_path, final_score) +} + +pub fn derive(hg: &Hypergraph, path: &[EdgeId], cur: NodeId, carry: &mut Vec<String>) { + // Find edge in path whose head matches cur + let node = &hg.nodes[cur.0]; + let edge_id = path + .iter() + .find(|&&eid| { + let e = &hg.edges[eid.0]; + let h = &hg.nodes[e.head.0]; + h.symbol == node.symbol && h.left == node.left && h.right == node.right + }) + 
.expect("derive: no matching edge found"); + + let edge = &hg.edges[edge_id.0]; + for sym in &edge.rule.target { + match sym { + Symbol::NT { index, .. } => { + // Find which tail to recurse into using map + let tail_idx = edge + .rule + .map + .iter() + .position(|&m| m == *index) + .expect("derive: NT index not found in map"); + derive(hg, path, edge.tails[tail_idx], carry); + } + Symbol::T(word) => { + carry.push(word.clone()); + } + } + } +} + +pub fn all_paths(hg: &mut Hypergraph) -> Vec<Vec<EdgeId>> { + let toposorted = topological_sort(hg); + + let mut paths: Vec<Vec<EdgeId>> = vec![vec![]]; + + for &nid in &toposorted { + let incoming = hg.nodes[nid.0].incoming.clone(); + if incoming.is_empty() { + continue; + } + let mut new_paths = Vec::new(); + while let Some(p) = paths.pop() { + for &eid in &incoming { + let mut np = p.clone(); + np.push(eid); + new_paths.push(np); + } + } + paths = new_paths; + } + + // Dedup by reachable edge set + let mut seen: HashSet<Vec<usize>> = HashSet::new(); + paths + .into_iter() + .filter(|p| { + if p.is_empty() { + return false; + } + let mut reachable = BTreeSet::new(); + mark_reachable(hg, p, *p.last().unwrap(), &mut reachable); + let key: Vec<usize> = reachable.into_iter().map(|eid| eid.0).collect(); + seen.insert(key) + }) + .collect() +} + +fn mark_reachable( + hg: &Hypergraph, + path: &[EdgeId], + edge_id: EdgeId, + used: &mut BTreeSet<EdgeId>, +) { + used.insert(edge_id); + let edge = &hg.edges[edge_id.0]; + for &tail_nid in &edge.tails { + // Find edge in path whose head is this tail node + if let Some(&child_eid) = path.iter().find(|&&eid| hg.edges[eid.0].head == tail_nid) { + if !used.contains(&child_eid) { + mark_reachable(hg, path, child_eid, used); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::hypergraph_io::read_hypergraph_from_json; + use crate::semiring::ViterbiSemiring; + + #[test] + fn test_viterbi_toy() { + let mut hg = read_hypergraph_from_json("../example/toy/toy.json", 
true).unwrap(); + let (path, score) = viterbi_path::<ViterbiSemiring>(&mut hg); + let mut carry = Vec::new(); + let last_head = hg.edges[path.last().unwrap().0].head; + derive(&hg, &path, last_head, &mut carry); + + let translation = carry.join(" "); + let log_score = score.ln(); + + assert_eq!(translation, "i saw a small shell"); + assert!((log_score - (-0.5)).abs() < 1e-9); + } + + #[test] + fn test_all_paths_toy() { + let mut hg = read_hypergraph_from_json("../example/toy/toy.json", true).unwrap(); + + let paths = all_paths(&mut hg); + // The toy hypergraph should have multiple distinct paths + assert!(paths.len() > 1); + + // Collect all translations + let mut translations: Vec<String> = Vec::new(); + for p in &paths { + let mut carry = Vec::new(); + let last_head = hg.edges[p.last().unwrap().0].head; + derive(&hg, p, last_head, &mut carry); + translations.push(carry.join(" ")); + } + + assert!(translations.contains(&"i saw a small shell".to_string())); + assert!(translations.contains(&"i saw a small house".to_string())); + assert!(translations.contains(&"i saw a little shell".to_string())); + assert!(translations.contains(&"i saw a little house".to_string())); + } +} diff --git a/rs/src/hypergraph_io.rs b/rs/src/hypergraph_io.rs new file mode 100644 index 0000000..5aa9f3b --- /dev/null +++ b/rs/src/hypergraph_io.rs @@ -0,0 +1,110 @@ +use std::fs; +use std::io; + +use crate::grammar::Rule; +use crate::hypergraph::{Hypergraph, NodeId}; +use crate::sparse_vector::SparseVector; + +pub fn read_hypergraph_from_json(path: &str, log_weights: bool) -> io::Result<Hypergraph> { + let content = fs::read_to_string(path)?; + let parsed: serde_json::Value = serde_json::from_str(&content) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + + let mut hg = Hypergraph::new(); + + let weights = match parsed.get("weights").and_then(|w| w.as_object()) { + Some(obj) => SparseVector::from_hash(obj), + None => SparseVector::new(), + }; + + // Add nodes + if let 
Some(nodes) = parsed.get("nodes").and_then(|n| n.as_array()) { + for n in nodes { + let id = n["id"].as_i64().unwrap(); + let cat = n["cat"].as_str().unwrap_or(""); + let span = n["span"].as_array().unwrap(); + let left = span[0].as_i64().unwrap() as i32; + let right = span[1].as_i64().unwrap() as i32; + hg.add_node(id, cat, left, right); + } + } + + // Add edges + if let Some(edges) = parsed.get("edges").and_then(|e| e.as_array()) { + for e in edges { + let head_id = e["head"].as_i64().unwrap(); + let head_nid = *hg.nodes_by_id.get(&head_id).unwrap(); + + let tails: Vec<NodeId> = e["tails"] + .as_array() + .unwrap() + .iter() + .map(|t| *hg.nodes_by_id.get(&t.as_i64().unwrap()).unwrap()) + .collect(); + + let f = match e.get("f").and_then(|f| f.as_object()) { + Some(obj) => SparseVector::from_hash(obj), + None => SparseVector::new(), + }; + + let rule_str = e["rule"].as_str().unwrap(); + let rule = Rule::from_str(rule_str); + + let score = if log_weights { + weights.dot(&f).exp() + } else { + weights.dot(&f) + }; + + hg.add_edge(head_nid, tails, score, f, rule); + } + } + + Ok(hg) +} + +pub fn write_hypergraph_to_json(hg: &Hypergraph, weights: &SparseVector) -> String { + let mut json_s = String::from("{\n"); + json_s.push_str(&format!( + "\"weights\":{},\n", + serde_json::to_string(&weights.to_json()).unwrap() + )); + + json_s.push_str("\"nodes\":\n[\n"); + let node_strs: Vec<String> = hg + .nodes + .iter() + .map(|n| { + format!( + "{{ \"id\":{}, \"cat\":\"{}\", \"span\":[{},{}] }}", + n.id, + n.symbol.replace('"', "\\\""), + n.left, + n.right + ) + }) + .collect(); + json_s.push_str(&node_strs.join(",\n")); + json_s.push_str("\n],\n"); + + json_s.push_str("\"edges\":\n[\n"); + let edge_strs: Vec<String> = hg + .edges + .iter() + .map(|e| { + let head_id = hg.nodes[e.head.0].id; + let tail_ids: Vec<i64> = e.tails.iter().map(|t| hg.nodes[t.0].id).collect(); + format!( + "{{ \"head\":{}, \"rule\":\"{}\", \"tails\":{:?}, \"f\":{} }}", + head_id, + 
e.rule.to_string().replace('"', "\\\""), + tail_ids, + serde_json::to_string(&e.f.to_json()).unwrap() + ) + }) + .collect(); + json_s.push_str(&edge_strs.join(",\n")); + json_s.push_str("\n]\n}\n"); + + json_s +} diff --git a/rs/src/lib.rs b/rs/src/lib.rs new file mode 100644 index 0000000..d446e19 --- /dev/null +++ b/rs/src/lib.rs @@ -0,0 +1,8 @@ +pub mod chart_to_hg; +pub mod grammar; +pub mod hypergraph; +pub mod hypergraph_algos; +pub mod hypergraph_io; +pub mod parse; +pub mod semiring; +pub mod sparse_vector; diff --git a/rs/src/main.rs b/rs/src/main.rs new file mode 100644 index 0000000..f03dbf5 --- /dev/null +++ b/rs/src/main.rs @@ -0,0 +1,110 @@ +use std::fs; +use std::io::{self, BufRead}; + +use clap::Parser; + +use odenwald::chart_to_hg::chart_to_hg; +use odenwald::grammar::Grammar; +use odenwald::hypergraph_algos::{derive, viterbi_path}; +use odenwald::parse::{self, Chart}; +use odenwald::semiring::ViterbiSemiring; +use odenwald::sparse_vector::SparseVector; + +#[derive(Parser)] +#[command(name = "odenwald")] +struct Cli { + /// Grammar file + #[arg(short = 'g', long)] + grammar: String, + + /// Weights file + #[arg(short = 'w', long)] + weights: String, + + /// Input file (default: stdin) + #[arg(short = 'i', long, default_value = "-")] + input: String, + + /// Add glue rules + #[arg(short = 'l', long)] + add_glue: bool, + + /// Add pass-through rules + #[arg(short = 'p', long)] + add_pass_through: bool, +} + +fn main() { + let cli = Cli::parse(); + + eprintln!("> reading grammar '{}'", cli.grammar); + let mut grammar = Grammar::from_file(&cli.grammar).expect("Failed to read grammar"); + + if cli.add_glue { + eprintln!(">> adding glue rules"); + grammar.add_glue_rules(); + } + + let weights_content = fs::read_to_string(&cli.weights).expect("Failed to read weights"); + let weights = SparseVector::from_kv(&weights_content, ' ', '\n'); + + let input_lines: Vec<String> = if cli.input == "-" { + let stdin = io::stdin(); + stdin.lock().lines().map(|l| 
l.unwrap()).collect()
    } else {
        let content = fs::read_to_string(&cli.input).expect("Failed to read input");
        content.lines().map(|l| l.to_string()).collect()
    };

    // Decode each input sentence independently; charts are rebuilt per line.
    for line in &input_lines {
        let line = line.trim();
        if line.is_empty() {
            continue;
        }
        let input: Vec<String> = line.split_whitespace().map(|s| s.to_string()).collect();
        let n = input.len();

        if cli.add_pass_through {
            eprintln!(">> adding pass-through rules");
            // NOTE(review): this mutates the shared grammar inside the per-line
            // loop, so pass-through rules accumulate across sentences — later
            // lines see rules for earlier lines' words. Confirm this matches
            // the Ruby prototype before changing.
            grammar.add_pass_through_rules(&input);
        }

        eprintln!("> initializing charts");
        let mut passive_chart = Chart::new(n);
        let mut active_chart = Chart::new(n);
        parse::init(&input, n, &mut active_chart, &mut passive_chart, &grammar);

        eprintln!("> parsing");
        parse::parse(&input, n, &mut active_chart, &mut passive_chart, &grammar);

        // Count passive chart items
        let mut chart_items = 0;
        let mut chart_cells = 0;
        parse::visit(1, 0, n, 0, |i, j| {
            let items = passive_chart.at(i, j).len();
            if items > 0 {
                chart_cells += 1;
                chart_items += items;
            }
        });
        eprintln!("> chart: {} items in {} cells", chart_items, chart_cells);

        // Convert the packed chart into a scored hypergraph, then decode.
        let mut hg = chart_to_hg(&passive_chart, n, &weights, &grammar);
        eprintln!("> hg: {} nodes, {} edges", hg.nodes.len(), hg.edges.len());

        eprintln!("> viterbi");
        let (path, score) = viterbi_path::<ViterbiSemiring>(&mut hg);

        if path.is_empty() {
            eprintln!("WARNING: no parse found");
            println!("NO_PARSE ||| 0.0");
            continue;
        }

        // The last edge on the Viterbi path is the goal edge; derive the
        // output string from its head node.
        let last_head = hg.edges[path.last().unwrap().0].head;
        let mut carry = Vec::new();
        derive(&hg, &path, last_head, &mut carry);

        // Scores are probability-like (Viterbi semiring); report log-score.
        println!("{} ||| {}", carry.join(" "), score.ln());
    }
}
diff --git a/rs/src/parse.rs b/rs/src/parse.rs
new file mode 100644
index 0000000..7f64a89
--- /dev/null
+++ b/rs/src/parse.rs
@@ -0,0 +1,347 @@
use crate::grammar::{Grammar, Symbol};

/// Half-open-ish span over input positions; (-1, -1) is the "not yet
/// filled" sentinel used for pending nonterminal tails (see Item::from_rule).
#[derive(Debug, Clone)]
pub struct Span {
    pub left: i32,
    pub right: i32,
}

/// A (possibly partial) CKY+ chart item. Stores an index into
/// Grammar::rules instead of cloning rule data (per-item dot position plus
/// the spans already bound to the rule's nonterminal tails).
#[derive(Debug, Clone)]
pub struct Item {
    pub rule_idx: usize,
    pub left:
usize,
    pub right: usize,
    pub dot: usize,
    // One slot per RHS symbol: Some(span) for nonterminals (sentinel -1/-1
    // until bound), None for terminals.
    pub tail_spans: Vec<Option<Span>>,
}

impl Item {
    /// Build a fresh item for `rule_idx` covering [left, right) with the dot
    /// at `dot`. Nonterminal tail slots start as unbound (-1, -1) sentinels.
    pub fn from_rule(rule_idx: usize, rhs: &[Symbol], left: usize, right: usize, dot: usize) -> Self {
        let tail_spans = rhs
            .iter()
            .map(|sym| {
                if sym.is_nt() {
                    Some(Span {
                        left: -1,
                        right: -1,
                    })
                } else {
                    None
                }
            })
            .collect();

        Self {
            rule_idx,
            left,
            right,
            dot,
            tail_spans,
        }
    }

    /// Copy this item with a new right boundary and dot position; tail spans
    /// are cloned so the caller can bind the newly consumed nonterminal.
    pub fn advance(&self, left: usize, right: usize, dot: usize) -> Self {
        Self {
            rule_idx: self.rule_idx,
            left,
            right,
            dot,
            tail_spans: self.tail_spans.clone(),
        }
    }
}

/// Triangular chart over spans [i, j), 0 <= i <= j <= n, with interned
/// symbol ids (u16) to avoid repeated string hashing.
pub struct Chart {
    _n: usize,
    // m[i][j]: items covering span [i, j).
    m: Vec<Vec<Vec<Item>>>,
    // Seen-set of (left, right, symbol_id) — guards the spans_by_left index.
    b: std::collections::HashSet<(u16, u16, u16)>,
    /// Index: (symbol_id, left) -> sorted vec of right endpoints
    spans_by_left: std::collections::HashMap<(u16, u16), Vec<u16>>,
    sym_to_id: std::collections::HashMap<String, u16>,
    next_sym_id: u16,
}

impl Chart {
    /// Allocate an empty (n+1) x (n+1) chart for an n-token input.
    pub fn new(n: usize) -> Self {
        let mut m = Vec::with_capacity(n + 1);
        for _ in 0..=n {
            let mut row = Vec::with_capacity(n + 1);
            for _ in 0..=n {
                row.push(Vec::new());
            }
            m.push(row);
        }
        Self {
            _n: n,
            m,
            b: std::collections::HashSet::new(),
            spans_by_left: std::collections::HashMap::new(),
            sym_to_id: std::collections::HashMap::new(),
            next_sym_id: 0,
        }
    }

    /// Intern `symbol`, allocating a fresh u16 id on first sight.
    fn sym_id(&mut self, symbol: &str) -> u16 {
        if let Some(&id) = self.sym_to_id.get(symbol) {
            id
        } else {
            let id = self.next_sym_id;
            self.next_sym_id += 1;
            self.sym_to_id.insert(symbol.to_string(), id);
            id
        }
    }

    /// Non-allocating lookup: None if the symbol was never interned.
    fn sym_id_lookup(&self, symbol: &str) -> Option<u16> {
        self.sym_to_id.get(symbol).copied()
    }

    pub fn at(&self, i: usize, j: usize) -> &Vec<Item> {
        &self.m[i][j]
    }

    pub fn at_mut(&mut self, i: usize, j: usize) -> &mut Vec<Item> {
        &mut self.m[i][j]
    }

    /// Append `item` to cell (i, j) under `symbol`. The seen-set only
    /// dedups the spans_by_left index; the item list itself is not deduped.
    pub fn add(&mut self, item: Item, i: usize, j: usize, symbol: &str) {
        let sid = self.sym_id(symbol);
        if self.b.insert((i as u16, j as u16, sid)) {
self.spans_by_left
                .entry((sid, i as u16))
                .or_default()
                .push(j as u16);
        }
        self.m[i][j].push(item);
    }

    /// True iff some passive item with `symbol` covers exactly [i, j).
    pub fn has(&self, symbol: &str, i: usize, j: usize) -> bool {
        match self.sym_id_lookup(symbol) {
            Some(sid) => self.b.contains(&(i as u16, j as u16, sid)),
            None => false,
        }
    }

    /// Returns right endpoints where (symbol, left, right) exists
    pub fn rights_for(&self, symbol: &str, left: usize) -> &[u16] {
        // An empty-slice literal has 'static lifetime, so no function-local
        // `static EMPTY: Vec<u16>` workaround is needed for the miss case.
        self.sym_id_lookup(symbol)
            .and_then(|sid| self.spans_by_left.get(&(sid, left as u16)))
            .map(Vec::as_slice)
            .unwrap_or(&[])
    }

    /// Check if any entry for (symbol, left, *) exists
    pub fn has_any_at(&self, symbol: &str, left: usize) -> bool {
        self.sym_id_lookup(symbol)
            .map_or(false, |sid| self.spans_by_left.contains_key(&(sid, left as u16)))
    }
}

/// Visit spans bottom-up: from span size `from` up.
/// Yields (left, right) pairs.
pub fn visit<F: FnMut(usize, usize)>(from: usize, l: usize, r: usize, x: usize, mut f: F) {
    for span in from..=(r - x) {
        for k in l..=(r - span) {
            f(k, k + span);
        }
    }
}

/// Seed the passive chart with fully lexical ("flat") rules whose entire
/// RHS matches the input starting at each position.
pub fn init(input: &[String], n: usize, _active_chart: &mut Chart, passive_chart: &mut Chart, grammar: &Grammar) {
    for i in 0..n {
        if let Some(matching) = grammar.flat_by_first_word.get(&input[i]) {
            for &fi in matching {
                let rule = &grammar.rules[fi];
                let rhs_len = rule.rhs.len();
                if i + rhs_len > n {
                    continue;
                }
                // First word already matched via the index; check the rest.
                let matches = rule
                    .rhs
                    .iter()
                    .enumerate()
                    .skip(1)
                    .all(|(k, s)| s.word() == input[i + k]);
                if matches {
                    // Dot at rhs_len: the item is born complete (passive).
                    let item = Item::from_rule(fi, &rule.rhs, i, i + rhs_len, rhs_len);
                    passive_chart.add(item, i, i + rhs_len, rule.lhs.nt_symbol());
                }
            }
        }
    }
}

/// Advance the dot over consecutive terminals that match the input,
/// stopping at `limit` (the right edge of the current cell). Returns false
/// on a terminal mismatch or when the input is exhausted mid-rule.
fn scan(item: &mut Item, input: &[String], limit: usize, grammar: &Grammar) -> bool {
    let rhs = &grammar.rules[item.rule_idx].rhs;
    while item.dot < rhs.len() && rhs[item.dot].is_t() {
        if item.right == limit {
            return
false;
        }
        if rhs[item.dot].word() == input[item.right] {
            item.dot += 1;
            item.right += 1;
        } else {
            return false;
        }
    }
    true
}

// CKY+ parse: fill each cell (i, j) bottom-up by seeding active items and
// repeatedly advancing them over passive items already in the chart.
pub fn parse(
    input: &[String],
    n: usize,
    active_chart: &mut Chart,
    passive_chart: &mut Chart,
    grammar: &Grammar,
) {
    visit(1, 0, n, 0, |i, j| {
        // Try to apply rules starting with T
        if let Some(matching) = grammar.start_t_by_word.get(&input[i]) {
            for &ri in matching {
                let mut new_item = Item::from_rule(ri, &grammar.rules[ri].rhs, i, i, 0);
                // Only keep items whose leading terminals actually match.
                if scan(&mut new_item, input, j, grammar) {
                    active_chart.at_mut(i, j).push(new_item);
                }
            }
        }

        // Seed active chart with rules starting with NT
        for (symbol, rule_indices) in &grammar.start_nt_by_symbol {
            // Skip symbols with no passive item starting at i (fast filter;
            // symbols created later in this cell are handled by the
            // self-filling step below).
            if !passive_chart.has_any_at(symbol, i) {
                continue;
            }
            for &ri in rule_indices {
                let rule = &grammar.rules[ri];
                // A rule longer than the span can never complete inside it.
                if rule.rhs.len() > j - i {
                    continue;
                }
                active_chart
                    .at_mut(i, j)
                    .push(Item::from_rule(ri, &rule.rhs, i, i, 0));
            }
        }

        // Parse
        let mut new_symbols: Vec<String> = Vec::new();
        let mut remaining_items: Vec<Item> = Vec::new();

        // Worklist loop: pop active items, try to consume a passive item.
        while !active_chart.at(i, j).is_empty() {
            let active_item = active_chart.at_mut(i, j).pop().unwrap();
            let mut advanced = false;

            let active_rhs = &grammar.rules[active_item.rule_idx].rhs;
            if active_item.dot < active_rhs.len() && active_rhs[active_item.dot].is_nt() {
                let wanted_symbol = active_rhs[active_item.dot].nt_symbol();
                let k = active_item.right;

                // Use index to directly get matching right endpoints
                // (collected into a Vec so passive_chart can be mutated below).
                let rights: Vec<u16> = passive_chart
                    .rights_for(wanted_symbol, k)
                    .iter()
                    .filter(|&&r| (r as usize) <= j)
                    .copied()
                    .collect();

                for l16 in rights {
                    let l = l16 as usize;

                    // Consume the passive item [k, l) and bind its span to
                    // the tail slot just advanced over.
                    let mut new_item =
                        active_item.advance(active_item.left, l, active_item.dot + 1);
                    new_item.tail_spans[new_item.dot - 1] = Some(Span {
                        left: k as i32,
                        right: l as i32,
                    });

                    let new_rhs = &grammar.rules[new_item.rule_idx].rhs;
                    if scan(&mut new_item, input, j, grammar)
{
                        if new_item.dot == new_rhs.len() {
                            // Complete item: only keep it if it covers this
                            // whole cell; smaller completions were produced
                            // when their own (smaller) cell was visited.
                            if new_item.left == i && new_item.right == j {
                                let sym_str = grammar.rules[new_item.rule_idx].lhs.nt_symbol();
                                if !new_symbols.iter().any(|s| s == sym_str) {
                                    new_symbols.push(sym_str.to_string());
                                }
                                passive_chart.add(new_item, i, j, sym_str);
                                advanced = true;
                            }
                        } else if new_item.right + (new_rhs.len() - new_item.dot) <= j {
                            // Still room for the remaining RHS symbols:
                            // requeue as an active item.
                            active_chart.at_mut(i, j).push(new_item);
                            advanced = true;
                        }
                    }
                }
            }

            if !advanced {
                remaining_items.push(active_item);
            }
        }

        // Self-filling step: unary-style closure over symbols newly created
        // in this cell (new_symbols grows while we iterate it).
        let mut si = 0;
        while si < new_symbols.len() {
            let s = new_symbols[si].clone();

            // Try start_nt rules from the grammar for this new symbol.
            // This handles rules that weren't seeded in step 2 because
            // the symbol didn't exist in the passive chart at that point.
            if let Some(rule_indices) = grammar.start_nt_by_symbol.get(&s) {
                for &ri in rule_indices {
                    let rule = &grammar.rules[ri];
                    if rule.rhs.len() > j - i {
                        continue;
                    }
                    // Bind the first tail to the full cell span directly.
                    let mut new_item = Item::from_rule(ri, &rule.rhs, i, i, 0);
                    new_item.dot = 1;
                    new_item.right = j;
                    new_item.tail_spans[0] = Some(Span {
                        left: i as i32,
                        right: j as i32,
                    });
                    // Only single-NT rules can complete here: the bound tail
                    // already spans the whole cell.
                    if new_item.dot == rule.rhs.len() {
                        let sym_str = rule.lhs.nt_symbol();
                        if !new_symbols.iter().any(|ns| ns == sym_str) {
                            new_symbols.push(sym_str.to_string());
                        }
                        passive_chart.add(new_item, i, j, sym_str);
                    }
                }
            }

            // Also check remaining active items (for rules that were
            // seeded but couldn't advance during the parse loop)
            for item in &remaining_items {
                if item.dot != 0 {
                    continue;
                }
                let item_rhs = &grammar.rules[item.rule_idx].rhs;
                if !item_rhs[item.dot].is_nt() {
                    continue;
                }
                if item_rhs[item.dot].nt_symbol() != s {
                    continue;
                }
                let mut new_item = item.advance(i, j, item.dot + 1);
                new_item.tail_spans[new_item.dot - 1] = Some(Span {
                    left: i as i32,
                    right: j as i32,
                });
                let new_rhs = &grammar.rules[new_item.rule_idx].rhs;
                if new_item.dot ==
new_rhs.len() {
                    let sym_str = grammar.rules[new_item.rule_idx].lhs.nt_symbol();
                    if !new_symbols.iter().any(|ns| ns == sym_str) {
                        new_symbols.push(sym_str.to_string());
                    }
                    passive_chart.add(new_item, i, j, sym_str);
                }
            }
            si += 1;
        }
    });
}
diff --git a/rs/src/semiring.rs b/rs/src/semiring.rs
new file mode 100644
index 0000000..9a8fe3e
--- /dev/null
+++ b/rs/src/semiring.rs
@@ -0,0 +1,39 @@
// Semiring abstraction over f64 scores: `one`/`null` are the multiplicative
// and additive identities; `add` combines alternatives, `multiply` chains.
pub trait Semiring {
    fn one() -> f64;
    fn null() -> f64;
    fn add(a: f64, b: f64) -> f64;
    fn multiply(a: f64, b: f64) -> f64;
}

// Viterbi (max/times) semiring over probabilities in [0, 1].
pub struct ViterbiSemiring;

impl Semiring for ViterbiSemiring {
    fn one() -> f64 {
        1.0
    }

    fn null() -> f64 {
        0.0
    }

    // "Sum" of alternatives is the best (maximum) score.
    fn add(a: f64, b: f64) -> f64 {
        a.max(b)
    }

    fn multiply(a: f64, b: f64) -> f64 {
        a * b
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_viterbi() {
        assert_eq!(ViterbiSemiring::one(), 1.0);
        assert_eq!(ViterbiSemiring::null(), 0.0);
        assert_eq!(ViterbiSemiring::add(0.3, 0.7), 0.7);
        // Exact in f64: 0.5 * 0.6 halves 0.6's representation, giving
        // exactly the f64 nearest to 0.3.
        assert_eq!(ViterbiSemiring::multiply(0.5, 0.6), 0.3);
    }
}
diff --git a/rs/src/sparse_vector.rs b/rs/src/sparse_vector.rs
new file mode 100644
index 0000000..4e62f95
--- /dev/null
+++ b/rs/src/sparse_vector.rs
@@ -0,0 +1,86 @@
use std::collections::HashMap;

// String-keyed sparse feature vector; absent keys are implicitly 0.0.
#[derive(Debug, Clone, Default)]
pub struct SparseVector {
    pub map: HashMap<String, f64>,
}

impl SparseVector {
    pub fn new() -> Self {
        Self {
            map: HashMap::new(),
        }
    }

    /// Parse "key<kv_sep>value" pairs separated by `pair_sep`.
    /// Unparsable values are silently skipped (best-effort parsing).
    pub fn from_kv(s: &str, kv_sep: char, pair_sep: char) -> Self {
        let mut map = HashMap::new();
        for pair in s.split(pair_sep) {
            let pair = pair.trim();
            if pair.is_empty() {
                continue;
            }
            if let Some((k, v)) = pair.split_once(kv_sep) {
                if let Ok(val) = v.trim().parse::<f64>() {
                    map.insert(k.trim().to_string(), val);
                }
            }
        }
        Self { map }
    }

    /// Build from a JSON object; non-numeric values are skipped.
    pub fn from_hash(h: &serde_json::Map<String, serde_json::Value>) -> Self {
        let mut map = HashMap::new();
        for (k, v) in h {
            if let Some(val) = v.as_f64() {
map.insert(k.clone(), val);
            }
        }
        Self { map }
    }

    /// Dot product over the intersection of keys; missing keys contribute 0.
    pub fn dot(&self, other: &SparseVector) -> f64 {
        let mut sum = 0.0;
        for (k, v) in &self.map {
            if let Some(ov) = other.map.get(k) {
                sum += v * ov;
            }
        }
        sum
    }

    /// Render as a serde_json object (key -> number).
    pub fn to_json(&self) -> serde_json::Value {
        let map: serde_json::Map<String, serde_json::Value> = self
            .map
            .iter()
            .map(|(k, v)| (k.clone(), serde_json::Value::from(*v)))
            .collect();
        serde_json::Value::Object(map)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_kv() {
        let sv = SparseVector::from_kv("logp 2\nuse_house 0\nuse_shell 1", ' ', '\n');
        assert_eq!(sv.map["logp"], 2.0);
        assert_eq!(sv.map["use_house"], 0.0);
        assert_eq!(sv.map["use_shell"], 1.0);
    }

    #[test]
    fn test_dot() {
        let a = SparseVector::from_kv("x=1 y=2", '=', ' ');
        let b = SparseVector::from_kv("x=3 y=4 z=5", '=', ' ');
        assert_eq!(a.dot(&b), 11.0);
    }

    #[test]
    fn test_empty_dot() {
        let a = SparseVector::new();
        let b = SparseVector::from_kv("x=1", '=', ' ');
        assert_eq!(a.dot(&b), 0.0);
    }
}
