summaryrefslogtreecommitdiff
path: root/training/dtrain/dtrain.cc
diff options
context:
space:
mode:
Diffstat (limited to 'training/dtrain/dtrain.cc')
-rw-r--r--training/dtrain/dtrain.cc166
1 files changed, 113 insertions, 53 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index b39fff3e..e563f541 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -13,45 +13,65 @@ main(int argc, char** argv)
if (!dtrain_init(argc, argv, &conf))
return 1;
const size_t k = conf["k"].as<size_t>();
+ const bool unique_kbest = conf["unique_kbest"].as<bool>();
+ const bool forest_sample = conf["forest_sample"].as<bool>();
const string score_name = conf["score"].as<string>();
+ const weight_t nakov_fix = conf["nakov_fix"].as<weight_t>();
+ const weight_t chiang_decay = conf["chiang_decay"].as<weight_t>();
const size_t N = conf["N"].as<size_t>();
const size_t T = conf["iterations"].as<size_t>();
const weight_t eta = conf["learning_rate"].as<weight_t>();
const weight_t margin = conf["margin"].as<weight_t>();
+ const weight_t cut = conf["cut"].as<weight_t>();
+ const bool adjust_cut = conf["adjust"].as<bool>();
+ const bool all_pairs = cut==0;
const bool average = conf["average"].as<bool>();
- const bool structured = conf["struct"].as<bool>();
+ const bool pro = conf["pro_sampling"].as<bool>();
+ const bool structured = conf["structured"].as<bool>();
+ const weight_t threshold = conf["threshold"].as<weight_t>();
+ const size_t max_up = conf["max_pairs"].as<size_t>();
const weight_t l1_reg = conf["l1_reg"].as<weight_t>();
const bool keep = conf["keep"].as<bool>();
const bool noup = conf["disable_learning"].as<bool>();
const string output_fn = conf["output"].as<string>();
- const string output_data_which = conf["output_data"].as<string>();
- const bool output_data = output_data_which!="";
vector<string> print_weights;
boost::split(print_weights, conf["print_weights"].as<string>(),
boost::is_any_of(" "));
+ const string output_updates_fn = conf["output_updates"].as<string>();
+ const bool output_updates = output_updates_fn!="";
+ const string output_raw_fn = conf["output_raw"].as<string>();
+ const bool output_raw = output_raw_fn!="";
- // setup decoder and scorer
+ // setup decoder
register_feature_functions();
SetSilent(true);
ReadFile f(conf["decoder_conf"].as<string>());
Decoder decoder(f.stream());
+
+ // setup scorer & observer
Scorer* scorer;
if (score_name == "nakov") {
- scorer = static_cast<PerSentenceBleuScorer*>(new PerSentenceBleuScorer(N));
+ scorer = static_cast<NakovBleuScorer*>(new NakovBleuScorer(N, nakov_fix));
} else if (score_name == "papineni") {
- scorer = static_cast<BleuScorer*>(new BleuScorer(N));
+ scorer = static_cast<PapineniBleuScorer*>(new PapineniBleuScorer(N));
} else if (score_name == "lin") {
- scorer = static_cast<OriginalPerSentenceBleuScorer*>\
- (new OriginalPerSentenceBleuScorer(N));
+ scorer = static_cast<LinBleuScorer*>(new LinBleuScorer(N));
} else if (score_name == "liang") {
- scorer = static_cast<SmoothPerSentenceBleuScorer*>\
- (new SmoothPerSentenceBleuScorer(N));
+ scorer = static_cast<LiangBleuScorer*>(new LiangBleuScorer(N));
} else if (score_name == "chiang") {
- scorer = static_cast<ApproxBleuScorer*>(new ApproxBleuScorer(N));
+ scorer = static_cast<ChiangBleuScorer*>(new ChiangBleuScorer(N));
+ } else if (score_name == "sum") {
+ scorer = static_cast<SumBleuScorer*>(new SumBleuScorer(N));
} else {
assert(false);
}
- ScoredKbest* observer = new ScoredKbest(k, scorer);
+ HypSampler* observer;
+ if (forest_sample)
+ observer = new KSampler(k, scorer);
+ else if (unique_kbest)
+ observer = new KBestSampler(k, scorer);
+ else
+ observer = new KBestNoFilterSampler(k, scorer);
// weights
vector<weight_t>& decoder_weights = decoder.CurrentWeightVector();
@@ -65,22 +85,46 @@ main(int argc, char** argv)
string input_fn = conf["bitext"].as<string>();
ReadFile input(input_fn);
vector<string> buf; // decoder only accepts strings as input
- vector<vector<Ngrams> > buf_ngs; // compute ngrams and lengths of references
- vector<vector<size_t> > buf_ls; // just once
+ vector<vector<Ngrams> > buffered_ngrams; // compute ngrams and lengths of references
+ vector<vector<size_t> > buffered_lengths; // (just once)
size_t input_sz = 0;
- cerr << _p4;
+ cerr << setprecision(4);
// output configuration
cerr << "Parameters:" << endl;
cerr << setw(25) << "bitext " << "'" << input_fn << "'" << endl;
cerr << setw(25) << "k " << k << endl;
+ if (unique_kbest && !forest_sample)
+ cerr << setw(25) << "unique k-best " << unique_kbest << endl;
+ if (forest_sample)
+ cerr << setw(25) << "forest " << forest_sample << endl;
+ if (all_pairs)
+ cerr << setw(25) << "all pairs " << all_pairs << endl;
+ else if (pro)
+ cerr << setw(25) << "PRO " << pro << endl;
cerr << setw(25) << "score " << "'" << score_name << "'" << endl;
+ if (score_name == "nakov")
+ cerr << setw(25) << "nakov fix " << nakov_fix << endl;
+ if (score_name == "chiang")
+ cerr << setw(25) << "chiang decay " << chiang_decay << endl;
cerr << setw(25) << "N " << N << endl;
cerr << setw(25) << "T " << T << endl;
cerr << setw(25) << "learning rate " << eta << endl;
cerr << setw(25) << "margin " << margin << endl;
+ if (!structured) {
+ cerr << setw(25) << "cut " << cut << endl;
+ cerr << setw(25) << "adjust " << adjust_cut << endl;
+ } else {
+ cerr << setw(25) << "struct. obj " << structured << endl;
+ }
+ if (threshold > 0)
+ cerr << setw(25) << "threshold " << threshold << endl;
+ if (max_up != numeric_limits<size_t>::max())
+ cerr << setw(25) << "max up. " << max_up << endl;
+ if (noup)
+ cerr << setw(25) << "no up. " << noup << endl;
cerr << setw(25) << "average " << average << endl;
- cerr << setw(25) << "l1 reg " << l1_reg << endl;
+ cerr << setw(25) << "l1 reg. " << l1_reg << endl;
cerr << setw(25) << "decoder conf " << "'"
<< conf["decoder_conf"].as<string>() << "'" << endl;
cerr << setw(25) << "input " << "'" << input_fn << "'" << endl;
@@ -89,6 +133,8 @@ main(int argc, char** argv)
cerr << setw(25) << "weights in " << "'"
<< conf["input_weights"].as<string>() << "'" << endl;
}
+ if (noup)
+ cerr << setw(25) << "no updates!" << endl;
cerr << "(1 dot per processed input)" << endl;
// meta
@@ -96,6 +142,13 @@ main(int argc, char** argv)
size_t best_iteration = 0;
time_t total_time = 0.;
+ // output
+ WriteFile raw_out;
+ if (output_raw) raw_out.Init(output_raw_fn);
+ WriteFile updates_out;
+ if (output_updates) updates_out.Init(output_raw_fn);
+
+
for (size_t t = 0; t < T; t++) // T iterations
{
@@ -120,16 +173,16 @@ main(int argc, char** argv)
boost::algorithm::split_regex(parts, in, boost::regex(" \\|\\|\\| "));
buf.push_back(parts[0]);
parts.erase(parts.begin());
- buf_ngs.push_back({});
- buf_ls.push_back({});
+ buffered_ngrams.push_back({});
+ buffered_lengths.push_back({});
for (auto s: parts) {
vector<WordID> r;
vector<string> toks;
boost::split(toks, s, boost::is_any_of(" "));
for (auto tok: toks)
r.push_back(TD::Convert(tok));
- buf_ngs.back().emplace_back(MakeNgrams(r, N));
- buf_ls.back().push_back(r.size());
+ buffered_ngrams.back().emplace_back(ngrams(r, N));
+ buffered_lengths.back().push_back(r.size());
}
}
} else {
@@ -155,50 +208,54 @@ main(int argc, char** argv)
// decode
if (t > 0 || i > 0)
lambdas.init_vector(&decoder_weights);
- observer->SetReference(buf_ngs[i], buf_ls[i]);
+ observer->reference_ngrams = &buffered_ngrams[i];
+ observer->reference_lengths = &buffered_lengths[i];
decoder.Decode(buf[i], observer);
- vector<ScoredHyp>* samples = observer->GetSamples();
-
- // stats for 1best
- gold_sum += samples->front().gold;
- model_sum += samples->front().model;
- feature_count += observer->GetFeatureCount();
- list_sz += observer->GetSize();
-
- if (output_data) {
- if (output_data_which == "kbest") {
- OutputKbest(samples);
- } else if (output_data_which == "default") {
- OutputMultipartitePairs(samples, margin);
- } else if (output_data_which == "all") {
- OutputAllPairs(samples);
- }
- }
+ vector<Hyp>* sample = &(observer->sample);
+
+ // stats for 1-best
+ gold_sum += sample->front().gold;
+ model_sum += sample->front().model;
+ feature_count += observer->feature_count;
+ list_sz += observer->effective_size;
+
+ if (output_raw)
+ output_sample(sample);
- // get pairs and update
+ // update model
if (!noup) {
SparseVector<weight_t> updates;
if (structured)
- num_up += CollectUpdatesStruct(samples, updates);
+ num_up += update_structured(sample, updates, margin/*,
+ output_updates, updates_out.get()*/); // FIXME
+ else if (all_pairs)
+ num_up += updates_all(sample, updates, max_up, threshold/*,
+ output_updates, updates_out.get()*/); // FIXME
+ else if (pro)
+ num_up += updates_pro(sample, updates, cut, max_up, threshold/*,
+ output_updates, updates_out.get()*/); // FIXME
else
- num_up += CollectUpdates(samples, updates, margin);
+ num_up += updates_multipartite(sample, updates, cut, margin,
+ max_up, threshold, adjust_cut/*,
+ output_updates, updates_out.get()*/); // FIXME
SparseVector<weight_t> lambdas_copy;
if (l1_reg)
lambdas_copy = lambdas;
lambdas.plus_eq_v_times_s(updates, eta);
- // update context for approx. BLEU
+ // update context for Chiang's approx. BLEU
if (score_name == "chiang") {
- for (auto it: *samples) {
+ for (auto it: *sample) {
if (it.rank == 0) {
- scorer->UpdateContext(it.w, buf_ngs[i], buf_ls[i], 0.9);
+ scorer->update_context(it.w, buffered_ngrams[i],
+ buffered_lengths[i], chiang_decay);
break;
}
}
}
- // l1 regularization
+ // \ell_1 regularization
// NB: regularization is done after each sentence,
// not after every single pair!
if (l1_reg) {
@@ -234,19 +291,22 @@ main(int argc, char** argv)
// stats
weight_t gold_avg = gold_sum/(weight_t)input_sz;
- cerr << _p << "WEIGHTS" << endl;
+ cerr << setiosflags(ios::showpos) << "WEIGHTS" << endl;
for (auto name: print_weights)
- cerr << setw(18) << name << " = " << lambdas.get(FD::Convert(name)) << endl;
+ cerr << setw(18) << name << " = "
+ << lambdas.get(FD::Convert(name)) << endl;
cerr << " ---" << endl;
- cerr << _np << " 1best avg score: " << gold_avg*100;
- cerr << _p << " (" << (gold_avg-gold_prev)*100 << ")" << endl;
+ cerr << resetiosflags(ios::showpos)
+ << " 1best avg score: " << gold_avg*100;
+ cerr << setiosflags(ios::showpos) << " ("
+ << (gold_avg-gold_prev)*100 << ")" << endl;
cerr << " 1best avg model score: "
<< model_sum/(weight_t)input_sz << endl;
cerr << " avg # updates: ";
- cerr << _np << num_up/(float)input_sz << endl;
- cerr << " non-0 feature count: " << lambdas.num_nonzero() << endl;
- cerr << " avg f count: " << feature_count/(float)list_sz << endl;
- cerr << " avg list sz: " << list_sz/(float)input_sz << endl;
+ cerr << resetiosflags(ios::showpos) << num_up/(float)input_sz << endl;
+ cerr << " non-0 feature count: " << lambdas.num_nonzero() << endl;
+ cerr << " avg f count: " << feature_count/(float)list_sz << endl;
+ cerr << " avg list sz: " << list_sz/(float)input_sz << endl;
if (gold_avg > best) {
best = gold_avg;