From 1db0933e17387525b6f36c4e37f9ae1ae2bfceb6 Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Tue, 29 Nov 2011 21:37:33 +0100 Subject: epoch averaging --- dtrain/README.md | 62 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ dtrain/dtrain.cc | 25 ++++++++++++++++++----- 2 files changed, 82 insertions(+), 5 deletions(-) diff --git a/dtrain/README.md b/dtrain/README.md index d78dc100..32dafade 100644 --- a/dtrain/README.md +++ b/dtrain/README.md @@ -336,3 +336,65 @@ ioh: 4 when does overfitting begin? +--- +Variables + k 100..1500 higher better + N 3/4 + learning rate + reg/gamma + epochs -> best on devtest (10..30) (select_weights) + scorer -> approx_bleu correlates ok (stupid bleu, bleu, smooth bleu) + sample from -> kbest | forest + filter -> no uniq (kbest) + pair sampling -> all 5050 108010 PRO alld + update_ok -> update towards correctly ranked + features + 6x tm + 2x lm + wp + Glue + rule ids + rule ngrams + rule shape + span features + + +PRO + k = 1500 + N = 4 + learning rate = 0.0005 + gamma = 0 + epochs = 30 + scorer = stupid bleu (Bleu+1) + sample from = kbest + filter = no + pair sampling = PRO + update_ok + features = base + +cur: + shard_sz 500 1k 3k + PRO with forest sampling + PRO w/o update_ok + tune learning rate + all with discard (not only top 50) + filter kbest uniq? + + -> repeat most on Tset, lXlX stuff + -> PRO approx bleu + -> tune gamma + -> best pair sampling method + -> reduce k? + => scorer => approx_bleu (test w PRO) + -> PRO on training set + -> PRO more features + -> discard + 108010 + + + +-- +forest vs kbest count vocab? +108010 select discard +approx bleu + + diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc index 581c985a..e07b9307 100644 --- a/dtrain/dtrain.cc +++ b/dtrain/dtrain.cc @@ -30,6 +30,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) ("l1_reg", po::value()->default_value("none"), "apply l1 regularization as in 'Tsuroka et al' (2010)") ("l1_reg_strength", po::value(), "l1 regularization strength") ("funny", po::value()->zero_tokens(), "include correctly ranked pairs into updates") + ("average", po::value()->zero_tokens(), "output weight vector is average of all iterations") #ifdef DTRAIN_LOCAL ("refs,r", po::value(), "references in local mode") #endif @@ -115,6 +116,9 @@ main(int argc, char** argv) bool funny = false; if (cfg.count("funny")) funny = true; + bool average = false; + if (cfg.count("average")) + average = true; const unsigned k = cfg["k"].as(); const unsigned N = cfg["N"].as(); @@ -167,7 +171,7 @@ main(int argc, char** argv) // init weights vector& dense_weights = decoder.CurrentWeightVector(); - SparseVector lambdas, cumulative_penalties; + SparseVector lambdas, cumulative_penalties, w_average; if (cfg.count("input_weights")) Weights::InitFromFile(cfg["input_weights"].as(), &dense_weights); Weights::InitSparseVector(dense_weights, &lambdas); @@ -480,6 +484,8 @@ if (false) { } // input loop + if (average) w_average += lambdas; + if (scorer_str == "approx_bleu") scorer->Reset(); if (t == 0) { @@ -572,20 +578,29 @@ if (false) { } // outer loop + if (average) w_average /= (weight_t)T; + #ifndef DTRAIN_LOCAL unlink(grammar_buf_fn.c_str()); #endif if (!noup) { if (!quiet) cerr << endl << "Writing weights file to '" << output_fn << "' ..." << endl; - if (select_weights == "last") { // last + if (select_weights == "last" || average) { // last, average WriteFile of(output_fn); // works with '-' ostream& o = *of.stream(); o.precision(17); o << _np; - for (SparseVector::const_iterator it = lambdas.begin(); it != lambdas.end(); ++it) { - if (it->second == 0) continue; - o << FD::Convert(it->first) << '\t' << it->second << endl; + if (average) { + for (SparseVector::const_iterator it = w_average.begin(); it != w_average.end(); ++it) { + if (it->second == 0) continue; + o << FD::Convert(it->first) << '\t' << it->second << endl; + } + } else { + for (SparseVector::const_iterator it = lambdas.begin(); it != lambdas.end(); ++it) { + if (it->second == 0) continue; + o << FD::Convert(it->first) << '\t' << it->second << endl; + } } } else if (select_weights == "VOID") { // do nothing with the weights } else { // best -- cgit v1.2.3