diff options
Diffstat (limited to 'rst_parser/mst_train.cc')
-rw-r--r-- | rst_parser/mst_train.cc | 29 |
1 files changed, 21 insertions, 8 deletions
diff --git a/rst_parser/mst_train.cc b/rst_parser/mst_train.cc index b5114726..c5cab6ec 100644 --- a/rst_parser/mst_train.cc +++ b/rst_parser/mst_train.cc @@ -23,7 +23,9 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { string cfg_file; opts.add_options() ("training_data,t",po::value<string>()->default_value("-"), "File containing training data (jsent format)") - ("feature_function,F",po::value<vector<string> >()->composing(), "feature function") + ("feature_function,F",po::value<vector<string> >()->composing(), "feature function (multiple permitted)") + ("weights,w",po::value<string>(), "Optional starting weights") + ("output_every_i_iterations,I",po::value<unsigned>()->default_value(1), "Write weights every I iterations") ("regularization_strength,C",po::value<double>()->default_value(1.0), "Regularization strength") ("correction_buffers,m", po::value<int>()->default_value(10), "LBFGS correction buffers"); po::options_description clo("Command line options"); @@ -161,9 +163,13 @@ int main(int argc, char** argv) { if (flag) cerr << endl; //cerr << "EMP: " << empirical << endl; //DE vector<weight_t> weights(FD::NumFeats(), 0.0); + if (conf.count("weights")) + Weights::InitFromFile(conf["weights"].as<string>(), &weights); vector<weight_t> g(FD::NumFeats(), 0.0); cerr << "features initialized\noptimizing...\n"; boost::shared_ptr<BatchOptimizer> o; + int every = corpus.size() / 20; + if (!every) ++every; o.reset(new LBFGSOptimizer(g.size(), conf["correction_buffers"].as<int>())); int iterations = 1000; for (int iter = 0; iter < iterations; ++iter) { @@ -174,11 +180,12 @@ int main(int argc, char** argv) { double obj = -empirical.dot(weights); // SparseVector<double> mfm; //DE for (int i = 0; i < corpus.size(); ++i) { + if ((i + 1) % every == 0) cerr << '.' << flush; const int num_words = corpus[i].ts.words.size(); forests[i].Reweight(weights); - double lz; - forests[i].EdgeMarginals(&lz); - obj -= lz; + prob_t z; + forests[i].EdgeMarginals(&z); + obj -= log(z); //cerr << " O = " << (-corpus[i].features.dot(weights)) << " D=" << -lz << " OO= " << (-corpus[i].features.dot(weights) - lz) << endl; //cerr << " ZZ = " << zz << endl; for (int h = -1; h < num_words; ++h) { @@ -202,14 +209,20 @@ int main(int argc, char** argv) { gnorm += g[i]*g[i]; ostringstream ll; ll << "ITER=" << (iter+1) << "\tOBJ=" << (obj+r) << "\t[F=" << obj << " R=" << r << "]\tGnorm=" << sqrt(gnorm); - cerr << endl << ll.str() << endl; + cerr << ' ' << ll.str().substr(ll.str().find('\t')+1) << endl; obj += r; assert(obj >= 0); o->Optimize(obj, g, &weights); Weights::ShowLargestFeatures(weights); - string sl = ll.str(); - Weights::WriteToFile(o->HasConverged() ? "weights.final.gz" : "weights.cur.gz", weights, true, &sl); - if (o->HasConverged()) { cerr << "CONVERGED\n"; break; } + const bool converged = o->HasConverged(); + const char* ofname = converged ? "weights.final.gz" : "weights.cur.gz"; + if (converged || ((iter+1) % conf["output_every_i_iterations"].as<unsigned>()) == 0) { + cerr << "writing..." << flush; + const string sl = ll.str(); + Weights::WriteToFile(ofname, weights, true, &sl); + cerr << "done" << endl; + } + if (converged) { cerr << "CONVERGED\n"; break; } } forests[0].Reweight(weights); TreeSampler ts(forests[0]); |