diff options
author | Chris Dyer <cdyer@allegro.clab.cs.cmu.edu> | 2014-10-18 22:58:24 -0400 |
---|---|---|
committer | Chris Dyer <cdyer@allegro.clab.cs.cmu.edu> | 2014-10-18 22:58:24 -0400 |
commit | 9f055ef11c3f1d06b51bab34addac68e4d63eab7 (patch) | |
tree | 0cc1e92bedf61f22689df6c34d317de750d3f948 | |
parent | e0b225782a239c217da585001c12c4de169d8f0c (diff) |
support serialization of TRule with boost serialization
-rw-r--r-- | decoder/trule.h | 62 | ||||
-rw-r--r-- | decoder/trule_test.cc | 36 |
2 files changed, 98 insertions, 0 deletions
diff --git a/decoder/trule.h b/decoder/trule.h index adef7cc7..85842bb5 100644 --- a/decoder/trule.h +++ b/decoder/trule.h @@ -11,6 +11,7 @@ #include "sparse_vector.h" #include "wordid.h" +#include "tdict.h" class TRule; typedef boost::shared_ptr<TRule> TRulePtr; @@ -26,6 +27,7 @@ struct AlignmentPoint { short s_; short t_; }; + inline std::ostream& operator<<(std::ostream& os, const AlignmentPoint& p) { return os << static_cast<int>(p.s_) << '-' << static_cast<int>(p.t_); } @@ -163,6 +165,66 @@ class TRule { // optional, shows internal structure of TSG rules boost::shared_ptr<cdec::TreeFragment> tree_structure; + friend class boost::serialization::access; + template<class Archive> + void save(Archive & ar, const unsigned int version) const { + ar & TD::Convert(-lhs_); + unsigned f_size = f_.size(); + ar & f_size; + assert(f_size <= (sizeof(size_t) * 8)); + size_t f_nt_mask = 0; + for (int i = f_.size() - 1; i >= 0; --i) { + f_nt_mask <<= 1; + f_nt_mask |= (f_[i] <= 0 ? 1 : 0); + } + ar & f_nt_mask; + for (unsigned i = 0; i < f_.size(); ++i) + ar & TD::Convert(f_[i] < 0 ? -f_[i] : f_[i]); + unsigned e_size = e_.size(); + ar & e_size; + size_t e_nt_mask = 0; + assert(e_size <= (sizeof(size_t) * 8)); + for (int i = e_.size() - 1; i >= 0; --i) { + e_nt_mask <<= 1; + e_nt_mask |= (e_[i] <= 0 ? 1 : 0); + } + ar & e_nt_mask; + for (unsigned i = 0; i < e_.size(); ++i) + if (e_[i] <= 0) ar & e_[i]; else ar & TD::Convert(e_[i]); + ar & arity_; + ar & scores_; + } + template<class Archive> + void load(Archive & ar, const unsigned int version) { + std::string lhs; ar & lhs; lhs_ = -TD::Convert(lhs); + unsigned f_size; ar & f_size; + f_.resize(f_size); + size_t f_nt_mask; ar & f_nt_mask; + std::string sym; + for (unsigned i = 0; i < f_size; ++i) { + bool mask = (f_nt_mask & 1); + ar & sym; + f_[i] = TD::Convert(sym) * (mask ? -1 : 1); + f_nt_mask >>= 1; + } + unsigned e_size; ar & e_size; + e_.resize(e_size); + size_t e_nt_mask; ar & e_nt_mask; + for (unsigned i = 0; i < e_size; ++i) { + bool mask = (e_nt_mask & 1); + if (mask) { + ar & e_[i]; + } else { + ar & sym; + e_[i] = TD::Convert(sym); + } + e_nt_mask >>= 1; + } + ar & arity_; + ar & scores_; + } + + BOOST_SERIALIZATION_SPLIT_MEMBER() private: TRule(const WordID& src, const WordID& trg) : e_(1, trg), f_(1, src), lhs_(), arity_(), prev_i(), prev_j() {} }; diff --git a/decoder/trule_test.cc b/decoder/trule_test.cc index 0cb7e2e8..d75c2016 100644 --- a/decoder/trule_test.cc +++ b/decoder/trule_test.cc @@ -4,6 +4,10 @@ #include <boost/test/unit_test.hpp> #include <boost/test/floating_point_comparison.hpp> #include <iostream> +#include <boost/archive/text_oarchive.hpp> +#include <boost/archive/text_iarchive.hpp> +#include <boost/serialization/shared_ptr.hpp> +#include <sstream> #include "tdict.h" using namespace std; @@ -53,3 +57,35 @@ BOOST_AUTO_TEST_CASE(TestRuleR) { BOOST_CHECK_EQUAL(t6.e_[3], 0); } +BOOST_AUTO_TEST_CASE(TestReadWriteHG_Boost) { + string str; + string t7str; + { + TRule t7; + t7.ReadFromString("[X] ||| den [X,1] sah [X,2] . ||| [2] saw the [1] . ||| Feature1=0.12321 Foo=0.23232 Bar=0.121"); + cerr << t7.AsString() << endl; + ostringstream os; + TRulePtr tp1(new TRule("[X] ||| a b c ||| x z y ||| A=1 B=2")); + TRulePtr tp2 = tp1; + boost::archive::text_oarchive oa(os); + oa << t7; + oa << tp1; + oa << tp2; + str = os.str(); + t7str = t7.AsString(); + } + { + istringstream is(str); + boost::archive::text_iarchive ia(is); + TRule t8; + ia >> t8; + TRulePtr tp3, tp4; + ia >> tp3; + ia >> tp4; + cerr << t8.AsString() << endl; + BOOST_CHECK_EQUAL(t7str, t8.AsString()); + cerr << tp3->AsString() << endl; + cerr << tp4->AsString() << endl; + } +} + |