diff options
author | Patrick Simianer <p@simianer.de> | 2011-11-14 01:20:50 +0100 |
---|---|---|
committer | Patrick Simianer <p@simianer.de> | 2011-11-14 01:20:50 +0100 |
commit | a27036119247fd5527ab8222e7df80ec2df31ca2 (patch) | |
tree | 1071cec16c106b1245ca82dc6841ae9f90f49a3c /dtrain | |
parent | 7b79fc9e6e6c9c2bb7f977978e319abe2143bbd9 (diff) |
l1 regularzation
Diffstat (limited to 'dtrain')
-rw-r--r-- | dtrain/dtrain.cc | 51 | ||||
-rw-r--r-- | dtrain/dtrain.h | 6 | ||||
-rw-r--r-- | dtrain/test/example/dtrain.ini | 6 |
3 files changed, 60 insertions, 3 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc index ca5f0c5e..448e639c 100644 --- a/dtrain/dtrain.cc +++ b/dtrain/dtrain.cc @@ -26,6 +26,8 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) ("select_weights", po::value<string>()->default_value("last"), "output 'best' or 'last' weights ('VOID' to throw away)") ("keep_w", po::value<bool>()->zero_tokens(), "protocol weights for each iteration") ("unit_weight_vector", po::value<bool>()->zero_tokens(), "Rescale weight vector after each input") + ("l1_reg", po::value<string>()->default_value("no"), "apply l1 regularization as in Tsuroka et al 2010") + ("l1_reg_strength", po::value<weight_t>(), "l1 regularization strength") #ifdef DTRAIN_LOCAL ("refs,r", po::value<string>(), "references in local mode") #endif @@ -155,13 +157,25 @@ main(int argc, char** argv) // init weights vector<weight_t>& dense_weights = decoder.CurrentWeightVector(); - SparseVector<weight_t> lambdas; + SparseVector<weight_t> lambdas, cumulative_penalties; if (cfg.count("input_weights")) Weights::InitFromFile(cfg["input_weights"].as<string>(), &dense_weights); Weights::InitSparseVector(dense_weights, &lambdas); // meta params for perceptron, SVM weight_t eta = cfg["learning_rate"].as<weight_t>(); weight_t gamma = cfg["gamma"].as<weight_t>(); + // l1 regularization + bool l1naive = false; + bool l1clip = false; + bool l1cumul = false; + weight_t l1_reg = 0; + if (cfg["l1_reg"].as<string>() != "no") { + string s = cfg["l1_reg"].as<string>(); + if (s == "naive") l1naive = true; + else if (s == "clip") l1clip = true; + else if (s == "cumul") l1cumul = true; + l1_reg = cfg["l1_reg_strength"].as<weight_t>(); + } // output string output_fn = cfg["output"].as<string>(); @@ -388,6 +402,41 @@ main(int argc, char** argv) lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs)); } } + + // reset cumulative_penalties after 1 iter? + // do this only once per INPUT (not per pair) + if (l1naive) { + for (unsigned d = 0; d < lambdas.size(); d++) { + weight_t v = lambdas.get(d); + lambdas.set_value(d, v - sign(v) * l1_reg); + } + } else if (l1clip) { + for (unsigned d = 0; d < lambdas.size(); d++) { + if (lambdas.nonzero(d)) { + weight_t v = lambdas.get(d); + if (v > 0) { + lambdas.set_value(d, max(0., v - l1_reg)); + } else { + lambdas.set_value(d, min(0., v + l1_reg)); + } + } + } + } else if (l1cumul) { + weight_t acc_penalty = (ii+1) * l1_reg; // ii is the index of the current input + for (unsigned d = 0; d < lambdas.size(); d++) { + if (lambdas.nonzero(d)) { + weight_t v = lambdas.get(d); + weight_t penalty = 0; + if (v > 0) { + penalty = max(0., v-(acc_penalty + cumulative_penalties.get(d))); + } else { + penalty = min(0., v+(acc_penalty - cumulative_penalties.get(d))); + } + lambdas.set_value(d, penalty); + cumulative_penalties.set_value(d, cumulative_penalties.get(d)+penalty); + } + } + } } if (unit_weight_vector && sample_from == "forest") lambdas /= lambdas.l2norm(); diff --git a/dtrain/dtrain.h b/dtrain/dtrain.h index 84f3f1f5..cfc3f460 100644 --- a/dtrain/dtrain.h +++ b/dtrain/dtrain.h @@ -85,5 +85,11 @@ inline void printWordIDVec(vector<WordID>& v) } } +template<typename T> +inline T sign(T z) { + if (z == 0) return 0; + return z < 0 ? -1 : +1; +} + #endif diff --git a/dtrain/test/example/dtrain.ini b/dtrain/test/example/dtrain.ini index 95eeb8e5..900878a5 100644 --- a/dtrain/test/example/dtrain.ini +++ b/dtrain/test/example/dtrain.ini @@ -6,11 +6,13 @@ epochs=20 input=test/example/nc-1k-tabs.gz scorer=stupid_bleu output=weights.gz -stop_after=10 +stop_after=100 sample_from=forest pair_sampling=108010 select_weights=VOID print_weights=Glue WordPenalty LanguageModel LanguageModel_OOV PhraseModel_0 PhraseModel_1 PhraseModel_2 PhraseModel_3 PhraseModel_4 PassThrough tmp=/tmp -unit_weight_vector=true +#unit_weight_vector= keep_w=true +#l1_reg=clip +#l1_reg_strength=0.00001 |