diff options
author | Patrick Simianer <p@simianer.de> | 2015-09-15 15:32:29 +0200 |
---|---|---|
committer | Patrick Simianer <p@simianer.de> | 2015-09-15 15:32:29 +0200 |
commit | 8d900ca0af90dff71d68a7b596571df3e64c2101 (patch) | |
tree | efc3fddedb3f3680054e475f66a3b5fef97d4bd6 /training/dtrain | |
parent | 0208c988890a72d4a3e80fb3cebf2abd03162050 (diff) |
Diffstat (limited to 'training/dtrain')
-rw-r--r-- | training/dtrain/dtrain.cc | 39 | ||||
-rw-r--r-- | training/dtrain/dtrain.h | 43 | ||||
-rw-r--r-- | training/dtrain/update.h | 50 |
3 files changed, 105 insertions, 27 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index 9ca048c0..b39fff3e 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -12,17 +12,20 @@ main(int argc, char** argv) po::variables_map conf; if (!dtrain_init(argc, argv, &conf)) return 1; - const size_t k = conf["k"].as<size_t>(); - const string score_name = conf["score"].as<string>(); - const size_t N = conf["N"].as<size_t>(); - const size_t T = conf["iterations"].as<size_t>(); - const weight_t eta = conf["learning_rate"].as<weight_t>(); - const weight_t margin = conf["margin"].as<weight_t>(); - const bool average = conf["average"].as<bool>(); - const bool structured = conf["struct"].as<bool>(); - const weight_t l1_reg = conf["l1_reg"].as<weight_t>(); - const bool keep = conf["keep"].as<bool>(); - const string output_fn = conf["output"].as<string>(); + const size_t k = conf["k"].as<size_t>(); + const string score_name = conf["score"].as<string>(); + const size_t N = conf["N"].as<size_t>(); + const size_t T = conf["iterations"].as<size_t>(); + const weight_t eta = conf["learning_rate"].as<weight_t>(); + const weight_t margin = conf["margin"].as<weight_t>(); + const bool average = conf["average"].as<bool>(); + const bool structured = conf["struct"].as<bool>(); + const weight_t l1_reg = conf["l1_reg"].as<weight_t>(); + const bool keep = conf["keep"].as<bool>(); + const bool noup = conf["disable_learning"].as<bool>(); + const string output_fn = conf["output"].as<string>(); + const string output_data_which = conf["output_data"].as<string>(); + const bool output_data = output_data_which!=""; vector<string> print_weights; boost::split(print_weights, conf["print_weights"].as<string>(), boost::is_any_of(" ")); @@ -162,7 +165,19 @@ main(int argc, char** argv) feature_count += observer->GetFeatureCount(); list_sz += observer->GetSize(); + if (output_data) { + if (output_data_which == "kbest") { + OutputKbest(samples); + } else if (output_data_which == "default") { + OutputMultipartitePairs(samples, margin); + } else if (output_data_which == "all") { + OutputAllPairs(samples); + } + } + // get pairs and update + if (!noup) { + SparseVector<weight_t> updates; if (structured) num_up += CollectUpdatesStruct(samples, updates); @@ -204,6 +219,8 @@ main(int argc, char** argv) } } + } // noup + i++; } // input loop diff --git a/training/dtrain/dtrain.h b/training/dtrain/dtrain.h index 6633b4f9..0bbb5c9b 100644 --- a/training/dtrain/dtrain.h +++ b/training/dtrain/dtrain.h @@ -48,22 +48,24 @@ dtrain_init(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration File Options"); opts.add_options() - ("bitext,b", po::value<string>(), "bitext") - ("decoder_conf,C", po::value<string>(), "configuration file for decoder") - ("iterations,T", po::value<size_t>()->default_value(15), "number of iterations T (per shard)") - ("k", po::value<size_t>()->default_value(100), "size of kbest list") - ("learning_rate,l", po::value<weight_t>()->default_value(0.00001), "learning rate") - ("l1_reg,r", po::value<weight_t>()->default_value(0.), "l1 regularization strength") - ("margin,m", po::value<weight_t>()->default_value(1.0), "margin for margin perceptron") - ("score,s", po::value<string>()->default_value("chiang"), "per-sentence BLEU approx.") - ("N", po::value<size_t>()->default_value(4), "N for BLEU approximation") - ("input_weights,w", po::value<string>(), "input weights file") - ("average,a", po::bool_switch()->default_value(true), "output average weights") - ("keep,K", po::bool_switch()->default_value(false), "output a weight file per iteration") - ("struct,S", po::bool_switch()->default_value(false), "structured SGD with hope/fear") - ("output,o", po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT") - ("print_weights,P", po::value<string>()->default_value("EgivenFCoherent SampleCountF CountEF MaxLexFgivenE MaxLexEgivenF IsSingletonF IsSingletonFE Glue WordPenalty PassThrough LanguageModel LanguageModel_OOV"), - "list of weights to print after each iteration"); + ("bitext,b", po::value<string>(), "bitext") + ("decoder_conf,C", po::value<string>(), "configuration file for decoder") + ("iterations,T", po::value<size_t>()->default_value(15), "number of iterations T (per shard)") + ("k", po::value<size_t>()->default_value(100), "size of kbest list") + ("learning_rate,l", po::value<weight_t>()->default_value(0.00001), "learning rate") + ("l1_reg,r", po::value<weight_t>()->default_value(0.), "l1 regularization strength") + ("margin,m", po::value<weight_t>()->default_value(1.0), "margin for margin perceptron") + ("score,s", po::value<string>()->default_value("chiang"), "per-sentence BLEU approx.") + ("N", po::value<size_t>()->default_value(4), "N for BLEU approximation") + ("input_weights,w", po::value<string>(), "input weights file") + ("average,a", po::bool_switch()->default_value(true), "output average weights") + ("keep,K", po::bool_switch()->default_value(false), "output a weight file per iteration") + ("struct,S", po::bool_switch()->default_value(false), "structured SGD with hope/fear") + ("output,o", po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT") + ("disable_learning,X", po::bool_switch()->default_value(false), "disable learning") + ("output_data,D", po::value<string>()->default_value(""), "output data to STDOUT; arg. is 'kbest', 'default' or 'all'") + ("print_weights,P", po::value<string>()->default_value("EgivenFCoherent SampleCountF CountEF MaxLexFgivenE MaxLexEgivenF IsSingletonF IsSingletonFE Glue WordPenalty PassThrough LanguageModel LanguageModel_OOV"), + "list of weights to print after each iteration"); po::options_description clopts("Command Line Options"); clopts.add_options() ("conf,c", po::value<string>(), "dtrain configuration file") @@ -93,6 +95,15 @@ dtrain_init(int argc, char** argv, po::variables_map* conf) return false; } + if ((*conf)["output_data"].as<string>() != "") { + if ((*conf)["output_data"].as<string>() != "kbest" && + (*conf)["output_data"].as<string>() != "default" && + (*conf)["output_data"].as<string>() != "all") { + cerr << "Wrong 'output_data' argument: "; + cerr << (*conf)["output_data"].as<string>() << endl; + return false; + } + } return true; } diff --git a/training/dtrain/update.h b/training/dtrain/update.h index 6f42e5bd..83dc3186 100644 --- a/training/dtrain/update.h +++ b/training/dtrain/update.h @@ -94,6 +94,56 @@ CollectUpdatesStruct(vector<ScoredHyp>* s, return updates.size(); } +inline void +OutputKbest(vector<ScoredHyp>* s) +{ + sort(s->begin(), s->end(), _cmp); + size_t i = 0; + for (auto k: *s) { + cout << i << "\t" << k.gold << "\t" << k.model << " \t" << k.f << endl; + i++; + } +} + +inline void +OutputMultipartitePairs(vector<ScoredHyp>* s, + weight_t margin=0., + bool all=true) +{ + size_t sz = s->size(); + sort(s->begin(), s->end(), _cmp); + size_t sep = round(sz*0.1); + for (size_t i = 0; i < sep; i++) { + for (size_t j = sep; j < sz; j++) { + if (!all && _good((*s)[i], (*s)[j], margin)) + continue; + cout << (*s)[i].f-(*s)[j].f << endl; + } + } + size_t sep_lo = sz-sep; + for (size_t i = sep; i < sep_lo; i++) { + for (size_t j = sep_lo; j < sz; j++) { + if (!all && _good((*s)[i], (*s)[j], margin)) + continue; + cout << (*s)[i].f-(*s)[j].f << endl; + } + } +} + +inline void +OutputAllPairs(vector<ScoredHyp>* s) +{ + size_t sz = s->size(); + sort(s->begin(), s->end(), _cmp); + for (size_t i = 0; i < sz-1; i++) { + for (size_t j = i+1; j < sz; j++) { + if ((*s)[i].gold == (*s)[j].gold) + continue; + cout << (*s)[i].f-(*s)[j].f << endl; + } + } +} + } // namespace #endif |