diff options
Diffstat (limited to 'training/mpi_batch_optimize.cc')
-rw-r--r-- | training/mpi_batch_optimize.cc | 78 |
1 files changed, 69 insertions, 9 deletions
diff --git a/training/mpi_batch_optimize.cc b/training/mpi_batch_optimize.cc index 7953513e..f1ee9fb4 100644 --- a/training/mpi_batch_optimize.cc +++ b/training/mpi_batch_optimize.cc @@ -1,6 +1,5 @@ #include <sstream> #include <iostream> -#include <fstream> #include <vector> #include <cassert> #include <cmath> @@ -61,6 +60,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("input_weights,w",po::value<string>(),"Input feature weights file") ("training_data,t",po::value<string>(),"Training data") ("decoder_config,d",po::value<string>(),"Decoder configuration file") + ("sharded_input,s",po::value<string>(), "Corpus and grammar files are 'sharded' so each processor loads its own input and grammar file. Argument is the directory containing the shards.") ("output_weights,o",po::value<string>()->default_value("-"),"Output feature weights file") ("optimization_method,m", po::value<string>()->default_value("lbfgs"), "Optimization method (sgd, lbfgs, rprop)") ("correction_buffers,M", po::value<int>()->default_value(10), "Number of gradients for LBFGS to maintain in memory") @@ -82,11 +82,16 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { } po::notify(*conf); - if (conf->count("help") || !conf->count("input_weights") || !conf->count("training_data") || !conf->count("decoder_config")) { + if (conf->count("help") || !conf->count("input_weights") || !(conf->count("training_data") | conf->count("sharded_input")) || !conf->count("decoder_config")) { cerr << dcmdline_options << endl; MPI::Finalize(); exit(1); } + if (conf->count("training_data") && conf->count("sharded_input")) { + cerr << "Cannot specify both --training_data and --sharded_input\n"; + MPI::Finalize(); + exit(1); + } } void ReadTrainingCorpus(const string& fname, int rank, int size, vector<string>* c) { @@ -183,32 +188,79 @@ struct TrainingObserver : public DecoderObserver { int state; }; +void ReadConfig(const string& ini, vector<string>* out) { + ReadFile rf(ini); + istream& in = *rf.stream(); + while(in) { + string line; + getline(in, line); + if (!in) continue; + out->push_back(line); + } +} + +void StoreConfig(const vector<string>& cfg, istringstream* o) { + ostringstream os; + for (int i = 0; i < cfg.size(); ++i) { os << cfg[i] << endl; } + o->str(os.str()); +} + int main(int argc, char** argv) { MPI::Init(argc, argv); const int size = MPI::COMM_WORLD.Get_size(); const int rank = MPI::COMM_WORLD.Get_rank(); SetSilent(true); // turn off verbose decoder output - cerr << "MPI: I am " << rank << '/' << size << endl; register_feature_functions(); po::variables_map conf; InitCommandLine(argc, argv, &conf); + string shard_dir; + if (conf.count("sharded_input")) { + shard_dir = conf["sharded_input"].as<string>(); + if (!DirectoryExists(shard_dir)) { + if (rank == 0) cerr << "Can't find shard directory: " << shard_dir << endl; + MPI::Finalize(); + return 1; + } + if (rank == 0) + cerr << "Shard directory: " << shard_dir << endl; + } + // load initial weights Weights weights; + if (rank == 0) { cerr << "Loading weights...\n"; } weights.InitFromFile(conf["input_weights"].as<string>()); + if (rank == 0) { cerr << "Done loading weights.\n"; } // freeze feature set (should be optional?) const bool freeze_feature_set = true; if (freeze_feature_set) FD::Freeze(); // load cdec.ini and set up decoder - ReadFile ini_rf(conf["decoder_config"].as<string>()); - Decoder decoder(ini_rf.stream()); - if (decoder.GetConf()["input"].as<string>() != "-") { + vector<string> cdec_ini; + ReadConfig(conf["decoder_config"].as<string>(), &cdec_ini); + if (shard_dir.size()) { + if (rank == 0) { + for (int i = 0; i < cdec_ini.size(); ++i) { + if (cdec_ini[i].find("grammar=") == 0) { + cerr << "!!! using sharded input and " << conf["decoder_config"].as<string>() << " contains a grammar specification:\n" << cdec_ini[i] << "\n VERIFY THAT THIS IS CORRECT!\n"; + } + } + } + ostringstream g; + g << "grammar=" << shard_dir << "/grammar." << rank << "_of_" << size << ".gz"; + cdec_ini.push_back(g.str()); + } + istringstream ini; + StoreConfig(cdec_ini, &ini); + if (rank == 0) cerr << "Loading grammar...\n"; + Decoder* decoder = new Decoder(&ini); + if (decoder->GetConf()["input"].as<string>() != "-") { cerr << "cdec.ini must not set an input file\n"; MPI::COMM_WORLD.Abort(1); } + if (rank == 0) cerr << "Done loading grammar!\n"; const int num_feats = FD::NumFeats(); if (rank == 0) cerr << "Number of features: " << num_feats << endl; @@ -247,8 +299,16 @@ int main(int argc, char** argv) { vector<double> gradient(num_feats, 0.0); vector<double> rcv_grad(num_feats, 0.0); bool converged = false; + vector<string> corpus; - ReadTrainingCorpus(conf["training_data"].as<string>(), rank, size, &corpus); + if (shard_dir.size()) { + ostringstream os; os << shard_dir << "/corpus." << rank << "_of_" << size; + ReadTrainingCorpus(os.str(), 0, 1, &corpus); + cerr << os.str() << " has " << corpus.size() << " training examples. " << endl; + if (corpus.size() > 500) { corpus.resize(500); cerr << " TRUNCATING\n"; } + } else { + ReadTrainingCorpus(conf["training_data"].as<string>(), rank, size, &corpus); + } assert(corpus.size() > 0); TrainingObserver observer; @@ -257,9 +317,9 @@ int main(int argc, char** argv) { if (rank == 0) { cerr << "Starting decoding... (~" << corpus.size() << " sentences / proc)\n"; } - decoder.SetWeights(lambdas); + decoder->SetWeights(lambdas); for (int i = 0; i < corpus.size(); ++i) - decoder.Decode(corpus[i], &observer); + decoder->Decode(corpus[i], &observer); fill(gradient.begin(), gradient.end(), 0); fill(rcv_grad.begin(), rcv_grad.end(), 0); |