From 77dce5d710bfed0c3e8c03f1f5f1ec5856087ba9 Mon Sep 17 00:00:00 2001
From: Patrick Simianer
Date: Wed, 30 Nov 2011 15:49:59 +0100
Subject: spare a arg
---
dtrain/dtrain.cc | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index e07b9307..9db8516c 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -25,12 +25,12 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
("scorer", po::value()->default_value("stupid_bleu"), "scoring: bleu, stupid_, smooth_, approx_")
("learning_rate", po::value()->default_value(0.0001), "learning rate")
("gamma", po::value()->default_value(0), "gamma for SVM (0 for perceptron)")
- ("select_weights", po::value()->default_value("last"), "output 'best' or 'last' weights ('VOID' to throw away)")
+ ("select_weights", po::value()->default_value("last"), "output best, last, avg weights ('VOID' to throw away)")
("rescale", po::value()->zero_tokens(), "rescale weight vector after each input")
("l1_reg", po::value()->default_value("none"), "apply l1 regularization as in 'Tsuroka et al' (2010)")
("l1_reg_strength", po::value(), "l1 regularization strength")
("funny", po::value()->zero_tokens(), "include correctly ranked pairs into updates")
- ("average", po::value()->zero_tokens(), "output weight vector is average of all iterations")
+ ("fselect", po::value()->default_value(-1), "select top x percent of features after each epoch")
#ifdef DTRAIN_LOCAL
("refs,r", po::value(), "references in local mode")
#endif
@@ -81,7 +81,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
return false;
}
if ((*cfg)["select_weights"].as() != "last" && (*cfg)["select_weights"].as() != "best" &&
- (*cfg)["select_weights"].as() != "VOID") {
+ (*cfg)["select_weights"].as() != "avg" && (*cfg)["select_weights"].as() != "VOID") {
cerr << "Wrong 'select_weights' param: '" << (*cfg)["select_weights"].as() << "', use 'last' or 'best'." << endl;
return false;
}
@@ -116,9 +116,6 @@ main(int argc, char** argv)
bool funny = false;
if (cfg.count("funny"))
funny = true;
- bool average = false;
- if (cfg.count("average"))
- average = true;
const unsigned k = cfg["k"].as();
const unsigned N = cfg["N"].as();
@@ -129,6 +126,9 @@ main(int argc, char** argv)
const string pair_sampling = cfg["pair_sampling"].as();
const score_t pair_threshold = cfg["pair_threshold"].as();
const string select_weights = cfg["select_weights"].as();
+ bool average = false;
+ if (select_weights == "avg")
+ average = true;
vector print_weights;
if (cfg.count("print_weights"))
boost::split(print_weights, cfg["print_weights"].as(), boost::is_any_of(" "));
--
cgit v1.2.3