summaryrefslogtreecommitdiff
path: root/extractor
diff options
context:
space:
mode:
authorPaul Baltescu <pauldb89@gmail.com>2013-03-07 22:49:46 +0000
committerPaul Baltescu <pauldb89@gmail.com>2013-03-07 22:49:46 +0000
commit092b7cf020680e949d6956ec6ef2cf012faccd86 (patch)
tree4bc074572925a10b63928639be244a60f153f7ac /extractor
parentd7271db305bd1aeaf9c3d9ac1043546fec22a402 (diff)
Parallelized grammar extraction.
Diffstat (limited to 'extractor')
-rw-r--r--extractor/Makefile.am3
-rw-r--r--extractor/matchings_trie.cc11
-rw-r--r--extractor/matchings_trie.h7
-rw-r--r--extractor/rule_extractor_helper.cc2
-rw-r--r--extractor/rule_factory.cc6
-rw-r--r--extractor/rule_factory.h1
-rw-r--r--extractor/run_extractor.cc41
-rw-r--r--extractor/vocabulary.cc27
-rw-r--r--extractor/vocabulary.h2
9 files changed, 60 insertions, 40 deletions
diff --git a/extractor/Makefile.am b/extractor/Makefile.am
index 721df18b..d8239b7d 100644
--- a/extractor/Makefile.am
+++ b/extractor/Makefile.am
@@ -145,4 +145,5 @@ libextractor_a_SOURCES = \
translation_table.cc \
vocabulary.cc
-AM_CPPFLAGS = -W -Wall -Wno-sign-compare -std=c++0x $(GTEST_CPPFLAGS) $(GMOCK_CPPFLAGS)
+AM_CPPFLAGS = -W -Wall -Wno-sign-compare -std=c++0x -fopenmp $(GTEST_CPPFLAGS) $(GMOCK_CPPFLAGS)
+AM_LDFLAGS = -fopenmp
diff --git a/extractor/matchings_trie.cc b/extractor/matchings_trie.cc
index c7b98765..7fb7a529 100644
--- a/extractor/matchings_trie.cc
+++ b/extractor/matchings_trie.cc
@@ -2,19 +2,22 @@
namespace extractor {
-void MatchingsTrie::Reset() {
- ResetTree(root);
+MatchingsTrie::MatchingsTrie() {
root = make_shared<TrieNode>();
}
+MatchingsTrie::~MatchingsTrie() {
+ DeleteTree(root);
+}
+
shared_ptr<TrieNode> MatchingsTrie::GetRoot() const {
return root;
}
-void MatchingsTrie::ResetTree(shared_ptr<TrieNode> root) {
+void MatchingsTrie::DeleteTree(shared_ptr<TrieNode> root) {
if (root != NULL) {
for (auto child: root->children) {
- ResetTree(child.second);
+ DeleteTree(child.second);
}
if (root->suffix_link != NULL) {
root->suffix_link.reset();
diff --git a/extractor/matchings_trie.h b/extractor/matchings_trie.h
index a54671d2..f3dcc075 100644
--- a/extractor/matchings_trie.h
+++ b/extractor/matchings_trie.h
@@ -37,11 +37,14 @@ struct TrieNode {
class MatchingsTrie {
public:
- void Reset();
+ MatchingsTrie();
+
+ virtual ~MatchingsTrie();
+
shared_ptr<TrieNode> GetRoot() const;
private:
- void ResetTree(shared_ptr<TrieNode> root);
+ void DeleteTree(shared_ptr<TrieNode> root);
shared_ptr<TrieNode> root;
};
diff --git a/extractor/rule_extractor_helper.cc b/extractor/rule_extractor_helper.cc
index d9ed6a7e..81b522f0 100644
--- a/extractor/rule_extractor_helper.cc
+++ b/extractor/rule_extractor_helper.cc
@@ -117,7 +117,7 @@ bool RuleExtractorHelper::FindFixPoint(
source_high, target_phrase_low, target_phrase_high);
if (target_phrase_low == -1) {
- // TODO(pauldb): Low priority corner case inherited from Adam's code:
+ // Note: Low priority corner case inherited from Adam's code:
// If w is unaligned, but we don't require aligned terminals, returning an
// error here prevents the extraction of the allowed rule
// X -> X_1 w X_2 / X_1 X_2
diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc
index 51f85c30..a5505ced 100644
--- a/extractor/rule_factory.cc
+++ b/extractor/rule_factory.cc
@@ -105,8 +105,8 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
double total_extract_time = 0;
double total_intersect_time = 0;
double total_lookup_time = 0;
- // Clear cache for every new sentence.
- trie.Reset();
+
+ MatchingsTrie trie;
shared_ptr<TrieNode> root = trie.GetRoot();
int first_x = vocabulary->GetNonterminalIndex(1);
@@ -200,8 +200,6 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
}
}
- cerr << "Vocabulary size = " << vocabulary->Size() << endl;
-
Clock::time_point stop_time = Clock::now();
cerr << "Total time for rule lookup, extraction, and scoring = "
<< GetDuration(start_time, stop_time) << " seconds" << endl;
diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h
index 0de04e40..d8dc2ccc 100644
--- a/extractor/rule_factory.h
+++ b/extractor/rule_factory.h
@@ -81,7 +81,6 @@ class HieroCachingRuleFactory {
shared_ptr<MatchingsFinder> matchings_finder;
shared_ptr<FastIntersector> fast_intersector;
- MatchingsTrie trie;
shared_ptr<PhraseBuilder> phrase_builder;
shared_ptr<RuleExtractor> rule_extractor;
shared_ptr<Vocabulary> vocabulary;
diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc
index ae3a875e..eb5600fe 100644
--- a/extractor/run_extractor.cc
+++ b/extractor/run_extractor.cc
@@ -5,6 +5,7 @@
#include <string>
#include <vector>
+#include <omp.h>
#include <boost/filesystem.hpp>
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
@@ -35,6 +36,10 @@ using namespace extractor;
using namespace features;
int main(int argc, char** argv) {
+ int num_threads_default = 1;
+ #pragma omp parallel
+ num_threads_default = omp_get_num_threads();
+
po::options_description desc("Command line options");
desc.add_options()
("help,h", "Show available options")
@@ -43,13 +48,15 @@ int main(int argc, char** argv) {
("bitext,b", po::value<string>(), "Parallel text (source ||| target)")
("alignment,a", po::value<string>()->required(), "Bitext word alignment")
("grammars,g", po::value<string>()->required(), "Grammars output path")
+ ("threads,t", po::value<int>()->default_value(num_threads_default),
+ "Number of parallel extractors")
("frequent", po::value<int>()->default_value(100),
"Number of precomputed frequent patterns")
("super_frequent", po::value<int>()->default_value(10),
"Number of precomputed super frequent patterns")
("max_rule_span", po::value<int>()->default_value(15),
"Maximum rule span")
- ("max_rule_symbols,l", po::value<int>()->default_value(5),
+ ("max_rule_symbols", po::value<int>()->default_value(5),
"Maximum number of symbols (terminals + nontermals) in a rule")
("min_gap_size", po::value<int>()->default_value(1), "Minimum gap size")
("max_phrase_len", po::value<int>()->default_value(4),
@@ -155,7 +162,6 @@ int main(int argc, char** argv) {
};
shared_ptr<Scorer> scorer = make_shared<Scorer>(features);
- // TODO(pauldb): Add parallelization.
GrammarExtractor extractor(
source_suffix_array,
target_data_array,
@@ -172,30 +178,39 @@ int main(int argc, char** argv) {
// Release extra memory used by the initial precomputation.
precomputation.reset();
- int grammar_id = 0;
fs::path grammar_path = vm["grammars"].as<string>();
if (!fs::is_directory(grammar_path)) {
fs::create_directory(grammar_path);
}
- string sentence, delimiter = "|||";
+ string sentence;
+ vector<string> sentences;
while (getline(cin, sentence)) {
+ sentences.push_back(sentence);
+ }
+
+ #pragma omp parallel for schedule(dynamic) \
+ num_threads(vm["threads"].as<int>()) ordered
+ for (size_t i = 0; i < sentences.size(); ++i) {
+ string delimiter = "|||";
string suffix = "";
- int position = sentence.find(delimiter);
- if (position != sentence.npos) {
- suffix = sentence.substr(position);
- sentence = sentence.substr(0, position);
+ int position = sentences[i].find(delimiter);
+ if (position != sentences[i].npos) {
+ suffix = sentences[i].substr(position);
+ sentences[i] = sentences[i].substr(0, position);
}
- Grammar grammar = extractor.GetGrammar(sentence);
- string file_name = "grammar." + to_string(grammar_id);
+ Grammar grammar = extractor.GetGrammar(sentences[i]);
+ string file_name = "grammar." + to_string(i);
fs::path grammar_file = grammar_path / file_name;
ofstream output(grammar_file.c_str());
output << grammar;
- cout << "<seg grammar=\"" << grammar_file << "\" id=\"" << grammar_id
- << "\"> " << sentence << " </seg> " << suffix << endl;
- ++grammar_id;
+ #pragma omp critical (stdout_write)
+ {
+ cout << "<seg grammar=\"" << grammar_file << "\" id=\"" << i << "\"> "
+ << sentences[i] << " </seg> " << suffix << endl;
+ }
}
Clock::time_point extraction_stop_time = Clock::now();
cerr << "Overall extraction step took "
diff --git a/extractor/vocabulary.cc b/extractor/vocabulary.cc
index 57f564d9..15795d1e 100644
--- a/extractor/vocabulary.cc
+++ b/extractor/vocabulary.cc
@@ -5,14 +5,18 @@ namespace extractor {
Vocabulary::~Vocabulary() {}
int Vocabulary::GetTerminalIndex(const string& word) {
- if (!dictionary.count(word)) {
- int word_id = words.size();
- dictionary[word] = word_id;
- words.push_back(word);
- return word_id;
+ int word_id = -1;
+ #pragma omp critical (vocabulary)
+ {
+ if (!dictionary.count(word)) {
+ word_id = words.size();
+ dictionary[word] = word_id;
+ words.push_back(word);
+ } else {
+ word_id = dictionary[word];
+ }
}
-
- return dictionary[word];
+ return word_id;
}
int Vocabulary::GetNonterminalIndex(int position) {
@@ -24,11 +28,10 @@ bool Vocabulary::IsTerminal(int symbol) {
}
string Vocabulary::GetTerminalValue(int symbol) {
- return words[symbol];
-}
-
-int Vocabulary::Size() {
- return words.size();
+ string word;
+ #pragma omp critical (vocabulary)
+ word = words[symbol];
+ return word;
}
} // namespace extractor
diff --git a/extractor/vocabulary.h b/extractor/vocabulary.h
index dcc2a8fa..03c7dc66 100644
--- a/extractor/vocabulary.h
+++ b/extractor/vocabulary.h
@@ -21,8 +21,6 @@ class Vocabulary {
virtual string GetTerminalValue(int symbol);
- int Size();
-
private:
unordered_map<string, int> dictionary;
vector<string> words;