From b021b58297de50a874165d26e1c5c808192bbe18 Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Wed, 30 Nov 2011 15:49:59 +0100 Subject: spare a arg --- dtrain/dtrain.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc index e07b9307..9db8516c 100644 --- a/dtrain/dtrain.cc +++ b/dtrain/dtrain.cc @@ -25,12 +25,12 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) ("scorer", po::value()->default_value("stupid_bleu"), "scoring: bleu, stupid_, smooth_, approx_") ("learning_rate", po::value()->default_value(0.0001), "learning rate") ("gamma", po::value()->default_value(0), "gamma for SVM (0 for perceptron)") - ("select_weights", po::value()->default_value("last"), "output 'best' or 'last' weights ('VOID' to throw away)") + ("select_weights", po::value()->default_value("last"), "output best, last, avg weights ('VOID' to throw away)") ("rescale", po::value()->zero_tokens(), "rescale weight vector after each input") ("l1_reg", po::value()->default_value("none"), "apply l1 regularization as in 'Tsuroka et al' (2010)") ("l1_reg_strength", po::value(), "l1 regularization strength") ("funny", po::value()->zero_tokens(), "include correctly ranked pairs into updates") - ("average", po::value()->zero_tokens(), "output weight vector is average of all iterations") + ("fselect", po::value()->default_value(-1), "select top x percent of features after each epoch") #ifdef DTRAIN_LOCAL ("refs,r", po::value(), "references in local mode") #endif @@ -81,7 +81,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) return false; } if ((*cfg)["select_weights"].as() != "last" && (*cfg)["select_weights"].as() != "best" && - (*cfg)["select_weights"].as() != "VOID") { + (*cfg)["select_weights"].as() != "avg" && (*cfg)["select_weights"].as() != "VOID") { cerr << "Wrong 'select_weights' param: '" << (*cfg)["select_weights"].as() << "', use 'last' or 'best'." << endl; return false; } @@ -116,9 +116,6 @@ main(int argc, char** argv) bool funny = false; if (cfg.count("funny")) funny = true; - bool average = false; - if (cfg.count("average")) - average = true; const unsigned k = cfg["k"].as(); const unsigned N = cfg["N"].as(); @@ -129,6 +126,9 @@ main(int argc, char** argv) const string pair_sampling = cfg["pair_sampling"].as(); const score_t pair_threshold = cfg["pair_threshold"].as(); const string select_weights = cfg["select_weights"].as(); + bool average = false; + if (select_weights == "avg") + average = true; vector print_weights; if (cfg.count("print_weights")) boost::split(print_weights, cfg["print_weights"].as(), boost::is_any_of(" ")); -- cgit v1.2.3