diff options
Diffstat (limited to 'pro-train/mr_pro_reduce.cc')
-rw-r--r-- | pro-train/mr_pro_reduce.cc | 82 |
1 files changed, 47 insertions, 35 deletions
diff --git a/pro-train/mr_pro_reduce.cc b/pro-train/mr_pro_reduce.cc index aff410a0..98cddba2 100644 --- a/pro-train/mr_pro_reduce.cc +++ b/pro-train/mr_pro_reduce.cc @@ -23,11 +23,12 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() ("weights,w", po::value<string>(), "Weights from previous iteration (used as initialization and interpolation") - ("interpolation,p",po::value<double>()->default_value(0.9), "Output weights are p*w + (1-p)*w_prev") + ("regularize_to_weights,y",po::value<double>()->default_value(0.0), "Differences in learned weights to previous weights are penalized with an l2 penalty with this strength; 0.0 = no effect") + ("interpolate_with_weights,p",po::value<double>()->default_value(1.0), "Output weights are p*w + (1-p)*w_prev; 1.0 = no effect") ("memory_buffers,m",po::value<unsigned>()->default_value(200), "Number of memory buffers (LBFGS)") - ("sigma_squared,s",po::value<double>()->default_value(0.1), "Sigma squared for Gaussian prior") - ("min_reg,r",po::value<double>()->default_value(1e-8), "When tuning (-T) regularization strength, minimum regularization strenght") - ("max_reg,R",po::value<double>()->default_value(10.0), "When tuning (-T) regularization strength, maximum regularization strenght") + ("regularization_strength,C",po::value<double>()->default_value(1.0), "l2 regularization strength") + ("min_reg,r",po::value<double>()->default_value(0.01), "When tuning (-T) regularization strength, minimum regularization strenght") + ("max_reg,R",po::value<double>()->default_value(1e6), "When tuning (-T) regularization strength, maximum regularization strenght") ("testset,t",po::value<string>(), "Optional held-out test set") ("tune_regularizer,T", "Use the held out test set (-t) to tune the regularization strength") ("help,h", "Help"); @@ -95,6 +96,27 @@ void GradAdd(const SparseVector<weight_t>& v, const double scale, vector<weight_ } } +double ApplyRegularizationTerms(const double C, + const double T, + const vector<weight_t>& weights, + const vector<weight_t>& prev_weights, + vector<weight_t>* g) { + assert(weights.size() == g->size()); + double reg = 0; + for (size_t i = 0; i < weights.size(); ++i) { + const double prev_w_i = (i < prev_weights.size() ? prev_weights[i] : 0.0); + const double& w_i = weights[i]; + double& g_i = (*g)[i]; + reg += C * w_i * w_i; + g_i += 2 * C * w_i; + + const double diff_i = w_i - prev_w_i; + reg += T * diff_i * diff_i; + g_i += 2 * T * diff_i; + } + return reg; +} + double TrainingInference(const vector<weight_t>& x, const vector<pair<bool, SparseVector<weight_t> > >& corpus, vector<weight_t>* g = NULL) { @@ -134,8 +156,10 @@ double TrainingInference(const vector<weight_t>& x, // return held-out log likelihood double LearnParameters(const vector<pair<bool, SparseVector<weight_t> > >& training, const vector<pair<bool, SparseVector<weight_t> > >& testing, - const double sigsq, + const double C, + const double T, const unsigned memory_buffers, + const vector<weight_t>& prev_x, vector<weight_t>* px) { vector<weight_t>& x = *px; vector<weight_t> vg(FD::NumFeats(), 0.0); @@ -157,26 +181,12 @@ double LearnParameters(const vector<pair<bool, SparseVector<weight_t> > >& train } // handle regularizer -#if 1 - double norm = 0; - for (int i = 1; i < x.size(); ++i) { - const double mean_i = 0.0; - const double param = (x[i] - mean_i); - norm += param * param; - vg[i] += param / sigsq; - } - const double reg = norm / (2.0 * sigsq); -#else - double reg = 0; -#endif + double reg = ApplyRegularizationTerms(C, T, x, prev_x, &vg); cll += reg; - cerr << cll << " (REG=" << reg << ")\tPPL=" << ppl << "\t TEST_PPL=" << tppl << "\t"; + cerr << cll << " (REG=" << reg << ")\tPPL=" << ppl << "\t TEST_PPL=" << tppl << "\t" << endl; try { - vector<weight_t> old_x = x; - do { - opt.Optimize(cll, vg, &x); - converged = opt.HasConverged(); - } while (!converged && x == old_x); + opt.Optimize(cll, vg, &x); + converged = opt.HasConverged(); } catch (...) { cerr << "Exception caught, assuming convergence is close enough...\n"; converged = true; @@ -201,13 +211,14 @@ int main(int argc, char** argv) { } const double min_reg = conf["min_reg"].as<double>(); const double max_reg = conf["max_reg"].as<double>(); - double sigsq = conf["sigma_squared"].as<double>(); // will be overridden if parameter is tuned - assert(sigsq > 0.0); + double C = conf["regularization_strength"].as<double>(); // will be overridden if parameter is tuned + const double T = conf["regularize_to_weights"].as<double>(); + assert(C > 0.0); assert(min_reg > 0.0); assert(max_reg > 0.0); assert(max_reg > min_reg); - const double psi = conf["interpolation"].as<double>(); - if (psi < 0.0 || psi > 1.0) { cerr << "Invalid interpolation weight: " << psi << endl; } + const double psi = conf["interpolate_with_weights"].as<double>(); + if (psi < 0.0 || psi > 1.0) { cerr << "Invalid interpolation weight: " << psi << endl; return 1; } ReadCorpus(&cin, &training); if (conf.count("testset")) { ReadFile rf(conf["testset"].as<string>()); @@ -231,14 +242,15 @@ int main(int argc, char** argv) { vector<pair<double,double> > sp; vector<double> smoothed; if (tune_regularizer) { - sigsq = min_reg; + C = min_reg; const double steps = 18; double sweep_factor = exp((log(max_reg) - log(min_reg)) / steps); cerr << "SWEEP FACTOR: " << sweep_factor << endl; - while(sigsq < max_reg) { - tppl = LearnParameters(training, testing, sigsq, conf["memory_buffers"].as<unsigned>(), &x); - sp.push_back(make_pair(sigsq, tppl)); - sigsq *= sweep_factor; + while(C < max_reg) { + cerr << "C=" << C << "\tT=" <<T << endl; + tppl = LearnParameters(training, testing, C, T, conf["memory_buffers"].as<unsigned>(), prev_x, &x); + sp.push_back(make_pair(C, tppl)); + C *= sweep_factor; } smoothed.resize(sp.size(), 0); smoothed[0] = sp[0].second; @@ -257,16 +269,16 @@ int main(int argc, char** argv) { best_i = i; } } - sigsq = sp[best_i].first; + C = sp[best_i].first; } // tune regularizer - tppl = LearnParameters(training, testing, sigsq, conf["memory_buffers"].as<unsigned>(), &x); + tppl = LearnParameters(training, testing, C, T, conf["memory_buffers"].as<unsigned>(), prev_x, &x); if (conf.count("weights")) { for (int i = 1; i < x.size(); ++i) { x[i] = (x[i] * psi) + prev_x[i] * (1.0 - psi); } } cout.precision(15); - cout << "# sigma^2=" << sigsq << "\theld out perplexity="; + cout << "# C=" << C << "\theld out perplexity="; if (tppl) { cout << tppl << endl; } else { cout << "N/A\n"; } if (sp.size()) { cout << "# Parameter sweep:\n"; |