diff options
author | Patrick Simianer <simianer@cl.uni-heidelberg.de> | 2013-05-15 13:25:36 +0200 |
---|---|---|
committer | Patrick Simianer <simianer@cl.uni-heidelberg.de> | 2013-05-15 13:25:36 +0200 |
commit | 72c9dedc8124977712462c6babbc0c1b0375f813 (patch) | |
tree | c7ec2d8171786aa18580eb8498f44a969bb0a664 /training/dtrain | |
parent | 0ce66778da6079506896739e9d97dc7dff83cd72 (diff) |
fixed l1 reg
Diffstat (limited to 'training/dtrain')
-rw-r--r-- | training/dtrain/dtrain.cc | 46 |
1 files changed, 27 insertions, 19 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index 83e4e440..e1d5a2d4 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -366,6 +366,9 @@ main(int argc, char** argv) PROsampling(samples, pairs, pair_threshold, max_pairs); npairs += pairs.size(); + SparseVector<weight_t> lambdas_copy; + if (l1naive||l1clip||l1cumul) lambdas_copy = lambdas; + for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin(); it != pairs.end(); it++) { bool rank_error; @@ -389,23 +392,26 @@ main(int argc, char** argv) } // l1 regularization - // please note that this penalizes _all_ weights - // (contrary to only the ones changed by the last update) - // after a _sentence_ (not after each example/pair) + // please note that this regularizations happen + // after a _sentence_ -- not after each example/pair! if (l1naive) { FastSparseVector<weight_t>::iterator it = lambdas.begin(); for (; it != lambdas.end(); ++it) { - it->second -= sign(it->second) * l1_reg; + if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) { + it->second -= sign(it->second) * l1_reg; + } } } else if (l1clip) { FastSparseVector<weight_t>::iterator it = lambdas.begin(); for (; it != lambdas.end(); ++it) { - if (it->second != 0) { - weight_t v = it->second; - if (v > 0) { - it->second = max(0., v - l1_reg); - } else { - it->second = min(0., v + l1_reg); + if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) { + if (it->second != 0) { + weight_t v = it->second; + if (v > 0) { + it->second = max(0., v - l1_reg); + } else { + it->second = min(0., v + l1_reg); + } } } } @@ -413,16 +419,18 @@ main(int argc, char** argv) weight_t acc_penalty = (ii+1) * l1_reg; // ii is the index of the current input FastSparseVector<weight_t>::iterator it = lambdas.begin(); for (; it != lambdas.end(); ++it) { - if (it->second != 0) { - weight_t v = it->second; - weight_t penalized = 0.; - if (v > 0) { - penalized = max(0., v-(acc_penalty + cumulative_penalties.get(it->first))); - } else { - penalized = min(0., v+(acc_penalty - cumulative_penalties.get(it->first))); + if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) { + if (it->second != 0) { + weight_t v = it->second; + weight_t penalized = 0.; + if (v > 0) { + penalized = max(0., v-(acc_penalty + cumulative_penalties.get(it->first))); + } else { + penalized = min(0., v+(acc_penalty - cumulative_penalties.get(it->first))); + } + it->second = penalized; + cumulative_penalties.set_value(it->first, cumulative_penalties.get(it->first)+penalized); } - it->second = penalized; - cumulative_penalties.set_value(it->first, cumulative_penalties.get(it->first)+penalized); } } } |