diff options
Diffstat (limited to 'extractor')
| -rw-r--r-- | extractor/alignment.cc | 15 | ||||
| -rw-r--r-- | extractor/alignment.h | 19 | ||||
| -rw-r--r-- | extractor/alignment_test.cc | 26 | ||||
| -rw-r--r-- | extractor/compile.cc | 82 | ||||
| -rw-r--r-- | extractor/data_array.cc | 33 | ||||
| -rw-r--r-- | extractor/data_array.h | 38 | ||||
| -rw-r--r-- | extractor/data_array_test.cc | 91 | ||||
| -rw-r--r-- | extractor/precomputation.cc | 21 | ||||
| -rw-r--r-- | extractor/precomputation.h | 35 | ||||
| -rw-r--r-- | extractor/precomputation_test.cc | 21 | ||||
| -rw-r--r-- | extractor/run_extractor.cc | 26 | ||||
| -rw-r--r-- | extractor/suffix_array.cc | 20 | ||||
| -rw-r--r-- | extractor/suffix_array.h | 28 | ||||
| -rw-r--r-- | extractor/suffix_array_test.cc | 46 | ||||
| -rw-r--r-- | extractor/translation_table.cc | 27 | ||||
| -rw-r--r-- | extractor/translation_table.h | 42 | ||||
| -rw-r--r-- | extractor/translation_table_test.cc | 149 | 
17 files changed, 473 insertions, 246 deletions
| diff --git a/extractor/alignment.cc b/extractor/alignment.cc index b187c03a..2278c825 100644 --- a/extractor/alignment.cc +++ b/extractor/alignment.cc @@ -23,8 +23,8 @@ Alignment::Alignment(const string& filename) {      boost::split(items, line, boost::is_any_of(" -"));      vector<pair<int, int>> alignment;      alignment.reserve(items.size() / 2); -    for (size_t i = 0; i < items.size(); i += 2) { -      alignment.push_back(make_pair(stoi(items[i]), stoi(items[i + 1]))); +    for (size_t i = 1; i < items.size(); i += 2) { +      alignment.push_back(make_pair(stoi(items[i - 1]), stoi(items[i])));      }      alignments.push_back(alignment);    } @@ -39,15 +39,8 @@ vector<pair<int, int>> Alignment::GetLinks(int sentence_index) const {    return alignments[sentence_index];  } -void Alignment::WriteBinary(const fs::path& filepath) { -  FILE* file = fopen(filepath.string().c_str(), "w"); -  int size = alignments.size(); -  fwrite(&size, sizeof(int), 1, file); -  for (vector<pair<int, int>> alignment: alignments) { -    size = alignment.size(); -    fwrite(&size, sizeof(int), 1, file); -    fwrite(alignment.data(), sizeof(pair<int, int>), size, file); -  } +bool Alignment::operator==(const Alignment& other) const { +  return alignments == other.alignments;  }  } // namespace extractor diff --git a/extractor/alignment.h b/extractor/alignment.h index 4596f92b..dc5a8b55 100644 --- a/extractor/alignment.h +++ b/extractor/alignment.h @@ -5,6 +5,10 @@  #include <vector>  #include <boost/filesystem.hpp> +#include <boost/serialization/serialization.hpp> +#include <boost/serialization/split_member.hpp> +#include <boost/serialization/utility.hpp> +#include <boost/serialization/vector.hpp>  namespace fs = boost::filesystem;  using namespace std; @@ -19,18 +23,23 @@ class Alignment {    // Reads alignment from text file.    Alignment(const string& filename); +  // Creates empty alignment. +  Alignment(); +    // Returns the alignment for a given sentence.    virtual vector<pair<int, int>> GetLinks(int sentence_index) const; -  // Writes alignment to file in binary format. -  void WriteBinary(const fs::path& filepath); -    virtual ~Alignment(); - protected: -  Alignment(); +  bool operator==(const Alignment& alignment) const;   private: +  friend class boost::serialization::access; + +  template<class Archive> void serialize(Archive& ar, unsigned int) { +    ar & alignments; +  } +    vector<vector<pair<int, int>>> alignments;  }; diff --git a/extractor/alignment_test.cc b/extractor/alignment_test.cc index 43c37ebd..1b8ff531 100644 --- a/extractor/alignment_test.cc +++ b/extractor/alignment_test.cc @@ -1,12 +1,16 @@  #include <gtest/gtest.h> -#include <memory> +#include <sstream>  #include <string> +#include <boost/archive/binary_iarchive.hpp> +#include <boost/archive/binary_oarchive.hpp> +  #include "alignment.h"  using namespace std;  using namespace ::testing; +namespace ar = boost::archive;  namespace extractor {  namespace { @@ -14,19 +18,31 @@ namespace {  class AlignmentTest : public Test {   protected:    virtual void SetUp() { -    alignment = make_shared<Alignment>("sample_alignment.txt"); +    alignment = Alignment("sample_alignment.txt");    } -  shared_ptr<Alignment> alignment; +  Alignment alignment;  };  TEST_F(AlignmentTest, TestGetLinks) {    vector<pair<int, int>> expected_links = {      make_pair(0, 0), make_pair(1, 1), make_pair(2, 2)    }; -  EXPECT_EQ(expected_links, alignment->GetLinks(0)); +  EXPECT_EQ(expected_links, alignment.GetLinks(0));    expected_links = {make_pair(1, 0), make_pair(2, 1)}; -  EXPECT_EQ(expected_links, alignment->GetLinks(1)); +  EXPECT_EQ(expected_links, alignment.GetLinks(1)); +} + +TEST_F(AlignmentTest, TestSerialization) { +  stringstream stream(ios_base::binary | ios_base::out | ios_base::in); +  ar::binary_oarchive output_stream(stream, ar::no_header); +  output_stream << alignment; + +  Alignment alignment_copy; +  ar::binary_iarchive input_stream(stream, ar::no_header); +  input_stream >> alignment_copy; + +  EXPECT_EQ(alignment, alignment_copy);  }  } // namespace diff --git a/extractor/compile.cc b/extractor/compile.cc index a9ae2cef..65fdd509 100644 --- a/extractor/compile.cc +++ b/extractor/compile.cc @@ -1,6 +1,8 @@ +#include <fstream>  #include <iostream>  #include <string> +#include <boost/archive/binary_oarchive.hpp>  #include <boost/filesystem.hpp>  #include <boost/program_options.hpp>  #include <boost/program_options/variables_map.hpp> @@ -9,8 +11,10 @@  #include "data_array.h"  #include "precomputation.h"  #include "suffix_array.h" +#include "time_util.h"  #include "translation_table.h" +namespace ar = boost::archive;  namespace fs = boost::filesystem;  namespace po = boost::program_options;  using namespace std; @@ -58,11 +62,14 @@ int main(int argc, char** argv) {      return 1;    } -  fs::path output_dir(vm["output"].as<string>().c_str()); +  fs::path output_dir(vm["output"].as<string>());    if (!fs::exists(output_dir)) {      fs::create_directory(output_dir);    } +  // Reading source and target data. +  Clock::time_point start_time = Clock::now(); +  cerr << "Reading source and target data..." << endl;    shared_ptr<DataArray> source_data_array, target_data_array;    if (vm.count("bitext")) {      source_data_array = make_shared<DataArray>( @@ -73,15 +80,53 @@ int main(int argc, char** argv) {      source_data_array = make_shared<DataArray>(vm["source"].as<string>());      target_data_array = make_shared<DataArray>(vm["target"].as<string>());    } + +  Clock::time_point start_write = Clock::now(); +  ofstream target_fstream((output_dir / fs::path("target.bin")).string()); +  ar::binary_oarchive target_stream(target_fstream); +  target_stream << *target_data_array; +  Clock::time_point stop_write = Clock::now(); +  double write_duration = GetDuration(start_write, stop_write); + +  Clock::time_point stop_time = Clock::now(); +  cerr << "Reading data took " << GetDuration(start_time, stop_time) +       << " seconds" << endl; + +  // Constructing and compiling the suffix array. +  start_time = Clock::now(); +  cerr << "Constructing source suffix array..." << endl;    shared_ptr<SuffixArray> source_suffix_array =        make_shared<SuffixArray>(source_data_array); -  source_suffix_array->WriteBinary(output_dir / fs::path("f.bin")); -  target_data_array->WriteBinary(output_dir / fs::path("e.bin")); +  start_write = Clock::now(); +  ofstream source_fstream((output_dir / fs::path("source.bin")).string()); +  ar::binary_oarchive output_stream(source_fstream); +  output_stream << *source_suffix_array; +  stop_write = Clock::now(); +  write_duration += GetDuration(start_write, stop_write); + +  cerr << "Constructing suffix array took " +       << GetDuration(start_time, stop_time) << " seconds" << endl; + +  // Reading alignment. +  start_time = Clock::now(); +  cerr << "Reading alignment..." << endl;    shared_ptr<Alignment> alignment =        make_shared<Alignment>(vm["alignment"].as<string>()); -  alignment->WriteBinary(output_dir / fs::path("a.bin")); +  start_write = Clock::now(); +  ofstream alignment_fstream((output_dir / fs::path("alignment.bin")).string()); +  ar::binary_oarchive alignment_stream(alignment_fstream); +  alignment_stream << *alignment; +  stop_write = Clock::now(); +  write_duration += GetDuration(start_write, stop_write); + +  stop_time = Clock::now(); +  cerr << "Reading alignment took " +       << GetDuration(start_time, stop_time) << " seconds" << endl; + +  start_time = Clock::now(); +  cerr << "Precomputing collocations..." << endl;    Precomputation precomputation(        source_suffix_array,        vm["frequent"].as<int>(), @@ -91,10 +136,35 @@ int main(int argc, char** argv) {        vm["min_gap_size"].as<int>(),        vm["max_phrase_len"].as<int>(),        vm["min_frequency"].as<int>()); -  precomputation.WriteBinary(output_dir / fs::path("precompute.bin")); +  start_write = Clock::now(); +  ofstream precomp_fstream((output_dir / fs::path("precomp.bin")).string()); +  ar::binary_oarchive precomp_stream(precomp_fstream); +  precomp_stream << precomputation; +  stop_write = Clock::now(); +  write_duration += GetDuration(start_write, stop_write); + +  stop_time = Clock::now(); +  cerr << "Precomputing collocations took " +       << GetDuration(start_time, stop_time) << " seconds" << endl; + +  start_time = Clock::now(); +  cerr << "Precomputing conditional probabilities..." << endl;    TranslationTable table(source_data_array, target_data_array, alignment); -  table.WriteBinary(output_dir / fs::path("lex.bin")); + +  start_write = Clock::now(); +  ofstream table_fstream((output_dir / fs::path("bilex.bin")).string()); +  ar::binary_oarchive table_stream(table_fstream); +  table_stream << table; +  stop_write = Clock::now(); +  write_duration += GetDuration(start_write, stop_write); + +  stop_time = Clock::now(); +  cerr << "Precomputing conditional probabilities took " +       << GetDuration(start_time, stop_time) << " seconds" << endl; + +  cerr << "Total time spent writing: " << write_duration +       << " seconds" << endl;    return 0;  } diff --git a/extractor/data_array.cc b/extractor/data_array.cc index 203fe219..2e4bdafb 100644 --- a/extractor/data_array.cc +++ b/extractor/data_array.cc @@ -118,33 +118,6 @@ int DataArray::GetSentenceId(int position) const {    return sentence_id[position];  } -void DataArray::WriteBinary(const fs::path& filepath) const { -  std::cerr << "File: " << filepath.string() << std::endl; -  WriteBinary(fopen(filepath.string().c_str(), "w")); -} - -void DataArray::WriteBinary(FILE* file) const { -  int size = id2word.size(); -  fwrite(&size, sizeof(int), 1, file); -  for (string word: id2word) { -    size = word.size(); -    fwrite(&size, sizeof(int), 1, file); -    fwrite(word.data(), sizeof(char), size, file); -  } - -  size = data.size(); -  fwrite(&size, sizeof(int), 1, file); -  fwrite(data.data(), sizeof(int), size, file); - -  size = sentence_id.size(); -  fwrite(&size, sizeof(int), 1, file); -  fwrite(sentence_id.data(), sizeof(int), size, file); - -  size = sentence_start.size(); -  fwrite(&size, sizeof(int), 1, file); -  fwrite(sentence_start.data(), sizeof(int), 1, file); -} -  bool DataArray::HasWord(const string& word) const {    return word2id.count(word);  } @@ -158,4 +131,10 @@ string DataArray::GetWord(int word_id) const {    return id2word[word_id];  } +bool DataArray::operator==(const DataArray& other) const { +  return word2id == other.word2id && id2word == other.id2word && +         data == other.data && sentence_start == other.sentence_start && +         sentence_id == other.sentence_id; +} +  } // namespace extractor diff --git a/extractor/data_array.h b/extractor/data_array.h index 978a6931..2be6a09c 100644 --- a/extractor/data_array.h +++ b/extractor/data_array.h @@ -6,6 +6,10 @@  #include <vector>  #include <boost/filesystem.hpp> +#include <boost/serialization/serialization.hpp> +#include <boost/serialization/split_member.hpp> +#include <boost/serialization/string.hpp> +#include <boost/serialization/vector.hpp>  namespace fs = boost::filesystem;  using namespace std; @@ -43,6 +47,9 @@ class DataArray {    // Reads data array from bitext file where the sentences are separated by |||.    DataArray(const string& filename, const Side& side); +  // Creates empty data array. +  DataArray(); +    virtual ~DataArray();    // Returns a vector containing the word ids. @@ -82,14 +89,7 @@ class DataArray {    // Returns the number of the sentence containing the given position.    virtual int GetSentenceId(int position) const; -  // Writes data array to file in binary format. -  void WriteBinary(const fs::path& filepath) const; - -  // Writes data array to file in binary format. -  void WriteBinary(FILE* file) const; - - protected: -  DataArray(); +  bool operator==(const DataArray& other) const;   private:    // Sets up specific constants. @@ -98,6 +98,28 @@ class DataArray {    // Constructs the data array.    void CreateDataArray(const vector<string>& lines); +  friend class boost::serialization::access; + +  template<class Archive> void save(Archive& ar, unsigned int) const { +    ar << id2word; +    ar << data; +    ar << sentence_id; +    ar << sentence_start; +  } + +  template<class Archive> void load(Archive& ar, unsigned int) { +    ar >> id2word; +    for (size_t i = 0; i < id2word.size(); ++i) { +      word2id[id2word[i]] = i; +    } + +    ar >> data; +    ar >> sentence_id; +    ar >> sentence_start; +  } + +  BOOST_SERIALIZATION_SPLIT_MEMBER(); +    unordered_map<string, int> word2id;    vector<string> id2word;    vector<int> data; diff --git a/extractor/data_array_test.cc b/extractor/data_array_test.cc index 71175fda..6c329e34 100644 --- a/extractor/data_array_test.cc +++ b/extractor/data_array_test.cc @@ -1,8 +1,11 @@  #include <gtest/gtest.h>  #include <memory> +#include <sstream>  #include <string> +#include <boost/archive/binary_iarchive.hpp> +#include <boost/archive/binary_oarchive.hpp>  #include <boost/filesystem.hpp>  #include "data_array.h" @@ -10,6 +13,7 @@  using namespace std;  using namespace ::testing;  namespace fs = boost::filesystem; +namespace ar = boost::archive;  namespace extractor {  namespace { @@ -18,12 +22,12 @@ class DataArrayTest : public Test {   protected:    virtual void SetUp() {      string sample_test_file("sample_bitext.txt"); -    source_data = make_shared<DataArray>(sample_test_file, SOURCE); -    target_data = make_shared<DataArray>(sample_test_file, TARGET); +    source_data = DataArray(sample_test_file, SOURCE); +    target_data = DataArray(sample_test_file, TARGET);    } -  shared_ptr<DataArray> source_data; -  shared_ptr<DataArray> target_data; +  DataArray source_data; +  DataArray target_data;  };  TEST_F(DataArrayTest, TestGetData) { @@ -32,11 +36,11 @@ TEST_F(DataArrayTest, TestGetData) {        "ana", "are", "mere", ".", "__END_OF_LINE__",        "ana", "bea", "mult", "lapte", ".", "__END_OF_LINE__"    }; -  EXPECT_EQ(expected_source_data, source_data->GetData()); -  EXPECT_EQ(expected_source_data.size(), source_data->GetSize()); +  EXPECT_EQ(expected_source_data, source_data.GetData()); +  EXPECT_EQ(expected_source_data.size(), source_data.GetSize());    for (size_t i = 0; i < expected_source_data.size(); ++i) { -    EXPECT_EQ(expected_source_data[i], source_data->AtIndex(i)); -    EXPECT_EQ(expected_source_words[i], source_data->GetWordAtIndex(i)); +    EXPECT_EQ(expected_source_data[i], source_data.AtIndex(i)); +    EXPECT_EQ(expected_source_words[i], source_data.GetWordAtIndex(i));    }    vector<int> expected_target_data = {2, 3, 4, 5, 1, 2, 6, 7, 8, 9, 10, 5, 1}; @@ -44,55 +48,68 @@ TEST_F(DataArrayTest, TestGetData) {        "anna", "has", "apples", ".", "__END_OF_LINE__",        "anna", "drinks", "a", "lot", "of", "milk", ".", "__END_OF_LINE__"    }; -  EXPECT_EQ(expected_target_data, target_data->GetData()); -  EXPECT_EQ(expected_target_data.size(), target_data->GetSize()); +  EXPECT_EQ(expected_target_data, target_data.GetData()); +  EXPECT_EQ(expected_target_data.size(), target_data.GetSize());    for (size_t i = 0; i < expected_target_data.size(); ++i) { -    EXPECT_EQ(expected_target_data[i], target_data->AtIndex(i)); -    EXPECT_EQ(expected_target_words[i], target_data->GetWordAtIndex(i)); +    EXPECT_EQ(expected_target_data[i], target_data.AtIndex(i)); +    EXPECT_EQ(expected_target_words[i], target_data.GetWordAtIndex(i));    }  }  TEST_F(DataArrayTest, TestVocabulary) { -  EXPECT_EQ(9, source_data->GetVocabularySize()); -  EXPECT_TRUE(source_data->HasWord("mere")); -  EXPECT_EQ(4, source_data->GetWordId("mere")); -  EXPECT_EQ("mere", source_data->GetWord(4)); -  EXPECT_FALSE(source_data->HasWord("banane")); - -  EXPECT_EQ(11, target_data->GetVocabularySize()); -  EXPECT_TRUE(target_data->HasWord("apples")); -  EXPECT_EQ(4, target_data->GetWordId("apples")); -  EXPECT_EQ("apples", target_data->GetWord(4)); -  EXPECT_FALSE(target_data->HasWord("bananas")); +  EXPECT_EQ(9, source_data.GetVocabularySize()); +  EXPECT_TRUE(source_data.HasWord("mere")); +  EXPECT_EQ(4, source_data.GetWordId("mere")); +  EXPECT_EQ("mere", source_data.GetWord(4)); +  EXPECT_FALSE(source_data.HasWord("banane")); + +  EXPECT_EQ(11, target_data.GetVocabularySize()); +  EXPECT_TRUE(target_data.HasWord("apples")); +  EXPECT_EQ(4, target_data.GetWordId("apples")); +  EXPECT_EQ("apples", target_data.GetWord(4)); +  EXPECT_FALSE(target_data.HasWord("bananas"));  }  TEST_F(DataArrayTest, TestSentenceData) { -  EXPECT_EQ(2, source_data->GetNumSentences()); -  EXPECT_EQ(0, source_data->GetSentenceStart(0)); -  EXPECT_EQ(5, source_data->GetSentenceStart(1)); -  EXPECT_EQ(11, source_data->GetSentenceStart(2)); +  EXPECT_EQ(2, source_data.GetNumSentences()); +  EXPECT_EQ(0, source_data.GetSentenceStart(0)); +  EXPECT_EQ(5, source_data.GetSentenceStart(1)); +  EXPECT_EQ(11, source_data.GetSentenceStart(2)); -  EXPECT_EQ(4, source_data->GetSentenceLength(0)); -  EXPECT_EQ(5, source_data->GetSentenceLength(1)); +  EXPECT_EQ(4, source_data.GetSentenceLength(0)); +  EXPECT_EQ(5, source_data.GetSentenceLength(1));    vector<int> expected_source_ids = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1};    for (size_t i = 0; i < expected_source_ids.size(); ++i) { -    EXPECT_EQ(expected_source_ids[i], source_data->GetSentenceId(i)); +    EXPECT_EQ(expected_source_ids[i], source_data.GetSentenceId(i));    } -  EXPECT_EQ(2, target_data->GetNumSentences()); -  EXPECT_EQ(0, target_data->GetSentenceStart(0)); -  EXPECT_EQ(5, target_data->GetSentenceStart(1)); -  EXPECT_EQ(13, target_data->GetSentenceStart(2)); +  EXPECT_EQ(2, target_data.GetNumSentences()); +  EXPECT_EQ(0, target_data.GetSentenceStart(0)); +  EXPECT_EQ(5, target_data.GetSentenceStart(1)); +  EXPECT_EQ(13, target_data.GetSentenceStart(2)); -  EXPECT_EQ(4, target_data->GetSentenceLength(0)); -  EXPECT_EQ(7, target_data->GetSentenceLength(1)); +  EXPECT_EQ(4, target_data.GetSentenceLength(0)); +  EXPECT_EQ(7, target_data.GetSentenceLength(1));    vector<int> expected_target_ids = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1};    for (size_t i = 0; i < expected_target_ids.size(); ++i) { -    EXPECT_EQ(expected_target_ids[i], target_data->GetSentenceId(i)); +    EXPECT_EQ(expected_target_ids[i], target_data.GetSentenceId(i));    }  } +TEST_F(DataArrayTest, TestSerialization) { +  stringstream stream(ios_base::binary | ios_base::out | ios_base::in); +  ar::binary_oarchive output_stream(stream, ar::no_header); +  output_stream << source_data << target_data; + +  DataArray source_copy, target_copy; +  ar::binary_iarchive input_stream(stream, ar::no_header); +  input_stream >> source_copy >> target_copy; + +  EXPECT_EQ(source_data, source_copy); +  EXPECT_EQ(target_data, target_copy); +} +  } // namespace  } // namespace extractor diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc index ee4ba42c..3b8aed69 100644 --- a/extractor/precomputation.cc +++ b/extractor/precomputation.cc @@ -165,25 +165,12 @@ void Precomputation::AddStartPositions(    positions.push_back(pos3);  } -void Precomputation::WriteBinary(const fs::path& filepath) const { -  FILE* file = fopen(filepath.string().c_str(), "w"); - -  // TODO(pauldb): Refactor this code. -  int size = collocations.size(); -  fwrite(&size, sizeof(int), 1, file); -  for (auto entry: collocations) { -    size = entry.first.size(); -    fwrite(&size, sizeof(int), 1, file); -    fwrite(entry.first.data(), sizeof(int), size, file); - -    size = entry.second.size(); -    fwrite(&size, sizeof(int), 1, file); -    fwrite(entry.second.data(), sizeof(int), size, file); -  } -} -  const Index& Precomputation::GetCollocations() const {    return collocations;  } +bool Precomputation::operator==(const Precomputation& other) const { +  return collocations == other.collocations; +} +  } // namespace extractor diff --git a/extractor/precomputation.h b/extractor/precomputation.h index 3e792ac7..9f0c9424 100644 --- a/extractor/precomputation.h +++ b/extractor/precomputation.h @@ -9,6 +9,9 @@  #include <boost/filesystem.hpp>  #include <boost/functional/hash.hpp> +#include <boost/serialization/serialization.hpp> +#include <boost/serialization/utility.hpp> +#include <boost/serialization/vector.hpp>  namespace fs = boost::filesystem;  using namespace std; @@ -39,19 +42,19 @@ class Precomputation {        int max_rule_symbols, int min_gap_size,        int max_frequent_phrase_len, int min_frequency); -  virtual ~Precomputation(); +  // Creates empty precomputation data structure. +  Precomputation(); -  void WriteBinary(const fs::path& filepath) const; +  virtual ~Precomputation();    // Returns a reference to the index.    virtual const Index& GetCollocations() const; +  bool operator==(const Precomputation& other) const; +    static int FIRST_NONTERMINAL;    static int SECOND_NONTERMINAL; - protected: -  Precomputation(); -   private:    // Finds the most frequent contiguous collocations.    vector<vector<int>> FindMostFrequentPatterns( @@ -72,6 +75,28 @@ class Precomputation {    // Adds an occurrence of a ternary collocation.    void AddStartPositions(vector<int>& positions, int pos1, int pos2, int pos3); +  friend class boost::serialization::access; + +  template<class Archive> void save(Archive& ar, unsigned int) const { +    int num_entries = collocations.size(); +    ar << num_entries; +    for (pair<vector<int>, vector<int>> entry: collocations) { +      ar << entry; +    } +  } + +  template<class Archive> void load(Archive& ar, unsigned int) { +    int num_entries; +    ar >> num_entries; +    for (size_t i = 0; i < num_entries; ++i) { +      pair<vector<int>, vector<int>> entry; +      ar >> entry; +      collocations.insert(entry); +    } +  } + +  BOOST_SERIALIZATION_SPLIT_MEMBER(); +    Index collocations;  }; diff --git a/extractor/precomputation_test.cc b/extractor/precomputation_test.cc index 363febb7..e81ece5d 100644 --- a/extractor/precomputation_test.cc +++ b/extractor/precomputation_test.cc @@ -1,14 +1,19 @@  #include <gtest/gtest.h>  #include <memory> +#include <sstream>  #include <vector> +#include <boost/archive/text_iarchive.hpp> +#include <boost/archive/text_oarchive.hpp> +  #include "mocks/mock_data_array.h"  #include "mocks/mock_suffix_array.h"  #include "precomputation.h"  using namespace std;  using namespace ::testing; +namespace ar = boost::archive;  namespace extractor {  namespace { @@ -29,15 +34,17 @@ class PrecomputationTest : public Test {                    GetSuffix(i)).WillRepeatedly(Return(suffixes[i]));      }      EXPECT_CALL(*suffix_array, BuildLCPArray()).WillRepeatedly(Return(lcp)); + +    precomputation = Precomputation(suffix_array, 3, 3, 10, 5, 1, 4, 2);    }    vector<int> data;    shared_ptr<MockDataArray> data_array;    shared_ptr<MockSuffixArray> suffix_array; +  Precomputation precomputation;  };  TEST_F(PrecomputationTest, TestCollocations) { -  Precomputation precomputation(suffix_array, 3, 3, 10, 5, 1, 4, 2);    Index collocations = precomputation.GetCollocations();    vector<int> key = {2, 3, -1, 2}; @@ -101,6 +108,18 @@ TEST_F(PrecomputationTest, TestCollocations) {    EXPECT_EQ(0, collocations.count(key));  } +TEST_F(PrecomputationTest, TestSerialization) { +  stringstream stream(ios_base::out | ios_base::in); +  ar::text_oarchive output_stream(stream, ar::no_header); +  output_stream << precomputation; + +  Precomputation precomputation_copy; +  ar::text_iarchive input_stream(stream, ar::no_header); +  input_stream >> precomputation_copy; + +  EXPECT_EQ(precomputation, precomputation_copy); +} +  } // namespace  } // namespace extractor diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc index 2fc6f724..8a9ca89d 100644 --- a/extractor/run_extractor.cc +++ b/extractor/run_extractor.cc @@ -42,11 +42,12 @@ fs::path GetGrammarFilePath(const fs::path& grammar_path, int file_number) {  }  int main(int argc, char** argv) { -  int num_threads_default = 1; -  #pragma omp parallel -  num_threads_default = omp_get_num_threads(); -    // Sets up the command line arguments map. +  int max_threads = 1; +  #pragma omp parallel +  max_threads = omp_get_num_threads(); +  string threads_option = "Number of parallel threads for extraction " +                          "(max=" + to_string(max_threads) + ")";    po::options_description desc("Command line options");    desc.add_options()      ("help,h", "Show available options") @@ -55,8 +56,7 @@ int main(int argc, char** argv) {      ("bitext,b", po::value<string>(), "Parallel text (source ||| target)")      ("alignment,a", po::value<string>()->required(), "Bitext word alignment")      ("grammars,g", po::value<string>()->required(), "Grammars output path") -    ("threads,t", po::value<int>()->default_value(num_threads_default), -        "Number of parallel extractors") +    ("threads,t", po::value<int>()->default_value(1), threads_option.c_str())      ("frequent", po::value<int>()->default_value(100),          "Number of precomputed frequent patterns")      ("super_frequent", po::value<int>()->default_value(10), @@ -97,7 +97,7 @@ int main(int argc, char** argv) {    }    int num_threads = vm["threads"].as<int>(); -  cout << "Grammar extraction will use " << num_threads << " threads." << endl; +  cerr << "Grammar extraction will use " << num_threads << " threads." << endl;    // Reads the parallel corpus.    Clock::time_point preprocess_start_time = Clock::now(); @@ -118,17 +118,17 @@ int main(int argc, char** argv) {         << " seconds" << endl;    // Constructs the suffix array for the source data. -  cerr << "Creating source suffix array..." << endl;    start_time = Clock::now(); +  cerr << "Constructing source suffix array..." << endl;    shared_ptr<SuffixArray> source_suffix_array =        make_shared<SuffixArray>(source_data_array);    stop_time = Clock::now(); -  cerr << "Creating suffix array took " +  cerr << "Constructing suffix array took "         << GetDuration(start_time, stop_time) << " seconds" << endl;    // Reads the alignment. -  cerr << "Reading alignment..." << endl;    start_time = Clock::now(); +  cerr << "Reading alignment..." << endl;    shared_ptr<Alignment> alignment =        make_shared<Alignment>(vm["alignment"].as<string>());    stop_time = Clock::now(); @@ -137,8 +137,8 @@ int main(int argc, char** argv) {    // Constructs an index storing the occurrences in the source data for each    // frequent collocation. -  cerr << "Precomputing collocations..." << endl;    start_time = Clock::now(); +  cerr << "Precomputing collocations..." << endl;    shared_ptr<Precomputation> precomputation = make_shared<Precomputation>(        source_suffix_array,        vm["frequent"].as<int>(), @@ -154,8 +154,8 @@ int main(int argc, char** argv) {    // Constructs a table storing p(e | f) and p(f | e) for every pair of source    // and target words. -  cerr << "Precomputing conditional probabilities..." << endl;    start_time = Clock::now(); +  cerr << "Precomputing conditional probabilities..." << endl;    shared_ptr<TranslationTable> table = make_shared<TranslationTable>(        source_data_array, target_data_array, alignment);    stop_time = Clock::now(); @@ -229,7 +229,7 @@ int main(int argc, char** argv) {    }    for (size_t i = 0; i < sentences.size(); ++i) { -    cout << "<seg grammar=\"" << GetGrammarFilePath(grammar_path, i) << "\" id=\"" +    cout << "<seg grammar=" << GetGrammarFilePath(grammar_path, i) << " id=\""           << i << "\"> " << sentences[i] << " </seg> " << suffixes[i] << endl;    } diff --git a/extractor/suffix_array.cc b/extractor/suffix_array.cc index 65b2d581..0cf4d1f6 100644 --- a/extractor/suffix_array.cc +++ b/extractor/suffix_array.cc @@ -186,20 +186,6 @@ shared_ptr<DataArray> SuffixArray::GetData() const {    return data_array;  } -void SuffixArray::WriteBinary(const fs::path& filepath) const { -  FILE* file = fopen(filepath.string().c_str(), "w"); -  assert(file); -  data_array->WriteBinary(file); - -  int size = suffix_array.size(); -  fwrite(&size, sizeof(int), 1, file); -  fwrite(suffix_array.data(), sizeof(int), size, file); - -  size = word_start.size(); -  fwrite(&size, sizeof(int), 1, file); -  fwrite(word_start.data(), sizeof(int), size, file); -} -  PhraseLocation SuffixArray::Lookup(int low, int high, const string& word,                                     int offset) const {    if (!data_array->HasWord(word)) { @@ -232,4 +218,10 @@ int SuffixArray::LookupRangeStart(int low, int high, int word_id,    return result;  } +bool SuffixArray::operator==(const SuffixArray& other) const { +  return *data_array == *other.data_array && +         suffix_array == other.suffix_array && +         word_start == other.word_start; +} +  } // namespace extractor diff --git a/extractor/suffix_array.h b/extractor/suffix_array.h index bf731d79..8ee454ec 100644 --- a/extractor/suffix_array.h +++ b/extractor/suffix_array.h @@ -6,6 +6,9 @@  #include <vector>  #include <boost/filesystem.hpp> +#include <boost/serialization/serialization.hpp> +#include <boost/serialization/split_member.hpp> +#include <boost/serialization/vector.hpp>  namespace fs = boost::filesystem;  using namespace std; @@ -20,6 +23,9 @@ class SuffixArray {    // Creates a suffix array from a data array.    SuffixArray(shared_ptr<DataArray> data_array); +  // Creates empty suffix array. +  SuffixArray(); +    virtual ~SuffixArray();    // Returns the size of the suffix array. @@ -40,10 +46,7 @@ class SuffixArray {    virtual PhraseLocation Lookup(int low, int high, const string& word,                                  int offset) const; -  void WriteBinary(const fs::path& filepath) const; - - protected: -  SuffixArray(); +  bool operator==(const SuffixArray& other) const;   private:    // Constructs the suffix array using the algorithm of Larsson and Sadakane @@ -65,6 +68,23 @@ class SuffixArray {    // offset value is greater or equal to word_id.    int LookupRangeStart(int low, int high, int word_id, int offset) const; +  friend class boost::serialization::access; + +  template<class Archive> void save(Archive& ar, unsigned int) const { +    ar << *data_array; +    ar << suffix_array; +    ar << word_start; +  } + +  template<class Archive> void load(Archive& ar, unsigned int) { +    data_array = make_shared<DataArray>(); +    ar >> *data_array; +    ar >> suffix_array; +    ar >> word_start; +  } + +  BOOST_SERIALIZATION_SPLIT_MEMBER(); +    shared_ptr<DataArray> data_array;    vector<int> suffix_array;    vector<int> word_start; diff --git a/extractor/suffix_array_test.cc b/extractor/suffix_array_test.cc index 8431a16e..ba0dbcc3 100644 --- a/extractor/suffix_array_test.cc +++ b/extractor/suffix_array_test.cc @@ -1,13 +1,17 @@  #include <gtest/gtest.h> +#include <vector> + +#include <boost/archive/binary_iarchive.hpp> +#include <boost/archive/binary_oarchive.hpp> +  #include "mocks/mock_data_array.h"  #include "phrase_location.h"  #include "suffix_array.h" -#include <vector> -  using namespace std;  using namespace ::testing; +namespace ar = boost::archive;  namespace extractor {  namespace { @@ -20,30 +24,30 @@ class SuffixArrayTest : public Test {      EXPECT_CALL(*data_array, GetData()).WillRepeatedly(ReturnRef(data));      EXPECT_CALL(*data_array, GetVocabularySize()).WillRepeatedly(Return(7));      EXPECT_CALL(*data_array, GetSize()).WillRepeatedly(Return(13)); -    suffix_array = make_shared<SuffixArray>(data_array); +    suffix_array = SuffixArray(data_array);    }    vector<int> data; -  shared_ptr<SuffixArray> suffix_array; +  SuffixArray suffix_array;    shared_ptr<MockDataArray> data_array;  };  TEST_F(SuffixArrayTest, TestData) { -  EXPECT_EQ(data_array, suffix_array->GetData()); -  EXPECT_EQ(14, suffix_array->GetSize()); +  EXPECT_EQ(data_array, suffix_array.GetData()); +  EXPECT_EQ(14, suffix_array.GetSize());  }  TEST_F(SuffixArrayTest, TestBuildSuffixArray) {    vector<int> expected_suffix_array =        {13, 11, 2, 12, 3, 6, 10, 1, 4, 7, 5, 9, 0, 8};    for (size_t i = 0; i < expected_suffix_array.size(); ++i) { -    EXPECT_EQ(expected_suffix_array[i], suffix_array->GetSuffix(i)); +    EXPECT_EQ(expected_suffix_array[i], suffix_array.GetSuffix(i));    }  }  TEST_F(SuffixArrayTest, TestBuildLCP) {    vector<int> expected_lcp = {-1, 0, 2, 0, 1, 0, 0, 3, 1, 1, 0, 0, 4, 1}; -  EXPECT_EQ(expected_lcp, suffix_array->BuildLCPArray()); +  EXPECT_EQ(expected_lcp, suffix_array.BuildLCPArray());  }  TEST_F(SuffixArrayTest, TestLookup) { @@ -53,25 +57,37 @@ TEST_F(SuffixArrayTest, TestLookup) {    EXPECT_CALL(*data_array, HasWord("word1")).WillRepeatedly(Return(true));    EXPECT_CALL(*data_array, GetWordId("word1")).WillRepeatedly(Return(6)); -  EXPECT_EQ(PhraseLocation(11, 14), suffix_array->Lookup(0, 14, "word1", 0)); +  EXPECT_EQ(PhraseLocation(11, 14), suffix_array.Lookup(0, 14, "word1", 0));    EXPECT_CALL(*data_array, HasWord("word2")).WillRepeatedly(Return(false)); -  EXPECT_EQ(PhraseLocation(0, 0), suffix_array->Lookup(0, 14, "word2", 0)); +  EXPECT_EQ(PhraseLocation(0, 0), suffix_array.Lookup(0, 14, "word2", 0));    EXPECT_CALL(*data_array, HasWord("word3")).WillRepeatedly(Return(true));    EXPECT_CALL(*data_array, GetWordId("word3")).WillRepeatedly(Return(4)); -  EXPECT_EQ(PhraseLocation(11, 13), suffix_array->Lookup(11, 14, "word3", 1)); +  EXPECT_EQ(PhraseLocation(11, 13), suffix_array.Lookup(11, 14, "word3", 1));    EXPECT_CALL(*data_array, HasWord("word4")).WillRepeatedly(Return(true));    EXPECT_CALL(*data_array, GetWordId("word4")).WillRepeatedly(Return(1)); -  EXPECT_EQ(PhraseLocation(11, 13), suffix_array->Lookup(11, 13, "word4", 2)); +  EXPECT_EQ(PhraseLocation(11, 13), suffix_array.Lookup(11, 13, "word4", 2));    EXPECT_CALL(*data_array, HasWord("word5")).WillRepeatedly(Return(true));    EXPECT_CALL(*data_array, GetWordId("word5")).WillRepeatedly(Return(2)); -  EXPECT_EQ(PhraseLocation(11, 13), suffix_array->Lookup(11, 13, "word5", 3)); +  EXPECT_EQ(PhraseLocation(11, 13), suffix_array.Lookup(11, 13, "word5", 3)); + +  EXPECT_EQ(PhraseLocation(12, 13), suffix_array.Lookup(11, 13, "word3", 4)); +  EXPECT_EQ(PhraseLocation(11, 11), suffix_array.Lookup(11, 13, "word5", 1)); +} + +TEST_F(SuffixArrayTest, TestSerialization) { +  stringstream stream(ios_base::binary | ios_base::out | ios_base::in); +  ar::binary_oarchive output_stream(stream, ar::no_header); +  output_stream << suffix_array; + +  SuffixArray suffix_array_copy; +  ar::binary_iarchive input_stream(stream, ar::no_header); +  input_stream >> suffix_array_copy; -  EXPECT_EQ(PhraseLocation(12, 13), suffix_array->Lookup(11, 13, "word3", 4)); -  EXPECT_EQ(PhraseLocation(11, 11), suffix_array->Lookup(11, 13, "word5", 1)); +  EXPECT_EQ(suffix_array, suffix_array_copy);  }  } // namespace diff --git a/extractor/translation_table.cc b/extractor/translation_table.cc index adb59cb5..1b1ba112 100644 --- a/extractor/translation_table.cc +++ b/extractor/translation_table.cc @@ -97,7 +97,12 @@ double TranslationTable::GetTargetGivenSourceScore(    int source_id = source_data_array->GetWordId(source_word);    int target_id = target_data_array->GetWordId(target_word); -  return translation_probabilities[make_pair(source_id, target_id)].first; +  auto entry = make_pair(source_id, target_id); +  auto it = translation_probabilities.find(entry); +  if (it == translation_probabilities.end()) { +    return 0; +  } +  return it->second.first;  }  double TranslationTable::GetSourceGivenTargetScore( @@ -109,18 +114,18 @@ double TranslationTable::GetSourceGivenTargetScore(    int source_id = source_data_array->GetWordId(source_word);    int target_id = target_data_array->GetWordId(target_word); -  return translation_probabilities[make_pair(source_id, target_id)].second; +  auto entry = make_pair(source_id, target_id); +  auto it = translation_probabilities.find(entry); +  if (it == translation_probabilities.end()) { +    return 0; +  } +  return it->second.second;  } -void TranslationTable::WriteBinary(const fs::path& filepath) const { -  FILE* file = fopen(filepath.string().c_str(), "w"); - -  int size = translation_probabilities.size(); -  fwrite(&size, sizeof(int), 1, file); -  for (auto entry: translation_probabilities) { -    fwrite(&entry.first, sizeof(entry.first), 1, file); -    fwrite(&entry.second, sizeof(entry.second), 1, file); -  } +bool TranslationTable::operator==(const TranslationTable& other) const { +  return *source_data_array == *other.source_data_array && +         *target_data_array == *other.target_data_array && +         translation_probabilities == other.translation_probabilities;  }  } // namespace extractor diff --git a/extractor/translation_table.h b/extractor/translation_table.h index ed43ad72..2a37bab7 100644 --- a/extractor/translation_table.h +++ b/extractor/translation_table.h @@ -7,6 +7,9 @@  #include <boost/filesystem.hpp>  #include <boost/functional/hash.hpp> +#include <boost/serialization/serialization.hpp> +#include <boost/serialization/split_member.hpp> +#include <boost/serialization/utility.hpp>  using namespace std;  namespace fs = boost::filesystem; @@ -23,11 +26,16 @@ class DataArray;   */  class TranslationTable {   public: +  // Constructs a translation table from source data, target data and the +  // corresponding alignment.    TranslationTable(        shared_ptr<DataArray> source_data_array,        shared_ptr<DataArray> target_data_array,        shared_ptr<Alignment> alignment); +  // Creates empty translation table. +  TranslationTable(); +    virtual ~TranslationTable();    // Returns p(e | f). @@ -38,10 +46,7 @@ class TranslationTable {    virtual double GetSourceGivenTargetScore(const string& source_word,                                             const string& target_word); -  void WriteBinary(const fs::path& filepath) const; - - protected: -  TranslationTable(); +  bool operator==(const TranslationTable& other) const;   private:    // Increment links count for the given (f, e) word pair. @@ -52,6 +57,35 @@ class TranslationTable {        int source_word_id,        int target_word_id) const; +  friend class boost::serialization::access; + +  template<class Archive> void save(Archive& ar, unsigned int) const { +    ar << *source_data_array << *target_data_array; + +    int num_entries = translation_probabilities.size(); +    ar << num_entries; +    for (auto entry: translation_probabilities) { +      ar << entry; +    } +  } + +  template<class Archive> void load(Archive& ar, unsigned int) { +    source_data_array = make_shared<DataArray>(); +    ar >> *source_data_array; +    target_data_array = make_shared<DataArray>(); +    ar >> *target_data_array; + +    int num_entries; +    ar >> num_entries; +    for (size_t i = 0; i < num_entries; ++i) { +      pair<pair<int, int>, pair<double, double>> entry; +      ar >> entry; +      translation_probabilities.insert(entry); +    } +  } + +  BOOST_SERIALIZATION_SPLIT_MEMBER(); +    shared_ptr<DataArray> source_data_array;    shared_ptr<DataArray> target_data_array;    unordered_map<pair<int, int>, pair<double, double>, PairHash> diff --git a/extractor/translation_table_test.cc b/extractor/translation_table_test.cc index d14f2f89..606777bd 100644 --- a/extractor/translation_table_test.cc +++ b/extractor/translation_table_test.cc @@ -1,83 +1,106 @@  #include <gtest/gtest.h>  #include <memory> +#include <sstream>  #include <string>  #include <vector> +#include <boost/archive/binary_iarchive.hpp> +#include <boost/archive/binary_oarchive.hpp> +  #include "mocks/mock_alignment.h"  #include "mocks/mock_data_array.h"  #include "translation_table.h"  using namespace std;  using namespace ::testing; +namespace ar = boost::archive;  namespace extractor {  namespace { -TEST(TranslationTableTest, TestScores) { -  vector<string> words = {"a", "b", "c"}; - -  vector<int> source_data = {2, 3, 2, 3, 4, 0, 2, 3, 6, 0, 2, 3, 6, 0}; -  vector<int> source_sentence_start = {0, 6, 10, 14}; -  shared_ptr<MockDataArray> source_data_array = make_shared<MockDataArray>(); -  EXPECT_CALL(*source_data_array, GetData()) -      .WillRepeatedly(ReturnRef(source_data)); -  EXPECT_CALL(*source_data_array, GetNumSentences()) -      .WillRepeatedly(Return(3)); -  for (size_t i = 0; i < source_sentence_start.size(); ++i) { -    EXPECT_CALL(*source_data_array, GetSentenceStart(i)) -        .WillRepeatedly(Return(source_sentence_start[i])); -  } -  for (size_t i = 0; i < words.size(); ++i) { -    EXPECT_CALL(*source_data_array, HasWord(words[i])) -        .WillRepeatedly(Return(true)); -    EXPECT_CALL(*source_data_array, GetWordId(words[i])) -        .WillRepeatedly(Return(i + 2)); -  } -  EXPECT_CALL(*source_data_array, HasWord("d")) -      .WillRepeatedly(Return(false)); - -  vector<int> target_data = {2, 3, 2, 3, 4, 5, 0, 3, 6, 0, 2, 7, 0}; -  vector<int> target_sentence_start = {0, 7, 10, 13}; -  shared_ptr<MockDataArray> target_data_array = make_shared<MockDataArray>(); -  EXPECT_CALL(*target_data_array, GetData()) -      .WillRepeatedly(ReturnRef(target_data)); -  for (size_t i = 0; i < target_sentence_start.size(); ++i) { -    EXPECT_CALL(*target_data_array, GetSentenceStart(i)) -        .WillRepeatedly(Return(target_sentence_start[i])); -  } -  for (size_t i = 0; i < words.size(); ++i) { -    EXPECT_CALL(*target_data_array, HasWord(words[i])) -        .WillRepeatedly(Return(true)); -    EXPECT_CALL(*target_data_array, GetWordId(words[i])) -        .WillRepeatedly(Return(i + 2)); +class TranslationTableTest : public Test { + protected: +  virtual void SetUp() { +    vector<string> words = {"a", "b", "c"}; + +    vector<int> source_data = {2, 3, 2, 3, 4, 0, 2, 3, 6, 0, 2, 3, 6, 0}; +    vector<int> source_sentence_start = {0, 6, 10, 14}; +    shared_ptr<MockDataArray> source_data_array = make_shared<MockDataArray>(); +    EXPECT_CALL(*source_data_array, GetData()) +        .WillRepeatedly(ReturnRef(source_data)); +    EXPECT_CALL(*source_data_array, GetNumSentences()) +        .WillRepeatedly(Return(3)); +    for (size_t i = 0; i < source_sentence_start.size(); ++i) { +      EXPECT_CALL(*source_data_array, GetSentenceStart(i)) +          .WillRepeatedly(Return(source_sentence_start[i])); +    } +    for (size_t i = 0; i < words.size(); ++i) { +      EXPECT_CALL(*source_data_array, HasWord(words[i])) +          .WillRepeatedly(Return(true)); +      EXPECT_CALL(*source_data_array, GetWordId(words[i])) +          .WillRepeatedly(Return(i + 2)); +    } +    EXPECT_CALL(*source_data_array, HasWord("d")) +        .WillRepeatedly(Return(false)); + +    vector<int> target_data = {2, 3, 2, 3, 4, 5, 0, 3, 6, 0, 2, 7, 0}; +    vector<int> target_sentence_start = {0, 7, 10, 13}; +    shared_ptr<MockDataArray> target_data_array = make_shared<MockDataArray>(); +    EXPECT_CALL(*target_data_array, GetData()) +        .WillRepeatedly(ReturnRef(target_data)); +    for (size_t i = 0; i < target_sentence_start.size(); ++i) { +      EXPECT_CALL(*target_data_array, GetSentenceStart(i)) +          .WillRepeatedly(Return(target_sentence_start[i])); +    } +    for (size_t i = 0; i < words.size(); ++i) { +      EXPECT_CALL(*target_data_array, HasWord(words[i])) +          .WillRepeatedly(Return(true)); +      EXPECT_CALL(*target_data_array, GetWordId(words[i])) +          .WillRepeatedly(Return(i + 2)); +    } +    EXPECT_CALL(*target_data_array, HasWord("d")) +        .WillRepeatedly(Return(false)); + +    vector<pair<int, int>> links1 = { +      make_pair(0, 0), make_pair(1, 1), make_pair(2, 2), make_pair(3, 3), +      make_pair(4, 4), make_pair(4, 5) +    }; +    vector<pair<int, int>> links2 = {make_pair(1, 0), make_pair(2, 1)}; +    vector<pair<int, int>> links3 = {make_pair(0, 0), make_pair(2, 1)}; +    shared_ptr<MockAlignment> alignment = make_shared<MockAlignment>(); +    EXPECT_CALL(*alignment, GetLinks(0)).WillRepeatedly(Return(links1)); +    EXPECT_CALL(*alignment, GetLinks(1)).WillRepeatedly(Return(links2)); +    EXPECT_CALL(*alignment, GetLinks(2)).WillRepeatedly(Return(links3)); + +    table = TranslationTable(source_data_array, target_data_array, alignment);    } -  EXPECT_CALL(*target_data_array, HasWord("d")) -      .WillRepeatedly(Return(false)); - -  vector<pair<int, int>> links1 = { -    make_pair(0, 0), make_pair(1, 1), make_pair(2, 2), make_pair(3, 3), -    make_pair(4, 4), make_pair(4, 5) -  }; -  vector<pair<int, int>> links2 = {make_pair(1, 0), make_pair(2, 1)}; -  vector<pair<int, int>> links3 = {make_pair(0, 0), make_pair(2, 1)}; -  shared_ptr<MockAlignment> alignment = make_shared<MockAlignment>(); -  EXPECT_CALL(*alignment, GetLinks(0)).WillRepeatedly(Return(links1)); -  EXPECT_CALL(*alignment, GetLinks(1)).WillRepeatedly(Return(links2)); -  EXPECT_CALL(*alignment, GetLinks(2)).WillRepeatedly(Return(links3)); - -  shared_ptr<TranslationTable> table = make_shared<TranslationTable>( -      source_data_array, target_data_array, alignment); - -  EXPECT_EQ(0.75, table->GetTargetGivenSourceScore("a", "a")); -  EXPECT_EQ(0, table->GetTargetGivenSourceScore("a", "b")); -  EXPECT_EQ(0.5, table->GetTargetGivenSourceScore("c", "c")); -  EXPECT_EQ(-1, table->GetTargetGivenSourceScore("c", "d")); - -  EXPECT_EQ(1, table->GetSourceGivenTargetScore("a", "a")); -  EXPECT_EQ(0, table->GetSourceGivenTargetScore("a", "b")); -  EXPECT_EQ(1, table->GetSourceGivenTargetScore("c", "c")); -  EXPECT_EQ(-1, table->GetSourceGivenTargetScore("c", "d")); + +  TranslationTable table; +}; + +TEST_F(TranslationTableTest, TestScores) { +  EXPECT_EQ(0.75, table.GetTargetGivenSourceScore("a", "a")); +  EXPECT_EQ(0, table.GetTargetGivenSourceScore("a", "b")); +  EXPECT_EQ(0.5, table.GetTargetGivenSourceScore("c", "c")); +  EXPECT_EQ(-1, table.GetTargetGivenSourceScore("c", "d")); + +  EXPECT_EQ(1, table.GetSourceGivenTargetScore("a", "a")); +  EXPECT_EQ(0, table.GetSourceGivenTargetScore("a", "b")); +  EXPECT_EQ(1, table.GetSourceGivenTargetScore("c", "c")); +  EXPECT_EQ(-1, table.GetSourceGivenTargetScore("c", "d")); +} + +TEST_F(TranslationTableTest, TestSerialization) { +  stringstream stream(ios_base::binary | ios_base::out | ios_base::in); +  ar::binary_oarchive output_stream(stream, ar::no_header); +  output_stream << table; + +  TranslationTable table_copy; +  ar::binary_iarchive input_stream(stream, ar::no_header); +  input_stream >> table_copy; + +  EXPECT_EQ(table, table_copy);  }  } // namespace | 
