summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
authorPatrick Simianer <p@simianer.de>2013-09-10 18:20:16 +0200
committerPatrick Simianer <p@simianer.de>2013-09-10 18:20:16 +0200
commitc398cef915ea7037c91066b6bfc19d915cac498b (patch)
tree0aad40078383b16dfce6d137fbf008aac8f085be /training
parentee1d45810c869411c6c3b7c6de366393882a2efe (diff)
simple pclr
Diffstat (limited to 'training')
-rw-r--r--training/dtrain/dtrain.cc25
1 files changed, 20 insertions, 5 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index 0ee2f124..34c0a54a 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -40,6 +40,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
("scale_bleu_diff", po::value<bool>()->zero_tokens(), "learning rate <- bleu diff of a misranked pair")
("loss_margin", po::value<weight_t>()->default_value(0.), "update if no error in pref pair but model scores this near")
("max_pairs", po::value<unsigned>()->default_value(std::numeric_limits<unsigned>::max()), "max. # of pairs per Sent.")
+ ("pclr", po::value<bool>()->zero_tokens(), "use a (simple) per-coordinate learning rate")
("noup", po::value<bool>()->zero_tokens(), "do not update weights");
po::options_description cl("Command Line Options");
cl.add_options()
@@ -124,6 +125,8 @@ main(int argc, char** argv)
if (loss_margin > 9998.) loss_margin = std::numeric_limits<float>::max();
bool scale_bleu_diff = false;
if (cfg.count("scale_bleu_diff")) scale_bleu_diff = true;
+ bool pclr = false;
+ if (cfg.count("pclr")) pclr = true;
bool average = false;
if (select_weights == "avg")
average = true;
@@ -131,7 +134,6 @@ main(int argc, char** argv)
if (cfg.count("print_weights"))
boost::split(print_weights, cfg["print_weights"].as<string>(), boost::is_any_of(" "));
-
// setup decoder
register_feature_functions();
SetSilent(true);
@@ -249,6 +251,8 @@ main(int argc, char** argv)
cerr << setw(25) << "l1 reg " << l1_reg << " '" << cfg["l1_reg"].as<string>() << "'" << endl;
if (rescale)
cerr << setw(25) << "rescale " << rescale << endl;
+ if (pclr)
+ cerr << setw(25) << "pclr " << pclr << endl;
cerr << setw(25) << "max pairs " << max_pairs << endl;
cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as<string>() << "'" << endl;
cerr << setw(25) << "input " << "'" << input_fn << "'" << endl;
@@ -261,6 +265,8 @@ main(int argc, char** argv)
if (!verbose) cerr << "(a dot represents " << DTRAIN_DOTS << " inputs)" << endl;
}
+ // pclr
+ SparseVector<weight_t> learning_rates;
for (unsigned t = 0; t < T; t++) // T epochs
{
@@ -385,7 +391,16 @@ main(int argc, char** argv)
if (scale_bleu_diff) eta = it->first.score - it->second.score;
if (rank_error || margin < loss_margin) {
SparseVector<weight_t> diff_vec = it->first.f - it->second.f;
- lambdas.plus_eq_v_times_s(diff_vec, eta);
+ if (pclr) {
+ SparseVector<weight_t>::iterator jt = diff_vec.begin();
+ for (; jt != diff_vec.end(); ++it) {
+ jt->second *= max(0.0000001, eta/(eta+learning_rates[jt->first])); // FIXME
+ learning_rates[jt->first]++;
+ }
+ lambdas += diff_vec;
+ } else {
+ lambdas.plus_eq_v_times_s(diff_vec, eta);
+ }
if (gamma)
lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs));
}
@@ -395,14 +410,14 @@ main(int argc, char** argv)
// please note that this regularizations happen
// after a _sentence_ -- not after each example/pair!
if (l1naive) {
- FastSparseVector<weight_t>::iterator it = lambdas.begin();
+ SparseVector<weight_t>::iterator it = lambdas.begin();
for (; it != lambdas.end(); ++it) {
if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) {
it->second -= sign(it->second) * l1_reg;
}
}
} else if (l1clip) {
- FastSparseVector<weight_t>::iterator it = lambdas.begin();
+ SparseVector<weight_t>::iterator it = lambdas.begin();
for (; it != lambdas.end(); ++it) {
if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) {
if (it->second != 0) {
@@ -417,7 +432,7 @@ main(int argc, char** argv)
}
} else if (l1cumul) {
weight_t acc_penalty = (ii+1) * l1_reg; // ii is the index of the current input
- FastSparseVector<weight_t>::iterator it = lambdas.begin();
+ SparseVector<weight_t>::iterator it = lambdas.begin();
for (; it != lambdas.end(); ++it) {
if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) {
if (it->second != 0) {