author     Patrick Simianer <p@simianer.de>    2015-01-29 19:50:27 +0100
committer  Patrick Simianer <p@simianer.de>    2015-01-29 19:50:27 +0100
commit     3b371a950f9ae04c9072f3df9b21dafa475916fa (patch)
tree       25fabc78bfab76ce746967c76aea5bc85f8c69d9 /training
parent     41eb8edaf5965b8efbe0ace199905927452e895d (diff)
dtrain: output_pairs, print_ranking
Diffstat (limited to 'training')
-rw-r--r--  training/dtrain/dtrain.cc  27
-rw-r--r--  training/dtrain/score.cc    2
2 files changed, 26 insertions, 3 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index 737326f8..0b648d95 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -41,8 +41,9 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
     ("max_pairs", po::value<unsigned>()->default_value(std::numeric_limits<unsigned>::max()), "max. # of pairs per Sent.")
     ("pclr", po::value<string>()->default_value("no"), "use a (simple|adagrad) per-coordinate learning rate")
     ("batch", po::value<bool>()->zero_tokens(), "do batch optimization")
-    ("repeat", po::value<unsigned>()->default_value(1), "repeat optimization over kbest list this number of times")
+    ("repeat", po::value<unsigned>()->default_value(1), "repeat optimization over kbest list this number of times")
     ("check", po::value<bool>()->zero_tokens(), "produce list of loss differentials")
+    ("print_ranking", po::value<bool>()->zero_tokens(), "output kbest with model score and metric")
     ("noup", po::value<bool>()->zero_tokens(), "do not update weights");
   po::options_description cl("Command Line Options");
   cl.add_options()
@@ -71,7 +72,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
     return false;
   }
   if ((*cfg)["pair_sampling"].as<string>() != "all" && (*cfg)["pair_sampling"].as<string>() != "XYX" &&
-        (*cfg)["pair_sampling"].as<string>() != "PRO") {
+        (*cfg)["pair_sampling"].as<string>() != "PRO" && (*cfg)["pair_sampling"].as<string>() != "output_pairs") {
     cerr << "Wrong 'pair_sampling' param: '" << (*cfg)["pair_sampling"].as<string>() << "'." << endl;
     return false;
   }
@@ -114,6 +115,8 @@ main(int argc, char** argv)
   if (cfg.count("rescale")) rescale = true;
   bool keep = false;
   if (cfg.count("keep")) keep = true;
+  bool print_ranking = false;
+  if (cfg.count("print_ranking")) print_ranking = true;

   const unsigned k = cfg["k"].as<unsigned>();
   const unsigned N = cfg["N"].as<unsigned>();
@@ -356,6 +359,14 @@ main(int argc, char** argv)
       // get (scored) samples
       vector<ScoredHyp>* samples = observer->GetSamples();

+      if (print_ranking) {
+        for (auto s: *samples) {
+          cout << ii << " ||| ";
+          printWordIDVec(s.w, cout);
+          cout << " ||| " << s.model << " ||| " << s.score << endl;
+        }
+      }
+
       if (verbose) {
         cerr << "--- refs for " << ii << ": ";
         for (auto r: refs_as_ids_buf[ii]) {
@@ -389,6 +400,8 @@ main(int argc, char** argv)
         partXYX(samples, pairs, pair_threshold, max_pairs, faster_perceptron, hi_lo);
       if (pair_sampling == "PRO")
         PROsampling(samples, pairs, pair_threshold, max_pairs);
+      if (pair_sampling == "output_pairs")
+        all_pairs(samples, pairs, pair_threshold, max_pairs, false);
       int cur_npairs = pairs.size();
       npairs += cur_npairs;
@@ -397,6 +410,15 @@ main(int argc, char** argv)
       if (check) repeat = 2;
       vector<float> losses; // for check

+      if (pair_sampling == "output_pairs") {
+        for (auto p: pairs) {
+          cout << p.first.model << " ||| " << p.first.score << " ||| " << p.first.f << endl;
+          cout << p.second.model << " ||| " << p.second.score << " ||| " << p.second.f << endl;
+          cout << endl;
+        }
+        continue;
+      }
+
       for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin();
            it != pairs.end(); it++) {
         score_t model_diff = it->first.model - it->second.model;
@@ -547,6 +569,7 @@ main(int argc, char** argv)
     if (t == 0)
       in_sz = ii; // remember size of input (# lines)

+    if (print_ranking) cout << "---" << endl;

     if (batch) {
       lambdas.plus_eq_v_times_s(batch_updates, eta);
diff --git a/training/dtrain/score.cc b/training/dtrain/score.cc
index d81eafcb..8a28771f 100644
--- a/training/dtrain/score.cc
+++ b/training/dtrain/score.cc
@@ -36,7 +36,7 @@
 RefLen(vector<vector<WordID> > refs)
 {
   size_t ref_len = 0;
   for (auto r: refs)
-    ref_len = max(ref_len, r.size());
+    ref_len = max(ref_len, r.size()); // FIXME
   return ref_len;
 }
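
For reference, below is a minimal, self-contained C++ sketch (not part of the commit) of how the per-hypothesis lines written by the new --print_ranking option could be consumed downstream. It assumes the "id ||| hypothesis ||| model ||| metric" format and the "---" separator emitted by the added code above; the RankedHyp struct and SplitFields helper are only illustrative, not dtrain API.

// consume_ranking.cc: read --print_ranking output from stdin (hypothetical helper).
#include <iostream>
#include <string>
#include <vector>

struct RankedHyp {
  unsigned id;        // input sentence index (ii in dtrain.cc)
  std::string hyp;    // hypothesis surface string
  double model;       // model score
  double metric;      // per-sentence metric score
};

// Split a line on the " ||| " separator used by dtrain's output.
static std::vector<std::string> SplitFields(const std::string& line) {
  std::vector<std::string> fields;
  const std::string sep = " ||| ";
  size_t start = 0, end;
  while ((end = line.find(sep, start)) != std::string::npos) {
    fields.push_back(line.substr(start, end - start));
    start = end + sep.size();
  }
  fields.push_back(line.substr(start));
  return fields;
}

int main() {
  std::string line;
  std::vector<RankedHyp> ranking;
  while (std::getline(std::cin, line)) {
    if (line == "---") continue;           // per-iteration separator
    std::vector<std::string> f = SplitFields(line);
    if (f.size() != 4) continue;           // skip non-ranking lines
    RankedHyp h;
    h.id = static_cast<unsigned>(std::stoul(f[0]));
    h.hyp = f[1];
    h.model = std::stod(f[2]);
    h.metric = std::stod(f[3]);
    ranking.push_back(h);
  }
  std::cerr << "read " << ranking.size() << " hypotheses" << std::endl;
  return 0;
}

Compiled with any C++11 compiler, this can be fed directly from dtrain's stdout; the pairs emitted by pair_sampling=output_pairs ("model ||| metric ||| features", two lines per pair plus a blank line) would need a similar but separate reader.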