summaryrefslogtreecommitdiff
path: root/rs/src/hypergraph_io.rs
diff options
context:
space:
mode:
Diffstat (limited to 'rs/src/hypergraph_io.rs')
-rw-r--r--rs/src/hypergraph_io.rs110
1 files changed, 110 insertions, 0 deletions
diff --git a/rs/src/hypergraph_io.rs b/rs/src/hypergraph_io.rs
new file mode 100644
index 0000000..5aa9f3b
--- /dev/null
+++ b/rs/src/hypergraph_io.rs
@@ -0,0 +1,110 @@
+use std::fs;
+use std::io;
+
+use crate::grammar::Rule;
+use crate::hypergraph::{Hypergraph, NodeId};
+use crate::sparse_vector::SparseVector;
+
+pub fn read_hypergraph_from_json(path: &str, log_weights: bool) -> io::Result<Hypergraph> {
+ let content = fs::read_to_string(path)?;
+ let parsed: serde_json::Value = serde_json::from_str(&content)
+ .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
+
+ let mut hg = Hypergraph::new();
+
+ let weights = match parsed.get("weights").and_then(|w| w.as_object()) {
+ Some(obj) => SparseVector::from_hash(obj),
+ None => SparseVector::new(),
+ };
+
+ // Add nodes
+ if let Some(nodes) = parsed.get("nodes").and_then(|n| n.as_array()) {
+ for n in nodes {
+ let id = n["id"].as_i64().unwrap();
+ let cat = n["cat"].as_str().unwrap_or("");
+ let span = n["span"].as_array().unwrap();
+ let left = span[0].as_i64().unwrap() as i32;
+ let right = span[1].as_i64().unwrap() as i32;
+ hg.add_node(id, cat, left, right);
+ }
+ }
+
+ // Add edges
+ if let Some(edges) = parsed.get("edges").and_then(|e| e.as_array()) {
+ for e in edges {
+ let head_id = e["head"].as_i64().unwrap();
+ let head_nid = *hg.nodes_by_id.get(&head_id).unwrap();
+
+ let tails: Vec<NodeId> = e["tails"]
+ .as_array()
+ .unwrap()
+ .iter()
+ .map(|t| *hg.nodes_by_id.get(&t.as_i64().unwrap()).unwrap())
+ .collect();
+
+ let f = match e.get("f").and_then(|f| f.as_object()) {
+ Some(obj) => SparseVector::from_hash(obj),
+ None => SparseVector::new(),
+ };
+
+ let rule_str = e["rule"].as_str().unwrap();
+ let rule = Rule::from_str(rule_str);
+
+ let score = if log_weights {
+ weights.dot(&f).exp()
+ } else {
+ weights.dot(&f)
+ };
+
+ hg.add_edge(head_nid, tails, score, f, rule);
+ }
+ }
+
+ Ok(hg)
+}
+
+pub fn write_hypergraph_to_json(hg: &Hypergraph, weights: &SparseVector) -> String {
+ let mut json_s = String::from("{\n");
+ json_s.push_str(&format!(
+ "\"weights\":{},\n",
+ serde_json::to_string(&weights.to_json()).unwrap()
+ ));
+
+ json_s.push_str("\"nodes\":\n[\n");
+ let node_strs: Vec<String> = hg
+ .nodes
+ .iter()
+ .map(|n| {
+ format!(
+ "{{ \"id\":{}, \"cat\":\"{}\", \"span\":[{},{}] }}",
+ n.id,
+ n.symbol.replace('"', "\\\""),
+ n.left,
+ n.right
+ )
+ })
+ .collect();
+ json_s.push_str(&node_strs.join(",\n"));
+ json_s.push_str("\n],\n");
+
+ json_s.push_str("\"edges\":\n[\n");
+ let edge_strs: Vec<String> = hg
+ .edges
+ .iter()
+ .map(|e| {
+ let head_id = hg.nodes[e.head.0].id;
+ let tail_ids: Vec<i64> = e.tails.iter().map(|t| hg.nodes[t.0].id).collect();
+ format!(
+ "{{ \"head\":{}, \"rule\":\"{}\", \"tails\":{:?}, \"f\":{} }}",
+ head_id,
+ e.rule.to_string().replace('"', "\\\""),
+ tail_ids,
+ serde_json::to_string(&e.f.to_json()).unwrap()
+ )
+ })
+ .collect();
+ json_s.push_str(&edge_strs.join(",\n"));
+ json_s.push_str("\n]\n}\n");
+
+ json_s
+}