summary | refs | log | tree | commit | diff
path: root/training/dtrain/dtrain.h
diff options
context:
space:
mode:
Diffstat (limited to 'training/dtrain/dtrain.h')
-rw-r--r-- training/dtrain/dtrain.h 162
1 files changed, 69 insertions, 93 deletions
diff --git a/training/dtrain/dtrain.h b/training/dtrain/dtrain.h
index 07bd9b65..0bbb5c9b 100644
--- a/training/dtrain/dtrain.h
+++ b/training/dtrain/dtrain.h
@@ -1,9 +1,6 @@
#ifndef _DTRAIN_H_
#define _DTRAIN_H_
-#define DTRAIN_DOTS 10 // after how many inputs to display a '.'
-#define DTRAIN_SCALE 100000
-
#include <iomanip>
#include <climits>
#include <string.h>
@@ -25,113 +22,92 @@ namespace po = boost::program_options;
namespace dtrain
{
-
-inline void register_and_convert(const vector<string>& strs, vector<WordID>& ids)
-{
- vector<string>::const_iterator it;
- for (it = strs.begin(); it < strs.end(); it++)
- ids.push_back(TD::Convert(*it));
-}
-
-inline string gettmpf(const string path, const string infix)
-{
- char fn[path.size() + infix.size() + 8];
- strcpy(fn, path.c_str());
- strcat(fn, "/");
- strcat(fn, infix.c_str());
- strcat(fn, "-XXXXXX");
- if (!mkstemp(fn)) {
- cerr << "Cannot make temp file in" << path << " , exiting." << endl;
- exit(1);
- }
- return string(fn);
-}
-
-typedef double score_t;
-
struct ScoredHyp
{
- vector<WordID> w;
- SparseVector<double> f;
- score_t model;
- score_t score;
- unsigned rank;
+ vector<WordID> w;
+ SparseVector<weight_t> f;
+ weight_t model, gold;
+ size_t rank;
};
-struct LocalScorer
+inline void
+PrintWordIDVec(vector<WordID>& v, ostream& os=cerr)
{
- unsigned N_;
- vector<score_t> w_;
-
- virtual score_t
- Score(const vector<WordID>& hyp, const vector<WordID>& ref, const unsigned rank, const unsigned src_len)=0;
-
- virtual void Reset() {} // only for ApproxBleuScorer, LinearBleuScorer
-
- inline void
- Init(unsigned N, vector<score_t> weights)
- {
- assert(N > 0);
- N_ = N;
- if (weights.empty()) for (unsigned i = 0; i < N_; i++) w_.push_back(1./N_);
- else w_ = weights;
- }
-
- inline score_t
- brevity_penalty(const unsigned hyp_len, const unsigned ref_len)
- {
- if (hyp_len > ref_len) return 1;
- return exp(1 - (score_t)ref_len/hyp_len);
+ for (size_t i = 0; i < v.size(); i++) {
+ os << TD::Convert(v[i]);
+ if (i < v.size()-1) os << " ";
}
-};
+}
-struct HypSampler : public DecoderObserver
-{
- LocalScorer* scorer_;
- vector<WordID>* ref_;
- unsigned f_count_, sz_;
- virtual vector<ScoredHyp>* GetSamples()=0;
- inline void SetScorer(LocalScorer* scorer) { scorer_ = scorer; }
- inline void SetRef(vector<WordID>& ref) { ref_ = &ref; }
- inline unsigned get_f_count() { return f_count_; }
- inline unsigned get_sz() { return sz_; }
-};
+inline ostream& _np(ostream& out) { return out << resetiosflags(ios::showpos); }
+inline ostream& _p(ostream& out) { return out << setiosflags(ios::showpos); }
+inline ostream& _p4(ostream& out) { return out << setprecision(4); }
-struct HSReporter
+bool
+dtrain_init(int argc, char** argv, po::variables_map* conf)
{
- string task_id_;
-
- HSReporter(string task_id) : task_id_(task_id) {}
-
- inline void update_counter(string name, unsigned amount) {
- cerr << "reporter:counter:" << task_id_ << "," << name << "," << amount << endl;
+ po::options_description opts("Configuration File Options");
+ opts.add_options()
+ ("bitext,b", po::value<string>(), "bitext")
+ ("decoder_conf,C", po::value<string>(), "configuration file for decoder")
+ ("iterations,T", po::value<size_t>()->default_value(15), "number of iterations T (per shard)")
+ ("k", po::value<size_t>()->default_value(100), "size of kbest list")
+ ("learning_rate,l", po::value<weight_t>()->default_value(0.00001), "learning rate")
+ ("l1_reg,r", po::value<weight_t>()->default_value(0.), "l1 regularization strength")
+ ("margin,m", po::value<weight_t>()->default_value(1.0), "margin for margin perceptron")
+ ("score,s", po::value<string>()->default_value("chiang"), "per-sentence BLEU approx.")
+ ("N", po::value<size_t>()->default_value(4), "N for BLEU approximation")
+ ("input_weights,w", po::value<string>(), "input weights file")
+ ("average,a", po::bool_switch()->default_value(true), "output average weights")
+ ("keep,K", po::bool_switch()->default_value(false), "output a weight file per iteration")
+ ("struct,S", po::bool_switch()->default_value(false), "structured SGD with hope/fear")
+ ("output,o", po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT")
+ ("disable_learning,X", po::bool_switch()->default_value(false), "disable learning")
+ ("output_data,D", po::value<string>()->default_value(""), "output data to STDOUT; arg. is 'kbest', 'default' or 'all'")
+ ("print_weights,P", po::value<string>()->default_value("EgivenFCoherent SampleCountF CountEF MaxLexFgivenE MaxLexEgivenF IsSingletonF IsSingletonFE Glue WordPenalty PassThrough LanguageModel LanguageModel_OOV"),
+ "list of weights to print after each iteration");
+ po::options_description clopts("Command Line Options");
+ clopts.add_options()
+ ("conf,c", po::value<string>(), "dtrain configuration file")
+ ("help,h", po::bool_switch(), "display options");
+ opts.add(clopts);
+ po::store(parse_command_line(argc, argv, opts), *conf);
+ cerr << "dtrain" << endl << endl;
+ if ((*conf)["help"].as<bool>()) {
+ cerr << opts << endl;
+
+ return false;
}
- inline void update_gcounter(string name, unsigned amount) {
- cerr << "reporter:counter:Global," << name << "," << amount << endl;
+ if (conf->count("conf")) {
+ ifstream f((*conf)["conf"].as<string>().c_str());
+ po::store(po::parse_config_file(f, opts), *conf);
}
-};
+ po::notify(*conf);
+ if (!conf->count("decoder_conf")) {
+ cerr << "Missing decoder configuration." << endl;
+ cerr << opts << endl;
-inline ostream& _np(ostream& out) { return out << resetiosflags(ios::showpos); }
-inline ostream& _p(ostream& out) { return out << setiosflags(ios::showpos); }
-inline ostream& _p2(ostream& out) { return out << setprecision(2); }
-inline ostream& _p5(ostream& out) { return out << setprecision(5); }
+ return false;
+ }
+ if (!conf->count("bitext")) {
+ cerr << "No input given." << endl;
+ cerr << opts << endl;
-inline void printWordIDVec(vector<WordID>& v, ostream& os=cerr)
-{
- for (unsigned i = 0; i < v.size(); i++) {
- os << TD::Convert(v[i]);
- if (i < v.size()-1) os << " ";
+ return false;
+ }
+ if ((*conf)["output_data"].as<string>() != "") {
+ if ((*conf)["output_data"].as<string>() != "kbest" &&
+ (*conf)["output_data"].as<string>() != "default" &&
+ (*conf)["output_data"].as<string>() != "all") {
+ cerr << "Wrong 'output_data' argument: ";
+ cerr << (*conf)["output_data"].as<string>() << endl;
+ return false;
+ }
}
-}
-template<typename T>
-inline T sign(T z)
-{
- if (z == 0) return 0;
- return z < 0 ? -1 : +1;
+ return true;
}
-
} // namespace
#endif