#include "intersector.h"

#include <chrono>

#include "data_array.h"
#include "matching_comparator.h"
#include "phrase.h"
#include "phrase_location.h"
#include "precomputation.h"
#include "suffix_array.h"
#include "veb.h"
#include "vocabulary.h"

using namespace std::chrono;

typedef high_resolution_clock Clock;

// NOTE(review): the template arguments used throughout this file
// (shared_ptr<...>, vector<int>, make_shared<...>, duration_cast<...>) were
// reconstructed from how the values are used below and from the included
// header names. Confirm them against the declarations in intersector.h.
//
// NOTE(review): none of the constructors below initializes the timing
// accumulators (sort_time, linear_merge_time, binary_merge_time), yet
// Intersect() reads linear_merge_time. Confirm that the header gives them
// initial values.

// Builds an intersector that creates its own linear and binary-search mergers
// from the vocabulary, the suffix array's data array and the comparator, then
// imports the precomputed indexes.
Intersector::Intersector(shared_ptr<Vocabulary> vocabulary,
                         shared_ptr<Precomputation> precomputation,
                         shared_ptr<SuffixArray> suffix_array,
                         shared_ptr<MatchingComparator> comparator,
                         bool use_baeza_yates) :
    vocabulary(vocabulary),
    suffix_array(suffix_array),
    use_baeza_yates(use_baeza_yates) {
  shared_ptr<DataArray> data_array = suffix_array->GetData();
  linear_merger = make_shared<LinearMerger>(vocabulary, data_array, comparator);
  binary_search_merger = make_shared<BinarySearchMerger>(
      vocabulary, linear_merger, data_array, comparator);
  ConvertIndexes(precomputation, data_array);
}

// Builds an intersector around caller-supplied mergers, then imports the
// precomputed indexes.
Intersector::Intersector(shared_ptr<Vocabulary> vocabulary,
                         shared_ptr<Precomputation> precomputation,
                         shared_ptr<SuffixArray> suffix_array,
                         shared_ptr<LinearMerger> linear_merger,
                         shared_ptr<BinarySearchMerger> binary_search_merger,
                         bool use_baeza_yates) :
    vocabulary(vocabulary),
    suffix_array(suffix_array),
    linear_merger(linear_merger),
    binary_search_merger(binary_search_merger),
    use_baeza_yates(use_baeza_yates) {
  ConvertIndexes(precomputation, suffix_array->GetData());
}

// Leaves every member default-initialized; no indexes are imported.
Intersector::Intersector() {}

Intersector::~Intersector() {}

// Copies the precomputed inverted index and collocations into this object's
// maps, re-keying every phrase through ConvertPhrase().
//
// Each inverted index entry is stored under three keys: the converted phrase
// itself, the phrase followed by nonterminal index 1, and the phrase preceded
// by nonterminal index 1. Collocations are stored under the converted phrase
// only.
void Intersector::ConvertIndexes(shared_ptr<Precomputation> precomputation,
                                 shared_ptr<DataArray> data_array) {
  const Index& precomputed_index = precomputation->GetInvertedIndex();
  // Bind by const reference: the entries are only read, so there is no need to
  // copy both vectors of every entry.
  for (const auto& entry: precomputed_index) {
    vector<int> phrase = ConvertPhrase(entry.first, data_array);
    inverted_index[phrase] = entry.second;

    // Same locations for the phrase with a trailing nonterminal...
    phrase.push_back(vocabulary->GetNonterminalIndex(1));
    inverted_index[phrase] = entry.second;
    phrase.pop_back();

    // ...and for the phrase with a leading nonterminal.
    phrase.insert(phrase.begin(), vocabulary->GetNonterminalIndex(1));
    inverted_index[phrase] = entry.second;
  }

  const Index& precomputed_collocations = precomputation->GetCollocations();
  for (const auto& entry: precomputed_collocations) {
    vector<int> phrase = ConvertPhrase(entry.first, data_array);
    collocations[phrase] = entry.second;
  }
}

// Translates a phrase made of precomputation word ids into vocabulary symbols.
//
// Every Precomputation::NON_TERMINAL marker becomes the next nonterminal index
// (numbered from 1 in order of appearance). Every other id is resolved to its
// word through data_array and then to that word's terminal index.
vector<int> Intersector::ConvertPhrase(const vector<int>& old_phrase,
                                       shared_ptr<DataArray> data_array) {
  vector<int> new_phrase;
  new_phrase.reserve(old_phrase.size());

  int arity = 0;
  for (int word_id: old_phrase) {
    if (word_id == Precomputation::NON_TERMINAL) {
      ++arity;
      new_phrase.push_back(vocabulary->GetNonterminalIndex(arity));
    } else {
      new_phrase.push_back(
          vocabulary->GetTerminalIndex(data_array->GetWord(word_id)));
    }
  }

  return new_phrase;
}

