Diffstat (limited to 'training/dtrain/dtrain.cc')
-rw-r--r--  training/dtrain/dtrain.cc | 31 ++++++++++++++++++-------------
1 file changed, 18 insertions(+), 13 deletions(-)
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index 9d60a903..38a9b69a 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -40,7 +40,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
     ("scale_bleu_diff",   po::value<bool>()->zero_tokens(),                      "learning rate <- bleu diff of a misranked pair")
     ("loss_margin",       po::value<weight_t>()->default_value(0.),  "update if no error in pref pair but model scores this near")
     ("max_pairs",         po::value<unsigned>()->default_value(std::numeric_limits<unsigned>::max()), "max. # of pairs per Sent.")
-    ("pclr",              po::value<bool>()->zero_tokens(),                         "use a (simple) per-coordinate learning rate")
+    ("pclr",              po::value<string>()->default_value("no"),         "use a (simple|adagrad) per-coordinate learning rate")
     ("noup",              po::value<bool>()->zero_tokens(),                                               "do not update weights");
   po::options_description cl("Command Line Options");
   cl.add_options()
@@ -125,8 +125,7 @@ main(int argc, char** argv)
   if (loss_margin > 9998.) loss_margin = std::numeric_limits<float>::max();
   bool scale_bleu_diff = false;
   if (cfg.count("scale_bleu_diff")) scale_bleu_diff = true;
-  bool pclr = false;
-  if (cfg.count("pclr")) pclr = true;
+  const string pclr = cfg["pclr"].as<string>();
   bool average = false;
   if (select_weights == "avg")
     average = true;
@@ -190,7 +189,6 @@ main(int argc, char** argv)
   weight_t gamma = cfg["gamma"].as<weight_t>();
 
   // faster perceptron: consider only misranked pairs, see
-  // DO NOT ENABLE WITH SVM (gamma > 0) OR loss_margin!
   bool faster_perceptron = false;
   if (gamma==0 && loss_margin==0) faster_perceptron = true;
@@ -251,8 +249,7 @@ main(int argc, char** argv)
       cerr << setw(25) << "l1 reg " << l1_reg << " '" << cfg["l1_reg"].as<string>() << "'" << endl;
     if (rescale)
       cerr << setw(25) << "rescale " << rescale << endl;
-    if (pclr)
-      cerr << setw(25) << "pclr " << pclr << endl;
+    cerr << setw(25) << "pclr " << pclr << endl;
     cerr << setw(25) << "max pairs " << max_pairs << endl;
     cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as<string>() << "'" << endl;
     cerr << setw(25) << "input " << "'" << input_fn << "'" << endl;
@@ -392,22 +389,30 @@ main(int argc, char** argv)
         if (scale_bleu_diff) eta = it->first.score - it->second.score;
         if (rank_error || margin < loss_margin) {
           SparseVector<weight_t> diff_vec = it->first.f - it->second.f;
-          if (pclr) {
+          if (pclr != "no") {
             sum_up += diff_vec;
           } else {
             lambdas.plus_eq_v_times_s(diff_vec, eta);
+            if (gamma) lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs)); // FIXME
           }
-          if (gamma)
-            lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs));
         }
       }
       // per-coordinate learning rate
-      if (pclr) {
+      if (pclr != "no") {
         SparseVector<weight_t>::iterator it = sum_up.begin();
-        for (; it != lambdas.end(); ++it) {
-          learning_rates[it->first]++;
-          lambdas[it->first] += it->second / learning_rates[it->first]; //* max(0.00000001, eta/(eta+learning_rates[it->first]));
+        for (; it != sum_up.end(); ++it) {
+          if (pclr == "simple") {
+            lambdas[it->first] += it->second / max(1.0, learning_rates[it->first]);
+            learning_rates[it->first]++;
+          } else if (pclr == "adagrad") {
+            if (learning_rates[it->first] == 0) {
+              lambdas[it->first] += it->second * eta;
+            } else {
+              lambdas[it->first] += it->second * eta * learning_rates[it->first];
+            }
+            learning_rates[it->first] += pow(it->second, 2.0);
+          }
         }
       }
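
The first two hunks turn --pclr from a boolean switch into a string option whose value selects the update rule, so the old cfg.count("pclr") test becomes a string comparison against the default "no". A minimal, self-contained sketch of how such a string-valued option with a default behaves under boost::program_options; only the option name, default, and description come from the hunk, the surrounding main() is illustrative:

#include <iostream>
#include <string>
#include <boost/program_options.hpp>

namespace po = boost::program_options;
using namespace std;

int main(int argc, char** argv)
{
  po::options_description ini("dtrain Options");
  ini.add_options()
    ("pclr", po::value<string>()->default_value("no"),
     "use a (simple|adagrad) per-coordinate learning rate");
  po::variables_map cfg;
  po::store(po::parse_command_line(argc, argv, ini), cfg);
  po::notify(cfg);
  // "--pclr adagrad" prints "adagrad"; with the flag omitted, the default "no"
  cout << cfg["pclr"].as<string>() << endl;
  return 0;
}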

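The last hunk is the core change: pairwise gradients are accumulated in sum_up and then applied coordinate by coordinate, with learning_rates doubling as an update counter under "simple" and as a running sum of squared gradients under "adagrad". Below is a hedged sketch of the two rules over a plain std::map instead of cdec's SparseVector; the function, the Vec typedef, and its signature are mine, not the commit's. One caveat flagged in a comment: textbook AdaGrad scales the step by eta / sqrt(G), whereas the committed hunk multiplies by the accumulator, which looks like a bug; the sketch uses the textbook form.

#include <algorithm>
#include <cmath>
#include <map>
#include <string>

typedef double weight_t;
typedef std::map<int, weight_t> Vec; // feature id -> value

// Apply one round of per-coordinate updates to 'lambdas'.
// 'sum_up' holds the gradient sum collected this round; 'eta' is the base rate.
void pclr_update(Vec& lambdas, Vec& learning_rates,
                 const Vec& sum_up, const std::string& pclr, weight_t eta)
{
  for (Vec::const_iterator it = sum_up.begin(); it != sum_up.end(); ++it) {
    if (pclr == "simple") {
      // step shrinks as 1/(number of updates this coordinate has seen);
      // max(1.0, .) keeps the very first update from dividing by zero
      lambdas[it->first] += it->second / std::max(1.0, learning_rates[it->first]);
      learning_rates[it->first]++;
    } else if (pclr == "adagrad") {
      if (learning_rates[it->first] == 0) {
        lambdas[it->first] += it->second * eta; // first touch: plain step
      } else {
        // textbook AdaGrad: divide by sqrt of the accumulated squared gradient
        // (the hunk above multiplies by the accumulator instead)
        lambdas[it->first] += it->second * eta / std::sqrt(learning_rates[it->first]);
      }
      learning_rates[it->first] += std::pow(it->second, 2.0);
    }
  }
}

Either way, frequently updated coordinates take ever smaller steps, which is the point of a per-coordinate learning rate.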