Diffstat (limited to 'training/dtrain')
-rw-r--r--  training/dtrain/dtrain.cc  27
-rw-r--r--  training/dtrain/score.cc    2
2 files changed, 26 insertions, 3 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index 737326f8..0b648d95 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -41,8 +41,9 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
     ("max_pairs",         po::value<unsigned>()->default_value(std::numeric_limits<unsigned>::max()), "max. # of pairs per Sent.")
     ("pclr",              po::value<string>()->default_value("no"),         "use a (simple|adagrad) per-coordinate learning rate")
     ("batch",             po::value<bool>()->zero_tokens(),                                               "do batch optimization")
-    ("repeat",            po::value<unsigned>()->default_value(1),          "repeat optimization over kbest list this number of times")
+    ("repeat",            po::value<unsigned>()->default_value(1),     "repeat optimization over kbest list this number of times")
     ("check",             po::value<bool>()->zero_tokens(),                                  "produce list of loss differentials")
+    ("print_ranking",     po::value<bool>()->zero_tokens(),                            "output kbest with model score and metric")
     ("noup",              po::value<bool>()->zero_tokens(),                                               "do not update weights");
   po::options_description cl("Command Line Options");
   cl.add_options()
@@ -71,7 +72,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
     return false;
   }
   if ((*cfg)["pair_sampling"].as<string>() != "all" && (*cfg)["pair_sampling"].as<string>() != "XYX" &&
-        (*cfg)["pair_sampling"].as<string>() != "PRO") {
+        (*cfg)["pair_sampling"].as<string>() != "PRO" && (*cfg)["pair_sampling"].as<string>() != "output_pairs") {
     cerr << "Wrong 'pair_sampling' param: '" << (*cfg)["pair_sampling"].as<string>() << "'." << endl;
     return false;
   }
@@ -114,6 +115,8 @@ main(int argc, char** argv)
   if (cfg.count("rescale")) rescale = true;
   bool keep = false;
   if (cfg.count("keep")) keep = true;
+  bool print_ranking = false;
+  if (cfg.count("print_ranking")) print_ranking = true;
 
   const unsigned k = cfg["k"].as<unsigned>();
   const unsigned N = cfg["N"].as<unsigned>();
@@ -356,6 +359,14 @@ main(int argc, char** argv)
     // get (scored) samples
     vector<ScoredHyp>* samples = observer->GetSamples();
 
+    if (print_ranking) {
+      for (auto s: *samples) {
+        cout << ii << " ||| ";
+        printWordIDVec(s.w, cout);
+        cout << " ||| " << s.model << " ||| " << s.score << endl;
+      }
+    }
+
     if (verbose) {
       cerr << "--- refs for " << ii << ": ";
       for (auto r: refs_as_ids_buf[ii]) {
@@ -389,6 +400,8 @@ main(int argc, char** argv)
         partXYX(samples, pairs, pair_threshold, max_pairs, faster_perceptron, hi_lo);
       if (pair_sampling == "PRO")
         PROsampling(samples, pairs, pair_threshold, max_pairs);
+      if (pair_sampling == "output_pairs")
+        all_pairs(samples, pairs, pair_threshold, max_pairs, false);
       int cur_npairs = pairs.size();
       npairs += cur_npairs;
 
@@ -397,6 +410,15 @@ main(int argc, char** argv)
       if (check) repeat = 2;
       vector<float> losses; // for check
 
+      if (pair_sampling == "output_pairs") {
+        for (auto p: pairs) {
+          cout << p.first.model << " ||| "  << p.first.score << " ||| " <<  p.first.f  << endl;
+          cout << p.second.model << " ||| "  << p.second.score << " ||| " <<  p.second.f  << endl;
+          cout << endl;
+        }
+        continue;
+      }
+
       for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin();
            it != pairs.end(); it++) {
         score_t model_diff = it->first.model - it->second.model;
@@ -547,6 +569,7 @@ main(int argc, char** argv)
 
   if (t == 0) in_sz = ii; // remember size of input (# lines)
 
+  if (print_ranking) cout << "---" << endl;
   if (batch) {
     lambdas.plus_eq_v_times_s(batch_updates, eta);
diff --git a/training/dtrain/score.cc b/training/dtrain/score.cc
index d81eafcb..8a28771f 100644
--- a/training/dtrain/score.cc
+++ b/training/dtrain/score.cc
@@ -36,7 +36,7 @@ RefLen(vector<vector<WordID> > refs)
 {
   size_t ref_len = 0;
   for (auto r: refs)
-    ref_len = max(ref_len, r.size());
+    ref_len = max(ref_len, r.size()); // FIXME
   return ref_len;
 }
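
Note on the two new output modes (editorial annotation, not part of the commit): with print_ranking set, each k-best entry is written as "sentence-id ||| hypothesis ||| model score ||| metric score" and a "---" line closes each pass over the input; with pair_sampling set to output_pairs, the sampled hypothesis pairs are printed and the weight update is skipped via continue. The following minimal, self-contained C++ sketch only reproduces the " ||| "-separated pair layout of the output_pairs branch; the Hyp struct, the feature string, and the numbers are made-up stand-ins for dtrain's ScoredHyp and real scores, not dtrain code.

// Sketch only: mirrors the " ||| " layout of the output_pairs branch.
// Hyp is a simplified stand-in for dtrain's ScoredHyp; all values are made up.
#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct Hyp {
  double model;    // model score
  double score;    // metric score (per-sentence approximation)
  std::string f;   // placeholder for the sparse feature vector
};

int main() {
  std::vector<std::pair<Hyp, Hyp> > pairs = {
    { {-1.2, 0.31, "LanguageModel=2.0"}, {-3.4, 0.12, "LanguageModel=1.1"} }
  };
  for (const auto& p : pairs) {
    std::cout << p.first.model  << " ||| " << p.first.score  << " ||| " << p.first.f  << "\n";
    std::cout << p.second.model << " ||| " << p.second.score << " ||| " << p.second.f << "\n";
    std::cout << "\n";  // blank line between pairs, as in the patch
  }
  return 0;
}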
