From 4d90fcf6a24be10aa35c7783e6f2623e9fedf9d6 Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Tue, 10 Sep 2013 18:20:16 +0200 Subject: simple pclr --- training/dtrain/dtrain.cc | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index 0ee2f124..34c0a54a 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -40,6 +40,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) ("scale_bleu_diff", po::value()->zero_tokens(), "learning rate <- bleu diff of a misranked pair") ("loss_margin", po::value()->default_value(0.), "update if no error in pref pair but model scores this near") ("max_pairs", po::value()->default_value(std::numeric_limits::max()), "max. # of pairs per Sent.") + ("pclr", po::value()->zero_tokens(), "use a (simple) per-coordinate learning rate") ("noup", po::value()->zero_tokens(), "do not update weights"); po::options_description cl("Command Line Options"); cl.add_options() @@ -124,6 +125,8 @@ main(int argc, char** argv) if (loss_margin > 9998.) loss_margin = std::numeric_limits::max(); bool scale_bleu_diff = false; if (cfg.count("scale_bleu_diff")) scale_bleu_diff = true; + bool pclr = false; + if (cfg.count("pclr")) pclr = true; bool average = false; if (select_weights == "avg") average = true; @@ -131,7 +134,6 @@ main(int argc, char** argv) if (cfg.count("print_weights")) boost::split(print_weights, cfg["print_weights"].as(), boost::is_any_of(" ")); - // setup decoder register_feature_functions(); SetSilent(true); @@ -249,6 +251,8 @@ main(int argc, char** argv) cerr << setw(25) << "l1 reg " << l1_reg << " '" << cfg["l1_reg"].as() << "'" << endl; if (rescale) cerr << setw(25) << "rescale " << rescale << endl; + if (pclr) + cerr << setw(25) << "pclr " << pclr << endl; cerr << setw(25) << "max pairs " << max_pairs << endl; cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as() << "'" << endl; cerr << setw(25) << "input " << "'" << input_fn << "'" << endl; @@ -261,6 +265,8 @@ main(int argc, char** argv) if (!verbose) cerr << "(a dot represents " << DTRAIN_DOTS << " inputs)" << endl; } + // pclr + SparseVector learning_rates; for (unsigned t = 0; t < T; t++) // T epochs { @@ -385,7 +391,16 @@ main(int argc, char** argv) if (scale_bleu_diff) eta = it->first.score - it->second.score; if (rank_error || margin < loss_margin) { SparseVector diff_vec = it->first.f - it->second.f; - lambdas.plus_eq_v_times_s(diff_vec, eta); + if (pclr) { + SparseVector::iterator jt = diff_vec.begin(); + for (; jt != diff_vec.end(); ++it) { + jt->second *= max(0.0000001, eta/(eta+learning_rates[jt->first])); // FIXME + learning_rates[jt->first]++; + } + lambdas += diff_vec; + } else { + lambdas.plus_eq_v_times_s(diff_vec, eta); + } if (gamma) lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs)); } @@ -395,14 +410,14 @@ main(int argc, char** argv) // please note that this regularizations happen // after a _sentence_ -- not after each example/pair! if (l1naive) { - FastSparseVector::iterator it = lambdas.begin(); + SparseVector::iterator it = lambdas.begin(); for (; it != lambdas.end(); ++it) { if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) { it->second -= sign(it->second) * l1_reg; } } } else if (l1clip) { - FastSparseVector::iterator it = lambdas.begin(); + SparseVector::iterator it = lambdas.begin(); for (; it != lambdas.end(); ++it) { if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) { if (it->second != 0) { @@ -417,7 +432,7 @@ main(int argc, char** argv) } } else if (l1cumul) { weight_t acc_penalty = (ii+1) * l1_reg; // ii is the index of the current input - FastSparseVector::iterator it = lambdas.begin(); + SparseVector::iterator it = lambdas.begin(); for (; it != lambdas.end(); ++it) { if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) { if (it->second != 0) { -- cgit v1.2.3