From 0c83a47353733c85814871a9a656cafdc70e6800 Mon Sep 17 00:00:00 2001
From: redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>
Date: Tue, 2 Nov 2010 19:13:43 +0000
Subject: mpi batch optimize without mpi

git-svn-id: https://ws10smt.googlecode.com/svn/trunk@705 ec762483-ff6d-05da-a07a-a48fb63a330f
---
 training/mpi_batch_optimize.cc | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

(limited to 'training/mpi_batch_optimize.cc')

diff --git a/training/mpi_batch_optimize.cc b/training/mpi_batch_optimize.cc
index f1ee9fb4..8f45aef1 100644
--- a/training/mpi_batch_optimize.cc
+++ b/training/mpi_batch_optimize.cc
@@ -4,7 +4,10 @@
 #include
 #include
 
+#ifdef HAVE_MPI
 #include <mpi.h>
+#endif
+
 #include
 #include
 #include
@@ -84,12 +87,16 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
   if (conf->count("help") || !conf->count("input_weights") || !(conf->count("training_data") | conf->count("sharded_input")) || !conf->count("decoder_config")) {
     cerr << dcmdline_options << endl;
+#ifdef HAVE_MPI
     MPI::Finalize();
+#endif
     exit(1);
   }
   if (conf->count("training_data") && conf->count("sharded_input")) {
     cerr << "Cannot specify both --training_data and --sharded_input\n";
+#ifdef HAVE_MPI
     MPI::Finalize();
+#endif
     exit(1);
   }
 }
 
@@ -206,9 +213,14 @@ void StoreConfig(const vector<string>& cfg, istringstream* o) {
 }
 
 int main(int argc, char** argv) {
+#ifdef HAVE_MPI
   MPI::Init(argc, argv);
   const int size = MPI::COMM_WORLD.Get_size();
   const int rank = MPI::COMM_WORLD.Get_rank();
+#else
+  const int size = 1;
+  const int rank = 0;
+#endif
   SetSilent(true);  // turn off verbose decoder output
   register_feature_functions();
 
@@ -220,7 +232,9 @@
     shard_dir = conf["sharded_input"].as<string>();
     if (!DirectoryExists(shard_dir)) {
       if (rank == 0) cerr << "Can't find shard directory: " << shard_dir << endl;
+#ifdef HAVE_MPI
       MPI::Finalize();
+#endif
       return 1;
     }
     if (rank == 0)
@@ -258,7 +272,9 @@ void StoreConfig(const vector<string>& cfg, istringstream* o) {
   Decoder* decoder = new Decoder(&ini);
   if (decoder->GetConf()["input"].as<string>() != "-") {
     cerr << "cdec.ini must not set an input file\n";
+#ifdef HAVE_MPI
     MPI::COMM_WORLD.Abort(1);
+#endif
   }
   if (rank == 0) cerr << "Done loading grammar!\n";
 
@@ -326,10 +342,12 @@ int main(int argc, char** argv) {
     observer.SetLocalGradientAndObjective(&gradient, &objective);
 
     double to = 0;
+#ifdef HAVE_MPI
     MPI::COMM_WORLD.Reduce(const_cast<double*>(&gradient.data()[0]), &rcv_grad[0], num_feats, MPI::DOUBLE, MPI::SUM, 0);
     MPI::COMM_WORLD.Reduce(&objective, &to, 1, MPI::DOUBLE, MPI::SUM, 0);
     swap(gradient, rcv_grad);
     objective = to;
+#endif
 
     if (rank == 0) {  // run optimizer only on rank=0 node
       if (gaussian_prior) {
@@ -376,11 +394,15 @@ int main(int argc, char** argv) {
       weights.WriteToFile(fname, true, &svv);
     }  // rank == 0
     int cint = converged;
+#ifdef HAVE_MPI
     MPI::COMM_WORLD.Bcast(const_cast<double*>(&lambdas.data()[0]), num_feats, MPI::DOUBLE, 0);
     MPI::COMM_WORLD.Bcast(&cint, 1, MPI::INT, 0);
     MPI::COMM_WORLD.Barrier();
+#endif
     converged = cint;
   }
+#ifdef HAVE_MPI
   MPI::Finalize();
+#endif
   return 0;
 }
-- 
cgit v1.2.3
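
The whole commit is one pattern: compile the MPI calls only when HAVE_MPI is defined, and otherwise fall back to size = 1, rank = 0 so the single-process build runs the exact same rank-0 code path. The sketch below shows that pattern in isolation. It is illustrative code, not cdec source: the toy local/total sum stands in for the per-shard gradient and objective, and HAVE_MPI is assumed to be supplied by the build (e.g. -DHAVE_MPI via an MPI compiler wrapper), as configure would do for cdec.

// Minimal sketch of the HAVE_MPI guard pattern applied by this commit.
// The MPI C++ bindings used here (MPI::Init, MPI::COMM_WORLD, ...) match
// the 2010-era code in the patch; they were removed in MPI-3.
#include <iostream>

#ifdef HAVE_MPI
#include <mpi.h>
#endif

int main(int argc, char** argv) {
#ifdef HAVE_MPI
  MPI::Init(argc, argv);
  const int size = MPI::COMM_WORLD.Get_size();
  const int rank = MPI::COMM_WORLD.Get_rank();
#else
  (void)argc; (void)argv;  // only MPI::Init needs these
  const int size = 1;      // without MPI, act as a one-node "cluster"
  const int rank = 0;      // ...whose only process is the root
#endif

  // Each process computes a local partial result (standing in for a
  // per-shard gradient contribution).
  double local = rank + 1.0;
  double total = local;    // single-process fallback: local IS global

#ifdef HAVE_MPI
  // With MPI, combine the per-process results on rank 0, as the patch
  // does for the gradient vector and the objective.
  MPI::COMM_WORLD.Reduce(&local, &total, 1, MPI::DOUBLE, MPI::SUM, 0);
#endif

  if (rank == 0)  // only the root reports, with or without MPI
    std::cout << "sum over " << size << " process(es): " << total << std::endl;

#ifdef HAVE_MPI
  MPI::Finalize();
#endif
  return 0;
}

Built with plain g++ it prints the single-process sum; built with something like mpic++ -DHAVE_MPI and launched under mpirun, the reduction runs across ranks and only rank 0 prints. That is the same shape as the guarded Init/Reduce/Bcast/Finalize calls in the patch, which is what lets the batch optimizer build and run on a machine with no MPI installation at all.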