diff options
author | Patrick Simianer <p@simianer.de> | 2011-11-29 21:37:33 +0100 |
---|---|---|
committer | Patrick Simianer <p@simianer.de> | 2011-11-29 21:37:33 +0100 |
commit | 1db0933e17387525b6f36c4e37f9ae1ae2bfceb6 (patch) | |
tree | d7bb94d9504bcd159fb50f7dc1efee9cfc1164ba /dtrain/dtrain.cc | |
parent | c7c40bd3585a03f8f0dbb2df7a20aed6960b0922 (diff) |
epoch averaging
Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r-- | dtrain/dtrain.cc | 25 |
1 files changed, 20 insertions, 5 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc index 581c985a..e07b9307 100644 --- a/dtrain/dtrain.cc +++ b/dtrain/dtrain.cc @@ -30,6 +30,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) ("l1_reg", po::value<string>()->default_value("none"), "apply l1 regularization as in 'Tsuroka et al' (2010)") ("l1_reg_strength", po::value<weight_t>(), "l1 regularization strength") ("funny", po::value<bool>()->zero_tokens(), "include correctly ranked pairs into updates") + ("average", po::value<bool>()->zero_tokens(), "output weight vector is average of all iterations") #ifdef DTRAIN_LOCAL ("refs,r", po::value<string>(), "references in local mode") #endif @@ -115,6 +116,9 @@ main(int argc, char** argv) bool funny = false; if (cfg.count("funny")) funny = true; + bool average = false; + if (cfg.count("average")) + average = true; const unsigned k = cfg["k"].as<unsigned>(); const unsigned N = cfg["N"].as<unsigned>(); @@ -167,7 +171,7 @@ main(int argc, char** argv) // init weights vector<weight_t>& dense_weights = decoder.CurrentWeightVector(); - SparseVector<weight_t> lambdas, cumulative_penalties; + SparseVector<weight_t> lambdas, cumulative_penalties, w_average; if (cfg.count("input_weights")) Weights::InitFromFile(cfg["input_weights"].as<string>(), &dense_weights); Weights::InitSparseVector(dense_weights, &lambdas); @@ -480,6 +484,8 @@ if (false) { } // input loop + if (average) w_average += lambdas; + if (scorer_str == "approx_bleu") scorer->Reset(); if (t == 0) { @@ -572,20 +578,29 @@ if (false) { } // outer loop + if (average) w_average /= (weight_t)T; + #ifndef DTRAIN_LOCAL unlink(grammar_buf_fn.c_str()); #endif if (!noup) { if (!quiet) cerr << endl << "Writing weights file to '" << output_fn << "' ..." << endl; - if (select_weights == "last") { // last + if (select_weights == "last" || average) { // last, average WriteFile of(output_fn); // works with '-' ostream& o = *of.stream(); o.precision(17); o << _np; - for (SparseVector<weight_t>::const_iterator it = lambdas.begin(); it != lambdas.end(); ++it) { - if (it->second == 0) continue; - o << FD::Convert(it->first) << '\t' << it->second << endl; + if (average) { + for (SparseVector<weight_t>::const_iterator it = w_average.begin(); it != w_average.end(); ++it) { + if (it->second == 0) continue; + o << FD::Convert(it->first) << '\t' << it->second << endl; + } + } else { + for (SparseVector<weight_t>::const_iterator it = lambdas.begin(); it != lambdas.end(); ++it) { + if (it->second == 0) continue; + o << FD::Convert(it->first) << '\t' << it->second << endl; + } } } else if (select_weights == "VOID") { // do nothing with the weights } else { // best |