diff options
author | Patrick Simianer <p@simianer.de> | 2015-02-26 13:26:37 +0100 |
---|---|---|
committer | Patrick Simianer <p@simianer.de> | 2015-02-26 13:26:37 +0100 |
commit | 4223261682388944fe1b1cf31b9d51d88f9ad53b (patch) | |
tree | daf072c310d60b0386587bde5e554312f193b3b2 /training/dtrain/dtrain.h | |
parent | 2a37a7ad1b21ab54701de3b5b44dc4ea55a75307 (diff) |
refactoring
Diffstat (limited to 'training/dtrain/dtrain.h')
-rw-r--r-- | training/dtrain/dtrain.h | 57 |
1 files changed, 42 insertions, 15 deletions
diff --git a/training/dtrain/dtrain.h b/training/dtrain/dtrain.h index 2b466930..728b0698 100644 --- a/training/dtrain/dtrain.h +++ b/training/dtrain/dtrain.h @@ -1,9 +1,6 @@ #ifndef _DTRAIN_H_ #define _DTRAIN_H_ -#define DTRAIN_DOTS 10 // after how many inputs to display a '.' -#define DTRAIN_SCALE 100000 - #include <iomanip> #include <climits> #include <string.h> @@ -25,20 +22,17 @@ namespace po = boost::program_options; namespace dtrain { -typedef double score_t; - struct ScoredHyp { - vector<WordID> w; + vector<WordID> w; SparseVector<weight_t> f; - score_t model, score; - unsigned rank; + weight_t model, gold; + size_t rank; }; inline void RegisterAndConvert(const vector<string>& strs, vector<WordID>& ids) { - vector<string>::const_iterator it; for (auto s: strs) ids.push_back(TD::Convert(s)); } @@ -46,7 +40,7 @@ RegisterAndConvert(const vector<string>& strs, vector<WordID>& ids) inline void PrintWordIDVec(vector<WordID>& v, ostream& os=cerr) { - for (unsigned i = 0; i < v.size(); i++) { + for (size_t i = 0; i < v.size(); i++) { os << TD::Convert(v[i]); if (i < v.size()-1) os << " "; } @@ -57,12 +51,45 @@ inline ostream& _p(ostream& out) { return out << setiosflags(ios::showpos); } inline ostream& _p2(ostream& out) { return out << setprecision(2); } inline ostream& _p5(ostream& out) { return out << setprecision(5); } -template<typename T> -inline T -sign(T z) +bool +dtrain_init(int argc, char** argv, po::variables_map* conf) { - if (z == 0) return 0; - return z < 0 ? -1 : +1; + po::options_description ini("Configuration File Options"); + ini.add_options() + ("bitext,b", po::value<string>(), "bitext") + ("decoder_config,C", po::value<string>(), "configuration file for decoder") + ("iterations,T", po::value<size_t>()->default_value(10), "number of iterations T (per shard)") + ("k", po::value<size_t>()->default_value(100), "size of kbest list") + ("learning_rate,l", po::value<weight_t>()->default_value(1.0), "learning rate") + ("l1_reg,r", po::value<weight_t>()->default_value(0.), "l1 regularization strength") + ("error_margin,m", po::value<weight_t>()->default_value(0.), "margin for margin perceptron") + ("N", po::value<size_t>()->default_value(4), "N for BLEU approximation") + ("input_weights,w", po::value<string>(), "input weights file") + ("average,a", po::value<bool>()->default_value(false), "output average weights") + ("keep,K", po::value<bool>()->default_value(false), "output a weight file per iteration") + ("output,o", po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT") + ("print_weights,P", po::value<string>()->default_value("EgivenFCoherent SampleCountF CountEF MaxLexFgivenE MaxLexEgivenF IsSingletonF IsSingletonFE Glue WordPenalty PassThrough LanguageModel LanguageModel_OOV"), + "list of weights to print after each iteration"); + po::options_description cl("Command Line Options"); + cl.add_options() + ("config,c", po::value<string>(), "dtrain config file"); + cl.add(ini); + po::store(parse_command_line(argc, argv, cl), *conf); + if (conf->count("config")) { + ifstream f((*conf)["config"].as<string>().c_str()); + po::store(po::parse_config_file(f, ini), *conf); + } + po::notify(*conf); + if (!conf->count("decoder_config")) { + cerr << "Missing decoder configuration." << endl; + return false; + } + if (!conf->count("bitext")) { + cerr << "No training data given." << endl; + return false; + } + + return true; } } // namespace |