diff options
Diffstat (limited to 'extractor/rule_factory.cc')
-rw-r--r-- | extractor/rule_factory.cc | 84 |
1 files changed, 72 insertions, 12 deletions
diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc index c22f9b48..374a0db1 100644 --- a/extractor/rule_factory.cc +++ b/extractor/rule_factory.cc @@ -1,6 +1,6 @@ #include "rule_factory.h" -#include <cassert> +#include <chrono> #include <memory> #include <queue> #include <vector> @@ -18,7 +18,9 @@ #include "vocabulary.h" using namespace std; -using namespace tr1; +using namespace std::chrono; + +typedef high_resolution_clock Clock; struct State { State(int start, int end, const vector<int>& phrase, @@ -68,8 +70,44 @@ HieroCachingRuleFactory::HieroCachingRuleFactory( sampler = make_shared<Sampler>(source_suffix_array, max_samples); } +HieroCachingRuleFactory::HieroCachingRuleFactory( + shared_ptr<MatchingsFinder> finder, + shared_ptr<Intersector> intersector, + shared_ptr<PhraseBuilder> phrase_builder, + shared_ptr<RuleExtractor> rule_extractor, + shared_ptr<Vocabulary> vocabulary, + shared_ptr<Sampler> sampler, + shared_ptr<Scorer> scorer, + int min_gap_size, + int max_rule_span, + int max_nonterminals, + int max_chunks, + int max_rule_symbols) : + matchings_finder(finder), + intersector(intersector), + phrase_builder(phrase_builder), + rule_extractor(rule_extractor), + vocabulary(vocabulary), + sampler(sampler), + scorer(scorer), + min_gap_size(min_gap_size), + max_rule_span(max_rule_span), + max_nonterminals(max_nonterminals), + max_chunks(max_chunks), + max_rule_symbols(max_rule_symbols) {} + +HieroCachingRuleFactory::HieroCachingRuleFactory() {} + +HieroCachingRuleFactory::~HieroCachingRuleFactory() {} Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { + intersector->binary_merge_time = 0; + intersector->linear_merge_time = 0; + intersector->sort_time = 0; + Clock::time_point start_time = Clock::now(); + double total_extract_time = 0; + double total_intersect_time = 0; + double total_lookup_time = 0; // Clear cache for every new sentence. trie.Reset(); shared_ptr<TrieNode> root = trie.GetRoot(); @@ -107,34 +145,42 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { } if (RequiresLookup(node, word_id)) { - shared_ptr<TrieNode> next_suffix_link = - node->suffix_link->GetChild(word_id); + shared_ptr<TrieNode> next_suffix_link = node->suffix_link == NULL ? + trie.GetRoot() : node->suffix_link->GetChild(word_id); if (state.starts_with_x) { // If the phrase starts with a non terminal, we simply use the matchings // from the suffix link. - next_node = shared_ptr<TrieNode>(new TrieNode( - next_suffix_link, next_phrase, next_suffix_link->matchings)); + next_node = make_shared<TrieNode>( + next_suffix_link, next_phrase, next_suffix_link->matchings); } else { PhraseLocation phrase_location; if (next_phrase.Arity() > 0) { + Clock::time_point intersect_start_time = Clock::now(); phrase_location = intersector->Intersect( node->phrase, node->matchings, next_suffix_link->phrase, next_suffix_link->matchings, next_phrase); + Clock::time_point intersect_stop_time = Clock::now(); + total_intersect_time += duration_cast<milliseconds>( + intersect_stop_time - intersect_start_time).count(); } else { + Clock::time_point lookup_start_time = Clock::now(); phrase_location = matchings_finder->Find( node->matchings, vocabulary->GetTerminalValue(word_id), state.phrase.size()); + Clock::time_point lookup_stop_time = Clock::now(); + total_lookup_time += duration_cast<milliseconds>( + lookup_stop_time - lookup_start_time).count(); } if (phrase_location.IsEmpty()) { continue; } - next_node = shared_ptr<TrieNode>(new TrieNode( - next_suffix_link, next_phrase, phrase_location)); + next_node = make_shared<TrieNode>( + next_suffix_link, next_phrase, phrase_location); } node->AddChild(word_id, next_node); @@ -143,12 +189,16 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { AddTrailingNonterminal(phrase, next_phrase, next_node, state.starts_with_x); + Clock::time_point extract_start_time = Clock::now(); if (!state.starts_with_x) { PhraseLocation sample = sampler->Sample(next_node->matchings); vector<Rule> new_rules = rule_extractor->ExtractRules(next_phrase, sample); rules.insert(rules.end(), new_rules.begin(), new_rules.end()); } + Clock::time_point extract_stop_time = Clock::now(); + total_extract_time += duration_cast<milliseconds>( + extract_stop_time - extract_start_time).count(); } else { next_node = node->GetChild(word_id); } @@ -160,6 +210,16 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { } } + Clock::time_point stop_time = Clock::now(); + milliseconds ms = duration_cast<milliseconds>(stop_time - start_time); + cerr << "Total time for rule lookup, extraction, and scoring = " + << ms.count() / 1000.0 << endl; + cerr << "Extract time = " << total_extract_time / 1000.0 << endl; + cerr << "Intersect time = " << total_intersect_time / 1000.0 << endl; + cerr << "Sort time = " << intersector->sort_time / 1000.0 << endl; + cerr << "Linear merge time = " << intersector->linear_merge_time / 1000.0 << endl; + cerr << "Binary merge time = " << intersector->binary_merge_time / 1000.0 << endl; + // cerr << "Lookup time = " << total_lookup_time / 1000.0 << endl; return Grammar(rules, scorer->GetFeatureNames()); } @@ -192,12 +252,12 @@ void HieroCachingRuleFactory::AddTrailingNonterminal( Phrase var_phrase = phrase_builder->Build(symbols); int suffix_var_id = vocabulary->GetNonterminalIndex( - prefix.Arity() + starts_with_x == 0); + prefix.Arity() + (starts_with_x == 0)); shared_ptr<TrieNode> var_suffix_link = prefix_node->suffix_link->GetChild(suffix_var_id); - prefix_node->AddChild(var_id, shared_ptr<TrieNode>(new TrieNode( - var_suffix_link, var_phrase, prefix_node->matchings))); + prefix_node->AddChild(var_id, make_shared<TrieNode>( + var_suffix_link, var_phrase, prefix_node->matchings)); } vector<State> HieroCachingRuleFactory::ExtendState( @@ -216,7 +276,7 @@ vector<State> HieroCachingRuleFactory::ExtendState( new_states.push_back(State(state.start, state.end + 1, symbols, state.subpatterns_start, node, state.starts_with_x)); - int num_subpatterns = phrase.Arity() + state.starts_with_x == 0; + int num_subpatterns = phrase.Arity() + (state.starts_with_x == 0); if (symbols.size() + 1 >= max_rule_symbols || phrase.Arity() >= max_nonterminals || num_subpatterns >= max_chunks) { |