summaryrefslogtreecommitdiff
path: root/extractor/rule_factory.cc
diff options
context:
space:
mode:
Diffstat (limited to 'extractor/rule_factory.cc')
-rw-r--r--extractor/rule_factory.cc84
1 files changed, 72 insertions, 12 deletions
diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc
index c22f9b48..374a0db1 100644
--- a/extractor/rule_factory.cc
+++ b/extractor/rule_factory.cc
@@ -1,6 +1,6 @@
#include "rule_factory.h"
-#include <cassert>
+#include <chrono>
#include <memory>
#include <queue>
#include <vector>
@@ -18,7 +18,9 @@
#include "vocabulary.h"
using namespace std;
-using namespace tr1;
+using namespace std::chrono;
+
+typedef high_resolution_clock Clock;
struct State {
State(int start, int end, const vector<int>& phrase,
@@ -68,8 +70,44 @@ HieroCachingRuleFactory::HieroCachingRuleFactory(
sampler = make_shared<Sampler>(source_suffix_array, max_samples);
}
+HieroCachingRuleFactory::HieroCachingRuleFactory(
+ shared_ptr<MatchingsFinder> finder,
+ shared_ptr<Intersector> intersector,
+ shared_ptr<PhraseBuilder> phrase_builder,
+ shared_ptr<RuleExtractor> rule_extractor,
+ shared_ptr<Vocabulary> vocabulary,
+ shared_ptr<Sampler> sampler,
+ shared_ptr<Scorer> scorer,
+ int min_gap_size,
+ int max_rule_span,
+ int max_nonterminals,
+ int max_chunks,
+ int max_rule_symbols) :
+ matchings_finder(finder),
+ intersector(intersector),
+ phrase_builder(phrase_builder),
+ rule_extractor(rule_extractor),
+ vocabulary(vocabulary),
+ sampler(sampler),
+ scorer(scorer),
+ min_gap_size(min_gap_size),
+ max_rule_span(max_rule_span),
+ max_nonterminals(max_nonterminals),
+ max_chunks(max_chunks),
+ max_rule_symbols(max_rule_symbols) {}
+
+HieroCachingRuleFactory::HieroCachingRuleFactory() {}
+
+HieroCachingRuleFactory::~HieroCachingRuleFactory() {}
Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
+ intersector->binary_merge_time = 0;
+ intersector->linear_merge_time = 0;
+ intersector->sort_time = 0;
+ Clock::time_point start_time = Clock::now();
+ double total_extract_time = 0;
+ double total_intersect_time = 0;
+ double total_lookup_time = 0;
// Clear cache for every new sentence.
trie.Reset();
shared_ptr<TrieNode> root = trie.GetRoot();
@@ -107,34 +145,42 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
}
if (RequiresLookup(node, word_id)) {
- shared_ptr<TrieNode> next_suffix_link =
- node->suffix_link->GetChild(word_id);
+ shared_ptr<TrieNode> next_suffix_link = node->suffix_link == NULL ?
+ trie.GetRoot() : node->suffix_link->GetChild(word_id);
if (state.starts_with_x) {
// If the phrase starts with a non terminal, we simply use the matchings
// from the suffix link.
- next_node = shared_ptr<TrieNode>(new TrieNode(
- next_suffix_link, next_phrase, next_suffix_link->matchings));
+ next_node = make_shared<TrieNode>(
+ next_suffix_link, next_phrase, next_suffix_link->matchings);
} else {
PhraseLocation phrase_location;
if (next_phrase.Arity() > 0) {
+ Clock::time_point intersect_start_time = Clock::now();
phrase_location = intersector->Intersect(
node->phrase,
node->matchings,
next_suffix_link->phrase,
next_suffix_link->matchings,
next_phrase);
+ Clock::time_point intersect_stop_time = Clock::now();
+ total_intersect_time += duration_cast<milliseconds>(
+ intersect_stop_time - intersect_start_time).count();
} else {
+ Clock::time_point lookup_start_time = Clock::now();
phrase_location = matchings_finder->Find(
node->matchings,
vocabulary->GetTerminalValue(word_id),
state.phrase.size());
+ Clock::time_point lookup_stop_time = Clock::now();
+ total_lookup_time += duration_cast<milliseconds>(
+ lookup_stop_time - lookup_start_time).count();
}
if (phrase_location.IsEmpty()) {
continue;
}
- next_node = shared_ptr<TrieNode>(new TrieNode(
- next_suffix_link, next_phrase, phrase_location));
+ next_node = make_shared<TrieNode>(
+ next_suffix_link, next_phrase, phrase_location);
}
node->AddChild(word_id, next_node);
@@ -143,12 +189,16 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
AddTrailingNonterminal(phrase, next_phrase, next_node,
state.starts_with_x);
+ Clock::time_point extract_start_time = Clock::now();
if (!state.starts_with_x) {
PhraseLocation sample = sampler->Sample(next_node->matchings);
vector<Rule> new_rules =
rule_extractor->ExtractRules(next_phrase, sample);
rules.insert(rules.end(), new_rules.begin(), new_rules.end());
}
+ Clock::time_point extract_stop_time = Clock::now();
+ total_extract_time += duration_cast<milliseconds>(
+ extract_stop_time - extract_start_time).count();
} else {
next_node = node->GetChild(word_id);
}
@@ -160,6 +210,16 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
}
}
+ Clock::time_point stop_time = Clock::now();
+ milliseconds ms = duration_cast<milliseconds>(stop_time - start_time);
+ cerr << "Total time for rule lookup, extraction, and scoring = "
+ << ms.count() / 1000.0 << endl;
+ cerr << "Extract time = " << total_extract_time / 1000.0 << endl;
+ cerr << "Intersect time = " << total_intersect_time / 1000.0 << endl;
+ cerr << "Sort time = " << intersector->sort_time / 1000.0 << endl;
+ cerr << "Linear merge time = " << intersector->linear_merge_time / 1000.0 << endl;
+ cerr << "Binary merge time = " << intersector->binary_merge_time / 1000.0 << endl;
+ // cerr << "Lookup time = " << total_lookup_time / 1000.0 << endl;
return Grammar(rules, scorer->GetFeatureNames());
}
@@ -192,12 +252,12 @@ void HieroCachingRuleFactory::AddTrailingNonterminal(
Phrase var_phrase = phrase_builder->Build(symbols);
int suffix_var_id = vocabulary->GetNonterminalIndex(
- prefix.Arity() + starts_with_x == 0);
+ prefix.Arity() + (starts_with_x == 0));
shared_ptr<TrieNode> var_suffix_link =
prefix_node->suffix_link->GetChild(suffix_var_id);
- prefix_node->AddChild(var_id, shared_ptr<TrieNode>(new TrieNode(
- var_suffix_link, var_phrase, prefix_node->matchings)));
+ prefix_node->AddChild(var_id, make_shared<TrieNode>(
+ var_suffix_link, var_phrase, prefix_node->matchings));
}
vector<State> HieroCachingRuleFactory::ExtendState(
@@ -216,7 +276,7 @@ vector<State> HieroCachingRuleFactory::ExtendState(
new_states.push_back(State(state.start, state.end + 1, symbols,
state.subpatterns_start, node, state.starts_with_x));
- int num_subpatterns = phrase.Arity() + state.starts_with_x == 0;
+ int num_subpatterns = phrase.Arity() + (state.starts_with_x == 0);
if (symbols.size() + 1 >= max_rule_symbols ||
phrase.Arity() >= max_nonterminals ||
num_subpatterns >= max_chunks) {