summaryrefslogtreecommitdiff
path: root/rs/src/hypergraph_io.rs
blob: 5aa9f3b1d3788c2e702541c20f578e4d3a7889b0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
use std::fs;
use std::io;

use crate::grammar::Rule;
use crate::hypergraph::{Hypergraph, NodeId};
use crate::sparse_vector::SparseVector;

pub fn read_hypergraph_from_json(path: &str, log_weights: bool) -> io::Result<Hypergraph> {
    // Converts a JSON integer to `i32`, rejecting non-integers and values outside the
    // `i32` range instead of silently truncating them with an `as` cast.
    fn json_i32(v: &serde_json::Value) -> Option<i32> {
        let x = v.as_i64()?;
        if x < i64::from(i32::MIN) || x > i64::from(i32::MAX) {
            None
        } else {
            Some(x as i32)
        }
    }

    // Builds the `InvalidData` error reported for any structural problem in the document.
    fn invalid(msg: String) -> io::Error {
        io::Error::new(io::ErrorKind::InvalidData, msg)
    }

    let content = fs::read_to_string(path)?;
    let parsed: serde_json::Value = serde_json::from_str(&content)
        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;

    let mut hg = Hypergraph::new();

    // A missing or non-object "weights" entry yields an empty weight vector.
    let weights = match parsed.get("weights").and_then(|w| w.as_object()) {
        Some(obj) => SparseVector::from_hash(obj),
        None => SparseVector::new(),
    };

    // Add nodes. They must be added before edges, which refer to them by JSON "id".
    if let Some(nodes) = parsed.get("nodes").and_then(|n| n.as_array()) {
        for (i, n) in nodes.iter().enumerate() {
            let id = n["id"]
                .as_i64()
                .ok_or_else(|| invalid(format!("node {}: missing integer \"id\"", i)))?;
            // A missing or non-string category is tolerated and stored as "".
            let cat = n["cat"].as_str().unwrap_or("");
            // "span" must start with two integers; any further entries are ignored.
            let span = n["span"].as_array().map(|s| s.as_slice());
            let (left, right) = match span {
                Some([l, r, ..]) => json_i32(l).zip(json_i32(r)),
                _ => None,
            }
            .ok_or_else(|| {
                invalid(format!(
                    "node {}: \"span\" must be [left, right] with i32 endpoints",
                    id
                ))
            })?;
            hg.add_node(id, cat, left, right);
        }
    }

    // Add edges.
    if let Some(edges) = parsed.get("edges").and_then(|e| e.as_array()) {
        for (i, e) in edges.iter().enumerate() {
            // Resolves a JSON node id to the `NodeId` registered by `add_node` above.
            let lookup = |v: &serde_json::Value| -> io::Result<NodeId> {
                let id = v.as_i64().ok_or_else(|| {
                    invalid(format!("edge {}: node reference is not an integer", i))
                })?;
                hg.nodes_by_id
                    .get(&id)
                    .copied()
                    .ok_or_else(|| invalid(format!("edge {}: unknown node id {}", i, id)))
            };

            let head_nid = lookup(&e["head"])?;

            let tails = e["tails"]
                .as_array()
                .ok_or_else(|| invalid(format!("edge {}: \"tails\" is not an array", i)))?
                .iter()
                .map(&lookup)
                .collect::<io::Result<Vec<NodeId>>>()?;

            // A missing or non-object "f" entry yields an empty feature vector.
            let f = match e.get("f").and_then(|f| f.as_object()) {
                Some(obj) => SparseVector::from_hash(obj),
                None => SparseVector::new(),
            };

            let rule_str = e["rule"]
                .as_str()
                .ok_or_else(|| invalid(format!("edge {}: missing string \"rule\"", i)))?;
            let rule = Rule::from_str(rule_str);

            // Edge score is w·f, or exp(w·f) when `log_weights` is set.
            let dot = weights.dot(&f);
            let score = if log_weights { dot.exp() } else { dot };

            hg.add_edge(head_nid, tails, score, f, rule);
        }
    }

    Ok(hg)
}

/// Serializes `hg` together with `weights` into a JSON document.
///
/// The document has three top-level keys, `"weights"`, `"nodes"` and `"edges"`.
/// These are the keys read by [`read_hypergraph_from_json`].
/// - Nodes are written as `{ "id", "cat", "span": [left, right] }`.
/// - Edges are written as `{ "head", "rule", "tails", "f" }`, where `head` and `tails`
///   are node `id`s rather than internal indices.
///
/// Edge scores are not written.
///
/// # Panics
/// Panics if `serde_json` fails to serialize the `to_json()` value of `weights` or of an
/// edge's feature vector.
pub fn write_hypergraph_to_json(hg: &Hypergraph, weights: &SparseVector) -> String {
    // Renders `s` as a quoted JSON string literal. Going through serde_json escapes
    // backslashes and control characters as well as quotes. Escaping only `"` would
    // produce invalid JSON for symbols or rules containing `\`, newlines, tabs, etc.
    fn json_str(s: &str) -> String {
        serde_json::to_string(s).expect("serializing a string to JSON cannot fail")
    }

    let mut json_s = String::from("{\n");
    json_s.push_str(&format!(
        "\"weights\":{},\n",
        serde_json::to_string(&weights.to_json()).unwrap()
    ));

    json_s.push_str("\"nodes\":\n[\n");
    let node_strs: Vec<String> = hg
        .nodes
        .iter()
        .map(|n| {
            format!(
                "{{ \"id\":{}, \"cat\":{}, \"span\":[{},{}] }}",
                n.id,
                json_str(&n.symbol),
                n.left,
                n.right
            )
        })
        .collect();
    json_s.push_str(&node_strs.join(",\n"));
    json_s.push_str("\n],\n");

    json_s.push_str("\"edges\":\n[\n");
    let edge_strs: Vec<String> = hg
        .edges
        .iter()
        .map(|e| {
            // Translate internal node indices back into the nodes' external ids.
            let head_id = hg.nodes[e.head.0].id;
            let tail_ids: Vec<i64> = e.tails.iter().map(|t| hg.nodes[t.0].id).collect();
            // `{:?}` on a Vec<i64> prints `[a, b, ...]`, which is also valid JSON.
            format!(
                "{{ \"head\":{}, \"rule\":{}, \"tails\":{:?}, \"f\":{} }}",
                head_id,
                json_str(&e.rule.to_string()),
                tail_ids,
                serde_json::to_string(&e.f.to_json()).unwrap()
            )
        })
        .collect();
    json_s.push_str(&edge_strs.join(",\n"));
    json_s.push_str("\n]\n}\n");

    json_s
}