diff options
Diffstat (limited to 'dtrain')
-rw-r--r--  dtrain/dtrain.cc               | 51
-rw-r--r--  dtrain/dtrain.h                |  6
-rw-r--r--  dtrain/test/example/dtrain.ini |  6
3 files changed, 60 insertions, 3 deletions
| diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc index ca5f0c5e..448e639c 100644 --- a/dtrain/dtrain.cc +++ b/dtrain/dtrain.cc @@ -26,6 +26,8 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)      ("select_weights", po::value<string>()->default_value("last"), "output 'best' or 'last' weights ('VOID' to throw away)")      ("keep_w",         po::value<bool>()->zero_tokens(),                              "protocol weights for each iteration")      ("unit_weight_vector", po::value<bool>()->zero_tokens(),                       "Rescale weight vector after each input") +    ("l1_reg",         po::value<string>()->default_value("no"),         "apply l1 regularization as in Tsuruoka et al 2010") +    ("l1_reg_strength", po::value<weight_t>(),                                                 "l1 regularization strength")  #ifdef DTRAIN_LOCAL      ("refs,r",         po::value<string>(),                                                      "references in local mode")  #endif @@ -155,13 +157,25 @@ main(int argc, char** argv)    // init weights    vector<weight_t>& dense_weights = decoder.CurrentWeightVector(); -  SparseVector<weight_t> lambdas; +  SparseVector<weight_t> lambdas, cumulative_penalties;    if (cfg.count("input_weights")) Weights::InitFromFile(cfg["input_weights"].as<string>(), &dense_weights);    Weights::InitSparseVector(dense_weights, &lambdas);    // meta params for perceptron, SVM    weight_t eta = cfg["learning_rate"].as<weight_t>();    weight_t gamma = cfg["gamma"].as<weight_t>(); +  // l1 regularization +  bool l1naive = false; +  bool l1clip = false; +  bool l1cumul = false; +  weight_t l1_reg = 0; +  if (cfg["l1_reg"].as<string>() != "no") { +    string s = cfg["l1_reg"].as<string>(); +    if (s == "naive") l1naive = true; +    else if (s == "clip") l1clip = true; +    else if (s == "cumul") l1cumul = true; +    l1_reg = cfg["l1_reg_strength"].as<weight_t>(); +  }    // output    string output_fn = cfg["output"].as<string>(); @@ -388,6 
+402,41 @@ main(int argc, char** argv)            lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs));          }        } + +      // reset cumulative_penalties after 1 iter?  +      // do this only once per INPUT (not per pair) +      if (l1naive) { +        for (unsigned d = 0; d < lambdas.size(); d++) { +          weight_t v = lambdas.get(d); +          lambdas.set_value(d, v - sign(v) * l1_reg); +        } +      } else if (l1clip) { +        for (unsigned d = 0; d < lambdas.size(); d++) { +          if (lambdas.nonzero(d)) { +            weight_t v = lambdas.get(d); +            if (v > 0) { +              lambdas.set_value(d, max(0., v - l1_reg)); +            } else { +              lambdas.set_value(d, min(0., v + l1_reg)); +            } +          } +        } +      } else if (l1cumul) { +        weight_t acc_penalty = (ii+1) * l1_reg; // ii is the index of the current input  +        for (unsigned d = 0; d < lambdas.size(); d++) { +          if (lambdas.nonzero(d)) { +            weight_t v = lambdas.get(d); +            weight_t penalty = 0; +            if (v > 0) { +              penalty = max(0., v-(acc_penalty + cumulative_penalties.get(d))); +            } else { +              penalty = min(0., v+(acc_penalty - cumulative_penalties.get(d))); +            } +            lambdas.set_value(d, penalty); +            cumulative_penalties.set_value(d, cumulative_penalties.get(d)+penalty); +          } +        } +      }      }      if (unit_weight_vector && sample_from == "forest") lambdas /= lambdas.l2norm(); diff --git a/dtrain/dtrain.h b/dtrain/dtrain.h index 84f3f1f5..cfc3f460 100644 --- a/dtrain/dtrain.h +++ b/dtrain/dtrain.h @@ -85,5 +85,11 @@ inline void printWordIDVec(vector<WordID>& v)    }  } +template<typename T> +inline T sign(T z) { +  if (z == 0) return 0; +  return z < 0 ? 
-1 : +1; +} +  #endif diff --git a/dtrain/test/example/dtrain.ini b/dtrain/test/example/dtrain.ini index 95eeb8e5..900878a5 100644 --- a/dtrain/test/example/dtrain.ini +++ b/dtrain/test/example/dtrain.ini @@ -6,11 +6,13 @@ epochs=20  input=test/example/nc-1k-tabs.gz  scorer=stupid_bleu  output=weights.gz -stop_after=10 +stop_after=100  sample_from=forest  pair_sampling=108010  select_weights=VOID  print_weights=Glue WordPenalty LanguageModel LanguageModel_OOV PhraseModel_0 PhraseModel_1 PhraseModel_2 PhraseModel_3 PhraseModel_4 PassThrough  tmp=/tmp -unit_weight_vector=true +#unit_weight_vector=  keep_w=true +#l1_reg=clip +#l1_reg_strength=0.00001 | 
