summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--training/dtrain/dtrain.cc39
1 files changed, 22 insertions, 17 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index 18286668..b317c365 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -246,7 +246,7 @@ main(int argc, char** argv)
cerr << setw(25) << "k " << k << endl;
cerr << setw(25) << "N " << N << endl;
cerr << setw(25) << "T " << T << endl;
- cerr << setw(25) << "scorer '" << scorer_str << "'" << endl;
+ cerr << setw(26) << "scorer '" << scorer_str << "'" << endl;
if (scorer_str == "approx_bleu")
cerr << setw(25) << "approx. B discount " << approx_bleu_d << endl;
cerr << setw(25) << "sample from " << "'" << sample_from << "'" << endl;
@@ -459,35 +459,40 @@ main(int argc, char** argv)
}
// l1 regularization
+ // please note that this penalizes _all_ weights
+ // (contrary to only the ones changed by the last update)
+ // after a _sentence_ (not after each example/pair)
if (l1naive) {
- for (unsigned d = 0; d < lambdas.size(); d++) {
- weight_t v = lambdas.get(d);
- lambdas.set_value(d, v - sign(v) * l1_reg);
+ FastSparseVector<weight_t>::iterator it = lambdas.begin();
+ for (; it != lambdas.end(); ++it) {
+ it->second -= sign(it->second) * l1_reg;
}
} else if (l1clip) {
- for (unsigned d = 0; d < lambdas.size(); d++) {
- if (lambdas.nonzero(d)) {
- weight_t v = lambdas.get(d);
+ FastSparseVector<weight_t>::iterator it = lambdas.begin();
+ for (; it != lambdas.end(); ++it) {
+ if (it->second != 0) {
+ weight_t v = it->second;
if (v > 0) {
- lambdas.set_value(d, max(0., v - l1_reg));
+ it->second = max(0., v - l1_reg);
} else {
- lambdas.set_value(d, min(0., v + l1_reg));
+ it->second = min(0., v + l1_reg);
}
}
}
} else if (l1cumul) {
weight_t acc_penalty = (ii+1) * l1_reg; // ii is the index of the current input
- for (unsigned d = 0; d < lambdas.size(); d++) {
- if (lambdas.nonzero(d)) {
- weight_t v = lambdas.get(d);
- weight_t penalty = 0;
+ FastSparseVector<weight_t>::iterator it = lambdas.begin();
+ for (; it != lambdas.end(); ++it) {
+ if (it->second != 0) {
+ weight_t v = it->second;
+ weight_t penalized = 0.;
if (v > 0) {
- penalty = max(0., v-(acc_penalty + cumulative_penalties.get(d)));
+ penalized = max(0., v-(acc_penalty + cumulative_penalties.get(it->first)));
} else {
- penalty = min(0., v+(acc_penalty - cumulative_penalties.get(d)));
+ penalized = min(0., v+(acc_penalty - cumulative_penalties.get(it->first)));
}
- lambdas.set_value(d, penalty);
- cumulative_penalties.set_value(d, cumulative_penalties.get(d)+penalty);
+ it->second = penalized;
+ cumulative_penalties.set_value(it->first, cumulative_penalties.get(it->first)+penalized);
}
}
}