summaryrefslogtreecommitdiff
path: root/rs/src/chart_to_hg.rs
blob: 95f642e70e1439f75b5348ed430563fb25476197 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
use std::collections::HashMap;

use crate::grammar::{Grammar, Rule, Symbol};
use crate::hypergraph::{Hypergraph, NodeId};
use crate::parse::{Chart, visit};
use crate::sparse_vector::SparseVector;

pub fn chart_to_hg(chart: &Chart, n: usize, weights: &SparseVector, grammar: &Grammar) -> Hypergraph {
    let mut hg = Hypergraph::new();
    let mut seen: HashMap<String, NodeId> = HashMap::new();

    // Add root node with id=-1
    let root_nid = hg.add_node(-1, "root", -1, -1);

    let mut next_id: i64 = 0;

    // First pass: create nodes
    visit(1, 0, n, 0, |i, j| {
        for item in chart.at(i, j) {
            let lhs_sym = grammar.rules[item.rule_idx].lhs.nt_symbol();
            let key = format!("{},{},{}", lhs_sym, i, j);
            if !seen.contains_key(&key) {
                let nid = hg.add_node(next_id, lhs_sym, i as i32, j as i32);
                seen.insert(key, nid);
                next_id += 1;
            }
        }
    });

    // Second pass: create edges
    visit(1, 0, n, 0, |i, j| {
        for item in chart.at(i, j) {
            let rule = &grammar.rules[item.rule_idx];
            let head_key = format!("{},{},{}", rule.lhs.nt_symbol(), i, j);
            let head_nid = *seen.get(&head_key).unwrap();

            // Build tails
            let tails: Vec<NodeId> = if item.tail_spans.iter().all(|s| s.is_none())
                || item.tail_spans.iter().all(|s| {
                    s.as_ref()
                        .map_or(true, |sp| sp.left == -1 && sp.right == -1)
                })
            {
                if rule.rhs.iter().any(|s| s.is_nt()) {
                    build_tails(&rule.rhs, &item.tail_spans, &seen, root_nid)
                } else {
                    vec![root_nid]
                }
            } else {
                build_tails(&rule.rhs, &item.tail_spans, &seen, root_nid)
            };

            let score = weights.dot(&rule.f).exp();
            let rule_clone = Rule::new(
                rule.lhs.clone(),
                rule.rhs.clone(),
                rule.target.clone(),
                rule.map.clone(),
                rule.f.clone(),
            );

            hg.add_edge(head_nid, tails, score, rule.f.clone(), rule_clone);
        }
    });

    hg
}

fn build_tails(
    rhs: &[Symbol],
    tail_spans: &[Option<crate::parse::Span>],
    seen: &HashMap<String, NodeId>,
    root_nid: NodeId,
) -> Vec<NodeId> {
    let mut tails = Vec::new();
    let mut has_nt = false;
    for (idx, sym) in rhs.iter().enumerate() {
        if let Symbol::NT { symbol, .. } = sym {
            has_nt = true;
            if let Some(Some(span)) = tail_spans.get(idx) {
                if span.left >= 0 && span.right >= 0 {
                    let key = format!("{},{},{}", symbol, span.left, span.right);
                    if let Some(&nid) = seen.get(&key) {
                        tails.push(nid);
                    }
                }
            }
        }
    }
    if !has_nt {
        vec![root_nid]
    } else {
        tails
    }
}