From 251da4347ea356f799e6c227ac8cf541c0cef2f2 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 13 Sep 2011 17:36:23 +0100 Subject: get rid of bad Weights class so it no longer keeps a copy of a vector inside it --- training/mpi_batch_optimize.cc | 127 +++++++++-------------------------------- 1 file changed, 27 insertions(+), 100 deletions(-) (limited to 'training/mpi_batch_optimize.cc') diff --git a/training/mpi_batch_optimize.cc b/training/mpi_batch_optimize.cc index 39a8af7d..cc5953f6 100644 --- a/training/mpi_batch_optimize.cc +++ b/training/mpi_batch_optimize.cc @@ -31,42 +31,12 @@ using namespace std; using boost::shared_ptr; namespace po = boost::program_options; -void SanityCheck(const vector& w) { - for (int i = 0; i < w.size(); ++i) { - assert(!isnan(w[i])); - assert(!isinf(w[i])); - } -} - -struct FComp { - const vector& w_; - FComp(const vector& w) : w_(w) {} - bool operator()(int a, int b) const { - return fabs(w_[a]) > fabs(w_[b]); - } -}; - -void ShowLargestFeatures(const vector& w) { - vector fnums(w.size()); - for (int i = 0; i < w.size(); ++i) - fnums[i] = i; - vector::iterator mid = fnums.begin(); - mid += (w.size() > 10 ? 10 : w.size()); - partial_sort(fnums.begin(), mid, fnums.end(), FComp(w)); - cerr << "TOP FEATURES:"; - for (vector::iterator i = fnums.begin(); i != mid; ++i) { - cerr << ' ' << FD::Convert(*i) << '=' << w[*i]; - } - cerr << endl; -} - bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() ("input_weights,w",po::value(),"Input feature weights file") ("training_data,t",po::value(),"Training data") ("decoder_config,d",po::value(),"Decoder configuration file") - ("sharded_input,s",po::value(), "Corpus and grammar files are 'sharded' so each processor loads its own input and grammar file. Argument is the directory containing the shards.") ("output_weights,o",po::value()->default_value("-"),"Output feature weights file") ("optimization_method,m", po::value()->default_value("lbfgs"), "Optimization method (sgd, lbfgs, rprop)") ("correction_buffers,M", po::value()->default_value(10), "Number of gradients for LBFGS to maintain in memory") @@ -88,14 +58,10 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { } po::notify(*conf); - if (conf->count("help") || !conf->count("input_weights") || !(conf->count("training_data") | conf->count("sharded_input")) || !conf->count("decoder_config")) { + if (conf->count("help") || !conf->count("input_weights") || !(conf->count("training_data")) || !conf->count("decoder_config")) { cerr << dcmdline_options << endl; return false; } - if (conf->count("training_data") && conf->count("sharded_input")) { - cerr << "Cannot specify both --training_data and --sharded_input\n"; - return false; - } return true; } @@ -236,42 +202,9 @@ int main(int argc, char** argv) { po::variables_map conf; if (!InitCommandLine(argc, argv, &conf)) return 1; - string shard_dir; - if (conf.count("sharded_input")) { - shard_dir = conf["sharded_input"].as(); - if (!DirectoryExists(shard_dir)) { - if (rank == 0) cerr << "Can't find shard directory: " << shard_dir << endl; - return 1; - } - if (rank == 0) - cerr << "Shard directory: " << shard_dir << endl; - } - - // load initial weights - Weights weights; - if (rank == 0) { cerr << "Loading weights...\n"; } - weights.InitFromFile(conf["input_weights"].as()); - if (rank == 0) { cerr << "Done loading weights.\n"; } - - // freeze feature set (should be optional?) - const bool freeze_feature_set = true; - if (freeze_feature_set) FD::Freeze(); - // load cdec.ini and set up decoder vector cdec_ini; ReadConfig(conf["decoder_config"].as(), &cdec_ini); - if (shard_dir.size()) { - if (rank == 0) { - for (int i = 0; i < cdec_ini.size(); ++i) { - if (cdec_ini[i].find("grammar=") == 0) { - cerr << "!!! using sharded input and " << conf["decoder_config"].as() << " contains a grammar specification:\n" << cdec_ini[i] << "\n VERIFY THAT THIS IS CORRECT!\n"; - } - } - } - ostringstream g; - g << "grammar=" << shard_dir << "/grammar." << rank << "_of_" << size << ".gz"; - cdec_ini.push_back(g.str()); - } istringstream ini; StoreConfig(cdec_ini, &ini); if (rank == 0) cerr << "Loading grammar...\n"; @@ -282,22 +215,28 @@ int main(int argc, char** argv) { } if (rank == 0) cerr << "Done loading grammar!\n"; + // load initial weights + if (rank == 0) { cerr << "Loading weights...\n"; } + vector& lambdas = decoder->CurrentWeightVector(); + Weights::InitFromFile(conf["input_weights"].as(), &lambdas); + if (rank == 0) { cerr << "Done loading weights.\n"; } + + // freeze feature set (should be optional?) + const bool freeze_feature_set = true; + if (freeze_feature_set) FD::Freeze(); + const int num_feats = FD::NumFeats(); if (rank == 0) cerr << "Number of features: " << num_feats << endl; + lambdas.resize(num_feats); + const bool gaussian_prior = conf.count("gaussian_prior"); - vector means(num_feats, 0); + vector means(num_feats, 0); if (conf.count("means")) { if (!gaussian_prior) { cerr << "Don't use --means without --gaussian_prior!\n"; exit(1); } - Weights wm; - wm.InitFromFile(conf["means"].as()); - if (num_feats != FD::NumFeats()) { - cerr << "[ERROR] Means file had unexpected features!\n"; - exit(1); - } - wm.InitVector(&means); + Weights::InitFromFile(conf["means"].as(), &means); } shared_ptr o; if (rank == 0) { @@ -309,26 +248,13 @@ int main(int argc, char** argv) { cerr << "Optimizer: " << o->Name() << endl; } double objective = 0; - vector lambdas(num_feats, 0.0); - weights.InitVector(&lambdas); - if (lambdas.size() != num_feats) { - cerr << "Initial weights file did not have all features specified!\n feats=" - << num_feats << "\n weights file=" << lambdas.size() << endl; - lambdas.resize(num_feats, 0.0); - } vector gradient(num_feats, 0.0); - vector rcv_grad(num_feats, 0.0); + vector rcv_grad; + rcv_grad.clear(); bool converged = false; vector corpus; - if (shard_dir.size()) { - ostringstream os; os << shard_dir << "/corpus." << rank << "_of_" << size; - ReadTrainingCorpus(os.str(), 0, 1, &corpus); - cerr << os.str() << " has " << corpus.size() << " training examples. " << endl; - if (corpus.size() > 500) { corpus.resize(500); cerr << " TRUNCATING\n"; } - } else { - ReadTrainingCorpus(conf["training_data"].as(), rank, size, &corpus); - } + ReadTrainingCorpus(conf["training_data"].as(), rank, size, &corpus); assert(corpus.size() > 0); TrainingObserver observer; @@ -341,19 +267,20 @@ int main(int argc, char** argv) { if (rank == 0) { cerr << "Starting decoding... (~" << corpus.size() << " sentences / proc)\n"; } - decoder->SetWeights(lambdas); for (int i = 0; i < corpus.size(); ++i) decoder->Decode(corpus[i], &observer); cerr << " process " << rank << '/' << size << " done\n"; fill(gradient.begin(), gradient.end(), 0); - fill(rcv_grad.begin(), rcv_grad.end(), 0); observer.SetLocalGradientAndObjective(&gradient, &objective); double to = 0; #ifdef HAVE_MPI + rcv_grad.resize(num_feats, 0.0); mpi::reduce(world, &gradient[0], gradient.size(), &rcv_grad[0], plus(), 0); - mpi::reduce(world, objective, to, plus(), 0); swap(gradient, rcv_grad); + rcv_grad.clear(); + + mpi::reduce(world, objective, to, plus(), 0); objective = to; #endif @@ -378,7 +305,7 @@ int main(int argc, char** argv) { for (int i = 0; i < gradient.size(); ++i) gnorm += gradient[i] * gradient[i]; cerr << " GNORM=" << sqrt(gnorm) << endl; - vector old = lambdas; + vector old = lambdas; int c = 0; while (old == lambdas) { ++c; @@ -387,9 +314,8 @@ int main(int argc, char** argv) { assert(c < 5); } old.clear(); - SanityCheck(lambdas); - ShowLargestFeatures(lambdas); - weights.InitFromVector(lambdas); + Weights::SanityCheck(lambdas); + Weights::ShowLargestFeatures(lambdas); converged = o->HasConverged(); if (converged) { cerr << "OPTIMIZER REPORTS CONVERGENCE!\n"; } @@ -399,7 +325,7 @@ int main(int argc, char** argv) { ostringstream vv; vv << "Objective = " << objective << " (eval count=" << o->EvaluationCount() << ")"; const string svv = vv.str(); - weights.WriteToFile(fname, true, &svv); + Weights::WriteToFile(fname, lambdas, true, &svv); } // rank == 0 int cint = converged; #ifdef HAVE_MPI @@ -411,3 +337,4 @@ int main(int argc, char** argv) { } return 0; } + -- cgit v1.2.3