diff options
| author | Paul Baltescu <pauldb89@gmail.com> | 2013-03-07 22:49:46 +0000 | 
|---|---|---|
| committer | Paul Baltescu <pauldb89@gmail.com> | 2013-03-07 22:49:46 +0000 | 
| commit | 092b7cf020680e949d6956ec6ef2cf012faccd86 (patch) | |
| tree | 4bc074572925a10b63928639be244a60f153f7ac | |
| parent | d7271db305bd1aeaf9c3d9ac1043546fec22a402 (diff) | |
Parallelized grammar extraction.
| -rw-r--r-- | configure.ac | 1 | ||||
| -rw-r--r-- | extractor/Makefile.am | 3 | ||||
| -rw-r--r-- | extractor/matchings_trie.cc | 11 | ||||
| -rw-r--r-- | extractor/matchings_trie.h | 7 | ||||
| -rw-r--r-- | extractor/rule_extractor_helper.cc | 2 | ||||
| -rw-r--r-- | extractor/rule_factory.cc | 6 | ||||
| -rw-r--r-- | extractor/rule_factory.h | 1 | ||||
| -rw-r--r-- | extractor/run_extractor.cc | 41 | ||||
| -rw-r--r-- | extractor/vocabulary.cc | 27 | ||||
| -rw-r--r-- | extractor/vocabulary.h | 2 | 
10 files changed, 61 insertions, 40 deletions
diff --git a/configure.ac b/configure.ac index 66ab7778..59224e0c 100644 --- a/configure.ac +++ b/configure.ac @@ -11,6 +11,7 @@ esac  AC_PROG_CC  AC_PROG_CXX  AC_LANG_CPLUSPLUS +AC_OPENMP  BOOST_REQUIRE([1.44])  BOOST_FILESYSTEM  BOOST_PROGRAM_OPTIONS diff --git a/extractor/Makefile.am b/extractor/Makefile.am index 721df18b..d8239b7d 100644 --- a/extractor/Makefile.am +++ b/extractor/Makefile.am @@ -145,4 +145,5 @@ libextractor_a_SOURCES = \    translation_table.cc \    vocabulary.cc -AM_CPPFLAGS = -W -Wall -Wno-sign-compare -std=c++0x $(GTEST_CPPFLAGS) $(GMOCK_CPPFLAGS) +AM_CPPFLAGS = -W -Wall -Wno-sign-compare -std=c++0x -fopenmp $(GTEST_CPPFLAGS) $(GMOCK_CPPFLAGS) +AM_LDFLAGS = -fopenmp diff --git a/extractor/matchings_trie.cc b/extractor/matchings_trie.cc index c7b98765..7fb7a529 100644 --- a/extractor/matchings_trie.cc +++ b/extractor/matchings_trie.cc @@ -2,19 +2,22 @@  namespace extractor { -void MatchingsTrie::Reset() { -  ResetTree(root); +MatchingsTrie::MatchingsTrie() {    root = make_shared<TrieNode>();  } +MatchingsTrie::~MatchingsTrie() { +  DeleteTree(root); +} +  shared_ptr<TrieNode> MatchingsTrie::GetRoot() const {    return root;  } -void MatchingsTrie::ResetTree(shared_ptr<TrieNode> root) { +void MatchingsTrie::DeleteTree(shared_ptr<TrieNode> root) {    if (root != NULL) {      for (auto child: root->children) { -      ResetTree(child.second); +      DeleteTree(child.second);      }      if (root->suffix_link != NULL) {        root->suffix_link.reset(); diff --git a/extractor/matchings_trie.h b/extractor/matchings_trie.h index a54671d2..f3dcc075 100644 --- a/extractor/matchings_trie.h +++ b/extractor/matchings_trie.h @@ -37,11 +37,14 @@ struct TrieNode {  class MatchingsTrie {   public: -  void Reset(); +  MatchingsTrie(); + +  virtual ~MatchingsTrie(); +    shared_ptr<TrieNode> GetRoot() const;   private: -  void ResetTree(shared_ptr<TrieNode> root); +  void DeleteTree(shared_ptr<TrieNode> root);    shared_ptr<TrieNode> root;  }; diff --git a/extractor/rule_extractor_helper.cc b/extractor/rule_extractor_helper.cc index d9ed6a7e..81b522f0 100644 --- a/extractor/rule_extractor_helper.cc +++ b/extractor/rule_extractor_helper.cc @@ -117,7 +117,7 @@ bool RuleExtractorHelper::FindFixPoint(                   source_high, target_phrase_low, target_phrase_high);    if (target_phrase_low == -1) { -    // TODO(pauldb): Low priority corner case inherited from Adam's code: +    // Note: Low priority corner case inherited from Adam's code:      // If w is unaligned, but we don't require aligned terminals, returning an      // error here prevents the extraction of the allowed rule      // X -> X_1 w X_2 / X_1 X_2 diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc index 51f85c30..a5505ced 100644 --- a/extractor/rule_factory.cc +++ b/extractor/rule_factory.cc @@ -105,8 +105,8 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {    double total_extract_time = 0;    double total_intersect_time = 0;    double total_lookup_time = 0; -  // Clear cache for every new sentence. -  trie.Reset(); + +  MatchingsTrie trie;    shared_ptr<TrieNode> root = trie.GetRoot();    int first_x = vocabulary->GetNonterminalIndex(1); @@ -200,8 +200,6 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {      }    } -  cerr << "Vocabulary size = " << vocabulary->Size() << endl; -    Clock::time_point stop_time = Clock::now();    cerr << "Total time for rule lookup, extraction, and scoring = "         << GetDuration(start_time, stop_time) << " seconds" << endl; diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h index 0de04e40..d8dc2ccc 100644 --- a/extractor/rule_factory.h +++ b/extractor/rule_factory.h @@ -81,7 +81,6 @@ class HieroCachingRuleFactory {    shared_ptr<MatchingsFinder> matchings_finder;    shared_ptr<FastIntersector> fast_intersector; -  MatchingsTrie trie;    shared_ptr<PhraseBuilder> phrase_builder;    shared_ptr<RuleExtractor> rule_extractor;    shared_ptr<Vocabulary> vocabulary; diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc index ae3a875e..eb5600fe 100644 --- a/extractor/run_extractor.cc +++ b/extractor/run_extractor.cc @@ -5,6 +5,7 @@  #include <string>  #include <vector> +#include <omp.h>  #include <boost/filesystem.hpp>  #include <boost/program_options.hpp>  #include <boost/program_options/variables_map.hpp> @@ -35,6 +36,10 @@ using namespace extractor;  using namespace features;  int main(int argc, char** argv) { +  int num_threads_default = 1; +  #pragma omp parallel +  num_threads_default = omp_get_num_threads(); +    po::options_description desc("Command line options");    desc.add_options()      ("help,h", "Show available options") @@ -43,13 +48,15 @@ int main(int argc, char** argv) {      ("bitext,b", po::value<string>(), "Parallel text (source ||| target)")      ("alignment,a", po::value<string>()->required(), "Bitext word alignment")      ("grammars,g", po::value<string>()->required(), "Grammars output path") +    ("threads,t", po::value<int>()->default_value(num_threads_default), +        "Number of parallel extractors")      ("frequent", po::value<int>()->default_value(100),          "Number of precomputed frequent patterns")      ("super_frequent", po::value<int>()->default_value(10),          "Number of precomputed super frequent patterns")      ("max_rule_span", po::value<int>()->default_value(15),          "Maximum rule span") -    ("max_rule_symbols,l", po::value<int>()->default_value(5), +    ("max_rule_symbols", po::value<int>()->default_value(5),          "Maximum number of symbols (terminals + nontermals) in a rule")      ("min_gap_size", po::value<int>()->default_value(1), "Minimum gap size")      ("max_phrase_len", po::value<int>()->default_value(4), @@ -155,7 +162,6 @@ int main(int argc, char** argv) {    };    shared_ptr<Scorer> scorer = make_shared<Scorer>(features); -  // TODO(pauldb): Add parallelization.    GrammarExtractor extractor(        source_suffix_array,        target_data_array, @@ -172,30 +178,39 @@ int main(int argc, char** argv) {    // Release extra memory used by the initial precomputation.    precomputation.reset(); -  int grammar_id = 0;    fs::path grammar_path = vm["grammars"].as<string>();    if (!fs::is_directory(grammar_path)) {      fs::create_directory(grammar_path);    } -  string sentence, delimiter = "|||"; +  string sentence; +  vector<string> sentences;    while (getline(cin, sentence)) { +    sentences.push_back(sentence); +  } + +  #pragma omp parallel for schedule(dynamic) \ +      num_threads(vm["threads"].as<int>()) ordered +  for (size_t i = 0; i < sentences.size(); ++i) { +    string delimiter = "|||";      string suffix = ""; -    int position = sentence.find(delimiter); -    if (position != sentence.npos) { -      suffix = sentence.substr(position); -      sentence = sentence.substr(0, position); +    int position = sentences[i].find(delimiter); +    if (position != sentences[i].npos) { +      suffix = sentences[i].substr(position); +      sentences[i] = sentences[i].substr(0, position);      } -    Grammar grammar = extractor.GetGrammar(sentence); -    string file_name = "grammar." + to_string(grammar_id); +    Grammar grammar = extractor.GetGrammar(sentences[i]); +    string file_name = "grammar." + to_string(i);      fs::path grammar_file = grammar_path / file_name;      ofstream output(grammar_file.c_str());      output << grammar; -    cout << "<seg grammar=\"" << grammar_file << "\" id=\"" << grammar_id -         << "\"> " << sentence << " </seg> " << suffix << endl; -    ++grammar_id; +    #pragma omp critical (stdout_write) +    { +      cout << "<seg grammar=\"" << grammar_file << "\" id=\"" << i << "\"> " +           << sentences[i] << " </seg> " << suffix << endl; +    }    }    Clock::time_point extraction_stop_time = Clock::now();    cerr << "Overall extraction step took " diff --git a/extractor/vocabulary.cc b/extractor/vocabulary.cc index 57f564d9..15795d1e 100644 --- a/extractor/vocabulary.cc +++ b/extractor/vocabulary.cc @@ -5,14 +5,18 @@ namespace extractor {  Vocabulary::~Vocabulary() {}  int Vocabulary::GetTerminalIndex(const string& word) { -  if (!dictionary.count(word)) { -    int word_id = words.size(); -    dictionary[word] = word_id; -    words.push_back(word); -    return word_id; +  int word_id = -1; +  #pragma omp critical (vocabulary) +  { +    if (!dictionary.count(word)) { +      word_id = words.size(); +      dictionary[word] = word_id; +      words.push_back(word); +    } else { +      word_id = dictionary[word]; +    }    } - -  return dictionary[word]; +  return word_id;  }  int Vocabulary::GetNonterminalIndex(int position) { @@ -24,11 +28,10 @@ bool Vocabulary::IsTerminal(int symbol) {  }  string Vocabulary::GetTerminalValue(int symbol) { -  return words[symbol]; -} - -int Vocabulary::Size() { -  return words.size(); +  string word; +  #pragma omp critical (vocabulary) +  word = words[symbol]; +  return word;  }  } // namespace extractor diff --git a/extractor/vocabulary.h b/extractor/vocabulary.h index dcc2a8fa..03c7dc66 100644 --- a/extractor/vocabulary.h +++ b/extractor/vocabulary.h @@ -21,8 +21,6 @@ class Vocabulary {    virtual string GetTerminalValue(int symbol); -  int Size(); -   private:    unordered_map<string, int> dictionary;    vector<string> words;  | 
