summaryrefslogtreecommitdiff
path: root/dtrain
diff options
context:
space:
mode:
Diffstat (limited to 'dtrain')
-rw-r--r--dtrain/README.md62
-rw-r--r--dtrain/dtrain.cc25
2 files changed, 82 insertions, 5 deletions
diff --git a/dtrain/README.md b/dtrain/README.md
index d78dc100..32dafade 100644
--- a/dtrain/README.md
+++ b/dtrain/README.md
@@ -336,3 +336,65 @@ ioh: 4
when does overfitting begin?
+---
+Variables
+ k 100..1500 higher better
+ N 3/4
+ learning rate
+ reg/gamma
+ epochs -> best on devtest (10..30) (select_weights)
+ scorer -> approx_bleu correlates ok (stupid bleu, bleu, smooth bleu)
+ sample from -> kbest | forest
+ filter -> no uniq (kbest)
+ pair sampling -> all 5050 108010 PRO alld
+ update_ok -> update towards correctly ranked
+ features
+ 6x tm
+ 2x lm
+ wp
+ Glue
+ rule ids
+ rule ngrams
+ rule shape
+ span features
+
+
+PRO
+ k = 1500
+ N = 4
+ learning rate = 0.0005
+ gamma = 0
+ epochs = 30
+ scorer = stupid bleu (Bleu+1)
+ sample from = kbest
+ filter = no
+ pair sampling = PRO
+ update_ok
+ features = base
+
+cur:
+ shard_sz 500 1k 3k
+ PRO with forest sampling
+ PRO w/o update_ok
+ tune learning rate
+ all with discard (not only top 50)
+ filter kbest uniq?
+
+ -> repeat most on Tset, lXlX stuff
+ -> PRO approx bleu
+ -> tune gamma
+ -> best pair sampling method
+ -> reduce k?
+ => scorer => approx_bleu (test w PRO)
+ -> PRO on training set
+ -> PRO more features
+ -> discard + 108010
+
+
+
+--
+forest vs kbest count vocab?
+108010 select discard
+approx bleu
+
+
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index 581c985a..e07b9307 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -30,6 +30,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
("l1_reg", po::value<string>()->default_value("none"), "apply l1 regularization as in 'Tsuroka et al' (2010)")
("l1_reg_strength", po::value<weight_t>(), "l1 regularization strength")
("funny", po::value<bool>()->zero_tokens(), "include correctly ranked pairs into updates")
+ ("average", po::value<bool>()->zero_tokens(), "output weight vector is average of all iterations")
#ifdef DTRAIN_LOCAL
("refs,r", po::value<string>(), "references in local mode")
#endif
@@ -115,6 +116,9 @@ main(int argc, char** argv)
bool funny = false;
if (cfg.count("funny"))
funny = true;
+ bool average = false;
+ if (cfg.count("average"))
+ average = true;
const unsigned k = cfg["k"].as<unsigned>();
const unsigned N = cfg["N"].as<unsigned>();
@@ -167,7 +171,7 @@ main(int argc, char** argv)
// init weights
vector<weight_t>& dense_weights = decoder.CurrentWeightVector();
- SparseVector<weight_t> lambdas, cumulative_penalties;
+ SparseVector<weight_t> lambdas, cumulative_penalties, w_average;
if (cfg.count("input_weights")) Weights::InitFromFile(cfg["input_weights"].as<string>(), &dense_weights);
Weights::InitSparseVector(dense_weights, &lambdas);
@@ -480,6 +484,8 @@ if (false) {
} // input loop
+ if (average) w_average += lambdas;
+
if (scorer_str == "approx_bleu") scorer->Reset();
if (t == 0) {
@@ -572,20 +578,29 @@ if (false) {
} // outer loop
+ if (average) w_average /= (weight_t)T;
+
#ifndef DTRAIN_LOCAL
unlink(grammar_buf_fn.c_str());
#endif
if (!noup) {
if (!quiet) cerr << endl << "Writing weights file to '" << output_fn << "' ..." << endl;
- if (select_weights == "last") { // last
+ if (select_weights == "last" || average) { // last, average
WriteFile of(output_fn); // works with '-'
ostream& o = *of.stream();
o.precision(17);
o << _np;
- for (SparseVector<weight_t>::const_iterator it = lambdas.begin(); it != lambdas.end(); ++it) {
- if (it->second == 0) continue;
- o << FD::Convert(it->first) << '\t' << it->second << endl;
+ if (average) {
+ for (SparseVector<weight_t>::const_iterator it = w_average.begin(); it != w_average.end(); ++it) {
+ if (it->second == 0) continue;
+ o << FD::Convert(it->first) << '\t' << it->second << endl;
+ }
+ } else {
+ for (SparseVector<weight_t>::const_iterator it = lambdas.begin(); it != lambdas.end(); ++it) {
+ if (it->second == 0) continue;
+ o << FD::Convert(it->first) << '\t' << it->second << endl;
+ }
}
} else if (select_weights == "VOID") { // do nothing with the weights
} else { // best