author    | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-11-02 19:13:43 +0000
committer | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-11-02 19:13:43 +0000
commit    | 1336aecfe930546f8836ffe65dd5ff78434084eb (patch)
tree      | 7d8c9396dadb58d06e72c30f5ca874389b22dea7
parent    | cd7562fde01771d461350cf91b383021754ea27b (diff)
mpi batch optimize without mpi
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@705 ec762483-ff6d-05da-a07a-a48fb63a330f
-rw-r--r-- | training/Makefile.am           | 10
-rw-r--r-- | training/mpi_batch_optimize.cc | 22

2 files changed, 27 insertions, 5 deletions
diff --git a/training/Makefile.am b/training/Makefile.am
index 89d4a4c9..ec33e267 100644
--- a/training/Makefile.am
+++ b/training/Makefile.am
@@ -9,7 +9,8 @@ bin_PROGRAMS = \
   plftools \
   collapse_weights \
   cllh_filter_grammar \
-  mpi_online_optimize
+  mpi_online_optimize \
+  mpi_batch_optimize
 
 noinst_PROGRAMS = \
   lbfgs_test \
@@ -20,13 +21,12 @@ TESTS = lbfgs_test optimize_test
 mpi_online_optimize_SOURCES = mpi_online_optimize.cc online_optimizer.cc
 mpi_online_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz
 
-if MPI
-bin_PROGRAMS += mpi_batch_optimize \
-                compute_cllh
-
 mpi_batch_optimize_SOURCES = mpi_batch_optimize.cc optimize.cc
 mpi_batch_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz
 
+if MPI
+bin_PROGRAMS += compute_cllh
+
 compute_cllh_SOURCES = compute_cllh.cc
 compute_cllh_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz
 endif
diff --git a/training/mpi_batch_optimize.cc b/training/mpi_batch_optimize.cc
index f1ee9fb4..8f45aef1 100644
--- a/training/mpi_batch_optimize.cc
+++ b/training/mpi_batch_optimize.cc
@@ -4,7 +4,10 @@
 #include <cassert>
 #include <cmath>
 
+#ifdef HAVE_MPI
 #include <mpi.h>
+#endif
+
 #include <boost/shared_ptr.hpp>
 #include <boost/program_options.hpp>
 #include <boost/program_options/variables_map.hpp>
@@ -84,12 +87,16 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
   if (conf->count("help") || !conf->count("input_weights") || !(conf->count("training_data") | conf->count("sharded_input")) || !conf->count("decoder_config")) {
     cerr << dcmdline_options << endl;
+#ifdef HAVE_MPI
     MPI::Finalize();
+#endif
     exit(1);
   }
   if (conf->count("training_data") && conf->count("sharded_input")) {
     cerr << "Cannot specify both --training_data and --sharded_input\n";
+#ifdef HAVE_MPI
     MPI::Finalize();
+#endif
     exit(1);
   }
 }
@@ -206,9 +213,14 @@ void StoreConfig(const vector<string>& cfg, istringstream* o) {
 }
 
 int main(int argc, char** argv) {
+#ifdef HAVE_MPI
   MPI::Init(argc, argv);
   const int size = MPI::COMM_WORLD.Get_size();
   const int rank = MPI::COMM_WORLD.Get_rank();
+#else
+  const int size = 1;
+  const int rank = 0;
+#endif
   SetSilent(true);  // turn off verbose decoder output
   register_feature_functions();
@@ -220,7 +232,9 @@ int main(int argc, char** argv) {
     shard_dir = conf["sharded_input"].as<string>();
     if (!DirectoryExists(shard_dir)) {
       if (rank == 0) cerr << "Can't find shard directory: " << shard_dir << endl;
+#ifdef HAVE_MPI
       MPI::Finalize();
+#endif
       return 1;
     }
     if (rank == 0)
@@ -258,7 +272,9 @@ int main(int argc, char** argv) {
   Decoder* decoder = new Decoder(&ini);
   if (decoder->GetConf()["input"].as<string>() != "-") {
     cerr << "cdec.ini must not set an input file\n";
+#ifdef HAVE_MPI
     MPI::COMM_WORLD.Abort(1);
+#endif
   }
   if (rank == 0) cerr << "Done loading grammar!\n";
@@ -326,10 +342,12 @@ int main(int argc, char** argv) {
       observer.SetLocalGradientAndObjective(&gradient, &objective);
 
       double to = 0;
+#ifdef HAVE_MPI
       MPI::COMM_WORLD.Reduce(const_cast<double*>(&gradient.data()[0]), &rcv_grad[0], num_feats, MPI::DOUBLE, MPI::SUM, 0);
       MPI::COMM_WORLD.Reduce(&objective, &to, 1, MPI::DOUBLE, MPI::SUM, 0);
       swap(gradient, rcv_grad);
       objective = to;
+#endif
 
       if (rank == 0) {  // run optimizer only on rank=0 node
         if (gaussian_prior) {
@@ -376,11 +394,15 @@ int main(int argc, char** argv) {
         weights.WriteToFile(fname, true, &svv);
       }  // rank == 0
       int cint = converged;
+#ifdef HAVE_MPI
       MPI::COMM_WORLD.Bcast(const_cast<double*>(&lambdas.data()[0]), num_feats, MPI::DOUBLE, 0);
       MPI::COMM_WORLD.Bcast(&cint, 1, MPI::INT, 0);
       MPI::COMM_WORLD.Barrier();
+#endif
       converged = cint;
     }
+#ifdef HAVE_MPI
   MPI::Finalize();
+#endif
   return 0;
 }
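The technique the patch uses is plain conditional compilation: every MPI call is wrapped in `#ifdef HAVE_MPI`, and a build without MPI falls back to a single process with `rank = 0` and `size = 1`, so the collectives (Reduce, Bcast, Barrier) can simply be skipped because the local gradient and objective are already the global ones. Below is a minimal self-contained sketch of that pattern, not code from the repository; `ComputeLocalObjective` is a hypothetical stand-in for the per-shard work the real trainer does, and how `HAVE_MPI` gets defined is left to the configure machinery, which is not shown in this diff.

```cpp
// Sketch of the HAVE_MPI guard pattern applied in mpi_batch_optimize.cc.
// ComputeLocalObjective() is a hypothetical helper, not part of the repo.
#ifdef HAVE_MPI
#include <mpi.h>
#endif
#include <iostream>

static double ComputeLocalObjective() { return 1.0; }  // hypothetical per-shard work

int main(int argc, char** argv) {
#ifdef HAVE_MPI
  MPI::Init(argc, argv);
  const int size = MPI::COMM_WORLD.Get_size();
  const int rank = MPI::COMM_WORLD.Get_rank();
#else
  (void)argc; (void)argv;   // unused without MPI
  const int size = 1;       // single-process fallback: this process is the only shard
  const int rank = 0;
#endif

  double local = ComputeLocalObjective();
  double total = local;     // without MPI the local value is already the global one
#ifdef HAVE_MPI
  // Sum the per-shard contributions onto rank 0, mirroring the guarded
  // Reduce calls in the patch above.
  MPI::COMM_WORLD.Reduce(&local, &total, 1, MPI::DOUBLE, MPI::SUM, 0);
#endif

  if (rank == 0)
    std::cout << "objective = " << total << " over " << size << " process(es)\n";

#ifdef HAVE_MPI
  MPI::Finalize();
#endif
  return 0;
}
```

Compiled without `HAVE_MPI` this has no MPI dependency at all, which is why the Makefile.am change can move mpi_batch_optimize out of the `if MPI` conditional and build it unconditionally; compiled with it, the same source runs across shards using the MPI C++ bindings the file already uses.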