summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2011-09-14 13:12:01 +0100
committerChris Dyer <cdyer@cs.cmu.edu>2011-09-14 13:12:01 +0100
commit9ba06c6f1a7e751da245219da291e329efa2b7e5 (patch)
tree1754cfd98fff05762f70b98cec470493e05c7f30
parent70e9873251215b887ede2cc03855bfccf725b593 (diff)
fix pro train bug causing it not to optimize when there is no held-out test set
-rw-r--r--pro-train/mr_pro_reduce.cc30
1 files changed, 15 insertions, 15 deletions
diff --git a/pro-train/mr_pro_reduce.cc b/pro-train/mr_pro_reduce.cc
index 239649c1..e71347ba 100644
--- a/pro-train/mr_pro_reduce.cc
+++ b/pro-train/mr_pro_reduce.cc
@@ -194,7 +194,6 @@ int main(int argc, char** argv) {
InitCommandLine(argc, argv, &conf);
string line;
vector<pair<bool, SparseVector<weight_t> > > training, testing;
- SparseVector<weight_t> old_weights;
const bool tune_regularizer = conf.count("tune_regularizer");
if (tune_regularizer && !conf.count("testset")) {
cerr << "--tune_regularizer requires --testset to be set\n";
@@ -202,28 +201,28 @@ int main(int argc, char** argv) {
}
const double min_reg = conf["min_reg"].as<double>();
const double max_reg = conf["max_reg"].as<double>();
- double sigsq = conf["sigma_squared"].as<double>();
+ double sigsq = conf["sigma_squared"].as<double>(); // will be overridden if parameter is tuned
assert(sigsq > 0.0);
assert(min_reg > 0.0);
assert(max_reg > 0.0);
assert(max_reg > min_reg);
const double psi = conf["interpolation"].as<double>();
if (psi < 0.0 || psi > 1.0) { cerr << "Invalid interpolation weight: " << psi << endl; }
- if (conf.count("weights")) {
- vector<weight_t> dt;
- Weights::InitFromFile(conf["weights"].as<string>(), &dt);
- Weights::InitSparseVector(dt, &old_weights);
- }
ReadCorpus(&cin, &training);
if (conf.count("testset")) {
ReadFile rf(conf["testset"].as<string>());
ReadCorpus(rf.stream(), &testing);
}
cerr << "Number of features: " << FD::NumFeats() << endl;
- vector<weight_t> x(FD::NumFeats(), 0.0); // x[0] is bias
- for (SparseVector<weight_t>::const_iterator it = old_weights.begin();
- it != old_weights.end(); ++it)
- x[it->first] = it->second;
+
+ vector<weight_t> x, prev_x; // x[0] is bias
+ if (conf.count("weights")) {
+ Weights::InitFromFile(conf["weights"].as<string>(), &x);
+ prev_x = x;
+ }
+ cerr << " Number of features: " << x.size() << endl;
+ cerr << "Number of training examples: " << training.size() << endl;
+ cerr << "Number of testing examples: " << testing.size() << endl;
double tppl = 0.0;
vector<pair<double,double> > sp;
vector<double> smoothed;
@@ -255,11 +254,12 @@ int main(int argc, char** argv) {
}
}
sigsq = sp[best_i].first;
- tppl = LearnParameters(training, testing, sigsq, conf["memory_buffers"].as<unsigned>(), &x);
- }
+ } // tune regularizer
+ tppl = LearnParameters(training, testing, sigsq, conf["memory_buffers"].as<unsigned>(), &x);
if (conf.count("weights")) {
- for (int i = 1; i < x.size(); ++i)
- x[i] = (x[i] * psi) + old_weights.get(i) * (1.0 - psi);
+ for (int i = 1; i < x.size(); ++i) {
+ x[i] = (x[i] * psi) + prev_x[i] * (1.0 - psi);
+ }
}
cout.precision(15);
cout << "# sigma^2=" << sigsq << "\theld out perplexity=";