// Computes the location of `phrase` by merging the matchings of its prefix and
// suffix.
//
// If the phrase is a precomputed collocation, the stored locations are
// returned directly. Otherwise both input locations are materialized through
// ExtendPhraseLocation() (which may modify them) and merged with either the
// binary search merger or the linear merger, depending on use_baeza_yates.
// The returned location has phrase.Arity() + 1 subpatterns.
PhraseLocation Intersector::Intersect(
    const Phrase& prefix, PhraseLocation& prefix_location,
    const Phrase& suffix, PhraseLocation& suffix_location,
    const Phrase& phrase) {
  // When our accumulator is zero, zero the merger's accumulator as well.
  // NOTE(review): presumably this propagates an external reset of
  // linear_merge_time to the merger - confirm.
  if (linear_merge_time == 0) {
    linear_merger->linear_merge_time = 0;
  }

  // Binds by const reference so the symbols are not copied when Get() returns
  // a reference (a by-value result has its lifetime extended).
  const vector<int>& symbols = phrase.Get();

  // We should never attempt to do an intersect query for a pattern starting or
  // ending with a non terminal. The RuleFactory should handle these cases,
  // initializing the matchings list with the one for the pattern without the
  // starting or ending terminal.
  // The emptiness check keeps front()/back() well-defined inside the assert.
  assert(!symbols.empty()
      && vocabulary->IsTerminal(symbols.front())
      && vocabulary->IsTerminal(symbols.back()));

  // Single lookup instead of count() followed by operator[].
  auto collocation = collocations.find(symbols);
  if (collocation != collocations.end()) {
    return PhraseLocation(collocation->second, phrase.Arity() + 1);
  }

  vector<int> locations;
  ExtendPhraseLocation(prefix, prefix_location);
  ExtendPhraseLocation(suffix, suffix_location);
  shared_ptr<vector<int> > prefix_matchings = prefix_location.matchings;
  shared_ptr<vector<int> > suffix_matchings = suffix_location.matchings;
  int prefix_subpatterns = prefix_location.num_subpatterns;
  int suffix_subpatterns = suffix_location.num_subpatterns;

  if (use_baeza_yates) {
    double prev_linear_merge_time = linear_merger->linear_merge_time;
    Clock::time_point start = Clock::now();
    binary_search_merger->Merge(locations, phrase, suffix,
        prefix_matchings->begin(), prefix_matchings->end(),
        suffix_matchings->begin(), suffix_matchings->end(),
        prefix_subpatterns, suffix_subpatterns);
    Clock::time_point stop = Clock::now();
    // Subtract whatever the linear merger accumulated during this call, so
    // that time is not counted under both accumulators.
    // NOTE(review): the duration unit is an assumption (milliseconds); it must
    // match the unit used for linear_merger->linear_merge_time - confirm.
    binary_merge_time += duration_cast<milliseconds>(stop - start).count() -
        (linear_merger->linear_merge_time - prev_linear_merge_time);
  } else {
    linear_merger->Merge(locations, phrase, suffix,
        prefix_matchings->begin(), prefix_matchings->end(),
        suffix_matchings->begin(), suffix_matchings->end(),
        prefix_subpatterns, suffix_subpatterns);
  }
  linear_merge_time = linear_merger->linear_merge_time;

  return PhraseLocation(locations, phrase.Arity() + 1);
}

// Ensures phrase_location carries an explicit matchings list.
//
// Does nothing if the list is already present. Otherwise the location is
// turned into a single-subpattern location with sa_low and sa_high reset to 0.
// Its matchings come from the inverted index when the phrase is found there;
// if not, the suffixes in [sa_low, sa_high) are inserted into a VEB structure
// and read back from its minimum through successive successors.
//
// Only the VEB path is added to sort_time; the inverted index path returns
// before the timer is read.
void Intersector::ExtendPhraseLocation(
    const Phrase& phrase, PhraseLocation& phrase_location) {
  if (phrase_location.matchings != nullptr) {
    return;
  }

  // Read the suffix array range before it is reset below.
  const int low = phrase_location.sa_low;
  const int high = phrase_location.sa_high;

  Clock::time_point sort_start = Clock::now();

  phrase_location.num_subpatterns = 1;
  phrase_location.sa_low = phrase_location.sa_high = 0;

  vector<int> symbols = phrase.Get();
  // Single lookup instead of count() followed by operator[].
  auto indexed = inverted_index.find(symbols);
  if (indexed != inverted_index.end()) {
    phrase_location.matchings = make_shared<vector<int> >(indexed->second);
    return;
  }

  // Build the result directly inside the shared vector, so the finished list
  // is not copied a second time.
  shared_ptr<vector<int> > matchings = make_shared<vector<int> >();
  if (high > low) {
    // Guarded so an inverted range never turns into a huge unsigned request.
    matchings->reserve(high - low);
  }

  auto veb = VEB::Create(suffix_array->GetSize());
  for (int i = low; i < high; ++i) {
    veb->Insert(suffix_array->GetSuffix(i));
  }

  // -1 is the value the loop treats as "no further element".
  for (int value = veb->GetMinimum(); value != -1;
       value = veb->GetSuccessor(value)) {
    matchings->push_back(value);
  }

  phrase_location.matchings = matchings;

  Clock::time_point sort_stop = Clock::now();
  // NOTE(review): the duration unit is an assumption (milliseconds) - confirm.
  sort_time += duration_cast<milliseconds>(sort_stop - sort_start).count();
}