diff options
Diffstat (limited to 'training/dtrain/dtrain.cc')
-rw-r--r-- | training/dtrain/dtrain.cc | 42 |
1 files changed, 2 insertions, 40 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index 18addcb0..69630206 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -21,9 +21,7 @@ dtrain_init(int argc, char** argv, po::variables_map* conf) ("epochs", po::value<unsigned>()->default_value(10), "# of iterations T (per shard)") ("k", po::value<unsigned>()->default_value(100), "how many translations to sample") ("filter", po::value<string>()->default_value("uniq"), "filter kbest list: 'not', 'uniq'") - ("pair_sampling", po::value<string>()->default_value("XYX"), "how to sample pairs: 'all', 'XYX' or 'PRO'") ("hi_lo", po::value<float>()->default_value(0.1), "hi and lo (X) for XYX (default 0.1), <= 0.5") - ("pair_threshold", po::value<score_t>()->default_value(0.), "bleu [0,1] threshold to filter pairs") ("N", po::value<unsigned>()->default_value(4), "N for Ngrams (BLEU)") ("scorer", po::value<string>()->default_value("stupid_bleu"), "scoring: bleu, stupid_, smooth_, approx_, lc_") ("learning_rate", po::value<weight_t>()->default_value(1.0), "learning rate") @@ -34,7 +32,6 @@ dtrain_init(int argc, char** argv, po::variables_map* conf) ("l1_reg_strength", po::value<weight_t>(), "l1 regularization strength") ("fselect", po::value<weight_t>()->default_value(-1), "select top x percent (or by threshold) of features after each epoch NOT IMPLEMENTED") // TODO ("loss_margin", po::value<weight_t>()->default_value(0.), "update if no error in pref pair but model scores this near") - ("max_pairs", po::value<unsigned>()->default_value(std::numeric_limits<unsigned>::max()), "max. # of pairs per Sent.") ("pclr", po::value<string>()->default_value("no"), "use a (simple|adagrad) per-coordinate learning rate") ("batch", po::value<bool>()->zero_tokens(), "do batch optimization") ("repeat", po::value<unsigned>()->default_value(1), "repeat optimization over kbest list this number of times") @@ -56,14 +53,6 @@ dtrain_init(int argc, char** argv, po::variables_map* conf) cerr << cl << endl; return false; } - if ((*conf)["pair_sampling"].as<string>() != "all" && (*conf)["pair_sampling"].as<string>() != "XYX" && - (*conf)["pair_sampling"].as<string>() != "PRO" && (*conf)["pair_sampling"].as<string>() != "output_pairs") { - cerr << "Wrong 'pair_sampling' param: '" << (*conf)["pair_sampling"].as<string>() << "'." << endl; - return false; - } - if (conf->count("hi_lo") && (*conf)["pair_sampling"].as<string>() != "XYX") { - cerr << "Warning: hi_lo only works with pair_sampling XYX." << endl; - } if ((*conf)["hi_lo"].as<float>() > 0.5 || (*conf)["hi_lo"].as<float>() < 0.01) { cerr << "hi_lo must lie in [0.01, 0.5]" << endl; return false; @@ -72,10 +61,6 @@ dtrain_init(int argc, char** argv, po::variables_map* conf) cerr << "No training data given." << endl; return false; } - if ((*conf)["pair_threshold"].as<score_t>() < 0) { - cerr << "The threshold must be >= 0!" << endl; - return false; - } if ((*conf)["select_weights"].as<string>() != "last" && (*conf)["select_weights"].as<string>() != "best" && (*conf)["select_weights"].as<string>() != "avg" && (*conf)["select_weights"].as<string>() != "VOID") { cerr << "Wrong 'select_weights' param: '" << (*conf)["select_weights"].as<string>() << "', use 'last' or 'best'." << endl; @@ -106,12 +91,9 @@ main(int argc, char** argv) const unsigned N = conf["N"].as<unsigned>(); const unsigned T = conf["epochs"].as<unsigned>(); const unsigned stop_after = conf["stop_after"].as<unsigned>(); - const string pair_sampling = conf["pair_sampling"].as<string>(); - const score_t pair_threshold = conf["pair_threshold"].as<score_t>(); const string select_weights = conf["select_weights"].as<string>(); const string output_ranking = conf["output_ranking"].as<string>(); const float hi_lo = conf["hi_lo"].as<float>(); - const unsigned max_pairs = conf["max_pairs"].as<unsigned>(); int repeat = conf["repeat"].as<unsigned>(); weight_t loss_margin = conf["loss_margin"].as<weight_t>(); bool batch = false; @@ -192,17 +174,13 @@ main(int argc, char** argv) cerr << setw(25) << "gamma " << gamma << endl; cerr << setw(25) << "loss margin " << loss_margin << endl; cerr << setw(25) << "faster perceptron " << faster_perceptron << endl; - cerr << setw(25) << "pairs " << "'" << pair_sampling << "'" << endl; - if (pair_sampling == "XYX") - cerr << setw(25) << "hi lo " << hi_lo << endl; - cerr << setw(25) << "pair threshold " << pair_threshold << endl; + cerr << setw(25) << "hi lo " << hi_lo << endl; cerr << setw(25) << "select weights " << "'" << select_weights << "'" << endl; if (conf.count("l1_reg")) cerr << setw(25) << "l1 reg " << l1_reg << " '" << conf["l1_reg"].as<string>() << "'" << endl; if (rescale) cerr << setw(25) << "rescale " << rescale << endl; cerr << setw(25) << "pclr " << pclr << endl; - cerr << setw(25) << "max pairs " << max_pairs << endl; cerr << setw(25) << "repeat " << repeat << endl; cerr << setw(25) << "cdec conf " << "'" << conf["decoder_config"].as<string>() << "'" << endl; cerr << setw(25) << "input " << "'" << input_fn << "'" << endl; @@ -335,28 +313,12 @@ main(int argc, char** argv) if (!noup) { // get pairs vector<pair<ScoredHyp,ScoredHyp> > pairs; - if (pair_sampling == "all") - all_pairs(samples, pairs, pair_threshold, max_pairs, faster_perceptron); - if (pair_sampling == "XYX") - partXYX(samples, pairs, pair_threshold, max_pairs, faster_perceptron, hi_lo); - if (pair_sampling == "PRO") - PROsampling(samples, pairs, pair_threshold, max_pairs); - if (pair_sampling == "output_pairs") - all_pairs(samples, pairs, pair_threshold, max_pairs, false); + MakePairs(samples, pairs, faster_perceptron, hi_lo); int cur_npairs = pairs.size(); npairs += cur_npairs; score_t kbest_loss_first = 0.0, kbest_loss_last = 0.0; - if (pair_sampling == "output_pairs") { - for (auto p: pairs) { - cout << p.first.model << " ||| " << p.first.score << " ||| " << p.first.f << endl; - cout << p.second.model << " ||| " << p.second.score << " ||| " << p.second.f << endl; - cout << endl; - } - continue; - } - for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin(); it != pairs.end(); it++) { if (rescale) { |