summaryrefslogtreecommitdiff
path: root/training/dtrain
diff options
context:
space:
mode:
Diffstat (limited to 'training/dtrain')
-rw-r--r--training/dtrain/README.md18
-rw-r--r--training/dtrain/dtrain.cc166
-rw-r--r--training/dtrain/dtrain.h110
3 files changed, 197 insertions, 97 deletions
diff --git a/training/dtrain/README.md b/training/dtrain/README.md
index 73a6a5a5..dc473568 100644
--- a/training/dtrain/README.md
+++ b/training/dtrain/README.md
@@ -16,6 +16,24 @@ Running
-------
Download runnable examples for all use cases from [1] and extract here.
+TODO
+----
+ * "stop_after" stop after X inputs
+ * "select_weights" average, best, last
+ * "rescale" rescale weight vector
+ * implement SVM objective?
+ * other variants of l1 regularization?
+ * l2 regularization?
+ * l1/l2 regularization?
+ * scale updates by bleu difference
+ * AdaGrad, per-coordinate learning rates
+ * batch update
+ * "repeat" iterate over k-best lists
+ * show k-best loss improvement
+ * "quiet"
+ * "verbose"
+ * fix output
+
Legal
-----
Copyright (c) 2012-2015 by Patrick Simianer <p@simianer.de>
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index b39fff3e..e563f541 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -13,45 +13,65 @@ main(int argc, char** argv)
if (!dtrain_init(argc, argv, &conf))
return 1;
const size_t k = conf["k"].as<size_t>();
+ const bool unique_kbest = conf["unique_kbest"].as<bool>();
+ const bool forest_sample = conf["forest_sample"].as<bool>();
const string score_name = conf["score"].as<string>();
+ const weight_t nakov_fix = conf["nakov_fix"].as<weight_t>();
+ const weight_t chiang_decay = conf["chiang_decay"].as<weight_t>();
const size_t N = conf["N"].as<size_t>();
const size_t T = conf["iterations"].as<size_t>();
const weight_t eta = conf["learning_rate"].as<weight_t>();
const weight_t margin = conf["margin"].as<weight_t>();
+ const weight_t cut = conf["cut"].as<weight_t>();
+ const bool adjust_cut = conf["adjust"].as<bool>();
+ const bool all_pairs = cut==0;
const bool average = conf["average"].as<bool>();
- const bool structured = conf["struct"].as<bool>();
+ const bool pro = conf["pro_sampling"].as<bool>();
+ const bool structured = conf["structured"].as<bool>();
+ const weight_t threshold = conf["threshold"].as<weight_t>();
+ const size_t max_up = conf["max_pairs"].as<size_t>();
const weight_t l1_reg = conf["l1_reg"].as<weight_t>();
const bool keep = conf["keep"].as<bool>();
const bool noup = conf["disable_learning"].as<bool>();
const string output_fn = conf["output"].as<string>();
- const string output_data_which = conf["output_data"].as<string>();
- const bool output_data = output_data_which!="";
vector<string> print_weights;
boost::split(print_weights, conf["print_weights"].as<string>(),
boost::is_any_of(" "));
+ const string output_updates_fn = conf["output_updates"].as<string>();
+ const bool output_updates = output_updates_fn!="";
+ const string output_raw_fn = conf["output_raw"].as<string>();
+ const bool output_raw = output_raw_fn!="";
- // setup decoder and scorer
+ // setup decoder
register_feature_functions();
SetSilent(true);
ReadFile f(conf["decoder_conf"].as<string>());
Decoder decoder(f.stream());
+
+ // setup scorer & observer
Scorer* scorer;
if (score_name == "nakov") {
- scorer = static_cast<PerSentenceBleuScorer*>(new PerSentenceBleuScorer(N));
+ scorer = static_cast<NakovBleuScorer*>(new NakovBleuScorer(N, nakov_fix));
} else if (score_name == "papineni") {
- scorer = static_cast<BleuScorer*>(new BleuScorer(N));
+ scorer = static_cast<PapineniBleuScorer*>(new PapineniBleuScorer(N));
} else if (score_name == "lin") {
- scorer = static_cast<OriginalPerSentenceBleuScorer*>\
- (new OriginalPerSentenceBleuScorer(N));
+ scorer = static_cast<LinBleuScorer*>(new LinBleuScorer(N));
} else if (score_name == "liang") {
- scorer = static_cast<SmoothPerSentenceBleuScorer*>\
- (new SmoothPerSentenceBleuScorer(N));
+ scorer = static_cast<LiangBleuScorer*>(new LiangBleuScorer(N));
} else if (score_name == "chiang") {
- scorer = static_cast<ApproxBleuScorer*>(new ApproxBleuScorer(N));
+ scorer = static_cast<ChiangBleuScorer*>(new ChiangBleuScorer(N));
+ } else if (score_name == "sum") {
+ scorer = static_cast<SumBleuScorer*>(new SumBleuScorer(N));
} else {
assert(false);
}
- ScoredKbest* observer = new ScoredKbest(k, scorer);
+ HypSampler* observer;
+ if (forest_sample)
+ observer = new KSampler(k, scorer);
+ else if (unique_kbest)
+ observer = new KBestSampler(k, scorer);
+ else
+ observer = new KBestNoFilterSampler(k, scorer);
// weights
vector<weight_t>& decoder_weights = decoder.CurrentWeightVector();
@@ -65,22 +85,46 @@ main(int argc, char** argv)
string input_fn = conf["bitext"].as<string>();
ReadFile input(input_fn);
vector<string> buf; // decoder only accepts strings as input
- vector<vector<Ngrams> > buf_ngs; // compute ngrams and lengths of references
- vector<vector<size_t> > buf_ls; // just once
+ vector<vector<Ngrams> > buffered_ngrams; // compute ngrams and lengths of references
+ vector<vector<size_t> > buffered_lengths; // (just once)
size_t input_sz = 0;
- cerr << _p4;
+ cerr << setprecision(4);
// output configuration
cerr << "Parameters:" << endl;
cerr << setw(25) << "bitext " << "'" << input_fn << "'" << endl;
cerr << setw(25) << "k " << k << endl;
+ if (unique_kbest && !forest_sample)
+ cerr << setw(25) << "unique k-best " << unique_kbest << endl;
+ if (forest_sample)
+ cerr << setw(25) << "forest " << forest_sample << endl;
+ if (all_pairs)
+ cerr << setw(25) << "all pairs " << all_pairs << endl;
+ else if (pro)
+ cerr << setw(25) << "PRO " << pro << endl;
cerr << setw(25) << "score " << "'" << score_name << "'" << endl;
+ if (score_name == "nakov")
+ cerr << setw(25) << "nakov fix " << nakov_fix << endl;
+ if (score_name == "chiang")
+ cerr << setw(25) << "chiang decay " << chiang_decay << endl;
cerr << setw(25) << "N " << N << endl;
cerr << setw(25) << "T " << T << endl;
cerr << setw(25) << "learning rate " << eta << endl;
cerr << setw(25) << "margin " << margin << endl;
+ if (!structured) {
+ cerr << setw(25) << "cut " << cut << endl;
+ cerr << setw(25) << "adjust " << adjust_cut << endl;
+ } else {
+ cerr << setw(25) << "struct. obj " << structured << endl;
+ }
+ if (threshold > 0)
+ cerr << setw(25) << "threshold " << threshold << endl;
+ if (max_up != numeric_limits<size_t>::max())
+ cerr << setw(25) << "max up. " << max_up << endl;
+ if (noup)
+ cerr << setw(25) << "no up. " << noup << endl;
cerr << setw(25) << "average " << average << endl;
- cerr << setw(25) << "l1 reg " << l1_reg << endl;
+ cerr << setw(25) << "l1 reg. " << l1_reg << endl;
cerr << setw(25) << "decoder conf " << "'"
<< conf["decoder_conf"].as<string>() << "'" << endl;
cerr << setw(25) << "input " << "'" << input_fn << "'" << endl;
@@ -89,6 +133,8 @@ main(int argc, char** argv)
cerr << setw(25) << "weights in " << "'"
<< conf["input_weights"].as<string>() << "'" << endl;
}
+ if (noup)
+ cerr << setw(25) << "no updates!" << endl;
cerr << "(1 dot per processed input)" << endl;
// meta
@@ -96,6 +142,13 @@ main(int argc, char** argv)
size_t best_iteration = 0;
time_t total_time = 0.;
+ // output
+ WriteFile raw_out;
+ if (output_raw) raw_out.Init(output_raw_fn);
+ WriteFile updates_out;
+ if (output_updates) updates_out.Init(output_raw_fn);
+
+
for (size_t t = 0; t < T; t++) // T iterations
{
@@ -120,16 +173,16 @@ main(int argc, char** argv)
boost::algorithm::split_regex(parts, in, boost::regex(" \\|\\|\\| "));
buf.push_back(parts[0]);
parts.erase(parts.begin());
- buf_ngs.push_back({});
- buf_ls.push_back({});
+ buffered_ngrams.push_back({});
+ buffered_lengths.push_back({});
for (auto s: parts) {
vector<WordID> r;
vector<string> toks;
boost::split(toks, s, boost::is_any_of(" "));
for (auto tok: toks)
r.push_back(TD::Convert(tok));
- buf_ngs.back().emplace_back(MakeNgrams(r, N));
- buf_ls.back().push_back(r.size());
+ buffered_ngrams.back().emplace_back(ngrams(r, N));
+ buffered_lengths.back().push_back(r.size());
}
}
} else {
@@ -155,50 +208,54 @@ main(int argc, char** argv)
// decode
if (t > 0 || i > 0)
lambdas.init_vector(&decoder_weights);
- observer->SetReference(buf_ngs[i], buf_ls[i]);
+ observer->reference_ngrams = &buffered_ngrams[i];
+ observer->reference_lengths = &buffered_lengths[i];
decoder.Decode(buf[i], observer);
- vector<ScoredHyp>* samples = observer->GetSamples();
-
- // stats for 1best
- gold_sum += samples->front().gold;
- model_sum += samples->front().model;
- feature_count += observer->GetFeatureCount();
- list_sz += observer->GetSize();
-
- if (output_data) {
- if (output_data_which == "kbest") {
- OutputKbest(samples);
- } else if (output_data_which == "default") {
- OutputMultipartitePairs(samples, margin);
- } else if (output_data_which == "all") {
- OutputAllPairs(samples);
- }
- }
+ vector<Hyp>* sample = &(observer->sample);
+
+ // stats for 1-best
+ gold_sum += sample->front().gold;
+ model_sum += sample->front().model;
+ feature_count += observer->feature_count;
+ list_sz += observer->effective_size;
+
+ if (output_raw)
+ output_sample(sample);
- // get pairs and update
+ // update model
if (!noup) {
SparseVector<weight_t> updates;
if (structured)
- num_up += CollectUpdatesStruct(samples, updates);
+ num_up += update_structured(sample, updates, margin/*,
+ output_updates, updates_out.get()*/); // FIXME
+ else if (all_pairs)
+ num_up += updates_all(sample, updates, max_up, threshold/*,
+ output_updates, updates_out.get()*/); // FIXME
+ else if (pro)
+ num_up += updates_pro(sample, updates, cut, max_up, threshold/*,
+ output_updates, updates_out.get()*/); // FIXME
else
- num_up += CollectUpdates(samples, updates, margin);
+ num_up += updates_multipartite(sample, updates, cut, margin,
+ max_up, threshold, adjust_cut/*,
+ output_updates, updates_out.get()*/); // FIXME
SparseVector<weight_t> lambdas_copy;
if (l1_reg)
lambdas_copy = lambdas;
lambdas.plus_eq_v_times_s(updates, eta);
- // update context for approx. BLEU
+ // update context for Chiang's approx. BLEU
if (score_name == "chiang") {
- for (auto it: *samples) {
+ for (auto it: *sample) {
if (it.rank == 0) {
- scorer->UpdateContext(it.w, buf_ngs[i], buf_ls[i], 0.9);
+ scorer->update_context(it.w, buffered_ngrams[i],
+ buffered_lengths[i], chiang_decay);
break;
}
}
}
- // l1 regularization
+ // \ell_1 regularization
// NB: regularization is done after each sentence,
// not after every single pair!
if (l1_reg) {
@@ -234,19 +291,22 @@ main(int argc, char** argv)
// stats
weight_t gold_avg = gold_sum/(weight_t)input_sz;
- cerr << _p << "WEIGHTS" << endl;
+ cerr << setiosflags(ios::showpos) << "WEIGHTS" << endl;
for (auto name: print_weights)
- cerr << setw(18) << name << " = " << lambdas.get(FD::Convert(name)) << endl;
+ cerr << setw(18) << name << " = "
+ << lambdas.get(FD::Convert(name)) << endl;
cerr << " ---" << endl;
- cerr << _np << " 1best avg score: " << gold_avg*100;
- cerr << _p << " (" << (gold_avg-gold_prev)*100 << ")" << endl;
+ cerr << resetiosflags(ios::showpos)
+ << " 1best avg score: " << gold_avg*100;
+ cerr << setiosflags(ios::showpos) << " ("
+ << (gold_avg-gold_prev)*100 << ")" << endl;
cerr << " 1best avg model score: "
<< model_sum/(weight_t)input_sz << endl;
cerr << " avg # updates: ";
- cerr << _np << num_up/(float)input_sz << endl;
- cerr << " non-0 feature count: " << lambdas.num_nonzero() << endl;
- cerr << " avg f count: " << feature_count/(float)list_sz << endl;
- cerr << " avg list sz: " << list_sz/(float)input_sz << endl;
+ cerr << resetiosflags(ios::showpos) << num_up/(float)input_sz << endl;
+ cerr << " non-0 feature count: " << lambdas.num_nonzero() << endl;
+ cerr << " avg f count: " << feature_count/(float)list_sz << endl;
+ cerr << " avg list sz: " << list_sz/(float)input_sz << endl;
if (gold_avg > best) {
best = gold_avg;
diff --git a/training/dtrain/dtrain.h b/training/dtrain/dtrain.h
index 0bbb5c9b..18a7dbdc 100644
--- a/training/dtrain/dtrain.h
+++ b/training/dtrain/dtrain.h
@@ -22,59 +22,90 @@ namespace po = boost::program_options;
namespace dtrain
{
-struct ScoredHyp
+struct Hyp
{
+ Hyp() {}
+ Hyp(vector<WordID> w, SparseVector<weight_t> f, weight_t model, weight_t gold,
+ size_t rank) : w(w), f(f), model(model), gold(gold), rank(rank) {}
+
vector<WordID> w;
SparseVector<weight_t> f;
weight_t model, gold;
size_t rank;
};
-inline void
-PrintWordIDVec(vector<WordID>& v, ostream& os=cerr)
-{
- for (size_t i = 0; i < v.size(); i++) {
- os << TD::Convert(v[i]);
- if (i < v.size()-1) os << " ";
- }
-}
-
-inline ostream& _np(ostream& out) { return out << resetiosflags(ios::showpos); }
-inline ostream& _p(ostream& out) { return out << setiosflags(ios::showpos); }
-inline ostream& _p4(ostream& out) { return out << setprecision(4); }
-
bool
-dtrain_init(int argc, char** argv, po::variables_map* conf)
+dtrain_init(int argc,
+ char** argv,
+ po::variables_map* conf)
{
po::options_description opts("Configuration File Options");
opts.add_options()
- ("bitext,b", po::value<string>(), "bitext")
- ("decoder_conf,C", po::value<string>(), "configuration file for decoder")
- ("iterations,T", po::value<size_t>()->default_value(15), "number of iterations T (per shard)")
- ("k", po::value<size_t>()->default_value(100), "size of kbest list")
- ("learning_rate,l", po::value<weight_t>()->default_value(0.00001), "learning rate")
- ("l1_reg,r", po::value<weight_t>()->default_value(0.), "l1 regularization strength")
- ("margin,m", po::value<weight_t>()->default_value(1.0), "margin for margin perceptron")
- ("score,s", po::value<string>()->default_value("chiang"), "per-sentence BLEU approx.")
- ("N", po::value<size_t>()->default_value(4), "N for BLEU approximation")
- ("input_weights,w", po::value<string>(), "input weights file")
- ("average,a", po::bool_switch()->default_value(true), "output average weights")
- ("keep,K", po::bool_switch()->default_value(false), "output a weight file per iteration")
- ("struct,S", po::bool_switch()->default_value(false), "structured SGD with hope/fear")
- ("output,o", po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT")
- ("disable_learning,X", po::bool_switch()->default_value(false), "disable learning")
- ("output_data,D", po::value<string>()->default_value(""), "output data to STDOUT; arg. is 'kbest', 'default' or 'all'")
+ ("bitext,b", po::value<string>(),
+ "bitext, source and references in a single file [e ||| f]")
+ ("decoder_conf,C", po::value<string>(),
+ "decoder configuration file")
+ ("iterations,T", po::value<size_t>()->default_value(15),
+ "number of iterations T")
+ ("k", po::value<size_t>()->default_value(100),
+ "sample size per input (e.g. size of k-best lists)")
+ ("unique_kbest", po::bool_switch()->default_value(true),
+ "unique k-best lists")
+ ("forest_sample", po::bool_switch()->default_value(false),
+ "sample k hyptheses from forest instead of using k-best list")
+ ("learning_rate,l", po::value<weight_t>()->default_value(0.00001),
+ "learning rate [only meaningful if margin>0 or input weights are given]")
+ ("l1_reg,r", po::value<weight_t>()->default_value(0.),
+ "l1 regularization strength [see Tsuruoka, Tsujii and Ananiadou (2009)]")
+ ("margin,m", po::value<weight_t>()->default_value(1.0),
+ "margin for margin perceptron [set =0 for standard perceptron]")
+ ("cut,u", po::value<weight_t>()->default_value(0.1),
+ "use top/bottom 10% (default) of k-best as 'good' and 'bad' for \
+pair sampling, 0 to use all pairs TODO")
+ ("adjust,A", po::bool_switch()->default_value(false),
+ "adjust cut for optimal pos. in k-best to cut")
+ ("score,s", po::value<string>()->default_value("chiang"),
+ "per-sentence BLEU (approx.)")
+ ("nakov_fix", po::value<weight_t>()->default_value(1.0),
+ "add to reference length [see score.h]")
+ ("chiang_decay", po::value<weight_t>()->default_value(0.9),
+ "decaying factor for Chiang's approx. BLEU")
+ ("N", po::value<size_t>()->default_value(4),
+ "N for BLEU approximation")
+ ("input_weights,w", po::value<string>(),
+ "weights to initialize model")
+ ("average,a", po::bool_switch()->default_value(true),
+ "output average weights")
+ ("keep,K", po::bool_switch()->default_value(false),
+ "output a weight file per iteration [as weights.T.gz]")
+ ("structured,S", po::bool_switch()->default_value(false),
+ "structured prediction objective [hope/fear] w/ SGD")
+ ("pro_sampling", po::bool_switch()->default_value(false),
+ "updates from pairs selected as shown in Fig.4 of (Hopkins and May, 2011) [Gamma=max_pairs (default 5000), Xi=cut (default 50); threshold default 0.05]")
+ ("threshold", po::value<weight_t>()->default_value(0.),
+ "(min.) threshold in terms of gold score for pair selection")
+ ("max_pairs",
+ po::value<size_t>()->default_value(numeric_limits<size_t>::max()),
+ "max. number of updates/pairs")
+ ("output,o", po::value<string>()->default_value("-"),
+ "output weights file, '-' for STDOUT")
+ ("disable_learning,X", po::bool_switch()->default_value(false),
+ "fix model")
+ ("output_updates,U", po::value<string>()->default_value(""),
+ "output updates (diff. vectors) [to filename]")
+ ("output_raw,R", po::value<string>()->default_value(""),
+ "output raw data (e.g. k-best lists) [to filename]")
("print_weights,P", po::value<string>()->default_value("EgivenFCoherent SampleCountF CountEF MaxLexFgivenE MaxLexEgivenF IsSingletonF IsSingletonFE Glue WordPenalty PassThrough LanguageModel LanguageModel_OOV"),
- "list of weights to print after each iteration");
+ "list of weights to print after each iteration");
po::options_description clopts("Command Line Options");
clopts.add_options()
("conf,c", po::value<string>(), "dtrain configuration file")
("help,h", po::bool_switch(), "display options");
opts.add(clopts);
po::store(parse_command_line(argc, argv, opts), *conf);
- cerr << "dtrain" << endl << endl;
+ cerr << "*dtrain*" << endl << endl;
if ((*conf)["help"].as<bool>()) {
- cerr << opts << endl;
+ cerr << setprecision(3) << opts << endl;
return false;
}
@@ -90,20 +121,11 @@ dtrain_init(int argc, char** argv, po::variables_map* conf)
return false;
}
if (!conf->count("bitext")) {
- cerr << "No input given." << endl;
+ cerr << "No input bitext." << endl;
cerr << opts << endl;
return false;
}
- if ((*conf)["output_data"].as<string>() != "") {
- if ((*conf)["output_data"].as<string>() != "kbest" &&
- (*conf)["output_data"].as<string>() != "default" &&
- (*conf)["output_data"].as<string>() != "all") {
- cerr << "Wrong 'output_data' argument: ";
- cerr << (*conf)["output_data"].as<string>() << endl;
- return false;
- }
- }
return true;
}