summaryrefslogtreecommitdiff
path: root/training/dtrain/dtrain.cc
diff options
context:
space:
mode:
Diffstat (limited to 'training/dtrain/dtrain.cc')
-rw-r--r--training/dtrain/dtrain.cc42
1 files changed, 2 insertions, 40 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index 18addcb0..69630206 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -21,9 +21,7 @@ dtrain_init(int argc, char** argv, po::variables_map* conf)
("epochs", po::value<unsigned>()->default_value(10), "# of iterations T (per shard)")
("k", po::value<unsigned>()->default_value(100), "how many translations to sample")
("filter", po::value<string>()->default_value("uniq"), "filter kbest list: 'not', 'uniq'")
- ("pair_sampling", po::value<string>()->default_value("XYX"), "how to sample pairs: 'all', 'XYX' or 'PRO'")
("hi_lo", po::value<float>()->default_value(0.1), "hi and lo (X) for XYX (default 0.1), <= 0.5")
- ("pair_threshold", po::value<score_t>()->default_value(0.), "bleu [0,1] threshold to filter pairs")
("N", po::value<unsigned>()->default_value(4), "N for Ngrams (BLEU)")
("scorer", po::value<string>()->default_value("stupid_bleu"), "scoring: bleu, stupid_, smooth_, approx_, lc_")
("learning_rate", po::value<weight_t>()->default_value(1.0), "learning rate")
@@ -34,7 +32,6 @@ dtrain_init(int argc, char** argv, po::variables_map* conf)
("l1_reg_strength", po::value<weight_t>(), "l1 regularization strength")
("fselect", po::value<weight_t>()->default_value(-1), "select top x percent (or by threshold) of features after each epoch NOT IMPLEMENTED") // TODO
("loss_margin", po::value<weight_t>()->default_value(0.), "update if no error in pref pair but model scores this near")
- ("max_pairs", po::value<unsigned>()->default_value(std::numeric_limits<unsigned>::max()), "max. # of pairs per Sent.")
("pclr", po::value<string>()->default_value("no"), "use a (simple|adagrad) per-coordinate learning rate")
("batch", po::value<bool>()->zero_tokens(), "do batch optimization")
("repeat", po::value<unsigned>()->default_value(1), "repeat optimization over kbest list this number of times")
@@ -56,14 +53,6 @@ dtrain_init(int argc, char** argv, po::variables_map* conf)
cerr << cl << endl;
return false;
}
- if ((*conf)["pair_sampling"].as<string>() != "all" && (*conf)["pair_sampling"].as<string>() != "XYX" &&
- (*conf)["pair_sampling"].as<string>() != "PRO" && (*conf)["pair_sampling"].as<string>() != "output_pairs") {
- cerr << "Wrong 'pair_sampling' param: '" << (*conf)["pair_sampling"].as<string>() << "'." << endl;
- return false;
- }
- if (conf->count("hi_lo") && (*conf)["pair_sampling"].as<string>() != "XYX") {
- cerr << "Warning: hi_lo only works with pair_sampling XYX." << endl;
- }
if ((*conf)["hi_lo"].as<float>() > 0.5 || (*conf)["hi_lo"].as<float>() < 0.01) {
cerr << "hi_lo must lie in [0.01, 0.5]" << endl;
return false;
@@ -72,10 +61,6 @@ dtrain_init(int argc, char** argv, po::variables_map* conf)
cerr << "No training data given." << endl;
return false;
}
- if ((*conf)["pair_threshold"].as<score_t>() < 0) {
- cerr << "The threshold must be >= 0!" << endl;
- return false;
- }
if ((*conf)["select_weights"].as<string>() != "last" && (*conf)["select_weights"].as<string>() != "best" &&
(*conf)["select_weights"].as<string>() != "avg" && (*conf)["select_weights"].as<string>() != "VOID") {
cerr << "Wrong 'select_weights' param: '" << (*conf)["select_weights"].as<string>() << "', use 'last' or 'best'." << endl;
@@ -106,12 +91,9 @@ main(int argc, char** argv)
const unsigned N = conf["N"].as<unsigned>();
const unsigned T = conf["epochs"].as<unsigned>();
const unsigned stop_after = conf["stop_after"].as<unsigned>();
- const string pair_sampling = conf["pair_sampling"].as<string>();
- const score_t pair_threshold = conf["pair_threshold"].as<score_t>();
const string select_weights = conf["select_weights"].as<string>();
const string output_ranking = conf["output_ranking"].as<string>();
const float hi_lo = conf["hi_lo"].as<float>();
- const unsigned max_pairs = conf["max_pairs"].as<unsigned>();
int repeat = conf["repeat"].as<unsigned>();
weight_t loss_margin = conf["loss_margin"].as<weight_t>();
bool batch = false;
@@ -192,17 +174,13 @@ main(int argc, char** argv)
cerr << setw(25) << "gamma " << gamma << endl;
cerr << setw(25) << "loss margin " << loss_margin << endl;
cerr << setw(25) << "faster perceptron " << faster_perceptron << endl;
- cerr << setw(25) << "pairs " << "'" << pair_sampling << "'" << endl;
- if (pair_sampling == "XYX")
- cerr << setw(25) << "hi lo " << hi_lo << endl;
- cerr << setw(25) << "pair threshold " << pair_threshold << endl;
+ cerr << setw(25) << "hi lo " << hi_lo << endl;
cerr << setw(25) << "select weights " << "'" << select_weights << "'" << endl;
if (conf.count("l1_reg"))
cerr << setw(25) << "l1 reg " << l1_reg << " '" << conf["l1_reg"].as<string>() << "'" << endl;
if (rescale)
cerr << setw(25) << "rescale " << rescale << endl;
cerr << setw(25) << "pclr " << pclr << endl;
- cerr << setw(25) << "max pairs " << max_pairs << endl;
cerr << setw(25) << "repeat " << repeat << endl;
cerr << setw(25) << "cdec conf " << "'" << conf["decoder_config"].as<string>() << "'" << endl;
cerr << setw(25) << "input " << "'" << input_fn << "'" << endl;
@@ -335,28 +313,12 @@ main(int argc, char** argv)
if (!noup) {
// get pairs
vector<pair<ScoredHyp,ScoredHyp> > pairs;
- if (pair_sampling == "all")
- all_pairs(samples, pairs, pair_threshold, max_pairs, faster_perceptron);
- if (pair_sampling == "XYX")
- partXYX(samples, pairs, pair_threshold, max_pairs, faster_perceptron, hi_lo);
- if (pair_sampling == "PRO")
- PROsampling(samples, pairs, pair_threshold, max_pairs);
- if (pair_sampling == "output_pairs")
- all_pairs(samples, pairs, pair_threshold, max_pairs, false);
+ MakePairs(samples, pairs, faster_perceptron, hi_lo);
int cur_npairs = pairs.size();
npairs += cur_npairs;
score_t kbest_loss_first = 0.0, kbest_loss_last = 0.0;
- if (pair_sampling == "output_pairs") {
- for (auto p: pairs) {
- cout << p.first.model << " ||| " << p.first.score << " ||| " << p.first.f << endl;
- cout << p.second.model << " ||| " << p.second.score << " ||| " << p.second.f << endl;
- cout << endl;
- }
- continue;
- }
-
for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin();
it != pairs.end(); it++) {
if (rescale) {