summaryrefslogtreecommitdiff
path: root/training/dtrain/dtrain.cc
diff options
context:
space:
mode:
Diffstat (limited to 'training/dtrain/dtrain.cc')
-rw-r--r--training/dtrain/dtrain.cc39
1 files changed, 28 insertions, 11 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index 9ca048c0..b39fff3e 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -12,17 +12,20 @@ main(int argc, char** argv)
po::variables_map conf;
if (!dtrain_init(argc, argv, &conf))
return 1;
- const size_t k = conf["k"].as<size_t>();
- const string score_name = conf["score"].as<string>();
- const size_t N = conf["N"].as<size_t>();
- const size_t T = conf["iterations"].as<size_t>();
- const weight_t eta = conf["learning_rate"].as<weight_t>();
- const weight_t margin = conf["margin"].as<weight_t>();
- const bool average = conf["average"].as<bool>();
- const bool structured = conf["struct"].as<bool>();
- const weight_t l1_reg = conf["l1_reg"].as<weight_t>();
- const bool keep = conf["keep"].as<bool>();
- const string output_fn = conf["output"].as<string>();
+ const size_t k = conf["k"].as<size_t>();
+ const string score_name = conf["score"].as<string>();
+ const size_t N = conf["N"].as<size_t>();
+ const size_t T = conf["iterations"].as<size_t>();
+ const weight_t eta = conf["learning_rate"].as<weight_t>();
+ const weight_t margin = conf["margin"].as<weight_t>();
+ const bool average = conf["average"].as<bool>();
+ const bool structured = conf["struct"].as<bool>();
+ const weight_t l1_reg = conf["l1_reg"].as<weight_t>();
+ const bool keep = conf["keep"].as<bool>();
+ const bool noup = conf["disable_learning"].as<bool>();
+ const string output_fn = conf["output"].as<string>();
+ const string output_data_which = conf["output_data"].as<string>();
+ const bool output_data = output_data_which!="";
vector<string> print_weights;
boost::split(print_weights, conf["print_weights"].as<string>(),
boost::is_any_of(" "));
@@ -162,7 +165,19 @@ main(int argc, char** argv)
feature_count += observer->GetFeatureCount();
list_sz += observer->GetSize();
+ if (output_data) {
+ if (output_data_which == "kbest") {
+ OutputKbest(samples);
+ } else if (output_data_which == "default") {
+ OutputMultipartitePairs(samples, margin);
+ } else if (output_data_which == "all") {
+ OutputAllPairs(samples);
+ }
+ }
+
// get pairs and update
+ if (!noup) {
+
SparseVector<weight_t> updates;
if (structured)
num_up += CollectUpdatesStruct(samples, updates);
@@ -204,6 +219,8 @@ main(int argc, char** argv)
}
}
+ } // noup
+
i++;
} // input loop