summaryrefslogtreecommitdiff
path: root/training/dtrain/dtrain.cc
diff options
context:
space:
mode:
Diffstat (limited to 'training/dtrain/dtrain.cc')
-rw-r--r--training/dtrain/dtrain.cc18
1 files changed, 11 insertions, 7 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index 2bb4ec98..149f87d4 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -34,8 +34,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
("scale_bleu_diff", po::value<bool>()->zero_tokens(), "learning rate <- bleu diff of a misranked pair")
("loss_margin", po::value<weight_t>()->default_value(0.), "update if no error in pref pair but model scores this near")
("max_pairs", po::value<unsigned>()->default_value(std::numeric_limits<unsigned>::max()), "max. # of pairs per Sent.")
- ("noup", po::value<bool>()->zero_tokens(), "do not update weights")
- ("pair_stats", po::value<bool>()->zero_tokens(), "stats about correctly ranked/misranked pairs even if loss_margin=0 and gamma=0");
+ ("noup", po::value<bool>()->zero_tokens(), "do not update weights");
po::options_description cl("Command Line Options");
cl.add_options()
("config,c", po::value<string>(), "dtrain config file")
@@ -125,10 +124,7 @@ main(int argc, char** argv)
vector<string> print_weights;
if (cfg.count("print_weights"))
boost::split(print_weights, cfg["print_weights"].as<string>(), boost::is_any_of(" "));
- bool pair_stats = false;
- if (cfg.count("pair_stats")) pair_stats = true;
- bool faster_perceptron = false;
- if (gamma==0 && loss_margin==0 && !pair_stats) faster_perceptron = true;
+
// setup decoder
register_feature_functions();
@@ -185,6 +181,11 @@ main(int argc, char** argv)
weight_t eta = cfg["learning_rate"].as<weight_t>();
weight_t gamma = cfg["gamma"].as<weight_t>();
+ // faster perceptron: consider only misranked pairs, see
+ // DO NOT ENABLE WITH SVM (gamma > 0) OR loss_margin!
+ bool faster_perceptron = false;
+ if (gamma==0 && loss_margin==0) faster_perceptron = true;
+
// l1 regularization
bool l1naive = false;
bool l1clip = false;
@@ -232,6 +233,7 @@ main(int argc, char** argv)
else cerr << setw(25) << "learning rate " << "bleu diff" << endl;
cerr << setw(25) << "gamma " << gamma << endl;
cerr << setw(25) << "loss margin " << loss_margin << endl;
+ cerr << setw(25) << "faster perceptron " << faster_perceptron << endl;
cerr << setw(25) << "pairs " << "'" << pair_sampling << "'" << endl;
if (pair_sampling == "XYX")
cerr << setw(25) << "hi lo " << hi_lo << endl;
@@ -461,7 +463,9 @@ main(int argc, char** argv)
cerr << _np << " 1best avg model score: " << model_avg;
cerr << _p << " (" << model_diff << ")" << endl;
cerr << " avg # pairs: ";
- cerr << _np << npairs/(float)in_sz << endl;
+ cerr << _np << npairs/(float)in_sz;
+ if (faster_perceptron) cerr << " (meaningless)";
+ cerr << endl;
cerr << " avg # rank err: ";
cerr << rank_errors/(float)in_sz << endl;
cerr << " avg # margin viol: ";