diff options
Diffstat (limited to 'training/dtrain')
| -rw-r--r-- | training/dtrain/dtrain.cc | 18 | 
1 files changed, 9 insertions, 9 deletions
| diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index 0b648d95..b180bc82 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -43,7 +43,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)      ("batch",             po::value<bool>()->zero_tokens(),                                               "do batch optimization")      ("repeat",            po::value<unsigned>()->default_value(1),     "repeat optimization over kbest list this number of times")      ("check",             po::value<bool>()->zero_tokens(),                                  "produce list of loss differentials") -    ("print_ranking",     po::value<bool>()->zero_tokens(),                            "output kbest with model score and metric") +    ("output_ranking",    po::value<string>()->default_value(""),                      "Output kbests with model scores and metric per iteration to this folder.")      ("noup",              po::value<bool>()->zero_tokens(),                                               "do not update weights");    po::options_description cl("Command Line Options");    cl.add_options() @@ -115,8 +115,6 @@ main(int argc, char** argv)    if (cfg.count("rescale")) rescale = true;    bool keep = false;    if (cfg.count("keep")) keep = true; -  bool print_ranking = false; -  if (cfg.count("print_ranking")) print_ranking = true;    const unsigned k = cfg["k"].as<unsigned>();    const unsigned N = cfg["N"].as<unsigned>(); @@ -127,6 +125,7 @@ main(int argc, char** argv)    const string pair_sampling = cfg["pair_sampling"].as<string>();    const score_t pair_threshold = cfg["pair_threshold"].as<score_t>();    const string select_weights = cfg["select_weights"].as<string>(); +  const string output_ranking = cfg["output_ranking"].as<string>();    const float hi_lo = cfg["hi_lo"].as<float>();    const score_t approx_bleu_d = cfg["approx_bleu_d"].as<score_t>();    const unsigned max_pairs = cfg["max_pairs"].as<unsigned>(); @@ -359,12 +358,15 @@ main(int argc, char** argv)      // get (scored) samples      vector<ScoredHyp>* samples = observer->GetSamples(); -    if (print_ranking) { +    if (output_ranking != "") { +      WriteFile of(output_ranking+"/"+to_string(t)+"."+to_string(ii)+".list"); // works with '-' +      stringstream ss;        for (auto s: *samples) { -        cout << ii << " ||| "; -        printWordIDVec(s.w, cout); -        cout << " ||| " << s.model << " ||| " << s.score << endl; +        ss << ii << " ||| "; +        printWordIDVec(s.w, ss); +        ss << " ||| " << s.model << " ||| " << s.score << endl;        } +      of.get() << ss.str();      }      if (verbose) { @@ -569,8 +571,6 @@ main(int argc, char** argv)    if (t == 0) in_sz = ii; // remember size of input (# lines) -  if (print_ranking) cout << "---" << endl; -    if (batch) {      lambdas.plus_eq_v_times_s(batch_updates, eta);      if (gamma) lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs)); | 
