From 77666a09c0f82b231605da463a946a5a5fcd09b6 Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Tue, 24 Feb 2026 17:07:57 +0100 Subject: Fix reordering bug in derive and add test example derive used a sequential counter to index into the source-side NT map, which only worked for monotone rules. Now looks up tails by the target NT's own index via map.index(i.index). Adds toy-reorder example (German verb-final -> English SVO) to exercise the fix. Also updates trollop -> optimist and guards xmlsimple require. Co-Authored-By: Claude Opus 4.6 --- example/toy-reorder/grammar | 6 ++++++ example/toy-reorder/in | 1 + example/toy-reorder/weights.toy | 1 + prototype/hypergraph.rb | 4 +--- prototype/ow_proto.rb | 6 +++--- 5 files changed, 12 insertions(+), 6 deletions(-) create mode 100644 example/toy-reorder/grammar create mode 100644 example/toy-reorder/in create mode 100644 example/toy-reorder/weights.toy diff --git a/example/toy-reorder/grammar b/example/toy-reorder/grammar new file mode 100644 index 0000000..d93d98c --- /dev/null +++ b/example/toy-reorder/grammar @@ -0,0 +1,6 @@ +[S] ||| [NP,1] [VP,2] ||| [NP,1] [VP,2] ||| logp=0 +[NP] ||| er ||| he ||| logp=-0.5 +[NP] ||| das [NN,1] ||| the [NN,1] ||| logp=0 +[NN] ||| buch ||| book ||| logp=0 +[VP] ||| [NP,1] [V,2] ||| [V,2] [NP,1] ||| logp=0 +[V] ||| liest ||| reads ||| logp=-0.25 diff --git a/example/toy-reorder/in b/example/toy-reorder/in new file mode 100644 index 0000000..6eeb3a9 --- /dev/null +++ b/example/toy-reorder/in @@ -0,0 +1 @@ +er das buch liest diff --git a/example/toy-reorder/weights.toy b/example/toy-reorder/weights.toy new file mode 100644 index 0000000..3eb2502 --- /dev/null +++ b/example/toy-reorder/weights.toy @@ -0,0 +1 @@ +logp 2 diff --git a/prototype/hypergraph.rb b/prototype/hypergraph.rb index fd72393..fdaba5a 100644 --- a/prototype/hypergraph.rb +++ b/prototype/hypergraph.rb @@ -196,11 +196,9 @@ def HG::derive path, cur, carry edge = path.select { |e| e.head.symbol==cur.symbol \ && 
e.head.left==cur.left \ && e.head.right==cur.right }.first - j = 0 edge.rule.target.each { |i| if i.class == Grammar::NT - derive path, edge.tails[edge.rule.map[j]], carry - j += 1 + derive path, edge.tails[edge.rule.map.index(i.index)], carry else carry << i end diff --git a/prototype/ow_proto.rb b/prototype/ow_proto.rb index 912090b..41fe683 100755 --- a/prototype/ow_proto.rb +++ b/prototype/ow_proto.rb @@ -1,7 +1,7 @@ #!/usr/bin/env ruby -require 'trollop' -require 'xmlsimple' +require 'optimist' +begin; require 'xmlsimple'; rescue LoadError; end require_relative 'parse' def read_grammar fn, add_glue, add_pass_through, input=nil @@ -19,7 +19,7 @@ def read_grammar fn, add_glue, add_pass_through, input=nil end def main - cfg = Trollop::options do + cfg = Optimist::options do opt :input, "", :type => :string, :default => '-', :short => '-i' opt :grammar, "", :type => :string, :required => true, :short => '-g' opt :weights, "", :type => :string, :required => true, :short => '-w' -- cgit v1.2.3 From e76951f21263eb7010a2898b9744364e989e90b8 Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Tue, 24 Feb 2026 17:14:37 +0100 Subject: Deduplicate all_paths by reachable edge set The Cartesian product over all nodes produces duplicate derivations when edges differ only in nodes unreachable from the top. Walk reachable edges from the top edge of each path and drop paths with identical reachable sets. 
Co-Authored-By: Claude Opus 4.6 --- prototype/hypergraph.rb | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/prototype/hypergraph.rb b/prototype/hypergraph.rb index fdaba5a..08d1a29 100644 --- a/prototype/hypergraph.rb +++ b/prototype/hypergraph.rb @@ -189,7 +189,21 @@ def HG::all_paths hypergraph, root paths = new_paths } - return paths + seen = Set.new + paths.select { |p| + reachable = Set.new + mark_reachable p, p.last, reachable + key = reachable.map(&:object_id).sort + !seen.include?(key) && seen.add(key) + } +end + +def HG::mark_reachable path, edge, used + used << edge + edge.tails.each { |t| + child = path.find { |e| e.head == t } + mark_reachable path, child, used if child && !used.include?(child) + } end def HG::derive path, cur, carry -- cgit v1.2.3 From 22dc0fbdf002c7824941abc17a715a3e70ff37c1 Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Tue, 24 Feb 2026 17:16:06 +0100 Subject: Emit binary glue rule only once The [S] -> [S] [X] concatenation rule was duplicated for every non-S LHS symbol. Move it out of the loop so it's added once. 
Co-Authored-By: Claude Opus 4.6 --- prototype/grammar.rb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/prototype/grammar.rb b/prototype/grammar.rb index abccb15..4aebd95 100644 --- a/prototype/grammar.rb +++ b/prototype/grammar.rb @@ -120,9 +120,9 @@ class Grammar @rules.map { |r| r.lhs.symbol }.select { |s| s != 'S' }.uniq.each { |symbol| @rules << Rule.new(NT.new('S'), [NT.new(symbol, 0)], [NT.new(symbol, 0)], [0]) @start_nt << @rules.last - @rules << Rule.new(NT.new('S'), [NT.new('S', 0), NT.new('X', 1)], [NT.new('S', 0), NT.new('X', 1)], [0, 1]) - @start_nt << @rules.last } + @rules << Rule.new(NT.new('S'), [NT.new('S', 0), NT.new('X', 1)], [NT.new('S', 0), NT.new('X', 1)], [0, 1]) + @start_nt << @rules.last end def add_pass_through_rules a -- cgit v1.2.3 From 4e62908a1757f83ff703399252ad50758c4eb237 Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Tue, 24 Feb 2026 17:18:29 +0100 Subject: Replace silent rescue with explicit type check in Item constructor When creating an Item from a Rule (not an Item), tail_spans doesn't exist. Check with is_a?(Item) instead of catching the exception silently. 
Co-Authored-By: Claude Opus 4.6 --- prototype/parse.rb | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/prototype/parse.rb b/prototype/parse.rb index adf2b91..40a69e7 100644 --- a/prototype/parse.rb +++ b/prototype/parse.rb @@ -90,14 +90,10 @@ class Item < Grammar::Rule rule_or_item.rhs.each_with_index { |x,i| # duplicate rhs partially @rhs << x if x.class == Grammar::NT - begin - if i >= dot - @tail_spans[i] = Span.new(-1, -1) - else - @tail_spans[i] = rule_or_item.tail_spans[i].dup - end - rescue + if i >= dot || !rule_or_item.is_a?(Item) @tail_spans[i] = Span.new(-1, -1) + else + @tail_spans[i] = rule_or_item.tail_spans[i].dup end end } -- cgit v1.2.3 From 0abcdd7e4358cb902c320b008d3c04bde07b749e Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Thu, 26 Feb 2026 19:28:22 +0100 Subject: Add Rust implementation of SCFG decoder Rust port of the Ruby prototype decoder with performance optimizations for real Hiero-style grammars: - Rule indexing by first terminal/NT symbol for fast lookup - Chart symbol interning (u16 IDs) instead of string hashing - Passive chart index by (symbol, left) for direct right-endpoint lookup - Items store rule index instead of cloned rule data Includes CKY+ parser, chart-to-hypergraph conversion, Viterbi decoding, derivation extraction, and JSON hypergraph I/O. Self-filling step in parse uses grammar lookup (not just remaining active items) to handle rules that were consumed during the parse loop or skipped by the has_any_at optimization. Produces identical output to the Ruby prototype on all test examples. 
Co-Authored-By: Claude Opus 4.6 --- rs/.gitignore | 1 + rs/Cargo.lock | 249 ++++++++++++++++++++++++++++ rs/Cargo.toml | 9 ++ rs/README.md | 39 +++++ rs/src/chart_to_hg.rs | 95 +++++++++++ rs/src/grammar.rs | 393 +++++++++++++++++++++++++++++++++++++++++++++ rs/src/hypergraph.rs | 115 +++++++++++++ rs/src/hypergraph_algos.rs | 215 +++++++++++++++++++++++++ rs/src/hypergraph_io.rs | 110 +++++++++++++ rs/src/lib.rs | 8 + rs/src/main.rs | 110 +++++++++++++ rs/src/parse.rs | 347 +++++++++++++++++++++++++++++++++++++++ rs/src/semiring.rs | 39 +++++ rs/src/sparse_vector.rs | 86 ++++++++++ 14 files changed, 1816 insertions(+) create mode 100644 rs/.gitignore create mode 100644 rs/Cargo.lock create mode 100644 rs/Cargo.toml create mode 100644 rs/README.md create mode 100644 rs/src/chart_to_hg.rs create mode 100644 rs/src/grammar.rs create mode 100644 rs/src/hypergraph.rs create mode 100644 rs/src/hypergraph_algos.rs create mode 100644 rs/src/hypergraph_io.rs create mode 100644 rs/src/lib.rs create mode 100644 rs/src/main.rs create mode 100644 rs/src/parse.rs create mode 100644 rs/src/semiring.rs create mode 100644 rs/src/sparse_vector.rs diff --git a/rs/.gitignore b/rs/.gitignore new file mode 100644 index 0000000..ea8c4bf --- /dev/null +++ b/rs/.gitignore @@ -0,0 +1 @@ +/target diff --git a/rs/Cargo.lock b/rs/Cargo.lock new file mode 100644 index 0000000..a113e8a --- /dev/null +++ b/rs/Cargo.lock @@ -0,0 +1,249 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. 
+version = 4 + +[[package]] +name = "anstream" +version = "0.6.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" + +[[package]] +name = "anstyle-parse" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" +dependencies = [ + "anstyle", + "once_cell_polyfill", + "windows-sys", +] + +[[package]] +name = "clap" +version = "4.5.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2797f34da339ce31042b27d23607e051786132987f595b02ba4f6a6dffb7030a" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.5.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24a241312cea5059b13574bb9b3861cabf758b879c15190b37b6d6fd63ab6876" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.5.55" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"a92793da1a46a5f2a02a6f4c46c6496b28c43638adea8306fcb0caa1634f24e5" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "clap_lex" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a822ea5bc7590f9d40f1ba12c0dc3c2760f3482c6984db1573ad11031420831" + +[[package]] +name = "colorchoice" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "is_terminal_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" + +[[package]] +name = "itoa" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "odenwald" +version = "0.1.0" +dependencies = [ + "clap", + "serde", + "serde_json", +] + +[[package]] +name = "once_cell_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.44" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b2ebcf727b7760c461f091f9f0f539b77b8e87f2fd88131e7f1b433b3cece4" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" diff --git a/rs/Cargo.toml b/rs/Cargo.toml new file mode 100644 index 0000000..e37f06f --- /dev/null +++ b/rs/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "odenwald" +version = "0.1.0" +edition = "2021" + +[dependencies] +clap = { version = "4", features = ["derive"] } +serde = { version = "1", features = ["derive"] } +serde_json = "1" diff --git a/rs/README.md b/rs/README.md new file mode 100644 index 0000000..5daa458 --- /dev/null +++ b/rs/README.md @@ -0,0 +1,39 @@ +# odenwald + +Rust implementation of the Odenwald SCFG (synchronous context-free grammar) machine translation decoder. + +## Build + +``` +cargo build --release +``` + +## Usage + +``` +odenwald -g <grammar> -w <weights> [-i <input>] [-l] [-p] +``` + +- `-g, --grammar` — grammar file (required) +- `-w, --weights` — weights file (required) +- `-i, --input` — input file (default: stdin) +- `-l, --add-glue` — add glue rules +- `-p, --add-pass-through` — add pass-through rules + +Output: `translation ||| log_score` per input line. 
+ +## Examples + +``` +cargo run -- -g ../example/toy/grammar -w ../example/toy/weights.toy -i ../example/toy/in -l +# → i saw a small shell ||| -0.5 + +cargo run -- -g ../example/toy-reorder/grammar -w ../example/toy-reorder/weights.toy -i ../example/toy-reorder/in -l +# → he reads the book ||| -1.5 +``` + +## Tests + +``` +cargo test +``` diff --git a/rs/src/chart_to_hg.rs b/rs/src/chart_to_hg.rs new file mode 100644 index 0000000..95f642e --- /dev/null +++ b/rs/src/chart_to_hg.rs @@ -0,0 +1,95 @@ +use std::collections::HashMap; + +use crate::grammar::{Grammar, Rule, Symbol}; +use crate::hypergraph::{Hypergraph, NodeId}; +use crate::parse::{Chart, visit}; +use crate::sparse_vector::SparseVector; + +pub fn chart_to_hg(chart: &Chart, n: usize, weights: &SparseVector, grammar: &Grammar) -> Hypergraph { + let mut hg = Hypergraph::new(); + let mut seen: HashMap = HashMap::new(); + + // Add root node with id=-1 + let root_nid = hg.add_node(-1, "root", -1, -1); + + let mut next_id: i64 = 0; + + // First pass: create nodes + visit(1, 0, n, 0, |i, j| { + for item in chart.at(i, j) { + let lhs_sym = grammar.rules[item.rule_idx].lhs.nt_symbol(); + let key = format!("{},{},{}", lhs_sym, i, j); + if !seen.contains_key(&key) { + let nid = hg.add_node(next_id, lhs_sym, i as i32, j as i32); + seen.insert(key, nid); + next_id += 1; + } + } + }); + + // Second pass: create edges + visit(1, 0, n, 0, |i, j| { + for item in chart.at(i, j) { + let rule = &grammar.rules[item.rule_idx]; + let head_key = format!("{},{},{}", rule.lhs.nt_symbol(), i, j); + let head_nid = *seen.get(&head_key).unwrap(); + + // Build tails + let tails: Vec = if item.tail_spans.iter().all(|s| s.is_none()) + || item.tail_spans.iter().all(|s| { + s.as_ref() + .map_or(true, |sp| sp.left == -1 && sp.right == -1) + }) + { + if rule.rhs.iter().any(|s| s.is_nt()) { + build_tails(&rule.rhs, &item.tail_spans, &seen, root_nid) + } else { + vec![root_nid] + } + } else { + build_tails(&rule.rhs, &item.tail_spans, &seen, 
root_nid) + }; + + let score = weights.dot(&rule.f).exp(); + let rule_clone = Rule::new( + rule.lhs.clone(), + rule.rhs.clone(), + rule.target.clone(), + rule.map.clone(), + rule.f.clone(), + ); + + hg.add_edge(head_nid, tails, score, rule.f.clone(), rule_clone); + } + }); + + hg +} + +fn build_tails( + rhs: &[Symbol], + tail_spans: &[Option], + seen: &HashMap, + root_nid: NodeId, +) -> Vec { + let mut tails = Vec::new(); + let mut has_nt = false; + for (idx, sym) in rhs.iter().enumerate() { + if let Symbol::NT { symbol, .. } = sym { + has_nt = true; + if let Some(Some(span)) = tail_spans.get(idx) { + if span.left >= 0 && span.right >= 0 { + let key = format!("{},{},{}", symbol, span.left, span.right); + if let Some(&nid) = seen.get(&key) { + tails.push(nid); + } + } + } + } + } + if !has_nt { + vec![root_nid] + } else { + tails + } +} diff --git a/rs/src/grammar.rs b/rs/src/grammar.rs new file mode 100644 index 0000000..cef49ee --- /dev/null +++ b/rs/src/grammar.rs @@ -0,0 +1,393 @@ +use std::fmt; +use std::fs; +use std::io::{self, BufRead}; + +use crate::sparse_vector::SparseVector; + +#[derive(Debug, Clone, PartialEq)] +pub enum Symbol { + T(String), + NT { symbol: String, index: i32 }, +} + +impl Symbol { + pub fn parse(s: &str) -> Symbol { + let s = s.trim(); + if s.starts_with('[') && s.ends_with(']') { + let inner = &s[1..s.len() - 1]; + if let Some((sym, idx)) = inner.split_once(',') { + Symbol::NT { + symbol: sym.trim().to_string(), + index: idx.trim().parse::().unwrap() - 1, + } + } else { + Symbol::NT { + symbol: inner.trim().to_string(), + index: -1, + } + } + } else { + Symbol::T(s.to_string()) + } + } + + pub fn is_nt(&self) -> bool { + matches!(self, Symbol::NT { .. }) + } + + pub fn is_t(&self) -> bool { + matches!(self, Symbol::T(_)) + } + + pub fn word(&self) -> &str { + match self { + Symbol::T(w) => w, + Symbol::NT { .. } => panic!("word() called on NT"), + } + } + + pub fn nt_symbol(&self) -> &str { + match self { + Symbol::NT { symbol, .. 
} => symbol, + Symbol::T(_) => panic!("nt_symbol() called on T"), + } + } + + pub fn nt_index(&self) -> i32 { + match self { + Symbol::NT { index, .. } => *index, + Symbol::T(_) => panic!("nt_index() called on T"), + } + } +} + +impl fmt::Display for Symbol { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Symbol::T(w) => write!(f, "{}", w), + Symbol::NT { symbol, index } if *index >= 0 => { + write!(f, "[{},{}]", symbol, index + 1) + } + Symbol::NT { symbol, .. } => write!(f, "[{}]", symbol), + } + } +} + +fn splitpipe(s: &str, n: usize) -> Vec<&str> { + let mut parts = Vec::new(); + let mut rest = s; + for _ in 0..n { + if let Some(pos) = rest.find("|||") { + parts.push(rest[..pos].trim()); + rest = rest[pos + 3..].trim(); + } else { + break; + } + } + parts.push(rest.trim()); + parts +} + +#[derive(Debug, Clone)] +pub struct Rule { + pub lhs: Symbol, + pub rhs: Vec, + pub target: Vec, + pub map: Vec, + pub f: SparseVector, + pub arity: usize, +} + +impl Rule { + pub fn new( + lhs: Symbol, + rhs: Vec, + target: Vec, + map: Vec, + f: SparseVector, + ) -> Self { + let arity = rhs.iter().filter(|s| s.is_nt()).count(); + Self { + lhs, + rhs, + target, + map, + f, + arity, + } + } + + pub fn from_str(s: &str) -> Self { + let parts = splitpipe(s, 3); + let lhs = Symbol::parse(parts[0]); + + let mut map = Vec::new(); + let mut arity = 0; + let rhs: Vec = parts[1] + .split_whitespace() + .map(|tok| { + let sym = Symbol::parse(tok); + if let Symbol::NT { index, .. 
} = &sym { + map.push(*index); + arity += 1; + } + sym + }) + .collect(); + + let target: Vec = parts[2] + .split_whitespace() + .map(|tok| Symbol::parse(tok)) + .collect(); + + let f = if parts.len() > 3 && !parts[3].is_empty() { + SparseVector::from_kv(parts[3], '=', ' ') + } else { + SparseVector::new() + }; + + Self { + lhs, + rhs, + target, + map, + f, + arity, + } + } +} + +impl fmt::Display for Rule { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{} ||| {} ||| {}", + self.lhs, + self.rhs + .iter() + .map(|s| s.to_string()) + .collect::>() + .join(" "), + self.target + .iter() + .map(|s| s.to_string()) + .collect::>() + .join(" "), + ) + } +} + +pub struct Grammar { + pub rules: Vec, + pub start_nt: Vec, + pub start_t: Vec, + pub flat: Vec, + pub start_t_by_word: std::collections::HashMap>, + pub flat_by_first_word: std::collections::HashMap>, + pub start_nt_by_symbol: std::collections::HashMap>, +} + +impl Grammar { + pub fn from_file(path: &str) -> io::Result { + let file = fs::File::open(path)?; + let reader = io::BufReader::new(file); + let mut rules = Vec::new(); + let mut start_nt = Vec::new(); + let mut start_t = Vec::new(); + let mut flat = Vec::new(); + let mut start_t_by_word: std::collections::HashMap> = + std::collections::HashMap::new(); + let mut flat_by_first_word: std::collections::HashMap> = + std::collections::HashMap::new(); + let mut start_nt_by_symbol: std::collections::HashMap> = + std::collections::HashMap::new(); + + for line in reader.lines() { + let line = line?; + let line = line.trim().to_string(); + if line.is_empty() { + continue; + } + let idx = rules.len(); + let rule = Rule::from_str(&line); + if rule.rhs.first().map_or(false, |s| s.is_nt()) { + start_nt_by_symbol + .entry(rule.rhs[0].nt_symbol().to_string()) + .or_default() + .push(idx); + start_nt.push(idx); + } else if rule.arity == 0 { + flat_by_first_word + .entry(rule.rhs[0].word().to_string()) + .or_default() + .push(idx); + 
flat.push(idx); + } else { + start_t_by_word + .entry(rule.rhs[0].word().to_string()) + .or_default() + .push(idx); + start_t.push(idx); + } + rules.push(rule); + } + + Ok(Self { + rules, + start_nt, + start_t, + flat, + start_t_by_word, + flat_by_first_word, + start_nt_by_symbol, + }) + } + + pub fn add_glue_rules(&mut self) { + let symbols: Vec = self + .rules + .iter() + .map(|r| r.lhs.nt_symbol().to_string()) + .filter(|s| s != "S") + .collect::>() + .into_iter() + .collect(); + + let mut once = false; + for symbol in symbols { + let idx = self.rules.len(); + self.rules.push(Rule::new( + Symbol::NT { + symbol: "S".to_string(), + index: -1, + }, + vec![Symbol::NT { + symbol: symbol.clone(), + index: 0, + }], + vec![Symbol::NT { + symbol: symbol.clone(), + index: 0, + }], + vec![0], + SparseVector::new(), + )); + self.start_nt.push(idx); + self.start_nt_by_symbol + .entry(symbol) + .or_default() + .push(idx); + once = true; + } + + if once { + let idx = self.rules.len(); + self.rules.push(Rule::new( + Symbol::NT { + symbol: "S".to_string(), + index: -1, + }, + vec![ + Symbol::NT { + symbol: "S".to_string(), + index: 0, + }, + Symbol::NT { + symbol: "X".to_string(), + index: 1, + }, + ], + vec![ + Symbol::NT { + symbol: "S".to_string(), + index: 0, + }, + Symbol::NT { + symbol: "X".to_string(), + index: 1, + }, + ], + vec![0, 1], + SparseVector::new(), + )); + self.start_nt.push(idx); + self.start_nt_by_symbol + .entry("S".to_string()) + .or_default() + .push(idx); + } + } + + pub fn add_pass_through_rules(&mut self, words: &[String]) { + for word in words { + let idx = self.rules.len(); + self.rules.push(Rule::new( + Symbol::NT { + symbol: "X".to_string(), + index: -1, + }, + vec![Symbol::T(word.clone())], + vec![Symbol::T(word.clone())], + vec![], + SparseVector::new(), + )); + self.flat.push(idx); + self.flat_by_first_word + .entry(word.clone()) + .or_default() + .push(idx); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn 
test_symbol_parse_t() { + let s = Symbol::parse("hello"); + assert_eq!(s, Symbol::T("hello".to_string())); + } + + #[test] + fn test_symbol_parse_nt() { + let s = Symbol::parse("[NP,1]"); + assert_eq!( + s, + Symbol::NT { + symbol: "NP".to_string(), + index: 0 + } + ); + } + + #[test] + fn test_symbol_parse_nt_no_index() { + let s = Symbol::parse("[S]"); + assert_eq!( + s, + Symbol::NT { + symbol: "S".to_string(), + index: -1 + } + ); + } + + #[test] + fn test_rule_from_str() { + let r = Rule::from_str("[NP] ||| ein [NN,1] ||| a [NN,1] ||| logp=0 use_a=1.0"); + assert_eq!(r.lhs.nt_symbol(), "NP"); + assert_eq!(r.rhs.len(), 2); + assert_eq!(r.target.len(), 2); + assert_eq!(r.map, vec![0]); + assert_eq!(r.arity, 1); + assert_eq!(*r.f.map.get("use_a").unwrap(), 1.0); + } + + #[test] + fn test_rule_display() { + let r = Rule::from_str("[S] ||| [NP,1] [VP,2] ||| [NP,1] [VP,2] ||| logp=0"); + assert_eq!(r.to_string(), "[S] ||| [NP,1] [VP,2] ||| [NP,1] [VP,2]"); + } +} diff --git a/rs/src/hypergraph.rs b/rs/src/hypergraph.rs new file mode 100644 index 0000000..90069b0 --- /dev/null +++ b/rs/src/hypergraph.rs @@ -0,0 +1,115 @@ +use crate::grammar::Rule; +use crate::sparse_vector::SparseVector; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct NodeId(pub usize); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct EdgeId(pub usize); + +#[derive(Debug, Clone)] +pub struct Node { + pub id: i64, + pub symbol: String, + pub left: i32, + pub right: i32, + pub outgoing: Vec, + pub incoming: Vec, + pub score: f64, +} + +impl Node { + pub fn new(id: i64, symbol: &str, left: i32, right: i32) -> Self { + Self { + id, + symbol: symbol.to_string(), + left, + right, + outgoing: Vec::new(), + incoming: Vec::new(), + score: 0.0, + } + } +} + +#[derive(Debug, Clone)] +pub struct Edge { + pub head: NodeId, + pub tails: Vec, + pub score: f64, + pub f: SparseVector, + pub mark: usize, + pub rule: Rule, +} + +impl Edge { + pub fn 
arity(&self) -> usize { + self.tails.len() + } + + pub fn marked(&self) -> bool { + self.arity() == self.mark + } +} + +#[derive(Debug)] +pub struct Hypergraph { + pub nodes: Vec, + pub edges: Vec, + pub nodes_by_id: std::collections::HashMap, +} + +impl Hypergraph { + pub fn new() -> Self { + Self { + nodes: Vec::new(), + edges: Vec::new(), + nodes_by_id: std::collections::HashMap::new(), + } + } + + pub fn add_node(&mut self, id: i64, symbol: &str, left: i32, right: i32) -> NodeId { + let nid = NodeId(self.nodes.len()); + self.nodes.push(Node::new(id, symbol, left, right)); + self.nodes_by_id.insert(id, nid); + nid + } + + pub fn add_edge( + &mut self, + head: NodeId, + tails: Vec, + score: f64, + f: SparseVector, + rule: Rule, + ) -> EdgeId { + let eid = EdgeId(self.edges.len()); + self.edges.push(Edge { + head, + tails: tails.clone(), + score, + f, + mark: 0, + rule, + }); + self.nodes[head.0].incoming.push(eid); + for &t in &tails { + self.nodes[t.0].outgoing.push(eid); + } + eid + } + + pub fn node(&self, nid: NodeId) -> &Node { + &self.nodes[nid.0] + } + + pub fn edge(&self, eid: EdgeId) -> &Edge { + &self.edges[eid.0] + } + + pub fn reset(&mut self) { + for e in &mut self.edges { + e.mark = 0; + } + } +} diff --git a/rs/src/hypergraph_algos.rs b/rs/src/hypergraph_algos.rs new file mode 100644 index 0000000..a4878b6 --- /dev/null +++ b/rs/src/hypergraph_algos.rs @@ -0,0 +1,215 @@ +use std::collections::{BTreeSet, HashSet, VecDeque}; + +use crate::grammar::Symbol; +use crate::hypergraph::{EdgeId, Hypergraph, NodeId}; +use crate::semiring::Semiring; + +pub fn topological_sort(hg: &mut Hypergraph) -> Vec { + let mut sorted = Vec::new(); + let mut queue: VecDeque = VecDeque::new(); + + // Start with nodes that have no incoming edges + for (i, node) in hg.nodes.iter().enumerate() { + if node.incoming.is_empty() { + queue.push_back(NodeId(i)); + } + } + + while let Some(nid) = queue.pop_front() { + sorted.push(nid); + let outgoing = 
hg.nodes[nid.0].outgoing.clone(); + for &eid in &outgoing { + let edge = &mut hg.edges[eid.0]; + if edge.marked() { + continue; + } + edge.mark += 1; + if edge.marked() { + let head = edge.head; + // Check if all incoming edges of head are marked + let all_marked = hg.nodes[head.0] + .incoming + .iter() + .all(|&ie| hg.edges[ie.0].marked()); + if all_marked { + queue.push_back(head); + } + } + } + } + + sorted +} + +pub fn viterbi_path(hg: &mut Hypergraph) -> (Vec, f64) { + let toposorted = topological_sort(hg); + + // Init + for node in &mut hg.nodes { + node.score = S::null(); + } + if let Some(&first) = toposorted.first() { + hg.nodes[first.0].score = S::one(); + } + + let mut best_path: Vec = Vec::new(); + + for &nid in &toposorted { + let incoming = hg.nodes[nid.0].incoming.clone(); + let mut best_edge: Option = None; + for &eid in &incoming { + let edge = &hg.edges[eid.0]; + let mut s = S::one(); + for &tid in &edge.tails { + s = S::multiply(s, hg.nodes[tid.0].score); + } + let candidate = S::multiply(s, edge.score); + if hg.nodes[nid.0].score < candidate { + best_edge = Some(eid); + } + hg.nodes[nid.0].score = S::add(hg.nodes[nid.0].score, candidate); + } + if let Some(e) = best_edge { + best_path.push(e); + } + } + + let final_score = toposorted + .last() + .map(|&nid| hg.nodes[nid.0].score) + .unwrap_or(0.0); + + (best_path, final_score) +} + +pub fn derive(hg: &Hypergraph, path: &[EdgeId], cur: NodeId, carry: &mut Vec) { + // Find edge in path whose head matches cur + let node = &hg.nodes[cur.0]; + let edge_id = path + .iter() + .find(|&&eid| { + let e = &hg.edges[eid.0]; + let h = &hg.nodes[e.head.0]; + h.symbol == node.symbol && h.left == node.left && h.right == node.right + }) + .expect("derive: no matching edge found"); + + let edge = &hg.edges[edge_id.0]; + for sym in &edge.rule.target { + match sym { + Symbol::NT { index, .. 
} => { + // Find which tail to recurse into using map + let tail_idx = edge + .rule + .map + .iter() + .position(|&m| m == *index) + .expect("derive: NT index not found in map"); + derive(hg, path, edge.tails[tail_idx], carry); + } + Symbol::T(word) => { + carry.push(word.clone()); + } + } + } +} + +pub fn all_paths(hg: &mut Hypergraph) -> Vec> { + let toposorted = topological_sort(hg); + + let mut paths: Vec> = vec![vec![]]; + + for &nid in &toposorted { + let incoming = hg.nodes[nid.0].incoming.clone(); + if incoming.is_empty() { + continue; + } + let mut new_paths = Vec::new(); + while let Some(p) = paths.pop() { + for &eid in &incoming { + let mut np = p.clone(); + np.push(eid); + new_paths.push(np); + } + } + paths = new_paths; + } + + // Dedup by reachable edge set + let mut seen: HashSet> = HashSet::new(); + paths + .into_iter() + .filter(|p| { + if p.is_empty() { + return false; + } + let mut reachable = BTreeSet::new(); + mark_reachable(hg, p, *p.last().unwrap(), &mut reachable); + let key: Vec = reachable.into_iter().map(|eid| eid.0).collect(); + seen.insert(key) + }) + .collect() +} + +fn mark_reachable( + hg: &Hypergraph, + path: &[EdgeId], + edge_id: EdgeId, + used: &mut BTreeSet, +) { + used.insert(edge_id); + let edge = &hg.edges[edge_id.0]; + for &tail_nid in &edge.tails { + // Find edge in path whose head is this tail node + if let Some(&child_eid) = path.iter().find(|&&eid| hg.edges[eid.0].head == tail_nid) { + if !used.contains(&child_eid) { + mark_reachable(hg, path, child_eid, used); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::hypergraph_io::read_hypergraph_from_json; + use crate::semiring::ViterbiSemiring; + + #[test] + fn test_viterbi_toy() { + let mut hg = read_hypergraph_from_json("../example/toy/toy.json", true).unwrap(); + let (path, score) = viterbi_path::(&mut hg); + let mut carry = Vec::new(); + let last_head = hg.edges[path.last().unwrap().0].head; + derive(&hg, &path, last_head, &mut carry); + + let 
translation = carry.join(" "); + let log_score = score.ln(); + + assert_eq!(translation, "i saw a small shell"); + assert!((log_score - (-0.5)).abs() < 1e-9); + } + + #[test] + fn test_all_paths_toy() { + let mut hg = read_hypergraph_from_json("../example/toy/toy.json", true).unwrap(); + + let paths = all_paths(&mut hg); + // The toy hypergraph should have multiple distinct paths + assert!(paths.len() > 1); + + // Collect all translations + let mut translations: Vec = Vec::new(); + for p in &paths { + let mut carry = Vec::new(); + let last_head = hg.edges[p.last().unwrap().0].head; + derive(&hg, p, last_head, &mut carry); + translations.push(carry.join(" ")); + } + + assert!(translations.contains(&"i saw a small shell".to_string())); + assert!(translations.contains(&"i saw a small house".to_string())); + assert!(translations.contains(&"i saw a little shell".to_string())); + assert!(translations.contains(&"i saw a little house".to_string())); + } +} diff --git a/rs/src/hypergraph_io.rs b/rs/src/hypergraph_io.rs new file mode 100644 index 0000000..5aa9f3b --- /dev/null +++ b/rs/src/hypergraph_io.rs @@ -0,0 +1,110 @@ +use std::fs; +use std::io; + +use crate::grammar::Rule; +use crate::hypergraph::{Hypergraph, NodeId}; +use crate::sparse_vector::SparseVector; + +pub fn read_hypergraph_from_json(path: &str, log_weights: bool) -> io::Result { + let content = fs::read_to_string(path)?; + let parsed: serde_json::Value = serde_json::from_str(&content) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + + let mut hg = Hypergraph::new(); + + let weights = match parsed.get("weights").and_then(|w| w.as_object()) { + Some(obj) => SparseVector::from_hash(obj), + None => SparseVector::new(), + }; + + // Add nodes + if let Some(nodes) = parsed.get("nodes").and_then(|n| n.as_array()) { + for n in nodes { + let id = n["id"].as_i64().unwrap(); + let cat = n["cat"].as_str().unwrap_or(""); + let span = n["span"].as_array().unwrap(); + let left = span[0].as_i64().unwrap() 
as i32; + let right = span[1].as_i64().unwrap() as i32; + hg.add_node(id, cat, left, right); + } + } + + // Add edges + if let Some(edges) = parsed.get("edges").and_then(|e| e.as_array()) { + for e in edges { + let head_id = e["head"].as_i64().unwrap(); + let head_nid = *hg.nodes_by_id.get(&head_id).unwrap(); + + let tails: Vec = e["tails"] + .as_array() + .unwrap() + .iter() + .map(|t| *hg.nodes_by_id.get(&t.as_i64().unwrap()).unwrap()) + .collect(); + + let f = match e.get("f").and_then(|f| f.as_object()) { + Some(obj) => SparseVector::from_hash(obj), + None => SparseVector::new(), + }; + + let rule_str = e["rule"].as_str().unwrap(); + let rule = Rule::from_str(rule_str); + + let score = if log_weights { + weights.dot(&f).exp() + } else { + weights.dot(&f) + }; + + hg.add_edge(head_nid, tails, score, f, rule); + } + } + + Ok(hg) +} + +pub fn write_hypergraph_to_json(hg: &Hypergraph, weights: &SparseVector) -> String { + let mut json_s = String::from("{\n"); + json_s.push_str(&format!( + "\"weights\":{},\n", + serde_json::to_string(&weights.to_json()).unwrap() + )); + + json_s.push_str("\"nodes\":\n[\n"); + let node_strs: Vec = hg + .nodes + .iter() + .map(|n| { + format!( + "{{ \"id\":{}, \"cat\":\"{}\", \"span\":[{},{}] }}", + n.id, + n.symbol.replace('"', "\\\""), + n.left, + n.right + ) + }) + .collect(); + json_s.push_str(&node_strs.join(",\n")); + json_s.push_str("\n],\n"); + + json_s.push_str("\"edges\":\n[\n"); + let edge_strs: Vec = hg + .edges + .iter() + .map(|e| { + let head_id = hg.nodes[e.head.0].id; + let tail_ids: Vec = e.tails.iter().map(|t| hg.nodes[t.0].id).collect(); + format!( + "{{ \"head\":{}, \"rule\":\"{}\", \"tails\":{:?}, \"f\":{} }}", + head_id, + e.rule.to_string().replace('"', "\\\""), + tail_ids, + serde_json::to_string(&e.f.to_json()).unwrap() + ) + }) + .collect(); + json_s.push_str(&edge_strs.join(",\n")); + json_s.push_str("\n]\n}\n"); + + json_s +} diff --git a/rs/src/lib.rs b/rs/src/lib.rs new file mode 100644 index 
0000000..d446e19 --- /dev/null +++ b/rs/src/lib.rs @@ -0,0 +1,8 @@ +pub mod chart_to_hg; +pub mod grammar; +pub mod hypergraph; +pub mod hypergraph_algos; +pub mod hypergraph_io; +pub mod parse; +pub mod semiring; +pub mod sparse_vector; diff --git a/rs/src/main.rs b/rs/src/main.rs new file mode 100644 index 0000000..f03dbf5 --- /dev/null +++ b/rs/src/main.rs @@ -0,0 +1,110 @@ +use std::fs; +use std::io::{self, BufRead}; + +use clap::Parser; + +use odenwald::chart_to_hg::chart_to_hg; +use odenwald::grammar::Grammar; +use odenwald::hypergraph_algos::{derive, viterbi_path}; +use odenwald::parse::{self, Chart}; +use odenwald::semiring::ViterbiSemiring; +use odenwald::sparse_vector::SparseVector; + +#[derive(Parser)] +#[command(name = "odenwald")] +struct Cli { + /// Grammar file + #[arg(short = 'g', long)] + grammar: String, + + /// Weights file + #[arg(short = 'w', long)] + weights: String, + + /// Input file (default: stdin) + #[arg(short = 'i', long, default_value = "-")] + input: String, + + /// Add glue rules + #[arg(short = 'l', long)] + add_glue: bool, + + /// Add pass-through rules + #[arg(short = 'p', long)] + add_pass_through: bool, +} + +fn main() { + let cli = Cli::parse(); + + eprintln!("> reading grammar '{}'", cli.grammar); + let mut grammar = Grammar::from_file(&cli.grammar).expect("Failed to read grammar"); + + if cli.add_glue { + eprintln!(">> adding glue rules"); + grammar.add_glue_rules(); + } + + let weights_content = fs::read_to_string(&cli.weights).expect("Failed to read weights"); + let weights = SparseVector::from_kv(&weights_content, ' ', '\n'); + + let input_lines: Vec = if cli.input == "-" { + let stdin = io::stdin(); + stdin.lock().lines().map(|l| l.unwrap()).collect() + } else { + let content = fs::read_to_string(&cli.input).expect("Failed to read input"); + content.lines().map(|l| l.to_string()).collect() + }; + + for line in &input_lines { + let line = line.trim(); + if line.is_empty() { + continue; + } + let input: Vec = 
line.split_whitespace().map(|s| s.to_string()).collect(); + let n = input.len(); + + if cli.add_pass_through { + eprintln!(">> adding pass-through rules"); + grammar.add_pass_through_rules(&input); + } + + eprintln!("> initializing charts"); + let mut passive_chart = Chart::new(n); + let mut active_chart = Chart::new(n); + parse::init(&input, n, &mut active_chart, &mut passive_chart, &grammar); + + eprintln!("> parsing"); + parse::parse(&input, n, &mut active_chart, &mut passive_chart, &grammar); + + // Count passive chart items + let mut chart_items = 0; + let mut chart_cells = 0; + parse::visit(1, 0, n, 0, |i, j| { + let items = passive_chart.at(i, j).len(); + if items > 0 { + chart_cells += 1; + chart_items += items; + } + }); + eprintln!("> chart: {} items in {} cells", chart_items, chart_cells); + + let mut hg = chart_to_hg(&passive_chart, n, &weights, &grammar); + eprintln!("> hg: {} nodes, {} edges", hg.nodes.len(), hg.edges.len()); + + eprintln!("> viterbi"); + let (path, score) = viterbi_path::(&mut hg); + + if path.is_empty() { + eprintln!("WARNING: no parse found"); + println!("NO_PARSE ||| 0.0"); + continue; + } + + let last_head = hg.edges[path.last().unwrap().0].head; + let mut carry = Vec::new(); + derive(&hg, &path, last_head, &mut carry); + + println!("{} ||| {}", carry.join(" "), score.ln()); + } +} diff --git a/rs/src/parse.rs b/rs/src/parse.rs new file mode 100644 index 0000000..7f64a89 --- /dev/null +++ b/rs/src/parse.rs @@ -0,0 +1,347 @@ +use crate::grammar::{Grammar, Symbol}; + +#[derive(Debug, Clone)] +pub struct Span { + pub left: i32, + pub right: i32, +} + +#[derive(Debug, Clone)] +pub struct Item { + pub rule_idx: usize, + pub left: usize, + pub right: usize, + pub dot: usize, + pub tail_spans: Vec>, +} + +impl Item { + pub fn from_rule(rule_idx: usize, rhs: &[Symbol], left: usize, right: usize, dot: usize) -> Self { + let tail_spans = rhs + .iter() + .map(|sym| { + if sym.is_nt() { + Some(Span { + left: -1, + right: -1, + }) + } else { 
+ None + } + }) + .collect(); + + Self { + rule_idx, + left, + right, + dot, + tail_spans, + } + } + + pub fn advance(&self, left: usize, right: usize, dot: usize) -> Self { + Self { + rule_idx: self.rule_idx, + left, + right, + dot, + tail_spans: self.tail_spans.clone(), + } + } +} + +pub struct Chart { + _n: usize, + m: Vec>>, + b: std::collections::HashSet<(u16, u16, u16)>, + /// Index: (symbol_id, left) -> sorted vec of right endpoints + spans_by_left: std::collections::HashMap<(u16, u16), Vec>, + sym_to_id: std::collections::HashMap, + next_sym_id: u16, +} + +impl Chart { + pub fn new(n: usize) -> Self { + let mut m = Vec::with_capacity(n + 1); + for _ in 0..=n { + let mut row = Vec::with_capacity(n + 1); + for _ in 0..=n { + row.push(Vec::new()); + } + m.push(row); + } + Self { + _n: n, + m, + b: std::collections::HashSet::new(), + spans_by_left: std::collections::HashMap::new(), + sym_to_id: std::collections::HashMap::new(), + next_sym_id: 0, + } + } + + fn sym_id(&mut self, symbol: &str) -> u16 { + if let Some(&id) = self.sym_to_id.get(symbol) { + id + } else { + let id = self.next_sym_id; + self.next_sym_id += 1; + self.sym_to_id.insert(symbol.to_string(), id); + id + } + } + + fn sym_id_lookup(&self, symbol: &str) -> Option { + self.sym_to_id.get(symbol).copied() + } + + pub fn at(&self, i: usize, j: usize) -> &Vec { + &self.m[i][j] + } + + pub fn at_mut(&mut self, i: usize, j: usize) -> &mut Vec { + &mut self.m[i][j] + } + + pub fn add(&mut self, item: Item, i: usize, j: usize, symbol: &str) { + let sid = self.sym_id(symbol); + if self.b.insert((i as u16, j as u16, sid)) { + self.spans_by_left + .entry((sid, i as u16)) + .or_default() + .push(j as u16); + } + self.m[i][j].push(item); + } + + pub fn has(&self, symbol: &str, i: usize, j: usize) -> bool { + if let Some(&sid) = self.sym_to_id.get(symbol) { + self.b.contains(&(i as u16, j as u16, sid)) + } else { + false + } + } + + /// Returns right endpoints where (symbol, left, right) exists + pub fn 
rights_for(&self, symbol: &str, left: usize) -> &[u16] { + static EMPTY: Vec = Vec::new(); + if let Some(sid) = self.sym_id_lookup(symbol) { + if let Some(rights) = self.spans_by_left.get(&(sid, left as u16)) { + return rights; + } + } + &EMPTY + } + + /// Check if any entry for (symbol, left, *) exists + pub fn has_any_at(&self, symbol: &str, left: usize) -> bool { + if let Some(sid) = self.sym_id_lookup(symbol) { + self.spans_by_left.contains_key(&(sid, left as u16)) + } else { + false + } + } +} + +/// Visit spans bottom-up: from span size `from` up. +/// Yields (left, right) pairs. +pub fn visit(from: usize, l: usize, r: usize, x: usize, mut f: F) { + for span in from..=(r - x) { + for k in l..=(r - span) { + f(k, k + span); + } + } +} + +pub fn init(input: &[String], n: usize, _active_chart: &mut Chart, passive_chart: &mut Chart, grammar: &Grammar) { + for i in 0..n { + if let Some(matching) = grammar.flat_by_first_word.get(&input[i]) { + for &fi in matching { + let rule = &grammar.rules[fi]; + let rhs_len = rule.rhs.len(); + if i + rhs_len > n { + continue; + } + let matches = rule + .rhs + .iter() + .enumerate() + .skip(1) + .all(|(k, s)| s.word() == input[i + k]); + if matches { + let item = Item::from_rule(fi, &rule.rhs, i, i + rhs_len, rhs_len); + passive_chart.add(item, i, i + rhs_len, rule.lhs.nt_symbol()); + } + } + } + } +} + +fn scan(item: &mut Item, input: &[String], limit: usize, grammar: &Grammar) -> bool { + let rhs = &grammar.rules[item.rule_idx].rhs; + while item.dot < rhs.len() && rhs[item.dot].is_t() { + if item.right == limit { + return false; + } + if rhs[item.dot].word() == input[item.right] { + item.dot += 1; + item.right += 1; + } else { + return false; + } + } + true +} + +pub fn parse( + input: &[String], + n: usize, + active_chart: &mut Chart, + passive_chart: &mut Chart, + grammar: &Grammar, +) { + visit(1, 0, n, 0, |i, j| { + // Try to apply rules starting with T + if let Some(matching) = grammar.start_t_by_word.get(&input[i]) { + 
for &ri in matching { + let mut new_item = Item::from_rule(ri, &grammar.rules[ri].rhs, i, i, 0); + if scan(&mut new_item, input, j, grammar) { + active_chart.at_mut(i, j).push(new_item); + } + } + } + + // Seed active chart with rules starting with NT + for (symbol, rule_indices) in &grammar.start_nt_by_symbol { + if !passive_chart.has_any_at(symbol, i) { + continue; + } + for &ri in rule_indices { + let rule = &grammar.rules[ri]; + if rule.rhs.len() > j - i { + continue; + } + active_chart + .at_mut(i, j) + .push(Item::from_rule(ri, &rule.rhs, i, i, 0)); + } + } + + // Parse + let mut new_symbols: Vec = Vec::new(); + let mut remaining_items: Vec = Vec::new(); + + while !active_chart.at(i, j).is_empty() { + let active_item = active_chart.at_mut(i, j).pop().unwrap(); + let mut advanced = false; + + let active_rhs = &grammar.rules[active_item.rule_idx].rhs; + if active_item.dot < active_rhs.len() && active_rhs[active_item.dot].is_nt() { + let wanted_symbol = active_rhs[active_item.dot].nt_symbol(); + let k = active_item.right; + + // Use index to directly get matching right endpoints + let rights: Vec = passive_chart + .rights_for(wanted_symbol, k) + .iter() + .filter(|&&r| (r as usize) <= j) + .copied() + .collect(); + + for l16 in rights { + let l = l16 as usize; + + let mut new_item = + active_item.advance(active_item.left, l, active_item.dot + 1); + new_item.tail_spans[new_item.dot - 1] = Some(Span { + left: k as i32, + right: l as i32, + }); + + let new_rhs = &grammar.rules[new_item.rule_idx].rhs; + if scan(&mut new_item, input, j, grammar) { + if new_item.dot == new_rhs.len() { + if new_item.left == i && new_item.right == j { + let sym_str = grammar.rules[new_item.rule_idx].lhs.nt_symbol(); + if !new_symbols.iter().any(|s| s == sym_str) { + new_symbols.push(sym_str.to_string()); + } + passive_chart.add(new_item, i, j, sym_str); + advanced = true; + } + } else if new_item.right + (new_rhs.len() - new_item.dot) <= j { + active_chart.at_mut(i, j).push(new_item); + 
advanced = true; + } + } + } + } + + if !advanced { + remaining_items.push(active_item); + } + } + + // Self-filling step + let mut si = 0; + while si < new_symbols.len() { + let s = new_symbols[si].clone(); + + // Try start_nt rules from the grammar for this new symbol. + // This handles rules that weren't seeded in step 2 because + // the symbol didn't exist in the passive chart at that point. + if let Some(rule_indices) = grammar.start_nt_by_symbol.get(&s) { + for &ri in rule_indices { + let rule = &grammar.rules[ri]; + if rule.rhs.len() > j - i { + continue; + } + let mut new_item = Item::from_rule(ri, &rule.rhs, i, i, 0); + new_item.dot = 1; + new_item.right = j; + new_item.tail_spans[0] = Some(Span { + left: i as i32, + right: j as i32, + }); + if new_item.dot == rule.rhs.len() { + let sym_str = rule.lhs.nt_symbol(); + if !new_symbols.iter().any(|ns| ns == sym_str) { + new_symbols.push(sym_str.to_string()); + } + passive_chart.add(new_item, i, j, sym_str); + } + } + } + + // Also check remaining active items (for rules that were + // seeded but couldn't advance during the parse loop) + for item in &remaining_items { + if item.dot != 0 { + continue; + } + let item_rhs = &grammar.rules[item.rule_idx].rhs; + if !item_rhs[item.dot].is_nt() { + continue; + } + if item_rhs[item.dot].nt_symbol() != s { + continue; + } + let mut new_item = item.advance(i, j, item.dot + 1); + new_item.tail_spans[new_item.dot - 1] = Some(Span { + left: i as i32, + right: j as i32, + }); + let new_rhs = &grammar.rules[new_item.rule_idx].rhs; + if new_item.dot == new_rhs.len() { + let sym_str = grammar.rules[new_item.rule_idx].lhs.nt_symbol(); + if !new_symbols.iter().any(|ns| ns == sym_str) { + new_symbols.push(sym_str.to_string()); + } + passive_chart.add(new_item, i, j, sym_str); + } + } + si += 1; + } + }); +} diff --git a/rs/src/semiring.rs b/rs/src/semiring.rs new file mode 100644 index 0000000..9a8fe3e --- /dev/null +++ b/rs/src/semiring.rs @@ -0,0 +1,39 @@ +pub trait Semiring { + 
fn one() -> f64; + fn null() -> f64; + fn add(a: f64, b: f64) -> f64; + fn multiply(a: f64, b: f64) -> f64; +} + +pub struct ViterbiSemiring; + +impl Semiring for ViterbiSemiring { + fn one() -> f64 { + 1.0 + } + + fn null() -> f64 { + 0.0 + } + + fn add(a: f64, b: f64) -> f64 { + a.max(b) + } + + fn multiply(a: f64, b: f64) -> f64 { + a * b + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_viterbi() { + assert_eq!(ViterbiSemiring::one(), 1.0); + assert_eq!(ViterbiSemiring::null(), 0.0); + assert_eq!(ViterbiSemiring::add(0.3, 0.7), 0.7); + assert_eq!(ViterbiSemiring::multiply(0.5, 0.6), 0.3); + } +} diff --git a/rs/src/sparse_vector.rs b/rs/src/sparse_vector.rs new file mode 100644 index 0000000..4e62f95 --- /dev/null +++ b/rs/src/sparse_vector.rs @@ -0,0 +1,86 @@ +use std::collections::HashMap; + +#[derive(Debug, Clone, Default)] +pub struct SparseVector { + pub map: HashMap, +} + +impl SparseVector { + pub fn new() -> Self { + Self { + map: HashMap::new(), + } + } + + pub fn from_kv(s: &str, kv_sep: char, pair_sep: char) -> Self { + let mut map = HashMap::new(); + for pair in s.split(pair_sep) { + let pair = pair.trim(); + if pair.is_empty() { + continue; + } + if let Some((k, v)) = pair.split_once(kv_sep) { + if let Ok(val) = v.trim().parse::() { + map.insert(k.trim().to_string(), val); + } + } + } + Self { map } + } + + pub fn from_hash(h: &serde_json::Map) -> Self { + let mut map = HashMap::new(); + for (k, v) in h { + if let Some(val) = v.as_f64() { + map.insert(k.clone(), val); + } + } + Self { map } + } + + pub fn dot(&self, other: &SparseVector) -> f64 { + let mut sum = 0.0; + for (k, v) in &self.map { + if let Some(ov) = other.map.get(k) { + sum += v * ov; + } + } + sum + } + + pub fn to_json(&self) -> serde_json::Value { + let map: serde_json::Map = self + .map + .iter() + .map(|(k, v)| (k.clone(), serde_json::Value::from(*v))) + .collect(); + serde_json::Value::Object(map) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + 
#[test] + fn test_from_kv() { + let sv = SparseVector::from_kv("logp 2\nuse_house 0\nuse_shell 1", ' ', '\n'); + assert_eq!(sv.map["logp"], 2.0); + assert_eq!(sv.map["use_house"], 0.0); + assert_eq!(sv.map["use_shell"], 1.0); + } + + #[test] + fn test_dot() { + let a = SparseVector::from_kv("x=1 y=2", '=', ' '); + let b = SparseVector::from_kv("x=3 y=4 z=5", '=', ' '); + assert_eq!(a.dot(&b), 11.0); + } + + #[test] + fn test_empty_dot() { + let a = SparseVector::new(); + let b = SparseVector::from_kv("x=1", '=', ' '); + assert_eq!(a.dot(&b), 0.0); + } +} -- cgit v1.2.3 From 44f225d0642d2ecf13f533f68b9ae12d849809ea Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Thu, 26 Feb 2026 19:29:18 +0100 Subject: Fix two bugs in prototype parser 1. Inner visit at span (0,1) yielded no sub-spans because visit(1, 0, 1, 1) iterates span from 1 to r-x=0, which is empty. This prevented unary rules like [S] ||| [X,1] from completing at the leftmost span, so S(0,1) was never created. Drop the x=1 parameter (default x=0); scan already handles bounds checking. 2. Self-filling step searched remaining_items for unary NT rules, but those rules could be absent if consumed (advanced) during the parse loop. Look up grammar.start_nt directly instead, which covers all cases. Co-Authored-By: Claude Opus 4.6 --- prototype/parse.rb | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/prototype/parse.rb b/prototype/parse.rb index 40a69e7..1741030 100644 --- a/prototype/parse.rb +++ b/prototype/parse.rb @@ -152,7 +152,7 @@ def Parse::parse input, n, active_chart, passive_chart, grammar while !active_chart.at(i,j).empty? 
active_item = active_chart.at(i,j).pop advanced = false - visit(1, i, j, 1) { |k,l| + visit(1, i, j) { |k,l| if passive_chart.has active_item.rhs[active_item.dot].symbol, k, l if k == active_item.right new_item = Item.new active_item, active_item.left, l, active_item.dot+1 @@ -182,16 +182,15 @@ def Parse::parse input, n, active_chart, passive_chart, grammar # 'self-filling' step new_symbols.each { |s| - remaining_items.each { |item| - next if item.dot!=0 - next if item.rhs[item.dot].class!=Grammar::NT - if item.rhs[item.dot].symbol == s - new_item = Item.new item, i, j, item.dot+1 - new_item.tail_spans[new_item.dot-1] = Span.new(i,j) - if new_item.dot==new_item.rhs.size - new_symbols << new_item.lhs.symbol if !new_symbols.include? new_item.lhs.symbol - passive_chart.add new_item, i, j - end + grammar.start_nt.each { |r| + next if r.rhs.size > j-i + next if r.rhs.first.class!=Grammar::NT + next if r.rhs.first.symbol != s + new_item = Item.new r, i, j, 1 + new_item.tail_spans[0] = Span.new(i,j) + if new_item.dot==new_item.rhs.size + new_symbols << new_item.lhs.symbol if !new_symbols.include? 
new_item.lhs.symbol + passive_chart.add new_item, i, j end } } -- cgit v1.2.3 From a1c5862a46b524d3e11a87c5a732c0c257aefe20 Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Thu, 26 Feb 2026 19:31:35 +0100 Subject: Fix C++ ow binary to produce translations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Node::mark and Node::score uninitialized, causing segfaults in topological_sort — add default initializers (0, 0.0) - odenwald.cc called incomplete sv_path() + exit(1) instead of viterbi_path() - viterbi_path: add reset() before topological_sort, initialize best_edge to nullptr - derive: off-by-one in NT order indexing — start j at 1 and use order[j]-1 (1-indexed alignment map) - read: ifs.readsome() returns 0 on macOS — use ifs.read() + ifs.gcount() - manual() signature: add missing Vocabulary parameter - Remove gperftools/tcmalloc dependency from Makefile Co-Authored-By: Claude Opus 4.6 --- Makefile | 8 +------- src/hypergraph.cc | 11 ++++++----- src/hypergraph.hh | 6 +++--- src/odenwald.cc | 3 +-- 4 files changed, 11 insertions(+), 17 deletions(-) diff --git a/Makefile b/Makefile index f59d18e..8d5147b 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,5 @@ COMPILER=clang CFLAGS=-std=c++11 -O3 -Wall -TCMALLOC=external/gperftools/lib/libtcmalloc_minimal.a -pthread MSGPACK_C_INCLUDE=-I external/msgpack-c/include MSGPACK_C=external/msgpack-c/lib/libmsgpack.a $(MSGPACK_C_INCLUDE) JSON_CPP_INCLUDE=-I external/json-cpp/include @@ -14,7 +13,7 @@ PRINT_END = @echo -e "\e[1;32mfinished building $@\e[0m" ############################################################################### # all # -all: $(BIN)/ow util test +all: $(BIN)/ow test ############################################################################### # ow @@ -26,7 +25,6 @@ $(BIN)/ow: $(BIN) $(SRC)/hypergraph.o $(SRC)/odenwald.cc -lstdc++ \ -lm \ $(MSGPACK_C) \ - $(TCMALLOC) \ $(SRC)/hypergraph.o \ $(SRC)/odenwald.cc \ -o $(BIN)/ow @@ -84,7 +82,6 @@ 
$(BIN)/test_grammar: $(BIN) $(SRC)/test_grammar.cc $(SRC)/grammar.hh $(COMPILER) $(CFLAGS) \ -lstdc++ \ -lm \ - $(TCMALLOC) \ $(MSGPACK_C) \ $(SRC)/test_grammar.cc \ -o $(BIN)/test_grammar @@ -95,7 +92,6 @@ $(BIN)/test_hypergraph: $(BIN) $(SRC)/test_hypergraph.cc $(SRC)/hypergraph.o $(S $(COMPILER) $(CFLAGS) \ -lstdc++ \ -lm \ - $(TCMALLOC) \ $(MSGPACK_C) \ $(SRC)/hypergraph.o \ $(SRC)/test_hypergraph.cc \ @@ -108,7 +104,6 @@ $(BIN)/test_parse: $(BIN) $(SRC)/test_parse.cc $(SRC)/parse.hh \ $(COMPILER) $(CFLAGS) \ -lstdc++ \ -lm \ - $(TCMALLOC) \ $(MSGPACK_C) \ $(SRC)/test_parse.cc \ -o $(BIN)/test_parse @@ -119,7 +114,6 @@ $(BIN)/test_sparse_vector: $(BIN) $(SRC)/test_sparse_vector.cc $(SRC)/sparse_vec $(COMPILER) $(CFLAGS) \ -lstdc++ \ -lm \ - $(TCMALLOC) \ $(SRC)/test_sparse_vector.cc \ -o $(BIN)/test_sparse_vector $(PRINT_END) diff --git a/src/hypergraph.cc b/src/hypergraph.cc index 6ec8441..0c36abe 100644 --- a/src/hypergraph.cc +++ b/src/hypergraph.cc @@ -73,13 +73,13 @@ viterbi_path(Hypergraph& hg, Path& p) [](Node* n) { return n->incoming.size() == 0; }); //list::iterator root = hg.nodes.begin(); + Hg::reset(hg.nodes, hg.edges); Hg::topological_sort(hg.nodes, root); - // ^^^ FIXME do I need to do this when reading from file? 
Semiring::Viterbi semiring; Hg::init(hg.nodes, root, semiring); for (auto n: hg.nodes) { - Edge* best_edge; + Edge* best_edge = nullptr; bool best = false; for (auto e: n->incoming) { score_t s = semiring.one; @@ -135,10 +135,10 @@ derive(const Path& p, const Node* cur, vector& carry) } } // FIXME this is probably not so good - unsigned j = 0; + unsigned j = 1; for (auto it: next->rule->target) { if (it->type() == G::NON_TERMINAL) { - derive(p, next->tails[next->rule->order[j]], carry); + derive(p, next->tails[next->rule->order[j]-1], carry); j++; } else { carry.push_back(it->symbol()); @@ -156,7 +156,8 @@ read(Hypergraph& hg, vector& rules, G::Vocabulary& vocab, const string msgpack::unpacker pac; while(true) { pac.reserve_buffer(32*1024); - size_t bytes = ifs.readsome(pac.buffer(), pac.buffer_capacity()); + ifs.read(pac.buffer(), pac.buffer_capacity()); + size_t bytes = ifs.gcount(); pac.buffer_consumed(bytes); msgpack::unpacked result; while(pac.next(&result)) { diff --git a/src/hypergraph.hh b/src/hypergraph.hh index 7a268c3..d782c9e 100644 --- a/src/hypergraph.hh +++ b/src/hypergraph.hh @@ -48,10 +48,10 @@ struct Node { string symbol; short left; short right; - score_t score; + score_t score = 0.0; vector incoming; vector outgoing; - unsigned int mark; + unsigned int mark = 0; inline bool is_marked() { return mark >= incoming.size(); }; friend ostream& operator<<(ostream& os, const Node& n); @@ -98,7 +98,7 @@ void write(Hypergraph& hg, vector& rules, const string& fn); // FIXME void -manual(Hypergraph& hg, vector& rules); +manual(Hypergraph& hg, vector& rules, G::Vocabulary& vocab); } // namespace diff --git a/src/odenwald.cc b/src/odenwald.cc index a520d0b..bdf21f8 100644 --- a/src/odenwald.cc +++ b/src/odenwald.cc @@ -20,8 +20,7 @@ main(int argc, char** argv) // viterbi clock_t begin_viterbi = clock(); Hg::Path p; - Hg::sv_path(hg, p); - exit(1); + Hg::viterbi_path(hg, p); vector s; Hg::derive(p, p.back()->head, s); for (auto it: s) -- cgit v1.2.3 From 
1377ffbdd2791e50cb3ca21d11c8c21febdbf911 Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Fri, 27 Feb 2026 09:46:14 +0100 Subject: Ignore output directories Co-Authored-By: Claude Opus 4.6 --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 479d0be..fed879d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ *.o bin/ -#example/*/output +*/*/output external/cdec_json_serialization/ -- cgit v1.2.3