diff options
| author | Patrick Simianer <p@simianer.de> | 2013-05-02 09:09:59 +0200 | 
|---|---|---|
| committer | Patrick Simianer <p@simianer.de> | 2013-05-02 09:09:59 +0200 | 
| commit | 9e50f0237413180fba11b500c9dce5c600e3c157 (patch) | |
| tree | 556fc31d231353c853a864afffddd43dc525549a /extractor | |
| parent | d18024a41cbc1b54db88d499571349a6234b6db8 (diff) | |
| parent | 14ed53426726202813a8e82d706b44266f015fe1 (diff) | |
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'extractor')
103 files changed, 7398 insertions, 0 deletions
| diff --git a/extractor/Makefile.am b/extractor/Makefile.am new file mode 100644 index 00000000..e94a9b91 --- /dev/null +++ b/extractor/Makefile.am @@ -0,0 +1,152 @@ +if HAVE_CXX11 + +bin_PROGRAMS = compile run_extractor + +EXTRA_PROGRAMS = alignment_test \ +    data_array_test \ +    fast_intersector_test \ +    feature_count_source_target_test \ +    feature_is_source_singleton_test \ +    feature_is_source_target_singleton_test \ +    feature_max_lex_source_given_target_test \ +    feature_max_lex_target_given_source_test \ +    feature_sample_source_count_test \ +    feature_target_given_source_coherent_test \ +    grammar_extractor_test \ +    matchings_finder_test \ +    phrase_test \ +    precomputation_test \ +    rule_extractor_helper_test \ +    rule_extractor_test \ +    rule_factory_test \ +    sampler_test \ +    scorer_test \ +    suffix_array_test \ +    target_phrase_extractor_test \ +    translation_table_test + +if HAVE_GTEST +  RUNNABLE_TESTS = alignment_test \ +    data_array_test \ +    fast_intersector_test \ +    feature_count_source_target_test \ +    feature_is_source_singleton_test \ +    feature_is_source_target_singleton_test \ +    feature_max_lex_source_given_target_test \ +    feature_max_lex_target_given_source_test \ +    feature_sample_source_count_test \ +    feature_target_given_source_coherent_test \ +    grammar_extractor_test \ +    matchings_finder_test \ +    phrase_test \ +    precomputation_test \ +    rule_extractor_helper_test \ +    rule_extractor_test \ +    rule_factory_test \ +    sampler_test \ +    scorer_test \ +    suffix_array_test \ +    target_phrase_extractor_test \ +    translation_table_test +endif + +noinst_PROGRAMS = $(RUNNABLE_TESTS) + +TESTS = $(RUNNABLE_TESTS) + +alignment_test_SOURCES = alignment_test.cc +alignment_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a +data_array_test_SOURCES = data_array_test.cc +data_array_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a +fast_intersector_test_SOURCES = fast_intersector_test.cc +fast_intersector_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +feature_count_source_target_test_SOURCES = features/count_source_target_test.cc +feature_count_source_target_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a +feature_is_source_singleton_test_SOURCES = features/is_source_singleton_test.cc +feature_is_source_singleton_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a +feature_is_source_target_singleton_test_SOURCES = features/is_source_target_singleton_test.cc +feature_is_source_target_singleton_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a +feature_max_lex_source_given_target_test_SOURCES = features/max_lex_source_given_target_test.cc +feature_max_lex_source_given_target_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +feature_max_lex_target_given_source_test_SOURCES = features/max_lex_target_given_source_test.cc +feature_max_lex_target_given_source_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +feature_sample_source_count_test_SOURCES = features/sample_source_count_test.cc +feature_sample_source_count_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a +feature_target_given_source_coherent_test_SOURCES = features/target_given_source_coherent_test.cc +feature_target_given_source_coherent_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a +grammar_extractor_test_SOURCES = grammar_extractor_test.cc +grammar_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +matchings_finder_test_SOURCES = matchings_finder_test.cc +matchings_finder_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +phrase_test_SOURCES = phrase_test.cc +phrase_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +precomputation_test_SOURCES = precomputation_test.cc +precomputation_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +rule_extractor_helper_test_SOURCES = rule_extractor_helper_test.cc +rule_extractor_helper_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +rule_extractor_test_SOURCES = rule_extractor_test.cc +rule_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +rule_factory_test_SOURCES = rule_factory_test.cc +rule_factory_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +sampler_test_SOURCES = sampler_test.cc +sampler_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +scorer_test_SOURCES = scorer_test.cc +scorer_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +suffix_array_test_SOURCES = suffix_array_test.cc +suffix_array_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +target_phrase_extractor_test_SOURCES = target_phrase_extractor_test.cc +target_phrase_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +translation_table_test_SOURCES = translation_table_test.cc +translation_table_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a + +noinst_LIBRARIES = libextractor.a libcompile.a + +compile_SOURCES = compile.cc +compile_LDADD = libcompile.a +run_extractor_SOURCES = run_extractor.cc +run_extractor_LDADD = libextractor.a + +libcompile_a_SOURCES = \ +  alignment.cc \ +  data_array.cc \ +  phrase_location.cc \ +  precomputation.cc \ +  suffix_array.cc \ +  time_util.cc \ +  translation_table.cc + +libextractor_a_SOURCES = \ +  alignment.cc \ +  data_array.cc \ +  fast_intersector.cc \ +  features/count_source_target.cc \ +  features/feature.cc \ +  features/is_source_singleton.cc \ +  features/is_source_target_singleton.cc \ +  features/max_lex_source_given_target.cc \ +  features/max_lex_target_given_source.cc \ +  features/sample_source_count.cc \ +  features/target_given_source_coherent.cc \ +  grammar.cc \ +  grammar_extractor.cc \ +  matchings_finder.cc \ +  matchings_trie.cc \ +  phrase.cc \ +  phrase_builder.cc \ +  phrase_location.cc \ +  precomputation.cc \ +  rule.cc \ +  rule_extractor.cc \ +  rule_extractor_helper.cc \ +  rule_factory.cc \ +  sampler.cc \ +  scorer.cc \ +  suffix_array.cc \ +  target_phrase_extractor.cc \ +  time_util.cc \ +  translation_table.cc \ +  vocabulary.cc + +AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(CXX11_SWITCH) -fopenmp $(GTEST_CPPFLAGS) $(GMOCK_CPPFLAGS) +AM_LDFLAGS = -fopenmp +endif diff --git a/extractor/README.md b/extractor/README.md new file mode 100644 index 00000000..575f5ca5 --- /dev/null +++ b/extractor/README.md @@ -0,0 +1,14 @@ +C++ implementation of the online grammar extractor originally developed by [Adam Lopez](http://www.cs.jhu.edu/~alopez/). + +To run the extractor you need to: + +    cdec/extractor/run_extractor -a <alignment> -b <parallel_corpus> -g <grammar_output_path> < <input_sentences> > <sgm_file> + +To run unit tests you need first to configure `cdec` with the [Google Test](https://code.google.com/p/googletest/) and [Google Mock](https://code.google.com/p/googlemock/) libraries: + +    ./configure --with-gtest=</absolute/path/to/gtest> --with-gmock=</absolute/path/to/gmock> + +Then, you simply need to: + +    cd cdec/extractor +    make check diff --git a/extractor/alignment.cc b/extractor/alignment.cc new file mode 100644 index 00000000..1aea34b3 --- /dev/null +++ b/extractor/alignment.cc @@ -0,0 +1,53 @@ +#include "alignment.h" + +#include <fstream> +#include <sstream> +#include <string> +#include <fcntl.h> +#include <unistd.h> +#include <vector> + +#include <boost/algorithm/string.hpp> +#include <boost/filesystem.hpp> + +namespace fs = boost::filesystem; +using namespace std; + +namespace extractor { + +Alignment::Alignment(const string& filename) { +  ifstream infile(filename.c_str()); +  string line; +  while (getline(infile, line)) { +    vector<string> items; +    boost::split(items, line, boost::is_any_of(" -")); +    vector<pair<int, int> > alignment; +    alignment.reserve(items.size() / 2); +    for (size_t i = 0; i < items.size(); i += 2) { +      alignment.push_back(make_pair(stoi(items[i]), stoi(items[i + 1]))); +    } +    alignments.push_back(alignment); +  } +  alignments.shrink_to_fit(); +} + +Alignment::Alignment() {} + +Alignment::~Alignment() {} + +vector<pair<int, int> > Alignment::GetLinks(int sentence_index) const { +  return alignments[sentence_index]; +} + +void Alignment::WriteBinary(const fs::path& filepath) { +  FILE* file = fopen(filepath.string().c_str(), "w"); +  int size = alignments.size(); +  fwrite(&size, sizeof(int), 1, file); +  for (vector<pair<int, int> > alignment: alignments) { +    size = alignment.size(); +    fwrite(&size, sizeof(int), 1, file); +    fwrite(alignment.data(), sizeof(pair<int, int>), size, file); +  } +} + +} // namespace extractor diff --git a/extractor/alignment.h b/extractor/alignment.h new file mode 100644 index 00000000..e9292121 --- /dev/null +++ b/extractor/alignment.h @@ -0,0 +1,39 @@ +#ifndef _ALIGNMENT_H_ +#define _ALIGNMENT_H_ + +#include <string> +#include <vector> + +#include <boost/filesystem.hpp> + +namespace fs = boost::filesystem; +using namespace std; + +namespace extractor { + +/** + * Data structure storing the word alignments for a parallel corpus. + */ +class Alignment { + public: +  // Reads alignment from text file. +  Alignment(const string& filename); + +  // Returns the alignment for a given sentence. +  virtual vector<pair<int, int> > GetLinks(int sentence_index) const; + +  // Writes alignment to file in binary format. +  void WriteBinary(const fs::path& filepath); + +  virtual ~Alignment(); + + protected: +  Alignment(); + + private: +  vector<vector<pair<int, int> > > alignments; +}; + +} // namespace extractor + +#endif diff --git a/extractor/alignment_test.cc b/extractor/alignment_test.cc new file mode 100644 index 00000000..a7defb66 --- /dev/null +++ b/extractor/alignment_test.cc @@ -0,0 +1,33 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <string> + +#include "alignment.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace { + +class AlignmentTest : public Test { + protected: +  virtual void SetUp() { +    alignment = make_shared<Alignment>("sample_alignment.txt"); +  } + +  shared_ptr<Alignment> alignment; +}; + +TEST_F(AlignmentTest, TestGetLinks) { +  vector<pair<int, int> > expected_links = { +    make_pair(0, 0), make_pair(1, 1), make_pair(2, 2) +  }; +  EXPECT_EQ(expected_links, alignment->GetLinks(0)); +  expected_links = {make_pair(1, 0), make_pair(2, 1)}; +  EXPECT_EQ(expected_links, alignment->GetLinks(1)); +} + +} // namespace +} // namespace extractor diff --git a/extractor/compile.cc b/extractor/compile.cc new file mode 100644 index 00000000..a9ae2cef --- /dev/null +++ b/extractor/compile.cc @@ -0,0 +1,100 @@ +#include <iostream> +#include <string> + +#include <boost/filesystem.hpp> +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> + +#include "alignment.h" +#include "data_array.h" +#include "precomputation.h" +#include "suffix_array.h" +#include "translation_table.h" + +namespace fs = boost::filesystem; +namespace po = boost::program_options; +using namespace std; +using namespace extractor; + +int main(int argc, char** argv) { +  po::options_description desc("Command line options"); +  desc.add_options() +    ("help,h", "Show available options") +    ("source,f", po::value<string>(), "Source language corpus") +    ("target,e", po::value<string>(), "Target language corpus") +    ("bitext,b", po::value<string>(), "Parallel text (source ||| target)") +    ("alignment,a", po::value<string>()->required(), "Bitext word alignment") +    ("output,o", po::value<string>()->required(), "Output path") +    ("frequent", po::value<int>()->default_value(100), +        "Number of precomputed frequent patterns") +    ("super_frequent", po::value<int>()->default_value(10), +        "Number of precomputed super frequent patterns") +    ("max_rule_span,s", po::value<int>()->default_value(15), +        "Maximum rule span") +    ("max_rule_symbols,l", po::value<int>()->default_value(5), +        "Maximum number of symbols (terminals + nontermals) in a rule") +    ("min_gap_size,g", po::value<int>()->default_value(1), "Minimum gap size") +    ("max_phrase_len,p", po::value<int>()->default_value(4), +        "Maximum frequent phrase length") +    ("min_frequency", po::value<int>()->default_value(1000), +        "Minimum number of occurrences for a pharse to be considered frequent"); + +  po::variables_map vm; +  po::store(po::parse_command_line(argc, argv, desc), vm); + +  // Check for help argument before notify, so we don't need to pass in the +  // required parameters. +  if (vm.count("help")) { +    cout << desc << endl; +    return 0; +  } + +  po::notify(vm); + +  if (!((vm.count("source") && vm.count("target")) || vm.count("bitext"))) { +    cerr << "A paralel corpus is required. " +         << "Use -f (source) with -e (target) or -b (bitext)." +         << endl; +    return 1; +  } + +  fs::path output_dir(vm["output"].as<string>().c_str()); +  if (!fs::exists(output_dir)) { +    fs::create_directory(output_dir); +  } + +  shared_ptr<DataArray> source_data_array, target_data_array; +  if (vm.count("bitext")) { +    source_data_array = make_shared<DataArray>( +        vm["bitext"].as<string>(), SOURCE); +    target_data_array = make_shared<DataArray>( +        vm["bitext"].as<string>(), TARGET); +  } else { +    source_data_array = make_shared<DataArray>(vm["source"].as<string>()); +    target_data_array = make_shared<DataArray>(vm["target"].as<string>()); +  } +  shared_ptr<SuffixArray> source_suffix_array = +      make_shared<SuffixArray>(source_data_array); +  source_suffix_array->WriteBinary(output_dir / fs::path("f.bin")); +  target_data_array->WriteBinary(output_dir / fs::path("e.bin")); + +  shared_ptr<Alignment> alignment = +      make_shared<Alignment>(vm["alignment"].as<string>()); +  alignment->WriteBinary(output_dir / fs::path("a.bin")); + +  Precomputation precomputation( +      source_suffix_array, +      vm["frequent"].as<int>(), +      vm["super_frequent"].as<int>(), +      vm["max_rule_span"].as<int>(), +      vm["max_rule_symbols"].as<int>(), +      vm["min_gap_size"].as<int>(), +      vm["max_phrase_len"].as<int>(), +      vm["min_frequency"].as<int>()); +  precomputation.WriteBinary(output_dir / fs::path("precompute.bin")); + +  TranslationTable table(source_data_array, target_data_array, alignment); +  table.WriteBinary(output_dir / fs::path("lex.bin")); + +  return 0; +} diff --git a/extractor/data_array.cc b/extractor/data_array.cc new file mode 100644 index 00000000..203fe219 --- /dev/null +++ b/extractor/data_array.cc @@ -0,0 +1,161 @@ +#include "data_array.h" + +#include <fstream> +#include <iostream> +#include <sstream> +#include <string> + +#include <boost/filesystem.hpp> + +namespace fs = boost::filesystem; +using namespace std; + +namespace extractor { + +int DataArray::NULL_WORD = 0; +int DataArray::END_OF_LINE = 1; +string DataArray::NULL_WORD_STR = "__NULL__"; +string DataArray::END_OF_LINE_STR = "__END_OF_LINE__"; + +DataArray::DataArray() { +  InitializeDataArray(); +} + +DataArray::DataArray(const string& filename) { +  InitializeDataArray(); +  ifstream infile(filename.c_str()); +  vector<string> lines; +  string line; +  while (getline(infile, line)) { +    lines.push_back(line); +  } +  CreateDataArray(lines); +} + +DataArray::DataArray(const string& filename, const Side& side) { +  InitializeDataArray(); +  ifstream infile(filename.c_str()); +  vector<string> lines; +  string line, delimiter = "|||"; +  while (getline(infile, line)) { +    int position = line.find(delimiter); +    if (side == SOURCE) { +      lines.push_back(line.substr(0, position)); +    } else { +      lines.push_back(line.substr(position + delimiter.size())); +    } +  } +  CreateDataArray(lines); +} + +void DataArray::InitializeDataArray() { +  word2id[NULL_WORD_STR] = NULL_WORD; +  id2word.push_back(NULL_WORD_STR); +  word2id[END_OF_LINE_STR] = END_OF_LINE; +  id2word.push_back(END_OF_LINE_STR); +} + +void DataArray::CreateDataArray(const vector<string>& lines) { +  for (size_t i = 0; i < lines.size(); ++i) { +    sentence_start.push_back(data.size()); + +    istringstream iss(lines[i]); +    string word; +    while (iss >> word) { +      if (word2id.count(word) == 0) { +        word2id[word] = id2word.size(); +        id2word.push_back(word); +      } +      data.push_back(word2id[word]); +      sentence_id.push_back(i); +    } +    data.push_back(END_OF_LINE); +    sentence_id.push_back(i); +  } +  sentence_start.push_back(data.size()); + +  data.shrink_to_fit(); +  sentence_id.shrink_to_fit(); +  sentence_start.shrink_to_fit(); +} + +DataArray::~DataArray() {} + +const vector<int>& DataArray::GetData() const { +  return data; +} + +int DataArray::AtIndex(int index) const { +  return data[index]; +} + +string DataArray::GetWordAtIndex(int index) const { +  return id2word[data[index]]; +} + +int DataArray::GetSize() const { +  return data.size(); +} + +int DataArray::GetVocabularySize() const { +  return id2word.size(); +} + +int DataArray::GetNumSentences() const { +  return sentence_start.size() - 1; +} + +int DataArray::GetSentenceStart(int position) const { +  return sentence_start[position]; +} + +int DataArray::GetSentenceLength(int sentence_id) const { +  // Ignore end of line markers. +  return sentence_start[sentence_id + 1] - sentence_start[sentence_id] - 1; +} + +int DataArray::GetSentenceId(int position) const { +  return sentence_id[position]; +} + +void DataArray::WriteBinary(const fs::path& filepath) const { +  std::cerr << "File: " << filepath.string() << std::endl; +  WriteBinary(fopen(filepath.string().c_str(), "w")); +} + +void DataArray::WriteBinary(FILE* file) const { +  int size = id2word.size(); +  fwrite(&size, sizeof(int), 1, file); +  for (string word: id2word) { +    size = word.size(); +    fwrite(&size, sizeof(int), 1, file); +    fwrite(word.data(), sizeof(char), size, file); +  } + +  size = data.size(); +  fwrite(&size, sizeof(int), 1, file); +  fwrite(data.data(), sizeof(int), size, file); + +  size = sentence_id.size(); +  fwrite(&size, sizeof(int), 1, file); +  fwrite(sentence_id.data(), sizeof(int), size, file); + +  size = sentence_start.size(); +  fwrite(&size, sizeof(int), 1, file); +  fwrite(sentence_start.data(), sizeof(int), 1, file); +} + +bool DataArray::HasWord(const string& word) const { +  return word2id.count(word); +} + +int DataArray::GetWordId(const string& word) const { +  auto result = word2id.find(word); +  return result == word2id.end() ? -1 : result->second; +} + +string DataArray::GetWord(int word_id) const { +  return id2word[word_id]; +} + +} // namespace extractor diff --git a/extractor/data_array.h b/extractor/data_array.h new file mode 100644 index 00000000..978a6931 --- /dev/null +++ b/extractor/data_array.h @@ -0,0 +1,110 @@ +#ifndef _DATA_ARRAY_H_ +#define _DATA_ARRAY_H_ + +#include <string> +#include <unordered_map> +#include <vector> + +#include <boost/filesystem.hpp> + +namespace fs = boost::filesystem; +using namespace std; + +namespace extractor { + +enum Side { +  SOURCE, +  TARGET +}; + +/** + * Data structure storing information about a single side of a parallel corpus. + * + * Each word is mapped to a unique integer (word_id). The data structure holds + * the corpus in the numberized format, together with the hash table mapping + * words to word_ids. It also holds additional information such as the starting + * index for each sentence and, for each token, the index of the sentence it + * belongs to. + * + * Note: This class has features for both the source and target data arrays. + * Maybe we can save some memory by having more specific implementations (not + * likely to save a lot of memory tough). + */ +class DataArray { + public: +  static int NULL_WORD; +  static int END_OF_LINE; +  static string NULL_WORD_STR; +  static string END_OF_LINE_STR; + +  // Reads data array from text file. +  DataArray(const string& filename); + +  // Reads data array from bitext file where the sentences are separated by |||. +  DataArray(const string& filename, const Side& side); + +  virtual ~DataArray(); + +  // Returns a vector containing the word ids. +  virtual const vector<int>& GetData() const; + +  // Returns the word id at the specified position. +  virtual int AtIndex(int index) const; + +  // Returns the original word at the specified position. +  virtual string GetWordAtIndex(int index) const; + +  // Returns the size of the data array. +  virtual int GetSize() const; + +  // Returns the number of distinct words in the data array. +  virtual int GetVocabularySize() const; + +  // Returns whether a word has ever been observed in the data array. +  virtual bool HasWord(const string& word) const; + +  // Returns the word id for a given word or -1 if it the word has never been +  // observed. +  virtual int GetWordId(const string& word) const; + +  // Returns the word corresponding to a particular word id. +  virtual string GetWord(int word_id) const; + +  // Returns the number of sentences in the data. +  virtual int GetNumSentences() const; + +  // Returns the index where the sentence containing the given position starts. +  virtual int GetSentenceStart(int position) const; + +  // Returns the length of the sentence. +  virtual int GetSentenceLength(int sentence_id) const; + +  // Returns the number of the sentence containing the given position. +  virtual int GetSentenceId(int position) const; + +  // Writes data array to file in binary format. +  void WriteBinary(const fs::path& filepath) const; + +  // Writes data array to file in binary format. +  void WriteBinary(FILE* file) const; + + protected: +  DataArray(); + + private: +  // Sets up specific constants. +  void InitializeDataArray(); + +  // Constructs the data array. +  void CreateDataArray(const vector<string>& lines); + +  unordered_map<string, int> word2id; +  vector<string> id2word; +  vector<int> data; +  vector<int> sentence_id; +  vector<int> sentence_start; +}; + +} // namespace extractor + +#endif diff --git a/extractor/data_array_test.cc b/extractor/data_array_test.cc new file mode 100644 index 00000000..71175fda --- /dev/null +++ b/extractor/data_array_test.cc @@ -0,0 +1,98 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <string> + +#include <boost/filesystem.hpp> + +#include "data_array.h" + +using namespace std; +using namespace ::testing; +namespace fs = boost::filesystem; + +namespace extractor { +namespace { + +class DataArrayTest : public Test { + protected: +  virtual void SetUp() { +    string sample_test_file("sample_bitext.txt"); +    source_data = make_shared<DataArray>(sample_test_file, SOURCE); +    target_data = make_shared<DataArray>(sample_test_file, TARGET); +  } + +  shared_ptr<DataArray> source_data; +  shared_ptr<DataArray> target_data; +}; + +TEST_F(DataArrayTest, TestGetData) { +  vector<int> expected_source_data = {2, 3, 4, 5, 1, 2, 6, 7, 8, 5, 1}; +  vector<string> expected_source_words = { +      "ana", "are", "mere", ".", "__END_OF_LINE__", +      "ana", "bea", "mult", "lapte", ".", "__END_OF_LINE__" +  }; +  EXPECT_EQ(expected_source_data, source_data->GetData()); +  EXPECT_EQ(expected_source_data.size(), source_data->GetSize()); +  for (size_t i = 0; i < expected_source_data.size(); ++i) { +    EXPECT_EQ(expected_source_data[i], source_data->AtIndex(i)); +    EXPECT_EQ(expected_source_words[i], source_data->GetWordAtIndex(i)); +  } + +  vector<int> expected_target_data = {2, 3, 4, 5, 1, 2, 6, 7, 8, 9, 10, 5, 1}; +  vector<string> expected_target_words = { +      "anna", "has", "apples", ".", "__END_OF_LINE__", +      "anna", "drinks", "a", "lot", "of", "milk", ".", "__END_OF_LINE__" +  }; +  EXPECT_EQ(expected_target_data, target_data->GetData()); +  EXPECT_EQ(expected_target_data.size(), target_data->GetSize()); +  for (size_t i = 0; i < expected_target_data.size(); ++i) { +    EXPECT_EQ(expected_target_data[i], target_data->AtIndex(i)); +    EXPECT_EQ(expected_target_words[i], target_data->GetWordAtIndex(i)); +  } +} + +TEST_F(DataArrayTest, TestVocabulary) { +  EXPECT_EQ(9, source_data->GetVocabularySize()); +  EXPECT_TRUE(source_data->HasWord("mere")); +  EXPECT_EQ(4, source_data->GetWordId("mere")); +  EXPECT_EQ("mere", source_data->GetWord(4)); +  EXPECT_FALSE(source_data->HasWord("banane")); + +  EXPECT_EQ(11, target_data->GetVocabularySize()); +  EXPECT_TRUE(target_data->HasWord("apples")); +  EXPECT_EQ(4, target_data->GetWordId("apples")); +  EXPECT_EQ("apples", target_data->GetWord(4)); +  EXPECT_FALSE(target_data->HasWord("bananas")); +} + +TEST_F(DataArrayTest, TestSentenceData) { +  EXPECT_EQ(2, source_data->GetNumSentences()); +  EXPECT_EQ(0, source_data->GetSentenceStart(0)); +  EXPECT_EQ(5, source_data->GetSentenceStart(1)); +  EXPECT_EQ(11, source_data->GetSentenceStart(2)); + +  EXPECT_EQ(4, source_data->GetSentenceLength(0)); +  EXPECT_EQ(5, source_data->GetSentenceLength(1)); + +  vector<int> expected_source_ids = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1}; +  for (size_t i = 0; i < expected_source_ids.size(); ++i) { +    EXPECT_EQ(expected_source_ids[i], source_data->GetSentenceId(i)); +  } + +  EXPECT_EQ(2, target_data->GetNumSentences()); +  EXPECT_EQ(0, target_data->GetSentenceStart(0)); +  EXPECT_EQ(5, target_data->GetSentenceStart(1)); +  EXPECT_EQ(13, target_data->GetSentenceStart(2)); + +  EXPECT_EQ(4, target_data->GetSentenceLength(0)); +  EXPECT_EQ(7, target_data->GetSentenceLength(1)); + +  vector<int> expected_target_ids = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1}; +  for (size_t i = 0; i < expected_target_ids.size(); ++i) { +    EXPECT_EQ(expected_target_ids[i], target_data->GetSentenceId(i)); +  } +} + +} // namespace +} // namespace extractor diff --git a/extractor/fast_intersector.cc b/extractor/fast_intersector.cc new file mode 100644 index 00000000..2a7693b2 --- /dev/null +++ b/extractor/fast_intersector.cc @@ -0,0 +1,195 @@ +#include "fast_intersector.h" + +#include <cassert> + +#include "data_array.h" +#include "phrase.h" +#include "phrase_location.h" +#include "precomputation.h" +#include "suffix_array.h" +#include "vocabulary.h" + +namespace extractor { + +FastIntersector::FastIntersector(shared_ptr<SuffixArray> suffix_array, +                                 shared_ptr<Precomputation> precomputation, +                                 shared_ptr<Vocabulary> vocabulary, +                                 int max_rule_span, +                                 int min_gap_size) : +    suffix_array(suffix_array), +    vocabulary(vocabulary), +    max_rule_span(max_rule_span), +    min_gap_size(min_gap_size) { +  Index precomputed_collocations = precomputation->GetCollocations(); +  for (pair<vector<int>, vector<int> > entry: precomputed_collocations) { +    vector<int> phrase = ConvertPhrase(entry.first); +    collocations[phrase] = entry.second; +  } +} + +FastIntersector::FastIntersector() {} + +FastIntersector::~FastIntersector() {} + +vector<int> FastIntersector::ConvertPhrase(const vector<int>& old_phrase) { +  vector<int> new_phrase; +  new_phrase.reserve(old_phrase.size()); +  shared_ptr<DataArray> data_array = suffix_array->GetData(); +  for (int word_id: old_phrase) { +    if (word_id < 0) { +      new_phrase.push_back(word_id); +    } else { +      new_phrase.push_back( +          vocabulary->GetTerminalIndex(data_array->GetWord(word_id))); +    } +  } +  return new_phrase; +} + +PhraseLocation FastIntersector::Intersect( +    PhraseLocation& prefix_location, +    PhraseLocation& suffix_location, +    const Phrase& phrase) { +  vector<int> symbols = phrase.Get(); + +  // We should never attempt to do an intersect query for a pattern starting or +  // ending with a non terminal. The RuleFactory should handle these cases, +  // initializing the matchings list with the one for the pattern without the +  // starting or ending terminal. +  assert(vocabulary->IsTerminal(symbols.front()) +      && vocabulary->IsTerminal(symbols.back())); + +  if (collocations.count(symbols)) { +    return PhraseLocation(collocations[symbols], phrase.Arity() + 1); +  } + +  bool prefix_ends_with_x = +      !vocabulary->IsTerminal(symbols[symbols.size() - 2]); +  bool suffix_starts_with_x = !vocabulary->IsTerminal(symbols[1]); +  if (EstimateNumOperations(prefix_location, prefix_ends_with_x) <= +      EstimateNumOperations(suffix_location, suffix_starts_with_x)) { +    return ExtendPrefixPhraseLocation(prefix_location, phrase, +                                      prefix_ends_with_x, symbols.back()); +  } else { +    return ExtendSuffixPhraseLocation(suffix_location, phrase, +                                      suffix_starts_with_x, symbols.front()); +  } +} + +int FastIntersector::EstimateNumOperations( +    const PhraseLocation& phrase_location, bool has_margin_x) const { +  int num_locations = phrase_location.GetSize(); +  return has_margin_x ? num_locations * max_rule_span : num_locations; +} + +PhraseLocation FastIntersector::ExtendPrefixPhraseLocation( +    PhraseLocation& prefix_location, const Phrase& phrase, +    bool prefix_ends_with_x, int next_symbol) const { +  ExtendPhraseLocation(prefix_location); +  vector<int> positions = *prefix_location.matchings; +  int num_subpatterns = prefix_location.num_subpatterns; + +  vector<int> new_positions; +  shared_ptr<DataArray> data_array = suffix_array->GetData(); +  int data_array_symbol = data_array->GetWordId( +      vocabulary->GetTerminalValue(next_symbol)); +  if (data_array_symbol == -1) { +    return PhraseLocation(new_positions, num_subpatterns); +  } + +  pair<int, int> range = GetSearchRange(prefix_ends_with_x); +  for (size_t i = 0; i < positions.size(); i += num_subpatterns) { +    int sent_id = data_array->GetSentenceId(positions[i]); +    int sent_end = data_array->GetSentenceStart(sent_id + 1) - 1; +    int pattern_end = positions[i + num_subpatterns - 1] + range.first; +    if (prefix_ends_with_x) { +      pattern_end += phrase.GetChunkLen(phrase.Arity() - 1) - 1; +    } else { +      pattern_end += phrase.GetChunkLen(phrase.Arity()) - 2; +    } +    // Searches for the last symbol in the phrase after each prefix occurrence. +    for (int j = range.first; j < range.second; ++j) { +      if (pattern_end >= sent_end || +          pattern_end - positions[i] >= max_rule_span) { +        break; +      } + +      if (data_array->AtIndex(pattern_end) == data_array_symbol) { +        new_positions.insert(new_positions.end(), positions.begin() + i, +                             positions.begin() + i + num_subpatterns); +        if (prefix_ends_with_x) { +          new_positions.push_back(pattern_end); +        } +      } +      ++pattern_end; +    } +  } + +  return PhraseLocation(new_positions, phrase.Arity() + 1); +} + +PhraseLocation FastIntersector::ExtendSuffixPhraseLocation( +    PhraseLocation& suffix_location, const Phrase& phrase, +    bool suffix_starts_with_x, int prev_symbol) const { +  ExtendPhraseLocation(suffix_location); +  vector<int> positions = *suffix_location.matchings; +  int num_subpatterns = suffix_location.num_subpatterns; + +  vector<int> new_positions; +  shared_ptr<DataArray> data_array = suffix_array->GetData(); +  int data_array_symbol = data_array->GetWordId( +      vocabulary->GetTerminalValue(prev_symbol)); +  if (data_array_symbol == -1) { +    return PhraseLocation(new_positions, num_subpatterns); +  } + +  pair<int, int> range = GetSearchRange(suffix_starts_with_x); +  for (size_t i = 0; i < positions.size(); i += num_subpatterns) { +    int sent_id = data_array->GetSentenceId(positions[i]); +    int sent_start = data_array->GetSentenceStart(sent_id); +    int pattern_start = positions[i] - range.first; +    int pattern_end = positions[i + num_subpatterns - 1] + +        phrase.GetChunkLen(phrase.Arity()) - 1; +    // Searches for the first symbol in the phrase before each suffix +    // occurrence. +    for (int j = range.first; j < range.second; ++j) { +      if (pattern_start < sent_start || +          pattern_end - pattern_start >= max_rule_span) { +        break; +      } + +      if (data_array->AtIndex(pattern_start) == data_array_symbol) { +        new_positions.push_back(pattern_start); +        new_positions.insert(new_positions.end(), +                             positions.begin() + i + !suffix_starts_with_x, +                             positions.begin() + i + num_subpatterns); +      } +      --pattern_start; +    } +  } + +  return PhraseLocation(new_positions, phrase.Arity() + 1); +} + +void FastIntersector::ExtendPhraseLocation(PhraseLocation& location) const { +  if (location.matchings != NULL) { +    return; +  } + +  location.num_subpatterns = 1; +  location.matchings = make_shared<vector<int> >(); +  for (int i = location.sa_low; i < location.sa_high; ++i) { +    location.matchings->push_back(suffix_array->GetSuffix(i)); +  } +  location.sa_low = location.sa_high = 0; +} + +pair<int, int> FastIntersector::GetSearchRange(bool has_marginal_x) const { +  if (has_marginal_x) { +    return make_pair(min_gap_size + 1, max_rule_span); +  } else { +    return make_pair(1, 2); +  } +} + +} // namespace extractor diff --git a/extractor/fast_intersector.h b/extractor/fast_intersector.h new file mode 100644 index 00000000..f950a2a9 --- /dev/null +++ b/extractor/fast_intersector.h @@ -0,0 +1,96 @@ +#ifndef _FAST_INTERSECTOR_H_ +#define _FAST_INTERSECTOR_H_ + +#include <memory> +#include <unordered_map> +#include <vector> + +#include <boost/functional/hash.hpp> + +using namespace std; + +namespace extractor { + +typedef boost::hash<vector<int> > VectorHash; +typedef unordered_map<vector<int>, vector<int>, VectorHash> Index; + +class Phrase; +class PhraseLocation; +class Precomputation; +class SuffixArray; +class Vocabulary; + +/** + * Component for searching the training data for occurrences of source phrases + * containing nonterminals + * + * Given a source phrase containing a nonterminal, we first query the + * precomputed index containing frequent collocations. If the phrase is not + * frequent enough, we extend the matchings of either its prefix or its suffix, + * depending on which operation seems to require less computations. + * + * Note: This method for intersecting phrase locations is faster than both + * mergers (linear or Baeza Yates) described in Adam Lopez' dissertation. + */ +class FastIntersector { + public: +  FastIntersector(shared_ptr<SuffixArray> suffix_array, +                  shared_ptr<Precomputation> precomputation, +                  shared_ptr<Vocabulary> vocabulary, +                  int max_rule_span, +                  int min_gap_size); + +  virtual ~FastIntersector(); + +  // Finds the locations of a phrase given the locations of its prefix and +  // suffix. +  virtual PhraseLocation Intersect(PhraseLocation& prefix_location, +                                   PhraseLocation& suffix_location, +                                   const Phrase& phrase); + + protected: +  FastIntersector(); + + private: +  // Uses the vocabulary to convert the phrase from the numberized format +  // specified by the source data array to the numberized format given by the +  // vocabulary. +  vector<int> ConvertPhrase(const vector<int>& old_phrase); + +  // Estimates the number of computations needed if the prefix/suffix is +  // extended. If the last/first symbol is separated from the rest of the phrase +  // by a nonterminal, then for each occurrence of the prefix/suffix we need to +  // check max_rule_span positions. Otherwise, we only need to check a single +  // position for each occurrence. +  int EstimateNumOperations(const PhraseLocation& phrase_location, +                            bool has_margin_x) const; + +  // Uses the occurrences of the prefix to find the occurrences of the phrase. +  PhraseLocation ExtendPrefixPhraseLocation(PhraseLocation& prefix_location, +                                            const Phrase& phrase, +                                            bool prefix_ends_with_x, +                                            int next_symbol) const; + +  // Uses the occurrences of the suffix to find the occurrences of the phrase. +  PhraseLocation ExtendSuffixPhraseLocation(PhraseLocation& suffix_location, +                                            const Phrase& phrase, +                                            bool suffix_starts_with_x, +                                            int prev_symbol) const; + +  // Extends the prefix/suffix location to a list of subpatterns positions if it +  // represents a suffix array range. +  void ExtendPhraseLocation(PhraseLocation& location) const; + +  // Returns the range in which the search should be performed. +  pair<int, int> GetSearchRange(bool has_marginal_x) const; + +  shared_ptr<SuffixArray> suffix_array; +  shared_ptr<Vocabulary> vocabulary; +  int max_rule_span; +  int min_gap_size; +  Index collocations; +}; + +} // namespace extractor + +#endif diff --git a/extractor/fast_intersector_test.cc b/extractor/fast_intersector_test.cc new file mode 100644 index 00000000..76c3aaea --- /dev/null +++ b/extractor/fast_intersector_test.cc @@ -0,0 +1,146 @@ +#include <gtest/gtest.h> + +#include <memory> + +#include "fast_intersector.h" +#include "mocks/mock_data_array.h" +#include "mocks/mock_precomputation.h" +#include "mocks/mock_suffix_array.h" +#include "mocks/mock_vocabulary.h" +#include "phrase.h" +#include "phrase_location.h" +#include "phrase_builder.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace { + +class FastIntersectorTest : public Test { + protected: +  virtual void SetUp() { +    vector<string> words = {"EOL", "it", "makes", "him", "and", "mars", ",", +                            "sets", "on", "takes", "off", "."}; +    vocabulary = make_shared<MockVocabulary>(); +    for (size_t i = 0; i < words.size(); ++i) { +      EXPECT_CALL(*vocabulary, GetTerminalIndex(words[i])) +          .WillRepeatedly(Return(i)); +      EXPECT_CALL(*vocabulary, GetTerminalValue(i)) +          .WillRepeatedly(Return(words[i])); +    } + +    vector<int> data = {1, 2, 3, 4, 1, 5, 3, 6, 1, +                        7, 3, 8, 4, 1, 9, 3, 10, 11, 0}; +    data_array = make_shared<MockDataArray>(); +    for (size_t i = 0; i < data.size(); ++i) { +      EXPECT_CALL(*data_array, AtIndex(i)).WillRepeatedly(Return(data[i])); +      EXPECT_CALL(*data_array, GetSentenceId(i)) +          .WillRepeatedly(Return(0)); +    } +    EXPECT_CALL(*data_array, GetSentenceStart(0)) +        .WillRepeatedly(Return(0)); +    EXPECT_CALL(*data_array, GetSentenceStart(1)) +        .WillRepeatedly(Return(19)); +    for (size_t i = 0; i < words.size(); ++i) { +      EXPECT_CALL(*data_array, GetWordId(words[i])) +          .WillRepeatedly(Return(i)); +      EXPECT_CALL(*data_array, GetWord(i)) +          .WillRepeatedly(Return(words[i])); +    } + +    vector<int> suffixes = {18, 0, 4, 8, 13, 1, 2, 6, 10, 15, 3, 12, 5, 7, 9, +                            11, 14, 16, 17}; +    suffix_array = make_shared<MockSuffixArray>(); +    EXPECT_CALL(*suffix_array, GetData()).WillRepeatedly(Return(data_array)); +    for (size_t i = 0; i < suffixes.size(); ++i) { +      EXPECT_CALL(*suffix_array, GetSuffix(i)). +          WillRepeatedly(Return(suffixes[i])); +    } + +    precomputation = make_shared<MockPrecomputation>(); +    EXPECT_CALL(*precomputation, GetCollocations()) +        .WillRepeatedly(ReturnRef(collocations)); + +    phrase_builder = make_shared<PhraseBuilder>(vocabulary); +    intersector = make_shared<FastIntersector>(suffix_array, precomputation, +                                               vocabulary, 15, 1); +  } + +  Index collocations; +  shared_ptr<MockDataArray> data_array; +  shared_ptr<MockSuffixArray> suffix_array; +  shared_ptr<MockPrecomputation> precomputation; +  shared_ptr<MockVocabulary> vocabulary; +  shared_ptr<FastIntersector> intersector; +  shared_ptr<PhraseBuilder> phrase_builder; +}; + +TEST_F(FastIntersectorTest, TestCachedCollocation) { +  vector<int> symbols = {8, -1, 9}; +  vector<int> expected_location = {11}; +  Phrase phrase = phrase_builder->Build(symbols); +  PhraseLocation prefix_location(15, 16), suffix_location(16, 17); + +  collocations[symbols] = expected_location; +  EXPECT_CALL(*precomputation, GetCollocations()) +      .WillRepeatedly(ReturnRef(collocations)); +  intersector = make_shared<FastIntersector>(suffix_array, precomputation, +                                             vocabulary, 15, 1); + +  PhraseLocation result = intersector->Intersect( +      prefix_location, suffix_location, phrase); + +  EXPECT_EQ(PhraseLocation(expected_location, 2), result); +  EXPECT_EQ(PhraseLocation(15, 16), prefix_location); +  EXPECT_EQ(PhraseLocation(16, 17), suffix_location); +} + +TEST_F(FastIntersectorTest, TestIntersectaXbXcExtendSuffix) { +  vector<int> symbols = {1, -1, 3, -1, 1}; +  Phrase phrase = phrase_builder->Build(symbols); +  vector<int> prefix_locs = {0, 2, 0, 6, 0, 10, 4, 6, 4, 10, 4, 15, 8, 10, +                             8, 15, 3, 15}; +  vector<int> suffix_locs = {2, 4, 2, 8, 2, 13, 6, 8, 6, 13, 10, 13}; +  PhraseLocation prefix_location(prefix_locs, 2); +  PhraseLocation suffix_location(suffix_locs, 2); + +  vector<int> expected_locs = {0, 2, 4, 0, 2, 8, 0, 2, 13, 4, 6, 8, 0, 6, 8, +                               4, 6, 13, 0, 6, 13, 8, 10, 13, 4, 10, 13, +                               0, 10, 13}; +  PhraseLocation result = intersector->Intersect( +      prefix_location, suffix_location, phrase); +  EXPECT_EQ(PhraseLocation(expected_locs, 3), result); +} + +TEST_F(FastIntersectorTest, TestIntersectaXbExtendPrefix) { +  vector<int> symbols = {1, -1, 3}; +  Phrase phrase = phrase_builder->Build(symbols); +  PhraseLocation prefix_location(1, 5), suffix_location(6, 10); + +  vector<int> expected_prefix_locs = {0, 4, 8, 13}; +  vector<int> expected_locs = {0, 2, 0, 6, 0, 10, 4, 6, 4, 10, 4, 15, 8, 10, +                               8, 15, 13, 15}; +  PhraseLocation result = intersector->Intersect( +      prefix_location, suffix_location, phrase); +  EXPECT_EQ(PhraseLocation(expected_locs, 2), result); +  EXPECT_EQ(PhraseLocation(expected_prefix_locs, 1), prefix_location); +} + +TEST_F(FastIntersectorTest, TestIntersectCheckEstimates) { +  // The suffix matches in fewer positions, but because it starts with an X +  // it requires more operations and we prefer extending the prefix. +  vector<int> symbols = {1, -1, 4, 1}; +  Phrase phrase = phrase_builder->Build(symbols); +  vector<int> prefix_locs = {0, 3, 0, 12, 4, 12, 8, 12}; +  PhraseLocation prefix_location(prefix_locs, 2), suffix_location(10, 12); + +  vector<int> expected_locs = {0, 3, 0, 12, 4, 12, 8, 12}; +  PhraseLocation result = intersector->Intersect( +      prefix_location, suffix_location, phrase); +  EXPECT_EQ(PhraseLocation(expected_locs, 2), result); +  EXPECT_EQ(PhraseLocation(10, 12), suffix_location); +} + +} // namespace +} // namespace extractor diff --git a/extractor/features/count_source_target.cc b/extractor/features/count_source_target.cc new file mode 100644 index 00000000..db0385e0 --- /dev/null +++ b/extractor/features/count_source_target.cc @@ -0,0 +1,17 @@ +#include "count_source_target.h" + +#include <cmath> + +namespace extractor { +namespace features { + +double CountSourceTarget::Score(const FeatureContext& context) const { +  return log10(1 + context.pair_count); +} + +string CountSourceTarget::GetName() const { +  return "CountEF"; +} + +} // namespace features +} // namespace extractor diff --git a/extractor/features/count_source_target.h b/extractor/features/count_source_target.h new file mode 100644 index 00000000..8747fa60 --- /dev/null +++ b/extractor/features/count_source_target.h @@ -0,0 +1,22 @@ +#ifndef _COUNT_SOURCE_TARGET_H_ +#define _COUNT_SOURCE_TARGET_H_ + +#include "feature.h" + +namespace extractor { +namespace features { + +/** + * Feature for the number of times a word pair was found in the bitext. + */ +class CountSourceTarget : public Feature { + public: +  double Score(const FeatureContext& context) const; + +  string GetName() const; +}; + +} // namespace features +} // namespace extractor + +#endif diff --git a/extractor/features/count_source_target_test.cc b/extractor/features/count_source_target_test.cc new file mode 100644 index 00000000..1fd0c2aa --- /dev/null +++ b/extractor/features/count_source_target_test.cc @@ -0,0 +1,36 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <string> + +#include "count_source_target.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace features { +namespace { + +class CountSourceTargetTest : public Test { + protected: +  virtual void SetUp() { +    feature = make_shared<CountSourceTarget>(); +  } + +  shared_ptr<CountSourceTarget> feature; +}; + +TEST_F(CountSourceTargetTest, TestGetName) { +  EXPECT_EQ("CountEF", feature->GetName()); +} + +TEST_F(CountSourceTargetTest, TestScore) { +  Phrase phrase; +  FeatureContext context(phrase, phrase, 0.5, 9, 13); +  EXPECT_EQ(1.0, feature->Score(context)); +} + +} // namespace +} // namespace features +} // namespace extractor diff --git a/extractor/features/feature.cc b/extractor/features/feature.cc new file mode 100644 index 00000000..939bcc59 --- /dev/null +++ b/extractor/features/feature.cc @@ -0,0 +1,11 @@ +#include "feature.h" + +namespace extractor { +namespace features { + +const double Feature::MAX_SCORE = 99.0; + +Feature::~Feature() {} + +} // namespace features +} // namespace extractor diff --git a/extractor/features/feature.h b/extractor/features/feature.h new file mode 100644 index 00000000..36ea504a --- /dev/null +++ b/extractor/features/feature.h @@ -0,0 +1,47 @@ +#ifndef _FEATURE_H_ +#define _FEATURE_H_ + +#include <string> + +#include "phrase.h" + +using namespace std; + +namespace extractor { +namespace features { + +/** + * Structure providing context for computing feature scores. + */ +struct FeatureContext { +  FeatureContext(const Phrase& source_phrase, const Phrase& target_phrase, +                 double source_phrase_count, int pair_count, int num_samples) : +    source_phrase(source_phrase), target_phrase(target_phrase), +    source_phrase_count(source_phrase_count), pair_count(pair_count), +    num_samples(num_samples) {} + +  Phrase source_phrase; +  Phrase target_phrase; +  double source_phrase_count; +  int pair_count; +  int num_samples; +}; + +/** + * Base class for features. + */ +class Feature { + public: +  virtual double Score(const FeatureContext& context) const = 0; + +  virtual string GetName() const = 0; + +  virtual ~Feature(); + +  static const double MAX_SCORE; +}; + +} // namespace features +} // namespace extractor + +#endif diff --git a/extractor/features/is_source_singleton.cc b/extractor/features/is_source_singleton.cc new file mode 100644 index 00000000..1abb486f --- /dev/null +++ b/extractor/features/is_source_singleton.cc @@ -0,0 +1,17 @@ +#include "is_source_singleton.h" + +#include <cmath> + +namespace extractor { +namespace features { + +double IsSourceSingleton::Score(const FeatureContext& context) const { +  return fabs(context.source_phrase_count - 1) < 1e-6; +} + +string IsSourceSingleton::GetName() const { +  return "IsSingletonF"; +} + +} // namespace features +} // namespace extractor diff --git a/extractor/features/is_source_singleton.h b/extractor/features/is_source_singleton.h new file mode 100644 index 00000000..b8352d0e --- /dev/null +++ b/extractor/features/is_source_singleton.h @@ -0,0 +1,22 @@ +#ifndef _IS_SOURCE_SINGLETON_H_ +#define _IS_SOURCE_SINGLETON_H_ + +#include "feature.h" + +namespace extractor { +namespace features { + +/** + * Boolean feature checking if the source phrase occurs only once in the data. + */ +class IsSourceSingleton : public Feature { + public: +  double Score(const FeatureContext& context) const; + +  string GetName() const; +}; + +} // namespace features +} // namespace extractor + +#endif diff --git a/extractor/features/is_source_singleton_test.cc b/extractor/features/is_source_singleton_test.cc new file mode 100644 index 00000000..f4266671 --- /dev/null +++ b/extractor/features/is_source_singleton_test.cc @@ -0,0 +1,39 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <string> + +#include "is_source_singleton.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace features { +namespace { + +class IsSourceSingletonTest : public Test { + protected: +  virtual void SetUp() { +    feature = make_shared<IsSourceSingleton>(); +  } + +  shared_ptr<IsSourceSingleton> feature; +}; + +TEST_F(IsSourceSingletonTest, TestGetName) { +  EXPECT_EQ("IsSingletonF", feature->GetName()); +} + +TEST_F(IsSourceSingletonTest, TestScore) { +  Phrase phrase; +  FeatureContext context(phrase, phrase, 0.5, 3, 31); +  EXPECT_EQ(0, feature->Score(context)); + +  context = FeatureContext(phrase, phrase, 1, 3, 25); +  EXPECT_EQ(1, feature->Score(context)); +} + +} // namespace +} // namespace features +} // namespace extractor diff --git a/extractor/features/is_source_target_singleton.cc b/extractor/features/is_source_target_singleton.cc new file mode 100644 index 00000000..03b3c62c --- /dev/null +++ b/extractor/features/is_source_target_singleton.cc @@ -0,0 +1,17 @@ +#include "is_source_target_singleton.h" + +#include <cmath> + +namespace extractor { +namespace features { + +double IsSourceTargetSingleton::Score(const FeatureContext& context) const { +  return context.pair_count == 1; +} + +string IsSourceTargetSingleton::GetName() const { +  return "IsSingletonFE"; +} + +} // namespace features +} // namespace extractor diff --git a/extractor/features/is_source_target_singleton.h b/extractor/features/is_source_target_singleton.h new file mode 100644 index 00000000..dacfebba --- /dev/null +++ b/extractor/features/is_source_target_singleton.h @@ -0,0 +1,22 @@ +#ifndef _IS_SOURCE_TARGET_SINGLETON_H_ +#define _IS_SOURCE_TARGET_SINGLETON_H_ + +#include "feature.h" + +namespace extractor { +namespace features { + +/** + * Boolean feature checking if the phrase pair occurs only once in the data. + */ +class IsSourceTargetSingleton : public Feature { + public: +  double Score(const FeatureContext& context) const; + +  string GetName() const; +}; + +} // namespace features +} // namespace extractor + +#endif diff --git a/extractor/features/is_source_target_singleton_test.cc b/extractor/features/is_source_target_singleton_test.cc new file mode 100644 index 00000000..929635b0 --- /dev/null +++ b/extractor/features/is_source_target_singleton_test.cc @@ -0,0 +1,39 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <string> + +#include "is_source_target_singleton.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace features { +namespace { + +class IsSourceTargetSingletonTest : public Test { + protected: +  virtual void SetUp() { +    feature = make_shared<IsSourceTargetSingleton>(); +  } + +  shared_ptr<IsSourceTargetSingleton> feature; +}; + +TEST_F(IsSourceTargetSingletonTest, TestGetName) { +  EXPECT_EQ("IsSingletonFE", feature->GetName()); +} + +TEST_F(IsSourceTargetSingletonTest, TestScore) { +  Phrase phrase; +  FeatureContext context(phrase, phrase, 0.5, 3, 7); +  EXPECT_EQ(0, feature->Score(context)); + +  context = FeatureContext(phrase, phrase, 2.3, 1, 28); +  EXPECT_EQ(1, feature->Score(context)); +} + +} // namespace +} // namespace features +} // namespace extractor diff --git a/extractor/features/max_lex_source_given_target.cc b/extractor/features/max_lex_source_given_target.cc new file mode 100644 index 00000000..65d0ec68 --- /dev/null +++ b/extractor/features/max_lex_source_given_target.cc @@ -0,0 +1,37 @@ +#include "max_lex_source_given_target.h" + +#include <cmath> + +#include "data_array.h" +#include "translation_table.h" + +namespace extractor { +namespace features { + +MaxLexSourceGivenTarget::MaxLexSourceGivenTarget( +    shared_ptr<TranslationTable> table) : +    table(table) {} + +double MaxLexSourceGivenTarget::Score(const FeatureContext& context) const { +  vector<string> source_words = context.source_phrase.GetWords(); +  vector<string> target_words = context.target_phrase.GetWords(); +  target_words.push_back(DataArray::NULL_WORD_STR); + +  double score = 0; +  for (string source_word: source_words) { +    double max_score = 0; +    for (string target_word: target_words) { +      max_score = max(max_score, +          table->GetSourceGivenTargetScore(source_word, target_word)); +    } +    score += max_score > 0 ? -log10(max_score) : MAX_SCORE; +  } +  return score; +} + +string MaxLexSourceGivenTarget::GetName() const { +  return "MaxLexFgivenE"; +} + +} // namespace features +} // namespace extractor diff --git a/extractor/features/max_lex_source_given_target.h b/extractor/features/max_lex_source_given_target.h new file mode 100644 index 00000000..461b0ebf --- /dev/null +++ b/extractor/features/max_lex_source_given_target.h @@ -0,0 +1,34 @@ +#ifndef _MAX_LEX_SOURCE_GIVEN_TARGET_H_ +#define _MAX_LEX_SOURCE_GIVEN_TARGET_H_ + +#include <memory> + +#include "feature.h" + +using namespace std; + +namespace extractor { + +class TranslationTable; + +namespace features { + +/** + * Feature computing max(p(f | e)) across all pairs of words in the phrase pair. + */ +class MaxLexSourceGivenTarget : public Feature { + public: +  MaxLexSourceGivenTarget(shared_ptr<TranslationTable> table); + +  double Score(const FeatureContext& context) const; + +  string GetName() const; + + private: +  shared_ptr<TranslationTable> table; +}; + +} // namespace features +} // namespace extractor + +#endif diff --git a/extractor/features/max_lex_source_given_target_test.cc b/extractor/features/max_lex_source_given_target_test.cc new file mode 100644 index 00000000..7f6aae41 --- /dev/null +++ b/extractor/features/max_lex_source_given_target_test.cc @@ -0,0 +1,78 @@ +#include <gtest/gtest.h> + +#include <cmath> +#include <memory> +#include <string> + +#include "data_array.h" +#include "mocks/mock_translation_table.h" +#include "mocks/mock_vocabulary.h" +#include "phrase_builder.h" +#include "max_lex_source_given_target.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace features { +namespace { + +class MaxLexSourceGivenTargetTest : public Test { + protected: +  virtual void SetUp() { +    vector<string> source_words = {"f1", "f2", "f3"}; +    vector<string> target_words = {"e1", "e2", "e3"}; + +    vocabulary = make_shared<MockVocabulary>(); +    for (size_t i = 0; i < source_words.size(); ++i) { +      EXPECT_CALL(*vocabulary, GetTerminalValue(i)) +          .WillRepeatedly(Return(source_words[i])); +    } +    for (size_t i = 0; i < target_words.size(); ++i) { +      EXPECT_CALL(*vocabulary, GetTerminalValue(i + source_words.size())) +          .WillRepeatedly(Return(target_words[i])); +    } + +    phrase_builder = make_shared<PhraseBuilder>(vocabulary); + +    table = make_shared<MockTranslationTable>(); +    for (size_t i = 0; i < source_words.size(); ++i) { +      for (size_t j = 0; j < target_words.size(); ++j) { +        int value = i - j; +        EXPECT_CALL(*table, GetSourceGivenTargetScore( +            source_words[i], target_words[j])).WillRepeatedly(Return(value)); +      } +    } + +    for (size_t i = 0; i < source_words.size(); ++i) { +      int value = i * 3; +      EXPECT_CALL(*table, GetSourceGivenTargetScore( +          source_words[i], DataArray::NULL_WORD_STR)) +          .WillRepeatedly(Return(value)); +    } + +    feature = make_shared<MaxLexSourceGivenTarget>(table); +  } + +  shared_ptr<MockVocabulary> vocabulary; +  shared_ptr<PhraseBuilder> phrase_builder; +  shared_ptr<MockTranslationTable> table; +  shared_ptr<MaxLexSourceGivenTarget> feature; +}; + +TEST_F(MaxLexSourceGivenTargetTest, TestGetName) { +  EXPECT_EQ("MaxLexFgivenE", feature->GetName()); +} + +TEST_F(MaxLexSourceGivenTargetTest, TestScore) { +  vector<int> source_symbols = {0, 1, 2}; +  Phrase source_phrase = phrase_builder->Build(source_symbols); +  vector<int> target_symbols = {3, 4, 5}; +  Phrase target_phrase = phrase_builder->Build(target_symbols); +  FeatureContext context(source_phrase, target_phrase, 0.3, 7, 11); +  EXPECT_EQ(99 - log10(18), feature->Score(context)); +} + +} // namespace +} // namespace features +} // namespace extractor diff --git a/extractor/features/max_lex_target_given_source.cc b/extractor/features/max_lex_target_given_source.cc new file mode 100644 index 00000000..33783054 --- /dev/null +++ b/extractor/features/max_lex_target_given_source.cc @@ -0,0 +1,37 @@ +#include "max_lex_target_given_source.h" + +#include <cmath> + +#include "data_array.h" +#include "translation_table.h" + +namespace extractor { +namespace features { + +MaxLexTargetGivenSource::MaxLexTargetGivenSource( +    shared_ptr<TranslationTable> table) : +    table(table) {} + +double MaxLexTargetGivenSource::Score(const FeatureContext& context) const { +  vector<string> source_words = context.source_phrase.GetWords(); +  source_words.push_back(DataArray::NULL_WORD_STR); +  vector<string> target_words = context.target_phrase.GetWords(); + +  double score = 0; +  for (string target_word: target_words) { +    double max_score = 0; +    for (string source_word: source_words) { +      max_score = max(max_score, +          table->GetTargetGivenSourceScore(source_word, target_word)); +    } +    score += max_score > 0 ? -log10(max_score) : MAX_SCORE; +  } +  return score; +} + +string MaxLexTargetGivenSource::GetName() const { +  return "MaxLexEgivenF"; +} + +} // namespace features +} // namespace extractor diff --git a/extractor/features/max_lex_target_given_source.h b/extractor/features/max_lex_target_given_source.h new file mode 100644 index 00000000..c3c87327 --- /dev/null +++ b/extractor/features/max_lex_target_given_source.h @@ -0,0 +1,34 @@ +#ifndef _MAX_LEX_TARGET_GIVEN_SOURCE_H_ +#define _MAX_LEX_TARGET_GIVEN_SOURCE_H_ + +#include <memory> + +#include "feature.h" + +using namespace std; + +namespace extractor { + +class TranslationTable; + +namespace features { + +/** + * Feature computing max(p(e | f)) across all pairs of words in the phrase pair. + */ +class MaxLexTargetGivenSource : public Feature { + public: +  MaxLexTargetGivenSource(shared_ptr<TranslationTable> table); + +  double Score(const FeatureContext& context) const; + +  string GetName() const; + + private: +  shared_ptr<TranslationTable> table; +}; + +} // namespace features +} // namespace extractor + +#endif diff --git a/extractor/features/max_lex_target_given_source_test.cc b/extractor/features/max_lex_target_given_source_test.cc new file mode 100644 index 00000000..6d0efd9c --- /dev/null +++ b/extractor/features/max_lex_target_given_source_test.cc @@ -0,0 +1,78 @@ +#include <gtest/gtest.h> + +#include <cmath> +#include <memory> +#include <string> + +#include "data_array.h" +#include "mocks/mock_translation_table.h" +#include "mocks/mock_vocabulary.h" +#include "phrase_builder.h" +#include "max_lex_target_given_source.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace features { +namespace { + +class MaxLexTargetGivenSourceTest : public Test { + protected: +  virtual void SetUp() { +    vector<string> source_words = {"f1", "f2", "f3"}; +    vector<string> target_words = {"e1", "e2", "e3"}; + +    vocabulary = make_shared<MockVocabulary>(); +    for (size_t i = 0; i < source_words.size(); ++i) { +      EXPECT_CALL(*vocabulary, GetTerminalValue(i)) +          .WillRepeatedly(Return(source_words[i])); +    } +    for (size_t i = 0; i < target_words.size(); ++i) { +      EXPECT_CALL(*vocabulary, GetTerminalValue(i + source_words.size())) +          .WillRepeatedly(Return(target_words[i])); +    } + +    phrase_builder = make_shared<PhraseBuilder>(vocabulary); + +    table = make_shared<MockTranslationTable>(); +    for (size_t i = 0; i < source_words.size(); ++i) { +      for (size_t j = 0; j < target_words.size(); ++j) { +        int value = i - j; +        EXPECT_CALL(*table, GetTargetGivenSourceScore( +            source_words[i], target_words[j])).WillRepeatedly(Return(value)); +      } +    } + +    for (size_t i = 0; i < target_words.size(); ++i) { +      int value = i * 3; +      EXPECT_CALL(*table, GetTargetGivenSourceScore( +          DataArray::NULL_WORD_STR, target_words[i])) +          .WillRepeatedly(Return(value)); +    } + +    feature = make_shared<MaxLexTargetGivenSource>(table); +  } + +  shared_ptr<MockVocabulary> vocabulary; +  shared_ptr<PhraseBuilder> phrase_builder; +  shared_ptr<MockTranslationTable> table; +  shared_ptr<MaxLexTargetGivenSource> feature; +}; + +TEST_F(MaxLexTargetGivenSourceTest, TestGetName) { +  EXPECT_EQ("MaxLexEgivenF", feature->GetName()); +} + +TEST_F(MaxLexTargetGivenSourceTest, TestScore) { +  vector<int> source_symbols = {0, 1, 2}; +  Phrase source_phrase = phrase_builder->Build(source_symbols); +  vector<int> target_symbols = {3, 4, 5}; +  Phrase target_phrase = phrase_builder->Build(target_symbols); +  FeatureContext context(source_phrase, target_phrase, 0.3, 7, 19); +  EXPECT_EQ(-log10(36), feature->Score(context)); +} + +} // namespace +} // namespace features +} // namespace extractor diff --git a/extractor/features/sample_source_count.cc b/extractor/features/sample_source_count.cc new file mode 100644 index 00000000..b110fc51 --- /dev/null +++ b/extractor/features/sample_source_count.cc @@ -0,0 +1,17 @@ +#include "sample_source_count.h" + +#include <cmath> + +namespace extractor { +namespace features { + +double SampleSourceCount::Score(const FeatureContext& context) const { +  return log10(1 + context.num_samples); +} + +string SampleSourceCount::GetName() const { +  return "SampleCountF"; +} + +} // namespace features +} // namespace extractor diff --git a/extractor/features/sample_source_count.h b/extractor/features/sample_source_count.h new file mode 100644 index 00000000..ee6e59a0 --- /dev/null +++ b/extractor/features/sample_source_count.h @@ -0,0 +1,23 @@ +#ifndef _SAMPLE_SOURCE_COUNT_H_ +#define _SAMPLE_SOURCE_COUNT_H_ + +#include "feature.h" + +namespace extractor { +namespace features { + +/** + * Feature scoring the number of times the source phrase occurs in the sampled + * set. + */ +class SampleSourceCount : public Feature { + public: +  double Score(const FeatureContext& context) const; + +  string GetName() const; +}; + +} // namespace features +} // namespace extractor + +#endif diff --git a/extractor/features/sample_source_count_test.cc b/extractor/features/sample_source_count_test.cc new file mode 100644 index 00000000..63856b9d --- /dev/null +++ b/extractor/features/sample_source_count_test.cc @@ -0,0 +1,40 @@ +#include <gtest/gtest.h> + +#include <cmath> +#include <memory> +#include <string> + +#include "sample_source_count.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace features { +namespace { + +class SampleSourceCountTest : public Test { + protected: +  virtual void SetUp() { +    feature = make_shared<SampleSourceCount>(); +  } + +  shared_ptr<SampleSourceCount> feature; +}; + +TEST_F(SampleSourceCountTest, TestGetName) { +  EXPECT_EQ("SampleCountF", feature->GetName()); +} + +TEST_F(SampleSourceCountTest, TestScore) { +  Phrase phrase; +  FeatureContext context(phrase, phrase, 0, 3, 1); +  EXPECT_EQ(log10(2), feature->Score(context)); + +  context = FeatureContext(phrase, phrase, 3.2, 3, 9); +  EXPECT_EQ(1.0, feature->Score(context)); +} + +} // namespace +} // namespace features +} // namespace extractor diff --git a/extractor/features/target_given_source_coherent.cc b/extractor/features/target_given_source_coherent.cc new file mode 100644 index 00000000..c4551d88 --- /dev/null +++ b/extractor/features/target_given_source_coherent.cc @@ -0,0 +1,18 @@ +#include "target_given_source_coherent.h" + +#include <cmath> + +namespace extractor { +namespace features { + +double TargetGivenSourceCoherent::Score(const FeatureContext& context) const { +  double prob = (double) context.pair_count / context.num_samples; +  return prob > 0 ? -log10(prob) : MAX_SCORE; +} + +string TargetGivenSourceCoherent::GetName() const { +  return "EgivenFCoherent"; +} + +} // namespace features +} // namespace extractor diff --git a/extractor/features/target_given_source_coherent.h b/extractor/features/target_given_source_coherent.h new file mode 100644 index 00000000..e66d70a5 --- /dev/null +++ b/extractor/features/target_given_source_coherent.h @@ -0,0 +1,23 @@ +#ifndef _TARGET_GIVEN_SOURCE_COHERENT_H_ +#define _TARGET_GIVEN_SOURCE_COHERENT_H_ + +#include "feature.h" + +namespace extractor { +namespace features { + +/** + * Feature computing the ratio of the phrase pair count over all source phrase + * occurrences (sampled). + */ +class TargetGivenSourceCoherent : public Feature { + public: +  double Score(const FeatureContext& context) const; + +  string GetName() const; +}; + +} // namespace features +} // namespace extractor + +#endif diff --git a/extractor/features/target_given_source_coherent_test.cc b/extractor/features/target_given_source_coherent_test.cc new file mode 100644 index 00000000..454105e1 --- /dev/null +++ b/extractor/features/target_given_source_coherent_test.cc @@ -0,0 +1,39 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <string> + +#include "target_given_source_coherent.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace features { +namespace { + +class TargetGivenSourceCoherentTest : public Test { + protected: +  virtual void SetUp() { +    feature = make_shared<TargetGivenSourceCoherent>(); +  } + +  shared_ptr<TargetGivenSourceCoherent> feature; +}; + +TEST_F(TargetGivenSourceCoherentTest, TestGetName) { +  EXPECT_EQ("EgivenFCoherent", feature->GetName()); +} + +TEST_F(TargetGivenSourceCoherentTest, TestScore) { +  Phrase phrase; +  FeatureContext context(phrase, phrase, 0.3, 2, 20); +  EXPECT_EQ(1.0, feature->Score(context)); + +  context = FeatureContext(phrase, phrase, 1.9, 0, 1); +  EXPECT_EQ(99.0, feature->Score(context)); +} + +} // namespace +} // namespace features +} // namespace extractor diff --git a/extractor/grammar.cc b/extractor/grammar.cc new file mode 100644 index 00000000..b45a8261 --- /dev/null +++ b/extractor/grammar.cc @@ -0,0 +1,43 @@ +#include "grammar.h" + +#include <iomanip> + +#include "rule.h" + +using namespace std; + +namespace extractor { + +Grammar::Grammar(const vector<Rule>& rules, +                 const vector<string>& feature_names) : +  rules(rules), feature_names(feature_names) {} + +vector<Rule> Grammar::GetRules() const { +  return rules; +} + +vector<string> Grammar::GetFeatureNames() const { +  return feature_names; +} + +ostream& operator<<(ostream& os, const Grammar& grammar) { +  vector<Rule> rules = grammar.GetRules(); +  vector<string> feature_names = grammar.GetFeatureNames(); +  os << setprecision(12); +  for (Rule rule: rules) { +    os << "[X] ||| " << rule.source_phrase << " ||| " +                     << rule.target_phrase << " |||"; +    for (size_t i = 0; i < rule.scores.size(); ++i) { +      os << " " << feature_names[i] << "=" << rule.scores[i]; +    } +    os << " |||"; +    for (auto link: rule.alignment) { +      os << " " << link.first << "-" << link.second; +    } +    os << '\n'; +  } + +  return os; +} + +} // namespace extractor diff --git a/extractor/grammar.h b/extractor/grammar.h new file mode 100644 index 00000000..fed41b16 --- /dev/null +++ b/extractor/grammar.h @@ -0,0 +1,34 @@ +#ifndef _GRAMMAR_H_ +#define _GRAMMAR_H_ + +#include <iostream> +#include <string> +#include <vector> + +using namespace std; + +namespace extractor { + +class Rule; + +/** + * Grammar class wrapping the set of rules to be extracted. + */ +class Grammar { + public: +  Grammar(const vector<Rule>& rules, const vector<string>& feature_names); + +  vector<Rule> GetRules() const; + +  vector<string> GetFeatureNames() const; + +  friend ostream& operator<<(ostream& os, const Grammar& grammar); + + private: +  vector<Rule> rules; +  vector<string> feature_names; +}; + +} // namespace extractor + +#endif diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc new file mode 100644 index 00000000..8050ce7b --- /dev/null +++ b/extractor/grammar_extractor.cc @@ -0,0 +1,62 @@ +#include "grammar_extractor.h" + +#include <iterator> +#include <sstream> +#include <vector> + +#include "grammar.h" +#include "rule.h" +#include "rule_factory.h" +#include "vocabulary.h" + +using namespace std; + +namespace extractor { + +GrammarExtractor::GrammarExtractor( +    shared_ptr<SuffixArray> source_suffix_array, +    shared_ptr<DataArray> target_data_array, +    shared_ptr<Alignment> alignment, shared_ptr<Precomputation> precomputation, +    shared_ptr<Scorer> scorer, int min_gap_size, int max_rule_span, +    int max_nonterminals, int max_rule_symbols, int max_samples, +    bool require_tight_phrases) : +    vocabulary(make_shared<Vocabulary>()), +    rule_factory(make_shared<HieroCachingRuleFactory>( +        source_suffix_array, target_data_array, alignment, vocabulary, +        precomputation, scorer, min_gap_size, max_rule_span, max_nonterminals, +        max_rule_symbols, max_samples, require_tight_phrases)) {} + +GrammarExtractor::GrammarExtractor( +    shared_ptr<Vocabulary> vocabulary, +    shared_ptr<HieroCachingRuleFactory> rule_factory) : +    vocabulary(vocabulary), +    rule_factory(rule_factory) {} + +Grammar GrammarExtractor::GetGrammar(const string& sentence) { +  vector<string> words = TokenizeSentence(sentence); +  vector<int> word_ids = AnnotateWords(words); +  return rule_factory->GetGrammar(word_ids); +} + +vector<string> GrammarExtractor::TokenizeSentence(const string& sentence) { +  vector<string> result; +  result.push_back("<s>"); + +  istringstream buffer(sentence); +  copy(istream_iterator<string>(buffer), +       istream_iterator<string>(), +       back_inserter(result)); + +  result.push_back("</s>"); +  return result; +} + +vector<int> GrammarExtractor::AnnotateWords(const vector<string>& words) { +  vector<int> result; +  for (string word: words) { +    result.push_back(vocabulary->GetTerminalIndex(word)); +  } +  return result; +} + +} // namespace extractor diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h new file mode 100644 index 00000000..b36ceeb9 --- /dev/null +++ b/extractor/grammar_extractor.h @@ -0,0 +1,62 @@ +#ifndef _GRAMMAR_EXTRACTOR_H_ +#define _GRAMMAR_EXTRACTOR_H_ + +#include <memory> +#include <string> +#include <vector> + +using namespace std; + +namespace extractor { + +class Alignment; +class DataArray; +class Grammar; +class HieroCachingRuleFactory; +class Precomputation; +class Rule; +class Scorer; +class SuffixArray; +class Vocabulary; + +/** + * Class wrapping all the logic for extracting the synchronous context free + * grammars. + */ +class GrammarExtractor { + public: +  GrammarExtractor( +      shared_ptr<SuffixArray> source_suffix_array, +      shared_ptr<DataArray> target_data_array, +      shared_ptr<Alignment> alignment, +      shared_ptr<Precomputation> precomputation, +      shared_ptr<Scorer> scorer, +      int min_gap_size, +      int max_rule_span, +      int max_nonterminals, +      int max_rule_symbols, +      int max_samples, +      bool require_tight_phrases); + +  // For testing only. +  GrammarExtractor(shared_ptr<Vocabulary> vocabulary, +                   shared_ptr<HieroCachingRuleFactory> rule_factory); + +  // Converts the sentence to a vector of word ids and uses the RuleFactory to +  // extract the SCFG rules which may be used to decode the sentence. +  Grammar GetGrammar(const string& sentence); + + private: +  // Splits the sentence in a vector of words. +  vector<string> TokenizeSentence(const string& sentence); + +  // Maps the words to word ids. +  vector<int> AnnotateWords(const vector<string>& words); + +  shared_ptr<Vocabulary> vocabulary; +  shared_ptr<HieroCachingRuleFactory> rule_factory; +}; + +} // namespace extractor + +#endif diff --git a/extractor/grammar_extractor_test.cc b/extractor/grammar_extractor_test.cc new file mode 100644 index 00000000..823bb8b4 --- /dev/null +++ b/extractor/grammar_extractor_test.cc @@ -0,0 +1,51 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <string> +#include <vector> + +#include "grammar.h" +#include "grammar_extractor.h" +#include "mocks/mock_rule_factory.h" +#include "mocks/mock_vocabulary.h" +#include "rule.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace { + +TEST(GrammarExtractorTest, TestAnnotatingWords) { +  shared_ptr<MockVocabulary> vocabulary = make_shared<MockVocabulary>(); +  EXPECT_CALL(*vocabulary, GetTerminalIndex("<s>")) +      .WillRepeatedly(Return(0)); +  EXPECT_CALL(*vocabulary, GetTerminalIndex("Anna")) +      .WillRepeatedly(Return(1)); +  EXPECT_CALL(*vocabulary, GetTerminalIndex("has")) +      .WillRepeatedly(Return(2)); +  EXPECT_CALL(*vocabulary, GetTerminalIndex("many")) +      .WillRepeatedly(Return(3)); +  EXPECT_CALL(*vocabulary, GetTerminalIndex("apples")) +      .WillRepeatedly(Return(4)); +  EXPECT_CALL(*vocabulary, GetTerminalIndex(".")) +      .WillRepeatedly(Return(5)); +  EXPECT_CALL(*vocabulary, GetTerminalIndex("</s>")) +      .WillRepeatedly(Return(6)); + +  shared_ptr<MockHieroCachingRuleFactory> factory = +      make_shared<MockHieroCachingRuleFactory>(); +  vector<int> word_ids = {0, 1, 2, 3, 3, 4, 5, 6}; +  vector<Rule> rules; +  vector<string> feature_names; +  Grammar grammar(rules, feature_names); +  EXPECT_CALL(*factory, GetGrammar(word_ids)) +      .WillOnce(Return(grammar)); + +  GrammarExtractor extractor(vocabulary, factory); +  string sentence = "Anna has many many apples ."; +  extractor.GetGrammar(sentence); +} + +} // namespace +} // namespace extractor diff --git a/extractor/matchings_finder.cc b/extractor/matchings_finder.cc new file mode 100644 index 00000000..ceed6891 --- /dev/null +++ b/extractor/matchings_finder.cc @@ -0,0 +1,25 @@ +#include "matchings_finder.h" + +#include "suffix_array.h" +#include "phrase_location.h" + +namespace extractor { + +MatchingsFinder::MatchingsFinder(shared_ptr<SuffixArray> suffix_array) : +    suffix_array(suffix_array) {} + +MatchingsFinder::MatchingsFinder() {} + +MatchingsFinder::~MatchingsFinder() {} + +PhraseLocation MatchingsFinder::Find(PhraseLocation& location, +                                     const string& word, int offset) { +  if (location.sa_low == -1 && location.sa_high == -1) { +    location.sa_low = 0; +    location.sa_high = suffix_array->GetSize(); +  } + +  return suffix_array->Lookup(location.sa_low, location.sa_high, word, offset); +} + +} // namespace extractor diff --git a/extractor/matchings_finder.h b/extractor/matchings_finder.h new file mode 100644 index 00000000..451f4a4c --- /dev/null +++ b/extractor/matchings_finder.h @@ -0,0 +1,37 @@ +#ifndef _MATCHINGS_FINDER_H_ +#define _MATCHINGS_FINDER_H_ + +#include <memory> +#include <string> + +using namespace std; + +namespace extractor { + +class PhraseLocation; +class SuffixArray; + +/** + * Class wrapping the suffix array lookup for a contiguous phrase. + */ +class MatchingsFinder { + public: +  MatchingsFinder(shared_ptr<SuffixArray> suffix_array); + +  virtual ~MatchingsFinder(); + +  // Uses the suffix array to search only for the last word of the phrase +  // starting from the range in which the prefix of the phrase occurs. +  virtual PhraseLocation Find(PhraseLocation& location, const string& word, +                              int offset); + + protected: +  MatchingsFinder(); + + private: +  shared_ptr<SuffixArray> suffix_array; +}; + +} // namespace extractor + +#endif diff --git a/extractor/matchings_finder_test.cc b/extractor/matchings_finder_test.cc new file mode 100644 index 00000000..d40e5191 --- /dev/null +++ b/extractor/matchings_finder_test.cc @@ -0,0 +1,44 @@ +#include <gtest/gtest.h> + +#include <memory> + +#include "matchings_finder.h" +#include "mocks/mock_suffix_array.h" +#include "phrase_location.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace { + +class MatchingsFinderTest : public Test { + protected: +  virtual void SetUp() { +    suffix_array = make_shared<MockSuffixArray>(); +    EXPECT_CALL(*suffix_array, Lookup(0, 10, _, _)) +        .Times(1) +        .WillOnce(Return(PhraseLocation(3, 5))); + +    matchings_finder = make_shared<MatchingsFinder>(suffix_array); +  } + +  shared_ptr<MatchingsFinder> matchings_finder; +  shared_ptr<MockSuffixArray> suffix_array; +}; + +TEST_F(MatchingsFinderTest, TestFind) { +  PhraseLocation phrase_location(0, 10), expected_result(3, 5); +  EXPECT_EQ(expected_result, matchings_finder->Find(phrase_location, "bla", 2)); +} + +TEST_F(MatchingsFinderTest, ResizeUnsetRange) { +  EXPECT_CALL(*suffix_array, GetSize()).Times(1).WillOnce(Return(10)); + +  PhraseLocation phrase_location, expected_result(3, 5); +  EXPECT_EQ(expected_result, matchings_finder->Find(phrase_location, "bla", 2)); +  EXPECT_EQ(PhraseLocation(0, 10), phrase_location); +} + +} // namespace +} // namespace extractor diff --git a/extractor/matchings_trie.cc b/extractor/matchings_trie.cc new file mode 100644 index 00000000..7fb7a529 --- /dev/null +++ b/extractor/matchings_trie.cc @@ -0,0 +1,29 @@ +#include "matchings_trie.h" + +namespace extractor { + +MatchingsTrie::MatchingsTrie() { +  root = make_shared<TrieNode>(); +} + +MatchingsTrie::~MatchingsTrie() { +  DeleteTree(root); +} + +shared_ptr<TrieNode> MatchingsTrie::GetRoot() const { +  return root; +} + +void MatchingsTrie::DeleteTree(shared_ptr<TrieNode> root) { +  if (root != NULL) { +    for (auto child: root->children) { +      DeleteTree(child.second); +    } +    if (root->suffix_link != NULL) { +      root->suffix_link.reset(); +    } +    root.reset(); +  } +} + +} // namespace extractor diff --git a/extractor/matchings_trie.h b/extractor/matchings_trie.h new file mode 100644 index 00000000..1fb29693 --- /dev/null +++ b/extractor/matchings_trie.h @@ -0,0 +1,66 @@ +#ifndef _MATCHINGS_TRIE_ +#define _MATCHINGS_TRIE_ + +#include <memory> +#include <unordered_map> + +#include "phrase.h" +#include "phrase_location.h" + +using namespace std; + +namespace extractor { + +/** + * Trie node containing all the occurrences of the corresponding phrase in the + * source data. + */ +struct TrieNode { +  TrieNode(shared_ptr<TrieNode> suffix_link = shared_ptr<TrieNode>(), +           Phrase phrase = Phrase(), +           PhraseLocation matchings = PhraseLocation()) : +      suffix_link(suffix_link), phrase(phrase), matchings(matchings) {} + +  // Adds a trie node as a child of the current node. +  void AddChild(int key, shared_ptr<TrieNode> child_node) { +    children[key] = child_node; +  } + +  // Checks if a child exists for a given key. +  bool HasChild(int key) { +    return children.count(key); +  } + +  // Gets the child corresponding to the given key. +  shared_ptr<TrieNode> GetChild(int key) { +    return children[key]; +  } + +  shared_ptr<TrieNode> suffix_link; +  Phrase phrase; +  PhraseLocation matchings; +  unordered_map<int, shared_ptr<TrieNode> > children; +}; + +/** + * Trie containing all the phrases that can be obtained from a sentence. + */ +class MatchingsTrie { + public: +  MatchingsTrie(); + +  virtual ~MatchingsTrie(); + +  // Returns the root of the trie. +  shared_ptr<TrieNode> GetRoot() const; + + private: +  // Recursively deletes a subtree of the trie. +  void DeleteTree(shared_ptr<TrieNode> root); + +  shared_ptr<TrieNode> root; +}; + +} // namespace extractor + +#endif diff --git a/extractor/mocks/mock_alignment.h b/extractor/mocks/mock_alignment.h new file mode 100644 index 00000000..299c3d1c --- /dev/null +++ b/extractor/mocks/mock_alignment.h @@ -0,0 +1,14 @@ +#include <gmock/gmock.h> + +#include "alignment.h" + +namespace extractor { + +typedef vector<pair<int, int> > SentenceLinks; + +class MockAlignment : public Alignment { + public: +  MOCK_CONST_METHOD1(GetLinks, SentenceLinks(int sentence_id)); +}; + +} // namespace extractor diff --git a/extractor/mocks/mock_data_array.h b/extractor/mocks/mock_data_array.h new file mode 100644 index 00000000..6f85abb4 --- /dev/null +++ b/extractor/mocks/mock_data_array.h @@ -0,0 +1,23 @@ +#include <gmock/gmock.h> + +#include "data_array.h" + +namespace extractor { + +class MockDataArray : public DataArray { + public: +  MOCK_CONST_METHOD0(GetData, const vector<int>&()); +  MOCK_CONST_METHOD1(AtIndex, int(int index)); +  MOCK_CONST_METHOD1(GetWordAtIndex, string(int index)); +  MOCK_CONST_METHOD0(GetSize, int()); +  MOCK_CONST_METHOD0(GetVocabularySize, int()); +  MOCK_CONST_METHOD1(HasWord, bool(const string& word)); +  MOCK_CONST_METHOD1(GetWordId, int(const string& word)); +  MOCK_CONST_METHOD1(GetWord, string(int word_id)); +  MOCK_CONST_METHOD1(GetSentenceLength, int(int sentence_id)); +  MOCK_CONST_METHOD0(GetNumSentences, int()); +  MOCK_CONST_METHOD1(GetSentenceStart, int(int sentence_id)); +  MOCK_CONST_METHOD1(GetSentenceId, int(int position)); +}; + +} // namespace extractor diff --git a/extractor/mocks/mock_fast_intersector.h b/extractor/mocks/mock_fast_intersector.h new file mode 100644 index 00000000..f0b628d7 --- /dev/null +++ b/extractor/mocks/mock_fast_intersector.h @@ -0,0 +1,15 @@ +#include <gmock/gmock.h> + +#include "fast_intersector.h" +#include "phrase.h" +#include "phrase_location.h" + +namespace extractor { + +class MockFastIntersector : public FastIntersector { + public: +  MOCK_METHOD3(Intersect, PhraseLocation(PhraseLocation&, PhraseLocation&, +                                         const Phrase&)); +}; + +} // namespace extractor diff --git a/extractor/mocks/mock_feature.h b/extractor/mocks/mock_feature.h new file mode 100644 index 00000000..0b0f0ead --- /dev/null +++ b/extractor/mocks/mock_feature.h @@ -0,0 +1,15 @@ +#include <gmock/gmock.h> + +#include "features/feature.h" + +namespace extractor { +namespace features { + +class MockFeature : public Feature { + public: +  MOCK_CONST_METHOD1(Score, double(const FeatureContext& context)); +  MOCK_CONST_METHOD0(GetName, string()); +}; + +} // namespace features +} // namespace extractor diff --git a/extractor/mocks/mock_matchings_finder.h b/extractor/mocks/mock_matchings_finder.h new file mode 100644 index 00000000..827526fd --- /dev/null +++ b/extractor/mocks/mock_matchings_finder.h @@ -0,0 +1,13 @@ +#include <gmock/gmock.h> + +#include "matchings_finder.h" +#include "phrase_location.h" + +namespace extractor { + +class MockMatchingsFinder : public MatchingsFinder { + public: +  MOCK_METHOD3(Find, PhraseLocation(PhraseLocation&, const string&, int)); +}; + +} // namespace extractor diff --git a/extractor/mocks/mock_precomputation.h b/extractor/mocks/mock_precomputation.h new file mode 100644 index 00000000..8753343e --- /dev/null +++ b/extractor/mocks/mock_precomputation.h @@ -0,0 +1,12 @@ +#include <gmock/gmock.h> + +#include "precomputation.h" + +namespace extractor { + +class MockPrecomputation : public Precomputation { + public: +  MOCK_CONST_METHOD0(GetCollocations, const Index&()); +}; + +} // namespace extractor diff --git a/extractor/mocks/mock_rule_extractor.h b/extractor/mocks/mock_rule_extractor.h new file mode 100644 index 00000000..aad11651 --- /dev/null +++ b/extractor/mocks/mock_rule_extractor.h @@ -0,0 +1,16 @@ +#include <gmock/gmock.h> + +#include "phrase.h" +#include "phrase_builder.h" +#include "rule.h" +#include "rule_extractor.h" + +namespace extractor { + +class MockRuleExtractor : public RuleExtractor { + public: +  MOCK_CONST_METHOD2(ExtractRules, vector<Rule>(const Phrase&, +      const PhraseLocation&)); +}; + +} // namespace extractor diff --git a/extractor/mocks/mock_rule_extractor_helper.h b/extractor/mocks/mock_rule_extractor_helper.h new file mode 100644 index 00000000..468468f6 --- /dev/null +++ b/extractor/mocks/mock_rule_extractor_helper.h @@ -0,0 +1,82 @@ +#include <gmock/gmock.h> + +#include <vector> + +#include "rule_extractor_helper.h" + +using namespace std; + +namespace extractor { + +typedef unordered_map<int, int> Indexes; + +class MockRuleExtractorHelper : public RuleExtractorHelper { + public: +  MOCK_CONST_METHOD5(GetLinksSpans, void(vector<int>&, vector<int>&, +      vector<int>&, vector<int>&, int)); +  MOCK_CONST_METHOD4(CheckAlignedTerminals, bool(const vector<int>&, +      const vector<int>&, const vector<int>&, int)); +  MOCK_CONST_METHOD4(CheckTightPhrases, bool(const vector<int>&, +      const vector<int>&, const vector<int>&, int)); +  MOCK_CONST_METHOD1(GetGapOrder, vector<int>(const vector<pair<int, int> >&)); +  MOCK_CONST_METHOD4(GetSourceIndexes, Indexes(const vector<int>&, +      const vector<int>&, int, int)); + +  // We need to implement these methods, because Google Mock doesn't support +  // methods with more than 10 arguments. +  bool FindFixPoint( +      int, int, const vector<int>&, const vector<int>&, int& target_phrase_low, +      int& target_phrase_high, const vector<int>&, const vector<int>&, +      int& source_back_low, int& source_back_high, int, int, int, int, bool, +      bool, bool) const { +    target_phrase_low = this->target_phrase_low; +    target_phrase_high = this->target_phrase_high; +    source_back_low = this->source_back_low; +    source_back_high = this->source_back_high; +    return find_fix_point; +  } + +  bool GetGaps(vector<pair<int, int> >& source_gaps, +               vector<pair<int, int> >& target_gaps, +               const vector<int>&, const vector<int>&, const vector<int>&, +               const vector<int>&, const vector<int>&, const vector<int>&, +               int, int, int, int, int, int, int& num_symbols, +               bool& met_constraints) const { +    source_gaps = this->source_gaps; +    target_gaps = this->target_gaps; +    num_symbols = this->num_symbols; +    met_constraints = this->met_constraints; +    return get_gaps; +  } + +  void SetUp( +      int target_phrase_low, int target_phrase_high, int source_back_low, +      int source_back_high, bool find_fix_point, +      vector<pair<int, int> > source_gaps, vector<pair<int, int> > target_gaps, +      int num_symbols, bool met_constraints, bool get_gaps) { +    this->target_phrase_low = target_phrase_low; +    this->target_phrase_high = target_phrase_high; +    this->source_back_low = source_back_low; +    this->source_back_high = source_back_high; +    this->find_fix_point = find_fix_point; +    this->source_gaps = source_gaps; +    this->target_gaps = target_gaps; +    this->num_symbols = num_symbols; +    this->met_constraints = met_constraints; +    this->get_gaps = get_gaps; +  } + + private: +  int target_phrase_low; +  int target_phrase_high; +  int source_back_low; +  int source_back_high; +  bool find_fix_point; +  vector<pair<int, int> > source_gaps; +  vector<pair<int, int> > target_gaps; +  int num_symbols; +  bool met_constraints; +  bool get_gaps; +}; + +} // namespace extractor diff --git a/extractor/mocks/mock_rule_factory.h b/extractor/mocks/mock_rule_factory.h new file mode 100644 index 00000000..7389b396 --- /dev/null +++ b/extractor/mocks/mock_rule_factory.h @@ -0,0 +1,13 @@ +#include <gmock/gmock.h> + +#include "grammar.h" +#include "rule_factory.h" + +namespace extractor { + +class MockHieroCachingRuleFactory : public HieroCachingRuleFactory { + public: +  MOCK_METHOD1(GetGrammar, Grammar(const vector<int>& word_ids)); +}; + +} // namespace extractor diff --git a/extractor/mocks/mock_sampler.h b/extractor/mocks/mock_sampler.h new file mode 100644 index 00000000..75c43c27 --- /dev/null +++ b/extractor/mocks/mock_sampler.h @@ -0,0 +1,13 @@ +#include <gmock/gmock.h> + +#include "phrase_location.h" +#include "sampler.h" + +namespace extractor { + +class MockSampler : public Sampler { + public: +  MOCK_CONST_METHOD1(Sample, PhraseLocation(const PhraseLocation& location)); +}; + +} // namespace extractor diff --git a/extractor/mocks/mock_scorer.h b/extractor/mocks/mock_scorer.h new file mode 100644 index 00000000..cc0c444d --- /dev/null +++ b/extractor/mocks/mock_scorer.h @@ -0,0 +1,15 @@ +#include <gmock/gmock.h> + +#include "scorer.h" +#include "features/feature.h" + +namespace extractor { + +class MockScorer : public Scorer { + public: +  MOCK_CONST_METHOD1(Score, vector<double>( +      const features::FeatureContext& context)); +  MOCK_CONST_METHOD0(GetFeatureNames, vector<string>()); +}; + +} // namespace extractor diff --git a/extractor/mocks/mock_suffix_array.h b/extractor/mocks/mock_suffix_array.h new file mode 100644 index 00000000..7018acc7 --- /dev/null +++ b/extractor/mocks/mock_suffix_array.h @@ -0,0 +1,23 @@ +#include <gmock/gmock.h> + +#include <memory> +#include <string> + +#include "data_array.h" +#include "phrase_location.h" +#include "suffix_array.h" + +using namespace std; + +namespace extractor { + +class MockSuffixArray : public SuffixArray { + public: +  MOCK_CONST_METHOD0(GetSize, int()); +  MOCK_CONST_METHOD0(GetData, shared_ptr<DataArray>()); +  MOCK_CONST_METHOD0(BuildLCPArray, vector<int>()); +  MOCK_CONST_METHOD1(GetSuffix, int(int)); +  MOCK_CONST_METHOD4(Lookup, PhraseLocation(int, int, const string& word, int)); +}; + +} // namespace extractor diff --git a/extractor/mocks/mock_target_phrase_extractor.h b/extractor/mocks/mock_target_phrase_extractor.h new file mode 100644 index 00000000..6aad853c --- /dev/null +++ b/extractor/mocks/mock_target_phrase_extractor.h @@ -0,0 +1,16 @@ +#include <gmock/gmock.h> + +#include "target_phrase_extractor.h" + +namespace extractor { + +typedef pair<Phrase, PhraseAlignment> PhraseExtract; + +class MockTargetPhraseExtractor : public TargetPhraseExtractor { + public: +  MOCK_CONST_METHOD6(ExtractPhrases, vector<PhraseExtract>( +      const vector<pair<int, int> > &, const vector<int>&, int, int, +      const unordered_map<int, int>&, int)); +}; + +} // namespace extractor diff --git a/extractor/mocks/mock_translation_table.h b/extractor/mocks/mock_translation_table.h new file mode 100644 index 00000000..307e4282 --- /dev/null +++ b/extractor/mocks/mock_translation_table.h @@ -0,0 +1,13 @@ +#include <gmock/gmock.h> + +#include "translation_table.h" + +namespace extractor { + +class MockTranslationTable : public TranslationTable { + public: +  MOCK_METHOD2(GetSourceGivenTargetScore, double(const string&, const string&)); +  MOCK_METHOD2(GetTargetGivenSourceScore, double(const string&, const string&)); +}; + +} // namespace extractor diff --git a/extractor/mocks/mock_vocabulary.h b/extractor/mocks/mock_vocabulary.h new file mode 100644 index 00000000..042c9ce2 --- /dev/null +++ b/extractor/mocks/mock_vocabulary.h @@ -0,0 +1,13 @@ +#include <gmock/gmock.h> + +#include "vocabulary.h" + +namespace extractor { + +class MockVocabulary : public Vocabulary { + public: +  MOCK_METHOD1(GetTerminalValue, string(int word_id)); +  MOCK_METHOD1(GetTerminalIndex, int(const string& word)); +}; + +} // namespace extractor diff --git a/extractor/phrase.cc b/extractor/phrase.cc new file mode 100644 index 00000000..e619bfe5 --- /dev/null +++ b/extractor/phrase.cc @@ -0,0 +1,58 @@ +#include "phrase.h" + +namespace extractor { + +int Phrase::Arity() const { +  return var_pos.size(); +} + +int Phrase::GetChunkLen(int index) const { +  if (var_pos.size() == 0) { +    return symbols.size(); +  } else if (index == 0) { +    return var_pos[0]; +  } else if (index == var_pos.size()) { +    return symbols.size() - var_pos.back() - 1; +  } else { +    return var_pos[index] - var_pos[index - 1] - 1; +  } +} + +vector<int> Phrase::Get() const { +  return symbols; +} + +int Phrase::GetSymbol(int position) const { +  return symbols[position]; +} + +int Phrase::GetNumSymbols() const { +  return symbols.size(); +} + +vector<string> Phrase::GetWords() const { +  return words; +} + +bool Phrase::operator<(const Phrase& other) const { +  return symbols < other.symbols; +} + +ostream& operator<<(ostream& os, const Phrase& phrase) { +  int current_word = 0; +  for (size_t i = 0; i < phrase.symbols.size(); ++i) { +    if (phrase.symbols[i] < 0) { +      os << "[X," << -phrase.symbols[i] << "]"; +    } else { +      os << phrase.words[current_word]; +      ++current_word; +    } + +    if (i + 1 < phrase.symbols.size()) { +      os << " "; +    } +  } +  return os; +} + +} // namspace extractor diff --git a/extractor/phrase.h b/extractor/phrase.h new file mode 100644 index 00000000..a8e91e3c --- /dev/null +++ b/extractor/phrase.h @@ -0,0 +1,52 @@ +#ifndef _PHRASE_H_ +#define _PHRASE_H_ + +#include <iostream> +#include <string> +#include <vector> + +#include "phrase_builder.h" + +using namespace std; + +namespace extractor { + +/** + * Structure containing the data for a phrase. + */ +class Phrase { + public: +  friend Phrase PhraseBuilder::Build(const vector<int>& phrase); + +  // Returns the number of nonterminals in the phrase. +  int Arity() const; + +  // Returns the number of terminals (length) for the given chunk. (A chunk is a +  // contiguous sequence of terminals in the phrase). +  int GetChunkLen(int index) const; + +  // Returns the symbols (word ids) marking up the phrase. +  vector<int> Get() const; + +  // Returns the symbol located at the given position in the phrase. +  int GetSymbol(int position) const; + +  // Returns the number of symbols in the phrase. +  int GetNumSymbols() const; + +  // Returns the words making up the phrase. (Nonterminals are stripped out.) +  vector<string> GetWords() const; + +  bool operator<(const Phrase& other) const; + +  friend ostream& operator<<(ostream& os, const Phrase& phrase); + + private: +  vector<int> symbols; +  vector<int> var_pos; +  vector<string> words; +}; + +} // namespace extractor + +#endif diff --git a/extractor/phrase_builder.cc b/extractor/phrase_builder.cc new file mode 100644 index 00000000..9faee4be --- /dev/null +++ b/extractor/phrase_builder.cc @@ -0,0 +1,48 @@ +#include "phrase_builder.h" + +#include "phrase.h" +#include "vocabulary.h" + +namespace extractor { + +PhraseBuilder::PhraseBuilder(shared_ptr<Vocabulary> vocabulary) : +    vocabulary(vocabulary) {} + +Phrase PhraseBuilder::Build(const vector<int>& symbols) { +  Phrase phrase; +  phrase.symbols = symbols; +  for (size_t i = 0; i < symbols.size(); ++i) { +    if (vocabulary->IsTerminal(symbols[i])) { +      phrase.words.push_back(vocabulary->GetTerminalValue(symbols[i])); +    } else { +      phrase.var_pos.push_back(i); +    } +  } +  return phrase; +} + +Phrase PhraseBuilder::Extend(const Phrase& phrase, bool start_x, bool end_x) { +  vector<int> symbols = phrase.Get(); +  int num_nonterminals = 0; +  if (start_x) { +    num_nonterminals = 1; +    symbols.insert(symbols.begin(), +        vocabulary->GetNonterminalIndex(num_nonterminals)); +  } + +  for (size_t i = start_x; i < symbols.size(); ++i) { +    if (!vocabulary->IsTerminal(symbols[i])) { +      ++num_nonterminals; +      symbols[i] = vocabulary->GetNonterminalIndex(num_nonterminals); +    } +  } + +  if (end_x) { +    ++num_nonterminals; +    symbols.push_back(vocabulary->GetNonterminalIndex(num_nonterminals)); +  } + +  return Build(symbols); +} + +} // namespace extractor diff --git a/extractor/phrase_builder.h b/extractor/phrase_builder.h new file mode 100644 index 00000000..de86dbae --- /dev/null +++ b/extractor/phrase_builder.h @@ -0,0 +1,33 @@ +#ifndef _PHRASE_BUILDER_H_ +#define _PHRASE_BUILDER_H_ + +#include <memory> +#include <vector> + +using namespace std; + +namespace extractor { + +class Phrase; +class Vocabulary; + +/** + * Component for constructing phrases. + */ +class PhraseBuilder { + public: +  PhraseBuilder(shared_ptr<Vocabulary> vocabulary); + +  // Constructs a phrase starting from an array of symbols. +  Phrase Build(const vector<int>& symbols); + +  // Extends a phrase with a leading and/or trailing nonterminal. +  Phrase Extend(const Phrase& phrase, bool start_x, bool end_x); + + private: +  shared_ptr<Vocabulary> vocabulary; +}; + +} // namespace extractor + +#endif diff --git a/extractor/phrase_location.cc b/extractor/phrase_location.cc new file mode 100644 index 00000000..678ae270 --- /dev/null +++ b/extractor/phrase_location.cc @@ -0,0 +1,43 @@ +#include "phrase_location.h" + +namespace extractor { + +PhraseLocation::PhraseLocation(int sa_low, int sa_high) : +    sa_low(sa_low), sa_high(sa_high), num_subpatterns(0) {} + +PhraseLocation::PhraseLocation(const vector<int>& matchings, +                               int num_subpatterns) : +    sa_low(0), sa_high(0), +    matchings(make_shared<vector<int> >(matchings)), +    num_subpatterns(num_subpatterns) {} + +bool PhraseLocation::IsEmpty() const { +  return GetSize() == 0; +} + +int PhraseLocation::GetSize() const { +  if (num_subpatterns > 0) { +    return matchings->size(); +  } else { +    return sa_high - sa_low; +  } +} + +bool operator==(const PhraseLocation& a, const PhraseLocation& b) { +  if (a.sa_low != b.sa_low || a.sa_high != b.sa_high || +      a.num_subpatterns != b.num_subpatterns) { +    return false; +  } + +  if (a.matchings == NULL && b.matchings == NULL) { +    return true; +  } + +  if (a.matchings == NULL || b.matchings == NULL) { +    return false; +  } + +  return *a.matchings == *b.matchings; +} + +} // namespace extractor diff --git a/extractor/phrase_location.h b/extractor/phrase_location.h new file mode 100644 index 00000000..91950e03 --- /dev/null +++ b/extractor/phrase_location.h @@ -0,0 +1,41 @@ +#ifndef _PHRASE_LOCATION_H_ +#define _PHRASE_LOCATION_H_ + +#include <memory> +#include <vector> + +using namespace std; + +namespace extractor { + +/** + * Structure containing information about the occurrences of a phrase in the + * source data. + * + * Every consecutive (disjoint) group of num_subpatterns entries in matchings + * vector encodes an occurrence of the phrase. The i-th entry of a group + * represents the start of the i-th subpattern of the phrase. If the phrase + * doesn't contain any nonterminals, then it may also be represented as the + * range in the suffix array which matches the phrase. + */ +struct PhraseLocation { +  PhraseLocation(int sa_low = -1, int sa_high = -1); + +  PhraseLocation(const vector<int>& matchings, int num_subpatterns); + +  // Checks if a phrase has any occurrences in the source data. +  bool IsEmpty() const; + +  // Returns the number of occurrences of a phrase in the source data. +  int GetSize() const; + +  friend bool operator==(const PhraseLocation& a, const PhraseLocation& b); + +  int sa_low, sa_high; +  shared_ptr<vector<int> > matchings; +  int num_subpatterns; +}; + +} // namespace extractor + +#endif diff --git a/extractor/phrase_test.cc b/extractor/phrase_test.cc new file mode 100644 index 00000000..3ba9368a --- /dev/null +++ b/extractor/phrase_test.cc @@ -0,0 +1,83 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <vector> + +#include "mocks/mock_vocabulary.h" +#include "phrase.h" +#include "phrase_builder.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace { + +class PhraseTest : public Test { + protected: +  virtual void SetUp() { +    shared_ptr<MockVocabulary> vocabulary = make_shared<MockVocabulary>(); +    vector<string> words = {"w1", "w2", "w3", "w4"}; +    for (size_t i = 0; i < words.size(); ++i) { +      EXPECT_CALL(*vocabulary, GetTerminalValue(i + 1)) +          .WillRepeatedly(Return(words[i])); +    } +    shared_ptr<PhraseBuilder> phrase_builder = +        make_shared<PhraseBuilder>(vocabulary); + +    symbols1 = vector<int>{1, 2, 3}; +    phrase1 = phrase_builder->Build(symbols1); +    symbols2 = vector<int>{1, 2, -1, 3, -2, 4}; +    phrase2 = phrase_builder->Build(symbols2); +  } + +  vector<int> symbols1, symbols2; +  Phrase phrase1, phrase2; +}; + +TEST_F(PhraseTest, TestArity) { +  EXPECT_EQ(0, phrase1.Arity()); +  EXPECT_EQ(2, phrase2.Arity()); +} + +TEST_F(PhraseTest, GetChunkLen) { +  EXPECT_EQ(3, phrase1.GetChunkLen(0)); + +  EXPECT_EQ(2, phrase2.GetChunkLen(0)); +  EXPECT_EQ(1, phrase2.GetChunkLen(1)); +  EXPECT_EQ(1, phrase2.GetChunkLen(2)); +} + +TEST_F(PhraseTest, TestGet) { +  EXPECT_EQ(symbols1, phrase1.Get()); +  EXPECT_EQ(symbols2, phrase2.Get()); +} + +TEST_F(PhraseTest, TestGetSymbol) { +  for (size_t i = 0; i < symbols1.size(); ++i) { +    EXPECT_EQ(symbols1[i], phrase1.GetSymbol(i)); +  } +  for (size_t i = 0; i < symbols2.size(); ++i) { +    EXPECT_EQ(symbols2[i], phrase2.GetSymbol(i)); +  } +} + +TEST_F(PhraseTest, TestGetNumSymbols) { +  EXPECT_EQ(3, phrase1.GetNumSymbols()); +  EXPECT_EQ(6, phrase2.GetNumSymbols()); +} + +TEST_F(PhraseTest, TestGetWords) { +  vector<string> expected_words = {"w1", "w2", "w3"}; +  EXPECT_EQ(expected_words, phrase1.GetWords()); +  expected_words = {"w1", "w2", "w3", "w4"}; +  EXPECT_EQ(expected_words, phrase2.GetWords()); +} + +TEST_F(PhraseTest, TestComparator) { +  EXPECT_FALSE(phrase1 < phrase2); +  EXPECT_TRUE(phrase2 < phrase1); +} + +} // namespace +} // namespace extractor diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc new file mode 100644 index 00000000..b3906943 --- /dev/null +++ b/extractor/precomputation.cc @@ -0,0 +1,189 @@ +#include "precomputation.h" + +#include <iostream> +#include <queue> + +#include "data_array.h" +#include "suffix_array.h" + +using namespace std; + +namespace extractor { + +int Precomputation::FIRST_NONTERMINAL = -1; +int Precomputation::SECOND_NONTERMINAL = -2; + +Precomputation::Precomputation( +    shared_ptr<SuffixArray> suffix_array, int num_frequent_patterns, +    int num_super_frequent_patterns, int max_rule_span, +    int max_rule_symbols, int min_gap_size, +    int max_frequent_phrase_len, int min_frequency) { +  vector<int> data = suffix_array->GetData()->GetData(); +  vector<vector<int> > frequent_patterns = FindMostFrequentPatterns( +      suffix_array, data, num_frequent_patterns, max_frequent_phrase_len, +      min_frequency); + +  // Construct sets containing the frequent and superfrequent contiguous +  // collocations. +  unordered_set<vector<int>, VectorHash> frequent_patterns_set; +  unordered_set<vector<int>, VectorHash> super_frequent_patterns_set; +  for (size_t i = 0; i < frequent_patterns.size(); ++i) { +    frequent_patterns_set.insert(frequent_patterns[i]); +    if (i < num_super_frequent_patterns) { +      super_frequent_patterns_set.insert(frequent_patterns[i]); +    } +  } + +  vector<tuple<int, int, int> > matchings; +  for (size_t i = 0; i < data.size(); ++i) { +    // If the sentence is over, add all the discontiguous frequent patterns to +    // the index. +    if (data[i] == DataArray::END_OF_LINE) { +      AddCollocations(matchings, data, max_rule_span, min_gap_size, +          max_rule_symbols); +      matchings.clear(); +      continue; +    } +    vector<int> pattern; +    // Find all the contiguous frequent patterns starting at position i. +    for (int j = 1; j <= max_frequent_phrase_len && i + j <= data.size(); ++j) { +      pattern.push_back(data[i + j - 1]); +      if (frequent_patterns_set.count(pattern)) { +        int is_super_frequent = super_frequent_patterns_set.count(pattern); +        matchings.push_back(make_tuple(i, j, is_super_frequent)); +      } else { +        // If the current pattern is not frequent, any longer pattern having the +        // current pattern as prefix will not be frequent. +        break; +      } +    } +  } +} + +Precomputation::Precomputation() {} + +Precomputation::~Precomputation() {} + +vector<vector<int> > Precomputation::FindMostFrequentPatterns( +    shared_ptr<SuffixArray> suffix_array, const vector<int>& data, +    int num_frequent_patterns, int max_frequent_phrase_len, int min_frequency) { +  vector<int> lcp = suffix_array->BuildLCPArray(); +  vector<int> run_start(max_frequent_phrase_len); + +  // Find all the patterns occurring at least min_frequency times. +  priority_queue<pair<int, pair<int, int> > > heap; +  for (size_t i = 1; i < lcp.size(); ++i) { +    for (int len = lcp[i]; len < max_frequent_phrase_len; ++len) { +      int frequency = i - run_start[len]; +      if (frequency >= min_frequency) { +        heap.push(make_pair(frequency, +            make_pair(suffix_array->GetSuffix(run_start[len]), len + 1))); +      } +      run_start[len] = i; +    } +  } + +  // Extract the most frequent patterns. +  vector<vector<int> > frequent_patterns; +  while (frequent_patterns.size() < num_frequent_patterns && !heap.empty()) { +    int start = heap.top().second.first; +    int len = heap.top().second.second; +    heap.pop(); + +    vector<int> pattern(data.begin() + start, data.begin() + start + len); +    if (find(pattern.begin(), pattern.end(), DataArray::END_OF_LINE) == +        pattern.end()) { +      frequent_patterns.push_back(pattern); +    } +  } +  return frequent_patterns; +} + +void Precomputation::AddCollocations( +    const vector<tuple<int, int, int> >& matchings, const vector<int>& data, +    int max_rule_span, int min_gap_size, int max_rule_symbols) { +  // Select the leftmost subpattern. +  for (size_t i = 0; i < matchings.size(); ++i) { +    int start1, size1, is_super1; +    tie(start1, size1, is_super1) = matchings[i]; + +    // Select the second (middle) subpattern +    for (size_t j = i + 1; j < matchings.size(); ++j) { +      int start2, size2, is_super2; +      tie(start2, size2, is_super2) = matchings[j]; +      if (start2 - start1 >= max_rule_span) { +        break; +      } + +      if (start2 - start1 - size1 >= min_gap_size +          && start2 + size2 - start1 <= max_rule_span +          && size1 + size2 + 1 <= max_rule_symbols) { +        vector<int> pattern(data.begin() + start1, +            data.begin() + start1 + size1); +        pattern.push_back(Precomputation::FIRST_NONTERMINAL); +        pattern.insert(pattern.end(), data.begin() + start2, +            data.begin() + start2 + size2); +        AddStartPositions(collocations[pattern], start1, start2); + +        // Try extending the binary collocation to a ternary collocation. +        if (is_super2) { +          pattern.push_back(Precomputation::SECOND_NONTERMINAL); +          // Select the rightmost subpattern. +          for (size_t k = j + 1; k < matchings.size(); ++k) { +            int start3, size3, is_super3; +            tie(start3, size3, is_super3) = matchings[k]; +            if (start3 - start1 >= max_rule_span) { +              break; +            } + +            if (start3 - start2 - size2 >= min_gap_size +                && start3 + size3 - start1 <= max_rule_span +                && size1 + size2 + size3 + 2 <= max_rule_symbols +                && (is_super1 || is_super3)) { +              pattern.insert(pattern.end(), data.begin() + start3, +                  data.begin() + start3 + size3); +              AddStartPositions(collocations[pattern], start1, start2, start3); +              pattern.erase(pattern.end() - size3); +            } +          } +        } +      } +    } +  } +} + +void Precomputation::AddStartPositions( +    vector<int>& positions, int pos1, int pos2) { +  positions.push_back(pos1); +  positions.push_back(pos2); +} + +void Precomputation::AddStartPositions( +    vector<int>& positions, int pos1, int pos2, int pos3) { +  positions.push_back(pos1); +  positions.push_back(pos2); +  positions.push_back(pos3); +} + +void Precomputation::WriteBinary(const fs::path& filepath) const { +  FILE* file = fopen(filepath.string().c_str(), "w"); + +  // TODO(pauldb): Refactor this code. +  int size = collocations.size(); +  fwrite(&size, sizeof(int), 1, file); +  for (auto entry: collocations) { +    size = entry.first.size(); +    fwrite(&size, sizeof(int), 1, file); +    fwrite(entry.first.data(), sizeof(int), size, file); + +    size = entry.second.size(); +    fwrite(&size, sizeof(int), 1, file); +    fwrite(entry.second.data(), sizeof(int), size, file); +  } +} + +const Index& Precomputation::GetCollocations() const { +  return collocations; +} + +} // namespace extractor diff --git a/extractor/precomputation.h b/extractor/precomputation.h new file mode 100644 index 00000000..e3c4d26a --- /dev/null +++ b/extractor/precomputation.h @@ -0,0 +1,80 @@ +#ifndef _PRECOMPUTATION_H_ +#define _PRECOMPUTATION_H_ + +#include <memory> +#include <unordered_map> +#include <unordered_set> +#include <tuple> +#include <vector> + +#include <boost/filesystem.hpp> +#include <boost/functional/hash.hpp> + +namespace fs = boost::filesystem; +using namespace std; + +namespace extractor { + +typedef boost::hash<vector<int> > VectorHash; +typedef unordered_map<vector<int>, vector<int>, VectorHash> Index; + +class SuffixArray; + +/** + * Data structure wrapping an index with all the occurrences of the most + * frequent discontiguous collocations in the source data. + * + * Let a, b, c be contiguous collocations. The index will contain an entry for + * every collocation of the form: + * - aXb, where a and b are frequent + * - aXbXc, where a and b are super-frequent and c is frequent or + *                b and c are super-frequent and a is frequent. + */ +class Precomputation { + public: +  // Constructs the index using the suffix array. +  Precomputation( +      shared_ptr<SuffixArray> suffix_array, int num_frequent_patterns, +      int num_super_frequent_patterns, int max_rule_span, +      int max_rule_symbols, int min_gap_size, +      int max_frequent_phrase_len, int min_frequency); + +  virtual ~Precomputation(); + +  void WriteBinary(const fs::path& filepath) const; + +  // Returns a reference to the index. +  virtual const Index& GetCollocations() const; + +  static int FIRST_NONTERMINAL; +  static int SECOND_NONTERMINAL; + + protected: +  Precomputation(); + + private: +  // Finds the most frequent contiguous collocations. +  vector<vector<int> > FindMostFrequentPatterns( +      shared_ptr<SuffixArray> suffix_array, const vector<int>& data, +      int num_frequent_patterns, int max_frequent_phrase_len, +      int min_frequency); + +  // Given the locations of the frequent contiguous collocations in a sentence, +  // it adds new entries to the index for each discontiguous collocation +  // matching the criteria specified in the class description. +  void AddCollocations( +      const vector<std::tuple<int, int, int> >& matchings, const vector<int>& data, +      int max_rule_span, int min_gap_size, int max_rule_symbols); + +  // Adds an occurrence of a binary collocation. +  void AddStartPositions(vector<int>& positions, int pos1, int pos2); + +  // Adds an occurrence of a ternary collocation. +  void AddStartPositions(vector<int>& positions, int pos1, int pos2, int pos3); + +  Index collocations; +}; + +} // namespace extractor + +#endif diff --git a/extractor/precomputation_test.cc b/extractor/precomputation_test.cc new file mode 100644 index 00000000..363febb7 --- /dev/null +++ b/extractor/precomputation_test.cc @@ -0,0 +1,106 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <vector> + +#include "mocks/mock_data_array.h" +#include "mocks/mock_suffix_array.h" +#include "precomputation.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace { + +class PrecomputationTest : public Test { + protected: +  virtual void SetUp() { +    data = {4, 2, 3, 5, 7, 2, 3, 5, 2, 3, 4, 2, 1}; +    data_array = make_shared<MockDataArray>(); +    EXPECT_CALL(*data_array, GetData()).WillRepeatedly(ReturnRef(data)); + +    vector<int> suffixes{12, 8, 5, 1, 9, 6, 2, 0, 10, 7, 3, 4, 13}; +    vector<int> lcp{-1, 0, 2, 3, 1, 0, 1, 2, 0, 2, 0, 1, 0, 0}; +    suffix_array = make_shared<MockSuffixArray>(); +    EXPECT_CALL(*suffix_array, GetData()).WillRepeatedly(Return(data_array)); +    for (size_t i = 0; i < suffixes.size(); ++i) { +      EXPECT_CALL(*suffix_array, +                  GetSuffix(i)).WillRepeatedly(Return(suffixes[i])); +    } +    EXPECT_CALL(*suffix_array, BuildLCPArray()).WillRepeatedly(Return(lcp)); +  } + +  vector<int> data; +  shared_ptr<MockDataArray> data_array; +  shared_ptr<MockSuffixArray> suffix_array; +}; + +TEST_F(PrecomputationTest, TestCollocations) { +  Precomputation precomputation(suffix_array, 3, 3, 10, 5, 1, 4, 2); +  Index collocations = precomputation.GetCollocations(); + +  vector<int> key = {2, 3, -1, 2}; +  vector<int> expected_value = {1, 5, 1, 8, 5, 8, 5, 11, 8, 11}; +  EXPECT_EQ(expected_value, collocations[key]); +  key = {2, 3, -1, 2, 3}; +  expected_value = {1, 5, 1, 8, 5, 8}; +  EXPECT_EQ(expected_value, collocations[key]); +  key = {2, 3, -1, 3}; +  expected_value = {1, 6, 1, 9, 5, 9}; +  EXPECT_EQ(expected_value, collocations[key]); +  key = {3, -1, 2}; +  expected_value = {2, 5, 2, 8, 2, 11, 6, 8, 6, 11, 9, 11}; +  EXPECT_EQ(expected_value, collocations[key]); +  key = {3, -1, 3}; +  expected_value = {2, 6, 2, 9, 6, 9}; +  EXPECT_EQ(expected_value, collocations[key]); +  key = {3, -1, 2, 3}; +  expected_value = {2, 5, 2, 8, 6, 8}; +  EXPECT_EQ(expected_value, collocations[key]); +  key = {2, -1, 2}; +  expected_value = {1, 5, 1, 8, 5, 8, 5, 11, 8, 11}; +  EXPECT_EQ(expected_value, collocations[key]); +  key = {2, -1, 2, 3}; +  expected_value = {1, 5, 1, 8, 5, 8}; +  EXPECT_EQ(expected_value, collocations[key]); +  key = {2, -1, 3}; +  expected_value = {1, 6, 1, 9, 5, 9}; +  EXPECT_EQ(expected_value, collocations[key]); + +  key = {2, -1, 2, -2, 2}; +  expected_value = {1, 5, 8, 5, 8, 11}; +  EXPECT_EQ(expected_value, collocations[key]); +  key = {2, -1, 2, -2, 3}; +  expected_value = {1, 5, 9}; +  EXPECT_EQ(expected_value, collocations[key]); +  key = {2, -1, 3, -2, 2}; +  expected_value = {1, 6, 8, 5, 9, 11}; +  EXPECT_EQ(expected_value, collocations[key]); +  key = {2, -1, 3, -2, 3}; +  expected_value = {1, 6, 9}; +  EXPECT_EQ(expected_value, collocations[key]); +  key = {3, -1, 2, -2, 2}; +  expected_value = {2, 5, 8, 2, 5, 11, 2, 8, 11, 6, 8, 11}; +  EXPECT_EQ(expected_value, collocations[key]); +  key = {3, -1, 2, -2, 3}; +  expected_value = {2, 5, 9}; +  EXPECT_EQ(expected_value, collocations[key]); +  key = {3, -1, 3, -2, 2}; +  expected_value = {2, 6, 8, 2, 6, 11, 2, 9, 11, 6, 9, 11}; +  EXPECT_EQ(expected_value, collocations[key]); +  key = {3, -1, 3, -2, 3}; +  expected_value = {2, 6, 9}; +  EXPECT_EQ(expected_value, collocations[key]); + +  // Exceeds max_rule_symbols. +  key = {2, -1, 2, -2, 2, 3}; +  EXPECT_EQ(0, collocations.count(key)); +  // Contains non frequent pattern. +  key = {2, -1, 5}; +  EXPECT_EQ(0, collocations.count(key)); +} + +} // namespace +} // namespace extractor + diff --git a/extractor/rule.cc b/extractor/rule.cc new file mode 100644 index 00000000..b6c7d783 --- /dev/null +++ b/extractor/rule.cc @@ -0,0 +1,14 @@ +#include "rule.h" + +namespace extractor { + +Rule::Rule(const Phrase& source_phrase, +           const Phrase& target_phrase, +           const vector<double>& scores, +           const vector<pair<int, int> >& alignment) : +  source_phrase(source_phrase), +  target_phrase(target_phrase), +  scores(scores), +  alignment(alignment) {} + +} // namespace extractor diff --git a/extractor/rule.h b/extractor/rule.h new file mode 100644 index 00000000..bc95709e --- /dev/null +++ b/extractor/rule.h @@ -0,0 +1,27 @@ +#ifndef _RULE_H_ +#define _RULE_H_ + +#include <vector> + +#include "phrase.h" + +using namespace std; + +namespace extractor { + +/** + * Structure containing the data for a SCFG rule. + */ +struct Rule { +  Rule(const Phrase& source_phrase, const Phrase& target_phrase, +       const vector<double>& scores, const vector<pair<int, int> >& alignment); + +  Phrase source_phrase; +  Phrase target_phrase; +  vector<double> scores; +  vector<pair<int, int> > alignment; +}; + +} // namespace extractor + +#endif diff --git a/extractor/rule_extractor.cc b/extractor/rule_extractor.cc new file mode 100644 index 00000000..fa7386a4 --- /dev/null +++ b/extractor/rule_extractor.cc @@ -0,0 +1,343 @@ +#include "rule_extractor.h" + +#include <map> + +#include "alignment.h" +#include "data_array.h" +#include "features/feature.h" +#include "phrase_builder.h" +#include "phrase_location.h" +#include "rule.h" +#include "rule_extractor_helper.h" +#include "scorer.h" +#include "target_phrase_extractor.h" + +using namespace std; + +namespace extractor { + +RuleExtractor::RuleExtractor( +    shared_ptr<DataArray> source_data_array, +    shared_ptr<DataArray> target_data_array, +    shared_ptr<Alignment> alignment, +    shared_ptr<PhraseBuilder> phrase_builder, +    shared_ptr<Scorer> scorer, +    shared_ptr<Vocabulary> vocabulary, +    int max_rule_span, +    int min_gap_size, +    int max_nonterminals, +    int max_rule_symbols, +    bool require_aligned_terminal, +    bool require_aligned_chunks, +    bool require_tight_phrases) : +    target_data_array(target_data_array), +    source_data_array(source_data_array), +    phrase_builder(phrase_builder), +    scorer(scorer), +    max_rule_span(max_rule_span), +    min_gap_size(min_gap_size), +    max_nonterminals(max_nonterminals), +    max_rule_symbols(max_rule_symbols), +    require_tight_phrases(require_tight_phrases) { +  helper = make_shared<RuleExtractorHelper>( +      source_data_array, target_data_array, alignment, max_rule_span, +      max_rule_symbols, require_aligned_terminal, require_aligned_chunks, +      require_tight_phrases); +  target_phrase_extractor = make_shared<TargetPhraseExtractor>( +      target_data_array, alignment, phrase_builder, helper, vocabulary, +      max_rule_span, require_tight_phrases); +} + +RuleExtractor::RuleExtractor( +    shared_ptr<DataArray> source_data_array, +    shared_ptr<PhraseBuilder> phrase_builder, +    shared_ptr<Scorer> scorer, +    shared_ptr<TargetPhraseExtractor> target_phrase_extractor, +    shared_ptr<RuleExtractorHelper> helper, +    int max_rule_span, +    int min_gap_size, +    int max_nonterminals, +    int max_rule_symbols, +    bool require_tight_phrases) : +    source_data_array(source_data_array), +    phrase_builder(phrase_builder), +    scorer(scorer), +    target_phrase_extractor(target_phrase_extractor), +    helper(helper), +    max_rule_span(max_rule_span), +    min_gap_size(min_gap_size), +    max_nonterminals(max_nonterminals), +    max_rule_symbols(max_rule_symbols), +    require_tight_phrases(require_tight_phrases) {} + +RuleExtractor::RuleExtractor() {} + +RuleExtractor::~RuleExtractor() {} + +vector<Rule> RuleExtractor::ExtractRules(const Phrase& phrase, +                                         const PhraseLocation& location) const { +  int num_subpatterns = location.num_subpatterns; +  vector<int> matchings = *location.matchings; + +  // Calculate statistics for the (sampled) occurrences of the source phrase. +  map<Phrase, double> source_phrase_counter; +  map<Phrase, map<Phrase, map<PhraseAlignment, int> > > alignments_counter; +  for (auto i = matchings.begin(); i != matchings.end(); i += num_subpatterns) { +    vector<int> matching(i, i + num_subpatterns); +    vector<Extract> extracts = ExtractAlignments(phrase, matching); + +    for (Extract e: extracts) { +      source_phrase_counter[e.source_phrase] += e.pairs_count; +      alignments_counter[e.source_phrase][e.target_phrase][e.alignment] += 1; +    } +  } + +  // Compute the feature scores and find the most likely (frequent) alignment +  // for each pair of source-target phrases. +  int num_samples = matchings.size() / num_subpatterns; +  vector<Rule> rules; +  for (auto source_phrase_entry: alignments_counter) { +    Phrase source_phrase = source_phrase_entry.first; +    for (auto target_phrase_entry: source_phrase_entry.second) { +      Phrase target_phrase = target_phrase_entry.first; + +      int max_locations = 0, num_locations = 0; +      PhraseAlignment most_frequent_alignment; +      for (auto alignment_entry: target_phrase_entry.second) { +        num_locations += alignment_entry.second; +        if (alignment_entry.second > max_locations) { +          most_frequent_alignment = alignment_entry.first; +          max_locations = alignment_entry.second; +        } +      } + +      features::FeatureContext context(source_phrase, target_phrase, +          source_phrase_counter[source_phrase], num_locations, num_samples); +      vector<double> scores = scorer->Score(context); +      rules.push_back(Rule(source_phrase, target_phrase, scores, +                           most_frequent_alignment)); +    } +  } +  return rules; +} + +vector<Extract> RuleExtractor::ExtractAlignments( +    const Phrase& phrase, const vector<int>& matching) const { +  vector<Extract> extracts; +  int sentence_id = source_data_array->GetSentenceId(matching[0]); +  int source_sent_start = source_data_array->GetSentenceStart(sentence_id); + +  // Get the span in the opposite sentence for each word in the source-target +  // sentece pair. +  vector<int> source_low, source_high, target_low, target_high; +  helper->GetLinksSpans(source_low, source_high, target_low, target_high, +                        sentence_id); + +  int num_subpatterns = matching.size(); +  vector<int> chunklen(num_subpatterns); +  for (size_t i = 0; i < num_subpatterns; ++i) { +    chunklen[i] = phrase.GetChunkLen(i); +  } + +  // Basic checks to see if we can extract phrase pairs for this occurrence. +  if (!helper->CheckAlignedTerminals(matching, chunklen, source_low, +                                     source_sent_start) || +      !helper->CheckTightPhrases(matching, chunklen, source_low, +                                 source_sent_start)) { +    return extracts; +  } + +  int source_back_low = -1, source_back_high = -1; +  int source_phrase_low = matching[0] - source_sent_start; +  int source_phrase_high = matching.back() + chunklen.back() - +                           source_sent_start; +  int target_phrase_low = -1, target_phrase_high = -1; +  // Find target span and reflected source span for the source phrase. +  if (!helper->FindFixPoint(source_phrase_low, source_phrase_high, source_low, +                            source_high, target_phrase_low, target_phrase_high, +                            target_low, target_high, source_back_low, +                            source_back_high, sentence_id, min_gap_size, 0, +                            max_nonterminals - matching.size() + 1, true, true, +                            false)) { +    return extracts; +  } + +  // Get spans for nonterminal gaps. +  bool met_constraints = true; +  int num_symbols = phrase.GetNumSymbols(); +  vector<pair<int, int> > source_gaps, target_gaps; +  if (!helper->GetGaps(source_gaps, target_gaps, matching, chunklen, source_low, +                       source_high, target_low, target_high, source_phrase_low, +                       source_phrase_high, source_back_low, source_back_high, +                       sentence_id, source_sent_start, num_symbols, +                       met_constraints)) { +    return extracts; +  } + +  // Find target phrases aligned with the initial source phrase. +  bool starts_with_x = source_back_low != source_phrase_low; +  bool ends_with_x = source_back_high != source_phrase_high; +  Phrase source_phrase = phrase_builder->Extend( +      phrase, starts_with_x, ends_with_x); +  unordered_map<int, int> source_indexes = helper->GetSourceIndexes( +      matching, chunklen, starts_with_x, source_sent_start); +  if (met_constraints) { +    AddExtracts(extracts, source_phrase, source_indexes, target_gaps, +                target_low, target_phrase_low, target_phrase_high, sentence_id); +  } + +  if (source_gaps.size() >= max_nonterminals || +      source_phrase.GetNumSymbols() >= max_rule_symbols || +      source_back_high - source_back_low + min_gap_size > max_rule_span) { +    // Cannot add any more nonterminals. +    return extracts; +  } + +  // Extend the source phrase by adding a leading and/or trailing nonterminal +  // and find target phrases aligned with the extended source phrase. +  for (int i = 0; i < 2; ++i) { +    for (int j = 1 - i; j < 2; ++j) { +      AddNonterminalExtremities(extracts, matching, chunklen, source_phrase, +          source_back_low, source_back_high, source_low, source_high, +          target_low, target_high, target_gaps, sentence_id, source_sent_start, +          starts_with_x, ends_with_x, i, j); +    } +  } + +  return extracts; +} + +void RuleExtractor::AddExtracts( +    vector<Extract>& extracts, const Phrase& source_phrase, +    const unordered_map<int, int>& source_indexes, +    const vector<pair<int, int> >& target_gaps, const vector<int>& target_low, +    int target_phrase_low, int target_phrase_high, int sentence_id) const { +  auto target_phrases = target_phrase_extractor->ExtractPhrases( +      target_gaps, target_low, target_phrase_low, target_phrase_high, +      source_indexes, sentence_id); + +  if (target_phrases.size() > 0) { +    // Split the probability equally across all target phrases that can be +    // aligned with a single occurrence of the source phrase. +    double pairs_count = 1.0 / target_phrases.size(); +    for (auto target_phrase: target_phrases) { +      extracts.push_back(Extract(source_phrase, target_phrase.first, +                                 pairs_count, target_phrase.second)); +    } +  } +} + +void RuleExtractor::AddNonterminalExtremities( +    vector<Extract>& extracts, const vector<int>& matching, +    const vector<int>& chunklen, const Phrase& source_phrase, +    int source_back_low, int source_back_high, const vector<int>& source_low, +    const vector<int>& source_high, const vector<int>& target_low, +    const vector<int>& target_high, vector<pair<int, int> > target_gaps, +    int sentence_id, int source_sent_start, int starts_with_x, int ends_with_x, +    int extend_left, int extend_right) const { +  int source_x_low = source_back_low, source_x_high = source_back_high; + +  // Check if the extended source phrase will remain tight. +  if (require_tight_phrases) { +    if (source_low[source_back_low - extend_left] == -1 || +        source_low[source_back_high + extend_right - 1] == -1) { +      return; +    } +  } + +  // Check if we can add a nonterminal to the left. +  if (extend_left) { +    if (starts_with_x || source_back_low < min_gap_size) { +      return; +    } + +    source_x_low = source_back_low - min_gap_size; +    if (require_tight_phrases) { +      while (source_x_low >= 0 && source_low[source_x_low] == -1) { +        --source_x_low; +      } +    } +    if (source_x_low < 0) { +      return; +    } +  } + +  // Check if we can add a nonterminal to the right. +  if (extend_right) { +    int source_sent_len = source_data_array->GetSentenceLength(sentence_id); +    if (ends_with_x || source_back_high + min_gap_size > source_sent_len) { +      return; +    } +    source_x_high = source_back_high + min_gap_size; +    if (require_tight_phrases) { +      while (source_x_high <= source_sent_len && +             source_low[source_x_high - 1] == -1) { +        ++source_x_high; +      } +    } + +    if (source_x_high > source_sent_len) { +      return; +    } +  } + +  // More length checks. +  int new_nonterminals = extend_left + extend_right; +  if (source_x_high - source_x_low > max_rule_span || +      target_gaps.size() + new_nonterminals > max_nonterminals || +      source_phrase.GetNumSymbols() + new_nonterminals > max_rule_symbols) { +    return; +  } + +  // Find the target span for the extended phrase and the reflected source span. +  int target_x_low = -1, target_x_high = -1; +  if (!helper->FindFixPoint(source_x_low, source_x_high, source_low, +                            source_high, target_x_low, target_x_high, +                            target_low, target_high, source_x_low, +                            source_x_high, sentence_id, 1, 1, +                            new_nonterminals, extend_left, extend_right, +                            true)) { +    return; +  } + +  // Check gap integrity for the leading nonterminal. +  if (extend_left) { +    int source_gap_low = -1, source_gap_high = -1; +    int target_gap_low = -1, target_gap_high = -1; +    if ((require_tight_phrases && source_low[source_x_low] == -1) || +        !helper->FindFixPoint(source_x_low, source_back_low, source_low, +                              source_high, target_gap_low, target_gap_high, +                              target_low, target_high, source_gap_low, +                              source_gap_high, sentence_id, 0, 0, 0, false, +                              false, false)) { +      return; +    } +    target_gaps.insert(target_gaps.begin(), +                       make_pair(target_gap_low, target_gap_high)); +  } + +  // Check gap integrity for the trailing nonterminal. +  if (extend_right) { +    int target_gap_low = -1, target_gap_high = -1; +    int source_gap_low = -1, source_gap_high = -1; +    if ((require_tight_phrases && source_low[source_x_high - 1] == -1) || +        !helper->FindFixPoint(source_back_high, source_x_high, source_low, +                              source_high, target_gap_low, target_gap_high, +                              target_low, target_high, source_gap_low, +                              source_gap_high, sentence_id, 0, 0, 0, false, +                              false, false)) { +      return; +    } +    target_gaps.push_back(make_pair(target_gap_low, target_gap_high)); +  } + +  // Find target phrases aligned with the extended source phrase. +  Phrase new_source_phrase = phrase_builder->Extend(source_phrase, extend_left, +                                                    extend_right); +  unordered_map<int, int> source_indexes = helper->GetSourceIndexes( +      matching, chunklen, extend_left || starts_with_x, source_sent_start); +  AddExtracts(extracts, new_source_phrase, source_indexes, target_gaps, +              target_low, target_x_low, target_x_high, sentence_id); +} + +} // namespace extractor diff --git a/extractor/rule_extractor.h b/extractor/rule_extractor.h new file mode 100644 index 00000000..26e6f21c --- /dev/null +++ b/extractor/rule_extractor.h @@ -0,0 +1,124 @@ +#ifndef _RULE_EXTRACTOR_H_ +#define _RULE_EXTRACTOR_H_ + +#include <memory> +#include <unordered_map> +#include <vector> + +#include "phrase.h" + +using namespace std; + +namespace extractor { + +typedef vector<pair<int, int> > PhraseAlignment; + +class Alignment; +class DataArray; +class PhraseBuilder; +class PhraseLocation; +class Rule; +class RuleExtractorHelper; +class Scorer; +class TargetPhraseExtractor; + +/** + * Structure containing data about the occurrences of a source-target phrase pair + * in the parallel corpus. + */ +struct Extract { +  Extract(const Phrase& source_phrase, const Phrase& target_phrase, +          double pairs_count, const PhraseAlignment& alignment) : +      source_phrase(source_phrase), target_phrase(target_phrase), +      pairs_count(pairs_count), alignment(alignment) {} + +  Phrase source_phrase; +  Phrase target_phrase; +  double pairs_count; +  PhraseAlignment alignment; +}; + +/** + * Component for extracting SCFG rules. + */ +class RuleExtractor { + public: +  RuleExtractor(shared_ptr<DataArray> source_data_array, +                shared_ptr<DataArray> target_data_array, +                shared_ptr<Alignment> alingment, +                shared_ptr<PhraseBuilder> phrase_builder, +                shared_ptr<Scorer> scorer, +                shared_ptr<Vocabulary> vocabulary, +                int min_gap_size, +                int max_rule_span, +                int max_nonterminals, +                int max_rule_symbols, +                bool require_aligned_terminal, +                bool require_aligned_chunks, +                bool require_tight_phrases); + +  // For testing only. +  RuleExtractor(shared_ptr<DataArray> source_data_array, +                shared_ptr<PhraseBuilder> phrase_builder, +                shared_ptr<Scorer> scorer, +                shared_ptr<TargetPhraseExtractor> target_phrase_extractor, +                shared_ptr<RuleExtractorHelper> helper, +                int max_rule_span, +                int min_gap_size, +                int max_nonterminals, +                int max_rule_symbols, +                bool require_tight_phrases); + +  virtual ~RuleExtractor(); + +  // Extracts SCFG rules given a source phrase and a set of its occurrences +  // in the source data. +  virtual vector<Rule> ExtractRules(const Phrase& phrase, +                                    const PhraseLocation& location) const; + + protected: +  RuleExtractor(); + + private: +  // Finds all target phrases that can be aligned with the source phrase for a +  // particular occurrence in the data. +  vector<Extract> ExtractAlignments(const Phrase& phrase, +                                    const vector<int>& matching) const; + +  // Extracts all target phrases for a given occurrence of the source phrase in +  // the data. Constructs a vector of Extracts using these target phrases. +  void AddExtracts( +      vector<Extract>& extracts, const Phrase& source_phrase, +      const unordered_map<int, int>& source_indexes, +      const vector<pair<int, int> >& target_gaps, const vector<int>& target_low, +      int target_phrase_low, int target_phrase_high, int sentence_id) const; + +  // Adds a leading and/or trailing nonterminal to the source phrase and +  // extracts target phrases that can be aligned with the extended source +  // phrase. +  void AddNonterminalExtremities( +      vector<Extract>& extracts, const vector<int>& matching, +      const vector<int>& chunklen, const Phrase& source_phrase, +      int source_back_low, int source_back_high, const vector<int>& source_low, +      const vector<int>& source_high, const vector<int>& target_low, +      const vector<int>& target_high, vector<pair<int, int> > target_gaps, +      int sentence_id, int source_sent_start, int starts_with_x, +      int ends_with_x, int extend_left, int extend_right) const; + + private: +  shared_ptr<DataArray> target_data_array; +  shared_ptr<DataArray> source_data_array; +  shared_ptr<PhraseBuilder> phrase_builder; +  shared_ptr<Scorer> scorer; +  shared_ptr<TargetPhraseExtractor> target_phrase_extractor; +  shared_ptr<RuleExtractorHelper> helper; +  int max_rule_span; +  int min_gap_size; +  int max_nonterminals; +  int max_rule_symbols; +  bool require_tight_phrases; +}; + +} // namespace extractor + +#endif diff --git a/extractor/rule_extractor_helper.cc b/extractor/rule_extractor_helper.cc new file mode 100644 index 00000000..8a9516f2 --- /dev/null +++ b/extractor/rule_extractor_helper.cc @@ -0,0 +1,362 @@ +#include "rule_extractor_helper.h" + +#include "data_array.h" +#include "alignment.h" + +namespace extractor { + +RuleExtractorHelper::RuleExtractorHelper( +    shared_ptr<DataArray> source_data_array, +    shared_ptr<DataArray> target_data_array, +    shared_ptr<Alignment> alignment, +    int max_rule_span, +    int max_rule_symbols, +    bool require_aligned_terminal, +    bool require_aligned_chunks, +    bool require_tight_phrases) : +    source_data_array(source_data_array), +    target_data_array(target_data_array), +    alignment(alignment), +    max_rule_span(max_rule_span), +    max_rule_symbols(max_rule_symbols), +    require_aligned_terminal(require_aligned_terminal), +    require_aligned_chunks(require_aligned_chunks), +    require_tight_phrases(require_tight_phrases) {} + +RuleExtractorHelper::RuleExtractorHelper() {} + +RuleExtractorHelper::~RuleExtractorHelper() {} + +void RuleExtractorHelper::GetLinksSpans( +    vector<int>& source_low, vector<int>& source_high, +    vector<int>& target_low, vector<int>& target_high, int sentence_id) const { +  int source_sent_len = source_data_array->GetSentenceLength(sentence_id); +  int target_sent_len = target_data_array->GetSentenceLength(sentence_id); +  source_low = vector<int>(source_sent_len, -1); +  source_high = vector<int>(source_sent_len, -1); + +  target_low = vector<int>(target_sent_len, -1); +  target_high = vector<int>(target_sent_len, -1); +  vector<pair<int, int> > links = alignment->GetLinks(sentence_id); +  for (auto link: links) { +    if (source_low[link.first] == -1 || source_low[link.first] > link.second) { +      source_low[link.first] = link.second; +    } +    source_high[link.first] = max(source_high[link.first], link.second + 1); + +    if (target_low[link.second] == -1 || target_low[link.second] > link.first) { +      target_low[link.second] = link.first; +    } +    target_high[link.second] = max(target_high[link.second], link.first + 1); +  } +} + +bool RuleExtractorHelper::CheckAlignedTerminals( +    const vector<int>& matching, +    const vector<int>& chunklen, +    const vector<int>& source_low, +    int source_sent_start) const { +  if (!require_aligned_terminal) { +    return true; +  } + +  int num_aligned_chunks = 0; +  for (size_t i = 0; i < chunklen.size(); ++i) { +    for (size_t j = 0; j < chunklen[i]; ++j) { +      int sent_index = matching[i] - source_sent_start + j; +      if (source_low[sent_index] != -1) { +        ++num_aligned_chunks; +        break; +      } +    } +  } + +  if (num_aligned_chunks == 0) { +    return false; +  } + +  return !require_aligned_chunks || num_aligned_chunks == chunklen.size(); +} + +bool RuleExtractorHelper::CheckTightPhrases( +    const vector<int>& matching, +    const vector<int>& chunklen, +    const vector<int>& source_low, +    int source_sent_start) const { +  if (!require_tight_phrases) { +    return true; +  } + +  // Check if the chunk extremities are aligned. +  for (size_t i = 0; i + 1 < chunklen.size(); ++i) { +    int gap_start = matching[i] + chunklen[i] - source_sent_start; +    int gap_end = matching[i + 1] - 1 - source_sent_start; +    if (source_low[gap_start] == -1 || source_low[gap_end] == -1) { +      return false; +    } +  } + +  return true; +} + +bool RuleExtractorHelper::FindFixPoint( +    int source_phrase_low, int source_phrase_high, +    const vector<int>& source_low, const vector<int>& source_high, +    int& target_phrase_low, int& target_phrase_high, +    const vector<int>& target_low, const vector<int>& target_high, +    int& source_back_low, int& source_back_high, int sentence_id, +    int min_source_gap_size, int min_target_gap_size, +    int max_new_x, bool allow_low_x, bool allow_high_x, +    bool allow_arbitrary_expansion) const { +  int prev_target_low = target_phrase_low; +  int prev_target_high = target_phrase_high; + +  FindProjection(source_phrase_low, source_phrase_high, source_low, +                 source_high, target_phrase_low, target_phrase_high); + +  if (target_phrase_low == -1) { +    // Note: Low priority corner case inherited from Adam's code: +    // If w is unaligned, but we don't require aligned terminals, returning an +    // error here prevents the extraction of the allowed rule +    // X -> X_1 w X_2 / X_1 X_2 +    return false; +  } + +  int source_sent_len = source_data_array->GetSentenceLength(sentence_id); +  int target_sent_len = target_data_array->GetSentenceLength(sentence_id); +  // Extend the target span to the left. +  if (prev_target_low != -1 && target_phrase_low != prev_target_low) { +    if (prev_target_low - target_phrase_low < min_target_gap_size) { +      target_phrase_low = prev_target_low - min_target_gap_size; +      if (target_phrase_low < 0) { +        return false; +      } +    } +  } + +  // Extend the target span to the right. +  if (prev_target_high != -1 && target_phrase_high != prev_target_high) { +    if (target_phrase_high - prev_target_high < min_target_gap_size) { +      target_phrase_high = prev_target_high + min_target_gap_size; +      if (target_phrase_high > target_sent_len) { +        return false; +      } +    } +  } + +  // Check target span length. +  if (target_phrase_high - target_phrase_low > max_rule_span) { +    return false; +  } + +  // Find the initial reflected source span. +  source_back_low = source_back_high = -1; +  FindProjection(target_phrase_low, target_phrase_high, target_low, target_high, +                 source_back_low, source_back_high); +  int new_x = 0; +  bool new_low_x = false, new_high_x = false; +  while (true) { +    source_back_low = min(source_back_low, source_phrase_low); +    source_back_high = max(source_back_high, source_phrase_high); + +    // Stop if the reflected source span matches the previous source span. +    if (source_back_low == source_phrase_low && +        source_back_high == source_phrase_high) { +      return true; +    } + +    if (!allow_low_x && source_back_low < source_phrase_low) { +      // Extension on the left side not allowed. +      return false; +    } +    if (!allow_high_x && source_back_high > source_phrase_high) { +      // Extension on the right side not allowed. +      return false; +    } + +    // Extend left side. +    if (source_back_low < source_phrase_low) { +      if (new_low_x == false) { +        if (new_x >= max_new_x) { +          return false; +        } +        new_low_x = true; +        ++new_x; +      } +      if (source_phrase_low - source_back_low < min_source_gap_size) { +        source_back_low = source_phrase_low - min_source_gap_size; +        if (source_back_low < 0) { +          return false; +        } +      } +    } + +    // Extend right side. +    if (source_back_high > source_phrase_high) { +      if (new_high_x == false) { +        if (new_x >= max_new_x) { +          return false; +        } +        new_high_x = true; +        ++new_x; +      } +      if (source_back_high - source_phrase_high < min_source_gap_size) { +        source_back_high = source_phrase_high + min_source_gap_size; +        if (source_back_high > source_sent_len) { +          return false; +        } +      } +    } + +    if (source_back_high - source_back_low > max_rule_span) { +      // Rule span too wide. +      return false; +    } + +    prev_target_low = target_phrase_low; +    prev_target_high = target_phrase_high; +    // Find the reflection including the left gap (if one was added). +    FindProjection(source_back_low, source_phrase_low, source_low, source_high, +                   target_phrase_low, target_phrase_high); +    // Find the reflection including the right gap (if one was added). +    FindProjection(source_phrase_high, source_back_high, source_low, +                   source_high, target_phrase_low, target_phrase_high); +    // Stop if the new re-reflected target span matches the previous target +    // span. +    if (prev_target_low == target_phrase_low && +        prev_target_high == target_phrase_high) { +      return true; +    } + +    if (!allow_arbitrary_expansion) { +      // Arbitrary expansion not allowed. +      return false; +    } +    if (target_phrase_high - target_phrase_low > max_rule_span) { +      // Target side too wide. +      return false; +    } + +    source_phrase_low = source_back_low; +    source_phrase_high = source_back_high; +    // Re-reflect the target span. +    FindProjection(target_phrase_low, prev_target_low, target_low, target_high, +                   source_back_low, source_back_high); +    FindProjection(prev_target_high, target_phrase_high, target_low, +                   target_high, source_back_low, source_back_high); +  } + +  return false; +} + +void RuleExtractorHelper::FindProjection( +    int source_phrase_low, int source_phrase_high, +    const vector<int>& source_low, const vector<int>& source_high, +    int& target_phrase_low, int& target_phrase_high) const { +  for (size_t i = source_phrase_low; i < source_phrase_high; ++i) { +    if (source_low[i] != -1) { +      if (target_phrase_low == -1 || source_low[i] < target_phrase_low) { +        target_phrase_low = source_low[i]; +      } +      target_phrase_high = max(target_phrase_high, source_high[i]); +    } +  } +} + +bool RuleExtractorHelper::GetGaps( +     vector<pair<int, int> >& source_gaps, vector<pair<int, int> >& target_gaps, +     const vector<int>& matching, const vector<int>& chunklen, +     const vector<int>& source_low, const vector<int>& source_high, +     const vector<int>& target_low, const vector<int>& target_high, +     int source_phrase_low, int source_phrase_high, int source_back_low, +     int source_back_high, int sentence_id, int source_sent_start, +     int& num_symbols, bool& met_constraints) const { +  if (source_back_low < source_phrase_low) { +    source_gaps.push_back(make_pair(source_back_low, source_phrase_low)); +    if (num_symbols >= max_rule_symbols) { +      // Source side contains too many symbols. +      return false; +    } +    ++num_symbols; +    if (require_tight_phrases && (source_low[source_back_low] == -1 || +        source_low[source_phrase_low - 1] == -1)) { +      // Inside edges of preceding gap are not tight. +      return false; +    } +  } else if (require_tight_phrases && source_low[source_phrase_low] == -1) { +    // This is not a hard error. We can't extract this phrase, but we might +    // still be able to extract a superphrase. +    met_constraints = false; +  } + +  for (size_t i = 0; i + 1 < chunklen.size(); ++i) { +    int gap_start = matching[i] + chunklen[i] - source_sent_start; +    int gap_end = matching[i + 1] - source_sent_start; +    source_gaps.push_back(make_pair(gap_start, gap_end)); +  } + +  if (source_phrase_high < source_back_high) { +    source_gaps.push_back(make_pair(source_phrase_high, source_back_high)); +    if (num_symbols >= max_rule_symbols) { +      // Source side contains too many symbols. +      return false; +    } +    ++num_symbols; +    if (require_tight_phrases && (source_low[source_phrase_high] == -1 || +        source_low[source_back_high - 1] == -1)) { +      // Inside edges of following gap are not tight. +      return false; +    } +  } else if (require_tight_phrases && +             source_low[source_phrase_high - 1] == -1) { +    // This is not a hard error. We can't extract this phrase, but we might +    // still be able to extract a superphrase. +    met_constraints = false; +  } + +  target_gaps.resize(source_gaps.size(), make_pair(-1, -1)); +  for (size_t i = 0; i < source_gaps.size(); ++i) { +    if (!FindFixPoint(source_gaps[i].first, source_gaps[i].second, source_low, +                      source_high, target_gaps[i].first, target_gaps[i].second, +                      target_low, target_high, source_gaps[i].first, +                      source_gaps[i].second, sentence_id, 0, 0, 0, false, false, +                      false)) { +      // Gap fails integrity check. +      return false; +    } +  } + +  return true; +} + +vector<int> RuleExtractorHelper::GetGapOrder( +    const vector<pair<int, int> >& gaps) const { +  vector<int> gap_order(gaps.size()); +  for (size_t i = 0; i < gap_order.size(); ++i) { +    for (size_t j = 0; j < i; ++j) { +      if (gaps[gap_order[j]] < gaps[i]) { +        ++gap_order[i]; +      } else { +        ++gap_order[j]; +      } +    } +  } +  return gap_order; +} + +unordered_map<int, int> RuleExtractorHelper::GetSourceIndexes( +    const vector<int>& matching, const vector<int>& chunklen, +    int starts_with_x, int source_sent_start) const { + unordered_map<int, int> source_indexes; + int num_symbols = starts_with_x; + for (size_t i = 0; i < matching.size(); ++i) { +   for (size_t j = 0; j < chunklen[i]; ++j) { +     source_indexes[matching[i] + j - source_sent_start] = num_symbols; +     ++num_symbols; +   } +   ++num_symbols; + } + return source_indexes; +} + +} // namespace extractor diff --git a/extractor/rule_extractor_helper.h b/extractor/rule_extractor_helper.h new file mode 100644 index 00000000..d4ae45d4 --- /dev/null +++ b/extractor/rule_extractor_helper.h @@ -0,0 +1,101 @@ +#ifndef _RULE_EXTRACTOR_HELPER_H_ +#define _RULE_EXTRACTOR_HELPER_H_ + +#include <memory> +#include <unordered_map> +#include <vector> + +using namespace std; + +namespace extractor { + +class Alignment; +class DataArray; + +/** + * Helper class for extracting SCFG rules. + */ +class RuleExtractorHelper { + public: +  RuleExtractorHelper(shared_ptr<DataArray> source_data_array, +                      shared_ptr<DataArray> target_data_array, +                      shared_ptr<Alignment> alignment, +                      int max_rule_span, +                      int max_rule_symbols, +                      bool require_aligned_terminal, +                      bool require_aligned_chunks, +                      bool require_tight_phrases); + +  virtual ~RuleExtractorHelper(); + +  // Find the alignment span for each word in the source target sentence pair. +  virtual void GetLinksSpans(vector<int>& source_low, vector<int>& source_high, +                             vector<int>& target_low, vector<int>& target_high, +                             int sentence_id) const; + +  // Check if one chunk (all chunks) is aligned at least in one point. +  virtual bool CheckAlignedTerminals(const vector<int>& matching, +                                     const vector<int>& chunklen, +                                     const vector<int>& source_low, +                                     int source_sent_start) const; + +  // Check if the chunks are tight. +  virtual bool CheckTightPhrases(const vector<int>& matching, +                                 const vector<int>& chunklen, +                                 const vector<int>& source_low, +                                 int source_sent_start) const; + +  // Find the target span and the reflected source span for a source phrase +  // occurrence. +  virtual bool FindFixPoint( +      int source_phrase_low, int source_phrase_high, +      const vector<int>& source_low, const vector<int>& source_high, +      int& target_phrase_low, int& target_phrase_high, +      const vector<int>& target_low, const vector<int>& target_high, +      int& source_back_low, int& source_back_high, int sentence_id, +      int min_source_gap_size, int min_target_gap_size, +      int max_new_x, bool allow_low_x, bool allow_high_x, +      bool allow_arbitrary_expansion) const; + +  // Find the gap spans for each nonterminal in the source phrase. +  virtual bool GetGaps( +      vector<pair<int, int> >& source_gaps, vector<pair<int, int> >& target_gaps, +      const vector<int>& matching, const vector<int>& chunklen, +      const vector<int>& source_low, const vector<int>& source_high, +      const vector<int>& target_low, const vector<int>& target_high, +      int source_phrase_low, int source_phrase_high, int source_back_low, +      int source_back_high, int sentence_id, int source_sent_start, +      int& num_symbols, bool& met_constraints) const; + +  // Get the order of the nonterminals in the target phrase. +  virtual vector<int> GetGapOrder(const vector<pair<int, int> >& gaps) const; + +  // Map each terminal symbol with its position in the source phrase. +  virtual unordered_map<int, int> GetSourceIndexes( +      const vector<int>& matching, const vector<int>& chunklen, +      int starts_with_x, int source_sent_start) const; + + protected: +  RuleExtractorHelper(); + + private: +  // Find the projection of a source phrase in the target sentence. May also be +  // used to find the projection of a target phrase in the source sentence. +  void FindProjection( +      int source_phrase_low, int source_phrase_high, +      const vector<int>& source_low, const vector<int>& source_high, +      int& target_phrase_low, int& target_phrase_high) const; + +  shared_ptr<DataArray> source_data_array; +  shared_ptr<DataArray> target_data_array; +  shared_ptr<Alignment> alignment; +  int max_rule_span; +  int max_rule_symbols; +  bool require_aligned_terminal; +  bool require_aligned_chunks; +  bool require_tight_phrases; +}; + +} // namespace extractor + +#endif diff --git a/extractor/rule_extractor_helper_test.cc b/extractor/rule_extractor_helper_test.cc new file mode 100644 index 00000000..9b82abb1 --- /dev/null +++ b/extractor/rule_extractor_helper_test.cc @@ -0,0 +1,645 @@ +#include <gtest/gtest.h> + +#include <memory> + +#include "mocks/mock_alignment.h" +#include "mocks/mock_data_array.h" +#include "rule_extractor_helper.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace { + +class RuleExtractorHelperTest : public Test { + protected: +  virtual void SetUp() { +    source_data_array = make_shared<MockDataArray>(); +    EXPECT_CALL(*source_data_array, GetSentenceLength(_)) +        .WillRepeatedly(Return(12)); + +    target_data_array = make_shared<MockDataArray>(); +    EXPECT_CALL(*target_data_array, GetSentenceLength(_)) +        .WillRepeatedly(Return(12)); + +    vector<pair<int, int> > links = { +      make_pair(0, 0), make_pair(0, 1), make_pair(2, 2), make_pair(3, 1) +    }; +    alignment = make_shared<MockAlignment>(); +    EXPECT_CALL(*alignment, GetLinks(_)).WillRepeatedly(Return(links)); +  } + +  shared_ptr<MockDataArray> source_data_array; +  shared_ptr<MockDataArray> target_data_array; +  shared_ptr<MockAlignment> alignment; +  shared_ptr<RuleExtractorHelper> helper; +}; + +TEST_F(RuleExtractorHelperTest, TestGetLinksSpans) { +  helper = make_shared<RuleExtractorHelper>(source_data_array, +      target_data_array, alignment, 10, 5, true, true, true); +  EXPECT_CALL(*source_data_array, GetSentenceLength(_)) +      .WillRepeatedly(Return(4)); +  EXPECT_CALL(*target_data_array, GetSentenceLength(_)) +      .WillRepeatedly(Return(3)); + +  vector<int> source_low, source_high, target_low, target_high; +  helper->GetLinksSpans(source_low, source_high, target_low, target_high, 0); + +  vector<int> expected_source_low = {0, -1, 2, 1}; +  EXPECT_EQ(expected_source_low, source_low); +  vector<int> expected_source_high = {2, -1, 3, 2}; +  EXPECT_EQ(expected_source_high, source_high); +  vector<int> expected_target_low = {0, 0, 2}; +  EXPECT_EQ(expected_target_low, target_low); +  vector<int> expected_target_high = {1, 4, 3}; +  EXPECT_EQ(expected_target_high, target_high); +} + +TEST_F(RuleExtractorHelperTest, TestCheckAlignedFalse) { +  helper = make_shared<RuleExtractorHelper>(source_data_array, +      target_data_array, alignment, 10, 5, false, false, true); +  EXPECT_CALL(*source_data_array, GetSentenceId(_)).Times(0); +  EXPECT_CALL(*source_data_array, GetSentenceStart(_)).Times(0); + +  vector<int> matching, chunklen, source_low; +  EXPECT_TRUE(helper->CheckAlignedTerminals(matching, chunklen, +                                            source_low, 10)); +} + +TEST_F(RuleExtractorHelperTest, TestCheckAlignedTerminal) { +  helper = make_shared<RuleExtractorHelper>(source_data_array, +      target_data_array, alignment, 10, 5, true, false, true); + +  vector<int> matching = {10, 12}; +  vector<int> chunklen = {1, 3}; +  vector<int> source_low = {-1, 1, -1, 3, -1}; +  EXPECT_TRUE(helper->CheckAlignedTerminals(matching, chunklen, +                                            source_low, 10)); +  source_low = {-1, 1, -1, -1, -1}; +  EXPECT_FALSE(helper->CheckAlignedTerminals(matching, chunklen, +                                             source_low, 10)); +} + +TEST_F(RuleExtractorHelperTest, TestCheckAlignedChunks) { +  helper = make_shared<RuleExtractorHelper>(source_data_array, +      target_data_array, alignment, 10, 5, true, true, true); + +  vector<int> matching = {10, 12}; +  vector<int> chunklen = {1, 3}; +  vector<int> source_low = {2, 1, -1, 3, -1}; +  EXPECT_TRUE(helper->CheckAlignedTerminals(matching, chunklen, +                                            source_low, 10)); +  source_low = {-1, 1, -1, 3, -1}; +  EXPECT_FALSE(helper->CheckAlignedTerminals(matching, chunklen, +                                             source_low, 10)); +  source_low = {2, 1, -1, -1, -1}; +  EXPECT_FALSE(helper->CheckAlignedTerminals(matching, chunklen, +                                             source_low, 10)); +} + + +TEST_F(RuleExtractorHelperTest, TestCheckTightPhrasesFalse) { +  helper = make_shared<RuleExtractorHelper>(source_data_array, +      target_data_array, alignment, 10, 5, true, true, false); +  EXPECT_CALL(*source_data_array, GetSentenceId(_)).Times(0); +  EXPECT_CALL(*source_data_array, GetSentenceStart(_)).Times(0); + +  vector<int> matching, chunklen, source_low; +  EXPECT_TRUE(helper->CheckTightPhrases(matching, chunklen, source_low, 10)); +} + +TEST_F(RuleExtractorHelperTest, TestCheckTightPhrases) { +  helper = make_shared<RuleExtractorHelper>(source_data_array, +      target_data_array, alignment, 10, 5, true, true, true); + +  vector<int> matching = {10, 14, 18}; +  vector<int> chunklen = {2, 3, 1}; +  // No missing links. +  vector<int> source_low = {0, 1, 2, 3, 4, 5, 6, 7, 8}; +  EXPECT_TRUE(helper->CheckTightPhrases(matching, chunklen, source_low, 10)); + +  // Missing link at the beginning or ending of a gap. +  source_low = {0, 1, -1, 3, 4, 5, 6, 7, 8}; +  EXPECT_FALSE(helper->CheckTightPhrases(matching, chunklen, source_low, 10)); +  source_low = {0, 1, 2, -1, 4, 5, 6, 7, 8}; +  EXPECT_FALSE(helper->CheckTightPhrases(matching, chunklen, source_low, 10)); +  source_low = {0, 1, 2, 3, 4, 5, 6, -1, 8}; +  EXPECT_FALSE(helper->CheckTightPhrases(matching, chunklen, source_low, 10)); + +  // Missing link inside the gap. +  chunklen = {1, 3, 1}; +  source_low = {0, 1, -1, 3, 4, 5, 6, 7, 8}; +  EXPECT_TRUE(helper->CheckTightPhrases(matching, chunklen, source_low, 10)); +} + +TEST_F(RuleExtractorHelperTest, TestFindFixPointBadEdgeCase) { +  helper = make_shared<RuleExtractorHelper>(source_data_array, +      target_data_array, alignment, 10, 5, true, true, true); + +  vector<int> source_low = {0, -1, 2}; +  vector<int> source_high = {1, -1, 3}; +  vector<int> target_low = {0, -1, 2}; +  vector<int> target_high = {1, -1, 3}; +  int source_phrase_low = 1, source_phrase_high = 2; +  int source_back_low, source_back_high; +  int target_phrase_low = -1, target_phrase_high = 1; + +  // This should be in fact true. See comment about the inherited bug. +  EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, +                                    source_low, source_high, target_phrase_low, +                                    target_phrase_high, target_low, target_high, +                                    source_back_low, source_back_high, 0, 0, 0, +                                    0, false, false, false)); +} + +TEST_F(RuleExtractorHelperTest, TestFindFixPointTargetSentenceOutOfBounds) { +  helper = make_shared<RuleExtractorHelper>(source_data_array, +      target_data_array, alignment, 10, 5, true, true, true); +  EXPECT_CALL(*source_data_array, GetSentenceLength(_)) +      .WillRepeatedly(Return(3)); +  EXPECT_CALL(*target_data_array, GetSentenceLength(_)) +      .WillRepeatedly(Return(3)); + +  vector<int> source_low = {0, 0, 2}; +  vector<int> source_high = {1, 2, 3}; +  vector<int> target_low = {0, 1, 2}; +  vector<int> target_high = {2, 2, 3}; +  int source_phrase_low = 1, source_phrase_high = 2; +  int source_back_low, source_back_high; +  int target_phrase_low = 1, target_phrase_high = 2; + +  // Extend out of sentence to left. +  EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, +                                    source_low, source_high, target_phrase_low, +                                    target_phrase_high, target_low, target_high, +                                    source_back_low, source_back_high, 0, 2, 2, +                                    0, false, false, false)); +  source_low = {0, 1, 2}; +  source_high = {1, 3, 3}; +  target_low = {0, 1, 1}; +  target_high = {1, 2, 3}; +  EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, +                                    source_low, source_high, target_phrase_low, +                                    target_phrase_high, target_low, target_high, +                                    source_back_low, source_back_high, 0, 2, 2, +                                    0, false, false, false)); +} + +TEST_F(RuleExtractorHelperTest, TestFindFixPointTargetTooWide) { +  helper = make_shared<RuleExtractorHelper>(source_data_array, +      target_data_array, alignment, 5, 5, true, true, true); +  EXPECT_CALL(*source_data_array, GetSentenceLength(_)) +      .WillRepeatedly(Return(7)); +  EXPECT_CALL(*target_data_array, GetSentenceLength(_)) +      .WillRepeatedly(Return(7)); + +  vector<int> source_low = {0, 0, 0, 0, 0, 0, 0}; +  vector<int> source_high = {7, 7, 7, 7, 7, 7, 7}; +  vector<int> target_low = {0, -1, -1, -1, -1, -1, 0}; +  vector<int> target_high = {7, -1, -1, -1, -1, -1, 7}; +  int source_phrase_low = 2, source_phrase_high = 5; +  int source_back_low, source_back_high; +  int target_phrase_low = -1, target_phrase_high = -1; + +  // Projection is too wide. +  EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, +                                    source_low, source_high, target_phrase_low, +                                    target_phrase_high, target_low, target_high, +                                    source_back_low, source_back_high, 0, 1, 1, +                                    0, false, false, false)); +} + +TEST_F(RuleExtractorHelperTest, TestFindFixPoint) { +  helper = make_shared<RuleExtractorHelper>(source_data_array, +      target_data_array, alignment, 10, 5, true, true, true); +  EXPECT_CALL(*source_data_array, GetSentenceLength(_)) +      .WillRepeatedly(Return(7)); +  EXPECT_CALL(*target_data_array, GetSentenceLength(_)) +      .WillRepeatedly(Return(7)); + +  vector<int> source_low = {1, 1, 1, 3, 4, 5, 5}; +  vector<int> source_high = {2, 2, 3, 4, 6, 6, 6}; +  vector<int> target_low = {-1, 0, 2, 3, 4, 4, -1}; +  vector<int> target_high = {-1, 3, 3, 4, 5, 7, -1}; +  int source_phrase_low = 2, source_phrase_high = 5; +  int source_back_low, source_back_high; +  int target_phrase_low = 2, target_phrase_high = 5; + +  EXPECT_TRUE(helper->FindFixPoint(source_phrase_low, source_phrase_high, +                                   source_low, source_high, target_phrase_low, +                                   target_phrase_high, target_low, target_high, +                                   source_back_low, source_back_high, 1, 1, 1, +                                   2, true, true, false)); +  EXPECT_EQ(1, target_phrase_low); +  EXPECT_EQ(6, target_phrase_high); +  EXPECT_EQ(0, source_back_low); +  EXPECT_EQ(7, source_back_high); + +  source_low = {0, -1, 1, 3, 4, -1, 6}; +  source_high = {1, -1, 3, 4, 6, -1, 7}; +  target_low = {0, 2, 2, 3, 4, 4, 6}; +  target_high = {1, 3, 3, 4, 5, 5, 7}; +  source_phrase_low = 2, source_phrase_high = 5; +  target_phrase_low = -1, target_phrase_high = -1; +  EXPECT_TRUE(helper->FindFixPoint(source_phrase_low, source_phrase_high, +                                   source_low, source_high, target_phrase_low, +                                   target_phrase_high, target_low, target_high, +                                   source_back_low, source_back_high, 1, 1, 1, +                                   2, true, true, false)); +  EXPECT_EQ(1, target_phrase_low); +  EXPECT_EQ(6, target_phrase_high); +  EXPECT_EQ(2, source_back_low); +  EXPECT_EQ(5, source_back_high); +} + +TEST_F(RuleExtractorHelperTest, TestFindFixPointExtensionsNotAllowed) { +  helper = make_shared<RuleExtractorHelper>(source_data_array, +      target_data_array, alignment, 10, 5, true, true, true); +  EXPECT_CALL(*source_data_array, GetSentenceLength(_)) +      .WillRepeatedly(Return(3)); +  EXPECT_CALL(*target_data_array, GetSentenceLength(_)) +      .WillRepeatedly(Return(3)); + +  vector<int> source_low = {0, 0, 2}; +  vector<int> source_high = {1, 2, 3}; +  vector<int> target_low = {0, 1, 2}; +  vector<int> target_high = {2, 2, 3}; +  int source_phrase_low = 1, source_phrase_high = 2; +  int source_back_low, source_back_high; +  int target_phrase_low = -1, target_phrase_high = -1; + +  // Extension on the left side not allowed. +  EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, +                                    source_low, source_high, target_phrase_low, +                                    target_phrase_high, target_low, target_high, +                                    source_back_low, source_back_high, 0, 1, 1, +                                    1, false, true, false)); +  // Extension on the left side is allowed, but we can't add anymore X. +  target_phrase_low = -1, target_phrase_high = -1; +  EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, +                                    source_low, source_high, target_phrase_low, +                                    target_phrase_high, target_low, target_high, +                                    source_back_low, source_back_high, 0, 1, 1, +                                    0, true, true, false)); +  source_low = {0, 1, 2}; +  source_high = {1, 3, 3}; +  target_low = {0, 1, 1}; +  target_high = {1, 2, 3}; +  // Extension on the right side not allowed. +  target_phrase_low = -1, target_phrase_high = -1; +  EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, +                                    source_low, source_high, target_phrase_low, +                                    target_phrase_high, target_low, target_high, +                                    source_back_low, source_back_high, 0, 1, 1, +                                    1, true, false, false)); +  // Extension on the right side is allowed, but we can't add anymore X. +  target_phrase_low = -1, target_phrase_high = -1; +  EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, +                                    source_low, source_high, target_phrase_low, +                                    target_phrase_high, target_low, target_high, +                                    source_back_low, source_back_high, 0, 1, 1, +                                    0, true, true, false)); +} + +TEST_F(RuleExtractorHelperTest, TestFindFixPointSourceSentenceOutOfBounds) { +  helper = make_shared<RuleExtractorHelper>(source_data_array, +      target_data_array, alignment, 10, 5, true, true, true); +  EXPECT_CALL(*source_data_array, GetSentenceLength(_)) +      .WillRepeatedly(Return(3)); +  EXPECT_CALL(*target_data_array, GetSentenceLength(_)) +      .WillRepeatedly(Return(3)); + +  vector<int> source_low = {0, 0, 2}; +  vector<int> source_high = {1, 2, 3}; +  vector<int> target_low = {0, 1, 2}; +  vector<int> target_high = {2, 2, 3}; +  int source_phrase_low = 1, source_phrase_high = 2; +  int source_back_low, source_back_high; +  int target_phrase_low = 1, target_phrase_high = 2; +  // Extend out of sentence to left. +  EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, +                                    source_low, source_high, target_phrase_low, +                                    target_phrase_high, target_low, target_high, +                                    source_back_low, source_back_high, 0, 2, 1, +                                    1, true, true, false)); +  source_low = {0, 1, 2}; +  source_high = {1, 3, 3}; +  target_low = {0, 1, 1}; +  target_high = {1, 2, 3}; +  target_phrase_low = 1, target_phrase_high = 2; +  EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, +                                    source_low, source_high, target_phrase_low, +                                    target_phrase_high, target_low, target_high, +                                    source_back_low, source_back_high, 0, 2, 1, +                                    1, true, true, false)); +} + +TEST_F(RuleExtractorHelperTest, TestFindFixPointTargetSourceWide) { +  helper = make_shared<RuleExtractorHelper>(source_data_array, +      target_data_array, alignment, 5, 5, true, true, true); +  EXPECT_CALL(*source_data_array, GetSentenceLength(_)) +      .WillRepeatedly(Return(7)); +  EXPECT_CALL(*target_data_array, GetSentenceLength(_)) +      .WillRepeatedly(Return(7)); + +  vector<int> source_low = {2, -1, 2, 3, 4, -1, 4}; +  vector<int> source_high = {3, -1, 3, 4, 5, -1, 5}; +  vector<int> target_low = {-1, -1, 0, 3, 4, -1, -1}; +  vector<int> target_high = {-1, -1, 3, 4, 7, -1, -1}; +  int source_phrase_low = 2, source_phrase_high = 5; +  int source_back_low, source_back_high; +  int target_phrase_low = -1, target_phrase_high = -1; + +  // Second projection (on source side) is too wide. +  EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, +                                    source_low, source_high, target_phrase_low, +                                    target_phrase_high, target_low, target_high, +                                    source_back_low, source_back_high, 0, 1, 1, +                                    2, true, true, false)); +} + +TEST_F(RuleExtractorHelperTest, TestFindFixPointArbitraryExpansion) { +  helper = make_shared<RuleExtractorHelper>(source_data_array, +      target_data_array, alignment, 20, 5, true, true, true); +  EXPECT_CALL(*source_data_array, GetSentenceLength(_)) +      .WillRepeatedly(Return(11)); +  EXPECT_CALL(*target_data_array, GetSentenceLength(_)) +      .WillRepeatedly(Return(11)); + +  vector<int> source_low = {1, 1, 2, 3, 4, 5, 6, 7, 7, 8, 9}; +  vector<int> source_high = {2, 3, 4, 5, 5, 6, 7, 8, 9, 10, 10}; +  vector<int> target_low = {-1, 0, 1, 2, 3, 5, 6, 7, 8, 9, -1}; +  vector<int> target_high = {-1, 2, 3, 4, 5, 6, 8, 9, 10, 11, -1}; +  int source_phrase_low = 4, source_phrase_high = 7; +  int source_back_low, source_back_high; +  int target_phrase_low = -1, target_phrase_high = -1; +  EXPECT_FALSE(helper->FindFixPoint(source_phrase_low, source_phrase_high, +                                    source_low, source_high, target_phrase_low, +                                    target_phrase_high, target_low, target_high, +                                    source_back_low, source_back_high, 0, 1, 1, +                                    10, true, true, false)); + +  source_phrase_low = 4, source_phrase_high = 7; +  target_phrase_low = -1, target_phrase_high = -1; +  EXPECT_TRUE(helper->FindFixPoint(source_phrase_low, source_phrase_high, +                                   source_low, source_high, target_phrase_low, +                                   target_phrase_high, target_low, target_high, +                                   source_back_low, source_back_high, 0, 1, 1, +                                   10, true, true, true)); +} + +TEST_F(RuleExtractorHelperTest, TestGetGapOrder) { +  helper = make_shared<RuleExtractorHelper>(source_data_array, +      target_data_array, alignment, 10, 5, true, true, true); + +  vector<pair<int, int> > gaps = +      {make_pair(0, 3), make_pair(5, 8), make_pair(11, 12), make_pair(15, 17)}; +  vector<int> expected_gap_order = {0, 1, 2, 3}; +  EXPECT_EQ(expected_gap_order, helper->GetGapOrder(gaps)); + +  gaps = {make_pair(15, 17), make_pair(8, 9), make_pair(5, 6), make_pair(0, 3)}; +  expected_gap_order = {3, 2, 1, 0}; +  EXPECT_EQ(expected_gap_order, helper->GetGapOrder(gaps)); + +  gaps = {make_pair(8, 9), make_pair(5, 6), make_pair(0, 3), make_pair(15, 17)}; +  expected_gap_order = {2, 1, 0, 3}; +  EXPECT_EQ(expected_gap_order, helper->GetGapOrder(gaps)); +} + +TEST_F(RuleExtractorHelperTest, TestGetGapsExceedNumSymbols) { +  helper = make_shared<RuleExtractorHelper>(source_data_array, +      target_data_array, alignment, 10, 5, true, true, true); +  EXPECT_CALL(*source_data_array, GetSentenceLength(_)) +      .WillRepeatedly(Return(7)); +  EXPECT_CALL(*target_data_array, GetSentenceLength(_)) +      .WillRepeatedly(Return(7)); + +  bool met_constraints = true; +  vector<int> source_low = {1, 1, 2, 3, 4, 5, 6}; +  vector<int> source_high = {2, 2, 3, 4, 5, 6, 7}; +  vector<int> target_low = {-1, 0, 2, 3, 4, 5, 6}; +  vector<int> target_high = {-1, 2, 3, 4, 5, 6, 7}; +  int source_phrase_low = 1, source_phrase_high = 6; +  int source_back_low = 0, source_back_high = 6; +  vector<int> matching = {11, 13, 15}; +  vector<int> chunklen = {1, 1, 1}; +  vector<pair<int, int> > source_gaps, target_gaps; +  int num_symbols = 5; +  EXPECT_FALSE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen, +                               source_low, source_high, target_low, target_high, +                               source_phrase_low, source_phrase_high, +                               source_back_low, source_back_high, 5, 10, +                               num_symbols, met_constraints)); + +  source_low = {0, 1, 2, 3, 4, 5, 5}; +  source_high = {1, 2, 3, 4, 5, 6, 6}; +  target_low = {0, 1, 2, 3, 4, 5, -1}; +  target_high = {1, 2, 3, 4, 5, 7, -1}; +  source_phrase_low = 1, source_phrase_high = 6; +  source_back_low = 1, source_back_high = 7; +  num_symbols = 5; +  EXPECT_FALSE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen, +                               source_low, source_high, target_low, target_high, +                               source_phrase_low, source_phrase_high, +                               source_back_low, source_back_high, 5, 10, +                               num_symbols, met_constraints)); +} + +TEST_F(RuleExtractorHelperTest, TestGetGapsExtensionsNotTight) { +  helper = make_shared<RuleExtractorHelper>(source_data_array, +      target_data_array, alignment, 10, 7, true, true, true); +  EXPECT_CALL(*source_data_array, GetSentenceLength(_)) +      .WillRepeatedly(Return(7)); +  EXPECT_CALL(*target_data_array, GetSentenceLength(_)) +      .WillRepeatedly(Return(7)); + +  bool met_constraints = true; +  vector<int> source_low = {-1, 1, 2, 3, 4, 5, -1}; +  vector<int> source_high = {-1, 2, 3, 4, 5, 6, -1}; +  vector<int> target_low = {-1, 1, 2, 3, 4, 5, -1}; +  vector<int> target_high = {-1, 2, 3, 4, 5, 6, -1}; +  int source_phrase_low = 1, source_phrase_high = 6; +  int source_back_low = 0, source_back_high = 6; +  vector<int> matching = {11, 13, 15}; +  vector<int> chunklen = {1, 1, 1}; +  vector<pair<int, int> > source_gaps, target_gaps; +  int num_symbols = 5; +  EXPECT_FALSE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen, +                               source_low, source_high, target_low, target_high, +                               source_phrase_low, source_phrase_high, +                               source_back_low, source_back_high, 5, 10, +                               num_symbols, met_constraints)); + +  source_phrase_low = 1, source_phrase_high = 6; +  source_back_low = 1, source_back_high = 7; +  num_symbols = 5; +  EXPECT_FALSE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen, +                               source_low, source_high, target_low, target_high, +                               source_phrase_low, source_phrase_high, +                               source_back_low, source_back_high, 5, 10, +                               num_symbols, met_constraints)); +} + +TEST_F(RuleExtractorHelperTest, TestGetGapsNotTightExtremities) { +  helper = make_shared<RuleExtractorHelper>(source_data_array, +      target_data_array, alignment, 10, 7, true, true, true); +  EXPECT_CALL(*source_data_array, GetSentenceLength(_)) +      .WillRepeatedly(Return(7)); +  EXPECT_CALL(*target_data_array, GetSentenceLength(_)) +      .WillRepeatedly(Return(7)); + +  bool met_constraints = true; +  vector<int> source_low = {-1, -1, 2, 3, 4, 5, 6}; +  vector<int> source_high = {-1, -1, 3, 4, 5, 6, 7}; +  vector<int> target_low = {-1, -1, 2, 3, 4, 5, 6}; +  vector<int> target_high = {-1, -1, 3, 4, 5, 6, 7}; +  int source_phrase_low = 1, source_phrase_high = 6; +  int source_back_low = 1, source_back_high = 6; +  vector<int> matching = {11, 13, 15}; +  vector<int> chunklen = {1, 1, 1}; +  vector<pair<int, int> > source_gaps, target_gaps; +  int num_symbols = 5; +  EXPECT_TRUE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen, +                              source_low, source_high, target_low, target_high, +                              source_phrase_low, source_phrase_high, +                              source_back_low, source_back_high, 5, 10, +                              num_symbols, met_constraints)); +  EXPECT_FALSE(met_constraints); +  vector<pair<int, int> > expected_gaps = {make_pair(2, 3), make_pair(4, 5)}; +  EXPECT_EQ(expected_gaps, source_gaps); +  EXPECT_EQ(expected_gaps, target_gaps); + +  source_low = {-1, 1, 2, 3, 4, -1, 6}; +  source_high = {-1, 2, 3, 4, 5, -1, 7}; +  target_low = {-1, 1, 2, 3, 4, -1, 6}; +  target_high = {-1, 2, 3, 4, 5, -1, 7}; +  met_constraints = true; +  source_gaps.clear(); +  target_gaps.clear(); +  EXPECT_TRUE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen, +                              source_low, source_high, target_low, target_high, +                              source_phrase_low, source_phrase_high, +                              source_back_low, source_back_high, 5, 10, +                              num_symbols, met_constraints)); +  EXPECT_FALSE(met_constraints); +  EXPECT_EQ(expected_gaps, source_gaps); +  EXPECT_EQ(expected_gaps, target_gaps); +} + +TEST_F(RuleExtractorHelperTest, TestGetGapsWithExtensions) { +  helper = make_shared<RuleExtractorHelper>(source_data_array, +      target_data_array, alignment, 10, 5, true, true, true); +  EXPECT_CALL(*source_data_array, GetSentenceLength(_)) +      .WillRepeatedly(Return(7)); +  EXPECT_CALL(*target_data_array, GetSentenceLength(_)) +      .WillRepeatedly(Return(7)); + +  bool met_constraints = true; +  vector<int> source_low = {-1, 5, 2, 3, 4, 1, -1}; +  vector<int> source_high = {-1, 6, 3, 4, 5, 2, -1}; +  vector<int> target_low = {-1, 5, 2, 3, 4, 1, -1}; +  vector<int> target_high = {-1, 6, 3, 4, 5, 2, -1}; +  int source_phrase_low = 2, source_phrase_high = 5; +  int source_back_low = 1, source_back_high = 6; +  vector<int> matching = {12, 14}; +  vector<int> chunklen = {1, 1}; +  vector<pair<int, int> > source_gaps, target_gaps; +  int num_symbols = 3; +  EXPECT_TRUE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen, +                              source_low, source_high, target_low, target_high, +                              source_phrase_low, source_phrase_high, +                              source_back_low, source_back_high, 5, 10, +                              num_symbols, met_constraints)); +  vector<pair<int, int> > expected_source_gaps = { +    make_pair(1, 2), make_pair(3, 4), make_pair(5, 6) +  }; +  EXPECT_EQ(expected_source_gaps, source_gaps); +  vector<pair<int, int> > expected_target_gaps = { +    make_pair(5, 6), make_pair(3, 4), make_pair(1, 2) +  }; +  EXPECT_EQ(expected_target_gaps, target_gaps); +} + +TEST_F(RuleExtractorHelperTest, TestGetGaps) { +  helper = make_shared<RuleExtractorHelper>(source_data_array, +      target_data_array, alignment, 10, 5, true, true, true); +  EXPECT_CALL(*source_data_array, GetSentenceLength(_)) +      .WillRepeatedly(Return(7)); +  EXPECT_CALL(*target_data_array, GetSentenceLength(_)) +      .WillRepeatedly(Return(7)); + +  bool met_constraints = true; +  vector<int> source_low = {-1, 1, 4, 3, 2, 5, -1}; +  vector<int> source_high = {-1, 2, 5, 4, 3, 6, -1}; +  vector<int> target_low = {-1, 1, 4, 3, 2, 5, -1}; +  vector<int> target_high = {-1, 2, 5, 4, 3, 6, -1}; +  int source_phrase_low = 1, source_phrase_high = 6; +  int source_back_low = 1, source_back_high = 6; +  vector<int> matching = {11, 13, 15}; +  vector<int> chunklen = {1, 1, 1}; +  vector<pair<int, int> > source_gaps, target_gaps; +  int num_symbols = 5; +  EXPECT_TRUE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen, +                              source_low, source_high, target_low, target_high, +                              source_phrase_low, source_phrase_high, +                              source_back_low, source_back_high, 5, 10, +                              num_symbols, met_constraints)); +  vector<pair<int, int> > expected_source_gaps = { +    make_pair(2, 3), make_pair(4, 5) +  }; +  EXPECT_EQ(expected_source_gaps, source_gaps); +  vector<pair<int, int> > expected_target_gaps = { +    make_pair(4, 5), make_pair(2, 3) +  }; +  EXPECT_EQ(expected_target_gaps, target_gaps); +} + +TEST_F(RuleExtractorHelperTest, TestGetGapIntegrityChecksFailed) { +  helper = make_shared<RuleExtractorHelper>(source_data_array, +      target_data_array, alignment, 10, 5, true, true, true); +  EXPECT_CALL(*source_data_array, GetSentenceLength(_)) +      .WillRepeatedly(Return(7)); +  EXPECT_CALL(*target_data_array, GetSentenceLength(_)) +      .WillRepeatedly(Return(7)); + +  bool met_constraints = true; +  vector<int> source_low = {-1, 3, 2, 3, 4, 3, -1}; +  vector<int> source_high = {-1, 4, 3, 4, 5, 4, -1}; +  vector<int> target_low = {-1, -1, 2, 1, 4, -1, -1}; +  vector<int> target_high = {-1, -1, 3, 6, 5, -1, -1}; +  int source_phrase_low = 2, source_phrase_high = 5; +  int source_back_low = 2, source_back_high = 5; +  vector<int> matching = {12, 14}; +  vector<int> chunklen = {1, 1}; +  vector<pair<int, int> > source_gaps, target_gaps; +  int num_symbols = 3; +  EXPECT_FALSE(helper->GetGaps(source_gaps, target_gaps, matching, chunklen, +                               source_low, source_high, target_low, target_high, +                               source_phrase_low, source_phrase_high, +                               source_back_low, source_back_high, 5, 10, +                               num_symbols, met_constraints)); +} + +TEST_F(RuleExtractorHelperTest, TestGetSourceIndexes) { +  helper = make_shared<RuleExtractorHelper>(source_data_array, +      target_data_array, alignment, 10, 5, true, true, true); + +  vector<int> matching = {13, 18, 21}; +  vector<int> chunklen = {3, 2, 1}; +  unordered_map<int, int> expected_indexes = { +      {3, 1}, {4, 2}, {5, 3}, {8, 5}, {9, 6}, {11, 8} +  }; +  EXPECT_EQ(expected_indexes, helper->GetSourceIndexes(matching, chunklen, +                                                       1, 10)); + +  matching = {12, 17}; +  chunklen = {2, 4}; +  expected_indexes = {{2, 0}, {3, 1}, {7, 3}, {8, 4}, {9, 5}, {10, 6}}; +  EXPECT_EQ(expected_indexes, helper->GetSourceIndexes(matching, chunklen, +                                                       0, 10)); +} + +} // namespace +} // namespace extractor diff --git a/extractor/rule_extractor_test.cc b/extractor/rule_extractor_test.cc new file mode 100644 index 00000000..5c1501c7 --- /dev/null +++ b/extractor/rule_extractor_test.cc @@ -0,0 +1,168 @@ +#include <gtest/gtest.h> + +#include <memory> + +#include "mocks/mock_alignment.h" +#include "mocks/mock_data_array.h" +#include "mocks/mock_rule_extractor_helper.h" +#include "mocks/mock_scorer.h" +#include "mocks/mock_target_phrase_extractor.h" +#include "mocks/mock_vocabulary.h" +#include "phrase.h" +#include "phrase_builder.h" +#include "phrase_location.h" +#include "rule_extractor.h" +#include "rule.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace { + +class RuleExtractorTest : public Test { + protected: +  virtual void SetUp() { +    source_data_array = make_shared<MockDataArray>(); +    EXPECT_CALL(*source_data_array, GetSentenceId(_)) +        .WillRepeatedly(Return(0)); +    EXPECT_CALL(*source_data_array, GetSentenceStart(_)) +        .WillRepeatedly(Return(0)); +    EXPECT_CALL(*source_data_array, GetSentenceLength(_)) +        .WillRepeatedly(Return(10)); + +    helper = make_shared<MockRuleExtractorHelper>(); +    EXPECT_CALL(*helper, CheckAlignedTerminals(_, _, _, _)) +        .WillRepeatedly(Return(true)); +    EXPECT_CALL(*helper, CheckTightPhrases(_, _, _, _)) +        .WillRepeatedly(Return(true)); +    unordered_map<int, int> source_indexes; +    EXPECT_CALL(*helper, GetSourceIndexes(_, _, _, _)) +        .WillRepeatedly(Return(source_indexes)); + +    vocabulary = make_shared<MockVocabulary>(); +    EXPECT_CALL(*vocabulary, GetTerminalValue(87)) +        .WillRepeatedly(Return("a")); +    phrase_builder = make_shared<PhraseBuilder>(vocabulary); +    vector<int> symbols = {87}; +    Phrase target_phrase = phrase_builder->Build(symbols); +    PhraseAlignment phrase_alignment = {make_pair(0, 0)}; + +    target_phrase_extractor = make_shared<MockTargetPhraseExtractor>(); +    vector<pair<Phrase, PhraseAlignment> > target_phrases = { +      make_pair(target_phrase, phrase_alignment) +    }; +    EXPECT_CALL(*target_phrase_extractor, ExtractPhrases(_, _, _, _, _, _)) +        .WillRepeatedly(Return(target_phrases)); + +    scorer = make_shared<MockScorer>(); +    vector<double> scores = {0.3, 7.2}; +    EXPECT_CALL(*scorer, Score(_)).WillRepeatedly(Return(scores)); + +    extractor = make_shared<RuleExtractor>(source_data_array, phrase_builder, +        scorer, target_phrase_extractor, helper, 10, 1, 3, 5, false); +  } + +  shared_ptr<MockDataArray> source_data_array; +  shared_ptr<MockVocabulary> vocabulary; +  shared_ptr<PhraseBuilder> phrase_builder; +  shared_ptr<MockRuleExtractorHelper> helper; +  shared_ptr<MockScorer> scorer; +  shared_ptr<MockTargetPhraseExtractor> target_phrase_extractor; +  shared_ptr<RuleExtractor> extractor; +}; + +TEST_F(RuleExtractorTest, TestExtractRulesAlignedTerminalsFail) { +  vector<int> symbols = {87}; +  Phrase phrase = phrase_builder->Build(symbols); +  vector<int> matching = {2}; +  PhraseLocation phrase_location(matching, 1); +  EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1); +  EXPECT_CALL(*helper, CheckAlignedTerminals(_, _, _, _)) +      .WillRepeatedly(Return(false)); +  vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location); +  EXPECT_EQ(0, rules.size()); +} + +TEST_F(RuleExtractorTest, TestExtractRulesTightPhrasesFail) { +  vector<int> symbols = {87}; +  Phrase phrase = phrase_builder->Build(symbols); +  vector<int> matching = {2}; +  PhraseLocation phrase_location(matching, 1); +  EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1); +  EXPECT_CALL(*helper, CheckTightPhrases(_, _, _, _)) +      .WillRepeatedly(Return(false)); +  vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location); +  EXPECT_EQ(0, rules.size()); +} + +TEST_F(RuleExtractorTest, TestExtractRulesNoFixPoint) { +  vector<int> symbols = {87}; +  Phrase phrase = phrase_builder->Build(symbols); +  vector<int> matching = {2}; +  PhraseLocation phrase_location(matching, 1); + +  EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1); +  // Set FindFixPoint to return false. +  vector<pair<int, int> > gaps; +  helper->SetUp(0, 0, 0, 0, false, gaps, gaps, 0, true, true); + +  vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location); +  EXPECT_EQ(0, rules.size()); +} + +TEST_F(RuleExtractorTest, TestExtractRulesGapsFail) { +  vector<int> symbols = {87}; +  Phrase phrase = phrase_builder->Build(symbols); +  vector<int> matching = {2}; +  PhraseLocation phrase_location(matching, 1); + +  EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1); +  // Set CheckGaps to return false. +  vector<pair<int, int> > gaps; +  helper->SetUp(0, 0, 0, 0, true, gaps, gaps, 0, true, false); + +  vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location); +  EXPECT_EQ(0, rules.size()); +} + +TEST_F(RuleExtractorTest, TestExtractRulesNoExtremities) { +  vector<int> symbols = {87}; +  Phrase phrase = phrase_builder->Build(symbols); +  vector<int> matching = {2}; +  PhraseLocation phrase_location(matching, 1); + +  EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).Times(1); +  vector<pair<int, int> > gaps(3); +  // Set FindFixPoint to return true. The number of gaps equals the number of +  // nonterminals, so we won't add any extremities. +  helper->SetUp(0, 0, 0, 0, true, gaps, gaps, 0, true, true); + +  vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location); +  EXPECT_EQ(1, rules.size()); +} + +TEST_F(RuleExtractorTest, TestExtractRulesAddExtremities) { +  vector<int> symbols = {87}; +  Phrase phrase = phrase_builder->Build(symbols); +  vector<int> matching = {2}; +  PhraseLocation phrase_location(matching, 1); + +  vector<int> links(10, -1); +  EXPECT_CALL(*helper, GetLinksSpans(_, _, _, _, _)).WillOnce(DoAll( +      SetArgReferee<0>(links), +      SetArgReferee<1>(links), +      SetArgReferee<2>(links), +      SetArgReferee<3>(links))); + +  vector<pair<int, int> > gaps; +  // Set FindFixPoint to return true. The number of gaps equals the number of +  // nonterminals, so we won't add any extremities. +  helper->SetUp(0, 0, 2, 3, true, gaps, gaps, 0, true, true); + +  vector<Rule> rules = extractor->ExtractRules(phrase, phrase_location); +  EXPECT_EQ(4, rules.size()); +} + +} // namespace +} // namespace extractor diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc new file mode 100644 index 00000000..8c30fb9e --- /dev/null +++ b/extractor/rule_factory.cc @@ -0,0 +1,303 @@ +#include "rule_factory.h" + +#include <chrono> +#include <memory> +#include <queue> +#include <vector> + +#include "grammar.h" +#include "fast_intersector.h" +#include "matchings_finder.h" +#include "phrase.h" +#include "phrase_builder.h" +#include "rule.h" +#include "rule_extractor.h" +#include "sampler.h" +#include "scorer.h" +#include "suffix_array.h" +#include "time_util.h" +#include "vocabulary.h" + +using namespace std; +using namespace chrono; + +namespace extractor { + +typedef high_resolution_clock Clock; + +struct State { +  State(int start, int end, const vector<int>& phrase, +      const vector<int>& subpatterns_start, shared_ptr<TrieNode> node, +      bool starts_with_x) : +      start(start), end(end), phrase(phrase), +      subpatterns_start(subpatterns_start), node(node), +      starts_with_x(starts_with_x) {} + +  int start, end; +  vector<int> phrase, subpatterns_start; +  shared_ptr<TrieNode> node; +  bool starts_with_x; +}; + +HieroCachingRuleFactory::HieroCachingRuleFactory( +    shared_ptr<SuffixArray> source_suffix_array, +    shared_ptr<DataArray> target_data_array, +    shared_ptr<Alignment> alignment, +    const shared_ptr<Vocabulary>& vocabulary, +    shared_ptr<Precomputation> precomputation, +    shared_ptr<Scorer> scorer, +    int min_gap_size, +    int max_rule_span, +    int max_nonterminals, +    int max_rule_symbols, +    int max_samples, +    bool require_tight_phrases) : +    vocabulary(vocabulary), +    scorer(scorer), +    min_gap_size(min_gap_size), +    max_rule_span(max_rule_span), +    max_nonterminals(max_nonterminals), +    max_chunks(max_nonterminals + 1), +    max_rule_symbols(max_rule_symbols) { +  matchings_finder = make_shared<MatchingsFinder>(source_suffix_array); +  fast_intersector = make_shared<FastIntersector>(source_suffix_array, +      precomputation, vocabulary, max_rule_span, min_gap_size); +  phrase_builder = make_shared<PhraseBuilder>(vocabulary); +  rule_extractor = make_shared<RuleExtractor>(source_suffix_array->GetData(), +      target_data_array, alignment, phrase_builder, scorer, vocabulary, +      max_rule_span, min_gap_size, max_nonterminals, max_rule_symbols, true, +      false, require_tight_phrases); +  sampler = make_shared<Sampler>(source_suffix_array, max_samples); +} + +HieroCachingRuleFactory::HieroCachingRuleFactory( +    shared_ptr<MatchingsFinder> finder, +    shared_ptr<FastIntersector> fast_intersector, +    shared_ptr<PhraseBuilder> phrase_builder, +    shared_ptr<RuleExtractor> rule_extractor, +    shared_ptr<Vocabulary> vocabulary, +    shared_ptr<Sampler> sampler, +    shared_ptr<Scorer> scorer, +    int min_gap_size, +    int max_rule_span, +    int max_nonterminals, +    int max_chunks, +    int max_rule_symbols) : +    matchings_finder(finder), +    fast_intersector(fast_intersector), +    phrase_builder(phrase_builder), +    rule_extractor(rule_extractor), +    vocabulary(vocabulary), +    sampler(sampler), +    scorer(scorer), +    min_gap_size(min_gap_size), +    max_rule_span(max_rule_span), +    max_nonterminals(max_nonterminals), +    max_chunks(max_chunks), +    max_rule_symbols(max_rule_symbols) {} + +HieroCachingRuleFactory::HieroCachingRuleFactory() {} + +HieroCachingRuleFactory::~HieroCachingRuleFactory() {} + +Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { +  Clock::time_point start_time = Clock::now(); +  double total_extract_time = 0; +  double total_intersect_time = 0; +  double total_lookup_time = 0; + +  MatchingsTrie trie; +  shared_ptr<TrieNode> root = trie.GetRoot(); + +  int first_x = vocabulary->GetNonterminalIndex(1); +  shared_ptr<TrieNode> x_root(new TrieNode(root)); +  root->AddChild(first_x, x_root); + +  queue<State> states; +  for (size_t i = 0; i < word_ids.size(); ++i) { +    states.push(State(i, i, vector<int>(), vector<int>(1, i), root, false)); +  } +  for (size_t i = min_gap_size; i < word_ids.size(); ++i) { +    states.push(State(i - min_gap_size, i, vector<int>(1, first_x), +        vector<int>(1, i), x_root, true)); +  } + +  vector<Rule> rules; +  while (!states.empty()) { +    State state = states.front(); +    states.pop(); + +    shared_ptr<TrieNode> node = state.node; +    vector<int> phrase = state.phrase; +    int word_id = word_ids[state.end]; +    phrase.push_back(word_id); +    Phrase next_phrase = phrase_builder->Build(phrase); +    shared_ptr<TrieNode> next_node; + +    if (CannotHaveMatchings(node, word_id)) { +      if (!node->HasChild(word_id)) { +        node->AddChild(word_id, shared_ptr<TrieNode>()); +      } +      continue; +    } + +    if (RequiresLookup(node, word_id)) { +      shared_ptr<TrieNode> next_suffix_link = node->suffix_link == NULL ? +          trie.GetRoot() : node->suffix_link->GetChild(word_id); +      if (state.starts_with_x) { +        // If the phrase starts with a non terminal, we simply use the matchings +        // from the suffix link. +        next_node = make_shared<TrieNode>( +            next_suffix_link, next_phrase, next_suffix_link->matchings); +      } else { +        PhraseLocation phrase_location; +        if (next_phrase.Arity() > 0) { +          // For phrases containing a nonterminal, we use either the occurrences +          // of the prefix or the suffix to determine the occurrences of the +          // phrase. +          Clock::time_point intersect_start = Clock::now(); +          phrase_location = fast_intersector->Intersect( +              node->matchings, next_suffix_link->matchings, next_phrase); +          Clock::time_point intersect_stop = Clock::now(); +          total_intersect_time += GetDuration(intersect_start, intersect_stop); +        } else { +          // For phrases not containing any nonterminals, we simply query the +          // suffix array using the suffix array range of the prefix as a +          // starting point. +          Clock::time_point lookup_start = Clock::now(); +          phrase_location = matchings_finder->Find( +              node->matchings, +              vocabulary->GetTerminalValue(word_id), +              state.phrase.size()); +          Clock::time_point lookup_stop = Clock::now(); +          total_lookup_time += GetDuration(lookup_start, lookup_stop); +        } + +        if (phrase_location.IsEmpty()) { +          continue; +        } + +        // Create new trie node to store data about the current phrase. +        next_node = make_shared<TrieNode>( +            next_suffix_link, next_phrase, phrase_location); +      } +      // Add the new trie node to the trie cache. +      node->AddChild(word_id, next_node); + +      // Automatically adds a trailing non terminal if allowed. Simply copy the +      // matchings from the prefix node. +      AddTrailingNonterminal(phrase, next_phrase, next_node, +                             state.starts_with_x); + +      Clock::time_point extract_start = Clock::now(); +      if (!state.starts_with_x) { +        // Extract rules for the sampled set of occurrences. +        PhraseLocation sample = sampler->Sample(next_node->matchings); +        vector<Rule> new_rules = +            rule_extractor->ExtractRules(next_phrase, sample); +        rules.insert(rules.end(), new_rules.begin(), new_rules.end()); +      } +      Clock::time_point extract_stop = Clock::now(); +      total_extract_time += GetDuration(extract_start, extract_stop); +    } else { +      next_node = node->GetChild(word_id); +    } + +    // Create more states (phrases) to be analyzed. +    vector<State> new_states = ExtendState(word_ids, state, phrase, next_phrase, +                                           next_node); +    for (State new_state: new_states) { +      states.push(new_state); +    } +  } + +  Clock::time_point stop_time = Clock::now(); +  #pragma omp critical (stderr_write) +  { +    cerr << "Total time for rule lookup, extraction, and scoring = " +         << GetDuration(start_time, stop_time) << " seconds" << endl; +    cerr << "Extract time = " << total_extract_time << " seconds" << endl; +    cerr << "Intersect time = " << total_intersect_time << " seconds" << endl; +    cerr << "Lookup time = " << total_lookup_time << " seconds" << endl; +  } +  return Grammar(rules, scorer->GetFeatureNames()); +} + +bool HieroCachingRuleFactory::CannotHaveMatchings( +    shared_ptr<TrieNode> node, int word_id) { +  if (node->HasChild(word_id) && node->GetChild(word_id) == NULL) { +    return true; +  } + +  shared_ptr<TrieNode> suffix_link = node->suffix_link; +  return suffix_link != NULL && suffix_link->GetChild(word_id) == NULL; +} + +bool HieroCachingRuleFactory::RequiresLookup( +    shared_ptr<TrieNode> node, int word_id) { +  return !node->HasChild(word_id); +} + +void HieroCachingRuleFactory::AddTrailingNonterminal( +    vector<int> symbols, +    const Phrase& prefix, +    const shared_ptr<TrieNode>& prefix_node, +    bool starts_with_x) { +  if (prefix.Arity() >= max_nonterminals) { +    return; +  } + +  int var_id = vocabulary->GetNonterminalIndex(prefix.Arity() + 1); +  symbols.push_back(var_id); +  Phrase var_phrase = phrase_builder->Build(symbols); + +  int suffix_var_id = vocabulary->GetNonterminalIndex( +      prefix.Arity() + (starts_with_x == 0)); +  shared_ptr<TrieNode> var_suffix_link = +      prefix_node->suffix_link->GetChild(suffix_var_id); + +  prefix_node->AddChild(var_id, make_shared<TrieNode>( +      var_suffix_link, var_phrase, prefix_node->matchings)); +} + +vector<State> HieroCachingRuleFactory::ExtendState( +    const vector<int>& word_ids, +    const State& state, +    vector<int> symbols, +    const Phrase& phrase, +    const shared_ptr<TrieNode>& node) { +  int span = state.end - state.start; +  vector<State> new_states; +  if (symbols.size() >= max_rule_symbols || state.end + 1 >= word_ids.size() || +      span >= max_rule_span) { +    return new_states; +  } + +  // New state for adding the next symbol. +  new_states.push_back(State(state.start, state.end + 1, symbols, +      state.subpatterns_start, node, state.starts_with_x)); + +  int num_subpatterns = phrase.Arity() + (state.starts_with_x == 0); +  if (symbols.size() + 1 >= max_rule_symbols || +      phrase.Arity() >= max_nonterminals || +      num_subpatterns >= max_chunks) { +    return new_states; +  } + +  // New states for adding a nonterminal followed by a new symbol. +  int var_id = vocabulary->GetNonterminalIndex(phrase.Arity() + 1); +  symbols.push_back(var_id); +  vector<int> subpatterns_start = state.subpatterns_start; +  size_t i = state.end + 1 + min_gap_size; +  while (i < word_ids.size() && i - state.start <= max_rule_span) { +    subpatterns_start.push_back(i); +    new_states.push_back(State(state.start, i, symbols, subpatterns_start, +        node->GetChild(var_id), state.starts_with_x)); +    subpatterns_start.pop_back(); +    ++i; +  } + +  return new_states; +} + +} // namespace extractor diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h new file mode 100644 index 00000000..52e8712a --- /dev/null +++ b/extractor/rule_factory.h @@ -0,0 +1,118 @@ +#ifndef _RULE_FACTORY_H_ +#define _RULE_FACTORY_H_ + +#include <memory> +#include <vector> + +#include "matchings_trie.h" + +using namespace std; + +namespace extractor { + +class Alignment; +class DataArray; +class FastIntersector; +class Grammar; +class MatchingsFinder; +class PhraseBuilder; +class Precomputation; +class Rule; +class RuleExtractor; +class Sampler; +class Scorer; +class State; +class SuffixArray; +class Vocabulary; + +/** + * Component containing most of the logic for extracting SCFG rules for a given + * sentence. + * + * Given a sentence (as a vector of word ids), this class constructs all the + * possible source phrases starting from this sentence. For each source phrase, + * it finds all its occurrences in the source data and samples some of these + * occurrences to extract aligned source-target phrase pairs. A trie cache is + * used to avoid unnecessary computations if a source phrase can be constructed + * more than once (e.g. some words occur more than once in the sentence). + */ +class HieroCachingRuleFactory { + public: +  HieroCachingRuleFactory( +      shared_ptr<SuffixArray> source_suffix_array, +      shared_ptr<DataArray> target_data_array, +      shared_ptr<Alignment> alignment, +      const shared_ptr<Vocabulary>& vocabulary, +      shared_ptr<Precomputation> precomputation, +      shared_ptr<Scorer> scorer, +      int min_gap_size, +      int max_rule_span, +      int max_nonterminals, +      int max_rule_symbols, +      int max_samples, +      bool require_tight_phrases); + +  // For testing only. +  HieroCachingRuleFactory( +      shared_ptr<MatchingsFinder> finder, +      shared_ptr<FastIntersector> fast_intersector, +      shared_ptr<PhraseBuilder> phrase_builder, +      shared_ptr<RuleExtractor> rule_extractor, +      shared_ptr<Vocabulary> vocabulary, +      shared_ptr<Sampler> sampler, +      shared_ptr<Scorer> scorer, +      int min_gap_size, +      int max_rule_span, +      int max_nonterminals, +      int max_chunks, +      int max_rule_symbols); + +  virtual ~HieroCachingRuleFactory(); + +  // Constructs SCFG rules for a given sentence. +  // (See class description for more details.) +  virtual Grammar GetGrammar(const vector<int>& word_ids); + + protected: +  HieroCachingRuleFactory(); + + private: +  // Checks if the phrase (if previously encountered) or its prefix have any +  // occurrences in the source data. +  bool CannotHaveMatchings(shared_ptr<TrieNode> node, int word_id); + +  // Checks if the phrase has previously been analyzed. +  bool RequiresLookup(shared_ptr<TrieNode> node, int word_id); + +  // Creates a new state in the trie that corresponds to adding a trailing +  // nonterminal to the current phrase. +  void AddTrailingNonterminal(vector<int> symbols, +                              const Phrase& prefix, +                              const shared_ptr<TrieNode>& prefix_node, +                              bool starts_with_x); + +  // Extends the current state by possibly adding a nonterminal followed by a +  // terminal. +  vector<State> ExtendState(const vector<int>& word_ids, +                            const State& state, +                            vector<int> symbols, +                            const Phrase& phrase, +                            const shared_ptr<TrieNode>& node); + +  shared_ptr<MatchingsFinder> matchings_finder; +  shared_ptr<FastIntersector> fast_intersector; +  shared_ptr<PhraseBuilder> phrase_builder; +  shared_ptr<RuleExtractor> rule_extractor; +  shared_ptr<Vocabulary> vocabulary; +  shared_ptr<Sampler> sampler; +  shared_ptr<Scorer> scorer; +  int min_gap_size; +  int max_rule_span; +  int max_nonterminals; +  int max_chunks; +  int max_rule_symbols; +}; + +} // namespace extractor + +#endif diff --git a/extractor/rule_factory_test.cc b/extractor/rule_factory_test.cc new file mode 100644 index 00000000..2129dfa0 --- /dev/null +++ b/extractor/rule_factory_test.cc @@ -0,0 +1,103 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <string> +#include <vector> + +#include "grammar.h" +#include "mocks/mock_fast_intersector.h" +#include "mocks/mock_matchings_finder.h" +#include "mocks/mock_rule_extractor.h" +#include "mocks/mock_sampler.h" +#include "mocks/mock_scorer.h" +#include "mocks/mock_vocabulary.h" +#include "phrase_builder.h" +#include "phrase_location.h" +#include "rule_factory.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace { + +class RuleFactoryTest : public Test { + protected: +  virtual void SetUp() { +    finder = make_shared<MockMatchingsFinder>(); +    fast_intersector = make_shared<MockFastIntersector>(); + +    vocabulary = make_shared<MockVocabulary>(); +    EXPECT_CALL(*vocabulary, GetTerminalValue(2)).WillRepeatedly(Return("a")); +    EXPECT_CALL(*vocabulary, GetTerminalValue(3)).WillRepeatedly(Return("b")); +    EXPECT_CALL(*vocabulary, GetTerminalValue(4)).WillRepeatedly(Return("c")); + +    phrase_builder = make_shared<PhraseBuilder>(vocabulary); + +    scorer = make_shared<MockScorer>(); +    feature_names = {"f1"}; +    EXPECT_CALL(*scorer, GetFeatureNames()) +        .WillRepeatedly(Return(feature_names)); + +    sampler = make_shared<MockSampler>(); +    EXPECT_CALL(*sampler, Sample(_)) +        .WillRepeatedly(Return(PhraseLocation(0, 1))); + +    Phrase phrase; +    vector<double> scores = {0.5}; +    vector<pair<int, int> > phrase_alignment = {make_pair(0, 0)}; +    vector<Rule> rules = {Rule(phrase, phrase, scores, phrase_alignment)}; +    extractor = make_shared<MockRuleExtractor>(); +    EXPECT_CALL(*extractor, ExtractRules(_, _)) +        .WillRepeatedly(Return(rules)); +  } + +  vector<string> feature_names; +  shared_ptr<MockMatchingsFinder> finder; +  shared_ptr<MockFastIntersector> fast_intersector; +  shared_ptr<MockVocabulary> vocabulary; +  shared_ptr<PhraseBuilder> phrase_builder; +  shared_ptr<MockScorer> scorer; +  shared_ptr<MockSampler> sampler; +  shared_ptr<MockRuleExtractor> extractor; +  shared_ptr<HieroCachingRuleFactory> factory; +}; + +TEST_F(RuleFactoryTest, TestGetGrammarDifferentWords) { +  factory = make_shared<HieroCachingRuleFactory>(finder, fast_intersector, +      phrase_builder, extractor, vocabulary, sampler, scorer, 1, 10, 2, 3, 5); + +  EXPECT_CALL(*finder, Find(_, _, _)) +      .Times(6) +      .WillRepeatedly(Return(PhraseLocation(0, 1))); + +  EXPECT_CALL(*fast_intersector, Intersect(_, _, _)) +      .Times(1) +      .WillRepeatedly(Return(PhraseLocation(0, 1))); + +  vector<int> word_ids = {2, 3, 4}; +  Grammar grammar = factory->GetGrammar(word_ids); +  EXPECT_EQ(feature_names, grammar.GetFeatureNames()); +  EXPECT_EQ(7, grammar.GetRules().size()); +} + +TEST_F(RuleFactoryTest, TestGetGrammarRepeatingWords) { +  factory = make_shared<HieroCachingRuleFactory>(finder, fast_intersector, +      phrase_builder, extractor, vocabulary, sampler, scorer, 1, 10, 2, 3, 5); + +  EXPECT_CALL(*finder, Find(_, _, _)) +      .Times(12) +      .WillRepeatedly(Return(PhraseLocation(0, 1))); + +  EXPECT_CALL(*fast_intersector, Intersect(_, _, _)) +      .Times(16) +      .WillRepeatedly(Return(PhraseLocation(0, 1))); + +  vector<int> word_ids = {2, 3, 4, 2, 3}; +  Grammar grammar = factory->GetGrammar(word_ids); +  EXPECT_EQ(feature_names, grammar.GetFeatureNames()); +  EXPECT_EQ(28, grammar.GetRules().size()); +} + +} // namespace +} // namespace extractor diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc new file mode 100644 index 00000000..aec83e3b --- /dev/null +++ b/extractor/run_extractor.cc @@ -0,0 +1,242 @@ +#include <chrono> +#include <fstream> +#include <iostream> +#include <memory> +#include <string> +#include <vector> + +#include <omp.h> +#include <boost/filesystem.hpp> +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> + +#include "alignment.h" +#include "data_array.h" +#include "features/count_source_target.h" +#include "features/feature.h" +#include "features/is_source_singleton.h" +#include "features/is_source_target_singleton.h" +#include "features/max_lex_source_given_target.h" +#include "features/max_lex_target_given_source.h" +#include "features/sample_source_count.h" +#include "features/target_given_source_coherent.h" +#include "grammar.h" +#include "grammar_extractor.h" +#include "precomputation.h" +#include "rule.h" +#include "scorer.h" +#include "suffix_array.h" +#include "time_util.h" +#include "translation_table.h" + +namespace fs = boost::filesystem; +namespace po = boost::program_options; +using namespace std; +using namespace extractor; +using namespace features; + +// Returns the file path in which a given grammar should be written. +fs::path GetGrammarFilePath(const fs::path& grammar_path, int file_number) { +  string file_name = "grammar." + to_string(file_number); +  return grammar_path / file_name; +} + +int main(int argc, char** argv) { +  int num_threads_default = 1; +  #pragma omp parallel +  num_threads_default = omp_get_num_threads(); + +  // Sets up the command line arguments map. +  po::options_description desc("Command line options"); +  desc.add_options() +    ("help,h", "Show available options") +    ("source,f", po::value<string>(), "Source language corpus") +    ("target,e", po::value<string>(), "Target language corpus") +    ("bitext,b", po::value<string>(), "Parallel text (source ||| target)") +    ("alignment,a", po::value<string>()->required(), "Bitext word alignment") +    ("grammars,g", po::value<string>()->required(), "Grammars output path") +    ("threads,t", po::value<int>()->default_value(num_threads_default), +        "Number of parallel extractors") +    ("frequent", po::value<int>()->default_value(100), +        "Number of precomputed frequent patterns") +    ("super_frequent", po::value<int>()->default_value(10), +        "Number of precomputed super frequent patterns") +    ("max_rule_span", po::value<int>()->default_value(15), +        "Maximum rule span") +    ("max_rule_symbols", po::value<int>()->default_value(5), +        "Maximum number of symbols (terminals + nontermals) in a rule") +    ("min_gap_size", po::value<int>()->default_value(1), "Minimum gap size") +    ("max_phrase_len", po::value<int>()->default_value(4), +        "Maximum frequent phrase length") +    ("max_nonterminals", po::value<int>()->default_value(2), +        "Maximum number of nonterminals in a rule") +    ("min_frequency", po::value<int>()->default_value(1000), +        "Minimum number of occurrences for a pharse to be considered frequent") +    ("max_samples", po::value<int>()->default_value(300), +        "Maximum number of samples") +    ("tight_phrases", po::value<bool>()->default_value(true), +        "False if phrases may be loose (better, but slower)"); + +  po::variables_map vm; +  po::store(po::parse_command_line(argc, argv, desc), vm); + +  // Checks for the help option before calling notify, so the we don't get an +  // exception for missing required arguments. +  if (vm.count("help")) { +    cout << desc << endl; +    return 0; +  } + +  po::notify(vm); + +  if (!((vm.count("source") && vm.count("target")) || vm.count("bitext"))) { +    cerr << "A paralel corpus is required. " +         << "Use -f (source) with -e (target) or -b (bitext)." +         << endl; +    return 1; +  } + +  int num_threads = vm["threads"].as<int>(); +  cout << "Grammar extraction will use " << num_threads << " threads." << endl; + +  // Reads the parallel corpus. +  Clock::time_point preprocess_start_time = Clock::now(); +  cerr << "Reading source and target data..." << endl; +  Clock::time_point start_time = Clock::now(); +  shared_ptr<DataArray> source_data_array, target_data_array; +  if (vm.count("bitext")) { +    source_data_array = make_shared<DataArray>( +        vm["bitext"].as<string>(), SOURCE); +    target_data_array = make_shared<DataArray>( +        vm["bitext"].as<string>(), TARGET); +  } else { +    source_data_array = make_shared<DataArray>(vm["source"].as<string>()); +    target_data_array = make_shared<DataArray>(vm["target"].as<string>()); +  } +  Clock::time_point stop_time = Clock::now(); +  cerr << "Reading data took " << GetDuration(start_time, stop_time) +       << " seconds" << endl; + +  // Constructs the suffix array for the source data. +  cerr << "Creating source suffix array..." << endl; +  start_time = Clock::now(); +  shared_ptr<SuffixArray> source_suffix_array = +      make_shared<SuffixArray>(source_data_array); +  stop_time = Clock::now(); +  cerr << "Creating suffix array took " +       << GetDuration(start_time, stop_time) << " seconds" << endl; + +  // Reads the alignment. +  cerr << "Reading alignment..." << endl; +  start_time = Clock::now(); +  shared_ptr<Alignment> alignment = +      make_shared<Alignment>(vm["alignment"].as<string>()); +  stop_time = Clock::now(); +  cerr << "Reading alignment took " +       << GetDuration(start_time, stop_time) << " seconds" << endl; + +  // Constructs an index storing the occurrences in the source data for each +  // frequent collocation. +  cerr << "Precomputing collocations..." << endl; +  start_time = Clock::now(); +  shared_ptr<Precomputation> precomputation = make_shared<Precomputation>( +      source_suffix_array, +      vm["frequent"].as<int>(), +      vm["super_frequent"].as<int>(), +      vm["max_rule_span"].as<int>(), +      vm["max_rule_symbols"].as<int>(), +      vm["min_gap_size"].as<int>(), +      vm["max_phrase_len"].as<int>(), +      vm["min_frequency"].as<int>()); +  stop_time = Clock::now(); +  cerr << "Precomputing collocations took " +       << GetDuration(start_time, stop_time) << " seconds" << endl; + +  // Constructs a table storing p(e | f) and p(f | e) for every pair of source +  // and target words. +  cerr << "Precomputing conditional probabilities..." << endl; +  start_time = Clock::now(); +  shared_ptr<TranslationTable> table = make_shared<TranslationTable>( +      source_data_array, target_data_array, alignment); +  stop_time = Clock::now(); +  cerr << "Precomputing conditional probabilities took " +       << GetDuration(start_time, stop_time) << " seconds" << endl; + +  Clock::time_point preprocess_stop_time = Clock::now(); +  cerr << "Overall preprocessing step took " +       << GetDuration(preprocess_start_time, preprocess_stop_time) +       << " seconds" << endl; + +  // Features used to score each grammar rule. +  Clock::time_point extraction_start_time = Clock::now(); +  vector<shared_ptr<Feature> > features = { +      make_shared<TargetGivenSourceCoherent>(), +      make_shared<SampleSourceCount>(), +      make_shared<CountSourceTarget>(), +      make_shared<MaxLexSourceGivenTarget>(table), +      make_shared<MaxLexTargetGivenSource>(table), +      make_shared<IsSourceSingleton>(), +      make_shared<IsSourceTargetSingleton>() +  }; +  shared_ptr<Scorer> scorer = make_shared<Scorer>(features); + +  // Sets up the grammar extractor. +  GrammarExtractor extractor( +      source_suffix_array, +      target_data_array, +      alignment, +      precomputation, +      scorer, +      vm["min_gap_size"].as<int>(), +      vm["max_rule_span"].as<int>(), +      vm["max_nonterminals"].as<int>(), +      vm["max_rule_symbols"].as<int>(), +      vm["max_samples"].as<int>(), +      vm["tight_phrases"].as<bool>()); + +  // Releases extra memory used by the initial precomputation. +  precomputation.reset(); + +  // Creates the grammars directory if it doesn't exist. +  fs::path grammar_path = vm["grammars"].as<string>(); +  if (!fs::is_directory(grammar_path)) { +    fs::create_directory(grammar_path); +  } + +  // Reads all sentences for which we extract grammar rules (the paralellization +  // is simplified if we read all sentences upfront). +  string sentence; +  vector<string> sentences; +  while (getline(cin, sentence)) { +    sentences.push_back(sentence); +  } + +  // Extracts the grammar for each sentence and saves it to a file. +  vector<string> suffixes(sentences.size()); +  #pragma omp parallel for schedule(dynamic) num_threads(num_threads) +  for (size_t i = 0; i < sentences.size(); ++i) { +    string suffix; +    int position = sentences[i].find("|||"); +    if (position != sentences[i].npos) { +      suffix = sentences[i].substr(position); +      sentences[i] = sentences[i].substr(0, position); +    } +    suffixes[i] = suffix; + +    Grammar grammar = extractor.GetGrammar(sentences[i]); +    ofstream output(GetGrammarFilePath(grammar_path, i).c_str()); +    output << grammar; +  } + +  for (size_t i = 0; i < sentences.size(); ++i) { +    cout << "<seg grammar=\"" << GetGrammarFilePath(grammar_path, i) << "\" id=\"" +         << i << "\"> " << sentences[i] << " </seg> " << suffixes[i] << endl; +  } + +  Clock::time_point extraction_stop_time = Clock::now(); +  cerr << "Overall extraction step took " +       << GetDuration(extraction_start_time, extraction_stop_time) +       << " seconds" << endl; + +  return 0; +} diff --git a/extractor/sample_alignment.txt b/extractor/sample_alignment.txt new file mode 100644 index 00000000..80b446a4 --- /dev/null +++ b/extractor/sample_alignment.txt @@ -0,0 +1,2 @@ +0-0 1-1 2-2 +1-0 2-1 diff --git a/extractor/sample_bitext.txt b/extractor/sample_bitext.txt new file mode 100644 index 00000000..93d6b39d --- /dev/null +++ b/extractor/sample_bitext.txt @@ -0,0 +1,2 @@ +ana are mere . ||| anna has apples . +ana bea mult lapte . ||| anna drinks a lot of milk . diff --git a/extractor/sampler.cc b/extractor/sampler.cc new file mode 100644 index 00000000..d81956b5 --- /dev/null +++ b/extractor/sampler.cc @@ -0,0 +1,46 @@ +#include "sampler.h" + +#include "phrase_location.h" +#include "suffix_array.h" + +namespace extractor { + +Sampler::Sampler(shared_ptr<SuffixArray> suffix_array, int max_samples) : +    suffix_array(suffix_array), max_samples(max_samples) {} + +Sampler::Sampler() {} + +Sampler::~Sampler() {} + +PhraseLocation Sampler::Sample(const PhraseLocation& location) const { +  vector<int> sample; +  int num_subpatterns; +  if (location.matchings == NULL) { +    // Sample suffix array range. +    num_subpatterns = 1; +    int low = location.sa_low, high = location.sa_high; +    double step = max(1.0, (double) (high - low) / max_samples); +    for (double i = low; i < high && sample.size() < max_samples; i += step) { +      sample.push_back(suffix_array->GetSuffix(Round(i))); +    } +  } else { +    // Sample vector of occurrences. +    num_subpatterns = location.num_subpatterns; +    int num_matchings = location.matchings->size() / num_subpatterns; +    double step = max(1.0, (double) num_matchings / max_samples); +    for (double i = 0, num_samples = 0; +         i < num_matchings && num_samples < max_samples; +         i += step, ++num_samples) { +      int start = Round(i) * num_subpatterns; +      sample.insert(sample.end(), location.matchings->begin() + start, +          location.matchings->begin() + start + num_subpatterns); +    } +  } +  return PhraseLocation(sample, num_subpatterns); +} + +int Sampler::Round(double x) const { +  return x + 0.5; +} + +} // namespace extractor diff --git a/extractor/sampler.h b/extractor/sampler.h new file mode 100644 index 00000000..be4aa1bb --- /dev/null +++ b/extractor/sampler.h @@ -0,0 +1,38 @@ +#ifndef _SAMPLER_H_ +#define _SAMPLER_H_ + +#include <memory> + +using namespace std; + +namespace extractor { + +class PhraseLocation; +class SuffixArray; + +/** + * Provides uniform sampling for a PhraseLocation. + */ +class Sampler { + public: +  Sampler(shared_ptr<SuffixArray> suffix_array, int max_samples); + +  virtual ~Sampler(); + +  // Samples uniformly at most max_samples phrase occurrences. +  virtual PhraseLocation Sample(const PhraseLocation& location) const; + + protected: +  Sampler(); + + private: +  // Round floating point number to the nearest integer. +  int Round(double x) const; + +  shared_ptr<SuffixArray> suffix_array; +  int max_samples; +}; + +} // namespace extractor + +#endif diff --git a/extractor/sampler_test.cc b/extractor/sampler_test.cc new file mode 100644 index 00000000..e9abebfa --- /dev/null +++ b/extractor/sampler_test.cc @@ -0,0 +1,74 @@ +#include <gtest/gtest.h> + +#include <memory> + +#include "mocks/mock_suffix_array.h" +#include "phrase_location.h" +#include "sampler.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace { + +class SamplerTest : public Test { + protected: +  virtual void SetUp() { +    suffix_array = make_shared<MockSuffixArray>(); +    for (int i = 0; i < 10; ++i) { +      EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i)); +    } +  } + +  shared_ptr<MockSuffixArray> suffix_array; +  shared_ptr<Sampler> sampler; +}; + +TEST_F(SamplerTest, TestSuffixArrayRange) { +  PhraseLocation location(0, 10); + +  sampler = make_shared<Sampler>(suffix_array, 1); +  vector<int> expected_locations = {0}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location)); + +  sampler = make_shared<Sampler>(suffix_array, 2); +  expected_locations = {0, 5}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location)); + +  sampler = make_shared<Sampler>(suffix_array, 3); +  expected_locations = {0, 3, 7}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location)); + +  sampler = make_shared<Sampler>(suffix_array, 4); +  expected_locations = {0, 3, 5, 8}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location)); + +  sampler = make_shared<Sampler>(suffix_array, 100); +  expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location)); +} + +TEST_F(SamplerTest, TestSubstringsSample) { +  vector<int> locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; +  PhraseLocation location(locations, 2); + +  sampler = make_shared<Sampler>(suffix_array, 1); +  vector<int> expected_locations = {0, 1}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location)); + +  sampler = make_shared<Sampler>(suffix_array, 2); +  expected_locations = {0, 1, 6, 7}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location)); + +  sampler = make_shared<Sampler>(suffix_array, 3); +  expected_locations = {0, 1, 4, 5, 6, 7}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location)); + +  sampler = make_shared<Sampler>(suffix_array, 7); +  expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location)); +} + +} // namespace +} // namespace extractor diff --git a/extractor/scorer.cc b/extractor/scorer.cc new file mode 100644 index 00000000..d3ebf1c9 --- /dev/null +++ b/extractor/scorer.cc @@ -0,0 +1,30 @@ +#include "scorer.h" + +#include "features/feature.h" + +namespace extractor { + +Scorer::Scorer(const vector<shared_ptr<features::Feature> >& features) : +    features(features) {} + +Scorer::Scorer() {} + +Scorer::~Scorer() {} + +vector<double> Scorer::Score(const features::FeatureContext& context) const { +  vector<double> scores; +  for (auto feature: features) { +    scores.push_back(feature->Score(context)); +  } +  return scores; +} + +vector<string> Scorer::GetFeatureNames() const { +  vector<string> feature_names; +  for (auto feature: features) { +    feature_names.push_back(feature->GetName()); +  } +  return feature_names; +} + +} // namespace extractor diff --git a/extractor/scorer.h b/extractor/scorer.h new file mode 100644 index 00000000..af8a3b10 --- /dev/null +++ b/extractor/scorer.h @@ -0,0 +1,41 @@ +#ifndef _SCORER_H_ +#define _SCORER_H_ + +#include <memory> +#include <string> +#include <vector> + +using namespace std; + +namespace extractor { + +namespace features { +  class Feature; +  class FeatureContext; +} // namespace features + +/** + * Computes the feature scores for a source-target phrase pair. + */ +class Scorer { + public: +  Scorer(const vector<shared_ptr<features::Feature> >& features); + +  virtual ~Scorer(); + +  // Computes the feature score for the given context. +  virtual vector<double> Score(const features::FeatureContext& context) const; + +  // Returns the set of feature names used to score any context. +  virtual vector<string> GetFeatureNames() const; + + protected: +  Scorer(); + + private: +  vector<shared_ptr<features::Feature> > features; +}; + +} // namespace extractor + +#endif diff --git a/extractor/scorer_test.cc b/extractor/scorer_test.cc new file mode 100644 index 00000000..3a09c9cc --- /dev/null +++ b/extractor/scorer_test.cc @@ -0,0 +1,49 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <string> +#include <vector> + +#include "mocks/mock_feature.h" +#include "scorer.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace { + +class ScorerTest : public Test { + protected: +  virtual void SetUp() { +    feature1 = make_shared<features::MockFeature>(); +    EXPECT_CALL(*feature1, Score(_)).WillRepeatedly(Return(0.5)); +    EXPECT_CALL(*feature1, GetName()).WillRepeatedly(Return("f1")); + +    feature2 = make_shared<features::MockFeature>(); +    EXPECT_CALL(*feature2, Score(_)).WillRepeatedly(Return(-1.3)); +    EXPECT_CALL(*feature2, GetName()).WillRepeatedly(Return("f2")); + +    vector<shared_ptr<features::Feature> > features = {feature1, feature2}; +    scorer = make_shared<Scorer>(features); +  } + +  shared_ptr<features::MockFeature> feature1; +  shared_ptr<features::MockFeature> feature2; +  shared_ptr<Scorer> scorer; +}; + +TEST_F(ScorerTest, TestScore) { +  vector<double> expected_scores = {0.5, -1.3}; +  Phrase phrase; +  features::FeatureContext context(phrase, phrase, 0.3, 2, 11); +  EXPECT_EQ(expected_scores, scorer->Score(context)); +} + +TEST_F(ScorerTest, TestGetNames) { +  vector<string> expected_names = {"f1", "f2"}; +  EXPECT_EQ(expected_names, scorer->GetFeatureNames()); +} + +} // namespace +} // namespace extractor diff --git a/extractor/suffix_array.cc b/extractor/suffix_array.cc new file mode 100644 index 00000000..65b2d581 --- /dev/null +++ b/extractor/suffix_array.cc @@ -0,0 +1,235 @@ +#include "suffix_array.h" + +#include <cassert> +#include <chrono> +#include <iostream> +#include <string> +#include <vector> + +#include "data_array.h" +#include "phrase_location.h" +#include "time_util.h" + +namespace fs = boost::filesystem; +using namespace std; +using namespace chrono; + +namespace extractor { + +SuffixArray::SuffixArray(shared_ptr<DataArray> data_array) : +    data_array(data_array) { +  BuildSuffixArray(); +} + +SuffixArray::SuffixArray() {} + +SuffixArray::~SuffixArray() {} + +void SuffixArray::BuildSuffixArray() { +  vector<int> groups = data_array->GetData(); +  groups.reserve(groups.size() + 1); +  groups.push_back(DataArray::NULL_WORD); +  suffix_array.resize(groups.size()); +  word_start.resize(data_array->GetVocabularySize() + 1); + +  InitialBucketSort(groups); + +  int combined_group_size = 0; +  for (size_t i = 1; i < word_start.size(); ++i) { +    if (word_start[i] - word_start[i - 1] == 1) { +      ++combined_group_size; +      suffix_array[word_start[i] - combined_group_size] = -combined_group_size; +    } else { +      combined_group_size = 0; +    } +  } + +  PrefixDoublingSort(groups); +  cerr << "\tFinalizing sort..." << endl; + +  for (size_t i = 0; i < groups.size(); ++i) { +    suffix_array[groups[i]] = i; +  } +} + +void SuffixArray::InitialBucketSort(vector<int>& groups) { +  Clock::time_point start_time = Clock::now(); +  for (size_t i = 0; i < groups.size(); ++i) { +    ++word_start[groups[i]]; +  } + +  for (size_t i = 1; i < word_start.size(); ++i) { +    word_start[i] += word_start[i - 1]; +  } + +  for (size_t i = 0; i < groups.size(); ++i) { +    --word_start[groups[i]]; +    suffix_array[word_start[groups[i]]] = i; +  } + +  for (size_t i = 0; i < suffix_array.size(); ++i) { +    groups[i] = word_start[groups[i] + 1] - 1; +  } +  Clock::time_point stop_time = Clock::now(); +  cerr << "\tBucket sort took " << GetDuration(start_time, stop_time) +       << " seconds" << endl; +} + +void SuffixArray::PrefixDoublingSort(vector<int>& groups) { +  int step = 1; +  while (suffix_array[0] != -suffix_array.size()) { +    int combined_group_size = 0; +    int i = 0; +    while (i < suffix_array.size()) { +      if (suffix_array[i] < 0) { +        int skip = -suffix_array[i]; +        combined_group_size += skip; +        i += skip; +        suffix_array[i - combined_group_size] = -combined_group_size; +      } else { +        combined_group_size = 0; +        int j = groups[suffix_array[i]]; +        TernaryQuicksort(i, j, step, groups); +        i = j + 1; +      } +    } +    step *= 2; +  } +} + +void SuffixArray::TernaryQuicksort(int left, int right, int step, +    vector<int>& groups) { +  if (left > right) { +    return; +  } + +  int pivot = left + rand() % (right - left + 1); +  int pivot_value = groups[suffix_array[pivot] + step]; +  swap(suffix_array[pivot], suffix_array[left]); +  int mid_left = left, mid_right = left; +  for (int i = left + 1; i <= right; ++i) { +    if (groups[suffix_array[i] + step] < pivot_value) { +      ++mid_right; +      int temp = suffix_array[i]; +      suffix_array[i] = suffix_array[mid_right]; +      suffix_array[mid_right] = suffix_array[mid_left]; +      suffix_array[mid_left] = temp; +      ++mid_left; +    } else if (groups[suffix_array[i] + step] == pivot_value) { +      ++mid_right; +      int temp = suffix_array[i]; +      suffix_array[i] = suffix_array[mid_right]; +      suffix_array[mid_right] = temp; +    } +  } + +  TernaryQuicksort(left, mid_left - 1, step, groups); + +  if (mid_left == mid_right) { +    groups[suffix_array[mid_left]] = mid_left; +    suffix_array[mid_left] = -1; +  } else { +    for (int i = mid_left; i <= mid_right; ++i) { +      groups[suffix_array[i]] = mid_right; +    } +  } + +  TernaryQuicksort(mid_right + 1, right, step, groups); +} + +vector<int> SuffixArray::BuildLCPArray() const { +  Clock::time_point start_time = Clock::now(); +  cerr << "\tConstructing LCP array..." << endl; + +  vector<int> lcp(suffix_array.size()); +  vector<int> rank(suffix_array.size()); +  const vector<int>& data = data_array->GetData(); + +  for (size_t i = 0; i < suffix_array.size(); ++i) { +    rank[suffix_array[i]] = i; +  } + +  int prefix_len = 0; +  for (size_t i = 0; i < suffix_array.size(); ++i) { +    if (rank[i] == 0) { +      lcp[rank[i]] = -1; +    } else { +      int j = suffix_array[rank[i] - 1]; +      while (i + prefix_len < data.size() && j + prefix_len < data.size() +          && data[i + prefix_len] == data[j + prefix_len]) { +        ++prefix_len; +      } +      lcp[rank[i]] = prefix_len; +    } + +    if (prefix_len > 0) { +      --prefix_len; +    } +  } + +  Clock::time_point stop_time = Clock::now(); +  cerr << "\tConstructing LCP took " +       << GetDuration(start_time, stop_time) << " seconds" << endl; + +  return lcp; +} + +int SuffixArray::GetSuffix(int rank) const { +  return suffix_array[rank]; +} + +int SuffixArray::GetSize() const { +  return suffix_array.size(); +} + +shared_ptr<DataArray> SuffixArray::GetData() const { +  return data_array; +} + +void SuffixArray::WriteBinary(const fs::path& filepath) const { +  FILE* file = fopen(filepath.string().c_str(), "w"); +  assert(file); +  data_array->WriteBinary(file); + +  int size = suffix_array.size(); +  fwrite(&size, sizeof(int), 1, file); +  fwrite(suffix_array.data(), sizeof(int), size, file); + +  size = word_start.size(); +  fwrite(&size, sizeof(int), 1, file); +  fwrite(word_start.data(), sizeof(int), size, file); +} + +PhraseLocation SuffixArray::Lookup(int low, int high, const string& word, +                                   int offset) const { +  if (!data_array->HasWord(word)) { +    // Return empty phrase location. +    return PhraseLocation(0, 0); +  } + +  int word_id = data_array->GetWordId(word); +  if (offset == 0) { +    return PhraseLocation(word_start[word_id], word_start[word_id + 1]); +  } + +  return PhraseLocation(LookupRangeStart(low, high, word_id, offset), +      LookupRangeStart(low, high, word_id + 1, offset)); +} + +int SuffixArray::LookupRangeStart(int low, int high, int word_id, +                                  int offset) const { +  int result = high; +  while (low < high) { +    int middle = low + (high - low) / 2; +    if (suffix_array[middle] + offset >= data_array->GetSize() || +        data_array->AtIndex(suffix_array[middle] + offset) < word_id) { +      low = middle + 1; +    } else { +      result = middle; +      high = middle; +    } +  } +  return result; +} + +} // namespace extractor diff --git a/extractor/suffix_array.h b/extractor/suffix_array.h new file mode 100644 index 00000000..bf731d79 --- /dev/null +++ b/extractor/suffix_array.h @@ -0,0 +1,75 @@ +#ifndef _SUFFIX_ARRAY_H_ +#define _SUFFIX_ARRAY_H_ + +#include <memory> +#include <string> +#include <vector> + +#include <boost/filesystem.hpp> + +namespace fs = boost::filesystem; +using namespace std; + +namespace extractor { + +class DataArray; +class PhraseLocation; + +class SuffixArray { + public: +  // Creates a suffix array from a data array. +  SuffixArray(shared_ptr<DataArray> data_array); + +  virtual ~SuffixArray(); + +  // Returns the size of the suffix array. +  virtual int GetSize() const; + +  // Returns the data array on top of which the suffix array is constructed. +  virtual shared_ptr<DataArray> GetData() const; + +  // Constructs the longest-common-prefix array using the algorithm of Kasai et +  // al. (2001). +  virtual vector<int> BuildLCPArray() const; + +  // Returns the i-th suffix. +  virtual int GetSuffix(int rank) const; + +  // Given the range in which a phrase is located and the next word, returns the +  // range corresponding to the phrase extended with the next word. +  virtual PhraseLocation Lookup(int low, int high, const string& word, +                                int offset) const; + +  void WriteBinary(const fs::path& filepath) const; + + protected: +  SuffixArray(); + + private: +  // Constructs the suffix array using the algorithm of Larsson and Sadakane +  // (1999). +  void BuildSuffixArray(); + +  // Bucket sort on the data array (used for initializing the construction of +  // the suffix array.) +  void InitialBucketSort(vector<int>& groups); + +  void TernaryQuicksort(int left, int right, int step, vector<int>& groups); + +  // Constructs the suffix array in log(n) steps by doubling the length of the +  // suffixes at each step. +  void PrefixDoublingSort(vector<int>& groups); + +  // Given a [low, high) range in the suffix array in which all elements have +  // the first offset-1 values the same, it returns the first position where the +  // offset value is greater or equal to word_id. +  int LookupRangeStart(int low, int high, int word_id, int offset) const; + +  shared_ptr<DataArray> data_array; +  vector<int> suffix_array; +  vector<int> word_start; +}; + +} // namespace extractor + +#endif diff --git a/extractor/suffix_array_test.cc b/extractor/suffix_array_test.cc new file mode 100644 index 00000000..8431a16e --- /dev/null +++ b/extractor/suffix_array_test.cc @@ -0,0 +1,78 @@ +#include <gtest/gtest.h> + +#include "mocks/mock_data_array.h" +#include "phrase_location.h" +#include "suffix_array.h" + +#include <vector> + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace { + +class SuffixArrayTest : public Test { + protected: +  virtual void SetUp() { +    data = {6, 4, 1, 2, 4, 5, 3, 4, 6, 6, 4, 1, 2}; +    data_array = make_shared<MockDataArray>(); +    EXPECT_CALL(*data_array, GetData()).WillRepeatedly(ReturnRef(data)); +    EXPECT_CALL(*data_array, GetVocabularySize()).WillRepeatedly(Return(7)); +    EXPECT_CALL(*data_array, GetSize()).WillRepeatedly(Return(13)); +    suffix_array = make_shared<SuffixArray>(data_array); +  } + +  vector<int> data; +  shared_ptr<SuffixArray> suffix_array; +  shared_ptr<MockDataArray> data_array; +}; + +TEST_F(SuffixArrayTest, TestData) { +  EXPECT_EQ(data_array, suffix_array->GetData()); +  EXPECT_EQ(14, suffix_array->GetSize()); +} + +TEST_F(SuffixArrayTest, TestBuildSuffixArray) { +  vector<int> expected_suffix_array = +      {13, 11, 2, 12, 3, 6, 10, 1, 4, 7, 5, 9, 0, 8}; +  for (size_t i = 0; i < expected_suffix_array.size(); ++i) { +    EXPECT_EQ(expected_suffix_array[i], suffix_array->GetSuffix(i)); +  } +} + +TEST_F(SuffixArrayTest, TestBuildLCP) { +  vector<int> expected_lcp = {-1, 0, 2, 0, 1, 0, 0, 3, 1, 1, 0, 0, 4, 1}; +  EXPECT_EQ(expected_lcp, suffix_array->BuildLCPArray()); +} + +TEST_F(SuffixArrayTest, TestLookup) { +  for (size_t i = 0; i < data.size(); ++i) { +    EXPECT_CALL(*data_array, AtIndex(i)).WillRepeatedly(Return(data[i])); +  } + +  EXPECT_CALL(*data_array, HasWord("word1")).WillRepeatedly(Return(true)); +  EXPECT_CALL(*data_array, GetWordId("word1")).WillRepeatedly(Return(6)); +  EXPECT_EQ(PhraseLocation(11, 14), suffix_array->Lookup(0, 14, "word1", 0)); + +  EXPECT_CALL(*data_array, HasWord("word2")).WillRepeatedly(Return(false)); +  EXPECT_EQ(PhraseLocation(0, 0), suffix_array->Lookup(0, 14, "word2", 0)); + +  EXPECT_CALL(*data_array, HasWord("word3")).WillRepeatedly(Return(true)); +  EXPECT_CALL(*data_array, GetWordId("word3")).WillRepeatedly(Return(4)); +  EXPECT_EQ(PhraseLocation(11, 13), suffix_array->Lookup(11, 14, "word3", 1)); + +  EXPECT_CALL(*data_array, HasWord("word4")).WillRepeatedly(Return(true)); +  EXPECT_CALL(*data_array, GetWordId("word4")).WillRepeatedly(Return(1)); +  EXPECT_EQ(PhraseLocation(11, 13), suffix_array->Lookup(11, 13, "word4", 2)); + +  EXPECT_CALL(*data_array, HasWord("word5")).WillRepeatedly(Return(true)); +  EXPECT_CALL(*data_array, GetWordId("word5")).WillRepeatedly(Return(2)); +  EXPECT_EQ(PhraseLocation(11, 13), suffix_array->Lookup(11, 13, "word5", 3)); + +  EXPECT_EQ(PhraseLocation(12, 13), suffix_array->Lookup(11, 13, "word3", 4)); +  EXPECT_EQ(PhraseLocation(11, 11), suffix_array->Lookup(11, 13, "word5", 1)); +} + +} // namespace +} // namespace extractor diff --git a/extractor/target_phrase_extractor.cc b/extractor/target_phrase_extractor.cc new file mode 100644 index 00000000..2b8a2e4a --- /dev/null +++ b/extractor/target_phrase_extractor.cc @@ -0,0 +1,158 @@ +#include "target_phrase_extractor.h" + +#include <unordered_set> + +#include "alignment.h" +#include "data_array.h" +#include "phrase.h" +#include "phrase_builder.h" +#include "rule_extractor_helper.h" +#include "vocabulary.h" + +using namespace std; + +namespace extractor { + +TargetPhraseExtractor::TargetPhraseExtractor( +    shared_ptr<DataArray> target_data_array, +    shared_ptr<Alignment> alignment, +    shared_ptr<PhraseBuilder> phrase_builder, +    shared_ptr<RuleExtractorHelper> helper, +    shared_ptr<Vocabulary> vocabulary, +    int max_rule_span, +    bool require_tight_phrases) : +    target_data_array(target_data_array), +    alignment(alignment), +    phrase_builder(phrase_builder), +    helper(helper), +    vocabulary(vocabulary), +    max_rule_span(max_rule_span), +    require_tight_phrases(require_tight_phrases) {} + +TargetPhraseExtractor::TargetPhraseExtractor() {} + +TargetPhraseExtractor::~TargetPhraseExtractor() {} + +vector<pair<Phrase, PhraseAlignment> > TargetPhraseExtractor::ExtractPhrases( +    const vector<pair<int, int> >& target_gaps, const vector<int>& target_low, +    int target_phrase_low, int target_phrase_high, +    const unordered_map<int, int>& source_indexes, int sentence_id) const { +  int target_sent_len = target_data_array->GetSentenceLength(sentence_id); + +  vector<int> target_gap_order = helper->GetGapOrder(target_gaps); + +  int target_x_low = target_phrase_low, target_x_high = target_phrase_high; +  if (!require_tight_phrases) { +    // Extend loose target phrase to the left. +    while (target_x_low > 0 && +           target_phrase_high - target_x_low < max_rule_span && +           target_low[target_x_low - 1] == -1) { +      --target_x_low; +    } +    // Extend loose target phrase to the right. +    while (target_x_high < target_sent_len && +           target_x_high - target_phrase_low < max_rule_span && +           target_low[target_x_high] == -1) { +      ++target_x_high; +    } +  } + +  vector<pair<int, int> > gaps(target_gaps.size()); +  for (size_t i = 0; i < gaps.size(); ++i) { +    gaps[i] = target_gaps[target_gap_order[i]]; +    if (!require_tight_phrases) { +      // Extend gap to the left. +      while (gaps[i].first > target_x_low && +             target_low[gaps[i].first - 1] == -1) { +        --gaps[i].first; +      } +      // Extend gap to the right. +      while (gaps[i].second < target_x_high && +             target_low[gaps[i].second] == -1) { +        ++gaps[i].second; +      } +    } +  } + +  // Compute the range in which each chunk may start or end. (Even indexes +  // represent the range in which the chunk may start, odd indexes represent the +  // range in which the chunk may end.) +  vector<pair<int, int> > ranges(2 * gaps.size() + 2); +  ranges.front() = make_pair(target_x_low, target_phrase_low); +  ranges.back() = make_pair(target_phrase_high, target_x_high); +  for (size_t i = 0; i < gaps.size(); ++i) { +    int j = target_gap_order[i]; +    ranges[i * 2 + 1] = make_pair(gaps[i].first, target_gaps[j].first); +    ranges[i * 2 + 2] = make_pair(target_gaps[j].second, gaps[i].second); +  } + +  vector<pair<Phrase, PhraseAlignment> > target_phrases; +  vector<int> subpatterns(ranges.size()); +  GeneratePhrases(target_phrases, ranges, 0, subpatterns, target_gap_order, +                  target_phrase_low, target_phrase_high, source_indexes, +                  sentence_id); +  return target_phrases; +} + +void TargetPhraseExtractor::GeneratePhrases( +    vector<pair<Phrase, PhraseAlignment> >& target_phrases, +    const vector<pair<int, int> >& ranges, int index, vector<int>& subpatterns, +    const vector<int>& target_gap_order, int target_phrase_low, +    int target_phrase_high, const unordered_map<int, int>& source_indexes, +    int sentence_id) const { +  if (index >= ranges.size()) { +    if (subpatterns.back() - subpatterns.front() > max_rule_span) { +      return; +    } + +    vector<int> symbols; +    unordered_map<int, int> target_indexes; + +    // Construct target phrase chunk by chunk. +    int target_sent_start = target_data_array->GetSentenceStart(sentence_id); +    for (size_t i = 0; i * 2 < subpatterns.size(); ++i) { +      for (size_t j = subpatterns[i * 2]; j < subpatterns[i * 2 + 1]; ++j) { +        target_indexes[j] = symbols.size(); +        string target_word = target_data_array->GetWordAtIndex( +            target_sent_start + j); +        symbols.push_back(vocabulary->GetTerminalIndex(target_word)); +      } +      if (i < target_gap_order.size()) { +        symbols.push_back(vocabulary->GetNonterminalIndex( +            target_gap_order[i] + 1)); +      } +    } + +    // Construct the alignment between the source and the target phrase. +    vector<pair<int, int> > links = alignment->GetLinks(sentence_id); +    vector<pair<int, int> > alignment; +    for (pair<int, int> link: links) { +      if (target_indexes.count(link.second)) { +        alignment.push_back(make_pair(source_indexes.find(link.first)->second, +                                      target_indexes[link.second])); +      } +    } + +    Phrase target_phrase = phrase_builder->Build(symbols); +    target_phrases.push_back(make_pair(target_phrase, alignment)); +    return; +  } + +  subpatterns[index] = ranges[index].first; +  if (index > 0) { +    subpatterns[index] = max(subpatterns[index], subpatterns[index - 1]); +  } +  // Choose every possible combination of [start, end) for the current chunk. +  while (subpatterns[index] <= ranges[index].second) { +    subpatterns[index + 1] = max(subpatterns[index], ranges[index + 1].first); +    while (subpatterns[index + 1] <= ranges[index + 1].second) { +      GeneratePhrases(target_phrases, ranges, index + 2, subpatterns, +                      target_gap_order, target_phrase_low, target_phrase_high, +                      source_indexes, sentence_id); +      ++subpatterns[index + 1]; +    } +    ++subpatterns[index]; +  } +} + +} // namespace extractor diff --git a/extractor/target_phrase_extractor.h b/extractor/target_phrase_extractor.h new file mode 100644 index 00000000..289bae2f --- /dev/null +++ b/extractor/target_phrase_extractor.h @@ -0,0 +1,64 @@ +#ifndef _TARGET_PHRASE_EXTRACTOR_H_ +#define _TARGET_PHRASE_EXTRACTOR_H_ + +#include <memory> +#include <unordered_map> +#include <vector> + +using namespace std; + +namespace extractor { + +typedef vector<pair<int, int> > PhraseAlignment; + +class Alignment; +class DataArray; +class Phrase; +class PhraseBuilder; +class RuleExtractorHelper; +class Vocabulary; + +class TargetPhraseExtractor { + public: +  TargetPhraseExtractor(shared_ptr<DataArray> target_data_array, +                        shared_ptr<Alignment> alignment, +                        shared_ptr<PhraseBuilder> phrase_builder, +                        shared_ptr<RuleExtractorHelper> helper, +                        shared_ptr<Vocabulary> vocabulary, +                        int max_rule_span, +                        bool require_tight_phrases); + +  virtual ~TargetPhraseExtractor(); + +  // Finds all the target phrases that can extracted from a span in the +  // target sentence (matching the given set of target phrase gaps). +  virtual vector<pair<Phrase, PhraseAlignment> > ExtractPhrases( +      const vector<pair<int, int> >& target_gaps, const vector<int>& target_low, +      int target_phrase_low, int target_phrase_high, +      const unordered_map<int, int>& source_indexes, int sentence_id) const; + + protected: +  TargetPhraseExtractor(); + + private: +  // Computes the cartesian product over the sets of possible target phrase +  // chunks. +  void GeneratePhrases( +      vector<pair<Phrase, PhraseAlignment> >& target_phrases, +      const vector<pair<int, int> >& ranges, int index, +      vector<int>& subpatterns, const vector<int>& target_gap_order, +      int target_phrase_low, int target_phrase_high, +      const unordered_map<int, int>& source_indexes, int sentence_id) const; + +  shared_ptr<DataArray> target_data_array; +  shared_ptr<Alignment> alignment; +  shared_ptr<PhraseBuilder> phrase_builder; +  shared_ptr<RuleExtractorHelper> helper; +  shared_ptr<Vocabulary> vocabulary; +  int max_rule_span; +  bool require_tight_phrases; +}; + +} // namespace extractor + +#endif diff --git a/extractor/target_phrase_extractor_test.cc b/extractor/target_phrase_extractor_test.cc new file mode 100644 index 00000000..80927dee --- /dev/null +++ b/extractor/target_phrase_extractor_test.cc @@ -0,0 +1,143 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <vector> + +#include "mocks/mock_alignment.h" +#include "mocks/mock_data_array.h" +#include "mocks/mock_rule_extractor_helper.h" +#include "mocks/mock_vocabulary.h" +#include "phrase.h" +#include "phrase_builder.h" +#include "target_phrase_extractor.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace { + +class TargetPhraseExtractorTest : public Test { + protected: +  virtual void SetUp() { +    data_array = make_shared<MockDataArray>(); +    alignment = make_shared<MockAlignment>(); +    vocabulary = make_shared<MockVocabulary>(); +    phrase_builder = make_shared<PhraseBuilder>(vocabulary); +    helper = make_shared<MockRuleExtractorHelper>(); +  } + +  shared_ptr<MockDataArray> data_array; +  shared_ptr<MockAlignment> alignment; +  shared_ptr<MockVocabulary> vocabulary; +  shared_ptr<PhraseBuilder> phrase_builder; +  shared_ptr<MockRuleExtractorHelper> helper; +  shared_ptr<TargetPhraseExtractor> extractor; +}; + +TEST_F(TargetPhraseExtractorTest, TestExtractTightPhrasesTrue) { +  EXPECT_CALL(*data_array, GetSentenceLength(1)).WillRepeatedly(Return(5)); +  EXPECT_CALL(*data_array, GetSentenceStart(1)).WillRepeatedly(Return(3)); + +  vector<string> target_words = {"a", "b", "c", "d", "e"}; +  vector<int> target_symbols = {20, 21, 22, 23, 24}; +  for (size_t i = 0; i < target_words.size(); ++i) { +    EXPECT_CALL(*data_array, GetWordAtIndex(i + 3)) +        .WillRepeatedly(Return(target_words[i])); +    EXPECT_CALL(*vocabulary, GetTerminalIndex(target_words[i])) +        .WillRepeatedly(Return(target_symbols[i])); +    EXPECT_CALL(*vocabulary, GetTerminalValue(target_symbols[i])) +        .WillRepeatedly(Return(target_words[i])); +  } + +  vector<pair<int, int> > links = { +    make_pair(0, 0), make_pair(1, 3), make_pair(2, 2), make_pair(3, 1), +    make_pair(4, 4) +  }; +  EXPECT_CALL(*alignment, GetLinks(1)).WillRepeatedly(Return(links)); + +  vector<int> gap_order = {1, 0}; +  EXPECT_CALL(*helper, GetGapOrder(_)).WillRepeatedly(Return(gap_order)); + +  extractor = make_shared<TargetPhraseExtractor>( +      data_array, alignment, phrase_builder, helper, vocabulary, 10, true); + +  vector<pair<int, int> > target_gaps = {make_pair(3, 4), make_pair(1, 2)}; +  vector<int> target_low = {0, 3, 2, 1, 4}; +  unordered_map<int, int> source_indexes = {{0, 0}, {2, 2}, {4, 4}}; + +  vector<pair<Phrase, PhraseAlignment> > results =  extractor->ExtractPhrases( +      target_gaps, target_low, 0, 5, source_indexes, 1); +  EXPECT_EQ(1, results.size()); +  vector<int> expected_symbols = {20, -2, 22, -1, 24}; +  EXPECT_EQ(expected_symbols, results[0].first.Get()); +  vector<string> expected_words = {"a", "c", "e"}; +  EXPECT_EQ(expected_words, results[0].first.GetWords()); +  vector<pair<int, int> > expected_alignment = { +    make_pair(0, 0), make_pair(2, 2), make_pair(4, 4) +  }; +  EXPECT_EQ(expected_alignment, results[0].second); +} + +TEST_F(TargetPhraseExtractorTest, TestExtractPhrasesTightPhrasesFalse) { +  vector<string> target_words = {"a", "b", "c", "d", "e", "f", "END_OF_LINE"}; +  vector<int> target_symbols = {20, 21, 22, 23, 24, 25, 1}; +  EXPECT_CALL(*data_array, GetSentenceLength(0)).WillRepeatedly(Return(6)); +  EXPECT_CALL(*data_array, GetSentenceStart(0)).WillRepeatedly(Return(0)); + +  for (size_t i = 0; i < target_words.size(); ++i) { +    EXPECT_CALL(*data_array, GetWordAtIndex(i)) +        .WillRepeatedly(Return(target_words[i])); +    EXPECT_CALL(*vocabulary, GetTerminalIndex(target_words[i])) +        .WillRepeatedly(Return(target_symbols[i])); +    EXPECT_CALL(*vocabulary, GetTerminalValue(target_symbols[i])) +        .WillRepeatedly(Return(target_words[i])); +  } + +  vector<pair<int, int> > links = {make_pair(1, 1)}; +  EXPECT_CALL(*alignment, GetLinks(0)).WillRepeatedly(Return(links)); + +  vector<int> gap_order = {0}; +  EXPECT_CALL(*helper, GetGapOrder(_)).WillRepeatedly(Return(gap_order)); + +  extractor = make_shared<TargetPhraseExtractor>( +      data_array, alignment, phrase_builder, helper, vocabulary, 10, false); + +  vector<pair<int, int> > target_gaps = {make_pair(2, 4)}; +  vector<int> target_low = {-1, 1, -1, -1, -1, -1}; +  unordered_map<int, int> source_indexes = {{1, 1}}; + +  vector<pair<Phrase, PhraseAlignment> > results = extractor->ExtractPhrases( +      target_gaps, target_low, 1, 5, source_indexes, 0); +  EXPECT_EQ(10, results.size()); + +  for (int i = 0; i < 2; ++i) { +    for (int j = 4; j <= 6; ++j) { +      for (int k = 4; k <= j; ++k) { +        vector<string> expected_words; +        for (int l = i; l < 2; ++l) { +          expected_words.push_back(target_words[l]); +        } +        for (int l = k; l < j; ++l) { +          expected_words.push_back(target_words[l]); +        } + +        PhraseAlignment expected_alignment; +        expected_alignment.push_back(make_pair(1, 1 - i)); + +        bool found_expected_pair = false; +        for (auto result: results) { +          if (result.first.GetWords() == expected_words && +              result.second == expected_alignment) { +            found_expected_pair = true; +          } +        } + +        EXPECT_TRUE(found_expected_pair); +      } +    } +  } +} + +} // namespace +} // namespace extractor diff --git a/extractor/time_util.cc b/extractor/time_util.cc new file mode 100644 index 00000000..e46a0c3d --- /dev/null +++ b/extractor/time_util.cc @@ -0,0 +1,10 @@ +#include "time_util.h" + +namespace extractor { + +double GetDuration(const Clock::time_point& start_time, +                   const Clock::time_point& stop_time) { +  return duration_cast<milliseconds>(stop_time - start_time).count() / 1000.0; +} + +} // namespace extractor diff --git a/extractor/time_util.h b/extractor/time_util.h new file mode 100644 index 00000000..f7fd51d3 --- /dev/null +++ b/extractor/time_util.h @@ -0,0 +1,19 @@ +#ifndef _TIME_UTIL_H_ +#define _TIME_UTIL_H_ + +#include <chrono> + +using namespace std; +using namespace chrono; + +namespace extractor { + +typedef high_resolution_clock Clock; + +// Computes the duration in seconds of the specified time interval. +double GetDuration(const Clock::time_point& start_time, +                   const Clock::time_point& stop_time); + +} // namespace extractor + +#endif diff --git a/extractor/translation_table.cc b/extractor/translation_table.cc new file mode 100644 index 00000000..45da707a --- /dev/null +++ b/extractor/translation_table.cc @@ -0,0 +1,126 @@ +#include "translation_table.h" + +#include <string> +#include <vector> + +#include <boost/functional/hash.hpp> + +#include "alignment.h" +#include "data_array.h" + +using namespace std; + +namespace extractor { + +TranslationTable::TranslationTable(shared_ptr<DataArray> source_data_array, +                                   shared_ptr<DataArray> target_data_array, +                                   shared_ptr<Alignment> alignment) : +    source_data_array(source_data_array), target_data_array(target_data_array) { +  const vector<int>& source_data = source_data_array->GetData(); +  const vector<int>& target_data = target_data_array->GetData(); + +  unordered_map<int, int> source_links_count; +  unordered_map<int, int> target_links_count; +  unordered_map<pair<int, int>, int, PairHash> links_count; + +  // For each pair of aligned source target words increment their link count by +  // 1. Unaligned words are paired with the NULL token. +  for (size_t i = 0; i < source_data_array->GetNumSentences(); ++i) { +    vector<pair<int, int> > links = alignment->GetLinks(i); +    int source_start = source_data_array->GetSentenceStart(i); +    int target_start = target_data_array->GetSentenceStart(i); +    // Ignore END_OF_LINE markers. +    int next_source_start = source_data_array->GetSentenceStart(i + 1) - 1; +    int next_target_start = target_data_array->GetSentenceStart(i + 1) - 1; +    vector<int> source_sentence(source_data.begin() + source_start, +        source_data.begin() + next_source_start); +    vector<int> target_sentence(target_data.begin() + target_start, +        target_data.begin() + next_target_start); +    vector<int> source_linked_words(source_sentence.size()); +    vector<int> target_linked_words(target_sentence.size()); + +    for (pair<int, int> link: links) { +      source_linked_words[link.first] = 1; +      target_linked_words[link.second] = 1; +      IncrementLinksCount(source_links_count, target_links_count, links_count, +          source_sentence[link.first], target_sentence[link.second]); +    } + +    for (size_t i = 0; i < source_sentence.size(); ++i) { +      if (!source_linked_words[i]) { +        IncrementLinksCount(source_links_count, target_links_count, links_count, +                            source_sentence[i], DataArray::NULL_WORD); +      } +    } + +    for (size_t i = 0; i < target_sentence.size(); ++i) { +      if (!target_linked_words[i]) { +        IncrementLinksCount(source_links_count, target_links_count, links_count, +                            DataArray::NULL_WORD, target_sentence[i]); +      } +    } +  } + +  // Calculating: +  //   p(e | f) = count(e, f) / count(f) +  //   p(f | e) = count(e, f) / count(e) +  for (pair<pair<int, int>, int> link_count: links_count) { +    int source_word = link_count.first.first; +    int target_word = link_count.first.second; +    double score1 = 1.0 * link_count.second / source_links_count[source_word]; +    double score2 = 1.0 * link_count.second / target_links_count[target_word]; +    translation_probabilities[link_count.first] = make_pair(score1, score2); +  } +} + +TranslationTable::TranslationTable() {} + +TranslationTable::~TranslationTable() {} + +void TranslationTable::IncrementLinksCount( +    unordered_map<int, int>& source_links_count, +    unordered_map<int, int>& target_links_count, +    unordered_map<pair<int, int>, int, PairHash>& links_count, +    int source_word_id, +    int target_word_id) const { +  ++source_links_count[source_word_id]; +  ++target_links_count[target_word_id]; +  ++links_count[make_pair(source_word_id, target_word_id)]; +} + +double TranslationTable::GetTargetGivenSourceScore( +    const string& source_word, const string& target_word) { +  if (!source_data_array->HasWord(source_word) || +      !target_data_array->HasWord(target_word)) { +    return -1; +  } + +  int source_id = source_data_array->GetWordId(source_word); +  int target_id = target_data_array->GetWordId(target_word); +  return translation_probabilities[make_pair(source_id, target_id)].first; +} + +double TranslationTable::GetSourceGivenTargetScore( +    const string& source_word, const string& target_word) { +  if (!source_data_array->HasWord(source_word) || +      !target_data_array->HasWord(target_word)) { +    return -1; +  } + +  int source_id = source_data_array->GetWordId(source_word); +  int target_id = target_data_array->GetWordId(target_word); +  return translation_probabilities[make_pair(source_id, target_id)].second; +} + +void TranslationTable::WriteBinary(const fs::path& filepath) const { +  FILE* file = fopen(filepath.string().c_str(), "w"); + +  int size = translation_probabilities.size(); +  fwrite(&size, sizeof(int), 1, file); +  for (auto entry: translation_probabilities) { +    fwrite(&entry.first, sizeof(entry.first), 1, file); +    fwrite(&entry.second, sizeof(entry.second), 1, file); +  } +} + +} // namespace extractor diff --git a/extractor/translation_table.h b/extractor/translation_table.h new file mode 100644 index 00000000..10504d3b --- /dev/null +++ b/extractor/translation_table.h @@ -0,0 +1,63 @@ +#ifndef _TRANSLATION_TABLE_ +#define _TRANSLATION_TABLE_ + +#include <memory> +#include <string> +#include <unordered_map> + +#include <boost/filesystem.hpp> +#include <boost/functional/hash.hpp> + +using namespace std; +namespace fs = boost::filesystem; + +namespace extractor { + +typedef boost::hash<pair<int, int> > PairHash; + +class Alignment; +class DataArray; + +/** + * Bilexical table with conditional probabilities. + */ +class TranslationTable { + public: +  TranslationTable( +      shared_ptr<DataArray> source_data_array, +      shared_ptr<DataArray> target_data_array, +      shared_ptr<Alignment> alignment); + +  virtual ~TranslationTable(); + +  // Returns p(e | f). +  virtual double GetTargetGivenSourceScore(const string& source_word, +                                           const string& target_word); + +  // Returns p(f | e). +  virtual double GetSourceGivenTargetScore(const string& source_word, +                                           const string& target_word); + +  void WriteBinary(const fs::path& filepath) const; + + protected: +  TranslationTable(); + + private: +  // Increment links count for the given (f, e) word pair. +  void IncrementLinksCount( +      unordered_map<int, int>& source_links_count, +      unordered_map<int, int>& target_links_count, +      unordered_map<pair<int, int>, int, PairHash>& links_count, +      int source_word_id, +      int target_word_id) const; + +  shared_ptr<DataArray> source_data_array; +  shared_ptr<DataArray> target_data_array; +  unordered_map<pair<int, int>, pair<double, double>, PairHash> +      translation_probabilities; +}; + +} // namespace extractor + +#endif diff --git a/extractor/translation_table_test.cc b/extractor/translation_table_test.cc new file mode 100644 index 00000000..051b5715 --- /dev/null +++ b/extractor/translation_table_test.cc @@ -0,0 +1,84 @@ +#include <gtest/gtest.h> + +#include <memory> +#include <string> +#include <vector> + +#include "mocks/mock_alignment.h" +#include "mocks/mock_data_array.h" +#include "translation_table.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace { + +TEST(TranslationTableTest, TestScores) { +  vector<string> words = {"a", "b", "c"}; + +  vector<int> source_data = {2, 3, 2, 3, 4, 0, 2, 3, 6, 0, 2, 3, 6, 0}; +  vector<int> source_sentence_start = {0, 6, 10, 14}; +  shared_ptr<MockDataArray> source_data_array = make_shared<MockDataArray>(); +  EXPECT_CALL(*source_data_array, GetData()) +      .WillRepeatedly(ReturnRef(source_data)); +  EXPECT_CALL(*source_data_array, GetNumSentences()) +      .WillRepeatedly(Return(3)); +  for (size_t i = 0; i < source_sentence_start.size(); ++i) { +    EXPECT_CALL(*source_data_array, GetSentenceStart(i)) +        .WillRepeatedly(Return(source_sentence_start[i])); +  } +  for (size_t i = 0; i < words.size(); ++i) { +    EXPECT_CALL(*source_data_array, HasWord(words[i])) +        .WillRepeatedly(Return(true)); +    EXPECT_CALL(*source_data_array, GetWordId(words[i])) +        .WillRepeatedly(Return(i + 2)); +  } +  EXPECT_CALL(*source_data_array, HasWord("d")) +      .WillRepeatedly(Return(false)); + +  vector<int> target_data = {2, 3, 2, 3, 4, 5, 0, 3, 6, 0, 2, 7, 0}; +  vector<int> target_sentence_start = {0, 7, 10, 13}; +  shared_ptr<MockDataArray> target_data_array = make_shared<MockDataArray>(); +  EXPECT_CALL(*target_data_array, GetData()) +      .WillRepeatedly(ReturnRef(target_data)); +  for (size_t i = 0; i < target_sentence_start.size(); ++i) { +    EXPECT_CALL(*target_data_array, GetSentenceStart(i)) +        .WillRepeatedly(Return(target_sentence_start[i])); +  } +  for (size_t i = 0; i < words.size(); ++i) { +    EXPECT_CALL(*target_data_array, HasWord(words[i])) +        .WillRepeatedly(Return(true)); +    EXPECT_CALL(*target_data_array, GetWordId(words[i])) +        .WillRepeatedly(Return(i + 2)); +  } +  EXPECT_CALL(*target_data_array, HasWord("d")) +      .WillRepeatedly(Return(false)); + +  vector<pair<int, int> > links1 = { +    make_pair(0, 0), make_pair(1, 1), make_pair(2, 2), make_pair(3, 3), +    make_pair(4, 4), make_pair(4, 5) +  }; +  vector<pair<int, int> > links2 = {make_pair(1, 0), make_pair(2, 1)}; +  vector<pair<int, int> > links3 = {make_pair(0, 0), make_pair(2, 1)}; +  shared_ptr<MockAlignment> alignment = make_shared<MockAlignment>(); +  EXPECT_CALL(*alignment, GetLinks(0)).WillRepeatedly(Return(links1)); +  EXPECT_CALL(*alignment, GetLinks(1)).WillRepeatedly(Return(links2)); +  EXPECT_CALL(*alignment, GetLinks(2)).WillRepeatedly(Return(links3)); + +  shared_ptr<TranslationTable> table = make_shared<TranslationTable>( +      source_data_array, target_data_array, alignment); + +  EXPECT_EQ(0.75, table->GetTargetGivenSourceScore("a", "a")); +  EXPECT_EQ(0, table->GetTargetGivenSourceScore("a", "b")); +  EXPECT_EQ(0.5, table->GetTargetGivenSourceScore("c", "c")); +  EXPECT_EQ(-1, table->GetTargetGivenSourceScore("c", "d")); + +  EXPECT_EQ(1, table->GetSourceGivenTargetScore("a", "a")); +  EXPECT_EQ(0, table->GetSourceGivenTargetScore("a", "b")); +  EXPECT_EQ(1, table->GetSourceGivenTargetScore("c", "c")); +  EXPECT_EQ(-1, table->GetSourceGivenTargetScore("c", "d")); +} + +} // namespace +} // namespace extractor diff --git a/extractor/vocabulary.cc b/extractor/vocabulary.cc new file mode 100644 index 00000000..15795d1e --- /dev/null +++ b/extractor/vocabulary.cc @@ -0,0 +1,37 @@ +#include "vocabulary.h" + +namespace extractor { + +Vocabulary::~Vocabulary() {} + +int Vocabulary::GetTerminalIndex(const string& word) { +  int word_id = -1; +  #pragma omp critical (vocabulary) +  { +    if (!dictionary.count(word)) { +      word_id = words.size(); +      dictionary[word] = word_id; +      words.push_back(word); +    } else { +      word_id = dictionary[word]; +    } +  } +  return word_id; +} + +int Vocabulary::GetNonterminalIndex(int position) { +  return -position; +} + +bool Vocabulary::IsTerminal(int symbol) { +  return symbol >= 0; +} + +string Vocabulary::GetTerminalValue(int symbol) { +  string word; +  #pragma omp critical (vocabulary) +  word = words[symbol]; +  return word; +} + +} // namespace extractor diff --git a/extractor/vocabulary.h b/extractor/vocabulary.h new file mode 100644 index 00000000..c8fd9411 --- /dev/null +++ b/extractor/vocabulary.h @@ -0,0 +1,48 @@ +#ifndef _VOCABULARY_H_ +#define _VOCABULARY_H_ + +#include <string> +#include <unordered_map> +#include <vector> + +using namespace std; + +namespace extractor { + +/** + * Data structure for mapping words to word ids. + * + * This strucure contains words located in the frequent collocations and words + * encountered during the grammar extraction time. This dictionary is + * considerably smaller than the dictionaries in the data arrays (and so is the + * query time). Note that this is the single data structure that changes state + * and needs to have thread safe read/write operations. + * + * Note: For an experiment using different vocabulary instances for each thread, + * the running time did not improve implying that the critical regions do not + * cause bottlenecks. + */ +class Vocabulary { + public: +  virtual ~Vocabulary(); + +  // Returns the word id for the given word. +  virtual int GetTerminalIndex(const string& word); + +  // Returns the id for a nonterminal located at the given position in a phrase. +  int GetNonterminalIndex(int position); + +  // Checks if a symbol is a nonterminal. +  bool IsTerminal(int symbol); + +  // Returns the word corresponding to the given word id. +  virtual string GetTerminalValue(int symbol); + + private: +  unordered_map<string, int> dictionary; +  vector<string> words; +}; + +} // namespace extractor + +#endif | 
