diff options
Diffstat (limited to 'extractor/precomputation.cc')
-rw-r--r-- | extractor/precomputation.cc | 96 |
1 files changed, 51 insertions, 45 deletions
diff --git a/extractor/precomputation.cc b/extractor/precomputation.cc index 3b8aed69..38d8f489 100644 --- a/extractor/precomputation.cc +++ b/extractor/precomputation.cc @@ -5,22 +5,21 @@ #include "data_array.h" #include "suffix_array.h" +#include "vocabulary.h" using namespace std; namespace extractor { -int Precomputation::FIRST_NONTERMINAL = -1; -int Precomputation::SECOND_NONTERMINAL = -2; +int Precomputation::NONTERMINAL = -1; Precomputation::Precomputation( - shared_ptr<SuffixArray> suffix_array, int num_frequent_patterns, - int num_super_frequent_patterns, int max_rule_span, - int max_rule_symbols, int min_gap_size, + shared_ptr<Vocabulary> vocabulary, shared_ptr<SuffixArray> suffix_array, + int num_frequent_patterns, int num_super_frequent_patterns, + int max_rule_span, int max_rule_symbols, int min_gap_size, int max_frequent_phrase_len, int min_frequency) { - vector<int> data = suffix_array->GetData()->GetData(); vector<vector<int>> frequent_patterns = FindMostFrequentPatterns( - suffix_array, data, num_frequent_patterns, max_frequent_phrase_len, + suffix_array, num_frequent_patterns, max_frequent_phrase_len, min_frequency); // Construct sets containing the frequent and superfrequent contiguous @@ -34,28 +33,30 @@ Precomputation::Precomputation( } } + shared_ptr<DataArray> data_array = suffix_array->GetData(); vector<tuple<int, int, int>> matchings; - for (size_t i = 0; i < data.size(); ++i) { + for (size_t i = 0; i < data_array->GetSize(); ++i) { // If the sentence is over, add all the discontiguous frequent patterns to // the index. - if (data[i] == DataArray::END_OF_LINE) { - AddCollocations(matchings, data, max_rule_span, min_gap_size, - max_rule_symbols); + if (data_array->AtIndex(i) == DataArray::END_OF_LINE) { + UpdateIndex(data_array, vocabulary, matchings, max_rule_span, + min_gap_size, max_rule_symbols); matchings.clear(); continue; } - vector<int> pattern; // Find all the contiguous frequent patterns starting at position i. - for (int j = 1; j <= max_frequent_phrase_len && i + j <= data.size(); ++j) { - pattern.push_back(data[i + j - 1]); - if (frequent_patterns_set.count(pattern)) { - int is_super_frequent = super_frequent_patterns_set.count(pattern); - matchings.push_back(make_tuple(i, j, is_super_frequent)); - } else { + vector<int> pattern; + for (int j = 1; + j <= max_frequent_phrase_len && i + j <= data_array->GetSize(); + ++j) { + pattern.push_back(data_array->AtIndex(i + j - 1)); + if (!frequent_patterns_set.count(pattern)) { // If the current pattern is not frequent, any longer pattern having the // current pattern as prefix will not be frequent. break; } + int is_super_frequent = super_frequent_patterns_set.count(pattern); + matchings.push_back(make_tuple(i, j, is_super_frequent)); } } } @@ -65,8 +66,8 @@ Precomputation::Precomputation() {} Precomputation::~Precomputation() {} vector<vector<int>> Precomputation::FindMostFrequentPatterns( - shared_ptr<SuffixArray> suffix_array, const vector<int>& data, - int num_frequent_patterns, int max_frequent_phrase_len, int min_frequency) { + shared_ptr<SuffixArray> suffix_array, int num_frequent_patterns, + int max_frequent_phrase_len, int min_frequency) { vector<int> lcp = suffix_array->BuildLCPArray(); vector<int> run_start(max_frequent_phrase_len); @@ -83,6 +84,7 @@ vector<vector<int>> Precomputation::FindMostFrequentPatterns( } } + shared_ptr<DataArray> data_array = suffix_array->GetData(); // Extract the most frequent patterns. vector<vector<int>> frequent_patterns; while (frequent_patterns.size() < num_frequent_patterns && !heap.empty()) { @@ -90,7 +92,7 @@ vector<vector<int>> Precomputation::FindMostFrequentPatterns( int len = heap.top().second.second; heap.pop(); - vector<int> pattern(data.begin() + start, data.begin() + start + len); + vector<int> pattern = data_array->GetWordIds(start, len); if (find(pattern.begin(), pattern.end(), DataArray::END_OF_LINE) == pattern.end()) { frequent_patterns.push_back(pattern); @@ -99,8 +101,9 @@ vector<vector<int>> Precomputation::FindMostFrequentPatterns( return frequent_patterns; } -void Precomputation::AddCollocations( - const vector<tuple<int, int, int>>& matchings, const vector<int>& data, +void Precomputation::UpdateIndex( + shared_ptr<DataArray> data_array, shared_ptr<Vocabulary> vocabulary, + const vector<tuple<int, int, int>>& matchings, int max_rule_span, int min_gap_size, int max_rule_symbols) { // Select the leftmost subpattern. for (size_t i = 0; i < matchings.size(); ++i) { @@ -118,16 +121,15 @@ void Precomputation::AddCollocations( if (start2 - start1 - size1 >= min_gap_size && start2 + size2 - start1 <= max_rule_span && size1 + size2 + 1 <= max_rule_symbols) { - vector<int> pattern(data.begin() + start1, - data.begin() + start1 + size1); - pattern.push_back(Precomputation::FIRST_NONTERMINAL); - pattern.insert(pattern.end(), data.begin() + start2, - data.begin() + start2 + size2); - AddStartPositions(collocations[pattern], start1, start2); + vector<int> pattern; + AppendSubpattern(pattern, data_array, vocabulary, start1, size1); + pattern.push_back(Precomputation::NONTERMINAL); + AppendSubpattern(pattern, data_array, vocabulary, start2, size2); + AppendCollocation(index[pattern], {start1, start2}); // Try extending the binary collocation to a ternary collocation. if (is_super2) { - pattern.push_back(Precomputation::SECOND_NONTERMINAL); + pattern.push_back(Precomputation::NONTERMINAL); // Select the rightmost subpattern. for (size_t k = j + 1; k < matchings.size(); ++k) { int start3, size3, is_super3; @@ -140,9 +142,8 @@ void Precomputation::AddCollocations( && start3 + size3 - start1 <= max_rule_span && size1 + size2 + size3 + 2 <= max_rule_symbols && (is_super1 || is_super3)) { - pattern.insert(pattern.end(), data.begin() + start3, - data.begin() + start3 + size3); - AddStartPositions(collocations[pattern], start1, start2, start3); + AppendSubpattern(pattern, data_array, vocabulary, start3, size3); + AppendCollocation(index[pattern], {start1, start2, start3}); pattern.erase(pattern.end() - size3); } } @@ -152,25 +153,30 @@ void Precomputation::AddCollocations( } } -void Precomputation::AddStartPositions( - vector<int>& positions, int pos1, int pos2) { - positions.push_back(pos1); - positions.push_back(pos2); +void Precomputation::AppendSubpattern( + vector<int>& pattern, shared_ptr<DataArray> data_array, + shared_ptr<Vocabulary> vocabulary, int start, int size) { + vector<string> words = data_array->GetWords(start, size); + for (const string& word: words) { + pattern.push_back(vocabulary->GetTerminalIndex(word)); + } +} + +void Precomputation::AppendCollocation( + vector<int>& collocations, const vector<int>& collocation) { + copy(collocation.begin(), collocation.end(), back_inserter(collocations)); } -void Precomputation::AddStartPositions( - vector<int>& positions, int pos1, int pos2, int pos3) { - positions.push_back(pos1); - positions.push_back(pos2); - positions.push_back(pos3); +bool Precomputation::Contains(const vector<int>& pattern) const { + return index.count(pattern); } -const Index& Precomputation::GetCollocations() const { - return collocations; +vector<int> Precomputation::GetCollocations(const vector<int>& pattern) const { + return index.at(pattern); } bool Precomputation::operator==(const Precomputation& other) const { - return collocations == other.collocations; + return index == other.index; } } // namespace extractor |