summaryrefslogtreecommitdiff
path: root/training/mpi_flex_optimize.cc
diff options
context:
space:
mode:
Diffstat (limited to 'training/mpi_flex_optimize.cc')
-rw-r--r--training/mpi_flex_optimize.cc13
1 files changed, 7 insertions, 6 deletions
diff --git a/training/mpi_flex_optimize.cc b/training/mpi_flex_optimize.cc
index 00746532..a9197208 100644
--- a/training/mpi_flex_optimize.cc
+++ b/training/mpi_flex_optimize.cc
@@ -205,7 +205,7 @@ int main(int argc, char** argv) {
const int size = 1;
const int rank = 0;
#endif
- if (size > 0) SetSilent(true); // turn off verbose decoder output
+ if (size > 1) SetSilent(true); // turn off verbose decoder output
register_feature_functions();
MT19937* rng = NULL;
@@ -272,6 +272,7 @@ int main(int argc, char** argv) {
int iter = -1;
bool converged = false;
+ vector<double> gg;
while (!converged) {
#ifdef HAVE_MPI
mpi::timer timer;
@@ -343,7 +344,7 @@ int main(int argc, char** argv) {
double obj = 0;
#ifdef HAVE_MPI
- // TODO obj
+ reduce(world, local_obj, obj, std::plus<double>(), 0);
reduce(world, local_grad, g, std::plus<SparseVector<double> >(), 0);
#else
obj = local_obj;
@@ -354,13 +355,14 @@ int main(int argc, char** argv) {
// g /= (size_per_proc * size);
if (!o)
o.reset(new LBFGSOptimizer(FD::NumFeats(), lbfgs_memory_buffers));
- vector<double> gg(FD::NumFeats());
+ gg.clear();
+ gg.resize(FD::NumFeats());
if (gg.size() != cur_weights.size()) { cur_weights.resize(gg.size()); }
for (SparseVector<double>::const_iterator it = g.begin(); it != g.end(); ++it)
if (it->first) { gg[it->first] = it->second; }
g.clear();
double r = ApplyRegularizationTerms(regularization_strength,
- time_series_strength * (iter == 0 ? 0.0 : 1.0),
+ time_series_strength, // * (iter == 0 ? 0.0 : 1.0),
cur_weights,
prev_weights,
&gg);
@@ -375,10 +377,9 @@ int main(int argc, char** argv) {
o->Optimize(obj, gg, &cur_weights);
}
#ifdef HAVE_MPI
- // broadcast(world, x, 0);
+ broadcast(world, cur_weights, 0);
broadcast(world, converged, 0);
world.barrier();
- if (rank == 0) { cerr << " ELAPSED TIME THIS ITERATION=" << timer.elapsed() << endl; }
#endif
}
prev_weights = cur_weights;