diff options
Diffstat (limited to 'dtrain')
-rw-r--r-- | dtrain/dtrain.cc | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc index e07b9307..9db8516c 100644 --- a/dtrain/dtrain.cc +++ b/dtrain/dtrain.cc @@ -25,12 +25,12 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) ("scorer", po::value<string>()->default_value("stupid_bleu"), "scoring: bleu, stupid_, smooth_, approx_") ("learning_rate", po::value<weight_t>()->default_value(0.0001), "learning rate") ("gamma", po::value<weight_t>()->default_value(0), "gamma for SVM (0 for perceptron)") - ("select_weights", po::value<string>()->default_value("last"), "output 'best' or 'last' weights ('VOID' to throw away)") + ("select_weights", po::value<string>()->default_value("last"), "output best, last, avg weights ('VOID' to throw away)") ("rescale", po::value<bool>()->zero_tokens(), "rescale weight vector after each input") ("l1_reg", po::value<string>()->default_value("none"), "apply l1 regularization as in 'Tsuroka et al' (2010)") ("l1_reg_strength", po::value<weight_t>(), "l1 regularization strength") ("funny", po::value<bool>()->zero_tokens(), "include correctly ranked pairs into updates") - ("average", po::value<bool>()->zero_tokens(), "output weight vector is average of all iterations") + ("fselect", po::value<weight_t>()->default_value(-1), "select top x percent of features after each epoch") #ifdef DTRAIN_LOCAL ("refs,r", po::value<string>(), "references in local mode") #endif @@ -81,7 +81,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) return false; } if ((*cfg)["select_weights"].as<string>() != "last" && (*cfg)["select_weights"].as<string>() != "best" && - (*cfg)["select_weights"].as<string>() != "VOID") { + (*cfg)["select_weights"].as<string>() != "avg" && (*cfg)["select_weights"].as<string>() != "VOID") { cerr << "Wrong 'select_weights' param: '" << (*cfg)["select_weights"].as<string>() << "', use 'last' or 'best'." << endl; return false; } @@ -116,9 +116,6 @@ main(int argc, char** argv) bool funny = false; if (cfg.count("funny")) funny = true; - bool average = false; - if (cfg.count("average")) - average = true; const unsigned k = cfg["k"].as<unsigned>(); const unsigned N = cfg["N"].as<unsigned>(); @@ -129,6 +126,9 @@ main(int argc, char** argv) const string pair_sampling = cfg["pair_sampling"].as<string>(); const score_t pair_threshold = cfg["pair_threshold"].as<score_t>(); const string select_weights = cfg["select_weights"].as<string>(); + bool average = false; + if (select_weights == "avg") + average = true; vector<string> print_weights; if (cfg.count("print_weights")) boost::split(print_weights, cfg["print_weights"].as<string>(), boost::is_any_of(" ")); |