From 176f311f6b4b2048dd05e0304d66ae5c61a4506e Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Fri, 8 Apr 2016 21:20:29 +0200 Subject: dtrain: adadelta, fix output, max input, batch learning --- training/dtrain/dtrain.cc | 111 +++++++++++++++++++++++++++++++++++++++------- 1 file changed, 96 insertions(+), 15 deletions(-) (limited to 'training/dtrain/dtrain.cc') diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index 3e9902ab..53e8cd50 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -41,6 +41,13 @@ main(int argc, char** argv) const bool output_updates = output_updates_fn!=""; const string output_raw_fn = conf["output_raw"].as(); const bool output_raw = output_raw_fn!=""; + const bool use_adadelta = conf["adadelta"].as(); + const weight_t adadelta_decay = conf["adadelta_decay"].as(); + const weight_t adadelta_eta = 0.000001; + const string adadelta_input = conf["adadelta_input"].as(); + const string adadelta_output = conf["adadelta_output"].as(); + const size_t max_input = conf["stop_after"].as(); + const bool batch = conf["batch"].as(); // setup decoder register_feature_functions(); @@ -89,8 +96,8 @@ main(int argc, char** argv) vector > buffered_lengths; // (just once) size_t input_sz = 0; - cerr << setprecision(4); // output configuration + cerr << fixed << setprecision(4); cerr << "Parameters:" << endl; cerr << setw(25) << "bitext " << "'" << input_fn << "'" << endl; cerr << setw(25) << "k " << k << endl; @@ -109,10 +116,10 @@ main(int argc, char** argv) cerr << setw(25) << "chiang decay " << chiang_decay << endl; cerr << setw(25) << "N " << N << endl; cerr << setw(25) << "T " << T << endl; - cerr << setw(25) << "learning rate " << eta << endl; + cerr << scientific << setw(25) << "learning rate " << eta << endl; cerr << setw(25) << "margin " << margin << endl; if (!structured) { - cerr << setw(25) << "cut " << round(cut*100) << "%" << endl; + cerr << fixed << setw(25) << "cut " << round(cut*100) << "%" << endl; cerr << setw(25) << "adjust " << adjust_cut << endl; } else { cerr << setw(25) << "struct. obj " << structured << endl; @@ -124,7 +131,7 @@ main(int argc, char** argv) if (noup) cerr << setw(25) << "no up. " << noup << endl; cerr << setw(25) << "average " << average << endl; - cerr << setw(25) << "l1 reg. " << l1_reg << endl; + cerr << scientific << setw(25) << "l1 reg. " << l1_reg << endl; cerr << setw(25) << "decoder conf " << "'" << conf["decoder_conf"].as() << "'" << endl; cerr << setw(25) << "input " << "'" << input_fn << "'" << endl; @@ -133,8 +140,17 @@ main(int argc, char** argv) cerr << setw(25) << "weights in " << "'" << conf["input_weights"].as() << "'" << endl; } + cerr << setw(25) << "batch " << batch << endl; if (noup) cerr << setw(25) << "no updates!" << endl; + if (use_adadelta) { + cerr << setw(25) << "adadelta " << use_adadelta << endl; + cerr << setw(25) << " decay " << adadelta_decay << endl; + if (adadelta_input != "") + cerr << setw(25) << "-input " << adadelta_input << endl; + if (adadelta_output != "") + cerr << setw(25) << "-output " << adadelta_output << endl; + } cerr << "(1 dot per processed input)" << endl; // meta @@ -153,10 +169,23 @@ main(int argc, char** argv) *out_up << setprecision(numeric_limits::digits10+1); } + // adadelta + SparseVector gradient_accum, update_accum; + if (use_adadelta && adadelta_input!="") { + vector grads_tmp; + Weights::InitFromFile(adadelta_input+".gradient", &grads_tmp); + Weights::InitSparseVector(grads_tmp, &gradient_accum); + vector update_tmp; + Weights::InitFromFile(adadelta_input+".update", &update_tmp); + Weights::InitSparseVector(update_tmp, &update_accum); + } for (size_t t = 0; t < T; t++) // T iterations { + // batch update + SparseVector batch_update; + time_t start, end; time(&start); weight_t gold_sum=0., model_sum=0.; @@ -194,6 +223,9 @@ main(int argc, char** argv) next = ieffective_size; if (output_raw) - output_sample(sample, *out_raw, i); + output_sample(sample, out_raw, i); // update model if (!noup) { @@ -233,21 +265,46 @@ main(int argc, char** argv) SparseVector updates; if (structured) num_up += update_structured(sample, updates, margin, - output_updates, *out_up, i); + out_up, i); else if (all_pairs) num_up += updates_all(sample, updates, max_up, threshold, - output_updates, *out_up, i); + out_up, i); else if (pro) num_up += updates_pro(sample, updates, cut, max_up, threshold, - output_updates, *out_up, i); + out_up, i); else num_up += updates_multipartite(sample, updates, cut, margin, max_up, threshold, adjust_cut, - output_updates, *out_up, i); + out_up, i); + SparseVector lambdas_copy; if (l1_reg) lambdas_copy = lambdas; - lambdas.plus_eq_v_times_s(updates, eta); + + if (use_adadelta) { // adadelta update + SparseVector squared; + for (auto it: updates) + squared[it.first] = pow(it.second, 2.0); + gradient_accum *= adadelta_decay; + squared *= 1.0-adadelta_decay; + gradient_accum += squared; + SparseVector u = gradient_accum + update_accum; + for (auto it: u) + u[it.first] = -1.0*( + sqrt(update_accum[it.first]+adadelta_eta) + / + sqrt(gradient_accum[it.first]+adadelta_eta) + ) * updates[it.first]; + lambdas += u; + update_accum *= adadelta_decay; + for (auto it: u) + u[it.first] = pow(it.second, 2.0); + update_accum = update_accum + (u*(1.0-adadelta_decay)); + } else if (batch) { + batch_update += updates; + } else { // regular update + lambdas.plus_eq_v_times_s(updates, eta); + } // update context for Chiang's approx. BLEU if (score_name == "chiang") { @@ -290,23 +347,47 @@ main(int argc, char** argv) if (t == 0) input_sz = i; // remember size of input (# lines) + // batch + if (batch) { + batch_update /= (weight_t)num_up; + lambdas.plus_eq_v_times_s(batch_update, eta); + lambdas.init_vector(&decoder_weights); + } + // update average if (average) w_average += lambdas; + if (adadelta_output != "") { + WriteFile g(adadelta_output+".gradient.gz"); + for (auto it: gradient_accum) + *g << FD::Convert(it.first) << " " << it.second << endl; + WriteFile u(adadelta_output+".update.gz"); + for (auto it: update_accum) + *u << FD::Convert(it.first) << " " << it.second << endl; + } + // stats weight_t gold_avg = gold_sum/(weight_t)input_sz; - cerr << setiosflags(ios::showpos) << "WEIGHTS" << endl; - for (auto name: print_weights) + cerr << setiosflags(ios::showpos) << scientific << "WEIGHTS" << endl; + for (auto name: print_weights) { cerr << setw(18) << name << " = " - << lambdas.get(FD::Convert(name)) << endl; + << lambdas.get(FD::Convert(name)); + if (use_adadelta) { + weight_t rate = -1.0*(sqrt(update_accum[FD::Convert(name)]+adadelta_eta) + / sqrt(gradient_accum[FD::Convert(name)]+adadelta_eta)); + cerr << " {" << rate << "}"; + } + cerr << endl; + } cerr << " ---" << endl; cerr << resetiosflags(ios::showpos) << " 1best avg score: " << gold_avg*100; - cerr << setiosflags(ios::showpos) << " (" + cerr << setiosflags(ios::showpos) << fixed << " (" << (gold_avg-gold_prev)*100 << ")" << endl; - cerr << " 1best avg model score: " + cerr << scientific << " 1best avg model score: " << model_sum/(weight_t)input_sz << endl; + cerr << fixed; cerr << " avg # updates: "; cerr << resetiosflags(ios::showpos) << num_up/(float)input_sz << endl; cerr << " non-0 feature count: " << lambdas.num_nonzero() << endl; -- cgit v1.2.3