summaryrefslogtreecommitdiff
path: root/training/mpi_batch_optimize.cc
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2011-04-14 15:52:13 -0400
committerChris Dyer <cdyer@cs.cmu.edu>2011-04-14 15:52:13 -0400
commitbfefc1215867094d808424289df39fc590641d83 (patch)
treec753ee876c9502c0eb9b39286b0d125bc1ccdc49 /training/mpi_batch_optimize.cc
parentbca7c502d92150af9cd7a7fea389eb21e8852ab0 (diff)
mpi update
Diffstat (limited to 'training/mpi_batch_optimize.cc')
-rw-r--r--training/mpi_batch_optimize.cc56
1 files changed, 28 insertions, 28 deletions
diff --git a/training/mpi_batch_optimize.cc b/training/mpi_batch_optimize.cc
index 8f45aef1..11be8bbe 100644
--- a/training/mpi_batch_optimize.cc
+++ b/training/mpi_batch_optimize.cc
@@ -5,7 +5,9 @@
#include <cmath>
#ifdef HAVE_MPI
-#include <mpi.h>
+#include <boost/mpi/timer.hpp>
+#include <boost/mpi.hpp>
+namespace mpi = boost::mpi;
#endif
#include <boost/shared_ptr.hpp>
@@ -57,7 +59,7 @@ void ShowLargestFeatures(const vector<double>& w) {
cerr << endl;
}
-void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
po::options_description opts("Configuration options");
opts.add_options()
("input_weights,w",po::value<string>(),"Input feature weights file")
@@ -87,18 +89,13 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
if (conf->count("help") || !conf->count("input_weights") || !(conf->count("training_data") | conf->count("sharded_input")) || !conf->count("decoder_config")) {
cerr << dcmdline_options << endl;
-#ifdef HAVE_MPI
- MPI::Finalize();
-#endif
- exit(1);
+ return false;
}
if (conf->count("training_data") && conf->count("sharded_input")) {
cerr << "Cannot specify both --training_data and --sharded_input\n";
-#ifdef HAVE_MPI
- MPI::Finalize();
-#endif
- exit(1);
+ return false;
}
+ return true;
}
void ReadTrainingCorpus(const string& fname, int rank, int size, vector<string>* c) {
@@ -212,11 +209,22 @@ void StoreConfig(const vector<string>& cfg, istringstream* o) {
o->str(os.str());
}
+template <typename T>
+struct VectorPlus : public binary_function<vector<T>, vector<T>, vector<T> > {
+ vector<T> operator()(const vector<int>& a, const vector<int>& b) const {
+ assert(a.size() == b.size());
+ vector<T> v(a.size());
+ transform(a.begin(), a.end(), b.begin(), v.begin(), plus<T>());
+ return v;
+ }
+};
+
int main(int argc, char** argv) {
#ifdef HAVE_MPI
- MPI::Init(argc, argv);
- const int size = MPI::COMM_WORLD.Get_size();
- const int rank = MPI::COMM_WORLD.Get_rank();
+ mpi::environment env(argc, argv);
+ mpi::communicator world;
+ const int size = world.size();
+ const int rank = world.rank();
#else
const int size = 1;
const int rank = 0;
@@ -225,16 +233,13 @@ int main(int argc, char** argv) {
register_feature_functions();
po::variables_map conf;
- InitCommandLine(argc, argv, &conf);
+ if (!InitCommandLine(argc, argv, &conf)) return 1;
string shard_dir;
if (conf.count("sharded_input")) {
shard_dir = conf["sharded_input"].as<string>();
if (!DirectoryExists(shard_dir)) {
if (rank == 0) cerr << "Can't find shard directory: " << shard_dir << endl;
-#ifdef HAVE_MPI
- MPI::Finalize();
-#endif
return 1;
}
if (rank == 0)
@@ -272,9 +277,7 @@ int main(int argc, char** argv) {
Decoder* decoder = new Decoder(&ini);
if (decoder->GetConf()["input"].as<string>() != "-") {
cerr << "cdec.ini must not set an input file\n";
-#ifdef HAVE_MPI
- MPI::COMM_WORLD.Abort(1);
-#endif
+ return 1;
}
if (rank == 0) cerr << "Done loading grammar!\n";
@@ -343,8 +346,8 @@ int main(int argc, char** argv) {
double to = 0;
#ifdef HAVE_MPI
- MPI::COMM_WORLD.Reduce(const_cast<double*>(&gradient.data()[0]), &rcv_grad[0], num_feats, MPI::DOUBLE, MPI::SUM, 0);
- MPI::COMM_WORLD.Reduce(&objective, &to, 1, MPI::DOUBLE, MPI::SUM, 0);
+ mpi::reduce(world, &gradient[0], &rcv_grad[0], gradient.size(), plus<double>(), 0);
+ mpi::reduce(world, objective, to, plus<double>(), 0);
swap(gradient, rcv_grad);
objective = to;
#endif
@@ -395,14 +398,11 @@ int main(int argc, char** argv) {
} // rank == 0
int cint = converged;
#ifdef HAVE_MPI
- MPI::COMM_WORLD.Bcast(const_cast<double*>(&lambdas.data()[0]), num_feats, MPI::DOUBLE, 0);
- MPI::COMM_WORLD.Bcast(&cint, 1, MPI::INT, 0);
- MPI::COMM_WORLD.Barrier();
+ mpi::broadcast(world, lambdas, 0);
+ mpi::broadcast(world, cint, 0);
+ world.barrier();
#endif
converged = cint;
}
-#ifdef HAVE_MPI
- MPI::Finalize();
-#endif
return 0;
}