diff options
-rw-r--r-- | extractor/compile.cc | 4 | ||||
-rw-r--r-- | extractor/data_array.cc | 14 | ||||
-rw-r--r-- | extractor/data_array.h | 10 | ||||
-rw-r--r-- | extractor/data_array_test.cc | 12 | ||||
-rw-r--r-- | extractor/fast_intersector.cc | 40 | ||||
-rw-r--r-- | extractor/fast_intersector.h | 8 | ||||
-rw-r--r-- | extractor/fast_intersector_test.cc | 10 | ||||
-rw-r--r-- | extractor/grammar_extractor.cc | 5 | ||||
-rw-r--r-- | extractor/grammar_extractor.h | 1 | ||||
-rw-r--r-- | extractor/mocks/mock_data_array.h | 4 | ||||
-rw-r--r-- | extractor/mocks/mock_precomputation.h | 3 | ||||
-rw-r--r-- | extractor/precomputation.cc | 96 | ||||
-rw-r--r-- | extractor/precomputation.h | 45 | ||||
-rw-r--r-- | extractor/precomputation_test.cc | 110 | ||||
-rw-r--r-- | extractor/run_extractor.cc | 5 | ||||
-rw-r--r-- | extractor/suffix_array_test.cc | 2 | ||||
-rw-r--r-- | extractor/translation_table_test.cc | 4 |
17 files changed, 225 insertions, 148 deletions
diff --git a/extractor/compile.cc b/extractor/compile.cc index 65fdd509..0d62757e 100644 --- a/extractor/compile.cc +++ b/extractor/compile.cc @@ -13,6 +13,7 @@ #include "suffix_array.h" #include "time_util.h" #include "translation_table.h" +#include "vocabulary.h" namespace ar = boost::archive; namespace fs = boost::filesystem; @@ -125,9 +126,12 @@ int main(int argc, char** argv) { cerr << "Reading alignment took " << GetDuration(start_time, stop_time) << " seconds" << endl; + shared_ptr<Vocabulary> vocabulary; + start_time = Clock::now(); cerr << "Precomputing collocations..." << endl; Precomputation precomputation( + vocabulary, source_suffix_array, vm["frequent"].as<int>(), vm["super_frequent"].as<int>(), diff --git a/extractor/data_array.cc b/extractor/data_array.cc index 82efcd51..dacc4283 100644 --- a/extractor/data_array.cc +++ b/extractor/data_array.cc @@ -78,7 +78,7 @@ void DataArray::CreateDataArray(const vector<string>& lines) { DataArray::~DataArray() {} -const vector<int>& DataArray::GetData() const { +vector<int> DataArray::GetData() const { return data; } @@ -90,6 +90,18 @@ string DataArray::GetWordAtIndex(int index) const { return id2word[data[index]]; } +vector<int> DataArray::GetWordIds(int index, int size) const { + return vector<int>(data.begin() + index, data.begin() + index + size); +} + +vector<string> DataArray::GetWords(int start_index, int size) const { + vector<string> words; + for (int word_id: GetWordIds(start_index, size)) { + words.push_back(id2word[word_id]); + } + return words; +} + int DataArray::GetSize() const { return data.size(); } diff --git a/extractor/data_array.h b/extractor/data_array.h index 5207366d..e3823d18 100644 --- a/extractor/data_array.h +++ b/extractor/data_array.h @@ -51,7 +51,7 @@ class DataArray { virtual ~DataArray(); // Returns a vector containing the word ids. - virtual const vector<int>& GetData() const; + virtual vector<int> GetData() const; // Returns the word id at the specified position. virtual int AtIndex(int index) const; @@ -59,6 +59,14 @@ class DataArray { // Returns the original word at the specified position. virtual string GetWordAtIndex(int index) const; + // Returns the substring of word ids starting at the specified position and + // having the specified length. + virtual vector<int> GetWordIds(int start_index, int size) const; + + // Returns the substring of words starting at the specified position and + // having the specified length. + virtual vector<string> GetWords(int start_index, int size) const; + // Returns the size of the data array. virtual int GetSize() const; diff --git a/extractor/data_array_test.cc b/extractor/data_array_test.cc index 6c329e34..7b085cd9 100644 --- a/extractor/data_array_test.cc +++ b/extractor/data_array_test.cc @@ -56,6 +56,18 @@ TEST_F(DataArrayTest, TestGetData) { } } +TEST_F(DataArrayTest, TestSubstrings) { + vector<int> expected_word_ids = {3, 4, 5}; + vector<string> expected_words = {"are", "mere", "."}; + EXPECT_EQ(expected_word_ids, source_data.GetWordIds(1, 3)); + EXPECT_EQ(expected_words, source_data.GetWords(1, 3)); + + expected_word_ids = {7, 8}; + expected_words = {"a", "lot"}; + EXPECT_EQ(expected_word_ids, target_data.GetWordIds(7, 2)); + EXPECT_EQ(expected_words, target_data.GetWords(7, 2)); +} + TEST_F(DataArrayTest, TestVocabulary) { EXPECT_EQ(9, source_data.GetVocabularySize()); EXPECT_TRUE(source_data.HasWord("mere")); diff --git a/extractor/fast_intersector.cc b/extractor/fast_intersector.cc index a8591a72..0d1fa6d8 100644 --- a/extractor/fast_intersector.cc +++ b/extractor/fast_intersector.cc @@ -11,41 +11,22 @@ namespace extractor { -FastIntersector::FastIntersector(shared_ptr<SuffixArray> suffix_array, - shared_ptr<Precomputation> precomputation, - shared_ptr<Vocabulary> vocabulary, - int max_rule_span, - int min_gap_size) : +FastIntersector::FastIntersector( + shared_ptr<SuffixArray> suffix_array, + shared_ptr<Precomputation> precomputation, + shared_ptr<Vocabulary> vocabulary, + int max_rule_span, + int min_gap_size) : suffix_array(suffix_array), + precomputation(precomputation), vocabulary(vocabulary), max_rule_span(max_rule_span), - min_gap_size(min_gap_size) { - Index precomputed_collocations = precomputation->GetCollocations(); - for (pair<vector<int>, vector<int>> entry: precomputed_collocations) { - vector<int> phrase = ConvertPhrase(entry.first); - collocations[phrase] = entry.second; - } -} + min_gap_size(min_gap_size) {} FastIntersector::FastIntersector() {} FastIntersector::~FastIntersector() {} -vector<int> FastIntersector::ConvertPhrase(const vector<int>& old_phrase) { - vector<int> new_phrase; - new_phrase.reserve(old_phrase.size()); - shared_ptr<DataArray> data_array = suffix_array->GetData(); - for (int word_id: old_phrase) { - if (word_id < 0) { - new_phrase.push_back(word_id); - } else { - new_phrase.push_back( - vocabulary->GetTerminalIndex(data_array->GetWord(word_id))); - } - } - return new_phrase; -} - PhraseLocation FastIntersector::Intersect( PhraseLocation& prefix_location, PhraseLocation& suffix_location, @@ -59,8 +40,9 @@ PhraseLocation FastIntersector::Intersect( assert(vocabulary->IsTerminal(symbols.front()) && vocabulary->IsTerminal(symbols.back())); - if (collocations.count(symbols)) { - return PhraseLocation(collocations[symbols], phrase.Arity() + 1); + if (precomputation->Contains(symbols)) { + return PhraseLocation(precomputation->GetCollocations(symbols), + phrase.Arity() + 1); } bool prefix_ends_with_x = diff --git a/extractor/fast_intersector.h b/extractor/fast_intersector.h index 2819d239..305373dc 100644 --- a/extractor/fast_intersector.h +++ b/extractor/fast_intersector.h @@ -12,7 +12,6 @@ using namespace std; namespace extractor { typedef boost::hash<vector<int>> VectorHash; -typedef unordered_map<vector<int>, vector<int>, VectorHash> Index; class Phrase; class PhraseLocation; @@ -52,11 +51,6 @@ class FastIntersector { FastIntersector(); private: - // Uses the vocabulary to convert the phrase from the numberized format - // specified by the source data array to the numberized format given by the - // vocabulary. - vector<int> ConvertPhrase(const vector<int>& old_phrase); - // Estimates the number of computations needed if the prefix/suffix is // extended. If the last/first symbol is separated from the rest of the phrase // by a nonterminal, then for each occurrence of the prefix/suffix we need to @@ -85,10 +79,10 @@ class FastIntersector { pair<int, int> GetSearchRange(bool has_marginal_x) const; shared_ptr<SuffixArray> suffix_array; + shared_ptr<Precomputation> precomputation; shared_ptr<Vocabulary> vocabulary; int max_rule_span; int min_gap_size; - Index collocations; }; } // namespace extractor diff --git a/extractor/fast_intersector_test.cc b/extractor/fast_intersector_test.cc index 76c3aaea..f2a26ba1 100644 --- a/extractor/fast_intersector_test.cc +++ b/extractor/fast_intersector_test.cc @@ -59,15 +59,13 @@ class FastIntersectorTest : public Test { } precomputation = make_shared<MockPrecomputation>(); - EXPECT_CALL(*precomputation, GetCollocations()) - .WillRepeatedly(ReturnRef(collocations)); + EXPECT_CALL(*precomputation, Contains(_)).WillRepeatedly(Return(false)); phrase_builder = make_shared<PhraseBuilder>(vocabulary); intersector = make_shared<FastIntersector>(suffix_array, precomputation, vocabulary, 15, 1); } - Index collocations; shared_ptr<MockDataArray> data_array; shared_ptr<MockSuffixArray> suffix_array; shared_ptr<MockPrecomputation> precomputation; @@ -82,9 +80,9 @@ TEST_F(FastIntersectorTest, TestCachedCollocation) { Phrase phrase = phrase_builder->Build(symbols); PhraseLocation prefix_location(15, 16), suffix_location(16, 17); - collocations[symbols] = expected_location; - EXPECT_CALL(*precomputation, GetCollocations()) - .WillRepeatedly(ReturnRef(collocations)); + EXPECT_CALL(*precomputation, Contains(symbols)).WillRepeatedly(Return(true)); + EXPECT_CALL(*precomputation, GetCollocations(symbols)). + WillRepeatedly(Return(expected_location)); intersector = make_shared<FastIntersector>(suffix_array, precomputation, vocabulary, 15, 1); diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc index 487abcaf..4d0738f7 100644 --- a/extractor/grammar_extractor.cc +++ b/extractor/grammar_extractor.cc @@ -19,10 +19,11 @@ GrammarExtractor::GrammarExtractor( shared_ptr<SuffixArray> source_suffix_array, shared_ptr<DataArray> target_data_array, shared_ptr<Alignment> alignment, shared_ptr<Precomputation> precomputation, - shared_ptr<Scorer> scorer, int min_gap_size, int max_rule_span, + shared_ptr<Scorer> scorer, shared_ptr<Vocabulary> vocabulary, + int min_gap_size, int max_rule_span, int max_nonterminals, int max_rule_symbols, int max_samples, bool require_tight_phrases) : - vocabulary(make_shared<Vocabulary>()), + vocabulary(vocabulary), rule_factory(make_shared<HieroCachingRuleFactory>( source_suffix_array, target_data_array, alignment, vocabulary, precomputation, scorer, min_gap_size, max_rule_span, max_nonterminals, diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h index ae407b47..8f570df2 100644 --- a/extractor/grammar_extractor.h +++ b/extractor/grammar_extractor.h @@ -32,6 +32,7 @@ class GrammarExtractor { shared_ptr<Alignment> alignment, shared_ptr<Precomputation> precomputation, shared_ptr<Scorer> scorer, + shared_ptr<Vocabulary> vocabulary, int min_gap_size, int max_rule_span, int max_nonterminals, diff --git a/extractor/mocks/mock_data_array.h b/extractor/mocks/mock_data_array.h index 6f85abb4..4bdcf21f 100644 --- a/extractor/mocks/mock_data_array.h +++ b/extractor/mocks/mock_data_array.h @@ -6,9 +6,11 @@ namespace extractor { class MockDataArray : public DataArray { public: - MOCK_CONST_METHOD0(GetData, const vector<int>&()); + MOCK_CONST_METHOD0(GetData, vector<int>()); MOCK_CONST_METHOD1(AtIndex, int(int index)); MOCK_CONST_METHOD1(GetWordAtIndex, string(int index)); + MOCK_CONST_METHOD2(GetWordIds, vector<int>(int start_index, int size)); + MOCK_CONST_METHOD2(GetWords, vector<string>(int start_index, int size)); MOCK_CONST_METHOD0(GetSize, int()); MOCK_CONST_METHOD0(GetVocabularySize, int()); MOCK_CONST_METHOD1(HasWord, bool(const string& word)); diff --git a/extractor/mocks/mock_precomputation.h b/extractor/mocks/mock_precomputation.h index 8753343e..5f7aa999 100644 --- a/extractor/mocks/mock_precomputation.h +++ b/extractor/mocks/mock_precomputation.h @@ -6,7 +6,8 @@ namespace extractor { class MockPrecomputation : public Precomputation { public: - MOCK_CONST_METHOD0(GetCollocations, const Index&()); + MOCK_CONST_METHOD1(Contains, bool(const vector<int>& pattern)); + MOCK_CONST_METHOD1(GetCollocations, vector<int>(const vector<int>& pattern)); }; } // namespace extractor diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc index 3b8aed69..38d8f489 100644 --- a/extractor/precomputation.cc +++ b/extractor/precomputation.cc @@ -5,22 +5,21 @@ #include "data_array.h" #include "suffix_array.h" +#include "vocabulary.h" using namespace std; namespace extractor { -int Precomputation::FIRST_NONTERMINAL = -1; -int Precomputation::SECOND_NONTERMINAL = -2; +int Precomputation::NONTERMINAL = -1; Precomputation::Precomputation( - shared_ptr<SuffixArray> suffix_array, int num_frequent_patterns, - int num_super_frequent_patterns, int max_rule_span, - int max_rule_symbols, int min_gap_size, + shared_ptr<Vocabulary> vocabulary, shared_ptr<SuffixArray> suffix_array, + int num_frequent_patterns, int num_super_frequent_patterns, + int max_rule_span, int max_rule_symbols, int min_gap_size, int max_frequent_phrase_len, int min_frequency) { - vector<int> data = suffix_array->GetData()->GetData(); vector<vector<int>> frequent_patterns = FindMostFrequentPatterns( - suffix_array, data, num_frequent_patterns, max_frequent_phrase_len, + suffix_array, num_frequent_patterns, max_frequent_phrase_len, min_frequency); // Construct sets containing the frequent and superfrequent contiguous @@ -34,28 +33,30 @@ Precomputation::Precomputation( } } + shared_ptr<DataArray> data_array = suffix_array->GetData(); vector<tuple<int, int, int>> matchings; - for (size_t i = 0; i < data.size(); ++i) { + for (size_t i = 0; i < data_array->GetSize(); ++i) { // If the sentence is over, add all the discontiguous frequent patterns to // the index. - if (data[i] == DataArray::END_OF_LINE) { - AddCollocations(matchings, data, max_rule_span, min_gap_size, - max_rule_symbols); + if (data_array->AtIndex(i) == DataArray::END_OF_LINE) { + UpdateIndex(data_array, vocabulary, matchings, max_rule_span, + min_gap_size, max_rule_symbols); matchings.clear(); continue; } - vector<int> pattern; // Find all the contiguous frequent patterns starting at position i. - for (int j = 1; j <= max_frequent_phrase_len && i + j <= data.size(); ++j) { - pattern.push_back(data[i + j - 1]); - if (frequent_patterns_set.count(pattern)) { - int is_super_frequent = super_frequent_patterns_set.count(pattern); - matchings.push_back(make_tuple(i, j, is_super_frequent)); - } else { + vector<int> pattern; + for (int j = 1; + j <= max_frequent_phrase_len && i + j <= data_array->GetSize(); + ++j) { + pattern.push_back(data_array->AtIndex(i + j - 1)); + if (!frequent_patterns_set.count(pattern)) { // If the current pattern is not frequent, any longer pattern having the // current pattern as prefix will not be frequent. break; } + int is_super_frequent = super_frequent_patterns_set.count(pattern); + matchings.push_back(make_tuple(i, j, is_super_frequent)); } } } @@ -65,8 +66,8 @@ Precomputation::Precomputation() {} Precomputation::~Precomputation() {} vector<vector<int>> Precomputation::FindMostFrequentPatterns( - shared_ptr<SuffixArray> suffix_array, const vector<int>& data, - int num_frequent_patterns, int max_frequent_phrase_len, int min_frequency) { + shared_ptr<SuffixArray> suffix_array, int num_frequent_patterns, + int max_frequent_phrase_len, int min_frequency) { vector<int> lcp = suffix_array->BuildLCPArray(); vector<int> run_start(max_frequent_phrase_len); @@ -83,6 +84,7 @@ vector<vector<int>> Precomputation::FindMostFrequentPatterns( } } + shared_ptr<DataArray> data_array = suffix_array->GetData(); // Extract the most frequent patterns. vector<vector<int>> frequent_patterns; while (frequent_patterns.size() < num_frequent_patterns && !heap.empty()) { @@ -90,7 +92,7 @@ vector<vector<int>> Precomputation::FindMostFrequentPatterns( int len = heap.top().second.second; heap.pop(); - vector<int> pattern(data.begin() + start, data.begin() + start + len); + vector<int> pattern = data_array->GetWordIds(start, len); if (find(pattern.begin(), pattern.end(), DataArray::END_OF_LINE) == pattern.end()) { frequent_patterns.push_back(pattern); @@ -99,8 +101,9 @@ vector<vector<int>> Precomputation::FindMostFrequentPatterns( return frequent_patterns; } -void Precomputation::AddCollocations( - const vector<tuple<int, int, int>>& matchings, const vector<int>& data, +void Precomputation::UpdateIndex( + shared_ptr<DataArray> data_array, shared_ptr<Vocabulary> vocabulary, + const vector<tuple<int, int, int>>& matchings, int max_rule_span, int min_gap_size, int max_rule_symbols) { // Select the leftmost subpattern. for (size_t i = 0; i < matchings.size(); ++i) { @@ -118,16 +121,15 @@ void Precomputation::AddCollocations( if (start2 - start1 - size1 >= min_gap_size && start2 + size2 - start1 <= max_rule_span && size1 + size2 + 1 <= max_rule_symbols) { - vector<int> pattern(data.begin() + start1, - data.begin() + start1 + size1); - pattern.push_back(Precomputation::FIRST_NONTERMINAL); - pattern.insert(pattern.end(), data.begin() + start2, - data.begin() + start2 + size2); - AddStartPositions(collocations[pattern], start1, start2); + vector<int> pattern; + AppendSubpattern(pattern, data_array, vocabulary, start1, size1); + pattern.push_back(Precomputation::NONTERMINAL); + AppendSubpattern(pattern, data_array, vocabulary, start2, size2); + AppendCollocation(index[pattern], {start1, start2}); // Try extending the binary collocation to a ternary collocation. if (is_super2) { - pattern.push_back(Precomputation::SECOND_NONTERMINAL); + pattern.push_back(Precomputation::NONTERMINAL); // Select the rightmost subpattern. for (size_t k = j + 1; k < matchings.size(); ++k) { int start3, size3, is_super3; @@ -140,9 +142,8 @@ void Precomputation::AddCollocations( && start3 + size3 - start1 <= max_rule_span && size1 + size2 + size3 + 2 <= max_rule_symbols && (is_super1 || is_super3)) { - pattern.insert(pattern.end(), data.begin() + start3, - data.begin() + start3 + size3); - AddStartPositions(collocations[pattern], start1, start2, start3); + AppendSubpattern(pattern, data_array, vocabulary, start3, size3); + AppendCollocation(index[pattern], {start1, start2, start3}); pattern.erase(pattern.end() - size3); } } @@ -152,25 +153,30 @@ void Precomputation::AddCollocations( } } -void Precomputation::AddStartPositions( - vector<int>& positions, int pos1, int pos2) { - positions.push_back(pos1); - positions.push_back(pos2); +void Precomputation::AppendSubpattern( + vector<int>& pattern, shared_ptr<DataArray> data_array, + shared_ptr<Vocabulary> vocabulary, int start, int size) { + vector<string> words = data_array->GetWords(start, size); + for (const string& word: words) { + pattern.push_back(vocabulary->GetTerminalIndex(word)); + } +} + +void Precomputation::AppendCollocation( + vector<int>& collocations, const vector<int>& collocation) { + copy(collocation.begin(), collocation.end(), back_inserter(collocations)); } -void Precomputation::AddStartPositions( - vector<int>& positions, int pos1, int pos2, int pos3) { - positions.push_back(pos1); - positions.push_back(pos2); - positions.push_back(pos3); +bool Precomputation::Contains(const vector<int>& pattern) const { + return index.count(pattern); } -const Index& Precomputation::GetCollocations() const { - return collocations; +vector<int> Precomputation::GetCollocations(const vector<int>& pattern) const { + return index.at(pattern); } bool Precomputation::operator==(const Precomputation& other) const { - return collocations == other.collocations; + return index == other.index; } } // namespace extractor diff --git a/extractor/precomputation.h b/extractor/precomputation.h index e5fa3e37..6ade58df 100644 --- a/extractor/precomputation.h +++ b/extractor/precomputation.h @@ -19,7 +19,9 @@ namespace extractor { typedef boost::hash<vector<int>> VectorHash; typedef unordered_map<vector<int>, vector<int>, VectorHash> Index; +class DataArray; class SuffixArray; +class Vocabulary; /** * Data structure wrapping an index with all the occurrences of the most @@ -35,9 +37,9 @@ class Precomputation { public: // Constructs the index using the suffix array. Precomputation( - shared_ptr<SuffixArray> suffix_array, int num_frequent_patterns, - int num_super_frequent_patterns, int max_rule_span, - int max_rule_symbols, int min_gap_size, + shared_ptr<Vocabulary> vocabulary, shared_ptr<SuffixArray> suffix_array, + int num_frequent_patterns, int num_super_frequent_patterns, + int max_rule_span, int max_rule_symbols, int min_gap_size, int max_frequent_phrase_len, int min_frequency); // Creates empty precomputation data structure. @@ -45,40 +47,43 @@ class Precomputation { virtual ~Precomputation(); - // Returns a reference to the index. - virtual const Index& GetCollocations() const; + // Returns whether a pattern is contained in the index of collocations. + virtual bool Contains(const vector<int>& pattern) const; + + // Returns the list of collocations for a given pattern. + virtual vector<int> GetCollocations(const vector<int>& pattern) const; bool operator==(const Precomputation& other) const; - static int FIRST_NONTERMINAL; - static int SECOND_NONTERMINAL; + static int NONTERMINAL; private: // Finds the most frequent contiguous collocations. vector<vector<int>> FindMostFrequentPatterns( - shared_ptr<SuffixArray> suffix_array, const vector<int>& data, - int num_frequent_patterns, int max_frequent_phrase_len, - int min_frequency); + shared_ptr<SuffixArray> suffix_array, int num_frequent_patterns, + int max_frequent_phrase_len, int min_frequency); // Given the locations of the frequent contiguous collocations in a sentence, // it adds new entries to the index for each discontiguous collocation // matching the criteria specified in the class description. - void AddCollocations( - const vector<std::tuple<int, int, int>>& matchings, const vector<int>& data, + void UpdateIndex( + shared_ptr<DataArray> data_array, shared_ptr<Vocabulary> vocabulary, + const vector<tuple<int, int, int>>& matchings, int max_rule_span, int min_gap_size, int max_rule_symbols); - // Adds an occurrence of a binary collocation. - void AddStartPositions(vector<int>& positions, int pos1, int pos2); + void AppendSubpattern( + vector<int>& pattern, shared_ptr<DataArray> data_array, + shared_ptr<Vocabulary> vocabulary, int start, int size); - // Adds an occurrence of a ternary collocation. - void AddStartPositions(vector<int>& positions, int pos1, int pos2, int pos3); + // Adds an occurrence of a collocation. + void AppendCollocation(vector<int>& collocations, const vector<int>& collocation); friend class boost::serialization::access; template<class Archive> void save(Archive& ar, unsigned int) const { - int num_entries = collocations.size(); + int num_entries = index.size(); ar << num_entries; - for (pair<vector<int>, vector<int>> entry: collocations) { + for (pair<vector<int>, vector<int>> entry: index) { ar << entry; } } @@ -89,13 +94,13 @@ class Precomputation { for (size_t i = 0; i < num_entries; ++i) { pair<vector<int>, vector<int>> entry; ar >> entry; - collocations.insert(entry); + index.insert(entry); } } BOOST_SERIALIZATION_SPLIT_MEMBER(); - Index collocations; + Index index; }; } // namespace extractor diff --git a/extractor/precomputation_test.cc b/extractor/precomputation_test.cc index e81ece5d..fd85fcf8 100644 --- a/extractor/precomputation_test.cc +++ b/extractor/precomputation_test.cc @@ -9,6 +9,7 @@ #include "mocks/mock_data_array.h" #include "mocks/mock_suffix_array.h" +#include "mocks/mock_vocabulary.h" #include "precomputation.h" using namespace std; @@ -23,7 +24,31 @@ class PrecomputationTest : public Test { virtual void SetUp() { data = {4, 2, 3, 5, 7, 2, 3, 5, 2, 3, 4, 2, 1}; data_array = make_shared<MockDataArray>(); - EXPECT_CALL(*data_array, GetData()).WillRepeatedly(ReturnRef(data)); + EXPECT_CALL(*data_array, GetSize()).WillRepeatedly(Return(data.size())); + for (size_t i = 0; i < data.size(); ++i) { + EXPECT_CALL(*data_array, AtIndex(i)).WillRepeatedly(Return(data[i])); + } + vector<pair<int, int>> expected_calls = {{8, 1}, {8, 2}, {6, 1}}; + for (const auto& call: expected_calls) { + int start = call.first; + int size = call.second; + vector<int> word_ids(data.begin() + start, data.begin() + start + size); + EXPECT_CALL(*data_array, GetWordIds(start, size)) + .WillRepeatedly(Return(word_ids)); + } + + expected_calls = {{1, 1}, {5, 1}, {8, 1}, {9, 1}, {5, 2}, + {6, 1}, {8, 2}, {1, 2}, {2, 1}, {11, 1}}; + for (const auto& call: expected_calls) { + int start = call.first; + int size = call.second; + vector<string> words; + for (size_t j = start; j < start + size; ++j) { + words.push_back(to_string(data[j])); + } + EXPECT_CALL(*data_array, GetWords(start, size)) + .WillRepeatedly(Return(words)); + } vector<int> suffixes{12, 8, 5, 1, 9, 6, 2, 0, 10, 7, 3, 4, 13}; vector<int> lcp{-1, 0, 2, 3, 1, 0, 1, 2, 0, 2, 0, 1, 0, 0}; @@ -35,77 +60,98 @@ class PrecomputationTest : public Test { } EXPECT_CALL(*suffix_array, BuildLCPArray()).WillRepeatedly(Return(lcp)); - precomputation = Precomputation(suffix_array, 3, 3, 10, 5, 1, 4, 2); + vocabulary = make_shared<MockVocabulary>(); + EXPECT_CALL(*vocabulary, GetTerminalIndex("2")).WillRepeatedly(Return(2)); + EXPECT_CALL(*vocabulary, GetTerminalIndex("3")).WillRepeatedly(Return(3)); + + precomputation = Precomputation(vocabulary, suffix_array, + 3, 3, 10, 5, 1, 4, 2); } vector<int> data; shared_ptr<MockDataArray> data_array; shared_ptr<MockSuffixArray> suffix_array; + shared_ptr<MockVocabulary> vocabulary; Precomputation precomputation; }; TEST_F(PrecomputationTest, TestCollocations) { - Index collocations = precomputation.GetCollocations(); - vector<int> key = {2, 3, -1, 2}; vector<int> expected_value = {1, 5, 1, 8, 5, 8, 5, 11, 8, 11}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {2, 3, -1, 2, 3}; expected_value = {1, 5, 1, 8, 5, 8}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {2, 3, -1, 3}; expected_value = {1, 6, 1, 9, 5, 9}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {3, -1, 2}; expected_value = {2, 5, 2, 8, 2, 11, 6, 8, 6, 11, 9, 11}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {3, -1, 3}; expected_value = {2, 6, 2, 9, 6, 9}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {3, -1, 2, 3}; expected_value = {2, 5, 2, 8, 6, 8}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {2, -1, 2}; expected_value = {1, 5, 1, 8, 5, 8, 5, 11, 8, 11}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {2, -1, 2, 3}; expected_value = {1, 5, 1, 8, 5, 8}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); key = {2, -1, 3}; expected_value = {1, 6, 1, 9, 5, 9}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); - key = {2, -1, 2, -2, 2}; + key = {2, -1, 2, -1, 2}; expected_value = {1, 5, 8, 5, 8, 11}; - EXPECT_EQ(expected_value, collocations[key]); - key = {2, -1, 2, -2, 3}; + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); + key = {2, -1, 2, -1, 3}; expected_value = {1, 5, 9}; - EXPECT_EQ(expected_value, collocations[key]); - key = {2, -1, 3, -2, 2}; + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); + key = {2, -1, 3, -1, 2}; expected_value = {1, 6, 8, 5, 9, 11}; - EXPECT_EQ(expected_value, collocations[key]); - key = {2, -1, 3, -2, 3}; + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); + key = {2, -1, 3, -1, 3}; expected_value = {1, 6, 9}; - EXPECT_EQ(expected_value, collocations[key]); - key = {3, -1, 2, -2, 2}; + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); + key = {3, -1, 2, -1, 2}; expected_value = {2, 5, 8, 2, 5, 11, 2, 8, 11, 6, 8, 11}; - EXPECT_EQ(expected_value, collocations[key]); - key = {3, -1, 2, -2, 3}; + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); + key = {3, -1, 2, -1, 3}; expected_value = {2, 5, 9}; - EXPECT_EQ(expected_value, collocations[key]); - key = {3, -1, 3, -2, 2}; + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); + key = {3, -1, 3, -1, 2}; expected_value = {2, 6, 8, 2, 6, 11, 2, 9, 11, 6, 9, 11}; - EXPECT_EQ(expected_value, collocations[key]); - key = {3, -1, 3, -2, 3}; + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); + key = {3, -1, 3, -1, 3}; expected_value = {2, 6, 9}; - EXPECT_EQ(expected_value, collocations[key]); + EXPECT_TRUE(precomputation.Contains(key)); + EXPECT_EQ(expected_value, precomputation.GetCollocations(key)); // Exceeds max_rule_symbols. - key = {2, -1, 2, -2, 2, 3}; - EXPECT_EQ(0, collocations.count(key)); + key = {2, -1, 2, -1, 2, 3}; + EXPECT_FALSE(precomputation.Contains(key)); // Contains non frequent pattern. key = {2, -1, 5}; - EXPECT_EQ(0, collocations.count(key)); + EXPECT_FALSE(precomputation.Contains(key)); } TEST_F(PrecomputationTest, TestSerialization) { diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc index 6eb55073..85c8a422 100644 --- a/extractor/run_extractor.cc +++ b/extractor/run_extractor.cc @@ -28,6 +28,7 @@ #include "suffix_array.h" #include "time_util.h" #include "translation_table.h" +#include "vocabulary.h" namespace fs = boost::filesystem; namespace po = boost::program_options; @@ -142,11 +143,14 @@ int main(int argc, char** argv) { cerr << "Reading alignment took " << GetDuration(start_time, stop_time) << " seconds" << endl; + shared_ptr<Vocabulary> vocabulary = make_shared<Vocabulary>(); + // Constructs an index storing the occurrences in the source data for each // frequent collocation. start_time = Clock::now(); cerr << "Precomputing collocations..." << endl; shared_ptr<Precomputation> precomputation = make_shared<Precomputation>( + vocabulary, source_suffix_array, vm["frequent"].as<int>(), vm["super_frequent"].as<int>(), @@ -194,6 +198,7 @@ int main(int argc, char** argv) { alignment, precomputation, scorer, + vocabulary, vm["min_gap_size"].as<int>(), vm["max_rule_span"].as<int>(), vm["max_nonterminals"].as<int>(), diff --git a/extractor/suffix_array_test.cc b/extractor/suffix_array_test.cc index ba0dbcc3..a9fd1eab 100644 --- a/extractor/suffix_array_test.cc +++ b/extractor/suffix_array_test.cc @@ -21,7 +21,7 @@ class SuffixArrayTest : public Test { virtual void SetUp() { data = {6, 4, 1, 2, 4, 5, 3, 4, 6, 6, 4, 1, 2}; data_array = make_shared<MockDataArray>(); - EXPECT_CALL(*data_array, GetData()).WillRepeatedly(ReturnRef(data)); + EXPECT_CALL(*data_array, GetData()).WillRepeatedly(Return(data)); EXPECT_CALL(*data_array, GetVocabularySize()).WillRepeatedly(Return(7)); EXPECT_CALL(*data_array, GetSize()).WillRepeatedly(Return(13)); suffix_array = SuffixArray(data_array); diff --git a/extractor/translation_table_test.cc b/extractor/translation_table_test.cc index 606777bd..72551a12 100644 --- a/extractor/translation_table_test.cc +++ b/extractor/translation_table_test.cc @@ -28,7 +28,7 @@ class TranslationTableTest : public Test { vector<int> source_sentence_start = {0, 6, 10, 14}; shared_ptr<MockDataArray> source_data_array = make_shared<MockDataArray>(); EXPECT_CALL(*source_data_array, GetData()) - .WillRepeatedly(ReturnRef(source_data)); + .WillRepeatedly(Return(source_data)); EXPECT_CALL(*source_data_array, GetNumSentences()) .WillRepeatedly(Return(3)); for (size_t i = 0; i < source_sentence_start.size(); ++i) { @@ -48,7 +48,7 @@ class TranslationTableTest : public Test { vector<int> target_sentence_start = {0, 7, 10, 13}; shared_ptr<MockDataArray> target_data_array = make_shared<MockDataArray>(); EXPECT_CALL(*target_data_array, GetData()) - .WillRepeatedly(ReturnRef(target_data)); + .WillRepeatedly(Return(target_data)); for (size_t i = 0; i < target_sentence_start.size(); ++i) { EXPECT_CALL(*target_data_array, GetSentenceStart(i)) .WillRepeatedly(Return(target_sentence_start[i])); |