summaryrefslogtreecommitdiff
path: root/pro-train/mr_pro_reduce.cc
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2011-09-13 17:36:23 +0100
committerChris Dyer <cdyer@cs.cmu.edu>2011-09-13 17:36:23 +0100
commit251da4347ea356f799e6c227ac8cf541c0cef2f2 (patch)
tree407e647e34aa89049754d83e9e1eb2cddff05de8 /pro-train/mr_pro_reduce.cc
parent75bff8e374f3cdcf3dc141f8b7b37858d0611234 (diff)
get rid of bad Weights class so it no longer keeps a copy of a vector inside it
Diffstat (limited to 'pro-train/mr_pro_reduce.cc')
-rw-r--r--pro-train/mr_pro_reduce.cc16
1 files changed, 7 insertions, 9 deletions
diff --git a/pro-train/mr_pro_reduce.cc b/pro-train/mr_pro_reduce.cc
index 9b422f33..9caaa1d1 100644
--- a/pro-train/mr_pro_reduce.cc
+++ b/pro-train/mr_pro_reduce.cc
@@ -194,7 +194,7 @@ int main(int argc, char** argv) {
InitCommandLine(argc, argv, &conf);
string line;
vector<pair<bool, SparseVector<double> > > training, testing;
- SparseVector<double> old_weights;
+ SparseVector<weight_t> old_weights;
const bool tune_regularizer = conf.count("tune_regularizer");
if (tune_regularizer && !conf.count("testset")) {
cerr << "--tune_regularizer requires --testset to be set\n";
@@ -210,9 +210,9 @@ int main(int argc, char** argv) {
const double psi = conf["interpolation"].as<double>();
if (psi < 0.0 || psi > 1.0) { cerr << "Invalid interpolation weight: " << psi << endl; }
if (conf.count("weights")) {
- Weights w;
- w.InitFromFile(conf["weights"].as<string>());
- w.InitSparseVector(&old_weights);
+ vector<weight_t> dt;
+ Weights::InitFromFile(conf["weights"].as<string>(), &dt);
+ Weights::InitSparseVector(dt, &old_weights);
}
ReadCorpus(&cin, &training);
if (conf.count("testset")) {
@@ -220,8 +220,8 @@ int main(int argc, char** argv) {
ReadCorpus(rf.stream(), &testing);
}
cerr << "Number of features: " << FD::NumFeats() << endl;
- vector<double> x(FD::NumFeats(), 0.0); // x[0] is bias
- for (SparseVector<double>::const_iterator it = old_weights.begin();
+ vector<weight_t> x(FD::NumFeats(), 0.0); // x[0] is bias
+ for (SparseVector<weight_t>::const_iterator it = old_weights.begin();
it != old_weights.end(); ++it)
x[it->first] = it->second;
double tppl = 0.0;
@@ -257,7 +257,6 @@ int main(int argc, char** argv) {
sigsq = sp[best_i].first;
tppl = LearnParameters(training, testing, sigsq, conf["memory_buffers"].as<unsigned>(), &x);
}
- Weights w;
if (conf.count("weights")) {
for (int i = 1; i < x.size(); ++i)
x[i] = (x[i] * psi) + old_weights.get(i) * (1.0 - psi);
@@ -271,7 +270,6 @@ int main(int argc, char** argv) {
cout << "# " << sp[i].first << "\t" << sp[i].second << "\t" << smoothed[i] << endl;
}
}
- w.InitFromVector(x);
- w.WriteToFile("-");
+ Weights::WriteToFile("-", x);
return 0;
}