From 2b9bc7c2f2a4a3993196a0a3f89f4007143ad351 Mon Sep 17 00:00:00 2001 From: redpony Date: Fri, 15 Oct 2010 20:30:49 +0000 Subject: few fixes git-svn-id: https://ws10smt.googlecode.com/svn/trunk@676 ec762483-ff6d-05da-a07a-a48fb63a330f --- training/mpi_online_optimize.cc | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) (limited to 'training') diff --git a/training/mpi_online_optimize.cc b/training/mpi_online_optimize.cc index 4c08b181..96f3bcfb 100644 --- a/training/mpi_online_optimize.cc +++ b/training/mpi_online_optimize.cc @@ -220,7 +220,6 @@ bool LoadAgenda(const string& file, vector >* a) { x.first = line.substr(0,d); x.second = atoi(line.substr(d+1).c_str()); a->push_back(x); - cerr << "X: " << x.second << " - " << x.first << "'\n"; if (!FileExists(x.first)) { cerr << "Can't find file " << x.first << endl; return false; @@ -261,7 +260,11 @@ int main(int argc, char** argv) { return 1; } + size_t total_corpus_size = 0; + reduce(world, corpus.size(), total_corpus_size, std::plus(), 0); + if (rank == 0) { + cerr << "Total corpus size: " << total_corpus_size << endl; const unsigned batch_size = size_per_proc * size; // TODO config lr.reset(new ExponentialDecayLearningRate(batch_size, conf["eta_0"].as())); @@ -269,7 +272,7 @@ int main(int argc, char** argv) { const string omethod = conf["optimization_method"].as(); if (omethod == "sgd") { const double C = conf["regularization_strength"].as(); - o.reset(new CumulativeL1OnlineOptimizer(lr, corpus.size(), C)); + o.reset(new CumulativeL1OnlineOptimizer(lr, total_corpus_size, C)); } else { assert(!"fail"); } @@ -302,7 +305,8 @@ int main(int argc, char** argv) { ReadFile ini_rf(cur_config); Decoder decoder(ini_rf.stream()); - o->ResetEpoch(); // resets the learning rate-- TODO is this good? + if (rank == 0) + o->ResetEpoch(); // resets the learning rate-- TODO is this good? int iter = -1; bool converged = false; @@ -324,7 +328,7 @@ int main(int argc, char** argv) { } if (converged && ((ai+1)==agenda.size())) { fname = "weights.final.gz"; } ostringstream vv; - vv << "total iter=" << titer << " (of current config iter=" << iter << ") minibatch=" << size_per_proc << " sentences/proc x " << size << " procs. num_feats=" << x.size() << '/' << FD::NumFeats() << " passes_thru_data=" << (titer * size * size_per_proc / static_cast(corpus.size())) << " eta=" << lr->eta(titer); + vv << "total iter=" << titer << " (of current config iter=" << iter << ") minibatch=" << size_per_proc << " sentences/proc x " << size << " procs. num_feats=" << x.size() << '/' << FD::NumFeats() << " passes_thru_data=" << (titer * size_per_proc / static_cast(corpus.size())) << " eta=" << lr->eta(titer); const string svv = vv.str(); cerr << svv << endl; weights.WriteToFile(fname, true, &svv); @@ -343,7 +347,6 @@ int main(int argc, char** argv) { if (rank == 0) { g /= (size_per_proc * size); o->UpdateWeights(g, FD::NumFeats(), &x); - cerr << "XX: " << x << endl; } broadcast(world, x, 0); broadcast(world, converged, 0); -- cgit v1.2.3