diff options
author | Patrick Simianer <p@simianer.de> | 2013-09-10 19:54:40 +0200 |
---|---|---|
committer | Patrick Simianer <p@simianer.de> | 2013-09-10 19:54:40 +0200 |
commit | c171ea9c37bf170b91946e0f5d22e7fd0d2c5825 (patch) | |
tree | 8e5711328440d3bcaf87189eb34cebef365910b6 /training/dtrain/dtrain.cc | |
parent | 4d90fcf6a24be10aa35c7783e6f2623e9fedf9d6 (diff) |
do pclr after sentences..
Diffstat (limited to 'training/dtrain/dtrain.cc')
-rw-r--r-- | training/dtrain/dtrain.cc | 36 |
1 files changed, 26 insertions, 10 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index 34c0a54a..2d090666 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -372,7 +372,8 @@ main(int argc, char** argv) PROsampling(samples, pairs, pair_threshold, max_pairs); npairs += pairs.size(); - SparseVector<weight_t> lambdas_copy; + SparseVector<weight_t> lambdas_copy; // for l1 regularization + SparseVector<weight_t> sum_up; // for pclr if (l1naive||l1clip||l1cumul) lambdas_copy = lambdas; for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin(); @@ -392,20 +393,24 @@ main(int argc, char** argv) if (rank_error || margin < loss_margin) { SparseVector<weight_t> diff_vec = it->first.f - it->second.f; if (pclr) { - SparseVector<weight_t>::iterator jt = diff_vec.begin(); - for (; jt != diff_vec.end(); ++it) { - jt->second *= max(0.0000001, eta/(eta+learning_rates[jt->first])); // FIXME - learning_rates[jt->first]++; - } - lambdas += diff_vec; - } else { - lambdas.plus_eq_v_times_s(diff_vec, eta); - } + sum_up += diff_vec; + } else { + lambdas.plus_eq_v_times_s(diff_vec, eta); + } if (gamma) lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs)); } } + // per-coordinate learning rate + if (pclr) { + SparseVector<weight_t>::iterator it = sum_up.begin(); + for (; it != lambdas.end(); ++it) { + lambdas[it->first] += it->second * max(0.00000001, eta/(eta+learning_rates[it->first])); + learning_rates[it->first]++; + } + } + // l1 regularization // please note that this regularizations happen // after a _sentence_ -- not after each example/pair! @@ -413,6 +418,8 @@ main(int argc, char** argv) SparseVector<weight_t>::iterator it = lambdas.begin(); for (; it != lambdas.end(); ++it) { if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) { + it->second *= max(0.0000001, eta/(eta+learning_rates[it->first])); // FIXME + learning_rates[it->first]++; it->second -= sign(it->second) * l1_reg; } } @@ -530,6 +537,15 @@ main(int argc, char** argv) Weights::WriteToFile(w_fn, dense_weights, true); } + WriteFile of("-"); + ostream& o = *of.stream(); + o << "<<<<<<<<<<<<<<<<<<<<<<<<\n"; + for (SparseVector<weight_t>::iterator it = learning_rates.begin(); it != learning_rates.end(); ++it) { + if (it->second == 0) continue; + o << FD::Convert(it->first) << '\t' << it->second << endl; + } + o << ">>>>>>>>>>>>>>>>>>>>>>>>>\n"; + } // outer loop if (average) w_average /= (weight_t)T; |