summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChris Dyer <cdyer@allegro.clab.cs.cmu.edu>2014-10-19 05:24:21 -0400
committerChris Dyer <cdyer@allegro.clab.cs.cmu.edu>2014-10-19 05:24:21 -0400
commit011a87cfe6d9cc702cb4a8a6d9a765556e460af9 (patch)
tree4747c3fc01d45dfdf58a454ef22e388d9f245e71
parente8451a312b01b0917ffef06afc17f1e8cad5d510 (diff)
stop switch to boost serialization for hypergraph IO
-rw-r--r--decoder/decoder.cc4
-rw-r--r--decoder/forest_writer.cc4
-rw-r--r--decoder/hg.h52
-rw-r--r--decoder/hg_io.cc101
-rw-r--r--decoder/hg_io.h5
-rw-r--r--decoder/hg_test.cc39
-rw-r--r--decoder/rule_lexer.ll1
-rw-r--r--python/cdec/hypergraph.pxd3
-rw-r--r--training/dpmert/mr_dpmert_generate_mapper_input.cc2
-rw-r--r--training/dpmert/mr_dpmert_map.cc4
-rw-r--r--training/minrisk/minrisk_optimize.cc2
-rw-r--r--training/pro/mr_pro_map.cc2
-rw-r--r--training/rampion/rampion_cccp.cc2
-rw-r--r--training/utils/grammar_convert.cc5
-rw-r--r--utils/small_vector.h16
-rw-r--r--utils/small_vector_test.cc30
16 files changed, 156 insertions, 116 deletions
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index 3cc77d27..f8214f5f 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -930,7 +930,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
Hypergraph new_hg;
{
ReadFile rf(writer.fname_);
- bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg);
+ bool succeeded = HypergraphIO::ReadFromBinary(rf.stream(), &new_hg);
if (!succeeded) abort();
}
HG::Union(forest, &new_hg);
@@ -1023,7 +1023,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
Hypergraph new_hg;
{
ReadFile rf(writer.fname_);
- bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg);
+ bool succeeded = HypergraphIO::ReadFromBinary(rf.stream(), &new_hg);
if (!succeeded) abort();
}
HG::Union(forest, &new_hg);
diff --git a/decoder/forest_writer.cc b/decoder/forest_writer.cc
index 6e4cccb3..c072e599 100644
--- a/decoder/forest_writer.cc
+++ b/decoder/forest_writer.cc
@@ -11,13 +11,13 @@
using namespace std;
ForestWriter::ForestWriter(const std::string& path, int num) :
- fname_(path + '/' + boost::lexical_cast<string>(num) + ".json.gz"), used_(false) {}
+ fname_(path + '/' + boost::lexical_cast<string>(num) + ".bin.gz"), used_(false) {}
bool ForestWriter::Write(const Hypergraph& forest, bool minimal_rules) {
assert(!used_);
used_ = true;
cerr << " Writing forest to " << fname_ << endl;
WriteFile wf(fname_);
- return HypergraphIO::WriteToJSON(forest, minimal_rules, wf.stream());
+ return HypergraphIO::WriteToBinary(forest, wf.stream());
}
diff --git a/decoder/hg.h b/decoder/hg.h
index 256f650f..124eab86 100644
--- a/decoder/hg.h
+++ b/decoder/hg.h
@@ -18,6 +18,7 @@
#include <string>
#include <vector>
#include <boost/shared_ptr.hpp>
+#include <boost/serialization/vector.hpp>
#include "feature_vector.h"
#include "small_vector.h"
@@ -69,6 +70,18 @@ namespace HG {
short int j_;
short int prev_i_;
short int prev_j_;
+ // Boost.Serialization hook for this edge; the same routine is used for both
+ // saving and loading. The order of the fields below is the on-disk order, so
+ // it must not be rearranged (note that id_ comes last).
+ template<class Archive>
+ void serialize(Archive & ar, const unsigned int version) {
+   ar & head_node_
+      & tail_nodes_
+      & rule_
+      & feature_values_
+      & i_
+      & j_
+      & prev_i_
+      & prev_j_
+      & id_;
+ }
void show(std::ostream &o,unsigned mask=SPAN|RULE) const {
o<<'{';
if (mask&CATEGORY)
@@ -149,6 +162,24 @@ namespace HG {
WordID NT() const { return -cat_; }
EdgesVector in_edges_; // an in edge is an edge with this node as its head. (in edges come from the bottom up to us) indices in edges_
EdgesVector out_edges_; // an out edge is an edge with this node as its tail. (out edges leave us up toward the top/goal). indices in edges_
+ // Writes this node to a Boost archive. The category is stored as the string
+ // returned by TD::Convert(-cat_) rather than as the raw integer cat_.
+ template<class Archive>
+ void save(Archive & ar, const unsigned int version) const {
+   ar & node_hash & id_;
+   ar & TD::Convert(-cat_);
+   ar & in_edges_ & out_edges_;
+ }
+ // Reads a node back in exactly the order save() wrote it: the stored category
+ // string is mapped to an id with TD::Convert and negated to recover cat_.
+ template<class Archive>
+ void load(Archive & ar, const unsigned int version) {
+   ar & node_hash & id_;
+   std::string category;
+   ar & category;
+   cat_ = -TD::Convert(category);
+   ar & in_edges_ & out_edges_;
+ }
+ // Tells Boost.Serialization to dispatch to save()/load() above.
+ BOOST_SERIALIZATION_SPLIT_MEMBER()
void copy_fixed(Node const& o) { // nonstructural fields only - structural ones are managed by sorting/pruning/subsetting
node_hash = o.node_hash;
cat_=o.cat_;
@@ -492,6 +523,27 @@ public:
void set_ids(); // resync edge,node .id_
void check_ids() const; // assert that .id_ have been kept in sync
+ // Archive layout: node count, edge count, every node, every edge, then
+ // edges_topo_ and is_linear_chain_, each written as an int.
+ template<class Archive>
+ void save(Archive & ar, const unsigned int version) const {
+   unsigned node_count = nodes_.size();
+   unsigned edge_count = edges_.size();
+   ar & node_count & edge_count;
+   for (const auto& node : nodes_) ar & node;
+   for (const auto& edge : edges_) ar & edge;
+   int topo_flag = edges_topo_;
+   int linear_flag = is_linear_chain_;
+   ar & topo_flag & linear_flag;
+ }
+ // Mirror image of save(): the containers are resized to the stored counts and
+ // then each element is read in place.
+ // NOTE(review): nothing here resets the graph first; HypergraphIO::ReadFromBinary
+ // calls clear() before loading, other callers should presumably do the same.
+ template<class Archive>
+ void load(Archive & ar, const unsigned int version) {
+   unsigned node_count;
+   ar & node_count;
+   nodes_.resize(node_count);
+   unsigned edge_count;
+   ar & edge_count;
+   edges_.resize(edge_count);
+   for (auto& node : nodes_) ar & node;
+   for (auto& edge : edges_) ar & edge;
+   int topo_flag;
+   ar & topo_flag;
+   edges_topo_ = topo_flag;
+   int linear_flag;
+   ar & linear_flag;
+   is_linear_chain_ = linear_flag;
+ }
+ // Tells Boost.Serialization to dispatch to save()/load() above.
+ BOOST_SERIALIZATION_SPLIT_MEMBER()
private:
Hypergraph(int num_nodes, int num_edges, bool is_lc) : is_linear_chain_(is_lc), nodes_(num_nodes), edges_(num_edges),edges_topo_(true) {}
};
diff --git a/decoder/hg_io.cc b/decoder/hg_io.cc
index eb0be3d4..67760fb1 100644
--- a/decoder/hg_io.cc
+++ b/decoder/hg_io.cc
@@ -6,6 +6,10 @@
#include <sstream>
#include <iostream>
+#include <boost/archive/binary_iarchive.hpp>
+#include <boost/archive/binary_oarchive.hpp>
+#include <boost/serialization/shared_ptr.hpp>
+
#include "fast_lexical_cast.hpp"
#include "tdict.h"
@@ -271,97 +275,16 @@ bool HypergraphIO::ReadFromJSON(istream* in, Hypergraph* hg) {
return reader.Parse(in);
}
-static void WriteRule(const TRule& r, ostream* out) {
- if (!r.lhs_) { (*out) << "[X] ||| "; }
- JSONParser::WriteEscapedString(r.AsString(), out);
+// Reads a hypergraph from a Boost binary archive on *in into *hg, discarding
+// whatever *hg held before.
+//
+// Returns true on success. Boost.Serialization signals a malformed or truncated
+// archive by throwing boost::archive::archive_exception; previously that escaped
+// this function, so the bool result could never be false and callers testing it
+// (e.g. "if (!succeeded) abort();") were checking a constant. The exception is
+// now reported on stderr and converted to a false return, with *hg left empty
+// rather than half-populated.
+bool HypergraphIO::ReadFromBinary(istream* in, Hypergraph* hg) {
+ hg->clear();
+ try {
+   boost::archive::binary_iarchive ia(*in);  // reads and validates the archive header
+   ia >> *hg;
+ } catch (const boost::archive::archive_exception& e) {
+   cerr << "HypergraphIO::ReadFromBinary: failed to read hypergraph: " << e.what() << endl;
+   hg->clear();  // drop any partially loaded nodes/edges
+   return false;
+ }
+ return true;
}
-bool HypergraphIO::WriteToJSON(const Hypergraph& hg, bool remove_rules, ostream* out) {
- if (hg.empty()) { *out << "{}\n"; return true; }
- map<const TRule*, int> rid;
- ostream& o = *out;
- rid[NULL] = 0;
- o << '{';
- if (!remove_rules) {
- o << "\"rules\":[";
- for (int i = 0; i < hg.edges_.size(); ++i) {
- const TRule* r = hg.edges_[i].rule_.get();
- int &id = rid[r];
- if (!id) {
- id=rid.size() - 1;
- if (id > 1) o << ',';
- o << id << ',';
- WriteRule(*r, &o);
- };
- }
- o << "],";
- }
- const bool use_fdict = FD::NumFeats() < 1000;
- if (use_fdict) {
- o << "\"features\":[";
- for (int i = 1; i < FD::NumFeats(); ++i) {
- o << (i==1 ? "":",");
- JSONParser::WriteEscapedString(FD::Convert(i), &o);
- }
- o << "],";
- }
- vector<int> edgemap(hg.edges_.size(), -1); // edges may be in non-topo order
- int edge_count = 0;
- for (int i = 0; i < hg.nodes_.size(); ++i) {
- const Hypergraph::Node& node = hg.nodes_[i];
- if (i > 0) { o << ","; }
- o << "\"edges\":[";
- for (int j = 0; j < node.in_edges_.size(); ++j) {
- const Hypergraph::Edge& edge = hg.edges_[node.in_edges_[j]];
- edgemap[edge.id_] = edge_count;
- ++edge_count;
- o << (j == 0 ? "" : ",") << "{";
-
- o << "\"tail\":[";
- for (int k = 0; k < edge.tail_nodes_.size(); ++k) {
- o << (k > 0 ? "," : "") << edge.tail_nodes_[k];
- }
- o << "],";
-
- o << "\"spans\":[" << edge.i_ << "," << edge.j_ << "," << edge.prev_i_ << "," << edge.prev_j_ << "],";
-
- o << "\"feats\":[";
- bool first = true;
- for (SparseVector<double>::const_iterator it = edge.feature_values_.begin(); it != edge.feature_values_.end(); ++it) {
- if (!it->second) continue; // don't write features that have a zero value
- if (!it->first) continue; // if the feature set was frozen this might happen
- if (!first) o << ',';
- if (use_fdict)
- o << (it->first - 1);
- else {
- JSONParser::WriteEscapedString(FD::Convert(it->first), &o);
- }
- o << ',' << it->second;
- first = false;
- }
- o << "]";
- if (!remove_rules) { o << ",\"rule\":" << rid[edge.rule_.get()]; }
- o << "}";
- }
- o << "],";
-
- o << "\"node\":{\"in_edges\":[";
- for (int j = 0; j < node.in_edges_.size(); ++j) {
- int mapped_edge = edgemap[node.in_edges_[j]];
- assert(mapped_edge >= 0);
- o << (j == 0 ? "" : ",") << mapped_edge;
- }
- o << "]";
- if (node.cat_ < 0) {
- o << ",\"cat\":";
- JSONParser::WriteEscapedString(TD::Convert(node.cat_ * -1), &o);
- }
- char buf[48];
- sprintf(buf, "%016lX", node.node_hash);
- o << ",\"node_hash\":\"" << buf << "\"";
- o << "}";
- }
- o << "}\n";
+// Serializes hg to *out as a Boost binary archive. The return value is always
+// true; nothing in this function inspects the stream state after writing.
+bool HypergraphIO::WriteToBinary(const Hypergraph& hg, ostream* out) {
+ boost::archive::binary_oarchive archive(*out);
+ archive << hg;
return true;
}
diff --git a/decoder/hg_io.h b/decoder/hg_io.h
index 5a2bd808..5ba86f69 100644
--- a/decoder/hg_io.h
+++ b/decoder/hg_io.h
@@ -18,10 +18,11 @@ struct HypergraphIO {
// see test_data/small.json.gz for an example encoding
static bool ReadFromJSON(std::istream* in, Hypergraph* out);
+ static bool ReadFromBinary(std::istream* in, Hypergraph* out);
+ static bool WriteToBinary(const Hypergraph& hg, std::ostream* out);
+
// if remove_rules is used, the hypergraph is serialized without rule information
// (so it only contains structure and feature information)
- static bool WriteToJSON(const Hypergraph& hg, bool remove_rules, std::ostream* out);
-
static void WriteAsCFG(const Hypergraph& hg);
// Write only the target size information in bottom-up order.
diff --git a/decoder/hg_test.cc b/decoder/hg_test.cc
index 5cb8626a..25eddcec 100644
--- a/decoder/hg_test.cc
+++ b/decoder/hg_test.cc
@@ -1,6 +1,11 @@
#define BOOST_TEST_MODULE hg_test
#include <boost/test/unit_test.hpp>
#include <boost/test/floating_point_comparison.hpp>
+#include <boost/archive/text_oarchive.hpp>
+#include <boost/archive/text_iarchive.hpp>
+#include <boost/serialization/shared_ptr.hpp>
+#include <boost/serialization/vector.hpp>
+#include <sstream>
#include <iostream>
#include "tdict.h"
@@ -427,19 +432,29 @@ BOOST_AUTO_TEST_CASE(TestGenericKBest) {
}
}
-BOOST_AUTO_TEST_CASE(TestReadWriteHG) {
+BOOST_AUTO_TEST_CASE(TestReadWriteHG_Boost) {
std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA);
- Hypergraph hg,hg2;
- CreateHG(path, &hg);
- hg.edges_.front().j_ = 23;
- hg.edges_.back().prev_i_ = 99;
- ostringstream os;
- HypergraphIO::WriteToJSON(hg, false, &os);
- istringstream is(os.str());
- HypergraphIO::ReadFromJSON(&is, &hg2);
- BOOST_CHECK_EQUAL(hg2.NumberOfPaths(), hg.NumberOfPaths());
- BOOST_CHECK_EQUAL(hg2.edges_.front().j_, 23);
- BOOST_CHECK_EQUAL(hg2.edges_.back().prev_i_, 99);
+ Hypergraph hg;
+ Hypergraph hg2;
+ std::string out;
+ {
+ CreateHG(path, &hg);
+ hg.edges_.front().j_ = 23;
+ hg.edges_.back().prev_i_ = 99;
+ ostringstream os;
+ boost::archive::text_oarchive oa(os);
+ oa << hg;
+ out = os.str();
+ }
+ {
+ cerr << out << endl;
+ istringstream is(out);
+ boost::archive::text_iarchive ia(is);
+ ia >> hg2;
+ BOOST_CHECK_EQUAL(hg2.NumberOfPaths(), hg.NumberOfPaths());
+ BOOST_CHECK_EQUAL(hg2.edges_.front().j_, 23);
+ BOOST_CHECK_EQUAL(hg2.edges_.back().prev_i_, 99);
+ }
}
BOOST_AUTO_TEST_SUITE_END()
diff --git a/decoder/rule_lexer.ll b/decoder/rule_lexer.ll
index d4a8d86b..8b48ab7b 100644
--- a/decoder/rule_lexer.ll
+++ b/decoder/rule_lexer.ll
@@ -356,6 +356,7 @@ void RuleLexer::ReadRules(std::istream* in, RuleLexer::RuleCallback func, const
void RuleLexer::ReadRule(const std::string& srule, RuleCallback func, bool mono, void* extra) {
init_default_feature_names();
+ scfglex_fname = srule;
lex_mono_rules = mono;
lex_line = 1;
rule_callback_extra = extra;
diff --git a/python/cdec/hypergraph.pxd b/python/cdec/hypergraph.pxd
index 1e150bbc..9780cf8b 100644
--- a/python/cdec/hypergraph.pxd
+++ b/python/cdec/hypergraph.pxd
@@ -63,7 +63,8 @@ cdef extern from "decoder/viterbi.h":
cdef extern from "decoder/hg_io.h" namespace "HypergraphIO":
# Hypergraph JSON and binary I/O
bint ReadFromJSON(istream* inp, Hypergraph* out)
- bint WriteToJSON(Hypergraph& hg, bint remove_rules, ostream* out)
+ bint ReadFromBinary(istream* inp, Hypergraph* out)
+ bint WriteToBinary(Hypergraph& hg, ostream* out)
# Hypergraph PLF I/O
void ReadFromPLF(string& inp, Hypergraph* out)
string AsPLF(Hypergraph& hg, bint include_global_parentheses)
diff --git a/training/dpmert/mr_dpmert_generate_mapper_input.cc b/training/dpmert/mr_dpmert_generate_mapper_input.cc
index 199cd23a..3fa2f476 100644
--- a/training/dpmert/mr_dpmert_generate_mapper_input.cc
+++ b/training/dpmert/mr_dpmert_generate_mapper_input.cc
@@ -70,7 +70,7 @@ int main(int argc, char** argv) {
unsigned dev_set_size = conf["dev_set_size"].as<unsigned>();
for (unsigned i = 0; i < dev_set_size; ++i) {
for (unsigned j = 0; j < directions.size(); ++j) {
- cout << forest_repository << '/' << i << ".json.gz " << i << ' ';
+ cout << forest_repository << '/' << i << ".bin.gz " << i << ' ';
print(cout, origin, "=", ";");
cout << ' ';
print(cout, directions[j], "=", ";");
diff --git a/training/dpmert/mr_dpmert_map.cc b/training/dpmert/mr_dpmert_map.cc
index d1efcf96..2bf3f8fc 100644
--- a/training/dpmert/mr_dpmert_map.cc
+++ b/training/dpmert/mr_dpmert_map.cc
@@ -83,7 +83,7 @@ int main(int argc, char** argv) {
istringstream is(line);
int sent_id;
string file, s_origin, s_direction;
- // path-to-file (JSON) sent_ed starting-point search-direction
+ // path-to-file sent_id starting-point search-direction
is >> file >> sent_id >> s_origin >> s_direction;
SparseVector<double> origin;
ReadSparseVectorString(s_origin, &origin);
@@ -93,7 +93,7 @@ int main(int argc, char** argv) {
if (last_file != file) {
last_file = file;
ReadFile rf(file);
- HypergraphIO::ReadFromJSON(rf.stream(), &hg);
+ HypergraphIO::ReadFromBinary(rf.stream(), &hg);
}
const ConvexHullWeightFunction wf(origin, direction);
const ConvexHull hull = Inside<ConvexHull, ConvexHullWeightFunction>(hg, NULL, wf);
diff --git a/training/minrisk/minrisk_optimize.cc b/training/minrisk/minrisk_optimize.cc
index da8b5260..a2938fb0 100644
--- a/training/minrisk/minrisk_optimize.cc
+++ b/training/minrisk/minrisk_optimize.cc
@@ -178,7 +178,7 @@ int main(int argc, char** argv) {
ReadFile rf(file);
if (kis.size() % 5 == 0) { cerr << '.'; }
if (kis.size() % 200 == 0) { cerr << " [" << kis.size() << "]\n"; }
- HypergraphIO::ReadFromJSON(rf.stream(), &hg);
+ HypergraphIO::ReadFromBinary(rf.stream(), &hg);
hg.Reweight(weights);
curkbest.AddKBestCandidates(hg, kbest_size, ds[sent_id]);
if (kbest_file.size())
diff --git a/training/pro/mr_pro_map.cc b/training/pro/mr_pro_map.cc
index da58cd24..b142fd05 100644
--- a/training/pro/mr_pro_map.cc
+++ b/training/pro/mr_pro_map.cc
@@ -203,7 +203,7 @@ int main(int argc, char** argv) {
const string kbest_file = os.str();
if (FileExists(kbest_file))
J_i.ReadFromFile(kbest_file);
- HypergraphIO::ReadFromJSON(rf.stream(), &hg);
+ HypergraphIO::ReadFromBinary(rf.stream(), &hg);
hg.Reweight(weights);
J_i.AddKBestCandidates(hg, kbest_size, ds[sent_id]);
J_i.WriteToFile(kbest_file);
diff --git a/training/rampion/rampion_cccp.cc b/training/rampion/rampion_cccp.cc
index 1e36dc51..1c45bac5 100644
--- a/training/rampion/rampion_cccp.cc
+++ b/training/rampion/rampion_cccp.cc
@@ -136,7 +136,7 @@ int main(int argc, char** argv) {
ReadFile rf(file);
if (kis.size() % 5 == 0) { cerr << '.'; }
if (kis.size() % 200 == 0) { cerr << " [" << kis.size() << "]\n"; }
- HypergraphIO::ReadFromJSON(rf.stream(), &hg);
+ HypergraphIO::ReadFromBinary(rf.stream(), &hg);
hg.Reweight(weights);
curkbest.AddKBestCandidates(hg, kbest_size, ds[sent_id]);
if (kbest_file.size())
diff --git a/training/utils/grammar_convert.cc b/training/utils/grammar_convert.cc
index 5c1b4d4a..000f2a26 100644
--- a/training/utils/grammar_convert.cc
+++ b/training/utils/grammar_convert.cc
@@ -43,7 +43,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
po::notify(*conf);
if (conf->count("help") || conf->count("input") == 0) {
- cerr << "\nUsage: grammar_convert [-options]\n\nConverts a grammar file (in Hiero format) into JSON hypergraph.\n";
+ cerr << "\nUsage: grammar_convert [-options]\n\nConverts a grammar file (in Hiero format) into serialized hypergraph.\n";
cerr << dcmdline_options << endl;
exit(1);
}
@@ -254,7 +254,8 @@ void ProcessHypergraph(const vector<double>& w, const po::variables_map& conf, c
if (w.size() > 0) { hg->Reweight(w); }
if (conf.count("collapse_weights")) CollapseWeights(hg);
if (conf["output"].as<string>() == "json") {
- HypergraphIO::WriteToJSON(*hg, false, &cout);
+ cerr << "NOT IMPLEMENTED ... talk to cdyer if you need this functionality\n";
+ // HypergraphIO::WriteToBinary(*hg, &cout);
if (!ref.empty()) { cerr << "REF: " << ref << endl; }
} else {
vector<WordID> onebest;
diff --git a/utils/small_vector.h b/utils/small_vector.h
index c8cbcb2c..f16bc898 100644
--- a/utils/small_vector.h
+++ b/utils/small_vector.h
@@ -15,6 +15,7 @@
#include <new>
#include <stdint.h>
#include <boost/functional/hash.hpp>
+#include <boost/serialization/map.hpp>
//sizeof(T)/sizeof(T*)>1?sizeof(T)/sizeof(T*):1
@@ -297,6 +298,21 @@ public:
return hash_range(data_.ptr,data_.ptr+size_);
}
+ // Archive layout: the element count (size_) followed by each element in order.
+ template<class Archive>
+ void save(Archive & ar, const unsigned int) const {
+   ar & size_;
+   for (unsigned idx = 0; idx != size_; ++idx) {
+     ar & (*this)[idx];
+   }
+ }
+ // Reads the count back as a uint16_t, resizes to it, then fills each slot.
+ // NOTE(review): this assumes size_ is itself a uint16_t so that save() and
+ // load() agree on the width of the count -- confirm against the member declaration.
+ template<class Archive>
+ void load(Archive & ar, const unsigned int) {
+   uint16_t stored_size;
+   ar & stored_size;
+   this->resize(stored_size);
+   for (unsigned idx = 0; idx != size_; ++idx) {
+     ar & (*this)[idx];
+   }
+ }
+ // Tells Boost.Serialization to dispatch to save()/load() above.
+ BOOST_SERIALIZATION_SPLIT_MEMBER()
private:
union StorageType {
T vals[SV_MAX];
diff --git a/utils/small_vector_test.cc b/utils/small_vector_test.cc
index a4eb89ae..9e1a148d 100644
--- a/utils/small_vector_test.cc
+++ b/utils/small_vector_test.cc
@@ -3,6 +3,10 @@
#define BOOST_TEST_MODULE svTest
#include <boost/test/unit_test.hpp>
#include <boost/test/floating_point_comparison.hpp>
+#include <boost/archive/text_oarchive.hpp>
+#include <boost/archive/text_iarchive.hpp>
+#include <string>
+#include <sstream>
#include <iostream>
#include <vector>
@@ -128,3 +132,29 @@ BOOST_AUTO_TEST_CASE(Small) {
cerr << sizeof(SmallVectorInt) << endl;
cerr << sizeof(vector<int>) << endl;
}
+
+// Round-trips a three-element SmallVectorInt through a Boost text archive and
+// checks that the size and every element survive.
+BOOST_AUTO_TEST_CASE(Serialize) {
+ std::string archived;
+ {
+   // Write phase: the text is captured while the output archive is still alive.
+   SmallVectorInt original;
+   original.push_back(0);
+   original.push_back(1);
+   original.push_back(-2);
+   ostringstream sink;
+   boost::archive::text_oarchive writer(sink);
+   writer << original;
+   archived = sink.str();
+   cerr << archived;
+ }
+ {
+   // Read phase: rebuild a fresh vector from the captured text.
+   istringstream source(archived);
+   boost::archive::text_iarchive reader(source);
+   SmallVectorInt restored;
+   reader >> restored;
+   BOOST_CHECK_EQUAL(restored.size(), 3);
+   BOOST_CHECK_EQUAL(restored[0], 0);
+   BOOST_CHECK_EQUAL(restored[1], 1);
+   BOOST_CHECK_EQUAL(restored[2], -2);
+ }
+}
+