Diffstat (limited to 'extractor')
53 files changed, 1298 insertions, 549 deletions
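The central change in this commit is that samplers now take a set of blacklisted sentence ids, so grammars can be extracted with leave-one-out estimation: when extracting rules for a training sentence, occurrences coming from that sentence itself are skipped. The new BackoffSampler (added below) steps uniformly through the range of occurrences and, when a candidate repeats an earlier pick or falls in a blacklisted sentence, backs off to neighboring positions within the current step. As orientation, the strategy amounts to the following simplified sketch; this is not the committed code, and sentence_of stands in for the GetPosition/GetSentenceId lookups:

    #include <algorithm>
    #include <cmath>
    #include <unordered_set>
    #include <vector>

    using namespace std;

    // Simplified model of the backoff sampling strategy: choose up to
    // max_samples indices spread uniformly over [low, high); when a candidate
    // repeats an earlier pick or lands in a blacklisted sentence, probe its
    // neighbors at i - 1, i + 1, i - 2, ... within the current step width.
    vector<int> BackoffSample(int low, int high, int max_samples,
                              const unordered_set<int>& blacklist,
                              int (*sentence_of)(int)) {
      vector<int> samples;
      int last = low - 1;
      double step = max(1.0, (double) (high - low) / max_samples);
      for (double i = low; (int) samples.size() < max_samples && i < high;
           i += step) {
        int sample = round(i);
        bool found = sample > last && !blacklist.count(sentence_of(sample));
        // Back off to nearby positions inside the current step.
        for (int backoff = 1; !found && backoff < step; ++backoff) {
          int left = round(i - backoff), right = round(i + backoff);
          if (left >= 0 && left > last && !blacklist.count(sentence_of(left))) {
            sample = left;
            found = true;
          } else if (right < high && !blacklist.count(sentence_of(right))) {
            sample = right;
            found = true;
          }
        }
        if (found) {
          last = sample;
          samples.push_back(sample);
        }
      }
      return samples;
    }

In the committed class the range bounds, position lookup, and sample collection are virtual hooks (GetRangeLow, GetRangeHigh, GetPosition, AppendMatching), so the same backoff logic serves both suffix-array ranges and explicit matchings via the two subclasses introduced below.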
| diff --git a/extractor/Makefile.am b/extractor/Makefile.am index 65a3d436..e5b439f9 100644 --- a/extractor/Makefile.am +++ b/extractor/Makefile.am @@ -1,5 +1,5 @@ -bin_PROGRAMS = compile run_extractor +bin_PROGRAMS = compile run_extractor extract  if HAVE_CXX11 @@ -15,16 +15,19 @@ EXTRA_PROGRAMS = alignment_test \      feature_target_given_source_coherent_test \      grammar_extractor_test \      matchings_finder_test \ +    matchings_sampler_test \ +    phrase_location_sampler_test \      phrase_test \      precomputation_test \      rule_extractor_helper_test \      rule_extractor_test \      rule_factory_test \ -    sampler_test \      scorer_test \ +    suffix_array_sampler_test \      suffix_array_test \      target_phrase_extractor_test \ -    translation_table_test +    translation_table_test \ +    vocabulary_test  if HAVE_GTEST    RUNNABLE_TESTS = alignment_test \ @@ -39,16 +42,19 @@ if HAVE_GTEST      feature_target_given_source_coherent_test \      grammar_extractor_test \      matchings_finder_test \ +    matchings_sampler_test \ +    phrase_location_sampler_test \      phrase_test \      precomputation_test \      rule_extractor_helper_test \      rule_extractor_test \      rule_factory_test \ -    sampler_test \      scorer_test \ +    suffix_array_sampler_test \      suffix_array_test \      target_phrase_extractor_test \ -    translation_table_test +    translation_table_test \ +    vocabulary_test  endif  noinst_PROGRAMS = $(RUNNABLE_TESTS) @@ -79,6 +85,10 @@ grammar_extractor_test_SOURCES = grammar_extractor_test.cc  grammar_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a  matchings_finder_test_SOURCES = matchings_finder_test.cc  matchings_finder_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +matchings_sampler_test_SOURCES = matchings_sampler_test.cc +matchings_sampler_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +phrase_location_sampler_test_SOURCES = phrase_location_sampler_test.cc +phrase_location_sampler_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a  phrase_test_SOURCES = phrase_test.cc  phrase_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a  precomputation_test_SOURCES = precomputation_test.cc @@ -89,57 +99,31 @@ rule_extractor_test_SOURCES = rule_extractor_test.cc  rule_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a  rule_factory_test_SOURCES = rule_factory_test.cc  rule_factory_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a -sampler_test_SOURCES = sampler_test.cc -sampler_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a  scorer_test_SOURCES = scorer_test.cc  scorer_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +suffix_array_sampler_test_SOURCES = suffix_array_sampler_test.cc +suffix_array_sampler_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a  suffix_array_test_SOURCES = suffix_array_test.cc  suffix_array_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a  target_phrase_extractor_test_SOURCES = target_phrase_extractor_test.cc  target_phrase_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a  translation_table_test_SOURCES = 
translation_table_test.cc  translation_table_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +vocabulary_test_SOURCES = vocabulary_test.cc +vocabulary_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a -noinst_LIBRARIES = libextractor.a libcompile.a +noinst_LIBRARIES = libextractor.a  compile_SOURCES = compile.cc -compile_LDADD = libcompile.a +compile_LDADD = libextractor.a  run_extractor_SOURCES = run_extractor.cc  run_extractor_LDADD = libextractor.a - -libcompile_a_SOURCES = \ -  alignment.cc \ -  data_array.cc \ -  phrase_location.cc \ -  precomputation.cc \ -  suffix_array.cc \ -  time_util.cc \ -  translation_table.cc \ -  alignment.h \ -  data_array.h \ -  fast_intersector.h \ -  grammar.h \ -  grammar_extractor.h \ -  matchings_finder.h \ -  matchings_trie.h \ -  phrase.h \ -  phrase_builder.h \ -  phrase_location.h \ -  precomputation.h \ -  rule.h \ -  rule_extractor.h \ -  rule_extractor_helper.h \ -  rule_factory.h \ -  sampler.h \ -  scorer.h \ -  suffix_array.h \ -  target_phrase_extractor.h \ -  time_util.h \ -  translation_table.h \ -  vocabulary.h +extract_SOURCES = extract.cc +extract_LDADD = libextractor.a  libextractor_a_SOURCES = \    alignment.cc \ +  backoff_sampler.cc \    data_array.cc \    fast_intersector.cc \    features/count_source_target.cc \ @@ -153,18 +137,20 @@ libextractor_a_SOURCES = \    grammar.cc \    grammar_extractor.cc \    matchings_finder.cc \ +  matchings_sampler.cc \    matchings_trie.cc \    phrase.cc \    phrase_builder.cc \    phrase_location.cc \ +  phrase_location_sampler.cc \    precomputation.cc \    rule.cc \    rule_extractor.cc \    rule_extractor_helper.cc \    rule_factory.cc \ -  sampler.cc \    scorer.cc \    suffix_array.cc \ +  suffix_array_sampler.cc \    target_phrase_extractor.cc \    time_util.cc \    translation_table.cc \ diff --git a/extractor/README.md b/extractor/README.md index c9db8de8..642fbd1d 100644 --- a/extractor/README.md +++ b/extractor/README.md @@ -1,8 +1,14 @@  C++ implementation of the online grammar extractor originally developed by [Adam Lopez](http://www.cs.jhu.edu/~alopez/). -To run the extractor you need to: +The grammar extraction takes place in two steps: (a) precomputing a number of data structures and (b) actually extracting the grammars. All the flags below have the same meaning as in the Cython implementation. 
-    cdec/extractor/run_extractor -t <num_threads> -a <alignment> -b <parallel_corpus> -g <grammar_output_path> < <input_sentences> > <sgm_file> +To compile the data structures you need to run: + +    cdec/extractor/compile -a <alignment> -b <parallel_corpus> -c <compile_config_file> -o <compile_directory> + +To extract the grammars you need to run: + +    cdec/extractor/extract -t <num_threads> -c <compile_config_file> -g <grammar_output_path> < <input_sentences> > <sgm_file>  To run unit tests you first need to configure `cdec` with the [Google Test](https://code.google.com/p/googletest/) and [Google Mock](https://code.google.com/p/googlemock/) libraries: diff --git a/extractor/alignment.cc b/extractor/alignment.cc index 2278c825..4a7a14f4 100644 --- a/extractor/alignment.cc +++ b/extractor/alignment.cc @@ -8,9 +8,7 @@  #include <vector>  #include <boost/algorithm/string.hpp> -#include <boost/filesystem.hpp> -namespace fs = boost::filesystem;  using namespace std;  namespace extractor { diff --git a/extractor/alignment.h b/extractor/alignment.h index dc5a8b55..76c27da2 100644 --- a/extractor/alignment.h +++ b/extractor/alignment.h @@ -4,13 +4,11 @@  #include <string>  #include <vector> -#include <boost/filesystem.hpp>  #include <boost/serialization/serialization.hpp>  #include <boost/serialization/split_member.hpp>  #include <boost/serialization/utility.hpp>  #include <boost/serialization/vector.hpp> -namespace fs = boost::filesystem;  using namespace std;  namespace extractor { diff --git a/extractor/backoff_sampler.cc b/extractor/backoff_sampler.cc new file mode 100644 index 00000000..891276c6 --- /dev/null +++ b/extractor/backoff_sampler.cc @@ -0,0 +1,66 @@ +#include "backoff_sampler.h" + +#include "data_array.h" +#include "phrase_location.h" + +namespace extractor { + +BackoffSampler::BackoffSampler( +    shared_ptr<DataArray> source_data_array, int max_samples) : +    source_data_array(source_data_array), max_samples(max_samples) {} + +BackoffSampler::BackoffSampler() {} + +PhraseLocation BackoffSampler::Sample( +    const PhraseLocation& location, +    const unordered_set<int>& blacklisted_sentence_ids) const { +  vector<int> samples; +  int low = GetRangeLow(location), high = GetRangeHigh(location); +  int last = low - 1; +  double step = max(1.0, (double) (high - low) / max_samples); +  for (double num_samples = 0, i = low; +       num_samples < max_samples && i < high; +       ++num_samples, i += step) { +    int sample = round(i); +    int position = GetPosition(location, sample); +    int sentence_id = source_data_array->GetSentenceId(position); +    bool found = false; +    if (last >= sample || +        blacklisted_sentence_ids.count(sentence_id)) { +      for (double backoff_step = 1; backoff_step < step; ++backoff_step) { +        double j = i - backoff_step; +        sample = round(j); +        if (sample >= 0) { +          position = GetPosition(location, sample); +          sentence_id = source_data_array->GetSentenceId(position); +          if (sample > last && !blacklisted_sentence_ids.count(sentence_id)) { +            found = true; +            break; +          } +        } + +        double k = i + backoff_step; +        sample = round(k); +        if (sample < high) { +          position = GetPosition(location, sample); +          sentence_id = source_data_array->GetSentenceId(position); +          if (!blacklisted_sentence_ids.count(sentence_id)) { +            found = true; +            break; +          } +        } +      } +    } else { +      found = true; +   
 } + +    if (found) { +      last = sample; +      AppendMatching(samples, sample, location); +    } +  } + +  return PhraseLocation(samples, GetNumSubpatterns(location)); +} + +} // namespace extractor diff --git a/extractor/backoff_sampler.h b/extractor/backoff_sampler.h new file mode 100644 index 00000000..5c244105 --- /dev/null +++ b/extractor/backoff_sampler.h @@ -0,0 +1,41 @@ +#ifndef _BACKOFF_SAMPLER_H_ +#define _BACKOFF_SAMPLER_H_ + +#include <vector> + +#include "sampler.h" + +namespace extractor { + +class DataArray; +class PhraseLocation; + +class BackoffSampler : public Sampler { + public: +  BackoffSampler(shared_ptr<DataArray> source_data_array, int max_samples); + +  BackoffSampler(); + +  PhraseLocation Sample( +      const PhraseLocation& location, +      const unordered_set<int>& blacklisted_sentence_ids) const; + + private: +  virtual int GetNumSubpatterns(const PhraseLocation& location) const = 0; + +  virtual int GetRangeLow(const PhraseLocation& location) const = 0; + +  virtual int GetRangeHigh(const PhraseLocation& location) const = 0; + +  virtual int GetPosition(const PhraseLocation& location, int index) const = 0; + +  virtual void AppendMatching(vector<int>& samples, int index, +                              const PhraseLocation& location) const = 0; + +  shared_ptr<DataArray> source_data_array; +  int max_samples; +}; + +} // namespace extractor + +#endif diff --git a/extractor/compile.cc b/extractor/compile.cc index 65fdd509..3ee668ce 100644 --- a/extractor/compile.cc +++ b/extractor/compile.cc @@ -13,6 +13,7 @@  #include "suffix_array.h"  #include "time_util.h"  #include "translation_table.h" +#include "vocabulary.h"  namespace ar = boost::archive;  namespace fs = boost::filesystem; @@ -29,6 +30,8 @@ int main(int argc, char** argv) {      ("bitext,b", po::value<string>(), "Parallel text (source ||| target)")      ("alignment,a", po::value<string>()->required(), "Bitext word alignment")      ("output,o", po::value<string>()->required(), "Output path") +    ("config,c", po::value<string>()->required(), +        "Path where the config file will be generated")      ("frequent", po::value<int>()->default_value(100),          "Number of precomputed frequent patterns")      ("super_frequent", po::value<int>()->default_value(10), @@ -81,8 +84,12 @@ int main(int argc, char** argv) {      target_data_array = make_shared<DataArray>(vm["target"].as<string>());    } +  ofstream config_stream(vm["config"].as<string>()); +    Clock::time_point start_write = Clock::now(); -  ofstream target_fstream((output_dir / fs::path("target.bin")).string()); +  string target_path = (output_dir / fs::path("target.bin")).string(); +  config_stream << "target = " << target_path << endl; +  ofstream target_fstream(target_path);    ar::binary_oarchive target_stream(target_fstream);    target_stream << *target_data_array;    Clock::time_point stop_write = Clock::now(); @@ -99,7 +106,9 @@ int main(int argc, char** argv) {        make_shared<SuffixArray>(source_data_array);    start_write = Clock::now(); -  ofstream source_fstream((output_dir / fs::path("source.bin")).string()); +  string source_path = (output_dir / fs::path("source.bin")).string(); +  config_stream << "source = " << source_path << endl; +  ofstream source_fstream(source_path);    ar::binary_oarchive output_stream(source_fstream);    output_stream << *source_suffix_array;    stop_write = Clock::now(); @@ -115,7 +124,9 @@ int main(int argc, char** argv) {        make_shared<Alignment>(vm["alignment"].as<string>());    
start_write = Clock::now(); -  ofstream alignment_fstream((output_dir / fs::path("alignment.bin")).string()); +  string alignment_path = (output_dir / fs::path("alignment.bin")).string(); +  config_stream << "alignment = " << alignment_path << endl; +  ofstream alignment_fstream(alignment_path);    ar::binary_oarchive alignment_stream(alignment_fstream);    alignment_stream << *alignment;    stop_write = Clock::now(); @@ -125,9 +136,12 @@ int main(int argc, char** argv) {    cerr << "Reading alignment took "         << GetDuration(start_time, stop_time) << " seconds" << endl; +  shared_ptr<Vocabulary> vocabulary = make_shared<Vocabulary>(); +    start_time = Clock::now();    cerr << "Precomputing collocations..." << endl;    Precomputation precomputation( +      vocabulary,        source_suffix_array,        vm["frequent"].as<int>(),        vm["super_frequent"].as<int>(), @@ -138,9 +152,17 @@ int main(int argc, char** argv) {        vm["min_frequency"].as<int>());    start_write = Clock::now(); -  ofstream precomp_fstream((output_dir / fs::path("precomp.bin")).string()); +  string precomputation_path = (output_dir / fs::path("precomp.bin")).string(); +  config_stream << "precomputation = " << precomputation_path << endl; +  ofstream precomp_fstream(precomputation_path);    ar::binary_oarchive precomp_stream(precomp_fstream);    precomp_stream << precomputation; + +  string vocabulary_path = (output_dir / fs::path("vocab.bin")).string(); +  config_stream << "vocabulary = " << vocabulary_path << endl; +  ofstream vocab_fstream(vocabulary_path); +  ar::binary_oarchive vocab_stream(vocab_fstream); +  vocab_stream << *vocabulary;    stop_write = Clock::now();    write_duration += GetDuration(start_write, stop_write); @@ -153,7 +175,9 @@ int main(int argc, char** argv) {    TranslationTable table(source_data_array, target_data_array, alignment);    start_write = Clock::now(); -  ofstream table_fstream((output_dir / fs::path("bilex.bin")).string()); +  string table_path = (output_dir / fs::path("bilex.bin")).string(); +  config_stream << "ttable = " << table_path << endl; +  ofstream table_fstream(table_path);    ar::binary_oarchive table_stream(table_fstream);    table_stream << table;    stop_write = Clock::now(); diff --git a/extractor/data_array.cc b/extractor/data_array.cc index 2e4bdafb..9612aa8a 100644 --- a/extractor/data_array.cc +++ b/extractor/data_array.cc @@ -5,9 +5,6 @@  #include <sstream>  #include <string> -#include <boost/filesystem.hpp> - -namespace fs = boost::filesystem;  using namespace std;  namespace extractor { @@ -81,7 +78,7 @@ void DataArray::CreateDataArray(const vector<string>& lines) {  DataArray::~DataArray() {} -const vector<int>& DataArray::GetData() const { +vector<int> DataArray::GetData() const {    return data;  } @@ -93,6 +90,18 @@ string DataArray::GetWordAtIndex(int index) const {    return id2word[data[index]];  } +vector<int> DataArray::GetWordIds(int index, int size) const { +  return vector<int>(data.begin() + index, data.begin() + index + size); +} + +vector<string> DataArray::GetWords(int start_index, int size) const { +  vector<string> words; +  for (int word_id: GetWordIds(start_index, size)) { +    words.push_back(id2word[word_id]); +  } +  return words; +} +  int DataArray::GetSize() const {    return data.size();  } @@ -118,10 +127,6 @@ int DataArray::GetSentenceId(int position) const {    return sentence_id[position];  } -bool DataArray::HasWord(const string& word) const { -  return word2id.count(word); -} -  int DataArray::GetWordId(const 
string& word) const {    auto result = word2id.find(word);    return result == word2id.end() ? -1 : result->second; diff --git a/extractor/data_array.h b/extractor/data_array.h index 2be6a09c..b96901d1 100644 --- a/extractor/data_array.h +++ b/extractor/data_array.h @@ -5,13 +5,11 @@  #include <unordered_map>  #include <vector> -#include <boost/filesystem.hpp>  #include <boost/serialization/serialization.hpp>  #include <boost/serialization/split_member.hpp>  #include <boost/serialization/string.hpp>  #include <boost/serialization/vector.hpp> -namespace fs = boost::filesystem;  using namespace std;  namespace extractor { @@ -53,7 +51,7 @@ class DataArray {    virtual ~DataArray();    // Returns a vector containing the word ids. -  virtual const vector<int>& GetData() const; +  virtual vector<int> GetData() const;    // Returns the word id at the specified position.    virtual int AtIndex(int index) const; @@ -61,15 +59,20 @@ class DataArray {    // Returns the original word at the specified position.    virtual string GetWordAtIndex(int index) const; +  // Returns the substring of word ids starting at the specified position and +  // having the specified length. +  virtual vector<int> GetWordIds(int start_index, int size) const; + +  // Returns the substring of words starting at the specified position and +  // having the specified length. +  virtual vector<string> GetWords(int start_index, int size) const; +    // Returns the size of the data array.    virtual int GetSize() const;    // Returns the number of distinct words in the data array.    virtual int GetVocabularySize() const; -  // Returns whether a word has ever been observed in the data array. -  virtual bool HasWord(const string& word) const; -    // Returns the word id for a given word or -1 if the word has never been    // observed.    
virtual int GetWordId(const string& word) const; diff --git a/extractor/data_array_test.cc b/extractor/data_array_test.cc index 6c329e34..99f79d91 100644 --- a/extractor/data_array_test.cc +++ b/extractor/data_array_test.cc @@ -56,18 +56,26 @@ TEST_F(DataArrayTest, TestGetData) {    }  } +TEST_F(DataArrayTest, TestSubstrings) { +  vector<int> expected_word_ids = {3, 4, 5}; +  vector<string> expected_words = {"are", "mere", "."}; +  EXPECT_EQ(expected_word_ids, source_data.GetWordIds(1, 3)); +  EXPECT_EQ(expected_words, source_data.GetWords(1, 3)); + +  expected_word_ids = {7, 8}; +  expected_words = {"a", "lot"}; +  EXPECT_EQ(expected_word_ids, target_data.GetWordIds(7, 2)); +  EXPECT_EQ(expected_words, target_data.GetWords(7, 2)); +} +  TEST_F(DataArrayTest, TestVocabulary) {    EXPECT_EQ(9, source_data.GetVocabularySize()); -  EXPECT_TRUE(source_data.HasWord("mere"));    EXPECT_EQ(4, source_data.GetWordId("mere"));    EXPECT_EQ("mere", source_data.GetWord(4)); -  EXPECT_FALSE(source_data.HasWord("banane"));    EXPECT_EQ(11, target_data.GetVocabularySize()); -  EXPECT_TRUE(target_data.HasWord("apples"));    EXPECT_EQ(4, target_data.GetWordId("apples"));    EXPECT_EQ("apples", target_data.GetWord(4)); -  EXPECT_FALSE(target_data.HasWord("bananas"));  }  TEST_F(DataArrayTest, TestSentenceData) { diff --git a/extractor/extract.cc b/extractor/extract.cc new file mode 100644 index 00000000..387cbe9b --- /dev/null +++ b/extractor/extract.cc @@ -0,0 +1,253 @@ +#include <fstream> +#include <iostream> +#include <memory> +#include <string> +#include <vector> + +#include <boost/archive/binary_iarchive.hpp> +#include <boost/filesystem.hpp> +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> +#include <omp.h> + +#include "alignment.h" +#include "data_array.h" +#include "features/count_source_target.h" +#include "features/feature.h" +#include "features/is_source_singleton.h" +#include "features/is_source_target_singleton.h" +#include "features/max_lex_source_given_target.h" +#include "features/max_lex_target_given_source.h" +#include "features/sample_source_count.h" +#include "features/target_given_source_coherent.h" +#include "grammar.h" +#include "grammar_extractor.h" +#include "precomputation.h" +#include "rule.h" +#include "scorer.h" +#include "suffix_array.h" +#include "time_util.h" +#include "translation_table.h" +#include "vocabulary.h" + +namespace ar = boost::archive; +namespace fs = boost::filesystem; +namespace po = boost::program_options; +using namespace extractor; +using namespace features; +using namespace std; + +// Returns the file path in which a given grammar should be written. +fs::path GetGrammarFilePath(const fs::path& grammar_path, int file_number) { +  string file_name = "grammar." 
+ to_string(file_number); +  return grammar_path / file_name; +} + +int main(int argc, char** argv) { +  po::options_description general_options("General options"); +  int max_threads = 1; +  #pragma omp parallel +  max_threads = omp_get_num_threads(); +  string threads_option = "Number of threads used for grammar extraction " +                          "(at most " + to_string(max_threads) + ")"; +  general_options.add_options() +    ("threads,t", po::value<int>()->required()->default_value(1), +     threads_option.c_str()) +    ("grammars,g", po::value<string>()->required(), "Grammars output path") +    ("max_rule_span", po::value<int>()->default_value(15), +        "Maximum rule span") +    ("max_rule_symbols", po::value<int>()->default_value(5), +        "Maximum number of symbols (terminals + nonterminals) in a rule") +    ("min_gap_size", po::value<int>()->default_value(1), "Minimum gap size") +    ("max_nonterminals", po::value<int>()->default_value(2), +        "Maximum number of nonterminals in a rule") +    ("max_samples", po::value<int>()->default_value(300), +        "Maximum number of samples") +    ("tight_phrases", po::value<bool>()->default_value(true), +        "False if phrases may be loose (better, but slower)") +    ("leave_one_out", po::value<bool>()->zero_tokens(), +        "Do leave-one-out estimation of grammars " +        "(e.g. for extracting grammars for the training set)"); + +  po::options_description cmdline_options("Command line options"); +  cmdline_options.add_options() +    ("help", "Show available options") +    ("config,c", po::value<string>()->required(), "Path to config file"); +  cmdline_options.add(general_options); + +  po::options_description config_options("Config file options"); +  config_options.add_options() +    ("target", po::value<string>()->required(), +        "Path to target data file in binary format") +    ("source", po::value<string>()->required(), +        "Path to source suffix array file in binary format") +    ("alignment", po::value<string>()->required(), +        "Path to alignment file in binary format") +    ("precomputation", po::value<string>()->required(), +        "Path to precomputation file in binary format") +    ("vocabulary", po::value<string>()->required(), +        "Path to vocabulary file in binary format") +    ("ttable", po::value<string>()->required(), +        "Path to translation table in binary format"); +  config_options.add(general_options); + +  po::variables_map vm; +  po::store(po::parse_command_line(argc, argv, cmdline_options), vm); +  if (vm.count("help")) { +    po::options_description all_options; +    all_options.add(cmdline_options).add(config_options); +    cout << all_options << endl; +    return 0; +  } + +  po::notify(vm); + +  ifstream config_stream(vm["config"].as<string>()); +  po::store(po::parse_config_file(config_stream, config_options), vm); +  po::notify(vm); + +  int num_threads = vm["threads"].as<int>(); +  cerr << "Grammar extraction will use " << num_threads << " threads." << endl; + +  Clock::time_point read_start_time = Clock::now(); + +  Clock::time_point start_time = Clock::now(); +  cerr << "Reading target data in binary format..." 
<< endl; +  shared_ptr<DataArray> target_data_array = make_shared<DataArray>(); +  ifstream target_fstream(vm["target"].as<string>()); +  ar::binary_iarchive target_stream(target_fstream); +  target_stream >> *target_data_array; +  Clock::time_point end_time = Clock::now(); +  cerr << "Reading target data took " << GetDuration(start_time, end_time) +       << " seconds" << endl; + +  start_time = Clock::now(); +  cerr << "Reading source suffix array in binary format..." << endl; +  shared_ptr<SuffixArray> source_suffix_array = make_shared<SuffixArray>(); +  ifstream source_fstream(vm["source"].as<string>()); +  ar::binary_iarchive source_stream(source_fstream); +  source_stream >> *source_suffix_array; +  end_time = Clock::now(); +  cerr << "Reading source suffix array took " +       << GetDuration(start_time, end_time) << " seconds" << endl; + +  start_time = Clock::now(); +  cerr << "Reading alignment in binary format..." << endl; +  shared_ptr<Alignment> alignment = make_shared<Alignment>(); +  ifstream alignment_fstream(vm["alignment"].as<string>()); +  ar::binary_iarchive alignment_stream(alignment_fstream); +  alignment_stream >> *alignment; +  end_time = Clock::now(); +  cerr << "Reading alignment took " << GetDuration(start_time, end_time) +       << " seconds" << endl; + +  start_time = Clock::now(); +  cerr << "Reading precomputation in binary format..." << endl; +  shared_ptr<Precomputation> precomputation = make_shared<Precomputation>(); +  ifstream precomputation_fstream(vm["precomputation"].as<string>()); +  ar::binary_iarchive precomputation_stream(precomputation_fstream); +  precomputation_stream >> *precomputation; +  end_time = Clock::now(); +  cerr << "Reading precomputation took " << GetDuration(start_time, end_time) +       << " seconds" << endl; + +  start_time = Clock::now(); +  cerr << "Reading vocabulary in binary format..." << endl; +  shared_ptr<Vocabulary> vocabulary = make_shared<Vocabulary>(); +  ifstream vocabulary_fstream(vm["vocabulary"].as<string>()); +  ar::binary_iarchive vocabulary_stream(vocabulary_fstream); +  vocabulary_stream >> *vocabulary; +  end_time = Clock::now(); +  cerr << "Reading vocabulary took " << GetDuration(start_time, end_time) +       << " seconds" << endl; + +  start_time = Clock::now(); +  cerr << "Reading translation table in binary format..." << endl; +  shared_ptr<TranslationTable> table = make_shared<TranslationTable>(); +  ifstream ttable_fstream(vm["ttable"].as<string>()); +  ar::binary_iarchive ttable_stream(ttable_fstream); +  ttable_stream >> *table; +  end_time = Clock::now(); +  cerr << "Reading translation table took " << GetDuration(start_time, end_time) +       << " seconds" << endl; + +  Clock::time_point read_end_time = Clock::now(); +  cerr << "Total time spent loading data structures into memory: " +       << GetDuration(read_start_time, read_end_time) << " seconds" << endl; + +  Clock::time_point extraction_start_time = Clock::now(); +  // Features used to score each grammar rule. 
+  vector<shared_ptr<Feature>> features = { +      make_shared<TargetGivenSourceCoherent>(), +      make_shared<SampleSourceCount>(), +      make_shared<CountSourceTarget>(), +      make_shared<MaxLexSourceGivenTarget>(table), +      make_shared<MaxLexTargetGivenSource>(table), +      make_shared<IsSourceSingleton>(), +      make_shared<IsSourceTargetSingleton>() +  }; +  shared_ptr<Scorer> scorer = make_shared<Scorer>(features); + +  GrammarExtractor extractor( +      source_suffix_array, +      target_data_array, +      alignment, +      precomputation, +      scorer, +      vocabulary, +      vm["min_gap_size"].as<int>(), +      vm["max_rule_span"].as<int>(), +      vm["max_nonterminals"].as<int>(), +      vm["max_rule_symbols"].as<int>(), +      vm["max_samples"].as<int>(), +      vm["tight_phrases"].as<bool>()); + +  // Creates the grammars directory if it doesn't exist. +  fs::path grammar_path = vm["grammars"].as<string>(); +  if (!fs::is_directory(grammar_path)) { +    fs::create_directory(grammar_path); +  } + +  // Reads all sentences for which we extract grammar rules (the parallelization +  // is simplified if we read all sentences upfront). +  string sentence; +  vector<string> sentences; +  while (getline(cin, sentence)) { +    sentences.push_back(sentence); +  } + +  // Extracts the grammar for each sentence and saves it to a file. +  vector<string> suffixes(sentences.size()); +  bool leave_one_out = vm.count("leave_one_out"); +  #pragma omp parallel for schedule(dynamic) num_threads(num_threads) +  for (size_t i = 0; i < sentences.size(); ++i) { +    string suffix; +    int position = sentences[i].find("|||"); +    if (position != sentences[i].npos) { +      suffix = sentences[i].substr(position); +      sentences[i] = sentences[i].substr(0, position); +    } +    suffixes[i] = suffix; + +    unordered_set<int> blacklisted_sentence_ids; +    if (leave_one_out) { +      blacklisted_sentence_ids.insert(i); +    } +    Grammar grammar = extractor.GetGrammar( +        sentences[i], blacklisted_sentence_ids); +    ofstream output(GetGrammarFilePath(grammar_path, i).c_str()); +    output << grammar; +  } + +  for (size_t i = 0; i < sentences.size(); ++i) { +    cout << "<seg grammar=" << GetGrammarFilePath(grammar_path, i) << " id=\"" +         << i << "\"> " << sentences[i] << " </seg> " << suffixes[i] << endl; +  } + +  Clock::time_point extraction_stop_time = Clock::now(); +  cerr << "Overall extraction step took " +       << GetDuration(extraction_start_time, extraction_stop_time) +       << " seconds" << endl; + +  return 0; +} diff --git a/extractor/fast_intersector.cc b/extractor/fast_intersector.cc index a8591a72..0d1fa6d8 100644 --- a/extractor/fast_intersector.cc +++ b/extractor/fast_intersector.cc @@ -11,41 +11,22 @@  namespace extractor { -FastIntersector::FastIntersector(shared_ptr<SuffixArray> suffix_array, -                                 shared_ptr<Precomputation> precomputation, -                                 shared_ptr<Vocabulary> vocabulary, -                                 int max_rule_span, -                                 int min_gap_size) : +FastIntersector::FastIntersector( +    shared_ptr<SuffixArray> suffix_array, +    shared_ptr<Precomputation> precomputation, +    shared_ptr<Vocabulary> vocabulary, +    int max_rule_span, +    int min_gap_size) :      suffix_array(suffix_array), +    precomputation(precomputation),      vocabulary(vocabulary),      max_rule_span(max_rule_span), -    min_gap_size(min_gap_size) { -  Index precomputed_collocations = 
precomputation->GetCollocations(); -  for (pair<vector<int>, vector<int>> entry: precomputed_collocations) { -    vector<int> phrase = ConvertPhrase(entry.first); -    collocations[phrase] = entry.second; -  } -} +    min_gap_size(min_gap_size) {}  FastIntersector::FastIntersector() {}  FastIntersector::~FastIntersector() {} -vector<int> FastIntersector::ConvertPhrase(const vector<int>& old_phrase) { -  vector<int> new_phrase; -  new_phrase.reserve(old_phrase.size()); -  shared_ptr<DataArray> data_array = suffix_array->GetData(); -  for (int word_id: old_phrase) { -    if (word_id < 0) { -      new_phrase.push_back(word_id); -    } else { -      new_phrase.push_back( -          vocabulary->GetTerminalIndex(data_array->GetWord(word_id))); -    } -  } -  return new_phrase; -} -  PhraseLocation FastIntersector::Intersect(      PhraseLocation& prefix_location,      PhraseLocation& suffix_location, @@ -59,8 +40,9 @@ PhraseLocation FastIntersector::Intersect(    assert(vocabulary->IsTerminal(symbols.front())        && vocabulary->IsTerminal(symbols.back())); -  if (collocations.count(symbols)) { -    return PhraseLocation(collocations[symbols], phrase.Arity() + 1); +  if (precomputation->Contains(symbols)) { +    return PhraseLocation(precomputation->GetCollocations(symbols), +                          phrase.Arity() + 1);    }    bool prefix_ends_with_x = diff --git a/extractor/fast_intersector.h b/extractor/fast_intersector.h index 2819d239..305373dc 100644 --- a/extractor/fast_intersector.h +++ b/extractor/fast_intersector.h @@ -12,7 +12,6 @@ using namespace std;  namespace extractor {  typedef boost::hash<vector<int>> VectorHash; -typedef unordered_map<vector<int>, vector<int>, VectorHash> Index;  class Phrase;  class PhraseLocation; @@ -52,11 +51,6 @@ class FastIntersector {    FastIntersector();   private: -  // Uses the vocabulary to convert the phrase from the numberized format -  // specified by the source data array to the numberized format given by the -  // vocabulary. -  vector<int> ConvertPhrase(const vector<int>& old_phrase); -    // Estimates the number of computations needed if the prefix/suffix is    // extended. 
If the last/first symbol is separated from the rest of the phrase    // by a nonterminal, then for each occurrence of the prefix/suffix we need to @@ -85,10 +79,10 @@ class FastIntersector {    pair<int, int> GetSearchRange(bool has_marginal_x) const;    shared_ptr<SuffixArray> suffix_array; +  shared_ptr<Precomputation> precomputation;    shared_ptr<Vocabulary> vocabulary;    int max_rule_span;    int min_gap_size; -  Index collocations;  };  } // namespace extractor diff --git a/extractor/fast_intersector_test.cc b/extractor/fast_intersector_test.cc index 76c3aaea..f2a26ba1 100644 --- a/extractor/fast_intersector_test.cc +++ b/extractor/fast_intersector_test.cc @@ -59,15 +59,13 @@ class FastIntersectorTest : public Test {      }      precomputation = make_shared<MockPrecomputation>(); -    EXPECT_CALL(*precomputation, GetCollocations()) -        .WillRepeatedly(ReturnRef(collocations)); +    EXPECT_CALL(*precomputation, Contains(_)).WillRepeatedly(Return(false));      phrase_builder = make_shared<PhraseBuilder>(vocabulary);      intersector = make_shared<FastIntersector>(suffix_array, precomputation,                                                 vocabulary, 15, 1);    } -  Index collocations;    shared_ptr<MockDataArray> data_array;    shared_ptr<MockSuffixArray> suffix_array;    shared_ptr<MockPrecomputation> precomputation; @@ -82,9 +80,9 @@ TEST_F(FastIntersectorTest, TestCachedCollocation) {    Phrase phrase = phrase_builder->Build(symbols);    PhraseLocation prefix_location(15, 16), suffix_location(16, 17); -  collocations[symbols] = expected_location; -  EXPECT_CALL(*precomputation, GetCollocations()) -      .WillRepeatedly(ReturnRef(collocations)); +  EXPECT_CALL(*precomputation, Contains(symbols)).WillRepeatedly(Return(true)); +  EXPECT_CALL(*precomputation, GetCollocations(symbols)). 
+      WillRepeatedly(Return(expected_location));    intersector = make_shared<FastIntersector>(suffix_array, precomputation,                                               vocabulary, 15, 1); diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc index 487abcaf..1dc94c25 100644 --- a/extractor/grammar_extractor.cc +++ b/extractor/grammar_extractor.cc @@ -19,10 +19,11 @@ GrammarExtractor::GrammarExtractor(      shared_ptr<SuffixArray> source_suffix_array,      shared_ptr<DataArray> target_data_array,      shared_ptr<Alignment> alignment, shared_ptr<Precomputation> precomputation, -    shared_ptr<Scorer> scorer, int min_gap_size, int max_rule_span, +    shared_ptr<Scorer> scorer, shared_ptr<Vocabulary> vocabulary, +    int min_gap_size, int max_rule_span,      int max_nonterminals, int max_rule_symbols, int max_samples,      bool require_tight_phrases) : -    vocabulary(make_shared<Vocabulary>()), +    vocabulary(vocabulary),      rule_factory(make_shared<HieroCachingRuleFactory>(          source_suffix_array, target_data_array, alignment, vocabulary,          precomputation, scorer, min_gap_size, max_rule_span, max_nonterminals, @@ -34,10 +35,12 @@ GrammarExtractor::GrammarExtractor(      vocabulary(vocabulary),      rule_factory(rule_factory) {} -Grammar GrammarExtractor::GetGrammar(const string& sentence, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) { +Grammar GrammarExtractor::GetGrammar( +    const string& sentence, +    const unordered_set<int>& blacklisted_sentence_ids) {    vector<string> words = TokenizeSentence(sentence);    vector<int> word_ids = AnnotateWords(words); -  return rule_factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array); +  return rule_factory->GetGrammar(word_ids, blacklisted_sentence_ids);  }  vector<string> GrammarExtractor::TokenizeSentence(const string& sentence) { diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h index ae407b47..0f3069b0 100644 --- a/extractor/grammar_extractor.h +++ b/extractor/grammar_extractor.h @@ -15,7 +15,6 @@ class DataArray;  class Grammar;  class HieroCachingRuleFactory;  class Precomputation; -class Rule;  class Scorer;  class SuffixArray;  class Vocabulary; @@ -32,6 +31,7 @@ class GrammarExtractor {        shared_ptr<Alignment> alignment,        shared_ptr<Precomputation> precomputation,        shared_ptr<Scorer> scorer, +      shared_ptr<Vocabulary> vocabulary,        int min_gap_size,        int max_rule_span,        int max_nonterminals, @@ -45,7 +45,9 @@ class GrammarExtractor {    // Converts the sentence to a vector of word ids and uses the RuleFactory to    // extract the SCFG rules which may be used to decode the sentence. -  Grammar GetGrammar(const string& sentence, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array); +  Grammar GetGrammar( +      const string& sentence, +      const unordered_set<int>& blacklisted_sentence_ids);   private:    // Splits the sentence in a vector of words. 
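Together with the extract.cc driver above, the signature changes give the following end-to-end usage of GrammarExtractor. This is a hypothetical fragment: the helper name is invented, the shared_ptr arguments are assumed to have been deserialized as in extract.cc, and the numeric settings mirror its default flags.

    // Hypothetical helper; assumes the headers included by extract.cc above.
    Grammar ExtractWithLeaveOneOut(
        shared_ptr<SuffixArray> source_suffix_array,
        shared_ptr<DataArray> target_data_array,
        shared_ptr<Alignment> alignment,
        shared_ptr<Precomputation> precomputation,
        shared_ptr<Scorer> scorer,
        shared_ptr<Vocabulary> vocabulary,
        const string& sentence, int sentence_id) {
      // The vocabulary is now injected rather than created internally, so the
      // extractor and the precomputation share one word id space.
      GrammarExtractor extractor(
          source_suffix_array, target_data_array, alignment, precomputation,
          scorer, vocabulary, /*min_gap_size=*/1, /*max_rule_span=*/15,
          /*max_nonterminals=*/2, /*max_rule_symbols=*/5, /*max_samples=*/300,
          /*require_tight_phrases=*/true);
      // Blacklisting the sentence's own id keeps its occurrences out of the
      // samples used to score its rules.
      unordered_set<int> blacklisted_sentence_ids = {sentence_id};
      return extractor.GetGrammar(sentence, blacklisted_sentence_ids);
    }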
diff --git a/extractor/grammar_extractor_test.cc b/extractor/grammar_extractor_test.cc index f32a9599..719e90ff 100644 --- a/extractor/grammar_extractor_test.cc +++ b/extractor/grammar_extractor_test.cc @@ -41,13 +41,13 @@ TEST(GrammarExtractorTest, TestAnnotatingWords) {    Grammar grammar(rules, feature_names);    unordered_set<int> blacklisted_sentence_ids;    shared_ptr<DataArray> source_data_array; -  EXPECT_CALL(*factory, GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array)) +  EXPECT_CALL(*factory, GetGrammar(word_ids, blacklisted_sentence_ids))        .WillOnce(Return(grammar));    GrammarExtractor extractor(vocabulary, factory);    string sentence = "Anna has many many apples ."; -  extractor.GetGrammar(sentence, blacklisted_sentence_ids, source_data_array); +  extractor.GetGrammar(sentence, blacklisted_sentence_ids);  }  } // namespace diff --git a/extractor/matchings_sampler.cc b/extractor/matchings_sampler.cc new file mode 100644 index 00000000..75a62366 --- /dev/null +++ b/extractor/matchings_sampler.cc @@ -0,0 +1,39 @@ +#include "matchings_sampler.h" + +#include "data_array.h" +#include "phrase_location.h" + +namespace extractor { + +MatchingsSampler::MatchingsSampler( +    shared_ptr<DataArray> data_array, int max_samples) : +    BackoffSampler(data_array, max_samples) {} + +MatchingsSampler::MatchingsSampler() {} + +int MatchingsSampler::GetNumSubpatterns(const PhraseLocation& location) const { +  return location.num_subpatterns; +} + +int MatchingsSampler::GetRangeLow(const PhraseLocation&) const { +  return 0; +} + +int MatchingsSampler::GetRangeHigh(const PhraseLocation& location) const { +  return location.matchings->size() / location.num_subpatterns; +} + +int MatchingsSampler::GetPosition(const PhraseLocation& location, +                                  int index) const { +  return (*location.matchings)[index * location.num_subpatterns]; +} + +void MatchingsSampler::AppendMatching(vector<int>& samples, int index, +                                      const PhraseLocation& location) const { +  int start = index * location.num_subpatterns; +  copy(location.matchings->begin() + start, +       location.matchings->begin() + start + location.num_subpatterns, +       back_inserter(samples)); +} + +} // namespace extractor diff --git a/extractor/matchings_sampler.h b/extractor/matchings_sampler.h new file mode 100644 index 00000000..ca4fce93 --- /dev/null +++ b/extractor/matchings_sampler.h @@ -0,0 +1,31 @@ +#ifndef _MATCHINGS_SAMPLER_H_ +#define _MATCHINGS_SAMPLER_H_ + +#include "backoff_sampler.h" + +namespace extractor { + +class DataArray; + +class MatchingsSampler : public BackoffSampler { + public: +  MatchingsSampler(shared_ptr<DataArray> data_array, int max_samples); + +  MatchingsSampler(); + + private: +   int GetNumSubpatterns(const PhraseLocation& location) const; + +   int GetRangeLow(const PhraseLocation& location) const; + +   int GetRangeHigh(const PhraseLocation& location) const; + +   int GetPosition(const PhraseLocation& location, int index) const; + +   void AppendMatching(vector<int>& samples, int index, +                       const PhraseLocation& location) const; +}; + +} // namespace extractor + +#endif diff --git a/extractor/matchings_sampler_test.cc b/extractor/matchings_sampler_test.cc new file mode 100644 index 00000000..bc927152 --- /dev/null +++ b/extractor/matchings_sampler_test.cc @@ -0,0 +1,118 @@ +#include <gtest/gtest.h> + +#include <memory> + +#include "mocks/mock_data_array.h" +#include "matchings_sampler.h" +#include 
"phrase_location.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace { + +class MatchingsSamplerTest : public Test { + protected: +  virtual void SetUp() { +    vector<int> locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; +    location = PhraseLocation(locations, 2); + +    data_array = make_shared<MockDataArray>(); +    for (int i = 0; i < 10; ++i) { +      EXPECT_CALL(*data_array, GetSentenceId(i)).WillRepeatedly(Return(i / 2)); +    } +  } + +  unordered_set<int> blacklisted_sentence_ids; +  PhraseLocation location; +  shared_ptr<MockDataArray> data_array; +  shared_ptr<MatchingsSampler> sampler; +}; + +TEST_F(MatchingsSamplerTest, TestSample) { +  sampler = make_shared<MatchingsSampler>(data_array, 1); +  vector<int> expected_locations = {0, 1}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklisted_sentence_ids)); + +  sampler = make_shared<MatchingsSampler>(data_array, 2); +  expected_locations = {0, 1, 6, 7}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklisted_sentence_ids)); + +  sampler = make_shared<MatchingsSampler>(data_array, 3); +  expected_locations = {0, 1, 4, 5, 6, 7}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklisted_sentence_ids)); + +  sampler = make_shared<MatchingsSampler>(data_array, 7); +  expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklisted_sentence_ids)); +} + +TEST_F(MatchingsSamplerTest, TestBackoffSample) { +  sampler = make_shared<MatchingsSampler>(data_array, 1); +  blacklisted_sentence_ids = {0}; +  vector<int> expected_locations = {2, 3}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklisted_sentence_ids)); + +  blacklisted_sentence_ids = {0, 1, 2, 3}; +  expected_locations = {8, 9}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklisted_sentence_ids)); + +  blacklisted_sentence_ids = {0, 1, 2, 3, 4}; +  expected_locations = {}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklisted_sentence_ids)); + +  sampler = make_shared<MatchingsSampler>(data_array, 2); +  blacklisted_sentence_ids = {0, 3}; +  expected_locations = {2, 3, 4, 5}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklisted_sentence_ids)); + +  sampler = make_shared<MatchingsSampler>(data_array, 3); +  blacklisted_sentence_ids = {0, 3}; +  expected_locations = {2, 3, 4, 5, 8, 9}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklisted_sentence_ids)); + +  blacklisted_sentence_ids = {0, 2, 3}; +  expected_locations = {2, 3, 8, 9}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklisted_sentence_ids)); + +  sampler = make_shared<MatchingsSampler>(data_array, 4); +  blacklisted_sentence_ids = {0, 1, 2, 3}; +  expected_locations = {8, 9}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklisted_sentence_ids)); + +  blacklisted_sentence_ids = {1, 3}; +  expected_locations = {0, 1, 4, 5, 8, 9}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklisted_sentence_ids)); + +  sampler = make_shared<MatchingsSampler>(data_array, 7); +  
blacklisted_sentence_ids = {0, 1, 2, 3, 4}; +  expected_locations = {}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklisted_sentence_ids)); + +  blacklisted_sentence_ids = {0, 2, 4}; +  expected_locations = {2, 3, 6, 7}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklisted_sentence_ids)); + +  blacklisted_sentence_ids = {1, 3}; +  expected_locations = {0, 1, 4, 5, 8, 9}; +  EXPECT_EQ(PhraseLocation(expected_locations, 2), +            sampler->Sample(location, blacklisted_sentence_ids)); +} + +} +} // namespace extractor diff --git a/extractor/mocks/mock_data_array.h b/extractor/mocks/mock_data_array.h index 6f85abb4..98e711d2 100644 --- a/extractor/mocks/mock_data_array.h +++ b/extractor/mocks/mock_data_array.h @@ -6,12 +6,13 @@ namespace extractor {  class MockDataArray : public DataArray {   public: -  MOCK_CONST_METHOD0(GetData, const vector<int>&()); +  MOCK_CONST_METHOD0(GetData, vector<int>());    MOCK_CONST_METHOD1(AtIndex, int(int index));    MOCK_CONST_METHOD1(GetWordAtIndex, string(int index)); +  MOCK_CONST_METHOD2(GetWordIds, vector<int>(int start_index, int size)); +  MOCK_CONST_METHOD2(GetWords, vector<string>(int start_index, int size));    MOCK_CONST_METHOD0(GetSize, int());    MOCK_CONST_METHOD0(GetVocabularySize, int()); -  MOCK_CONST_METHOD1(HasWord, bool(const string& word));    MOCK_CONST_METHOD1(GetWordId, int(const string& word));    MOCK_CONST_METHOD1(GetWord, string(int word_id));    MOCK_CONST_METHOD1(GetSentenceLength, int(int sentence_id)); diff --git a/extractor/mocks/mock_matchings_sampler.h b/extractor/mocks/mock_matchings_sampler.h new file mode 100644 index 00000000..de2009c3 --- /dev/null +++ b/extractor/mocks/mock_matchings_sampler.h @@ -0,0 +1,15 @@ +#include <gmock/gmock.h> + +#include "phrase_location.h" +#include "matchings_sampler.h" + +namespace extractor { + +class MockMatchingsSampler : public MatchingsSampler { + public: +  MOCK_CONST_METHOD2(Sample, PhraseLocation( +      const PhraseLocation& location, +      const unordered_set<int>& blacklisted_sentence_ids)); +}; + +} // namespace extractor diff --git a/extractor/mocks/mock_precomputation.h b/extractor/mocks/mock_precomputation.h index 8753343e..5f7aa999 100644 --- a/extractor/mocks/mock_precomputation.h +++ b/extractor/mocks/mock_precomputation.h @@ -6,7 +6,8 @@ namespace extractor {  class MockPrecomputation : public Precomputation {   public: -  MOCK_CONST_METHOD0(GetCollocations, const Index&()); +  MOCK_CONST_METHOD1(Contains, bool(const vector<int>& pattern)); +  MOCK_CONST_METHOD1(GetCollocations, vector<int>(const vector<int>& pattern));  };  } // namespace extractor diff --git a/extractor/mocks/mock_rule_factory.h b/extractor/mocks/mock_rule_factory.h index 86a084b5..53eb5022 100644 --- a/extractor/mocks/mock_rule_factory.h +++ b/extractor/mocks/mock_rule_factory.h @@ -7,7 +7,9 @@ namespace extractor {  class MockHieroCachingRuleFactory : public HieroCachingRuleFactory {   public: -  MOCK_METHOD3(GetGrammar, Grammar(const vector<int>& word_ids, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array)); +  MOCK_METHOD2(GetGrammar, Grammar( +      const vector<int>& word_ids, +      const unordered_set<int>& blacklisted_sentence_ids));  };  } // namespace extractor diff --git a/extractor/mocks/mock_sampler.h b/extractor/mocks/mock_sampler.h index 75c43c27..b2742f62 100644 --- a/extractor/mocks/mock_sampler.h +++ b/extractor/mocks/mock_sampler.h 
@@ -7,7 +7,9 @@ namespace extractor {  class MockSampler : public Sampler {   public: -  MOCK_CONST_METHOD1(Sample, PhraseLocation(const PhraseLocation& location)); +  MOCK_CONST_METHOD2(Sample, PhraseLocation( +      const PhraseLocation& location, +      const unordered_set<int>& blacklisted_sentence_ids));  };  } // namespace extractor diff --git a/extractor/mocks/mock_suffix_array_sampler.h b/extractor/mocks/mock_suffix_array_sampler.h new file mode 100644 index 00000000..d799b969 --- /dev/null +++ b/extractor/mocks/mock_suffix_array_sampler.h @@ -0,0 +1,15 @@ +#include <gmock/gmock.h> + +#include "phrase_location.h" +#include "suffix_array_sampler.h" + +namespace extractor { + +class MockSuffixArraySampler : public SuffixArrayRangeSampler { + public: +  MOCK_CONST_METHOD2(Sample, PhraseLocation( +      const PhraseLocation& location, +      const unordered_set<int>& blacklisted_sentence_ids)); +}; + +} // namespace extractor diff --git a/extractor/phrase_location.cc b/extractor/phrase_location.cc index 13140cac..2c367893 100644 --- a/extractor/phrase_location.cc +++ b/extractor/phrase_location.cc @@ -1,5 +1,7 @@  #include "phrase_location.h" +#include <iostream> +  namespace extractor {  PhraseLocation::PhraseLocation(int sa_low, int sa_high) : diff --git a/extractor/phrase_location_sampler.cc b/extractor/phrase_location_sampler.cc new file mode 100644 index 00000000..a2eec105 --- /dev/null +++ b/extractor/phrase_location_sampler.cc @@ -0,0 +1,34 @@ +#include "phrase_location_sampler.h" + +#include "matchings_sampler.h" +#include "phrase_location.h" +#include "suffix_array.h" +#include "suffix_array_sampler.h" + +namespace extractor { + +PhraseLocationSampler::PhraseLocationSampler( +    shared_ptr<SuffixArray> suffix_array, int max_samples) { +  matchings_sampler = make_shared<MatchingsSampler>( +      suffix_array->GetData(), max_samples); +  suffix_array_sampler = make_shared<SuffixArrayRangeSampler>( +      suffix_array, max_samples); +} + +PhraseLocationSampler::PhraseLocationSampler( +    shared_ptr<MatchingsSampler> matchings_sampler, +    shared_ptr<SuffixArrayRangeSampler> suffix_array_sampler) : +    matchings_sampler(matchings_sampler), +    suffix_array_sampler(suffix_array_sampler) {} + +PhraseLocation PhraseLocationSampler::Sample( +    const PhraseLocation& location, +    const unordered_set<int>& blacklisted_sentence_ids) const { +  if (location.matchings == NULL) { +    return suffix_array_sampler->Sample(location, blacklisted_sentence_ids); +  } else { +    return matchings_sampler->Sample(location, blacklisted_sentence_ids); +  } +} + +} // namespace extractor diff --git a/extractor/phrase_location_sampler.h b/extractor/phrase_location_sampler.h new file mode 100644 index 00000000..0e88335e --- /dev/null +++ b/extractor/phrase_location_sampler.h @@ -0,0 +1,35 @@ +#ifndef _PHRASE_LOCATION_SAMPLER_H_ +#define _PHRASE_LOCATION_SAMPLER_H_ + +#include <memory> + +#include "sampler.h" + +namespace extractor { + +class MatchingsSampler; +class PhraseLocation; +class SuffixArray; +class SuffixArrayRangeSampler; + +class PhraseLocationSampler : public Sampler { + public: +  PhraseLocationSampler(shared_ptr<SuffixArray> suffix_array, int max_samples); + +  // For testing only. 
+  PhraseLocationSampler( +      shared_ptr<MatchingsSampler> matchings_sampler, +      shared_ptr<SuffixArrayRangeSampler> suffix_array_sampler); + +  PhraseLocation Sample( +      const PhraseLocation& location, +      const unordered_set<int>& blacklisted_sentence_ids) const; + + private: +  shared_ptr<MatchingsSampler> matchings_sampler; +  shared_ptr<SuffixArrayRangeSampler> suffix_array_sampler; +}; + +} // namespace extractor + +#endif diff --git a/extractor/phrase_location_sampler_test.cc b/extractor/phrase_location_sampler_test.cc new file mode 100644 index 00000000..e7520ce7 --- /dev/null +++ b/extractor/phrase_location_sampler_test.cc @@ -0,0 +1,50 @@ +#include <gtest/gtest.h> + +#include <memory> + +#include "mocks/mock_matchings_sampler.h" +#include "mocks/mock_suffix_array_sampler.h" +#include "phrase_location.h" +#include "phrase_location_sampler.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace { + +class MatchingsSamplerTest : public Test { + protected: +  virtual void SetUp() { +    matchings_sampler = make_shared<MockMatchingsSampler>(); +    suffix_array_sampler = make_shared<MockSuffixArraySampler>(); + +    sampler = make_shared<PhraseLocationSampler>( +        matchings_sampler, suffix_array_sampler); +  } + +  shared_ptr<MockMatchingsSampler> matchings_sampler; +  shared_ptr<MockSuffixArraySampler> suffix_array_sampler; +  shared_ptr<PhraseLocationSampler> sampler; +}; + +TEST_F(MatchingsSamplerTest, TestSuffixArrayRange) { +  vector<int> locations = {0, 1, 2, 3}; +  PhraseLocation location(0, 3), result(locations, 2); +  unordered_set<int> blacklisted_sentence_ids; +  EXPECT_CALL(*suffix_array_sampler, Sample(location, blacklisted_sentence_ids)) +      .WillOnce(Return(result)); +  EXPECT_EQ(result, sampler->Sample(location, blacklisted_sentence_ids)); +} + +TEST_F(MatchingsSamplerTest, TestMatchings) { +  vector<int> locations = {0, 1, 2, 3}; +  PhraseLocation location(locations, 2), result(locations, 2); +  unordered_set<int> blacklisted_sentence_ids; +  EXPECT_CALL(*matchings_sampler, Sample(location, blacklisted_sentence_ids)) +      .WillOnce(Return(result)); +  EXPECT_EQ(result, sampler->Sample(location, blacklisted_sentence_ids)); +} + +} +} // namespace extractor diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc index 3b8aed69..3e58e2a9 100644 --- a/extractor/precomputation.cc +++ b/extractor/precomputation.cc @@ -5,59 +5,67 @@  #include "data_array.h"  #include "suffix_array.h" +#include "time_util.h" +#include "vocabulary.h"  using namespace std;  namespace extractor { -int Precomputation::FIRST_NONTERMINAL = -1; -int Precomputation::SECOND_NONTERMINAL = -2; -  Precomputation::Precomputation( -    shared_ptr<SuffixArray> suffix_array, int num_frequent_patterns, -    int num_super_frequent_patterns, int max_rule_span, -    int max_rule_symbols, int min_gap_size, +    shared_ptr<Vocabulary> vocabulary, shared_ptr<SuffixArray> suffix_array, +    int num_frequent_patterns, int num_super_frequent_patterns, +    int max_rule_span, int max_rule_symbols, int min_gap_size,      int max_frequent_phrase_len, int min_frequency) { -  vector<int> data = suffix_array->GetData()->GetData(); +  Clock::time_point start_time = Clock::now(); +  shared_ptr<DataArray> data_array = suffix_array->GetData(); +  vector<int> data = data_array->GetData();    vector<vector<int>> frequent_patterns = FindMostFrequentPatterns(        suffix_array, data, num_frequent_patterns, max_frequent_phrase_len,        min_frequency); +  
Clock::time_point end_time = Clock::now(); +  cerr << "Finding most frequent patterns took " +       << GetDuration(start_time, end_time) << " seconds..." << endl; -  // Construct sets containing the frequent and superfrequent contiguous -  // collocations. -  unordered_set<vector<int>, VectorHash> frequent_patterns_set; -  unordered_set<vector<int>, VectorHash> super_frequent_patterns_set; +  vector<vector<int>> pattern_annotations(frequent_patterns.size()); +  unordered_map<vector<int>, int, VectorHash> frequent_patterns_index;    for (size_t i = 0; i < frequent_patterns.size(); ++i) { -    frequent_patterns_set.insert(frequent_patterns[i]); -    if (i < num_super_frequent_patterns) { -      super_frequent_patterns_set.insert(frequent_patterns[i]); -    } +    frequent_patterns_index[frequent_patterns[i]] = i; +    pattern_annotations[i] = AnnotatePattern(vocabulary, data_array, +                                             frequent_patterns[i]);    } +  start_time = Clock::now();    vector<tuple<int, int, int>> matchings; +  vector<vector<int>> annotations;    for (size_t i = 0; i < data.size(); ++i) {      // If the sentence is over, add all the discontiguous frequent patterns to      // the index.      if (data[i] == DataArray::END_OF_LINE) { -      AddCollocations(matchings, data, max_rule_span, min_gap_size, -          max_rule_symbols); +      UpdateIndex(matchings, annotations, max_rule_span, min_gap_size, +                  max_rule_symbols);        matchings.clear(); +      annotations.clear();        continue;      } -    vector<int> pattern;      // Find all the contiguous frequent patterns starting at position i. +    vector<int> pattern;      for (int j = 1; j <= max_frequent_phrase_len && i + j <= data.size(); ++j) {        pattern.push_back(data[i + j - 1]); -      if (frequent_patterns_set.count(pattern)) { -        int is_super_frequent = super_frequent_patterns_set.count(pattern); -        matchings.push_back(make_tuple(i, j, is_super_frequent)); -      } else { +      auto it = frequent_patterns_index.find(pattern); +      if (it == frequent_patterns_index.end()) {          // If the current pattern is not frequent, any longer pattern having the          // current pattern as prefix will not be frequent.          break;        } +      int is_super_frequent = it->second < num_super_frequent_patterns; +      matchings.push_back(make_tuple(i, j, is_super_frequent)); +      annotations.push_back(pattern_annotations[it->second]);      }    } +  end_time = Clock::now(); +  cerr << "Constructing collocations index took " +       << GetDuration(start_time, end_time) << " seconds..." 
<< endl;  }  Precomputation::Precomputation() {} @@ -75,9 +83,9 @@ vector<vector<int>> Precomputation::FindMostFrequentPatterns(    for (size_t i = 1; i < lcp.size(); ++i) {      for (int len = lcp[i]; len < max_frequent_phrase_len; ++len) {        int frequency = i - run_start[len]; -      if (frequency >= min_frequency) { -        heap.push(make_pair(frequency, -            make_pair(suffix_array->GetSuffix(run_start[len]), len + 1))); +      int start = suffix_array->GetSuffix(run_start[len]); +      if (frequency >= min_frequency && start + len <= data.size()) { +        heap.push(make_pair(frequency, make_pair(start, len + 1)));        }        run_start[len] = i;      } @@ -99,8 +107,20 @@ vector<vector<int>> Precomputation::FindMostFrequentPatterns(    return frequent_patterns;  } -void Precomputation::AddCollocations( -    const vector<tuple<int, int, int>>& matchings, const vector<int>& data, +vector<int> Precomputation::AnnotatePattern( +    shared_ptr<Vocabulary> vocabulary, shared_ptr<DataArray> data_array, +    const vector<int>& pattern) const { +  vector<int> annotation; +  for (int word_id: pattern) { +    annotation.push_back(vocabulary->GetTerminalIndex( +        data_array->GetWord(word_id))); +  } +  return annotation; +} + +void Precomputation::UpdateIndex( +    const vector<tuple<int, int, int>>& matchings, +    const vector<vector<int>>& annotations,      int max_rule_span, int min_gap_size, int max_rule_symbols) {    // Select the leftmost subpattern.    for (size_t i = 0; i < matchings.size(); ++i) { @@ -118,16 +138,14 @@ void Precomputation::AddCollocations(        if (start2 - start1 - size1 >= min_gap_size            && start2 + size2 - start1 <= max_rule_span            && size1 + size2 + 1 <= max_rule_symbols) { -        vector<int> pattern(data.begin() + start1, -            data.begin() + start1 + size1); -        pattern.push_back(Precomputation::FIRST_NONTERMINAL); -        pattern.insert(pattern.end(), data.begin() + start2, -            data.begin() + start2 + size2); -        AddStartPositions(collocations[pattern], start1, start2); +        vector<int> pattern = annotations[i]; +        pattern.push_back(-1); +        AppendSubpattern(pattern, annotations[j]); +        AppendCollocation(index[pattern], start1, start2);          // Try extending the binary collocation to a ternary collocation.          if (is_super2) { -          pattern.push_back(Precomputation::SECOND_NONTERMINAL); +          pattern.push_back(-2);            // Select the rightmost subpattern.            
for (size_t k = j + 1; k < matchings.size(); ++k) {              int start3, size3, is_super3; @@ -140,9 +158,8 @@ void Precomputation::AddCollocations(                  && start3 + size3 - start1 <= max_rule_span                  && size1 + size2 + size3 + 2 <= max_rule_symbols                  && (is_super1 || is_super3)) { -              pattern.insert(pattern.end(), data.begin() + start3, -                  data.begin() + start3 + size3); -              AddStartPositions(collocations[pattern], start1, start2, start3); +              AppendSubpattern(pattern, annotations[k]); +              AppendCollocation(index[pattern], start1, start2, start3);                pattern.erase(pattern.end() - size3);              }            } @@ -152,25 +169,35 @@ void Precomputation::AddCollocations(    }  } -void Precomputation::AddStartPositions( -    vector<int>& positions, int pos1, int pos2) { -  positions.push_back(pos1); -  positions.push_back(pos2); +void Precomputation::AppendSubpattern( +    vector<int>& pattern, +    const vector<int>& subpattern) { +  copy(subpattern.begin(), subpattern.end(), back_inserter(pattern)); +} + +void Precomputation::AppendCollocation( +    vector<int>& collocations, int pos1, int pos2) { +  collocations.push_back(pos1); +  collocations.push_back(pos2); +} + +void Precomputation::AppendCollocation( +    vector<int>& collocations, int pos1, int pos2, int pos3) { +  collocations.push_back(pos1); +  collocations.push_back(pos2); +  collocations.push_back(pos3);  } -void Precomputation::AddStartPositions( -    vector<int>& positions, int pos1, int pos2, int pos3) { -  positions.push_back(pos1); -  positions.push_back(pos2); -  positions.push_back(pos3); +bool Precomputation::Contains(const vector<int>& pattern) const { +  return index.count(pattern);  } -const Index& Precomputation::GetCollocations() const { -  return collocations; +vector<int> Precomputation::GetCollocations(const vector<int>& pattern) const { +  return index.at(pattern);  }  bool Precomputation::operator==(const Precomputation& other) const { -  return collocations == other.collocations; +  return index == other.index;  }  } // namespace extractor diff --git a/extractor/precomputation.h b/extractor/precomputation.h index 9f0c9424..2b34fc29 100644 --- a/extractor/precomputation.h +++ b/extractor/precomputation.h @@ -7,13 +7,11 @@  #include <tuple>  #include <vector> -#include <boost/filesystem.hpp>  #include <boost/functional/hash.hpp>  #include <boost/serialization/serialization.hpp>  #include <boost/serialization/utility.hpp>  #include <boost/serialization/vector.hpp> -namespace fs = boost::filesystem;  using namespace std;  namespace extractor { @@ -21,7 +19,9 @@ namespace extractor {  typedef boost::hash<vector<int>> VectorHash;  typedef unordered_map<vector<int>, vector<int>, VectorHash> Index; +class DataArray;  class SuffixArray; +class Vocabulary;  /**   * Data structure wrapping an index with all the occurrences of the most @@ -37,9 +37,9 @@ class Precomputation {   public:    // Constructs the index using the suffix array.    
Precomputation( -      shared_ptr<SuffixArray> suffix_array, int num_frequent_patterns, -      int num_super_frequent_patterns, int max_rule_span, -      int max_rule_symbols, int min_gap_size, +      shared_ptr<Vocabulary> vocabulary, shared_ptr<SuffixArray> suffix_array, +      int num_frequent_patterns, int num_super_frequent_patterns, +      int max_rule_span, int max_rule_symbols, int min_gap_size,        int max_frequent_phrase_len, int min_frequency);    // Creates empty precomputation data structure. @@ -47,13 +47,13 @@ class Precomputation {    virtual ~Precomputation(); -  // Returns a reference to the index. -  virtual const Index& GetCollocations() const; +  // Returns whether a pattern is contained in the index of collocations. +  virtual bool Contains(const vector<int>& pattern) const; -  bool operator==(const Precomputation& other) const; +  // Returns the list of collocations for a given pattern. +  virtual vector<int> GetCollocations(const vector<int>& pattern) const; -  static int FIRST_NONTERMINAL; -  static int SECOND_NONTERMINAL; +  bool operator==(const Precomputation& other) const;   private:    // Finds the most frequent contiguous collocations. @@ -62,25 +62,32 @@ class Precomputation {        int num_frequent_patterns, int max_frequent_phrase_len,        int min_frequency); +  vector<int> AnnotatePattern(shared_ptr<Vocabulary> vocabulary, +                              shared_ptr<DataArray> data_array, +                              const vector<int>& pattern) const; +    // Given the locations of the frequent contiguous collocations in a sentence,    // it adds new entries to the index for each discontiguous collocation    // matching the criteria specified in the class description. -  void AddCollocations( -      const vector<std::tuple<int, int, int>>& matchings, const vector<int>& data, +  void UpdateIndex( +      const vector<tuple<int, int, int>>& matchings, +      const vector<vector<int>>& annotations,        int max_rule_span, int min_gap_size, int max_rule_symbols); +  void AppendSubpattern(vector<int>& pattern, const vector<int>& subpattern); +    // Adds an occurrence of a binary collocation. -  void AddStartPositions(vector<int>& positions, int pos1, int pos2); +  void AppendCollocation(vector<int>& collocations, int pos1, int pos2);    // Adds an occurrence of a ternary collocation. 
-  void AddStartPositions(vector<int>& positions, int pos1, int pos2, int pos3); +  void AppendCollocation(vector<int>& collocations, int pos1, int pos2, int pos3);    friend class boost::serialization::access;    template<class Archive> void save(Archive& ar, unsigned int) const { -    int num_entries = collocations.size(); +    int num_entries = index.size();      ar << num_entries; -    for (pair<vector<int>, vector<int>> entry: collocations) { +    for (pair<vector<int>, vector<int>> entry: index) {        ar << entry;      }    } @@ -91,13 +98,13 @@ class Precomputation {      for (size_t i = 0; i < num_entries; ++i) {        pair<vector<int>, vector<int>> entry;        ar >> entry; -      collocations.insert(entry); +      index.insert(entry);      }    }    BOOST_SERIALIZATION_SPLIT_MEMBER(); -  Index collocations; +  Index index;  };  } // namespace extractor diff --git a/extractor/precomputation_test.cc b/extractor/precomputation_test.cc index e81ece5d..3a98ce05 100644 --- a/extractor/precomputation_test.cc +++ b/extractor/precomputation_test.cc @@ -9,6 +9,7 @@  #include "mocks/mock_data_array.h"  #include "mocks/mock_suffix_array.h" +#include "mocks/mock_vocabulary.h"  #include "precomputation.h"  using namespace std; @@ -23,7 +24,12 @@ class PrecomputationTest : public Test {    virtual void SetUp() {      data = {4, 2, 3, 5, 7, 2, 3, 5, 2, 3, 4, 2, 1};      data_array = make_shared<MockDataArray>(); -    EXPECT_CALL(*data_array, GetData()).WillRepeatedly(ReturnRef(data)); +    EXPECT_CALL(*data_array, GetData()).WillRepeatedly(Return(data)); +    for (size_t i = 0; i < data.size(); ++i) { +      EXPECT_CALL(*data_array, AtIndex(i)).WillRepeatedly(Return(data[i])); +    } +    EXPECT_CALL(*data_array, GetWord(2)).WillRepeatedly(Return("2")); +    EXPECT_CALL(*data_array, GetWord(3)).WillRepeatedly(Return("3"));      vector<int> suffixes{12, 8, 5, 1, 9, 6, 2, 0, 10, 7, 3, 4, 13};      vector<int> lcp{-1, 0, 2, 3, 1, 0, 1, 2, 0, 2, 0, 1, 0, 0}; @@ -35,77 +41,98 @@ class PrecomputationTest : public Test {      }      EXPECT_CALL(*suffix_array, BuildLCPArray()).WillRepeatedly(Return(lcp)); -    precomputation = Precomputation(suffix_array, 3, 3, 10, 5, 1, 4, 2); +    vocabulary = make_shared<MockVocabulary>(); +    EXPECT_CALL(*vocabulary, GetTerminalIndex("2")).WillRepeatedly(Return(2)); +    EXPECT_CALL(*vocabulary, GetTerminalIndex("3")).WillRepeatedly(Return(3)); + +    precomputation = Precomputation(vocabulary, suffix_array, +                                    3, 3, 10, 5, 1, 4, 2);    }    vector<int> data;    shared_ptr<MockDataArray> data_array;    shared_ptr<MockSuffixArray> suffix_array; +  shared_ptr<MockVocabulary> vocabulary;    Precomputation precomputation;  };  TEST_F(PrecomputationTest, TestCollocations) { -  Index collocations = precomputation.GetCollocations(); -    vector<int> key = {2, 3, -1, 2};    vector<int> expected_value = {1, 5, 1, 8, 5, 8, 5, 11, 8, 11}; -  EXPECT_EQ(expected_value, collocations[key]); +  EXPECT_TRUE(precomputation.Contains(key)); +  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));    key = {2, 3, -1, 2, 3};    expected_value = {1, 5, 1, 8, 5, 8}; -  EXPECT_EQ(expected_value, collocations[key]); +  EXPECT_TRUE(precomputation.Contains(key)); +  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));    key = {2, 3, -1, 3};    expected_value = {1, 6, 1, 9, 5, 9}; -  EXPECT_EQ(expected_value, collocations[key]); +  EXPECT_TRUE(precomputation.Contains(key)); +  EXPECT_EQ(expected_value, 
precomputation.GetCollocations(key));    key = {3, -1, 2};    expected_value = {2, 5, 2, 8, 2, 11, 6, 8, 6, 11, 9, 11}; -  EXPECT_EQ(expected_value, collocations[key]); +  EXPECT_TRUE(precomputation.Contains(key)); +  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));    key = {3, -1, 3};    expected_value = {2, 6, 2, 9, 6, 9}; -  EXPECT_EQ(expected_value, collocations[key]); +  EXPECT_TRUE(precomputation.Contains(key)); +  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));    key = {3, -1, 2, 3};    expected_value = {2, 5, 2, 8, 6, 8}; -  EXPECT_EQ(expected_value, collocations[key]); +  EXPECT_TRUE(precomputation.Contains(key)); +  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));    key = {2, -1, 2};    expected_value = {1, 5, 1, 8, 5, 8, 5, 11, 8, 11}; -  EXPECT_EQ(expected_value, collocations[key]); +  EXPECT_TRUE(precomputation.Contains(key)); +  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));    key = {2, -1, 2, 3};    expected_value = {1, 5, 1, 8, 5, 8}; -  EXPECT_EQ(expected_value, collocations[key]); +  EXPECT_TRUE(precomputation.Contains(key)); +  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));    key = {2, -1, 3};    expected_value = {1, 6, 1, 9, 5, 9}; -  EXPECT_EQ(expected_value, collocations[key]); +  EXPECT_TRUE(precomputation.Contains(key)); +  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));    key = {2, -1, 2, -2, 2};    expected_value = {1, 5, 8, 5, 8, 11}; -  EXPECT_EQ(expected_value, collocations[key]); +  EXPECT_TRUE(precomputation.Contains(key)); +  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));    key = {2, -1, 2, -2, 3};    expected_value = {1, 5, 9}; -  EXPECT_EQ(expected_value, collocations[key]); +  EXPECT_TRUE(precomputation.Contains(key)); +  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));    key = {2, -1, 3, -2, 2};    expected_value = {1, 6, 8, 5, 9, 11}; -  EXPECT_EQ(expected_value, collocations[key]); +  EXPECT_TRUE(precomputation.Contains(key)); +  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));    key = {2, -1, 3, -2, 3};    expected_value = {1, 6, 9}; -  EXPECT_EQ(expected_value, collocations[key]); +  EXPECT_TRUE(precomputation.Contains(key)); +  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));    key = {3, -1, 2, -2, 2};    expected_value = {2, 5, 8, 2, 5, 11, 2, 8, 11, 6, 8, 11}; -  EXPECT_EQ(expected_value, collocations[key]); +  EXPECT_TRUE(precomputation.Contains(key)); +  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));    key = {3, -1, 2, -2, 3};    expected_value = {2, 5, 9}; -  EXPECT_EQ(expected_value, collocations[key]); +  EXPECT_TRUE(precomputation.Contains(key)); +  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));    key = {3, -1, 3, -2, 2};    expected_value = {2, 6, 8, 2, 6, 11, 2, 9, 11, 6, 9, 11}; -  EXPECT_EQ(expected_value, collocations[key]); +  EXPECT_TRUE(precomputation.Contains(key)); +  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));    key = {3, -1, 3, -2, 3};    expected_value = {2, 6, 9}; -  EXPECT_EQ(expected_value, collocations[key]); +  EXPECT_TRUE(precomputation.Contains(key)); +  EXPECT_EQ(expected_value, precomputation.GetCollocations(key));    // Exceeds max_rule_symbols.    key = {2, -1, 2, -2, 2, 3}; -  EXPECT_EQ(0, collocations.count(key)); +  EXPECT_FALSE(precomputation.Contains(key));    // Contains non frequent pattern.    
key = {2, -1, 5}; -  EXPECT_EQ(0, collocations.count(key)); +  EXPECT_FALSE(precomputation.Contains(key));  }  TEST_F(PrecomputationTest, TestSerialization) { diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc index 6ae2d792..18a60695 100644 --- a/extractor/rule_factory.cc +++ b/extractor/rule_factory.cc @@ -12,6 +12,7 @@  #include "phrase_builder.h"  #include "rule.h"  #include "rule_extractor.h" +#include "phrase_location_sampler.h"  #include "sampler.h"  #include "scorer.h"  #include "suffix_array.h" @@ -68,7 +69,8 @@ HieroCachingRuleFactory::HieroCachingRuleFactory(        target_data_array, alignment, phrase_builder, scorer, vocabulary,        max_rule_span, min_gap_size, max_nonterminals, max_rule_symbols, true,        false, require_tight_phrases); -  sampler = make_shared<Sampler>(source_suffix_array, max_samples); +  sampler = make_shared<PhraseLocationSampler>( +      source_suffix_array, max_samples);  }  HieroCachingRuleFactory::HieroCachingRuleFactory( @@ -101,7 +103,9 @@ HieroCachingRuleFactory::HieroCachingRuleFactory() {}  HieroCachingRuleFactory::~HieroCachingRuleFactory() {} -Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) { +Grammar HieroCachingRuleFactory::GetGrammar( +    const vector<int>& word_ids, +    const unordered_set<int>& blacklisted_sentence_ids) {    Clock::time_point start_time = Clock::now();    double total_extract_time = 0;    double total_intersect_time = 0; @@ -193,7 +197,8 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids, const u        Clock::time_point extract_start = Clock::now();        if (!state.starts_with_x) {          // Extract rules for the sampled set of occurrences. -        PhraseLocation sample = sampler->Sample(next_node->matchings, blacklisted_sentence_ids, source_data_array); +        PhraseLocation sample = sampler->Sample( +            next_node->matchings, blacklisted_sentence_ids);          vector<Rule> new_rules =              rule_extractor->ExtractRules(next_phrase, sample);          rules.insert(rules.end(), new_rules.begin(), new_rules.end()); diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h index df63a9d8..1a9fa2af 100644 --- a/extractor/rule_factory.h +++ b/extractor/rule_factory.h @@ -72,7 +72,9 @@ class HieroCachingRuleFactory {    // Constructs SCFG rules for a given sentence.    // (See class description for more details.) 
-  virtual Grammar GetGrammar(const vector<int>& word_ids, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array); +  virtual Grammar GetGrammar( +      const vector<int>& word_ids, +      const unordered_set<int>& blacklisted_sentence_ids);   protected:    HieroCachingRuleFactory(); diff --git a/extractor/rule_factory_test.cc b/extractor/rule_factory_test.cc index f26cc567..332c5959 100644 --- a/extractor/rule_factory_test.cc +++ b/extractor/rule_factory_test.cc @@ -40,7 +40,7 @@ class RuleFactoryTest : public Test {          .WillRepeatedly(Return(feature_names));      sampler = make_shared<MockSampler>(); -    EXPECT_CALL(*sampler, Sample(_)) +    EXPECT_CALL(*sampler, Sample(_, _))          .WillRepeatedly(Return(PhraseLocation(0, 1)));      Phrase phrase; @@ -77,8 +77,7 @@ TEST_F(RuleFactoryTest, TestGetGrammarDifferentWords) {    vector<int> word_ids = {2, 3, 4};    unordered_set<int> blacklisted_sentence_ids; -  shared_ptr<DataArray> source_data_array; -  Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array); +  Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids);    EXPECT_EQ(feature_names, grammar.GetFeatureNames());    EXPECT_EQ(7, grammar.GetRules().size());  } @@ -97,8 +96,7 @@ TEST_F(RuleFactoryTest, TestGetGrammarRepeatingWords) {    vector<int> word_ids = {2, 3, 4, 2, 3};    unordered_set<int> blacklisted_sentence_ids; -  shared_ptr<DataArray> source_data_array; -  Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array); +  Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids);    EXPECT_EQ(feature_names, grammar.GetFeatureNames());    EXPECT_EQ(28, grammar.GetRules().size());  } diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc index 6f59f0b6..f1aa5e35 100644 --- a/extractor/run_extractor.cc +++ b/extractor/run_extractor.cc @@ -5,10 +5,10 @@  #include <string>  #include <vector> -#include <omp.h>  #include <boost/filesystem.hpp>  #include <boost/program_options.hpp>  #include <boost/program_options/variables_map.hpp> +#include <omp.h>  #include "alignment.h"  #include "data_array.h" @@ -28,6 +28,7 @@  #include "suffix_array.h"  #include "time_util.h"  #include "translation_table.h" +#include "vocabulary.h"  namespace fs = boost::filesystem;  namespace po = boost::program_options; @@ -77,7 +78,8 @@ int main(int argc, char** argv) {      ("tight_phrases", po::value<bool>()->default_value(true),          "False if phrases may be loose (better, but slower)")      ("leave_one_out", po::value<bool>()->zero_tokens(), -        "do leave-one-out estimation (e.g. for extracting grammars for the training set)"); +        "do leave-one-out estimation of grammars " +        "(e.g. for extracting grammars for the training set)");    po::variables_map vm;    po::store(po::parse_command_line(argc, argv, desc), vm); @@ -98,11 +100,6 @@ int main(int argc, char** argv) {      return 1;    } -  bool leave_one_out = false; -  if (vm.count("leave_one_out")) { -    leave_one_out = true; -  } -    int num_threads = vm["threads"].as<int>();    cerr << "Grammar extraction will use " << num_threads << " threads."
<< endl; @@ -142,11 +139,14 @@ int main(int argc, char** argv) {    cerr << "Reading alignment took "         << GetDuration(start_time, stop_time) << " seconds" << endl; +  shared_ptr<Vocabulary> vocabulary = make_shared<Vocabulary>(); +    // Constructs an index storing the occurrences in the source data for each    // frequent collocation.    start_time = Clock::now();    cerr << "Precomputing collocations..." << endl;    shared_ptr<Precomputation> precomputation = make_shared<Precomputation>( +      vocabulary,        source_suffix_array,        vm["frequent"].as<int>(),        vm["super_frequent"].as<int>(), @@ -174,8 +174,8 @@ int main(int argc, char** argv) {         << GetDuration(preprocess_start_time, preprocess_stop_time)         << " seconds" << endl; -  // Features used to score each grammar rule.    Clock::time_point extraction_start_time = Clock::now(); +  // Features used to score each grammar rule.    vector<shared_ptr<Feature>> features = {        make_shared<TargetGivenSourceCoherent>(),        make_shared<SampleSourceCount>(), @@ -194,6 +194,7 @@ int main(int argc, char** argv) {        alignment,        precomputation,        scorer, +      vocabulary,        vm["min_gap_size"].as<int>(),        vm["max_rule_span"].as<int>(),        vm["max_nonterminals"].as<int>(), @@ -201,9 +202,6 @@ int main(int argc, char** argv) {        vm["max_samples"].as<int>(),        vm["tight_phrases"].as<bool>()); -  // Releases extra memory used by the initial precomputation. -  precomputation.reset(); -    // Creates the grammars directory if it doesn't exist.    fs::path grammar_path = vm["grammars"].as<string>();    if (!fs::is_directory(grammar_path)) { @@ -219,6 +217,7 @@ int main(int argc, char** argv) {    }    // Extracts the grammar for each sentence and saves it to a file. +  bool leave_one_out = vm.count("leave_one_out");    vector<string> suffixes(sentences.size());    #pragma omp parallel for schedule(dynamic) num_threads(num_threads)    for (size_t i = 0; i < sentences.size(); ++i) { @@ -231,8 +230,11 @@ int main(int argc, char** argv) {      suffixes[i] = suffix;      unordered_set<int> blacklisted_sentence_ids; -    if (leave_one_out) blacklisted_sentence_ids.insert(i); -    Grammar grammar = extractor.GetGrammar(sentences[i], blacklisted_sentence_ids, source_data_array); +    if (leave_one_out) { +      blacklisted_sentence_ids.insert(i); +    } +    Grammar grammar = extractor.GetGrammar( +        sentences[i], blacklisted_sentence_ids);      ofstream output(GetGrammarFilePath(grammar_path, i).c_str());      output << grammar;    } diff --git a/extractor/sampler.cc b/extractor/sampler.cc deleted file mode 100644 index 963afa7a..00000000 --- a/extractor/sampler.cc +++ /dev/null @@ -1,75 +0,0 @@ -#include "sampler.h" - -#include "phrase_location.h" -#include "suffix_array.h" - -namespace extractor { - -Sampler::Sampler(shared_ptr<SuffixArray> suffix_array, int max_samples) : -    suffix_array(suffix_array), max_samples(max_samples) {} - -Sampler::Sampler() {} - -Sampler::~Sampler() {} - -PhraseLocation Sampler::Sample(const PhraseLocation& location, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) const { -  vector<int> sample; -  int num_subpatterns; -  if (location.matchings == NULL) { -    // Sample suffix array range. 
-    num_subpatterns = 1; -    int low = location.sa_low, high = location.sa_high; -    double step = max(1.0, (double) (high - low) / max_samples); -    double i = low, last = i; -    bool found; -    while (sample.size() < max_samples && i < high) { -      int x = suffix_array->GetSuffix(Round(i)); -      int id = source_data_array->GetSentenceId(x); -      if (find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) != blacklisted_sentence_ids.end()) { -        found = false; -        double backoff_step = 1; -        while (true) { -          if ((double)backoff_step >= step) break; -          double j = i - backoff_step; -          x = suffix_array->GetSuffix(Round(j)); -          id = source_data_array->GetSentenceId(x); -          if (x >= 0 && j > last && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) { -            found = true; last = i; break; -          } -          double k = i + backoff_step; -          x = suffix_array->GetSuffix(Round(k)); -          id = source_data_array->GetSentenceId(x); -          if (k < min(i+step, (double)high) && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) { -            found = true; last = k; break; -          } -          if (j <= last && k >= high) break; -          backoff_step++; -        } -      } else { -        found = true; -        last = i; -      } -      if (found) sample.push_back(x); -      i += step; -    } -  } else { -    // Sample vector of occurrences. -    num_subpatterns = location.num_subpatterns; -    int num_matchings = location.matchings->size() / num_subpatterns; -    double step = max(1.0, (double) num_matchings / max_samples); -    for (double i = 0, num_samples = 0; -         i < num_matchings && num_samples < max_samples; -         i += step, ++num_samples) { -      int start = Round(i) * num_subpatterns; -      sample.insert(sample.end(), location.matchings->begin() + start, -          location.matchings->begin() + start + num_subpatterns); -    } -  } -  return PhraseLocation(sample, num_subpatterns); -} - -int Sampler::Round(double x) const { -  return x + 0.5; -} - -} // namespace extractor diff --git a/extractor/sampler.h b/extractor/sampler.h index de450c48..3c4e37f1 100644 --- a/extractor/sampler.h +++ b/extractor/sampler.h @@ -4,36 +4,20 @@  #include <memory>  #include <unordered_set> -#include "data_array.h" -  using namespace std;  namespace extractor {  class PhraseLocation; -class SuffixArray;  /** - * Provides uniform sampling for a PhraseLocation. + * Base sampler class.   */  class Sampler {   public: -  Sampler(shared_ptr<SuffixArray> suffix_array, int max_samples); - -  virtual ~Sampler(); - -  // Samples uniformly at most max_samples phrase occurrences. -  virtual PhraseLocation Sample(const PhraseLocation& location, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) const; - - protected: -  Sampler(); - - private: -  // Round floating point number to the nearest integer. 
-  int Round(double x) const; - -  shared_ptr<SuffixArray> suffix_array; -  int max_samples; +  virtual PhraseLocation Sample( +      const PhraseLocation& location, +      const unordered_set<int>& blacklisted_sentence_ids) const = 0;  };  } // namespace extractor diff --git a/extractor/sampler_test.cc b/extractor/sampler_test.cc deleted file mode 100644 index 965567ba..00000000 --- a/extractor/sampler_test.cc +++ /dev/null @@ -1,80 +0,0 @@ -#include <gtest/gtest.h> - -#include <memory> - -#include "mocks/mock_suffix_array.h" -#include "mocks/mock_data_array.h" -#include "phrase_location.h" -#include "sampler.h" - -using namespace std; -using namespace ::testing; - -namespace extractor { -namespace { - -class SamplerTest : public Test { - protected: -  virtual void SetUp() { -    source_data_array = make_shared<MockDataArray>(); -    EXPECT_CALL(*source_data_array, GetSentenceId(_)).WillRepeatedly(Return(9999)); -    suffix_array = make_shared<MockSuffixArray>(); -    for (int i = 0; i < 10; ++i) { -      EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i)); -    } -  } - -  shared_ptr<MockSuffixArray> suffix_array; -  shared_ptr<Sampler> sampler; -  shared_ptr<MockDataArray> source_data_array; -}; - -TEST_F(SamplerTest, TestSuffixArrayRange) { -  PhraseLocation location(0, 10); -  unordered_set<int> blacklist; - -  sampler = make_shared<Sampler>(suffix_array, 1); -  vector<int> expected_locations = {0}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); - -  sampler = make_shared<Sampler>(suffix_array, 2); -  expected_locations = {0, 5}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); - -  sampler = make_shared<Sampler>(suffix_array, 3); -  expected_locations = {0, 3, 7}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); - -  sampler = make_shared<Sampler>(suffix_array, 4); -  expected_locations = {0, 3, 5, 8}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); - -  sampler = make_shared<Sampler>(suffix_array, 100); -  expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); -} - -TEST_F(SamplerTest, TestSubstringsSample) { -  vector<int> locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; -  unordered_set<int> blacklist; -  PhraseLocation location(locations, 2); - -  sampler = make_shared<Sampler>(suffix_array, 1); -  vector<int> expected_locations = {0, 1}; -  EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array)); - -  sampler = make_shared<Sampler>(suffix_array, 2); -  expected_locations = {0, 1, 6, 7}; -  EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array)); - -  sampler = make_shared<Sampler>(suffix_array, 3); -  expected_locations = {0, 1, 4, 5, 6, 7}; -  EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array)); - -  sampler = make_shared<Sampler>(suffix_array, 7); -  expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; -  EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array)); -} - -} // namespace -} // namespace extractor diff --git a/extractor/sampler_test_blacklist.cc b/extractor/sampler_test_blacklist.cc deleted file mode 100644 
index 3305b990..00000000 --- a/extractor/sampler_test_blacklist.cc +++ /dev/null @@ -1,102 +0,0 @@ -#include <gtest/gtest.h> - -#include <memory> - -#include "mocks/mock_suffix_array.h" -#include "mocks/mock_data_array.h" -#include "phrase_location.h" -#include "sampler.h" - -using namespace std; -using namespace ::testing; - -namespace extractor { -namespace { - -class SamplerTestBlacklist : public Test { - protected: -  virtual void SetUp() { -    source_data_array = make_shared<MockDataArray>(); -    for (int i = 0; i < 10; ++i) { -      EXPECT_CALL(*source_data_array, GetSentenceId(i)).WillRepeatedly(Return(i)); -    } -    for (int i = -10; i < 0; ++i) { -      EXPECT_CALL(*source_data_array, GetSentenceId(i)).WillRepeatedly(Return(0)); -    } -    suffix_array = make_shared<MockSuffixArray>(); -    for (int i = -10; i < 10; ++i) { -      EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i)); -    } -  } - -  shared_ptr<MockSuffixArray> suffix_array; -  shared_ptr<Sampler> sampler; -  shared_ptr<MockDataArray> source_data_array; -}; - -TEST_F(SamplerTestBlacklist, TestSuffixArrayRange) { -  PhraseLocation location(0, 10); -  unordered_set<int> blacklist; -  vector<int> expected_locations; -    -  blacklist.insert(0); -  sampler = make_shared<Sampler>(suffix_array, 1); -  expected_locations = {1}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); -  blacklist.clear(); -   -  for (int i = 0; i < 9; i++) { -    blacklist.insert(i); -  } -  sampler = make_shared<Sampler>(suffix_array, 1); -  expected_locations = {9}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); -  blacklist.clear(); - -  blacklist.insert(0); -  blacklist.insert(5); -  sampler = make_shared<Sampler>(suffix_array, 2); -  expected_locations = {1, 4}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); -  blacklist.clear(); - -  blacklist.insert(0); -  blacklist.insert(1); -  blacklist.insert(2); -  blacklist.insert(3); -  sampler = make_shared<Sampler>(suffix_array, 2); -  expected_locations = {4, 5}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); -  blacklist.clear(); - -  blacklist.insert(0); -  blacklist.insert(3); -  blacklist.insert(7); -  sampler = make_shared<Sampler>(suffix_array, 3); -  expected_locations = {1, 2, 6}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); -  blacklist.clear(); - -  blacklist.insert(0); -  blacklist.insert(3); -  blacklist.insert(5); -  blacklist.insert(8); -  sampler = make_shared<Sampler>(suffix_array, 4); -  expected_locations = {1, 2, 4, 7}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); -  blacklist.clear(); -   -  blacklist.insert(0); -  sampler = make_shared<Sampler>(suffix_array, 100); -  expected_locations = {1, 2, 3, 4, 5, 6, 7, 8, 9}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); -  blacklist.clear(); -  -  blacklist.insert(9); -  sampler = make_shared<Sampler>(suffix_array, 100); -  expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8}; -  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); -} - -} // namespace -} // namespace extractor diff --git a/extractor/suffix_array.cc 
b/extractor/suffix_array.cc index 0cf4d1f6..4a514b12 100644 --- a/extractor/suffix_array.cc +++ b/extractor/suffix_array.cc @@ -10,7 +10,6 @@  #include "phrase_location.h"  #include "time_util.h" -namespace fs = boost::filesystem;  using namespace std;  using namespace chrono; @@ -188,12 +187,12 @@ shared_ptr<DataArray> SuffixArray::GetData() const {  PhraseLocation SuffixArray::Lookup(int low, int high, const string& word,                                     int offset) const { -  if (!data_array->HasWord(word)) { +  int word_id = data_array->GetWordId(word); +  if (word_id == -1) {      // Return empty phrase location.      return PhraseLocation(0, 0);    } -  int word_id = data_array->GetWordId(word);    if (offset == 0) {      return PhraseLocation(word_start[word_id], word_start[word_id + 1]);    } diff --git a/extractor/suffix_array.h b/extractor/suffix_array.h index 8ee454ec..df80c152 100644 --- a/extractor/suffix_array.h +++ b/extractor/suffix_array.h @@ -5,12 +5,10 @@  #include <string>  #include <vector> -#include <boost/filesystem.hpp>  #include <boost/serialization/serialization.hpp>  #include <boost/serialization/split_member.hpp>  #include <boost/serialization/vector.hpp> -namespace fs = boost::filesystem;  using namespace std;  namespace extractor { diff --git a/extractor/suffix_array_sampler.cc b/extractor/suffix_array_sampler.cc new file mode 100644 index 00000000..4a4ced34 --- /dev/null +++ b/extractor/suffix_array_sampler.cc @@ -0,0 +1,40 @@ +#include "suffix_array_sampler.h" + +#include "data_array.h" +#include "phrase_location.h" +#include "suffix_array.h" + +namespace extractor { + +SuffixArrayRangeSampler::SuffixArrayRangeSampler( +    shared_ptr<SuffixArray> source_suffix_array, int max_samples) : +    BackoffSampler(source_suffix_array->GetData(), max_samples), +    source_suffix_array(source_suffix_array) {} + +SuffixArrayRangeSampler::SuffixArrayRangeSampler() {} + +int SuffixArrayRangeSampler::GetNumSubpatterns(const PhraseLocation&) const { +  return 1; +} + +int SuffixArrayRangeSampler::GetRangeLow( +    const PhraseLocation& location) const { +  return location.sa_low; +} + +int SuffixArrayRangeSampler::GetRangeHigh( +    const PhraseLocation& location) const { +  return location.sa_high; +} + +int SuffixArrayRangeSampler::GetPosition( +    const PhraseLocation&, int position) const { +  return source_suffix_array->GetSuffix(position); +} + +void SuffixArrayRangeSampler::AppendMatching( +    vector<int>& samples, int index, const PhraseLocation&) const { +  samples.push_back(source_suffix_array->GetSuffix(index)); +} + +} // namespace extractor diff --git a/extractor/suffix_array_sampler.h b/extractor/suffix_array_sampler.h new file mode 100644 index 00000000..bb3c2653 --- /dev/null +++ b/extractor/suffix_array_sampler.h @@ -0,0 +1,34 @@ +#ifndef _SUFFIX_ARRAY_SAMPLER_H_ +#define _SUFFIX_ARRAY_SAMPLER_H_ + +#include "backoff_sampler.h" + +namespace extractor { + +class SuffixArray; + +class SuffixArrayRangeSampler : public BackoffSampler { + public: +  SuffixArrayRangeSampler(shared_ptr<SuffixArray> suffix_array, +                          int max_samples); + +  SuffixArrayRangeSampler(); + + private: +  int GetNumSubpatterns(const PhraseLocation& location) const; + +  int GetRangeLow(const PhraseLocation& location) const; + +  int GetRangeHigh(const PhraseLocation& location) const; + +  int GetPosition(const PhraseLocation& location, int index) const; + +  void AppendMatching(vector<int>& samples, int index, +                      const PhraseLocation& 
location) const; + +  shared_ptr<SuffixArray> source_suffix_array; +}; + +} // namespace extractor + +#endif diff --git a/extractor/suffix_array_sampler_test.cc b/extractor/suffix_array_sampler_test.cc new file mode 100644 index 00000000..4b88c027 --- /dev/null +++ b/extractor/suffix_array_sampler_test.cc @@ -0,0 +1,114 @@ +#include <gtest/gtest.h> + +#include <memory> + +#include "mocks/mock_data_array.h" +#include "mocks/mock_suffix_array.h" +#include "suffix_array_sampler.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace { + +class SuffixArraySamplerTest : public Test { + protected: +  virtual void SetUp() { +    data_array = make_shared<MockDataArray>(); +    for (int i = 0; i < 10; ++i) { +      EXPECT_CALL(*data_array, GetSentenceId(i)).WillRepeatedly(Return(i)); +    } + +    suffix_array = make_shared<MockSuffixArray>(); +    EXPECT_CALL(*suffix_array, GetData()).WillRepeatedly(Return(data_array)); +    for (int i = 0; i < 10; ++i) { +      EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i)); +    } +  } + +  shared_ptr<MockDataArray> data_array; +  shared_ptr<MockSuffixArray> suffix_array; +}; + +TEST_F(SuffixArraySamplerTest, TestSample) { +  PhraseLocation location(0, 10); +  unordered_set<int> blacklisted_sentence_ids; + +  SuffixArrayRangeSampler sampler(suffix_array, 1); +  vector<int> expected_locations = {0}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), +            sampler.Sample(location, blacklisted_sentence_ids)); + +  sampler = SuffixArrayRangeSampler(suffix_array, 2); +  expected_locations = {0, 5}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), +            sampler.Sample(location, blacklisted_sentence_ids)); + +  sampler = SuffixArrayRangeSampler(suffix_array, 3); +  expected_locations = {0, 3, 7}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), +            sampler.Sample(location, blacklisted_sentence_ids)); + +  sampler = SuffixArrayRangeSampler(suffix_array, 4); +  expected_locations = {0, 3, 5, 8}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), +            sampler.Sample(location, blacklisted_sentence_ids)); + +  sampler = SuffixArrayRangeSampler(suffix_array, 100); +  expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), +            sampler.Sample(location, blacklisted_sentence_ids)); +} + +TEST_F(SuffixArraySamplerTest, TestBackoffSample) { +  PhraseLocation location(0, 10); + +  SuffixArrayRangeSampler sampler(suffix_array, 1); +  unordered_set<int> blacklisted_sentence_ids = {0}; +  vector<int> expected_locations = {1}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), +            sampler.Sample(location, blacklisted_sentence_ids)); + +  blacklisted_sentence_ids = {0, 1, 2, 3, 4, 5, 6, 7, 8}; +  expected_locations = {9}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), +            sampler.Sample(location, blacklisted_sentence_ids)); + +  sampler = SuffixArrayRangeSampler(suffix_array, 2); +  blacklisted_sentence_ids = {0, 5}; +  expected_locations = {1, 4}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), +            sampler.Sample(location, blacklisted_sentence_ids)); + +  blacklisted_sentence_ids = {0, 1, 2, 3}; +  expected_locations = {4, 5}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), +            sampler.Sample(location, blacklisted_sentence_ids)); + +  sampler = SuffixArrayRangeSampler(suffix_array, 3); +  blacklisted_sentence_ids = {0, 3, 7}; +  expected_locations = {1, 2, 6}; +  
EXPECT_EQ(PhraseLocation(expected_locations, 1), +            sampler.Sample(location, blacklisted_sentence_ids)); + +  sampler = SuffixArrayRangeSampler(suffix_array, 4); +  blacklisted_sentence_ids = {0, 3, 5, 8}; +  expected_locations = {1, 2, 4, 7}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), +            sampler.Sample(location, blacklisted_sentence_ids)); + +  sampler = SuffixArrayRangeSampler(suffix_array, 100); +  blacklisted_sentence_ids = {0}; +  expected_locations = {1, 2, 3, 4, 5, 6, 7, 8, 9}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), +            sampler.Sample(location, blacklisted_sentence_ids)); + +  blacklisted_sentence_ids = {9}; +  expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8}; +  EXPECT_EQ(PhraseLocation(expected_locations, 1), +            sampler.Sample(location, blacklisted_sentence_ids)); +} + +} +} // namespace extractor diff --git a/extractor/suffix_array_test.cc b/extractor/suffix_array_test.cc index ba0dbcc3..161edbc0 100644 --- a/extractor/suffix_array_test.cc +++ b/extractor/suffix_array_test.cc @@ -21,7 +21,7 @@ class SuffixArrayTest : public Test {    virtual void SetUp() {      data = {6, 4, 1, 2, 4, 5, 3, 4, 6, 6, 4, 1, 2};      data_array = make_shared<MockDataArray>(); -    EXPECT_CALL(*data_array, GetData()).WillRepeatedly(ReturnRef(data)); +    EXPECT_CALL(*data_array, GetData()).WillRepeatedly(Return(data));      EXPECT_CALL(*data_array, GetVocabularySize()).WillRepeatedly(Return(7));      EXPECT_CALL(*data_array, GetSize()).WillRepeatedly(Return(13));      suffix_array = SuffixArray(data_array); @@ -55,22 +55,18 @@ TEST_F(SuffixArrayTest, TestLookup) {      EXPECT_CALL(*data_array, AtIndex(i)).WillRepeatedly(Return(data[i]));    } -  EXPECT_CALL(*data_array, HasWord("word1")).WillRepeatedly(Return(true));    EXPECT_CALL(*data_array, GetWordId("word1")).WillRepeatedly(Return(6));    EXPECT_EQ(PhraseLocation(11, 14), suffix_array.Lookup(0, 14, "word1", 0)); -  EXPECT_CALL(*data_array, HasWord("word2")).WillRepeatedly(Return(false)); +  EXPECT_CALL(*data_array, GetWordId("word2")).WillRepeatedly(Return(-1));    EXPECT_EQ(PhraseLocation(0, 0), suffix_array.Lookup(0, 14, "word2", 0)); -  EXPECT_CALL(*data_array, HasWord("word3")).WillRepeatedly(Return(true));    EXPECT_CALL(*data_array, GetWordId("word3")).WillRepeatedly(Return(4));    EXPECT_EQ(PhraseLocation(11, 13), suffix_array.Lookup(11, 14, "word3", 1)); -  EXPECT_CALL(*data_array, HasWord("word4")).WillRepeatedly(Return(true));    EXPECT_CALL(*data_array, GetWordId("word4")).WillRepeatedly(Return(1));    EXPECT_EQ(PhraseLocation(11, 13), suffix_array.Lookup(11, 13, "word4", 2)); -  EXPECT_CALL(*data_array, HasWord("word5")).WillRepeatedly(Return(true));    EXPECT_CALL(*data_array, GetWordId("word5")).WillRepeatedly(Return(2));    EXPECT_EQ(PhraseLocation(11, 13), suffix_array.Lookup(11, 13, "word5", 3)); diff --git a/extractor/translation_table.cc b/extractor/translation_table.cc index 1b1ba112..11e29e1e 100644 --- a/extractor/translation_table.cc +++ b/extractor/translation_table.cc @@ -90,13 +90,12 @@ void TranslationTable::IncrementLinksCount(  double TranslationTable::GetTargetGivenSourceScore(      const string& source_word, const string& target_word) { -  if (!source_data_array->HasWord(source_word) || -      !target_data_array->HasWord(target_word)) { +  int source_id = source_data_array->GetWordId(source_word); +  int target_id = target_data_array->GetWordId(target_word); +  if (source_id == -1 || target_id == -1) {      return -1;    } -  int source_id = 
source_data_array->GetWordId(source_word); -  int target_id = target_data_array->GetWordId(target_word);    auto entry = make_pair(source_id, target_id);    auto it = translation_probabilities.find(entry);    if (it == translation_probabilities.end()) { @@ -107,13 +106,12 @@ double TranslationTable::GetTargetGivenSourceScore(  double TranslationTable::GetSourceGivenTargetScore(      const string& source_word, const string& target_word) { -  if (!source_data_array->HasWord(source_word) || -      !target_data_array->HasWord(target_word)) { +  int source_id = source_data_array->GetWordId(source_word); +  int target_id = target_data_array->GetWordId(target_word); +  if (source_id == -1 || target_id == -1) {      return -1;    } -  int source_id = source_data_array->GetWordId(source_word); -  int target_id = target_data_array->GetWordId(target_word);    auto entry = make_pair(source_id, target_id);    auto it = translation_probabilities.find(entry);    if (it == translation_probabilities.end()) { diff --git a/extractor/translation_table.h b/extractor/translation_table.h index 2a37bab7..97620727 100644 --- a/extractor/translation_table.h +++ b/extractor/translation_table.h @@ -5,14 +5,12 @@  #include <string>  #include <unordered_map> -#include <boost/filesystem.hpp>  #include <boost/functional/hash.hpp>  #include <boost/serialization/serialization.hpp>  #include <boost/serialization/split_member.hpp>  #include <boost/serialization/utility.hpp>  using namespace std; -namespace fs = boost::filesystem;  namespace extractor { diff --git a/extractor/translation_table_test.cc b/extractor/translation_table_test.cc index 606777bd..3cfc0011 100644 --- a/extractor/translation_table_test.cc +++ b/extractor/translation_table_test.cc @@ -28,7 +28,7 @@ class TranslationTableTest : public Test {      vector<int> source_sentence_start = {0, 6, 10, 14};      shared_ptr<MockDataArray> source_data_array = make_shared<MockDataArray>();      EXPECT_CALL(*source_data_array, GetData()) -        .WillRepeatedly(ReturnRef(source_data)); +        .WillRepeatedly(Return(source_data));      EXPECT_CALL(*source_data_array, GetNumSentences())          .WillRepeatedly(Return(3));      for (size_t i = 0; i < source_sentence_start.size(); ++i) { @@ -36,31 +36,25 @@ class TranslationTableTest : public Test {            .WillRepeatedly(Return(source_sentence_start[i]));      }      for (size_t i = 0; i < words.size(); ++i) { -      EXPECT_CALL(*source_data_array, HasWord(words[i])) -          .WillRepeatedly(Return(true));        EXPECT_CALL(*source_data_array, GetWordId(words[i]))            .WillRepeatedly(Return(i + 2));      } -    EXPECT_CALL(*source_data_array, HasWord("d")) -        .WillRepeatedly(Return(false)); +    EXPECT_CALL(*source_data_array, GetWordId("d")).WillRepeatedly(Return(-1));      vector<int> target_data = {2, 3, 2, 3, 4, 5, 0, 3, 6, 0, 2, 7, 0};      vector<int> target_sentence_start = {0, 7, 10, 13};      shared_ptr<MockDataArray> target_data_array = make_shared<MockDataArray>();      EXPECT_CALL(*target_data_array, GetData()) -        .WillRepeatedly(ReturnRef(target_data)); +        .WillRepeatedly(Return(target_data));      for (size_t i = 0; i < target_sentence_start.size(); ++i) {        EXPECT_CALL(*target_data_array, GetSentenceStart(i))            .WillRepeatedly(Return(target_sentence_start[i]));      }      for (size_t i = 0; i < words.size(); ++i) { -      EXPECT_CALL(*target_data_array, HasWord(words[i])) -          .WillRepeatedly(Return(true));        EXPECT_CALL(*target_data_array, 
GetWordId(words[i]))            .WillRepeatedly(Return(i + 2));      } -    EXPECT_CALL(*target_data_array, HasWord("d")) -        .WillRepeatedly(Return(false)); +    EXPECT_CALL(*target_data_array, GetWordId("d")).WillRepeatedly(Return(-1));      vector<pair<int, int>> links1 = {        make_pair(0, 0), make_pair(1, 1), make_pair(2, 2), make_pair(3, 3), diff --git a/extractor/vocabulary.cc b/extractor/vocabulary.cc index 15795d1e..c9c2d6f4 100644 --- a/extractor/vocabulary.cc +++ b/extractor/vocabulary.cc @@ -8,12 +8,13 @@ int Vocabulary::GetTerminalIndex(const string& word) {    int word_id = -1;    #pragma omp critical (vocabulary)    { -    if (!dictionary.count(word)) { +    auto it = dictionary.find(word); +    if (it != dictionary.end()) { +      word_id = it->second; +    } else {        word_id = words.size();        dictionary[word] = word_id;        words.push_back(word); -    } else { -      word_id = dictionary[word];      }    }    return word_id; @@ -34,4 +35,8 @@ string Vocabulary::GetTerminalValue(int symbol) {    return word;  } +bool Vocabulary::operator==(const Vocabulary& other) const { +  return words == other.words && dictionary == other.dictionary; +} +  } // namespace extractor diff --git a/extractor/vocabulary.h b/extractor/vocabulary.h index c8fd9411..db092e99 100644 --- a/extractor/vocabulary.h +++ b/extractor/vocabulary.h @@ -5,6 +5,10 @@  #include <unordered_map>  #include <vector> +#include <boost/serialization/serialization.hpp> +#include <boost/serialization/string.hpp> +#include <boost/serialization/vector.hpp> +  using namespace std;  namespace extractor { @@ -14,7 +18,7 @@ namespace extractor {   *   * This structure contains words located in the frequent collocations and words   * encountered during the grammar extraction time. This dictionary is - * considerably smaller than the dictionaries in the data arays (and so is the + * considerably smaller than the dictionaries in the data arrays (and so is the   * query time). Note that this is the single data structure that changes state   * and needs to have thread safe read/write operations.   * @@ -38,7 +42,24 @@ class Vocabulary {    // Returns the word corresponding to the given word id.
virtual string GetTerminalValue(int symbol); +  bool operator==(const Vocabulary& vocabulary) const; +   private: +  friend class boost::serialization::access; + +  template<class Archive> void save(Archive& ar, unsigned int) const { +    ar << words; +  } + +  template<class Archive> void load(Archive& ar, unsigned int) { +    ar >> words; +    for (size_t i = 0; i < words.size(); ++i) { +      dictionary[words[i]] = i; +    } +  } + +  BOOST_SERIALIZATION_SPLIT_MEMBER(); +    unordered_map<string, int> dictionary;    vector<string> words;  }; diff --git a/extractor/vocabulary_test.cc b/extractor/vocabulary_test.cc new file mode 100644 index 00000000..cf5e3e36 --- /dev/null +++ b/extractor/vocabulary_test.cc @@ -0,0 +1,45 @@ +#include <gtest/gtest.h> + +#include <sstream> +#include <string> +#include <vector> + +#include <boost/archive/text_iarchive.hpp> +#include <boost/archive/text_oarchive.hpp> + +#include "vocabulary.h" + +using namespace std; +using namespace ::testing; +namespace ar = boost::archive; + +namespace extractor { +namespace { + +TEST(VocabularyTest, TestIndexes) { +  Vocabulary vocabulary; +  EXPECT_EQ(0, vocabulary.GetTerminalIndex("zero")); +  EXPECT_EQ("zero", vocabulary.GetTerminalValue(0)); + +  EXPECT_EQ(1, vocabulary.GetTerminalIndex("one")); +  EXPECT_EQ("one", vocabulary.GetTerminalValue(1)); +} + +TEST(VocabularyTest, TestSerialization) { +  Vocabulary vocabulary; +  EXPECT_EQ(0, vocabulary.GetTerminalIndex("zero")); +  EXPECT_EQ("zero", vocabulary.GetTerminalValue(0)); + +  stringstream stream(ios_base::out | ios_base::in); +  ar::text_oarchive output_stream(stream, ar::no_header); +  output_stream << vocabulary; + +  Vocabulary vocabulary_copy; +  ar::text_iarchive input_stream(stream, ar::no_header); +  input_stream >> vocabulary_copy; + +  EXPECT_EQ(vocabulary, vocabulary_copy); +} + +} // namespace +} // namespace extractor | 
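Note on the new sampler hierarchy: the commit splits the old monolithic Sampler into an abstract interface (sampler.h above), a BackoffSampler base, and two concrete samplers, MatchingsSampler and SuffixArrayRangeSampler, composed behind PhraseLocationSampler. The body of phrase_location_sampler.cc is not among the hunks shown here; the sketch below is inferred from its header and the two unit tests above, reusing the location.matchings == NULL test from the deleted Sampler::Sample, and it assumes the obvious header names. Treat it as an illustration, not the committed code.

#include <memory>
#include <unordered_set>

#include "matchings_sampler.h"
#include "phrase_location.h"
#include "phrase_location_sampler.h"
#include "suffix_array_sampler.h"

using namespace std;

namespace extractor {

PhraseLocation PhraseLocationSampler::Sample(
    const PhraseLocation& location,
    const unordered_set<int>& blacklisted_sentence_ids) const {
  // A PhraseLocation either denotes a suffix array range [sa_low, sa_high)
  // or carries an explicit vector of matchings; the deleted Sampler::Sample
  // branched on the same condition.
  if (location.matchings == NULL) {
    return suffix_array_sampler->Sample(location, blacklisted_sentence_ids);
  }
  return matchings_sampler->Sample(location, blacklisted_sentence_ids);
}

} // namespace extractor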

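The expected locations in suffix_array_sampler_test.cc follow the stratified selection arithmetic of the deleted Sampler::Sample: step = max(1.0, (high - low) / max_samples), then take Round(low + k * step) for k = 0, 1, ... while the position stays below high, with Round(x) = int(x + 0.5). The helper below is a self-contained illustration of that rule (hypothetical code, not part of the commit, and it ignores the blacklist backoff); it reproduces the values asserted in TestSample, e.g. {0, 3, 7} for max_samples = 3 and {0, 3, 5, 8} for max_samples = 4 over the range [0, 10).

#include <algorithm>
#include <vector>

using namespace std;

// Hypothetical helper mirroring the uniform selection arithmetic of the
// deleted Sampler::Sample (without the blacklist handling).
vector<int> SelectUniformly(int low, int high, int max_samples) {
  vector<int> positions;
  double step = max(1.0, (double) (high - low) / max_samples);
  for (double i = low;
       i < high && (int) positions.size() < max_samples; i += step) {
    positions.push_back((int) (i + 0.5));  // Round to the nearest integer.
  }
  return positions;
}

For max_samples = 100 the step degenerates to 1.0 and all ten positions are returned, matching the final assertion in TestSample.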