diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/dtrain/dtrain.cc | 46 | 
1 files changed, 27 insertions, 19 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index 83e4e440..e1d5a2d4 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -366,6 +366,9 @@ main(int argc, char** argv)
         PROsampling(samples, pairs, pair_threshold, max_pairs);
       npairs += pairs.size();
 
+      SparseVector<weight_t> lambdas_copy;
+      if (l1naive||l1clip||l1cumul) lambdas_copy = lambdas;
+
       for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin();
            it != pairs.end(); it++) {
         bool rank_error;
@@ -389,23 +392,26 @@ main(int argc, char** argv)
       }
 
       // l1 regularization
-      // please note that this penalizes _all_ weights
-      // (contrary to only the ones changed by the last update)
-      // after a _sentence_ (not after each example/pair)
+      // please note that this regularizations happen
+      // after a _sentence_ -- not after each example/pair!
       if (l1naive) {
         FastSparseVector<weight_t>::iterator it = lambdas.begin();
         for (; it != lambdas.end(); ++it) {
-          it->second -= sign(it->second) * l1_reg;
+          if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) {
+            it->second -= sign(it->second) * l1_reg;
+          }
         }
       } else if (l1clip) {
         FastSparseVector<weight_t>::iterator it = lambdas.begin();
         for (; it != lambdas.end(); ++it) {
-          if (it->second != 0) {
-            weight_t v = it->second;
-            if (v > 0) {
-              it->second = max(0., v - l1_reg);
-            } else {
-              it->second = min(0., v + l1_reg);
+          if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) {
+            if (it->second != 0) {
+              weight_t v = it->second;
+              if (v > 0) {
+                it->second = max(0., v - l1_reg);
+              } else {
+                it->second = min(0., v + l1_reg);
+              }
             }
           }
         }
@@ -413,16 +419,18 @@ main(int argc, char** argv)
         weight_t acc_penalty = (ii+1) * l1_reg; // ii is the index of the current input
         FastSparseVector<weight_t>::iterator it = lambdas.begin();
         for (; it != lambdas.end(); ++it) {
-          if (it->second != 0) {
-            weight_t v = it->second;
-            weight_t penalized = 0.;
-            if (v > 0) {
-              penalized = max(0., v-(acc_penalty + cumulative_penalties.get(it->first)));
-            } else {
-              penalized = min(0., v+(acc_penalty - cumulative_penalties.get(it->first)));
+          if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) {
+            if (it->second != 0) {
+              weight_t v = it->second;
+              weight_t penalized = 0.;
+              if (v > 0) {
+                penalized = max(0., v-(acc_penalty + cumulative_penalties.get(it->first)));
+              } else {
+                penalized = min(0., v+(acc_penalty - cumulative_penalties.get(it->first)));
+              }
+              it->second = penalized;
+              cumulative_penalties.set_value(it->first, cumulative_penalties.get(it->first)+penalized);
             }
-            it->second = penalized;
-            cumulative_penalties.set_value(it->first, cumulative_penalties.get(it->first)+penalized);
           }
         }
       }
