From 3c73e472444ff0cd436b12f3679440a6969cbf2d Mon Sep 17 00:00:00 2001 From: Paul Baltescu Date: Mon, 25 Nov 2013 23:56:31 +0000 Subject: Clean up leave-one-out sampling. --- extractor/sampler.cc | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) (limited to 'extractor/sampler.cc') diff --git a/extractor/sampler.cc b/extractor/sampler.cc index 963afa7a..fc386ed1 100644 --- a/extractor/sampler.cc +++ b/extractor/sampler.cc @@ -12,7 +12,9 @@ Sampler::Sampler() {} Sampler::~Sampler() {} -PhraseLocation Sampler::Sample(const PhraseLocation& location, const unordered_set& blacklisted_sentence_ids, const shared_ptr source_data_array) const { +PhraseLocation Sampler::Sample( + const PhraseLocation& location, + const unordered_set& blacklisted_sentence_ids) const { vector sample; int num_subpatterns; if (location.matchings == NULL) { @@ -22,10 +24,11 @@ PhraseLocation Sampler::Sample(const PhraseLocation& location, const unordered_s double step = max(1.0, (double) (high - low) / max_samples); double i = low, last = i; bool found; + shared_ptr source_data_array = suffix_array->GetData(); while (sample.size() < max_samples && i < high) { int x = suffix_array->GetSuffix(Round(i)); int id = source_data_array->GetSentenceId(x); - if (find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) != blacklisted_sentence_ids.end()) { + if (blacklisted_sentence_ids.count(id)) { found = false; double backoff_step = 1; while (true) { @@ -33,13 +36,14 @@ PhraseLocation Sampler::Sample(const PhraseLocation& location, const unordered_s double j = i - backoff_step; x = suffix_array->GetSuffix(Round(j)); id = source_data_array->GetSentenceId(x); - if (x >= 0 && j > last && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) { + if (x >= 0 && j > last && !blacklisted_sentence_ids.count(id)) { found = true; last = i; break; } double k = i + backoff_step; x = suffix_array->GetSuffix(Round(k)); id = source_data_array->GetSentenceId(x); - if (k < min(i+step, (double)high) && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) { + if (k < min(i+step, (double)high) && + !blacklisted_sentence_ids.count(id)) { found = true; last = k; break; } if (j <= last && k >= high) break; -- cgit v1.2.3 From bed3e4b867e4132917fa0640956e8ce713f0e451 Mon Sep 17 00:00:00 2001 From: Paul Baltescu Date: Tue, 26 Nov 2013 15:01:14 +0000 Subject: Script for grammar extraction only. --- .gitignore | 1 + extractor/Makefile.am | 40 +------ extractor/extract.cc | 253 ++++++++++++++++++++++++++++++++++++++++++ extractor/grammar_extractor.h | 1 - extractor/run_extractor.cc | 20 ++-- extractor/sampler.cc | 23 ++-- 6 files changed, 278 insertions(+), 60 deletions(-) create mode 100644 extractor/extract.cc (limited to 'extractor/sampler.cc') diff --git a/.gitignore b/.gitignore index 942539cb..f964fa0c 100644 --- a/.gitignore +++ b/.gitignore @@ -72,6 +72,7 @@ extools/score_grammar extools/sg_lexer.cc extractor/*_test extractor/compile +extractor/extract extractor/run_extractor gi/clda/src/clda gi/markov_al/ml diff --git a/extractor/Makefile.am b/extractor/Makefile.am index 64a5a2b5..7825012c 100644 --- a/extractor/Makefile.am +++ b/extractor/Makefile.am @@ -1,5 +1,5 @@ -bin_PROGRAMS = compile run_extractor +bin_PROGRAMS = compile run_extractor extract if HAVE_CXX11 @@ -105,44 +105,14 @@ translation_table_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $ vocabulary_test_SOURCES = vocabulary_test.cc vocabulary_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a -noinst_LIBRARIES = libextractor.a libcompile.a +noinst_LIBRARIES = libextractor.a compile_SOURCES = compile.cc -compile_LDADD = libcompile.a +compile_LDADD = libextractor.a run_extractor_SOURCES = run_extractor.cc run_extractor_LDADD = libextractor.a - -libcompile_a_SOURCES = \ - alignment.cc \ - data_array.cc \ - phrase_location.cc \ - precomputation.cc \ - suffix_array.cc \ - time_util.cc \ - translation_table.cc \ - vocabulary.cc \ - alignment.h \ - data_array.h \ - fast_intersector.h \ - grammar.h \ - grammar_extractor.h \ - matchings_finder.h \ - matchings_trie.h \ - phrase.h \ - phrase_builder.h \ - phrase_location.h \ - precomputation.h \ - rule.h \ - rule_extractor.h \ - rule_extractor_helper.h \ - rule_factory.h \ - sampler.h \ - scorer.h \ - suffix_array.h \ - target_phrase_extractor.h \ - time_util.h \ - translation_table.h \ - vocabulary.h +extract_SOURCES = extract.cc +extract_LDADD = libextractor.a libextractor_a_SOURCES = \ alignment.cc \ diff --git a/extractor/extract.cc b/extractor/extract.cc new file mode 100644 index 00000000..2d5831fa --- /dev/null +++ b/extractor/extract.cc @@ -0,0 +1,253 @@ +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "alignment.h" +#include "data_array.h" +#include "features/count_source_target.h" +#include "features/feature.h" +#include "features/is_source_singleton.h" +#include "features/is_source_target_singleton.h" +#include "features/max_lex_source_given_target.h" +#include "features/max_lex_target_given_source.h" +#include "features/sample_source_count.h" +#include "features/target_given_source_coherent.h" +#include "grammar.h" +#include "grammar_extractor.h" +#include "precomputation.h" +#include "rule.h" +#include "scorer.h" +#include "suffix_array.h" +#include "time_util.h" +#include "translation_table.h" +#include "vocabulary.h" + +namespace ar = boost::archive; +namespace fs = boost::filesystem; +namespace po = boost::program_options; +using namespace extractor; +using namespace features; +using namespace std; + +// Returns the file path in which a given grammar should be written. +fs::path GetGrammarFilePath(const fs::path& grammar_path, int file_number) { + string file_name = "grammar." + to_string(file_number); + return grammar_path / file_name; +} + +int main(int argc, char** argv) { + po::options_description general_options("General options"); + int max_threads = 1; + #pragma omp parallel + max_threads = omp_get_num_threads(); + string threads_option = "Number of threads used for grammar extraction " + "max(" + to_string(max_threads) + ")"; + general_options.add_options() + ("threads,t", po::value()->required()->default_value(1), + threads_option.c_str()) + ("grammars,g", po::value()->required(), "Grammars output path") + ("max_rule_span", po::value()->default_value(15), + "Maximum rule span") + ("max_rule_symbols", po::value()->default_value(5), + "Maximum number of symbols (terminals + nontermals) in a rule") + ("min_gap_size", po::value()->default_value(1), "Minimum gap size") + ("max_nonterminals", po::value()->default_value(2), + "Maximum number of nonterminals in a rule") + ("max_samples", po::value()->default_value(300), + "Maximum number of samples") + ("tight_phrases", po::value()->default_value(true), + "False if phrases may be loose (better, but slower)") + ("leave_one_out", po::value()->zero_tokens(), + "do leave-one-out estimation of grammars " + "(e.g. for extracting grammars for the training set"); + + po::options_description cmdline_options("Command line options"); + cmdline_options.add_options() + ("help", "Show available options") + ("config", po::value()->required(), "Path to config file"); + cmdline_options.add(general_options); + + po::options_description config_options("Config file options"); + config_options.add_options() + ("target", po::value()->required(), + "Path to target data file in binary format") + ("source", po::value()->required(), + "Path to source suffix array file in binary format") + ("alignment", po::value()->required(), + "Path to alignment file in binary format") + ("precomputation", po::value()->required(), + "Path to precomputation file in binary format") + ("vocabulary", po::value()->required(), + "Path to vocabulary file in binary format") + ("ttable", po::value()->required(), + "Path to translation table in binary format"); + config_options.add(general_options); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, cmdline_options), vm); + if (vm.count("help")) { + po::options_description all_options; + all_options.add(cmdline_options).add(config_options); + cout << all_options << endl; + return 0; + } + + po::notify(vm); + + ifstream config_stream(vm["config"].as()); + po::store(po::parse_config_file(config_stream, config_options), vm); + po::notify(vm); + + int num_threads = vm["threads"].as(); + cerr << "Grammar extraction will use " << num_threads << " threads." << endl; + + Clock::time_point read_start_time = Clock::now(); + + Clock::time_point start_time = Clock::now(); + cerr << "Reading target data in binary format..." << endl; + shared_ptr target_data_array = make_shared(); + ifstream target_fstream(vm["target"].as()); + ar::binary_iarchive target_stream(target_fstream); + target_stream >> *target_data_array; + Clock::time_point end_time = Clock::now(); + cerr << "Reading target data took " << GetDuration(start_time, end_time) + << " seconds" << endl; + + start_time = Clock::now(); + cerr << "Reading source suffix array in binary format..." << endl; + shared_ptr source_suffix_array = make_shared(); + ifstream source_fstream(vm["source"].as()); + ar::binary_iarchive source_stream(source_fstream); + source_stream >> *source_suffix_array; + end_time = Clock::now(); + cerr << "Reading source suffix array took " + << GetDuration(start_time, end_time) << " seconds" << endl; + + start_time = Clock::now(); + cerr << "Reading alignment in binary format..." << endl; + shared_ptr alignment = make_shared(); + ifstream alignment_fstream(vm["alignment"].as()); + ar::binary_iarchive alignment_stream(alignment_fstream); + alignment_stream >> *alignment; + end_time = Clock::now(); + cerr << "Reading alignment took " << GetDuration(start_time, end_time) + << " seconds" << endl; + + start_time = Clock::now(); + cerr << "Reading precomputation in binary format..." << endl; + shared_ptr precomputation = make_shared(); + ifstream precomputation_fstream(vm["precomputation"].as()); + ar::binary_iarchive precomputation_stream(precomputation_fstream); + precomputation_stream >> *precomputation; + end_time = Clock::now(); + cerr << "Reading precomputation took " << GetDuration(start_time, end_time) + << " seconds" << endl; + + start_time = Clock::now(); + cerr << "Reading vocabulary in binary format..." << endl; + shared_ptr vocabulary = make_shared(); + ifstream vocabulary_fstream(vm["vocabulary"].as()); + ar::binary_iarchive vocabulary_stream(vocabulary_fstream); + vocabulary_stream >> *vocabulary; + end_time = Clock::now(); + cerr << "Reading vocabulary took " << GetDuration(start_time, end_time) + << " seconds" << endl; + + start_time = Clock::now(); + cerr << "Reading translation table in binary format..." << endl; + shared_ptr table = make_shared(); + ifstream ttable_fstream(vm["ttable"].as()); + ar::binary_iarchive ttable_stream(ttable_fstream); + ttable_stream >> *table; + end_time = Clock::now(); + cerr << "Reading translation table took " << GetDuration(start_time, end_time) + << " seconds" << endl; + + Clock::time_point read_end_time = Clock::now(); + cerr << "Total time spent loading data structures into memory: " + << GetDuration(read_start_time, read_end_time) << " seconds" << endl; + + Clock::time_point extraction_start_time = Clock::now(); + // Features used to score each grammar rule. + vector> features = { + make_shared(), + make_shared(), + make_shared(), + make_shared(table), + make_shared(table), + make_shared(), + make_shared() + }; + shared_ptr scorer = make_shared(features); + + GrammarExtractor extractor( + source_suffix_array, + target_data_array, + alignment, + precomputation, + scorer, + vocabulary, + vm["min_gap_size"].as(), + vm["max_rule_span"].as(), + vm["max_nonterminals"].as(), + vm["max_rule_symbols"].as(), + vm["max_samples"].as(), + vm["tight_phrases"].as()); + + // Creates the grammars directory if it doesn't exist. + fs::path grammar_path = vm["grammars"].as(); + if (!fs::is_directory(grammar_path)) { + fs::create_directory(grammar_path); + } + + // Reads all sentences for which we extract grammar rules (the paralellization + // is simplified if we read all sentences upfront). + string sentence; + vector sentences; + while (getline(cin, sentence)) { + sentences.push_back(sentence); + } + + // Extracts the grammar for each sentence and saves it to a file. + vector suffixes(sentences.size()); + bool leave_one_out = vm.count("leave_one_out"); + #pragma omp parallel for schedule(dynamic) num_threads(num_threads) + for (size_t i = 0; i < sentences.size(); ++i) { + string suffix; + int position = sentences[i].find("|||"); + if (position != sentences[i].npos) { + suffix = sentences[i].substr(position); + sentences[i] = sentences[i].substr(0, position); + } + suffixes[i] = suffix; + + unordered_set blacklisted_sentence_ids; + if (leave_one_out) { + blacklisted_sentence_ids.insert(i); + } + Grammar grammar = extractor.GetGrammar( + sentences[i], blacklisted_sentence_ids); + ofstream output(GetGrammarFilePath(grammar_path, i).c_str()); + // output << grammar; + } + + for (size_t i = 0; i < sentences.size(); ++i) { + cout << " " << sentences[i] << " " << suffixes[i] << endl; + } + + Clock::time_point extraction_stop_time = Clock::now(); + cerr << "Overall extraction step took " + << GetDuration(extraction_start_time, extraction_stop_time) + << " seconds" << endl; + + return 0; +} diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h index eb79f53c..0f3069b0 100644 --- a/extractor/grammar_extractor.h +++ b/extractor/grammar_extractor.h @@ -15,7 +15,6 @@ class DataArray; class Grammar; class HieroCachingRuleFactory; class Precomputation; -class Rule; class Scorer; class SuffixArray; class Vocabulary; diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc index 6b22a302..f1aa5e35 100644 --- a/extractor/run_extractor.cc +++ b/extractor/run_extractor.cc @@ -5,10 +5,10 @@ #include #include -#include #include #include #include +#include #include "alignment.h" #include "data_array.h" @@ -78,7 +78,8 @@ int main(int argc, char** argv) { ("tight_phrases", po::value()->default_value(true), "False if phrases may be loose (better, but slower)") ("leave_one_out", po::value()->zero_tokens(), - "do leave-one-out estimation of grammars (e.g. for extracting grammars for the training set"); + "do leave-one-out estimation of grammars " + "(e.g. for extracting grammars for the training set"); po::variables_map vm; po::store(po::parse_command_line(argc, argv, desc), vm); @@ -99,11 +100,6 @@ int main(int argc, char** argv) { return 1; } - bool leave_one_out = false; - if (vm.count("leave_one_out")) { - leave_one_out = true; - } - int num_threads = vm["threads"].as(); cerr << "Grammar extraction will use " << num_threads << " threads." << endl; @@ -178,8 +174,8 @@ int main(int argc, char** argv) { << GetDuration(preprocess_start_time, preprocess_stop_time) << " seconds" << endl; - // Features used to score each grammar rule. Clock::time_point extraction_start_time = Clock::now(); + // Features used to score each grammar rule. vector> features = { make_shared(), make_shared(), @@ -206,9 +202,6 @@ int main(int argc, char** argv) { vm["max_samples"].as(), vm["tight_phrases"].as()); - // Releases extra memory used by the initial precomputation. - precomputation.reset(); - // Creates the grammars directory if it doesn't exist. fs::path grammar_path = vm["grammars"].as(); if (!fs::is_directory(grammar_path)) { @@ -224,6 +217,7 @@ int main(int argc, char** argv) { } // Extracts the grammar for each sentence and saves it to a file. + bool leave_one_out = vm.count("leave_one_out"); vector suffixes(sentences.size()); #pragma omp parallel for schedule(dynamic) num_threads(num_threads) for (size_t i = 0; i < sentences.size(); ++i) { @@ -236,7 +230,9 @@ int main(int argc, char** argv) { suffixes[i] = suffix; unordered_set blacklisted_sentence_ids; - if (leave_one_out) blacklisted_sentence_ids.insert(i); + if (leave_one_out) { + blacklisted_sentence_ids.insert(i); + } Grammar grammar = extractor.GetGrammar( sentences[i], blacklisted_sentence_ids); ofstream output(GetGrammarFilePath(grammar_path, i).c_str()); diff --git a/extractor/sampler.cc b/extractor/sampler.cc index fc386ed1..887aaec1 100644 --- a/extractor/sampler.cc +++ b/extractor/sampler.cc @@ -15,6 +15,7 @@ Sampler::~Sampler() {} PhraseLocation Sampler::Sample( const PhraseLocation& location, const unordered_set& blacklisted_sentence_ids) const { + shared_ptr source_data_array = suffix_array->GetData(); vector sample; int num_subpatterns; if (location.matchings == NULL) { @@ -22,32 +23,30 @@ PhraseLocation Sampler::Sample( num_subpatterns = 1; int low = location.sa_low, high = location.sa_high; double step = max(1.0, (double) (high - low) / max_samples); - double i = low, last = i; - bool found; - shared_ptr source_data_array = suffix_array->GetData(); + double i = low, last = i - 1; while (sample.size() < max_samples && i < high) { int x = suffix_array->GetSuffix(Round(i)); int id = source_data_array->GetSentenceId(x); + bool found = false; if (blacklisted_sentence_ids.count(id)) { - found = false; - double backoff_step = 1; - while (true) { - if ((double)backoff_step >= step) break; + for (int backoff_step = 1; backoff_step <= step; ++backoff_step) { double j = i - backoff_step; x = suffix_array->GetSuffix(Round(j)); id = source_data_array->GetSentenceId(x); if (x >= 0 && j > last && !blacklisted_sentence_ids.count(id)) { - found = true; last = i; break; + found = true; + last = i; + break; } double k = i + backoff_step; x = suffix_array->GetSuffix(Round(k)); id = source_data_array->GetSentenceId(x); - if (k < min(i+step, (double)high) && + if (k < min(i+step, (double) high) && !blacklisted_sentence_ids.count(id)) { - found = true; last = k; break; + found = true; + last = k; + break; } - if (j <= last && k >= high) break; - backoff_step++; } } else { found = true; -- cgit v1.2.3 From a6e6a369f40d8fb6a191fd7f74fc5efa8bfae2a0 Mon Sep 17 00:00:00 2001 From: Paul Baltescu Date: Wed, 27 Nov 2013 14:33:36 +0000 Subject: Unify sampling backoff strategy. --- extractor/Makefile.am | 24 ++++-- extractor/backoff_sampler.cc | 66 ++++++++++++++++ extractor/backoff_sampler.h | 41 ++++++++++ extractor/matchings_sampler.cc | 38 +++++++++ extractor/matchings_sampler.h | 31 ++++++++ extractor/matchings_sampler_test.cc | 118 ++++++++++++++++++++++++++++ extractor/mocks/mock_matchings_sampler.h | 15 ++++ extractor/mocks/mock_suffix_array_sampler.h | 15 ++++ extractor/phrase_location.cc | 2 + extractor/phrase_location_sampler.cc | 34 ++++++++ extractor/phrase_location_sampler.h | 35 +++++++++ extractor/phrase_location_sampler_test.cc | 50 ++++++++++++ extractor/precomputation.cc | 3 +- extractor/precomputation_test.cc | 2 +- extractor/rule_factory.cc | 4 +- extractor/sampler.cc | 78 ------------------ extractor/sampler.h | 22 +----- extractor/sampler_test.cc | 92 ---------------------- extractor/sampler_test_blacklist.cc | 102 ------------------------ extractor/suffix_array_sampler.cc | 40 ++++++++++ extractor/suffix_array_sampler.h | 34 ++++++++ extractor/suffix_array_sampler_test.cc | 114 +++++++++++++++++++++++++++ 22 files changed, 657 insertions(+), 303 deletions(-) create mode 100644 extractor/backoff_sampler.cc create mode 100644 extractor/backoff_sampler.h create mode 100644 extractor/matchings_sampler.cc create mode 100644 extractor/matchings_sampler.h create mode 100644 extractor/matchings_sampler_test.cc create mode 100644 extractor/mocks/mock_matchings_sampler.h create mode 100644 extractor/mocks/mock_suffix_array_sampler.h create mode 100644 extractor/phrase_location_sampler.cc create mode 100644 extractor/phrase_location_sampler.h create mode 100644 extractor/phrase_location_sampler_test.cc delete mode 100644 extractor/sampler.cc delete mode 100644 extractor/sampler_test.cc delete mode 100644 extractor/sampler_test_blacklist.cc create mode 100644 extractor/suffix_array_sampler.cc create mode 100644 extractor/suffix_array_sampler.h create mode 100644 extractor/suffix_array_sampler_test.cc (limited to 'extractor/sampler.cc') diff --git a/extractor/Makefile.am b/extractor/Makefile.am index 7825012c..e5b439f9 100644 --- a/extractor/Makefile.am +++ b/extractor/Makefile.am @@ -15,13 +15,15 @@ EXTRA_PROGRAMS = alignment_test \ feature_target_given_source_coherent_test \ grammar_extractor_test \ matchings_finder_test \ + matchings_sampler_test \ + phrase_location_sampler_test \ phrase_test \ precomputation_test \ rule_extractor_helper_test \ rule_extractor_test \ rule_factory_test \ - sampler_test \ scorer_test \ + suffix_array_sampler_test \ suffix_array_test \ target_phrase_extractor_test \ translation_table_test \ @@ -40,13 +42,15 @@ if HAVE_GTEST feature_target_given_source_coherent_test \ grammar_extractor_test \ matchings_finder_test \ + matchings_sampler_test \ + phrase_location_sampler_test \ phrase_test \ precomputation_test \ rule_extractor_helper_test \ rule_extractor_test \ rule_factory_test \ - sampler_test \ scorer_test \ + suffix_array_sampler_test \ suffix_array_test \ target_phrase_extractor_test \ translation_table_test \ @@ -55,8 +59,7 @@ endif noinst_PROGRAMS = $(RUNNABLE_TESTS) -# TESTS = $(RUNNABLE_TESTS) -TESTS = vocabulary_test +TESTS = $(RUNNABLE_TESTS) alignment_test_SOURCES = alignment_test.cc alignment_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a @@ -82,6 +85,10 @@ grammar_extractor_test_SOURCES = grammar_extractor_test.cc grammar_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a matchings_finder_test_SOURCES = matchings_finder_test.cc matchings_finder_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +matchings_sampler_test_SOURCES = matchings_sampler_test.cc +matchings_sampler_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +phrase_location_sampler_test_SOURCES = phrase_location_sampler_test.cc +phrase_location_sampler_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a phrase_test_SOURCES = phrase_test.cc phrase_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a precomputation_test_SOURCES = precomputation_test.cc @@ -92,10 +99,10 @@ rule_extractor_test_SOURCES = rule_extractor_test.cc rule_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a rule_factory_test_SOURCES = rule_factory_test.cc rule_factory_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a -sampler_test_SOURCES = sampler_test.cc -sampler_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a scorer_test_SOURCES = scorer_test.cc scorer_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +suffix_array_sampler_test_SOURCES = suffix_array_sampler_test.cc +suffix_array_sampler_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a suffix_array_test_SOURCES = suffix_array_test.cc suffix_array_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a target_phrase_extractor_test_SOURCES = target_phrase_extractor_test.cc @@ -116,6 +123,7 @@ extract_LDADD = libextractor.a libextractor_a_SOURCES = \ alignment.cc \ + backoff_sampler.cc \ data_array.cc \ fast_intersector.cc \ features/count_source_target.cc \ @@ -129,18 +137,20 @@ libextractor_a_SOURCES = \ grammar.cc \ grammar_extractor.cc \ matchings_finder.cc \ + matchings_sampler.cc \ matchings_trie.cc \ phrase.cc \ phrase_builder.cc \ phrase_location.cc \ + phrase_location_sampler.cc \ precomputation.cc \ rule.cc \ rule_extractor.cc \ rule_extractor_helper.cc \ rule_factory.cc \ - sampler.cc \ scorer.cc \ suffix_array.cc \ + suffix_array_sampler.cc \ target_phrase_extractor.cc \ time_util.cc \ translation_table.cc \ diff --git a/extractor/backoff_sampler.cc b/extractor/backoff_sampler.cc new file mode 100644 index 00000000..28b12909 --- /dev/null +++ b/extractor/backoff_sampler.cc @@ -0,0 +1,66 @@ +#include "backoff_sampler.h" + +#include "data_array.h" +#include "phrase_location.h" + +namespace extractor { + +BackoffSampler::BackoffSampler( + shared_ptr source_data_array, int max_samples) : + source_data_array(source_data_array), max_samples(max_samples) {} + +BackoffSampler::BackoffSampler() {} + +PhraseLocation BackoffSampler::Sample( + const PhraseLocation& location, + const unordered_set& blacklisted_sentence_ids) const { + vector samples; + int low = GetRangeLow(location), high = GetRangeHigh(location); + int last_position = low - 1; + double step = max(1.0, (double) (high - low) / max_samples); + for (double num_samples = 0, i = low; + num_samples < max_samples && i < high; + ++num_samples, i += step) { + int position = GetPosition(location, round(i)); + int sentence_id = source_data_array->GetSentenceId(position); + bool found = false; + if (last_position >= position || + blacklisted_sentence_ids.count(sentence_id)) { + for (double backoff_step = 1; backoff_step < step; ++backoff_step) { + double j = i - backoff_step; + if (round(j) >= 0) { + position = GetPosition(location, round(j)); + sentence_id = source_data_array->GetSentenceId(position); + if (position > last_position && + !blacklisted_sentence_ids.count(sentence_id)) { + found = true; + last_position = position; + break; + } + } + + double k = i + backoff_step; + if (round(k) < high) { + position = GetPosition(location, round(k)); + sentence_id = source_data_array->GetSentenceId(position); + if (!blacklisted_sentence_ids.count(sentence_id)) { + found = true; + last_position = position; + break; + } + } + } + } else { + found = true; + last_position = position; + } + + if (found) { + AppendMatching(samples, position, location); + } + } + + return PhraseLocation(samples, GetNumSubpatterns(location)); +} + +} // namespace extractor diff --git a/extractor/backoff_sampler.h b/extractor/backoff_sampler.h new file mode 100644 index 00000000..5c244105 --- /dev/null +++ b/extractor/backoff_sampler.h @@ -0,0 +1,41 @@ +#ifndef _BACKOFF_SAMPLER_H_ +#define _BACKOFF_SAMPLER_H_ + +#include + +#include "sampler.h" + +namespace extractor { + +class DataArray; +class PhraseLocation; + +class BackoffSampler : public Sampler { + public: + BackoffSampler(shared_ptr source_data_array, int max_samples); + + BackoffSampler(); + + PhraseLocation Sample( + const PhraseLocation& location, + const unordered_set& blacklisted_sentence_ids) const; + + private: + virtual int GetNumSubpatterns(const PhraseLocation& location) const = 0; + + virtual int GetRangeLow(const PhraseLocation& location) const = 0; + + virtual int GetRangeHigh(const PhraseLocation& location) const = 0; + + virtual int GetPosition(const PhraseLocation& location, int index) const = 0; + + virtual void AppendMatching(vector& samples, int index, + const PhraseLocation& location) const = 0; + + shared_ptr source_data_array; + int max_samples; +}; + +} // namespace extractor + +#endif diff --git a/extractor/matchings_sampler.cc b/extractor/matchings_sampler.cc new file mode 100644 index 00000000..bb916e49 --- /dev/null +++ b/extractor/matchings_sampler.cc @@ -0,0 +1,38 @@ +#include "matchings_sampler.h" + +#include "data_array.h" +#include "phrase_location.h" + +namespace extractor { + +MatchingsSampler::MatchingsSampler( + shared_ptr data_array, int max_samples) : + BackoffSampler(data_array, max_samples) {} + +MatchingsSampler::MatchingsSampler() {} + +int MatchingsSampler::GetNumSubpatterns(const PhraseLocation& location) const { + return location.num_subpatterns; +} + +int MatchingsSampler::GetRangeLow(const PhraseLocation&) const { + return 0; +} + +int MatchingsSampler::GetRangeHigh(const PhraseLocation& location) const { + return location.matchings->size() / location.num_subpatterns; +} + +int MatchingsSampler::GetPosition(const PhraseLocation& location, + int index) const { + return (*location.matchings)[index * location.num_subpatterns]; +} + +void MatchingsSampler::AppendMatching(vector& samples, int index, + const PhraseLocation& location) const { + copy(location.matchings->begin() + index, + location.matchings->begin() + index + location.num_subpatterns, + back_inserter(samples)); +} + +} // namespace extractor diff --git a/extractor/matchings_sampler.h b/extractor/matchings_sampler.h new file mode 100644 index 00000000..ca4fce93 --- /dev/null +++ b/extractor/matchings_sampler.h @@ -0,0 +1,31 @@ +#ifndef _MATCHINGS_SAMPLER_H_ +#define _MATCHINGS_SAMPLER_H_ + +#include "backoff_sampler.h" + +namespace extractor { + +class DataArray; + +class MatchingsSampler : public BackoffSampler { + public: + MatchingsSampler(shared_ptr data_array, int max_samples); + + MatchingsSampler(); + + private: + int GetNumSubpatterns(const PhraseLocation& location) const; + + int GetRangeLow(const PhraseLocation& location) const; + + int GetRangeHigh(const PhraseLocation& location) const; + + int GetPosition(const PhraseLocation& location, int index) const; + + void AppendMatching(vector& samples, int index, + const PhraseLocation& location) const; +}; + +} // namespace extractor + +#endif diff --git a/extractor/matchings_sampler_test.cc b/extractor/matchings_sampler_test.cc new file mode 100644 index 00000000..bc927152 --- /dev/null +++ b/extractor/matchings_sampler_test.cc @@ -0,0 +1,118 @@ +#include + +#include + +#include "mocks/mock_data_array.h" +#include "matchings_sampler.h" +#include "phrase_location.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace { + +class MatchingsSamplerTest : public Test { + protected: + virtual void SetUp() { + vector locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + location = PhraseLocation(locations, 2); + + data_array = make_shared(); + for (int i = 0; i < 10; ++i) { + EXPECT_CALL(*data_array, GetSentenceId(i)).WillRepeatedly(Return(i / 2)); + } + } + + unordered_set blacklisted_sentence_ids; + PhraseLocation location; + shared_ptr data_array; + shared_ptr sampler; +}; + +TEST_F(MatchingsSamplerTest, TestSample) { + sampler = make_shared(data_array, 1); + vector expected_locations = {0, 1}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklisted_sentence_ids)); + + sampler = make_shared(data_array, 2); + expected_locations = {0, 1, 6, 7}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklisted_sentence_ids)); + + sampler = make_shared(data_array, 3); + expected_locations = {0, 1, 4, 5, 6, 7}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklisted_sentence_ids)); + + sampler = make_shared(data_array, 7); + expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklisted_sentence_ids)); +} + +TEST_F(MatchingsSamplerTest, TestBackoffSample) { + sampler = make_shared(data_array, 1); + blacklisted_sentence_ids = {0}; + vector expected_locations = {2, 3}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklisted_sentence_ids)); + + blacklisted_sentence_ids = {0, 1, 2, 3}; + expected_locations = {8, 9}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklisted_sentence_ids)); + + blacklisted_sentence_ids = {0, 1, 2, 3, 4}; + expected_locations = {}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklisted_sentence_ids)); + + sampler = make_shared(data_array, 2); + blacklisted_sentence_ids = {0, 3}; + expected_locations = {2, 3, 4, 5}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklisted_sentence_ids)); + + sampler = make_shared(data_array, 3); + blacklisted_sentence_ids = {0, 3}; + expected_locations = {2, 3, 4, 5, 8, 9}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklisted_sentence_ids)); + + blacklisted_sentence_ids = {0, 2, 3}; + expected_locations = {2, 3, 8, 9}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklisted_sentence_ids)); + + sampler = make_shared(data_array, 4); + blacklisted_sentence_ids = {0, 1, 2, 3}; + expected_locations = {8, 9}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklisted_sentence_ids)); + + blacklisted_sentence_ids = {1, 3}; + expected_locations = {0, 1, 4, 5, 8, 9}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklisted_sentence_ids)); + + sampler = make_shared(data_array, 7); + blacklisted_sentence_ids = {0, 1, 2, 3, 4}; + expected_locations = {}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklisted_sentence_ids)); + + blacklisted_sentence_ids = {0, 2, 4}; + expected_locations = {2, 3, 6, 7}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklisted_sentence_ids)); + + blacklisted_sentence_ids = {1, 3}; + expected_locations = {0, 1, 4, 5, 8, 9}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklisted_sentence_ids)); +} + +} +} // namespace extractor diff --git a/extractor/mocks/mock_matchings_sampler.h b/extractor/mocks/mock_matchings_sampler.h new file mode 100644 index 00000000..de2009c3 --- /dev/null +++ b/extractor/mocks/mock_matchings_sampler.h @@ -0,0 +1,15 @@ +#include + +#include "phrase_location.h" +#include "matchings_sampler.h" + +namespace extractor { + +class MockMatchingsSampler : public MatchingsSampler { + public: + MOCK_CONST_METHOD2(Sample, PhraseLocation( + const PhraseLocation& location, + const unordered_set& blacklisted_sentence_ids)); +}; + +} // namespace extractor diff --git a/extractor/mocks/mock_suffix_array_sampler.h b/extractor/mocks/mock_suffix_array_sampler.h new file mode 100644 index 00000000..d799b969 --- /dev/null +++ b/extractor/mocks/mock_suffix_array_sampler.h @@ -0,0 +1,15 @@ +#include + +#include "phrase_location.h" +#include "suffix_array_sampler.h" + +namespace extractor { + +class MockSuffixArraySampler : public SuffixArrayRangeSampler { + public: + MOCK_CONST_METHOD2(Sample, PhraseLocation( + const PhraseLocation& location, + const unordered_set& blacklisted_sentence_ids)); +}; + +} // namespace extractor diff --git a/extractor/phrase_location.cc b/extractor/phrase_location.cc index 13140cac..2c367893 100644 --- a/extractor/phrase_location.cc +++ b/extractor/phrase_location.cc @@ -1,5 +1,7 @@ #include "phrase_location.h" +#include + namespace extractor { PhraseLocation::PhraseLocation(int sa_low, int sa_high) : diff --git a/extractor/phrase_location_sampler.cc b/extractor/phrase_location_sampler.cc new file mode 100644 index 00000000..a2eec105 --- /dev/null +++ b/extractor/phrase_location_sampler.cc @@ -0,0 +1,34 @@ +#include "phrase_location_sampler.h" + +#include "matchings_sampler.h" +#include "phrase_location.h" +#include "suffix_array.h" +#include "suffix_array_sampler.h" + +namespace extractor { + +PhraseLocationSampler::PhraseLocationSampler( + shared_ptr suffix_array, int max_samples) { + matchings_sampler = make_shared( + suffix_array->GetData(), max_samples); + suffix_array_sampler = make_shared( + suffix_array, max_samples); +} + +PhraseLocationSampler::PhraseLocationSampler( + shared_ptr matchings_sampler, + shared_ptr suffix_array_sampler) : + matchings_sampler(matchings_sampler), + suffix_array_sampler(suffix_array_sampler) {} + +PhraseLocation PhraseLocationSampler::Sample( + const PhraseLocation& location, + const unordered_set& blacklisted_sentence_ids) const { + if (location.matchings == NULL) { + return suffix_array_sampler->Sample(location, blacklisted_sentence_ids); + } else { + return matchings_sampler->Sample(location, blacklisted_sentence_ids); + } +} + +} // namespace extractor diff --git a/extractor/phrase_location_sampler.h b/extractor/phrase_location_sampler.h new file mode 100644 index 00000000..0e88335e --- /dev/null +++ b/extractor/phrase_location_sampler.h @@ -0,0 +1,35 @@ +#ifndef _PHRASE_LOCATION_SAMPLER_H_ +#define _PHRASE_LOCATION_SAMPLER_H_ + +#include + +#include "sampler.h" + +namespace extractor { + +class MatchingsSampler; +class PhraseLocation; +class SuffixArray; +class SuffixArrayRangeSampler; + +class PhraseLocationSampler : public Sampler { + public: + PhraseLocationSampler(shared_ptr suffix_array, int max_samples); + + // For testing only. + PhraseLocationSampler( + shared_ptr matchings_sampler, + shared_ptr suffix_array_sampler); + + PhraseLocation Sample( + const PhraseLocation& location, + const unordered_set& blacklisted_sentence_ids) const; + + private: + shared_ptr matchings_sampler; + shared_ptr suffix_array_sampler; +}; + +} // namespace extractor + +#endif diff --git a/extractor/phrase_location_sampler_test.cc b/extractor/phrase_location_sampler_test.cc new file mode 100644 index 00000000..e7520ce7 --- /dev/null +++ b/extractor/phrase_location_sampler_test.cc @@ -0,0 +1,50 @@ +#include + +#include + +#include "mocks/mock_matchings_sampler.h" +#include "mocks/mock_suffix_array_sampler.h" +#include "phrase_location.h" +#include "phrase_location_sampler.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace { + +class MatchingsSamplerTest : public Test { + protected: + virtual void SetUp() { + matchings_sampler = make_shared(); + suffix_array_sampler = make_shared(); + + sampler = make_shared( + matchings_sampler, suffix_array_sampler); + } + + shared_ptr matchings_sampler; + shared_ptr suffix_array_sampler; + shared_ptr sampler; +}; + +TEST_F(MatchingsSamplerTest, TestSuffixArrayRange) { + vector locations = {0, 1, 2, 3}; + PhraseLocation location(0, 3), result(locations, 2); + unordered_set blacklisted_sentence_ids; + EXPECT_CALL(*suffix_array_sampler, Sample(location, blacklisted_sentence_ids)) + .WillOnce(Return(result)); + EXPECT_EQ(result, sampler->Sample(location, blacklisted_sentence_ids)); +} + +TEST_F(MatchingsSamplerTest, TestMatchings) { + vector locations = {0, 1, 2, 3}; + PhraseLocation location(locations, 2), result(locations, 2); + unordered_set blacklisted_sentence_ids; + EXPECT_CALL(*matchings_sampler, Sample(location, blacklisted_sentence_ids)) + .WillOnce(Return(result)); + EXPECT_EQ(result, sampler->Sample(location, blacklisted_sentence_ids)); +} + +} +} // namespace extractor diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc index b79daae3..3e58e2a9 100644 --- a/extractor/precomputation.cc +++ b/extractor/precomputation.cc @@ -91,7 +91,6 @@ vector> Precomputation::FindMostFrequentPatterns( } } - shared_ptr data_array = suffix_array->GetData(); // Extract the most frequent patterns. vector> frequent_patterns; while (frequent_patterns.size() < num_frequent_patterns && !heap.empty()) { @@ -99,7 +98,7 @@ vector> Precomputation::FindMostFrequentPatterns( int len = heap.top().second.second; heap.pop(); - vector pattern = data_array->GetWordIds(start, len); + vector pattern(data.begin() + start, data.begin() + start + len); if (find(pattern.begin(), pattern.end(), DataArray::END_OF_LINE) == pattern.end()) { frequent_patterns.push_back(pattern); diff --git a/extractor/precomputation_test.cc b/extractor/precomputation_test.cc index d5f5ef63..3a98ce05 100644 --- a/extractor/precomputation_test.cc +++ b/extractor/precomputation_test.cc @@ -94,7 +94,7 @@ TEST_F(PrecomputationTest, TestCollocations) { EXPECT_TRUE(precomputation.Contains(key)); EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); - key = {2, -1, 2, -1, 2}; + key = {2, -1, 2, -2, 2}; expected_value = {1, 5, 8, 5, 8, 11}; EXPECT_TRUE(precomputation.Contains(key)); EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc index 5b66f685..18a60695 100644 --- a/extractor/rule_factory.cc +++ b/extractor/rule_factory.cc @@ -12,6 +12,7 @@ #include "phrase_builder.h" #include "rule.h" #include "rule_extractor.h" +#include "phrase_location_sampler.h" #include "sampler.h" #include "scorer.h" #include "suffix_array.h" @@ -68,7 +69,8 @@ HieroCachingRuleFactory::HieroCachingRuleFactory( target_data_array, alignment, phrase_builder, scorer, vocabulary, max_rule_span, min_gap_size, max_nonterminals, max_rule_symbols, true, false, require_tight_phrases); - sampler = make_shared(source_suffix_array, max_samples); + sampler = make_shared( + source_suffix_array, max_samples); } HieroCachingRuleFactory::HieroCachingRuleFactory( diff --git a/extractor/sampler.cc b/extractor/sampler.cc deleted file mode 100644 index 887aaec1..00000000 --- a/extractor/sampler.cc +++ /dev/null @@ -1,78 +0,0 @@ -#include "sampler.h" - -#include "phrase_location.h" -#include "suffix_array.h" - -namespace extractor { - -Sampler::Sampler(shared_ptr suffix_array, int max_samples) : - suffix_array(suffix_array), max_samples(max_samples) {} - -Sampler::Sampler() {} - -Sampler::~Sampler() {} - -PhraseLocation Sampler::Sample( - const PhraseLocation& location, - const unordered_set& blacklisted_sentence_ids) const { - shared_ptr source_data_array = suffix_array->GetData(); - vector sample; - int num_subpatterns; - if (location.matchings == NULL) { - // Sample suffix array range. - num_subpatterns = 1; - int low = location.sa_low, high = location.sa_high; - double step = max(1.0, (double) (high - low) / max_samples); - double i = low, last = i - 1; - while (sample.size() < max_samples && i < high) { - int x = suffix_array->GetSuffix(Round(i)); - int id = source_data_array->GetSentenceId(x); - bool found = false; - if (blacklisted_sentence_ids.count(id)) { - for (int backoff_step = 1; backoff_step <= step; ++backoff_step) { - double j = i - backoff_step; - x = suffix_array->GetSuffix(Round(j)); - id = source_data_array->GetSentenceId(x); - if (x >= 0 && j > last && !blacklisted_sentence_ids.count(id)) { - found = true; - last = i; - break; - } - double k = i + backoff_step; - x = suffix_array->GetSuffix(Round(k)); - id = source_data_array->GetSentenceId(x); - if (k < min(i+step, (double) high) && - !blacklisted_sentence_ids.count(id)) { - found = true; - last = k; - break; - } - } - } else { - found = true; - last = i; - } - if (found) sample.push_back(x); - i += step; - } - } else { - // Sample vector of occurrences. - num_subpatterns = location.num_subpatterns; - int num_matchings = location.matchings->size() / num_subpatterns; - double step = max(1.0, (double) num_matchings / max_samples); - for (double i = 0, num_samples = 0; - i < num_matchings && num_samples < max_samples; - i += step, ++num_samples) { - int start = Round(i) * num_subpatterns; - sample.insert(sample.end(), location.matchings->begin() + start, - location.matchings->begin() + start + num_subpatterns); - } - } - return PhraseLocation(sample, num_subpatterns); -} - -int Sampler::Round(double x) const { - return x + 0.5; -} - -} // namespace extractor diff --git a/extractor/sampler.h b/extractor/sampler.h index bd8a5876..3c4e37f1 100644 --- a/extractor/sampler.h +++ b/extractor/sampler.h @@ -4,38 +4,20 @@ #include #include -#include "data_array.h" - using namespace std; namespace extractor { class PhraseLocation; -class SuffixArray; /** - * Provides uniform sampling for a PhraseLocation. + * Base sampler class. */ class Sampler { public: - Sampler(shared_ptr suffix_array, int max_samples); - - virtual ~Sampler(); - - // Samples uniformly at most max_samples phrase occurrences. virtual PhraseLocation Sample( const PhraseLocation& location, - const unordered_set& blacklisted_sentence_ids) const; - - protected: - Sampler(); - - private: - // Round floating point number to the nearest integer. - int Round(double x) const; - - shared_ptr suffix_array; - int max_samples; + const unordered_set& blacklisted_sentence_ids) const = 0; }; } // namespace extractor diff --git a/extractor/sampler_test.cc b/extractor/sampler_test.cc deleted file mode 100644 index 14e72780..00000000 --- a/extractor/sampler_test.cc +++ /dev/null @@ -1,92 +0,0 @@ -#include - -#include - -#include "mocks/mock_suffix_array.h" -#include "mocks/mock_data_array.h" -#include "phrase_location.h" -#include "sampler.h" - -using namespace std; -using namespace ::testing; - -namespace extractor { -namespace { - -class SamplerTest : public Test { - protected: - virtual void SetUp() { - source_data_array = make_shared(); - EXPECT_CALL(*source_data_array, GetSentenceId(_)).WillRepeatedly(Return(9999)); - suffix_array = make_shared(); - EXPECT_CALL(*suffix_array, GetData()) - .WillRepeatedly(Return(source_data_array)); - for (int i = 0; i < 10; ++i) { - EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i)); - } - } - - shared_ptr suffix_array; - shared_ptr sampler; - shared_ptr source_data_array; -}; - -TEST_F(SamplerTest, TestSuffixArrayRange) { - PhraseLocation location(0, 10); - unordered_set blacklist; - - sampler = make_shared(suffix_array, 1); - vector expected_locations = {0}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), - sampler->Sample(location, blacklist)); - return; - - sampler = make_shared(suffix_array, 2); - expected_locations = {0, 5}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), - sampler->Sample(location, blacklist)); - - sampler = make_shared(suffix_array, 3); - expected_locations = {0, 3, 7}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), - sampler->Sample(location, blacklist)); - - sampler = make_shared(suffix_array, 4); - expected_locations = {0, 3, 5, 8}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), - sampler->Sample(location, blacklist)); - - sampler = make_shared(suffix_array, 100); - expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), - sampler->Sample(location, blacklist)); -} - -TEST_F(SamplerTest, TestSubstringsSample) { - vector locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; - unordered_set blacklist; - PhraseLocation location(locations, 2); - - sampler = make_shared(suffix_array, 1); - vector expected_locations = {0, 1}; - EXPECT_EQ(PhraseLocation(expected_locations, 2), - sampler->Sample(location, blacklist)); - - sampler = make_shared(suffix_array, 2); - expected_locations = {0, 1, 6, 7}; - EXPECT_EQ(PhraseLocation(expected_locations, 2), - sampler->Sample(location, blacklist)); - - sampler = make_shared(suffix_array, 3); - expected_locations = {0, 1, 4, 5, 6, 7}; - EXPECT_EQ(PhraseLocation(expected_locations, 2), - sampler->Sample(location, blacklist)); - - sampler = make_shared(suffix_array, 7); - expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; - EXPECT_EQ(PhraseLocation(expected_locations, 2), - sampler->Sample(location, blacklist)); -} - -} // namespace -} // namespace extractor diff --git a/extractor/sampler_test_blacklist.cc b/extractor/sampler_test_blacklist.cc deleted file mode 100644 index 3305b990..00000000 --- a/extractor/sampler_test_blacklist.cc +++ /dev/null @@ -1,102 +0,0 @@ -#include - -#include - -#include "mocks/mock_suffix_array.h" -#include "mocks/mock_data_array.h" -#include "phrase_location.h" -#include "sampler.h" - -using namespace std; -using namespace ::testing; - -namespace extractor { -namespace { - -class SamplerTestBlacklist : public Test { - protected: - virtual void SetUp() { - source_data_array = make_shared(); - for (int i = 0; i < 10; ++i) { - EXPECT_CALL(*source_data_array, GetSentenceId(i)).WillRepeatedly(Return(i)); - } - for (int i = -10; i < 0; ++i) { - EXPECT_CALL(*source_data_array, GetSentenceId(i)).WillRepeatedly(Return(0)); - } - suffix_array = make_shared(); - for (int i = -10; i < 10; ++i) { - EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i)); - } - } - - shared_ptr suffix_array; - shared_ptr sampler; - shared_ptr source_data_array; -}; - -TEST_F(SamplerTestBlacklist, TestSuffixArrayRange) { - PhraseLocation location(0, 10); - unordered_set blacklist; - vector expected_locations; - - blacklist.insert(0); - sampler = make_shared(suffix_array, 1); - expected_locations = {1}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); - blacklist.clear(); - - for (int i = 0; i < 9; i++) { - blacklist.insert(i); - } - sampler = make_shared(suffix_array, 1); - expected_locations = {9}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); - blacklist.clear(); - - blacklist.insert(0); - blacklist.insert(5); - sampler = make_shared(suffix_array, 2); - expected_locations = {1, 4}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); - blacklist.clear(); - - blacklist.insert(0); - blacklist.insert(1); - blacklist.insert(2); - blacklist.insert(3); - sampler = make_shared(suffix_array, 2); - expected_locations = {4, 5}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); - blacklist.clear(); - - blacklist.insert(0); - blacklist.insert(3); - blacklist.insert(7); - sampler = make_shared(suffix_array, 3); - expected_locations = {1, 2, 6}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); - blacklist.clear(); - - blacklist.insert(0); - blacklist.insert(3); - blacklist.insert(5); - blacklist.insert(8); - sampler = make_shared(suffix_array, 4); - expected_locations = {1, 2, 4, 7}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); - blacklist.clear(); - - blacklist.insert(0); - sampler = make_shared(suffix_array, 100); - expected_locations = {1, 2, 3, 4, 5, 6, 7, 8, 9}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); - blacklist.clear(); - - blacklist.insert(9); - sampler = make_shared(suffix_array, 100); - expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); -} - -} // namespace -} // namespace extractor diff --git a/extractor/suffix_array_sampler.cc b/extractor/suffix_array_sampler.cc new file mode 100644 index 00000000..4a4ced34 --- /dev/null +++ b/extractor/suffix_array_sampler.cc @@ -0,0 +1,40 @@ +#include "suffix_array_sampler.h" + +#include "data_array.h" +#include "phrase_location.h" +#include "suffix_array.h" + +namespace extractor { + +SuffixArrayRangeSampler::SuffixArrayRangeSampler( + shared_ptr source_suffix_array, int max_samples) : + BackoffSampler(source_suffix_array->GetData(), max_samples), + source_suffix_array(source_suffix_array) {} + +SuffixArrayRangeSampler::SuffixArrayRangeSampler() {} + +int SuffixArrayRangeSampler::GetNumSubpatterns(const PhraseLocation&) const { + return 1; +} + +int SuffixArrayRangeSampler::GetRangeLow( + const PhraseLocation& location) const { + return location.sa_low; +} + +int SuffixArrayRangeSampler::GetRangeHigh( + const PhraseLocation& location) const { + return location.sa_high; +} + +int SuffixArrayRangeSampler::GetPosition( + const PhraseLocation&, int position) const { + return source_suffix_array->GetSuffix(position); +} + +void SuffixArrayRangeSampler::AppendMatching( + vector& samples, int index, const PhraseLocation&) const { + samples.push_back(source_suffix_array->GetSuffix(index)); +} + +} // namespace extractor diff --git a/extractor/suffix_array_sampler.h b/extractor/suffix_array_sampler.h new file mode 100644 index 00000000..bb3c2653 --- /dev/null +++ b/extractor/suffix_array_sampler.h @@ -0,0 +1,34 @@ +#ifndef _SUFFIX_ARRAY_SAMPLER_H_ +#define _SUFFIX_ARRAY_SAMPLER_H_ + +#include "backoff_sampler.h" + +namespace extractor { + +class SuffixArray; + +class SuffixArrayRangeSampler : public BackoffSampler { + public: + SuffixArrayRangeSampler(shared_ptr suffix_array, + int max_samples); + + SuffixArrayRangeSampler(); + + private: + int GetNumSubpatterns(const PhraseLocation& location) const; + + int GetRangeLow(const PhraseLocation& location) const; + + int GetRangeHigh(const PhraseLocation& location) const; + + int GetPosition(const PhraseLocation& location, int index) const; + + void AppendMatching(vector& samples, int index, + const PhraseLocation& location) const; + + shared_ptr source_suffix_array; +}; + +} // namespace extractor + +#endif diff --git a/extractor/suffix_array_sampler_test.cc b/extractor/suffix_array_sampler_test.cc new file mode 100644 index 00000000..4b88c027 --- /dev/null +++ b/extractor/suffix_array_sampler_test.cc @@ -0,0 +1,114 @@ +#include + +#include + +#include "mocks/mock_data_array.h" +#include "mocks/mock_suffix_array.h" +#include "suffix_array_sampler.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace { + +class SuffixArraySamplerTest : public Test { + protected: + virtual void SetUp() { + data_array = make_shared(); + for (int i = 0; i < 10; ++i) { + EXPECT_CALL(*data_array, GetSentenceId(i)).WillRepeatedly(Return(i)); + } + + suffix_array = make_shared(); + EXPECT_CALL(*suffix_array, GetData()).WillRepeatedly(Return(data_array)); + for (int i = 0; i < 10; ++i) { + EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i)); + } + } + + shared_ptr data_array; + shared_ptr suffix_array; +}; + +TEST_F(SuffixArraySamplerTest, TestSample) { + PhraseLocation location(0, 10); + unordered_set blacklisted_sentence_ids; + + SuffixArrayRangeSampler sampler(suffix_array, 1); + vector expected_locations = {0}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler.Sample(location, blacklisted_sentence_ids)); + + sampler = SuffixArrayRangeSampler(suffix_array, 2); + expected_locations = {0, 5}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler.Sample(location, blacklisted_sentence_ids)); + + sampler = SuffixArrayRangeSampler(suffix_array, 3); + expected_locations = {0, 3, 7}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler.Sample(location, blacklisted_sentence_ids)); + + sampler = SuffixArrayRangeSampler(suffix_array, 4); + expected_locations = {0, 3, 5, 8}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler.Sample(location, blacklisted_sentence_ids)); + + sampler = SuffixArrayRangeSampler(suffix_array, 100); + expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler.Sample(location, blacklisted_sentence_ids)); +} + +TEST_F(SuffixArraySamplerTest, TestBackoffSample) { + PhraseLocation location(0, 10); + + SuffixArrayRangeSampler sampler(suffix_array, 1); + unordered_set blacklisted_sentence_ids = {0}; + vector expected_locations = {1}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler.Sample(location, blacklisted_sentence_ids)); + + blacklisted_sentence_ids = {0, 1, 2, 3, 4, 5, 6, 7, 8}; + expected_locations = {9}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler.Sample(location, blacklisted_sentence_ids)); + + sampler = SuffixArrayRangeSampler(suffix_array, 2); + blacklisted_sentence_ids = {0, 5}; + expected_locations = {1, 4}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler.Sample(location, blacklisted_sentence_ids)); + + blacklisted_sentence_ids = {0, 1, 2, 3}; + expected_locations = {4, 5}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler.Sample(location, blacklisted_sentence_ids)); + + sampler = SuffixArrayRangeSampler(suffix_array, 3); + blacklisted_sentence_ids = {0, 3, 7}; + expected_locations = {1, 2, 6}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler.Sample(location, blacklisted_sentence_ids)); + + sampler = SuffixArrayRangeSampler(suffix_array, 4); + blacklisted_sentence_ids = {0, 3, 5, 8}; + expected_locations = {1, 2, 4, 7}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler.Sample(location, blacklisted_sentence_ids)); + + sampler = SuffixArrayRangeSampler(suffix_array, 100); + blacklisted_sentence_ids = {0}; + expected_locations = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler.Sample(location, blacklisted_sentence_ids)); + + blacklisted_sentence_ids = {9}; + expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler.Sample(location, blacklisted_sentence_ids)); +} + +} +} // namespace extractor -- cgit v1.2.3