summaryrefslogtreecommitdiff
path: root/training/dtrain/dtrain.cc
diff options
context:
space:
mode:
authorPatrick Simianer <p@simianer.de>2013-09-10 19:54:40 +0200
committerPatrick Simianer <p@simianer.de>2013-09-10 19:54:40 +0200
commitc171ea9c37bf170b91946e0f5d22e7fd0d2c5825 (patch)
tree8e5711328440d3bcaf87189eb34cebef365910b6 /training/dtrain/dtrain.cc
parent4d90fcf6a24be10aa35c7783e6f2623e9fedf9d6 (diff)
do pclr after sentences..
Diffstat (limited to 'training/dtrain/dtrain.cc')
-rw-r--r--training/dtrain/dtrain.cc36
1 files changed, 26 insertions, 10 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index 34c0a54a..2d090666 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -372,7 +372,8 @@ main(int argc, char** argv)
PROsampling(samples, pairs, pair_threshold, max_pairs);
npairs += pairs.size();
- SparseVector<weight_t> lambdas_copy;
+ SparseVector<weight_t> lambdas_copy; // for l1 regularization
+ SparseVector<weight_t> sum_up; // for pclr
if (l1naive||l1clip||l1cumul) lambdas_copy = lambdas;
for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin();
@@ -392,20 +393,24 @@ main(int argc, char** argv)
if (rank_error || margin < loss_margin) {
SparseVector<weight_t> diff_vec = it->first.f - it->second.f;
if (pclr) {
- SparseVector<weight_t>::iterator jt = diff_vec.begin();
- for (; jt != diff_vec.end(); ++it) {
- jt->second *= max(0.0000001, eta/(eta+learning_rates[jt->first])); // FIXME
- learning_rates[jt->first]++;
- }
- lambdas += diff_vec;
- } else {
- lambdas.plus_eq_v_times_s(diff_vec, eta);
- }
+ sum_up += diff_vec;
+ } else {
+ lambdas.plus_eq_v_times_s(diff_vec, eta);
+ }
if (gamma)
lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs));
}
}
+ // per-coordinate learning rate
+ if (pclr) {
+ SparseVector<weight_t>::iterator it = sum_up.begin();
+ for (; it != lambdas.end(); ++it) {
+ lambdas[it->first] += it->second * max(0.00000001, eta/(eta+learning_rates[it->first]));
+ learning_rates[it->first]++;
+ }
+ }
+
// l1 regularization
// please note that this regularizations happen
// after a _sentence_ -- not after each example/pair!
@@ -413,6 +418,8 @@ main(int argc, char** argv)
SparseVector<weight_t>::iterator it = lambdas.begin();
for (; it != lambdas.end(); ++it) {
if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) {
+ it->second *= max(0.0000001, eta/(eta+learning_rates[it->first])); // FIXME
+ learning_rates[it->first]++;
it->second -= sign(it->second) * l1_reg;
}
}
@@ -530,6 +537,15 @@ main(int argc, char** argv)
Weights::WriteToFile(w_fn, dense_weights, true);
}
+ WriteFile of("-");
+ ostream& o = *of.stream();
+ o << "<<<<<<<<<<<<<<<<<<<<<<<<\n";
+ for (SparseVector<weight_t>::iterator it = learning_rates.begin(); it != learning_rates.end(); ++it) {
+ if (it->second == 0) continue;
+ o << FD::Convert(it->first) << '\t' << it->second << endl;
+ }
+ o << ">>>>>>>>>>>>>>>>>>>>>>>>>\n";
+
} // outer loop
if (average) w_average /= (weight_t)T;