From 79206291f78fba893fda6a61ff0ae9264d00bb82 Mon Sep 17 00:00:00 2001 From: Paul Baltescu Date: Sat, 23 Nov 2013 18:39:39 +0000 Subject: Fix broken extractor test. --- extractor/mocks/mock_rule_factory.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'extractor/mocks') diff --git a/extractor/mocks/mock_rule_factory.h b/extractor/mocks/mock_rule_factory.h index 86a084b5..6b7b6586 100644 --- a/extractor/mocks/mock_rule_factory.h +++ b/extractor/mocks/mock_rule_factory.h @@ -7,7 +7,9 @@ namespace extractor { class MockHieroCachingRuleFactory : public HieroCachingRuleFactory { public: - MOCK_METHOD3(GetGrammar, Grammar(const vector& word_ids, const unordered_set blacklisted_sentence_ids, const shared_ptr source_data_array)); + MOCK_METHOD3(GetGrammar, Grammar(const vector& word_ids, const + unordered_set& blacklisted_sentence_ids, + const shared_ptr source_data_array)); }; } // namespace extractor -- cgit v1.2.3 From f528ac27dab11770f01595b043675dba2947a263 Mon Sep 17 00:00:00 2001 From: Paul Baltescu Date: Sun, 24 Nov 2013 13:19:28 +0000 Subject: Reduce memory overhead for constructing the intersector. --- extractor/compile.cc | 4 ++ extractor/data_array.cc | 14 ++++- extractor/data_array.h | 10 +++- extractor/data_array_test.cc | 12 ++++ extractor/fast_intersector.cc | 40 ++++--------- extractor/fast_intersector.h | 8 +-- extractor/fast_intersector_test.cc | 10 ++-- extractor/grammar_extractor.cc | 5 +- extractor/grammar_extractor.h | 1 + extractor/mocks/mock_data_array.h | 4 +- extractor/mocks/mock_precomputation.h | 3 +- extractor/precomputation.cc | 96 +++++++++++++++-------------- extractor/precomputation.h | 45 +++++++------- extractor/precomputation_test.cc | 110 ++++++++++++++++++++++++---------- extractor/run_extractor.cc | 5 ++ extractor/suffix_array_test.cc | 2 +- extractor/translation_table_test.cc | 4 +- 17 files changed, 225 insertions(+), 148 deletions(-) (limited to 'extractor/mocks') diff --git a/extractor/compile.cc b/extractor/compile.cc index 65fdd509..0d62757e 100644 --- a/extractor/compile.cc +++ b/extractor/compile.cc @@ -13,6 +13,7 @@ #include "suffix_array.h" #include "time_util.h" #include "translation_table.h" +#include "vocabulary.h" namespace ar = boost::archive; namespace fs = boost::filesystem; @@ -125,9 +126,12 @@ int main(int argc, char** argv) { cerr << "Reading alignment took " << GetDuration(start_time, stop_time) << " seconds" << endl; + shared_ptr vocabulary; + start_time = Clock::now(); cerr << "Precomputing collocations..." << endl; Precomputation precomputation( + vocabulary, source_suffix_array, vm["frequent"].as(), vm["super_frequent"].as(), diff --git a/extractor/data_array.cc b/extractor/data_array.cc index 82efcd51..dacc4283 100644 --- a/extractor/data_array.cc +++ b/extractor/data_array.cc @@ -78,7 +78,7 @@ void DataArray::CreateDataArray(const vector& lines) { DataArray::~DataArray() {} -const vector& DataArray::GetData() const { +vector DataArray::GetData() const { return data; } @@ -90,6 +90,18 @@ string DataArray::GetWordAtIndex(int index) const { return id2word[data[index]]; } +vector DataArray::GetWordIds(int index, int size) const { + return vector(data.begin() + index, data.begin() + index + size); +} + +vector DataArray::GetWords(int start_index, int size) const { + vector words; + for (int word_id: GetWordIds(start_index, size)) { + words.push_back(id2word[word_id]); + } + return words; +} + int DataArray::GetSize() const { return data.size(); } diff --git a/extractor/data_array.h b/extractor/data_array.h index 5207366d..e3823d18 100644 --- a/extractor/data_array.h +++ b/extractor/data_array.h @@ -51,7 +51,7 @@ class DataArray { virtual ~DataArray(); // Returns a vector containing the word ids. - virtual const vector& GetData() const; + virtual vector GetData() const; // Returns the word id at the specified position. virtual int AtIndex(int index) const; @@ -59,6 +59,14 @@ class DataArray { // Returns the original word at the specified position. virtual string GetWordAtIndex(int index) const; + // Returns the substring of word ids starting at the specified position and + // having the specified length. + virtual vector GetWordIds(int start_index, int size) const; + + // Returns the substring of words starting at the specified position and + // having the specified length. + virtual vector GetWords(int start_index, int size) const; + // Returns the size of the data array. virtual int GetSize() const; diff --git a/extractor/data_array_test.cc b/extractor/data_array_test.cc index 6c329e34..7b085cd9 100644 --- a/extractor/data_array_test.cc +++ b/extractor/data_array_test.cc @@ -56,6 +56,18 @@ TEST_F(DataArrayTest, TestGetData) { } } +TEST_F(DataArrayTest, TestSubstrings) { + vector expected_word_ids = {3, 4, 5}; + vector expected_words = {"are", "mere", "."}; + EXPECT_EQ(expected_word_ids, source_data.GetWordIds(1, 3)); + EXPECT_EQ(expected_words, source_data.GetWords(1, 3)); + + expected_word_ids = {7, 8}; + expected_words = {"a", "lot"}; + EXPECT_EQ(expected_word_ids, target_data.GetWordIds(7, 2)); + EXPECT_EQ(expected_words, target_data.GetWords(7, 2)); +} + TEST_F(DataArrayTest, TestVocabulary) { EXPECT_EQ(9, source_data.GetVocabularySize()); EXPECT_TRUE(source_data.HasWord("mere")); diff --git a/extractor/fast_intersector.cc b/extractor/fast_intersector.cc index a8591a72..0d1fa6d8 100644 --- a/extractor/fast_intersector.cc +++ b/extractor/fast_intersector.cc @@ -11,41 +11,22 @@ namespace extractor { -FastIntersector::FastIntersector(shared_ptr suffix_array, - shared_ptr precomputation, - shared_ptr vocabulary, - int max_rule_span, - int min_gap_size) : +FastIntersector::FastIntersector( + shared_ptr suffix_array, + shared_ptr precomputation, + shared_ptr vocabulary, + int max_rule_span, + int min_gap_size) : suffix_array(suffix_array), + precomputation(precomputation), vocabulary(vocabulary), max_rule_span(max_rule_span), - min_gap_size(min_gap_size) { - Index precomputed_collocations = precomputation->GetCollocations(); - for (pair, vector> entry: precomputed_collocations) { - vector phrase = ConvertPhrase(entry.first); - collocations[phrase] = entry.second; - } -} + min_gap_size(min_gap_size) {} FastIntersector::FastIntersector() {} FastIntersector::~FastIntersector() {} -vector FastIntersector::ConvertPhrase(const vector& old_phrase) { - vector new_phrase; - new_phrase.reserve(old_phrase.size()); - shared_ptr data_array = suffix_array->GetData(); - for (int word_id: old_phrase) { - if (word_id < 0) { - new_phrase.push_back(word_id); - } else { - new_phrase.push_back( - vocabulary->GetTerminalIndex(data_array->GetWord(word_id))); - } - } - return new_phrase; -} - PhraseLocation FastIntersector::Intersect( PhraseLocation& prefix_location, PhraseLocation& suffix_location, @@ -59,8 +40,9 @@ PhraseLocation FastIntersector::Intersect( assert(vocabulary->IsTerminal(symbols.front()) && vocabulary->IsTerminal(symbols.back())); - if (collocations.count(symbols)) { - return PhraseLocation(collocations[symbols], phrase.Arity() + 1); + if (precomputation->Contains(symbols)) { + return PhraseLocation(precomputation->GetCollocations(symbols), + phrase.Arity() + 1); } bool prefix_ends_with_x = diff --git a/extractor/fast_intersector.h b/extractor/fast_intersector.h index 2819d239..305373dc 100644 --- a/extractor/fast_intersector.h +++ b/extractor/fast_intersector.h @@ -12,7 +12,6 @@ using namespace std; namespace extractor { typedef boost::hash> VectorHash; -typedef unordered_map, vector, VectorHash> Index; class Phrase; class PhraseLocation; @@ -52,11 +51,6 @@ class FastIntersector { FastIntersector(); private: - // Uses the vocabulary to convert the phrase from the numberized format - // specified by the source data array to the numberized format given by the - // vocabulary. - vector ConvertPhrase(const vector& old_phrase); - // Estimates the number of computations needed if the prefix/suffix is // extended. If the last/first symbol is separated from the rest of the phrase // by a nonterminal, then for each occurrence of the prefix/suffix we need to @@ -85,10 +79,10 @@ class FastIntersector { pair GetSearchRange(bool has_marginal_x) const; shared_ptr suffix_array; + shared_ptr precomputation; shared_ptr vocabulary; int max_rule_span; int min_gap_size; - Index collocations; }; } // namespace extractor diff --git a/extractor/fast_intersector_test.cc b/extractor/fast_intersector_test.cc index 76c3aaea..f2a26ba1 100644 --- a/extractor/fast_intersector_test.cc +++ b/extractor/fast_intersector_test.cc @@ -59,15 +59,13 @@ class FastIntersectorTest : public Test { } precomputation = make_shared(); - EXPECT_CALL(*precomputation, GetCollocations()) - .WillRepeatedly(ReturnRef(collocations)); + EXPECT_CALL(*precomputation, Contains(_)).WillRepeatedly(Return(false)); phrase_builder = make_shared(vocabulary); intersector = make_shared(suffix_array, precomputation, vocabulary, 15, 1); } - Index collocations; shared_ptr data_array; shared_ptr suffix_array; shared_ptr precomputation; @@ -82,9 +80,9 @@ TEST_F(FastIntersectorTest, TestCachedCollocation) { Phrase phrase = phrase_builder->Build(symbols); PhraseLocation prefix_location(15, 16), suffix_location(16, 17); - collocations[symbols] = expected_location; - EXPECT_CALL(*precomputation, GetCollocations()) - .WillRepeatedly(ReturnRef(collocations)); + EXPECT_CALL(*precomputation, Contains(symbols)).WillRepeatedly(Return(true)); + EXPECT_CALL(*precomputation, GetCollocations(symbols)). + WillRepeatedly(Return(expected_location)); intersector = make_shared(suffix_array, precomputation, vocabulary, 15, 1); diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc index 487abcaf..4d0738f7 100644 --- a/extractor/grammar_extractor.cc +++ b/extractor/grammar_extractor.cc @@ -19,10 +19,11 @@ GrammarExtractor::GrammarExtractor( shared_ptr source_suffix_array, shared_ptr target_data_array, shared_ptr alignment, shared_ptr precomputation, - shared_ptr scorer, int min_gap_size, int max_rule_span, + shared_ptr scorer, shared_ptr vocabulary, + int min_gap_size, int max_rule_span, int max_nonterminals, int max_rule_symbols, int max_samples, bool require_tight_phrases) : - vocabulary(make_shared()), + vocabulary(vocabulary), rule_factory(make_shared( source_suffix_array, target_data_array, alignment, vocabulary, precomputation, scorer, min_gap_size, max_rule_span, max_nonterminals, diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h index ae407b47..8f570df2 100644 --- a/extractor/grammar_extractor.h +++ b/extractor/grammar_extractor.h @@ -32,6 +32,7 @@ class GrammarExtractor { shared_ptr alignment, shared_ptr precomputation, shared_ptr scorer, + shared_ptr vocabulary, int min_gap_size, int max_rule_span, int max_nonterminals, diff --git a/extractor/mocks/mock_data_array.h b/extractor/mocks/mock_data_array.h index 6f85abb4..4bdcf21f 100644 --- a/extractor/mocks/mock_data_array.h +++ b/extractor/mocks/mock_data_array.h @@ -6,9 +6,11 @@ namespace extractor { class MockDataArray : public DataArray { public: - MOCK_CONST_METHOD0(GetData, const vector&()); + MOCK_CONST_METHOD0(GetData, vector()); MOCK_CONST_METHOD1(AtIndex, int(int index)); MOCK_CONST_METHOD1(GetWordAtIndex, string(int index)); + MOCK_CONST_METHOD2(GetWordIds, vector(int start_index, int size)); + MOCK_CONST_METHOD2(GetWords, vector(int start_index, int size)); MOCK_CONST_METHOD0(GetSize, int()); MOCK_CONST_METHOD0(GetVocabularySize, int()); MOCK_CONST_METHOD1(HasWord, bool(const string& word)); diff --git a/extractor/mocks/mock_precomputation.h b/extractor/mocks/mock_precomputation.h index 8753343e..5f7aa999 100644 --- a/extractor/mocks/mock_precomputation.h +++ b/extractor/mocks/mock_precomputation.h @@ -6,7 +6,8 @@ namespace extractor { class MockPrecomputation : public Precomputation { public: - MOCK_CONST_METHOD0(GetCollocations, const Index&()); + MOCK_CONST_METHOD1(Contains, bool(const vector& pattern)); + MOCK_CONST_METHOD1(GetCollocations, vector(const vector& pattern)); }; } // namespace extractor diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc index 3b8aed69..38d8f489 100644 --- a/extractor/precomputation.cc +++ b/extractor/precomputation.cc @@ -5,22 +5,21 @@ #include "data_array.h" #include "suffix_array.h" +#include "vocabulary.h" using namespace std; namespace extractor { -int Precomputation::FIRST_NONTERMINAL = -1; -int Precomputation::SECOND_NONTERMINAL = -2; +int Precomputation::NONTERMINAL = -1; Precomputation::Precomputation( - shared_ptr suffix_array, int num_frequent_patterns, - int num_super_frequent_patterns, int max_rule_span, - int max_rule_symbols, int min_gap_size, + shared_ptr vocabulary, shared_ptr suffix_array, + int num_frequent_patterns, int num_super_frequent_patterns, + int max_rule_span, int max_rule_symbols, int min_gap_size, int max_frequent_phrase_len, int min_frequency) { - vector data = suffix_array->GetData()->GetData(); vector> frequent_patterns = FindMostFrequentPatterns( - suffix_array, data, num_frequent_patterns, max_frequent_phrase_len, + suffix_array, num_frequent_patterns, max_frequent_phrase_len, min_frequency); // Construct sets containing the frequent and superfrequent contiguous @@ -34,28 +33,30 @@ Precomputation::Precomputation( } } + shared_ptr data_array = suffix_array->GetData(); vector> matchings; - for (size_t i = 0; i < data.size(); ++i) { + for (size_t i = 0; i < data_array->GetSize(); ++i) { // If the sentence is over, add all the discontiguous frequent patterns to // the index. - if (data[i] == DataArray::END_OF_LINE) { - AddCollocations(matchings, data, max_rule_span, min_gap_size, - max_rule_symbols); + if (data_array->AtIndex(i) == DataArray::END_OF_LINE) { + UpdateIndex(data_array, vocabulary, matchings, max_rule_span, + min_gap_size, max_rule_symbols); matchings.clear(); continue; } - vector pattern; // Find all the contiguous frequent patterns starting at position i. - for (int j = 1; j <= max_frequent_phrase_len && i + j <= data.size(); ++j) { - pattern.push_back(data[i + j - 1]); - if (frequent_patterns_set.count(pattern)) { - int is_super_frequent = super_frequent_patterns_set.count(pattern); - matchings.push_back(make_tuple(i, j, is_super_frequent)); - } else { + vector pattern; + for (int j = 1; + j <= max_frequent_phrase_len && i + j <= data_array->GetSize(); + ++j) { + pattern.push_back(data_array->AtIndex(i + j - 1)); + if (!frequent_patterns_set.count(pattern)) { // If the current pattern is not frequent, any longer pattern having the // current pattern as prefix will not be frequent. break; } + int is_super_frequent = super_frequent_patterns_set.count(pattern); + matchings.push_back(make_tuple(i, j, is_super_frequent)); } } } @@ -65,8 +66,8 @@ Precomputation::Precomputation() {} Precomputation::~Precomputation() {} vector> Precomputation::FindMostFrequentPatterns( - shared_ptr suffix_array, const vector& data, - int num_frequent_patterns, int max_frequent_phrase_len, int min_frequency) { + shared_ptr suffix_array, int num_frequent_patterns, + int max_frequent_phrase_len, int min_frequency) { vector lcp = suffix_array->BuildLCPArray(); vector run_start(max_frequent_phrase_len); @@ -83,6 +84,7 @@ vector> Precomputation::FindMostFrequentPatterns( } } + shared_ptr data_array = suffix_array->GetData(); // Extract the most frequent patterns. vector> frequent_patterns; while (frequent_patterns.size() < num_frequent_patterns && !heap.empty()) { @@ -90,7 +92,7 @@ vector> Precomputation::FindMostFrequentPatterns( int len = heap.top().second.second; heap.pop(); - vector pattern(data.begin() + start, data.begin() + start + len); + vector pattern = data_array->GetWordIds(start, len); if (find(pattern.begin(), pattern.end(), DataArray::END_OF_LINE) == pattern.end()) { frequent_patterns.push_back(pattern); @@ -99,8 +101,9 @@ vector> Precomputation::FindMostFrequentPatterns( return frequent_patterns; } -void Precomputation::AddCollocations( - const vector>& matchings, const vector& data, +void Precomputation::UpdateIndex( + shared_ptr data_array, shared_ptr vocabulary, + const vector>& matchings, int max_rule_span, int min_gap_size, int max_rule_symbols) { // Select the leftmost subpattern. for (size_t i = 0; i < matchings.size(); ++i) { @@ -118,16 +121,15 @@ void Precomputation::AddCollocations( if (start2 - start1 - size1 >= min_gap_size && start2 + size2 - start1 <= max_rule_span && size1 + size2 + 1 <= max_rule_symbols) { - vector pattern(data.begin() + start1, - data.begin() + start1 + size1); - pattern.push_back(Precomputation::FIRST_NONTERMINAL); - pattern.insert(pattern.end(), data.begin() + start2, - data.begin() + start2 + size2); - AddStartPositions(collocations[pattern], start1, start2); + vector pattern; + AppendSubpattern(pattern, data_array, vocabulary, start1, size1); + pattern.push_back(Precomputation::NONTERMINAL); + AppendSubpattern(pattern, data_array, vocabulary, start2, size2); + AppendCollocation(index[pattern], {start1, start2}); // Try extending the binary collocation to a ternary collocation. if (is_super2) { - pattern.push_back(Precomputation::SECOND_NONTERMINAL); + pattern.push_back(Precomputation::NONTERMINAL); // Select the rightmost subpattern. for (size_t k = j + 1; k < matchings.size(); ++k) { int start3, size3, is_super3; @@ -140,9 +142,8 @@ void Precomputation::AddCollocations( && start3 + size3 - start1 <= max_rule_span && size1 + size2 + size3 + 2 <= max_rule_symbols && (is_super1 || is_super3)) { - pattern.insert(pattern.end(), data.begin() + start3, - data.begin() + start3 + size3); - AddStartPositions(collocations[pattern], start1, start2, start3); + AppendSubpattern(pattern, data_array, vocabulary, start3, size3); + AppendCollocation(index[pattern], {start1, start2, start3}); pattern.erase(pattern.end() - size3); } } @@ -152,25 +153,30 @@ void Precomputation::AddCollocations( } } -void Precomputation::AddStartPositions( - vector& positions, int pos1, int pos2) { - positions.push_back(pos1); - positions.push_back(pos2); +void Precomputation::AppendSubpattern( + vector& pattern, shared_ptr data_array, + shared_ptr vocabulary, int start, int size) { + vector words = data_array->GetWords(start, size); + for (const string& word: words) { + pattern.push_back(vocabulary->GetTerminalIndex(word)); + } +} + +void Precomputation::AppendCollocation( + vector& collocations, const vector& collocation) { + copy(collocation.begin(), collocation.end(), back_inserter(collocations)); } -void Precomputation::AddStartPositions( - vector& positions, int pos1, int pos2, int pos3) { - positions.push_back(pos1); - positions.push_back(pos2); - positions.push_back(pos3); +bool Precomputation::Contains(const vector& pattern) const { + return index.count(pattern); } -const Index& Precomputation::GetCollocations() const { - return collocations; +vector Precomputation::GetCollocations(const vector& pattern) const { + return index.at(pattern); } bool Precomputation::operator==(const Precomputation& other) const { - return collocations == other.collocations; + return index == other.index; } } // namespace extractor diff --git a/extractor/precomputation.h b/extractor/precomputation.h index e5fa3e37..6ade58df 100644 --- a/extractor/precomputation.h +++ b/extractor/precomputation.h @@ -19,7 +19,9 @@ namespace extractor { typedef boost::hash> VectorHash; typedef unordered_map, vector, VectorHash> Index; +class DataArray; class SuffixArray; +class Vocabulary; /** * Data structure wrapping an index with all the occurrences of the most @@ -35,9 +37,9 @@ class Precomputation { public: // Constructs the index using the suffix array. Precomputation( - shared_ptr suffix_array, int num_frequent_patterns, - int num_super_frequent_patterns, int max_rule_span, - int max_rule_symbols, int min_gap_size, + shared_ptr vocabulary, shared_ptr suffix_array, + int num_frequent_patterns, int num_super_frequent_patterns, + int max_rule_span, int max_rule_symbols, int min_gap_size, int max_frequent_phrase_len, int min_frequency); // Creates empty precomputation data structure. @@ -45,40 +47,43 @@ class Precomputation { virtual ~Precomputation(); - // Returns a reference to the index. - virtual const Index& GetCollocations() const; + // Returns whether a pattern is contained in the index of collocations. + virtual bool Contains(const vector& pattern) const; + + // Returns the list of collocations for a given pattern. + virtual vector GetCollocations(const vector& pattern) const; bool operator==(const Precomputation& other) const; - static int FIRST_NONTERMINAL; - static int SECOND_NONTERMINAL; + static int NONTERMINAL; private: // Finds the most frequent contiguous collocations. vector> FindMostFrequentPatterns( - shared_ptr suffix_array, const vector& data, - int num_frequent_patterns, int max_frequent_phrase_len, - int min_frequency); + shared_ptr suffix_array, int num_frequent_patterns, + int max_frequent_phrase_len, int min_frequency); // Given the locations of the frequent contiguous collocations in a sentence, // it adds new entries to the index for each discontiguous collocation // matching the criteria specified in the class description. - void AddCollocations( - const vector>& matchings, const vector& data, + void UpdateIndex( + shared_ptr data_array, shared_ptr vocabulary, + const vector>& matchings, int max_rule_span, int min_gap_size, int max_rule_symbols); - // Adds an occurrence of a binary collocation. - void AddStartPositions(vector& positions, int pos1, int pos2); + void AppendSubpattern( + vector& pattern, shared_ptr data_array, + shared_ptr vocabulary, int start, int size); - // Adds an occurrence of a ternary collocation. - void AddStartPositions(vector& positions, int pos1, int pos2, int pos3); + // Adds an occurrence of a collocation. + void AppendCollocation(vector& collocations, const vector& collocation); friend class boost::serialization::access; template void save(Archive& ar, unsigned int) const { - int num_entries = collocations.size(); + int num_entries = index.size(); ar << num_entries; - for (pair, vector> entry: collocations) { + for (pair, vector> entry: index) { ar << entry; } } @@ -89,13 +94,13 @@ class Precomputation { for (size_t i = 0; i < num_entries; ++i) { pair, vector> entry; ar >> entry; - collocations.insert(entry); + index.insert(entry); } } BOOST_SERIALIZATION_SPLIT_MEMBER(); - Index collocations; + Index index; }; } // namespace extractor diff --git a/extractor/precomputation_test.cc b/extractor/precomputation_test.cc index e81ece5d..fd85fcf8 100644 --- a/extractor/precomputation_test.cc +++ b/extractor/precomputation_test.cc @@ -9,6 +9,7 @@ #include "mocks/mock_data_array.h" #include "mocks/mock_suffix_array.h" +#include "mocks/mock_vocabulary.h" #include "precomputation.h" using namespace std; @@ -23,7 +24,31 @@ class PrecomputationTest : public Test { virtual void SetUp() { data = {4, 2, 3, 5, 7, 2, 3, 5, 2, 3, 4, 2, 1}; data_array = make_shared(); - EXPECT_CALL(*data_array, GetData()).WillRepeatedly(ReturnRef(data)); + EXPECT_CALL(*data_array, GetSize()).WillRepeatedly(Return(data.size())); + for (size_t i = 0; i < data.size(); ++i) { + EXPECT_CALL(*data_array, AtIndex(i)).WillRepeatedly(Return(data[i])); + } + vector> expected_calls = {{8, 1}, {8, 2}, {6, 1}}; + for (const auto& call: expected_calls) { + int start = call.first; + int size = call.second; + vector word_ids(data.begin() + start, data.begin() + start + size); + EXPECT_CALL(*data_array, GetWordIds(start, size)) + .WillRepeatedly(Return(word_ids)); + } + + expected_calls = {{1, 1}, {5, 1}, {8, 1}, {9, 1}, {5, 2}, + {6, 1}, {8, 2}, {1, 2}, {2, 1}, {11, 1}}; + for (const auto& call: expected_calls) { + int start = call.first; + int size = call.second; + vector words; + for (size_t j = start; j < start + size; ++j) { + words.push_back(to_string(data[j])); + } + EXPECT_CALL(*data_array, GetWords(start, size)) + .WillRepeatedly(Return(words)); + } vector suffixes{12, 8, 5, 1, 9, 6, 2, 0, 10, 7, 3, 4, 13}; vector lcp{-1, 0, 2, 3, 1, 0, 1, 2, 0, 2, 0, 1, 0, 0}; @@ -35,77 +60,98 @@ class PrecomputationTest : public Test { } EXPECT_CALL(*suffix_array, BuildLCPArray()).WillRepeatedly(Return(lcp)); - precomputation = Precomputation(suffix_array, 3, 3, 10, 5, 1, 4, 2); + vocabulary = make_shared(); + EXPECT_CALL(*vocabulary, GetTerminalIndex("2")).WillRepeatedly(Return(2)); + EXPECT_CALL(*vocabulary, GetTerminalIndex("3")).WillRepeatedly(Return(3)); + + precomputation = Precomputation(vocabulary, suffix_array, + 3, 3, 10, 5, 1, 4, 2); } vector data; shared_ptr data_array; shared_ptr suffix_array; + shared_ptr vocabulary; Precomputation precomputation; }; TEST_F(PrecomputationTest, TestCollocations) { - Index collocations = precomputation.GetCollocations(); - vector key = {2, 3, -1, 2}; vector expected_value = {1, 5, 1, 8, 5, 8, 5, 11, 8, 11}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {2, 3, -1, 2, 3}; expected_value = {1, 5, 1, 8, 5, 8}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {2, 3, -1, 3}; expected_value = {1, 6, 1, 9, 5, 9}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {3, -1, 2}; expected_value = {2, 5, 2, 8, 2, 11, 6, 8, 6, 11, 9, 11}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {3, -1, 3}; expected_value = {2, 6, 2, 9, 6, 9}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {3, -1, 2, 3}; expected_value = {2, 5, 2, 8, 6, 8}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {2, -1, 2}; expected_value = {1, 5, 1, 8, 5, 8, 5, 11, 8, 11}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {2, -1, 2, 3}; expected_value = {1, 5, 1, 8, 5, 8}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {2, -1, 3}; expected_value = {1, 6, 1, 9, 5, 9}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); - key = {2, -1, 2, -2, 2}; + key = {2, -1, 2, -1, 2}; expected_value = {1, 5, 8, 5, 8, 11}; - EXPECT_EQ(expected_value, collocations[key]); - key = {2, -1, 2, -2, 3}; + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); + key = {2, -1, 2, -1, 3}; expected_value = {1, 5, 9}; - EXPECT_EQ(expected_value, collocations[key]); - key = {2, -1, 3, -2, 2}; + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); + key = {2, -1, 3, -1, 2}; expected_value = {1, 6, 8, 5, 9, 11}; - EXPECT_EQ(expected_value, collocations[key]); - key = {2, -1, 3, -2, 3}; + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); + key = {2, -1, 3, -1, 3}; expected_value = {1, 6, 9}; - EXPECT_EQ(expected_value, collocations[key]); - key = {3, -1, 2, -2, 2}; + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); + key = {3, -1, 2, -1, 2}; expected_value = {2, 5, 8, 2, 5, 11, 2, 8, 11, 6, 8, 11}; - EXPECT_EQ(expected_value, collocations[key]); - key = {3, -1, 2, -2, 3}; + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); + key = {3, -1, 2, -1, 3}; expected_value = {2, 5, 9}; - EXPECT_EQ(expected_value, collocations[key]); - key = {3, -1, 3, -2, 2}; + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); + key = {3, -1, 3, -1, 2}; expected_value = {2, 6, 8, 2, 6, 11, 2, 9, 11, 6, 9, 11}; - EXPECT_EQ(expected_value, collocations[key]); - key = {3, -1, 3, -2, 3}; + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); + key = {3, -1, 3, -1, 3}; expected_value = {2, 6, 9}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); // Exceeds max_rule_symbols. - key = {2, -1, 2, -2, 2, 3}; - EXPECT_EQ(0, collocations.count(key)); + key = {2, -1, 2, -1, 2, 3}; + EXPECT_FALSE(precomputation.Contains(key)); // Contains non frequent pattern. key = {2, -1, 5}; - EXPECT_EQ(0, collocations.count(key)); + EXPECT_FALSE(precomputation.Contains(key)); } TEST_F(PrecomputationTest, TestSerialization) { diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc index 6eb55073..85c8a422 100644 --- a/extractor/run_extractor.cc +++ b/extractor/run_extractor.cc @@ -28,6 +28,7 @@ #include "suffix_array.h" #include "time_util.h" #include "translation_table.h" +#include "vocabulary.h" namespace fs = boost::filesystem; namespace po = boost::program_options; @@ -142,11 +143,14 @@ int main(int argc, char** argv) { cerr << "Reading alignment took " << GetDuration(start_time, stop_time) << " seconds" << endl; + shared_ptr vocabulary = make_shared(); + // Constructs an index storing the occurrences in the source data for each // frequent collocation. start_time = Clock::now(); cerr << "Precomputing collocations..." << endl; shared_ptr precomputation = make_shared( + vocabulary, source_suffix_array, vm["frequent"].as(), vm["super_frequent"].as(), @@ -194,6 +198,7 @@ int main(int argc, char** argv) { alignment, precomputation, scorer, + vocabulary, vm["min_gap_size"].as(), vm["max_rule_span"].as(), vm["max_nonterminals"].as(), diff --git a/extractor/suffix_array_test.cc b/extractor/suffix_array_test.cc index ba0dbcc3..a9fd1eab 100644 --- a/extractor/suffix_array_test.cc +++ b/extractor/suffix_array_test.cc @@ -21,7 +21,7 @@ class SuffixArrayTest : public Test { virtual void SetUp() { data = {6, 4, 1, 2, 4, 5, 3, 4, 6, 6, 4, 1, 2}; data_array = make_shared(); - EXPECT_CALL(*data_array, GetData()).WillRepeatedly(ReturnRef(data)); + EXPECT_CALL(*data_array, GetData()).WillRepeatedly(Return(data)); EXPECT_CALL(*data_array, GetVocabularySize()).WillRepeatedly(Return(7)); EXPECT_CALL(*data_array, GetSize()).WillRepeatedly(Return(13)); suffix_array = SuffixArray(data_array); diff --git a/extractor/translation_table_test.cc b/extractor/translation_table_test.cc index 606777bd..72551a12 100644 --- a/extractor/translation_table_test.cc +++ b/extractor/translation_table_test.cc @@ -28,7 +28,7 @@ class TranslationTableTest : public Test { vector source_sentence_start = {0, 6, 10, 14}; shared_ptr source_data_array = make_shared(); EXPECT_CALL(*source_data_array, GetData()) - .WillRepeatedly(ReturnRef(source_data)); + .WillRepeatedly(Return(source_data)); EXPECT_CALL(*source_data_array, GetNumSentences()) .WillRepeatedly(Return(3)); for (size_t i = 0; i < source_sentence_start.size(); ++i) { @@ -48,7 +48,7 @@ class TranslationTableTest : public Test { vector target_sentence_start = {0, 7, 10, 13}; shared_ptr target_data_array = make_shared(); EXPECT_CALL(*target_data_array, GetData()) - .WillRepeatedly(ReturnRef(target_data)); + .WillRepeatedly(Return(target_data)); for (size_t i = 0; i < target_sentence_start.size(); ++i) { EXPECT_CALL(*target_data_array, GetSentenceStart(i)) .WillRepeatedly(Return(target_sentence_start[i])); -- cgit v1.2.3 From 3973a7e4a8302b4a02fee7d2950bb469b37e2452 Mon Sep 17 00:00:00 2001 From: Paul Baltescu Date: Sun, 24 Nov 2013 13:19:28 +0000 Subject: Reduce memory overhead for constructing the intersector. --- extractor/Makefile.am | 3 +- extractor/compile.cc | 4 ++ extractor/data_array.cc | 2 +- extractor/data_array.h | 2 +- extractor/fast_intersector.cc | 40 +++-------- extractor/fast_intersector.h | 8 +-- extractor/fast_intersector_test.cc | 10 ++- extractor/grammar_extractor.cc | 5 +- extractor/grammar_extractor.h | 1 + extractor/mocks/mock_data_array.h | 2 +- extractor/mocks/mock_precomputation.h | 3 +- extractor/precomputation.cc | 125 +++++++++++++++++++++------------- extractor/precomputation.h | 41 ++++++----- extractor/precomputation_test.cc | 73 +++++++++++++------- extractor/run_extractor.cc | 5 ++ extractor/suffix_array_test.cc | 2 +- extractor/translation_table_test.cc | 4 +- extractor/vocabulary.cc | 7 +- 18 files changed, 194 insertions(+), 143 deletions(-) (limited to 'extractor/mocks') diff --git a/extractor/Makefile.am b/extractor/Makefile.am index 65a3d436..faf25d89 100644 --- a/extractor/Makefile.am +++ b/extractor/Makefile.am @@ -53,7 +53,8 @@ endif noinst_PROGRAMS = $(RUNNABLE_TESTS) -TESTS = $(RUNNABLE_TESTS) +# TESTS = $(RUNNABLE_TESTS) +TESTS = precomputation_test alignment_test_SOURCES = alignment_test.cc alignment_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a diff --git a/extractor/compile.cc b/extractor/compile.cc index 65fdd509..0d62757e 100644 --- a/extractor/compile.cc +++ b/extractor/compile.cc @@ -13,6 +13,7 @@ #include "suffix_array.h" #include "time_util.h" #include "translation_table.h" +#include "vocabulary.h" namespace ar = boost::archive; namespace fs = boost::filesystem; @@ -125,9 +126,12 @@ int main(int argc, char** argv) { cerr << "Reading alignment took " << GetDuration(start_time, stop_time) << " seconds" << endl; + shared_ptr vocabulary; + start_time = Clock::now(); cerr << "Precomputing collocations..." << endl; Precomputation precomputation( + vocabulary, source_suffix_array, vm["frequent"].as(), vm["super_frequent"].as(), diff --git a/extractor/data_array.cc b/extractor/data_array.cc index 82efcd51..6757cae7 100644 --- a/extractor/data_array.cc +++ b/extractor/data_array.cc @@ -78,7 +78,7 @@ void DataArray::CreateDataArray(const vector& lines) { DataArray::~DataArray() {} -const vector& DataArray::GetData() const { +vector DataArray::GetData() const { return data; } diff --git a/extractor/data_array.h b/extractor/data_array.h index 5207366d..e9af5bd0 100644 --- a/extractor/data_array.h +++ b/extractor/data_array.h @@ -51,7 +51,7 @@ class DataArray { virtual ~DataArray(); // Returns a vector containing the word ids. - virtual const vector& GetData() const; + virtual vector GetData() const; // Returns the word id at the specified position. virtual int AtIndex(int index) const; diff --git a/extractor/fast_intersector.cc b/extractor/fast_intersector.cc index a8591a72..0d1fa6d8 100644 --- a/extractor/fast_intersector.cc +++ b/extractor/fast_intersector.cc @@ -11,41 +11,22 @@ namespace extractor { -FastIntersector::FastIntersector(shared_ptr suffix_array, - shared_ptr precomputation, - shared_ptr vocabulary, - int max_rule_span, - int min_gap_size) : +FastIntersector::FastIntersector( + shared_ptr suffix_array, + shared_ptr precomputation, + shared_ptr vocabulary, + int max_rule_span, + int min_gap_size) : suffix_array(suffix_array), + precomputation(precomputation), vocabulary(vocabulary), max_rule_span(max_rule_span), - min_gap_size(min_gap_size) { - Index precomputed_collocations = precomputation->GetCollocations(); - for (pair, vector> entry: precomputed_collocations) { - vector phrase = ConvertPhrase(entry.first); - collocations[phrase] = entry.second; - } -} + min_gap_size(min_gap_size) {} FastIntersector::FastIntersector() {} FastIntersector::~FastIntersector() {} -vector FastIntersector::ConvertPhrase(const vector& old_phrase) { - vector new_phrase; - new_phrase.reserve(old_phrase.size()); - shared_ptr data_array = suffix_array->GetData(); - for (int word_id: old_phrase) { - if (word_id < 0) { - new_phrase.push_back(word_id); - } else { - new_phrase.push_back( - vocabulary->GetTerminalIndex(data_array->GetWord(word_id))); - } - } - return new_phrase; -} - PhraseLocation FastIntersector::Intersect( PhraseLocation& prefix_location, PhraseLocation& suffix_location, @@ -59,8 +40,9 @@ PhraseLocation FastIntersector::Intersect( assert(vocabulary->IsTerminal(symbols.front()) && vocabulary->IsTerminal(symbols.back())); - if (collocations.count(symbols)) { - return PhraseLocation(collocations[symbols], phrase.Arity() + 1); + if (precomputation->Contains(symbols)) { + return PhraseLocation(precomputation->GetCollocations(symbols), + phrase.Arity() + 1); } bool prefix_ends_with_x = diff --git a/extractor/fast_intersector.h b/extractor/fast_intersector.h index 2819d239..305373dc 100644 --- a/extractor/fast_intersector.h +++ b/extractor/fast_intersector.h @@ -12,7 +12,6 @@ using namespace std; namespace extractor { typedef boost::hash> VectorHash; -typedef unordered_map, vector, VectorHash> Index; class Phrase; class PhraseLocation; @@ -52,11 +51,6 @@ class FastIntersector { FastIntersector(); private: - // Uses the vocabulary to convert the phrase from the numberized format - // specified by the source data array to the numberized format given by the - // vocabulary. - vector ConvertPhrase(const vector& old_phrase); - // Estimates the number of computations needed if the prefix/suffix is // extended. If the last/first symbol is separated from the rest of the phrase // by a nonterminal, then for each occurrence of the prefix/suffix we need to @@ -85,10 +79,10 @@ class FastIntersector { pair GetSearchRange(bool has_marginal_x) const; shared_ptr suffix_array; + shared_ptr precomputation; shared_ptr vocabulary; int max_rule_span; int min_gap_size; - Index collocations; }; } // namespace extractor diff --git a/extractor/fast_intersector_test.cc b/extractor/fast_intersector_test.cc index 76c3aaea..f2a26ba1 100644 --- a/extractor/fast_intersector_test.cc +++ b/extractor/fast_intersector_test.cc @@ -59,15 +59,13 @@ class FastIntersectorTest : public Test { } precomputation = make_shared(); - EXPECT_CALL(*precomputation, GetCollocations()) - .WillRepeatedly(ReturnRef(collocations)); + EXPECT_CALL(*precomputation, Contains(_)).WillRepeatedly(Return(false)); phrase_builder = make_shared(vocabulary); intersector = make_shared(suffix_array, precomputation, vocabulary, 15, 1); } - Index collocations; shared_ptr data_array; shared_ptr suffix_array; shared_ptr precomputation; @@ -82,9 +80,9 @@ TEST_F(FastIntersectorTest, TestCachedCollocation) { Phrase phrase = phrase_builder->Build(symbols); PhraseLocation prefix_location(15, 16), suffix_location(16, 17); - collocations[symbols] = expected_location; - EXPECT_CALL(*precomputation, GetCollocations()) - .WillRepeatedly(ReturnRef(collocations)); + EXPECT_CALL(*precomputation, Contains(symbols)).WillRepeatedly(Return(true)); + EXPECT_CALL(*precomputation, GetCollocations(symbols)). + WillRepeatedly(Return(expected_location)); intersector = make_shared(suffix_array, precomputation, vocabulary, 15, 1); diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc index 487abcaf..4d0738f7 100644 --- a/extractor/grammar_extractor.cc +++ b/extractor/grammar_extractor.cc @@ -19,10 +19,11 @@ GrammarExtractor::GrammarExtractor( shared_ptr source_suffix_array, shared_ptr target_data_array, shared_ptr alignment, shared_ptr precomputation, - shared_ptr scorer, int min_gap_size, int max_rule_span, + shared_ptr scorer, shared_ptr vocabulary, + int min_gap_size, int max_rule_span, int max_nonterminals, int max_rule_symbols, int max_samples, bool require_tight_phrases) : - vocabulary(make_shared()), + vocabulary(vocabulary), rule_factory(make_shared( source_suffix_array, target_data_array, alignment, vocabulary, precomputation, scorer, min_gap_size, max_rule_span, max_nonterminals, diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h index ae407b47..8f570df2 100644 --- a/extractor/grammar_extractor.h +++ b/extractor/grammar_extractor.h @@ -32,6 +32,7 @@ class GrammarExtractor { shared_ptr alignment, shared_ptr precomputation, shared_ptr scorer, + shared_ptr vocabulary, int min_gap_size, int max_rule_span, int max_nonterminals, diff --git a/extractor/mocks/mock_data_array.h b/extractor/mocks/mock_data_array.h index 6f85abb4..d39cb0c4 100644 --- a/extractor/mocks/mock_data_array.h +++ b/extractor/mocks/mock_data_array.h @@ -6,7 +6,7 @@ namespace extractor { class MockDataArray : public DataArray { public: - MOCK_CONST_METHOD0(GetData, const vector&()); + MOCK_CONST_METHOD0(GetData, vector()); MOCK_CONST_METHOD1(AtIndex, int(int index)); MOCK_CONST_METHOD1(GetWordAtIndex, string(int index)); MOCK_CONST_METHOD0(GetSize, int()); diff --git a/extractor/mocks/mock_precomputation.h b/extractor/mocks/mock_precomputation.h index 8753343e..5f7aa999 100644 --- a/extractor/mocks/mock_precomputation.h +++ b/extractor/mocks/mock_precomputation.h @@ -6,7 +6,8 @@ namespace extractor { class MockPrecomputation : public Precomputation { public: - MOCK_CONST_METHOD0(GetCollocations, const Index&()); + MOCK_CONST_METHOD1(Contains, bool(const vector& pattern)); + MOCK_CONST_METHOD1(GetCollocations, vector(const vector& pattern)); }; } // namespace extractor diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc index 3b8aed69..3e58e2a9 100644 --- a/extractor/precomputation.cc +++ b/extractor/precomputation.cc @@ -5,59 +5,67 @@ #include "data_array.h" #include "suffix_array.h" +#include "time_util.h" +#include "vocabulary.h" using namespace std; namespace extractor { -int Precomputation::FIRST_NONTERMINAL = -1; -int Precomputation::SECOND_NONTERMINAL = -2; - Precomputation::Precomputation( - shared_ptr suffix_array, int num_frequent_patterns, - int num_super_frequent_patterns, int max_rule_span, - int max_rule_symbols, int min_gap_size, + shared_ptr vocabulary, shared_ptr suffix_array, + int num_frequent_patterns, int num_super_frequent_patterns, + int max_rule_span, int max_rule_symbols, int min_gap_size, int max_frequent_phrase_len, int min_frequency) { - vector data = suffix_array->GetData()->GetData(); + Clock::time_point start_time = Clock::now(); + shared_ptr data_array = suffix_array->GetData(); + vector data = data_array->GetData(); vector> frequent_patterns = FindMostFrequentPatterns( suffix_array, data, num_frequent_patterns, max_frequent_phrase_len, min_frequency); + Clock::time_point end_time = Clock::now(); + cerr << "Finding most frequent patterns took " + << GetDuration(start_time, end_time) << " seconds..." << endl; - // Construct sets containing the frequent and superfrequent contiguous - // collocations. - unordered_set, VectorHash> frequent_patterns_set; - unordered_set, VectorHash> super_frequent_patterns_set; + vector> pattern_annotations(frequent_patterns.size()); + unordered_map, int, VectorHash> frequent_patterns_index; for (size_t i = 0; i < frequent_patterns.size(); ++i) { - frequent_patterns_set.insert(frequent_patterns[i]); - if (i < num_super_frequent_patterns) { - super_frequent_patterns_set.insert(frequent_patterns[i]); - } + frequent_patterns_index[frequent_patterns[i]] = i; + pattern_annotations[i] = AnnotatePattern(vocabulary, data_array, + frequent_patterns[i]); } + start_time = Clock::now(); vector> matchings; + vector> annotations; for (size_t i = 0; i < data.size(); ++i) { // If the sentence is over, add all the discontiguous frequent patterns to // the index. if (data[i] == DataArray::END_OF_LINE) { - AddCollocations(matchings, data, max_rule_span, min_gap_size, - max_rule_symbols); + UpdateIndex(matchings, annotations, max_rule_span, min_gap_size, + max_rule_symbols); matchings.clear(); + annotations.clear(); continue; } - vector pattern; // Find all the contiguous frequent patterns starting at position i. + vector pattern; for (int j = 1; j <= max_frequent_phrase_len && i + j <= data.size(); ++j) { pattern.push_back(data[i + j - 1]); - if (frequent_patterns_set.count(pattern)) { - int is_super_frequent = super_frequent_patterns_set.count(pattern); - matchings.push_back(make_tuple(i, j, is_super_frequent)); - } else { + auto it = frequent_patterns_index.find(pattern); + if (it == frequent_patterns_index.end()) { // If the current pattern is not frequent, any longer pattern having the // current pattern as prefix will not be frequent. break; } + int is_super_frequent = it->second < num_super_frequent_patterns; + matchings.push_back(make_tuple(i, j, is_super_frequent)); + annotations.push_back(pattern_annotations[it->second]); } } + end_time = Clock::now(); + cerr << "Constructing collocations index took " + << GetDuration(start_time, end_time) << " seconds..." << endl; } Precomputation::Precomputation() {} @@ -75,9 +83,9 @@ vector> Precomputation::FindMostFrequentPatterns( for (size_t i = 1; i < lcp.size(); ++i) { for (int len = lcp[i]; len < max_frequent_phrase_len; ++len) { int frequency = i - run_start[len]; - if (frequency >= min_frequency) { - heap.push(make_pair(frequency, - make_pair(suffix_array->GetSuffix(run_start[len]), len + 1))); + int start = suffix_array->GetSuffix(run_start[len]); + if (frequency >= min_frequency && start + len <= data.size()) { + heap.push(make_pair(frequency, make_pair(start, len + 1))); } run_start[len] = i; } @@ -99,8 +107,20 @@ vector> Precomputation::FindMostFrequentPatterns( return frequent_patterns; } -void Precomputation::AddCollocations( - const vector>& matchings, const vector& data, +vector Precomputation::AnnotatePattern( + shared_ptr vocabulary, shared_ptr data_array, + const vector& pattern) const { + vector annotation; + for (int word_id: pattern) { + annotation.push_back(vocabulary->GetTerminalIndex( + data_array->GetWord(word_id))); + } + return annotation; +} + +void Precomputation::UpdateIndex( + const vector>& matchings, + const vector>& annotations, int max_rule_span, int min_gap_size, int max_rule_symbols) { // Select the leftmost subpattern. for (size_t i = 0; i < matchings.size(); ++i) { @@ -118,16 +138,14 @@ void Precomputation::AddCollocations( if (start2 - start1 - size1 >= min_gap_size && start2 + size2 - start1 <= max_rule_span && size1 + size2 + 1 <= max_rule_symbols) { - vector pattern(data.begin() + start1, - data.begin() + start1 + size1); - pattern.push_back(Precomputation::FIRST_NONTERMINAL); - pattern.insert(pattern.end(), data.begin() + start2, - data.begin() + start2 + size2); - AddStartPositions(collocations[pattern], start1, start2); + vector pattern = annotations[i]; + pattern.push_back(-1); + AppendSubpattern(pattern, annotations[j]); + AppendCollocation(index[pattern], start1, start2); // Try extending the binary collocation to a ternary collocation. if (is_super2) { - pattern.push_back(Precomputation::SECOND_NONTERMINAL); + pattern.push_back(-2); // Select the rightmost subpattern. for (size_t k = j + 1; k < matchings.size(); ++k) { int start3, size3, is_super3; @@ -140,9 +158,8 @@ void Precomputation::AddCollocations( && start3 + size3 - start1 <= max_rule_span && size1 + size2 + size3 + 2 <= max_rule_symbols && (is_super1 || is_super3)) { - pattern.insert(pattern.end(), data.begin() + start3, - data.begin() + start3 + size3); - AddStartPositions(collocations[pattern], start1, start2, start3); + AppendSubpattern(pattern, annotations[k]); + AppendCollocation(index[pattern], start1, start2, start3); pattern.erase(pattern.end() - size3); } } @@ -152,25 +169,35 @@ void Precomputation::AddCollocations( } } -void Precomputation::AddStartPositions( - vector& positions, int pos1, int pos2) { - positions.push_back(pos1); - positions.push_back(pos2); +void Precomputation::AppendSubpattern( + vector& pattern, + const vector& subpattern) { + copy(subpattern.begin(), subpattern.end(), back_inserter(pattern)); +} + +void Precomputation::AppendCollocation( + vector& collocations, int pos1, int pos2) { + collocations.push_back(pos1); + collocations.push_back(pos2); +} + +void Precomputation::AppendCollocation( + vector& collocations, int pos1, int pos2, int pos3) { + collocations.push_back(pos1); + collocations.push_back(pos2); + collocations.push_back(pos3); } -void Precomputation::AddStartPositions( - vector& positions, int pos1, int pos2, int pos3) { - positions.push_back(pos1); - positions.push_back(pos2); - positions.push_back(pos3); +bool Precomputation::Contains(const vector& pattern) const { + return index.count(pattern); } -const Index& Precomputation::GetCollocations() const { - return collocations; +vector Precomputation::GetCollocations(const vector& pattern) const { + return index.at(pattern); } bool Precomputation::operator==(const Precomputation& other) const { - return collocations == other.collocations; + return index == other.index; } } // namespace extractor diff --git a/extractor/precomputation.h b/extractor/precomputation.h index e5fa3e37..2b34fc29 100644 --- a/extractor/precomputation.h +++ b/extractor/precomputation.h @@ -19,7 +19,9 @@ namespace extractor { typedef boost::hash> VectorHash; typedef unordered_map, vector, VectorHash> Index; +class DataArray; class SuffixArray; +class Vocabulary; /** * Data structure wrapping an index with all the occurrences of the most @@ -35,9 +37,9 @@ class Precomputation { public: // Constructs the index using the suffix array. Precomputation( - shared_ptr suffix_array, int num_frequent_patterns, - int num_super_frequent_patterns, int max_rule_span, - int max_rule_symbols, int min_gap_size, + shared_ptr vocabulary, shared_ptr suffix_array, + int num_frequent_patterns, int num_super_frequent_patterns, + int max_rule_span, int max_rule_symbols, int min_gap_size, int max_frequent_phrase_len, int min_frequency); // Creates empty precomputation data structure. @@ -45,13 +47,13 @@ class Precomputation { virtual ~Precomputation(); - // Returns a reference to the index. - virtual const Index& GetCollocations() const; + // Returns whether a pattern is contained in the index of collocations. + virtual bool Contains(const vector& pattern) const; - bool operator==(const Precomputation& other) const; + // Returns the list of collocations for a given pattern. + virtual vector GetCollocations(const vector& pattern) const; - static int FIRST_NONTERMINAL; - static int SECOND_NONTERMINAL; + bool operator==(const Precomputation& other) const; private: // Finds the most frequent contiguous collocations. @@ -60,25 +62,32 @@ class Precomputation { int num_frequent_patterns, int max_frequent_phrase_len, int min_frequency); + vector AnnotatePattern(shared_ptr vocabulary, + shared_ptr data_array, + const vector& pattern) const; + // Given the locations of the frequent contiguous collocations in a sentence, // it adds new entries to the index for each discontiguous collocation // matching the criteria specified in the class description. - void AddCollocations( - const vector>& matchings, const vector& data, + void UpdateIndex( + const vector>& matchings, + const vector>& annotations, int max_rule_span, int min_gap_size, int max_rule_symbols); + void AppendSubpattern(vector& pattern, const vector& subpattern); + // Adds an occurrence of a binary collocation. - void AddStartPositions(vector& positions, int pos1, int pos2); + void AppendCollocation(vector& collocations, int pos1, int pos2); // Adds an occurrence of a ternary collocation. - void AddStartPositions(vector& positions, int pos1, int pos2, int pos3); + void AppendCollocation(vector& collocations, int pos1, int pos2, int pos3); friend class boost::serialization::access; template void save(Archive& ar, unsigned int) const { - int num_entries = collocations.size(); + int num_entries = index.size(); ar << num_entries; - for (pair, vector> entry: collocations) { + for (pair, vector> entry: index) { ar << entry; } } @@ -89,13 +98,13 @@ class Precomputation { for (size_t i = 0; i < num_entries; ++i) { pair, vector> entry; ar >> entry; - collocations.insert(entry); + index.insert(entry); } } BOOST_SERIALIZATION_SPLIT_MEMBER(); - Index collocations; + Index index; }; } // namespace extractor diff --git a/extractor/precomputation_test.cc b/extractor/precomputation_test.cc index e81ece5d..3a98ce05 100644 --- a/extractor/precomputation_test.cc +++ b/extractor/precomputation_test.cc @@ -9,6 +9,7 @@ #include "mocks/mock_data_array.h" #include "mocks/mock_suffix_array.h" +#include "mocks/mock_vocabulary.h" #include "precomputation.h" using namespace std; @@ -23,7 +24,12 @@ class PrecomputationTest : public Test { virtual void SetUp() { data = {4, 2, 3, 5, 7, 2, 3, 5, 2, 3, 4, 2, 1}; data_array = make_shared(); - EXPECT_CALL(*data_array, GetData()).WillRepeatedly(ReturnRef(data)); + EXPECT_CALL(*data_array, GetData()).WillRepeatedly(Return(data)); + for (size_t i = 0; i < data.size(); ++i) { + EXPECT_CALL(*data_array, AtIndex(i)).WillRepeatedly(Return(data[i])); + } + EXPECT_CALL(*data_array, GetWord(2)).WillRepeatedly(Return("2")); + EXPECT_CALL(*data_array, GetWord(3)).WillRepeatedly(Return("3")); vector suffixes{12, 8, 5, 1, 9, 6, 2, 0, 10, 7, 3, 4, 13}; vector lcp{-1, 0, 2, 3, 1, 0, 1, 2, 0, 2, 0, 1, 0, 0}; @@ -35,77 +41,98 @@ class PrecomputationTest : public Test { } EXPECT_CALL(*suffix_array, BuildLCPArray()).WillRepeatedly(Return(lcp)); - precomputation = Precomputation(suffix_array, 3, 3, 10, 5, 1, 4, 2); + vocabulary = make_shared(); + EXPECT_CALL(*vocabulary, GetTerminalIndex("2")).WillRepeatedly(Return(2)); + EXPECT_CALL(*vocabulary, GetTerminalIndex("3")).WillRepeatedly(Return(3)); + + precomputation = Precomputation(vocabulary, suffix_array, + 3, 3, 10, 5, 1, 4, 2); } vector data; shared_ptr data_array; shared_ptr suffix_array; + shared_ptr vocabulary; Precomputation precomputation; }; TEST_F(PrecomputationTest, TestCollocations) { - Index collocations = precomputation.GetCollocations(); - vector key = {2, 3, -1, 2}; vector expected_value = {1, 5, 1, 8, 5, 8, 5, 11, 8, 11}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {2, 3, -1, 2, 3}; expected_value = {1, 5, 1, 8, 5, 8}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {2, 3, -1, 3}; expected_value = {1, 6, 1, 9, 5, 9}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {3, -1, 2}; expected_value = {2, 5, 2, 8, 2, 11, 6, 8, 6, 11, 9, 11}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {3, -1, 3}; expected_value = {2, 6, 2, 9, 6, 9}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {3, -1, 2, 3}; expected_value = {2, 5, 2, 8, 6, 8}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {2, -1, 2}; expected_value = {1, 5, 1, 8, 5, 8, 5, 11, 8, 11}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {2, -1, 2, 3}; expected_value = {1, 5, 1, 8, 5, 8}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {2, -1, 3}; expected_value = {1, 6, 1, 9, 5, 9}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {2, -1, 2, -2, 2}; expected_value = {1, 5, 8, 5, 8, 11}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {2, -1, 2, -2, 3}; expected_value = {1, 5, 9}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {2, -1, 3, -2, 2}; expected_value = {1, 6, 8, 5, 9, 11}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {2, -1, 3, -2, 3}; expected_value = {1, 6, 9}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {3, -1, 2, -2, 2}; expected_value = {2, 5, 8, 2, 5, 11, 2, 8, 11, 6, 8, 11}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {3, -1, 2, -2, 3}; expected_value = {2, 5, 9}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {3, -1, 3, -2, 2}; expected_value = {2, 6, 8, 2, 6, 11, 2, 9, 11, 6, 9, 11}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {3, -1, 3, -2, 3}; expected_value = {2, 6, 9}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); // Exceeds max_rule_symbols. key = {2, -1, 2, -2, 2, 3}; - EXPECT_EQ(0, collocations.count(key)); + EXPECT_FALSE(precomputation.Contains(key)); // Contains non frequent pattern. key = {2, -1, 5}; - EXPECT_EQ(0, collocations.count(key)); + EXPECT_FALSE(precomputation.Contains(key)); } TEST_F(PrecomputationTest, TestSerialization) { diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc index 6eb55073..85c8a422 100644 --- a/extractor/run_extractor.cc +++ b/extractor/run_extractor.cc @@ -28,6 +28,7 @@ #include "suffix_array.h" #include "time_util.h" #include "translation_table.h" +#include "vocabulary.h" namespace fs = boost::filesystem; namespace po = boost::program_options; @@ -142,11 +143,14 @@ int main(int argc, char** argv) { cerr << "Reading alignment took " << GetDuration(start_time, stop_time) << " seconds" << endl; + shared_ptr vocabulary = make_shared(); + // Constructs an index storing the occurrences in the source data for each // frequent collocation. start_time = Clock::now(); cerr << "Precomputing collocations..." << endl; shared_ptr precomputation = make_shared( + vocabulary, source_suffix_array, vm["frequent"].as(), vm["super_frequent"].as(), @@ -194,6 +198,7 @@ int main(int argc, char** argv) { alignment, precomputation, scorer, + vocabulary, vm["min_gap_size"].as(), vm["max_rule_span"].as(), vm["max_nonterminals"].as(), diff --git a/extractor/suffix_array_test.cc b/extractor/suffix_array_test.cc index ba0dbcc3..a9fd1eab 100644 --- a/extractor/suffix_array_test.cc +++ b/extractor/suffix_array_test.cc @@ -21,7 +21,7 @@ class SuffixArrayTest : public Test { virtual void SetUp() { data = {6, 4, 1, 2, 4, 5, 3, 4, 6, 6, 4, 1, 2}; data_array = make_shared(); - EXPECT_CALL(*data_array, GetData()).WillRepeatedly(ReturnRef(data)); + EXPECT_CALL(*data_array, GetData()).WillRepeatedly(Return(data)); EXPECT_CALL(*data_array, GetVocabularySize()).WillRepeatedly(Return(7)); EXPECT_CALL(*data_array, GetSize()).WillRepeatedly(Return(13)); suffix_array = SuffixArray(data_array); diff --git a/extractor/translation_table_test.cc b/extractor/translation_table_test.cc index 606777bd..72551a12 100644 --- a/extractor/translation_table_test.cc +++ b/extractor/translation_table_test.cc @@ -28,7 +28,7 @@ class TranslationTableTest : public Test { vector source_sentence_start = {0, 6, 10, 14}; shared_ptr source_data_array = make_shared(); EXPECT_CALL(*source_data_array, GetData()) - .WillRepeatedly(ReturnRef(source_data)); + .WillRepeatedly(Return(source_data)); EXPECT_CALL(*source_data_array, GetNumSentences()) .WillRepeatedly(Return(3)); for (size_t i = 0; i < source_sentence_start.size(); ++i) { @@ -48,7 +48,7 @@ class TranslationTableTest : public Test { vector target_sentence_start = {0, 7, 10, 13}; shared_ptr target_data_array = make_shared(); EXPECT_CALL(*target_data_array, GetData()) - .WillRepeatedly(ReturnRef(target_data)); + .WillRepeatedly(Return(target_data)); for (size_t i = 0; i < target_sentence_start.size(); ++i) { EXPECT_CALL(*target_data_array, GetSentenceStart(i)) .WillRepeatedly(Return(target_sentence_start[i])); diff --git a/extractor/vocabulary.cc b/extractor/vocabulary.cc index 15795d1e..aef674a5 100644 --- a/extractor/vocabulary.cc +++ b/extractor/vocabulary.cc @@ -8,12 +8,13 @@ int Vocabulary::GetTerminalIndex(const string& word) { int word_id = -1; #pragma omp critical (vocabulary) { - if (!dictionary.count(word)) { + auto it = dictionary.find(word); + if (it != dictionary.end()) { + word_id = it->second; + } else { word_id = words.size(); dictionary[word] = word_id; words.push_back(word); - } else { - word_id = dictionary[word]; } } return word_id; -- cgit v1.2.3 From 467ef6ce78cfe7341a696ebf0948e377be619ae5 Mon Sep 17 00:00:00 2001 From: Paul Baltescu Date: Mon, 25 Nov 2013 18:19:13 +0000 Subject: Reduce unordered_map calls. --- extractor/Makefile.am | 3 +-- extractor/data_array.cc | 4 ---- extractor/data_array.h | 3 --- extractor/data_array_test.cc | 4 ---- extractor/mocks/mock_data_array.h | 1 - extractor/suffix_array.cc | 4 ++-- extractor/suffix_array_test.cc | 6 +----- extractor/translation_table.cc | 14 ++++++-------- extractor/translation_table_test.cc | 10 ++-------- 9 files changed, 12 insertions(+), 37 deletions(-) (limited to 'extractor/mocks') diff --git a/extractor/Makefile.am b/extractor/Makefile.am index faf25d89..65a3d436 100644 --- a/extractor/Makefile.am +++ b/extractor/Makefile.am @@ -53,8 +53,7 @@ endif noinst_PROGRAMS = $(RUNNABLE_TESTS) -# TESTS = $(RUNNABLE_TESTS) -TESTS = precomputation_test +TESTS = $(RUNNABLE_TESTS) alignment_test_SOURCES = alignment_test.cc alignment_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a diff --git a/extractor/data_array.cc b/extractor/data_array.cc index 6757cae7..ac0493fd 100644 --- a/extractor/data_array.cc +++ b/extractor/data_array.cc @@ -115,10 +115,6 @@ int DataArray::GetSentenceId(int position) const { return sentence_id[position]; } -bool DataArray::HasWord(const string& word) const { - return word2id.count(word); -} - int DataArray::GetWordId(const string& word) const { auto result = word2id.find(word); return result == word2id.end() ? -1 : result->second; diff --git a/extractor/data_array.h b/extractor/data_array.h index e9af5bd0..c5dc8a26 100644 --- a/extractor/data_array.h +++ b/extractor/data_array.h @@ -65,9 +65,6 @@ class DataArray { // Returns the number of distinct words in the data array. virtual int GetVocabularySize() const; - // Returns whether a word has ever been observed in the data array. - virtual bool HasWord(const string& word) const; - // Returns the word id for a given word or -1 if it the word has never been // observed. virtual int GetWordId(const string& word) const; diff --git a/extractor/data_array_test.cc b/extractor/data_array_test.cc index 6c329e34..b6b56561 100644 --- a/extractor/data_array_test.cc +++ b/extractor/data_array_test.cc @@ -58,16 +58,12 @@ TEST_F(DataArrayTest, TestGetData) { TEST_F(DataArrayTest, TestVocabulary) { EXPECT_EQ(9, source_data.GetVocabularySize()); - EXPECT_TRUE(source_data.HasWord("mere")); EXPECT_EQ(4, source_data.GetWordId("mere")); EXPECT_EQ("mere", source_data.GetWord(4)); - EXPECT_FALSE(source_data.HasWord("banane")); EXPECT_EQ(11, target_data.GetVocabularySize()); - EXPECT_TRUE(target_data.HasWord("apples")); EXPECT_EQ(4, target_data.GetWordId("apples")); EXPECT_EQ("apples", target_data.GetWord(4)); - EXPECT_FALSE(target_data.HasWord("bananas")); } TEST_F(DataArrayTest, TestSentenceData) { diff --git a/extractor/mocks/mock_data_array.h b/extractor/mocks/mock_data_array.h index d39cb0c4..edc525fa 100644 --- a/extractor/mocks/mock_data_array.h +++ b/extractor/mocks/mock_data_array.h @@ -11,7 +11,6 @@ class MockDataArray : public DataArray { MOCK_CONST_METHOD1(GetWordAtIndex, string(int index)); MOCK_CONST_METHOD0(GetSize, int()); MOCK_CONST_METHOD0(GetVocabularySize, int()); - MOCK_CONST_METHOD1(HasWord, bool(const string& word)); MOCK_CONST_METHOD1(GetWordId, int(const string& word)); MOCK_CONST_METHOD1(GetWord, string(int word_id)); MOCK_CONST_METHOD1(GetSentenceLength, int(int sentence_id)); diff --git a/extractor/suffix_array.cc b/extractor/suffix_array.cc index ac230d13..4a514b12 100644 --- a/extractor/suffix_array.cc +++ b/extractor/suffix_array.cc @@ -187,12 +187,12 @@ shared_ptr SuffixArray::GetData() const { PhraseLocation SuffixArray::Lookup(int low, int high, const string& word, int offset) const { - if (!data_array->HasWord(word)) { + int word_id = data_array->GetWordId(word); + if (word_id == -1) { // Return empty phrase location. return PhraseLocation(0, 0); } - int word_id = data_array->GetWordId(word); if (offset == 0) { return PhraseLocation(word_start[word_id], word_start[word_id + 1]); } diff --git a/extractor/suffix_array_test.cc b/extractor/suffix_array_test.cc index a9fd1eab..161edbc0 100644 --- a/extractor/suffix_array_test.cc +++ b/extractor/suffix_array_test.cc @@ -55,22 +55,18 @@ TEST_F(SuffixArrayTest, TestLookup) { EXPECT_CALL(*data_array, AtIndex(i)).WillRepeatedly(Return(data[i])); } - EXPECT_CALL(*data_array, HasWord("word1")).WillRepeatedly(Return(true)); EXPECT_CALL(*data_array, GetWordId("word1")).WillRepeatedly(Return(6)); EXPECT_EQ(PhraseLocation(11, 14), suffix_array.Lookup(0, 14, "word1", 0)); - EXPECT_CALL(*data_array, HasWord("word2")).WillRepeatedly(Return(false)); + EXPECT_CALL(*data_array, GetWordId("word2")).WillRepeatedly(Return(-1)); EXPECT_EQ(PhraseLocation(0, 0), suffix_array.Lookup(0, 14, "word2", 0)); - EXPECT_CALL(*data_array, HasWord("word3")).WillRepeatedly(Return(true)); EXPECT_CALL(*data_array, GetWordId("word3")).WillRepeatedly(Return(4)); EXPECT_EQ(PhraseLocation(11, 13), suffix_array.Lookup(11, 14, "word3", 1)); - EXPECT_CALL(*data_array, HasWord("word4")).WillRepeatedly(Return(true)); EXPECT_CALL(*data_array, GetWordId("word4")).WillRepeatedly(Return(1)); EXPECT_EQ(PhraseLocation(11, 13), suffix_array.Lookup(11, 13, "word4", 2)); - EXPECT_CALL(*data_array, HasWord("word5")).WillRepeatedly(Return(true)); EXPECT_CALL(*data_array, GetWordId("word5")).WillRepeatedly(Return(2)); EXPECT_EQ(PhraseLocation(11, 13), suffix_array.Lookup(11, 13, "word5", 3)); diff --git a/extractor/translation_table.cc b/extractor/translation_table.cc index 1b1ba112..11e29e1e 100644 --- a/extractor/translation_table.cc +++ b/extractor/translation_table.cc @@ -90,13 +90,12 @@ void TranslationTable::IncrementLinksCount( double TranslationTable::GetTargetGivenSourceScore( const string& source_word, const string& target_word) { - if (!source_data_array->HasWord(source_word) || - !target_data_array->HasWord(target_word)) { + int source_id = source_data_array->GetWordId(source_word); + int target_id = target_data_array->GetWordId(target_word); + if (source_id == -1 || target_id == -1) { return -1; } - int source_id = source_data_array->GetWordId(source_word); - int target_id = target_data_array->GetWordId(target_word); auto entry = make_pair(source_id, target_id); auto it = translation_probabilities.find(entry); if (it == translation_probabilities.end()) { @@ -107,13 +106,12 @@ double TranslationTable::GetTargetGivenSourceScore( double TranslationTable::GetSourceGivenTargetScore( const string& source_word, const string& target_word) { - if (!source_data_array->HasWord(source_word) || - !target_data_array->HasWord(target_word)) { + int source_id = source_data_array->GetWordId(source_word); + int target_id = target_data_array->GetWordId(target_word); + if (source_id == -1 || target_id == -1) { return -1; } - int source_id = source_data_array->GetWordId(source_word); - int target_id = target_data_array->GetWordId(target_word); auto entry = make_pair(source_id, target_id); auto it = translation_probabilities.find(entry); if (it == translation_probabilities.end()) { diff --git a/extractor/translation_table_test.cc b/extractor/translation_table_test.cc index 72551a12..3cfc0011 100644 --- a/extractor/translation_table_test.cc +++ b/extractor/translation_table_test.cc @@ -36,13 +36,10 @@ class TranslationTableTest : public Test { .WillRepeatedly(Return(source_sentence_start[i])); } for (size_t i = 0; i < words.size(); ++i) { - EXPECT_CALL(*source_data_array, HasWord(words[i])) - .WillRepeatedly(Return(true)); EXPECT_CALL(*source_data_array, GetWordId(words[i])) .WillRepeatedly(Return(i + 2)); } - EXPECT_CALL(*source_data_array, HasWord("d")) - .WillRepeatedly(Return(false)); + EXPECT_CALL(*source_data_array, GetWordId("d")).WillRepeatedly(Return(-1)); vector target_data = {2, 3, 2, 3, 4, 5, 0, 3, 6, 0, 2, 7, 0}; vector target_sentence_start = {0, 7, 10, 13}; @@ -54,13 +51,10 @@ class TranslationTableTest : public Test { .WillRepeatedly(Return(target_sentence_start[i])); } for (size_t i = 0; i < words.size(); ++i) { - EXPECT_CALL(*target_data_array, HasWord(words[i])) - .WillRepeatedly(Return(true)); EXPECT_CALL(*target_data_array, GetWordId(words[i])) .WillRepeatedly(Return(i + 2)); } - EXPECT_CALL(*target_data_array, HasWord("d")) - .WillRepeatedly(Return(false)); + EXPECT_CALL(*target_data_array, GetWordId("d")).WillRepeatedly(Return(-1)); vector> links1 = { make_pair(0, 0), make_pair(1, 1), make_pair(2, 2), make_pair(3, 3), -- cgit v1.2.3 From 3c73e472444ff0cd436b12f3679440a6969cbf2d Mon Sep 17 00:00:00 2001 From: Paul Baltescu Date: Mon, 25 Nov 2013 23:56:31 +0000 Subject: Clean up leave-one-out sampling. --- extractor/grammar_extractor.cc | 6 ++++-- extractor/grammar_extractor.h | 4 +++- extractor/grammar_extractor_test.cc | 4 ++-- extractor/mocks/mock_rule_factory.h | 6 +++--- extractor/mocks/mock_sampler.h | 4 +++- extractor/rule_factory.cc | 7 +++++-- extractor/rule_factory.h | 3 +-- extractor/rule_factory_test.cc | 8 +++----- extractor/run_extractor.cc | 3 ++- extractor/sampler.cc | 12 ++++++++---- extractor/sampler.h | 4 +++- extractor/sampler_test.cc | 30 +++++++++++++++++++++--------- 12 files changed, 58 insertions(+), 33 deletions(-) (limited to 'extractor/mocks') diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc index 4d0738f7..1dc94c25 100644 --- a/extractor/grammar_extractor.cc +++ b/extractor/grammar_extractor.cc @@ -35,10 +35,12 @@ GrammarExtractor::GrammarExtractor( vocabulary(vocabulary), rule_factory(rule_factory) {} -Grammar GrammarExtractor::GetGrammar(const string& sentence, const unordered_set& blacklisted_sentence_ids, const shared_ptr source_data_array) { +Grammar GrammarExtractor::GetGrammar( + const string& sentence, + const unordered_set& blacklisted_sentence_ids) { vector words = TokenizeSentence(sentence); vector word_ids = AnnotateWords(words); - return rule_factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array); + return rule_factory->GetGrammar(word_ids, blacklisted_sentence_ids); } vector GrammarExtractor::TokenizeSentence(const string& sentence) { diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h index 8f570df2..eb79f53c 100644 --- a/extractor/grammar_extractor.h +++ b/extractor/grammar_extractor.h @@ -46,7 +46,9 @@ class GrammarExtractor { // Converts the sentence to a vector of word ids and uses the RuleFactory to // extract the SCFG rules which may be used to decode the sentence. - Grammar GetGrammar(const string& sentence, const unordered_set& blacklisted_sentence_ids, const shared_ptr source_data_array); + Grammar GetGrammar( + const string& sentence, + const unordered_set& blacklisted_sentence_ids); private: // Splits the sentence in a vector of words. diff --git a/extractor/grammar_extractor_test.cc b/extractor/grammar_extractor_test.cc index f32a9599..719e90ff 100644 --- a/extractor/grammar_extractor_test.cc +++ b/extractor/grammar_extractor_test.cc @@ -41,13 +41,13 @@ TEST(GrammarExtractorTest, TestAnnotatingWords) { Grammar grammar(rules, feature_names); unordered_set blacklisted_sentence_ids; shared_ptr source_data_array; - EXPECT_CALL(*factory, GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array)) + EXPECT_CALL(*factory, GetGrammar(word_ids, blacklisted_sentence_ids)) .WillOnce(Return(grammar)); GrammarExtractor extractor(vocabulary, factory); string sentence = "Anna has many many apples ."; - extractor.GetGrammar(sentence, blacklisted_sentence_ids, source_data_array); + extractor.GetGrammar(sentence, blacklisted_sentence_ids); } } // namespace diff --git a/extractor/mocks/mock_rule_factory.h b/extractor/mocks/mock_rule_factory.h index 6b7b6586..53eb5022 100644 --- a/extractor/mocks/mock_rule_factory.h +++ b/extractor/mocks/mock_rule_factory.h @@ -7,9 +7,9 @@ namespace extractor { class MockHieroCachingRuleFactory : public HieroCachingRuleFactory { public: - MOCK_METHOD3(GetGrammar, Grammar(const vector& word_ids, const - unordered_set& blacklisted_sentence_ids, - const shared_ptr source_data_array)); + MOCK_METHOD2(GetGrammar, Grammar( + const vector& word_ids, + const unordered_set& blacklisted_sentence_ids)); }; } // namespace extractor diff --git a/extractor/mocks/mock_sampler.h b/extractor/mocks/mock_sampler.h index 75c43c27..b2742f62 100644 --- a/extractor/mocks/mock_sampler.h +++ b/extractor/mocks/mock_sampler.h @@ -7,7 +7,9 @@ namespace extractor { class MockSampler : public Sampler { public: - MOCK_CONST_METHOD1(Sample, PhraseLocation(const PhraseLocation& location)); + MOCK_CONST_METHOD2(Sample, PhraseLocation( + const PhraseLocation& location, + const unordered_set& blacklisted_sentence_ids)); }; } // namespace extractor diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc index 6ae2d792..5b66f685 100644 --- a/extractor/rule_factory.cc +++ b/extractor/rule_factory.cc @@ -101,7 +101,9 @@ HieroCachingRuleFactory::HieroCachingRuleFactory() {} HieroCachingRuleFactory::~HieroCachingRuleFactory() {} -Grammar HieroCachingRuleFactory::GetGrammar(const vector& word_ids, const unordered_set& blacklisted_sentence_ids, const shared_ptr source_data_array) { +Grammar HieroCachingRuleFactory::GetGrammar( + const vector& word_ids, + const unordered_set& blacklisted_sentence_ids) { Clock::time_point start_time = Clock::now(); double total_extract_time = 0; double total_intersect_time = 0; @@ -193,7 +195,8 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector& word_ids, const u Clock::time_point extract_start = Clock::now(); if (!state.starts_with_x) { // Extract rules for the sampled set of occurrences. - PhraseLocation sample = sampler->Sample(next_node->matchings, blacklisted_sentence_ids, source_data_array); + PhraseLocation sample = sampler->Sample( + next_node->matchings, blacklisted_sentence_ids); vector new_rules = rule_extractor->ExtractRules(next_phrase, sample); rules.insert(rules.end(), new_rules.begin(), new_rules.end()); diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h index a1ff76e4..1a9fa2af 100644 --- a/extractor/rule_factory.h +++ b/extractor/rule_factory.h @@ -74,8 +74,7 @@ class HieroCachingRuleFactory { // (See class description for more details.) virtual Grammar GetGrammar( const vector& word_ids, - const unordered_set& blacklisted_sentence_ids, - const shared_ptr source_data_array); + const unordered_set& blacklisted_sentence_ids); protected: HieroCachingRuleFactory(); diff --git a/extractor/rule_factory_test.cc b/extractor/rule_factory_test.cc index f26cc567..332c5959 100644 --- a/extractor/rule_factory_test.cc +++ b/extractor/rule_factory_test.cc @@ -40,7 +40,7 @@ class RuleFactoryTest : public Test { .WillRepeatedly(Return(feature_names)); sampler = make_shared(); - EXPECT_CALL(*sampler, Sample(_)) + EXPECT_CALL(*sampler, Sample(_, _)) .WillRepeatedly(Return(PhraseLocation(0, 1))); Phrase phrase; @@ -77,8 +77,7 @@ TEST_F(RuleFactoryTest, TestGetGrammarDifferentWords) { vector word_ids = {2, 3, 4}; unordered_set blacklisted_sentence_ids; - shared_ptr source_data_array; - Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array); + Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids); EXPECT_EQ(feature_names, grammar.GetFeatureNames()); EXPECT_EQ(7, grammar.GetRules().size()); } @@ -97,8 +96,7 @@ TEST_F(RuleFactoryTest, TestGetGrammarRepeatingWords) { vector word_ids = {2, 3, 4, 2, 3}; unordered_set blacklisted_sentence_ids; - shared_ptr source_data_array; - Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array); + Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids); EXPECT_EQ(feature_names, grammar.GetFeatureNames()); EXPECT_EQ(28, grammar.GetRules().size()); } diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc index 85c8a422..6b22a302 100644 --- a/extractor/run_extractor.cc +++ b/extractor/run_extractor.cc @@ -237,7 +237,8 @@ int main(int argc, char** argv) { unordered_set blacklisted_sentence_ids; if (leave_one_out) blacklisted_sentence_ids.insert(i); - Grammar grammar = extractor.GetGrammar(sentences[i], blacklisted_sentence_ids, source_data_array); + Grammar grammar = extractor.GetGrammar( + sentences[i], blacklisted_sentence_ids); ofstream output(GetGrammarFilePath(grammar_path, i).c_str()); output << grammar; } diff --git a/extractor/sampler.cc b/extractor/sampler.cc index 963afa7a..fc386ed1 100644 --- a/extractor/sampler.cc +++ b/extractor/sampler.cc @@ -12,7 +12,9 @@ Sampler::Sampler() {} Sampler::~Sampler() {} -PhraseLocation Sampler::Sample(const PhraseLocation& location, const unordered_set& blacklisted_sentence_ids, const shared_ptr source_data_array) const { +PhraseLocation Sampler::Sample( + const PhraseLocation& location, + const unordered_set& blacklisted_sentence_ids) const { vector sample; int num_subpatterns; if (location.matchings == NULL) { @@ -22,10 +24,11 @@ PhraseLocation Sampler::Sample(const PhraseLocation& location, const unordered_s double step = max(1.0, (double) (high - low) / max_samples); double i = low, last = i; bool found; + shared_ptr source_data_array = suffix_array->GetData(); while (sample.size() < max_samples && i < high) { int x = suffix_array->GetSuffix(Round(i)); int id = source_data_array->GetSentenceId(x); - if (find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) != blacklisted_sentence_ids.end()) { + if (blacklisted_sentence_ids.count(id)) { found = false; double backoff_step = 1; while (true) { @@ -33,13 +36,14 @@ PhraseLocation Sampler::Sample(const PhraseLocation& location, const unordered_s double j = i - backoff_step; x = suffix_array->GetSuffix(Round(j)); id = source_data_array->GetSentenceId(x); - if (x >= 0 && j > last && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) { + if (x >= 0 && j > last && !blacklisted_sentence_ids.count(id)) { found = true; last = i; break; } double k = i + backoff_step; x = suffix_array->GetSuffix(Round(k)); id = source_data_array->GetSentenceId(x); - if (k < min(i+step, (double)high) && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) { + if (k < min(i+step, (double)high) && + !blacklisted_sentence_ids.count(id)) { found = true; last = k; break; } if (j <= last && k >= high) break; diff --git a/extractor/sampler.h b/extractor/sampler.h index de450c48..bd8a5876 100644 --- a/extractor/sampler.h +++ b/extractor/sampler.h @@ -23,7 +23,9 @@ class Sampler { virtual ~Sampler(); // Samples uniformly at most max_samples phrase occurrences. - virtual PhraseLocation Sample(const PhraseLocation& location, const unordered_set& blacklisted_sentence_ids, const shared_ptr source_data_array) const; + virtual PhraseLocation Sample( + const PhraseLocation& location, + const unordered_set& blacklisted_sentence_ids) const; protected: Sampler(); diff --git a/extractor/sampler_test.cc b/extractor/sampler_test.cc index 965567ba..14e72780 100644 --- a/extractor/sampler_test.cc +++ b/extractor/sampler_test.cc @@ -19,6 +19,8 @@ class SamplerTest : public Test { source_data_array = make_shared(); EXPECT_CALL(*source_data_array, GetSentenceId(_)).WillRepeatedly(Return(9999)); suffix_array = make_shared(); + EXPECT_CALL(*suffix_array, GetData()) + .WillRepeatedly(Return(source_data_array)); for (int i = 0; i < 10; ++i) { EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i)); } @@ -35,23 +37,29 @@ TEST_F(SamplerTest, TestSuffixArrayRange) { sampler = make_shared(suffix_array, 1); vector expected_locations = {0}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler->Sample(location, blacklist)); + return; sampler = make_shared(suffix_array, 2); expected_locations = {0, 5}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler->Sample(location, blacklist)); sampler = make_shared(suffix_array, 3); expected_locations = {0, 3, 7}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler->Sample(location, blacklist)); sampler = make_shared(suffix_array, 4); expected_locations = {0, 3, 5, 8}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler->Sample(location, blacklist)); sampler = make_shared(suffix_array, 100); expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler->Sample(location, blacklist)); } TEST_F(SamplerTest, TestSubstringsSample) { @@ -61,19 +69,23 @@ TEST_F(SamplerTest, TestSubstringsSample) { sampler = make_shared(suffix_array, 1); vector expected_locations = {0, 1}; - EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array)); + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklist)); sampler = make_shared(suffix_array, 2); expected_locations = {0, 1, 6, 7}; - EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array)); + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklist)); sampler = make_shared(suffix_array, 3); expected_locations = {0, 1, 4, 5, 6, 7}; - EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array)); + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklist)); sampler = make_shared(suffix_array, 7); expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; - EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array)); + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklist)); } } // namespace -- cgit v1.2.3 From a6e6a369f40d8fb6a191fd7f74fc5efa8bfae2a0 Mon Sep 17 00:00:00 2001 From: Paul Baltescu Date: Wed, 27 Nov 2013 14:33:36 +0000 Subject: Unify sampling backoff strategy. --- extractor/Makefile.am | 24 ++++-- extractor/backoff_sampler.cc | 66 ++++++++++++++++ extractor/backoff_sampler.h | 41 ++++++++++ extractor/matchings_sampler.cc | 38 +++++++++ extractor/matchings_sampler.h | 31 ++++++++ extractor/matchings_sampler_test.cc | 118 ++++++++++++++++++++++++++++ extractor/mocks/mock_matchings_sampler.h | 15 ++++ extractor/mocks/mock_suffix_array_sampler.h | 15 ++++ extractor/phrase_location.cc | 2 + extractor/phrase_location_sampler.cc | 34 ++++++++ extractor/phrase_location_sampler.h | 35 +++++++++ extractor/phrase_location_sampler_test.cc | 50 ++++++++++++ extractor/precomputation.cc | 3 +- extractor/precomputation_test.cc | 2 +- extractor/rule_factory.cc | 4 +- extractor/sampler.cc | 78 ------------------ extractor/sampler.h | 22 +----- extractor/sampler_test.cc | 92 ---------------------- extractor/sampler_test_blacklist.cc | 102 ------------------------ extractor/suffix_array_sampler.cc | 40 ++++++++++ extractor/suffix_array_sampler.h | 34 ++++++++ extractor/suffix_array_sampler_test.cc | 114 +++++++++++++++++++++++++++ 22 files changed, 657 insertions(+), 303 deletions(-) create mode 100644 extractor/backoff_sampler.cc create mode 100644 extractor/backoff_sampler.h create mode 100644 extractor/matchings_sampler.cc create mode 100644 extractor/matchings_sampler.h create mode 100644 extractor/matchings_sampler_test.cc create mode 100644 extractor/mocks/mock_matchings_sampler.h create mode 100644 extractor/mocks/mock_suffix_array_sampler.h create mode 100644 extractor/phrase_location_sampler.cc create mode 100644 extractor/phrase_location_sampler.h create mode 100644 extractor/phrase_location_sampler_test.cc delete mode 100644 extractor/sampler.cc delete mode 100644 extractor/sampler_test.cc delete mode 100644 extractor/sampler_test_blacklist.cc create mode 100644 extractor/suffix_array_sampler.cc create mode 100644 extractor/suffix_array_sampler.h create mode 100644 extractor/suffix_array_sampler_test.cc (limited to 'extractor/mocks') diff --git a/extractor/Makefile.am b/extractor/Makefile.am index 7825012c..e5b439f9 100644 --- a/extractor/Makefile.am +++ b/extractor/Makefile.am @@ -15,13 +15,15 @@ EXTRA_PROGRAMS = alignment_test \ feature_target_given_source_coherent_test \ grammar_extractor_test \ matchings_finder_test \ + matchings_sampler_test \ + phrase_location_sampler_test \ phrase_test \ precomputation_test \ rule_extractor_helper_test \ rule_extractor_test \ rule_factory_test \ - sampler_test \ scorer_test \ + suffix_array_sampler_test \ suffix_array_test \ target_phrase_extractor_test \ translation_table_test \ @@ -40,13 +42,15 @@ if HAVE_GTEST feature_target_given_source_coherent_test \ grammar_extractor_test \ matchings_finder_test \ + matchings_sampler_test \ + phrase_location_sampler_test \ phrase_test \ precomputation_test \ rule_extractor_helper_test \ rule_extractor_test \ rule_factory_test \ - sampler_test \ scorer_test \ + suffix_array_sampler_test \ suffix_array_test \ target_phrase_extractor_test \ translation_table_test \ @@ -55,8 +59,7 @@ endif noinst_PROGRAMS = $(RUNNABLE_TESTS) -# TESTS = $(RUNNABLE_TESTS) -TESTS = vocabulary_test +TESTS = $(RUNNABLE_TESTS) alignment_test_SOURCES = alignment_test.cc alignment_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a @@ -82,6 +85,10 @@ grammar_extractor_test_SOURCES = grammar_extractor_test.cc grammar_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a matchings_finder_test_SOURCES = matchings_finder_test.cc matchings_finder_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +matchings_sampler_test_SOURCES = matchings_sampler_test.cc +matchings_sampler_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +phrase_location_sampler_test_SOURCES = phrase_location_sampler_test.cc +phrase_location_sampler_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a phrase_test_SOURCES = phrase_test.cc phrase_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a precomputation_test_SOURCES = precomputation_test.cc @@ -92,10 +99,10 @@ rule_extractor_test_SOURCES = rule_extractor_test.cc rule_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a rule_factory_test_SOURCES = rule_factory_test.cc rule_factory_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a -sampler_test_SOURCES = sampler_test.cc -sampler_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a scorer_test_SOURCES = scorer_test.cc scorer_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a +suffix_array_sampler_test_SOURCES = suffix_array_sampler_test.cc +suffix_array_sampler_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a suffix_array_test_SOURCES = suffix_array_test.cc suffix_array_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a target_phrase_extractor_test_SOURCES = target_phrase_extractor_test.cc @@ -116,6 +123,7 @@ extract_LDADD = libextractor.a libextractor_a_SOURCES = \ alignment.cc \ + backoff_sampler.cc \ data_array.cc \ fast_intersector.cc \ features/count_source_target.cc \ @@ -129,18 +137,20 @@ libextractor_a_SOURCES = \ grammar.cc \ grammar_extractor.cc \ matchings_finder.cc \ + matchings_sampler.cc \ matchings_trie.cc \ phrase.cc \ phrase_builder.cc \ phrase_location.cc \ + phrase_location_sampler.cc \ precomputation.cc \ rule.cc \ rule_extractor.cc \ rule_extractor_helper.cc \ rule_factory.cc \ - sampler.cc \ scorer.cc \ suffix_array.cc \ + suffix_array_sampler.cc \ target_phrase_extractor.cc \ time_util.cc \ translation_table.cc \ diff --git a/extractor/backoff_sampler.cc b/extractor/backoff_sampler.cc new file mode 100644 index 00000000..28b12909 --- /dev/null +++ b/extractor/backoff_sampler.cc @@ -0,0 +1,66 @@ +#include "backoff_sampler.h" + +#include "data_array.h" +#include "phrase_location.h" + +namespace extractor { + +BackoffSampler::BackoffSampler( + shared_ptr source_data_array, int max_samples) : + source_data_array(source_data_array), max_samples(max_samples) {} + +BackoffSampler::BackoffSampler() {} + +PhraseLocation BackoffSampler::Sample( + const PhraseLocation& location, + const unordered_set& blacklisted_sentence_ids) const { + vector samples; + int low = GetRangeLow(location), high = GetRangeHigh(location); + int last_position = low - 1; + double step = max(1.0, (double) (high - low) / max_samples); + for (double num_samples = 0, i = low; + num_samples < max_samples && i < high; + ++num_samples, i += step) { + int position = GetPosition(location, round(i)); + int sentence_id = source_data_array->GetSentenceId(position); + bool found = false; + if (last_position >= position || + blacklisted_sentence_ids.count(sentence_id)) { + for (double backoff_step = 1; backoff_step < step; ++backoff_step) { + double j = i - backoff_step; + if (round(j) >= 0) { + position = GetPosition(location, round(j)); + sentence_id = source_data_array->GetSentenceId(position); + if (position > last_position && + !blacklisted_sentence_ids.count(sentence_id)) { + found = true; + last_position = position; + break; + } + } + + double k = i + backoff_step; + if (round(k) < high) { + position = GetPosition(location, round(k)); + sentence_id = source_data_array->GetSentenceId(position); + if (!blacklisted_sentence_ids.count(sentence_id)) { + found = true; + last_position = position; + break; + } + } + } + } else { + found = true; + last_position = position; + } + + if (found) { + AppendMatching(samples, position, location); + } + } + + return PhraseLocation(samples, GetNumSubpatterns(location)); +} + +} // namespace extractor diff --git a/extractor/backoff_sampler.h b/extractor/backoff_sampler.h new file mode 100644 index 00000000..5c244105 --- /dev/null +++ b/extractor/backoff_sampler.h @@ -0,0 +1,41 @@ +#ifndef _BACKOFF_SAMPLER_H_ +#define _BACKOFF_SAMPLER_H_ + +#include + +#include "sampler.h" + +namespace extractor { + +class DataArray; +class PhraseLocation; + +class BackoffSampler : public Sampler { + public: + BackoffSampler(shared_ptr source_data_array, int max_samples); + + BackoffSampler(); + + PhraseLocation Sample( + const PhraseLocation& location, + const unordered_set& blacklisted_sentence_ids) const; + + private: + virtual int GetNumSubpatterns(const PhraseLocation& location) const = 0; + + virtual int GetRangeLow(const PhraseLocation& location) const = 0; + + virtual int GetRangeHigh(const PhraseLocation& location) const = 0; + + virtual int GetPosition(const PhraseLocation& location, int index) const = 0; + + virtual void AppendMatching(vector& samples, int index, + const PhraseLocation& location) const = 0; + + shared_ptr source_data_array; + int max_samples; +}; + +} // namespace extractor + +#endif diff --git a/extractor/matchings_sampler.cc b/extractor/matchings_sampler.cc new file mode 100644 index 00000000..bb916e49 --- /dev/null +++ b/extractor/matchings_sampler.cc @@ -0,0 +1,38 @@ +#include "matchings_sampler.h" + +#include "data_array.h" +#include "phrase_location.h" + +namespace extractor { + +MatchingsSampler::MatchingsSampler( + shared_ptr data_array, int max_samples) : + BackoffSampler(data_array, max_samples) {} + +MatchingsSampler::MatchingsSampler() {} + +int MatchingsSampler::GetNumSubpatterns(const PhraseLocation& location) const { + return location.num_subpatterns; +} + +int MatchingsSampler::GetRangeLow(const PhraseLocation&) const { + return 0; +} + +int MatchingsSampler::GetRangeHigh(const PhraseLocation& location) const { + return location.matchings->size() / location.num_subpatterns; +} + +int MatchingsSampler::GetPosition(const PhraseLocation& location, + int index) const { + return (*location.matchings)[index * location.num_subpatterns]; +} + +void MatchingsSampler::AppendMatching(vector& samples, int index, + const PhraseLocation& location) const { + copy(location.matchings->begin() + index, + location.matchings->begin() + index + location.num_subpatterns, + back_inserter(samples)); +} + +} // namespace extractor diff --git a/extractor/matchings_sampler.h b/extractor/matchings_sampler.h new file mode 100644 index 00000000..ca4fce93 --- /dev/null +++ b/extractor/matchings_sampler.h @@ -0,0 +1,31 @@ +#ifndef _MATCHINGS_SAMPLER_H_ +#define _MATCHINGS_SAMPLER_H_ + +#include "backoff_sampler.h" + +namespace extractor { + +class DataArray; + +class MatchingsSampler : public BackoffSampler { + public: + MatchingsSampler(shared_ptr data_array, int max_samples); + + MatchingsSampler(); + + private: + int GetNumSubpatterns(const PhraseLocation& location) const; + + int GetRangeLow(const PhraseLocation& location) const; + + int GetRangeHigh(const PhraseLocation& location) const; + + int GetPosition(const PhraseLocation& location, int index) const; + + void AppendMatching(vector& samples, int index, + const PhraseLocation& location) const; +}; + +} // namespace extractor + +#endif diff --git a/extractor/matchings_sampler_test.cc b/extractor/matchings_sampler_test.cc new file mode 100644 index 00000000..bc927152 --- /dev/null +++ b/extractor/matchings_sampler_test.cc @@ -0,0 +1,118 @@ +#include + +#include + +#include "mocks/mock_data_array.h" +#include "matchings_sampler.h" +#include "phrase_location.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace { + +class MatchingsSamplerTest : public Test { + protected: + virtual void SetUp() { + vector locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + location = PhraseLocation(locations, 2); + + data_array = make_shared(); + for (int i = 0; i < 10; ++i) { + EXPECT_CALL(*data_array, GetSentenceId(i)).WillRepeatedly(Return(i / 2)); + } + } + + unordered_set blacklisted_sentence_ids; + PhraseLocation location; + shared_ptr data_array; + shared_ptr sampler; +}; + +TEST_F(MatchingsSamplerTest, TestSample) { + sampler = make_shared(data_array, 1); + vector expected_locations = {0, 1}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklisted_sentence_ids)); + + sampler = make_shared(data_array, 2); + expected_locations = {0, 1, 6, 7}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklisted_sentence_ids)); + + sampler = make_shared(data_array, 3); + expected_locations = {0, 1, 4, 5, 6, 7}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklisted_sentence_ids)); + + sampler = make_shared(data_array, 7); + expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklisted_sentence_ids)); +} + +TEST_F(MatchingsSamplerTest, TestBackoffSample) { + sampler = make_shared(data_array, 1); + blacklisted_sentence_ids = {0}; + vector expected_locations = {2, 3}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklisted_sentence_ids)); + + blacklisted_sentence_ids = {0, 1, 2, 3}; + expected_locations = {8, 9}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklisted_sentence_ids)); + + blacklisted_sentence_ids = {0, 1, 2, 3, 4}; + expected_locations = {}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklisted_sentence_ids)); + + sampler = make_shared(data_array, 2); + blacklisted_sentence_ids = {0, 3}; + expected_locations = {2, 3, 4, 5}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklisted_sentence_ids)); + + sampler = make_shared(data_array, 3); + blacklisted_sentence_ids = {0, 3}; + expected_locations = {2, 3, 4, 5, 8, 9}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklisted_sentence_ids)); + + blacklisted_sentence_ids = {0, 2, 3}; + expected_locations = {2, 3, 8, 9}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklisted_sentence_ids)); + + sampler = make_shared(data_array, 4); + blacklisted_sentence_ids = {0, 1, 2, 3}; + expected_locations = {8, 9}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklisted_sentence_ids)); + + blacklisted_sentence_ids = {1, 3}; + expected_locations = {0, 1, 4, 5, 8, 9}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklisted_sentence_ids)); + + sampler = make_shared(data_array, 7); + blacklisted_sentence_ids = {0, 1, 2, 3, 4}; + expected_locations = {}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklisted_sentence_ids)); + + blacklisted_sentence_ids = {0, 2, 4}; + expected_locations = {2, 3, 6, 7}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklisted_sentence_ids)); + + blacklisted_sentence_ids = {1, 3}; + expected_locations = {0, 1, 4, 5, 8, 9}; + EXPECT_EQ(PhraseLocation(expected_locations, 2), + sampler->Sample(location, blacklisted_sentence_ids)); +} + +} +} // namespace extractor diff --git a/extractor/mocks/mock_matchings_sampler.h b/extractor/mocks/mock_matchings_sampler.h new file mode 100644 index 00000000..de2009c3 --- /dev/null +++ b/extractor/mocks/mock_matchings_sampler.h @@ -0,0 +1,15 @@ +#include + +#include "phrase_location.h" +#include "matchings_sampler.h" + +namespace extractor { + +class MockMatchingsSampler : public MatchingsSampler { + public: + MOCK_CONST_METHOD2(Sample, PhraseLocation( + const PhraseLocation& location, + const unordered_set& blacklisted_sentence_ids)); +}; + +} // namespace extractor diff --git a/extractor/mocks/mock_suffix_array_sampler.h b/extractor/mocks/mock_suffix_array_sampler.h new file mode 100644 index 00000000..d799b969 --- /dev/null +++ b/extractor/mocks/mock_suffix_array_sampler.h @@ -0,0 +1,15 @@ +#include + +#include "phrase_location.h" +#include "suffix_array_sampler.h" + +namespace extractor { + +class MockSuffixArraySampler : public SuffixArrayRangeSampler { + public: + MOCK_CONST_METHOD2(Sample, PhraseLocation( + const PhraseLocation& location, + const unordered_set& blacklisted_sentence_ids)); +}; + +} // namespace extractor diff --git a/extractor/phrase_location.cc b/extractor/phrase_location.cc index 13140cac..2c367893 100644 --- a/extractor/phrase_location.cc +++ b/extractor/phrase_location.cc @@ -1,5 +1,7 @@ #include "phrase_location.h" +#include + namespace extractor { PhraseLocation::PhraseLocation(int sa_low, int sa_high) : diff --git a/extractor/phrase_location_sampler.cc b/extractor/phrase_location_sampler.cc new file mode 100644 index 00000000..a2eec105 --- /dev/null +++ b/extractor/phrase_location_sampler.cc @@ -0,0 +1,34 @@ +#include "phrase_location_sampler.h" + +#include "matchings_sampler.h" +#include "phrase_location.h" +#include "suffix_array.h" +#include "suffix_array_sampler.h" + +namespace extractor { + +PhraseLocationSampler::PhraseLocationSampler( + shared_ptr suffix_array, int max_samples) { + matchings_sampler = make_shared( + suffix_array->GetData(), max_samples); + suffix_array_sampler = make_shared( + suffix_array, max_samples); +} + +PhraseLocationSampler::PhraseLocationSampler( + shared_ptr matchings_sampler, + shared_ptr suffix_array_sampler) : + matchings_sampler(matchings_sampler), + suffix_array_sampler(suffix_array_sampler) {} + +PhraseLocation PhraseLocationSampler::Sample( + const PhraseLocation& location, + const unordered_set& blacklisted_sentence_ids) const { + if (location.matchings == NULL) { + return suffix_array_sampler->Sample(location, blacklisted_sentence_ids); + } else { + return matchings_sampler->Sample(location, blacklisted_sentence_ids); + } +} + +} // namespace extractor diff --git a/extractor/phrase_location_sampler.h b/extractor/phrase_location_sampler.h new file mode 100644 index 00000000..0e88335e --- /dev/null +++ b/extractor/phrase_location_sampler.h @@ -0,0 +1,35 @@ +#ifndef _PHRASE_LOCATION_SAMPLER_H_ +#define _PHRASE_LOCATION_SAMPLER_H_ + +#include + +#include "sampler.h" + +namespace extractor { + +class MatchingsSampler; +class PhraseLocation; +class SuffixArray; +class SuffixArrayRangeSampler; + +class PhraseLocationSampler : public Sampler { + public: + PhraseLocationSampler(shared_ptr suffix_array, int max_samples); + + // For testing only. + PhraseLocationSampler( + shared_ptr matchings_sampler, + shared_ptr suffix_array_sampler); + + PhraseLocation Sample( + const PhraseLocation& location, + const unordered_set& blacklisted_sentence_ids) const; + + private: + shared_ptr matchings_sampler; + shared_ptr suffix_array_sampler; +}; + +} // namespace extractor + +#endif diff --git a/extractor/phrase_location_sampler_test.cc b/extractor/phrase_location_sampler_test.cc new file mode 100644 index 00000000..e7520ce7 --- /dev/null +++ b/extractor/phrase_location_sampler_test.cc @@ -0,0 +1,50 @@ +#include + +#include + +#include "mocks/mock_matchings_sampler.h" +#include "mocks/mock_suffix_array_sampler.h" +#include "phrase_location.h" +#include "phrase_location_sampler.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace { + +class MatchingsSamplerTest : public Test { + protected: + virtual void SetUp() { + matchings_sampler = make_shared(); + suffix_array_sampler = make_shared(); + + sampler = make_shared( + matchings_sampler, suffix_array_sampler); + } + + shared_ptr matchings_sampler; + shared_ptr suffix_array_sampler; + shared_ptr sampler; +}; + +TEST_F(MatchingsSamplerTest, TestSuffixArrayRange) { + vector locations = {0, 1, 2, 3}; + PhraseLocation location(0, 3), result(locations, 2); + unordered_set blacklisted_sentence_ids; + EXPECT_CALL(*suffix_array_sampler, Sample(location, blacklisted_sentence_ids)) + .WillOnce(Return(result)); + EXPECT_EQ(result, sampler->Sample(location, blacklisted_sentence_ids)); +} + +TEST_F(MatchingsSamplerTest, TestMatchings) { + vector locations = {0, 1, 2, 3}; + PhraseLocation location(locations, 2), result(locations, 2); + unordered_set blacklisted_sentence_ids; + EXPECT_CALL(*matchings_sampler, Sample(location, blacklisted_sentence_ids)) + .WillOnce(Return(result)); + EXPECT_EQ(result, sampler->Sample(location, blacklisted_sentence_ids)); +} + +} +} // namespace extractor diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc index b79daae3..3e58e2a9 100644 --- a/extractor/precomputation.cc +++ b/extractor/precomputation.cc @@ -91,7 +91,6 @@ vector> Precomputation::FindMostFrequentPatterns( } } - shared_ptr data_array = suffix_array->GetData(); // Extract the most frequent patterns. vector> frequent_patterns; while (frequent_patterns.size() < num_frequent_patterns && !heap.empty()) { @@ -99,7 +98,7 @@ vector> Precomputation::FindMostFrequentPatterns( int len = heap.top().second.second; heap.pop(); - vector pattern = data_array->GetWordIds(start, len); + vector pattern(data.begin() + start, data.begin() + start + len); if (find(pattern.begin(), pattern.end(), DataArray::END_OF_LINE) == pattern.end()) { frequent_patterns.push_back(pattern); diff --git a/extractor/precomputation_test.cc b/extractor/precomputation_test.cc index d5f5ef63..3a98ce05 100644 --- a/extractor/precomputation_test.cc +++ b/extractor/precomputation_test.cc @@ -94,7 +94,7 @@ TEST_F(PrecomputationTest, TestCollocations) { EXPECT_TRUE(precomputation.Contains(key)); EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); - key = {2, -1, 2, -1, 2}; + key = {2, -1, 2, -2, 2}; expected_value = {1, 5, 8, 5, 8, 11}; EXPECT_TRUE(precomputation.Contains(key)); EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc index 5b66f685..18a60695 100644 --- a/extractor/rule_factory.cc +++ b/extractor/rule_factory.cc @@ -12,6 +12,7 @@ #include "phrase_builder.h" #include "rule.h" #include "rule_extractor.h" +#include "phrase_location_sampler.h" #include "sampler.h" #include "scorer.h" #include "suffix_array.h" @@ -68,7 +69,8 @@ HieroCachingRuleFactory::HieroCachingRuleFactory( target_data_array, alignment, phrase_builder, scorer, vocabulary, max_rule_span, min_gap_size, max_nonterminals, max_rule_symbols, true, false, require_tight_phrases); - sampler = make_shared(source_suffix_array, max_samples); + sampler = make_shared( + source_suffix_array, max_samples); } HieroCachingRuleFactory::HieroCachingRuleFactory( diff --git a/extractor/sampler.cc b/extractor/sampler.cc deleted file mode 100644 index 887aaec1..00000000 --- a/extractor/sampler.cc +++ /dev/null @@ -1,78 +0,0 @@ -#include "sampler.h" - -#include "phrase_location.h" -#include "suffix_array.h" - -namespace extractor { - -Sampler::Sampler(shared_ptr suffix_array, int max_samples) : - suffix_array(suffix_array), max_samples(max_samples) {} - -Sampler::Sampler() {} - -Sampler::~Sampler() {} - -PhraseLocation Sampler::Sample( - const PhraseLocation& location, - const unordered_set& blacklisted_sentence_ids) const { - shared_ptr source_data_array = suffix_array->GetData(); - vector sample; - int num_subpatterns; - if (location.matchings == NULL) { - // Sample suffix array range. - num_subpatterns = 1; - int low = location.sa_low, high = location.sa_high; - double step = max(1.0, (double) (high - low) / max_samples); - double i = low, last = i - 1; - while (sample.size() < max_samples && i < high) { - int x = suffix_array->GetSuffix(Round(i)); - int id = source_data_array->GetSentenceId(x); - bool found = false; - if (blacklisted_sentence_ids.count(id)) { - for (int backoff_step = 1; backoff_step <= step; ++backoff_step) { - double j = i - backoff_step; - x = suffix_array->GetSuffix(Round(j)); - id = source_data_array->GetSentenceId(x); - if (x >= 0 && j > last && !blacklisted_sentence_ids.count(id)) { - found = true; - last = i; - break; - } - double k = i + backoff_step; - x = suffix_array->GetSuffix(Round(k)); - id = source_data_array->GetSentenceId(x); - if (k < min(i+step, (double) high) && - !blacklisted_sentence_ids.count(id)) { - found = true; - last = k; - break; - } - } - } else { - found = true; - last = i; - } - if (found) sample.push_back(x); - i += step; - } - } else { - // Sample vector of occurrences. - num_subpatterns = location.num_subpatterns; - int num_matchings = location.matchings->size() / num_subpatterns; - double step = max(1.0, (double) num_matchings / max_samples); - for (double i = 0, num_samples = 0; - i < num_matchings && num_samples < max_samples; - i += step, ++num_samples) { - int start = Round(i) * num_subpatterns; - sample.insert(sample.end(), location.matchings->begin() + start, - location.matchings->begin() + start + num_subpatterns); - } - } - return PhraseLocation(sample, num_subpatterns); -} - -int Sampler::Round(double x) const { - return x + 0.5; -} - -} // namespace extractor diff --git a/extractor/sampler.h b/extractor/sampler.h index bd8a5876..3c4e37f1 100644 --- a/extractor/sampler.h +++ b/extractor/sampler.h @@ -4,38 +4,20 @@ #include #include -#include "data_array.h" - using namespace std; namespace extractor { class PhraseLocation; -class SuffixArray; /** - * Provides uniform sampling for a PhraseLocation. + * Base sampler class. */ class Sampler { public: - Sampler(shared_ptr suffix_array, int max_samples); - - virtual ~Sampler(); - - // Samples uniformly at most max_samples phrase occurrences. virtual PhraseLocation Sample( const PhraseLocation& location, - const unordered_set& blacklisted_sentence_ids) const; - - protected: - Sampler(); - - private: - // Round floating point number to the nearest integer. - int Round(double x) const; - - shared_ptr suffix_array; - int max_samples; + const unordered_set& blacklisted_sentence_ids) const = 0; }; } // namespace extractor diff --git a/extractor/sampler_test.cc b/extractor/sampler_test.cc deleted file mode 100644 index 14e72780..00000000 --- a/extractor/sampler_test.cc +++ /dev/null @@ -1,92 +0,0 @@ -#include - -#include - -#include "mocks/mock_suffix_array.h" -#include "mocks/mock_data_array.h" -#include "phrase_location.h" -#include "sampler.h" - -using namespace std; -using namespace ::testing; - -namespace extractor { -namespace { - -class SamplerTest : public Test { - protected: - virtual void SetUp() { - source_data_array = make_shared(); - EXPECT_CALL(*source_data_array, GetSentenceId(_)).WillRepeatedly(Return(9999)); - suffix_array = make_shared(); - EXPECT_CALL(*suffix_array, GetData()) - .WillRepeatedly(Return(source_data_array)); - for (int i = 0; i < 10; ++i) { - EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i)); - } - } - - shared_ptr suffix_array; - shared_ptr sampler; - shared_ptr source_data_array; -}; - -TEST_F(SamplerTest, TestSuffixArrayRange) { - PhraseLocation location(0, 10); - unordered_set blacklist; - - sampler = make_shared(suffix_array, 1); - vector expected_locations = {0}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), - sampler->Sample(location, blacklist)); - return; - - sampler = make_shared(suffix_array, 2); - expected_locations = {0, 5}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), - sampler->Sample(location, blacklist)); - - sampler = make_shared(suffix_array, 3); - expected_locations = {0, 3, 7}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), - sampler->Sample(location, blacklist)); - - sampler = make_shared(suffix_array, 4); - expected_locations = {0, 3, 5, 8}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), - sampler->Sample(location, blacklist)); - - sampler = make_shared(suffix_array, 100); - expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), - sampler->Sample(location, blacklist)); -} - -TEST_F(SamplerTest, TestSubstringsSample) { - vector locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; - unordered_set blacklist; - PhraseLocation location(locations, 2); - - sampler = make_shared(suffix_array, 1); - vector expected_locations = {0, 1}; - EXPECT_EQ(PhraseLocation(expected_locations, 2), - sampler->Sample(location, blacklist)); - - sampler = make_shared(suffix_array, 2); - expected_locations = {0, 1, 6, 7}; - EXPECT_EQ(PhraseLocation(expected_locations, 2), - sampler->Sample(location, blacklist)); - - sampler = make_shared(suffix_array, 3); - expected_locations = {0, 1, 4, 5, 6, 7}; - EXPECT_EQ(PhraseLocation(expected_locations, 2), - sampler->Sample(location, blacklist)); - - sampler = make_shared(suffix_array, 7); - expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; - EXPECT_EQ(PhraseLocation(expected_locations, 2), - sampler->Sample(location, blacklist)); -} - -} // namespace -} // namespace extractor diff --git a/extractor/sampler_test_blacklist.cc b/extractor/sampler_test_blacklist.cc deleted file mode 100644 index 3305b990..00000000 --- a/extractor/sampler_test_blacklist.cc +++ /dev/null @@ -1,102 +0,0 @@ -#include - -#include - -#include "mocks/mock_suffix_array.h" -#include "mocks/mock_data_array.h" -#include "phrase_location.h" -#include "sampler.h" - -using namespace std; -using namespace ::testing; - -namespace extractor { -namespace { - -class SamplerTestBlacklist : public Test { - protected: - virtual void SetUp() { - source_data_array = make_shared(); - for (int i = 0; i < 10; ++i) { - EXPECT_CALL(*source_data_array, GetSentenceId(i)).WillRepeatedly(Return(i)); - } - for (int i = -10; i < 0; ++i) { - EXPECT_CALL(*source_data_array, GetSentenceId(i)).WillRepeatedly(Return(0)); - } - suffix_array = make_shared(); - for (int i = -10; i < 10; ++i) { - EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i)); - } - } - - shared_ptr suffix_array; - shared_ptr sampler; - shared_ptr source_data_array; -}; - -TEST_F(SamplerTestBlacklist, TestSuffixArrayRange) { - PhraseLocation location(0, 10); - unordered_set blacklist; - vector expected_locations; - - blacklist.insert(0); - sampler = make_shared(suffix_array, 1); - expected_locations = {1}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); - blacklist.clear(); - - for (int i = 0; i < 9; i++) { - blacklist.insert(i); - } - sampler = make_shared(suffix_array, 1); - expected_locations = {9}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); - blacklist.clear(); - - blacklist.insert(0); - blacklist.insert(5); - sampler = make_shared(suffix_array, 2); - expected_locations = {1, 4}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); - blacklist.clear(); - - blacklist.insert(0); - blacklist.insert(1); - blacklist.insert(2); - blacklist.insert(3); - sampler = make_shared(suffix_array, 2); - expected_locations = {4, 5}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); - blacklist.clear(); - - blacklist.insert(0); - blacklist.insert(3); - blacklist.insert(7); - sampler = make_shared(suffix_array, 3); - expected_locations = {1, 2, 6}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); - blacklist.clear(); - - blacklist.insert(0); - blacklist.insert(3); - blacklist.insert(5); - blacklist.insert(8); - sampler = make_shared(suffix_array, 4); - expected_locations = {1, 2, 4, 7}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); - blacklist.clear(); - - blacklist.insert(0); - sampler = make_shared(suffix_array, 100); - expected_locations = {1, 2, 3, 4, 5, 6, 7, 8, 9}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); - blacklist.clear(); - - blacklist.insert(9); - sampler = make_shared(suffix_array, 100); - expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); -} - -} // namespace -} // namespace extractor diff --git a/extractor/suffix_array_sampler.cc b/extractor/suffix_array_sampler.cc new file mode 100644 index 00000000..4a4ced34 --- /dev/null +++ b/extractor/suffix_array_sampler.cc @@ -0,0 +1,40 @@ +#include "suffix_array_sampler.h" + +#include "data_array.h" +#include "phrase_location.h" +#include "suffix_array.h" + +namespace extractor { + +SuffixArrayRangeSampler::SuffixArrayRangeSampler( + shared_ptr source_suffix_array, int max_samples) : + BackoffSampler(source_suffix_array->GetData(), max_samples), + source_suffix_array(source_suffix_array) {} + +SuffixArrayRangeSampler::SuffixArrayRangeSampler() {} + +int SuffixArrayRangeSampler::GetNumSubpatterns(const PhraseLocation&) const { + return 1; +} + +int SuffixArrayRangeSampler::GetRangeLow( + const PhraseLocation& location) const { + return location.sa_low; +} + +int SuffixArrayRangeSampler::GetRangeHigh( + const PhraseLocation& location) const { + return location.sa_high; +} + +int SuffixArrayRangeSampler::GetPosition( + const PhraseLocation&, int position) const { + return source_suffix_array->GetSuffix(position); +} + +void SuffixArrayRangeSampler::AppendMatching( + vector& samples, int index, const PhraseLocation&) const { + samples.push_back(source_suffix_array->GetSuffix(index)); +} + +} // namespace extractor diff --git a/extractor/suffix_array_sampler.h b/extractor/suffix_array_sampler.h new file mode 100644 index 00000000..bb3c2653 --- /dev/null +++ b/extractor/suffix_array_sampler.h @@ -0,0 +1,34 @@ +#ifndef _SUFFIX_ARRAY_SAMPLER_H_ +#define _SUFFIX_ARRAY_SAMPLER_H_ + +#include "backoff_sampler.h" + +namespace extractor { + +class SuffixArray; + +class SuffixArrayRangeSampler : public BackoffSampler { + public: + SuffixArrayRangeSampler(shared_ptr suffix_array, + int max_samples); + + SuffixArrayRangeSampler(); + + private: + int GetNumSubpatterns(const PhraseLocation& location) const; + + int GetRangeLow(const PhraseLocation& location) const; + + int GetRangeHigh(const PhraseLocation& location) const; + + int GetPosition(const PhraseLocation& location, int index) const; + + void AppendMatching(vector& samples, int index, + const PhraseLocation& location) const; + + shared_ptr source_suffix_array; +}; + +} // namespace extractor + +#endif diff --git a/extractor/suffix_array_sampler_test.cc b/extractor/suffix_array_sampler_test.cc new file mode 100644 index 00000000..4b88c027 --- /dev/null +++ b/extractor/suffix_array_sampler_test.cc @@ -0,0 +1,114 @@ +#include + +#include + +#include "mocks/mock_data_array.h" +#include "mocks/mock_suffix_array.h" +#include "suffix_array_sampler.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace { + +class SuffixArraySamplerTest : public Test { + protected: + virtual void SetUp() { + data_array = make_shared(); + for (int i = 0; i < 10; ++i) { + EXPECT_CALL(*data_array, GetSentenceId(i)).WillRepeatedly(Return(i)); + } + + suffix_array = make_shared(); + EXPECT_CALL(*suffix_array, GetData()).WillRepeatedly(Return(data_array)); + for (int i = 0; i < 10; ++i) { + EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i)); + } + } + + shared_ptr data_array; + shared_ptr suffix_array; +}; + +TEST_F(SuffixArraySamplerTest, TestSample) { + PhraseLocation location(0, 10); + unordered_set blacklisted_sentence_ids; + + SuffixArrayRangeSampler sampler(suffix_array, 1); + vector expected_locations = {0}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler.Sample(location, blacklisted_sentence_ids)); + + sampler = SuffixArrayRangeSampler(suffix_array, 2); + expected_locations = {0, 5}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler.Sample(location, blacklisted_sentence_ids)); + + sampler = SuffixArrayRangeSampler(suffix_array, 3); + expected_locations = {0, 3, 7}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler.Sample(location, blacklisted_sentence_ids)); + + sampler = SuffixArrayRangeSampler(suffix_array, 4); + expected_locations = {0, 3, 5, 8}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler.Sample(location, blacklisted_sentence_ids)); + + sampler = SuffixArrayRangeSampler(suffix_array, 100); + expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler.Sample(location, blacklisted_sentence_ids)); +} + +TEST_F(SuffixArraySamplerTest, TestBackoffSample) { + PhraseLocation location(0, 10); + + SuffixArrayRangeSampler sampler(suffix_array, 1); + unordered_set blacklisted_sentence_ids = {0}; + vector expected_locations = {1}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler.Sample(location, blacklisted_sentence_ids)); + + blacklisted_sentence_ids = {0, 1, 2, 3, 4, 5, 6, 7, 8}; + expected_locations = {9}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler.Sample(location, blacklisted_sentence_ids)); + + sampler = SuffixArrayRangeSampler(suffix_array, 2); + blacklisted_sentence_ids = {0, 5}; + expected_locations = {1, 4}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler.Sample(location, blacklisted_sentence_ids)); + + blacklisted_sentence_ids = {0, 1, 2, 3}; + expected_locations = {4, 5}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler.Sample(location, blacklisted_sentence_ids)); + + sampler = SuffixArrayRangeSampler(suffix_array, 3); + blacklisted_sentence_ids = {0, 3, 7}; + expected_locations = {1, 2, 6}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler.Sample(location, blacklisted_sentence_ids)); + + sampler = SuffixArrayRangeSampler(suffix_array, 4); + blacklisted_sentence_ids = {0, 3, 5, 8}; + expected_locations = {1, 2, 4, 7}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler.Sample(location, blacklisted_sentence_ids)); + + sampler = SuffixArrayRangeSampler(suffix_array, 100); + blacklisted_sentence_ids = {0}; + expected_locations = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler.Sample(location, blacklisted_sentence_ids)); + + blacklisted_sentence_ids = {9}; + expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), + sampler.Sample(location, blacklisted_sentence_ids)); +} + +} +} // namespace extractor -- cgit v1.2.3