diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/mpi_online_optimize.cc | 19 |
1 files changed, 12 insertions, 7 deletions
diff --git a/training/mpi_online_optimize.cc b/training/mpi_online_optimize.cc index d662e8bd..509fbf15 100644 --- a/training/mpi_online_optimize.cc +++ b/training/mpi_online_optimize.cc @@ -4,10 +4,9 @@ #include <vector> #include <cassert> #include <cmath> +#include <tr1/memory> -#include <mpi.h> #include <boost/mpi.hpp> -#include <boost/shared_ptr.hpp> #include <boost/program_options.hpp> #include <boost/program_options/variables_map.hpp> @@ -24,8 +23,8 @@ #include "sparse_vector.h" #include "sampler.h" + using namespace std; -using boost::shared_ptr; namespace po = boost::program_options; void SanityCheck(const vector<double>& w) { @@ -57,13 +56,14 @@ void ShowLargestFeatures(const vector<double>& w) { cerr << endl; } -void InitCommandLine(int argc, char** argv, po::variables_map* conf) { +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() ("input_weights,w",po::value<string>(),"Input feature weights file") ("training_data,t",po::value<string>(),"Training data corpus") ("decoder_config,c",po::value<string>(),"Decoder configuration file") ("output_weights,o",po::value<string>()->default_value("-"),"Output feature weights file") + ("maximum_iteration,i", po::value<unsigned>(), "Maximum number of iterations") ("minibatch_size_per_proc,s", po::value<unsigned>()->default_value(5), "Number of training instances evaluated per processor in each minibatch") ("freeze_feature_set,Z", "The feature set specified in the initial weights file is frozen throughout the duration of training") ("optimization_method,m", po::value<string>()->default_value("sgd"), "Optimization method (sgd)") @@ -89,9 +89,9 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { if (conf->count("help") || !conf->count("training_data") || !conf->count("decoder_config")) { cerr << dcmdline_options << endl; - MPI::Finalize(); - exit(1); + return false; } + return true; } void ReadTrainingCorpus(const string& fname, vector<string>* c) { @@ -220,7 +220,8 @@ int main(int argc, char** argv) { std::tr1::shared_ptr<MT19937> rng; po::variables_map conf; - InitCommandLine(argc, argv, &conf); + if (!InitCommandLine(argc, argv, &conf)) + return 1; // load initial weights Weights weights; @@ -292,6 +293,10 @@ int main(int argc, char** argv) { observer.Reset(); decoder.SetWeights(lambdas); if (rank == 0) { + if (conf.count("maximum_iteration")) { + if (iter == conf["maximum_iteration"].as<unsigned>()) + converged = true; + } SanityCheck(lambdas); ShowLargestFeatures(lambdas); string fname = "weights.cur.gz"; |