From 6802ac200ef614b4935d597ed4cfc3857c1f6c06 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 20 Nov 2012 13:56:08 -0500 Subject: fixes for 2011 optimizer --- training/crf/mpi_online_optimize.cc | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) (limited to 'training') diff --git a/training/crf/mpi_online_optimize.cc b/training/crf/mpi_online_optimize.cc index d6968848..9e1ae34c 100644 --- a/training/crf/mpi_online_optimize.cc +++ b/training/crf/mpi_online_optimize.cc @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -41,6 +42,7 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("training_agenda,a",po::value(), "Text file listing a series of configuration files and the number of iterations to train using each configuration successively") ("minibatch_size_per_proc,s", po::value()->default_value(5), "Number of training instances evaluated per processor in each minibatch") ("optimization_method,m", po::value()->default_value("sgd"), "Optimization method (sgd)") + ("max_walltime", po::value(), "Maximum walltime to run (in minutes)") ("random_seed,S", po::value(), "Random seed (if not specified, /dev/random will be used)") ("eta_0,e", po::value()->default_value(0.2), "Initial learning rate for SGD (eta_0)") ("L1,1","Use L1 regularization") @@ -304,6 +306,9 @@ int main(int argc, char** argv) { int write_weights_every_ith = 100; // TODO configure int titer = -1; + unsigned timeout = 0; + if (conf.count("max_walltime")) timeout = 60 * conf["max_walltime"].as(); + const time_t start_time = time(NULL); for (int ai = 0; ai < agenda.size(); ++ai) { const string& cur_config = agenda[ai].first; const unsigned max_iteration = agenda[ai].second; @@ -336,9 +341,14 @@ int main(int argc, char** argv) { ostringstream o; o << "weights.epoch_" << (ai+1) << '.' << iter << ".gz"; fname = o.str(); } + const time_t cur_time = time(NULL); + if (timeout) { + if ((cur_time - start_time) > timeout) converged = true; + } if (converged && ((ai+1)==agenda.size())) { fname = "weights.final.gz"; } ostringstream vv; - vv << "total iter=" << titer << " (of current config iter=" << iter << ") minibatch=" << size_per_proc << " sentences/proc x " << size << " procs. num_feats=" << x.size() << '/' << FD::NumFeats() << " passes_thru_data=" << (titer * size_per_proc / static_cast(corpus.size())) << " eta=" << lr->eta(titer); + double minutes = (cur_time - start_time) / 60.0; + vv << "total walltime=" << minutes << "min iter=" << titer << " (of current config iter=" << iter << ") minibatch=" << size_per_proc << " sentences/proc x " << size << " procs. num_feats=" << x.size() << '/' << FD::NumFeats() << " passes_thru_data=" << (titer * size_per_proc / static_cast(corpus.size())) << " eta=" << lr->eta(titer); const string svv = vv.str(); cerr << svv << endl; Weights::WriteToFile(fname, lambdas, true, &svv); -- cgit v1.2.3