diff options
| -rw-r--r-- | training/Makefile.am | 10 | ||||
| -rw-r--r-- | training/mpi_batch_optimize.cc | 22 | 
2 files changed, 27 insertions, 5 deletions
| diff --git a/training/Makefile.am b/training/Makefile.am index 89d4a4c9..ec33e267 100644 --- a/training/Makefile.am +++ b/training/Makefile.am @@ -9,7 +9,8 @@ bin_PROGRAMS = \    plftools \    collapse_weights \    cllh_filter_grammar \ -  mpi_online_optimize +  mpi_online_optimize \ +  mpi_batch_optimize  noinst_PROGRAMS = \    lbfgs_test \ @@ -20,13 +21,12 @@ TESTS = lbfgs_test optimize_test  mpi_online_optimize_SOURCES = mpi_online_optimize.cc online_optimizer.cc  mpi_online_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz -if MPI -bin_PROGRAMS += mpi_batch_optimize \ -             compute_cllh -  mpi_batch_optimize_SOURCES = mpi_batch_optimize.cc optimize.cc  mpi_batch_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz +if MPI +bin_PROGRAMS += compute_cllh +  compute_cllh_SOURCES = compute_cllh.cc  compute_cllh_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz  endif diff --git a/training/mpi_batch_optimize.cc b/training/mpi_batch_optimize.cc index f1ee9fb4..8f45aef1 100644 --- a/training/mpi_batch_optimize.cc +++ b/training/mpi_batch_optimize.cc @@ -4,7 +4,10 @@  #include <cassert>  #include <cmath> +#ifdef HAVE_MPI  #include <mpi.h> +#endif +  #include <boost/shared_ptr.hpp>  #include <boost/program_options.hpp>  #include <boost/program_options/variables_map.hpp> @@ -84,12 +87,16 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {    if (conf->count("help") || !conf->count("input_weights") || !(conf->count("training_data") | conf->count("sharded_input")) || !conf->count("decoder_config")) {      cerr << dcmdline_options << endl; +#ifdef HAVE_MPI      MPI::Finalize(); +#endif      exit(1);    }    if (conf->count("training_data") && conf->count("sharded_input")) {      cerr << "Cannot specify both --training_data and --sharded_input\n"; +#ifdef HAVE_MPI      MPI::Finalize(); +#endif      exit(1);    }  } @@ -206,9 +213,14 @@ void StoreConfig(const vector<string>& cfg, istringstream* o) {  }  int main(int argc, char** argv) { +#ifdef HAVE_MPI    MPI::Init(argc, argv);    const int size = MPI::COMM_WORLD.Get_size();     const int rank = MPI::COMM_WORLD.Get_rank(); +#else +  const int size = 1; +  const int rank = 0; +#endif    SetSilent(true);  // turn off verbose decoder output    register_feature_functions(); @@ -220,7 +232,9 @@ int main(int argc, char** argv) {      shard_dir = conf["sharded_input"].as<string>();      if (!DirectoryExists(shard_dir)) {        if (rank == 0) cerr << "Can't find shard directory: " << shard_dir << endl; +#ifdef HAVE_MPI        MPI::Finalize(); +#endif        return 1;      }      if (rank == 0) @@ -258,7 +272,9 @@ int main(int argc, char** argv) {    Decoder* decoder = new Decoder(&ini);    if (decoder->GetConf()["input"].as<string>() != "-") {      cerr << "cdec.ini must not set an input file\n"; +#ifdef HAVE_MPI      MPI::COMM_WORLD.Abort(1); +#endif    }    if (rank == 0) cerr << "Done loading grammar!\n"; @@ -326,10 +342,12 @@ int main(int argc, char** argv) {      observer.SetLocalGradientAndObjective(&gradient, &objective);      double to = 0; +#ifdef HAVE_MPI      MPI::COMM_WORLD.Reduce(const_cast<double*>(&gradient.data()[0]), &rcv_grad[0], num_feats, MPI::DOUBLE, MPI::SUM, 0);      MPI::COMM_WORLD.Reduce(&objective, &to, 1, MPI::DOUBLE, MPI::SUM, 0);      swap(gradient, rcv_grad);      objective = to; +#endif      if (rank == 0) {  // run optimizer only on rank=0 node        if (gaussian_prior) { @@ -376,11 +394,15 @@ int main(int argc, char** argv) {        weights.WriteToFile(fname, true, &svv);      }  // rank == 0      int cint = converged; +#ifdef HAVE_MPI      MPI::COMM_WORLD.Bcast(const_cast<double*>(&lambdas.data()[0]), num_feats, MPI::DOUBLE, 0);      MPI::COMM_WORLD.Bcast(&cint, 1, MPI::INT, 0);      MPI::COMM_WORLD.Barrier(); +#endif      converged = cint;    } +#ifdef HAVE_MPI    MPI::Finalize();  +#endif    return 0;  } | 
