summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authorChris Dyer <cdyer@allegro.clab.cs.cmu.edu>2014-10-19 05:24:21 -0400
committerChris Dyer <cdyer@allegro.clab.cs.cmu.edu>2014-10-19 05:24:21 -0400
commit011a87cfe6d9cc702cb4a8a6d9a765556e460af9 (patch)
tree4747c3fc01d45dfdf58a454ef22e388d9f245e71 /decoder
parente8451a312b01b0917ffef06afc17f1e8cad5d510 (diff)
stop switch to boost serialization for hypergraph IO
Diffstat (limited to 'decoder')
-rw-r--r--decoder/decoder.cc4
-rw-r--r--decoder/forest_writer.cc4
-rw-r--r--decoder/hg.h52
-rw-r--r--decoder/hg_io.cc101
-rw-r--r--decoder/hg_io.h5
-rw-r--r--decoder/hg_test.cc39
-rw-r--r--decoder/rule_lexer.ll1
7 files changed, 99 insertions, 107 deletions
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index 3cc77d27..f8214f5f 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -930,7 +930,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
Hypergraph new_hg;
{
ReadFile rf(writer.fname_);
- bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg);
+ bool succeeded = HypergraphIO::ReadFromBinary(rf.stream(), &new_hg);
if (!succeeded) abort();
}
HG::Union(forest, &new_hg);
@@ -1023,7 +1023,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
Hypergraph new_hg;
{
ReadFile rf(writer.fname_);
- bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg);
+ bool succeeded = HypergraphIO::ReadFromBinary(rf.stream(), &new_hg);
if (!succeeded) abort();
}
HG::Union(forest, &new_hg);
diff --git a/decoder/forest_writer.cc b/decoder/forest_writer.cc
index 6e4cccb3..c072e599 100644
--- a/decoder/forest_writer.cc
+++ b/decoder/forest_writer.cc
@@ -11,13 +11,13 @@
using namespace std;
ForestWriter::ForestWriter(const std::string& path, int num) :
- fname_(path + '/' + boost::lexical_cast<string>(num) + ".json.gz"), used_(false) {}
+ fname_(path + '/' + boost::lexical_cast<string>(num) + ".bin.gz"), used_(false) {}
bool ForestWriter::Write(const Hypergraph& forest, bool minimal_rules) {
assert(!used_);
used_ = true;
cerr << " Writing forest to " << fname_ << endl;
WriteFile wf(fname_);
- return HypergraphIO::WriteToJSON(forest, minimal_rules, wf.stream());
+ return HypergraphIO::WriteToBinary(forest, wf.stream());
}
diff --git a/decoder/hg.h b/decoder/hg.h
index 256f650f..124eab86 100644
--- a/decoder/hg.h
+++ b/decoder/hg.h
@@ -18,6 +18,7 @@
#include <string>
#include <vector>
#include <boost/shared_ptr.hpp>
+#include <boost/serialization/vector.hpp>
#include "feature_vector.h"
#include "small_vector.h"
@@ -69,6 +70,18 @@ namespace HG {
short int j_;
short int prev_i_;
short int prev_j_;
+ template<class Archive>
+ void serialize(Archive & ar, const unsigned int version) {
+ ar & head_node_;
+ ar & tail_nodes_;
+ ar & rule_;
+ ar & feature_values_;
+ ar & i_;
+ ar & j_;
+ ar & prev_i_;
+ ar & prev_j_;
+ ar & id_;
+ }
void show(std::ostream &o,unsigned mask=SPAN|RULE) const {
o<<'{';
if (mask&CATEGORY)
@@ -149,6 +162,24 @@ namespace HG {
WordID NT() const { return -cat_; }
EdgesVector in_edges_; // an in edge is an edge with this node as its head. (in edges come from the bottom up to us) indices in edges_
EdgesVector out_edges_; // an out edge is an edge with this node as its tail. (out edges leave us up toward the top/goal). indices in edges_
+ template<class Archive>
+ void save(Archive & ar, const unsigned int version) const {
+ ar & node_hash;
+ ar & id_;
+ ar & TD::Convert(-cat_);
+ ar & in_edges_;
+ ar & out_edges_;
+ }
+ template<class Archive>
+ void load(Archive & ar, const unsigned int version) {
+ ar & node_hash;
+ ar & id_;
+ std::string cat; ar & cat;
+ cat_ = -TD::Convert(cat);
+ ar & in_edges_;
+ ar & out_edges_;
+ }
+ BOOST_SERIALIZATION_SPLIT_MEMBER()
void copy_fixed(Node const& o) { // nonstructural fields only - structural ones are managed by sorting/pruning/subsetting
node_hash = o.node_hash;
cat_=o.cat_;
@@ -492,6 +523,27 @@ public:
void set_ids(); // resync edge,node .id_
void check_ids() const; // assert that .id_ have been kept in sync
+ template<class Archive>
+ void save(Archive & ar, const unsigned int version) const {
+ unsigned ns = nodes_.size(); ar & ns;
+ unsigned es = edges_.size(); ar & es;
+ for (auto& n : nodes_) ar & n;
+ for (auto& e : edges_) ar & e;
+ int x;
+ x = edges_topo_; ar & x;
+ x = is_linear_chain_; ar & x;
+ }
+ template<class Archive>
+ void load(Archive & ar, const unsigned int version) {
+ unsigned ns; ar & ns; nodes_.resize(ns);
+ unsigned es; ar & es; edges_.resize(es);
+ for (auto& n : nodes_) ar & n;
+ for (auto& e : edges_) ar & e;
+ int x;
+ ar & x; edges_topo_ = x;
+ ar & x; is_linear_chain_ = x;
+ }
+ BOOST_SERIALIZATION_SPLIT_MEMBER()
private:
Hypergraph(int num_nodes, int num_edges, bool is_lc) : is_linear_chain_(is_lc), nodes_(num_nodes), edges_(num_edges),edges_topo_(true) {}
};
diff --git a/decoder/hg_io.cc b/decoder/hg_io.cc
index eb0be3d4..67760fb1 100644
--- a/decoder/hg_io.cc
+++ b/decoder/hg_io.cc
@@ -6,6 +6,10 @@
#include <sstream>
#include <iostream>
+#include <boost/archive/binary_iarchive.hpp>
+#include <boost/archive/binary_oarchive.hpp>
+#include <boost/serialization/shared_ptr.hpp>
+
#include "fast_lexical_cast.hpp"
#include "tdict.h"
@@ -271,97 +275,16 @@ bool HypergraphIO::ReadFromJSON(istream* in, Hypergraph* hg) {
return reader.Parse(in);
}
-static void WriteRule(const TRule& r, ostream* out) {
- if (!r.lhs_) { (*out) << "[X] ||| "; }
- JSONParser::WriteEscapedString(r.AsString(), out);
+bool HypergraphIO::ReadFromBinary(istream* in, Hypergraph* hg) {
+ boost::archive::binary_iarchive oa(*in);
+ hg->clear();
+ oa >> *hg;
+ return true;
}
-bool HypergraphIO::WriteToJSON(const Hypergraph& hg, bool remove_rules, ostream* out) {
- if (hg.empty()) { *out << "{}\n"; return true; }
- map<const TRule*, int> rid;
- ostream& o = *out;
- rid[NULL] = 0;
- o << '{';
- if (!remove_rules) {
- o << "\"rules\":[";
- for (int i = 0; i < hg.edges_.size(); ++i) {
- const TRule* r = hg.edges_[i].rule_.get();
- int &id = rid[r];
- if (!id) {
- id=rid.size() - 1;
- if (id > 1) o << ',';
- o << id << ',';
- WriteRule(*r, &o);
- };
- }
- o << "],";
- }
- const bool use_fdict = FD::NumFeats() < 1000;
- if (use_fdict) {
- o << "\"features\":[";
- for (int i = 1; i < FD::NumFeats(); ++i) {
- o << (i==1 ? "":",");
- JSONParser::WriteEscapedString(FD::Convert(i), &o);
- }
- o << "],";
- }
- vector<int> edgemap(hg.edges_.size(), -1); // edges may be in non-topo order
- int edge_count = 0;
- for (int i = 0; i < hg.nodes_.size(); ++i) {
- const Hypergraph::Node& node = hg.nodes_[i];
- if (i > 0) { o << ","; }
- o << "\"edges\":[";
- for (int j = 0; j < node.in_edges_.size(); ++j) {
- const Hypergraph::Edge& edge = hg.edges_[node.in_edges_[j]];
- edgemap[edge.id_] = edge_count;
- ++edge_count;
- o << (j == 0 ? "" : ",") << "{";
-
- o << "\"tail\":[";
- for (int k = 0; k < edge.tail_nodes_.size(); ++k) {
- o << (k > 0 ? "," : "") << edge.tail_nodes_[k];
- }
- o << "],";
-
- o << "\"spans\":[" << edge.i_ << "," << edge.j_ << "," << edge.prev_i_ << "," << edge.prev_j_ << "],";
-
- o << "\"feats\":[";
- bool first = true;
- for (SparseVector<double>::const_iterator it = edge.feature_values_.begin(); it != edge.feature_values_.end(); ++it) {
- if (!it->second) continue; // don't write features that have a zero value
- if (!it->first) continue; // if the feature set was frozen this might happen
- if (!first) o << ',';
- if (use_fdict)
- o << (it->first - 1);
- else {
- JSONParser::WriteEscapedString(FD::Convert(it->first), &o);
- }
- o << ',' << it->second;
- first = false;
- }
- o << "]";
- if (!remove_rules) { o << ",\"rule\":" << rid[edge.rule_.get()]; }
- o << "}";
- }
- o << "],";
-
- o << "\"node\":{\"in_edges\":[";
- for (int j = 0; j < node.in_edges_.size(); ++j) {
- int mapped_edge = edgemap[node.in_edges_[j]];
- assert(mapped_edge >= 0);
- o << (j == 0 ? "" : ",") << mapped_edge;
- }
- o << "]";
- if (node.cat_ < 0) {
- o << ",\"cat\":";
- JSONParser::WriteEscapedString(TD::Convert(node.cat_ * -1), &o);
- }
- char buf[48];
- sprintf(buf, "%016lX", node.node_hash);
- o << ",\"node_hash\":\"" << buf << "\"";
- o << "}";
- }
- o << "}\n";
+bool HypergraphIO::WriteToBinary(const Hypergraph& hg, ostream* out) {
+ boost::archive::binary_oarchive oa(*out);
+ oa << hg;
return true;
}
diff --git a/decoder/hg_io.h b/decoder/hg_io.h
index 5a2bd808..5ba86f69 100644
--- a/decoder/hg_io.h
+++ b/decoder/hg_io.h
@@ -18,10 +18,11 @@ struct HypergraphIO {
// see test_data/small.json.gz for an email encoding
static bool ReadFromJSON(std::istream* in, Hypergraph* out);
+ static bool ReadFromBinary(std::istream* in, Hypergraph* out);
+ static bool WriteToBinary(const Hypergraph& hg, std::ostream* out);
+
// if remove_rules is used, the hypergraph is serialized without rule information
// (so it only contains structure and feature information)
- static bool WriteToJSON(const Hypergraph& hg, bool remove_rules, std::ostream* out);
-
static void WriteAsCFG(const Hypergraph& hg);
// Write only the target size information in bottom-up order.
diff --git a/decoder/hg_test.cc b/decoder/hg_test.cc
index 5cb8626a..25eddcec 100644
--- a/decoder/hg_test.cc
+++ b/decoder/hg_test.cc
@@ -1,6 +1,11 @@
#define BOOST_TEST_MODULE hg_test
#include <boost/test/unit_test.hpp>
#include <boost/test/floating_point_comparison.hpp>
+#include <boost/archive/text_oarchive.hpp>
+#include <boost/archive/text_iarchive.hpp>
+#include <boost/serialization/shared_ptr.hpp>
+#include <boost/serialization/vector.hpp>
+#include <sstream>
#include <iostream>
#include "tdict.h"
@@ -427,19 +432,29 @@ BOOST_AUTO_TEST_CASE(TestGenericKBest) {
}
}
-BOOST_AUTO_TEST_CASE(TestReadWriteHG) {
+BOOST_AUTO_TEST_CASE(TestReadWriteHG_Boost) {
std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA);
- Hypergraph hg,hg2;
- CreateHG(path, &hg);
- hg.edges_.front().j_ = 23;
- hg.edges_.back().prev_i_ = 99;
- ostringstream os;
- HypergraphIO::WriteToJSON(hg, false, &os);
- istringstream is(os.str());
- HypergraphIO::ReadFromJSON(&is, &hg2);
- BOOST_CHECK_EQUAL(hg2.NumberOfPaths(), hg.NumberOfPaths());
- BOOST_CHECK_EQUAL(hg2.edges_.front().j_, 23);
- BOOST_CHECK_EQUAL(hg2.edges_.back().prev_i_, 99);
+ Hypergraph hg;
+ Hypergraph hg2;
+ std::string out;
+ {
+ CreateHG(path, &hg);
+ hg.edges_.front().j_ = 23;
+ hg.edges_.back().prev_i_ = 99;
+ ostringstream os;
+ boost::archive::text_oarchive oa(os);
+ oa << hg;
+ out = os.str();
+ }
+ {
+ cerr << out << endl;
+ istringstream is(out);
+ boost::archive::text_iarchive ia(is);
+ ia >> hg2;
+ BOOST_CHECK_EQUAL(hg2.NumberOfPaths(), hg.NumberOfPaths());
+ BOOST_CHECK_EQUAL(hg2.edges_.front().j_, 23);
+ BOOST_CHECK_EQUAL(hg2.edges_.back().prev_i_, 99);
+ }
}
BOOST_AUTO_TEST_SUITE_END()
diff --git a/decoder/rule_lexer.ll b/decoder/rule_lexer.ll
index d4a8d86b..8b48ab7b 100644
--- a/decoder/rule_lexer.ll
+++ b/decoder/rule_lexer.ll
@@ -356,6 +356,7 @@ void RuleLexer::ReadRules(std::istream* in, RuleLexer::RuleCallback func, const
void RuleLexer::ReadRule(const std::string& srule, RuleCallback func, bool mono, void* extra) {
init_default_feature_names();
+ scfglex_fname = srule;
lex_mono_rules = mono;
lex_line = 1;
rule_callback_extra = extra;