diff options
-rw-r--r-- | extractor/alignment.cc | 11 | ||||
-rw-r--r-- | extractor/alignment.h | 19 | ||||
-rw-r--r-- | extractor/alignment_test.cc | 26 | ||||
-rw-r--r-- | extractor/compile.cc | 82 | ||||
-rw-r--r-- | extractor/data_array.cc | 33 | ||||
-rw-r--r-- | extractor/data_array.h | 38 | ||||
-rw-r--r-- | extractor/data_array_test.cc | 91 | ||||
-rw-r--r-- | extractor/precomputation.cc | 21 | ||||
-rw-r--r-- | extractor/precomputation.h | 35 | ||||
-rw-r--r-- | extractor/precomputation_test.cc | 21 | ||||
-rw-r--r-- | extractor/run_extractor.cc | 10 | ||||
-rw-r--r-- | extractor/suffix_array.cc | 20 | ||||
-rw-r--r-- | extractor/suffix_array.h | 28 | ||||
-rw-r--r-- | extractor/suffix_array_test.cc | 46 | ||||
-rw-r--r-- | extractor/translation_table.cc | 13 | ||||
-rw-r--r-- | extractor/translation_table.h | 42 | ||||
-rw-r--r-- | extractor/translation_table_test.cc | 149 |
17 files changed, 451 insertions, 234 deletions
diff --git a/extractor/alignment.cc b/extractor/alignment.cc index 68bfde1a..2278c825 100644 --- a/extractor/alignment.cc +++ b/extractor/alignment.cc @@ -39,15 +39,8 @@ vector<pair<int, int>> Alignment::GetLinks(int sentence_index) const { return alignments[sentence_index]; } -void Alignment::WriteBinary(const fs::path& filepath) { - FILE* file = fopen(filepath.string().c_str(), "w"); - int size = alignments.size(); - fwrite(&size, sizeof(int), 1, file); - for (vector<pair<int, int>> alignment: alignments) { - size = alignment.size(); - fwrite(&size, sizeof(int), 1, file); - fwrite(alignment.data(), sizeof(pair<int, int>), size, file); - } +bool Alignment::operator==(const Alignment& other) const { + return alignments == other.alignments; } } // namespace extractor diff --git a/extractor/alignment.h b/extractor/alignment.h index 4596f92b..dc5a8b55 100644 --- a/extractor/alignment.h +++ b/extractor/alignment.h @@ -5,6 +5,10 @@ #include <vector> #include <boost/filesystem.hpp> +#include <boost/serialization/serialization.hpp> +#include <boost/serialization/split_member.hpp> +#include <boost/serialization/utility.hpp> +#include <boost/serialization/vector.hpp> namespace fs = boost::filesystem; using namespace std; @@ -19,18 +23,23 @@ class Alignment { // Reads alignment from text file. Alignment(const string& filename); + // Creates empty alignment. + Alignment(); + // Returns the alignment for a given sentence. virtual vector<pair<int, int>> GetLinks(int sentence_index) const; - // Writes alignment to file in binary format. - void WriteBinary(const fs::path& filepath); - virtual ~Alignment(); - protected: - Alignment(); + bool operator==(const Alignment& alignment) const; private: + friend class boost::serialization::access; + + template<class Archive> void serialize(Archive& ar, unsigned int) { + ar & alignments; + } + vector<vector<pair<int, int>>> alignments; }; diff --git a/extractor/alignment_test.cc b/extractor/alignment_test.cc index 43c37ebd..1b8ff531 100644 --- a/extractor/alignment_test.cc +++ b/extractor/alignment_test.cc @@ -1,12 +1,16 @@ #include <gtest/gtest.h> -#include <memory> +#include <sstream> #include <string> +#include <boost/archive/binary_iarchive.hpp> +#include <boost/archive/binary_oarchive.hpp> + #include "alignment.h" using namespace std; using namespace ::testing; +namespace ar = boost::archive; namespace extractor { namespace { @@ -14,19 +18,31 @@ namespace { class AlignmentTest : public Test { protected: virtual void SetUp() { - alignment = make_shared<Alignment>("sample_alignment.txt"); + alignment = Alignment("sample_alignment.txt"); } - shared_ptr<Alignment> alignment; + Alignment alignment; }; TEST_F(AlignmentTest, TestGetLinks) { vector<pair<int, int>> expected_links = { make_pair(0, 0), make_pair(1, 1), make_pair(2, 2) }; - EXPECT_EQ(expected_links, alignment->GetLinks(0)); + EXPECT_EQ(expected_links, alignment.GetLinks(0)); expected_links = {make_pair(1, 0), make_pair(2, 1)}; - EXPECT_EQ(expected_links, alignment->GetLinks(1)); + EXPECT_EQ(expected_links, alignment.GetLinks(1)); +} + +TEST_F(AlignmentTest, TestSerialization) { + stringstream stream(ios_base::binary | ios_base::out | ios_base::in); + ar::binary_oarchive output_stream(stream, ar::no_header); + output_stream << alignment; + + Alignment alignment_copy; + ar::binary_iarchive input_stream(stream, ar::no_header); + input_stream >> alignment_copy; + + EXPECT_EQ(alignment, alignment_copy); } } // namespace diff --git a/extractor/compile.cc b/extractor/compile.cc index a9ae2cef..65fdd509 100644 --- a/extractor/compile.cc +++ b/extractor/compile.cc @@ -1,6 +1,8 @@ +#include <fstream> #include <iostream> #include <string> +#include <boost/archive/binary_oarchive.hpp> #include <boost/filesystem.hpp> #include <boost/program_options.hpp> #include <boost/program_options/variables_map.hpp> @@ -9,8 +11,10 @@ #include "data_array.h" #include "precomputation.h" #include "suffix_array.h" +#include "time_util.h" #include "translation_table.h" +namespace ar = boost::archive; namespace fs = boost::filesystem; namespace po = boost::program_options; using namespace std; @@ -58,11 +62,14 @@ int main(int argc, char** argv) { return 1; } - fs::path output_dir(vm["output"].as<string>().c_str()); + fs::path output_dir(vm["output"].as<string>()); if (!fs::exists(output_dir)) { fs::create_directory(output_dir); } + // Reading source and target data. + Clock::time_point start_time = Clock::now(); + cerr << "Reading source and target data..." << endl; shared_ptr<DataArray> source_data_array, target_data_array; if (vm.count("bitext")) { source_data_array = make_shared<DataArray>( @@ -73,15 +80,53 @@ int main(int argc, char** argv) { source_data_array = make_shared<DataArray>(vm["source"].as<string>()); target_data_array = make_shared<DataArray>(vm["target"].as<string>()); } + + Clock::time_point start_write = Clock::now(); + ofstream target_fstream((output_dir / fs::path("target.bin")).string()); + ar::binary_oarchive target_stream(target_fstream); + target_stream << *target_data_array; + Clock::time_point stop_write = Clock::now(); + double write_duration = GetDuration(start_write, stop_write); + + Clock::time_point stop_time = Clock::now(); + cerr << "Reading data took " << GetDuration(start_time, stop_time) + << " seconds" << endl; + + // Constructing and compiling the suffix array. + start_time = Clock::now(); + cerr << "Constructing source suffix array..." << endl; shared_ptr<SuffixArray> source_suffix_array = make_shared<SuffixArray>(source_data_array); - source_suffix_array->WriteBinary(output_dir / fs::path("f.bin")); - target_data_array->WriteBinary(output_dir / fs::path("e.bin")); + start_write = Clock::now(); + ofstream source_fstream((output_dir / fs::path("source.bin")).string()); + ar::binary_oarchive output_stream(source_fstream); + output_stream << *source_suffix_array; + stop_write = Clock::now(); + write_duration += GetDuration(start_write, stop_write); + + cerr << "Constructing suffix array took " + << GetDuration(start_time, stop_time) << " seconds" << endl; + + // Reading alignment. + start_time = Clock::now(); + cerr << "Reading alignment..." << endl; shared_ptr<Alignment> alignment = make_shared<Alignment>(vm["alignment"].as<string>()); - alignment->WriteBinary(output_dir / fs::path("a.bin")); + start_write = Clock::now(); + ofstream alignment_fstream((output_dir / fs::path("alignment.bin")).string()); + ar::binary_oarchive alignment_stream(alignment_fstream); + alignment_stream << *alignment; + stop_write = Clock::now(); + write_duration += GetDuration(start_write, stop_write); + + stop_time = Clock::now(); + cerr << "Reading alignment took " + << GetDuration(start_time, stop_time) << " seconds" << endl; + + start_time = Clock::now(); + cerr << "Precomputing collocations..." << endl; Precomputation precomputation( source_suffix_array, vm["frequent"].as<int>(), @@ -91,10 +136,35 @@ int main(int argc, char** argv) { vm["min_gap_size"].as<int>(), vm["max_phrase_len"].as<int>(), vm["min_frequency"].as<int>()); - precomputation.WriteBinary(output_dir / fs::path("precompute.bin")); + start_write = Clock::now(); + ofstream precomp_fstream((output_dir / fs::path("precomp.bin")).string()); + ar::binary_oarchive precomp_stream(precomp_fstream); + precomp_stream << precomputation; + stop_write = Clock::now(); + write_duration += GetDuration(start_write, stop_write); + + stop_time = Clock::now(); + cerr << "Precomputing collocations took " + << GetDuration(start_time, stop_time) << " seconds" << endl; + + start_time = Clock::now(); + cerr << "Precomputing conditional probabilities..." << endl; TranslationTable table(source_data_array, target_data_array, alignment); - table.WriteBinary(output_dir / fs::path("lex.bin")); + + start_write = Clock::now(); + ofstream table_fstream((output_dir / fs::path("bilex.bin")).string()); + ar::binary_oarchive table_stream(table_fstream); + table_stream << table; + stop_write = Clock::now(); + write_duration += GetDuration(start_write, stop_write); + + stop_time = Clock::now(); + cerr << "Precomputing conditional probabilities took " + << GetDuration(start_time, stop_time) << " seconds" << endl; + + cerr << "Total time spent writing: " << write_duration + << " seconds" << endl; return 0; } diff --git a/extractor/data_array.cc b/extractor/data_array.cc index 203fe219..2e4bdafb 100644 --- a/extractor/data_array.cc +++ b/extractor/data_array.cc @@ -118,33 +118,6 @@ int DataArray::GetSentenceId(int position) const { return sentence_id[position]; } -void DataArray::WriteBinary(const fs::path& filepath) const { - std::cerr << "File: " << filepath.string() << std::endl; - WriteBinary(fopen(filepath.string().c_str(), "w")); -} - -void DataArray::WriteBinary(FILE* file) const { - int size = id2word.size(); - fwrite(&size, sizeof(int), 1, file); - for (string word: id2word) { - size = word.size(); - fwrite(&size, sizeof(int), 1, file); - fwrite(word.data(), sizeof(char), size, file); - } - - size = data.size(); - fwrite(&size, sizeof(int), 1, file); - fwrite(data.data(), sizeof(int), size, file); - - size = sentence_id.size(); - fwrite(&size, sizeof(int), 1, file); - fwrite(sentence_id.data(), sizeof(int), size, file); - - size = sentence_start.size(); - fwrite(&size, sizeof(int), 1, file); - fwrite(sentence_start.data(), sizeof(int), 1, file); -} - bool DataArray::HasWord(const string& word) const { return word2id.count(word); } @@ -158,4 +131,10 @@ string DataArray::GetWord(int word_id) const { return id2word[word_id]; } +bool DataArray::operator==(const DataArray& other) const { + return word2id == other.word2id && id2word == other.id2word && + data == other.data && sentence_start == other.sentence_start && + sentence_id == other.sentence_id; +} + } // namespace extractor diff --git a/extractor/data_array.h b/extractor/data_array.h index 978a6931..2be6a09c 100644 --- a/extractor/data_array.h +++ b/extractor/data_array.h @@ -6,6 +6,10 @@ #include <vector> #include <boost/filesystem.hpp> +#include <boost/serialization/serialization.hpp> +#include <boost/serialization/split_member.hpp> +#include <boost/serialization/string.hpp> +#include <boost/serialization/vector.hpp> namespace fs = boost::filesystem; using namespace std; @@ -43,6 +47,9 @@ class DataArray { // Reads data array from bitext file where the sentences are separated by |||. DataArray(const string& filename, const Side& side); + // Creates empty data array. + DataArray(); + virtual ~DataArray(); // Returns a vector containing the word ids. @@ -82,14 +89,7 @@ class DataArray { // Returns the number of the sentence containing the given position. virtual int GetSentenceId(int position) const; - // Writes data array to file in binary format. - void WriteBinary(const fs::path& filepath) const; - - // Writes data array to file in binary format. - void WriteBinary(FILE* file) const; - - protected: - DataArray(); + bool operator==(const DataArray& other) const; private: // Sets up specific constants. @@ -98,6 +98,28 @@ class DataArray { // Constructs the data array. void CreateDataArray(const vector<string>& lines); + friend class boost::serialization::access; + + template<class Archive> void save(Archive& ar, unsigned int) const { + ar << id2word; + ar << data; + ar << sentence_id; + ar << sentence_start; + } + + template<class Archive> void load(Archive& ar, unsigned int) { + ar >> id2word; + for (size_t i = 0; i < id2word.size(); ++i) { + word2id[id2word[i]] = i; + } + + ar >> data; + ar >> sentence_id; + ar >> sentence_start; + } + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + unordered_map<string, int> word2id; vector<string> id2word; vector<int> data; diff --git a/extractor/data_array_test.cc b/extractor/data_array_test.cc index 71175fda..6c329e34 100644 --- a/extractor/data_array_test.cc +++ b/extractor/data_array_test.cc @@ -1,8 +1,11 @@ #include <gtest/gtest.h> #include <memory> +#include <sstream> #include <string> +#include <boost/archive/binary_iarchive.hpp> +#include <boost/archive/binary_oarchive.hpp> #include <boost/filesystem.hpp> #include "data_array.h" @@ -10,6 +13,7 @@ using namespace std; using namespace ::testing; namespace fs = boost::filesystem; +namespace ar = boost::archive; namespace extractor { namespace { @@ -18,12 +22,12 @@ class DataArrayTest : public Test { protected: virtual void SetUp() { string sample_test_file("sample_bitext.txt"); - source_data = make_shared<DataArray>(sample_test_file, SOURCE); - target_data = make_shared<DataArray>(sample_test_file, TARGET); + source_data = DataArray(sample_test_file, SOURCE); + target_data = DataArray(sample_test_file, TARGET); } - shared_ptr<DataArray> source_data; - shared_ptr<DataArray> target_data; + DataArray source_data; + DataArray target_data; }; TEST_F(DataArrayTest, TestGetData) { @@ -32,11 +36,11 @@ TEST_F(DataArrayTest, TestGetData) { "ana", "are", "mere", ".", "__END_OF_LINE__", "ana", "bea", "mult", "lapte", ".", "__END_OF_LINE__" }; - EXPECT_EQ(expected_source_data, source_data->GetData()); - EXPECT_EQ(expected_source_data.size(), source_data->GetSize()); + EXPECT_EQ(expected_source_data, source_data.GetData()); + EXPECT_EQ(expected_source_data.size(), source_data.GetSize()); for (size_t i = 0; i < expected_source_data.size(); ++i) { - EXPECT_EQ(expected_source_data[i], source_data->AtIndex(i)); - EXPECT_EQ(expected_source_words[i], source_data->GetWordAtIndex(i)); + EXPECT_EQ(expected_source_data[i], source_data.AtIndex(i)); + EXPECT_EQ(expected_source_words[i], source_data.GetWordAtIndex(i)); } vector<int> expected_target_data = {2, 3, 4, 5, 1, 2, 6, 7, 8, 9, 10, 5, 1}; @@ -44,55 +48,68 @@ TEST_F(DataArrayTest, TestGetData) { "anna", "has", "apples", ".", "__END_OF_LINE__", "anna", "drinks", "a", "lot", "of", "milk", ".", "__END_OF_LINE__" }; - EXPECT_EQ(expected_target_data, target_data->GetData()); - EXPECT_EQ(expected_target_data.size(), target_data->GetSize()); + EXPECT_EQ(expected_target_data, target_data.GetData()); + EXPECT_EQ(expected_target_data.size(), target_data.GetSize()); for (size_t i = 0; i < expected_target_data.size(); ++i) { - EXPECT_EQ(expected_target_data[i], target_data->AtIndex(i)); - EXPECT_EQ(expected_target_words[i], target_data->GetWordAtIndex(i)); + EXPECT_EQ(expected_target_data[i], target_data.AtIndex(i)); + EXPECT_EQ(expected_target_words[i], target_data.GetWordAtIndex(i)); } } TEST_F(DataArrayTest, TestVocabulary) { - EXPECT_EQ(9, source_data->GetVocabularySize()); - EXPECT_TRUE(source_data->HasWord("mere")); - EXPECT_EQ(4, source_data->GetWordId("mere")); - EXPECT_EQ("mere", source_data->GetWord(4)); - EXPECT_FALSE(source_data->HasWord("banane")); - - EXPECT_EQ(11, target_data->GetVocabularySize()); - EXPECT_TRUE(target_data->HasWord("apples")); - EXPECT_EQ(4, target_data->GetWordId("apples")); - EXPECT_EQ("apples", target_data->GetWord(4)); - EXPECT_FALSE(target_data->HasWord("bananas")); + EXPECT_EQ(9, source_data.GetVocabularySize()); + EXPECT_TRUE(source_data.HasWord("mere")); + EXPECT_EQ(4, source_data.GetWordId("mere")); + EXPECT_EQ("mere", source_data.GetWord(4)); + EXPECT_FALSE(source_data.HasWord("banane")); + + EXPECT_EQ(11, target_data.GetVocabularySize()); + EXPECT_TRUE(target_data.HasWord("apples")); + EXPECT_EQ(4, target_data.GetWordId("apples")); + EXPECT_EQ("apples", target_data.GetWord(4)); + EXPECT_FALSE(target_data.HasWord("bananas")); } TEST_F(DataArrayTest, TestSentenceData) { - EXPECT_EQ(2, source_data->GetNumSentences()); - EXPECT_EQ(0, source_data->GetSentenceStart(0)); - EXPECT_EQ(5, source_data->GetSentenceStart(1)); - EXPECT_EQ(11, source_data->GetSentenceStart(2)); + EXPECT_EQ(2, source_data.GetNumSentences()); + EXPECT_EQ(0, source_data.GetSentenceStart(0)); + EXPECT_EQ(5, source_data.GetSentenceStart(1)); + EXPECT_EQ(11, source_data.GetSentenceStart(2)); - EXPECT_EQ(4, source_data->GetSentenceLength(0)); - EXPECT_EQ(5, source_data->GetSentenceLength(1)); + EXPECT_EQ(4, source_data.GetSentenceLength(0)); + EXPECT_EQ(5, source_data.GetSentenceLength(1)); vector<int> expected_source_ids = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1}; for (size_t i = 0; i < expected_source_ids.size(); ++i) { - EXPECT_EQ(expected_source_ids[i], source_data->GetSentenceId(i)); + EXPECT_EQ(expected_source_ids[i], source_data.GetSentenceId(i)); } - EXPECT_EQ(2, target_data->GetNumSentences()); - EXPECT_EQ(0, target_data->GetSentenceStart(0)); - EXPECT_EQ(5, target_data->GetSentenceStart(1)); - EXPECT_EQ(13, target_data->GetSentenceStart(2)); + EXPECT_EQ(2, target_data.GetNumSentences()); + EXPECT_EQ(0, target_data.GetSentenceStart(0)); + EXPECT_EQ(5, target_data.GetSentenceStart(1)); + EXPECT_EQ(13, target_data.GetSentenceStart(2)); - EXPECT_EQ(4, target_data->GetSentenceLength(0)); - EXPECT_EQ(7, target_data->GetSentenceLength(1)); + EXPECT_EQ(4, target_data.GetSentenceLength(0)); + EXPECT_EQ(7, target_data.GetSentenceLength(1)); vector<int> expected_target_ids = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1}; for (size_t i = 0; i < expected_target_ids.size(); ++i) { - EXPECT_EQ(expected_target_ids[i], target_data->GetSentenceId(i)); + EXPECT_EQ(expected_target_ids[i], target_data.GetSentenceId(i)); } } +TEST_F(DataArrayTest, TestSerialization) { + stringstream stream(ios_base::binary | ios_base::out | ios_base::in); + ar::binary_oarchive output_stream(stream, ar::no_header); + output_stream << source_data << target_data; + + DataArray source_copy, target_copy; + ar::binary_iarchive input_stream(stream, ar::no_header); + input_stream >> source_copy >> target_copy; + + EXPECT_EQ(source_data, source_copy); + EXPECT_EQ(target_data, target_copy); +} + } // namespace } // namespace extractor diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc index ee4ba42c..3b8aed69 100644 --- a/extractor/precomputation.cc +++ b/extractor/precomputation.cc @@ -165,25 +165,12 @@ void Precomputation::AddStartPositions( positions.push_back(pos3); } -void Precomputation::WriteBinary(const fs::path& filepath) const { - FILE* file = fopen(filepath.string().c_str(), "w"); - - // TODO(pauldb): Refactor this code. - int size = collocations.size(); - fwrite(&size, sizeof(int), 1, file); - for (auto entry: collocations) { - size = entry.first.size(); - fwrite(&size, sizeof(int), 1, file); - fwrite(entry.first.data(), sizeof(int), size, file); - - size = entry.second.size(); - fwrite(&size, sizeof(int), 1, file); - fwrite(entry.second.data(), sizeof(int), size, file); - } -} - const Index& Precomputation::GetCollocations() const { return collocations; } +bool Precomputation::operator==(const Precomputation& other) const { + return collocations == other.collocations; +} + } // namespace extractor diff --git a/extractor/precomputation.h b/extractor/precomputation.h index 3e792ac7..9f0c9424 100644 --- a/extractor/precomputation.h +++ b/extractor/precomputation.h @@ -9,6 +9,9 @@ #include <boost/filesystem.hpp> #include <boost/functional/hash.hpp> +#include <boost/serialization/serialization.hpp> +#include <boost/serialization/utility.hpp> +#include <boost/serialization/vector.hpp> namespace fs = boost::filesystem; using namespace std; @@ -39,19 +42,19 @@ class Precomputation { int max_rule_symbols, int min_gap_size, int max_frequent_phrase_len, int min_frequency); - virtual ~Precomputation(); + // Creates empty precomputation data structure. + Precomputation(); - void WriteBinary(const fs::path& filepath) const; + virtual ~Precomputation(); // Returns a reference to the index. virtual const Index& GetCollocations() const; + bool operator==(const Precomputation& other) const; + static int FIRST_NONTERMINAL; static int SECOND_NONTERMINAL; - protected: - Precomputation(); - private: // Finds the most frequent contiguous collocations. vector<vector<int>> FindMostFrequentPatterns( @@ -72,6 +75,28 @@ class Precomputation { // Adds an occurrence of a ternary collocation. void AddStartPositions(vector<int>& positions, int pos1, int pos2, int pos3); + friend class boost::serialization::access; + + template<class Archive> void save(Archive& ar, unsigned int) const { + int num_entries = collocations.size(); + ar << num_entries; + for (pair<vector<int>, vector<int>> entry: collocations) { + ar << entry; + } + } + + template<class Archive> void load(Archive& ar, unsigned int) { + int num_entries; + ar >> num_entries; + for (size_t i = 0; i < num_entries; ++i) { + pair<vector<int>, vector<int>> entry; + ar >> entry; + collocations.insert(entry); + } + } + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + Index collocations; }; diff --git a/extractor/precomputation_test.cc b/extractor/precomputation_test.cc index 363febb7..e81ece5d 100644 --- a/extractor/precomputation_test.cc +++ b/extractor/precomputation_test.cc @@ -1,14 +1,19 @@ #include <gtest/gtest.h> #include <memory> +#include <sstream> #include <vector> +#include <boost/archive/text_iarchive.hpp> +#include <boost/archive/text_oarchive.hpp> + #include "mocks/mock_data_array.h" #include "mocks/mock_suffix_array.h" #include "precomputation.h" using namespace std; using namespace ::testing; +namespace ar = boost::archive; namespace extractor { namespace { @@ -29,15 +34,17 @@ class PrecomputationTest : public Test { GetSuffix(i)).WillRepeatedly(Return(suffixes[i])); } EXPECT_CALL(*suffix_array, BuildLCPArray()).WillRepeatedly(Return(lcp)); + + precomputation = Precomputation(suffix_array, 3, 3, 10, 5, 1, 4, 2); } vector<int> data; shared_ptr<MockDataArray> data_array; shared_ptr<MockSuffixArray> suffix_array; + Precomputation precomputation; }; TEST_F(PrecomputationTest, TestCollocations) { - Precomputation precomputation(suffix_array, 3, 3, 10, 5, 1, 4, 2); Index collocations = precomputation.GetCollocations(); vector<int> key = {2, 3, -1, 2}; @@ -101,6 +108,18 @@ TEST_F(PrecomputationTest, TestCollocations) { EXPECT_EQ(0, collocations.count(key)); } +TEST_F(PrecomputationTest, TestSerialization) { + stringstream stream(ios_base::out | ios_base::in); + ar::text_oarchive output_stream(stream, ar::no_header); + output_stream << precomputation; + + Precomputation precomputation_copy; + ar::text_iarchive input_stream(stream, ar::no_header); + input_stream >> precomputation_copy; + + EXPECT_EQ(precomputation, precomputation_copy); +} + } // namespace } // namespace extractor diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc index 6cee42d5..8a9ca89d 100644 --- a/extractor/run_extractor.cc +++ b/extractor/run_extractor.cc @@ -118,17 +118,17 @@ int main(int argc, char** argv) { << " seconds" << endl; // Constructs the suffix array for the source data. - cerr << "Creating source suffix array..." << endl; start_time = Clock::now(); + cerr << "Constructing source suffix array..." << endl; shared_ptr<SuffixArray> source_suffix_array = make_shared<SuffixArray>(source_data_array); stop_time = Clock::now(); - cerr << "Creating suffix array took " + cerr << "Constructing suffix array took " << GetDuration(start_time, stop_time) << " seconds" << endl; // Reads the alignment. - cerr << "Reading alignment..." << endl; start_time = Clock::now(); + cerr << "Reading alignment..." << endl; shared_ptr<Alignment> alignment = make_shared<Alignment>(vm["alignment"].as<string>()); stop_time = Clock::now(); @@ -137,8 +137,8 @@ int main(int argc, char** argv) { // Constructs an index storing the occurrences in the source data for each // frequent collocation. - cerr << "Precomputing collocations..." << endl; start_time = Clock::now(); + cerr << "Precomputing collocations..." << endl; shared_ptr<Precomputation> precomputation = make_shared<Precomputation>( source_suffix_array, vm["frequent"].as<int>(), @@ -154,8 +154,8 @@ int main(int argc, char** argv) { // Constructs a table storing p(e | f) and p(f | e) for every pair of source // and target words. - cerr << "Precomputing conditional probabilities..." << endl; start_time = Clock::now(); + cerr << "Precomputing conditional probabilities..." << endl; shared_ptr<TranslationTable> table = make_shared<TranslationTable>( source_data_array, target_data_array, alignment); stop_time = Clock::now(); diff --git a/extractor/suffix_array.cc b/extractor/suffix_array.cc index 65b2d581..0cf4d1f6 100644 --- a/extractor/suffix_array.cc +++ b/extractor/suffix_array.cc @@ -186,20 +186,6 @@ shared_ptr<DataArray> SuffixArray::GetData() const { return data_array; } -void SuffixArray::WriteBinary(const fs::path& filepath) const { - FILE* file = fopen(filepath.string().c_str(), "w"); - assert(file); - data_array->WriteBinary(file); - - int size = suffix_array.size(); - fwrite(&size, sizeof(int), 1, file); - fwrite(suffix_array.data(), sizeof(int), size, file); - - size = word_start.size(); - fwrite(&size, sizeof(int), 1, file); - fwrite(word_start.data(), sizeof(int), size, file); -} - PhraseLocation SuffixArray::Lookup(int low, int high, const string& word, int offset) const { if (!data_array->HasWord(word)) { @@ -232,4 +218,10 @@ int SuffixArray::LookupRangeStart(int low, int high, int word_id, return result; } +bool SuffixArray::operator==(const SuffixArray& other) const { + return *data_array == *other.data_array && + suffix_array == other.suffix_array && + word_start == other.word_start; +} + } // namespace extractor diff --git a/extractor/suffix_array.h b/extractor/suffix_array.h index bf731d79..8ee454ec 100644 --- a/extractor/suffix_array.h +++ b/extractor/suffix_array.h @@ -6,6 +6,9 @@ #include <vector> #include <boost/filesystem.hpp> +#include <boost/serialization/serialization.hpp> +#include <boost/serialization/split_member.hpp> +#include <boost/serialization/vector.hpp> namespace fs = boost::filesystem; using namespace std; @@ -20,6 +23,9 @@ class SuffixArray { // Creates a suffix array from a data array. SuffixArray(shared_ptr<DataArray> data_array); + // Creates empty suffix array. + SuffixArray(); + virtual ~SuffixArray(); // Returns the size of the suffix array. @@ -40,10 +46,7 @@ class SuffixArray { virtual PhraseLocation Lookup(int low, int high, const string& word, int offset) const; - void WriteBinary(const fs::path& filepath) const; - - protected: - SuffixArray(); + bool operator==(const SuffixArray& other) const; private: // Constructs the suffix array using the algorithm of Larsson and Sadakane @@ -65,6 +68,23 @@ class SuffixArray { // offset value is greater or equal to word_id. int LookupRangeStart(int low, int high, int word_id, int offset) const; + friend class boost::serialization::access; + + template<class Archive> void save(Archive& ar, unsigned int) const { + ar << *data_array; + ar << suffix_array; + ar << word_start; + } + + template<class Archive> void load(Archive& ar, unsigned int) { + data_array = make_shared<DataArray>(); + ar >> *data_array; + ar >> suffix_array; + ar >> word_start; + } + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + shared_ptr<DataArray> data_array; vector<int> suffix_array; vector<int> word_start; diff --git a/extractor/suffix_array_test.cc b/extractor/suffix_array_test.cc index 8431a16e..ba0dbcc3 100644 --- a/extractor/suffix_array_test.cc +++ b/extractor/suffix_array_test.cc @@ -1,13 +1,17 @@ #include <gtest/gtest.h> +#include <vector> + +#include <boost/archive/binary_iarchive.hpp> +#include <boost/archive/binary_oarchive.hpp> + #include "mocks/mock_data_array.h" #include "phrase_location.h" #include "suffix_array.h" -#include <vector> - using namespace std; using namespace ::testing; +namespace ar = boost::archive; namespace extractor { namespace { @@ -20,30 +24,30 @@ class SuffixArrayTest : public Test { EXPECT_CALL(*data_array, GetData()).WillRepeatedly(ReturnRef(data)); EXPECT_CALL(*data_array, GetVocabularySize()).WillRepeatedly(Return(7)); EXPECT_CALL(*data_array, GetSize()).WillRepeatedly(Return(13)); - suffix_array = make_shared<SuffixArray>(data_array); + suffix_array = SuffixArray(data_array); } vector<int> data; - shared_ptr<SuffixArray> suffix_array; + SuffixArray suffix_array; shared_ptr<MockDataArray> data_array; }; TEST_F(SuffixArrayTest, TestData) { - EXPECT_EQ(data_array, suffix_array->GetData()); - EXPECT_EQ(14, suffix_array->GetSize()); + EXPECT_EQ(data_array, suffix_array.GetData()); + EXPECT_EQ(14, suffix_array.GetSize()); } TEST_F(SuffixArrayTest, TestBuildSuffixArray) { vector<int> expected_suffix_array = {13, 11, 2, 12, 3, 6, 10, 1, 4, 7, 5, 9, 0, 8}; for (size_t i = 0; i < expected_suffix_array.size(); ++i) { - EXPECT_EQ(expected_suffix_array[i], suffix_array->GetSuffix(i)); + EXPECT_EQ(expected_suffix_array[i], suffix_array.GetSuffix(i)); } } TEST_F(SuffixArrayTest, TestBuildLCP) { vector<int> expected_lcp = {-1, 0, 2, 0, 1, 0, 0, 3, 1, 1, 0, 0, 4, 1}; - EXPECT_EQ(expected_lcp, suffix_array->BuildLCPArray()); + EXPECT_EQ(expected_lcp, suffix_array.BuildLCPArray()); } TEST_F(SuffixArrayTest, TestLookup) { @@ -53,25 +57,37 @@ TEST_F(SuffixArrayTest, TestLookup) { EXPECT_CALL(*data_array, HasWord("word1")).WillRepeatedly(Return(true)); EXPECT_CALL(*data_array, GetWordId("word1")).WillRepeatedly(Return(6)); - EXPECT_EQ(PhraseLocation(11, 14), suffix_array->Lookup(0, 14, "word1", 0)); + EXPECT_EQ(PhraseLocation(11, 14), suffix_array.Lookup(0, 14, "word1", 0)); EXPECT_CALL(*data_array, HasWord("word2")).WillRepeatedly(Return(false)); - EXPECT_EQ(PhraseLocation(0, 0), suffix_array->Lookup(0, 14, "word2", 0)); + EXPECT_EQ(PhraseLocation(0, 0), suffix_array.Lookup(0, 14, "word2", 0)); EXPECT_CALL(*data_array, HasWord("word3")).WillRepeatedly(Return(true)); EXPECT_CALL(*data_array, GetWordId("word3")).WillRepeatedly(Return(4)); - EXPECT_EQ(PhraseLocation(11, 13), suffix_array->Lookup(11, 14, "word3", 1)); + EXPECT_EQ(PhraseLocation(11, 13), suffix_array.Lookup(11, 14, "word3", 1)); EXPECT_CALL(*data_array, HasWord("word4")).WillRepeatedly(Return(true)); EXPECT_CALL(*data_array, GetWordId("word4")).WillRepeatedly(Return(1)); - EXPECT_EQ(PhraseLocation(11, 13), suffix_array->Lookup(11, 13, "word4", 2)); + EXPECT_EQ(PhraseLocation(11, 13), suffix_array.Lookup(11, 13, "word4", 2)); EXPECT_CALL(*data_array, HasWord("word5")).WillRepeatedly(Return(true)); EXPECT_CALL(*data_array, GetWordId("word5")).WillRepeatedly(Return(2)); - EXPECT_EQ(PhraseLocation(11, 13), suffix_array->Lookup(11, 13, "word5", 3)); + EXPECT_EQ(PhraseLocation(11, 13), suffix_array.Lookup(11, 13, "word5", 3)); + + EXPECT_EQ(PhraseLocation(12, 13), suffix_array.Lookup(11, 13, "word3", 4)); + EXPECT_EQ(PhraseLocation(11, 11), suffix_array.Lookup(11, 13, "word5", 1)); +} + +TEST_F(SuffixArrayTest, TestSerialization) { + stringstream stream(ios_base::binary | ios_base::out | ios_base::in); + ar::binary_oarchive output_stream(stream, ar::no_header); + output_stream << suffix_array; + + SuffixArray suffix_array_copy; + ar::binary_iarchive input_stream(stream, ar::no_header); + input_stream >> suffix_array_copy; - EXPECT_EQ(PhraseLocation(12, 13), suffix_array->Lookup(11, 13, "word3", 4)); - EXPECT_EQ(PhraseLocation(11, 11), suffix_array->Lookup(11, 13, "word5", 1)); + EXPECT_EQ(suffix_array, suffix_array_copy); } } // namespace diff --git a/extractor/translation_table.cc b/extractor/translation_table.cc index adb59cb5..03e41d9a 100644 --- a/extractor/translation_table.cc +++ b/extractor/translation_table.cc @@ -112,15 +112,10 @@ double TranslationTable::GetSourceGivenTargetScore( return translation_probabilities[make_pair(source_id, target_id)].second; } -void TranslationTable::WriteBinary(const fs::path& filepath) const { - FILE* file = fopen(filepath.string().c_str(), "w"); - - int size = translation_probabilities.size(); - fwrite(&size, sizeof(int), 1, file); - for (auto entry: translation_probabilities) { - fwrite(&entry.first, sizeof(entry.first), 1, file); - fwrite(&entry.second, sizeof(entry.second), 1, file); - } +bool TranslationTable::operator==(const TranslationTable& other) const { + return *source_data_array == *other.source_data_array && + *target_data_array == *other.target_data_array && + translation_probabilities == other.translation_probabilities; } } // namespace extractor diff --git a/extractor/translation_table.h b/extractor/translation_table.h index ed43ad72..2a37bab7 100644 --- a/extractor/translation_table.h +++ b/extractor/translation_table.h @@ -7,6 +7,9 @@ #include <boost/filesystem.hpp> #include <boost/functional/hash.hpp> +#include <boost/serialization/serialization.hpp> +#include <boost/serialization/split_member.hpp> +#include <boost/serialization/utility.hpp> using namespace std; namespace fs = boost::filesystem; @@ -23,11 +26,16 @@ class DataArray; */ class TranslationTable { public: + // Constructs a translation table from source data, target data and the + // corresponding alignment. TranslationTable( shared_ptr<DataArray> source_data_array, shared_ptr<DataArray> target_data_array, shared_ptr<Alignment> alignment); + // Creates empty translation table. + TranslationTable(); + virtual ~TranslationTable(); // Returns p(e | f). @@ -38,10 +46,7 @@ class TranslationTable { virtual double GetSourceGivenTargetScore(const string& source_word, const string& target_word); - void WriteBinary(const fs::path& filepath) const; - - protected: - TranslationTable(); + bool operator==(const TranslationTable& other) const; private: // Increment links count for the given (f, e) word pair. @@ -52,6 +57,35 @@ class TranslationTable { int source_word_id, int target_word_id) const; + friend class boost::serialization::access; + + template<class Archive> void save(Archive& ar, unsigned int) const { + ar << *source_data_array << *target_data_array; + + int num_entries = translation_probabilities.size(); + ar << num_entries; + for (auto entry: translation_probabilities) { + ar << entry; + } + } + + template<class Archive> void load(Archive& ar, unsigned int) { + source_data_array = make_shared<DataArray>(); + ar >> *source_data_array; + target_data_array = make_shared<DataArray>(); + ar >> *target_data_array; + + int num_entries; + ar >> num_entries; + for (size_t i = 0; i < num_entries; ++i) { + pair<pair<int, int>, pair<double, double>> entry; + ar >> entry; + translation_probabilities.insert(entry); + } + } + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + shared_ptr<DataArray> source_data_array; shared_ptr<DataArray> target_data_array; unordered_map<pair<int, int>, pair<double, double>, PairHash> diff --git a/extractor/translation_table_test.cc b/extractor/translation_table_test.cc index d14f2f89..606777bd 100644 --- a/extractor/translation_table_test.cc +++ b/extractor/translation_table_test.cc @@ -1,83 +1,106 @@ #include <gtest/gtest.h> #include <memory> +#include <sstream> #include <string> #include <vector> +#include <boost/archive/binary_iarchive.hpp> +#include <boost/archive/binary_oarchive.hpp> + #include "mocks/mock_alignment.h" #include "mocks/mock_data_array.h" #include "translation_table.h" using namespace std; using namespace ::testing; +namespace ar = boost::archive; namespace extractor { namespace { -TEST(TranslationTableTest, TestScores) { - vector<string> words = {"a", "b", "c"}; - - vector<int> source_data = {2, 3, 2, 3, 4, 0, 2, 3, 6, 0, 2, 3, 6, 0}; - vector<int> source_sentence_start = {0, 6, 10, 14}; - shared_ptr<MockDataArray> source_data_array = make_shared<MockDataArray>(); - EXPECT_CALL(*source_data_array, GetData()) - .WillRepeatedly(ReturnRef(source_data)); - EXPECT_CALL(*source_data_array, GetNumSentences()) - .WillRepeatedly(Return(3)); - for (size_t i = 0; i < source_sentence_start.size(); ++i) { - EXPECT_CALL(*source_data_array, GetSentenceStart(i)) - .WillRepeatedly(Return(source_sentence_start[i])); - } - for (size_t i = 0; i < words.size(); ++i) { - EXPECT_CALL(*source_data_array, HasWord(words[i])) - .WillRepeatedly(Return(true)); - EXPECT_CALL(*source_data_array, GetWordId(words[i])) - .WillRepeatedly(Return(i + 2)); - } - EXPECT_CALL(*source_data_array, HasWord("d")) - .WillRepeatedly(Return(false)); - - vector<int> target_data = {2, 3, 2, 3, 4, 5, 0, 3, 6, 0, 2, 7, 0}; - vector<int> target_sentence_start = {0, 7, 10, 13}; - shared_ptr<MockDataArray> target_data_array = make_shared<MockDataArray>(); - EXPECT_CALL(*target_data_array, GetData()) - .WillRepeatedly(ReturnRef(target_data)); - for (size_t i = 0; i < target_sentence_start.size(); ++i) { - EXPECT_CALL(*target_data_array, GetSentenceStart(i)) - .WillRepeatedly(Return(target_sentence_start[i])); - } - for (size_t i = 0; i < words.size(); ++i) { - EXPECT_CALL(*target_data_array, HasWord(words[i])) - .WillRepeatedly(Return(true)); - EXPECT_CALL(*target_data_array, GetWordId(words[i])) - .WillRepeatedly(Return(i + 2)); +class TranslationTableTest : public Test { + protected: + virtual void SetUp() { + vector<string> words = {"a", "b", "c"}; + + vector<int> source_data = {2, 3, 2, 3, 4, 0, 2, 3, 6, 0, 2, 3, 6, 0}; + vector<int> source_sentence_start = {0, 6, 10, 14}; + shared_ptr<MockDataArray> source_data_array = make_shared<MockDataArray>(); + EXPECT_CALL(*source_data_array, GetData()) + .WillRepeatedly(ReturnRef(source_data)); + EXPECT_CALL(*source_data_array, GetNumSentences()) + .WillRepeatedly(Return(3)); + for (size_t i = 0; i < source_sentence_start.size(); ++i) { + EXPECT_CALL(*source_data_array, GetSentenceStart(i)) + .WillRepeatedly(Return(source_sentence_start[i])); + } + for (size_t i = 0; i < words.size(); ++i) { + EXPECT_CALL(*source_data_array, HasWord(words[i])) + .WillRepeatedly(Return(true)); + EXPECT_CALL(*source_data_array, GetWordId(words[i])) + .WillRepeatedly(Return(i + 2)); + } + EXPECT_CALL(*source_data_array, HasWord("d")) + .WillRepeatedly(Return(false)); + + vector<int> target_data = {2, 3, 2, 3, 4, 5, 0, 3, 6, 0, 2, 7, 0}; + vector<int> target_sentence_start = {0, 7, 10, 13}; + shared_ptr<MockDataArray> target_data_array = make_shared<MockDataArray>(); + EXPECT_CALL(*target_data_array, GetData()) + .WillRepeatedly(ReturnRef(target_data)); + for (size_t i = 0; i < target_sentence_start.size(); ++i) { + EXPECT_CALL(*target_data_array, GetSentenceStart(i)) + .WillRepeatedly(Return(target_sentence_start[i])); + } + for (size_t i = 0; i < words.size(); ++i) { + EXPECT_CALL(*target_data_array, HasWord(words[i])) + .WillRepeatedly(Return(true)); + EXPECT_CALL(*target_data_array, GetWordId(words[i])) + .WillRepeatedly(Return(i + 2)); + } + EXPECT_CALL(*target_data_array, HasWord("d")) + .WillRepeatedly(Return(false)); + + vector<pair<int, int>> links1 = { + make_pair(0, 0), make_pair(1, 1), make_pair(2, 2), make_pair(3, 3), + make_pair(4, 4), make_pair(4, 5) + }; + vector<pair<int, int>> links2 = {make_pair(1, 0), make_pair(2, 1)}; + vector<pair<int, int>> links3 = {make_pair(0, 0), make_pair(2, 1)}; + shared_ptr<MockAlignment> alignment = make_shared<MockAlignment>(); + EXPECT_CALL(*alignment, GetLinks(0)).WillRepeatedly(Return(links1)); + EXPECT_CALL(*alignment, GetLinks(1)).WillRepeatedly(Return(links2)); + EXPECT_CALL(*alignment, GetLinks(2)).WillRepeatedly(Return(links3)); + + table = TranslationTable(source_data_array, target_data_array, alignment); } - EXPECT_CALL(*target_data_array, HasWord("d")) - .WillRepeatedly(Return(false)); - - vector<pair<int, int>> links1 = { - make_pair(0, 0), make_pair(1, 1), make_pair(2, 2), make_pair(3, 3), - make_pair(4, 4), make_pair(4, 5) - }; - vector<pair<int, int>> links2 = {make_pair(1, 0), make_pair(2, 1)}; - vector<pair<int, int>> links3 = {make_pair(0, 0), make_pair(2, 1)}; - shared_ptr<MockAlignment> alignment = make_shared<MockAlignment>(); - EXPECT_CALL(*alignment, GetLinks(0)).WillRepeatedly(Return(links1)); - EXPECT_CALL(*alignment, GetLinks(1)).WillRepeatedly(Return(links2)); - EXPECT_CALL(*alignment, GetLinks(2)).WillRepeatedly(Return(links3)); - - shared_ptr<TranslationTable> table = make_shared<TranslationTable>( - source_data_array, target_data_array, alignment); - - EXPECT_EQ(0.75, table->GetTargetGivenSourceScore("a", "a")); - EXPECT_EQ(0, table->GetTargetGivenSourceScore("a", "b")); - EXPECT_EQ(0.5, table->GetTargetGivenSourceScore("c", "c")); - EXPECT_EQ(-1, table->GetTargetGivenSourceScore("c", "d")); - - EXPECT_EQ(1, table->GetSourceGivenTargetScore("a", "a")); - EXPECT_EQ(0, table->GetSourceGivenTargetScore("a", "b")); - EXPECT_EQ(1, table->GetSourceGivenTargetScore("c", "c")); - EXPECT_EQ(-1, table->GetSourceGivenTargetScore("c", "d")); + + TranslationTable table; +}; + +TEST_F(TranslationTableTest, TestScores) { + EXPECT_EQ(0.75, table.GetTargetGivenSourceScore("a", "a")); + EXPECT_EQ(0, table.GetTargetGivenSourceScore("a", "b")); + EXPECT_EQ(0.5, table.GetTargetGivenSourceScore("c", "c")); + EXPECT_EQ(-1, table.GetTargetGivenSourceScore("c", "d")); + + EXPECT_EQ(1, table.GetSourceGivenTargetScore("a", "a")); + EXPECT_EQ(0, table.GetSourceGivenTargetScore("a", "b")); + EXPECT_EQ(1, table.GetSourceGivenTargetScore("c", "c")); + EXPECT_EQ(-1, table.GetSourceGivenTargetScore("c", "d")); +} + +TEST_F(TranslationTableTest, TestSerialization) { + stringstream stream(ios_base::binary | ios_base::out | ios_base::in); + ar::binary_oarchive output_stream(stream, ar::no_header); + output_stream << table; + + TranslationTable table_copy; + ar::binary_iarchive input_stream(stream, ar::no_header); + input_stream >> table_copy; + + EXPECT_EQ(table, table_copy); } } // namespace |