Diffstat (limited to 'extractor/rule_factory.cc')
-rw-r--r--  extractor/rule_factory.cc | 70
1 file changed, 38 insertions(+), 32 deletions(-)
diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc
index 374a0db1..4101fcfa 100644
--- a/extractor/rule_factory.cc
+++ b/extractor/rule_factory.cc
@@ -6,6 +6,7 @@
#include <vector>
#include "grammar.h"
+#include "fast_intersector.h"
#include "intersector.h"
#include "matchings_finder.h"
#include "matching_comparator.h"
@@ -15,10 +16,11 @@
#include "sampler.h"
#include "scorer.h"
#include "suffix_array.h"
+#include "time_util.h"
#include "vocabulary.h"
using namespace std;
-using namespace std::chrono;
+using namespace chrono;
typedef high_resolution_clock Clock;
@@ -48,6 +50,7 @@ HieroCachingRuleFactory::HieroCachingRuleFactory(
int max_nonterminals,
int max_rule_symbols,
int max_samples,
+ bool use_fast_intersect,
bool use_baeza_yates,
bool require_tight_phrases) :
vocabulary(vocabulary),
@@ -56,12 +59,15 @@ HieroCachingRuleFactory::HieroCachingRuleFactory(
max_rule_span(max_rule_span),
max_nonterminals(max_nonterminals),
max_chunks(max_nonterminals + 1),
- max_rule_symbols(max_rule_symbols) {
+ max_rule_symbols(max_rule_symbols),
+ use_fast_intersect(use_fast_intersect) {
matchings_finder = make_shared<MatchingsFinder>(source_suffix_array);
shared_ptr<MatchingComparator> comparator =
make_shared<MatchingComparator>(min_gap_size, max_rule_span);
intersector = make_shared<Intersector>(vocabulary, precomputation,
source_suffix_array, comparator, use_baeza_yates);
+ fast_intersector = make_shared<FastIntersector>(source_suffix_array,
+ precomputation, vocabulary, max_rule_span, min_gap_size);
phrase_builder = make_shared<PhraseBuilder>(vocabulary);
rule_extractor = make_shared<RuleExtractor>(source_suffix_array->GetData(),
target_data_array, alignment, phrase_builder, scorer, vocabulary,
@@ -73,6 +79,7 @@ HieroCachingRuleFactory::HieroCachingRuleFactory(
HieroCachingRuleFactory::HieroCachingRuleFactory(
shared_ptr<MatchingsFinder> finder,
shared_ptr<Intersector> intersector,
+ shared_ptr<FastIntersector> fast_intersector,
shared_ptr<PhraseBuilder> phrase_builder,
shared_ptr<RuleExtractor> rule_extractor,
shared_ptr<Vocabulary> vocabulary,
@@ -82,9 +89,11 @@ HieroCachingRuleFactory::HieroCachingRuleFactory(
int max_rule_span,
int max_nonterminals,
int max_chunks,
- int max_rule_symbols) :
+ int max_rule_symbols,
+ bool use_fast_intersect) :
matchings_finder(finder),
intersector(intersector),
+ fast_intersector(fast_intersector),
phrase_builder(phrase_builder),
rule_extractor(rule_extractor),
vocabulary(vocabulary),
@@ -94,15 +103,14 @@ HieroCachingRuleFactory::HieroCachingRuleFactory(
max_rule_span(max_rule_span),
max_nonterminals(max_nonterminals),
max_chunks(max_chunks),
- max_rule_symbols(max_rule_symbols) {}
+ max_rule_symbols(max_rule_symbols),
+ use_fast_intersect(use_fast_intersect) {}
HieroCachingRuleFactory::HieroCachingRuleFactory() {}
HieroCachingRuleFactory::~HieroCachingRuleFactory() {}
Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
- intersector->binary_merge_time = 0;
- intersector->linear_merge_time = 0;
intersector->sort_time = 0;
Clock::time_point start_time = Clock::now();
double total_extract_time = 0;
@@ -155,25 +163,28 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
} else {
PhraseLocation phrase_location;
if (next_phrase.Arity() > 0) {
- Clock::time_point intersect_start_time = Clock::now();
- phrase_location = intersector->Intersect(
- node->phrase,
- node->matchings,
- next_suffix_link->phrase,
- next_suffix_link->matchings,
- next_phrase);
- Clock::time_point intersect_stop_time = Clock::now();
- total_intersect_time += duration_cast<milliseconds>(
- intersect_stop_time - intersect_start_time).count();
+ Clock::time_point intersect_start = Clock::now();
+ if (use_fast_intersect) {
+ phrase_location = fast_intersector->Intersect(
+ node->matchings, next_suffix_link->matchings, next_phrase);
+ } else {
+ phrase_location = intersector->Intersect(
+ node->phrase,
+ node->matchings,
+ next_suffix_link->phrase,
+ next_suffix_link->matchings,
+ next_phrase);
+ }
+ Clock::time_point intersect_stop = Clock::now();
+ total_intersect_time += GetDuration(intersect_start, intersect_stop);
} else {
- Clock::time_point lookup_start_time = Clock::now();
+ Clock::time_point lookup_start = Clock::now();
phrase_location = matchings_finder->Find(
node->matchings,
vocabulary->GetTerminalValue(word_id),
state.phrase.size());
- Clock::time_point lookup_stop_time = Clock::now();
- total_lookup_time += duration_cast<milliseconds>(
- lookup_stop_time - lookup_start_time).count();
+ Clock::time_point lookup_stop = Clock::now();
+ total_lookup_time += GetDuration(lookup_start, lookup_stop);
}
if (phrase_location.IsEmpty()) {
@@ -189,16 +200,15 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
AddTrailingNonterminal(phrase, next_phrase, next_node,
state.starts_with_x);
- Clock::time_point extract_start_time = Clock::now();
+ Clock::time_point extract_start = Clock::now();
if (!state.starts_with_x) {
PhraseLocation sample = sampler->Sample(next_node->matchings);
vector<Rule> new_rules =
rule_extractor->ExtractRules(next_phrase, sample);
rules.insert(rules.end(), new_rules.begin(), new_rules.end());
}
- Clock::time_point extract_stop_time = Clock::now();
- total_extract_time += duration_cast<milliseconds>(
- extract_stop_time - extract_start_time).count();
+ Clock::time_point extract_stop = Clock::now();
+ total_extract_time += GetDuration(extract_start, extract_stop);
} else {
next_node = node->GetChild(word_id);
}
@@ -211,15 +221,11 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
}
Clock::time_point stop_time = Clock::now();
- milliseconds ms = duration_cast<milliseconds>(stop_time - start_time);
cerr << "Total time for rule lookup, extraction, and scoring = "
- << ms.count() / 1000.0 << endl;
- cerr << "Extract time = " << total_extract_time / 1000.0 << endl;
- cerr << "Intersect time = " << total_intersect_time / 1000.0 << endl;
- cerr << "Sort time = " << intersector->sort_time / 1000.0 << endl;
- cerr << "Linear merge time = " << intersector->linear_merge_time / 1000.0 << endl;
- cerr << "Binary merge time = " << intersector->binary_merge_time / 1000.0 << endl;
- // cerr << "Lookup time = " << total_lookup_time / 1000.0 << endl;
+ << GetDuration(start_time, stop_time) << " seconds" << endl;
+ cerr << "Extract time = " << total_extract_time << " seconds" << endl;
+ cerr << "Intersect time = " << total_intersect_time << " seconds" << endl;
+ cerr << "Lookup time = " << total_lookup_time << " seconds" << endl;
return Grammar(rules, scorer->GetFeatureNames());
}
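Note: the GetDuration() calls introduced above replace the inline duration_cast<milliseconds>(...).count() / 1000.0 arithmetic and report times in seconds, using the helper pulled in via the new "time_util.h" include. That header is not part of this diff; the following is only a minimal sketch of such a helper, assuming it simply converts the interval between two Clock::time_points to fractional seconds (the actual header in the repository may differ).

#include <chrono>

typedef std::chrono::high_resolution_clock Clock;

// Hypothetical sketch of a GetDuration helper; the real time_util.h may differ.
// Converts the interval between two time points to fractional seconds.
inline double GetDuration(const Clock::time_point& start,
                          const Clock::time_point& stop) {
  return std::chrono::duration_cast<std::chrono::duration<double>>(
      stop - start).count();
}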