summaryrefslogtreecommitdiff
path: root/training/dtrain/dtrain.cc
diff options
context:
space:
mode:
Diffstat (limited to 'training/dtrain/dtrain.cc')
-rw-r--r--training/dtrain/dtrain.cc36
1 files changed, 20 insertions, 16 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index fcb46db2..2bb4ec98 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -6,7 +6,8 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
{
po::options_description ini("Configuration File Options");
ini.add_options()
- ("input", po::value<string>()->default_value("-"), "input file")
+ ("input", po::value<string>()->default_value("-"), "input file (src)")
+ ("refs,r", po::value<string>(), "references")
("output", po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT")
("input_weights", po::value<string>(), "input weights file (e.g. from previous iteration)")
("decoder_config", po::value<string>(), "configuration file for cdec")
@@ -33,8 +34,8 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
("scale_bleu_diff", po::value<bool>()->zero_tokens(), "learning rate <- bleu diff of a misranked pair")
("loss_margin", po::value<weight_t>()->default_value(0.), "update if no error in pref pair but model scores this near")
("max_pairs", po::value<unsigned>()->default_value(std::numeric_limits<unsigned>::max()), "max. # of pairs per Sent.")
- ("refs,r", po::value<string>(), "references in local mode")
- ("noup", po::value<bool>()->zero_tokens(), "do not update weights");
+ ("noup", po::value<bool>()->zero_tokens(), "do not update weights")
+ ("pair_stats", po::value<bool>()->zero_tokens(), "stats about correctly ranked/misranked pairs even if loss_margin=0 and gamma=0");
po::options_description cl("Command Line Options");
cl.add_options()
("config,c", po::value<string>(), "dtrain config file")
@@ -124,6 +125,10 @@ main(int argc, char** argv)
vector<string> print_weights;
if (cfg.count("print_weights"))
boost::split(print_weights, cfg["print_weights"].as<string>(), boost::is_any_of(" "));
+ bool pair_stats = false;
+ if (cfg.count("pair_stats")) pair_stats = true;
+ bool faster_perceptron = false;
+ if (gamma==0 && loss_margin==0 && !pair_stats) faster_perceptron = true;
// setup decoder
register_feature_functions();
@@ -346,25 +351,26 @@ main(int argc, char** argv)
// get pairs
vector<pair<ScoredHyp,ScoredHyp> > pairs;
if (pair_sampling == "all")
- all_pairs(samples, pairs, pair_threshold, max_pairs);
+ all_pairs(samples, pairs, pair_threshold, max_pairs, faster_perceptron);
if (pair_sampling == "XYX")
- partXYX(samples, pairs, pair_threshold, max_pairs, hi_lo);
+ partXYX(samples, pairs, pair_threshold, max_pairs, faster_perceptron, hi_lo);
if (pair_sampling == "PRO")
PROsampling(samples, pairs, pair_threshold, max_pairs);
npairs += pairs.size();
for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin();
it != pairs.end(); it++) {
-#ifdef DTRAIN_FASTER_PERCEPTRON
- bool rank_error = true; // pair sampling already did this for us
- rank_errors++;
- score_t margin = std::numeric_limits<float>::max();
-#else
- bool rank_error = it->first.model <= it->second.model;
+ bool rank_error;
+ score_t margin;
+ if (faster_perceptron) { // we only have considering misranked pairs
+ rank_error = true; // pair sampling already did this for us
+ margin = std::numeric_limits<float>::max();
+ } else {
+ rank_error = it->first.model <= it->second.model;
+ margin = fabs(fabs(it->first.model) - fabs(it->second.model));
+ if (!rank_error && margin < loss_margin) margin_violations++;
+ }
if (rank_error) rank_errors++;
- score_t margin = fabs(fabs(it->first.model) - fabs(it->second.model));
- if (!rank_error && margin < loss_margin) margin_violations++;
-#endif
if (scale_bleu_diff) eta = it->first.score - it->second.score;
if (rank_error || margin < loss_margin) {
SparseVector<weight_t> diff_vec = it->first.f - it->second.f;
@@ -458,10 +464,8 @@ main(int argc, char** argv)
cerr << _np << npairs/(float)in_sz << endl;
cerr << " avg # rank err: ";
cerr << rank_errors/(float)in_sz << endl;
-#ifndef DTRAIN_FASTER_PERCEPTRON
cerr << " avg # margin viol: ";
cerr << margin_violations/(float)in_sz << endl;
-#endif
cerr << " non0 feature count: " << nonz << endl;
cerr << " avg list sz: " << list_sz/(float)in_sz << endl;
cerr << " avg f count: " << f_count/(float)list_sz << endl;