diff options
Diffstat (limited to 'src/hg_io.cc')
-rw-r--r-- | src/hg_io.cc | 598 |
1 files changed, 0 insertions, 598 deletions
diff --git a/src/hg_io.cc b/src/hg_io.cc deleted file mode 100644 index e21b1714..00000000 --- a/src/hg_io.cc +++ /dev/null @@ -1,598 +0,0 @@ -#include "hg_io.h" - -#include <sstream> -#include <iostream> - -#include <boost/lexical_cast.hpp> - -#include "tdict.h" -#include "json_parse.h" -#include "hg.h" - -using namespace std; - -struct HGReader : public JSONParser { - HGReader(Hypergraph* g) : rp("[X] ||| "), state(-1), hg(*g), nodes_needed(true), edges_needed(true) { nodes = 0; edges = 0; } - - void CreateNode(const string& cat, const vector<int>& in_edges) { - WordID c = TD::Convert("X") * -1; - if (!cat.empty()) c = TD::Convert(cat) * -1; - Hypergraph::Node* node = hg.AddNode(c, ""); - for (int i = 0; i < in_edges.size(); ++i) { - if (in_edges[i] >= hg.edges_.size()) { - cerr << "JSONParser: in_edges[" << i << "]=" << in_edges[i] - << ", but hg only has " << hg.edges_.size() << " edges!\n"; - abort(); - } - hg.ConnectEdgeToHeadNode(&hg.edges_[in_edges[i]], node); - } - } - void CreateEdge(const TRulePtr& rule, SparseVector<double>* feats, const SmallVector& tail) { - Hypergraph::Edge* edge = hg.AddEdge(rule, tail); - feats->swap(edge->feature_values_); - } - - bool HandleJSONEvent(int type, const JSON_value* value) { - switch(state) { - case -1: - assert(type == JSON_T_OBJECT_BEGIN); - state = 0; - break; - case 0: - if (type == JSON_T_OBJECT_END) { - //cerr << "HG created\n"; // TODO, signal some kind of callback - } else if (type == JSON_T_KEY) { - string val = value->vu.str.value; - if (val == "features") { assert(fdict.empty()); state = 1; } - else if (val == "is_sorted") { state = 3; } - else if (val == "rules") { assert(rules.empty()); state = 4; } - else if (val == "node") { state = 8; } - else if (val == "edges") { state = 13; } - else { cerr << "Unexpected key: " << val << endl; return false; } - } - break; - - // features - case 1: - if(type == JSON_T_NULL) { state = 0; break; } - assert(type == JSON_T_ARRAY_BEGIN); - state = 2; - break; - case 2: - if(type == JSON_T_ARRAY_END) { state = 0; break; } - assert(type == JSON_T_STRING); - fdict.push_back(FD::Convert(value->vu.str.value)); - break; - - // is_sorted - case 3: - assert(type == JSON_T_TRUE || type == JSON_T_FALSE); - is_sorted = (type == JSON_T_TRUE); - if (!is_sorted) { cerr << "[WARNING] is_sorted flag is ignored\n"; } - state = 0; - break; - - // rules - case 4: - if(type == JSON_T_NULL) { state = 0; break; } - assert(type == JSON_T_ARRAY_BEGIN); - state = 5; - break; - case 5: - if(type == JSON_T_ARRAY_END) { state = 0; break; } - assert(type == JSON_T_INTEGER); - state = 6; - rule_id = value->vu.integer_value; - break; - case 6: - assert(type == JSON_T_STRING); - rules[rule_id] = TRulePtr(new TRule(value->vu.str.value)); - state = 5; - break; - - // Nodes - case 8: - assert(type == JSON_T_OBJECT_BEGIN); - ++nodes; - in_edges.clear(); - cat.clear(); - state = 9; break; - case 9: - if (type == JSON_T_OBJECT_END) { - //cerr << "Creating NODE\n"; - CreateNode(cat, in_edges); - state = 0; break; - } - assert(type == JSON_T_KEY); - cur_key = value->vu.str.value; - if (cur_key == "cat") { assert(cat.empty()); state = 10; break; } - if (cur_key == "in_edges") { assert(in_edges.empty()); state = 11; break; } - cerr << "Syntax error: unexpected key " << cur_key << " in node specification.\n"; - return false; - case 10: - assert(type == JSON_T_STRING || type == JSON_T_NULL); - cat = value->vu.str.value; - state = 9; break; - case 11: - if (type == JSON_T_NULL) { state = 9; break; } - assert(type == JSON_T_ARRAY_BEGIN); - state = 12; break; - case 12: - if (type == JSON_T_ARRAY_END) { state = 9; break; } - assert(type == JSON_T_INTEGER); - //cerr << "in_edges: " << value->vu.integer_value << endl; - in_edges.push_back(value->vu.integer_value); - break; - - // "edges": [ { "tail": null, "feats" : [0,1.63,1,-0.54], "rule": 12}, - // { "tail": null, "feats" : [0,0.87,1,0.02], "rule": 17}, - // { "tail": [0], "feats" : [1,2.3,2,15.3,"ExtraFeature",1.2], "rule": 13}] - case 13: - assert(type == JSON_T_ARRAY_BEGIN); - state = 14; - break; - case 14: - if (type == JSON_T_ARRAY_END) { state = 0; break; } - assert(type == JSON_T_OBJECT_BEGIN); - //cerr << "New edge\n"; - ++edges; - cur_rule.reset(); feats.clear(); tail.clear(); - state = 15; break; - case 15: - if (type == JSON_T_OBJECT_END) { - CreateEdge(cur_rule, &feats, tail); - state = 14; break; - } - assert(type == JSON_T_KEY); - cur_key = value->vu.str.value; - //cerr << "edge key " << cur_key << endl; - if (cur_key == "rule") { assert(!cur_rule); state = 16; break; } - if (cur_key == "feats") { assert(feats.empty()); state = 17; break; } - if (cur_key == "tail") { assert(tail.empty()); state = 20; break; } - cerr << "Unexpected key " << cur_key << " in edge specification\n"; - return false; - case 16: // edge.rule - if (type == JSON_T_INTEGER) { - int rule_id = value->vu.integer_value; - if (rules.find(rule_id) == rules.end()) { - // rules list must come before the edge definitions! - cerr << "Rule_id " << rule_id << " given but only loaded " << rules.size() << " rules\n"; - return false; - } - cur_rule = rules[rule_id]; - } else if (type == JSON_T_STRING) { - cur_rule.reset(new TRule(value->vu.str.value)); - } else { - cerr << "Rule must be either a rule id or a rule string" << endl; - return false; - } - // cerr << "Edge: rule=" << cur_rule->AsString() << endl; - state = 15; - break; - case 17: // edge.feats - if (type == JSON_T_NULL) { state = 15; break; } - assert(type == JSON_T_ARRAY_BEGIN); - state = 18; break; - case 18: - if (type == JSON_T_ARRAY_END) { state = 15; break; } - if (type != JSON_T_INTEGER && type != JSON_T_STRING) { - cerr << "Unexpected feature id type\n"; return false; - } - if (type == JSON_T_INTEGER) { - fid = value->vu.integer_value; - assert(fid < fdict.size()); - fid = fdict[fid]; - } else if (JSON_T_STRING) { - fid = FD::Convert(value->vu.str.value); - } else { abort(); } - state = 19; - break; - case 19: - { - assert(type == JSON_T_INTEGER || type == JSON_T_FLOAT); - double val = (type == JSON_T_INTEGER ? static_cast<double>(value->vu.integer_value) : - strtod(value->vu.str.value, NULL)); - feats.set_value(fid, val); - state = 18; - break; - } - case 20: // edge.tail - if (type == JSON_T_NULL) { state = 15; break; } - assert(type == JSON_T_ARRAY_BEGIN); - state = 21; break; - case 21: - if (type == JSON_T_ARRAY_END) { state = 15; break; } - assert(type == JSON_T_INTEGER); - tail.push_back(value->vu.integer_value); - break; - } - return true; - } - string rp; - string cat; - SmallVector tail; - vector<int> in_edges; - TRulePtr cur_rule; - map<int, TRulePtr> rules; - vector<int> fdict; - SparseVector<double> feats; - int state; - int fid; - int nodes; - int edges; - string cur_key; - Hypergraph& hg; - int rule_id; - bool nodes_needed; - bool edges_needed; - bool is_sorted; -}; - -bool HypergraphIO::ReadFromJSON(istream* in, Hypergraph* hg) { - hg->clear(); - HGReader reader(hg); - return reader.Parse(in); -} - -static void WriteRule(const TRule& r, ostream* out) { - if (!r.lhs_) { (*out) << "[X] ||| "; } - JSONParser::WriteEscapedString(r.AsString(), out); -} - -bool HypergraphIO::WriteToJSON(const Hypergraph& hg, bool remove_rules, ostream* out) { - map<const TRule*, int> rid; - ostream& o = *out; - rid[NULL] = 0; - o << '{'; - if (!remove_rules) { - o << "\"rules\":["; - for (int i = 0; i < hg.edges_.size(); ++i) { - const TRule* r = hg.edges_[i].rule_.get(); - int &id = rid[r]; - if (!id) { - id=rid.size() - 1; - if (id > 1) o << ','; - o << id << ','; - WriteRule(*r, &o); - }; - } - o << "],"; - } - const bool use_fdict = FD::NumFeats() < 1000; - if (use_fdict) { - o << "\"features\":["; - for (int i = 1; i < FD::NumFeats(); ++i) { - o << (i==1 ? "":",") << '"' << FD::Convert(i) << '"'; - } - o << "],"; - } - vector<int> edgemap(hg.edges_.size(), -1); // edges may be in non-topo order - int edge_count = 0; - for (int i = 0; i < hg.nodes_.size(); ++i) { - const Hypergraph::Node& node = hg.nodes_[i]; - if (i > 0) { o << ","; } - o << "\"edges\":["; - for (int j = 0; j < node.in_edges_.size(); ++j) { - const Hypergraph::Edge& edge = hg.edges_[node.in_edges_[j]]; - edgemap[edge.id_] = edge_count; - ++edge_count; - o << (j == 0 ? "" : ",") << "{"; - - o << "\"tail\":["; - for (int k = 0; k < edge.tail_nodes_.size(); ++k) { - o << (k > 0 ? "," : "") << edge.tail_nodes_[k]; - } - o << "],"; - - o << "\"feats\":["; - bool first = true; - for (SparseVector<double>::const_iterator it = edge.feature_values_.begin(); it != edge.feature_values_.end(); ++it) { - if (!it->second) continue; - if (!first) o << ','; - if (use_fdict) - o << (it->first - 1); - else - o << '"' << FD::Convert(it->first) << '"'; - o << ',' << it->second; - first = false; - } - o << "]"; - if (!remove_rules) { o << ",\"rule\":" << rid[edge.rule_.get()]; } - o << "}"; - } - o << "],"; - - o << "\"node\":{\"in_edges\":["; - for (int j = 0; j < node.in_edges_.size(); ++j) { - int mapped_edge = edgemap[node.in_edges_[j]]; - assert(mapped_edge >= 0); - o << (j == 0 ? "" : ",") << mapped_edge; - } - o << "]"; - if (node.cat_ < 0) { o << ",\"cat\":\"" << TD::Convert(node.cat_ * -1) << '"'; } - o << "}"; - } - o << "}\n"; - return true; -} - -bool needs_escape[128]; -void InitEscapes() { - memset(needs_escape, false, 128); - needs_escape[static_cast<size_t>('\'')] = true; - needs_escape[static_cast<size_t>('\\')] = true; -} - -string HypergraphIO::Escape(const string& s) { - size_t len = s.size(); - for (int i = 0; i < s.size(); ++i) { - unsigned char c = s[i]; - if (c < 128 && needs_escape[c]) ++len; - } - if (len == s.size()) return s; - string res(len, ' '); - size_t o = 0; - for (int i = 0; i < s.size(); ++i) { - unsigned char c = s[i]; - if (c < 128 && needs_escape[c]) - res[o++] = '\\'; - res[o++] = c; - } - assert(o == len); - return res; -} - -string HypergraphIO::AsPLF(const Hypergraph& hg, bool include_global_parentheses) { - static bool first = true; - if (first) { InitEscapes(); first = false; } - if (hg.nodes_.empty()) return "()"; - ostringstream os; - if (include_global_parentheses) os << '('; - static const string EPS="*EPS*"; - for (int i = 0; i < hg.nodes_.size()-1; ++i) { - if (hg.nodes_[i].out_edges_.empty()) abort(); - const bool last_node = (i == hg.nodes_.size() - 2); - const int out_edges_size = hg.nodes_[i].out_edges_.size(); - // compound splitter adds an extra goal transition which we suppress with - // the following conditional - if (!last_node || out_edges_size != 1 || - hg.edges_[hg.nodes_[i].out_edges_[0]].rule_->EWords() == 1) { - os << '('; - for (int j = 0; j < out_edges_size; ++j) { - const Hypergraph::Edge& e = hg.edges_[hg.nodes_[i].out_edges_[j]]; - const string output = e.rule_->e_.size() ==2 ? Escape(TD::Convert(e.rule_->e_[1])) : EPS; - double prob = log(e.edge_prob_); - if (isinf(prob)) { prob = -9e20; } - if (isnan(prob)) { prob = 0; } - os << "('" << output << "'," << prob << "," << e.head_node_ - i << "),"; - } - os << "),"; - } - } - if (include_global_parentheses) os << ')'; - return os.str(); -} - -namespace PLF { - -const string chars = "'\\"; -const char& quote = chars[0]; -const char& slash = chars[1]; - -// safe get -inline char get(const std::string& in, int c) { - if (c < 0 || c >= (int)in.size()) return 0; - else return in[(size_t)c]; -} - -// consume whitespace -inline void eatws(const std::string& in, int& c) { - while (get(in,c) == ' ') { c++; } -} - -// from 'foo' return foo -std::string getEscapedString(const std::string& in, int &c) -{ - eatws(in,c); - if (get(in,c++) != quote) return "ERROR"; - std::string res; - char cur = 0; - do { - cur = get(in,c++); - if (cur == slash) { res += get(in,c++); } - else if (cur != quote) { res += cur; } - } while (get(in,c) != quote && (c < (int)in.size())); - c++; - eatws(in,c); - return res; -} - -// basically atof -float getFloat(const std::string& in, int &c) -{ - std::string tmp; - eatws(in,c); - while (c < (int)in.size() && get(in,c) != ' ' && get(in,c) != ')' && get(in,c) != ',') { - tmp += get(in,c++); - } - eatws(in,c); - if (tmp.empty()) { - cerr << "Syntax error while reading number! col=" << c << endl; - abort(); - } - return atof(tmp.c_str()); -} - -// basically atoi -int getInt(const std::string& in, int &c) -{ - std::string tmp; - eatws(in,c); - while (c < (int)in.size() && get(in,c) != ' ' && get(in,c) != ')' && get(in,c) != ',') { - tmp += get(in,c++); - } - eatws(in,c); - return atoi(tmp.c_str()); -} - -// maximum number of nodes permitted -#define MAX_NODES 100000000 -// parse ('foo', 0.23) -void ReadPLFEdge(const std::string& in, int &c, int cur_node, Hypergraph* hg) { - if (get(in,c++) != '(') { assert(!"PCN/PLF parse error: expected ( at start of cn alt block\n"); } - vector<WordID> ewords(2, 0); - ewords[1] = TD::Convert(getEscapedString(in,c)); - TRulePtr r(new TRule(ewords)); - // cerr << "RULE: " << r->AsString() << endl; - if (get(in,c++) != ',') { assert(!"PCN/PLF parse error: expected , after string\n"); } - size_t cnNext = 1; - std::vector<float> probs; - probs.push_back(getFloat(in,c)); - while (get(in,c) == ',') { - c++; - float val = getFloat(in,c); - probs.push_back(val); - // cerr << val << endl; //REMO - } - //if we read more than one prob, this was a lattice, last item was column increment - if (probs.size()>1) { - cnNext = static_cast<size_t>(probs.back()); - probs.pop_back(); - if (cnNext < 1) { cerr << cnNext << endl; - assert(!"PCN/PLF parse error: bad link length at last element of cn alt block\n"); } - } - if (get(in,c++) != ')') { assert(!"PCN/PLF parse error: expected ) at end of cn alt block\n"); } - eatws(in,c); - Hypergraph::TailNodeVector tail(1, cur_node); - Hypergraph::Edge* edge = hg->AddEdge(r, tail); - //cerr << " <--" << cur_node << endl; - int head_node = cur_node + cnNext; - assert(head_node < MAX_NODES); // prevent malicious PLFs from using all the memory - if (hg->nodes_.size() < (head_node + 1)) { hg->ResizeNodes(head_node + 1); } - hg->ConnectEdgeToHeadNode(edge, &hg->nodes_[head_node]); - for (int i = 0; i < probs.size(); ++i) - edge->feature_values_.set_value(FD::Convert("Feature_" + boost::lexical_cast<string>(i)), probs[i]); -} - -// parse (('foo', 0.23), ('bar', 0.77)) -void ReadPLFNode(const std::string& in, int &c, int cur_node, int line, Hypergraph* hg) { - //cerr << "PLF READING NODE " << cur_node << endl; - if (hg->nodes_.size() < (cur_node + 1)) { hg->ResizeNodes(cur_node + 1); } - if (get(in,c++) != '(') { cerr << line << ": Syntax error 1\n"; abort(); } - eatws(in,c); - while (1) { - if (c > (int)in.size()) { break; } - if (get(in,c) == ')') { - c++; - eatws(in,c); - break; - } - if (get(in,c) == ',' && get(in,c+1) == ')') { - c+=2; - eatws(in,c); - break; - } - if (get(in,c) == ',') { c++; eatws(in,c); } - ReadPLFEdge(in, c, cur_node, hg); - } -} - -} // namespace PLF - -void HypergraphIO::ReadFromPLF(const std::string& in, Hypergraph* hg, int line) { - hg->clear(); - int c = 0; - int cur_node = 0; - if (in[c++] != '(') { cerr << line << ": Syntax error!\n"; abort(); } - while (1) { - if (c > (int)in.size()) { break; } - if (PLF::get(in,c) == ')') { - c++; - PLF::eatws(in,c); - break; - } - if (PLF::get(in,c) == ',' && PLF::get(in,c+1) == ')') { - c+=2; - PLF::eatws(in,c); - break; - } - if (PLF::get(in,c) == ',') { c++; PLF::eatws(in,c); } - PLF::ReadPLFNode(in, c, cur_node, line, hg); - ++cur_node; - } - assert(cur_node == hg->nodes_.size() - 1); -} - -void HypergraphIO::PLFtoLattice(const string& plf, Lattice* pl) { - Lattice& l = *pl; - Hypergraph g; - ReadFromPLF(plf, &g, 0); - const int num_nodes = g.nodes_.size() - 1; - l.resize(num_nodes); - for (int i = 0; i < num_nodes; ++i) { - vector<LatticeArc>& alts = l[i]; - const Hypergraph::Node& node = g.nodes_[i]; - const int num_alts = node.out_edges_.size(); - alts.resize(num_alts); - for (int j = 0; j < num_alts; ++j) { - const Hypergraph::Edge& edge = g.edges_[node.out_edges_[j]]; - alts[j].label = edge.rule_->e_[1]; - alts[j].cost = edge.feature_values_.value(FD::Convert("Feature_0")); - alts[j].dist2next = edge.head_node_ - node.id_; - } - } -} - -namespace B64 { - -static const char cb64[]="ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; -static const char cd64[]="|$$$}rstuvwxyz{$$$$$$$>?@ABCDEFGHIJKLMNOPQRSTUVW$$$$$$XYZ[\\]^_`abcdefghijklmnopq"; - -static void encodeblock(const unsigned char* in, ostream* os, int len) { - char out[4]; - out[0] = cb64[ in[0] >> 2 ]; - out[1] = cb64[ ((in[0] & 0x03) << 4) | ((in[1] & 0xf0) >> 4) ]; - out[2] = (len > 1 ? cb64[ ((in[1] & 0x0f) << 2) | ((in[2] & 0xc0) >> 6) ] : '='); - out[3] = (len > 2 ? cb64[ in[2] & 0x3f ] : '='); - os->write(out, 4); -} - -void b64encode(const char* data, const size_t size, ostream* out) { - size_t cur = 0; - while(cur < size) { - int len = min(static_cast<size_t>(3), size - cur); - encodeblock(reinterpret_cast<const unsigned char*>(&data[cur]), out, len); - cur += len; - } -} - -static void decodeblock(const unsigned char* in, unsigned char* out) { - out[0] = (unsigned char ) (in[0] << 2 | in[1] >> 4); - out[1] = (unsigned char ) (in[1] << 4 | in[2] >> 2); - out[2] = (unsigned char ) (((in[2] << 6) & 0xc0) | in[3]); -} - -bool b64decode(const unsigned char* data, const size_t insize, char* out, const size_t outsize) { - size_t cur = 0; - size_t ocur = 0; - unsigned char in[4]; - while(cur < insize) { - assert(ocur < outsize); - for (int i = 0; i < 4; ++i) { - unsigned char v = data[cur]; - v = (unsigned char) ((v < 43 || v > 122) ? '\0' : cd64[ v - 43 ]); - if (!v) { - cerr << "B64 decode error at offset " << cur << " offending character: " << (int)data[cur] << endl; - return false; - } - v = (unsigned char) ((v == '$') ? '\0' : v - 61); - if (v) in[i] = v - 1; else in[i] = 0; - ++cur; - } - decodeblock(in, reinterpret_cast<unsigned char*>(&out[ocur])); - ocur += 3; - } - return true; -} -} - |