diff options
Diffstat (limited to 'training/dtrain/dtrain.cc')
-rw-r--r-- | training/dtrain/dtrain.cc | 39 |
1 files changed, 28 insertions, 11 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index 9ca048c0..b39fff3e 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -12,17 +12,20 @@ main(int argc, char** argv) po::variables_map conf; if (!dtrain_init(argc, argv, &conf)) return 1; - const size_t k = conf["k"].as<size_t>(); - const string score_name = conf["score"].as<string>(); - const size_t N = conf["N"].as<size_t>(); - const size_t T = conf["iterations"].as<size_t>(); - const weight_t eta = conf["learning_rate"].as<weight_t>(); - const weight_t margin = conf["margin"].as<weight_t>(); - const bool average = conf["average"].as<bool>(); - const bool structured = conf["struct"].as<bool>(); - const weight_t l1_reg = conf["l1_reg"].as<weight_t>(); - const bool keep = conf["keep"].as<bool>(); - const string output_fn = conf["output"].as<string>(); + const size_t k = conf["k"].as<size_t>(); + const string score_name = conf["score"].as<string>(); + const size_t N = conf["N"].as<size_t>(); + const size_t T = conf["iterations"].as<size_t>(); + const weight_t eta = conf["learning_rate"].as<weight_t>(); + const weight_t margin = conf["margin"].as<weight_t>(); + const bool average = conf["average"].as<bool>(); + const bool structured = conf["struct"].as<bool>(); + const weight_t l1_reg = conf["l1_reg"].as<weight_t>(); + const bool keep = conf["keep"].as<bool>(); + const bool noup = conf["disable_learning"].as<bool>(); + const string output_fn = conf["output"].as<string>(); + const string output_data_which = conf["output_data"].as<string>(); + const bool output_data = output_data_which!=""; vector<string> print_weights; boost::split(print_weights, conf["print_weights"].as<string>(), boost::is_any_of(" ")); @@ -162,7 +165,19 @@ main(int argc, char** argv) feature_count += observer->GetFeatureCount(); list_sz += observer->GetSize(); + if (output_data) { + if (output_data_which == "kbest") { + OutputKbest(samples); + } else if (output_data_which == "default") { + OutputMultipartitePairs(samples, margin); + } else if (output_data_which == "all") { + OutputAllPairs(samples); + } + } + // get pairs and update + if (!noup) { + SparseVector<weight_t> updates; if (structured) num_up += CollectUpdatesStruct(samples, updates); @@ -204,6 +219,8 @@ main(int argc, char** argv) } } + } // noup + i++; } // input loop |