diff options
Diffstat (limited to 'training/dtrain/dtrain.h')
-rw-r--r-- | training/dtrain/dtrain.h | 162 |
1 files changed, 69 insertions, 93 deletions
diff --git a/training/dtrain/dtrain.h b/training/dtrain/dtrain.h index 07bd9b65..0bbb5c9b 100644 --- a/training/dtrain/dtrain.h +++ b/training/dtrain/dtrain.h @@ -1,9 +1,6 @@ #ifndef _DTRAIN_H_ #define _DTRAIN_H_ -#define DTRAIN_DOTS 10 // after how many inputs to display a '.' -#define DTRAIN_SCALE 100000 - #include <iomanip> #include <climits> #include <string.h> @@ -25,113 +22,92 @@ namespace po = boost::program_options; namespace dtrain { - -inline void register_and_convert(const vector<string>& strs, vector<WordID>& ids) -{ - vector<string>::const_iterator it; - for (it = strs.begin(); it < strs.end(); it++) - ids.push_back(TD::Convert(*it)); -} - -inline string gettmpf(const string path, const string infix) -{ - char fn[path.size() + infix.size() + 8]; - strcpy(fn, path.c_str()); - strcat(fn, "/"); - strcat(fn, infix.c_str()); - strcat(fn, "-XXXXXX"); - if (!mkstemp(fn)) { - cerr << "Cannot make temp file in" << path << " , exiting." << endl; - exit(1); - } - return string(fn); -} - -typedef double score_t; - struct ScoredHyp { - vector<WordID> w; - SparseVector<double> f; - score_t model; - score_t score; - unsigned rank; + vector<WordID> w; + SparseVector<weight_t> f; + weight_t model, gold; + size_t rank; }; -struct LocalScorer +inline void +PrintWordIDVec(vector<WordID>& v, ostream& os=cerr) { - unsigned N_; - vector<score_t> w_; - - virtual score_t - Score(const vector<WordID>& hyp, const vector<WordID>& ref, const unsigned rank, const unsigned src_len)=0; - - virtual void Reset() {} // only for ApproxBleuScorer, LinearBleuScorer - - inline void - Init(unsigned N, vector<score_t> weights) - { - assert(N > 0); - N_ = N; - if (weights.empty()) for (unsigned i = 0; i < N_; i++) w_.push_back(1./N_); - else w_ = weights; - } - - inline score_t - brevity_penalty(const unsigned hyp_len, const unsigned ref_len) - { - if (hyp_len > ref_len) return 1; - return exp(1 - (score_t)ref_len/hyp_len); + for (size_t i = 0; i < v.size(); i++) { + os << TD::Convert(v[i]); + if (i < v.size()-1) os << " "; } -}; +} -struct HypSampler : public DecoderObserver -{ - LocalScorer* scorer_; - vector<WordID>* ref_; - unsigned f_count_, sz_; - virtual vector<ScoredHyp>* GetSamples()=0; - inline void SetScorer(LocalScorer* scorer) { scorer_ = scorer; } - inline void SetRef(vector<WordID>& ref) { ref_ = &ref; } - inline unsigned get_f_count() { return f_count_; } - inline unsigned get_sz() { return sz_; } -}; +inline ostream& _np(ostream& out) { return out << resetiosflags(ios::showpos); } +inline ostream& _p(ostream& out) { return out << setiosflags(ios::showpos); } +inline ostream& _p4(ostream& out) { return out << setprecision(4); } -struct HSReporter +bool +dtrain_init(int argc, char** argv, po::variables_map* conf) { - string task_id_; - - HSReporter(string task_id) : task_id_(task_id) {} - - inline void update_counter(string name, unsigned amount) { - cerr << "reporter:counter:" << task_id_ << "," << name << "," << amount << endl; + po::options_description opts("Configuration File Options"); + opts.add_options() + ("bitext,b", po::value<string>(), "bitext") + ("decoder_conf,C", po::value<string>(), "configuration file for decoder") + ("iterations,T", po::value<size_t>()->default_value(15), "number of iterations T (per shard)") + ("k", po::value<size_t>()->default_value(100), "size of kbest list") + ("learning_rate,l", po::value<weight_t>()->default_value(0.00001), "learning rate") + ("l1_reg,r", po::value<weight_t>()->default_value(0.), "l1 regularization strength") + ("margin,m", po::value<weight_t>()->default_value(1.0), "margin for margin perceptron") + ("score,s", po::value<string>()->default_value("chiang"), "per-sentence BLEU approx.") + ("N", po::value<size_t>()->default_value(4), "N for BLEU approximation") + ("input_weights,w", po::value<string>(), "input weights file") + ("average,a", po::bool_switch()->default_value(true), "output average weights") + ("keep,K", po::bool_switch()->default_value(false), "output a weight file per iteration") + ("struct,S", po::bool_switch()->default_value(false), "structured SGD with hope/fear") + ("output,o", po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT") + ("disable_learning,X", po::bool_switch()->default_value(false), "disable learning") + ("output_data,D", po::value<string>()->default_value(""), "output data to STDOUT; arg. is 'kbest', 'default' or 'all'") + ("print_weights,P", po::value<string>()->default_value("EgivenFCoherent SampleCountF CountEF MaxLexFgivenE MaxLexEgivenF IsSingletonF IsSingletonFE Glue WordPenalty PassThrough LanguageModel LanguageModel_OOV"), + "list of weights to print after each iteration"); + po::options_description clopts("Command Line Options"); + clopts.add_options() + ("conf,c", po::value<string>(), "dtrain configuration file") + ("help,h", po::bool_switch(), "display options"); + opts.add(clopts); + po::store(parse_command_line(argc, argv, opts), *conf); + cerr << "dtrain" << endl << endl; + if ((*conf)["help"].as<bool>()) { + cerr << opts << endl; + + return false; } - inline void update_gcounter(string name, unsigned amount) { - cerr << "reporter:counter:Global," << name << "," << amount << endl; + if (conf->count("conf")) { + ifstream f((*conf)["conf"].as<string>().c_str()); + po::store(po::parse_config_file(f, opts), *conf); } -}; + po::notify(*conf); + if (!conf->count("decoder_conf")) { + cerr << "Missing decoder configuration." << endl; + cerr << opts << endl; -inline ostream& _np(ostream& out) { return out << resetiosflags(ios::showpos); } -inline ostream& _p(ostream& out) { return out << setiosflags(ios::showpos); } -inline ostream& _p2(ostream& out) { return out << setprecision(2); } -inline ostream& _p5(ostream& out) { return out << setprecision(5); } + return false; + } + if (!conf->count("bitext")) { + cerr << "No input given." << endl; + cerr << opts << endl; -inline void printWordIDVec(vector<WordID>& v, ostream& os=cerr) -{ - for (unsigned i = 0; i < v.size(); i++) { - os << TD::Convert(v[i]); - if (i < v.size()-1) os << " "; + return false; + } + if ((*conf)["output_data"].as<string>() != "") { + if ((*conf)["output_data"].as<string>() != "kbest" && + (*conf)["output_data"].as<string>() != "default" && + (*conf)["output_data"].as<string>() != "all") { + cerr << "Wrong 'output_data' argument: "; + cerr << (*conf)["output_data"].as<string>() << endl; + return false; + } } -} -template<typename T> -inline T sign(T z) -{ - if (z == 0) return 0; - return z < 0 ? -1 : +1; + return true; } - } // namespace #endif |