summaryrefslogtreecommitdiff
path: root/pro-train/mr_pro_reduce.cc
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2011-07-12 22:34:34 -0400
committerChris Dyer <cdyer@cs.cmu.edu>2011-07-12 22:34:34 -0400
commitc87835f5f94b3aa954682133c40117b3f8e26585 (patch)
treedfe5e8cffbdbf3b911d7ef6fc9d7eb8508a28d89 /pro-train/mr_pro_reduce.cc
parent9ab32f74dd821f08cb5863faf88d40ca60301688 (diff)
debugged pro trainer
Diffstat (limited to 'pro-train/mr_pro_reduce.cc')
-rw-r--r--pro-train/mr_pro_reduce.cc57
1 files changed, 33 insertions, 24 deletions
diff --git a/pro-train/mr_pro_reduce.cc b/pro-train/mr_pro_reduce.cc
index 2b9c5ce7..e1a7db8a 100644
--- a/pro-train/mr_pro_reduce.cc
+++ b/pro-train/mr_pro_reduce.cc
@@ -24,7 +24,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
("weights,w", po::value<string>(), "Weights from previous iteration (used as initialization and interpolation")
("interpolation,p",po::value<double>()->default_value(0.9), "Output weights are p*w + (1-p)*w_prev")
("memory_buffers,m",po::value<unsigned>()->default_value(200), "Number of memory buffers (LBFGS)")
- ("sigma_squared,s",po::value<double>()->default_value(0.5), "Sigma squared for Gaussian prior")
+ ("sigma_squared,s",po::value<double>()->default_value(1.0), "Sigma squared for Gaussian prior")
("help,h", "Help");
po::options_description dcmdline_options;
dcmdline_options.add(opts);
@@ -35,6 +35,31 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
}
}
+void ParseSparseVector(string& line, size_t cur, SparseVector<double>* out) {
+ SparseVector<double>& x = *out;
+ size_t last_start = cur;
+ size_t last_comma = string::npos;
+ while(cur <= line.size()) {
+ if (line[cur] == ' ' || cur == line.size()) {
+ if (!(cur > last_start && last_comma != string::npos && cur > last_comma)) {
+ cerr << "[ERROR] " << line << endl << " position = " << cur << endl;
+ exit(1);
+ }
+ const int fid = FD::Convert(line.substr(last_start, last_comma - last_start));
+ if (cur < line.size()) line[cur] = 0;
+ const double val = strtod(&line[last_comma + 1], NULL);
+ x.set_value(fid, val);
+
+ last_comma = string::npos;
+ last_start = cur+1;
+ } else {
+ if (line[cur] == '=')
+ last_comma = cur;
+ }
+ ++cur;
+ }
+}
+
int main(int argc, char** argv) {
po::variables_map conf;
InitCommandLine(argc, argv, &conf);
@@ -60,28 +85,7 @@ int main(int argc, char** argv) {
assert(ks == 1);
const bool y = line[0] == '1';
SparseVector<double> x;
- size_t last_start = ks + 1;
- size_t last_comma = string::npos;
- size_t cur = last_start;
- while(cur <= line.size()) {
- if (line[cur] == ' ' || cur == line.size()) {
- if (!(cur > last_start && last_comma != string::npos && cur > last_comma)) {
- cerr << "[ERROR] " << line << endl << " position = " << cur << endl;
- exit(1);
- }
- const int fid = FD::Convert(line.substr(last_start, last_comma - last_start));
- if (cur < line.size()) line[cur] = 0;
- const double val = strtod(&line[last_comma + 1], NULL);
- x.set_value(fid, val);
-
- last_comma = string::npos;
- last_start = cur+1;
- } else {
- if (line[cur] == '=')
- last_comma = cur;
- }
- ++cur;
- }
+ ParseSparseVector(line, ks + 1, &x);
training.push_back(make_pair(y, x));
}
if (flag) cerr << endl;
@@ -95,6 +99,7 @@ int main(int argc, char** argv) {
SparseVector<double> g;
bool converged = false;
LBFGSOptimizer opt(FD::NumFeats(), conf["memory_buffers"].as<unsigned>());
+ double ppl = 0;
while(!converged) {
double cll = 0;
double dbias = 0;
@@ -114,14 +119,18 @@ int main(int argc, char** argv) {
lp_false*=-1;
if (training[i].first) { // true label
cll -= lp_true;
+ ppl += lp_true / log(2);
g -= training[i].second * exp(lp_false);
dbias -= exp(lp_false);
} else { // false label
cll -= lp_false;
+ ppl += lp_false / log(2);
g += training[i].second * exp(lp_true);
dbias += exp(lp_true);
}
}
+ ppl /= training.size();
+ ppl = pow(2.0, - ppl);
vg.clear();
g.init_vector(&vg);
vg[0] = dbias;
@@ -139,7 +148,7 @@ int main(int argc, char** argv) {
double reg = 0;
#endif
cll += reg;
- cerr << cll << " (REG=" << reg << ")\t";
+ cerr << cll << " (REG=" << reg << ")\tPPL=" << ppl << "\t";
bool failed = false;
try {
opt.Optimize(cll, vg, &x);