summaryrefslogtreecommitdiff
path: root/training/mpi_batch_optimize.cc
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2011-04-14 22:22:22 -0400
committerChris Dyer <cdyer@cs.cmu.edu>2011-04-14 22:22:22 -0400
commitfb205d034960c42c37475318d39631ff6e2c7357 (patch)
tree5441a8880ebdbe964c9c2a18521b46629f578a62 /training/mpi_batch_optimize.cc
parentbfefc1215867094d808424289df39fc590641d83 (diff)
mpi optimizations
Diffstat (limited to 'training/mpi_batch_optimize.cc')
-rw-r--r--training/mpi_batch_optimize.cc11
1 files changed, 7 insertions, 4 deletions
diff --git a/training/mpi_batch_optimize.cc b/training/mpi_batch_optimize.cc
index 11be8bbe..5a6bf301 100644
--- a/training/mpi_batch_optimize.cc
+++ b/training/mpi_batch_optimize.cc
@@ -4,6 +4,7 @@
#include <cassert>
#include <cmath>
+#include "config.h"
#ifdef HAVE_MPI
#include <boost/mpi/timer.hpp>
#include <boost/mpi.hpp>
@@ -333,20 +334,23 @@ int main(int argc, char** argv) {
TrainingObserver observer;
while (!converged) {
observer.Reset();
+#ifdef HAVE_MPI
+ world.barrier();
+#endif
if (rank == 0) {
cerr << "Starting decoding... (~" << corpus.size() << " sentences / proc)\n";
}
decoder->SetWeights(lambdas);
for (int i = 0; i < corpus.size(); ++i)
decoder->Decode(corpus[i], &observer);
-
+ cerr << " process " << rank << '/' << size << " done\n";
fill(gradient.begin(), gradient.end(), 0);
fill(rcv_grad.begin(), rcv_grad.end(), 0);
observer.SetLocalGradientAndObjective(&gradient, &objective);
double to = 0;
#ifdef HAVE_MPI
- mpi::reduce(world, &gradient[0], &rcv_grad[0], gradient.size(), plus<double>(), 0);
+ mpi::reduce(world, &gradient[0], gradient.size(), &rcv_grad[0], plus<double>(), 0);
mpi::reduce(world, objective, to, plus<double>(), 0);
swap(gradient, rcv_grad);
objective = to;
@@ -398,9 +402,8 @@ int main(int argc, char** argv) {
} // rank == 0
int cint = converged;
#ifdef HAVE_MPI
- mpi::broadcast(world, lambdas, 0);
+ mpi::broadcast(world, &lambdas[0], lambdas.size(), 0);
mpi::broadcast(world, cint, 0);
- world.barrier();
#endif
converged = cint;
}