diff options
Diffstat (limited to 'training/dtrain/dtrain.cc')
-rw-r--r-- | training/dtrain/dtrain.cc | 36 |
1 files changed, 20 insertions, 16 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index fcb46db2..2bb4ec98 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -6,7 +6,8 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) { po::options_description ini("Configuration File Options"); ini.add_options() - ("input", po::value<string>()->default_value("-"), "input file") + ("input", po::value<string>()->default_value("-"), "input file (src)") + ("refs,r", po::value<string>(), "references") ("output", po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT") ("input_weights", po::value<string>(), "input weights file (e.g. from previous iteration)") ("decoder_config", po::value<string>(), "configuration file for cdec") @@ -33,8 +34,8 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) ("scale_bleu_diff", po::value<bool>()->zero_tokens(), "learning rate <- bleu diff of a misranked pair") ("loss_margin", po::value<weight_t>()->default_value(0.), "update if no error in pref pair but model scores this near") ("max_pairs", po::value<unsigned>()->default_value(std::numeric_limits<unsigned>::max()), "max. # of pairs per Sent.") - ("refs,r", po::value<string>(), "references in local mode") - ("noup", po::value<bool>()->zero_tokens(), "do not update weights"); + ("noup", po::value<bool>()->zero_tokens(), "do not update weights") + ("pair_stats", po::value<bool>()->zero_tokens(), "stats about correctly ranked/misranked pairs even if loss_margin=0 and gamma=0"); po::options_description cl("Command Line Options"); cl.add_options() ("config,c", po::value<string>(), "dtrain config file") @@ -124,6 +125,10 @@ main(int argc, char** argv) vector<string> print_weights; if (cfg.count("print_weights")) boost::split(print_weights, cfg["print_weights"].as<string>(), boost::is_any_of(" ")); + bool pair_stats = false; + if (cfg.count("pair_stats")) pair_stats = true; + bool faster_perceptron = false; + if (gamma==0 && loss_margin==0 && !pair_stats) faster_perceptron = true; // setup decoder register_feature_functions(); @@ -346,25 +351,26 @@ main(int argc, char** argv) // get pairs vector<pair<ScoredHyp,ScoredHyp> > pairs; if (pair_sampling == "all") - all_pairs(samples, pairs, pair_threshold, max_pairs); + all_pairs(samples, pairs, pair_threshold, max_pairs, faster_perceptron); if (pair_sampling == "XYX") - partXYX(samples, pairs, pair_threshold, max_pairs, hi_lo); + partXYX(samples, pairs, pair_threshold, max_pairs, faster_perceptron, hi_lo); if (pair_sampling == "PRO") PROsampling(samples, pairs, pair_threshold, max_pairs); npairs += pairs.size(); for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin(); it != pairs.end(); it++) { -#ifdef DTRAIN_FASTER_PERCEPTRON - bool rank_error = true; // pair sampling already did this for us - rank_errors++; - score_t margin = std::numeric_limits<float>::max(); -#else - bool rank_error = it->first.model <= it->second.model; + bool rank_error; + score_t margin; + if (faster_perceptron) { // we only have considering misranked pairs + rank_error = true; // pair sampling already did this for us + margin = std::numeric_limits<float>::max(); + } else { + rank_error = it->first.model <= it->second.model; + margin = fabs(fabs(it->first.model) - fabs(it->second.model)); + if (!rank_error && margin < loss_margin) margin_violations++; + } if (rank_error) rank_errors++; - score_t margin = fabs(fabs(it->first.model) - fabs(it->second.model)); - if (!rank_error && margin < loss_margin) margin_violations++; -#endif if (scale_bleu_diff) eta = it->first.score - it->second.score; if (rank_error || margin < loss_margin) { SparseVector<weight_t> diff_vec = it->first.f - it->second.f; @@ -458,10 +464,8 @@ main(int argc, char** argv) cerr << _np << npairs/(float)in_sz << endl; cerr << " avg # rank err: "; cerr << rank_errors/(float)in_sz << endl; -#ifndef DTRAIN_FASTER_PERCEPTRON cerr << " avg # margin viol: "; cerr << margin_violations/(float)in_sz << endl; -#endif cerr << " non0 feature count: " << nonz << endl; cerr << " avg list sz: " << list_sz/(float)in_sz << endl; cerr << " avg f count: " << f_count/(float)list_sz << endl; |