author    Patrick Simianer <p@simianer.de>    2011-11-14 01:20:50 +0100
committer Patrick Simianer <p@simianer.de>    2011-11-14 01:20:50 +0100
commit    a27036119247fd5527ab8222e7df80ec2df31ca2 (patch)
tree      1071cec16c106b1245ca82dc6841ae9f90f49a3c /dtrain
parent    7b79fc9e6e6c9c2bb7f977978e319abe2143bbd9 (diff)
l1 regularization
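
Add optional L1 regularization to the weight updates, selected via the
new 'l1_reg' option ('naive', 'clip' or 'cumul'), with the strength set
by 'l1_reg_strength':
 * naive: subtract sign(w) * strength from every weight
 * clip:  move each nonzero weight towards 0 by strength, clipping at 0
 * cumul: clip against the penalty accumulated so far
          (Tsuruoka et al., 2009)
Also adds a sign() helper to dtrain.h and commented-out l1_reg settings
to the example ini.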
Diffstat (limited to 'dtrain')
-rw-r--r--  dtrain/dtrain.cc                | 51
-rw-r--r--  dtrain/dtrain.h                 |  6
-rw-r--r--  dtrain/test/example/dtrain.ini  |  6
3 files changed, 60 insertions, 3 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index ca5f0c5e..448e639c 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -26,6 +26,8 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
("select_weights", po::value<string>()->default_value("last"), "output 'best' or 'last' weights ('VOID' to throw away)")
("keep_w", po::value<bool>()->zero_tokens(), "protocol weights for each iteration")
("unit_weight_vector", po::value<bool>()->zero_tokens(), "Rescale weight vector after each input")
+ ("l1_reg", po::value<string>()->default_value("no"), "apply l1 regularization as in Tsuroka et al 2010")
+ ("l1_reg_strength", po::value<weight_t>(), "l1 regularization strength")
#ifdef DTRAIN_LOCAL
("refs,r", po::value<string>(), "references in local mode")
#endif
@@ -155,13 +157,25 @@ main(int argc, char** argv)
// init weights
vector<weight_t>& dense_weights = decoder.CurrentWeightVector();
- SparseVector<weight_t> lambdas;
+ SparseVector<weight_t> lambdas, cumulative_penalties;
if (cfg.count("input_weights")) Weights::InitFromFile(cfg["input_weights"].as<string>(), &dense_weights);
Weights::InitSparseVector(dense_weights, &lambdas);
// meta params for perceptron, SVM
weight_t eta = cfg["learning_rate"].as<weight_t>();
weight_t gamma = cfg["gamma"].as<weight_t>();
+  // l1 regularization
+  bool l1naive = false;
+  bool l1clip = false;
+  bool l1cumul = false;
+  weight_t l1_reg = 0;
+  if (cfg["l1_reg"].as<string>() != "no") {
+    string s = cfg["l1_reg"].as<string>();
+    if (s == "naive") l1naive = true;
+    else if (s == "clip") l1clip = true;
+    else if (s == "cumul") l1cumul = true;
+    l1_reg = cfg["l1_reg_strength"].as<weight_t>();
+  }
// output
string output_fn = cfg["output"].as<string>();
@@ -388,6 +402,41 @@ main(int argc, char** argv)
lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs));
}
}
+
+      // reset cumulative_penalties after 1 iter?
+      // do this only once per INPUT (not per pair)
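+      // penalty schemes below (C = l1_reg), applied once per input:
+      //   naive: w_d <- w_d - sign(w_d) * C          for every feature
+      //   clip:  move w_d towards 0 by C, clip at 0  for nonzero features
+      //   cumul: clip against the penalty accumulated so far,
+      //          cf. Tsuruoka et al., 2009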
+      if (l1naive) {
+        for (unsigned d = 0; d < lambdas.size(); d++) {
+          weight_t v = lambdas.get(d);
+          lambdas.set_value(d, v - sign(v) * l1_reg);
+        }
+      } else if (l1clip) {
+        for (unsigned d = 0; d < lambdas.size(); d++) {
+          if (lambdas.nonzero(d)) {
+            weight_t v = lambdas.get(d);
+            if (v > 0) {
+              lambdas.set_value(d, max(0., v - l1_reg));
+            } else {
+              lambdas.set_value(d, min(0., v + l1_reg));
+            }
+          }
+        }
+      } else if (l1cumul) {
+        weight_t acc_penalty = (ii+1) * l1_reg; // ii is the index of the current input
+        for (unsigned d = 0; d < lambdas.size(); d++) {
+          if (lambdas.nonzero(d)) {
+            weight_t v = lambdas.get(d);
+            weight_t penalized = 0;
+            if (v > 0) {
+              penalized = max(0., v - (acc_penalty + cumulative_penalties.get(d)));
+            } else {
+              penalized = min(0., v + (acc_penalty - cumulative_penalties.get(d)));
+            }
+            lambdas.set_value(d, penalized);
+            // track the penalty actually applied to this feature (new - old),
+            // as in Tsuruoka et al., 2009
+            cumulative_penalties.set_value(d, cumulative_penalties.get(d) + (penalized - v));
+          }
+        }
+      }
}
if (unit_weight_vector && sample_from == "forest") lambdas /= lambdas.l2norm();
diff --git a/dtrain/dtrain.h b/dtrain/dtrain.h
index 84f3f1f5..cfc3f460 100644
--- a/dtrain/dtrain.h
+++ b/dtrain/dtrain.h
@@ -85,5 +85,11 @@ inline void printWordIDVec(vector<WordID>& v)
}
}
+template<typename T>
+inline T sign(T z) {
+ if (z == 0) return 0;
+ return z < 0 ? -1 : +1;
+}
+
#endif
diff --git a/dtrain/test/example/dtrain.ini b/dtrain/test/example/dtrain.ini
index 95eeb8e5..900878a5 100644
--- a/dtrain/test/example/dtrain.ini
+++ b/dtrain/test/example/dtrain.ini
@@ -6,11 +6,13 @@ epochs=20
input=test/example/nc-1k-tabs.gz
scorer=stupid_bleu
output=weights.gz
-stop_after=10
+stop_after=100
sample_from=forest
pair_sampling=108010
select_weights=VOID
print_weights=Glue WordPenalty LanguageModel LanguageModel_OOV PhraseModel_0 PhraseModel_1 PhraseModel_2 PhraseModel_3 PhraseModel_4 PassThrough
tmp=/tmp
-unit_weight_vector=true
+#unit_weight_vector=
keep_w=true
+#l1_reg=clip
+#l1_reg_strength=0.00001
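
For reference, a minimal self-contained sketch of the cumulative penalty
scheme enabled by l1_reg=cumul, following Tsuruoka et al., 2009: u is the
total absolute penalty each weight could have received so far, q[d] is the
(signed) penalty weight d has actually received, and weights are clipped at
zero instead of being pushed past it, which is what produces exact zeros.
The dense vectors and the name apply_cumulative_l1 are illustrative here,
not dtrain's API.

#include <algorithm>
#include <cstdio>
#include <vector>

typedef double weight_t;

// One cumulative-penalty step (Tsuruoka et al., 2009).
//   w: current weights
//   q: per-feature penalty received so far (signed)
//   u: total absolute penalty available per feature up to now,
//      e.g. (ii+1) * l1_reg with a constant learning rate
void apply_cumulative_l1(std::vector<weight_t>& w,
                         std::vector<weight_t>& q, weight_t u)
{
  for (size_t d = 0; d < w.size(); d++) {
    weight_t v = w[d];
    if (v == 0) continue;
    weight_t clipped;
    if (v > 0) clipped = std::max(0., v - (u + q[d]));
    else       clipped = std::min(0., v + (u - q[d]));
    q[d] += clipped - v; // the penalty actually applied this step
    w[d] = clipped;
  }
}

int main()
{
  std::vector<weight_t> w = {0.3, -0.05, 0.0};
  std::vector<weight_t> q(w.size(), 0.);
  const weight_t l1_reg = 0.1;
  for (int ii = 0; ii < 3; ii++) { // three 'inputs'
    apply_cumulative_l1(w, q, (ii + 1) * l1_reg);
    std::printf("after input %d: %g %g %g\n", ii, w[0], w[1], w[2]);
  }
  return 0;
}

With these settings the small weight is driven to exactly 0 after the first
input and stays there, while the larger weight decays towards 0; in dtrain
the same clipping runs once per input, interleaved with the pairwise
perceptron updates.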