diff options
Diffstat (limited to 'extractor/rule_factory.cc')
| -rw-r--r-- | extractor/rule_factory.cc | 84 | 
1 files changed, 72 insertions, 12 deletions
diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc index c22f9b48..374a0db1 100644 --- a/extractor/rule_factory.cc +++ b/extractor/rule_factory.cc @@ -1,6 +1,6 @@  #include "rule_factory.h" -#include <cassert> +#include <chrono>  #include <memory>  #include <queue>  #include <vector> @@ -18,7 +18,9 @@  #include "vocabulary.h"  using namespace std; -using namespace tr1; +using namespace std::chrono; + +typedef high_resolution_clock Clock;  struct State {    State(int start, int end, const vector<int>& phrase, @@ -68,8 +70,44 @@ HieroCachingRuleFactory::HieroCachingRuleFactory(    sampler = make_shared<Sampler>(source_suffix_array, max_samples);  } +HieroCachingRuleFactory::HieroCachingRuleFactory( +    shared_ptr<MatchingsFinder> finder, +    shared_ptr<Intersector> intersector, +    shared_ptr<PhraseBuilder> phrase_builder, +    shared_ptr<RuleExtractor> rule_extractor, +    shared_ptr<Vocabulary> vocabulary, +    shared_ptr<Sampler> sampler, +    shared_ptr<Scorer> scorer, +    int min_gap_size, +    int max_rule_span, +    int max_nonterminals, +    int max_chunks, +    int max_rule_symbols) : +    matchings_finder(finder), +    intersector(intersector), +    phrase_builder(phrase_builder), +    rule_extractor(rule_extractor), +    vocabulary(vocabulary), +    sampler(sampler), +    scorer(scorer), +    min_gap_size(min_gap_size), +    max_rule_span(max_rule_span), +    max_nonterminals(max_nonterminals), +    max_chunks(max_chunks), +    max_rule_symbols(max_rule_symbols) {} + +HieroCachingRuleFactory::HieroCachingRuleFactory() {} + +HieroCachingRuleFactory::~HieroCachingRuleFactory() {}  Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { +  intersector->binary_merge_time = 0; +  intersector->linear_merge_time = 0; +  intersector->sort_time = 0; +  Clock::time_point start_time = Clock::now(); +  double total_extract_time = 0; +  double total_intersect_time = 0; +  double total_lookup_time = 0;    // Clear cache for every new sentence.    trie.Reset();    shared_ptr<TrieNode> root = trie.GetRoot(); @@ -107,34 +145,42 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {      }      if (RequiresLookup(node, word_id)) { -      shared_ptr<TrieNode> next_suffix_link = -          node->suffix_link->GetChild(word_id); +      shared_ptr<TrieNode> next_suffix_link = node->suffix_link == NULL ? +          trie.GetRoot() : node->suffix_link->GetChild(word_id);        if (state.starts_with_x) {          // If the phrase starts with a non terminal, we simply use the matchings          // from the suffix link. -        next_node = shared_ptr<TrieNode>(new TrieNode( -            next_suffix_link, next_phrase, next_suffix_link->matchings)); +        next_node = make_shared<TrieNode>( +            next_suffix_link, next_phrase, next_suffix_link->matchings);        } else {          PhraseLocation phrase_location;          if (next_phrase.Arity() > 0) { +          Clock::time_point intersect_start_time = Clock::now();            phrase_location = intersector->Intersect(                node->phrase,                node->matchings,                next_suffix_link->phrase,                next_suffix_link->matchings,                next_phrase); +          Clock::time_point intersect_stop_time = Clock::now(); +          total_intersect_time += duration_cast<milliseconds>( +              intersect_stop_time - intersect_start_time).count();          } else { +          Clock::time_point lookup_start_time = Clock::now();            phrase_location = matchings_finder->Find(                node->matchings,                vocabulary->GetTerminalValue(word_id),                state.phrase.size()); +          Clock::time_point lookup_stop_time = Clock::now(); +          total_lookup_time += duration_cast<milliseconds>( +              lookup_stop_time - lookup_start_time).count();          }          if (phrase_location.IsEmpty()) {            continue;          } -        next_node = shared_ptr<TrieNode>(new TrieNode( -            next_suffix_link, next_phrase, phrase_location)); +        next_node = make_shared<TrieNode>( +            next_suffix_link, next_phrase, phrase_location);        }        node->AddChild(word_id, next_node); @@ -143,12 +189,16 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {        AddTrailingNonterminal(phrase, next_phrase, next_node,                               state.starts_with_x); +      Clock::time_point extract_start_time = Clock::now();        if (!state.starts_with_x) {          PhraseLocation sample = sampler->Sample(next_node->matchings);          vector<Rule> new_rules =              rule_extractor->ExtractRules(next_phrase, sample);          rules.insert(rules.end(), new_rules.begin(), new_rules.end());        } +      Clock::time_point extract_stop_time = Clock::now(); +      total_extract_time += duration_cast<milliseconds>( +          extract_stop_time - extract_start_time).count();      } else {        next_node = node->GetChild(word_id);      } @@ -160,6 +210,16 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {      }    } +  Clock::time_point stop_time = Clock::now(); +  milliseconds ms = duration_cast<milliseconds>(stop_time - start_time); +  cerr << "Total time for rule lookup, extraction, and scoring = " +       << ms.count() / 1000.0 << endl; +  cerr << "Extract time = " << total_extract_time / 1000.0 << endl; +  cerr << "Intersect time = " << total_intersect_time / 1000.0 << endl; +  cerr << "Sort time = " << intersector->sort_time / 1000.0 << endl; +  cerr << "Linear merge time = " << intersector->linear_merge_time / 1000.0 << endl; +  cerr << "Binary merge time = " << intersector->binary_merge_time / 1000.0 << endl; +  // cerr << "Lookup time = " << total_lookup_time / 1000.0 << endl;    return Grammar(rules, scorer->GetFeatureNames());  } @@ -192,12 +252,12 @@ void HieroCachingRuleFactory::AddTrailingNonterminal(    Phrase var_phrase = phrase_builder->Build(symbols);    int suffix_var_id = vocabulary->GetNonterminalIndex( -      prefix.Arity() + starts_with_x == 0); +      prefix.Arity() + (starts_with_x == 0));    shared_ptr<TrieNode> var_suffix_link =        prefix_node->suffix_link->GetChild(suffix_var_id); -  prefix_node->AddChild(var_id, shared_ptr<TrieNode>(new TrieNode( -      var_suffix_link, var_phrase, prefix_node->matchings))); +  prefix_node->AddChild(var_id, make_shared<TrieNode>( +      var_suffix_link, var_phrase, prefix_node->matchings));  }  vector<State> HieroCachingRuleFactory::ExtendState( @@ -216,7 +276,7 @@ vector<State> HieroCachingRuleFactory::ExtendState(    new_states.push_back(State(state.start, state.end + 1, symbols,        state.subpatterns_start, node, state.starts_with_x)); -  int num_subpatterns = phrase.Arity() + state.starts_with_x == 0; +  int num_subpatterns = phrase.Arity() + (state.starts_with_x == 0);    if (symbols.size() + 1 >= max_rule_symbols ||        phrase.Arity() >= max_nonterminals ||        num_subpatterns >= max_chunks) {  | 
