summaryrefslogtreecommitdiff
path: root/training/dtrain/dtrain.cc
diff options
context:
space:
mode:
Diffstat (limited to 'training/dtrain/dtrain.cc')
-rw-r--r--training/dtrain/dtrain.cc46
1 files changed, 27 insertions, 19 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index 83e4e440..e1d5a2d4 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -366,6 +366,9 @@ main(int argc, char** argv)
PROsampling(samples, pairs, pair_threshold, max_pairs);
npairs += pairs.size();
+ SparseVector<weight_t> lambdas_copy;
+ if (l1naive||l1clip||l1cumul) lambdas_copy = lambdas;
+
for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin();
it != pairs.end(); it++) {
bool rank_error;
@@ -389,23 +392,26 @@ main(int argc, char** argv)
}
// l1 regularization
- // please note that this penalizes _all_ weights
- // (contrary to only the ones changed by the last update)
- // after a _sentence_ (not after each example/pair)
+ // please note that this regularizations happen
+ // after a _sentence_ -- not after each example/pair!
if (l1naive) {
FastSparseVector<weight_t>::iterator it = lambdas.begin();
for (; it != lambdas.end(); ++it) {
- it->second -= sign(it->second) * l1_reg;
+ if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) {
+ it->second -= sign(it->second) * l1_reg;
+ }
}
} else if (l1clip) {
FastSparseVector<weight_t>::iterator it = lambdas.begin();
for (; it != lambdas.end(); ++it) {
- if (it->second != 0) {
- weight_t v = it->second;
- if (v > 0) {
- it->second = max(0., v - l1_reg);
- } else {
- it->second = min(0., v + l1_reg);
+ if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) {
+ if (it->second != 0) {
+ weight_t v = it->second;
+ if (v > 0) {
+ it->second = max(0., v - l1_reg);
+ } else {
+ it->second = min(0., v + l1_reg);
+ }
}
}
}
@@ -413,16 +419,18 @@ main(int argc, char** argv)
weight_t acc_penalty = (ii+1) * l1_reg; // ii is the index of the current input
FastSparseVector<weight_t>::iterator it = lambdas.begin();
for (; it != lambdas.end(); ++it) {
- if (it->second != 0) {
- weight_t v = it->second;
- weight_t penalized = 0.;
- if (v > 0) {
- penalized = max(0., v-(acc_penalty + cumulative_penalties.get(it->first)));
- } else {
- penalized = min(0., v+(acc_penalty - cumulative_penalties.get(it->first)));
+ if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) {
+ if (it->second != 0) {
+ weight_t v = it->second;
+ weight_t penalized = 0.;
+ if (v > 0) {
+ penalized = max(0., v-(acc_penalty + cumulative_penalties.get(it->first)));
+ } else {
+ penalized = min(0., v+(acc_penalty - cumulative_penalties.get(it->first)));
+ }
+ it->second = penalized;
+ cumulative_penalties.set_value(it->first, cumulative_penalties.get(it->first)+penalized);
}
- it->second = penalized;
- cumulative_penalties.set_value(it->first, cumulative_penalties.get(it->first)+penalized);
}
}
}