path: root/training/mpi_batch_optimize.cc
author    redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>  2010-10-15 20:13:01 +0000
committer redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>  2010-10-15 20:13:01 +0000
commit    99fc029872c94d59851237e2f64e66951ceec1d2 (patch)
tree      0272a7832ee0e217d80007ce3f61c69dd9ed527d /training/mpi_batch_optimize.cc
parent    0c5cbb5bc9ad7226213035ec27f3d20f7a74cd7c (diff)
new multi-epoch online optimizer
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@675 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'training/mpi_batch_optimize.cc')
-rw-r--r--  training/mpi_batch_optimize.cc  78
1 file changed, 69 insertions, 9 deletions
diff --git a/training/mpi_batch_optimize.cc b/training/mpi_batch_optimize.cc
index 7953513e..f1ee9fb4 100644
--- a/training/mpi_batch_optimize.cc
+++ b/training/mpi_batch_optimize.cc
@@ -1,6 +1,5 @@
#include <sstream>
#include <iostream>
-#include <fstream>
#include <vector>
#include <cassert>
#include <cmath>
@@ -61,6 +60,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
("input_weights,w",po::value<string>(),"Input feature weights file")
("training_data,t",po::value<string>(),"Training data")
("decoder_config,d",po::value<string>(),"Decoder configuration file")
+ ("sharded_input,s",po::value<string>(), "Corpus and grammar files are 'sharded' so each processor loads its own input and grammar file. Argument is the directory containing the shards.")
("output_weights,o",po::value<string>()->default_value("-"),"Output feature weights file")
("optimization_method,m", po::value<string>()->default_value("lbfgs"), "Optimization method (sgd, lbfgs, rprop)")
("correction_buffers,M", po::value<int>()->default_value(10), "Number of gradients for LBFGS to maintain in memory")
@@ -82,11 +82,16 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
}
po::notify(*conf);
- if (conf->count("help") || !conf->count("input_weights") || !conf->count("training_data") || !conf->count("decoder_config")) {
+ if (conf->count("help") || !conf->count("input_weights") || !(conf->count("training_data") || conf->count("sharded_input")) || !conf->count("decoder_config")) {
cerr << dcmdline_options << endl;
MPI::Finalize();
exit(1);
}
+ if (conf->count("training_data") && conf->count("sharded_input")) {
+ cerr << "Cannot specify both --training_data and --sharded_input\n";
+ MPI::Finalize();
+ exit(1);
+ }
}
void ReadTrainingCorpus(const string& fname, int rank, int size, vector<string>* c) {
@@ -183,32 +188,79 @@ struct TrainingObserver : public DecoderObserver {
int state;
};
+void ReadConfig(const string& ini, vector<string>* out) {
+ ReadFile rf(ini);
+ istream& in = *rf.stream();
+ string line;
+ // getline in the loop condition stops cleanly at EOF and, unlike
+ // testing the stream before reading, keeps a final unterminated line
+ while (getline(in, line)) {
+ out->push_back(line);
+ }
+}
+
+void StoreConfig(const vector<string>& cfg, istringstream* o) {
+ ostringstream os;
+ for (int i = 0; i < cfg.size(); ++i) { os << cfg[i] << endl; }
+ o->str(os.str());
+}
+
int main(int argc, char** argv) {
MPI::Init(argc, argv);
const int size = MPI::COMM_WORLD.Get_size();
const int rank = MPI::COMM_WORLD.Get_rank();
SetSilent(true); // turn off verbose decoder output
- cerr << "MPI: I am " << rank << '/' << size << endl;
register_feature_functions();
po::variables_map conf;
InitCommandLine(argc, argv, &conf);
+ string shard_dir;
+ if (conf.count("sharded_input")) {
+ shard_dir = conf["sharded_input"].as<string>();
+ if (!DirectoryExists(shard_dir)) {
+ if (rank == 0) cerr << "Can't find shard directory: " << shard_dir << endl;
+ MPI::Finalize();
+ return 1;
+ }
+ if (rank == 0)
+ cerr << "Shard directory: " << shard_dir << endl;
+ }
+
// load initial weights
Weights weights;
+ if (rank == 0) { cerr << "Loading weights...\n"; }
weights.InitFromFile(conf["input_weights"].as<string>());
+ if (rank == 0) { cerr << "Done loading weights.\n"; }
// freeze feature set (should be optional?)
const bool freeze_feature_set = true;
if (freeze_feature_set) FD::Freeze();
// load cdec.ini and set up decoder
- ReadFile ini_rf(conf["decoder_config"].as<string>());
- Decoder decoder(ini_rf.stream());
- if (decoder.GetConf()["input"].as<string>() != "-") {
+ vector<string> cdec_ini;
+ ReadConfig(conf["decoder_config"].as<string>(), &cdec_ini);
+ if (shard_dir.size()) {
+ if (rank == 0) {
+ for (int i = 0; i < cdec_ini.size(); ++i) {
+ if (cdec_ini[i].find("grammar=") == 0) {
+ cerr << "!!! using sharded input and " << conf["decoder_config"].as<string>() << " contains a grammar specification:\n" << cdec_ini[i] << "\n VERIFY THAT THIS IS CORRECT!\n";
+ }
+ }
+ }
+ ostringstream g;
+ g << "grammar=" << shard_dir << "/grammar." << rank << "_of_" << size << ".gz";
+ cdec_ini.push_back(g.str());
+ }
+ istringstream ini;
+ StoreConfig(cdec_ini, &ini);
+ if (rank == 0) cerr << "Loading grammar...\n";
+ Decoder* decoder = new Decoder(&ini);
+ if (decoder->GetConf()["input"].as<string>() != "-") {
cerr << "cdec.ini must not set an input file\n";
MPI::COMM_WORLD.Abort(1);
}
+ if (rank == 0) cerr << "Done loading grammar!\n";
const int num_feats = FD::NumFeats();
if (rank == 0) cerr << "Number of features: " << num_feats << endl;
@@ -247,8 +299,16 @@ int main(int argc, char** argv) {
vector<double> gradient(num_feats, 0.0);
vector<double> rcv_grad(num_feats, 0.0);
bool converged = false;
+
vector<string> corpus;
- ReadTrainingCorpus(conf["training_data"].as<string>(), rank, size, &corpus);
+ if (shard_dir.size()) {
+ ostringstream os; os << shard_dir << "/corpus." << rank << "_of_" << size;
+ ReadTrainingCorpus(os.str(), 0, 1, &corpus);
+ cerr << os.str() << " has " << corpus.size() << " training examples. " << endl;
+ if (corpus.size() > 500) { corpus.resize(500); cerr << " TRUNCATING\n"; }
+ } else {
+ ReadTrainingCorpus(conf["training_data"].as<string>(), rank, size, &corpus);
+ }
assert(corpus.size() > 0);
TrainingObserver observer;
@@ -257,9 +317,9 @@ int main(int argc, char** argv) {
if (rank == 0) {
cerr << "Starting decoding... (~" << corpus.size() << " sentences / proc)\n";
}
- decoder.SetWeights(lambdas);
+ decoder->SetWeights(lambdas);
for (int i = 0; i < corpus.size(); ++i)
- decoder.Decode(corpus[i], &observer);
+ decoder->Decode(corpus[i], &observer);
fill(gradient.begin(), gradient.end(), 0);
fill(rcv_grad.begin(), rcv_grad.end(), 0);
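
The body of ReadTrainingCorpus is not shown in these hunks, but the two call sites fix its contract: called with (rank, size) it keeps only this rank's share of a shared corpus, and called with (0, 1), as in the sharded branch above, it keeps every line of the per-rank shard file. A minimal sketch of a reader with that contract, assuming round-robin striping by line number (the striping scheme itself is an assumption, not taken from this diff):

#include <fstream>
#include <string>
#include <vector>
using namespace std;

// Hypothetical sketch -- not the body from this commit.
// Keeps line i iff i % size == rank, so rank=0, size=1 keeps every line.
void ReadTrainingCorpus(const string& fname, int rank, int size, vector<string>* c) {
  ifstream in(fname.c_str());
  string line;
  for (int id = 0; getline(in, line); ++id)
    if (id % size == rank)
      c->push_back(line);
}

Under that contract, the sharded branch's (0, 1) call avoids duplicating the striping logic: each shard file already is the rank's share by construction.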
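
For orientation, the filename patterns built in the new code (corpus.<rank>_of_<size> and grammar.<rank>_of_<size>.gz) imply a shard directory holding one corpus file and one gzipped grammar per MPI rank. With 4 processes, the directory passed to --sharded_input would contain corpus.0_of_4 through corpus.3_of_4 and grammar.0_of_4.gz through grammar.3_of_4.gz, and a run might look like the following (the mpirun syntax and file names are illustrative; only the -w/-d/-s flags come from the options table above):

  mpirun -np 4 ./mpi_batch_optimize -w weights.init -d cdec.ini -s shards/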