diff options
Diffstat (limited to 'pro-train/mr_pro_reduce.cc')
-rw-r--r-- | pro-train/mr_pro_reduce.cc | 57 |
1 files changed, 33 insertions, 24 deletions
diff --git a/pro-train/mr_pro_reduce.cc b/pro-train/mr_pro_reduce.cc index 2b9c5ce7..e1a7db8a 100644 --- a/pro-train/mr_pro_reduce.cc +++ b/pro-train/mr_pro_reduce.cc @@ -24,7 +24,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("weights,w", po::value<string>(), "Weights from previous iteration (used as initialization and interpolation") ("interpolation,p",po::value<double>()->default_value(0.9), "Output weights are p*w + (1-p)*w_prev") ("memory_buffers,m",po::value<unsigned>()->default_value(200), "Number of memory buffers (LBFGS)") - ("sigma_squared,s",po::value<double>()->default_value(0.5), "Sigma squared for Gaussian prior") + ("sigma_squared,s",po::value<double>()->default_value(1.0), "Sigma squared for Gaussian prior") ("help,h", "Help"); po::options_description dcmdline_options; dcmdline_options.add(opts); @@ -35,6 +35,31 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { } } +void ParseSparseVector(string& line, size_t cur, SparseVector<double>* out) { + SparseVector<double>& x = *out; + size_t last_start = cur; + size_t last_comma = string::npos; + while(cur <= line.size()) { + if (line[cur] == ' ' || cur == line.size()) { + if (!(cur > last_start && last_comma != string::npos && cur > last_comma)) { + cerr << "[ERROR] " << line << endl << " position = " << cur << endl; + exit(1); + } + const int fid = FD::Convert(line.substr(last_start, last_comma - last_start)); + if (cur < line.size()) line[cur] = 0; + const double val = strtod(&line[last_comma + 1], NULL); + x.set_value(fid, val); + + last_comma = string::npos; + last_start = cur+1; + } else { + if (line[cur] == '=') + last_comma = cur; + } + ++cur; + } +} + int main(int argc, char** argv) { po::variables_map conf; InitCommandLine(argc, argv, &conf); @@ -60,28 +85,7 @@ int main(int argc, char** argv) { assert(ks == 1); const bool y = line[0] == '1'; SparseVector<double> x; - size_t last_start = ks + 1; - size_t last_comma = string::npos; - size_t cur = last_start; - while(cur <= line.size()) { - if (line[cur] == ' ' || cur == line.size()) { - if (!(cur > last_start && last_comma != string::npos && cur > last_comma)) { - cerr << "[ERROR] " << line << endl << " position = " << cur << endl; - exit(1); - } - const int fid = FD::Convert(line.substr(last_start, last_comma - last_start)); - if (cur < line.size()) line[cur] = 0; - const double val = strtod(&line[last_comma + 1], NULL); - x.set_value(fid, val); - - last_comma = string::npos; - last_start = cur+1; - } else { - if (line[cur] == '=') - last_comma = cur; - } - ++cur; - } + ParseSparseVector(line, ks + 1, &x); training.push_back(make_pair(y, x)); } if (flag) cerr << endl; @@ -95,6 +99,7 @@ int main(int argc, char** argv) { SparseVector<double> g; bool converged = false; LBFGSOptimizer opt(FD::NumFeats(), conf["memory_buffers"].as<unsigned>()); + double ppl = 0; while(!converged) { double cll = 0; double dbias = 0; @@ -114,14 +119,18 @@ int main(int argc, char** argv) { lp_false*=-1; if (training[i].first) { // true label cll -= lp_true; + ppl += lp_true / log(2); g -= training[i].second * exp(lp_false); dbias -= exp(lp_false); } else { // false label cll -= lp_false; + ppl += lp_false / log(2); g += training[i].second * exp(lp_true); dbias += exp(lp_true); } } + ppl /= training.size(); + ppl = pow(2.0, - ppl); vg.clear(); g.init_vector(&vg); vg[0] = dbias; @@ -139,7 +148,7 @@ int main(int argc, char** argv) { double reg = 0; #endif cll += reg; - cerr << cll << " (REG=" << reg << ")\t"; + cerr << cll << " (REG=" << reg << ")\tPPL=" << ppl << "\t"; bool failed = false; try { opt.Optimize(cll, vg, &x); |