summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChris Dyer <cdyer@allegro.clab.cs.cmu.edu>2014-10-18 22:58:24 -0400
committerChris Dyer <cdyer@allegro.clab.cs.cmu.edu>2014-10-18 22:58:24 -0400
commite8451a312b01b0917ffef06afc17f1e8cad5d510 (patch)
tree379711cb2e2d06e831921a457928d74373ac7ce6
parent2e485c55fb17b75c0b153af349d8283ab0e9384f (diff)
support serialization of TRule with boost serialization
-rw-r--r--decoder/trule.h62
-rw-r--r--decoder/trule_test.cc36
2 files changed, 98 insertions, 0 deletions
diff --git a/decoder/trule.h b/decoder/trule.h
index adef7cc7..85842bb5 100644
--- a/decoder/trule.h
+++ b/decoder/trule.h
@@ -11,6 +11,7 @@
#include "sparse_vector.h"
#include "wordid.h"
+#include "tdict.h"
class TRule;
typedef boost::shared_ptr<TRule> TRulePtr;
@@ -26,6 +27,7 @@ struct AlignmentPoint {
short s_;
short t_;
};
+
inline std::ostream& operator<<(std::ostream& os, const AlignmentPoint& p) {
return os << static_cast<int>(p.s_) << '-' << static_cast<int>(p.t_);
}
@@ -163,6 +165,66 @@ class TRule {
// optional, shows internal structure of TSG rules
boost::shared_ptr<cdec::TreeFragment> tree_structure;
+ friend class boost::serialization::access;
+ template<class Archive>
+ void save(Archive & ar, const unsigned int version) const {
+ ar & TD::Convert(-lhs_);
+ unsigned f_size = f_.size();
+ ar & f_size;
+ assert(f_size <= (sizeof(size_t) * 8));
+ size_t f_nt_mask = 0;
+ for (int i = f_.size() - 1; i >= 0; --i) {
+ f_nt_mask <<= 1;
+ f_nt_mask |= (f_[i] <= 0 ? 1 : 0);
+ }
+ ar & f_nt_mask;
+ for (unsigned i = 0; i < f_.size(); ++i)
+ ar & TD::Convert(f_[i] < 0 ? -f_[i] : f_[i]);
+ unsigned e_size = e_.size();
+ ar & e_size;
+ size_t e_nt_mask = 0;
+ assert(e_size <= (sizeof(size_t) * 8));
+ for (int i = e_.size() - 1; i >= 0; --i) {
+ e_nt_mask <<= 1;
+ e_nt_mask |= (e_[i] <= 0 ? 1 : 0);
+ }
+ ar & e_nt_mask;
+ for (unsigned i = 0; i < e_.size(); ++i)
+ if (e_[i] <= 0) ar & e_[i]; else ar & TD::Convert(e_[i]);
+ ar & arity_;
+ ar & scores_;
+ }
+ template<class Archive>
+ void load(Archive & ar, const unsigned int version) {
+ std::string lhs; ar & lhs; lhs_ = -TD::Convert(lhs);
+ unsigned f_size; ar & f_size;
+ f_.resize(f_size);
+ size_t f_nt_mask; ar & f_nt_mask;
+ std::string sym;
+ for (unsigned i = 0; i < f_size; ++i) {
+ bool mask = (f_nt_mask & 1);
+ ar & sym;
+ f_[i] = TD::Convert(sym) * (mask ? -1 : 1);
+ f_nt_mask >>= 1;
+ }
+ unsigned e_size; ar & e_size;
+ e_.resize(e_size);
+ size_t e_nt_mask; ar & e_nt_mask;
+ for (unsigned i = 0; i < e_size; ++i) {
+ bool mask = (e_nt_mask & 1);
+ if (mask) {
+ ar & e_[i];
+ } else {
+ ar & sym;
+ e_[i] = TD::Convert(sym);
+ }
+ e_nt_mask >>= 1;
+ }
+ ar & arity_;
+ ar & scores_;
+ }
+
+ BOOST_SERIALIZATION_SPLIT_MEMBER()
private:
TRule(const WordID& src, const WordID& trg) : e_(1, trg), f_(1, src), lhs_(), arity_(), prev_i(), prev_j() {}
};
diff --git a/decoder/trule_test.cc b/decoder/trule_test.cc
index 0cb7e2e8..d75c2016 100644
--- a/decoder/trule_test.cc
+++ b/decoder/trule_test.cc
@@ -4,6 +4,10 @@
#include <boost/test/unit_test.hpp>
#include <boost/test/floating_point_comparison.hpp>
#include <iostream>
+#include <boost/archive/text_oarchive.hpp>
+#include <boost/archive/text_iarchive.hpp>
+#include <boost/serialization/shared_ptr.hpp>
+#include <sstream>
#include "tdict.h"
using namespace std;
@@ -53,3 +57,35 @@ BOOST_AUTO_TEST_CASE(TestRuleR) {
BOOST_CHECK_EQUAL(t6.e_[3], 0);
}
+BOOST_AUTO_TEST_CASE(TestReadWriteHG_Boost) {
+ string str;
+ string t7str;
+ {
+ TRule t7;
+ t7.ReadFromString("[X] ||| den [X,1] sah [X,2] . ||| [2] saw the [1] . ||| Feature1=0.12321 Foo=0.23232 Bar=0.121");
+ cerr << t7.AsString() << endl;
+ ostringstream os;
+ TRulePtr tp1(new TRule("[X] ||| a b c ||| x z y ||| A=1 B=2"));
+ TRulePtr tp2 = tp1;
+ boost::archive::text_oarchive oa(os);
+ oa << t7;
+ oa << tp1;
+ oa << tp2;
+ str = os.str();
+ t7str = t7.AsString();
+ }
+ {
+ istringstream is(str);
+ boost::archive::text_iarchive ia(is);
+ TRule t8;
+ ia >> t8;
+ TRulePtr tp3, tp4;
+ ia >> tp3;
+ ia >> tp4;
+ cerr << t8.AsString() << endl;
+ BOOST_CHECK_EQUAL(t7str, t8.AsString());
+ cerr << tp3->AsString() << endl;
+ cerr << tp4->AsString() << endl;
+ }
+}
+