diff options
author | Paul Baltescu <pauldb89@gmail.com> | 2013-06-04 23:17:57 +0100 |
---|---|---|
committer | Paul Baltescu <pauldb89@gmail.com> | 2013-06-04 23:17:57 +0100 |
commit | a3243017d6b8c46cc3e41f4243311dc3dbc80ab4 (patch) | |
tree | 8b4ac016293196f10b48bed86fe9db0bdc371ac8 /extractor/translation_table.h | |
parent | 4a254c100a7565633fb79d57bffafdcb4e8f03db (diff) |
Serialize data structures.
Diffstat (limited to 'extractor/translation_table.h')
-rw-r--r-- | extractor/translation_table.h | 42 |
1 files changed, 38 insertions, 4 deletions
diff --git a/extractor/translation_table.h b/extractor/translation_table.h index ed43ad72..2a37bab7 100644 --- a/extractor/translation_table.h +++ b/extractor/translation_table.h @@ -7,6 +7,9 @@ #include <boost/filesystem.hpp> #include <boost/functional/hash.hpp> +#include <boost/serialization/serialization.hpp> +#include <boost/serialization/split_member.hpp> +#include <boost/serialization/utility.hpp> using namespace std; namespace fs = boost::filesystem; @@ -23,11 +26,16 @@ class DataArray; */ class TranslationTable { public: + // Constructs a translation table from source data, target data and the + // corresponding alignment. TranslationTable( shared_ptr<DataArray> source_data_array, shared_ptr<DataArray> target_data_array, shared_ptr<Alignment> alignment); + // Creates empty translation table. + TranslationTable(); + virtual ~TranslationTable(); // Returns p(e | f). @@ -38,10 +46,7 @@ class TranslationTable { virtual double GetSourceGivenTargetScore(const string& source_word, const string& target_word); - void WriteBinary(const fs::path& filepath) const; - - protected: - TranslationTable(); + bool operator==(const TranslationTable& other) const; private: // Increment links count for the given (f, e) word pair. @@ -52,6 +57,35 @@ class TranslationTable { int source_word_id, int target_word_id) const; + friend class boost::serialization::access; + + template<class Archive> void save(Archive& ar, unsigned int) const { + ar << *source_data_array << *target_data_array; + + int num_entries = translation_probabilities.size(); + ar << num_entries; + for (auto entry: translation_probabilities) { + ar << entry; + } + } + + template<class Archive> void load(Archive& ar, unsigned int) { + source_data_array = make_shared<DataArray>(); + ar >> *source_data_array; + target_data_array = make_shared<DataArray>(); + ar >> *target_data_array; + + int num_entries; + ar >> num_entries; + for (size_t i = 0; i < num_entries; ++i) { + pair<pair<int, int>, pair<double, double>> entry; + ar >> entry; + translation_probabilities.insert(entry); + } + } + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + shared_ptr<DataArray> source_data_array; shared_ptr<DataArray> target_data_array; unordered_map<pair<int, int>, pair<double, double>, PairHash> |