summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/mpi_online_optimize.cc19
1 files changed, 12 insertions, 7 deletions
diff --git a/training/mpi_online_optimize.cc b/training/mpi_online_optimize.cc
index d662e8bd..509fbf15 100644
--- a/training/mpi_online_optimize.cc
+++ b/training/mpi_online_optimize.cc
@@ -4,10 +4,9 @@
#include <vector>
#include <cassert>
#include <cmath>
+#include <tr1/memory>
-#include <mpi.h>
#include <boost/mpi.hpp>
-#include <boost/shared_ptr.hpp>
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
@@ -24,8 +23,8 @@
#include "sparse_vector.h"
#include "sampler.h"
+
using namespace std;
-using boost::shared_ptr;
namespace po = boost::program_options;
void SanityCheck(const vector<double>& w) {
@@ -57,13 +56,14 @@ void ShowLargestFeatures(const vector<double>& w) {
cerr << endl;
}
-void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+bool InitCommandLine(int argc, char** argv, po::variables_map* conf) {
po::options_description opts("Configuration options");
opts.add_options()
("input_weights,w",po::value<string>(),"Input feature weights file")
("training_data,t",po::value<string>(),"Training data corpus")
("decoder_config,c",po::value<string>(),"Decoder configuration file")
("output_weights,o",po::value<string>()->default_value("-"),"Output feature weights file")
+ ("maximum_iteration,i", po::value<unsigned>(), "Maximum number of iterations")
("minibatch_size_per_proc,s", po::value<unsigned>()->default_value(5), "Number of training instances evaluated per processor in each minibatch")
("freeze_feature_set,Z", "The feature set specified in the initial weights file is frozen throughout the duration of training")
("optimization_method,m", po::value<string>()->default_value("sgd"), "Optimization method (sgd)")
@@ -89,9 +89,9 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
if (conf->count("help") || !conf->count("training_data") || !conf->count("decoder_config")) {
cerr << dcmdline_options << endl;
- MPI::Finalize();
- exit(1);
+ return false;
}
+ return true;
}
void ReadTrainingCorpus(const string& fname, vector<string>* c) {
@@ -220,7 +220,8 @@ int main(int argc, char** argv) {
std::tr1::shared_ptr<MT19937> rng;
po::variables_map conf;
- InitCommandLine(argc, argv, &conf);
+ if (!InitCommandLine(argc, argv, &conf))
+ return 1;
// load initial weights
Weights weights;
@@ -292,6 +293,10 @@ int main(int argc, char** argv) {
observer.Reset();
decoder.SetWeights(lambdas);
if (rank == 0) {
+ if (conf.count("maximum_iteration")) {
+ if (iter == conf["maximum_iteration"].as<unsigned>())
+ converged = true;
+ }
SanityCheck(lambdas);
ShowLargestFeatures(lambdas);
string fname = "weights.cur.gz";