From 72c9dedc8124977712462c6babbc0c1b0375f813 Mon Sep 17 00:00:00 2001
From: Patrick Simianer <simianer@cl.uni-heidelberg.de>
Date: Wed, 15 May 2013 13:25:36 +0200
Subject: fix l1 regularization: penalize only the weights changed by the last update

---
 training/dtrain/dtrain.cc | 46 +++++++++++++++++++++++++++-------------------
 1 file changed, 27 insertions(+), 19 deletions(-)

diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index 83e4e440..e1d5a2d4 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -366,6 +366,9 @@ main(int argc, char** argv)
         PROsampling(samples, pairs, pair_threshold, max_pairs);
       npairs += pairs.size();
 
+      SparseVector<weight_t> lambdas_copy;
+      if (l1naive||l1clip||l1cumul) lambdas_copy = lambdas;
+
       for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin();
            it != pairs.end(); it++) {
         bool rank_error;
@@ -389,23 +392,26 @@ main(int argc, char** argv)
       }
 
       // l1 regularization
-      // please note that this penalizes _all_ weights
-      // (contrary to only the ones changed by the last update)
-      // after a _sentence_ (not after each example/pair)
+      // please note that this regularization happens
+      // after a _sentence_ -- not after each example/pair!
       if (l1naive) {
         FastSparseVector<weight_t>::iterator it = lambdas.begin();
         for (; it != lambdas.end(); ++it) {
-          it->second -= sign(it->second) * l1_reg;
+          if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) {
+            it->second -= sign(it->second) * l1_reg;
+          }
         }
       } else if (l1clip) {
         FastSparseVector<weight_t>::iterator it = lambdas.begin();
         for (; it != lambdas.end(); ++it) {
-          if (it->second != 0) {
-            weight_t v = it->second;
-            if (v > 0) {
-              it->second = max(0., v - l1_reg);
-            } else {
-              it->second = min(0., v + l1_reg);
+          if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) {
+            if (it->second != 0) {
+              weight_t v = it->second;
+              if (v > 0) {
+                it->second = max(0., v - l1_reg);
+              } else {
+                it->second = min(0., v + l1_reg);
+              }
             }
           }
         }
@@ -413,16 +419,18 @@ main(int argc, char** argv)
         weight_t acc_penalty = (ii+1) * l1_reg; // ii is the index of the current input
         FastSparseVector<weight_t>::iterator it = lambdas.begin();
         for (; it != lambdas.end(); ++it) {
-          if (it->second != 0) {
-            weight_t v = it->second;
-            weight_t penalized = 0.;
-            if (v > 0) {
-              penalized = max(0., v-(acc_penalty + cumulative_penalties.get(it->first)));
-            } else {
-              penalized = min(0., v+(acc_penalty - cumulative_penalties.get(it->first)));
+          if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) {
+            if (it->second != 0) {
+              weight_t v = it->second;
+              weight_t penalized = 0.;
+              if (v > 0) {
+                penalized = max(0., v-(acc_penalty + cumulative_penalties.get(it->first)));
+              } else {
+                penalized = min(0., v+(acc_penalty - cumulative_penalties.get(it->first)));
+              }
+              it->second = penalized;
+              cumulative_penalties.set_value(it->first, cumulative_penalties.get(it->first)+penalized);
             }
-            it->second = penalized;
-            cumulative_penalties.set_value(it->first, cumulative_penalties.get(it->first)+penalized);
           }
         }
       }
-- 
cgit v1.2.3
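
Note: for reference, below is a minimal standalone sketch of the guarded penalty introduced by this patch. It is not dtrain code: std::map<int, double> stands in for cdec's SparseVector/FastSparseVector, the function names l1_naive/l1_clip and the feature values in main() are made up for illustration, and only the naive and clipped variants are shown. A feature is penalized only if it was zero before the update or its value changed, mirroring the lambdas_copy guard in the diff.

// Sketch only: assumes std::map as a stand-in sparse vector.
#include <algorithm>
#include <cstdio>
#include <map>

typedef std::map<int, double> Weights;  // feature id -> weight

static int sign(double x) { return (x > 0) - (x < 0); }

// A feature counts as "touched" by the last update if it was absent
// (i.e. zero) before or its value differs from the snapshot.
static bool touched(const Weights& before, int fid, double val) {
  Weights::const_iterator old = before.find(fid);
  return old == before.end() || old->second != val;
}

// Naive L1: subtract a constant step towards zero from touched features.
void l1_naive(Weights& w, const Weights& before, double l1_reg) {
  for (Weights::iterator it = w.begin(); it != w.end(); ++it)
    if (touched(before, it->first, it->second))
      it->second -= sign(it->second) * l1_reg;
}

// Clipped L1: move touched features towards zero without crossing it.
void l1_clip(Weights& w, const Weights& before, double l1_reg) {
  for (Weights::iterator it = w.begin(); it != w.end(); ++it) {
    if (touched(before, it->first, it->second) && it->second != 0.) {
      double v = it->second;
      it->second = (v > 0) ? std::max(0., v - l1_reg)
                           : std::min(0., v + l1_reg);
    }
  }
}

int main() {
  Weights before;
  before[0] = 0.5;   // unchanged by the update -> must not be penalized
  before[1] = -0.3;  // changed by the update   -> penalized

  Weights w = before;
  w[1] = -0.4;       // updated feature
  w[2] = 0.2;        // new feature

  l1_clip(w, before, 0.05);
  for (Weights::iterator it = w.begin(); it != w.end(); ++it)
    std::printf("feature %d: %g\n", it->first, it->second);
  return 0;
}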