diff options
Diffstat (limited to 'training/dtrain/dtrain.cc')
-rw-r--r-- | training/dtrain/dtrain.cc | 18 |
1 files changed, 11 insertions, 7 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index 2bb4ec98..149f87d4 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -34,8 +34,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) ("scale_bleu_diff", po::value<bool>()->zero_tokens(), "learning rate <- bleu diff of a misranked pair") ("loss_margin", po::value<weight_t>()->default_value(0.), "update if no error in pref pair but model scores this near") ("max_pairs", po::value<unsigned>()->default_value(std::numeric_limits<unsigned>::max()), "max. # of pairs per Sent.") - ("noup", po::value<bool>()->zero_tokens(), "do not update weights") - ("pair_stats", po::value<bool>()->zero_tokens(), "stats about correctly ranked/misranked pairs even if loss_margin=0 and gamma=0"); + ("noup", po::value<bool>()->zero_tokens(), "do not update weights"); po::options_description cl("Command Line Options"); cl.add_options() ("config,c", po::value<string>(), "dtrain config file") @@ -125,10 +124,7 @@ main(int argc, char** argv) vector<string> print_weights; if (cfg.count("print_weights")) boost::split(print_weights, cfg["print_weights"].as<string>(), boost::is_any_of(" ")); - bool pair_stats = false; - if (cfg.count("pair_stats")) pair_stats = true; - bool faster_perceptron = false; - if (gamma==0 && loss_margin==0 && !pair_stats) faster_perceptron = true; + // setup decoder register_feature_functions(); @@ -185,6 +181,11 @@ main(int argc, char** argv) weight_t eta = cfg["learning_rate"].as<weight_t>(); weight_t gamma = cfg["gamma"].as<weight_t>(); + // faster perceptron: consider only misranked pairs, see + // DO NOT ENABLE WITH SVM (gamma > 0) OR loss_margin! + bool faster_perceptron = false; + if (gamma==0 && loss_margin==0) faster_perceptron = true; + // l1 regularization bool l1naive = false; bool l1clip = false; @@ -232,6 +233,7 @@ main(int argc, char** argv) else cerr << setw(25) << "learning rate " << "bleu diff" << endl; cerr << setw(25) << "gamma " << gamma << endl; cerr << setw(25) << "loss margin " << loss_margin << endl; + cerr << setw(25) << "faster perceptron " << faster_perceptron << endl; cerr << setw(25) << "pairs " << "'" << pair_sampling << "'" << endl; if (pair_sampling == "XYX") cerr << setw(25) << "hi lo " << hi_lo << endl; @@ -461,7 +463,9 @@ main(int argc, char** argv) cerr << _np << " 1best avg model score: " << model_avg; cerr << _p << " (" << model_diff << ")" << endl; cerr << " avg # pairs: "; - cerr << _np << npairs/(float)in_sz << endl; + cerr << _np << npairs/(float)in_sz; + if (faster_perceptron) cerr << " (meaningless)"; + cerr << endl; cerr << " avg # rank err: "; cerr << rank_errors/(float)in_sz << endl; cerr << " avg # margin viol: "; |