summaryrefslogtreecommitdiff
path: root/extractor
diff options
context:
space:
mode:
authorPatrick Simianer <p@simianer.de>2013-06-24 14:40:07 +0200
committerPatrick Simianer <p@simianer.de>2013-06-24 14:40:07 +0200
commite547ab5f765c72ad326b1d3a79f26bb221364d7d (patch)
treee205609de0adce98bdf4ec4e799cd776cebe8b72 /extractor
parentbecb1347773ebaae8cab2669afe4bad048cda992 (diff)
parent5794c0109902cf19a52cc8f1799353270ed9d85d (diff)
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'extractor')
-rw-r--r--extractor/alignment.cc15
-rw-r--r--extractor/alignment.h19
-rw-r--r--extractor/alignment_test.cc26
-rw-r--r--extractor/compile.cc82
-rw-r--r--extractor/data_array.cc33
-rw-r--r--extractor/data_array.h38
-rw-r--r--extractor/data_array_test.cc91
-rw-r--r--extractor/precomputation.cc21
-rw-r--r--extractor/precomputation.h35
-rw-r--r--extractor/precomputation_test.cc21
-rw-r--r--extractor/run_extractor.cc26
-rw-r--r--extractor/suffix_array.cc20
-rw-r--r--extractor/suffix_array.h28
-rw-r--r--extractor/suffix_array_test.cc46
-rw-r--r--extractor/translation_table.cc27
-rw-r--r--extractor/translation_table.h42
-rw-r--r--extractor/translation_table_test.cc149
17 files changed, 473 insertions, 246 deletions
diff --git a/extractor/alignment.cc b/extractor/alignment.cc
index b187c03a..2278c825 100644
--- a/extractor/alignment.cc
+++ b/extractor/alignment.cc
@@ -23,8 +23,8 @@ Alignment::Alignment(const string& filename) {
boost::split(items, line, boost::is_any_of(" -"));
vector<pair<int, int>> alignment;
alignment.reserve(items.size() / 2);
- for (size_t i = 0; i < items.size(); i += 2) {
- alignment.push_back(make_pair(stoi(items[i]), stoi(items[i + 1])));
+ for (size_t i = 1; i < items.size(); i += 2) {
+ alignment.push_back(make_pair(stoi(items[i - 1]), stoi(items[i])));
}
alignments.push_back(alignment);
}
@@ -39,15 +39,8 @@ vector<pair<int, int>> Alignment::GetLinks(int sentence_index) const {
return alignments[sentence_index];
}
-void Alignment::WriteBinary(const fs::path& filepath) {
- FILE* file = fopen(filepath.string().c_str(), "w");
- int size = alignments.size();
- fwrite(&size, sizeof(int), 1, file);
- for (vector<pair<int, int>> alignment: alignments) {
- size = alignment.size();
- fwrite(&size, sizeof(int), 1, file);
- fwrite(alignment.data(), sizeof(pair<int, int>), size, file);
- }
+bool Alignment::operator==(const Alignment& other) const {
+ return alignments == other.alignments;
}
} // namespace extractor
diff --git a/extractor/alignment.h b/extractor/alignment.h
index 4596f92b..dc5a8b55 100644
--- a/extractor/alignment.h
+++ b/extractor/alignment.h
@@ -5,6 +5,10 @@
#include <vector>
#include <boost/filesystem.hpp>
+#include <boost/serialization/serialization.hpp>
+#include <boost/serialization/split_member.hpp>
+#include <boost/serialization/utility.hpp>
+#include <boost/serialization/vector.hpp>
namespace fs = boost::filesystem;
using namespace std;
@@ -19,18 +23,23 @@ class Alignment {
// Reads alignment from text file.
Alignment(const string& filename);
+ // Creates empty alignment.
+ Alignment();
+
// Returns the alignment for a given sentence.
virtual vector<pair<int, int>> GetLinks(int sentence_index) const;
- // Writes alignment to file in binary format.
- void WriteBinary(const fs::path& filepath);
-
virtual ~Alignment();
- protected:
- Alignment();
+ bool operator==(const Alignment& alignment) const;
private:
+ friend class boost::serialization::access;
+
+ template<class Archive> void serialize(Archive& ar, unsigned int) {
+ ar & alignments;
+ }
+
vector<vector<pair<int, int>>> alignments;
};
diff --git a/extractor/alignment_test.cc b/extractor/alignment_test.cc
index 43c37ebd..1b8ff531 100644
--- a/extractor/alignment_test.cc
+++ b/extractor/alignment_test.cc
@@ -1,12 +1,16 @@
#include <gtest/gtest.h>
-#include <memory>
+#include <sstream>
#include <string>
+#include <boost/archive/binary_iarchive.hpp>
+#include <boost/archive/binary_oarchive.hpp>
+
#include "alignment.h"
using namespace std;
using namespace ::testing;
+namespace ar = boost::archive;
namespace extractor {
namespace {
@@ -14,19 +18,31 @@ namespace {
class AlignmentTest : public Test {
protected:
virtual void SetUp() {
- alignment = make_shared<Alignment>("sample_alignment.txt");
+ alignment = Alignment("sample_alignment.txt");
}
- shared_ptr<Alignment> alignment;
+ Alignment alignment;
};
TEST_F(AlignmentTest, TestGetLinks) {
vector<pair<int, int>> expected_links = {
make_pair(0, 0), make_pair(1, 1), make_pair(2, 2)
};
- EXPECT_EQ(expected_links, alignment->GetLinks(0));
+ EXPECT_EQ(expected_links, alignment.GetLinks(0));
expected_links = {make_pair(1, 0), make_pair(2, 1)};
- EXPECT_EQ(expected_links, alignment->GetLinks(1));
+ EXPECT_EQ(expected_links, alignment.GetLinks(1));
+}
+
+TEST_F(AlignmentTest, TestSerialization) {
+ stringstream stream(ios_base::binary | ios_base::out | ios_base::in);
+ ar::binary_oarchive output_stream(stream, ar::no_header);
+ output_stream << alignment;
+
+ Alignment alignment_copy;
+ ar::binary_iarchive input_stream(stream, ar::no_header);
+ input_stream >> alignment_copy;
+
+ EXPECT_EQ(alignment, alignment_copy);
}
} // namespace
diff --git a/extractor/compile.cc b/extractor/compile.cc
index a9ae2cef..65fdd509 100644
--- a/extractor/compile.cc
+++ b/extractor/compile.cc
@@ -1,6 +1,8 @@
+#include <fstream>
#include <iostream>
#include <string>
+#include <boost/archive/binary_oarchive.hpp>
#include <boost/filesystem.hpp>
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
@@ -9,8 +11,10 @@
#include "data_array.h"
#include "precomputation.h"
#include "suffix_array.h"
+#include "time_util.h"
#include "translation_table.h"
+namespace ar = boost::archive;
namespace fs = boost::filesystem;
namespace po = boost::program_options;
using namespace std;
@@ -58,11 +62,14 @@ int main(int argc, char** argv) {
return 1;
}
- fs::path output_dir(vm["output"].as<string>().c_str());
+ fs::path output_dir(vm["output"].as<string>());
if (!fs::exists(output_dir)) {
fs::create_directory(output_dir);
}
+ // Reading source and target data.
+ Clock::time_point start_time = Clock::now();
+ cerr << "Reading source and target data..." << endl;
shared_ptr<DataArray> source_data_array, target_data_array;
if (vm.count("bitext")) {
source_data_array = make_shared<DataArray>(
@@ -73,15 +80,53 @@ int main(int argc, char** argv) {
source_data_array = make_shared<DataArray>(vm["source"].as<string>());
target_data_array = make_shared<DataArray>(vm["target"].as<string>());
}
+
+ Clock::time_point start_write = Clock::now();
+ ofstream target_fstream((output_dir / fs::path("target.bin")).string());
+ ar::binary_oarchive target_stream(target_fstream);
+ target_stream << *target_data_array;
+ Clock::time_point stop_write = Clock::now();
+ double write_duration = GetDuration(start_write, stop_write);
+
+ Clock::time_point stop_time = Clock::now();
+ cerr << "Reading data took " << GetDuration(start_time, stop_time)
+ << " seconds" << endl;
+
+ // Constructing and compiling the suffix array.
+ start_time = Clock::now();
+ cerr << "Constructing source suffix array..." << endl;
shared_ptr<SuffixArray> source_suffix_array =
make_shared<SuffixArray>(source_data_array);
- source_suffix_array->WriteBinary(output_dir / fs::path("f.bin"));
- target_data_array->WriteBinary(output_dir / fs::path("e.bin"));
+ start_write = Clock::now();
+ ofstream source_fstream((output_dir / fs::path("source.bin")).string());
+ ar::binary_oarchive output_stream(source_fstream);
+ output_stream << *source_suffix_array;
+ stop_write = Clock::now();
+ write_duration += GetDuration(start_write, stop_write);
+
+ cerr << "Constructing suffix array took "
+ << GetDuration(start_time, stop_time) << " seconds" << endl;
+
+ // Reading alignment.
+ start_time = Clock::now();
+ cerr << "Reading alignment..." << endl;
shared_ptr<Alignment> alignment =
make_shared<Alignment>(vm["alignment"].as<string>());
- alignment->WriteBinary(output_dir / fs::path("a.bin"));
+ start_write = Clock::now();
+ ofstream alignment_fstream((output_dir / fs::path("alignment.bin")).string());
+ ar::binary_oarchive alignment_stream(alignment_fstream);
+ alignment_stream << *alignment;
+ stop_write = Clock::now();
+ write_duration += GetDuration(start_write, stop_write);
+
+ stop_time = Clock::now();
+ cerr << "Reading alignment took "
+ << GetDuration(start_time, stop_time) << " seconds" << endl;
+
+ start_time = Clock::now();
+ cerr << "Precomputing collocations..." << endl;
Precomputation precomputation(
source_suffix_array,
vm["frequent"].as<int>(),
@@ -91,10 +136,35 @@ int main(int argc, char** argv) {
vm["min_gap_size"].as<int>(),
vm["max_phrase_len"].as<int>(),
vm["min_frequency"].as<int>());
- precomputation.WriteBinary(output_dir / fs::path("precompute.bin"));
+ start_write = Clock::now();
+ ofstream precomp_fstream((output_dir / fs::path("precomp.bin")).string());
+ ar::binary_oarchive precomp_stream(precomp_fstream);
+ precomp_stream << precomputation;
+ stop_write = Clock::now();
+ write_duration += GetDuration(start_write, stop_write);
+
+ stop_time = Clock::now();
+ cerr << "Precomputing collocations took "
+ << GetDuration(start_time, stop_time) << " seconds" << endl;
+
+ start_time = Clock::now();
+ cerr << "Precomputing conditional probabilities..." << endl;
TranslationTable table(source_data_array, target_data_array, alignment);
- table.WriteBinary(output_dir / fs::path("lex.bin"));
+
+ start_write = Clock::now();
+ ofstream table_fstream((output_dir / fs::path("bilex.bin")).string());
+ ar::binary_oarchive table_stream(table_fstream);
+ table_stream << table;
+ stop_write = Clock::now();
+ write_duration += GetDuration(start_write, stop_write);
+
+ stop_time = Clock::now();
+ cerr << "Precomputing conditional probabilities took "
+ << GetDuration(start_time, stop_time) << " seconds" << endl;
+
+ cerr << "Total time spent writing: " << write_duration
+ << " seconds" << endl;
return 0;
}
diff --git a/extractor/data_array.cc b/extractor/data_array.cc
index 203fe219..2e4bdafb 100644
--- a/extractor/data_array.cc
+++ b/extractor/data_array.cc
@@ -118,33 +118,6 @@ int DataArray::GetSentenceId(int position) const {
return sentence_id[position];
}
-void DataArray::WriteBinary(const fs::path& filepath) const {
- std::cerr << "File: " << filepath.string() << std::endl;
- WriteBinary(fopen(filepath.string().c_str(), "w"));
-}
-
-void DataArray::WriteBinary(FILE* file) const {
- int size = id2word.size();
- fwrite(&size, sizeof(int), 1, file);
- for (string word: id2word) {
- size = word.size();
- fwrite(&size, sizeof(int), 1, file);
- fwrite(word.data(), sizeof(char), size, file);
- }
-
- size = data.size();
- fwrite(&size, sizeof(int), 1, file);
- fwrite(data.data(), sizeof(int), size, file);
-
- size = sentence_id.size();
- fwrite(&size, sizeof(int), 1, file);
- fwrite(sentence_id.data(), sizeof(int), size, file);
-
- size = sentence_start.size();
- fwrite(&size, sizeof(int), 1, file);
- fwrite(sentence_start.data(), sizeof(int), 1, file);
-}
-
bool DataArray::HasWord(const string& word) const {
return word2id.count(word);
}
@@ -158,4 +131,10 @@ string DataArray::GetWord(int word_id) const {
return id2word[word_id];
}
+bool DataArray::operator==(const DataArray& other) const {
+ return word2id == other.word2id && id2word == other.id2word &&
+ data == other.data && sentence_start == other.sentence_start &&
+ sentence_id == other.sentence_id;
+}
+
} // namespace extractor
diff --git a/extractor/data_array.h b/extractor/data_array.h
index 978a6931..2be6a09c 100644
--- a/extractor/data_array.h
+++ b/extractor/data_array.h
@@ -6,6 +6,10 @@
#include <vector>
#include <boost/filesystem.hpp>
+#include <boost/serialization/serialization.hpp>
+#include <boost/serialization/split_member.hpp>
+#include <boost/serialization/string.hpp>
+#include <boost/serialization/vector.hpp>
namespace fs = boost::filesystem;
using namespace std;
@@ -43,6 +47,9 @@ class DataArray {
// Reads data array from bitext file where the sentences are separated by |||.
DataArray(const string& filename, const Side& side);
+ // Creates empty data array.
+ DataArray();
+
virtual ~DataArray();
// Returns a vector containing the word ids.
@@ -82,14 +89,7 @@ class DataArray {
// Returns the number of the sentence containing the given position.
virtual int GetSentenceId(int position) const;
- // Writes data array to file in binary format.
- void WriteBinary(const fs::path& filepath) const;
-
- // Writes data array to file in binary format.
- void WriteBinary(FILE* file) const;
-
- protected:
- DataArray();
+ bool operator==(const DataArray& other) const;
private:
// Sets up specific constants.
@@ -98,6 +98,28 @@ class DataArray {
// Constructs the data array.
void CreateDataArray(const vector<string>& lines);
+ friend class boost::serialization::access;
+
+ template<class Archive> void save(Archive& ar, unsigned int) const {
+ ar << id2word;
+ ar << data;
+ ar << sentence_id;
+ ar << sentence_start;
+ }
+
+ template<class Archive> void load(Archive& ar, unsigned int) {
+ ar >> id2word;
+ for (size_t i = 0; i < id2word.size(); ++i) {
+ word2id[id2word[i]] = i;
+ }
+
+ ar >> data;
+ ar >> sentence_id;
+ ar >> sentence_start;
+ }
+
+ BOOST_SERIALIZATION_SPLIT_MEMBER();
+
unordered_map<string, int> word2id;
vector<string> id2word;
vector<int> data;
diff --git a/extractor/data_array_test.cc b/extractor/data_array_test.cc
index 71175fda..6c329e34 100644
--- a/extractor/data_array_test.cc
+++ b/extractor/data_array_test.cc
@@ -1,8 +1,11 @@
#include <gtest/gtest.h>
#include <memory>
+#include <sstream>
#include <string>
+#include <boost/archive/binary_iarchive.hpp>
+#include <boost/archive/binary_oarchive.hpp>
#include <boost/filesystem.hpp>
#include "data_array.h"
@@ -10,6 +13,7 @@
using namespace std;
using namespace ::testing;
namespace fs = boost::filesystem;
+namespace ar = boost::archive;
namespace extractor {
namespace {
@@ -18,12 +22,12 @@ class DataArrayTest : public Test {
protected:
virtual void SetUp() {
string sample_test_file("sample_bitext.txt");
- source_data = make_shared<DataArray>(sample_test_file, SOURCE);
- target_data = make_shared<DataArray>(sample_test_file, TARGET);
+ source_data = DataArray(sample_test_file, SOURCE);
+ target_data = DataArray(sample_test_file, TARGET);
}
- shared_ptr<DataArray> source_data;
- shared_ptr<DataArray> target_data;
+ DataArray source_data;
+ DataArray target_data;
};
TEST_F(DataArrayTest, TestGetData) {
@@ -32,11 +36,11 @@ TEST_F(DataArrayTest, TestGetData) {
"ana", "are", "mere", ".", "__END_OF_LINE__",
"ana", "bea", "mult", "lapte", ".", "__END_OF_LINE__"
};
- EXPECT_EQ(expected_source_data, source_data->GetData());
- EXPECT_EQ(expected_source_data.size(), source_data->GetSize());
+ EXPECT_EQ(expected_source_data, source_data.GetData());
+ EXPECT_EQ(expected_source_data.size(), source_data.GetSize());
for (size_t i = 0; i < expected_source_data.size(); ++i) {
- EXPECT_EQ(expected_source_data[i], source_data->AtIndex(i));
- EXPECT_EQ(expected_source_words[i], source_data->GetWordAtIndex(i));
+ EXPECT_EQ(expected_source_data[i], source_data.AtIndex(i));
+ EXPECT_EQ(expected_source_words[i], source_data.GetWordAtIndex(i));
}
vector<int> expected_target_data = {2, 3, 4, 5, 1, 2, 6, 7, 8, 9, 10, 5, 1};
@@ -44,55 +48,68 @@ TEST_F(DataArrayTest, TestGetData) {
"anna", "has", "apples", ".", "__END_OF_LINE__",
"anna", "drinks", "a", "lot", "of", "milk", ".", "__END_OF_LINE__"
};
- EXPECT_EQ(expected_target_data, target_data->GetData());
- EXPECT_EQ(expected_target_data.size(), target_data->GetSize());
+ EXPECT_EQ(expected_target_data, target_data.GetData());
+ EXPECT_EQ(expected_target_data.size(), target_data.GetSize());
for (size_t i = 0; i < expected_target_data.size(); ++i) {
- EXPECT_EQ(expected_target_data[i], target_data->AtIndex(i));
- EXPECT_EQ(expected_target_words[i], target_data->GetWordAtIndex(i));
+ EXPECT_EQ(expected_target_data[i], target_data.AtIndex(i));
+ EXPECT_EQ(expected_target_words[i], target_data.GetWordAtIndex(i));
}
}
TEST_F(DataArrayTest, TestVocabulary) {
- EXPECT_EQ(9, source_data->GetVocabularySize());
- EXPECT_TRUE(source_data->HasWord("mere"));
- EXPECT_EQ(4, source_data->GetWordId("mere"));
- EXPECT_EQ("mere", source_data->GetWord(4));
- EXPECT_FALSE(source_data->HasWord("banane"));
-
- EXPECT_EQ(11, target_data->GetVocabularySize());
- EXPECT_TRUE(target_data->HasWord("apples"));
- EXPECT_EQ(4, target_data->GetWordId("apples"));
- EXPECT_EQ("apples", target_data->GetWord(4));
- EXPECT_FALSE(target_data->HasWord("bananas"));
+ EXPECT_EQ(9, source_data.GetVocabularySize());
+ EXPECT_TRUE(source_data.HasWord("mere"));
+ EXPECT_EQ(4, source_data.GetWordId("mere"));
+ EXPECT_EQ("mere", source_data.GetWord(4));
+ EXPECT_FALSE(source_data.HasWord("banane"));
+
+ EXPECT_EQ(11, target_data.GetVocabularySize());
+ EXPECT_TRUE(target_data.HasWord("apples"));
+ EXPECT_EQ(4, target_data.GetWordId("apples"));
+ EXPECT_EQ("apples", target_data.GetWord(4));
+ EXPECT_FALSE(target_data.HasWord("bananas"));
}
TEST_F(DataArrayTest, TestSentenceData) {
- EXPECT_EQ(2, source_data->GetNumSentences());
- EXPECT_EQ(0, source_data->GetSentenceStart(0));
- EXPECT_EQ(5, source_data->GetSentenceStart(1));
- EXPECT_EQ(11, source_data->GetSentenceStart(2));
+ EXPECT_EQ(2, source_data.GetNumSentences());
+ EXPECT_EQ(0, source_data.GetSentenceStart(0));
+ EXPECT_EQ(5, source_data.GetSentenceStart(1));
+ EXPECT_EQ(11, source_data.GetSentenceStart(2));
- EXPECT_EQ(4, source_data->GetSentenceLength(0));
- EXPECT_EQ(5, source_data->GetSentenceLength(1));
+ EXPECT_EQ(4, source_data.GetSentenceLength(0));
+ EXPECT_EQ(5, source_data.GetSentenceLength(1));
vector<int> expected_source_ids = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1};
for (size_t i = 0; i < expected_source_ids.size(); ++i) {
- EXPECT_EQ(expected_source_ids[i], source_data->GetSentenceId(i));
+ EXPECT_EQ(expected_source_ids[i], source_data.GetSentenceId(i));
}
- EXPECT_EQ(2, target_data->GetNumSentences());
- EXPECT_EQ(0, target_data->GetSentenceStart(0));
- EXPECT_EQ(5, target_data->GetSentenceStart(1));
- EXPECT_EQ(13, target_data->GetSentenceStart(2));
+ EXPECT_EQ(2, target_data.GetNumSentences());
+ EXPECT_EQ(0, target_data.GetSentenceStart(0));
+ EXPECT_EQ(5, target_data.GetSentenceStart(1));
+ EXPECT_EQ(13, target_data.GetSentenceStart(2));
- EXPECT_EQ(4, target_data->GetSentenceLength(0));
- EXPECT_EQ(7, target_data->GetSentenceLength(1));
+ EXPECT_EQ(4, target_data.GetSentenceLength(0));
+ EXPECT_EQ(7, target_data.GetSentenceLength(1));
vector<int> expected_target_ids = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1};
for (size_t i = 0; i < expected_target_ids.size(); ++i) {
- EXPECT_EQ(expected_target_ids[i], target_data->GetSentenceId(i));
+ EXPECT_EQ(expected_target_ids[i], target_data.GetSentenceId(i));
}
}
+TEST_F(DataArrayTest, TestSerialization) {
+ stringstream stream(ios_base::binary | ios_base::out | ios_base::in);
+ ar::binary_oarchive output_stream(stream, ar::no_header);
+ output_stream << source_data << target_data;
+
+ DataArray source_copy, target_copy;
+ ar::binary_iarchive input_stream(stream, ar::no_header);
+ input_stream >> source_copy >> target_copy;
+
+ EXPECT_EQ(source_data, source_copy);
+ EXPECT_EQ(target_data, target_copy);
+}
+
} // namespace
} // namespace extractor
diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc
index ee4ba42c..3b8aed69 100644
--- a/extractor/precomputation.cc
+++ b/extractor/precomputation.cc
@@ -165,25 +165,12 @@ void Precomputation::AddStartPositions(
positions.push_back(pos3);
}
-void Precomputation::WriteBinary(const fs::path& filepath) const {
- FILE* file = fopen(filepath.string().c_str(), "w");
-
- // TODO(pauldb): Refactor this code.
- int size = collocations.size();
- fwrite(&size, sizeof(int), 1, file);
- for (auto entry: collocations) {
- size = entry.first.size();
- fwrite(&size, sizeof(int), 1, file);
- fwrite(entry.first.data(), sizeof(int), size, file);
-
- size = entry.second.size();
- fwrite(&size, sizeof(int), 1, file);
- fwrite(entry.second.data(), sizeof(int), size, file);
- }
-}
-
const Index& Precomputation::GetCollocations() const {
return collocations;
}
+bool Precomputation::operator==(const Precomputation& other) const {
+ return collocations == other.collocations;
+}
+
} // namespace extractor
diff --git a/extractor/precomputation.h b/extractor/precomputation.h
index 3e792ac7..9f0c9424 100644
--- a/extractor/precomputation.h
+++ b/extractor/precomputation.h
@@ -9,6 +9,9 @@
#include <boost/filesystem.hpp>
#include <boost/functional/hash.hpp>
+#include <boost/serialization/serialization.hpp>
+#include <boost/serialization/utility.hpp>
+#include <boost/serialization/vector.hpp>
namespace fs = boost::filesystem;
using namespace std;
@@ -39,19 +42,19 @@ class Precomputation {
int max_rule_symbols, int min_gap_size,
int max_frequent_phrase_len, int min_frequency);
- virtual ~Precomputation();
+ // Creates empty precomputation data structure.
+ Precomputation();
- void WriteBinary(const fs::path& filepath) const;
+ virtual ~Precomputation();
// Returns a reference to the index.
virtual const Index& GetCollocations() const;
+ bool operator==(const Precomputation& other) const;
+
static int FIRST_NONTERMINAL;
static int SECOND_NONTERMINAL;
- protected:
- Precomputation();
-
private:
// Finds the most frequent contiguous collocations.
vector<vector<int>> FindMostFrequentPatterns(
@@ -72,6 +75,28 @@ class Precomputation {
// Adds an occurrence of a ternary collocation.
void AddStartPositions(vector<int>& positions, int pos1, int pos2, int pos3);
+ friend class boost::serialization::access;
+
+ template<class Archive> void save(Archive& ar, unsigned int) const {
+ int num_entries = collocations.size();
+ ar << num_entries;
+ for (pair<vector<int>, vector<int>> entry: collocations) {
+ ar << entry;
+ }
+ }
+
+ template<class Archive> void load(Archive& ar, unsigned int) {
+ int num_entries;
+ ar >> num_entries;
+ for (size_t i = 0; i < num_entries; ++i) {
+ pair<vector<int>, vector<int>> entry;
+ ar >> entry;
+ collocations.insert(entry);
+ }
+ }
+
+ BOOST_SERIALIZATION_SPLIT_MEMBER();
+
Index collocations;
};
diff --git a/extractor/precomputation_test.cc b/extractor/precomputation_test.cc
index 363febb7..e81ece5d 100644
--- a/extractor/precomputation_test.cc
+++ b/extractor/precomputation_test.cc
@@ -1,14 +1,19 @@
#include <gtest/gtest.h>
#include <memory>
+#include <sstream>
#include <vector>
+#include <boost/archive/text_iarchive.hpp>
+#include <boost/archive/text_oarchive.hpp>
+
#include "mocks/mock_data_array.h"
#include "mocks/mock_suffix_array.h"
#include "precomputation.h"
using namespace std;
using namespace ::testing;
+namespace ar = boost::archive;
namespace extractor {
namespace {
@@ -29,15 +34,17 @@ class PrecomputationTest : public Test {
GetSuffix(i)).WillRepeatedly(Return(suffixes[i]));
}
EXPECT_CALL(*suffix_array, BuildLCPArray()).WillRepeatedly(Return(lcp));
+
+ precomputation = Precomputation(suffix_array, 3, 3, 10, 5, 1, 4, 2);
}
vector<int> data;
shared_ptr<MockDataArray> data_array;
shared_ptr<MockSuffixArray> suffix_array;
+ Precomputation precomputation;
};
TEST_F(PrecomputationTest, TestCollocations) {
- Precomputation precomputation(suffix_array, 3, 3, 10, 5, 1, 4, 2);
Index collocations = precomputation.GetCollocations();
vector<int> key = {2, 3, -1, 2};
@@ -101,6 +108,18 @@ TEST_F(PrecomputationTest, TestCollocations) {
EXPECT_EQ(0, collocations.count(key));
}
+TEST_F(PrecomputationTest, TestSerialization) {
+ stringstream stream(ios_base::out | ios_base::in);
+ ar::text_oarchive output_stream(stream, ar::no_header);
+ output_stream << precomputation;
+
+ Precomputation precomputation_copy;
+ ar::text_iarchive input_stream(stream, ar::no_header);
+ input_stream >> precomputation_copy;
+
+ EXPECT_EQ(precomputation, precomputation_copy);
+}
+
} // namespace
} // namespace extractor
diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc
index 2fc6f724..8a9ca89d 100644
--- a/extractor/run_extractor.cc
+++ b/extractor/run_extractor.cc
@@ -42,11 +42,12 @@ fs::path GetGrammarFilePath(const fs::path& grammar_path, int file_number) {
}
int main(int argc, char** argv) {
- int num_threads_default = 1;
- #pragma omp parallel
- num_threads_default = omp_get_num_threads();
-
// Sets up the command line arguments map.
+ int max_threads = 1;
+ #pragma omp parallel
+ max_threads = omp_get_num_threads();
+ string threads_option = "Number of parallel threads for extraction "
+ "(max=" + to_string(max_threads) + ")";
po::options_description desc("Command line options");
desc.add_options()
("help,h", "Show available options")
@@ -55,8 +56,7 @@ int main(int argc, char** argv) {
("bitext,b", po::value<string>(), "Parallel text (source ||| target)")
("alignment,a", po::value<string>()->required(), "Bitext word alignment")
("grammars,g", po::value<string>()->required(), "Grammars output path")
- ("threads,t", po::value<int>()->default_value(num_threads_default),
- "Number of parallel extractors")
+ ("threads,t", po::value<int>()->default_value(1), threads_option.c_str())
("frequent", po::value<int>()->default_value(100),
"Number of precomputed frequent patterns")
("super_frequent", po::value<int>()->default_value(10),
@@ -97,7 +97,7 @@ int main(int argc, char** argv) {
}
int num_threads = vm["threads"].as<int>();
- cout << "Grammar extraction will use " << num_threads << " threads." << endl;
+ cerr << "Grammar extraction will use " << num_threads << " threads." << endl;
// Reads the parallel corpus.
Clock::time_point preprocess_start_time = Clock::now();
@@ -118,17 +118,17 @@ int main(int argc, char** argv) {
<< " seconds" << endl;
// Constructs the suffix array for the source data.
- cerr << "Creating source suffix array..." << endl;
start_time = Clock::now();
+ cerr << "Constructing source suffix array..." << endl;
shared_ptr<SuffixArray> source_suffix_array =
make_shared<SuffixArray>(source_data_array);
stop_time = Clock::now();
- cerr << "Creating suffix array took "
+ cerr << "Constructing suffix array took "
<< GetDuration(start_time, stop_time) << " seconds" << endl;
// Reads the alignment.
- cerr << "Reading alignment..." << endl;
start_time = Clock::now();
+ cerr << "Reading alignment..." << endl;
shared_ptr<Alignment> alignment =
make_shared<Alignment>(vm["alignment"].as<string>());
stop_time = Clock::now();
@@ -137,8 +137,8 @@ int main(int argc, char** argv) {
// Constructs an index storing the occurrences in the source data for each
// frequent collocation.
- cerr << "Precomputing collocations..." << endl;
start_time = Clock::now();
+ cerr << "Precomputing collocations..." << endl;
shared_ptr<Precomputation> precomputation = make_shared<Precomputation>(
source_suffix_array,
vm["frequent"].as<int>(),
@@ -154,8 +154,8 @@ int main(int argc, char** argv) {
// Constructs a table storing p(e | f) and p(f | e) for every pair of source
// and target words.
- cerr << "Precomputing conditional probabilities..." << endl;
start_time = Clock::now();
+ cerr << "Precomputing conditional probabilities..." << endl;
shared_ptr<TranslationTable> table = make_shared<TranslationTable>(
source_data_array, target_data_array, alignment);
stop_time = Clock::now();
@@ -229,7 +229,7 @@ int main(int argc, char** argv) {
}
for (size_t i = 0; i < sentences.size(); ++i) {
- cout << "<seg grammar=\"" << GetGrammarFilePath(grammar_path, i) << "\" id=\""
+ cout << "<seg grammar=" << GetGrammarFilePath(grammar_path, i) << " id=\""
<< i << "\"> " << sentences[i] << " </seg> " << suffixes[i] << endl;
}
diff --git a/extractor/suffix_array.cc b/extractor/suffix_array.cc
index 65b2d581..0cf4d1f6 100644
--- a/extractor/suffix_array.cc
+++ b/extractor/suffix_array.cc
@@ -186,20 +186,6 @@ shared_ptr<DataArray> SuffixArray::GetData() const {
return data_array;
}
-void SuffixArray::WriteBinary(const fs::path& filepath) const {
- FILE* file = fopen(filepath.string().c_str(), "w");
- assert(file);
- data_array->WriteBinary(file);
-
- int size = suffix_array.size();
- fwrite(&size, sizeof(int), 1, file);
- fwrite(suffix_array.data(), sizeof(int), size, file);
-
- size = word_start.size();
- fwrite(&size, sizeof(int), 1, file);
- fwrite(word_start.data(), sizeof(int), size, file);
-}
-
PhraseLocation SuffixArray::Lookup(int low, int high, const string& word,
int offset) const {
if (!data_array->HasWord(word)) {
@@ -232,4 +218,10 @@ int SuffixArray::LookupRangeStart(int low, int high, int word_id,
return result;
}
+bool SuffixArray::operator==(const SuffixArray& other) const {
+ return *data_array == *other.data_array &&
+ suffix_array == other.suffix_array &&
+ word_start == other.word_start;
+}
+
} // namespace extractor
diff --git a/extractor/suffix_array.h b/extractor/suffix_array.h
index bf731d79..8ee454ec 100644
--- a/extractor/suffix_array.h
+++ b/extractor/suffix_array.h
@@ -6,6 +6,9 @@
#include <vector>
#include <boost/filesystem.hpp>
+#include <boost/serialization/serialization.hpp>
+#include <boost/serialization/split_member.hpp>
+#include <boost/serialization/vector.hpp>
namespace fs = boost::filesystem;
using namespace std;
@@ -20,6 +23,9 @@ class SuffixArray {
// Creates a suffix array from a data array.
SuffixArray(shared_ptr<DataArray> data_array);
+ // Creates empty suffix array.
+ SuffixArray();
+
virtual ~SuffixArray();
// Returns the size of the suffix array.
@@ -40,10 +46,7 @@ class SuffixArray {
virtual PhraseLocation Lookup(int low, int high, const string& word,
int offset) const;
- void WriteBinary(const fs::path& filepath) const;
-
- protected:
- SuffixArray();
+ bool operator==(const SuffixArray& other) const;
private:
// Constructs the suffix array using the algorithm of Larsson and Sadakane
@@ -65,6 +68,23 @@ class SuffixArray {
// offset value is greater or equal to word_id.
int LookupRangeStart(int low, int high, int word_id, int offset) const;
+ friend class boost::serialization::access;
+
+ template<class Archive> void save(Archive& ar, unsigned int) const {
+ ar << *data_array;
+ ar << suffix_array;
+ ar << word_start;
+ }
+
+ template<class Archive> void load(Archive& ar, unsigned int) {
+ data_array = make_shared<DataArray>();
+ ar >> *data_array;
+ ar >> suffix_array;
+ ar >> word_start;
+ }
+
+ BOOST_SERIALIZATION_SPLIT_MEMBER();
+
shared_ptr<DataArray> data_array;
vector<int> suffix_array;
vector<int> word_start;
diff --git a/extractor/suffix_array_test.cc b/extractor/suffix_array_test.cc
index 8431a16e..ba0dbcc3 100644
--- a/extractor/suffix_array_test.cc
+++ b/extractor/suffix_array_test.cc
@@ -1,13 +1,17 @@
#include <gtest/gtest.h>
+#include <vector>
+
+#include <boost/archive/binary_iarchive.hpp>
+#include <boost/archive/binary_oarchive.hpp>
+
#include "mocks/mock_data_array.h"
#include "phrase_location.h"
#include "suffix_array.h"
-#include <vector>
-
using namespace std;
using namespace ::testing;
+namespace ar = boost::archive;
namespace extractor {
namespace {
@@ -20,30 +24,30 @@ class SuffixArrayTest : public Test {
EXPECT_CALL(*data_array, GetData()).WillRepeatedly(ReturnRef(data));
EXPECT_CALL(*data_array, GetVocabularySize()).WillRepeatedly(Return(7));
EXPECT_CALL(*data_array, GetSize()).WillRepeatedly(Return(13));
- suffix_array = make_shared<SuffixArray>(data_array);
+ suffix_array = SuffixArray(data_array);
}
vector<int> data;
- shared_ptr<SuffixArray> suffix_array;
+ SuffixArray suffix_array;
shared_ptr<MockDataArray> data_array;
};
TEST_F(SuffixArrayTest, TestData) {
- EXPECT_EQ(data_array, suffix_array->GetData());
- EXPECT_EQ(14, suffix_array->GetSize());
+ EXPECT_EQ(data_array, suffix_array.GetData());
+ EXPECT_EQ(14, suffix_array.GetSize());
}
TEST_F(SuffixArrayTest, TestBuildSuffixArray) {
vector<int> expected_suffix_array =
{13, 11, 2, 12, 3, 6, 10, 1, 4, 7, 5, 9, 0, 8};
for (size_t i = 0; i < expected_suffix_array.size(); ++i) {
- EXPECT_EQ(expected_suffix_array[i], suffix_array->GetSuffix(i));
+ EXPECT_EQ(expected_suffix_array[i], suffix_array.GetSuffix(i));
}
}
TEST_F(SuffixArrayTest, TestBuildLCP) {
vector<int> expected_lcp = {-1, 0, 2, 0, 1, 0, 0, 3, 1, 1, 0, 0, 4, 1};
- EXPECT_EQ(expected_lcp, suffix_array->BuildLCPArray());
+ EXPECT_EQ(expected_lcp, suffix_array.BuildLCPArray());
}
TEST_F(SuffixArrayTest, TestLookup) {
@@ -53,25 +57,37 @@ TEST_F(SuffixArrayTest, TestLookup) {
EXPECT_CALL(*data_array, HasWord("word1")).WillRepeatedly(Return(true));
EXPECT_CALL(*data_array, GetWordId("word1")).WillRepeatedly(Return(6));
- EXPECT_EQ(PhraseLocation(11, 14), suffix_array->Lookup(0, 14, "word1", 0));
+ EXPECT_EQ(PhraseLocation(11, 14), suffix_array.Lookup(0, 14, "word1", 0));
EXPECT_CALL(*data_array, HasWord("word2")).WillRepeatedly(Return(false));
- EXPECT_EQ(PhraseLocation(0, 0), suffix_array->Lookup(0, 14, "word2", 0));
+ EXPECT_EQ(PhraseLocation(0, 0), suffix_array.Lookup(0, 14, "word2", 0));
EXPECT_CALL(*data_array, HasWord("word3")).WillRepeatedly(Return(true));
EXPECT_CALL(*data_array, GetWordId("word3")).WillRepeatedly(Return(4));
- EXPECT_EQ(PhraseLocation(11, 13), suffix_array->Lookup(11, 14, "word3", 1));
+ EXPECT_EQ(PhraseLocation(11, 13), suffix_array.Lookup(11, 14, "word3", 1));
EXPECT_CALL(*data_array, HasWord("word4")).WillRepeatedly(Return(true));
EXPECT_CALL(*data_array, GetWordId("word4")).WillRepeatedly(Return(1));
- EXPECT_EQ(PhraseLocation(11, 13), suffix_array->Lookup(11, 13, "word4", 2));
+ EXPECT_EQ(PhraseLocation(11, 13), suffix_array.Lookup(11, 13, "word4", 2));
EXPECT_CALL(*data_array, HasWord("word5")).WillRepeatedly(Return(true));
EXPECT_CALL(*data_array, GetWordId("word5")).WillRepeatedly(Return(2));
- EXPECT_EQ(PhraseLocation(11, 13), suffix_array->Lookup(11, 13, "word5", 3));
+ EXPECT_EQ(PhraseLocation(11, 13), suffix_array.Lookup(11, 13, "word5", 3));
+
+ EXPECT_EQ(PhraseLocation(12, 13), suffix_array.Lookup(11, 13, "word3", 4));
+ EXPECT_EQ(PhraseLocation(11, 11), suffix_array.Lookup(11, 13, "word5", 1));
+}
+
+TEST_F(SuffixArrayTest, TestSerialization) {
+ stringstream stream(ios_base::binary | ios_base::out | ios_base::in);
+ ar::binary_oarchive output_stream(stream, ar::no_header);
+ output_stream << suffix_array;
+
+ SuffixArray suffix_array_copy;
+ ar::binary_iarchive input_stream(stream, ar::no_header);
+ input_stream >> suffix_array_copy;
- EXPECT_EQ(PhraseLocation(12, 13), suffix_array->Lookup(11, 13, "word3", 4));
- EXPECT_EQ(PhraseLocation(11, 11), suffix_array->Lookup(11, 13, "word5", 1));
+ EXPECT_EQ(suffix_array, suffix_array_copy);
}
} // namespace
diff --git a/extractor/translation_table.cc b/extractor/translation_table.cc
index adb59cb5..1b1ba112 100644
--- a/extractor/translation_table.cc
+++ b/extractor/translation_table.cc
@@ -97,7 +97,12 @@ double TranslationTable::GetTargetGivenSourceScore(
int source_id = source_data_array->GetWordId(source_word);
int target_id = target_data_array->GetWordId(target_word);
- return translation_probabilities[make_pair(source_id, target_id)].first;
+ auto entry = make_pair(source_id, target_id);
+ auto it = translation_probabilities.find(entry);
+ if (it == translation_probabilities.end()) {
+ return 0;
+ }
+ return it->second.first;
}
double TranslationTable::GetSourceGivenTargetScore(
@@ -109,18 +114,18 @@ double TranslationTable::GetSourceGivenTargetScore(
int source_id = source_data_array->GetWordId(source_word);
int target_id = target_data_array->GetWordId(target_word);
- return translation_probabilities[make_pair(source_id, target_id)].second;
+ auto entry = make_pair(source_id, target_id);
+ auto it = translation_probabilities.find(entry);
+ if (it == translation_probabilities.end()) {
+ return 0;
+ }
+ return it->second.second;
}
-void TranslationTable::WriteBinary(const fs::path& filepath) const {
- FILE* file = fopen(filepath.string().c_str(), "w");
-
- int size = translation_probabilities.size();
- fwrite(&size, sizeof(int), 1, file);
- for (auto entry: translation_probabilities) {
- fwrite(&entry.first, sizeof(entry.first), 1, file);
- fwrite(&entry.second, sizeof(entry.second), 1, file);
- }
+bool TranslationTable::operator==(const TranslationTable& other) const {
+ return *source_data_array == *other.source_data_array &&
+ *target_data_array == *other.target_data_array &&
+ translation_probabilities == other.translation_probabilities;
}
} // namespace extractor
diff --git a/extractor/translation_table.h b/extractor/translation_table.h
index ed43ad72..2a37bab7 100644
--- a/extractor/translation_table.h
+++ b/extractor/translation_table.h
@@ -7,6 +7,9 @@
#include <boost/filesystem.hpp>
#include <boost/functional/hash.hpp>
+#include <boost/serialization/serialization.hpp>
+#include <boost/serialization/split_member.hpp>
+#include <boost/serialization/utility.hpp>
using namespace std;
namespace fs = boost::filesystem;
@@ -23,11 +26,16 @@ class DataArray;
*/
class TranslationTable {
public:
+ // Constructs a translation table from source data, target data and the
+ // corresponding alignment.
TranslationTable(
shared_ptr<DataArray> source_data_array,
shared_ptr<DataArray> target_data_array,
shared_ptr<Alignment> alignment);
+ // Creates empty translation table.
+ TranslationTable();
+
virtual ~TranslationTable();
// Returns p(e | f).
@@ -38,10 +46,7 @@ class TranslationTable {
virtual double GetSourceGivenTargetScore(const string& source_word,
const string& target_word);
- void WriteBinary(const fs::path& filepath) const;
-
- protected:
- TranslationTable();
+ bool operator==(const TranslationTable& other) const;
private:
// Increment links count for the given (f, e) word pair.
@@ -52,6 +57,35 @@ class TranslationTable {
int source_word_id,
int target_word_id) const;
+ friend class boost::serialization::access;
+
+ template<class Archive> void save(Archive& ar, unsigned int) const {
+ ar << *source_data_array << *target_data_array;
+
+ int num_entries = translation_probabilities.size();
+ ar << num_entries;
+ for (auto entry: translation_probabilities) {
+ ar << entry;
+ }
+ }
+
+ template<class Archive> void load(Archive& ar, unsigned int) {
+ source_data_array = make_shared<DataArray>();
+ ar >> *source_data_array;
+ target_data_array = make_shared<DataArray>();
+ ar >> *target_data_array;
+
+ int num_entries;
+ ar >> num_entries;
+ for (size_t i = 0; i < num_entries; ++i) {
+ pair<pair<int, int>, pair<double, double>> entry;
+ ar >> entry;
+ translation_probabilities.insert(entry);
+ }
+ }
+
+ BOOST_SERIALIZATION_SPLIT_MEMBER();
+
shared_ptr<DataArray> source_data_array;
shared_ptr<DataArray> target_data_array;
unordered_map<pair<int, int>, pair<double, double>, PairHash>
diff --git a/extractor/translation_table_test.cc b/extractor/translation_table_test.cc
index d14f2f89..606777bd 100644
--- a/extractor/translation_table_test.cc
+++ b/extractor/translation_table_test.cc
@@ -1,83 +1,106 @@
#include <gtest/gtest.h>
#include <memory>
+#include <sstream>
#include <string>
#include <vector>
+#include <boost/archive/binary_iarchive.hpp>
+#include <boost/archive/binary_oarchive.hpp>
+
#include "mocks/mock_alignment.h"
#include "mocks/mock_data_array.h"
#include "translation_table.h"
using namespace std;
using namespace ::testing;
+namespace ar = boost::archive;
namespace extractor {
namespace {
-TEST(TranslationTableTest, TestScores) {
- vector<string> words = {"a", "b", "c"};
-
- vector<int> source_data = {2, 3, 2, 3, 4, 0, 2, 3, 6, 0, 2, 3, 6, 0};
- vector<int> source_sentence_start = {0, 6, 10, 14};
- shared_ptr<MockDataArray> source_data_array = make_shared<MockDataArray>();
- EXPECT_CALL(*source_data_array, GetData())
- .WillRepeatedly(ReturnRef(source_data));
- EXPECT_CALL(*source_data_array, GetNumSentences())
- .WillRepeatedly(Return(3));
- for (size_t i = 0; i < source_sentence_start.size(); ++i) {
- EXPECT_CALL(*source_data_array, GetSentenceStart(i))
- .WillRepeatedly(Return(source_sentence_start[i]));
- }
- for (size_t i = 0; i < words.size(); ++i) {
- EXPECT_CALL(*source_data_array, HasWord(words[i]))
- .WillRepeatedly(Return(true));
- EXPECT_CALL(*source_data_array, GetWordId(words[i]))
- .WillRepeatedly(Return(i + 2));
- }
- EXPECT_CALL(*source_data_array, HasWord("d"))
- .WillRepeatedly(Return(false));
-
- vector<int> target_data = {2, 3, 2, 3, 4, 5, 0, 3, 6, 0, 2, 7, 0};
- vector<int> target_sentence_start = {0, 7, 10, 13};
- shared_ptr<MockDataArray> target_data_array = make_shared<MockDataArray>();
- EXPECT_CALL(*target_data_array, GetData())
- .WillRepeatedly(ReturnRef(target_data));
- for (size_t i = 0; i < target_sentence_start.size(); ++i) {
- EXPECT_CALL(*target_data_array, GetSentenceStart(i))
- .WillRepeatedly(Return(target_sentence_start[i]));
- }
- for (size_t i = 0; i < words.size(); ++i) {
- EXPECT_CALL(*target_data_array, HasWord(words[i]))
- .WillRepeatedly(Return(true));
- EXPECT_CALL(*target_data_array, GetWordId(words[i]))
- .WillRepeatedly(Return(i + 2));
+class TranslationTableTest : public Test {
+ protected:
+ virtual void SetUp() {
+ vector<string> words = {"a", "b", "c"};
+
+ vector<int> source_data = {2, 3, 2, 3, 4, 0, 2, 3, 6, 0, 2, 3, 6, 0};
+ vector<int> source_sentence_start = {0, 6, 10, 14};
+ shared_ptr<MockDataArray> source_data_array = make_shared<MockDataArray>();
+ EXPECT_CALL(*source_data_array, GetData())
+ .WillRepeatedly(ReturnRef(source_data));
+ EXPECT_CALL(*source_data_array, GetNumSentences())
+ .WillRepeatedly(Return(3));
+ for (size_t i = 0; i < source_sentence_start.size(); ++i) {
+ EXPECT_CALL(*source_data_array, GetSentenceStart(i))
+ .WillRepeatedly(Return(source_sentence_start[i]));
+ }
+ for (size_t i = 0; i < words.size(); ++i) {
+ EXPECT_CALL(*source_data_array, HasWord(words[i]))
+ .WillRepeatedly(Return(true));
+ EXPECT_CALL(*source_data_array, GetWordId(words[i]))
+ .WillRepeatedly(Return(i + 2));
+ }
+ EXPECT_CALL(*source_data_array, HasWord("d"))
+ .WillRepeatedly(Return(false));
+
+ vector<int> target_data = {2, 3, 2, 3, 4, 5, 0, 3, 6, 0, 2, 7, 0};
+ vector<int> target_sentence_start = {0, 7, 10, 13};
+ shared_ptr<MockDataArray> target_data_array = make_shared<MockDataArray>();
+ EXPECT_CALL(*target_data_array, GetData())
+ .WillRepeatedly(ReturnRef(target_data));
+ for (size_t i = 0; i < target_sentence_start.size(); ++i) {
+ EXPECT_CALL(*target_data_array, GetSentenceStart(i))
+ .WillRepeatedly(Return(target_sentence_start[i]));
+ }
+ for (size_t i = 0; i < words.size(); ++i) {
+ EXPECT_CALL(*target_data_array, HasWord(words[i]))
+ .WillRepeatedly(Return(true));
+ EXPECT_CALL(*target_data_array, GetWordId(words[i]))
+ .WillRepeatedly(Return(i + 2));
+ }
+ EXPECT_CALL(*target_data_array, HasWord("d"))
+ .WillRepeatedly(Return(false));
+
+ vector<pair<int, int>> links1 = {
+ make_pair(0, 0), make_pair(1, 1), make_pair(2, 2), make_pair(3, 3),
+ make_pair(4, 4), make_pair(4, 5)
+ };
+ vector<pair<int, int>> links2 = {make_pair(1, 0), make_pair(2, 1)};
+ vector<pair<int, int>> links3 = {make_pair(0, 0), make_pair(2, 1)};
+ shared_ptr<MockAlignment> alignment = make_shared<MockAlignment>();
+ EXPECT_CALL(*alignment, GetLinks(0)).WillRepeatedly(Return(links1));
+ EXPECT_CALL(*alignment, GetLinks(1)).WillRepeatedly(Return(links2));
+ EXPECT_CALL(*alignment, GetLinks(2)).WillRepeatedly(Return(links3));
+
+ table = TranslationTable(source_data_array, target_data_array, alignment);
}
- EXPECT_CALL(*target_data_array, HasWord("d"))
- .WillRepeatedly(Return(false));
-
- vector<pair<int, int>> links1 = {
- make_pair(0, 0), make_pair(1, 1), make_pair(2, 2), make_pair(3, 3),
- make_pair(4, 4), make_pair(4, 5)
- };
- vector<pair<int, int>> links2 = {make_pair(1, 0), make_pair(2, 1)};
- vector<pair<int, int>> links3 = {make_pair(0, 0), make_pair(2, 1)};
- shared_ptr<MockAlignment> alignment = make_shared<MockAlignment>();
- EXPECT_CALL(*alignment, GetLinks(0)).WillRepeatedly(Return(links1));
- EXPECT_CALL(*alignment, GetLinks(1)).WillRepeatedly(Return(links2));
- EXPECT_CALL(*alignment, GetLinks(2)).WillRepeatedly(Return(links3));
-
- shared_ptr<TranslationTable> table = make_shared<TranslationTable>(
- source_data_array, target_data_array, alignment);
-
- EXPECT_EQ(0.75, table->GetTargetGivenSourceScore("a", "a"));
- EXPECT_EQ(0, table->GetTargetGivenSourceScore("a", "b"));
- EXPECT_EQ(0.5, table->GetTargetGivenSourceScore("c", "c"));
- EXPECT_EQ(-1, table->GetTargetGivenSourceScore("c", "d"));
-
- EXPECT_EQ(1, table->GetSourceGivenTargetScore("a", "a"));
- EXPECT_EQ(0, table->GetSourceGivenTargetScore("a", "b"));
- EXPECT_EQ(1, table->GetSourceGivenTargetScore("c", "c"));
- EXPECT_EQ(-1, table->GetSourceGivenTargetScore("c", "d"));
+
+ TranslationTable table;
+};
+
+TEST_F(TranslationTableTest, TestScores) {
+ EXPECT_EQ(0.75, table.GetTargetGivenSourceScore("a", "a"));
+ EXPECT_EQ(0, table.GetTargetGivenSourceScore("a", "b"));
+ EXPECT_EQ(0.5, table.GetTargetGivenSourceScore("c", "c"));
+ EXPECT_EQ(-1, table.GetTargetGivenSourceScore("c", "d"));
+
+ EXPECT_EQ(1, table.GetSourceGivenTargetScore("a", "a"));
+ EXPECT_EQ(0, table.GetSourceGivenTargetScore("a", "b"));
+ EXPECT_EQ(1, table.GetSourceGivenTargetScore("c", "c"));
+ EXPECT_EQ(-1, table.GetSourceGivenTargetScore("c", "d"));
+}
+
+TEST_F(TranslationTableTest, TestSerialization) {
+ stringstream stream(ios_base::binary | ios_base::out | ios_base::in);
+ ar::binary_oarchive output_stream(stream, ar::no_header);
+ output_stream << table;
+
+ TranslationTable table_copy;
+ ar::binary_iarchive input_stream(stream, ar::no_header);
+ input_stream >> table_copy;
+
+ EXPECT_EQ(table, table_copy);
}
} // namespace