From 176f311f6b4b2048dd05e0304d66ae5c61a4506e Mon Sep 17 00:00:00 2001
From: Patrick Simianer
Date: Fri, 8 Apr 2016 21:20:29 +0200
Subject: dtrain: adadelta, fix output, max input, batch learning
---
training/dtrain/dtrain.cc | 111 +++++++++++++++++++++++++++++++++++++++-------
1 file changed, 96 insertions(+), 15 deletions(-)
(limited to 'training/dtrain/dtrain.cc')
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index 3e9902ab..53e8cd50 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -41,6 +41,13 @@ main(int argc, char** argv)
const bool output_updates = output_updates_fn!="";
const string output_raw_fn = conf["output_raw"].as();
const bool output_raw = output_raw_fn!="";
+ const bool use_adadelta = conf["adadelta"].as();
+ const weight_t adadelta_decay = conf["adadelta_decay"].as();
+ const weight_t adadelta_eta = 0.000001;
+ const string adadelta_input = conf["adadelta_input"].as();
+ const string adadelta_output = conf["adadelta_output"].as();
+ const size_t max_input = conf["stop_after"].as();
+ const bool batch = conf["batch"].as();
// setup decoder
register_feature_functions();
@@ -89,8 +96,8 @@ main(int argc, char** argv)
vector > buffered_lengths; // (just once)
size_t input_sz = 0;
- cerr << setprecision(4);
// output configuration
+ cerr << fixed << setprecision(4);
cerr << "Parameters:" << endl;
cerr << setw(25) << "bitext " << "'" << input_fn << "'" << endl;
cerr << setw(25) << "k " << k << endl;
@@ -109,10 +116,10 @@ main(int argc, char** argv)
cerr << setw(25) << "chiang decay " << chiang_decay << endl;
cerr << setw(25) << "N " << N << endl;
cerr << setw(25) << "T " << T << endl;
- cerr << setw(25) << "learning rate " << eta << endl;
+ cerr << scientific << setw(25) << "learning rate " << eta << endl;
cerr << setw(25) << "margin " << margin << endl;
if (!structured) {
- cerr << setw(25) << "cut " << round(cut*100) << "%" << endl;
+ cerr << fixed << setw(25) << "cut " << round(cut*100) << "%" << endl;
cerr << setw(25) << "adjust " << adjust_cut << endl;
} else {
cerr << setw(25) << "struct. obj " << structured << endl;
@@ -124,7 +131,7 @@ main(int argc, char** argv)
if (noup)
cerr << setw(25) << "no up. " << noup << endl;
cerr << setw(25) << "average " << average << endl;
- cerr << setw(25) << "l1 reg. " << l1_reg << endl;
+ cerr << scientific << setw(25) << "l1 reg. " << l1_reg << endl;
cerr << setw(25) << "decoder conf " << "'"
<< conf["decoder_conf"].as() << "'" << endl;
cerr << setw(25) << "input " << "'" << input_fn << "'" << endl;
@@ -133,8 +140,17 @@ main(int argc, char** argv)
cerr << setw(25) << "weights in " << "'"
<< conf["input_weights"].as() << "'" << endl;
}
+ cerr << setw(25) << "batch " << batch << endl;
if (noup)
cerr << setw(25) << "no updates!" << endl;
+ if (use_adadelta) {
+ cerr << setw(25) << "adadelta " << use_adadelta << endl;
+ cerr << setw(25) << " decay " << adadelta_decay << endl;
+ if (adadelta_input != "")
+ cerr << setw(25) << "-input " << adadelta_input << endl;
+ if (adadelta_output != "")
+ cerr << setw(25) << "-output " << adadelta_output << endl;
+ }
cerr << "(1 dot per processed input)" << endl;
// meta
@@ -153,10 +169,23 @@ main(int argc, char** argv)
*out_up << setprecision(numeric_limits::digits10+1);
}
+ // adadelta
+ SparseVector gradient_accum, update_accum;
+ if (use_adadelta && adadelta_input!="") {
+ vector grads_tmp;
+ Weights::InitFromFile(adadelta_input+".gradient", &grads_tmp);
+ Weights::InitSparseVector(grads_tmp, &gradient_accum);
+ vector update_tmp;
+ Weights::InitFromFile(adadelta_input+".update", &update_tmp);
+ Weights::InitSparseVector(update_tmp, &update_accum);
+ }
for (size_t t = 0; t < T; t++) // T iterations
{
+ // batch update
+ SparseVector batch_update;
+
time_t start, end;
time(&start);
weight_t gold_sum=0., model_sum=0.;
@@ -194,6 +223,9 @@ main(int argc, char** argv)
next = ieffective_size;
if (output_raw)
- output_sample(sample, *out_raw, i);
+ output_sample(sample, out_raw, i);
// update model
if (!noup) {
@@ -233,21 +265,46 @@ main(int argc, char** argv)
SparseVector updates;
if (structured)
num_up += update_structured(sample, updates, margin,
- output_updates, *out_up, i);
+ out_up, i);
else if (all_pairs)
num_up += updates_all(sample, updates, max_up, threshold,
- output_updates, *out_up, i);
+ out_up, i);
else if (pro)
num_up += updates_pro(sample, updates, cut, max_up, threshold,
- output_updates, *out_up, i);
+ out_up, i);
else
num_up += updates_multipartite(sample, updates, cut, margin,
max_up, threshold, adjust_cut,
- output_updates, *out_up, i);
+ out_up, i);
+
SparseVector lambdas_copy;
if (l1_reg)
lambdas_copy = lambdas;
- lambdas.plus_eq_v_times_s(updates, eta);
+
+ if (use_adadelta) { // adadelta update
+ SparseVector squared;
+ for (auto it: updates)
+ squared[it.first] = pow(it.second, 2.0);
+ gradient_accum *= adadelta_decay;
+ squared *= 1.0-adadelta_decay;
+ gradient_accum += squared;
+ SparseVector u = gradient_accum + update_accum;
+ for (auto it: u)
+ u[it.first] = -1.0*(
+ sqrt(update_accum[it.first]+adadelta_eta)
+ /
+ sqrt(gradient_accum[it.first]+adadelta_eta)
+ ) * updates[it.first];
+ lambdas += u;
+ update_accum *= adadelta_decay;
+ for (auto it: u)
+ u[it.first] = pow(it.second, 2.0);
+ update_accum = update_accum + (u*(1.0-adadelta_decay));
+ } else if (batch) {
+ batch_update += updates;
+ } else { // regular update
+ lambdas.plus_eq_v_times_s(updates, eta);
+ }
// update context for Chiang's approx. BLEU
if (score_name == "chiang") {
@@ -290,23 +347,47 @@ main(int argc, char** argv)
if (t == 0)
input_sz = i; // remember size of input (# lines)
+ // batch
+ if (batch) {
+ batch_update /= (weight_t)num_up;
+ lambdas.plus_eq_v_times_s(batch_update, eta);
+ lambdas.init_vector(&decoder_weights);
+ }
+
// update average
if (average)
w_average += lambdas;
+ if (adadelta_output != "") {
+ WriteFile g(adadelta_output+".gradient.gz");
+ for (auto it: gradient_accum)
+ *g << FD::Convert(it.first) << " " << it.second << endl;
+ WriteFile u(adadelta_output+".update.gz");
+ for (auto it: update_accum)
+ *u << FD::Convert(it.first) << " " << it.second << endl;
+ }
+
// stats
weight_t gold_avg = gold_sum/(weight_t)input_sz;
- cerr << setiosflags(ios::showpos) << "WEIGHTS" << endl;
- for (auto name: print_weights)
+ cerr << setiosflags(ios::showpos) << scientific << "WEIGHTS" << endl;
+ for (auto name: print_weights) {
cerr << setw(18) << name << " = "
- << lambdas.get(FD::Convert(name)) << endl;
+ << lambdas.get(FD::Convert(name));
+ if (use_adadelta) {
+ weight_t rate = -1.0*(sqrt(update_accum[FD::Convert(name)]+adadelta_eta)
+ / sqrt(gradient_accum[FD::Convert(name)]+adadelta_eta));
+ cerr << " {" << rate << "}";
+ }
+ cerr << endl;
+ }
cerr << " ---" << endl;
cerr << resetiosflags(ios::showpos)
<< " 1best avg score: " << gold_avg*100;
- cerr << setiosflags(ios::showpos) << " ("
+ cerr << setiosflags(ios::showpos) << fixed << " ("
<< (gold_avg-gold_prev)*100 << ")" << endl;
- cerr << " 1best avg model score: "
+ cerr << scientific << " 1best avg model score: "
<< model_sum/(weight_t)input_sz << endl;
+ cerr << fixed;
cerr << " avg # updates: ";
cerr << resetiosflags(ios::showpos) << num_up/(float)input_sz << endl;
cerr << " non-0 feature count: " << lambdas.num_nonzero() << endl;
--
cgit v1.2.3