diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/dtrain/dtrain.cc | 39 |
1 files changed, 22 insertions, 17 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index 18286668..b317c365 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -246,7 +246,7 @@ main(int argc, char** argv) cerr << setw(25) << "k " << k << endl; cerr << setw(25) << "N " << N << endl; cerr << setw(25) << "T " << T << endl; - cerr << setw(25) << "scorer '" << scorer_str << "'" << endl; + cerr << setw(26) << "scorer '" << scorer_str << "'" << endl; if (scorer_str == "approx_bleu") cerr << setw(25) << "approx. B discount " << approx_bleu_d << endl; cerr << setw(25) << "sample from " << "'" << sample_from << "'" << endl; @@ -459,35 +459,40 @@ main(int argc, char** argv) } // l1 regularization + // please note that this penalizes _all_ weights + // (contrary to only the ones changed by the last update) + // after a _sentence_ (not after each example/pair) if (l1naive) { - for (unsigned d = 0; d < lambdas.size(); d++) { - weight_t v = lambdas.get(d); - lambdas.set_value(d, v - sign(v) * l1_reg); + FastSparseVector<weight_t>::iterator it = lambdas.begin(); + for (; it != lambdas.end(); ++it) { + it->second -= sign(it->second) * l1_reg; } } else if (l1clip) { - for (unsigned d = 0; d < lambdas.size(); d++) { - if (lambdas.nonzero(d)) { - weight_t v = lambdas.get(d); + FastSparseVector<weight_t>::iterator it = lambdas.begin(); + for (; it != lambdas.end(); ++it) { + if (it->second != 0) { + weight_t v = it->second; if (v > 0) { - lambdas.set_value(d, max(0., v - l1_reg)); + it->second = max(0., v - l1_reg); } else { - lambdas.set_value(d, min(0., v + l1_reg)); + it->second = min(0., v + l1_reg); } } } } else if (l1cumul) { weight_t acc_penalty = (ii+1) * l1_reg; // ii is the index of the current input - for (unsigned d = 0; d < lambdas.size(); d++) { - if (lambdas.nonzero(d)) { - weight_t v = lambdas.get(d); - weight_t penalty = 0; + FastSparseVector<weight_t>::iterator it = lambdas.begin(); + for (; it != lambdas.end(); ++it) { + if (it->second != 0) { + weight_t v = it->second; + weight_t penalized = 0.; if (v > 0) { - penalty = max(0., v-(acc_penalty + cumulative_penalties.get(d))); + penalized = max(0., v-(acc_penalty + cumulative_penalties.get(it->first))); } else { - penalty = min(0., v+(acc_penalty - cumulative_penalties.get(d))); + penalized = min(0., v+(acc_penalty - cumulative_penalties.get(it->first))); } - lambdas.set_value(d, penalty); - cumulative_penalties.set_value(d, cumulative_penalties.get(d)+penalty); + it->second = penalized; + cumulative_penalties.set_value(it->first, cumulative_penalties.get(it->first)+penalized); } } } |