From 092b7cf020680e949d6956ec6ef2cf012faccd86 Mon Sep 17 00:00:00 2001 From: Paul Baltescu Date: Thu, 7 Mar 2013 22:49:46 +0000 Subject: Parallelized grammar extraction. --- configure.ac | 1 + extractor/Makefile.am | 3 ++- extractor/matchings_trie.cc | 11 ++++++---- extractor/matchings_trie.h | 7 +++++-- extractor/rule_extractor_helper.cc | 2 +- extractor/rule_factory.cc | 6 ++---- extractor/rule_factory.h | 1 - extractor/run_extractor.cc | 41 ++++++++++++++++++++++++++------------ extractor/vocabulary.cc | 27 ++++++++++++++----------- extractor/vocabulary.h | 2 -- 10 files changed, 61 insertions(+), 40 deletions(-) diff --git a/configure.ac b/configure.ac index 66ab7778..59224e0c 100644 --- a/configure.ac +++ b/configure.ac @@ -11,6 +11,7 @@ esac AC_PROG_CC AC_PROG_CXX AC_LANG_CPLUSPLUS +AC_OPENMP BOOST_REQUIRE([1.44]) BOOST_FILESYSTEM BOOST_PROGRAM_OPTIONS diff --git a/extractor/Makefile.am b/extractor/Makefile.am index 721df18b..d8239b7d 100644 --- a/extractor/Makefile.am +++ b/extractor/Makefile.am @@ -145,4 +145,5 @@ libextractor_a_SOURCES = \ translation_table.cc \ vocabulary.cc -AM_CPPFLAGS = -W -Wall -Wno-sign-compare -std=c++0x $(GTEST_CPPFLAGS) $(GMOCK_CPPFLAGS) +AM_CPPFLAGS = -W -Wall -Wno-sign-compare -std=c++0x -fopenmp $(GTEST_CPPFLAGS) $(GMOCK_CPPFLAGS) +AM_LDFLAGS = -fopenmp diff --git a/extractor/matchings_trie.cc b/extractor/matchings_trie.cc index c7b98765..7fb7a529 100644 --- a/extractor/matchings_trie.cc +++ b/extractor/matchings_trie.cc @@ -2,19 +2,22 @@ namespace extractor { -void MatchingsTrie::Reset() { - ResetTree(root); +MatchingsTrie::MatchingsTrie() { root = make_shared(); } +MatchingsTrie::~MatchingsTrie() { + DeleteTree(root); +} + shared_ptr MatchingsTrie::GetRoot() const { return root; } -void MatchingsTrie::ResetTree(shared_ptr root) { +void MatchingsTrie::DeleteTree(shared_ptr root) { if (root != NULL) { for (auto child: root->children) { - ResetTree(child.second); + DeleteTree(child.second); } if (root->suffix_link != NULL) { root->suffix_link.reset(); diff --git a/extractor/matchings_trie.h b/extractor/matchings_trie.h index a54671d2..f3dcc075 100644 --- a/extractor/matchings_trie.h +++ b/extractor/matchings_trie.h @@ -37,11 +37,14 @@ struct TrieNode { class MatchingsTrie { public: - void Reset(); + MatchingsTrie(); + + virtual ~MatchingsTrie(); + shared_ptr GetRoot() const; private: - void ResetTree(shared_ptr root); + void DeleteTree(shared_ptr root); shared_ptr root; }; diff --git a/extractor/rule_extractor_helper.cc b/extractor/rule_extractor_helper.cc index d9ed6a7e..81b522f0 100644 --- a/extractor/rule_extractor_helper.cc +++ b/extractor/rule_extractor_helper.cc @@ -117,7 +117,7 @@ bool RuleExtractorHelper::FindFixPoint( source_high, target_phrase_low, target_phrase_high); if (target_phrase_low == -1) { - // TODO(pauldb): Low priority corner case inherited from Adam's code: + // Note: Low priority corner case inherited from Adam's code: // If w is unaligned, but we don't require aligned terminals, returning an // error here prevents the extraction of the allowed rule // X -> X_1 w X_2 / X_1 X_2 diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc index 51f85c30..a5505ced 100644 --- a/extractor/rule_factory.cc +++ b/extractor/rule_factory.cc @@ -105,8 +105,8 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector& word_ids) { double total_extract_time = 0; double total_intersect_time = 0; double total_lookup_time = 0; - // Clear cache for every new sentence. - trie.Reset(); + + MatchingsTrie trie; shared_ptr root = trie.GetRoot(); int first_x = vocabulary->GetNonterminalIndex(1); @@ -200,8 +200,6 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector& word_ids) { } } - cerr << "Vocabulary size = " << vocabulary->Size() << endl; - Clock::time_point stop_time = Clock::now(); cerr << "Total time for rule lookup, extraction, and scoring = " << GetDuration(start_time, stop_time) << " seconds" << endl; diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h index 0de04e40..d8dc2ccc 100644 --- a/extractor/rule_factory.h +++ b/extractor/rule_factory.h @@ -81,7 +81,6 @@ class HieroCachingRuleFactory { shared_ptr matchings_finder; shared_ptr fast_intersector; - MatchingsTrie trie; shared_ptr phrase_builder; shared_ptr rule_extractor; shared_ptr vocabulary; diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc index ae3a875e..eb5600fe 100644 --- a/extractor/run_extractor.cc +++ b/extractor/run_extractor.cc @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -35,6 +36,10 @@ using namespace extractor; using namespace features; int main(int argc, char** argv) { + int num_threads_default = 1; + #pragma omp parallel + num_threads_default = omp_get_num_threads(); + po::options_description desc("Command line options"); desc.add_options() ("help,h", "Show available options") @@ -43,13 +48,15 @@ int main(int argc, char** argv) { ("bitext,b", po::value(), "Parallel text (source ||| target)") ("alignment,a", po::value()->required(), "Bitext word alignment") ("grammars,g", po::value()->required(), "Grammars output path") + ("threads,t", po::value()->default_value(num_threads_default), + "Number of parallel extractors") ("frequent", po::value()->default_value(100), "Number of precomputed frequent patterns") ("super_frequent", po::value()->default_value(10), "Number of precomputed super frequent patterns") ("max_rule_span", po::value()->default_value(15), "Maximum rule span") - ("max_rule_symbols,l", po::value()->default_value(5), + ("max_rule_symbols", po::value()->default_value(5), "Maximum number of symbols (terminals + nontermals) in a rule") ("min_gap_size", po::value()->default_value(1), "Minimum gap size") ("max_phrase_len", po::value()->default_value(4), @@ -155,7 +162,6 @@ int main(int argc, char** argv) { }; shared_ptr scorer = make_shared(features); - // TODO(pauldb): Add parallelization. GrammarExtractor extractor( source_suffix_array, target_data_array, @@ -172,30 +178,39 @@ int main(int argc, char** argv) { // Release extra memory used by the initial precomputation. precomputation.reset(); - int grammar_id = 0; fs::path grammar_path = vm["grammars"].as(); if (!fs::is_directory(grammar_path)) { fs::create_directory(grammar_path); } - string sentence, delimiter = "|||"; + string sentence; + vector sentences; while (getline(cin, sentence)) { + sentences.push_back(sentence); + } + + #pragma omp parallel for schedule(dynamic) \ + num_threads(vm["threads"].as()) ordered + for (size_t i = 0; i < sentences.size(); ++i) { + string delimiter = "|||"; string suffix = ""; - int position = sentence.find(delimiter); - if (position != sentence.npos) { - suffix = sentence.substr(position); - sentence = sentence.substr(0, position); + int position = sentences[i].find(delimiter); + if (position != sentences[i].npos) { + suffix = sentences[i].substr(position); + sentences[i] = sentences[i].substr(0, position); } - Grammar grammar = extractor.GetGrammar(sentence); - string file_name = "grammar." + to_string(grammar_id); + Grammar grammar = extractor.GetGrammar(sentences[i]); + string file_name = "grammar." + to_string(i); fs::path grammar_file = grammar_path / file_name; ofstream output(grammar_file.c_str()); output << grammar; - cout << " " << sentence << " " << suffix << endl; - ++grammar_id; + #pragma omp critical (stdout_write) + { + cout << " " + << sentences[i] << " " << suffix << endl; + } } Clock::time_point extraction_stop_time = Clock::now(); cerr << "Overall extraction step took " diff --git a/extractor/vocabulary.cc b/extractor/vocabulary.cc index 57f564d9..15795d1e 100644 --- a/extractor/vocabulary.cc +++ b/extractor/vocabulary.cc @@ -5,14 +5,18 @@ namespace extractor { Vocabulary::~Vocabulary() {} int Vocabulary::GetTerminalIndex(const string& word) { - if (!dictionary.count(word)) { - int word_id = words.size(); - dictionary[word] = word_id; - words.push_back(word); - return word_id; + int word_id = -1; + #pragma omp critical (vocabulary) + { + if (!dictionary.count(word)) { + word_id = words.size(); + dictionary[word] = word_id; + words.push_back(word); + } else { + word_id = dictionary[word]; + } } - - return dictionary[word]; + return word_id; } int Vocabulary::GetNonterminalIndex(int position) { @@ -24,11 +28,10 @@ bool Vocabulary::IsTerminal(int symbol) { } string Vocabulary::GetTerminalValue(int symbol) { - return words[symbol]; -} - -int Vocabulary::Size() { - return words.size(); + string word; + #pragma omp critical (vocabulary) + word = words[symbol]; + return word; } } // namespace extractor diff --git a/extractor/vocabulary.h b/extractor/vocabulary.h index dcc2a8fa..03c7dc66 100644 --- a/extractor/vocabulary.h +++ b/extractor/vocabulary.h @@ -21,8 +21,6 @@ class Vocabulary { virtual string GetTerminalValue(int symbol); - int Size(); - private: unordered_map dictionary; vector words; -- cgit v1.2.3