summaryrefslogtreecommitdiff
path: root/pro-train
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2011-07-13 16:25:05 -0400
committerChris Dyer <cdyer@cs.cmu.edu>2011-07-13 16:25:05 -0400
commit34fdc73e613bbc30d59d7bd36c5db31a94a7ac68 (patch)
tree8c32bf116211f553fb484595787c302bd2b61924 /pro-train
parenta037f52a87f7d5711b5521047e7fb3fcd756c647 (diff)
faster code, optional held-out test set
Diffstat (limited to 'pro-train')
-rw-r--r--pro-train/mr_pro_reduce.cc140
1 files changed, 89 insertions, 51 deletions
diff --git a/pro-train/mr_pro_reduce.cc b/pro-train/mr_pro_reduce.cc
index 5382e1a5..491ceb3a 100644
--- a/pro-train/mr_pro_reduce.cc
+++ b/pro-train/mr_pro_reduce.cc
@@ -7,6 +7,7 @@
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
+#include "filelib.h"
#include "weights.h"
#include "sparse_vector.h"
#include "optimize.h"
@@ -25,6 +26,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
("interpolation,p",po::value<double>()->default_value(0.9), "Output weights are p*w + (1-p)*w_prev")
("memory_buffers,m",po::value<unsigned>()->default_value(200), "Number of memory buffers (LBFGS)")
("sigma_squared,s",po::value<double>()->default_value(1.0), "Sigma squared for Gaussian prior")
+ ("testset,t",po::value<string>(), "Optional held-out test set to tune regularizer")
("help,h", "Help");
po::options_description dcmdline_options;
dcmdline_options.add(opts);
@@ -60,13 +62,79 @@ void ParseSparseVector(string& line, size_t cur, SparseVector<double>* out) {
}
}
+void ReadCorpus(istream* pin, vector<pair<bool, SparseVector<double> > >* corpus) {
+ istream& in = *pin;
+ corpus->clear();
+ bool flag = false;
+ int lc = 0;
+ string line;
+ SparseVector<double> x;
+ while(getline(in, line)) {
+ ++lc;
+ if (lc % 1000 == 0) { cerr << '.'; flag = true; }
+ if (lc % 40000 == 0) { cerr << " [" << lc << "]\n"; flag = false; }
+ if (line.empty()) continue;
+ const size_t ks = line.find("\t");
+ assert(string::npos != ks);
+ assert(ks == 1);
+ const bool y = line[0] == '1';
+ x.clear();
+ ParseSparseVector(line, ks + 1, &x);
+ corpus->push_back(make_pair(y, x));
+ }
+ if (flag) cerr << endl;
+}
+
+void GradAdd(const SparseVector<double>& v, const double scale, vector<double>* acc) {
+ for (SparseVector<double>::const_iterator it = v.begin();
+ it != v.end(); ++it) {
+ (*acc)[it->first] += it->second * scale;
+ }
+}
+
+double TrainingInference(const vector<double>& x,
+ const vector<pair<bool, SparseVector<double> > >& corpus,
+ vector<double>* g = NULL) {
+ if (g) fill(g->begin(), g->end(), 0.0);
+
+ double cll = 0;
+ for (int i = 0; i < corpus.size(); ++i) {
+ const double dotprod = corpus[i].second.dot(x) + x[0]; // x[0] is bias
+ double lp_false = dotprod;
+ double lp_true = -dotprod;
+ if (0 < lp_true) {
+ lp_true += log1p(exp(-lp_true));
+ lp_false = log1p(exp(lp_false));
+ } else {
+ lp_true = log1p(exp(lp_true));
+ lp_false += log1p(exp(-lp_false));
+ }
+ lp_true*=-1;
+ lp_false*=-1;
+ if (corpus[i].first) { // true label
+ cll -= lp_true;
+ if (g) {
+ // g -= corpus[i].second * exp(lp_false);
+ GradAdd(corpus[i].second, -exp(lp_false), g);
+ (*g)[0] -= exp(lp_false); // bias
+ }
+ } else { // false label
+ cll -= lp_false;
+ if (g) {
+ // g += corpus[i].second * exp(lp_true);
+ GradAdd(corpus[i].second, exp(lp_true), g);
+ (*g)[0] += exp(lp_true); // bias
+ }
+ }
+ }
+ return cll;
+}
+
int main(int argc, char** argv) {
po::variables_map conf;
InitCommandLine(argc, argv, &conf);
string line;
- vector<pair<bool, SparseVector<double> > > training;
- int lc = 0;
- bool flag = false;
+ vector<pair<bool, SparseVector<double> > > training, testing;
SparseVector<double> old_weights;
const double psi = conf["interpolation"].as<double>();
if (psi < 0.0 || psi > 1.0) { cerr << "Invalid interpolation weight: " << psi << endl; }
@@ -75,20 +143,11 @@ int main(int argc, char** argv) {
w.InitFromFile(conf["weights"].as<string>());
w.InitSparseVector(&old_weights);
}
- while(getline(cin, line)) {
- ++lc;
- if (lc % 1000 == 0) { cerr << '.'; flag = true; }
- if (lc % 40000 == 0) { cerr << " [" << lc << "]\n"; flag = false; }
- if (line.empty()) continue;
- const size_t ks = line.find("\t");
- assert(string::npos != ks);
- assert(ks == 1);
- const bool y = line[0] == '1';
- SparseVector<double> x;
- ParseSparseVector(line, ks + 1, &x);
- training.push_back(make_pair(y, x));
+ ReadCorpus(&cin, &training);
+ if (conf.count("testset")) {
+ ReadFile rf(conf["testset"].as<string>());
+ ReadCorpus(rf.stream(), &testing);
}
- if (flag) cerr << endl;
cerr << "Number of features: " << FD::NumFeats() << endl;
vector<double> x(FD::NumFeats(), 0.0); // x[0] is bias
@@ -96,44 +155,23 @@ int main(int argc, char** argv) {
it != old_weights.end(); ++it)
x[it->first] = it->second;
vector<double> vg(FD::NumFeats(), 0.0);
- SparseVector<double> g;
bool converged = false;
LBFGSOptimizer opt(FD::NumFeats(), conf["memory_buffers"].as<unsigned>());
- double ppl = 0;
while(!converged) {
- double cll = 0;
- double dbias = 0;
- g.clear();
- for (int i = 0; i < training.size(); ++i) {
- const double dotprod = training[i].second.dot(x) + x[0]; // x[0] is bias
- double lp_false = dotprod;
- double lp_true = -dotprod;
- if (0 < lp_true) {
- lp_true += log1p(exp(-lp_true));
- lp_false = log1p(exp(lp_false));
- } else {
- lp_true = log1p(exp(lp_true));
- lp_false += log1p(exp(-lp_false));
- }
- lp_true*=-1;
- lp_false*=-1;
- if (training[i].first) { // true label
- cll -= lp_true;
- ppl += lp_true / log(2);
- g -= training[i].second * exp(lp_false);
- dbias -= exp(lp_false);
- } else { // false label
- cll -= lp_false;
- ppl += lp_false / log(2);
- g += training[i].second * exp(lp_true);
- dbias += exp(lp_true);
- }
- }
+ double cll = TrainingInference(x, training, &vg);
+ double ppl = cll / log(2);
ppl /= training.size();
- ppl = pow(2.0, - ppl);
- vg.clear();
- g.init_vector(&vg);
- vg[0] = dbias;
+ ppl = pow(2.0, ppl);
+ double tppl = 0.0;
+
+ // evaluate optional held-out test set
+ if (testing.size()) {
+ tppl = TrainingInference(x, testing) / log(2);
+ tppl /= testing.size();
+ tppl = pow(2.0, tppl);
+ }
+
+ // handle regularizer
#if 1
const double sigsq = conf["sigma_squared"].as<double>();
double norm = 0;
@@ -148,7 +186,7 @@ int main(int argc, char** argv) {
double reg = 0;
#endif
cll += reg;
- cerr << cll << " (REG=" << reg << ")\tPPL=" << ppl << "\t";
+ cerr << cll << " (REG=" << reg << ")\tPPL=" << ppl << "\t TEST_PPL=" << tppl << "\t";
try {
vector<double> old_x = x;
do {