summaryrefslogtreecommitdiff
path: root/rst_parser/mst_train.cc
diff options
context:
space:
mode:
Diffstat (limited to 'rst_parser/mst_train.cc')
-rw-r--r--rst_parser/mst_train.cc29
1 files changed, 21 insertions, 8 deletions
diff --git a/rst_parser/mst_train.cc b/rst_parser/mst_train.cc
index b5114726..c5cab6ec 100644
--- a/rst_parser/mst_train.cc
+++ b/rst_parser/mst_train.cc
@@ -23,7 +23,9 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
string cfg_file;
opts.add_options()
("training_data,t",po::value<string>()->default_value("-"), "File containing training data (jsent format)")
- ("feature_function,F",po::value<vector<string> >()->composing(), "feature function")
+ ("feature_function,F",po::value<vector<string> >()->composing(), "feature function (multiple permitted)")
+ ("weights,w",po::value<string>(), "Optional starting weights")
+ ("output_every_i_iterations,I",po::value<unsigned>()->default_value(1), "Write weights every I iterations")
("regularization_strength,C",po::value<double>()->default_value(1.0), "Regularization strength")
("correction_buffers,m", po::value<int>()->default_value(10), "LBFGS correction buffers");
po::options_description clo("Command line options");
@@ -161,9 +163,13 @@ int main(int argc, char** argv) {
if (flag) cerr << endl;
//cerr << "EMP: " << empirical << endl; //DE
vector<weight_t> weights(FD::NumFeats(), 0.0);
+ if (conf.count("weights"))
+ Weights::InitFromFile(conf["weights"].as<string>(), &weights);
vector<weight_t> g(FD::NumFeats(), 0.0);
cerr << "features initialized\noptimizing...\n";
boost::shared_ptr<BatchOptimizer> o;
+ int every = corpus.size() / 20;
+ if (!every) ++every;
o.reset(new LBFGSOptimizer(g.size(), conf["correction_buffers"].as<int>()));
int iterations = 1000;
for (int iter = 0; iter < iterations; ++iter) {
@@ -174,11 +180,12 @@ int main(int argc, char** argv) {
double obj = -empirical.dot(weights);
// SparseVector<double> mfm; //DE
for (int i = 0; i < corpus.size(); ++i) {
+ if ((i + 1) % every == 0) cerr << '.' << flush;
const int num_words = corpus[i].ts.words.size();
forests[i].Reweight(weights);
- double lz;
- forests[i].EdgeMarginals(&lz);
- obj -= lz;
+ prob_t z;
+ forests[i].EdgeMarginals(&z);
+ obj -= log(z);
//cerr << " O = " << (-corpus[i].features.dot(weights)) << " D=" << -lz << " OO= " << (-corpus[i].features.dot(weights) - lz) << endl;
//cerr << " ZZ = " << zz << endl;
for (int h = -1; h < num_words; ++h) {
@@ -202,14 +209,20 @@ int main(int argc, char** argv) {
gnorm += g[i]*g[i];
ostringstream ll;
ll << "ITER=" << (iter+1) << "\tOBJ=" << (obj+r) << "\t[F=" << obj << " R=" << r << "]\tGnorm=" << sqrt(gnorm);
- cerr << endl << ll.str() << endl;
+ cerr << ' ' << ll.str().substr(ll.str().find('\t')+1) << endl;
obj += r;
assert(obj >= 0);
o->Optimize(obj, g, &weights);
Weights::ShowLargestFeatures(weights);
- string sl = ll.str();
- Weights::WriteToFile(o->HasConverged() ? "weights.final.gz" : "weights.cur.gz", weights, true, &sl);
- if (o->HasConverged()) { cerr << "CONVERGED\n"; break; }
+ const bool converged = o->HasConverged();
+ const char* ofname = converged ? "weights.final.gz" : "weights.cur.gz";
+ if (converged || ((iter+1) % conf["output_every_i_iterations"].as<unsigned>()) == 0) {
+ cerr << "writing..." << flush;
+ const string sl = ll.str();
+ Weights::WriteToFile(ofname, weights, true, &sl);
+ cerr << "done" << endl;
+ }
+ if (converged) { cerr << "CONVERGED\n"; break; }
}
forests[0].Reweight(weights);
TreeSampler ts(forests[0]);