summaryrefslogtreecommitdiff
path: root/rs/src/hypergraph.rs
blob: 90069b0e07246152059dc414c361c9077c22586d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
use crate::grammar::Rule;
use crate::sparse_vector::SparseVector;

/// Position of a node inside `Hypergraph::nodes`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct NodeId(pub usize);

/// Position of an edge inside `Hypergraph::edges`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct EdgeId(pub usize);

/// A vertex of the hypergraph.
///
/// NOTE(review): `left`/`right` look like the bounds of the span covered by
/// `symbol` — confirm against the code that builds the graph.
#[derive(Debug)]
#[derive(Clone)]
pub struct Node {
    /// External identifier; `Hypergraph::add_node` keys `nodes_by_id` on it.
    pub id: i64,
    /// Label of this node.
    pub symbol: String,
    /// Left bound (see struct-level note).
    pub left: i32,
    /// Right bound (see struct-level note).
    pub right: i32,
    /// Edges in which this node appears as a tail.
    pub outgoing: Vec<EdgeId>,
    /// Edges whose head is this node.
    pub incoming: Vec<EdgeId>,
    /// Node score; `Node::new` starts it at 0.0.
    pub score: f64,
}

impl Node {
    pub fn new(id: i64, symbol: &str, left: i32, right: i32) -> Self {
        Self {
            id,
            symbol: symbol.to_string(),
            left,
            right,
            outgoing: Vec::new(),
            incoming: Vec::new(),
            score: 0.0,
        }
    }
}

/// A hyperedge connecting a list of tail nodes to a single head node.
#[derive(Debug)]
#[derive(Clone)]
pub struct Edge {
    /// Node this edge points to.
    pub head: NodeId,
    /// Source nodes; their count is the edge's arity.
    pub tails: Vec<NodeId>,
    /// Edge score.
    pub score: f64,
    /// NOTE(review): presumably the edge's feature vector — confirm.
    pub f: SparseVector,
    /// Counter compared against the arity by `Edge::marked`; reset to 0 by
    /// `Hypergraph::reset`.
    pub mark: usize,
    /// NOTE(review): presumably the grammar rule this edge instantiates — confirm.
    pub rule: Rule,
}

impl Edge {
    /// Number of tail nodes.
    pub fn arity(&self) -> usize {
        let Self { tails, .. } = self;
        tails.len()
    }

    /// Whether `mark` has reached the arity (always true for a tail-less edge
    /// with `mark == 0`).
    pub fn marked(&self) -> bool {
        self.mark == self.arity()
    }
}

/// Hypergraph whose nodes and edges live in two index-addressed vectors.
///
/// A `NodeId`/`EdgeId` is the position of the item in `nodes`/`edges`;
/// `nodes_by_id` maps each node's external `id` to its `NodeId`.
#[derive(Debug)]
pub struct Hypergraph {
    /// All nodes, addressed by `NodeId.0`.
    pub nodes: Vec<Node>,
    /// All edges, addressed by `EdgeId.0`.
    pub edges: Vec<Edge>,
    /// Lookup from a node's external `id` to its position in `nodes`.
    pub nodes_by_id: std::collections::HashMap<i64, NodeId>,
}

impl Hypergraph {
    /// Creates an empty hypergraph with no nodes and no edges.
    pub fn new() -> Self {
        Self {
            nodes: Vec::new(),
            edges: Vec::new(),
            nodes_by_id: std::collections::HashMap::new(),
        }
    }

    /// Appends a new node and returns its index.
    ///
    /// The node is also registered in `nodes_by_id` under `id`. If `id` was
    /// already registered, the map entry is overwritten to point at the new
    /// node; the earlier node stays in `nodes` but is no longer reachable by id.
    pub fn add_node(&mut self, id: i64, symbol: &str, left: i32, right: i32) -> NodeId {
        let nid = NodeId(self.nodes.len());
        self.nodes.push(Node::new(id, symbol, left, right));
        self.nodes_by_id.insert(id, nid);
        nid
    }

    /// Appends an edge from `tails` to `head` and wires it into the node
    /// adjacency lists, returning the new edge's index.
    ///
    /// The edge id is pushed onto `head`'s `incoming` list and onto the
    /// `outgoing` list of every tail, once per occurrence (a node listed twice
    /// in `tails` records the edge twice). The edge starts with `mark == 0`.
    ///
    /// # Panics
    ///
    /// Panics if `head` or any tail is not a valid index into `nodes`. All
    /// indices are checked before anything is modified, so a panic leaves the
    /// graph unchanged.
    pub fn add_edge(
        &mut self,
        head: NodeId,
        tails: Vec<NodeId>,
        score: f64,
        f: SparseVector,
        rule: Rule,
    ) -> EdgeId {
        let node_count = self.nodes.len();
        // Validate up front so a bad index cannot leave a half-wired edge behind.
        assert!(
            head.0 < node_count && tails.iter().all(|t| t.0 < node_count),
            "add_edge: node index out of range (graph has {} nodes)",
            node_count
        );
        let eid = EdgeId(self.edges.len());
        self.nodes[head.0].incoming.push(eid);
        for &t in &tails {
            self.nodes[t.0].outgoing.push(eid);
        }
        // Adjacency lists are updated, so `tails` can be moved rather than cloned.
        self.edges.push(Edge {
            head,
            tails,
            score,
            f,
            mark: 0,
            rule,
        });
        eid
    }

    /// Returns the node at `nid`.
    ///
    /// # Panics
    ///
    /// Panics if `nid` is not a valid index into `nodes`.
    pub fn node(&self, nid: NodeId) -> &Node {
        &self.nodes[nid.0]
    }

    /// Returns the edge at `eid`.
    ///
    /// # Panics
    ///
    /// Panics if `eid` is not a valid index into `edges`.
    pub fn edge(&self, eid: EdgeId) -> &Edge {
        &self.edges[eid.0]
    }

    /// Sets `mark` back to 0 on every edge.
    pub fn reset(&mut self) {
        for e in &mut self.edges {
            e.mark = 0;
        }
    }
}

impl Default for Hypergraph {
    /// Same as [`Hypergraph::new`]: an empty graph.
    fn default() -> Self {
        Self::new()
    }
}