summaryrefslogtreecommitdiff
path: root/extractor
diff options
context:
space:
mode:
Diffstat (limited to 'extractor')
-rw-r--r--extractor/Makefile.am7
-rw-r--r--extractor/data_array.cc3
-rw-r--r--extractor/fast_intersector.cc191
-rw-r--r--extractor/fast_intersector.h65
-rw-r--r--extractor/fast_intersector_test.cc146
-rw-r--r--extractor/grammar_extractor.cc4
-rw-r--r--extractor/grammar_extractor.h1
-rw-r--r--extractor/intersector.cc18
-rw-r--r--extractor/linear_merger.cc10
-rw-r--r--extractor/linear_merger.h4
-rw-r--r--extractor/mocks/mock_fast_intersector.h11
-rw-r--r--extractor/phrase_location.cc12
-rw-r--r--extractor/phrase_location.h4
-rw-r--r--extractor/rule_factory.cc70
-rw-r--r--extractor/rule_factory.h8
-rw-r--r--extractor/rule_factory_test.cc54
-rw-r--r--extractor/run_extractor.cc44
-rw-r--r--extractor/suffix_array.cc15
-rw-r--r--extractor/time_util.cc6
-rw-r--r--extractor/time_util.h14
20 files changed, 610 insertions, 77 deletions
diff --git a/extractor/Makefile.am b/extractor/Makefile.am
index c82fc1ae..8f76dea5 100644
--- a/extractor/Makefile.am
+++ b/extractor/Makefile.am
@@ -4,6 +4,7 @@ noinst_PROGRAMS = \
alignment_test \
binary_search_merger_test \
data_array_test \
+ fast_intersector_test \
feature_count_source_target_test \
feature_is_source_singleton_test \
feature_is_source_target_singleton_test \
@@ -32,6 +33,7 @@ noinst_PROGRAMS = \
TESTS = alignment_test \
binary_search_merger_test \
data_array_test \
+ fast_intersector_test \
feature_count_source_target_test \
feature_is_source_singleton_test \
feature_is_source_target_singleton_test \
@@ -63,6 +65,8 @@ binary_search_merger_test_SOURCES = binary_search_merger_test.cc
binary_search_merger_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
data_array_test_SOURCES = data_array_test.cc
data_array_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a
+fast_intersector_test_SOURCES = fast_intersector_test.cc
+fast_intersector_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
feature_count_source_target_test_SOURCES = features/count_source_target_test.cc
feature_count_source_target_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a
feature_is_source_singleton_test_SOURCES = features/is_source_singleton_test.cc
@@ -125,12 +129,14 @@ libcompile_a_SOURCES = \
phrase_location.cc \
precomputation.cc \
suffix_array.cc \
+ time_util.cc \
translation_table.cc
libextractor_a_SOURCES = \
alignment.cc \
binary_search_merger.cc \
data_array.cc \
+ fast_intersector.cc \
features/count_source_target.cc \
features/feature.cc \
features/is_source_singleton.cc \
@@ -159,6 +165,7 @@ libextractor_a_SOURCES = \
scorer.cc \
suffix_array.cc \
target_phrase_extractor.cc \
+ time_util.cc \
translation_table.cc \
veb.cc \
veb_bitset.cc \
diff --git a/extractor/data_array.cc b/extractor/data_array.cc
index 1097caf3..cd430c69 100644
--- a/extractor/data_array.cc
+++ b/extractor/data_array.cc
@@ -147,7 +147,8 @@ bool DataArray::HasWord(const string& word) const {
}
int DataArray::GetWordId(const string& word) const {
- return word2id.find(word)->second;
+ auto result = word2id.find(word);
+ return result == word2id.end() ? -1 : result->second;
}
string DataArray::GetWord(int word_id) const {
diff --git a/extractor/fast_intersector.cc b/extractor/fast_intersector.cc
new file mode 100644
index 00000000..8c7a7af8
--- /dev/null
+++ b/extractor/fast_intersector.cc
@@ -0,0 +1,191 @@
+#include "fast_intersector.h"
+
+#include <cassert>
+
+#include "data_array.h"
+#include "phrase.h"
+#include "phrase_location.h"
+#include "precomputation.h"
+#include "suffix_array.h"
+#include "vocabulary.h"
+
+FastIntersector::FastIntersector(shared_ptr<SuffixArray> suffix_array,
+ shared_ptr<Precomputation> precomputation,
+ shared_ptr<Vocabulary> vocabulary,
+ int max_rule_span,
+ int min_gap_size) :
+ suffix_array(suffix_array),
+ vocabulary(vocabulary),
+ max_rule_span(max_rule_span),
+ min_gap_size(min_gap_size) {
+ Index precomputed_collocations = precomputation->GetCollocations();
+ for (pair<vector<int>, vector<int> > entry: precomputed_collocations) {
+ vector<int> phrase = ConvertPhrase(entry.first);
+ collocations[phrase] = entry.second;
+ }
+}
+
+FastIntersector::FastIntersector() {}
+
+FastIntersector::~FastIntersector() {}
+
+vector<int> FastIntersector::ConvertPhrase(const vector<int>& old_phrase) {
+ vector<int> new_phrase;
+ new_phrase.reserve(old_phrase.size());
+ shared_ptr<DataArray> data_array = suffix_array->GetData();
+ int num_nonterminals = 0;
+ for (int word_id: old_phrase) {
+ // TODO(pauldb): Remove overhead for relabelling the nonterminals here.
+ if (word_id == Precomputation::NON_TERMINAL) {
+ ++num_nonterminals;
+ new_phrase.push_back(vocabulary->GetNonterminalIndex(num_nonterminals));
+ } else {
+ new_phrase.push_back(
+ vocabulary->GetTerminalIndex(data_array->GetWord(word_id)));
+ }
+ }
+ return new_phrase;
+}
+
+PhraseLocation FastIntersector::Intersect(
+ PhraseLocation& prefix_location,
+ PhraseLocation& suffix_location,
+ const Phrase& phrase) {
+ vector<int> symbols = phrase.Get();
+
+ // We should never attempt to do an intersect query for a pattern starting or
+ // ending with a non terminal. The RuleFactory should handle these cases,
+ // initializing the matchings list with the one for the pattern without the
+ // starting or ending terminal.
+ assert(vocabulary->IsTerminal(symbols.front())
+ && vocabulary->IsTerminal(symbols.back()));
+
+ if (collocations.count(symbols)) {
+ return PhraseLocation(collocations[symbols], phrase.Arity() + 1);
+ }
+
+ bool prefix_ends_with_x =
+ !vocabulary->IsTerminal(symbols[symbols.size() - 2]);
+ bool suffix_starts_with_x = !vocabulary->IsTerminal(symbols[1]);
+ if (EstimateNumOperations(prefix_location, prefix_ends_with_x) <=
+ EstimateNumOperations(suffix_location, suffix_starts_with_x)) {
+ return ExtendPrefixPhraseLocation(prefix_location, phrase,
+ prefix_ends_with_x, symbols.back());
+ } else {
+ return ExtendSuffixPhraseLocation(suffix_location, phrase,
+ suffix_starts_with_x, symbols.front());
+ }
+}
+
+int FastIntersector::EstimateNumOperations(
+ const PhraseLocation& phrase_location, bool has_margin_x) const {
+ int num_locations = phrase_location.GetSize();
+ return has_margin_x ? num_locations * max_rule_span : num_locations;
+}
+
+PhraseLocation FastIntersector::ExtendPrefixPhraseLocation(
+ PhraseLocation& prefix_location, const Phrase& phrase,
+ bool prefix_ends_with_x, int next_symbol) const {
+ ExtendPhraseLocation(prefix_location);
+ vector<int> positions = *prefix_location.matchings;
+ int num_subpatterns = prefix_location.num_subpatterns;
+
+ vector<int> new_positions;
+ shared_ptr<DataArray> data_array = suffix_array->GetData();
+ int data_array_symbol = data_array->GetWordId(
+ vocabulary->GetTerminalValue(next_symbol));
+ if (data_array_symbol == -1) {
+ return PhraseLocation(new_positions, num_subpatterns);
+ }
+
+ pair<int, int> range = GetSearchRange(prefix_ends_with_x);
+ for (size_t i = 0; i < positions.size(); i += num_subpatterns) {
+ int sent_id = data_array->GetSentenceId(positions[i]);
+ int sent_end = data_array->GetSentenceStart(sent_id + 1) - 1;
+ int pattern_end = positions[i + num_subpatterns - 1] + range.first;
+ if (prefix_ends_with_x) {
+ pattern_end += phrase.GetChunkLen(phrase.Arity() - 1) - 1;
+ } else {
+ pattern_end += phrase.GetChunkLen(phrase.Arity()) - 2;
+ }
+ for (int j = range.first; j < range.second; ++j) {
+ if (pattern_end >= sent_end ||
+ pattern_end - positions[i] >= max_rule_span) {
+ break;
+ }
+
+ if (data_array->AtIndex(pattern_end) == data_array_symbol) {
+ new_positions.insert(new_positions.end(), positions.begin() + i,
+ positions.begin() + i + num_subpatterns);
+ if (prefix_ends_with_x) {
+ new_positions.push_back(pattern_end);
+ }
+ }
+ ++pattern_end;
+ }
+ }
+
+ return PhraseLocation(new_positions, phrase.Arity() + 1);
+}
+
+PhraseLocation FastIntersector::ExtendSuffixPhraseLocation(
+ PhraseLocation& suffix_location, const Phrase& phrase,
+ bool suffix_starts_with_x, int prev_symbol) const {
+ ExtendPhraseLocation(suffix_location);
+ vector<int> positions = *suffix_location.matchings;
+ int num_subpatterns = suffix_location.num_subpatterns;
+
+ vector<int> new_positions;
+ shared_ptr<DataArray> data_array = suffix_array->GetData();
+ int data_array_symbol = data_array->GetWordId(
+ vocabulary->GetTerminalValue(prev_symbol));
+ if (data_array_symbol == -1) {
+ return PhraseLocation(new_positions, num_subpatterns);
+ }
+
+ pair<int, int> range = GetSearchRange(suffix_starts_with_x);
+ for (size_t i = 0; i < positions.size(); i += num_subpatterns) {
+ int sent_id = data_array->GetSentenceId(positions[i]);
+ int sent_start = data_array->GetSentenceStart(sent_id);
+ int pattern_start = positions[i] - range.first;
+ int pattern_end = positions[i + num_subpatterns - 1] +
+ phrase.GetChunkLen(phrase.Arity()) - 1;
+ for (int j = range.first; j < range.second; ++j) {
+ if (pattern_start < sent_start ||
+ pattern_end - pattern_start >= max_rule_span) {
+ break;
+ }
+
+ if (data_array->AtIndex(pattern_start) == data_array_symbol) {
+ new_positions.push_back(pattern_start);
+ new_positions.insert(new_positions.end(),
+ positions.begin() + i + !suffix_starts_with_x,
+ positions.begin() + i + num_subpatterns);
+ }
+ --pattern_start;
+ }
+ }
+
+ return PhraseLocation(new_positions, phrase.Arity() + 1);
+}
+
+void FastIntersector::ExtendPhraseLocation(PhraseLocation& location) const {
+ if (location.matchings != NULL) {
+ return;
+ }
+
+ location.num_subpatterns = 1;
+ location.matchings = make_shared<vector<int> >();
+ for (int i = location.sa_low; i < location.sa_high; ++i) {
+ location.matchings->push_back(suffix_array->GetSuffix(i));
+ }
+ location.sa_low = location.sa_high = 0;
+}
+
+pair<int, int> FastIntersector::GetSearchRange(bool has_marginal_x) const {
+ if (has_marginal_x) {
+ return make_pair(min_gap_size + 1, max_rule_span);
+ } else {
+ return make_pair(1, 2);
+ }
+}
diff --git a/extractor/fast_intersector.h b/extractor/fast_intersector.h
new file mode 100644
index 00000000..785e428e
--- /dev/null
+++ b/extractor/fast_intersector.h
@@ -0,0 +1,65 @@
+#ifndef _FAST_INTERSECTOR_H_
+#define _FAST_INTERSECTOR_H_
+
+#include <memory>
+#include <unordered_map>
+#include <vector>
+
+#include <boost/functional/hash.hpp>
+
+using namespace std;
+
+typedef boost::hash<vector<int> > VectorHash;
+typedef unordered_map<vector<int>, vector<int>, VectorHash> Index;
+
+class Phrase;
+class PhraseLocation;
+class Precomputation;
+class SuffixArray;
+class Vocabulary;
+
+class FastIntersector {
+ public:
+ FastIntersector(shared_ptr<SuffixArray> suffix_array,
+ shared_ptr<Precomputation> precomputation,
+ shared_ptr<Vocabulary> vocabulary,
+ int max_rule_span,
+ int min_gap_size);
+
+ virtual ~FastIntersector();
+
+ virtual PhraseLocation Intersect(PhraseLocation& prefix_location,
+ PhraseLocation& suffix_location,
+ const Phrase& phrase);
+
+ protected:
+ FastIntersector();
+
+ private:
+ vector<int> ConvertPhrase(const vector<int>& old_phrase);
+
+ int EstimateNumOperations(const PhraseLocation& phrase_location,
+ bool has_margin_x) const;
+
+ PhraseLocation ExtendPrefixPhraseLocation(PhraseLocation& prefix_location,
+ const Phrase& phrase,
+ bool prefix_ends_with_x,
+ int next_symbol) const;
+
+ PhraseLocation ExtendSuffixPhraseLocation(PhraseLocation& suffix_location,
+ const Phrase& phrase,
+ bool suffix_starts_with_x,
+ int prev_symbol) const;
+
+ void ExtendPhraseLocation(PhraseLocation& location) const;
+
+ pair<int, int> GetSearchRange(bool has_marginal_x) const;
+
+ shared_ptr<SuffixArray> suffix_array;
+ shared_ptr<Vocabulary> vocabulary;
+ int max_rule_span;
+ int min_gap_size;
+ Index collocations;
+};
+
+#endif
diff --git a/extractor/fast_intersector_test.cc b/extractor/fast_intersector_test.cc
new file mode 100644
index 00000000..0d6ef367
--- /dev/null
+++ b/extractor/fast_intersector_test.cc
@@ -0,0 +1,146 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "fast_intersector.h"
+#include "mocks/mock_data_array.h"
+#include "mocks/mock_suffix_array.h"
+#include "mocks/mock_precomputation.h"
+#include "mocks/mock_vocabulary.h"
+#include "phrase.h"
+#include "phrase_location.h"
+#include "phrase_builder.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace {
+
+class FastIntersectorTest : public Test {
+ protected:
+ virtual void SetUp() {
+ vector<string> words = {"EOL", "it", "makes", "him", "and", "mars", ",",
+ "sets", "on", "takes", "off", "."};
+ vocabulary = make_shared<MockVocabulary>();
+ for (size_t i = 0; i < words.size(); ++i) {
+ EXPECT_CALL(*vocabulary, GetTerminalIndex(words[i]))
+ .WillRepeatedly(Return(i));
+ EXPECT_CALL(*vocabulary, GetTerminalValue(i))
+ .WillRepeatedly(Return(words[i]));
+ }
+
+ vector<int> data = {1, 2, 3, 4, 1, 5, 3, 6, 1,
+ 7, 3, 8, 4, 1, 9, 3, 10, 11, 0};
+ data_array = make_shared<MockDataArray>();
+ for (size_t i = 0; i < data.size(); ++i) {
+ EXPECT_CALL(*data_array, AtIndex(i)).WillRepeatedly(Return(data[i]));
+ EXPECT_CALL(*data_array, GetSentenceId(i))
+ .WillRepeatedly(Return(0));
+ }
+ EXPECT_CALL(*data_array, GetSentenceStart(0))
+ .WillRepeatedly(Return(0));
+ EXPECT_CALL(*data_array, GetSentenceStart(1))
+ .WillRepeatedly(Return(19));
+ for (size_t i = 0; i < words.size(); ++i) {
+ EXPECT_CALL(*data_array, GetWordId(words[i]))
+ .WillRepeatedly(Return(i));
+ EXPECT_CALL(*data_array, GetWord(i))
+ .WillRepeatedly(Return(words[i]));
+ }
+
+ vector<int> suffixes = {18, 0, 4, 8, 13, 1, 2, 6, 10, 15, 3, 12, 5, 7, 9,
+ 11, 14, 16, 17};
+ suffix_array = make_shared<MockSuffixArray>();
+ EXPECT_CALL(*suffix_array, GetData()).WillRepeatedly(Return(data_array));
+ for (size_t i = 0; i < suffixes.size(); ++i) {
+ EXPECT_CALL(*suffix_array, GetSuffix(i)).
+ WillRepeatedly(Return(suffixes[i]));
+ }
+
+ precomputation = make_shared<MockPrecomputation>();
+ EXPECT_CALL(*precomputation, GetCollocations())
+ .WillRepeatedly(ReturnRef(collocations));
+
+ phrase_builder = make_shared<PhraseBuilder>(vocabulary);
+ intersector = make_shared<FastIntersector>(suffix_array, precomputation,
+ vocabulary, 15, 1);
+ }
+
+ Index collocations;
+ shared_ptr<MockDataArray> data_array;
+ shared_ptr<MockSuffixArray> suffix_array;
+ shared_ptr<MockPrecomputation> precomputation;
+ shared_ptr<MockVocabulary> vocabulary;
+ shared_ptr<FastIntersector> intersector;
+ shared_ptr<PhraseBuilder> phrase_builder;
+};
+
+TEST_F(FastIntersectorTest, TestCachedCollocation) {
+ vector<int> symbols = {8, -1, 9};
+ vector<int> expected_location = {11};
+ Phrase phrase = phrase_builder->Build(symbols);
+ PhraseLocation prefix_location(15, 16), suffix_location(16, 17);
+
+ collocations[symbols] = expected_location;
+ EXPECT_CALL(*precomputation, GetCollocations())
+ .WillRepeatedly(ReturnRef(collocations));
+ intersector = make_shared<FastIntersector>(suffix_array, precomputation,
+ vocabulary, 15, 1);
+
+ PhraseLocation result = intersector->Intersect(
+ prefix_location, suffix_location, phrase);
+
+ EXPECT_EQ(PhraseLocation(expected_location, 2), result);
+ EXPECT_EQ(PhraseLocation(15, 16), prefix_location);
+ EXPECT_EQ(PhraseLocation(16, 17), suffix_location);
+}
+
+TEST_F(FastIntersectorTest, TestIntersectaXbXcExtendSuffix) {
+ vector<int> symbols = {1, -1, 3, -1, 1};
+ Phrase phrase = phrase_builder->Build(symbols);
+ vector<int> prefix_locs = {0, 2, 0, 6, 0, 10, 4, 6, 4, 10, 4, 15, 8, 10,
+ 8, 15, 3, 15};
+ vector<int> suffix_locs = {2, 4, 2, 8, 2, 13, 6, 8, 6, 13, 10, 13};
+ PhraseLocation prefix_location(prefix_locs, 2);
+ PhraseLocation suffix_location(suffix_locs, 2);
+
+ vector<int> expected_locs = {0, 2, 4, 0, 2, 8, 0, 2, 13, 4, 6, 8, 0, 6, 8,
+ 4, 6, 13, 0, 6, 13, 8, 10, 13, 4, 10, 13,
+ 0, 10, 13};
+ PhraseLocation result = intersector->Intersect(
+ prefix_location, suffix_location, phrase);
+ EXPECT_EQ(PhraseLocation(expected_locs, 3), result);
+}
+
+/*
+TEST_F(FastIntersectorTest, TestIntersectaXbExtendPrefix) {
+ vector<int> symbols = {1, -1, 3};
+ Phrase phrase = phrase_builder->Build(symbols);
+ PhraseLocation prefix_location(1, 5), suffix_location(6, 10);
+
+ vector<int> expected_prefix_locs = {0, 4, 8, 13};
+ vector<int> expected_locs = {0, 2, 0, 6, 0, 10, 4, 6, 4, 10, 4, 15, 8, 10,
+ 8, 15, 13, 15};
+ PhraseLocation result = intersector->Intersect(
+ prefix_location, suffix_location, phrase);
+ EXPECT_EQ(PhraseLocation(expected_locs, 2), result);
+ EXPECT_EQ(PhraseLocation(expected_prefix_locs, 1), prefix_location);
+}
+
+TEST_F(FastIntersectorTest, TestIntersectCheckEstimates) {
+ // The suffix matches in fewer positions, but because it starts with an X
+ // it requires more operations and we prefer extending the prefix.
+ vector<int> symbols = {1, -1, 4, 1};
+ Phrase phrase = phrase_builder->Build(symbols);
+ vector<int> prefix_locs = {0, 3, 0, 12, 4, 12, 8, 12};
+ PhraseLocation prefix_location(prefix_locs, 2), suffix_location(10, 12);
+
+ vector<int> expected_locs = {0, 3, 0, 12, 4, 12, 8, 12};
+ PhraseLocation result = intersector->Intersect(
+ prefix_location, suffix_location, phrase);
+ EXPECT_EQ(PhraseLocation(expected_locs, 2), result);
+ EXPECT_EQ(PhraseLocation(10, 12), suffix_location);
+}
+*/
+
+} // namespace
diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc
index 2f008026..a03e805f 100644
--- a/extractor/grammar_extractor.cc
+++ b/extractor/grammar_extractor.cc
@@ -16,12 +16,12 @@ GrammarExtractor::GrammarExtractor(
shared_ptr<Alignment> alignment, shared_ptr<Precomputation> precomputation,
shared_ptr<Scorer> scorer, int min_gap_size, int max_rule_span,
int max_nonterminals, int max_rule_symbols, int max_samples,
- bool use_baeza_yates, bool require_tight_phrases) :
+ bool use_fast_intersect, bool use_baeza_yates, bool require_tight_phrases) :
vocabulary(make_shared<Vocabulary>()),
rule_factory(make_shared<HieroCachingRuleFactory>(
source_suffix_array, target_data_array, alignment, vocabulary,
precomputation, scorer, min_gap_size, max_rule_span, max_nonterminals,
- max_rule_symbols, max_samples, use_baeza_yates,
+ max_rule_symbols, max_samples, use_fast_intersect, use_baeza_yates,
require_tight_phrases)) {}
GrammarExtractor::GrammarExtractor(
diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h
index 5f87faa7..a8f2090d 100644
--- a/extractor/grammar_extractor.h
+++ b/extractor/grammar_extractor.h
@@ -29,6 +29,7 @@ class GrammarExtractor {
int max_nonterminals,
int max_rule_symbols,
int max_samples,
+ bool use_fast_intersect,
bool use_baeza_yates,
bool require_tight_phrases);
diff --git a/extractor/intersector.cc b/extractor/intersector.cc
index cf42f630..39a7648d 100644
--- a/extractor/intersector.cc
+++ b/extractor/intersector.cc
@@ -1,7 +1,5 @@
#include "intersector.h"
-#include <chrono>
-
#include "data_array.h"
#include "matching_comparator.h"
#include "phrase.h"
@@ -11,10 +9,6 @@
#include "veb.h"
#include "vocabulary.h"
-using namespace std::chrono;
-
-typedef high_resolution_clock Clock;
-
Intersector::Intersector(shared_ptr<Vocabulary> vocabulary,
shared_ptr<Precomputation> precomputation,
shared_ptr<SuffixArray> suffix_array,
@@ -92,9 +86,6 @@ PhraseLocation Intersector::Intersect(
const Phrase& prefix, PhraseLocation& prefix_location,
const Phrase& suffix, PhraseLocation& suffix_location,
const Phrase& phrase) {
- if (linear_merge_time == 0) {
- linear_merger->linear_merge_time = 0;
- }
vector<int> symbols = phrase.Get();
// We should never attempt to do an intersect query for a pattern starting or
@@ -116,21 +107,15 @@ PhraseLocation Intersector::Intersect(
int prefix_subpatterns = prefix_location.num_subpatterns;
int suffix_subpatterns = suffix_location.num_subpatterns;
if (use_baeza_yates) {
- double prev_linear_merge_time = linear_merger->linear_merge_time;
- Clock::time_point start = Clock::now();
binary_search_merger->Merge(locations, phrase, suffix,
prefix_matchings->begin(), prefix_matchings->end(),
suffix_matchings->begin(), suffix_matchings->end(),
prefix_subpatterns, suffix_subpatterns);
- Clock::time_point stop = Clock::now();
- binary_merge_time += duration_cast<milliseconds>(stop - start).count() -
- (linear_merger->linear_merge_time - prev_linear_merge_time);
} else {
linear_merger->Merge(locations, phrase, suffix, prefix_matchings->begin(),
prefix_matchings->end(), suffix_matchings->begin(),
suffix_matchings->end(), prefix_subpatterns, suffix_subpatterns);
}
- linear_merge_time = linear_merger->linear_merge_time;
return PhraseLocation(locations, phrase.Arity() + 1);
}
@@ -141,7 +126,6 @@ void Intersector::ExtendPhraseLocation(
return;
}
- Clock::time_point sort_start = Clock::now();
phrase_location.num_subpatterns = 1;
phrase_location.sa_low = phrase_location.sa_high = 0;
@@ -167,6 +151,4 @@ void Intersector::ExtendPhraseLocation(
}
phrase_location.matchings = make_shared<vector<int> >(matchings);
- Clock::time_point sort_stop = Clock::now();
- sort_time += duration_cast<milliseconds>(sort_stop - sort_start).count();
}
diff --git a/extractor/linear_merger.cc b/extractor/linear_merger.cc
index 7233f945..e7a32788 100644
--- a/extractor/linear_merger.cc
+++ b/extractor/linear_merger.cc
@@ -1,6 +1,5 @@
#include "linear_merger.h"
-#include <chrono>
#include <cmath>
#include "data_array.h"
@@ -10,10 +9,6 @@
#include "phrase_location.h"
#include "vocabulary.h"
-using namespace std::chrono;
-
-typedef high_resolution_clock Clock;
-
LinearMerger::LinearMerger(shared_ptr<Vocabulary> vocabulary,
shared_ptr<DataArray> data_array,
shared_ptr<MatchingComparator> comparator) :
@@ -28,8 +23,6 @@ void LinearMerger::Merge(
vector<int>::iterator prefix_start, vector<int>::iterator prefix_end,
vector<int>::iterator suffix_start, vector<int>::iterator suffix_end,
int prefix_subpatterns, int suffix_subpatterns) {
- Clock::time_point start = Clock::now();
-
int last_chunk_len = suffix.GetChunkLen(suffix.Arity());
bool offset = !vocabulary->IsTerminal(suffix.GetSymbol(0));
@@ -69,7 +62,4 @@ void LinearMerger::Merge(
prefix_start += prefix_subpatterns;
}
}
-
- Clock::time_point stop = Clock::now();
- linear_merge_time += duration_cast<milliseconds>(stop - start).count();
}
diff --git a/extractor/linear_merger.h b/extractor/linear_merger.h
index 25692b15..c3c7111e 100644
--- a/extractor/linear_merger.h
+++ b/extractor/linear_merger.h
@@ -33,10 +33,6 @@ class LinearMerger {
shared_ptr<Vocabulary> vocabulary;
shared_ptr<DataArray> data_array;
shared_ptr<MatchingComparator> comparator;
-
- // TODO(pauldb): Remove this eventually.
- public:
- double linear_merge_time;
};
#endif
diff --git a/extractor/mocks/mock_fast_intersector.h b/extractor/mocks/mock_fast_intersector.h
new file mode 100644
index 00000000..201386f2
--- /dev/null
+++ b/extractor/mocks/mock_fast_intersector.h
@@ -0,0 +1,11 @@
+#include <gmock/gmock.h>
+
+#include "../fast_intersector.h"
+#include "../phrase.h"
+#include "../phrase_location.h"
+
+class MockFastIntersector : public FastIntersector {
+ public:
+ MOCK_METHOD3(Intersect, PhraseLocation(PhraseLocation&, PhraseLocation&,
+ const Phrase&));
+};
diff --git a/extractor/phrase_location.cc b/extractor/phrase_location.cc
index 62f1e714..b0bfed80 100644
--- a/extractor/phrase_location.cc
+++ b/extractor/phrase_location.cc
@@ -5,15 +5,19 @@ PhraseLocation::PhraseLocation(int sa_low, int sa_high) :
PhraseLocation::PhraseLocation(const vector<int>& matchings,
int num_subpatterns) :
- sa_high(0), sa_low(0),
+ sa_low(0), sa_high(0),
matchings(make_shared<vector<int> >(matchings)),
num_subpatterns(num_subpatterns) {}
-bool PhraseLocation::IsEmpty() {
+bool PhraseLocation::IsEmpty() const {
+ return GetSize() == 0;
+}
+
+int PhraseLocation::GetSize() const {
if (num_subpatterns > 0) {
- return matchings->size() == 0;
+ return matchings->size();
} else {
- return sa_low >= sa_high;
+ return sa_high - sa_low;
}
}
diff --git a/extractor/phrase_location.h b/extractor/phrase_location.h
index e04d8628..a0eb36c8 100644
--- a/extractor/phrase_location.h
+++ b/extractor/phrase_location.h
@@ -11,7 +11,9 @@ struct PhraseLocation {
PhraseLocation(const vector<int>& matchings, int num_subpatterns);
- bool IsEmpty();
+ bool IsEmpty() const;
+
+ int GetSize() const;
friend bool operator==(const PhraseLocation& a, const PhraseLocation& b);
diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc
index 374a0db1..4101fcfa 100644
--- a/extractor/rule_factory.cc
+++ b/extractor/rule_factory.cc
@@ -6,6 +6,7 @@
#include <vector>
#include "grammar.h"
+#include "fast_intersector.h"
#include "intersector.h"
#include "matchings_finder.h"
#include "matching_comparator.h"
@@ -15,10 +16,11 @@
#include "sampler.h"
#include "scorer.h"
#include "suffix_array.h"
+#include "time_util.h"
#include "vocabulary.h"
using namespace std;
-using namespace std::chrono;
+using namespace chrono;
typedef high_resolution_clock Clock;
@@ -48,6 +50,7 @@ HieroCachingRuleFactory::HieroCachingRuleFactory(
int max_nonterminals,
int max_rule_symbols,
int max_samples,
+ bool use_fast_intersect,
bool use_baeza_yates,
bool require_tight_phrases) :
vocabulary(vocabulary),
@@ -56,12 +59,15 @@ HieroCachingRuleFactory::HieroCachingRuleFactory(
max_rule_span(max_rule_span),
max_nonterminals(max_nonterminals),
max_chunks(max_nonterminals + 1),
- max_rule_symbols(max_rule_symbols) {
+ max_rule_symbols(max_rule_symbols),
+ use_fast_intersect(use_fast_intersect) {
matchings_finder = make_shared<MatchingsFinder>(source_suffix_array);
shared_ptr<MatchingComparator> comparator =
make_shared<MatchingComparator>(min_gap_size, max_rule_span);
intersector = make_shared<Intersector>(vocabulary, precomputation,
source_suffix_array, comparator, use_baeza_yates);
+ fast_intersector = make_shared<FastIntersector>(source_suffix_array,
+ precomputation, vocabulary, max_rule_span, min_gap_size);
phrase_builder = make_shared<PhraseBuilder>(vocabulary);
rule_extractor = make_shared<RuleExtractor>(source_suffix_array->GetData(),
target_data_array, alignment, phrase_builder, scorer, vocabulary,
@@ -73,6 +79,7 @@ HieroCachingRuleFactory::HieroCachingRuleFactory(
HieroCachingRuleFactory::HieroCachingRuleFactory(
shared_ptr<MatchingsFinder> finder,
shared_ptr<Intersector> intersector,
+ shared_ptr<FastIntersector> fast_intersector,
shared_ptr<PhraseBuilder> phrase_builder,
shared_ptr<RuleExtractor> rule_extractor,
shared_ptr<Vocabulary> vocabulary,
@@ -82,9 +89,11 @@ HieroCachingRuleFactory::HieroCachingRuleFactory(
int max_rule_span,
int max_nonterminals,
int max_chunks,
- int max_rule_symbols) :
+ int max_rule_symbols,
+ bool use_fast_intersect) :
matchings_finder(finder),
intersector(intersector),
+ fast_intersector(fast_intersector),
phrase_builder(phrase_builder),
rule_extractor(rule_extractor),
vocabulary(vocabulary),
@@ -94,15 +103,14 @@ HieroCachingRuleFactory::HieroCachingRuleFactory(
max_rule_span(max_rule_span),
max_nonterminals(max_nonterminals),
max_chunks(max_chunks),
- max_rule_symbols(max_rule_symbols) {}
+ max_rule_symbols(max_rule_symbols),
+ use_fast_intersect(use_fast_intersect) {}
HieroCachingRuleFactory::HieroCachingRuleFactory() {}
HieroCachingRuleFactory::~HieroCachingRuleFactory() {}
Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
- intersector->binary_merge_time = 0;
- intersector->linear_merge_time = 0;
intersector->sort_time = 0;
Clock::time_point start_time = Clock::now();
double total_extract_time = 0;
@@ -155,25 +163,28 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
} else {
PhraseLocation phrase_location;
if (next_phrase.Arity() > 0) {
- Clock::time_point intersect_start_time = Clock::now();
- phrase_location = intersector->Intersect(
- node->phrase,
- node->matchings,
- next_suffix_link->phrase,
- next_suffix_link->matchings,
- next_phrase);
- Clock::time_point intersect_stop_time = Clock::now();
- total_intersect_time += duration_cast<milliseconds>(
- intersect_stop_time - intersect_start_time).count();
+ Clock::time_point intersect_start = Clock::now();
+ if (use_fast_intersect) {
+ phrase_location = fast_intersector->Intersect(
+ node->matchings, next_suffix_link->matchings, next_phrase);
+ } else {
+ phrase_location = intersector->Intersect(
+ node->phrase,
+ node->matchings,
+ next_suffix_link->phrase,
+ next_suffix_link->matchings,
+ next_phrase);
+ }
+ Clock::time_point intersect_stop = Clock::now();
+ total_intersect_time += GetDuration(intersect_start, intersect_stop);
} else {
- Clock::time_point lookup_start_time = Clock::now();
+ Clock::time_point lookup_start = Clock::now();
phrase_location = matchings_finder->Find(
node->matchings,
vocabulary->GetTerminalValue(word_id),
state.phrase.size());
- Clock::time_point lookup_stop_time = Clock::now();
- total_lookup_time += duration_cast<milliseconds>(
- lookup_stop_time - lookup_start_time).count();
+ Clock::time_point lookup_stop = Clock::now();
+ total_lookup_time += GetDuration(lookup_start, lookup_stop);
}
if (phrase_location.IsEmpty()) {
@@ -189,16 +200,15 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
AddTrailingNonterminal(phrase, next_phrase, next_node,
state.starts_with_x);
- Clock::time_point extract_start_time = Clock::now();
+ Clock::time_point extract_start = Clock::now();
if (!state.starts_with_x) {
PhraseLocation sample = sampler->Sample(next_node->matchings);
vector<Rule> new_rules =
rule_extractor->ExtractRules(next_phrase, sample);
rules.insert(rules.end(), new_rules.begin(), new_rules.end());
}
- Clock::time_point extract_stop_time = Clock::now();
- total_extract_time += duration_cast<milliseconds>(
- extract_stop_time - extract_start_time).count();
+ Clock::time_point extract_stop = Clock::now();
+ total_extract_time += GetDuration(extract_start, extract_stop);
} else {
next_node = node->GetChild(word_id);
}
@@ -211,15 +221,11 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
}
Clock::time_point stop_time = Clock::now();
- milliseconds ms = duration_cast<milliseconds>(stop_time - start_time);
cerr << "Total time for rule lookup, extraction, and scoring = "
- << ms.count() / 1000.0 << endl;
- cerr << "Extract time = " << total_extract_time / 1000.0 << endl;
- cerr << "Intersect time = " << total_intersect_time / 1000.0 << endl;
- cerr << "Sort time = " << intersector->sort_time / 1000.0 << endl;
- cerr << "Linear merge time = " << intersector->linear_merge_time / 1000.0 << endl;
- cerr << "Binary merge time = " << intersector->binary_merge_time / 1000.0 << endl;
- // cerr << "Lookup time = " << total_lookup_time / 1000.0 << endl;
+ << GetDuration(start_time, stop_time) << " seconds" << endl;
+ cerr << "Extract time = " << total_extract_time << " seconds" << endl;
+ cerr << "Intersect time = " << total_intersect_time << " seconds" << endl;
+ cerr << "Lookup time = " << total_lookup_time << " seconds" << endl;
return Grammar(rules, scorer->GetFeatureNames());
}
diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h
index cf344667..a39386a8 100644
--- a/extractor/rule_factory.h
+++ b/extractor/rule_factory.h
@@ -13,6 +13,7 @@ class Alignment;
class DataArray;
class Grammar;
class MatchingsFinder;
+class FastIntersector;
class Intersector;
class Precomputation;
class Rule;
@@ -37,6 +38,7 @@ class HieroCachingRuleFactory {
int max_nonterminals,
int max_rule_symbols,
int max_samples,
+ bool use_fast_intersect,
bool use_beaza_yates,
bool require_tight_phrases);
@@ -44,6 +46,7 @@ class HieroCachingRuleFactory {
HieroCachingRuleFactory(
shared_ptr<MatchingsFinder> finder,
shared_ptr<Intersector> intersector,
+ shared_ptr<FastIntersector> fast_intersector,
shared_ptr<PhraseBuilder> phrase_builder,
shared_ptr<RuleExtractor> rule_extractor,
shared_ptr<Vocabulary> vocabulary,
@@ -53,7 +56,8 @@ class HieroCachingRuleFactory {
int max_rule_span,
int max_nonterminals,
int max_chunks,
- int max_rule_symbols);
+ int max_rule_symbols,
+ bool use_fast_intersect);
virtual ~HieroCachingRuleFactory();
@@ -80,6 +84,7 @@ class HieroCachingRuleFactory {
shared_ptr<MatchingsFinder> matchings_finder;
shared_ptr<Intersector> intersector;
+ shared_ptr<FastIntersector> fast_intersector;
MatchingsTrie trie;
shared_ptr<PhraseBuilder> phrase_builder;
shared_ptr<RuleExtractor> rule_extractor;
@@ -91,6 +96,7 @@ class HieroCachingRuleFactory {
int max_nonterminals;
int max_chunks;
int max_rule_symbols;
+ bool use_fast_intersect;
};
#endif
diff --git a/extractor/rule_factory_test.cc b/extractor/rule_factory_test.cc
index d6fbab74..d329382a 100644
--- a/extractor/rule_factory_test.cc
+++ b/extractor/rule_factory_test.cc
@@ -5,6 +5,7 @@
#include <vector>
#include "grammar.h"
+#include "mocks/mock_fast_intersector.h"
#include "mocks/mock_intersector.h"
#include "mocks/mock_matchings_finder.h"
#include "mocks/mock_rule_extractor.h"
@@ -25,6 +26,7 @@ class RuleFactoryTest : public Test {
virtual void SetUp() {
finder = make_shared<MockMatchingsFinder>();
intersector = make_shared<MockIntersector>();
+ fast_intersector = make_shared<MockFastIntersector>();
vocabulary = make_shared<MockVocabulary>();
EXPECT_CALL(*vocabulary, GetTerminalValue(2)).WillRepeatedly(Return("a"));
@@ -49,14 +51,12 @@ class RuleFactoryTest : public Test {
extractor = make_shared<MockRuleExtractor>();
EXPECT_CALL(*extractor, ExtractRules(_, _))
.WillRepeatedly(Return(rules));
-
- factory = make_shared<HieroCachingRuleFactory>(finder, intersector,
- phrase_builder, extractor, vocabulary, sampler, scorer, 1, 10, 2, 3, 5);
}
vector<string> feature_names;
shared_ptr<MockMatchingsFinder> finder;
shared_ptr<MockIntersector> intersector;
+ shared_ptr<MockFastIntersector> fast_intersector;
shared_ptr<MockVocabulary> vocabulary;
shared_ptr<PhraseBuilder> phrase_builder;
shared_ptr<MockScorer> scorer;
@@ -66,6 +66,10 @@ class RuleFactoryTest : public Test {
};
TEST_F(RuleFactoryTest, TestGetGrammarDifferentWords) {
+ factory = make_shared<HieroCachingRuleFactory>(finder, intersector,
+ fast_intersector, phrase_builder, extractor, vocabulary, sampler,
+ scorer, 1, 10, 2, 3, 5, false);
+
EXPECT_CALL(*finder, Find(_, _, _))
.Times(6)
.WillRepeatedly(Return(PhraseLocation(0, 1)));
@@ -73,14 +77,37 @@ TEST_F(RuleFactoryTest, TestGetGrammarDifferentWords) {
EXPECT_CALL(*intersector, Intersect(_, _, _, _, _))
.Times(1)
.WillRepeatedly(Return(PhraseLocation(0, 1)));
+ EXPECT_CALL(*fast_intersector, Intersect(_, _, _)).Times(0);
vector<int> word_ids = {2, 3, 4};
Grammar grammar = factory->GetGrammar(word_ids);
EXPECT_EQ(feature_names, grammar.GetFeatureNames());
EXPECT_EQ(7, grammar.GetRules().size());
+
+ // Test for fast intersector.
+ factory = make_shared<HieroCachingRuleFactory>(finder, intersector,
+ fast_intersector, phrase_builder, extractor, vocabulary, sampler,
+ scorer, 1, 10, 2, 3, 5, true);
+
+ EXPECT_CALL(*finder, Find(_, _, _))
+ .Times(6)
+ .WillRepeatedly(Return(PhraseLocation(0, 1)));
+
+ EXPECT_CALL(*fast_intersector, Intersect(_, _, _))
+ .Times(1)
+ .WillRepeatedly(Return(PhraseLocation(0, 1)));
+ EXPECT_CALL(*intersector, Intersect(_, _, _, _, _)).Times(0);
+
+ grammar = factory->GetGrammar(word_ids);
+ EXPECT_EQ(feature_names, grammar.GetFeatureNames());
+ EXPECT_EQ(7, grammar.GetRules().size());
}
TEST_F(RuleFactoryTest, TestGetGrammarRepeatingWords) {
+ factory = make_shared<HieroCachingRuleFactory>(finder, intersector,
+ fast_intersector, phrase_builder, extractor, vocabulary, sampler,
+ scorer, 1, 10, 2, 3, 5, false);
+
EXPECT_CALL(*finder, Find(_, _, _))
.Times(12)
.WillRepeatedly(Return(PhraseLocation(0, 1)));
@@ -89,10 +116,31 @@ TEST_F(RuleFactoryTest, TestGetGrammarRepeatingWords) {
.Times(16)
.WillRepeatedly(Return(PhraseLocation(0, 1)));
+ EXPECT_CALL(*fast_intersector, Intersect(_, _, _)).Times(0);
+
vector<int> word_ids = {2, 3, 4, 2, 3};
Grammar grammar = factory->GetGrammar(word_ids);
EXPECT_EQ(feature_names, grammar.GetFeatureNames());
EXPECT_EQ(28, grammar.GetRules().size());
+
+ // Test for fast intersector.
+ factory = make_shared<HieroCachingRuleFactory>(finder, intersector,
+ fast_intersector, phrase_builder, extractor, vocabulary, sampler,
+ scorer, 1, 10, 2, 3, 5, true);
+
+ EXPECT_CALL(*finder, Find(_, _, _))
+ .Times(12)
+ .WillRepeatedly(Return(PhraseLocation(0, 1)));
+
+ EXPECT_CALL(*fast_intersector, Intersect(_, _, _))
+ .Times(16)
+ .WillRepeatedly(Return(PhraseLocation(0, 1)));
+
+ EXPECT_CALL(*intersector, Intersect(_, _, _, _, _)).Times(0);
+
+ grammar = factory->GetGrammar(word_ids);
+ EXPECT_EQ(feature_names, grammar.GetFeatureNames());
+ EXPECT_EQ(28, grammar.GetRules().size());
}
} // namespace
diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc
index ed30e6fe..38f10a5f 100644
--- a/extractor/run_extractor.cc
+++ b/extractor/run_extractor.cc
@@ -1,3 +1,4 @@
+#include <chrono>
#include <fstream>
#include <iostream>
#include <string>
@@ -23,6 +24,7 @@
#include "rule.h"
#include "scorer.h"
#include "suffix_array.h"
+#include "time_util.h"
#include "translation_table.h"
namespace fs = boost::filesystem;
@@ -56,6 +58,9 @@ int main(int argc, char** argv) {
"Minimum number of occurences for a pharse to be considered frequent")
("max_samples", po::value<int>()->default_value(300),
"Maximum number of samples")
+ ("fast_intersect", po::value<bool>()->default_value(false),
+ "Enable fast intersect")
+ // TODO(pauldb): Check if this works when set to false.
("tight_phrases", po::value<bool>()->default_value(true),
"False if phrases may be loose (better, but slower)")
("baeza_yates", po::value<bool>()->default_value(true),
@@ -80,6 +85,9 @@ int main(int argc, char** argv) {
return 1;
}
+ Clock::time_point preprocess_start_time = Clock::now();
+ cerr << "Reading source and target data..." << endl;
+ Clock::time_point start_time = Clock::now();
shared_ptr<DataArray> source_data_array, target_data_array;
if (vm.count("bitext")) {
source_data_array = make_shared<DataArray>(
@@ -90,13 +98,28 @@ int main(int argc, char** argv) {
source_data_array = make_shared<DataArray>(vm["source"].as<string>());
target_data_array = make_shared<DataArray>(vm["target"].as<string>());
}
+ Clock::time_point stop_time = Clock::now();
+ cerr << "Reading data took " << GetDuration(start_time, stop_time)
+ << " seconds" << endl;
+
+ cerr << "Creating source suffix array..." << endl;
+ start_time = Clock::now();
shared_ptr<SuffixArray> source_suffix_array =
make_shared<SuffixArray>(source_data_array);
+ stop_time = Clock::now();
+ cerr << "Creating suffix array took "
+ << GetDuration(start_time, stop_time) << " seconds" << endl;
-
+ cerr << "Reading alignment..." << endl;
+ start_time = Clock::now();
shared_ptr<Alignment> alignment =
make_shared<Alignment>(vm["alignment"].as<string>());
+ stop_time = Clock::now();
+ cerr << "Reading alignment took "
+ << GetDuration(start_time, stop_time) << " seconds" << endl;
+ cerr << "Precomputating collocations..." << endl;
+ start_time = Clock::now();
shared_ptr<Precomputation> precomputation = make_shared<Precomputation>(
source_suffix_array,
vm["frequent"].as<int>(),
@@ -106,10 +129,24 @@ int main(int argc, char** argv) {
vm["min_gap_size"].as<int>(),
vm["max_phrase_len"].as<int>(),
vm["min_frequency"].as<int>());
+ stop_time = Clock::now();
+ cerr << "Precomputing collocations took "
+ << GetDuration(start_time, stop_time) << " seconds" << endl;
+ cerr << "Precomputing conditional probabilities..." << endl;
+ start_time = Clock::now();
shared_ptr<TranslationTable> table = make_shared<TranslationTable>(
source_data_array, target_data_array, alignment);
+ stop_time = Clock::now();
+ cerr << "Precomputing conditional probabilities took "
+ << GetDuration(start_time, stop_time) << " seconds" << endl;
+
+ Clock::time_point preprocess_stop_time = Clock::now();
+ cerr << "Overall preprocessing step took "
+ << GetDuration(preprocess_start_time, preprocess_stop_time)
+ << " seconds" << endl;
+ Clock::time_point extraction_start_time = Clock::now();
vector<shared_ptr<Feature> > features = {
make_shared<TargetGivenSourceCoherent>(),
make_shared<SampleSourceCount>(),
@@ -133,6 +170,7 @@ int main(int argc, char** argv) {
vm["max_nonterminals"].as<int>(),
vm["max_rule_symbols"].as<int>(),
vm["max_samples"].as<int>(),
+ vm["fast_intersect"].as<bool>(),
vm["baeza_yates"].as<bool>(),
vm["tight_phrases"].as<bool>());
@@ -161,6 +199,10 @@ int main(int argc, char** argv) {
<< "\"> " << sentence << " </seg> " << suffix << endl;
++grammar_id;
}
+ Clock::time_point extraction_stop_time = Clock::now();
+ cerr << "Overall extraction step took "
+ << GetDuration(extraction_start_time, extraction_stop_time)
+ << " seconds" << endl;
return 0;
}
diff --git a/extractor/suffix_array.cc b/extractor/suffix_array.cc
index 9815996f..23c458a4 100644
--- a/extractor/suffix_array.cc
+++ b/extractor/suffix_array.cc
@@ -1,14 +1,17 @@
#include "suffix_array.h"
+#include <chrono>
#include <iostream>
#include <string>
#include <vector>
#include "data_array.h"
#include "phrase_location.h"
+#include "time_util.h"
namespace fs = boost::filesystem;
using namespace std;
+using namespace chrono;
SuffixArray::SuffixArray(shared_ptr<DataArray> data_array) :
data_array(data_array) {
@@ -39,6 +42,7 @@ void SuffixArray::BuildSuffixArray() {
}
PrefixDoublingSort(groups);
+ cerr << "\tFinalizing sort..." << endl;
for (size_t i = 0; i < groups.size(); ++i) {
suffix_array[groups[i]] = i;
@@ -46,6 +50,7 @@ void SuffixArray::BuildSuffixArray() {
}
void SuffixArray::InitialBucketSort(vector<int>& groups) {
+ Clock::time_point start_time = Clock::now();
for (size_t i = 0; i < groups.size(); ++i) {
++word_start[groups[i]];
}
@@ -62,6 +67,9 @@ void SuffixArray::InitialBucketSort(vector<int>& groups) {
for (size_t i = 0; i < suffix_array.size(); ++i) {
groups[i] = word_start[groups[i] + 1] - 1;
}
+ Clock::time_point stop_time = Clock::now();
+ cerr << "\tBucket sort took " << GetDuration(start_time, stop_time)
+ << " seconds" << endl;
}
void SuffixArray::PrefixDoublingSort(vector<int>& groups) {
@@ -127,6 +135,9 @@ void SuffixArray::TernaryQuicksort(int left, int right, int step,
}
vector<int> SuffixArray::BuildLCPArray() const {
+ Clock::time_point start_time = Clock::now();
+ cerr << "Constructing LCP array..." << endl;
+
vector<int> lcp(suffix_array.size());
vector<int> rank(suffix_array.size());
const vector<int>& data = data_array->GetData();
@@ -153,6 +164,10 @@ vector<int> SuffixArray::BuildLCPArray() const {
}
}
+ Clock::time_point stop_time = Clock::now();
+ cerr << "Constructing LCP took "
+ << GetDuration(start_time, stop_time) << " seconds" << endl;
+
return lcp;
}
diff --git a/extractor/time_util.cc b/extractor/time_util.cc
new file mode 100644
index 00000000..88395f77
--- /dev/null
+++ b/extractor/time_util.cc
@@ -0,0 +1,6 @@
+#include "time_util.h"
+
+double GetDuration(const Clock::time_point& start_time,
+ const Clock::time_point& stop_time) {
+ return duration_cast<milliseconds>(stop_time - start_time).count() / 1000.0;
+}
diff --git a/extractor/time_util.h b/extractor/time_util.h
new file mode 100644
index 00000000..6f7eda70
--- /dev/null
+++ b/extractor/time_util.h
@@ -0,0 +1,14 @@
+#ifndef _TIME_UTIL_H_
+#define _TIME_UTIL_H_
+
+#include <chrono>
+
+using namespace std;
+using namespace chrono;
+
+typedef high_resolution_clock Clock;
+
+double GetDuration(const Clock::time_point& start_time,
+ const Clock::time_point& stop_time);
+
+#endif