Diffstat (limited to 'extractor')
29 files changed, 544 insertions, 222 deletions
diff --git a/extractor/Makefile.am b/extractor/Makefile.am
index 65a3d436..7825012c 100644
--- a/extractor/Makefile.am
+++ b/extractor/Makefile.am
@@ -1,5 +1,5 @@
-bin_PROGRAMS = compile run_extractor
+bin_PROGRAMS = compile run_extractor extract
 
 if HAVE_CXX11
 
@@ -24,7 +24,8 @@ EXTRA_PROGRAMS = alignment_test \
   scorer_test \
   suffix_array_test \
   target_phrase_extractor_test \
-  translation_table_test
+  translation_table_test \
+  vocabulary_test
 
 if HAVE_GTEST
 RUNNABLE_TESTS = alignment_test \
@@ -48,12 +49,14 @@ if HAVE_GTEST
   scorer_test \
   suffix_array_test \
   target_phrase_extractor_test \
-  translation_table_test
+  translation_table_test \
+  vocabulary_test
 endif
 
 noinst_PROGRAMS = $(RUNNABLE_TESTS)
 
-TESTS = $(RUNNABLE_TESTS)
+# TESTS = $(RUNNABLE_TESTS)
+TESTS = vocabulary_test
 
 alignment_test_SOURCES = alignment_test.cc
 alignment_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a
@@ -99,44 +102,17 @@ target_phrase_extractor_test_SOURCES = target_phrase_extractor_test.cc
 target_phrase_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
 translation_table_test_SOURCES = translation_table_test.cc
 translation_table_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+vocabulary_test_SOURCES = vocabulary_test.cc
+vocabulary_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a
 
-noinst_LIBRARIES = libextractor.a libcompile.a
+noinst_LIBRARIES = libextractor.a
 
 compile_SOURCES = compile.cc
-compile_LDADD = libcompile.a
+compile_LDADD = libextractor.a
 run_extractor_SOURCES = run_extractor.cc
 run_extractor_LDADD = libextractor.a
-
-libcompile_a_SOURCES = \
-  alignment.cc \
-  data_array.cc \
-  phrase_location.cc \
-  precomputation.cc \
-  suffix_array.cc \
-  time_util.cc \
-  translation_table.cc \
-  alignment.h \
-  data_array.h \
-  fast_intersector.h \
-  grammar.h \
-  grammar_extractor.h \
-  matchings_finder.h \
-  matchings_trie.h \
-  phrase.h \
-  phrase_builder.h \
-  phrase_location.h \
-  precomputation.h \
-  rule.h \
-  rule_extractor.h \
-  rule_extractor_helper.h \
-  rule_factory.h \
-  sampler.h \
-  scorer.h \
-  suffix_array.h \
-  target_phrase_extractor.h \
-  time_util.h \
-  translation_table.h \
-  vocabulary.h
+extract_SOURCES = extract.cc
+extract_LDADD = libextractor.a
 
 libextractor_a_SOURCES = \
   alignment.cc \
diff --git a/extractor/compile.cc b/extractor/compile.cc
index 0d62757e..3ee668ce 100644
--- a/extractor/compile.cc
+++ b/extractor/compile.cc
@@ -30,6 +30,8 @@ int main(int argc, char** argv) {
     ("bitext,b", po::value<string>(), "Parallel text (source ||| target)")
     ("alignment,a", po::value<string>()->required(), "Bitext word alignment")
     ("output,o", po::value<string>()->required(), "Output path")
+    ("config,c", po::value<string>()->required(),
+        "Path where the config file will be generated")
     ("frequent", po::value<int>()->default_value(100),
         "Number of precomputed frequent patterns")
     ("super_frequent", po::value<int>()->default_value(10),
@@ -82,8 +84,12 @@ int main(int argc, char** argv) {
     target_data_array = make_shared<DataArray>(vm["target"].as<string>());
   }
 
+  ofstream config_stream(vm["config"].as<string>());
+
   Clock::time_point start_write = Clock::now();
-  ofstream target_fstream((output_dir / fs::path("target.bin")).string());
+  string target_path = (output_dir / fs::path("target.bin")).string();
+  config_stream << "target = " << target_path << endl;
+  ofstream target_fstream(target_path);
   ar::binary_oarchive target_stream(target_fstream);
   target_stream << *target_data_array;
 Clock::time_point stop_write = Clock::now();
@@ -100,7 +106,9 @@ int main(int argc, char** argv) {
       make_shared<SuffixArray>(source_data_array);
 
   start_write = Clock::now();
-  ofstream source_fstream((output_dir / fs::path("source.bin")).string());
+  string source_path = (output_dir / fs::path("source.bin")).string();
+  config_stream << "source = " << source_path << endl;
+  ofstream source_fstream(source_path);
   ar::binary_oarchive output_stream(source_fstream);
   output_stream << *source_suffix_array;
   stop_write = Clock::now();
@@ -116,7 +124,9 @@ int main(int argc, char** argv) {
       make_shared<Alignment>(vm["alignment"].as<string>());
 
   start_write = Clock::now();
-  ofstream alignment_fstream((output_dir / fs::path("alignment.bin")).string());
+  string alignment_path = (output_dir / fs::path("alignment.bin")).string();
+  config_stream << "alignment = " << alignment_path << endl;
+  ofstream alignment_fstream(alignment_path);
   ar::binary_oarchive alignment_stream(alignment_fstream);
   alignment_stream << *alignment;
   stop_write = Clock::now();
@@ -126,7 +136,7 @@ int main(int argc, char** argv) {
   cerr << "Reading alignment took "
        << GetDuration(start_time, stop_time) << " seconds" << endl;
 
-  shared_ptr<Vocabulary> vocabulary;
+  shared_ptr<Vocabulary> vocabulary = make_shared<Vocabulary>();
 
   start_time = Clock::now();
   cerr << "Precomputing collocations..." << endl;
@@ -142,9 +152,17 @@ int main(int argc, char** argv) {
       vm["min_frequency"].as<int>());
 
   start_write = Clock::now();
-  ofstream precomp_fstream((output_dir / fs::path("precomp.bin")).string());
+  string precomputation_path = (output_dir / fs::path("precomp.bin")).string();
+  config_stream << "precomputation = " << precomputation_path << endl;
+  ofstream precomp_fstream(precomputation_path);
   ar::binary_oarchive precomp_stream(precomp_fstream);
   precomp_stream << precomputation;
+
+  string vocabulary_path = (output_dir / fs::path("vocab.bin")).string();
+  config_stream << "vocabulary = " << vocabulary_path << endl;
+  ofstream vocab_fstream(vocabulary_path);
+  ar::binary_oarchive vocab_stream(vocab_fstream);
+  vocab_stream << *vocabulary;
   stop_write = Clock::now();
   write_duration += GetDuration(start_write, stop_write);
 
@@ -157,7 +175,9 @@ int main(int argc, char** argv) {
   TranslationTable table(source_data_array, target_data_array, alignment);
 
   start_write = Clock::now();
-  ofstream table_fstream((output_dir / fs::path("bilex.bin")).string());
+  string table_path = (output_dir / fs::path("bilex.bin")).string();
+  config_stream << "ttable = " << table_path << endl;
+  ofstream table_fstream(table_path);
   ar::binary_oarchive table_stream(table_fstream);
   table_stream << table;
   stop_write = Clock::now();
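compile now records, for each serialized structure, the path it was written to, producing a config file that extract later reads back through parse_config_file. Assuming a hypothetical output directory output/, the generated config would contain:

    target = output/target.bin
    source = output/source.bin
    alignment = output/alignment.bin
    precomputation = output/precomp.bin
    vocabulary = output/vocab.bin
    ttable = output/bilex.bin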
diff --git a/extractor/data_array.cc b/extractor/data_array.cc
index dacc4283..9612aa8a 100644
--- a/extractor/data_array.cc
+++ b/extractor/data_array.cc
@@ -127,10 +127,6 @@ int DataArray::GetSentenceId(int position) const {
   return sentence_id[position];
 }
 
-bool DataArray::HasWord(const string& word) const {
-  return word2id.count(word);
-}
-
 int DataArray::GetWordId(const string& word) const {
   auto result = word2id.find(word);
   return result == word2id.end() ? -1 : result->second;
diff --git a/extractor/data_array.h b/extractor/data_array.h
index e3823d18..b96901d1 100644
--- a/extractor/data_array.h
+++ b/extractor/data_array.h
@@ -73,9 +73,6 @@ class DataArray {
   // Returns the number of distinct words in the data array.
   virtual int GetVocabularySize() const;
 
-  // Returns whether a word has ever been observed in the data array.
-  virtual bool HasWord(const string& word) const;
-
   // Returns the word id for a given word or -1 if the word has never been
   // observed.
   virtual int GetWordId(const string& word) const;
diff --git a/extractor/data_array_test.cc b/extractor/data_array_test.cc
index 7b085cd9..99f79d91 100644
--- a/extractor/data_array_test.cc
+++ b/extractor/data_array_test.cc
@@ -70,16 +70,12 @@ TEST_F(DataArrayTest, TestSubstrings) {
 
 TEST_F(DataArrayTest, TestVocabulary) {
   EXPECT_EQ(9, source_data.GetVocabularySize());
-  EXPECT_TRUE(source_data.HasWord("mere"));
   EXPECT_EQ(4, source_data.GetWordId("mere"));
   EXPECT_EQ("mere", source_data.GetWord(4));
-  EXPECT_FALSE(source_data.HasWord("banane"));
 
   EXPECT_EQ(11, target_data.GetVocabularySize());
-  EXPECT_TRUE(target_data.HasWord("apples"));
   EXPECT_EQ(4, target_data.GetWordId("apples"));
   EXPECT_EQ("apples", target_data.GetWord(4));
-  EXPECT_FALSE(target_data.HasWord("bananas"));
 }
 
 TEST_F(DataArrayTest, TestSentenceData) {
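Dropping HasWord in favor of a -1 sentinel from GetWordId halves the hash lookups at every call site that previously probed the dictionary twice. A minimal sketch of the idiom (standalone, simplified from DataArray):

    #include <string>
    #include <unordered_map>

    // Before: if (dict.HasWord(w)) id = dict.GetWordId(w);  -- two lookups.
    // After: a single find(), with -1 as the "never observed" sentinel.
    int GetWordId(const std::unordered_map<std::string, int>& word2id,
                  const std::string& word) {
      auto it = word2id.find(word);  // one hash lookup
      return it == word2id.end() ? -1 : it->second;
    }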
diff --git a/extractor/extract.cc b/extractor/extract.cc
new file mode 100644
index 00000000..387cbe9b
--- /dev/null
+++ b/extractor/extract.cc
@@ -0,0 +1,253 @@
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <boost/archive/binary_iarchive.hpp>
+#include <boost/filesystem.hpp>
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+#include <omp.h>
+
+#include "alignment.h"
+#include "data_array.h"
+#include "features/count_source_target.h"
+#include "features/feature.h"
+#include "features/is_source_singleton.h"
+#include "features/is_source_target_singleton.h"
+#include "features/max_lex_source_given_target.h"
+#include "features/max_lex_target_given_source.h"
+#include "features/sample_source_count.h"
+#include "features/target_given_source_coherent.h"
+#include "grammar.h"
+#include "grammar_extractor.h"
+#include "precomputation.h"
+#include "rule.h"
+#include "scorer.h"
+#include "suffix_array.h"
+#include "time_util.h"
+#include "translation_table.h"
+#include "vocabulary.h"
+
+namespace ar = boost::archive;
+namespace fs = boost::filesystem;
+namespace po = boost::program_options;
+using namespace extractor;
+using namespace features;
+using namespace std;
+
+// Returns the file path in which a given grammar should be written.
+fs::path GetGrammarFilePath(const fs::path& grammar_path, int file_number) {
+  string file_name = "grammar." + to_string(file_number);
+  return grammar_path / file_name;
+}
+
+int main(int argc, char** argv) {
+  po::options_description general_options("General options");
+  int max_threads = 1;
+  #pragma omp parallel
+  max_threads = omp_get_num_threads();
+  string threads_option = "Number of threads used for grammar extraction "
+      "max(" + to_string(max_threads) + ")";
+  general_options.add_options()
+    ("threads,t", po::value<int>()->required()->default_value(1),
+        threads_option.c_str())
+    ("grammars,g", po::value<string>()->required(), "Grammars output path")
+    ("max_rule_span", po::value<int>()->default_value(15),
+        "Maximum rule span")
+    ("max_rule_symbols", po::value<int>()->default_value(5),
+        "Maximum number of symbols (terminals + nonterminals) in a rule")
+    ("min_gap_size", po::value<int>()->default_value(1), "Minimum gap size")
+    ("max_nonterminals", po::value<int>()->default_value(2),
+        "Maximum number of nonterminals in a rule")
+    ("max_samples", po::value<int>()->default_value(300),
+        "Maximum number of samples")
+    ("tight_phrases", po::value<bool>()->default_value(true),
+        "False if phrases may be loose (better, but slower)")
+    ("leave_one_out", po::value<bool>()->zero_tokens(),
+        "do leave-one-out estimation of grammars "
+        "(e.g. for extracting grammars for the training set)");
+
+  po::options_description cmdline_options("Command line options");
+  cmdline_options.add_options()
+    ("help", "Show available options")
+    ("config,c", po::value<string>()->required(), "Path to config file");
+  cmdline_options.add(general_options);
+
+  po::options_description config_options("Config file options");
+  config_options.add_options()
+    ("target", po::value<string>()->required(),
+        "Path to target data file in binary format")
+    ("source", po::value<string>()->required(),
+        "Path to source suffix array file in binary format")
+    ("alignment", po::value<string>()->required(),
+        "Path to alignment file in binary format")
+    ("precomputation", po::value<string>()->required(),
+        "Path to precomputation file in binary format")
+    ("vocabulary", po::value<string>()->required(),
+        "Path to vocabulary file in binary format")
+    ("ttable", po::value<string>()->required(),
+        "Path to translation table in binary format");
+  config_options.add(general_options);
+
+  po::variables_map vm;
+  po::store(po::parse_command_line(argc, argv, cmdline_options), vm);
+  if (vm.count("help")) {
+    po::options_description all_options;
+    all_options.add(cmdline_options).add(config_options);
+    cout << all_options << endl;
+    return 0;
+  }
+
+  po::notify(vm);
+
+  ifstream config_stream(vm["config"].as<string>());
+  po::store(po::parse_config_file(config_stream, config_options), vm);
+  po::notify(vm);
+
+  int num_threads = vm["threads"].as<int>();
+  cerr << "Grammar extraction will use " << num_threads << " threads." << endl;
+
+  Clock::time_point read_start_time = Clock::now();
+
+  Clock::time_point start_time = Clock::now();
+  cerr << "Reading target data in binary format..." << endl;
+  shared_ptr<DataArray> target_data_array = make_shared<DataArray>();
+  ifstream target_fstream(vm["target"].as<string>());
+  ar::binary_iarchive target_stream(target_fstream);
+  target_stream >> *target_data_array;
+  Clock::time_point end_time = Clock::now();
+  cerr << "Reading target data took " << GetDuration(start_time, end_time)
+       << " seconds" << endl;
+
+  start_time = Clock::now();
+  cerr << "Reading source suffix array in binary format..." << endl;
+  shared_ptr<SuffixArray> source_suffix_array = make_shared<SuffixArray>();
+  ifstream source_fstream(vm["source"].as<string>());
+  ar::binary_iarchive source_stream(source_fstream);
+  source_stream >> *source_suffix_array;
+  end_time = Clock::now();
+  cerr << "Reading source suffix array took "
+       << GetDuration(start_time, end_time) << " seconds" << endl;
+
+  start_time = Clock::now();
+  cerr << "Reading alignment in binary format..." << endl;
+  shared_ptr<Alignment> alignment = make_shared<Alignment>();
+  ifstream alignment_fstream(vm["alignment"].as<string>());
+  ar::binary_iarchive alignment_stream(alignment_fstream);
+  alignment_stream >> *alignment;
+  end_time = Clock::now();
+  cerr << "Reading alignment took " << GetDuration(start_time, end_time)
+       << " seconds" << endl;
+
+  start_time = Clock::now();
+  cerr << "Reading precomputation in binary format..." << endl;
+  shared_ptr<Precomputation> precomputation = make_shared<Precomputation>();
+  ifstream precomputation_fstream(vm["precomputation"].as<string>());
+  ar::binary_iarchive precomputation_stream(precomputation_fstream);
+  precomputation_stream >> *precomputation;
+  end_time = Clock::now();
+  cerr << "Reading precomputation took " << GetDuration(start_time, end_time)
+       << " seconds" << endl;
+
+  start_time = Clock::now();
+  cerr << "Reading vocabulary in binary format..." << endl;
+  shared_ptr<Vocabulary> vocabulary = make_shared<Vocabulary>();
+  ifstream vocabulary_fstream(vm["vocabulary"].as<string>());
+  ar::binary_iarchive vocabulary_stream(vocabulary_fstream);
+  vocabulary_stream >> *vocabulary;
+  end_time = Clock::now();
+  cerr << "Reading vocabulary took " << GetDuration(start_time, end_time)
+       << " seconds" << endl;
+
+  start_time = Clock::now();
+  cerr << "Reading translation table in binary format..." << endl;
+  shared_ptr<TranslationTable> table = make_shared<TranslationTable>();
+  ifstream ttable_fstream(vm["ttable"].as<string>());
+  ar::binary_iarchive ttable_stream(ttable_fstream);
+  ttable_stream >> *table;
+  end_time = Clock::now();
+  cerr << "Reading translation table took " << GetDuration(start_time, end_time)
+       << " seconds" << endl;
+
+  Clock::time_point read_end_time = Clock::now();
+  cerr << "Total time spent loading data structures into memory: "
+       << GetDuration(read_start_time, read_end_time) << " seconds" << endl;
+
+  Clock::time_point extraction_start_time = Clock::now();
+  // Features used to score each grammar rule.
+  vector<shared_ptr<Feature>> features = {
+      make_shared<TargetGivenSourceCoherent>(),
+      make_shared<SampleSourceCount>(),
+      make_shared<CountSourceTarget>(),
+      make_shared<MaxLexSourceGivenTarget>(table),
+      make_shared<MaxLexTargetGivenSource>(table),
+      make_shared<IsSourceSingleton>(),
+      make_shared<IsSourceTargetSingleton>()
+  };
+  shared_ptr<Scorer> scorer = make_shared<Scorer>(features);
+
+  GrammarExtractor extractor(
+      source_suffix_array,
+      target_data_array,
+      alignment,
+      precomputation,
+      scorer,
+      vocabulary,
+      vm["min_gap_size"].as<int>(),
+      vm["max_rule_span"].as<int>(),
+      vm["max_nonterminals"].as<int>(),
+      vm["max_rule_symbols"].as<int>(),
+      vm["max_samples"].as<int>(),
+      vm["tight_phrases"].as<bool>());
+
+  // Creates the grammars directory if it doesn't exist.
+  fs::path grammar_path = vm["grammars"].as<string>();
+  if (!fs::is_directory(grammar_path)) {
+    fs::create_directory(grammar_path);
+  }
+
+  // Reads all sentences for which we extract grammar rules (the
+  // parallelization is simplified if we read all sentences upfront).
+  string sentence;
+  vector<string> sentences;
+  while (getline(cin, sentence)) {
+    sentences.push_back(sentence);
+  }
+
+  // Extracts the grammar for each sentence and saves it to a file.
+  vector<string> suffixes(sentences.size());
+  bool leave_one_out = vm.count("leave_one_out");
+  #pragma omp parallel for schedule(dynamic) num_threads(num_threads)
+  for (size_t i = 0; i < sentences.size(); ++i) {
+    string suffix;
+    size_t position = sentences[i].find("|||");
+    if (position != string::npos) {
+      suffix = sentences[i].substr(position);
+      sentences[i] = sentences[i].substr(0, position);
+    }
+    suffixes[i] = suffix;
+
+    unordered_set<int> blacklisted_sentence_ids;
+    if (leave_one_out) {
+      blacklisted_sentence_ids.insert(i);
+    }
+    Grammar grammar = extractor.GetGrammar(
+        sentences[i], blacklisted_sentence_ids);
+    ofstream output(GetGrammarFilePath(grammar_path, i).c_str());
+    output << grammar;
+  }
+
+  for (size_t i = 0; i < sentences.size(); ++i) {
+    cout << "<seg grammar=" << GetGrammarFilePath(grammar_path, i) << " id=\""
+         << i << "\"> " << sentences[i] << " </seg> " << suffixes[i] << endl;
+  }
+
+  Clock::time_point extraction_stop_time = Clock::now();
+  cerr << "Overall extraction step took "
+       << GetDuration(extraction_start_time, extraction_stop_time)
+       << " seconds" << endl;
+
+  return 0;
+}
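The new binary splits the old monolithic workflow in two: compile serializes the data structures once per corpus, and extract then loads them and processes sentences from standard input. A hypothetical invocation (all paths are placeholders):

    ./compile -b bitext.txt -a alignment.al -o output/ -c extract.ini
    ./extract -c extract.ini -g grammars/ -t 8 < dev.txt > dev.sgm

extract writes one grammar.N file per input sentence and echoes each sentence wrapped in <seg> markup pointing at its grammar file.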
diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc
index 4d0738f7..1dc94c25 100644
--- a/extractor/grammar_extractor.cc
+++ b/extractor/grammar_extractor.cc
@@ -35,10 +35,12 @@ GrammarExtractor::GrammarExtractor(
     vocabulary(vocabulary),
     rule_factory(rule_factory) {}
 
-Grammar GrammarExtractor::GetGrammar(const string& sentence, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) {
+Grammar GrammarExtractor::GetGrammar(
+    const string& sentence,
+    const unordered_set<int>& blacklisted_sentence_ids) {
   vector<string> words = TokenizeSentence(sentence);
   vector<int> word_ids = AnnotateWords(words);
-  return rule_factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array);
+  return rule_factory->GetGrammar(word_ids, blacklisted_sentence_ids);
 }
 
 vector<string> GrammarExtractor::TokenizeSentence(const string& sentence) {
diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h
index 8f570df2..0f3069b0 100644
--- a/extractor/grammar_extractor.h
+++ b/extractor/grammar_extractor.h
@@ -15,7 +15,6 @@ class DataArray;
 class Grammar;
 class HieroCachingRuleFactory;
 class Precomputation;
-class Rule;
 class Scorer;
 class SuffixArray;
 class Vocabulary;
@@ -46,7 +45,9 @@ class GrammarExtractor {
   // Converts the sentence to a vector of word ids and uses the RuleFactory to
   // extract the SCFG rules which may be used to decode the sentence.
-  Grammar GetGrammar(const string& sentence, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array);
+  Grammar GetGrammar(
+      const string& sentence,
+      const unordered_set<int>& blacklisted_sentence_ids);
 
  private:
   // Splits the sentence in a vector of words.
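The shared_ptr<DataArray> parameter could be removed from this call chain because the sentence ids needed for blacklisting are now obtained inside Sampler via suffix_array->GetData(); the data array no longer has to be threaded through GrammarExtractor and HieroCachingRuleFactory.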
diff --git a/extractor/grammar_extractor_test.cc b/extractor/grammar_extractor_test.cc
index f32a9599..719e90ff 100644
--- a/extractor/grammar_extractor_test.cc
+++ b/extractor/grammar_extractor_test.cc
@@ -41,13 +41,13 @@ TEST(GrammarExtractorTest, TestAnnotatingWords) {
   Grammar grammar(rules, feature_names);
   unordered_set<int> blacklisted_sentence_ids;
   shared_ptr<DataArray> source_data_array;
-  EXPECT_CALL(*factory, GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array))
+  EXPECT_CALL(*factory, GetGrammar(word_ids, blacklisted_sentence_ids))
       .WillOnce(Return(grammar));
 
   GrammarExtractor extractor(vocabulary, factory);
   string sentence = "Anna has many many apples .";
 
-  extractor.GetGrammar(sentence, blacklisted_sentence_ids, source_data_array);
+  extractor.GetGrammar(sentence, blacklisted_sentence_ids);
 }
 
 }  // namespace
diff --git a/extractor/mocks/mock_data_array.h b/extractor/mocks/mock_data_array.h
index 4bdcf21f..98e711d2 100644
--- a/extractor/mocks/mock_data_array.h
+++ b/extractor/mocks/mock_data_array.h
@@ -13,7 +13,6 @@ class MockDataArray : public DataArray {
   MOCK_CONST_METHOD2(GetWords, vector<string>(int start_index, int size));
   MOCK_CONST_METHOD0(GetSize, int());
   MOCK_CONST_METHOD0(GetVocabularySize, int());
-  MOCK_CONST_METHOD1(HasWord, bool(const string& word));
   MOCK_CONST_METHOD1(GetWordId, int(const string& word));
   MOCK_CONST_METHOD1(GetWord, string(int word_id));
   MOCK_CONST_METHOD1(GetSentenceLength, int(int sentence_id));
diff --git a/extractor/mocks/mock_rule_factory.h b/extractor/mocks/mock_rule_factory.h
index 6b7b6586..53eb5022 100644
--- a/extractor/mocks/mock_rule_factory.h
+++ b/extractor/mocks/mock_rule_factory.h
@@ -7,9 +7,9 @@ namespace extractor {
 class MockHieroCachingRuleFactory : public HieroCachingRuleFactory {
  public:
-  MOCK_METHOD3(GetGrammar, Grammar(const vector<int>& word_ids, const
-      unordered_set<int>& blacklisted_sentence_ids,
-      const shared_ptr<DataArray> source_data_array));
+  MOCK_METHOD2(GetGrammar, Grammar(
+      const vector<int>& word_ids,
+      const unordered_set<int>& blacklisted_sentence_ids));
 };
 
 }  // namespace extractor
diff --git a/extractor/mocks/mock_sampler.h b/extractor/mocks/mock_sampler.h
index 75c43c27..b2742f62 100644
--- a/extractor/mocks/mock_sampler.h
+++ b/extractor/mocks/mock_sampler.h
@@ -7,7 +7,9 @@ namespace extractor {
 class MockSampler : public Sampler {
  public:
-  MOCK_CONST_METHOD1(Sample, PhraseLocation(const PhraseLocation& location));
+  MOCK_CONST_METHOD2(Sample, PhraseLocation(
+      const PhraseLocation& location,
+      const unordered_set<int>& blacklisted_sentence_ids));
 };
 
 }  // namespace extractor
diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc
index 38d8f489..b79daae3 100644
--- a/extractor/precomputation.cc
+++ b/extractor/precomputation.cc
@@ -5,60 +5,67 @@
 #include "data_array.h"
 #include "suffix_array.h"
+#include "time_util.h"
 #include "vocabulary.h"
 
 using namespace std;
 
 namespace extractor {
 
-int Precomputation::NONTERMINAL = -1;
-
 Precomputation::Precomputation(
     shared_ptr<Vocabulary> vocabulary, shared_ptr<SuffixArray> suffix_array,
     int num_frequent_patterns, int num_super_frequent_patterns,
     int max_rule_span, int max_rule_symbols, int min_gap_size,
     int max_frequent_phrase_len, int min_frequency) {
+  Clock::time_point start_time = Clock::now();
+  shared_ptr<DataArray> data_array = suffix_array->GetData();
+  vector<int> data = data_array->GetData();
   vector<vector<int>> frequent_patterns = FindMostFrequentPatterns(
-      suffix_array, num_frequent_patterns, max_frequent_phrase_len,
+      suffix_array, data, num_frequent_patterns, max_frequent_phrase_len,
       min_frequency);
+  Clock::time_point end_time = Clock::now();
+  cerr << "Finding most frequent patterns took "
+       << GetDuration(start_time, end_time) << " seconds..." << endl;
 
-  // Construct sets containing the frequent and superfrequent contiguous
-  // collocations.
-  unordered_set<vector<int>, VectorHash> frequent_patterns_set;
-  unordered_set<vector<int>, VectorHash> super_frequent_patterns_set;
+  vector<vector<int>> pattern_annotations(frequent_patterns.size());
+  unordered_map<vector<int>, int, VectorHash> frequent_patterns_index;
   for (size_t i = 0; i < frequent_patterns.size(); ++i) {
-    frequent_patterns_set.insert(frequent_patterns[i]);
-    if (i < num_super_frequent_patterns) {
-      super_frequent_patterns_set.insert(frequent_patterns[i]);
-    }
+    frequent_patterns_index[frequent_patterns[i]] = i;
+    pattern_annotations[i] = AnnotatePattern(vocabulary, data_array,
+                                             frequent_patterns[i]);
   }
 
-  shared_ptr<DataArray> data_array = suffix_array->GetData();
+  start_time = Clock::now();
   vector<tuple<int, int, int>> matchings;
-  for (size_t i = 0; i < data_array->GetSize(); ++i) {
+  vector<vector<int>> annotations;
+  for (size_t i = 0; i < data.size(); ++i) {
     // If the sentence is over, add all the discontiguous frequent patterns to
     // the index.
-    if (data_array->AtIndex(i) == DataArray::END_OF_LINE) {
-      UpdateIndex(data_array, vocabulary, matchings, max_rule_span,
-                  min_gap_size, max_rule_symbols);
+    if (data[i] == DataArray::END_OF_LINE) {
+      UpdateIndex(matchings, annotations, max_rule_span, min_gap_size,
+                  max_rule_symbols);
       matchings.clear();
+      annotations.clear();
       continue;
     }
 
     // Find all the contiguous frequent patterns starting at position i.
     vector<int> pattern;
-    for (int j = 1;
-         j <= max_frequent_phrase_len && i + j <= data_array->GetSize();
-         ++j) {
-      pattern.push_back(data_array->AtIndex(i + j - 1));
-      if (!frequent_patterns_set.count(pattern)) {
+    for (int j = 1; j <= max_frequent_phrase_len && i + j <= data.size(); ++j) {
+      pattern.push_back(data[i + j - 1]);
+      auto it = frequent_patterns_index.find(pattern);
+      if (it == frequent_patterns_index.end()) {
        // If the current pattern is not frequent, any longer pattern having the
        // current pattern as prefix will not be frequent.
         break;
       }
-      int is_super_frequent = super_frequent_patterns_set.count(pattern);
+      int is_super_frequent = it->second < num_super_frequent_patterns;
       matchings.push_back(make_tuple(i, j, is_super_frequent));
+      annotations.push_back(pattern_annotations[it->second]);
    }
  }
+  end_time = Clock::now();
+  cerr << "Constructing collocations index took "
+       << GetDuration(start_time, end_time) << " seconds..." << endl;
 }
 
 Precomputation::Precomputation() {}
@@ -66,8 +73,8 @@ Precomputation::Precomputation() {}
 
 Precomputation::~Precomputation() {}
 
 vector<vector<int>> Precomputation::FindMostFrequentPatterns(
-    shared_ptr<SuffixArray> suffix_array, int num_frequent_patterns,
-    int max_frequent_phrase_len, int min_frequency) {
+    shared_ptr<SuffixArray> suffix_array, const vector<int>& data,
+    int num_frequent_patterns, int max_frequent_phrase_len, int min_frequency) {
   vector<int> lcp = suffix_array->BuildLCPArray();
   vector<int> run_start(max_frequent_phrase_len);
 
@@ -76,9 +83,9 @@ vector<vector<int>> Precomputation::FindMostFrequentPatterns(
   for (size_t i = 1; i < lcp.size(); ++i) {
     for (int len = lcp[i]; len < max_frequent_phrase_len; ++len) {
       int frequency = i - run_start[len];
-      if (frequency >= min_frequency) {
-        heap.push(make_pair(frequency,
-            make_pair(suffix_array->GetSuffix(run_start[len]), len + 1)));
+      int start = suffix_array->GetSuffix(run_start[len]);
+      if (frequency >= min_frequency && start + len <= data.size()) {
+        heap.push(make_pair(frequency, make_pair(start, len + 1)));
       }
       run_start[len] = i;
     }
@@ -101,9 +108,20 @@ vector<vector<int>> Precomputation::FindMostFrequentPatterns(
   return frequent_patterns;
 }
 
+vector<int> Precomputation::AnnotatePattern(
+    shared_ptr<Vocabulary> vocabulary, shared_ptr<DataArray> data_array,
+    const vector<int>& pattern) const {
+  vector<int> annotation;
+  for (int word_id: pattern) {
+    annotation.push_back(vocabulary->GetTerminalIndex(
+        data_array->GetWord(word_id)));
+  }
+  return annotation;
+}
+
 void Precomputation::UpdateIndex(
-    shared_ptr<DataArray> data_array, shared_ptr<Vocabulary> vocabulary,
     const vector<tuple<int, int, int>>& matchings,
+    const vector<vector<int>>& annotations,
     int max_rule_span, int min_gap_size, int max_rule_symbols) {
   // Select the leftmost subpattern.
   for (size_t i = 0; i < matchings.size(); ++i) {
@@ -121,15 +139,14 @@ void Precomputation::UpdateIndex(
       if (start2 - start1 - size1 >= min_gap_size
           && start2 + size2 - start1 <= max_rule_span
           && size1 + size2 + 1 <= max_rule_symbols) {
-        vector<int> pattern;
-        AppendSubpattern(pattern, data_array, vocabulary, start1, size1);
-        pattern.push_back(Precomputation::NONTERMINAL);
-        AppendSubpattern(pattern, data_array, vocabulary, start2, size2);
-        AppendCollocation(index[pattern], {start1, start2});
+        vector<int> pattern = annotations[i];
+        pattern.push_back(-1);
+        AppendSubpattern(pattern, annotations[j]);
+        AppendCollocation(index[pattern], start1, start2);
 
         // Try extending the binary collocation to a ternary collocation.
         if (is_super2) {
-          pattern.push_back(Precomputation::NONTERMINAL);
+          pattern.push_back(-2);
           // Select the rightmost subpattern.
           for (size_t k = j + 1; k < matchings.size(); ++k) {
             int start3, size3, is_super3;
@@ -142,8 +159,8 @@ void Precomputation::UpdateIndex(
                 && start3 + size3 - start1 <= max_rule_span
                 && size1 + size2 + size3 + 2 <= max_rule_symbols
                 && (is_super1 || is_super3)) {
-              AppendSubpattern(pattern, data_array, vocabulary, start3, size3);
-              AppendCollocation(index[pattern], {start1, start2, start3});
+              AppendSubpattern(pattern, annotations[k]);
+              AppendCollocation(index[pattern], start1, start2, start3);
               pattern.erase(pattern.end() - size3);
             }
           }
@@ -154,17 +171,22 @@ void Precomputation::UpdateIndex(
 }
 
 void Precomputation::AppendSubpattern(
-    vector<int>& pattern, shared_ptr<DataArray> data_array,
-    shared_ptr<Vocabulary> vocabulary, int start, int size) {
-  vector<string> words = data_array->GetWords(start, size);
-  for (const string& word: words) {
-    pattern.push_back(vocabulary->GetTerminalIndex(word));
-  }
+    vector<int>& pattern,
+    const vector<int>& subpattern) {
+  copy(subpattern.begin(), subpattern.end(), back_inserter(pattern));
+}
+
+void Precomputation::AppendCollocation(
+    vector<int>& collocations, int pos1, int pos2) {
+  collocations.push_back(pos1);
+  collocations.push_back(pos2);
 }
 
 void Precomputation::AppendCollocation(
-    vector<int>& collocations, const vector<int>& collocation) {
-  copy(collocation.begin(), collocation.end(), back_inserter(collocations));
+    vector<int>& collocations, int pos1, int pos2, int pos3) {
+  collocations.push_back(pos1);
+  collocations.push_back(pos2);
+  collocations.push_back(pos3);
 }
 
 bool Precomputation::Contains(const vector<int>& pattern) const {
diff --git a/extractor/precomputation.h b/extractor/precomputation.h
index 6ade58df..2b34fc29 100644
--- a/extractor/precomputation.h
+++ b/extractor/precomputation.h
@@ -55,28 +55,32 @@ class Precomputation {
 
   bool operator==(const Precomputation& other) const;
 
-  static int NONTERMINAL;
-
  private:
   // Finds the most frequent contiguous collocations.
   vector<vector<int>> FindMostFrequentPatterns(
-      shared_ptr<SuffixArray> suffix_array, int num_frequent_patterns,
-      int max_frequent_phrase_len, int min_frequency);
+      shared_ptr<SuffixArray> suffix_array, const vector<int>& data,
+      int num_frequent_patterns, int max_frequent_phrase_len,
+      int min_frequency);
+
+  vector<int> AnnotatePattern(shared_ptr<Vocabulary> vocabulary,
+                              shared_ptr<DataArray> data_array,
+                              const vector<int>& pattern) const;
 
   // Given the locations of the frequent contiguous collocations in a sentence,
   // it adds new entries to the index for each discontiguous collocation
   // matching the criteria specified in the class description.
   void UpdateIndex(
-      shared_ptr<DataArray> data_array, shared_ptr<Vocabulary> vocabulary,
       const vector<tuple<int, int, int>>& matchings,
+      const vector<vector<int>>& annotations,
       int max_rule_span, int min_gap_size, int max_rule_symbols);
 
-  void AppendSubpattern(
-      vector<int>& pattern, shared_ptr<DataArray> data_array,
-      shared_ptr<Vocabulary> vocabulary, int start, int size);
+  void AppendSubpattern(vector<int>& pattern, const vector<int>& subpattern);
+
+  // Adds an occurrence of a binary collocation.
+  void AppendCollocation(vector<int>& collocations, int pos1, int pos2);
 
-  // Adds an occurrence of a collocation.
-  void AppendCollocation(vector<int>& collocations, const vector<int>& collocation);
+  // Adds an occurrence of a ternary collocation.
+  void AppendCollocation(vector<int>& collocations, int pos1, int pos2, int pos3);
 
   friend class boost::serialization::access;
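The index keys are now sequences of vocabulary ids with -1 marking the first gap and -2 the second (replacing the single NONTERMINAL constant), and each value is a flat run of subpattern start positions. A sketch of how a consumer might walk such a flattened entry (illustrative only, not part of the patch):

    #include <iostream>
    #include <vector>

    // A key such as {2, -1, 3, -2, 2} contains two gap symbols, so every
    // occurrence contributes three start positions; the value
    // {1, 6, 8, 5, 9, 11} encodes the occurrences (1, 6, 8) and (5, 9, 11).
    void PrintCollocations(const std::vector<int>& key,
                           const std::vector<int>& positions) {
      int num_subpatterns = 1;
      for (int symbol : key) {
        if (symbol < 0) ++num_subpatterns;  // -1 and -2 mark gaps
      }
      for (size_t i = 0; i < positions.size(); i += num_subpatterns) {
        for (int j = 0; j < num_subpatterns; ++j) {
          std::cout << positions[i + j]
                    << (j + 1 < num_subpatterns ? ' ' : '\n');
        }
      }
    }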
diff --git a/extractor/precomputation_test.cc b/extractor/precomputation_test.cc
index fd85fcf8..d5f5ef63 100644
--- a/extractor/precomputation_test.cc
+++ b/extractor/precomputation_test.cc
@@ -24,31 +24,12 @@ class PrecomputationTest : public Test {
   virtual void SetUp() {
     data = {4, 2, 3, 5, 7, 2, 3, 5, 2, 3, 4, 2, 1};
     data_array = make_shared<MockDataArray>();
-    EXPECT_CALL(*data_array, GetSize()).WillRepeatedly(Return(data.size()));
+    EXPECT_CALL(*data_array, GetData()).WillRepeatedly(Return(data));
     for (size_t i = 0; i < data.size(); ++i) {
       EXPECT_CALL(*data_array, AtIndex(i)).WillRepeatedly(Return(data[i]));
     }
-
-    vector<pair<int, int>> expected_calls = {{8, 1}, {8, 2}, {6, 1}};
-    for (const auto& call: expected_calls) {
-      int start = call.first;
-      int size = call.second;
-      vector<int> word_ids(data.begin() + start, data.begin() + start + size);
-      EXPECT_CALL(*data_array, GetWordIds(start, size))
-          .WillRepeatedly(Return(word_ids));
-    }
-
-    expected_calls = {{1, 1}, {5, 1}, {8, 1}, {9, 1}, {5, 2},
-                      {6, 1}, {8, 2}, {1, 2}, {2, 1}, {11, 1}};
-    for (const auto& call: expected_calls) {
-      int start = call.first;
-      int size = call.second;
-      vector<string> words;
-      for (size_t j = start; j < start + size; ++j) {
-        words.push_back(to_string(data[j]));
-      }
-      EXPECT_CALL(*data_array, GetWords(start, size))
-          .WillRepeatedly(Return(words));
-    }
+    EXPECT_CALL(*data_array, GetWord(2)).WillRepeatedly(Return("2"));
+    EXPECT_CALL(*data_array, GetWord(3)).WillRepeatedly(Return("3"));
 
     vector<int> suffixes{12, 8, 5, 1, 9, 6, 2, 0, 10, 7, 3, 4, 13};
     vector<int> lcp{-1, 0, 2, 3, 1, 0, 1, 2, 0, 2, 0, 1, 0, 0};
@@ -117,37 +98,37 @@ TEST_F(PrecomputationTest, TestCollocations) {
   expected_value = {1, 5, 8, 5, 8, 11};
   EXPECT_TRUE(precomputation.Contains(key));
   EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
 
-  key = {2, -1, 2, -1, 3};
+  key = {2, -1, 2, -2, 3};
   expected_value = {1, 5, 9};
   EXPECT_TRUE(precomputation.Contains(key));
   EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
-  key = {2, -1, 3, -1, 2};
+  key = {2, -1, 3, -2, 2};
   expected_value = {1, 6, 8, 5, 9, 11};
   EXPECT_TRUE(precomputation.Contains(key));
   EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
-  key = {2, -1, 3, -1, 3};
+  key = {2, -1, 3, -2, 3};
   expected_value = {1, 6, 9};
   EXPECT_TRUE(precomputation.Contains(key));
   EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
-  key = {3, -1, 2, -1, 2};
+  key = {3, -1, 2, -2, 2};
   expected_value = {2, 5, 8, 2, 5, 11, 2, 8, 11, 6, 8, 11};
   EXPECT_TRUE(precomputation.Contains(key));
   EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
-  key = {3, -1, 2, -1, 3};
+  key = {3, -1, 2, -2, 3};
   expected_value = {2, 5, 9};
   EXPECT_TRUE(precomputation.Contains(key));
   EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
-  key = {3, -1, 3, -1, 2};
+  key = {3, -1, 3, -2, 2};
   expected_value = {2, 6, 8, 2, 6, 11, 2, 9, 11, 6, 9, 11};
   EXPECT_TRUE(precomputation.Contains(key));
   EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
-  key = {3, -1, 3, -1, 3};
+  key = {3, -1, 3, -2, 3};
   expected_value = {2, 6, 9};
   EXPECT_TRUE(precomputation.Contains(key));
   EXPECT_EQ(expected_value, precomputation.GetCollocations(key));
 
   // Exceeds max_rule_symbols.
-  key = {2, -1, 2, -1, 2, 3};
+  key = {2, -1, 2, -2, 2, 3};
   EXPECT_FALSE(precomputation.Contains(key));
   // Contains non frequent pattern.
   key = {2, -1, 5};
diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc
index 6ae2d792..5b66f685 100644
--- a/extractor/rule_factory.cc
+++ b/extractor/rule_factory.cc
@@ -101,7 +101,9 @@ HieroCachingRuleFactory::HieroCachingRuleFactory() {}
 
 HieroCachingRuleFactory::~HieroCachingRuleFactory() {}
 
-Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) {
+Grammar HieroCachingRuleFactory::GetGrammar(
+    const vector<int>& word_ids,
+    const unordered_set<int>& blacklisted_sentence_ids) {
   Clock::time_point start_time = Clock::now();
   double total_extract_time = 0;
   double total_intersect_time = 0;
@@ -193,7 +195,8 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids, const u
     Clock::time_point extract_start = Clock::now();
     if (!state.starts_with_x) {
       // Extract rules for the sampled set of occurrences.
-      PhraseLocation sample = sampler->Sample(next_node->matchings, blacklisted_sentence_ids, source_data_array);
+      PhraseLocation sample = sampler->Sample(
+          next_node->matchings, blacklisted_sentence_ids);
       vector<Rule> new_rules = rule_extractor->ExtractRules(next_phrase, sample);
       rules.insert(rules.end(), new_rules.begin(), new_rules.end());
diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h
index a1ff76e4..1a9fa2af 100644
--- a/extractor/rule_factory.h
+++ b/extractor/rule_factory.h
@@ -74,8 +74,7 @@ class HieroCachingRuleFactory {
   // (See class description for more details.)
   virtual Grammar GetGrammar(
       const vector<int>& word_ids,
-      const unordered_set<int>& blacklisted_sentence_ids,
-      const shared_ptr<DataArray> source_data_array);
+      const unordered_set<int>& blacklisted_sentence_ids);
 
  protected:
   HieroCachingRuleFactory();
diff --git a/extractor/rule_factory_test.cc b/extractor/rule_factory_test.cc
index f26cc567..332c5959 100644
--- a/extractor/rule_factory_test.cc
+++ b/extractor/rule_factory_test.cc
@@ -40,7 +40,7 @@ class RuleFactoryTest : public Test {
         .WillRepeatedly(Return(feature_names));
 
     sampler = make_shared<MockSampler>();
-    EXPECT_CALL(*sampler, Sample(_))
+    EXPECT_CALL(*sampler, Sample(_, _))
        .WillRepeatedly(Return(PhraseLocation(0, 1)));
 
     Phrase phrase;
@@ -77,8 +77,7 @@ TEST_F(RuleFactoryTest, TestGetGrammarDifferentWords) {
 
   vector<int> word_ids = {2, 3, 4};
   unordered_set<int> blacklisted_sentence_ids;
-  shared_ptr<DataArray> source_data_array;
-  Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array);
+  Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids);
   EXPECT_EQ(feature_names, grammar.GetFeatureNames());
   EXPECT_EQ(7, grammar.GetRules().size());
 }
@@ -97,8 +96,7 @@ TEST_F(RuleFactoryTest, TestGetGrammarRepeatingWords) {
 
   vector<int> word_ids = {2, 3, 4, 2, 3};
   unordered_set<int> blacklisted_sentence_ids;
-  shared_ptr<DataArray> source_data_array;
-  Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array);
+  Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids);
   EXPECT_EQ(feature_names, grammar.GetFeatureNames());
   EXPECT_EQ(28, grammar.GetRules().size());
 }
diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc
index 85c8a422..f1aa5e35 100644
--- a/extractor/run_extractor.cc
+++ b/extractor/run_extractor.cc
@@ -5,10 +5,10 @@
 #include <string>
 #include <vector>
 
-#include <omp.h>
 #include <boost/filesystem.hpp>
 #include <boost/program_options.hpp>
 #include <boost/program_options/variables_map.hpp>
+#include <omp.h>
 
 #include "alignment.h"
 #include "data_array.h"
@@ -78,7 +78,8 @@ int main(int argc, char** argv) {
     ("tight_phrases", po::value<bool>()->default_value(true),
         "False if phrases may be loose (better, but slower)")
     ("leave_one_out", po::value<bool>()->zero_tokens(),
-        "do leave-one-out estimation of grammars (e.g. for extracting grammars for the training set");
+        "do leave-one-out estimation of grammars "
+        "(e.g. for extracting grammars for the training set)");
 
   po::variables_map vm;
   po::store(po::parse_command_line(argc, argv, desc), vm);
@@ -99,11 +100,6 @@ int main(int argc, char** argv) {
     return 1;
   }
 
-  bool leave_one_out = false;
-  if (vm.count("leave_one_out")) {
-    leave_one_out = true;
-  }
-
   int num_threads = vm["threads"].as<int>();
   cerr << "Grammar extraction will use " << num_threads << " threads." << endl;
 
@@ -178,8 +174,8 @@ int main(int argc, char** argv) {
        << GetDuration(preprocess_start_time, preprocess_stop_time)
        << " seconds" << endl;
 
-  // Features used to score each grammar rule.
   Clock::time_point extraction_start_time = Clock::now();
+  // Features used to score each grammar rule.
   vector<shared_ptr<Feature>> features = {
       make_shared<TargetGivenSourceCoherent>(),
       make_shared<SampleSourceCount>(),
@@ -206,9 +202,6 @@ int main(int argc, char** argv) {
       vm["max_samples"].as<int>(),
      vm["tight_phrases"].as<bool>());
 
-  // Releases extra memory used by the initial precomputation.
-  precomputation.reset();
-
   // Creates the grammars directory if it doesn't exist.
   fs::path grammar_path = vm["grammars"].as<string>();
   if (!fs::is_directory(grammar_path)) {
@@ -224,6 +217,7 @@ int main(int argc, char** argv) {
   }
 
   // Extracts the grammar for each sentence and saves it to a file.
+  bool leave_one_out = vm.count("leave_one_out");
   vector<string> suffixes(sentences.size());
   #pragma omp parallel for schedule(dynamic) num_threads(num_threads)
   for (size_t i = 0; i < sentences.size(); ++i) {
@@ -236,8 +230,11 @@ int main(int argc, char** argv) {
     suffixes[i] = suffix;
 
     unordered_set<int> blacklisted_sentence_ids;
-    if (leave_one_out) blacklisted_sentence_ids.insert(i);
-    Grammar grammar = extractor.GetGrammar(sentences[i], blacklisted_sentence_ids, source_data_array);
+    if (leave_one_out) {
+      blacklisted_sentence_ids.insert(i);
+    }
+    Grammar grammar = extractor.GetGrammar(
+        sentences[i], blacklisted_sentence_ids);
     ofstream output(GetGrammarFilePath(grammar_path, i).c_str());
     output << grammar;
   }
diff --git a/extractor/sampler.cc b/extractor/sampler.cc
index 963afa7a..887aaec1 100644
--- a/extractor/sampler.cc
+++ b/extractor/sampler.cc
@@ -12,7 +12,10 @@ Sampler::Sampler() {}
 
 Sampler::~Sampler() {}
 
-PhraseLocation Sampler::Sample(const PhraseLocation& location, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) const {
+PhraseLocation Sampler::Sample(
+    const PhraseLocation& location,
+    const unordered_set<int>& blacklisted_sentence_ids) const {
+  shared_ptr<DataArray> source_data_array = suffix_array->GetData();
   vector<int> sample;
   int num_subpatterns;
   if (location.matchings == NULL) {
@@ -20,30 +23,30 @@ PhraseLocation Sampler::Sample(const PhraseLocation& location, const unordered_s
     num_subpatterns = 1;
     int low = location.sa_low, high = location.sa_high;
     double step = max(1.0, (double) (high - low) / max_samples);
-    double i = low, last = i;
-    bool found;
+    double i = low, last = i - 1;
     while (sample.size() < max_samples && i < high) {
       int x = suffix_array->GetSuffix(Round(i));
       int id = source_data_array->GetSentenceId(x);
-      if (find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) != blacklisted_sentence_ids.end()) {
-        found = false;
-        double backoff_step = 1;
-        while (true) {
-          if ((double)backoff_step >= step) break;
+      bool found = false;
+      if (blacklisted_sentence_ids.count(id)) {
+        for (int backoff_step = 1; backoff_step <= step; ++backoff_step) {
           double j = i - backoff_step;
           x = suffix_array->GetSuffix(Round(j));
           id = source_data_array->GetSentenceId(x);
-          if (x >= 0 && j > last && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) {
-            found = true; last = i; break;
+          if (x >= 0 && j > last && !blacklisted_sentence_ids.count(id)) {
+            found = true;
+            last = i;
+            break;
           }
           double k = i + backoff_step;
           x = suffix_array->GetSuffix(Round(k));
           id = source_data_array->GetSentenceId(x);
-          if (k < min(i+step, (double)high) && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) {
-            found = true; last = k; break;
+          if (k < min(i + step, (double) high) &&
+              !blacklisted_sentence_ids.count(id)) {
+            found = true;
+            last = k;
+            break;
           }
-          if (j <= last && k >= high) break;
-          backoff_step++;
         }
       } else {
         found = true;
diff --git a/extractor/sampler.h b/extractor/sampler.h
index de450c48..bd8a5876 100644
--- a/extractor/sampler.h
+++ b/extractor/sampler.h
@@ -23,7 +23,9 @@ class Sampler {
   virtual ~Sampler();
 
   // Samples uniformly at most max_samples phrase occurrences.
-  virtual PhraseLocation Sample(const PhraseLocation& location, const unordered_set<int>& blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) const;
+  virtual PhraseLocation Sample(
+      const PhraseLocation& location,
+      const unordered_set<int>& blacklisted_sentence_ids) const;
 
  protected:
   Sampler();
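When a sample point i falls in a blacklisted sentence, the rewritten loop probes the neighboring suffixes at offsets -1, +1, -2, +2, ... and keeps the first one outside the blacklist, giving up once backoff_step exceeds the sampling stride. The probe order, as a tiny standalone helper (hypothetical, for illustration only):

    #include <vector>

    // For step = 3, returns {-1, +1, -2, +2, -3, +3}: the order in which
    // the sampler tries neighbors of a blacklisted sample point.
    std::vector<int> BackoffOffsets(int step) {
      std::vector<int> offsets;
      for (int backoff_step = 1; backoff_step <= step; ++backoff_step) {
        offsets.push_back(-backoff_step);
        offsets.push_back(backoff_step);
      }
      return offsets;
    }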
diff --git a/extractor/sampler_test.cc b/extractor/sampler_test.cc
index 965567ba..14e72780 100644
--- a/extractor/sampler_test.cc
+++ b/extractor/sampler_test.cc
@@ -19,6 +19,8 @@ class SamplerTest : public Test {
     source_data_array = make_shared<MockDataArray>();
     EXPECT_CALL(*source_data_array, GetSentenceId(_)).WillRepeatedly(Return(9999));
     suffix_array = make_shared<MockSuffixArray>();
+    EXPECT_CALL(*suffix_array, GetData())
+        .WillRepeatedly(Return(source_data_array));
     for (int i = 0; i < 10; ++i) {
       EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i));
     }
@@ -35,23 +37,29 @@ TEST_F(SamplerTest, TestSuffixArrayRange) {
   sampler = make_shared<Sampler>(suffix_array, 1);
   vector<int> expected_locations = {0};
-  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+  EXPECT_EQ(PhraseLocation(expected_locations, 1),
+            sampler->Sample(location, blacklist));
+  return;
 
   sampler = make_shared<Sampler>(suffix_array, 2);
   expected_locations = {0, 5};
-  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+  EXPECT_EQ(PhraseLocation(expected_locations, 1),
+            sampler->Sample(location, blacklist));
 
   sampler = make_shared<Sampler>(suffix_array, 3);
   expected_locations = {0, 3, 7};
-  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+  EXPECT_EQ(PhraseLocation(expected_locations, 1),
+            sampler->Sample(location, blacklist));
 
   sampler = make_shared<Sampler>(suffix_array, 4);
   expected_locations = {0, 3, 5, 8};
-  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+  EXPECT_EQ(PhraseLocation(expected_locations, 1),
+            sampler->Sample(location, blacklist));
 
   sampler = make_shared<Sampler>(suffix_array, 100);
   expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
-  EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+  EXPECT_EQ(PhraseLocation(expected_locations, 1),
+            sampler->Sample(location, blacklist));
 }
 
 TEST_F(SamplerTest, TestSubstringsSample) {
@@ -61,19 +69,23 @@ TEST_F(SamplerTest, TestSubstringsSample) {
 
   sampler = make_shared<Sampler>(suffix_array, 1);
   vector<int> expected_locations = {0, 1};
-  EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
+  EXPECT_EQ(PhraseLocation(expected_locations, 2),
+            sampler->Sample(location, blacklist));
 
   sampler = make_shared<Sampler>(suffix_array, 2);
   expected_locations = {0, 1, 6, 7};
-  EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
+  EXPECT_EQ(PhraseLocation(expected_locations, 2),
+            sampler->Sample(location, blacklist));
 
   sampler = make_shared<Sampler>(suffix_array, 3);
   expected_locations = {0, 1, 4, 5, 6, 7};
-  EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
+  EXPECT_EQ(PhraseLocation(expected_locations, 2),
+            sampler->Sample(location, blacklist));
 
   sampler = make_shared<Sampler>(suffix_array, 7);
   expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
-  EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
+  EXPECT_EQ(PhraseLocation(expected_locations, 2),
+            sampler->Sample(location, blacklist));
 }
 
 }  // namespace
diff --git a/extractor/suffix_array.cc b/extractor/suffix_array.cc
index ac230d13..4a514b12 100644
--- a/extractor/suffix_array.cc
+++ b/extractor/suffix_array.cc
@@ -187,12 +187,12 @@ shared_ptr<DataArray> SuffixArray::GetData() const {
 
 PhraseLocation SuffixArray::Lookup(int low, int high, const string& word,
                                    int offset) const {
-  if (!data_array->HasWord(word)) {
+  int word_id = data_array->GetWordId(word);
+  if (word_id == -1) {
     // Return empty phrase location.
     return PhraseLocation(0, 0);
   }
 
-  int word_id = data_array->GetWordId(word);
   if (offset == 0) {
     return PhraseLocation(word_start[word_id], word_start[word_id + 1]);
   }
diff --git a/extractor/suffix_array_test.cc b/extractor/suffix_array_test.cc
index a9fd1eab..161edbc0 100644
--- a/extractor/suffix_array_test.cc
+++ b/extractor/suffix_array_test.cc
@@ -55,22 +55,18 @@ TEST_F(SuffixArrayTest, TestLookup) {
     EXPECT_CALL(*data_array, AtIndex(i)).WillRepeatedly(Return(data[i]));
   }
 
-  EXPECT_CALL(*data_array, HasWord("word1")).WillRepeatedly(Return(true));
   EXPECT_CALL(*data_array, GetWordId("word1")).WillRepeatedly(Return(6));
   EXPECT_EQ(PhraseLocation(11, 14), suffix_array.Lookup(0, 14, "word1", 0));
 
-  EXPECT_CALL(*data_array, HasWord("word2")).WillRepeatedly(Return(false));
+  EXPECT_CALL(*data_array, GetWordId("word2")).WillRepeatedly(Return(-1));
   EXPECT_EQ(PhraseLocation(0, 0), suffix_array.Lookup(0, 14, "word2", 0));
 
-  EXPECT_CALL(*data_array, HasWord("word3")).WillRepeatedly(Return(true));
   EXPECT_CALL(*data_array, GetWordId("word3")).WillRepeatedly(Return(4));
   EXPECT_EQ(PhraseLocation(11, 13), suffix_array.Lookup(11, 14, "word3", 1));
 
-  EXPECT_CALL(*data_array, HasWord("word4")).WillRepeatedly(Return(true));
   EXPECT_CALL(*data_array, GetWordId("word4")).WillRepeatedly(Return(1));
   EXPECT_EQ(PhraseLocation(11, 13), suffix_array.Lookup(11, 13, "word4", 2));
 
-  EXPECT_CALL(*data_array, HasWord("word5")).WillRepeatedly(Return(true));
   EXPECT_CALL(*data_array, GetWordId("word5")).WillRepeatedly(Return(2));
   EXPECT_EQ(PhraseLocation(11, 13), suffix_array.Lookup(11, 13, "word5", 3));
diff --git a/extractor/translation_table.cc b/extractor/translation_table.cc
index 1b1ba112..11e29e1e 100644
--- a/extractor/translation_table.cc
+++ b/extractor/translation_table.cc
@@ -90,13 +90,12 @@ void TranslationTable::IncrementLinksCount(
 
 double TranslationTable::GetTargetGivenSourceScore(
     const string& source_word, const string& target_word) {
-  if (!source_data_array->HasWord(source_word) ||
-      !target_data_array->HasWord(target_word)) {
+  int source_id = source_data_array->GetWordId(source_word);
+  int target_id = target_data_array->GetWordId(target_word);
+  if (source_id == -1 || target_id == -1) {
     return -1;
   }
 
-  int source_id = source_data_array->GetWordId(source_word);
-  int target_id = target_data_array->GetWordId(target_word);
   auto entry = make_pair(source_id, target_id);
   auto it = translation_probabilities.find(entry);
   if (it == translation_probabilities.end()) {
@@ -107,13 +106,12 @@ double TranslationTable::GetTargetGivenSourceScore(
 
 double TranslationTable::GetSourceGivenTargetScore(
     const string& source_word, const string& target_word) {
-  if (!source_data_array->HasWord(source_word) ||
-      !target_data_array->HasWord(target_word)) {
+  int source_id = source_data_array->GetWordId(source_word);
+  int target_id = target_data_array->GetWordId(target_word);
+  if (source_id == -1 || target_id == -1) {
     return -1;
   }
 
-  int source_id = source_data_array->GetWordId(source_word);
-  int target_id = target_data_array->GetWordId(target_word);
   auto entry = make_pair(source_id, target_id);
   auto it = translation_probabilities.find(entry);
   if (it == translation_probabilities.end()) {
diff --git a/extractor/translation_table_test.cc b/extractor/translation_table_test.cc
index 72551a12..3cfc0011 100644
--- a/extractor/translation_table_test.cc
+++ b/extractor/translation_table_test.cc
@@ -36,13 +36,10 @@ class TranslationTableTest : public Test {
           .WillRepeatedly(Return(source_sentence_start[i]));
     }
     for (size_t i = 0; i < words.size(); ++i) {
-      EXPECT_CALL(*source_data_array, HasWord(words[i]))
-          .WillRepeatedly(Return(true));
       EXPECT_CALL(*source_data_array, GetWordId(words[i]))
          .WillRepeatedly(Return(i + 2));
     }
-    EXPECT_CALL(*source_data_array, HasWord("d"))
-        .WillRepeatedly(Return(false));
+    EXPECT_CALL(*source_data_array, GetWordId("d")).WillRepeatedly(Return(-1));
 
     vector<int> target_data = {2, 3, 2, 3, 4, 5, 0, 3, 6, 0, 2, 7, 0};
     vector<int> target_sentence_start = {0, 7, 10, 13};
@@ -54,13 +51,10 @@ class TranslationTableTest : public Test {
           .WillRepeatedly(Return(target_sentence_start[i]));
     }
     for (size_t i = 0; i < words.size(); ++i) {
-      EXPECT_CALL(*target_data_array, HasWord(words[i]))
-          .WillRepeatedly(Return(true));
       EXPECT_CALL(*target_data_array, GetWordId(words[i]))
          .WillRepeatedly(Return(i + 2));
     }
-    EXPECT_CALL(*target_data_array, HasWord("d"))
-        .WillRepeatedly(Return(false));
+    EXPECT_CALL(*target_data_array, GetWordId("d")).WillRepeatedly(Return(-1));
 
     vector<pair<int, int>> links1 = {
         make_pair(0, 0), make_pair(1, 1), make_pair(2, 2), make_pair(3, 3),
diff --git a/extractor/vocabulary.cc b/extractor/vocabulary.cc
index 15795d1e..c9c2d6f4 100644
--- a/extractor/vocabulary.cc
+++ b/extractor/vocabulary.cc
@@ -8,12 +8,13 @@ int Vocabulary::GetTerminalIndex(const string& word) {
   int word_id = -1;
   #pragma omp critical (vocabulary)
   {
-    if (!dictionary.count(word)) {
+    auto it = dictionary.find(word);
+    if (it != dictionary.end()) {
+      word_id = it->second;
+    } else {
       word_id = words.size();
       dictionary[word] = word_id;
       words.push_back(word);
-    } else {
-      word_id = dictionary[word];
    }
  }
   return word_id;
@@ -34,4 +35,8 @@ string Vocabulary::GetTerminalValue(int symbol) {
   return word;
 }
 
+bool Vocabulary::operator==(const Vocabulary& other) const {
+  return words == other.words && dictionary == other.dictionary;
+}
+
 }  // namespace extractor
diff --git a/extractor/vocabulary.h b/extractor/vocabulary.h
index c8fd9411..db092e99 100644
--- a/extractor/vocabulary.h
+++ b/extractor/vocabulary.h
@@ -5,6 +5,10 @@
 #include <unordered_map>
 #include <vector>
 
+#include <boost/serialization/serialization.hpp>
+#include <boost/serialization/string.hpp>
+#include <boost/serialization/vector.hpp>
+
 using namespace std;
 
 namespace extractor {
@@ -14,7 +18,7 @@ namespace extractor {
  *
  * This structure contains words located in the frequent collocations and words
  * encountered during the grammar extraction time. This dictionary is
- * considerably smaller than the dictionaries in the data arays (and so is the
+ * considerably smaller than the dictionaries in the data arrays (and so is the
 * query time). Note that this is the single data structure that changes state
 * and needs to have thread safe read/write operations.
@@ -38,7 +42,24 @@ class Vocabulary {
   // Returns the word corresponding to the given word id.
   virtual string GetTerminalValue(int symbol);
 
+  bool operator==(const Vocabulary& vocabulary) const;
+
  private:
+  friend class boost::serialization::access;
+
+  template<class Archive> void save(Archive& ar, unsigned int) const {
+    ar << words;
+  }
+
+  template<class Archive> void load(Archive& ar, unsigned int) {
+    ar >> words;
+    for (size_t i = 0; i < words.size(); ++i) {
+      dictionary[words[i]] = i;
+    }
+  }
+
+  BOOST_SERIALIZATION_SPLIT_MEMBER();
+
   unordered_map<string, int> dictionary;
   vector<string> words;
 };
diff --git a/extractor/vocabulary_test.cc b/extractor/vocabulary_test.cc
new file mode 100644
index 00000000..cf5e3e36
--- /dev/null
+++ b/extractor/vocabulary_test.cc
@@ -0,0 +1,45 @@
+#include <gtest/gtest.h>
+
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include <boost/archive/text_iarchive.hpp>
+#include <boost/archive/text_oarchive.hpp>
+
+#include "vocabulary.h"
+
+using namespace std;
+using namespace ::testing;
+namespace ar = boost::archive;
+
+namespace extractor {
+namespace {
+
+TEST(VocabularyTest, TestIndexes) {
+  Vocabulary vocabulary;
+  EXPECT_EQ(0, vocabulary.GetTerminalIndex("zero"));
+  EXPECT_EQ("zero", vocabulary.GetTerminalValue(0));
+
+  EXPECT_EQ(1, vocabulary.GetTerminalIndex("one"));
+  EXPECT_EQ("one", vocabulary.GetTerminalValue(1));
+}
+
+TEST(VocabularyTest, TestSerialization) {
+  Vocabulary vocabulary;
+  EXPECT_EQ(0, vocabulary.GetTerminalIndex("zero"));
+  EXPECT_EQ("zero", vocabulary.GetTerminalValue(0));
+
+  stringstream stream(ios_base::out | ios_base::in);
+  ar::text_oarchive output_stream(stream, ar::no_header);
+  output_stream << vocabulary;
+
+  Vocabulary vocabulary_copy;
+  ar::text_iarchive input_stream(stream, ar::no_header);
+  input_stream >> vocabulary_copy;
+
+  EXPECT_EQ(vocabulary, vocabulary_copy);
+}
+
+}  // namespace
+}  // namespace extractor
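The split-member serialization is the subtle piece: save() writes only the words vector, and load() rebuilds dictionary from it, relying on the invariant that words[i] has id i, which roughly halves the archive size. The same pattern in a self-contained sketch (hypothetical miniature, not from the patch):

    #include <string>
    #include <unordered_map>
    #include <vector>

    #include <boost/serialization/split_member.hpp>
    #include <boost/serialization/string.hpp>
    #include <boost/serialization/vector.hpp>

    class TinyVocab {
     public:
      int Add(const std::string& word) {
        auto it = index.find(word);
        if (it != index.end()) return it->second;
        int id = words.size();
        index[word] = id;
        words.push_back(word);
        return id;
      }

     private:
      friend class boost::serialization::access;

      // Serialize only the id -> word vector...
      template<class Archive> void save(Archive& ar, unsigned int) const {
        ar << words;
      }

      // ...and rebuild the word -> id map on load, preserving the
      // invariant index[words[i]] == i.
      template<class Archive> void load(Archive& ar, unsigned int) {
        ar >> words;
        for (size_t i = 0; i < words.size(); ++i) {
          index[words[i]] = i;
        }
      }

      BOOST_SERIALIZATION_SPLIT_MEMBER();

      std::unordered_map<std::string, int> index;
      std::vector<std::string> words;
    };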