summaryrefslogtreecommitdiff
path: root/dtrain
diff options
context:
space:
mode:
Diffstat (limited to 'dtrain')
-rw-r--r--dtrain/dtrain.cc12
1 files changed, 6 insertions, 6 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index e07b9307..9db8516c 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -25,12 +25,12 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
("scorer", po::value<string>()->default_value("stupid_bleu"), "scoring: bleu, stupid_, smooth_, approx_")
("learning_rate", po::value<weight_t>()->default_value(0.0001), "learning rate")
("gamma", po::value<weight_t>()->default_value(0), "gamma for SVM (0 for perceptron)")
- ("select_weights", po::value<string>()->default_value("last"), "output 'best' or 'last' weights ('VOID' to throw away)")
+ ("select_weights", po::value<string>()->default_value("last"), "output best, last, avg weights ('VOID' to throw away)")
("rescale", po::value<bool>()->zero_tokens(), "rescale weight vector after each input")
("l1_reg", po::value<string>()->default_value("none"), "apply l1 regularization as in 'Tsuroka et al' (2010)")
("l1_reg_strength", po::value<weight_t>(), "l1 regularization strength")
("funny", po::value<bool>()->zero_tokens(), "include correctly ranked pairs into updates")
- ("average", po::value<bool>()->zero_tokens(), "output weight vector is average of all iterations")
+ ("fselect", po::value<weight_t>()->default_value(-1), "select top x percent of features after each epoch")
#ifdef DTRAIN_LOCAL
("refs,r", po::value<string>(), "references in local mode")
#endif
@@ -81,7 +81,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
return false;
}
if ((*cfg)["select_weights"].as<string>() != "last" && (*cfg)["select_weights"].as<string>() != "best" &&
- (*cfg)["select_weights"].as<string>() != "VOID") {
+ (*cfg)["select_weights"].as<string>() != "avg" && (*cfg)["select_weights"].as<string>() != "VOID") {
cerr << "Wrong 'select_weights' param: '" << (*cfg)["select_weights"].as<string>() << "', use 'last' or 'best'." << endl;
return false;
}
@@ -116,9 +116,6 @@ main(int argc, char** argv)
bool funny = false;
if (cfg.count("funny"))
funny = true;
- bool average = false;
- if (cfg.count("average"))
- average = true;
const unsigned k = cfg["k"].as<unsigned>();
const unsigned N = cfg["N"].as<unsigned>();
@@ -129,6 +126,9 @@ main(int argc, char** argv)
const string pair_sampling = cfg["pair_sampling"].as<string>();
const score_t pair_threshold = cfg["pair_threshold"].as<score_t>();
const string select_weights = cfg["select_weights"].as<string>();
+ bool average = false;
+ if (select_weights == "avg")
+ average = true;
vector<string> print_weights;
if (cfg.count("print_weights"))
boost::split(print_weights, cfg["print_weights"].as<string>(), boost::is_any_of(" "));