summaryrefslogtreecommitdiff
path: root/extractor/translation_table.h
diff options
context:
space:
mode:
authorPaul Baltescu <pauldb89@gmail.com>2013-06-04 23:17:57 +0100
committerPaul Baltescu <pauldb89@gmail.com>2013-06-04 23:17:57 +0100
commita3243017d6b8c46cc3e41f4243311dc3dbc80ab4 (patch)
tree8b4ac016293196f10b48bed86fe9db0bdc371ac8 /extractor/translation_table.h
parent4a254c100a7565633fb79d57bffafdcb4e8f03db (diff)
Serialize data structures.
Diffstat (limited to 'extractor/translation_table.h')
-rw-r--r--extractor/translation_table.h42
1 files changed, 38 insertions, 4 deletions
diff --git a/extractor/translation_table.h b/extractor/translation_table.h
index ed43ad72..2a37bab7 100644
--- a/extractor/translation_table.h
+++ b/extractor/translation_table.h
@@ -7,6 +7,9 @@
#include <boost/filesystem.hpp>
#include <boost/functional/hash.hpp>
+#include <boost/serialization/serialization.hpp>
+#include <boost/serialization/split_member.hpp>
+#include <boost/serialization/utility.hpp>
using namespace std;
namespace fs = boost::filesystem;
@@ -23,11 +26,16 @@ class DataArray;
*/
class TranslationTable {
public:
+ // Constructs a translation table from source data, target data and the
+ // corresponding alignment.
TranslationTable(
shared_ptr<DataArray> source_data_array,
shared_ptr<DataArray> target_data_array,
shared_ptr<Alignment> alignment);
+ // Creates empty translation table.
+ TranslationTable();
+
virtual ~TranslationTable();
// Returns p(e | f).
@@ -38,10 +46,7 @@ class TranslationTable {
virtual double GetSourceGivenTargetScore(const string& source_word,
const string& target_word);
- void WriteBinary(const fs::path& filepath) const;
-
- protected:
- TranslationTable();
+ bool operator==(const TranslationTable& other) const;
private:
// Increment links count for the given (f, e) word pair.
@@ -52,6 +57,35 @@ class TranslationTable {
int source_word_id,
int target_word_id) const;
+ friend class boost::serialization::access;
+
+ template<class Archive> void save(Archive& ar, unsigned int) const {
+ ar << *source_data_array << *target_data_array;
+
+ int num_entries = translation_probabilities.size();
+ ar << num_entries;
+ for (auto entry: translation_probabilities) {
+ ar << entry;
+ }
+ }
+
+ template<class Archive> void load(Archive& ar, unsigned int) {
+ source_data_array = make_shared<DataArray>();
+ ar >> *source_data_array;
+ target_data_array = make_shared<DataArray>();
+ ar >> *target_data_array;
+
+ int num_entries;
+ ar >> num_entries;
+ for (size_t i = 0; i < num_entries; ++i) {
+ pair<pair<int, int>, pair<double, double>> entry;
+ ar >> entry;
+ translation_probabilities.insert(entry);
+ }
+ }
+
+ BOOST_SERIALIZATION_SPLIT_MEMBER();
+
shared_ptr<DataArray> source_data_array;
shared_ptr<DataArray> target_data_array;
unordered_map<pair<int, int>, pair<double, double>, PairHash>