diff options
-rw-r--r-- | training/dtrain/dtrain.cc | 18 |
1 files changed, 9 insertions, 9 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index 0b648d95..b180bc82 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -43,7 +43,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) ("batch", po::value<bool>()->zero_tokens(), "do batch optimization") ("repeat", po::value<unsigned>()->default_value(1), "repeat optimization over kbest list this number of times") ("check", po::value<bool>()->zero_tokens(), "produce list of loss differentials") - ("print_ranking", po::value<bool>()->zero_tokens(), "output kbest with model score and metric") + ("output_ranking", po::value<string>()->default_value(""), "Output kbests with model scores and metric per iteration to this folder.") ("noup", po::value<bool>()->zero_tokens(), "do not update weights"); po::options_description cl("Command Line Options"); cl.add_options() @@ -115,8 +115,6 @@ main(int argc, char** argv) if (cfg.count("rescale")) rescale = true; bool keep = false; if (cfg.count("keep")) keep = true; - bool print_ranking = false; - if (cfg.count("print_ranking")) print_ranking = true; const unsigned k = cfg["k"].as<unsigned>(); const unsigned N = cfg["N"].as<unsigned>(); @@ -127,6 +125,7 @@ main(int argc, char** argv) const string pair_sampling = cfg["pair_sampling"].as<string>(); const score_t pair_threshold = cfg["pair_threshold"].as<score_t>(); const string select_weights = cfg["select_weights"].as<string>(); + const string output_ranking = cfg["output_ranking"].as<string>(); const float hi_lo = cfg["hi_lo"].as<float>(); const score_t approx_bleu_d = cfg["approx_bleu_d"].as<score_t>(); const unsigned max_pairs = cfg["max_pairs"].as<unsigned>(); @@ -359,12 +358,15 @@ main(int argc, char** argv) // get (scored) samples vector<ScoredHyp>* samples = observer->GetSamples(); - if (print_ranking) { + if (output_ranking != "") { + WriteFile of(output_ranking+"/"+to_string(t)+"."+to_string(ii)+".list"); // works with '-' + stringstream ss; for (auto s: *samples) { - cout << ii << " ||| "; - printWordIDVec(s.w, cout); - cout << " ||| " << s.model << " ||| " << s.score << endl; + ss << ii << " ||| "; + printWordIDVec(s.w, ss); + ss << " ||| " << s.model << " ||| " << s.score << endl; } + of.get() << ss.str(); } if (verbose) { @@ -569,8 +571,6 @@ main(int argc, char** argv) if (t == 0) in_sz = ii; // remember size of input (# lines) - if (print_ranking) cout << "---" << endl; - if (batch) { lambdas.plus_eq_v_times_s(batch_updates, eta); if (gamma) lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs)); |