path: root/training/mpi_batch_optimize.cc
author     redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>  2010-11-02 19:13:43 +0000
committer  redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>  2010-11-02 19:13:43 +0000
commit     1336aecfe930546f8836ffe65dd5ff78434084eb (patch)
tree       7d8c9396dadb58d06e72c30f5ca874389b22dea7 /training/mpi_batch_optimize.cc
parent     cd7562fde01771d461350cf91b383021754ea27b (diff)
mpi batch optimize without mpi
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@705 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'training/mpi_batch_optimize.cc')
-rw-r--r--  training/mpi_batch_optimize.cc  22
1 file changed, 22 insertions(+), 0 deletions(-)
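
The change is mechanical: every MPI call gets wrapped in #ifdef HAVE_MPI, and the non-MPI build substitutes the trivial single-process values size = 1 and rank = 0, so all of the rank-0 logic (optimizer, weight writing) still runs. Below is a minimal, hypothetical sketch of that pattern, distilled from the hunks that follow; it is not code from the file itself:

#include <cstdio>
#ifdef HAVE_MPI
#include <mpi.h>
#endif

int main(int argc, char** argv) {
#ifdef HAVE_MPI
  MPI::Init(argc, argv);
  const int size = MPI::COMM_WORLD.Get_size();
  const int rank = MPI::COMM_WORLD.Get_rank();
#else
  const int size = 1;  // a single process stands in for the whole communicator
  const int rank = 0;  // and it plays the root, so rank-0 branches still run
#endif
  double local = 1.0, total = local;  // without MPI, local already is the total
#ifdef HAVE_MPI
  // sum per-process values onto rank 0, mirroring the gradient Reduce below
  MPI::COMM_WORLD.Reduce(&local, &total, 1, MPI::DOUBLE, MPI::SUM, 0);
#endif
  if (rank == 0) std::printf("procs=%d total=%g\n", size, total);
#ifdef HAVE_MPI
  MPI::Finalize();
#endif
  return 0;
}
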
diff --git a/training/mpi_batch_optimize.cc b/training/mpi_batch_optimize.cc
index f1ee9fb4..8f45aef1 100644
--- a/training/mpi_batch_optimize.cc
+++ b/training/mpi_batch_optimize.cc
@@ -4,7 +4,10 @@
#include <cassert>
#include <cmath>
+#ifdef HAVE_MPI
#include <mpi.h>
+#endif
+
#include <boost/shared_ptr.hpp>
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
@@ -84,12 +87,16 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
if (conf->count("help") || !conf->count("input_weights") || !(conf->count("training_data") | conf->count("sharded_input")) || !conf->count("decoder_config")) {
cerr << dcmdline_options << endl;
+#ifdef HAVE_MPI
MPI::Finalize();
+#endif
exit(1);
}
if (conf->count("training_data") && conf->count("sharded_input")) {
cerr << "Cannot specify both --training_data and --sharded_input\n";
+#ifdef HAVE_MPI
MPI::Finalize();
+#endif
exit(1);
}
}
@@ -206,9 +213,14 @@ void StoreConfig(const vector<string>& cfg, istringstream* o) {
}
int main(int argc, char** argv) {
+#ifdef HAVE_MPI
MPI::Init(argc, argv);
const int size = MPI::COMM_WORLD.Get_size();
const int rank = MPI::COMM_WORLD.Get_rank();
+#else
+ const int size = 1;
+ const int rank = 0;
+#endif
SetSilent(true); // turn off verbose decoder output
register_feature_functions();
@@ -220,7 +232,9 @@ int main(int argc, char** argv) {
shard_dir = conf["sharded_input"].as<string>();
if (!DirectoryExists(shard_dir)) {
if (rank == 0) cerr << "Can't find shard directory: " << shard_dir << endl;
+#ifdef HAVE_MPI
MPI::Finalize();
+#endif
return 1;
}
if (rank == 0)
@@ -258,7 +272,9 @@ int main(int argc, char** argv) {
Decoder* decoder = new Decoder(&ini);
if (decoder->GetConf()["input"].as<string>() != "-") {
cerr << "cdec.ini must not set an input file\n";
+#ifdef HAVE_MPI
MPI::COMM_WORLD.Abort(1);
+#endif
}
if (rank == 0) cerr << "Done loading grammar!\n";
@@ -326,10 +342,12 @@ int main(int argc, char** argv) {
observer.SetLocalGradientAndObjective(&gradient, &objective);
double to = 0;
+#ifdef HAVE_MPI
MPI::COMM_WORLD.Reduce(const_cast<double*>(&gradient.data()[0]), &rcv_grad[0], num_feats, MPI::DOUBLE, MPI::SUM, 0);
MPI::COMM_WORLD.Reduce(&objective, &to, 1, MPI::DOUBLE, MPI::SUM, 0);
swap(gradient, rcv_grad);
objective = to;
+#endif
if (rank == 0) { // run optimizer only on rank=0 node
if (gaussian_prior) {
@@ -376,11 +394,15 @@ int main(int argc, char** argv) {
weights.WriteToFile(fname, true, &svv);
} // rank == 0
int cint = converged;
+#ifdef HAVE_MPI
MPI::COMM_WORLD.Bcast(const_cast<double*>(&lambdas.data()[0]), num_feats, MPI::DOUBLE, 0);
MPI::COMM_WORLD.Bcast(&cint, 1, MPI::INT, 0);
MPI::COMM_WORLD.Barrier();
+#endif
converged = cint;
}
+#ifdef HAVE_MPI
MPI::Finalize();
+#endif
return 0;
}
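
Note that in the non-MPI build the skipped collectives are semantic no-ops: with a single process, the local gradient and objective already equal the global sums that Reduce would have produced, and there is no other rank for Bcast to update, so the optimizer on rank 0 sees identical inputs either way. How HAVE_MPI gets defined is not shown in this patch; presumably the build system supplies it (e.g., passing -DHAVE_MPI when compiling under an MPI wrapper such as mpic++), which is an assumption, not something this diff confirms.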