diff options
author | Chris Dyer <cdyer@cs.cmu.edu> | 2011-04-14 15:52:13 -0400 |
---|---|---|
committer | Chris Dyer <cdyer@cs.cmu.edu> | 2011-04-14 15:52:13 -0400 |
commit | bfefc1215867094d808424289df39fc590641d83 (patch) | |
tree | c753ee876c9502c0eb9b39286b0d125bc1ccdc49 /training/mpi_batch_optimize.cc | |
parent | bca7c502d92150af9cd7a7fea389eb21e8852ab0 (diff) |
mpi update
Diffstat (limited to 'training/mpi_batch_optimize.cc')
-rw-r--r-- | training/mpi_batch_optimize.cc | 56 |
1 files changed, 28 insertions, 28 deletions
diff --git a/training/mpi_batch_optimize.cc b/training/mpi_batch_optimize.cc index 8f45aef1..11be8bbe 100644 --- a/training/mpi_batch_optimize.cc +++ b/training/mpi_batch_optimize.cc @@ -5,7 +5,9 @@ #include <cmath> #ifdef HAVE_MPI -#include <mpi.h> +#include <boost/mpi/timer.hpp> +#include <boost/mpi.hpp> +namespace mpi = boost::mpi; #endif #include <boost/shared_ptr.hpp> @@ -57,7 +59,7 @@ void ShowLargestFeatures(const vector<double>& w) { cerr << endl; } -void InitCommandLine(int argc, char** argv, po::variables_map* conf) { +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() ("input_weights,w",po::value<string>(),"Input feature weights file") @@ -87,18 +89,13 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { if (conf->count("help") || !conf->count("input_weights") || !(conf->count("training_data") | conf->count("sharded_input")) || !conf->count("decoder_config")) { cerr << dcmdline_options << endl; -#ifdef HAVE_MPI - MPI::Finalize(); -#endif - exit(1); + return false; } if (conf->count("training_data") && conf->count("sharded_input")) { cerr << "Cannot specify both --training_data and --sharded_input\n"; -#ifdef HAVE_MPI - MPI::Finalize(); -#endif - exit(1); + return false; } + return true; } void ReadTrainingCorpus(const string& fname, int rank, int size, vector<string>* c) { @@ -212,11 +209,22 @@ void StoreConfig(const vector<string>& cfg, istringstream* o) { o->str(os.str()); } +template <typename T> +struct VectorPlus : public binary_function<vector<T>, vector<T>, vector<T> > { + vector<T> operator()(const vector<int>& a, const vector<int>& b) const { + assert(a.size() == b.size()); + vector<T> v(a.size()); + transform(a.begin(), a.end(), b.begin(), v.begin(), plus<T>()); + return v; + } +}; + int main(int argc, char** argv) { #ifdef HAVE_MPI - MPI::Init(argc, argv); - const int size = MPI::COMM_WORLD.Get_size(); - const int rank = MPI::COMM_WORLD.Get_rank(); + mpi::environment env(argc, argv); + mpi::communicator world; + const int size = world.size(); + const int rank = world.rank(); #else const int size = 1; const int rank = 0; @@ -225,16 +233,13 @@ int main(int argc, char** argv) { register_feature_functions(); po::variables_map conf; - InitCommandLine(argc, argv, &conf); + if (!InitCommandLine(argc, argv, &conf)) return 1; string shard_dir; if (conf.count("sharded_input")) { shard_dir = conf["sharded_input"].as<string>(); if (!DirectoryExists(shard_dir)) { if (rank == 0) cerr << "Can't find shard directory: " << shard_dir << endl; -#ifdef HAVE_MPI - MPI::Finalize(); -#endif return 1; } if (rank == 0) @@ -272,9 +277,7 @@ int main(int argc, char** argv) { Decoder* decoder = new Decoder(&ini); if (decoder->GetConf()["input"].as<string>() != "-") { cerr << "cdec.ini must not set an input file\n"; -#ifdef HAVE_MPI - MPI::COMM_WORLD.Abort(1); -#endif + return 1; } if (rank == 0) cerr << "Done loading grammar!\n"; @@ -343,8 +346,8 @@ int main(int argc, char** argv) { double to = 0; #ifdef HAVE_MPI - MPI::COMM_WORLD.Reduce(const_cast<double*>(&gradient.data()[0]), &rcv_grad[0], num_feats, MPI::DOUBLE, MPI::SUM, 0); - MPI::COMM_WORLD.Reduce(&objective, &to, 1, MPI::DOUBLE, MPI::SUM, 0); + mpi::reduce(world, &gradient[0], &rcv_grad[0], gradient.size(), plus<double>(), 0); + mpi::reduce(world, objective, to, plus<double>(), 0); swap(gradient, rcv_grad); objective = to; #endif @@ -395,14 +398,11 @@ int main(int argc, char** argv) { } // rank == 0 int cint = converged; #ifdef HAVE_MPI - MPI::COMM_WORLD.Bcast(const_cast<double*>(&lambdas.data()[0]), num_feats, MPI::DOUBLE, 0); - MPI::COMM_WORLD.Bcast(&cint, 1, MPI::INT, 0); - MPI::COMM_WORLD.Barrier(); + mpi::broadcast(world, lambdas, 0); + mpi::broadcast(world, cint, 0); + world.barrier(); #endif converged = cint; } -#ifdef HAVE_MPI - MPI::Finalize(); -#endif return 0; } |