summaryrefslogtreecommitdiff
path: root/dtrain/dtrain.cc
diff options
context:
space:
mode:
Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r--dtrain/dtrain.cc127
1 files changed, 68 insertions, 59 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index 434ae2d6..581c985a 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -10,25 +10,26 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
("output", po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT")
("input_weights", po::value<string>(), "input weights file (e.g. from previous iteration)")
("decoder_config", po::value<string>(), "configuration file for cdec")
- ("sample_from", po::value<string>()->default_value("kbest"), "where to sample translations from: kbest, forest")
+ ("print_weights", po::value<string>(), "weights to print on each iteration")
+ ("stop_after", po::value<unsigned>()->default_value(0), "stop after X input sentences")
+ ("tmp", po::value<string>()->default_value("/tmp"), "temp dir to use")
+ ("keep", po::value<bool>()->zero_tokens(), "keep weights files for each iteration")
+ ("hstreaming", po::value<string>(), "run in hadoop streaming mode, arg is a task id")
+ ("epochs", po::value<unsigned>()->default_value(10), "# of iterations T (per shard)")
("k", po::value<unsigned>()->default_value(100), "how many translations to sample")
- ("filter", po::value<string>()->default_value("uniq"), "filter kbest list: no, uniq")
- ("pair_sampling", po::value<string>()->default_value("all"), "how to sample pairs: all, 5050, 108010, PRO")
- ("N", po::value<unsigned>()->default_value(3), "N for Ngrams (BLEU)")
- ("epochs", po::value<unsigned>()->default_value(2), "# of iterations T (per shard)")
- ("scorer", po::value<string>()->default_value("stupid_bleu"), "scoring: bleu, stupid_*, smooth_*, approx_*")
- ("learning_rate", po::value<weight_t>()->default_value(0.0005), "learning rate")
+ ("sample_from", po::value<string>()->default_value("kbest"), "where to sample translations from: 'kbest', 'forest'")
+ ("filter", po::value<string>()->default_value("uniq"), "filter kbest list: 'not', 'uniq'")
+ ("pair_sampling", po::value<string>()->default_value("108010"), "how to sample pairs: 'all', '108010' or 'PRO'")
+ ("pair_threshold", po::value<score_t>()->default_value(0), "bleu [0,1] threshold to filter pairs")
+ ("N", po::value<unsigned>()->default_value(4), "N for Ngrams (BLEU)")
+ ("scorer", po::value<string>()->default_value("stupid_bleu"), "scoring: bleu, stupid_, smooth_, approx_")
+ ("learning_rate", po::value<weight_t>()->default_value(0.0001), "learning rate")
("gamma", po::value<weight_t>()->default_value(0), "gamma for SVM (0 for perceptron)")
("select_weights", po::value<string>()->default_value("last"), "output 'best' or 'last' weights ('VOID' to throw away)")
- ("unit_wv", po::value<bool>()->zero_tokens(), "Rescale weight vector after each input")
- ("l1_reg", po::value<string>()->default_value("no"), "apply l1 regularization as in Tsuroka et al 2010")
+ ("rescale", po::value<bool>()->zero_tokens(), "rescale weight vector after each input")
+ ("l1_reg", po::value<string>()->default_value("none"), "apply l1 regularization as in 'Tsuroka et al' (2010)")
("l1_reg_strength", po::value<weight_t>(), "l1 regularization strength")
- ("update_ok", po::value<bool>()->zero_tokens(), "include correctly ranked pairs into updates")
- ("stop_after", po::value<unsigned>()->default_value(0), "stop after X input sentences")
- ("keep_w", po::value<bool>()->zero_tokens(), "keep weights files for each iteration")
- ("print_weights", po::value<string>(), "weights to print on each iteration")
- ("hstreaming", po::value<string>(), "run in hadoop streaming mode, arg is a task id")
- ("tmp", po::value<string>()->default_value("/tmp"), "temp dir to use")
+ ("funny", po::value<bool>()->zero_tokens(), "include correctly ranked pairs into updates")
#ifdef DTRAIN_LOCAL
("refs,r", po::value<string>(), "references in local mode")
#endif
@@ -64,18 +65,22 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
cerr << "Wrong 'sample_from' param: '" << (*cfg)["sample_from"].as<string>() << "', use 'kbest' or 'forest'." << endl;
return false;
}
- if ((*cfg)["sample_from"].as<string>() == "kbest" && (*cfg)["filter"].as<string>() != "uniq"
- && (*cfg)["filter"].as<string>() != "no") {
- cerr << "Wrong 'filter' param: '" << (*cfg)["filter"].as<string>() << "', use 'uniq' or 'no'." << endl;
+ if ((*cfg)["sample_from"].as<string>() == "kbest" && (*cfg)["filter"].as<string>() != "uniq" &&
+ (*cfg)["filter"].as<string>() != "not") {
+ cerr << "Wrong 'filter' param: '" << (*cfg)["filter"].as<string>() << "', use 'uniq' or 'not'." << endl;
return false;
}
- string s = (*cfg)["pair_sampling"].as<string>();
- if (s != "all" && s != "5050" && s != "108010" && s != "PRO" && s != "alld" && s != "108010d") {
+ if ((*cfg)["pair_sampling"].as<string>() != "all" && (*cfg)["pair_sampling"].as<string>() != "108010" &&
+ (*cfg)["pair_sampling"].as<string>() != "PRO") {
cerr << "Wrong 'pair_sampling' param: '" << (*cfg)["pair_sampling"].as<string>() << "'." << endl;
return false;
}
- if ((*cfg)["select_weights"].as<string>() != "last"
- && (*cfg)["select_weights"].as<string>() != "best" && (*cfg)["select_weights"].as<string>() != "VOID") {
+ if ((*cfg)["pair_threshold"].as<score_t>() < 0) {
+ cerr << "The threshold must be >= 0!" << endl;
+ return false;
+ }
+ if ((*cfg)["select_weights"].as<string>() != "last" && (*cfg)["select_weights"].as<string>() != "best" &&
+ (*cfg)["select_weights"].as<string>() != "VOID") {
cerr << "Wrong 'select_weights' param: '" << (*cfg)["select_weights"].as<string>() << "', use 'last' or 'best'." << endl;
return false;
}
@@ -102,14 +107,14 @@ main(int argc, char** argv)
task_id = cfg["hstreaming"].as<string>();
cerr.precision(17);
}
- bool unit_wv = false;
- if (cfg.count("unit_wv")) unit_wv = true;
+ bool rescale = false;
+ if (cfg.count("rescale")) rescale = true;
HSReporter rep(task_id);
- bool keep_w = false;
- if (cfg.count("keep_w")) keep_w = true;
- bool update_ok = false;
- if (cfg.count("update_ok"))
- update_ok = true;
+ bool keep = false;
+ if (cfg.count("keep")) keep = true;
+ bool funny = false;
+ if (cfg.count("funny"))
+ funny = true;
const unsigned k = cfg["k"].as<unsigned>();
const unsigned N = cfg["N"].as<unsigned>();
@@ -118,6 +123,7 @@ main(int argc, char** argv)
const string filter_type = cfg["filter"].as<string>();
const string sample_from = cfg["sample_from"].as<string>();
const string pair_sampling = cfg["pair_sampling"].as<string>();
+ const score_t pair_threshold = cfg["pair_threshold"].as<score_t>();
const string select_weights = cfg["select_weights"].as<string>();
vector<string> print_weights;
if (cfg.count("print_weights"))
@@ -168,12 +174,13 @@ main(int argc, char** argv)
// meta params for perceptron, SVM
weight_t eta = cfg["learning_rate"].as<weight_t>();
weight_t gamma = cfg["gamma"].as<weight_t>();
+
// l1 regularization
bool l1naive = false;
bool l1clip = false;
bool l1cumul = false;
weight_t l1_reg = 0;
- if (cfg["l1_reg"].as<string>() != "no") {
+ if (cfg["l1_reg"].as<string>() != "none") {
string s = cfg["l1_reg"].as<string>();
if (s == "naive") l1naive = true;
else if (s == "clip") l1clip = true;
@@ -191,7 +198,7 @@ main(int argc, char** argv)
vector<vector<WordID> > ref_ids_buf; // references as WordID vecs
// where temp files go
string tmp_path = cfg["tmp"].as<string>();
- vector<string> w_tmp_files; // used for keep_w
+ vector<string> w_tmp_files; // used for keep
#ifdef DTRAIN_LOCAL
string refs_fn = cfg["refs"].as<string>();
ReadFile refs(refs_fn);
@@ -214,28 +221,30 @@ main(int argc, char** argv)
cerr << setw(25) << "k " << k << endl;
cerr << setw(25) << "N " << N << endl;
cerr << setw(25) << "T " << T << endl;
- if (cfg.count("stop-after"))
- cerr << setw(25) << "stop_after " << stop_after << endl;
- if (cfg.count("input_weights"))
- cerr << setw(25) << "weights in" << cfg["input_weights"].as<string>() << endl;
- cerr << setw(25) << "input " << "'" << input_fn << "'" << endl;
-#ifdef DTRAIN_LOCAL
- cerr << setw(25) << "refs " << "'" << refs_fn << "'" << endl;
-#endif
- cerr << setw(25) << "output " << "'" << output_fn << "'" << endl;
+ cerr << setw(25) << "sample from " << "'" << sample_from << "'" << endl;
if (sample_from == "kbest")
cerr << setw(25) << "filter " << "'" << filter_type << "'" << endl;
cerr << setw(25) << "learning rate " << eta << endl;
cerr << setw(25) << "gamma " << gamma << endl;
- cerr << setw(25) << "sample from " << "'" << sample_from << "'" << endl;
cerr << setw(25) << "pairs " << "'" << pair_sampling << "'" << endl;
+ cerr << setw(25) << "pair threshold " << pair_threshold << endl;
cerr << setw(25) << "select weights " << "'" << select_weights << "'" << endl;
if (cfg.count("l1_reg"))
cerr << setw(25) << "l1 reg " << l1_reg << " '" << cfg["l1_reg"].as<string>() << "'" << endl;
- if (update_ok)
- cerr << setw(25) << "up ok " << update_ok << endl;
- if (unit_wv)
- cerr << setw(25) << "unit weight vec " << unit_wv << endl;
+ if (funny)
+ cerr << setw(25) << "funny " << funny << endl;
+ if (rescale)
+ cerr << setw(25) << "rescale " << rescale << endl;
+ cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as<string>() << "'" << endl;
+ cerr << setw(25) << "input " << "'" << input_fn << "'" << endl;
+#ifdef DTRAIN_LOCAL
+ cerr << setw(25) << "refs " << "'" << refs_fn << "'" << endl;
+#endif
+ cerr << setw(25) << "output " << "'" << output_fn << "'" << endl;
+ if (cfg.count("input_weights"))
+ cerr << setw(25) << "weights in" << cfg["input_weights"].as<string>() << endl;
+ if (cfg.count("stop-after"))
+ cerr << setw(25) << "stop_after " << stop_after << endl;
if (!verbose) cerr << "(a dot represents " << DTRAIN_DOTS << " lines of input)" << endl;
}
@@ -382,17 +391,11 @@ main(int argc, char** argv)
if (!noup) {
vector<pair<ScoredHyp,ScoredHyp> > pairs;
if (pair_sampling == "all")
- all_pairs(samples, pairs);
- if (pair_sampling == "5050")
- rand_pairs_5050(samples, pairs, &rng);
+ all_pairs(samples, pairs, pair_threshold);
if (pair_sampling == "108010")
- multpart108010(samples, pairs);
+ part108010(samples, pairs, pair_threshold);
if (pair_sampling == "PRO")
PROsampling(samples, pairs);
- if (pair_sampling == "alld")
- all_pairs_discard(samples, pairs);
- if (pair_sampling == "108010d")
- multpart108010_discard(samples, pairs);
npairs += pairs.size();
for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin();
@@ -405,7 +408,7 @@ main(int argc, char** argv)
lambdas.plus_eq_v_times_s(diff_vec, eta);
rank_errors++;
} else {
- if (update_ok) {
+ if (funny) {
SparseVector<weight_t> diff_vec = it->first.f - it->second.f;
lambdas.plus_eq_v_times_s(diff_vec, eta);
}
@@ -429,6 +432,7 @@ main(int argc, char** argv)
// TEST THIS
// reset cumulative_penalties after 1 iter?
// do this only once per INPUT (not per pair)
+if (false) {
if (l1naive) {
for (unsigned d = 0; d < lambdas.size(); d++) {
weight_t v = lambdas.get(d);
@@ -462,9 +466,10 @@ main(int argc, char** argv)
}
}
}
+}
////////
- if (unit_wv && sample_from == "forest") lambdas /= lambdas.l2norm();
+ if (rescale) lambdas /= lambdas.l2norm();
++ii;
@@ -505,6 +510,9 @@ main(int argc, char** argv)
score_diff = score_avg;
model_diff = model_avg;
}
+
+ unsigned nonz;
+ if (!quiet || hstreaming) nonz = (unsigned)lambdas.size_nonzero();
if (!quiet) {
cerr << _p5 << _p << "WEIGHTS" << endl;
@@ -522,6 +530,8 @@ main(int argc, char** argv)
cerr << rank_errors/(float)in_sz << endl;
cerr << " avg #margin viol: ";
cerr << margin_violations/float(in_sz) << endl;
+ cerr << " non0 feature count: ";
+ cerr << nonz << endl;
}
if (hstreaming) {
@@ -530,7 +540,6 @@ main(int argc, char** argv)
rep.update_counter("Pairs avg #"+boost::lexical_cast<string>(t+1), (unsigned)((npairs/(weight_t)in_sz)*DTRAIN_SCALE));
rep.update_counter("Rank errors avg #"+boost::lexical_cast<string>(t+1), (unsigned)((rank_errors/(weight_t)in_sz)*DTRAIN_SCALE));
rep.update_counter("Margin violations avg #"+boost::lexical_cast<string>(t+1), (unsigned)((margin_violations/(weight_t)in_sz)*DTRAIN_SCALE));
- unsigned nonz = (unsigned)lambdas.size_nonzero();
rep.update_counter("Non zero feature count #"+boost::lexical_cast<string>(t+1), nonz);
rep.update_gcounter("Non zero feature count #"+boost::lexical_cast<string>(t+1), nonz);
}
@@ -555,7 +564,7 @@ main(int argc, char** argv)
if (noup) break;
// write weights to file
- if (select_weights == "best" || keep_w) {
+ if (select_weights == "best" || keep) {
lambdas.init_vector(&dense_weights);
string w_fn = "weights." + boost::lexical_cast<string>(t) + ".gz";
Weights::WriteToFile(w_fn, dense_weights, true);
@@ -589,7 +598,7 @@ main(int argc, char** argv)
cout << _np;
while(getline(*bestw, o)) cout << o << endl;
}
- if (!keep_w) {
+ if (!keep) {
for (unsigned i = 0; i < T; i++) {
string s = "weights." + boost::lexical_cast<string>(i) + ".gz";
unlink(s.c_str());
@@ -606,7 +615,7 @@ main(int argc, char** argv)
cerr << _p2 << "This took " << overall_time/60. << " min." << endl;
}
- if (keep_w) {
+ if (keep) {
cout << endl << "Weight files per iteration:" << endl;
for (unsigned i = 0; i < w_tmp_files.size(); i++) {
cout << w_tmp_files[i] << endl;