summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Patrick Simianer <p@simianer.de> 2015-01-29 19:50:27 +0100
committer: Patrick Simianer <p@simianer.de> 2015-01-29 19:50:27 +0100
commit: 3b371a950f9ae04c9072f3df9b21dafa475916fa (patch)
tree: 25fabc78bfab76ce746967c76aea5bc85f8c69d9
parent: 41eb8edaf5965b8efbe0ace199905927452e895d (diff)
dtrain: output_pairs, print_ranking
-rw-r--r-- training/dtrain/dtrain.cc 27
-rw-r--r-- training/dtrain/score.cc 2
2 files changed, 26 insertions, 3 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index 737326f8..0b648d95 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -41,8 +41,9 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
("max_pairs", po::value<unsigned>()->default_value(std::numeric_limits<unsigned>::max()), "max. # of pairs per Sent.")
("pclr", po::value<string>()->default_value("no"), "use a (simple|adagrad) per-coordinate learning rate")
("batch", po::value<bool>()->zero_tokens(), "do batch optimization")
- ("repeat", po::value<unsigned>()->default_value(1), "repeat optimization over kbest list this number of times")
+ ("repeat", po::value<unsigned>()->default_value(1), "repeat optimization over kbest list this number of times")
("check", po::value<bool>()->zero_tokens(), "produce list of loss differentials")
+ ("print_ranking", po::value<bool>()->zero_tokens(), "output kbest with model score and metric")
("noup", po::value<bool>()->zero_tokens(), "do not update weights");
po::options_description cl("Command Line Options");
cl.add_options()
@@ -71,7 +72,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
return false;
}
if ((*cfg)["pair_sampling"].as<string>() != "all" && (*cfg)["pair_sampling"].as<string>() != "XYX" &&
- (*cfg)["pair_sampling"].as<string>() != "PRO") {
+ (*cfg)["pair_sampling"].as<string>() != "PRO" && (*cfg)["pair_sampling"].as<string>() != "output_pairs") {
cerr << "Wrong 'pair_sampling' param: '" << (*cfg)["pair_sampling"].as<string>() << "'." << endl;
return false;
}
@@ -114,6 +115,8 @@ main(int argc, char** argv)
if (cfg.count("rescale")) rescale = true;
bool keep = false;
if (cfg.count("keep")) keep = true;
+ bool print_ranking = false;
+ if (cfg.count("print_ranking")) print_ranking = true;
const unsigned k = cfg["k"].as<unsigned>();
const unsigned N = cfg["N"].as<unsigned>();
@@ -356,6 +359,14 @@ main(int argc, char** argv)
// get (scored) samples
vector<ScoredHyp>* samples = observer->GetSamples();
+ if (print_ranking) {
+ for (auto s: *samples) {
+ cout << ii << " ||| ";
+ printWordIDVec(s.w, cout);
+ cout << " ||| " << s.model << " ||| " << s.score << endl;
+ }
+ }
+
if (verbose) {
cerr << "--- refs for " << ii << ": ";
for (auto r: refs_as_ids_buf[ii]) {
@@ -389,6 +400,8 @@ main(int argc, char** argv)
partXYX(samples, pairs, pair_threshold, max_pairs, faster_perceptron, hi_lo);
if (pair_sampling == "PRO")
PROsampling(samples, pairs, pair_threshold, max_pairs);
+ if (pair_sampling == "output_pairs")
+ all_pairs(samples, pairs, pair_threshold, max_pairs, false);
int cur_npairs = pairs.size();
npairs += cur_npairs;
@@ -397,6 +410,15 @@ main(int argc, char** argv)
if (check) repeat = 2;
vector<float> losses; // for check
+ if (pair_sampling == "output_pairs") {
+ for (auto p: pairs) {
+ cout << p.first.model << " ||| " << p.first.score << " ||| " << p.first.f << endl;
+ cout << p.second.model << " ||| " << p.second.score << " ||| " << p.second.f << endl;
+ cout << endl;
+ }
+ continue;
+ }
+
for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin();
it != pairs.end(); it++) {
score_t model_diff = it->first.model - it->second.model;
@@ -547,6 +569,7 @@ main(int argc, char** argv)
if (t == 0) in_sz = ii; // remember size of input (# lines)
+ if (print_ranking) cout << "---" << endl;
if (batch) {
lambdas.plus_eq_v_times_s(batch_updates, eta);
diff --git a/training/dtrain/score.cc b/training/dtrain/score.cc
index d81eafcb..8a28771f 100644
--- a/training/dtrain/score.cc
+++ b/training/dtrain/score.cc
@@ -36,7 +36,7 @@ RefLen(vector<vector<WordID> > refs)
{
size_t ref_len = 0;
for (auto r: refs)
- ref_len = max(ref_len, r.size());
+ ref_len = max(ref_len, r.size()); // FIXME
return ref_len;
}