diff options
Diffstat (limited to 'training/mpi_online_optimize.cc')
-rw-r--r-- | training/mpi_online_optimize.cc | 69 |
1 files changed, 24 insertions, 45 deletions
diff --git a/training/mpi_online_optimize.cc b/training/mpi_online_optimize.cc index 32033c19..2ef4a2e7 100644 --- a/training/mpi_online_optimize.cc +++ b/training/mpi_online_optimize.cc @@ -31,35 +31,6 @@ namespace mpi = boost::mpi; using namespace std; namespace po = boost::program_options; -void SanityCheck(const vector<double>& w) { - for (int i = 0; i < w.size(); ++i) { - assert(!isnan(w[i])); - assert(!isinf(w[i])); - } -} - -struct FComp { - const vector<double>& w_; - FComp(const vector<double>& w) : w_(w) {} - bool operator()(int a, int b) const { - return fabs(w_[a]) > fabs(w_[b]); - } -}; - -void ShowLargestFeatures(const vector<double>& w) { - vector<int> fnums(w.size()); - for (int i = 0; i < w.size(); ++i) - fnums[i] = i; - vector<int>::iterator mid = fnums.begin(); - mid += (w.size() > 10 ? 10 : w.size()); - partial_sort(fnums.begin(), mid, fnums.end(), FComp(w)); - cerr << "TOP FEATURES:"; - for (vector<int>::iterator i = fnums.begin(); i != mid; ++i) { - cerr << ' ' << FD::Convert(*i) << '=' << w[*i]; - } - cerr << endl; -} - bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() @@ -250,10 +221,25 @@ int main(int argc, char** argv) { if (!InitCommandLine(argc, argv, &conf)) return 1; + vector<pair<string, int> > agenda; + if (!LoadAgenda(conf["training_agenda"].as<string>(), &agenda)) + return 1; + if (rank == 0) + cerr << "Loaded agenda defining " << agenda.size() << " training epochs\n"; + + assert(agenda.size() > 0); + + if (1) { // hack to load the feature hash functions -- TODO this should not be in cdec.ini + const string& cur_config = agenda[0].first; + const unsigned max_iteration = agenda[0].second; + ReadFile ini_rf(cur_config); + Decoder decoder(ini_rf.stream()); + } + // load initial weights - Weights weights; + vector<weight_t> init_weights; if (conf.count("input_weights")) - weights.InitFromFile(conf["input_weights"].as<string>()); + Weights::InitFromFile(conf["input_weights"].as<string>(), &init_weights); vector<int> frozen_fids; if (conf.count("frozen_features")) { @@ -310,19 +296,12 @@ int main(int argc, char** argv) { rng.reset(new MT19937); SparseVector<double> x; - weights.InitSparseVector(&x); + Weights::InitSparseVector(init_weights, &x); TrainingObserver observer; int write_weights_every_ith = 100; // TODO configure int titer = -1; - vector<pair<string, int> > agenda; - if (!LoadAgenda(conf["training_agenda"].as<string>(), &agenda)) - return 1; - if (rank == 0) - cerr << "Loaded agenda defining " << agenda.size() << " training epochs\n"; - - vector<double> lambdas; for (int ai = 0; ai < agenda.size(); ++ai) { const string& cur_config = agenda[ai].first; const unsigned max_iteration = agenda[ai].second; @@ -331,6 +310,8 @@ int main(int argc, char** argv) { // load cdec.ini and set up decoder ReadFile ini_rf(cur_config); Decoder decoder(ini_rf.stream()); + vector<weight_t>& lambdas = decoder.CurrentWeightVector(); + if (ai == 0) { lambdas.swap(init_weights); init_weights.clear(); } if (rank == 0) o->ResetEpoch(); // resets the learning rate-- TODO is this good? @@ -341,15 +322,13 @@ int main(int argc, char** argv) { #ifdef HAVE_MPI mpi::timer timer; #endif - weights.InitFromVector(x); - weights.InitVector(&lambdas); + x.init_vector(&lambdas); ++iter; ++titer; observer.Reset(); - decoder.SetWeights(lambdas); if (rank == 0) { converged = (iter == max_iteration); - SanityCheck(lambdas); - ShowLargestFeatures(lambdas); + Weights::SanityCheck(lambdas); + Weights::ShowLargestFeatures(lambdas); string fname = "weights.cur.gz"; if (iter % write_weights_every_ith == 0) { ostringstream o; o << "weights.epoch_" << (ai+1) << '.' << iter << ".gz"; @@ -360,7 +339,7 @@ int main(int argc, char** argv) { vv << "total iter=" << titer << " (of current config iter=" << iter << ") minibatch=" << size_per_proc << " sentences/proc x " << size << " procs. num_feats=" << x.size() << '/' << FD::NumFeats() << " passes_thru_data=" << (titer * size_per_proc / static_cast<double>(corpus.size())) << " eta=" << lr->eta(titer); const string svv = vv.str(); cerr << svv << endl; - weights.WriteToFile(fname, true, &svv); + Weights::WriteToFile(fname, lambdas, true, &svv); } for (int i = 0; i < size_per_proc; ++i) { |