summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--training/dtrain/dtrain.cc18
1 files changed, 9 insertions, 9 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index 0b648d95..b180bc82 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -43,7 +43,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
("batch", po::value<bool>()->zero_tokens(), "do batch optimization")
("repeat", po::value<unsigned>()->default_value(1), "repeat optimization over kbest list this number of times")
("check", po::value<bool>()->zero_tokens(), "produce list of loss differentials")
- ("print_ranking", po::value<bool>()->zero_tokens(), "output kbest with model score and metric")
+ ("output_ranking", po::value<string>()->default_value(""), "Output kbests with model scores and metric per iteration to this folder.")
("noup", po::value<bool>()->zero_tokens(), "do not update weights");
po::options_description cl("Command Line Options");
cl.add_options()
@@ -115,8 +115,6 @@ main(int argc, char** argv)
if (cfg.count("rescale")) rescale = true;
bool keep = false;
if (cfg.count("keep")) keep = true;
- bool print_ranking = false;
- if (cfg.count("print_ranking")) print_ranking = true;
const unsigned k = cfg["k"].as<unsigned>();
const unsigned N = cfg["N"].as<unsigned>();
@@ -127,6 +125,7 @@ main(int argc, char** argv)
const string pair_sampling = cfg["pair_sampling"].as<string>();
const score_t pair_threshold = cfg["pair_threshold"].as<score_t>();
const string select_weights = cfg["select_weights"].as<string>();
+ const string output_ranking = cfg["output_ranking"].as<string>();
const float hi_lo = cfg["hi_lo"].as<float>();
const score_t approx_bleu_d = cfg["approx_bleu_d"].as<score_t>();
const unsigned max_pairs = cfg["max_pairs"].as<unsigned>();
@@ -359,12 +358,15 @@ main(int argc, char** argv)
// get (scored) samples
vector<ScoredHyp>* samples = observer->GetSamples();
- if (print_ranking) {
+ if (output_ranking != "") {
+ WriteFile of(output_ranking+"/"+to_string(t)+"."+to_string(ii)+".list"); // works with '-'
+ stringstream ss;
for (auto s: *samples) {
- cout << ii << " ||| ";
- printWordIDVec(s.w, cout);
- cout << " ||| " << s.model << " ||| " << s.score << endl;
+ ss << ii << " ||| ";
+ printWordIDVec(s.w, ss);
+ ss << " ||| " << s.model << " ||| " << s.score << endl;
}
+ of.get() << ss.str();
}
if (verbose) {
@@ -569,8 +571,6 @@ main(int argc, char** argv)
if (t == 0) in_sz = ii; // remember size of input (# lines)
- if (print_ranking) cout << "---" << endl;
-
if (batch) {
lambdas.plus_eq_v_times_s(batch_updates, eta);
if (gamma) lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs));