summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
authorredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-11-02 19:13:43 +0000
committerredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-11-02 19:13:43 +0000
commit0c83a47353733c85814871a9a656cafdc70e6800 (patch)
tree6f2f453f81eb33a63537c04e0f192869614a6d42 /training
parent8ab20672df0eb71e5d1e3c6b84adaa1f4ddc2b74 (diff)
mpi batch optimize without mpi
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@705 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'training')
-rw-r--r--training/Makefile.am10
-rw-r--r--training/mpi_batch_optimize.cc22
2 files changed, 27 insertions, 5 deletions
diff --git a/training/Makefile.am b/training/Makefile.am
index 89d4a4c9..ec33e267 100644
--- a/training/Makefile.am
+++ b/training/Makefile.am
@@ -9,7 +9,8 @@ bin_PROGRAMS = \
plftools \
collapse_weights \
cllh_filter_grammar \
- mpi_online_optimize
+ mpi_online_optimize \
+ mpi_batch_optimize
noinst_PROGRAMS = \
lbfgs_test \
@@ -20,13 +21,12 @@ TESTS = lbfgs_test optimize_test
mpi_online_optimize_SOURCES = mpi_online_optimize.cc online_optimizer.cc
mpi_online_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz
-if MPI
-bin_PROGRAMS += mpi_batch_optimize \
- compute_cllh
-
mpi_batch_optimize_SOURCES = mpi_batch_optimize.cc optimize.cc
mpi_batch_optimize_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz
+if MPI
+bin_PROGRAMS += compute_cllh
+
compute_cllh_SOURCES = compute_cllh.cc
compute_cllh_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz
endif
diff --git a/training/mpi_batch_optimize.cc b/training/mpi_batch_optimize.cc
index f1ee9fb4..8f45aef1 100644
--- a/training/mpi_batch_optimize.cc
+++ b/training/mpi_batch_optimize.cc
@@ -4,7 +4,10 @@
#include <cassert>
#include <cmath>
+#ifdef HAVE_MPI
#include <mpi.h>
+#endif
+
#include <boost/shared_ptr.hpp>
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
@@ -84,12 +87,16 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
if (conf->count("help") || !conf->count("input_weights") || !(conf->count("training_data") | conf->count("sharded_input")) || !conf->count("decoder_config")) {
cerr << dcmdline_options << endl;
+#ifdef HAVE_MPI
MPI::Finalize();
+#endif
exit(1);
}
if (conf->count("training_data") && conf->count("sharded_input")) {
cerr << "Cannot specify both --training_data and --sharded_input\n";
+#ifdef HAVE_MPI
MPI::Finalize();
+#endif
exit(1);
}
}
@@ -206,9 +213,14 @@ void StoreConfig(const vector<string>& cfg, istringstream* o) {
}
int main(int argc, char** argv) {
+#ifdef HAVE_MPI
MPI::Init(argc, argv);
const int size = MPI::COMM_WORLD.Get_size();
const int rank = MPI::COMM_WORLD.Get_rank();
+#else
+ const int size = 1;
+ const int rank = 0;
+#endif
SetSilent(true); // turn off verbose decoder output
register_feature_functions();
@@ -220,7 +232,9 @@ int main(int argc, char** argv) {
shard_dir = conf["sharded_input"].as<string>();
if (!DirectoryExists(shard_dir)) {
if (rank == 0) cerr << "Can't find shard directory: " << shard_dir << endl;
+#ifdef HAVE_MPI
MPI::Finalize();
+#endif
return 1;
}
if (rank == 0)
@@ -258,7 +272,9 @@ int main(int argc, char** argv) {
Decoder* decoder = new Decoder(&ini);
if (decoder->GetConf()["input"].as<string>() != "-") {
cerr << "cdec.ini must not set an input file\n";
+#ifdef HAVE_MPI
MPI::COMM_WORLD.Abort(1);
+#endif
}
if (rank == 0) cerr << "Done loading grammar!\n";
@@ -326,10 +342,12 @@ int main(int argc, char** argv) {
observer.SetLocalGradientAndObjective(&gradient, &objective);
double to = 0;
+#ifdef HAVE_MPI
MPI::COMM_WORLD.Reduce(const_cast<double*>(&gradient.data()[0]), &rcv_grad[0], num_feats, MPI::DOUBLE, MPI::SUM, 0);
MPI::COMM_WORLD.Reduce(&objective, &to, 1, MPI::DOUBLE, MPI::SUM, 0);
swap(gradient, rcv_grad);
objective = to;
+#endif
if (rank == 0) { // run optimizer only on rank=0 node
if (gaussian_prior) {
@@ -376,11 +394,15 @@ int main(int argc, char** argv) {
weights.WriteToFile(fname, true, &svv);
} // rank == 0
int cint = converged;
+#ifdef HAVE_MPI
MPI::COMM_WORLD.Bcast(const_cast<double*>(&lambdas.data()[0]), num_feats, MPI::DOUBLE, 0);
MPI::COMM_WORLD.Bcast(&cint, 1, MPI::INT, 0);
MPI::COMM_WORLD.Barrier();
+#endif
converged = cint;
}
+#ifdef HAVE_MPI
MPI::Finalize();
+#endif
return 0;
}