diff options
-rw-r--r-- | training/cllh_filter_grammar.cc | 94 |
1 files changed, 67 insertions, 27 deletions
diff --git a/training/cllh_filter_grammar.cc b/training/cllh_filter_grammar.cc index b5a4c35d..90fe9fba 100644 --- a/training/cllh_filter_grammar.cc +++ b/training/cllh_filter_grammar.cc @@ -1,6 +1,8 @@ #include <iostream> #include <vector> #include <cassert> +#include <unistd.h> // fork +#include <sys/wait.h> // waitpid #include <boost/program_options.hpp> #include <boost/program_options/variables_map.hpp> @@ -19,7 +21,8 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() ("training_data,t",po::value<string>(),"Training data corpus") - ("decoder_config,c",po::value<string>(),"Decoder configuration file"); + ("decoder_config,c",po::value<string>(),"Decoder configuration file") + ("ncpus,n",po::value<unsigned>()->default_value(1),"Number of CPUs to use"); po::options_description clo("Command line options"); clo.add_options() ("config", po::value<string>(), "Configuration file") @@ -41,14 +44,19 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { } } -void ReadTrainingCorpus(const string& fname, vector<string>* c) { +void ReadTrainingCorpus(const string& fname, int rank, int size, vector<string>* c, vector<int>* ids) { ReadFile rf(fname); istream& in = *rf.stream(); string line; + int lc = 0; while(in) { getline(in, line); if (!in) break; - c->push_back(line); + if (lc % size == rank) { + c->push_back(line); + ids->push_back(lc); + } + ++lc; } } @@ -68,11 +76,10 @@ struct TrainingObserver : public DecoderObserver { virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) { assert(state == 1); for (int i = 0; i < hg->edges_.size(); ++i) { - const TRulePtr& rule = hg->edges_[i].rule_; + const TRule* rule = hg->edges_[i].rule_.get(); if (rule->lhs_ == s_lhs || rule->lhs_ == goal_lhs) // fragile hack to filter out glue rules continue; - if (rule->prev_i == -1) - used.insert(rule); + used.insert(rule); } state = 2; } @@ -90,7 +97,7 @@ struct TrainingObserver : public DecoderObserver { } } - set<TRulePtr> used; + set<const TRule*> used; const int s_lhs; const int goal_lhs; @@ -99,39 +106,72 @@ struct TrainingObserver : public DecoderObserver { int state; }; +void work(const string& fname, int rank, int size, Decoder* decoder) { + cerr << "Worker " << rank << '/' << size << " starting.\n"; + vector<string> corpus; + vector<int> ids; + ReadTrainingCorpus(fname, rank, size, &corpus, &ids); + assert(corpus.size() > 0); + cerr << " " << rank << '/' << size << ": has " << corpus.size() << " sentences to process\n"; + ostringstream oc; oc << "corpus." << rank << "_of_" << size; + WriteFile foc(oc.str()); + ostringstream og; og << "grammar." << rank << "_of_" << size << ".gz"; + WriteFile fog(og.str()); + + set<const TRule*> all_used; + TrainingObserver observer; + for (int i = 0; i < corpus.size(); ++i) { + int ex_num = ids[i]; + decoder->SetId(ex_num); + decoder->Decode(corpus[ex_num], &observer); + if (observer.failed) { + (*foc.stream()) << "*** id=" << ex_num << " is unreachable\n"; + } else { + (*foc.stream()) << corpus[ex_num] << endl; + for (set<const TRule*>::iterator it = observer.used.begin(); it != observer.used.end(); ++it) { + if (all_used.insert(*it).second) + (*fog.stream()) << **it << endl; + } + } + } +} + int main(int argc, char** argv) { - SetSilent(true); // turn off verbose decoder output register_feature_functions(); po::variables_map conf; InitCommandLine(argc, argv, &conf); - + const string fname = conf["training_data"].as<string>(); + const unsigned ncpus = conf["ncpus"].as<unsigned>(); + assert(ncpus > 0); ReadFile ini_rf(conf["decoder_config"].as<string>()); Decoder decoder(ini_rf.stream()); if (decoder.GetConf()["input"].as<string>() != "-") { cerr << "cdec.ini must not set an input file\n"; abort(); } - - vector<string> corpus; - ReadTrainingCorpus(conf["training_data"].as<string>(), &corpus); - assert(corpus.size() > 0); - - TrainingObserver observer; - for (int i = 0; i < corpus.size(); ++i) { - int ex_num = i; - decoder.SetId(ex_num); - decoder.Decode(corpus[ex_num], &observer); - if (observer.failed) { - cerr << "*** id=" << ex_num << " is unreachable\n"; - observer.used.clear(); + SetSilent(true); // turn off verbose decoder output + cerr << "Forking " << ncpus << " time(s)\n"; + vector<pid_t> children; + for (int i = 0; i < ncpus; ++i) { + pid_t pid = fork(); + if (pid < 0) { + cerr << "Fork failed!\n"; + exit(1); + } + if (pid > 0) { + children.push_back(pid); } else { - cerr << corpus[ex_num] << endl; - for (set<TRulePtr>::iterator it = observer.used.begin(); it != observer.used.end(); ++it) { - cout << **it << endl; - (*it)->prev_i = 0; - } + work(fname, i, ncpus, &decoder); + cerr << " " << i << "/" << ncpus << " finished.\n"; + _exit(0); } } + for (int i = 0; i < children.size(); ++i) { + int status; + int w = waitpid(children[i], &status, 0); + if (w < 0) { cerr << "Error while waiting for children!"; return 1; } + cerr << "Child " << i << ": status=" << status << " sig?=" << WIFSIGNALED(status) << " sig=" << WTERMSIG(status) << endl; + } return 0; } |