summaryrefslogtreecommitdiff
path: root/training/crf
diff options
context:
space:
mode:
Diffstat (limited to 'training/crf')
-rw-r--r--training/crf/mpi_online_optimize.cc12
1 files changed, 11 insertions, 1 deletions
diff --git a/training/crf/mpi_online_optimize.cc b/training/crf/mpi_online_optimize.cc
index d6968848..9e1ae34c 100644
--- a/training/crf/mpi_online_optimize.cc
+++ b/training/crf/mpi_online_optimize.cc
@@ -5,6 +5,7 @@
#include <cassert>
#include <cmath>
#include <tr1/memory>
+#include <ctime>
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
@@ -41,6 +42,7 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
("training_agenda,a",po::value<string>(), "Text file listing a series of configuration files and the number of iterations to train using each configuration successively")
("minibatch_size_per_proc,s", po::value<unsigned>()->default_value(5), "Number of training instances evaluated per processor in each minibatch")
("optimization_method,m", po::value<string>()->default_value("sgd"), "Optimization method (sgd)")
+ ("max_walltime", po::value<unsigned>(), "Maximum walltime to run (in minutes)")
("random_seed,S", po::value<uint32_t>(), "Random seed (if not specified, /dev/random will be used)")
("eta_0,e", po::value<double>()->default_value(0.2), "Initial learning rate for SGD (eta_0)")
("L1,1","Use L1 regularization")
@@ -304,6 +306,9 @@ int main(int argc, char** argv) {
int write_weights_every_ith = 100; // TODO configure
int titer = -1;
+ unsigned timeout = 0;
+ if (conf.count("max_walltime")) timeout = 60 * conf["max_walltime"].as<unsigned>();
+ const time_t start_time = time(NULL);
for (int ai = 0; ai < agenda.size(); ++ai) {
const string& cur_config = agenda[ai].first;
const unsigned max_iteration = agenda[ai].second;
@@ -336,9 +341,14 @@ int main(int argc, char** argv) {
ostringstream o; o << "weights.epoch_" << (ai+1) << '.' << iter << ".gz";
fname = o.str();
}
+ const time_t cur_time = time(NULL);
+ if (timeout) {
+ if ((cur_time - start_time) > timeout) converged = true;
+ }
if (converged && ((ai+1)==agenda.size())) { fname = "weights.final.gz"; }
ostringstream vv;
- vv << "total iter=" << titer << " (of current config iter=" << iter << ") minibatch=" << size_per_proc << " sentences/proc x " << size << " procs. num_feats=" << x.size() << '/' << FD::NumFeats() << " passes_thru_data=" << (titer * size_per_proc / static_cast<double>(corpus.size())) << " eta=" << lr->eta(titer);
+ double minutes = (cur_time - start_time) / 60.0;
+ vv << "total walltime=" << minutes << "min iter=" << titer << " (of current config iter=" << iter << ") minibatch=" << size_per_proc << " sentences/proc x " << size << " procs. num_feats=" << x.size() << '/' << FD::NumFeats() << " passes_thru_data=" << (titer * size_per_proc / static_cast<double>(corpus.size())) << " eta=" << lr->eta(titer);
const string svv = vv.str();
cerr << svv << endl;
Weights::WriteToFile(fname, lambdas, true, &svv);