From e8451a312b01b0917ffef06afc17f1e8cad5d510 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 18 Oct 2014 22:58:24 -0400 Subject: support serialization of TRule with boost serialization --- decoder/trule.h | 62 +++++++++++++++++++++++++++++++++++++++++++++++++++ decoder/trule_test.cc | 36 ++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+) diff --git a/decoder/trule.h b/decoder/trule.h index adef7cc7..85842bb5 100644 --- a/decoder/trule.h +++ b/decoder/trule.h @@ -11,6 +11,7 @@ #include "sparse_vector.h" #include "wordid.h" +#include "tdict.h" class TRule; typedef boost::shared_ptr TRulePtr; @@ -26,6 +27,7 @@ struct AlignmentPoint { short s_; short t_; }; + inline std::ostream& operator<<(std::ostream& os, const AlignmentPoint& p) { return os << static_cast(p.s_) << '-' << static_cast(p.t_); } @@ -163,6 +165,66 @@ class TRule { // optional, shows internal structure of TSG rules boost::shared_ptr tree_structure; + friend class boost::serialization::access; + template + void save(Archive & ar, const unsigned int version) const { + ar & TD::Convert(-lhs_); + unsigned f_size = f_.size(); + ar & f_size; + assert(f_size <= (sizeof(size_t) * 8)); + size_t f_nt_mask = 0; + for (int i = f_.size() - 1; i >= 0; --i) { + f_nt_mask <<= 1; + f_nt_mask |= (f_[i] <= 0 ? 1 : 0); + } + ar & f_nt_mask; + for (unsigned i = 0; i < f_.size(); ++i) + ar & TD::Convert(f_[i] < 0 ? -f_[i] : f_[i]); + unsigned e_size = e_.size(); + ar & e_size; + size_t e_nt_mask = 0; + assert(e_size <= (sizeof(size_t) * 8)); + for (int i = e_.size() - 1; i >= 0; --i) { + e_nt_mask <<= 1; + e_nt_mask |= (e_[i] <= 0 ? 1 : 0); + } + ar & e_nt_mask; + for (unsigned i = 0; i < e_.size(); ++i) + if (e_[i] <= 0) ar & e_[i]; else ar & TD::Convert(e_[i]); + ar & arity_; + ar & scores_; + } + template + void load(Archive & ar, const unsigned int version) { + std::string lhs; ar & lhs; lhs_ = -TD::Convert(lhs); + unsigned f_size; ar & f_size; + f_.resize(f_size); + size_t f_nt_mask; ar & f_nt_mask; + std::string sym; + for (unsigned i = 0; i < f_size; ++i) { + bool mask = (f_nt_mask & 1); + ar & sym; + f_[i] = TD::Convert(sym) * (mask ? -1 : 1); + f_nt_mask >>= 1; + } + unsigned e_size; ar & e_size; + e_.resize(e_size); + size_t e_nt_mask; ar & e_nt_mask; + for (unsigned i = 0; i < e_size; ++i) { + bool mask = (e_nt_mask & 1); + if (mask) { + ar & e_[i]; + } else { + ar & sym; + e_[i] = TD::Convert(sym); + } + e_nt_mask >>= 1; + } + ar & arity_; + ar & scores_; + } + + BOOST_SERIALIZATION_SPLIT_MEMBER() private: TRule(const WordID& src, const WordID& trg) : e_(1, trg), f_(1, src), lhs_(), arity_(), prev_i(), prev_j() {} }; diff --git a/decoder/trule_test.cc b/decoder/trule_test.cc index 0cb7e2e8..d75c2016 100644 --- a/decoder/trule_test.cc +++ b/decoder/trule_test.cc @@ -4,6 +4,10 @@ #include #include #include +#include +#include +#include +#include #include "tdict.h" using namespace std; @@ -53,3 +57,35 @@ BOOST_AUTO_TEST_CASE(TestRuleR) { BOOST_CHECK_EQUAL(t6.e_[3], 0); } +BOOST_AUTO_TEST_CASE(TestReadWriteHG_Boost) { + string str; + string t7str; + { + TRule t7; + t7.ReadFromString("[X] ||| den [X,1] sah [X,2] . ||| [2] saw the [1] . ||| Feature1=0.12321 Foo=0.23232 Bar=0.121"); + cerr << t7.AsString() << endl; + ostringstream os; + TRulePtr tp1(new TRule("[X] ||| a b c ||| x z y ||| A=1 B=2")); + TRulePtr tp2 = tp1; + boost::archive::text_oarchive oa(os); + oa << t7; + oa << tp1; + oa << tp2; + str = os.str(); + t7str = t7.AsString(); + } + { + istringstream is(str); + boost::archive::text_iarchive ia(is); + TRule t8; + ia >> t8; + TRulePtr tp3, tp4; + ia >> tp3; + ia >> tp4; + cerr << t8.AsString() << endl; + BOOST_CHECK_EQUAL(t7str, t8.AsString()); + cerr << tp3->AsString() << endl; + cerr << tp4->AsString() << endl; + } +} + -- cgit v1.2.3