diff options
Diffstat (limited to 'extractor/rule_factory.cc')
-rw-r--r-- | extractor/rule_factory.cc | 70 |
1 files changed, 38 insertions, 32 deletions
diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc index 374a0db1..4101fcfa 100644 --- a/extractor/rule_factory.cc +++ b/extractor/rule_factory.cc @@ -6,6 +6,7 @@ #include <vector> #include "grammar.h" +#include "fast_intersector.h" #include "intersector.h" #include "matchings_finder.h" #include "matching_comparator.h" @@ -15,10 +16,11 @@ #include "sampler.h" #include "scorer.h" #include "suffix_array.h" +#include "time_util.h" #include "vocabulary.h" using namespace std; -using namespace std::chrono; +using namespace chrono; typedef high_resolution_clock Clock; @@ -48,6 +50,7 @@ HieroCachingRuleFactory::HieroCachingRuleFactory( int max_nonterminals, int max_rule_symbols, int max_samples, + bool use_fast_intersect, bool use_baeza_yates, bool require_tight_phrases) : vocabulary(vocabulary), @@ -56,12 +59,15 @@ HieroCachingRuleFactory::HieroCachingRuleFactory( max_rule_span(max_rule_span), max_nonterminals(max_nonterminals), max_chunks(max_nonterminals + 1), - max_rule_symbols(max_rule_symbols) { + max_rule_symbols(max_rule_symbols), + use_fast_intersect(use_fast_intersect) { matchings_finder = make_shared<MatchingsFinder>(source_suffix_array); shared_ptr<MatchingComparator> comparator = make_shared<MatchingComparator>(min_gap_size, max_rule_span); intersector = make_shared<Intersector>(vocabulary, precomputation, source_suffix_array, comparator, use_baeza_yates); + fast_intersector = make_shared<FastIntersector>(source_suffix_array, + precomputation, vocabulary, max_rule_span, min_gap_size); phrase_builder = make_shared<PhraseBuilder>(vocabulary); rule_extractor = make_shared<RuleExtractor>(source_suffix_array->GetData(), target_data_array, alignment, phrase_builder, scorer, vocabulary, @@ -73,6 +79,7 @@ HieroCachingRuleFactory::HieroCachingRuleFactory( HieroCachingRuleFactory::HieroCachingRuleFactory( shared_ptr<MatchingsFinder> finder, shared_ptr<Intersector> intersector, + shared_ptr<FastIntersector> fast_intersector, shared_ptr<PhraseBuilder> phrase_builder, shared_ptr<RuleExtractor> rule_extractor, shared_ptr<Vocabulary> vocabulary, @@ -82,9 +89,11 @@ HieroCachingRuleFactory::HieroCachingRuleFactory( int max_rule_span, int max_nonterminals, int max_chunks, - int max_rule_symbols) : + int max_rule_symbols, + bool use_fast_intersect) : matchings_finder(finder), intersector(intersector), + fast_intersector(fast_intersector), phrase_builder(phrase_builder), rule_extractor(rule_extractor), vocabulary(vocabulary), @@ -94,15 +103,14 @@ HieroCachingRuleFactory::HieroCachingRuleFactory( max_rule_span(max_rule_span), max_nonterminals(max_nonterminals), max_chunks(max_chunks), - max_rule_symbols(max_rule_symbols) {} + max_rule_symbols(max_rule_symbols), + use_fast_intersect(use_fast_intersect) {} HieroCachingRuleFactory::HieroCachingRuleFactory() {} HieroCachingRuleFactory::~HieroCachingRuleFactory() {} Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { - intersector->binary_merge_time = 0; - intersector->linear_merge_time = 0; intersector->sort_time = 0; Clock::time_point start_time = Clock::now(); double total_extract_time = 0; @@ -155,25 +163,28 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { } else { PhraseLocation phrase_location; if (next_phrase.Arity() > 0) { - Clock::time_point intersect_start_time = Clock::now(); - phrase_location = intersector->Intersect( - node->phrase, - node->matchings, - next_suffix_link->phrase, - next_suffix_link->matchings, - next_phrase); - Clock::time_point intersect_stop_time = Clock::now(); - total_intersect_time += duration_cast<milliseconds>( - intersect_stop_time - intersect_start_time).count(); + Clock::time_point intersect_start = Clock::now(); + if (use_fast_intersect) { + phrase_location = fast_intersector->Intersect( + node->matchings, next_suffix_link->matchings, next_phrase); + } else { + phrase_location = intersector->Intersect( + node->phrase, + node->matchings, + next_suffix_link->phrase, + next_suffix_link->matchings, + next_phrase); + } + Clock::time_point intersect_stop = Clock::now(); + total_intersect_time += GetDuration(intersect_start, intersect_stop); } else { - Clock::time_point lookup_start_time = Clock::now(); + Clock::time_point lookup_start = Clock::now(); phrase_location = matchings_finder->Find( node->matchings, vocabulary->GetTerminalValue(word_id), state.phrase.size()); - Clock::time_point lookup_stop_time = Clock::now(); - total_lookup_time += duration_cast<milliseconds>( - lookup_stop_time - lookup_start_time).count(); + Clock::time_point lookup_stop = Clock::now(); + total_lookup_time += GetDuration(lookup_start, lookup_stop); } if (phrase_location.IsEmpty()) { @@ -189,16 +200,15 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { AddTrailingNonterminal(phrase, next_phrase, next_node, state.starts_with_x); - Clock::time_point extract_start_time = Clock::now(); + Clock::time_point extract_start = Clock::now(); if (!state.starts_with_x) { PhraseLocation sample = sampler->Sample(next_node->matchings); vector<Rule> new_rules = rule_extractor->ExtractRules(next_phrase, sample); rules.insert(rules.end(), new_rules.begin(), new_rules.end()); } - Clock::time_point extract_stop_time = Clock::now(); - total_extract_time += duration_cast<milliseconds>( - extract_stop_time - extract_start_time).count(); + Clock::time_point extract_stop = Clock::now(); + total_extract_time += GetDuration(extract_start, extract_stop); } else { next_node = node->GetChild(word_id); } @@ -211,15 +221,11 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { } Clock::time_point stop_time = Clock::now(); - milliseconds ms = duration_cast<milliseconds>(stop_time - start_time); cerr << "Total time for rule lookup, extraction, and scoring = " - << ms.count() / 1000.0 << endl; - cerr << "Extract time = " << total_extract_time / 1000.0 << endl; - cerr << "Intersect time = " << total_intersect_time / 1000.0 << endl; - cerr << "Sort time = " << intersector->sort_time / 1000.0 << endl; - cerr << "Linear merge time = " << intersector->linear_merge_time / 1000.0 << endl; - cerr << "Binary merge time = " << intersector->binary_merge_time / 1000.0 << endl; - // cerr << "Lookup time = " << total_lookup_time / 1000.0 << endl; + << GetDuration(start_time, stop_time) << " seconds" << endl; + cerr << "Extract time = " << total_extract_time << " seconds" << endl; + cerr << "Intersect time = " << total_intersect_time << " seconds" << endl; + cerr << "Lookup time = " << total_lookup_time << " seconds" << endl; return Grammar(rules, scorer->GetFeatureNames()); } |