From 9ba06c6f1a7e751da245219da291e329efa2b7e5 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 14 Sep 2011 13:12:01 +0100 Subject: fix pro train bug causing it not to optimize when there is no held-out test set --- pro-train/mr_pro_reduce.cc | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) (limited to 'pro-train') diff --git a/pro-train/mr_pro_reduce.cc b/pro-train/mr_pro_reduce.cc index 239649c1..e71347ba 100644 --- a/pro-train/mr_pro_reduce.cc +++ b/pro-train/mr_pro_reduce.cc @@ -194,7 +194,6 @@ int main(int argc, char** argv) { InitCommandLine(argc, argv, &conf); string line; vector > > training, testing; - SparseVector old_weights; const bool tune_regularizer = conf.count("tune_regularizer"); if (tune_regularizer && !conf.count("testset")) { cerr << "--tune_regularizer requires --testset to be set\n"; @@ -202,28 +201,28 @@ int main(int argc, char** argv) { } const double min_reg = conf["min_reg"].as(); const double max_reg = conf["max_reg"].as(); - double sigsq = conf["sigma_squared"].as(); + double sigsq = conf["sigma_squared"].as(); // will be overridden if parameter is tuned assert(sigsq > 0.0); assert(min_reg > 0.0); assert(max_reg > 0.0); assert(max_reg > min_reg); const double psi = conf["interpolation"].as(); if (psi < 0.0 || psi > 1.0) { cerr << "Invalid interpolation weight: " << psi << endl; } - if (conf.count("weights")) { - vector dt; - Weights::InitFromFile(conf["weights"].as(), &dt); - Weights::InitSparseVector(dt, &old_weights); - } ReadCorpus(&cin, &training); if (conf.count("testset")) { ReadFile rf(conf["testset"].as()); ReadCorpus(rf.stream(), &testing); } cerr << "Number of features: " << FD::NumFeats() << endl; - vector x(FD::NumFeats(), 0.0); // x[0] is bias - for (SparseVector::const_iterator it = old_weights.begin(); - it != old_weights.end(); ++it) - x[it->first] = it->second; + + vector x, prev_x; // x[0] is bias + if (conf.count("weights")) { + Weights::InitFromFile(conf["weights"].as(), &x); + prev_x = x; + } + cerr << " Number of features: " << x.size() << endl; + cerr << "Number of training examples: " << training.size() << endl; + cerr << "Number of testing examples: " << testing.size() << endl; double tppl = 0.0; vector > sp; vector smoothed; @@ -255,11 +254,12 @@ int main(int argc, char** argv) { } } sigsq = sp[best_i].first; - tppl = LearnParameters(training, testing, sigsq, conf["memory_buffers"].as(), &x); - } + } // tune regularizer + tppl = LearnParameters(training, testing, sigsq, conf["memory_buffers"].as(), &x); if (conf.count("weights")) { - for (int i = 1; i < x.size(); ++i) - x[i] = (x[i] * psi) + old_weights.get(i) * (1.0 - psi); + for (int i = 1; i < x.size(); ++i) { + x[i] = (x[i] * psi) + prev_x[i] * (1.0 - psi); + } } cout.precision(15); cout << "# sigma^2=" << sigsq << "\theld out perplexity="; -- cgit v1.2.3