From 72c9dedc8124977712462c6babbc0c1b0375f813 Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Wed, 15 May 2013 13:25:36 +0200 Subject: fixed l1 reg --- training/dtrain/dtrain.cc | 46 +++++++++++++++++++++++++++------------------- 1 file changed, 27 insertions(+), 19 deletions(-) (limited to 'training') diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index 83e4e440..e1d5a2d4 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -366,6 +366,9 @@ main(int argc, char** argv) PROsampling(samples, pairs, pair_threshold, max_pairs); npairs += pairs.size(); + SparseVector lambdas_copy; + if (l1naive||l1clip||l1cumul) lambdas_copy = lambdas; + for (vector >::iterator it = pairs.begin(); it != pairs.end(); it++) { bool rank_error; @@ -389,23 +392,26 @@ main(int argc, char** argv) } // l1 regularization - // please note that this penalizes _all_ weights - // (contrary to only the ones changed by the last update) - // after a _sentence_ (not after each example/pair) + // please note that this regularizations happen + // after a _sentence_ -- not after each example/pair! if (l1naive) { FastSparseVector::iterator it = lambdas.begin(); for (; it != lambdas.end(); ++it) { - it->second -= sign(it->second) * l1_reg; + if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) { + it->second -= sign(it->second) * l1_reg; + } } } else if (l1clip) { FastSparseVector::iterator it = lambdas.begin(); for (; it != lambdas.end(); ++it) { - if (it->second != 0) { - weight_t v = it->second; - if (v > 0) { - it->second = max(0., v - l1_reg); - } else { - it->second = min(0., v + l1_reg); + if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) { + if (it->second != 0) { + weight_t v = it->second; + if (v > 0) { + it->second = max(0., v - l1_reg); + } else { + it->second = min(0., v + l1_reg); + } } } } @@ -413,16 +419,18 @@ main(int argc, char** argv) weight_t acc_penalty = (ii+1) * l1_reg; // ii is the index of the current input FastSparseVector::iterator it = lambdas.begin(); for (; it != lambdas.end(); ++it) { - if (it->second != 0) { - weight_t v = it->second; - weight_t penalized = 0.; - if (v > 0) { - penalized = max(0., v-(acc_penalty + cumulative_penalties.get(it->first))); - } else { - penalized = min(0., v+(acc_penalty - cumulative_penalties.get(it->first))); + if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) { + if (it->second != 0) { + weight_t v = it->second; + weight_t penalized = 0.; + if (v > 0) { + penalized = max(0., v-(acc_penalty + cumulative_penalties.get(it->first))); + } else { + penalized = min(0., v+(acc_penalty - cumulative_penalties.get(it->first))); + } + it->second = penalized; + cumulative_penalties.set_value(it->first, cumulative_penalties.get(it->first)+penalized); } - it->second = penalized; - cumulative_penalties.set_value(it->first, cumulative_penalties.get(it->first)+penalized); } } } -- cgit v1.2.3