summaryrefslogtreecommitdiff
path: root/src/hg_io.cc
diff options
context:
space:
mode:
authorChris Dyer <redpony@gmail.com>2009-12-03 16:33:55 -0500
committerChris Dyer <redpony@gmail.com>2009-12-03 16:33:55 -0500
commit671c21451542e2dd20e45b4033d44d8e8735f87b (patch)
treeb1773b077dd65b826f067a423d26f7942ce4e043 /src/hg_io.cc
initial check in
Diffstat (limited to 'src/hg_io.cc')
-rw-r--r--src/hg_io.cc585
1 files changed, 585 insertions, 0 deletions
diff --git a/src/hg_io.cc b/src/hg_io.cc
new file mode 100644
index 00000000..629e65f1
--- /dev/null
+++ b/src/hg_io.cc
@@ -0,0 +1,585 @@
+#include "hg_io.h"
+
+#include <sstream>
+#include <iostream>
+
+#include <boost/lexical_cast.hpp>
+
+#include "tdict.h"
+#include "json_parse.h"
+#include "hg.h"
+
+using namespace std;
+
+struct HGReader : public JSONParser {
+ HGReader(Hypergraph* g) : rp("[X] ||| "), state(-1), hg(*g), nodes_needed(true), edges_needed(true) { nodes = 0; edges = 0; }
+
+ void CreateNode(const string& cat, const vector<int>& in_edges) {
+ WordID c = TD::Convert("X") * -1;
+ if (!cat.empty()) c = TD::Convert(cat) * -1;
+ Hypergraph::Node* node = hg.AddNode(c, "");
+ for (int i = 0; i < in_edges.size(); ++i) {
+ if (in_edges[i] >= hg.edges_.size()) {
+ cerr << "JSONParser: in_edges[" << i << "]=" << in_edges[i]
+ << ", but hg only has " << hg.edges_.size() << " edges!\n";
+ abort();
+ }
+ hg.ConnectEdgeToHeadNode(&hg.edges_[in_edges[i]], node);
+ }
+ }
+ void CreateEdge(const TRulePtr& rule, SparseVector<double>* feats, const SmallVector& tail) {
+ Hypergraph::Edge* edge = hg.AddEdge(rule, tail);
+ feats->swap(edge->feature_values_);
+ }
+
+ bool HandleJSONEvent(int type, const JSON_value* value) {
+ switch(state) {
+ case -1:
+ assert(type == JSON_T_OBJECT_BEGIN);
+ state = 0;
+ break;
+ case 0:
+ if (type == JSON_T_OBJECT_END) {
+ //cerr << "HG created\n"; // TODO, signal some kind of callback
+ } else if (type == JSON_T_KEY) {
+ string val = value->vu.str.value;
+ if (val == "features") { assert(fdict.empty()); state = 1; }
+ else if (val == "is_sorted") { state = 3; }
+ else if (val == "rules") { assert(rules.empty()); state = 4; }
+ else if (val == "node") { state = 8; }
+ else if (val == "edges") { state = 13; }
+ else { cerr << "Unexpected key: " << val << endl; return false; }
+ }
+ break;
+
+ // features
+ case 1:
+ if(type == JSON_T_NULL) { state = 0; break; }
+ assert(type == JSON_T_ARRAY_BEGIN);
+ state = 2;
+ break;
+ case 2:
+ if(type == JSON_T_ARRAY_END) { state = 0; break; }
+ assert(type == JSON_T_STRING);
+ fdict.push_back(FD::Convert(value->vu.str.value));
+ break;
+
+ // is_sorted
+ case 3:
+ assert(type == JSON_T_TRUE || type == JSON_T_FALSE);
+ is_sorted = (type == JSON_T_TRUE);
+ if (!is_sorted) { cerr << "[WARNING] is_sorted flag is ignored\n"; }
+ state = 0;
+ break;
+
+ // rules
+ case 4:
+ if(type == JSON_T_NULL) { state = 0; break; }
+ assert(type == JSON_T_ARRAY_BEGIN);
+ state = 5;
+ break;
+ case 5:
+ if(type == JSON_T_ARRAY_END) { state = 0; break; }
+ assert(type == JSON_T_INTEGER);
+ state = 6;
+ rule_id = value->vu.integer_value;
+ break;
+ case 6:
+ assert(type == JSON_T_STRING);
+ rules[rule_id] = TRulePtr(new TRule(value->vu.str.value));
+ state = 5;
+ break;
+
+ // Nodes
+ case 8:
+ assert(type == JSON_T_OBJECT_BEGIN);
+ ++nodes;
+ in_edges.clear();
+ cat.clear();
+ state = 9; break;
+ case 9:
+ if (type == JSON_T_OBJECT_END) {
+ //cerr << "Creating NODE\n";
+ CreateNode(cat, in_edges);
+ state = 0; break;
+ }
+ assert(type == JSON_T_KEY);
+ cur_key = value->vu.str.value;
+ if (cur_key == "cat") { assert(cat.empty()); state = 10; break; }
+ if (cur_key == "in_edges") { assert(in_edges.empty()); state = 11; break; }
+ cerr << "Syntax error: unexpected key " << cur_key << " in node specification.\n";
+ return false;
+ case 10:
+ assert(type == JSON_T_STRING || type == JSON_T_NULL);
+ cat = value->vu.str.value;
+ state = 9; break;
+ case 11:
+ if (type == JSON_T_NULL) { state = 9; break; }
+ assert(type == JSON_T_ARRAY_BEGIN);
+ state = 12; break;
+ case 12:
+ if (type == JSON_T_ARRAY_END) { state = 9; break; }
+ assert(type == JSON_T_INTEGER);
+ //cerr << "in_edges: " << value->vu.integer_value << endl;
+ in_edges.push_back(value->vu.integer_value);
+ break;
+
+ // "edges": [ { "tail": null, "feats" : [0,1.63,1,-0.54], "rule": 12},
+ // { "tail": null, "feats" : [0,0.87,1,0.02], "rule": 17},
+ // { "tail": [0], "feats" : [1,2.3,2,15.3,"ExtraFeature",1.2], "rule": 13}]
+ case 13:
+ assert(type == JSON_T_ARRAY_BEGIN);
+ state = 14;
+ break;
+ case 14:
+ if (type == JSON_T_ARRAY_END) { state = 0; break; }
+ assert(type == JSON_T_OBJECT_BEGIN);
+ //cerr << "New edge\n";
+ ++edges;
+ cur_rule.reset(); feats.clear(); tail.clear();
+ state = 15; break;
+ case 15:
+ if (type == JSON_T_OBJECT_END) {
+ CreateEdge(cur_rule, &feats, tail);
+ state = 14; break;
+ }
+ assert(type == JSON_T_KEY);
+ cur_key = value->vu.str.value;
+ //cerr << "edge key " << cur_key << endl;
+ if (cur_key == "rule") { assert(!cur_rule); state = 16; break; }
+ if (cur_key == "feats") { assert(feats.empty()); state = 17; break; }
+ if (cur_key == "tail") { assert(tail.empty()); state = 20; break; }
+ cerr << "Unexpected key " << cur_key << " in edge specification\n";
+ return false;
+ case 16: // edge.rule
+ if (type == JSON_T_INTEGER) {
+ int rule_id = value->vu.integer_value;
+ if (rules.find(rule_id) == rules.end()) {
+ // rules list must come before the edge definitions!
+ cerr << "Rule_id " << rule_id << " given but only loaded " << rules.size() << " rules\n";
+ return false;
+ }
+ cur_rule = rules[rule_id];
+ } else if (type == JSON_T_STRING) {
+ cur_rule.reset(new TRule(value->vu.str.value));
+ } else {
+ cerr << "Rule must be either a rule id or a rule string" << endl;
+ return false;
+ }
+ // cerr << "Edge: rule=" << cur_rule->AsString() << endl;
+ state = 15;
+ break;
+ case 17: // edge.feats
+ if (type == JSON_T_NULL) { state = 15; break; }
+ assert(type == JSON_T_ARRAY_BEGIN);
+ state = 18; break;
+ case 18:
+ if (type == JSON_T_ARRAY_END) { state = 15; break; }
+ if (type != JSON_T_INTEGER && type != JSON_T_STRING) {
+ cerr << "Unexpected feature id type\n"; return false;
+ }
+ if (type == JSON_T_INTEGER) {
+ fid = value->vu.integer_value;
+ assert(fid < fdict.size());
+ fid = fdict[fid];
+ } else if (JSON_T_STRING) {
+ fid = FD::Convert(value->vu.str.value);
+ } else { abort(); }
+ state = 19;
+ break;
+ case 19:
+ {
+ assert(type == JSON_T_INTEGER || type == JSON_T_FLOAT);
+ double val = (type == JSON_T_INTEGER ? static_cast<double>(value->vu.integer_value) :
+ strtod(value->vu.str.value, NULL));
+ feats.set_value(fid, val);
+ state = 18;
+ break;
+ }
+ case 20: // edge.tail
+ if (type == JSON_T_NULL) { state = 15; break; }
+ assert(type == JSON_T_ARRAY_BEGIN);
+ state = 21; break;
+ case 21:
+ if (type == JSON_T_ARRAY_END) { state = 15; break; }
+ assert(type == JSON_T_INTEGER);
+ tail.push_back(value->vu.integer_value);
+ break;
+ }
+ return true;
+ }
+ string rp;
+ string cat;
+ SmallVector tail;
+ vector<int> in_edges;
+ TRulePtr cur_rule;
+ map<int, TRulePtr> rules;
+ vector<int> fdict;
+ SparseVector<double> feats;
+ int state;
+ int fid;
+ int nodes;
+ int edges;
+ string cur_key;
+ Hypergraph& hg;
+ int rule_id;
+ bool nodes_needed;
+ bool edges_needed;
+ bool is_sorted;
+};
+
+bool HypergraphIO::ReadFromJSON(istream* in, Hypergraph* hg) {
+ hg->clear();
+ HGReader reader(hg);
+ return reader.Parse(in);
+}
+
+static void WriteRule(const TRule& r, ostream* out) {
+ if (!r.lhs_) { (*out) << "[X] ||| "; }
+ JSONParser::WriteEscapedString(r.AsString(), out);
+}
+
+bool HypergraphIO::WriteToJSON(const Hypergraph& hg, bool remove_rules, ostream* out) {
+ map<const TRule*, int> rid;
+ ostream& o = *out;
+ rid[NULL] = 0;
+ o << '{';
+ if (!remove_rules) {
+ o << "\"rules\":[";
+ for (int i = 0; i < hg.edges_.size(); ++i) {
+ const TRule* r = hg.edges_[i].rule_.get();
+ int &id = rid[r];
+ if (!id) {
+ id=rid.size() - 1;
+ if (id > 1) o << ',';
+ o << id << ',';
+ WriteRule(*r, &o);
+ };
+ }
+ o << "],";
+ }
+ const bool use_fdict = FD::NumFeats() < 1000;
+ if (use_fdict) {
+ o << "\"features\":[";
+ for (int i = 1; i < FD::NumFeats(); ++i) {
+ o << (i==1 ? "":",") << '"' << FD::Convert(i) << '"';
+ }
+ o << "],";
+ }
+ vector<int> edgemap(hg.edges_.size(), -1); // edges may be in non-topo order
+ int edge_count = 0;
+ for (int i = 0; i < hg.nodes_.size(); ++i) {
+ const Hypergraph::Node& node = hg.nodes_[i];
+ if (i > 0) { o << ","; }
+ o << "\"edges\":[";
+ for (int j = 0; j < node.in_edges_.size(); ++j) {
+ const Hypergraph::Edge& edge = hg.edges_[node.in_edges_[j]];
+ edgemap[edge.id_] = edge_count;
+ ++edge_count;
+ o << (j == 0 ? "" : ",") << "{";
+
+ o << "\"tail\":[";
+ for (int k = 0; k < edge.tail_nodes_.size(); ++k) {
+ o << (k > 0 ? "," : "") << edge.tail_nodes_[k];
+ }
+ o << "],";
+
+ o << "\"feats\":[";
+ bool first = true;
+ for (SparseVector<double>::const_iterator it = edge.feature_values_.begin(); it != edge.feature_values_.end(); ++it) {
+ if (!it->second) continue;
+ if (!first) o << ',';
+ if (use_fdict)
+ o << (it->first - 1);
+ else
+ o << '"' << FD::Convert(it->first) << '"';
+ o << ',' << it->second;
+ first = false;
+ }
+ o << "]";
+ if (!remove_rules) { o << ",\"rule\":" << rid[edge.rule_.get()]; }
+ o << "}";
+ }
+ o << "],";
+
+ o << "\"node\":{\"in_edges\":[";
+ for (int j = 0; j < node.in_edges_.size(); ++j) {
+ int mapped_edge = edgemap[node.in_edges_[j]];
+ assert(mapped_edge >= 0);
+ o << (j == 0 ? "" : ",") << mapped_edge;
+ }
+ o << "]";
+ if (node.cat_ < 0) { o << ",\"cat\":\"" << TD::Convert(node.cat_ * -1) << '"'; }
+ o << "}";
+ }
+ o << "}\n";
+ return true;
+}
+
+bool needs_escape[128];
+void InitEscapes() {
+ memset(needs_escape, false, 128);
+ needs_escape[static_cast<size_t>('\'')] = true;
+ needs_escape[static_cast<size_t>('\\')] = true;
+}
+
+string HypergraphIO::Escape(const string& s) {
+ size_t len = s.size();
+ for (int i = 0; i < s.size(); ++i) {
+ unsigned char c = s[i];
+ if (c < 128 && needs_escape[c]) ++len;
+ }
+ if (len == s.size()) return s;
+ string res(len, ' ');
+ size_t o = 0;
+ for (int i = 0; i < s.size(); ++i) {
+ unsigned char c = s[i];
+ if (c < 128 && needs_escape[c])
+ res[o++] = '\\';
+ res[o++] = c;
+ }
+ assert(o == len);
+ return res;
+}
+
+string HypergraphIO::AsPLF(const Hypergraph& hg, bool include_global_parentheses) {
+ static bool first = true;
+ if (first) { InitEscapes(); first = false; }
+ if (hg.nodes_.empty()) return "()";
+ ostringstream os;
+ if (include_global_parentheses) os << '(';
+ static const string EPS="*EPS*";
+ for (int i = 0; i < hg.nodes_.size()-1; ++i) {
+ os << '(';
+ if (hg.nodes_[i].out_edges_.empty()) abort();
+ for (int j = 0; j < hg.nodes_[i].out_edges_.size(); ++j) {
+ const Hypergraph::Edge& e = hg.edges_[hg.nodes_[i].out_edges_[j]];
+ const string output = e.rule_->e_.size() ==2 ? Escape(TD::Convert(e.rule_->e_[1])) : EPS;
+ double prob = log(e.edge_prob_);
+ if (isinf(prob)) { prob = -9e20; }
+ if (isnan(prob)) { prob = 0; }
+ os << "('" << output << "'," << prob << "," << e.head_node_ - i << "),";
+ }
+ os << "),";
+ }
+ if (include_global_parentheses) os << ')';
+ return os.str();
+}
+
+namespace PLF {
+
+const string chars = "'\\";
+const char& quote = chars[0];
+const char& slash = chars[1];
+
+// safe get
+inline char get(const std::string& in, int c) {
+ if (c < 0 || c >= (int)in.size()) return 0;
+ else return in[(size_t)c];
+}
+
+// consume whitespace
+inline void eatws(const std::string& in, int& c) {
+ while (get(in,c) == ' ') { c++; }
+}
+
+// from 'foo' return foo
+std::string getEscapedString(const std::string& in, int &c)
+{
+ eatws(in,c);
+ if (get(in,c++) != quote) return "ERROR";
+ std::string res;
+ char cur = 0;
+ do {
+ cur = get(in,c++);
+ if (cur == slash) { res += get(in,c++); }
+ else if (cur != quote) { res += cur; }
+ } while (get(in,c) != quote && (c < (int)in.size()));
+ c++;
+ eatws(in,c);
+ return res;
+}
+
+// basically atof
+float getFloat(const std::string& in, int &c)
+{
+ std::string tmp;
+ eatws(in,c);
+ while (c < (int)in.size() && get(in,c) != ' ' && get(in,c) != ')' && get(in,c) != ',') {
+ tmp += get(in,c++);
+ }
+ eatws(in,c);
+ return atof(tmp.c_str());
+}
+
+// basically atoi
+int getInt(const std::string& in, int &c)
+{
+ std::string tmp;
+ eatws(in,c);
+ while (c < (int)in.size() && get(in,c) != ' ' && get(in,c) != ')' && get(in,c) != ',') {
+ tmp += get(in,c++);
+ }
+ eatws(in,c);
+ return atoi(tmp.c_str());
+}
+
+// maximum number of nodes permitted
+#define MAX_NODES 100000000
+// parse ('foo', 0.23)
+void ReadPLFEdge(const std::string& in, int &c, int cur_node, Hypergraph* hg) {
+ if (get(in,c++) != '(') { assert(!"PCN/PLF parse error: expected ( at start of cn alt block\n"); }
+ vector<WordID> ewords(2, 0);
+ ewords[1] = TD::Convert(getEscapedString(in,c));
+ TRulePtr r(new TRule(ewords));
+ //cerr << "RULE: " << r->AsString() << endl;
+ if (get(in,c++) != ',') { assert(!"PCN/PLF parse error: expected , after string\n"); }
+ size_t cnNext = 1;
+ std::vector<float> probs;
+ probs.push_back(getFloat(in,c));
+ while (get(in,c) == ',') {
+ c++;
+ float val = getFloat(in,c);
+ probs.push_back(val);
+ }
+ //if we read more than one prob, this was a lattice, last item was column increment
+ if (probs.size()>1) {
+ cnNext = static_cast<size_t>(probs.back());
+ probs.pop_back();
+ if (cnNext < 1) { assert(!"PCN/PLF parse error: bad link length at last element of cn alt block\n"); }
+ }
+ if (get(in,c++) != ')') { assert(!"PCN/PLF parse error: expected ) at end of cn alt block\n"); }
+ eatws(in,c);
+ Hypergraph::TailNodeVector tail(1, cur_node);
+ Hypergraph::Edge* edge = hg->AddEdge(r, tail);
+ //cerr << " <--" << cur_node << endl;
+ int head_node = cur_node + cnNext;
+ assert(head_node < MAX_NODES); // prevent malicious PLFs from using all the memory
+ if (hg->nodes_.size() < (head_node + 1)) { hg->ResizeNodes(head_node + 1); }
+ hg->ConnectEdgeToHeadNode(edge, &hg->nodes_[head_node]);
+ for (int i = 0; i < probs.size(); ++i)
+ edge->feature_values_.set_value(FD::Convert("Feature_" + boost::lexical_cast<string>(i)), probs[i]);
+}
+
+// parse (('foo', 0.23), ('bar', 0.77))
+void ReadPLFNode(const std::string& in, int &c, int cur_node, int line, Hypergraph* hg) {
+ //cerr << "PLF READING NODE " << cur_node << endl;
+ if (hg->nodes_.size() < (cur_node + 1)) { hg->ResizeNodes(cur_node + 1); }
+ if (get(in,c++) != '(') { cerr << line << ": Syntax error 1\n"; abort(); }
+ eatws(in,c);
+ while (1) {
+ if (c > (int)in.size()) { break; }
+ if (get(in,c) == ')') {
+ c++;
+ eatws(in,c);
+ break;
+ }
+ if (get(in,c) == ',' && get(in,c+1) == ')') {
+ c+=2;
+ eatws(in,c);
+ break;
+ }
+ if (get(in,c) == ',') { c++; eatws(in,c); }
+ ReadPLFEdge(in, c, cur_node, hg);
+ }
+}
+
+} // namespace PLF
+
+void HypergraphIO::ReadFromPLF(const std::string& in, Hypergraph* hg, int line) {
+ hg->clear();
+ int c = 0;
+ int cur_node = 0;
+ if (in[c++] != '(') { cerr << line << ": Syntax error!\n"; abort(); }
+ while (1) {
+ if (c > (int)in.size()) { break; }
+ if (PLF::get(in,c) == ')') {
+ c++;
+ PLF::eatws(in,c);
+ break;
+ }
+ if (PLF::get(in,c) == ',' && PLF::get(in,c+1) == ')') {
+ c+=2;
+ PLF::eatws(in,c);
+ break;
+ }
+ if (PLF::get(in,c) == ',') { c++; PLF::eatws(in,c); }
+ PLF::ReadPLFNode(in, c, cur_node, line, hg);
+ ++cur_node;
+ }
+ assert(cur_node == hg->nodes_.size() - 1);
+}
+
+void HypergraphIO::PLFtoLattice(const string& plf, Lattice* pl) {
+ Lattice& l = *pl;
+ Hypergraph g;
+ ReadFromPLF(plf, &g, 0);
+ const int num_nodes = g.nodes_.size() - 1;
+ l.resize(num_nodes);
+ for (int i = 0; i < num_nodes; ++i) {
+ vector<LatticeArc>& alts = l[i];
+ const Hypergraph::Node& node = g.nodes_[i];
+ const int num_alts = node.out_edges_.size();
+ alts.resize(num_alts);
+ for (int j = 0; j < num_alts; ++j) {
+ const Hypergraph::Edge& edge = g.edges_[node.out_edges_[j]];
+ alts[j].label = edge.rule_->e_[1];
+ alts[j].cost = edge.feature_values_.value(FD::Convert("Feature_0"));
+ alts[j].dist2next = edge.head_node_ - node.id_;
+ }
+ }
+}
+
+namespace B64 {
+
+static const char cb64[]="ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
+static const char cd64[]="|$$$}rstuvwxyz{$$$$$$$>?@ABCDEFGHIJKLMNOPQRSTUVW$$$$$$XYZ[\\]^_`abcdefghijklmnopq";
+
+static void encodeblock(const unsigned char* in, ostream* os, int len) {
+ char out[4];
+ out[0] = cb64[ in[0] >> 2 ];
+ out[1] = cb64[ ((in[0] & 0x03) << 4) | ((in[1] & 0xf0) >> 4) ];
+ out[2] = (len > 1 ? cb64[ ((in[1] & 0x0f) << 2) | ((in[2] & 0xc0) >> 6) ] : '=');
+ out[3] = (len > 2 ? cb64[ in[2] & 0x3f ] : '=');
+ os->write(out, 4);
+}
+
+void b64encode(const char* data, const size_t size, ostream* out) {
+ size_t cur = 0;
+ while(cur < size) {
+ int len = min(static_cast<size_t>(3), size - cur);
+ encodeblock(reinterpret_cast<const unsigned char*>(&data[cur]), out, len);
+ cur += len;
+ }
+}
+
+static void decodeblock(const unsigned char* in, unsigned char* out) {
+ out[0] = (unsigned char ) (in[0] << 2 | in[1] >> 4);
+ out[1] = (unsigned char ) (in[1] << 4 | in[2] >> 2);
+ out[2] = (unsigned char ) (((in[2] << 6) & 0xc0) | in[3]);
+}
+
+bool b64decode(const unsigned char* data, const size_t insize, char* out, const size_t outsize) {
+ size_t cur = 0;
+ size_t ocur = 0;
+ unsigned char in[4];
+ while(cur < insize) {
+ assert(ocur < outsize);
+ for (int i = 0; i < 4; ++i) {
+ unsigned char v = data[cur];
+ v = (unsigned char) ((v < 43 || v > 122) ? '\0' : cd64[ v - 43 ]);
+ if (!v) {
+ cerr << "B64 decode error at offset " << cur << " offending character: " << (int)data[cur] << endl;
+ return false;
+ }
+ v = (unsigned char) ((v == '$') ? '\0' : v - 61);
+ if (v) in[i] = v - 1; else in[i] = 0;
+ ++cur;
+ }
+ decodeblock(in, reinterpret_cast<unsigned char*>(&out[ocur]));
+ ocur += 3;
+ }
+ return true;
+}
+}
+