diff options
| -rw-r--r-- | training/cllh_filter_grammar.cc | 94 | 
1 files changed, 67 insertions, 27 deletions
| diff --git a/training/cllh_filter_grammar.cc b/training/cllh_filter_grammar.cc index b5a4c35d..90fe9fba 100644 --- a/training/cllh_filter_grammar.cc +++ b/training/cllh_filter_grammar.cc @@ -1,6 +1,8 @@  #include <iostream>  #include <vector>  #include <cassert> +#include <unistd.h>   // fork +#include <sys/wait.h> // waitpid  #include <boost/program_options.hpp>  #include <boost/program_options/variables_map.hpp> @@ -19,7 +21,8 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {    po::options_description opts("Configuration options");    opts.add_options()          ("training_data,t",po::value<string>(),"Training data corpus") -        ("decoder_config,c",po::value<string>(),"Decoder configuration file"); +        ("decoder_config,c",po::value<string>(),"Decoder configuration file") +        ("ncpus,n",po::value<unsigned>()->default_value(1),"Number of CPUs to use");    po::options_description clo("Command line options");    clo.add_options()          ("config", po::value<string>(), "Configuration file") @@ -41,14 +44,19 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {    }  } -void ReadTrainingCorpus(const string& fname, vector<string>* c) { +void ReadTrainingCorpus(const string& fname, int rank, int size, vector<string>* c, vector<int>* ids) {    ReadFile rf(fname);    istream& in = *rf.stream();    string line; +  int lc = 0;    while(in) {      getline(in, line);      if (!in) break; -    c->push_back(line); +    if (lc % size == rank) { +      c->push_back(line); +      ids->push_back(lc); +    } +    ++lc;    }  } @@ -68,11 +76,10 @@ struct TrainingObserver : public DecoderObserver {    virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) {      assert(state == 1);      for (int i = 0; i < hg->edges_.size(); ++i) { -      const TRulePtr& rule = hg->edges_[i].rule_; +      const TRule* rule = hg->edges_[i].rule_.get();        if (rule->lhs_ == s_lhs || rule->lhs_ == goal_lhs)  // fragile hack to filter out glue rules          continue; -      if (rule->prev_i == -1) -        used.insert(rule); +      used.insert(rule);      }      state = 2;    } @@ -90,7 +97,7 @@ struct TrainingObserver : public DecoderObserver {      }    } -  set<TRulePtr> used; +  set<const TRule*> used;    const int s_lhs;    const int goal_lhs; @@ -99,39 +106,72 @@ struct TrainingObserver : public DecoderObserver {    int state;  }; +void work(const string& fname, int rank, int size, Decoder* decoder) { +  cerr << "Worker " << rank << '/' << size << " starting.\n"; +  vector<string> corpus; +  vector<int> ids; +  ReadTrainingCorpus(fname, rank, size, &corpus, &ids); +  assert(corpus.size() > 0); +  cerr << "  " << rank << '/' << size << ": has " << corpus.size() << " sentences to process\n"; +  ostringstream oc; oc << "corpus." << rank << "_of_" << size; +  WriteFile foc(oc.str()); +  ostringstream og; og << "grammar." << rank << "_of_" << size << ".gz"; +  WriteFile fog(og.str()); + +  set<const TRule*> all_used; +  TrainingObserver observer; +  for (int i = 0; i < corpus.size(); ++i) { +    int ex_num = ids[i]; +    decoder->SetId(ex_num); +    decoder->Decode(corpus[ex_num], &observer); +    if (observer.failed) { +      (*foc.stream()) << "*** id=" << ex_num << " is unreachable\n"; +    } else { +      (*foc.stream()) << corpus[ex_num] << endl; +      for (set<const TRule*>::iterator it = observer.used.begin(); it != observer.used.end(); ++it) { +        if (all_used.insert(*it).second) +          (*fog.stream()) << **it << endl; +      } +    } +  } +} +  int main(int argc, char** argv) { -  SetSilent(true);  // turn off verbose decoder output    register_feature_functions();    po::variables_map conf;    InitCommandLine(argc, argv, &conf); - +  const string fname = conf["training_data"].as<string>(); +  const unsigned ncpus = conf["ncpus"].as<unsigned>(); +  assert(ncpus > 0);    ReadFile ini_rf(conf["decoder_config"].as<string>());    Decoder decoder(ini_rf.stream());    if (decoder.GetConf()["input"].as<string>() != "-") {      cerr << "cdec.ini must not set an input file\n";      abort();    } - -  vector<string> corpus; -  ReadTrainingCorpus(conf["training_data"].as<string>(), &corpus); -  assert(corpus.size() > 0); - -  TrainingObserver observer; -  for (int i = 0; i < corpus.size(); ++i) { -    int ex_num = i; -    decoder.SetId(ex_num); -    decoder.Decode(corpus[ex_num], &observer); -    if (observer.failed) { -      cerr << "*** id=" << ex_num << " is unreachable\n"; -      observer.used.clear(); +  SetSilent(true);  // turn off verbose decoder output +  cerr << "Forking " << ncpus << " time(s)\n"; +  vector<pid_t> children; +  for (int i = 0; i < ncpus; ++i) { +    pid_t pid = fork(); +    if (pid < 0) { +      cerr << "Fork failed!\n"; +      exit(1); +    } +    if (pid > 0) { +      children.push_back(pid);      } else { -      cerr << corpus[ex_num] << endl; -      for (set<TRulePtr>::iterator it = observer.used.begin(); it != observer.used.end(); ++it) { -        cout << **it << endl; -        (*it)->prev_i = 0; -      } +      work(fname, i, ncpus, &decoder); +      cerr << "  " << i << "/" << ncpus << " finished.\n"; +      _exit(0);      }    } +  for (int i = 0; i < children.size(); ++i) { +    int status; +    int w = waitpid(children[i], &status, 0); +    if (w < 0) { cerr << "Error while waiting for children!"; return 1; } +    cerr << "Child " << i << ": status=" << status << " sig?=" << WIFSIGNALED(status) << " sig=" << WTERMSIG(status) << endl; +  }    return 0;  } | 
