From 160dbdfa96ae57df82bc33475578904e2cd23317 Mon Sep 17 00:00:00 2001
From: Patrick Simianer
Date: Sun, 1 Feb 2015 22:32:40 +0100
Subject: dtrain: simplified pair generation

---
 training/dtrain/dtrain.cc |  42 +----------------
 training/dtrain/pairs.h   | 114 ++++++----------------------------------------
 2 files changed, 16 insertions(+), 140 deletions(-)

diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index 18addcb0..69630206 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -21,9 +21,7 @@ dtrain_init(int argc, char** argv, po::variables_map* conf)
     ("epochs", po::value<unsigned>()->default_value(10), "# of iterations T (per shard)")
     ("k", po::value<unsigned>()->default_value(100), "how many translations to sample")
     ("filter", po::value<string>()->default_value("uniq"), "filter kbest list: 'not', 'uniq'")
-    ("pair_sampling", po::value<string>()->default_value("XYX"), "how to sample pairs: 'all', 'XYX' or 'PRO'")
     ("hi_lo", po::value<float>()->default_value(0.1), "hi and lo (X) for XYX (default 0.1), <= 0.5")
-    ("pair_threshold", po::value<score_t>()->default_value(0.), "bleu [0,1] threshold to filter pairs")
     ("N", po::value<unsigned>()->default_value(4), "N for Ngrams (BLEU)")
     ("scorer", po::value<string>()->default_value("stupid_bleu"), "scoring: bleu, stupid_, smooth_, approx_, lc_")
     ("learning_rate", po::value<weight_t>()->default_value(1.0), "learning rate")
@@ -34,7 +32,6 @@ dtrain_init(int argc, char** argv, po::variables_map* conf)
     ("l1_reg_strength", po::value<weight_t>(), "l1 regularization strength")
     ("fselect", po::value<weight_t>()->default_value(-1), "select top x percent (or by threshold) of features after each epoch NOT IMPLEMENTED") // TODO
     ("loss_margin", po::value<weight_t>()->default_value(0.), "update if no error in pref pair but model scores this near")
-    ("max_pairs", po::value<unsigned>()->default_value(std::numeric_limits<unsigned>::max()), "max. # of pairs per Sent.")
     ("pclr", po::value<string>()->default_value("no"), "use a (simple|adagrad) per-coordinate learning rate")
     ("batch", po::value<bool>()->zero_tokens(), "do batch optimization")
     ("repeat", po::value<unsigned>()->default_value(1), "repeat optimization over kbest list this number of times")
@@ -56,14 +53,6 @@ dtrain_init(int argc, char** argv, po::variables_map* conf)
     cerr << cl << endl;
     return false;
   }
-  if ((*conf)["pair_sampling"].as<string>() != "all" && (*conf)["pair_sampling"].as<string>() != "XYX" &&
-        (*conf)["pair_sampling"].as<string>() != "PRO" && (*conf)["pair_sampling"].as<string>() != "output_pairs") {
-    cerr << "Wrong 'pair_sampling' param: '" << (*conf)["pair_sampling"].as<string>() << "'." << endl;
-    return false;
-  }
-  if (conf->count("hi_lo") && (*conf)["pair_sampling"].as<string>() != "XYX") {
-    cerr << "Warning: hi_lo only works with pair_sampling XYX." << endl;
-  }
   if ((*conf)["hi_lo"].as<float>() > 0.5 || (*conf)["hi_lo"].as<float>() < 0.01) {
     cerr << "hi_lo must lie in [0.01, 0.5]" << endl;
     return false;
@@ -72,10 +61,6 @@
     cerr << "No training data given." << endl;
     return false;
   }
-  if ((*conf)["pair_threshold"].as<score_t>() < 0) {
-    cerr << "The threshold must be >= 0!" << endl;
-    return false;
-  }
   if ((*conf)["select_weights"].as<string>() != "last" && (*conf)["select_weights"].as<string>() != "best" &&
       (*conf)["select_weights"].as<string>() != "avg" && (*conf)["select_weights"].as<string>() != "VOID") {
     cerr << "Wrong 'select_weights' param: '" << (*conf)["select_weights"].as<string>() << "', use 'last' or 'best'." << endl;
@@ -106,12 +91,9 @@ main(int argc, char** argv)
   const unsigned N = conf["N"].as<unsigned>();
   const unsigned T = conf["epochs"].as<unsigned>();
   const unsigned stop_after = conf["stop_after"].as<unsigned>();
-  const string pair_sampling = conf["pair_sampling"].as<string>();
-  const score_t pair_threshold = conf["pair_threshold"].as<score_t>();
   const string select_weights = conf["select_weights"].as<string>();
   const string output_ranking = conf["output_ranking"].as<string>();
   const float hi_lo = conf["hi_lo"].as<float>();
-  const unsigned max_pairs = conf["max_pairs"].as<unsigned>();
   int repeat = conf["repeat"].as<unsigned>();
   weight_t loss_margin = conf["loss_margin"].as<weight_t>();
   bool batch = false;
@@ -192,17 +174,13 @@ main(int argc, char** argv)
   cerr << setw(25) << "gamma " << gamma << endl;
   cerr << setw(25) << "loss margin " << loss_margin << endl;
   cerr << setw(25) << "faster perceptron " << faster_perceptron << endl;
-  cerr << setw(25) << "pairs " << "'" << pair_sampling << "'" << endl;
-  if (pair_sampling == "XYX")
-    cerr << setw(25) << "hi lo " << hi_lo << endl;
-  cerr << setw(25) << "pair threshold " << pair_threshold << endl;
+  cerr << setw(25) << "hi lo " << hi_lo << endl;
   cerr << setw(25) << "select weights " << "'" << select_weights << "'" << endl;
   if (conf.count("l1_reg"))
     cerr << setw(25) << "l1 reg " << l1_reg << " '" << conf["l1_reg"].as<string>() << "'" << endl;
   if (rescale)
     cerr << setw(25) << "rescale " << rescale << endl;
   cerr << setw(25) << "pclr " << pclr << endl;
-  cerr << setw(25) << "max pairs " << max_pairs << endl;
   cerr << setw(25) << "repeat " << repeat << endl;
   cerr << setw(25) << "cdec conf " << "'" << conf["decoder_config"].as<string>() << "'" << endl;
   cerr << setw(25) << "input " << "'" << input_fn << "'" << endl;
@@ -335,28 +313,12 @@ main(int argc, char** argv)
       if (!noup) {
         // get pairs
         vector<pair<ScoredHyp,ScoredHyp> > pairs;
-        if (pair_sampling == "all")
-          all_pairs(samples, pairs, pair_threshold, max_pairs, faster_perceptron);
-        if (pair_sampling == "XYX")
-          partXYX(samples, pairs, pair_threshold, max_pairs, faster_perceptron, hi_lo);
-        if (pair_sampling == "PRO")
-          PROsampling(samples, pairs, pair_threshold, max_pairs);
-        if (pair_sampling == "output_pairs")
-          all_pairs(samples, pairs, pair_threshold, max_pairs, false);
+        MakePairs(samples, pairs, faster_perceptron, hi_lo);
         int cur_npairs = pairs.size();
         npairs += cur_npairs;

         score_t kbest_loss_first = 0.0, kbest_loss_last = 0.0;

-        if (pair_sampling == "output_pairs") {
-          for (auto p: pairs) {
-            cout << p.first.model << " ||| " << p.first.score << " ||| " << p.first.f << endl;
-            cout << p.second.model << " ||| " << p.second.score << " ||| " << p.second.f << endl;
-            cout << endl;
-          }
-          continue;
-        }
-
         for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin(); it != pairs.end(); it++) {

           if (rescale) {
diff --git a/training/dtrain/pairs.h b/training/dtrain/pairs.h
index fd08be8c..dea0dabc 100644
--- a/training/dtrain/pairs.h
+++ b/training/dtrain/pairs.h
@@ -1,140 +1,54 @@
-#ifndef _DTRAIN_PAIRSAMPLING_H_
-#define _DTRAIN_PAIRSAMPLING_H_
+#ifndef _DTRAIN_PAIRS_H_
+#define _DTRAIN_PAIRS_H_

 namespace dtrain
 {
-
 bool
-accept_pair(score_t a, score_t b, score_t threshold)
-{
-  if (fabs(a - b) < threshold) return false;
-  return true;
-}
-
-bool
-cmp_hyp_by_score_d(ScoredHyp a, ScoredHyp b)
+CmpHypsByScore(ScoredHyp a, ScoredHyp b)
 {
   return a.score > b.score;
 }

-inline void
-all_pairs(vector<ScoredHyp>* s, vector<pair<ScoredHyp,ScoredHyp> >& training, score_t threshold, unsigned max, bool misranked_only, float _unused=1)
-{
-  sort(s->begin(), s->end(), cmp_hyp_by_score_d);
-  unsigned sz = s->size();
-  bool b = false;
-  unsigned count = 0;
-  for (unsigned i = 0; i < sz-1; i++) {
-    for (unsigned j = i+1; j < sz; j++) {
-      if (misranked_only && !((*s)[i].model <= (*s)[j].model)) continue;
-      if (threshold > 0) {
-        if (accept_pair((*s)[i].score, (*s)[j].score, threshold))
-          training.push_back(make_pair((*s)[i], (*s)[j]));
-      } else {
-        if ((*s)[i].score != (*s)[j].score)
-          training.push_back(make_pair((*s)[i], (*s)[j]));
-      }
-      if (++count == max) {
-        b = true;
-        break;
-      }
-    }
-    if (b) break;
-  }
-}
-
 /*
  * multipartite ranking
  * sort (descending) by bleu
- * compare top X to middle Y and low X
+ * compare top X (hi) to middle Y (med) and low X (lo)
  * cmp middle Y to low X
  */
-
 inline void
-partXYX(vector<ScoredHyp>* s, vector<pair<ScoredHyp,ScoredHyp> >& training, score_t threshold, unsigned max, bool misranked_only, float hi_lo)
+MakePairs(vector<ScoredHyp>* s,
+          vector<pair<ScoredHyp,ScoredHyp> >& training,
+          bool misranked_only,
+          float hi_lo)
 {
   unsigned sz = s->size();
   if (sz < 2) return;
-  sort(s->begin(), s->end(), cmp_hyp_by_score_d);
+  sort(s->begin(), s->end(), CmpHypsByScore);
   unsigned sep = round(sz*hi_lo);
+  // hi vs. med vs. low
   unsigned sep_hi = sep;
   if (sz > 4)
     while (sep_hi < sz && (*s)[sep_hi-1].score == (*s)[sep_hi].score) ++sep_hi;
   else
     sep_hi = 1;
-  bool b = false;
-  unsigned count = 0;
   for (unsigned i = 0; i < sep_hi; i++) {
     for (unsigned j = sep_hi; j < sz; j++) {
       if (misranked_only && !((*s)[i].model <= (*s)[j].model)) continue;
-      if (threshold > 0) {
-        if (accept_pair((*s)[i].score, (*s)[j].score, threshold))
-          training.push_back(make_pair((*s)[i], (*s)[j]));
-      } else {
-        if ((*s)[i].score != (*s)[j].score)
-          training.push_back(make_pair((*s)[i], (*s)[j]));
-      }
-      if (++count == max) {
-        b = true;
-        break;
-      }
+      if ((*s)[i].score != (*s)[j].score)
+        training.push_back(make_pair((*s)[i], (*s)[j]));
     }
-    if (b) break;
   }
+  // med vs. low
   unsigned sep_lo = sz-sep;
   while (sep_lo > 0 && (*s)[sep_lo-1].score == (*s)[sep_lo].score) --sep_lo;
   for (unsigned i = sep_hi; i < sep_lo; i++) {
     for (unsigned j = sep_lo; j < sz; j++) {
       if (misranked_only && !((*s)[i].model <= (*s)[j].model)) continue;
-      if (threshold > 0) {
-        if (accept_pair((*s)[i].score, (*s)[j].score, threshold))
-          training.push_back(make_pair((*s)[i], (*s)[j]));
-      } else {
-        if ((*s)[i].score != (*s)[j].score)
-          training.push_back(make_pair((*s)[i], (*s)[j]));
-      }
-      if (++count == max) return;
-    }
-  }
-}
-
-/*
- * pair sampling as in
- * 'Tuning as Ranking' (Hopkins & May, 2011)
- * count = max (5000)
- * threshold = 5% BLEU (0.05 for param 3)
- * cut = top 10%
- */
-bool
-_PRO_cmp_pair_by_diff_d(pair<ScoredHyp,ScoredHyp> a, pair<ScoredHyp,ScoredHyp> b)
-{
-  return (fabs(a.first.score - a.second.score)) > (fabs(b.first.score - b.second.score));
-}
-inline void
-PROsampling(vector<ScoredHyp>* s, vector<pair<ScoredHyp,ScoredHyp> >& training, score_t threshold, unsigned max, bool _unused=false, float _also_unused=0)
-{
-  sort(s->begin(), s->end(), cmp_hyp_by_score_d);
-  unsigned max_count = max, count = 0, sz = s->size();
-  bool b = false;
-  for (unsigned i = 0; i < sz-1; i++) {
-    for (unsigned j = i+1; j < sz; j++) {
-      if (accept_pair((*s)[i].score, (*s)[j].score, threshold)) {
+      if ((*s)[i].score != (*s)[j].score)
         training.push_back(make_pair((*s)[i], (*s)[j]));
-        if (++count == max_count) {
-          b = true;
-          break;
-        }
-      }
     }
-    if (b) break;
-  }
-  if (training.size() > max/10) {
-    sort(training.begin(), training.end(), _PRO_cmp_pair_by_diff_d);
-    training.erase(training.begin()+(max/10), training.end());
   }
-  return;
 }
-
 } // namespace

 #endif
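Note: after this patch, multipartite (XYX) ranking is the only remaining pair generation strategy. The program below is a small self-contained sketch of that logic for reference, not code from the repository: ScoredHyp is reduced to the two fields MakePairs actually reads, plain doubles stand in for dtrain's score_t/weight_t typedefs, and the k-best scores in main() are made up. With hi_lo = 0.1 on a 10-best list it pairs the top hypothesis against the nine below it and the middle eight against the lowest one, skipping BLEU ties and, when misranked_only is set, pairs the model already ranks correctly.

// Standalone sketch (not repository code) of the multipartite pair
// generation kept by this patch. ScoredHyp is a minimal stand-in with
// only the fields MakePairs touches; the real struct lives in training/dtrain/.
#include <algorithm>
#include <cmath>
#include <iostream>
#include <utility>
#include <vector>

using namespace std;

struct ScoredHyp {
  double model;  // model score of the hypothesis
  double score;  // (approximate) BLEU
};

bool CmpHypsByScore(ScoredHyp a, ScoredHyp b) { return a.score > b.score; }

void
MakePairs(vector<ScoredHyp>* s, vector<pair<ScoredHyp,ScoredHyp> >& training,
          bool misranked_only, float hi_lo)
{
  unsigned sz = s->size();
  if (sz < 2) return;
  sort(s->begin(), s->end(), CmpHypsByScore);
  unsigned sep = round(sz*hi_lo);
  // hi vs. med and low: pair each of the top X% with everything below it
  unsigned sep_hi = sep;
  if (sz > 4)
    while (sep_hi < sz && (*s)[sep_hi-1].score == (*s)[sep_hi].score) ++sep_hi;
  else
    sep_hi = 1;
  for (unsigned i = 0; i < sep_hi; i++)
    for (unsigned j = sep_hi; j < sz; j++) {
      if (misranked_only && !((*s)[i].model <= (*s)[j].model)) continue;
      if ((*s)[i].score != (*s)[j].score)
        training.push_back(make_pair((*s)[i], (*s)[j]));
    }
  // med vs. low: pair the middle part with the bottom X%
  unsigned sep_lo = sz-sep;
  while (sep_lo > 0 && (*s)[sep_lo-1].score == (*s)[sep_lo].score) --sep_lo;
  for (unsigned i = sep_hi; i < sep_lo; i++)
    for (unsigned j = sep_lo; j < sz; j++) {
      if (misranked_only && !((*s)[i].model <= (*s)[j].model)) continue;
      if ((*s)[i].score != (*s)[j].score)
        training.push_back(make_pair((*s)[i], (*s)[j]));
    }
}

int main()
{
  // 10 hypotheses with made-up scores: the model prefers exactly the
  // translations that BLEU dislikes, so no pair is filtered out.
  vector<ScoredHyp> kbest;
  for (int i = 0; i < 10; i++) {
    ScoredHyp h;
    h.model = (10 - i) * 0.5;
    h.score = i * 0.1;
    kbest.push_back(h);
  }
  vector<pair<ScoredHyp,ScoredHyp> > pairs;
  MakePairs(&kbest, pairs, /*misranked_only=*/false, /*hi_lo=*/0.1);
  cout << pairs.size() << " pairs" << endl;  // 9 (hi vs. rest) + 8 (med vs. low)
  for (size_t i = 0; i < pairs.size(); i++)
    cout << pairs[i].first.score << " > " << pairs[i].second.score
         << "  (model " << pairs[i].first.model
         << " vs. " << pairs[i].second.model << ")" << endl;
  return 0;
}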