diff options
author | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-09-22 18:28:00 +0000 |
---|---|---|
committer | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-09-22 18:28:00 +0000 |
commit | 46c1d1c56535cdaf71e920b12f1cbdcbd1bd9d4f (patch) | |
tree | 7ba40a52456b967200a0f09e9aed939e79af81e4 /training/mpi_online_optimize.cc | |
parent | 716356a5de1eb7d066d511af51b76e1294609d87 (diff) |
few fixes
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@655 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'training/mpi_online_optimize.cc')
-rw-r--r-- | training/mpi_online_optimize.cc | 18 |
1 files changed, 14 insertions, 4 deletions
diff --git a/training/mpi_online_optimize.cc b/training/mpi_online_optimize.cc index 95b462bb..62821aa3 100644 --- a/training/mpi_online_optimize.cc +++ b/training/mpi_online_optimize.cc @@ -85,7 +85,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { } po::notify(*conf); - if (conf->count("help") || !conf->count("input_weights") || !conf->count("training_data") || !conf->count("decoder_config")) { + if (conf->count("help") || !conf->count("training_data") || !conf->count("decoder_config")) { cerr << dcmdline_options << endl; MPI::Finalize(); exit(1); @@ -201,14 +201,14 @@ int main(int argc, char** argv) { cerr << "MPI: I am " << rank << '/' << size << endl; register_feature_functions(); MT19937* rng = NULL; - if (rank == 0) rng = new MT19937; po::variables_map conf; InitCommandLine(argc, argv, &conf); // load initial weights Weights weights; - weights.InitFromFile(conf["input_weights"].as<string>()); + if (conf.count("input_weights")) + weights.InitFromFile(conf["input_weights"].as<string>()); // freeze feature set const bool freeze_feature_set = conf.count("freeze_feature_set"); @@ -228,6 +228,7 @@ int main(int argc, char** argv) { std::tr1::shared_ptr<OnlineOptimizer> o; std::tr1::shared_ptr<LearningRateSchedule> lr; + vector<int> order; if (rank == 0) { // TODO config lr.reset(new ExponentialDecayLearningRate(corpus.size(), conf["eta_0"].as<double>())); @@ -239,12 +240,21 @@ int main(int argc, char** argv) { } else { assert(!"fail"); } + + // randomize corpus + rng = new MT19937; + order.resize(corpus.size()); + for (unsigned i = 0; i < order.size(); ++i) order[i]=i; + Shuffle(&order, rng); } double objective = 0; vector<double> lambdas; weights.InitVector(&lambdas); bool converged = false; - + const unsigned size_per_proc = conf["minibatch_size_per_proc"].as<unsigned>(); + for (int i = 0; i < size_per_proc; ++i) + cerr << "i=" << i << ": " << order[i] << endl; + abort(); TrainingObserver observer; while (!converged) { observer.Reset(); |