summaryrefslogtreecommitdiff
path: root/rs/src/chart_to_hg.rs
diff options
context:
space:
mode:
Diffstat (limited to 'rs/src/chart_to_hg.rs')
-rw-r--r--rs/src/chart_to_hg.rs95
1 files changed, 95 insertions, 0 deletions
diff --git a/rs/src/chart_to_hg.rs b/rs/src/chart_to_hg.rs
new file mode 100644
index 0000000..95f642e
--- /dev/null
+++ b/rs/src/chart_to_hg.rs
@@ -0,0 +1,95 @@
+use std::collections::HashMap;
+
+use crate::grammar::{Grammar, Rule, Symbol};
+use crate::hypergraph::{Hypergraph, NodeId};
+use crate::parse::{Chart, visit};
+use crate::sparse_vector::SparseVector;
+
+pub fn chart_to_hg(chart: &Chart, n: usize, weights: &SparseVector, grammar: &Grammar) -> Hypergraph {
+ let mut hg = Hypergraph::new();
+ let mut seen: HashMap<String, NodeId> = HashMap::new();
+
+ // Add root node with id=-1
+ let root_nid = hg.add_node(-1, "root", -1, -1);
+
+ let mut next_id: i64 = 0;
+
+ // First pass: create nodes
+ visit(1, 0, n, 0, |i, j| {
+ for item in chart.at(i, j) {
+ let lhs_sym = grammar.rules[item.rule_idx].lhs.nt_symbol();
+ let key = format!("{},{},{}", lhs_sym, i, j);
+ if !seen.contains_key(&key) {
+ let nid = hg.add_node(next_id, lhs_sym, i as i32, j as i32);
+ seen.insert(key, nid);
+ next_id += 1;
+ }
+ }
+ });
+
+ // Second pass: create edges
+ visit(1, 0, n, 0, |i, j| {
+ for item in chart.at(i, j) {
+ let rule = &grammar.rules[item.rule_idx];
+ let head_key = format!("{},{},{}", rule.lhs.nt_symbol(), i, j);
+ let head_nid = *seen.get(&head_key).unwrap();
+
+ // Build tails
+ let tails: Vec<NodeId> = if item.tail_spans.iter().all(|s| s.is_none())
+ || item.tail_spans.iter().all(|s| {
+ s.as_ref()
+ .map_or(true, |sp| sp.left == -1 && sp.right == -1)
+ })
+ {
+ if rule.rhs.iter().any(|s| s.is_nt()) {
+ build_tails(&rule.rhs, &item.tail_spans, &seen, root_nid)
+ } else {
+ vec![root_nid]
+ }
+ } else {
+ build_tails(&rule.rhs, &item.tail_spans, &seen, root_nid)
+ };
+
+ let score = weights.dot(&rule.f).exp();
+ let rule_clone = Rule::new(
+ rule.lhs.clone(),
+ rule.rhs.clone(),
+ rule.target.clone(),
+ rule.map.clone(),
+ rule.f.clone(),
+ );
+
+ hg.add_edge(head_nid, tails, score, rule.f.clone(), rule_clone);
+ }
+ });
+
+ hg
+}
+
+fn build_tails(
+ rhs: &[Symbol],
+ tail_spans: &[Option<crate::parse::Span>],
+ seen: &HashMap<String, NodeId>,
+ root_nid: NodeId,
+) -> Vec<NodeId> {
+ let mut tails = Vec::new();
+ let mut has_nt = false;
+ for (idx, sym) in rhs.iter().enumerate() {
+ if let Symbol::NT { symbol, .. } = sym {
+ has_nt = true;
+ if let Some(Some(span)) = tail_spans.get(idx) {
+ if span.left >= 0 && span.right >= 0 {
+ let key = format!("{},{},{}", symbol, span.left, span.right);
+ if let Some(&nid) = seen.get(&key) {
+ tails.push(nid);
+ }
+ }
+ }
+ }
+ }
+ if !has_nt {
+ vec![root_nid]
+ } else {
+ tails
+ }
+}