diff options
Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r-- | dtrain/dtrain.cc | 51 |
1 files changed, 50 insertions, 1 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc index ca5f0c5e..448e639c 100644 --- a/dtrain/dtrain.cc +++ b/dtrain/dtrain.cc @@ -26,6 +26,8 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) ("select_weights", po::value<string>()->default_value("last"), "output 'best' or 'last' weights ('VOID' to throw away)") ("keep_w", po::value<bool>()->zero_tokens(), "protocol weights for each iteration") ("unit_weight_vector", po::value<bool>()->zero_tokens(), "Rescale weight vector after each input") + ("l1_reg", po::value<string>()->default_value("no"), "apply l1 regularization as in Tsuroka et al 2010") + ("l1_reg_strength", po::value<weight_t>(), "l1 regularization strength") #ifdef DTRAIN_LOCAL ("refs,r", po::value<string>(), "references in local mode") #endif @@ -155,13 +157,25 @@ main(int argc, char** argv) // init weights vector<weight_t>& dense_weights = decoder.CurrentWeightVector(); - SparseVector<weight_t> lambdas; + SparseVector<weight_t> lambdas, cumulative_penalties; if (cfg.count("input_weights")) Weights::InitFromFile(cfg["input_weights"].as<string>(), &dense_weights); Weights::InitSparseVector(dense_weights, &lambdas); // meta params for perceptron, SVM weight_t eta = cfg["learning_rate"].as<weight_t>(); weight_t gamma = cfg["gamma"].as<weight_t>(); + // l1 regularization + bool l1naive = false; + bool l1clip = false; + bool l1cumul = false; + weight_t l1_reg = 0; + if (cfg["l1_reg"].as<string>() != "no") { + string s = cfg["l1_reg"].as<string>(); + if (s == "naive") l1naive = true; + else if (s == "clip") l1clip = true; + else if (s == "cumul") l1cumul = true; + l1_reg = cfg["l1_reg_strength"].as<weight_t>(); + } // output string output_fn = cfg["output"].as<string>(); @@ -388,6 +402,41 @@ main(int argc, char** argv) lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs)); } } + + // reset cumulative_penalties after 1 iter? + // do this only once per INPUT (not per pair) + if (l1naive) { + for (unsigned d = 0; d < lambdas.size(); d++) { + weight_t v = lambdas.get(d); + lambdas.set_value(d, v - sign(v) * l1_reg); + } + } else if (l1clip) { + for (unsigned d = 0; d < lambdas.size(); d++) { + if (lambdas.nonzero(d)) { + weight_t v = lambdas.get(d); + if (v > 0) { + lambdas.set_value(d, max(0., v - l1_reg)); + } else { + lambdas.set_value(d, min(0., v + l1_reg)); + } + } + } + } else if (l1cumul) { + weight_t acc_penalty = (ii+1) * l1_reg; // ii is the index of the current input + for (unsigned d = 0; d < lambdas.size(); d++) { + if (lambdas.nonzero(d)) { + weight_t v = lambdas.get(d); + weight_t penalty = 0; + if (v > 0) { + penalty = max(0., v-(acc_penalty + cumulative_penalties.get(d))); + } else { + penalty = min(0., v+(acc_penalty - cumulative_penalties.get(d))); + } + lambdas.set_value(d, penalty); + cumulative_penalties.set_value(d, cumulative_penalties.get(d)+penalty); + } + } + } } if (unit_weight_vector && sample_from == "forest") lambdas /= lambdas.l2norm(); |