summary | refs | log | tree | commit | diff
path: root/dtrain/dtrain.cc
diff options
context:
space:
mode:
Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r-- dtrain/dtrain.cc 25
1 files changed, 20 insertions, 5 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index 581c985a..e07b9307 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -30,6 +30,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
("l1_reg", po::value<string>()->default_value("none"), "apply l1 regularization as in 'Tsuroka et al' (2010)")
("l1_reg_strength", po::value<weight_t>(), "l1 regularization strength")
("funny", po::value<bool>()->zero_tokens(), "include correctly ranked pairs into updates")
+ ("average", po::value<bool>()->zero_tokens(), "output weight vector is average of all iterations")
#ifdef DTRAIN_LOCAL
("refs,r", po::value<string>(), "references in local mode")
#endif
@@ -115,6 +116,9 @@ main(int argc, char** argv)
bool funny = false;
if (cfg.count("funny"))
funny = true;
+ bool average = false;
+ if (cfg.count("average"))
+ average = true;
const unsigned k = cfg["k"].as<unsigned>();
const unsigned N = cfg["N"].as<unsigned>();
@@ -167,7 +171,7 @@ main(int argc, char** argv)
// init weights
vector<weight_t>& dense_weights = decoder.CurrentWeightVector();
- SparseVector<weight_t> lambdas, cumulative_penalties;
+ SparseVector<weight_t> lambdas, cumulative_penalties, w_average;
if (cfg.count("input_weights")) Weights::InitFromFile(cfg["input_weights"].as<string>(), &dense_weights);
Weights::InitSparseVector(dense_weights, &lambdas);
@@ -480,6 +484,8 @@ if (false) {
} // input loop
+ if (average) w_average += lambdas;
+
if (scorer_str == "approx_bleu") scorer->Reset();
if (t == 0) {
@@ -572,20 +578,29 @@ if (false) {
} // outer loop
+ if (average) w_average /= (weight_t)T;
+
#ifndef DTRAIN_LOCAL
unlink(grammar_buf_fn.c_str());
#endif
if (!noup) {
if (!quiet) cerr << endl << "Writing weights file to '" << output_fn << "' ..." << endl;
- if (select_weights == "last") { // last
+ if (select_weights == "last" || average) { // last, average
WriteFile of(output_fn); // works with '-'
ostream& o = *of.stream();
o.precision(17);
o << _np;
- for (SparseVector<weight_t>::const_iterator it = lambdas.begin(); it != lambdas.end(); ++it) {
- if (it->second == 0) continue;
- o << FD::Convert(it->first) << '\t' << it->second << endl;
+ if (average) {
+ for (SparseVector<weight_t>::const_iterator it = w_average.begin(); it != w_average.end(); ++it) {
+ if (it->second == 0) continue;
+ o << FD::Convert(it->first) << '\t' << it->second << endl;
+ }
+ } else {
+ for (SparseVector<weight_t>::const_iterator it = lambdas.begin(); it != lambdas.end(); ++it) {
+ if (it->second == 0) continue;
+ o << FD::Convert(it->first) << '\t' << it->second << endl;
+ }
}
} else if (select_weights == "VOID") { // do nothing with the weights
} else { // best