diff options
| -rw-r--r-- | training/mpi_online_optimize.cc | 18 | 
1 files changed, 14 insertions, 4 deletions
diff --git a/training/mpi_online_optimize.cc b/training/mpi_online_optimize.cc index 95b462bb..62821aa3 100644 --- a/training/mpi_online_optimize.cc +++ b/training/mpi_online_optimize.cc @@ -85,7 +85,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {    }    po::notify(*conf); -  if (conf->count("help") || !conf->count("input_weights") || !conf->count("training_data") || !conf->count("decoder_config")) { +  if (conf->count("help") || !conf->count("training_data") || !conf->count("decoder_config")) {      cerr << dcmdline_options << endl;      MPI::Finalize();      exit(1); @@ -201,14 +201,14 @@ int main(int argc, char** argv) {    cerr << "MPI: I am " << rank << '/' << size << endl;    register_feature_functions();    MT19937* rng = NULL; -  if (rank == 0) rng = new MT19937;    po::variables_map conf;    InitCommandLine(argc, argv, &conf);    // load initial weights    Weights weights; -  weights.InitFromFile(conf["input_weights"].as<string>()); +  if (conf.count("input_weights")) +    weights.InitFromFile(conf["input_weights"].as<string>());    // freeze feature set    const bool freeze_feature_set = conf.count("freeze_feature_set"); @@ -228,6 +228,7 @@ int main(int argc, char** argv) {    std::tr1::shared_ptr<OnlineOptimizer> o;    std::tr1::shared_ptr<LearningRateSchedule> lr; +  vector<int> order;    if (rank == 0) {      // TODO config      lr.reset(new ExponentialDecayLearningRate(corpus.size(), conf["eta_0"].as<double>())); @@ -239,12 +240,21 @@ int main(int argc, char** argv) {      } else {        assert(!"fail");      } + +    // randomize corpus +    rng = new MT19937; +    order.resize(corpus.size()); +    for (unsigned i = 0; i < order.size(); ++i) order[i]=i; +    Shuffle(&order, rng);    }    double objective = 0;    vector<double> lambdas;    weights.InitVector(&lambdas);    bool converged = false; - +  const unsigned size_per_proc = conf["minibatch_size_per_proc"].as<unsigned>(); +  for (int i = 0; i < size_per_proc; ++i) +    cerr << "i=" << i << ": " << order[i] << endl; +  abort();    TrainingObserver observer;    while (!converged) {      observer.Reset();  | 
