diff options
Diffstat (limited to 'training/dtrain/dtrain.cc')
-rw-r--r-- | training/dtrain/dtrain.cc | 27 |
1 files changed, 25 insertions, 2 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index 737326f8..0b648d95 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -41,8 +41,9 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) ("max_pairs", po::value<unsigned>()->default_value(std::numeric_limits<unsigned>::max()), "max. # of pairs per Sent.") ("pclr", po::value<string>()->default_value("no"), "use a (simple|adagrad) per-coordinate learning rate") ("batch", po::value<bool>()->zero_tokens(), "do batch optimization") - ("repeat", po::value<unsigned>()->default_value(1), "repeat optimization over kbest list this number of times") + ("repeat", po::value<unsigned>()->default_value(1), "repeat optimization over kbest list this number of times") ("check", po::value<bool>()->zero_tokens(), "produce list of loss differentials") + ("print_ranking", po::value<bool>()->zero_tokens(), "output kbest with model score and metric") ("noup", po::value<bool>()->zero_tokens(), "do not update weights"); po::options_description cl("Command Line Options"); cl.add_options() @@ -71,7 +72,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) return false; } if ((*cfg)["pair_sampling"].as<string>() != "all" && (*cfg)["pair_sampling"].as<string>() != "XYX" && - (*cfg)["pair_sampling"].as<string>() != "PRO") { + (*cfg)["pair_sampling"].as<string>() != "PRO" && (*cfg)["pair_sampling"].as<string>() != "output_pairs") { cerr << "Wrong 'pair_sampling' param: '" << (*cfg)["pair_sampling"].as<string>() << "'." << endl; return false; } @@ -114,6 +115,8 @@ main(int argc, char** argv) if (cfg.count("rescale")) rescale = true; bool keep = false; if (cfg.count("keep")) keep = true; + bool print_ranking = false; + if (cfg.count("print_ranking")) print_ranking = true; const unsigned k = cfg["k"].as<unsigned>(); const unsigned N = cfg["N"].as<unsigned>(); @@ -356,6 +359,14 @@ main(int argc, char** argv) // get (scored) samples vector<ScoredHyp>* samples = observer->GetSamples(); + if (print_ranking) { + for (auto s: *samples) { + cout << ii << " ||| "; + printWordIDVec(s.w, cout); + cout << " ||| " << s.model << " ||| " << s.score << endl; + } + } + if (verbose) { cerr << "--- refs for " << ii << ": "; for (auto r: refs_as_ids_buf[ii]) { @@ -389,6 +400,8 @@ main(int argc, char** argv) partXYX(samples, pairs, pair_threshold, max_pairs, faster_perceptron, hi_lo); if (pair_sampling == "PRO") PROsampling(samples, pairs, pair_threshold, max_pairs); + if (pair_sampling == "output_pairs") + all_pairs(samples, pairs, pair_threshold, max_pairs, false); int cur_npairs = pairs.size(); npairs += cur_npairs; @@ -397,6 +410,15 @@ main(int argc, char** argv) if (check) repeat = 2; vector<float> losses; // for check + if (pair_sampling == "output_pairs") { + for (auto p: pairs) { + cout << p.first.model << " ||| " << p.first.score << " ||| " << p.first.f << endl; + cout << p.second.model << " ||| " << p.second.score << " ||| " << p.second.f << endl; + cout << endl; + } + continue; + } + for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin(); it != pairs.end(); it++) { score_t model_diff = it->first.model - it->second.model; @@ -547,6 +569,7 @@ main(int argc, char** argv) if (t == 0) in_sz = ii; // remember size of input (# lines) + if (print_ranking) cout << "---" << endl; if (batch) { lambdas.plus_eq_v_times_s(batch_updates, eta); |