diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/dtrain/dtrain.cc | 25 | 
1 files changed, 20 insertions, 5 deletions
| diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index 0ee2f124..34c0a54a 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -40,6 +40,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)      ("scale_bleu_diff",   po::value<bool>()->zero_tokens(),                      "learning rate <- bleu diff of a misranked pair")      ("loss_margin",       po::value<weight_t>()->default_value(0.),  "update if no error in pref pair but model scores this near")      ("max_pairs",         po::value<unsigned>()->default_value(std::numeric_limits<unsigned>::max()), "max. # of pairs per Sent.") +    ("pclr",              po::value<bool>()->zero_tokens(),                         "use a (simple) per-coordinate learning rate")      ("noup",              po::value<bool>()->zero_tokens(),                                               "do not update weights");    po::options_description cl("Command Line Options");    cl.add_options() @@ -124,6 +125,8 @@ main(int argc, char** argv)    if (loss_margin > 9998.) loss_margin = std::numeric_limits<float>::max();    bool scale_bleu_diff = false;    if (cfg.count("scale_bleu_diff")) scale_bleu_diff = true; +  bool pclr = false; +  if (cfg.count("pclr")) pclr = true;    bool average = false;    if (select_weights == "avg")      average = true; @@ -131,7 +134,6 @@ main(int argc, char** argv)    if (cfg.count("print_weights"))      boost::split(print_weights, cfg["print_weights"].as<string>(), boost::is_any_of(" ")); -    // setup decoder    register_feature_functions();    SetSilent(true); @@ -249,6 +251,8 @@ main(int argc, char** argv)        cerr << setw(25) << "l1 reg " << l1_reg << " '" << cfg["l1_reg"].as<string>() << "'" << endl;      if (rescale)        cerr << setw(25) << "rescale " << rescale << endl; +    if (pclr) +      cerr << setw(25) << "pclr " << pclr << endl;      cerr << setw(25) << "max pairs " << max_pairs << endl;      cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as<string>() << "'" << endl;      cerr << setw(25) << "input " << "'" << input_fn << "'" << endl; @@ -261,6 +265,8 @@ main(int argc, char** argv)      if (!verbose) cerr << "(a dot represents " << DTRAIN_DOTS << " inputs)" << endl;    } +  // pclr +  SparseVector<weight_t> learning_rates;    for (unsigned t = 0; t < T; t++) // T epochs    { @@ -385,7 +391,16 @@ main(int argc, char** argv)          if (scale_bleu_diff) eta = it->first.score - it->second.score;          if (rank_error || margin < loss_margin) {            SparseVector<weight_t> diff_vec = it->first.f - it->second.f; -          lambdas.plus_eq_v_times_s(diff_vec, eta); +          if (pclr) { +            SparseVector<weight_t>::iterator jt = diff_vec.begin(); +            for (; jt != diff_vec.end(); ++it) { +              jt->second *= max(0.0000001, eta/(eta+learning_rates[jt->first])); // FIXME +              learning_rates[jt->first]++; +            } +            lambdas += diff_vec; +            } else { +              lambdas.plus_eq_v_times_s(diff_vec, eta); +            }            if (gamma)              lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs));          } @@ -395,14 +410,14 @@ main(int argc, char** argv)        // please note that this regularizations happen        // after a _sentence_ -- not after each example/pair!        if (l1naive) { -        FastSparseVector<weight_t>::iterator it = lambdas.begin(); +        SparseVector<weight_t>::iterator it = lambdas.begin();          for (; it != lambdas.end(); ++it) {            if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) {              it->second -= sign(it->second) * l1_reg;            }          }        } else if (l1clip) { -        FastSparseVector<weight_t>::iterator it = lambdas.begin(); +        SparseVector<weight_t>::iterator it = lambdas.begin();          for (; it != lambdas.end(); ++it) {            if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) {              if (it->second != 0) { @@ -417,7 +432,7 @@ main(int argc, char** argv)          }        } else if (l1cumul) {          weight_t acc_penalty = (ii+1) * l1_reg; // ii is the index of the current input -        FastSparseVector<weight_t>::iterator it = lambdas.begin(); +        SparseVector<weight_t>::iterator it = lambdas.begin();          for (; it != lambdas.end(); ++it) {            if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) {              if (it->second != 0) { | 
