-rw-r--r--   training/mpi_online_optimize.cc   13
1 file changed, 8 insertions, 5 deletions
diff --git a/training/mpi_online_optimize.cc b/training/mpi_online_optimize.cc
index 4c08b181..96f3bcfb 100644
--- a/training/mpi_online_optimize.cc
+++ b/training/mpi_online_optimize.cc
@@ -220,7 +220,6 @@ bool LoadAgenda(const string& file, vector<pair<string, int> >* a) {
     x.first = line.substr(0,d);
     x.second = atoi(line.substr(d+1).c_str());
     a->push_back(x);
-    cerr << "X: " << x.second << " - " << x.first << "'\n";
     if (!FileExists(x.first)) {
       cerr << "Can't find file " << x.first << endl;
       return false;
@@ -261,7 +260,11 @@ int main(int argc, char** argv) {
     return 1;
   }
 
+  size_t total_corpus_size = 0;
+  reduce(world, corpus.size(), total_corpus_size, std::plus<size_t>(), 0);
+
   if (rank == 0) {
+    cerr << "Total corpus size: " << total_corpus_size << endl;
     const unsigned batch_size = size_per_proc * size;
     // TODO config
     lr.reset(new ExponentialDecayLearningRate(batch_size, conf["eta_0"].as<double>()));
@@ -269,7 +272,7 @@ int main(int argc, char** argv) {
     const string omethod = conf["optimization_method"].as<string>();
     if (omethod == "sgd") {
       const double C = conf["regularization_strength"].as<double>();
-      o.reset(new CumulativeL1OnlineOptimizer(lr, corpus.size(), C));
+      o.reset(new CumulativeL1OnlineOptimizer(lr, total_corpus_size, C));
     } else {
       assert(!"fail");
     }
@@ -302,7 +305,8 @@ int main(int argc, char** argv) {
     ReadFile ini_rf(cur_config);
     Decoder decoder(ini_rf.stream());
 
-    o->ResetEpoch(); // resets the learning rate-- TODO is this good?
+    if (rank == 0)
+      o->ResetEpoch(); // resets the learning rate-- TODO is this good?
 
     int iter = -1;
     bool converged = false;
@@ -324,7 +328,7 @@ int main(int argc, char** argv) {
         }
         if (converged && ((ai+1)==agenda.size())) { fname = "weights.final.gz"; }
         ostringstream vv;
-        vv << "total iter=" << titer << " (of current config iter=" << iter << ")  minibatch=" << size_per_proc << " sentences/proc x " << size << " procs.   num_feats=" << x.size() << '/' << FD::NumFeats() << "   passes_thru_data=" << (titer * size * size_per_proc / static_cast<double>(corpus.size())) << "   eta=" << lr->eta(titer);
+        vv << "total iter=" << titer << " (of current config iter=" << iter << ")  minibatch=" << size_per_proc << " sentences/proc x " << size << " procs.   num_feats=" << x.size() << '/' << FD::NumFeats() << "   passes_thru_data=" << (titer * size_per_proc / static_cast<double>(corpus.size())) << "   eta=" << lr->eta(titer);
         const string svv = vv.str();
         cerr << svv << endl;
         weights.WriteToFile(fname, true, &svv);
@@ -343,7 +347,6 @@ int main(int argc, char** argv) {
       if (rank == 0) {
         g /= (size_per_proc * size);
         o->UpdateWeights(g, FD::NumFeats(), &x);
-        cerr << "XX: " << x << endl;
       }
       broadcast(world, x, 0);
       broadcast(world, converged, 0);
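
For reference, below is a minimal, hypothetical Boost.MPI sketch (not part of this commit) of the reduction pattern the patch introduces: each rank contributes its local corpus.size(), std::plus sums the values onto rank 0 (the root), and only rank 0 prints the result, mirroring the new total_corpus_size / "Total corpus size" logging above. The shard sizes and the standalone program are illustrative assumptions.

// Hypothetical standalone sketch of the reduce(...) pattern used in the patch.
#include <cstddef>
#include <functional>
#include <iostream>
#include <vector>
#include <boost/mpi.hpp>

int main(int argc, char** argv) {
  boost::mpi::environment env(argc, argv);
  boost::mpi::communicator world;

  // Stand-in for the per-process training shard; sizes differ across ranks.
  std::vector<int> corpus(100 + 10 * world.rank());

  // Sum every rank's local corpus.size() onto rank 0; non-root ranks
  // leave total_corpus_size at its initial value.
  std::size_t total_corpus_size = 0;
  boost::mpi::reduce(world, corpus.size(), total_corpus_size,
                     std::plus<std::size_t>(), 0);

  if (world.rank() == 0)
    std::cerr << "Total corpus size: " << total_corpus_size << std::endl;
  return 0;
}

Compiled with an MPI compiler wrapper and linked against Boost.MPI and Boost.Serialization, running it under mpirun with several processes prints a single aggregated total from rank 0, which is the value the patch then feeds to CumulativeL1OnlineOptimizer in place of the per-process corpus.size().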
