diff options
author | Chris Dyer <cdyer@cs.cmu.edu> | 2011-09-13 17:36:23 +0100 |
---|---|---|
committer | Chris Dyer <cdyer@cs.cmu.edu> | 2011-09-13 17:36:23 +0100 |
commit | 251da4347ea356f799e6c227ac8cf541c0cef2f2 (patch) | |
tree | 407e647e34aa89049754d83e9e1eb2cddff05de8 /pro-train/mr_pro_reduce.cc | |
parent | 75bff8e374f3cdcf3dc141f8b7b37858d0611234 (diff) |
get rid of bad Weights class so it no longer keeps a copy of a vector inside it
Diffstat (limited to 'pro-train/mr_pro_reduce.cc')
-rw-r--r-- | pro-train/mr_pro_reduce.cc | 16 |
1 files changed, 7 insertions, 9 deletions
diff --git a/pro-train/mr_pro_reduce.cc b/pro-train/mr_pro_reduce.cc index 9b422f33..9caaa1d1 100644 --- a/pro-train/mr_pro_reduce.cc +++ b/pro-train/mr_pro_reduce.cc @@ -194,7 +194,7 @@ int main(int argc, char** argv) { InitCommandLine(argc, argv, &conf); string line; vector<pair<bool, SparseVector<double> > > training, testing; - SparseVector<double> old_weights; + SparseVector<weight_t> old_weights; const bool tune_regularizer = conf.count("tune_regularizer"); if (tune_regularizer && !conf.count("testset")) { cerr << "--tune_regularizer requires --testset to be set\n"; @@ -210,9 +210,9 @@ int main(int argc, char** argv) { const double psi = conf["interpolation"].as<double>(); if (psi < 0.0 || psi > 1.0) { cerr << "Invalid interpolation weight: " << psi << endl; } if (conf.count("weights")) { - Weights w; - w.InitFromFile(conf["weights"].as<string>()); - w.InitSparseVector(&old_weights); + vector<weight_t> dt; + Weights::InitFromFile(conf["weights"].as<string>(), &dt); + Weights::InitSparseVector(dt, &old_weights); } ReadCorpus(&cin, &training); if (conf.count("testset")) { @@ -220,8 +220,8 @@ int main(int argc, char** argv) { ReadCorpus(rf.stream(), &testing); } cerr << "Number of features: " << FD::NumFeats() << endl; - vector<double> x(FD::NumFeats(), 0.0); // x[0] is bias - for (SparseVector<double>::const_iterator it = old_weights.begin(); + vector<weight_t> x(FD::NumFeats(), 0.0); // x[0] is bias + for (SparseVector<weight_t>::const_iterator it = old_weights.begin(); it != old_weights.end(); ++it) x[it->first] = it->second; double tppl = 0.0; @@ -257,7 +257,6 @@ int main(int argc, char** argv) { sigsq = sp[best_i].first; tppl = LearnParameters(training, testing, sigsq, conf["memory_buffers"].as<unsigned>(), &x); } - Weights w; if (conf.count("weights")) { for (int i = 1; i < x.size(); ++i) x[i] = (x[i] * psi) + old_weights.get(i) * (1.0 - psi); @@ -271,7 +270,6 @@ int main(int argc, char** argv) { cout << "# " << sp[i].first << "\t" << sp[i].second << "\t" << smoothed[i] << endl; } } - w.InitFromVector(x); - w.WriteToFile("-"); + Weights::WriteToFile("-", x); return 0; } |