diff options
Diffstat (limited to 'rs/src/main.rs')
| -rw-r--r-- | rs/src/main.rs | 110 |
1 files changed, 110 insertions, 0 deletions
diff --git a/rs/src/main.rs b/rs/src/main.rs new file mode 100644 index 0000000..f03dbf5 --- /dev/null +++ b/rs/src/main.rs @@ -0,0 +1,110 @@ +use std::fs; +use std::io::{self, BufRead}; + +use clap::Parser; + +use odenwald::chart_to_hg::chart_to_hg; +use odenwald::grammar::Grammar; +use odenwald::hypergraph_algos::{derive, viterbi_path}; +use odenwald::parse::{self, Chart}; +use odenwald::semiring::ViterbiSemiring; +use odenwald::sparse_vector::SparseVector; + +#[derive(Parser)] +#[command(name = "odenwald")] +struct Cli { + /// Grammar file + #[arg(short = 'g', long)] + grammar: String, + + /// Weights file + #[arg(short = 'w', long)] + weights: String, + + /// Input file (default: stdin) + #[arg(short = 'i', long, default_value = "-")] + input: String, + + /// Add glue rules + #[arg(short = 'l', long)] + add_glue: bool, + + /// Add pass-through rules + #[arg(short = 'p', long)] + add_pass_through: bool, +} + +fn main() { + let cli = Cli::parse(); + + eprintln!("> reading grammar '{}'", cli.grammar); + let mut grammar = Grammar::from_file(&cli.grammar).expect("Failed to read grammar"); + + if cli.add_glue { + eprintln!(">> adding glue rules"); + grammar.add_glue_rules(); + } + + let weights_content = fs::read_to_string(&cli.weights).expect("Failed to read weights"); + let weights = SparseVector::from_kv(&weights_content, ' ', '\n'); + + let input_lines: Vec<String> = if cli.input == "-" { + let stdin = io::stdin(); + stdin.lock().lines().map(|l| l.unwrap()).collect() + } else { + let content = fs::read_to_string(&cli.input).expect("Failed to read input"); + content.lines().map(|l| l.to_string()).collect() + }; + + for line in &input_lines { + let line = line.trim(); + if line.is_empty() { + continue; + } + let input: Vec<String> = line.split_whitespace().map(|s| s.to_string()).collect(); + let n = input.len(); + + if cli.add_pass_through { + eprintln!(">> adding pass-through rules"); + grammar.add_pass_through_rules(&input); + } + + eprintln!("> initializing charts"); + let mut passive_chart = Chart::new(n); + let mut active_chart = Chart::new(n); + parse::init(&input, n, &mut active_chart, &mut passive_chart, &grammar); + + eprintln!("> parsing"); + parse::parse(&input, n, &mut active_chart, &mut passive_chart, &grammar); + + // Count passive chart items + let mut chart_items = 0; + let mut chart_cells = 0; + parse::visit(1, 0, n, 0, |i, j| { + let items = passive_chart.at(i, j).len(); + if items > 0 { + chart_cells += 1; + chart_items += items; + } + }); + eprintln!("> chart: {} items in {} cells", chart_items, chart_cells); + + let mut hg = chart_to_hg(&passive_chart, n, &weights, &grammar); + eprintln!("> hg: {} nodes, {} edges", hg.nodes.len(), hg.edges.len()); + + eprintln!("> viterbi"); + let (path, score) = viterbi_path::<ViterbiSemiring>(&mut hg); + + if path.is_empty() { + eprintln!("WARNING: no parse found"); + println!("NO_PARSE ||| 0.0"); + continue; + } + + let last_head = hg.edges[path.last().unwrap().0].head; + let mut carry = Vec::new(); + derive(&hg, &path, last_head, &mut carry); + + println!("{} ||| {}", carry.join(" "), score.ln()); + } +} |
