From 54a1c0e2bde259e3acc9c0a8ec8da3c7704e80ca Mon Sep 17 00:00:00 2001 From: Paul Baltescu Date: Tue, 19 Feb 2013 21:23:48 +0000 Subject: Timing every part of the extractor. --- extractor/Makefile.am | 7 ++ extractor/data_array.cc | 3 +- extractor/fast_intersector.cc | 191 ++++++++++++++++++++++++++++++++ extractor/fast_intersector.h | 65 +++++++++++ extractor/fast_intersector_test.cc | 146 ++++++++++++++++++++++++ extractor/grammar_extractor.cc | 4 +- extractor/grammar_extractor.h | 1 + extractor/intersector.cc | 18 --- extractor/linear_merger.cc | 10 -- extractor/linear_merger.h | 4 - extractor/mocks/mock_fast_intersector.h | 11 ++ extractor/phrase_location.cc | 12 +- extractor/phrase_location.h | 4 +- extractor/rule_factory.cc | 70 ++++++------ extractor/rule_factory.h | 8 +- extractor/rule_factory_test.cc | 54 ++++++++- extractor/run_extractor.cc | 44 +++++++- extractor/suffix_array.cc | 15 +++ extractor/time_util.cc | 6 + extractor/time_util.h | 14 +++ 20 files changed, 610 insertions(+), 77 deletions(-) create mode 100644 extractor/fast_intersector.cc create mode 100644 extractor/fast_intersector.h create mode 100644 extractor/fast_intersector_test.cc create mode 100644 extractor/mocks/mock_fast_intersector.h create mode 100644 extractor/time_util.cc create mode 100644 extractor/time_util.h (limited to 'extractor') diff --git a/extractor/Makefile.am b/extractor/Makefile.am index c82fc1ae..8f76dea5 100644 --- a/extractor/Makefile.am +++ b/extractor/Makefile.am @@ -4,6 +4,7 @@ noinst_PROGRAMS = \ alignment_test \ binary_search_merger_test \ data_array_test \ + fast_intersector_test \ feature_count_source_target_test \ feature_is_source_singleton_test \ feature_is_source_target_singleton_test \ @@ -32,6 +33,7 @@ noinst_PROGRAMS = \ TESTS = alignment_test \ binary_search_merger_test \ data_array_test \ + fast_intersector_test \ feature_count_source_target_test \ feature_is_source_singleton_test \ feature_is_source_target_singleton_test \ @@ -63,6 +65,8 @@ binary_search_merger_test_SOURCES = binary_search_merger_test.cc binary_search_merger_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a data_array_test_SOURCES = data_array_test.cc data_array_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a +fast_intersector_test_SOURCES = fast_intersector_test.cc +fast_intersector_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a feature_count_source_target_test_SOURCES = features/count_source_target_test.cc feature_count_source_target_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a feature_is_source_singleton_test_SOURCES = features/is_source_singleton_test.cc @@ -125,12 +129,14 @@ libcompile_a_SOURCES = \ phrase_location.cc \ precomputation.cc \ suffix_array.cc \ + time_util.cc \ translation_table.cc libextractor_a_SOURCES = \ alignment.cc \ binary_search_merger.cc \ data_array.cc \ + fast_intersector.cc \ features/count_source_target.cc \ features/feature.cc \ features/is_source_singleton.cc \ @@ -159,6 +165,7 @@ libextractor_a_SOURCES = \ scorer.cc \ suffix_array.cc \ target_phrase_extractor.cc \ + time_util.cc \ translation_table.cc \ veb.cc \ veb_bitset.cc \ diff --git a/extractor/data_array.cc b/extractor/data_array.cc index 1097caf3..cd430c69 100644 --- a/extractor/data_array.cc +++ b/extractor/data_array.cc @@ -147,7 +147,8 @@ bool DataArray::HasWord(const string& word) const { } int DataArray::GetWordId(const string& word) const { - return word2id.find(word)->second; + auto result = word2id.find(word); + return result == word2id.end() ? -1 : result->second; } string DataArray::GetWord(int word_id) const { diff --git a/extractor/fast_intersector.cc b/extractor/fast_intersector.cc new file mode 100644 index 00000000..8c7a7af8 --- /dev/null +++ b/extractor/fast_intersector.cc @@ -0,0 +1,191 @@ +#include "fast_intersector.h" + +#include + +#include "data_array.h" +#include "phrase.h" +#include "phrase_location.h" +#include "precomputation.h" +#include "suffix_array.h" +#include "vocabulary.h" + +FastIntersector::FastIntersector(shared_ptr suffix_array, + shared_ptr precomputation, + shared_ptr vocabulary, + int max_rule_span, + int min_gap_size) : + suffix_array(suffix_array), + vocabulary(vocabulary), + max_rule_span(max_rule_span), + min_gap_size(min_gap_size) { + Index precomputed_collocations = precomputation->GetCollocations(); + for (pair, vector > entry: precomputed_collocations) { + vector phrase = ConvertPhrase(entry.first); + collocations[phrase] = entry.second; + } +} + +FastIntersector::FastIntersector() {} + +FastIntersector::~FastIntersector() {} + +vector FastIntersector::ConvertPhrase(const vector& old_phrase) { + vector new_phrase; + new_phrase.reserve(old_phrase.size()); + shared_ptr data_array = suffix_array->GetData(); + int num_nonterminals = 0; + for (int word_id: old_phrase) { + // TODO(pauldb): Remove overhead for relabelling the nonterminals here. + if (word_id == Precomputation::NON_TERMINAL) { + ++num_nonterminals; + new_phrase.push_back(vocabulary->GetNonterminalIndex(num_nonterminals)); + } else { + new_phrase.push_back( + vocabulary->GetTerminalIndex(data_array->GetWord(word_id))); + } + } + return new_phrase; +} + +PhraseLocation FastIntersector::Intersect( + PhraseLocation& prefix_location, + PhraseLocation& suffix_location, + const Phrase& phrase) { + vector symbols = phrase.Get(); + + // We should never attempt to do an intersect query for a pattern starting or + // ending with a non terminal. The RuleFactory should handle these cases, + // initializing the matchings list with the one for the pattern without the + // starting or ending terminal. + assert(vocabulary->IsTerminal(symbols.front()) + && vocabulary->IsTerminal(symbols.back())); + + if (collocations.count(symbols)) { + return PhraseLocation(collocations[symbols], phrase.Arity() + 1); + } + + bool prefix_ends_with_x = + !vocabulary->IsTerminal(symbols[symbols.size() - 2]); + bool suffix_starts_with_x = !vocabulary->IsTerminal(symbols[1]); + if (EstimateNumOperations(prefix_location, prefix_ends_with_x) <= + EstimateNumOperations(suffix_location, suffix_starts_with_x)) { + return ExtendPrefixPhraseLocation(prefix_location, phrase, + prefix_ends_with_x, symbols.back()); + } else { + return ExtendSuffixPhraseLocation(suffix_location, phrase, + suffix_starts_with_x, symbols.front()); + } +} + +int FastIntersector::EstimateNumOperations( + const PhraseLocation& phrase_location, bool has_margin_x) const { + int num_locations = phrase_location.GetSize(); + return has_margin_x ? num_locations * max_rule_span : num_locations; +} + +PhraseLocation FastIntersector::ExtendPrefixPhraseLocation( + PhraseLocation& prefix_location, const Phrase& phrase, + bool prefix_ends_with_x, int next_symbol) const { + ExtendPhraseLocation(prefix_location); + vector positions = *prefix_location.matchings; + int num_subpatterns = prefix_location.num_subpatterns; + + vector new_positions; + shared_ptr data_array = suffix_array->GetData(); + int data_array_symbol = data_array->GetWordId( + vocabulary->GetTerminalValue(next_symbol)); + if (data_array_symbol == -1) { + return PhraseLocation(new_positions, num_subpatterns); + } + + pair range = GetSearchRange(prefix_ends_with_x); + for (size_t i = 0; i < positions.size(); i += num_subpatterns) { + int sent_id = data_array->GetSentenceId(positions[i]); + int sent_end = data_array->GetSentenceStart(sent_id + 1) - 1; + int pattern_end = positions[i + num_subpatterns - 1] + range.first; + if (prefix_ends_with_x) { + pattern_end += phrase.GetChunkLen(phrase.Arity() - 1) - 1; + } else { + pattern_end += phrase.GetChunkLen(phrase.Arity()) - 2; + } + for (int j = range.first; j < range.second; ++j) { + if (pattern_end >= sent_end || + pattern_end - positions[i] >= max_rule_span) { + break; + } + + if (data_array->AtIndex(pattern_end) == data_array_symbol) { + new_positions.insert(new_positions.end(), positions.begin() + i, + positions.begin() + i + num_subpatterns); + if (prefix_ends_with_x) { + new_positions.push_back(pattern_end); + } + } + ++pattern_end; + } + } + + return PhraseLocation(new_positions, phrase.Arity() + 1); +} + +PhraseLocation FastIntersector::ExtendSuffixPhraseLocation( + PhraseLocation& suffix_location, const Phrase& phrase, + bool suffix_starts_with_x, int prev_symbol) const { + ExtendPhraseLocation(suffix_location); + vector positions = *suffix_location.matchings; + int num_subpatterns = suffix_location.num_subpatterns; + + vector new_positions; + shared_ptr data_array = suffix_array->GetData(); + int data_array_symbol = data_array->GetWordId( + vocabulary->GetTerminalValue(prev_symbol)); + if (data_array_symbol == -1) { + return PhraseLocation(new_positions, num_subpatterns); + } + + pair range = GetSearchRange(suffix_starts_with_x); + for (size_t i = 0; i < positions.size(); i += num_subpatterns) { + int sent_id = data_array->GetSentenceId(positions[i]); + int sent_start = data_array->GetSentenceStart(sent_id); + int pattern_start = positions[i] - range.first; + int pattern_end = positions[i + num_subpatterns - 1] + + phrase.GetChunkLen(phrase.Arity()) - 1; + for (int j = range.first; j < range.second; ++j) { + if (pattern_start < sent_start || + pattern_end - pattern_start >= max_rule_span) { + break; + } + + if (data_array->AtIndex(pattern_start) == data_array_symbol) { + new_positions.push_back(pattern_start); + new_positions.insert(new_positions.end(), + positions.begin() + i + !suffix_starts_with_x, + positions.begin() + i + num_subpatterns); + } + --pattern_start; + } + } + + return PhraseLocation(new_positions, phrase.Arity() + 1); +} + +void FastIntersector::ExtendPhraseLocation(PhraseLocation& location) const { + if (location.matchings != NULL) { + return; + } + + location.num_subpatterns = 1; + location.matchings = make_shared >(); + for (int i = location.sa_low; i < location.sa_high; ++i) { + location.matchings->push_back(suffix_array->GetSuffix(i)); + } + location.sa_low = location.sa_high = 0; +} + +pair FastIntersector::GetSearchRange(bool has_marginal_x) const { + if (has_marginal_x) { + return make_pair(min_gap_size + 1, max_rule_span); + } else { + return make_pair(1, 2); + } +} diff --git a/extractor/fast_intersector.h b/extractor/fast_intersector.h new file mode 100644 index 00000000..785e428e --- /dev/null +++ b/extractor/fast_intersector.h @@ -0,0 +1,65 @@ +#ifndef _FAST_INTERSECTOR_H_ +#define _FAST_INTERSECTOR_H_ + +#include +#include +#include + +#include + +using namespace std; + +typedef boost::hash > VectorHash; +typedef unordered_map, vector, VectorHash> Index; + +class Phrase; +class PhraseLocation; +class Precomputation; +class SuffixArray; +class Vocabulary; + +class FastIntersector { + public: + FastIntersector(shared_ptr suffix_array, + shared_ptr precomputation, + shared_ptr vocabulary, + int max_rule_span, + int min_gap_size); + + virtual ~FastIntersector(); + + virtual PhraseLocation Intersect(PhraseLocation& prefix_location, + PhraseLocation& suffix_location, + const Phrase& phrase); + + protected: + FastIntersector(); + + private: + vector ConvertPhrase(const vector& old_phrase); + + int EstimateNumOperations(const PhraseLocation& phrase_location, + bool has_margin_x) const; + + PhraseLocation ExtendPrefixPhraseLocation(PhraseLocation& prefix_location, + const Phrase& phrase, + bool prefix_ends_with_x, + int next_symbol) const; + + PhraseLocation ExtendSuffixPhraseLocation(PhraseLocation& suffix_location, + const Phrase& phrase, + bool suffix_starts_with_x, + int prev_symbol) const; + + void ExtendPhraseLocation(PhraseLocation& location) const; + + pair GetSearchRange(bool has_marginal_x) const; + + shared_ptr suffix_array; + shared_ptr vocabulary; + int max_rule_span; + int min_gap_size; + Index collocations; +}; + +#endif diff --git a/extractor/fast_intersector_test.cc b/extractor/fast_intersector_test.cc new file mode 100644 index 00000000..0d6ef367 --- /dev/null +++ b/extractor/fast_intersector_test.cc @@ -0,0 +1,146 @@ +#include + +#include + +#include "fast_intersector.h" +#include "mocks/mock_data_array.h" +#include "mocks/mock_suffix_array.h" +#include "mocks/mock_precomputation.h" +#include "mocks/mock_vocabulary.h" +#include "phrase.h" +#include "phrase_location.h" +#include "phrase_builder.h" + +using namespace std; +using namespace ::testing; + +namespace { + +class FastIntersectorTest : public Test { + protected: + virtual void SetUp() { + vector words = {"EOL", "it", "makes", "him", "and", "mars", ",", + "sets", "on", "takes", "off", "."}; + vocabulary = make_shared(); + for (size_t i = 0; i < words.size(); ++i) { + EXPECT_CALL(*vocabulary, GetTerminalIndex(words[i])) + .WillRepeatedly(Return(i)); + EXPECT_CALL(*vocabulary, GetTerminalValue(i)) + .WillRepeatedly(Return(words[i])); + } + + vector data = {1, 2, 3, 4, 1, 5, 3, 6, 1, + 7, 3, 8, 4, 1, 9, 3, 10, 11, 0}; + data_array = make_shared(); + for (size_t i = 0; i < data.size(); ++i) { + EXPECT_CALL(*data_array, AtIndex(i)).WillRepeatedly(Return(data[i])); + EXPECT_CALL(*data_array, GetSentenceId(i)) + .WillRepeatedly(Return(0)); + } + EXPECT_CALL(*data_array, GetSentenceStart(0)) + .WillRepeatedly(Return(0)); + EXPECT_CALL(*data_array, GetSentenceStart(1)) + .WillRepeatedly(Return(19)); + for (size_t i = 0; i < words.size(); ++i) { + EXPECT_CALL(*data_array, GetWordId(words[i])) + .WillRepeatedly(Return(i)); + EXPECT_CALL(*data_array, GetWord(i)) + .WillRepeatedly(Return(words[i])); + } + + vector suffixes = {18, 0, 4, 8, 13, 1, 2, 6, 10, 15, 3, 12, 5, 7, 9, + 11, 14, 16, 17}; + suffix_array = make_shared(); + EXPECT_CALL(*suffix_array, GetData()).WillRepeatedly(Return(data_array)); + for (size_t i = 0; i < suffixes.size(); ++i) { + EXPECT_CALL(*suffix_array, GetSuffix(i)). + WillRepeatedly(Return(suffixes[i])); + } + + precomputation = make_shared(); + EXPECT_CALL(*precomputation, GetCollocations()) + .WillRepeatedly(ReturnRef(collocations)); + + phrase_builder = make_shared(vocabulary); + intersector = make_shared(suffix_array, precomputation, + vocabulary, 15, 1); + } + + Index collocations; + shared_ptr data_array; + shared_ptr suffix_array; + shared_ptr precomputation; + shared_ptr vocabulary; + shared_ptr intersector; + shared_ptr phrase_builder; +}; + +TEST_F(FastIntersectorTest, TestCachedCollocation) { + vector symbols = {8, -1, 9}; + vector expected_location = {11}; + Phrase phrase = phrase_builder->Build(symbols); + PhraseLocation prefix_location(15, 16), suffix_location(16, 17); + + collocations[symbols] = expected_location; + EXPECT_CALL(*precomputation, GetCollocations()) + .WillRepeatedly(ReturnRef(collocations)); + intersector = make_shared(suffix_array, precomputation, + vocabulary, 15, 1); + + PhraseLocation result = intersector->Intersect( + prefix_location, suffix_location, phrase); + + EXPECT_EQ(PhraseLocation(expected_location, 2), result); + EXPECT_EQ(PhraseLocation(15, 16), prefix_location); + EXPECT_EQ(PhraseLocation(16, 17), suffix_location); +} + +TEST_F(FastIntersectorTest, TestIntersectaXbXcExtendSuffix) { + vector symbols = {1, -1, 3, -1, 1}; + Phrase phrase = phrase_builder->Build(symbols); + vector prefix_locs = {0, 2, 0, 6, 0, 10, 4, 6, 4, 10, 4, 15, 8, 10, + 8, 15, 3, 15}; + vector suffix_locs = {2, 4, 2, 8, 2, 13, 6, 8, 6, 13, 10, 13}; + PhraseLocation prefix_location(prefix_locs, 2); + PhraseLocation suffix_location(suffix_locs, 2); + + vector expected_locs = {0, 2, 4, 0, 2, 8, 0, 2, 13, 4, 6, 8, 0, 6, 8, + 4, 6, 13, 0, 6, 13, 8, 10, 13, 4, 10, 13, + 0, 10, 13}; + PhraseLocation result = intersector->Intersect( + prefix_location, suffix_location, phrase); + EXPECT_EQ(PhraseLocation(expected_locs, 3), result); +} + +/* +TEST_F(FastIntersectorTest, TestIntersectaXbExtendPrefix) { + vector symbols = {1, -1, 3}; + Phrase phrase = phrase_builder->Build(symbols); + PhraseLocation prefix_location(1, 5), suffix_location(6, 10); + + vector expected_prefix_locs = {0, 4, 8, 13}; + vector expected_locs = {0, 2, 0, 6, 0, 10, 4, 6, 4, 10, 4, 15, 8, 10, + 8, 15, 13, 15}; + PhraseLocation result = intersector->Intersect( + prefix_location, suffix_location, phrase); + EXPECT_EQ(PhraseLocation(expected_locs, 2), result); + EXPECT_EQ(PhraseLocation(expected_prefix_locs, 1), prefix_location); +} + +TEST_F(FastIntersectorTest, TestIntersectCheckEstimates) { + // The suffix matches in fewer positions, but because it starts with an X + // it requires more operations and we prefer extending the prefix. + vector symbols = {1, -1, 4, 1}; + Phrase phrase = phrase_builder->Build(symbols); + vector prefix_locs = {0, 3, 0, 12, 4, 12, 8, 12}; + PhraseLocation prefix_location(prefix_locs, 2), suffix_location(10, 12); + + vector expected_locs = {0, 3, 0, 12, 4, 12, 8, 12}; + PhraseLocation result = intersector->Intersect( + prefix_location, suffix_location, phrase); + EXPECT_EQ(PhraseLocation(expected_locs, 2), result); + EXPECT_EQ(PhraseLocation(10, 12), suffix_location); +} +*/ + +} // namespace diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc index 2f008026..a03e805f 100644 --- a/extractor/grammar_extractor.cc +++ b/extractor/grammar_extractor.cc @@ -16,12 +16,12 @@ GrammarExtractor::GrammarExtractor( shared_ptr alignment, shared_ptr precomputation, shared_ptr scorer, int min_gap_size, int max_rule_span, int max_nonterminals, int max_rule_symbols, int max_samples, - bool use_baeza_yates, bool require_tight_phrases) : + bool use_fast_intersect, bool use_baeza_yates, bool require_tight_phrases) : vocabulary(make_shared()), rule_factory(make_shared( source_suffix_array, target_data_array, alignment, vocabulary, precomputation, scorer, min_gap_size, max_rule_span, max_nonterminals, - max_rule_symbols, max_samples, use_baeza_yates, + max_rule_symbols, max_samples, use_fast_intersect, use_baeza_yates, require_tight_phrases)) {} GrammarExtractor::GrammarExtractor( diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h index 5f87faa7..a8f2090d 100644 --- a/extractor/grammar_extractor.h +++ b/extractor/grammar_extractor.h @@ -29,6 +29,7 @@ class GrammarExtractor { int max_nonterminals, int max_rule_symbols, int max_samples, + bool use_fast_intersect, bool use_baeza_yates, bool require_tight_phrases); diff --git a/extractor/intersector.cc b/extractor/intersector.cc index cf42f630..39a7648d 100644 --- a/extractor/intersector.cc +++ b/extractor/intersector.cc @@ -1,7 +1,5 @@ #include "intersector.h" -#include - #include "data_array.h" #include "matching_comparator.h" #include "phrase.h" @@ -11,10 +9,6 @@ #include "veb.h" #include "vocabulary.h" -using namespace std::chrono; - -typedef high_resolution_clock Clock; - Intersector::Intersector(shared_ptr vocabulary, shared_ptr precomputation, shared_ptr suffix_array, @@ -92,9 +86,6 @@ PhraseLocation Intersector::Intersect( const Phrase& prefix, PhraseLocation& prefix_location, const Phrase& suffix, PhraseLocation& suffix_location, const Phrase& phrase) { - if (linear_merge_time == 0) { - linear_merger->linear_merge_time = 0; - } vector symbols = phrase.Get(); // We should never attempt to do an intersect query for a pattern starting or @@ -116,21 +107,15 @@ PhraseLocation Intersector::Intersect( int prefix_subpatterns = prefix_location.num_subpatterns; int suffix_subpatterns = suffix_location.num_subpatterns; if (use_baeza_yates) { - double prev_linear_merge_time = linear_merger->linear_merge_time; - Clock::time_point start = Clock::now(); binary_search_merger->Merge(locations, phrase, suffix, prefix_matchings->begin(), prefix_matchings->end(), suffix_matchings->begin(), suffix_matchings->end(), prefix_subpatterns, suffix_subpatterns); - Clock::time_point stop = Clock::now(); - binary_merge_time += duration_cast(stop - start).count() - - (linear_merger->linear_merge_time - prev_linear_merge_time); } else { linear_merger->Merge(locations, phrase, suffix, prefix_matchings->begin(), prefix_matchings->end(), suffix_matchings->begin(), suffix_matchings->end(), prefix_subpatterns, suffix_subpatterns); } - linear_merge_time = linear_merger->linear_merge_time; return PhraseLocation(locations, phrase.Arity() + 1); } @@ -141,7 +126,6 @@ void Intersector::ExtendPhraseLocation( return; } - Clock::time_point sort_start = Clock::now(); phrase_location.num_subpatterns = 1; phrase_location.sa_low = phrase_location.sa_high = 0; @@ -167,6 +151,4 @@ void Intersector::ExtendPhraseLocation( } phrase_location.matchings = make_shared >(matchings); - Clock::time_point sort_stop = Clock::now(); - sort_time += duration_cast(sort_stop - sort_start).count(); } diff --git a/extractor/linear_merger.cc b/extractor/linear_merger.cc index 7233f945..e7a32788 100644 --- a/extractor/linear_merger.cc +++ b/extractor/linear_merger.cc @@ -1,6 +1,5 @@ #include "linear_merger.h" -#include #include #include "data_array.h" @@ -10,10 +9,6 @@ #include "phrase_location.h" #include "vocabulary.h" -using namespace std::chrono; - -typedef high_resolution_clock Clock; - LinearMerger::LinearMerger(shared_ptr vocabulary, shared_ptr data_array, shared_ptr comparator) : @@ -28,8 +23,6 @@ void LinearMerger::Merge( vector::iterator prefix_start, vector::iterator prefix_end, vector::iterator suffix_start, vector::iterator suffix_end, int prefix_subpatterns, int suffix_subpatterns) { - Clock::time_point start = Clock::now(); - int last_chunk_len = suffix.GetChunkLen(suffix.Arity()); bool offset = !vocabulary->IsTerminal(suffix.GetSymbol(0)); @@ -69,7 +62,4 @@ void LinearMerger::Merge( prefix_start += prefix_subpatterns; } } - - Clock::time_point stop = Clock::now(); - linear_merge_time += duration_cast(stop - start).count(); } diff --git a/extractor/linear_merger.h b/extractor/linear_merger.h index 25692b15..c3c7111e 100644 --- a/extractor/linear_merger.h +++ b/extractor/linear_merger.h @@ -33,10 +33,6 @@ class LinearMerger { shared_ptr vocabulary; shared_ptr data_array; shared_ptr comparator; - - // TODO(pauldb): Remove this eventually. - public: - double linear_merge_time; }; #endif diff --git a/extractor/mocks/mock_fast_intersector.h b/extractor/mocks/mock_fast_intersector.h new file mode 100644 index 00000000..201386f2 --- /dev/null +++ b/extractor/mocks/mock_fast_intersector.h @@ -0,0 +1,11 @@ +#include + +#include "../fast_intersector.h" +#include "../phrase.h" +#include "../phrase_location.h" + +class MockFastIntersector : public FastIntersector { + public: + MOCK_METHOD3(Intersect, PhraseLocation(PhraseLocation&, PhraseLocation&, + const Phrase&)); +}; diff --git a/extractor/phrase_location.cc b/extractor/phrase_location.cc index 62f1e714..b0bfed80 100644 --- a/extractor/phrase_location.cc +++ b/extractor/phrase_location.cc @@ -5,15 +5,19 @@ PhraseLocation::PhraseLocation(int sa_low, int sa_high) : PhraseLocation::PhraseLocation(const vector& matchings, int num_subpatterns) : - sa_high(0), sa_low(0), + sa_low(0), sa_high(0), matchings(make_shared >(matchings)), num_subpatterns(num_subpatterns) {} -bool PhraseLocation::IsEmpty() { +bool PhraseLocation::IsEmpty() const { + return GetSize() == 0; +} + +int PhraseLocation::GetSize() const { if (num_subpatterns > 0) { - return matchings->size() == 0; + return matchings->size(); } else { - return sa_low >= sa_high; + return sa_high - sa_low; } } diff --git a/extractor/phrase_location.h b/extractor/phrase_location.h index e04d8628..a0eb36c8 100644 --- a/extractor/phrase_location.h +++ b/extractor/phrase_location.h @@ -11,7 +11,9 @@ struct PhraseLocation { PhraseLocation(const vector& matchings, int num_subpatterns); - bool IsEmpty(); + bool IsEmpty() const; + + int GetSize() const; friend bool operator==(const PhraseLocation& a, const PhraseLocation& b); diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc index 374a0db1..4101fcfa 100644 --- a/extractor/rule_factory.cc +++ b/extractor/rule_factory.cc @@ -6,6 +6,7 @@ #include #include "grammar.h" +#include "fast_intersector.h" #include "intersector.h" #include "matchings_finder.h" #include "matching_comparator.h" @@ -15,10 +16,11 @@ #include "sampler.h" #include "scorer.h" #include "suffix_array.h" +#include "time_util.h" #include "vocabulary.h" using namespace std; -using namespace std::chrono; +using namespace chrono; typedef high_resolution_clock Clock; @@ -48,6 +50,7 @@ HieroCachingRuleFactory::HieroCachingRuleFactory( int max_nonterminals, int max_rule_symbols, int max_samples, + bool use_fast_intersect, bool use_baeza_yates, bool require_tight_phrases) : vocabulary(vocabulary), @@ -56,12 +59,15 @@ HieroCachingRuleFactory::HieroCachingRuleFactory( max_rule_span(max_rule_span), max_nonterminals(max_nonterminals), max_chunks(max_nonterminals + 1), - max_rule_symbols(max_rule_symbols) { + max_rule_symbols(max_rule_symbols), + use_fast_intersect(use_fast_intersect) { matchings_finder = make_shared(source_suffix_array); shared_ptr comparator = make_shared(min_gap_size, max_rule_span); intersector = make_shared(vocabulary, precomputation, source_suffix_array, comparator, use_baeza_yates); + fast_intersector = make_shared(source_suffix_array, + precomputation, vocabulary, max_rule_span, min_gap_size); phrase_builder = make_shared(vocabulary); rule_extractor = make_shared(source_suffix_array->GetData(), target_data_array, alignment, phrase_builder, scorer, vocabulary, @@ -73,6 +79,7 @@ HieroCachingRuleFactory::HieroCachingRuleFactory( HieroCachingRuleFactory::HieroCachingRuleFactory( shared_ptr finder, shared_ptr intersector, + shared_ptr fast_intersector, shared_ptr phrase_builder, shared_ptr rule_extractor, shared_ptr vocabulary, @@ -82,9 +89,11 @@ HieroCachingRuleFactory::HieroCachingRuleFactory( int max_rule_span, int max_nonterminals, int max_chunks, - int max_rule_symbols) : + int max_rule_symbols, + bool use_fast_intersect) : matchings_finder(finder), intersector(intersector), + fast_intersector(fast_intersector), phrase_builder(phrase_builder), rule_extractor(rule_extractor), vocabulary(vocabulary), @@ -94,15 +103,14 @@ HieroCachingRuleFactory::HieroCachingRuleFactory( max_rule_span(max_rule_span), max_nonterminals(max_nonterminals), max_chunks(max_chunks), - max_rule_symbols(max_rule_symbols) {} + max_rule_symbols(max_rule_symbols), + use_fast_intersect(use_fast_intersect) {} HieroCachingRuleFactory::HieroCachingRuleFactory() {} HieroCachingRuleFactory::~HieroCachingRuleFactory() {} Grammar HieroCachingRuleFactory::GetGrammar(const vector& word_ids) { - intersector->binary_merge_time = 0; - intersector->linear_merge_time = 0; intersector->sort_time = 0; Clock::time_point start_time = Clock::now(); double total_extract_time = 0; @@ -155,25 +163,28 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector& word_ids) { } else { PhraseLocation phrase_location; if (next_phrase.Arity() > 0) { - Clock::time_point intersect_start_time = Clock::now(); - phrase_location = intersector->Intersect( - node->phrase, - node->matchings, - next_suffix_link->phrase, - next_suffix_link->matchings, - next_phrase); - Clock::time_point intersect_stop_time = Clock::now(); - total_intersect_time += duration_cast( - intersect_stop_time - intersect_start_time).count(); + Clock::time_point intersect_start = Clock::now(); + if (use_fast_intersect) { + phrase_location = fast_intersector->Intersect( + node->matchings, next_suffix_link->matchings, next_phrase); + } else { + phrase_location = intersector->Intersect( + node->phrase, + node->matchings, + next_suffix_link->phrase, + next_suffix_link->matchings, + next_phrase); + } + Clock::time_point intersect_stop = Clock::now(); + total_intersect_time += GetDuration(intersect_start, intersect_stop); } else { - Clock::time_point lookup_start_time = Clock::now(); + Clock::time_point lookup_start = Clock::now(); phrase_location = matchings_finder->Find( node->matchings, vocabulary->GetTerminalValue(word_id), state.phrase.size()); - Clock::time_point lookup_stop_time = Clock::now(); - total_lookup_time += duration_cast( - lookup_stop_time - lookup_start_time).count(); + Clock::time_point lookup_stop = Clock::now(); + total_lookup_time += GetDuration(lookup_start, lookup_stop); } if (phrase_location.IsEmpty()) { @@ -189,16 +200,15 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector& word_ids) { AddTrailingNonterminal(phrase, next_phrase, next_node, state.starts_with_x); - Clock::time_point extract_start_time = Clock::now(); + Clock::time_point extract_start = Clock::now(); if (!state.starts_with_x) { PhraseLocation sample = sampler->Sample(next_node->matchings); vector new_rules = rule_extractor->ExtractRules(next_phrase, sample); rules.insert(rules.end(), new_rules.begin(), new_rules.end()); } - Clock::time_point extract_stop_time = Clock::now(); - total_extract_time += duration_cast( - extract_stop_time - extract_start_time).count(); + Clock::time_point extract_stop = Clock::now(); + total_extract_time += GetDuration(extract_start, extract_stop); } else { next_node = node->GetChild(word_id); } @@ -211,15 +221,11 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector& word_ids) { } Clock::time_point stop_time = Clock::now(); - milliseconds ms = duration_cast(stop_time - start_time); cerr << "Total time for rule lookup, extraction, and scoring = " - << ms.count() / 1000.0 << endl; - cerr << "Extract time = " << total_extract_time / 1000.0 << endl; - cerr << "Intersect time = " << total_intersect_time / 1000.0 << endl; - cerr << "Sort time = " << intersector->sort_time / 1000.0 << endl; - cerr << "Linear merge time = " << intersector->linear_merge_time / 1000.0 << endl; - cerr << "Binary merge time = " << intersector->binary_merge_time / 1000.0 << endl; - // cerr << "Lookup time = " << total_lookup_time / 1000.0 << endl; + << GetDuration(start_time, stop_time) << " seconds" << endl; + cerr << "Extract time = " << total_extract_time << " seconds" << endl; + cerr << "Intersect time = " << total_intersect_time << " seconds" << endl; + cerr << "Lookup time = " << total_lookup_time << " seconds" << endl; return Grammar(rules, scorer->GetFeatureNames()); } diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h index cf344667..a39386a8 100644 --- a/extractor/rule_factory.h +++ b/extractor/rule_factory.h @@ -13,6 +13,7 @@ class Alignment; class DataArray; class Grammar; class MatchingsFinder; +class FastIntersector; class Intersector; class Precomputation; class Rule; @@ -37,6 +38,7 @@ class HieroCachingRuleFactory { int max_nonterminals, int max_rule_symbols, int max_samples, + bool use_fast_intersect, bool use_beaza_yates, bool require_tight_phrases); @@ -44,6 +46,7 @@ class HieroCachingRuleFactory { HieroCachingRuleFactory( shared_ptr finder, shared_ptr intersector, + shared_ptr fast_intersector, shared_ptr phrase_builder, shared_ptr rule_extractor, shared_ptr vocabulary, @@ -53,7 +56,8 @@ class HieroCachingRuleFactory { int max_rule_span, int max_nonterminals, int max_chunks, - int max_rule_symbols); + int max_rule_symbols, + bool use_fast_intersect); virtual ~HieroCachingRuleFactory(); @@ -80,6 +84,7 @@ class HieroCachingRuleFactory { shared_ptr matchings_finder; shared_ptr intersector; + shared_ptr fast_intersector; MatchingsTrie trie; shared_ptr phrase_builder; shared_ptr rule_extractor; @@ -91,6 +96,7 @@ class HieroCachingRuleFactory { int max_nonterminals; int max_chunks; int max_rule_symbols; + bool use_fast_intersect; }; #endif diff --git a/extractor/rule_factory_test.cc b/extractor/rule_factory_test.cc index d6fbab74..d329382a 100644 --- a/extractor/rule_factory_test.cc +++ b/extractor/rule_factory_test.cc @@ -5,6 +5,7 @@ #include #include "grammar.h" +#include "mocks/mock_fast_intersector.h" #include "mocks/mock_intersector.h" #include "mocks/mock_matchings_finder.h" #include "mocks/mock_rule_extractor.h" @@ -25,6 +26,7 @@ class RuleFactoryTest : public Test { virtual void SetUp() { finder = make_shared(); intersector = make_shared(); + fast_intersector = make_shared(); vocabulary = make_shared(); EXPECT_CALL(*vocabulary, GetTerminalValue(2)).WillRepeatedly(Return("a")); @@ -49,14 +51,12 @@ class RuleFactoryTest : public Test { extractor = make_shared(); EXPECT_CALL(*extractor, ExtractRules(_, _)) .WillRepeatedly(Return(rules)); - - factory = make_shared(finder, intersector, - phrase_builder, extractor, vocabulary, sampler, scorer, 1, 10, 2, 3, 5); } vector feature_names; shared_ptr finder; shared_ptr intersector; + shared_ptr fast_intersector; shared_ptr vocabulary; shared_ptr phrase_builder; shared_ptr scorer; @@ -66,6 +66,10 @@ class RuleFactoryTest : public Test { }; TEST_F(RuleFactoryTest, TestGetGrammarDifferentWords) { + factory = make_shared(finder, intersector, + fast_intersector, phrase_builder, extractor, vocabulary, sampler, + scorer, 1, 10, 2, 3, 5, false); + EXPECT_CALL(*finder, Find(_, _, _)) .Times(6) .WillRepeatedly(Return(PhraseLocation(0, 1))); @@ -73,14 +77,37 @@ TEST_F(RuleFactoryTest, TestGetGrammarDifferentWords) { EXPECT_CALL(*intersector, Intersect(_, _, _, _, _)) .Times(1) .WillRepeatedly(Return(PhraseLocation(0, 1))); + EXPECT_CALL(*fast_intersector, Intersect(_, _, _)).Times(0); vector word_ids = {2, 3, 4}; Grammar grammar = factory->GetGrammar(word_ids); EXPECT_EQ(feature_names, grammar.GetFeatureNames()); EXPECT_EQ(7, grammar.GetRules().size()); + + // Test for fast intersector. + factory = make_shared(finder, intersector, + fast_intersector, phrase_builder, extractor, vocabulary, sampler, + scorer, 1, 10, 2, 3, 5, true); + + EXPECT_CALL(*finder, Find(_, _, _)) + .Times(6) + .WillRepeatedly(Return(PhraseLocation(0, 1))); + + EXPECT_CALL(*fast_intersector, Intersect(_, _, _)) + .Times(1) + .WillRepeatedly(Return(PhraseLocation(0, 1))); + EXPECT_CALL(*intersector, Intersect(_, _, _, _, _)).Times(0); + + grammar = factory->GetGrammar(word_ids); + EXPECT_EQ(feature_names, grammar.GetFeatureNames()); + EXPECT_EQ(7, grammar.GetRules().size()); } TEST_F(RuleFactoryTest, TestGetGrammarRepeatingWords) { + factory = make_shared(finder, intersector, + fast_intersector, phrase_builder, extractor, vocabulary, sampler, + scorer, 1, 10, 2, 3, 5, false); + EXPECT_CALL(*finder, Find(_, _, _)) .Times(12) .WillRepeatedly(Return(PhraseLocation(0, 1))); @@ -89,10 +116,31 @@ TEST_F(RuleFactoryTest, TestGetGrammarRepeatingWords) { .Times(16) .WillRepeatedly(Return(PhraseLocation(0, 1))); + EXPECT_CALL(*fast_intersector, Intersect(_, _, _)).Times(0); + vector word_ids = {2, 3, 4, 2, 3}; Grammar grammar = factory->GetGrammar(word_ids); EXPECT_EQ(feature_names, grammar.GetFeatureNames()); EXPECT_EQ(28, grammar.GetRules().size()); + + // Test for fast intersector. + factory = make_shared(finder, intersector, + fast_intersector, phrase_builder, extractor, vocabulary, sampler, + scorer, 1, 10, 2, 3, 5, true); + + EXPECT_CALL(*finder, Find(_, _, _)) + .Times(12) + .WillRepeatedly(Return(PhraseLocation(0, 1))); + + EXPECT_CALL(*fast_intersector, Intersect(_, _, _)) + .Times(16) + .WillRepeatedly(Return(PhraseLocation(0, 1))); + + EXPECT_CALL(*intersector, Intersect(_, _, _, _, _)).Times(0); + + grammar = factory->GetGrammar(word_ids); + EXPECT_EQ(feature_names, grammar.GetFeatureNames()); + EXPECT_EQ(28, grammar.GetRules().size()); } } // namespace diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc index ed30e6fe..38f10a5f 100644 --- a/extractor/run_extractor.cc +++ b/extractor/run_extractor.cc @@ -1,3 +1,4 @@ +#include #include #include #include @@ -23,6 +24,7 @@ #include "rule.h" #include "scorer.h" #include "suffix_array.h" +#include "time_util.h" #include "translation_table.h" namespace fs = boost::filesystem; @@ -56,6 +58,9 @@ int main(int argc, char** argv) { "Minimum number of occurences for a pharse to be considered frequent") ("max_samples", po::value()->default_value(300), "Maximum number of samples") + ("fast_intersect", po::value()->default_value(false), + "Enable fast intersect") + // TODO(pauldb): Check if this works when set to false. ("tight_phrases", po::value()->default_value(true), "False if phrases may be loose (better, but slower)") ("baeza_yates", po::value()->default_value(true), @@ -80,6 +85,9 @@ int main(int argc, char** argv) { return 1; } + Clock::time_point preprocess_start_time = Clock::now(); + cerr << "Reading source and target data..." << endl; + Clock::time_point start_time = Clock::now(); shared_ptr source_data_array, target_data_array; if (vm.count("bitext")) { source_data_array = make_shared( @@ -90,13 +98,28 @@ int main(int argc, char** argv) { source_data_array = make_shared(vm["source"].as()); target_data_array = make_shared(vm["target"].as()); } + Clock::time_point stop_time = Clock::now(); + cerr << "Reading data took " << GetDuration(start_time, stop_time) + << " seconds" << endl; + + cerr << "Creating source suffix array..." << endl; + start_time = Clock::now(); shared_ptr source_suffix_array = make_shared(source_data_array); + stop_time = Clock::now(); + cerr << "Creating suffix array took " + << GetDuration(start_time, stop_time) << " seconds" << endl; - + cerr << "Reading alignment..." << endl; + start_time = Clock::now(); shared_ptr alignment = make_shared(vm["alignment"].as()); + stop_time = Clock::now(); + cerr << "Reading alignment took " + << GetDuration(start_time, stop_time) << " seconds" << endl; + cerr << "Precomputating collocations..." << endl; + start_time = Clock::now(); shared_ptr precomputation = make_shared( source_suffix_array, vm["frequent"].as(), @@ -106,10 +129,24 @@ int main(int argc, char** argv) { vm["min_gap_size"].as(), vm["max_phrase_len"].as(), vm["min_frequency"].as()); + stop_time = Clock::now(); + cerr << "Precomputing collocations took " + << GetDuration(start_time, stop_time) << " seconds" << endl; + cerr << "Precomputing conditional probabilities..." << endl; + start_time = Clock::now(); shared_ptr table = make_shared( source_data_array, target_data_array, alignment); + stop_time = Clock::now(); + cerr << "Precomputing conditional probabilities took " + << GetDuration(start_time, stop_time) << " seconds" << endl; + + Clock::time_point preprocess_stop_time = Clock::now(); + cerr << "Overall preprocessing step took " + << GetDuration(preprocess_start_time, preprocess_stop_time) + << " seconds" << endl; + Clock::time_point extraction_start_time = Clock::now(); vector > features = { make_shared(), make_shared(), @@ -133,6 +170,7 @@ int main(int argc, char** argv) { vm["max_nonterminals"].as(), vm["max_rule_symbols"].as(), vm["max_samples"].as(), + vm["fast_intersect"].as(), vm["baeza_yates"].as(), vm["tight_phrases"].as()); @@ -161,6 +199,10 @@ int main(int argc, char** argv) { << "\"> " << sentence << " " << suffix << endl; ++grammar_id; } + Clock::time_point extraction_stop_time = Clock::now(); + cerr << "Overall extraction step took " + << GetDuration(extraction_start_time, extraction_stop_time) + << " seconds" << endl; return 0; } diff --git a/extractor/suffix_array.cc b/extractor/suffix_array.cc index 9815996f..23c458a4 100644 --- a/extractor/suffix_array.cc +++ b/extractor/suffix_array.cc @@ -1,14 +1,17 @@ #include "suffix_array.h" +#include #include #include #include #include "data_array.h" #include "phrase_location.h" +#include "time_util.h" namespace fs = boost::filesystem; using namespace std; +using namespace chrono; SuffixArray::SuffixArray(shared_ptr data_array) : data_array(data_array) { @@ -39,6 +42,7 @@ void SuffixArray::BuildSuffixArray() { } PrefixDoublingSort(groups); + cerr << "\tFinalizing sort..." << endl; for (size_t i = 0; i < groups.size(); ++i) { suffix_array[groups[i]] = i; @@ -46,6 +50,7 @@ void SuffixArray::BuildSuffixArray() { } void SuffixArray::InitialBucketSort(vector& groups) { + Clock::time_point start_time = Clock::now(); for (size_t i = 0; i < groups.size(); ++i) { ++word_start[groups[i]]; } @@ -62,6 +67,9 @@ void SuffixArray::InitialBucketSort(vector& groups) { for (size_t i = 0; i < suffix_array.size(); ++i) { groups[i] = word_start[groups[i] + 1] - 1; } + Clock::time_point stop_time = Clock::now(); + cerr << "\tBucket sort took " << GetDuration(start_time, stop_time) + << " seconds" << endl; } void SuffixArray::PrefixDoublingSort(vector& groups) { @@ -127,6 +135,9 @@ void SuffixArray::TernaryQuicksort(int left, int right, int step, } vector SuffixArray::BuildLCPArray() const { + Clock::time_point start_time = Clock::now(); + cerr << "Constructing LCP array..." << endl; + vector lcp(suffix_array.size()); vector rank(suffix_array.size()); const vector& data = data_array->GetData(); @@ -153,6 +164,10 @@ vector SuffixArray::BuildLCPArray() const { } } + Clock::time_point stop_time = Clock::now(); + cerr << "Constructing LCP took " + << GetDuration(start_time, stop_time) << " seconds" << endl; + return lcp; } diff --git a/extractor/time_util.cc b/extractor/time_util.cc new file mode 100644 index 00000000..88395f77 --- /dev/null +++ b/extractor/time_util.cc @@ -0,0 +1,6 @@ +#include "time_util.h" + +double GetDuration(const Clock::time_point& start_time, + const Clock::time_point& stop_time) { + return duration_cast(stop_time - start_time).count() / 1000.0; +} diff --git a/extractor/time_util.h b/extractor/time_util.h new file mode 100644 index 00000000..6f7eda70 --- /dev/null +++ b/extractor/time_util.h @@ -0,0 +1,14 @@ +#ifndef _TIME_UTIL_H_ +#define _TIME_UTIL_H_ + +#include + +using namespace std; +using namespace chrono; + +typedef high_resolution_clock Clock; + +double GetDuration(const Clock::time_point& start_time, + const Clock::time_point& stop_time); + +#endif -- cgit v1.2.3