summaryrefslogtreecommitdiff
path: root/training/dtrain/dtrain.cc
diff options
context:
space:
mode:
authorPatrick Simianer <simianer@cl.uni-heidelberg.de>2013-10-08 13:57:45 +0200
committerPatrick Simianer <simianer@cl.uni-heidelberg.de>2013-10-08 13:57:45 +0200
commit8fae8c224fc7a8f8a858ed9a022992d020057f65 (patch)
tree295c876d6fa8c8cc6725ec6cd2a9f164fb78e232 /training/dtrain/dtrain.cc
parent61ab9164cce065ee3cb4fc6c72d0b7246874d18c (diff)
dtrain: added pclr variants and new expected-output; fixed bug in soft syntax features
Diffstat (limited to 'training/dtrain/dtrain.cc')
-rw-r--r--training/dtrain/dtrain.cc31
1 files changed, 18 insertions, 13 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index 9d60a903..38a9b69a 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -40,7 +40,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
("scale_bleu_diff", po::value<bool>()->zero_tokens(), "learning rate <- bleu diff of a misranked pair")
("loss_margin", po::value<weight_t>()->default_value(0.), "update if no error in pref pair but model scores this near")
("max_pairs", po::value<unsigned>()->default_value(std::numeric_limits<unsigned>::max()), "max. # of pairs per Sent.")
- ("pclr", po::value<bool>()->zero_tokens(), "use a (simple) per-coordinate learning rate")
+ ("pclr", po::value<string>()->default_value("no"), "use a (simple|adagrad) per-coordinate learning rate")
("noup", po::value<bool>()->zero_tokens(), "do not update weights");
po::options_description cl("Command Line Options");
cl.add_options()
@@ -125,8 +125,7 @@ main(int argc, char** argv)
if (loss_margin > 9998.) loss_margin = std::numeric_limits<float>::max();
bool scale_bleu_diff = false;
if (cfg.count("scale_bleu_diff")) scale_bleu_diff = true;
- bool pclr = false;
- if (cfg.count("pclr")) pclr = true;
+ const string pclr = cfg["pclr"].as<string>();
bool average = false;
if (select_weights == "avg")
average = true;
@@ -190,7 +189,6 @@ main(int argc, char** argv)
weight_t gamma = cfg["gamma"].as<weight_t>();
// faster perceptron: consider only misranked pairs, see
- // DO NOT ENABLE WITH SVM (gamma > 0) OR loss_margin!
bool faster_perceptron = false;
if (gamma==0 && loss_margin==0) faster_perceptron = true;
@@ -251,8 +249,7 @@ main(int argc, char** argv)
cerr << setw(25) << "l1 reg " << l1_reg << " '" << cfg["l1_reg"].as<string>() << "'" << endl;
if (rescale)
cerr << setw(25) << "rescale " << rescale << endl;
- if (pclr)
- cerr << setw(25) << "pclr " << pclr << endl;
+ cerr << setw(25) << "pclr " << pclr << endl;
cerr << setw(25) << "max pairs " << max_pairs << endl;
cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as<string>() << "'" << endl;
cerr << setw(25) << "input " << "'" << input_fn << "'" << endl;
@@ -392,22 +389,30 @@ main(int argc, char** argv)
if (scale_bleu_diff) eta = it->first.score - it->second.score;
if (rank_error || margin < loss_margin) {
SparseVector<weight_t> diff_vec = it->first.f - it->second.f;
- if (pclr) {
+ if (pclr != "no") {
sum_up += diff_vec;
} else {
lambdas.plus_eq_v_times_s(diff_vec, eta);
+ if (gamma) lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs)); // FIXME
}
- if (gamma)
- lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs));
}
}
// per-coordinate learning rate
- if (pclr) {
+ if (pclr != "no") {
SparseVector<weight_t>::iterator it = sum_up.begin();
- for (; it != lambdas.end(); ++it) {
- learning_rates[it->first]++;
- lambdas[it->first] += it->second / learning_rates[it->first]; //* max(0.00000001, eta/(eta+learning_rates[it->first]));
+ for (; it != sum_up.end(); ++it) {
+ if (pclr == "simple") {
+ lambdas[it->first] += it->second / max(1.0, learning_rates[it->first]);
+ learning_rates[it->first]++;
+ } else if (pclr == "adagrad") {
+ if (learning_rates[it->first] == 0) {
+ lambdas[it->first] += it->second * eta;
+ } else {
+ lambdas[it->first] += it->second * eta * learning_rates[it->first];
+ }
+ learning_rates[it->first] += pow(it->second, 2.0);
+ }
}
}