diff options
Diffstat (limited to 'training/crf/mpi_online_optimize.cc')
-rw-r--r-- | training/crf/mpi_online_optimize.cc | 12 |
1 files changed, 11 insertions, 1 deletions
diff --git a/training/crf/mpi_online_optimize.cc b/training/crf/mpi_online_optimize.cc index d6968848..9e1ae34c 100644 --- a/training/crf/mpi_online_optimize.cc +++ b/training/crf/mpi_online_optimize.cc @@ -5,6 +5,7 @@ #include <cassert> #include <cmath> #include <tr1/memory> +#include <ctime> #include <boost/program_options.hpp> #include <boost/program_options/variables_map.hpp> @@ -41,6 +42,7 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("training_agenda,a",po::value<string>(), "Text file listing a series of configuration files and the number of iterations to train using each configuration successively") ("minibatch_size_per_proc,s", po::value<unsigned>()->default_value(5), "Number of training instances evaluated per processor in each minibatch") ("optimization_method,m", po::value<string>()->default_value("sgd"), "Optimization method (sgd)") + ("max_walltime", po::value<unsigned>(), "Maximum walltime to run (in minutes)") ("random_seed,S", po::value<uint32_t>(), "Random seed (if not specified, /dev/random will be used)") ("eta_0,e", po::value<double>()->default_value(0.2), "Initial learning rate for SGD (eta_0)") ("L1,1","Use L1 regularization") @@ -304,6 +306,9 @@ int main(int argc, char** argv) { int write_weights_every_ith = 100; // TODO configure int titer = -1; + unsigned timeout = 0; + if (conf.count("max_walltime")) timeout = 60 * conf["max_walltime"].as<unsigned>(); + const time_t start_time = time(NULL); for (int ai = 0; ai < agenda.size(); ++ai) { const string& cur_config = agenda[ai].first; const unsigned max_iteration = agenda[ai].second; @@ -336,9 +341,14 @@ int main(int argc, char** argv) { ostringstream o; o << "weights.epoch_" << (ai+1) << '.' << iter << ".gz"; fname = o.str(); } + const time_t cur_time = time(NULL); + if (timeout) { + if ((cur_time - start_time) > timeout) converged = true; + } if (converged && ((ai+1)==agenda.size())) { fname = "weights.final.gz"; } ostringstream vv; - vv << "total iter=" << titer << " (of current config iter=" << iter << ") minibatch=" << size_per_proc << " sentences/proc x " << size << " procs. num_feats=" << x.size() << '/' << FD::NumFeats() << " passes_thru_data=" << (titer * size_per_proc / static_cast<double>(corpus.size())) << " eta=" << lr->eta(titer); + double minutes = (cur_time - start_time) / 60.0; + vv << "total walltime=" << minutes << "min iter=" << titer << " (of current config iter=" << iter << ") minibatch=" << size_per_proc << " sentences/proc x " << size << " procs. num_feats=" << x.size() << '/' << FD::NumFeats() << " passes_thru_data=" << (titer * size_per_proc / static_cast<double>(corpus.size())) << " eta=" << lr->eta(titer); const string svv = vv.str(); cerr << svv << endl; Weights::WriteToFile(fname, lambdas, true, &svv); |