From a6d8ae2bd3cc2294e17588656e6aa20a96f6fcbc Mon Sep 17 00:00:00 2001
From: Patrick Simianer
Date: Tue, 12 Nov 2013 18:36:03 +0100
Subject: implemented batch tuning
---
training/dtrain/dtrain.cc | 81 ++++++++++++++++++++++------
training/dtrain/examples/standard/dtrain.ini | 4 +-
2 files changed, 67 insertions(+), 18 deletions(-)
(limited to 'training/dtrain')
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index a496f08a..23131810 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -42,6 +42,9 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
("loss_margin", po::value()->default_value(0.), "update if no error in pref pair but model scores this near")
("max_pairs", po::value()->default_value(std::numeric_limits::max()), "max. # of pairs per Sent.")
("pclr", po::value()->default_value("no"), "use a (simple|adagrad) per-coordinate learning rate")
+ ("batch", po::value()->zero_tokens(), "do batch optimization")
+ //("repeat", po::value()->default_value(1), "repeat optimization over kbest list this number of times")
+ //("test-k-best", po::value()->zero_tokens(), "check if optimization works (use repeat >= 2)")
("noup", po::value()->zero_tokens(), "do not update weights");
po::options_description cl("Command Line Options");
cl.add_options()
@@ -126,7 +129,12 @@ main(int argc, char** argv)
const float hi_lo = cfg["hi_lo"].as();
const score_t approx_bleu_d = cfg["approx_bleu_d"].as();
const unsigned max_pairs = cfg["max_pairs"].as();
+ //int repeat = cfg["repeat"].as();
+ //bool test_k_best = false;
+ //if (cfg.count("test-k-best")) test_k_best = true;
weight_t loss_margin = cfg["loss_margin"].as();
+ bool batch = false;
+ if (cfg.count("batch")) batch = true;
if (loss_margin > 9998.) loss_margin = std::numeric_limits::max();
bool scale_bleu_diff = false;
if (cfg.count("scale_bleu_diff")) scale_bleu_diff = true;
@@ -184,10 +192,10 @@ main(int argc, char** argv)
observer->SetScorer(scorer);
// init weights
- vector& dense_weights = decoder.CurrentWeightVector();
+ vector& decoder_weights = decoder.CurrentWeightVector();
SparseVector lambdas, cumulative_penalties, w_average;
- if (cfg.count("input_weights")) Weights::InitFromFile(cfg["input_weights"].as(), &dense_weights);
- Weights::InitSparseVector(dense_weights, &lambdas);
+ if (cfg.count("input_weights")) Weights::InitFromFile(cfg["input_weights"].as(), &decoder_weights);
+ Weights::InitSparseVector(decoder_weights, &lambdas);
// meta params for perceptron, SVM
weight_t eta = cfg["learning_rate"].as();
@@ -245,6 +253,7 @@ main(int argc, char** argv)
cerr << setw(25) << "k " << k << endl;
cerr << setw(25) << "N " << N << endl;
cerr << setw(25) << "T " << T << endl;
+ cerr << setw(25) << "batch " << batch << endl;
cerr << setw(26) << "scorer '" << scorer_str << "'" << endl;
if (scorer_str == "approx_bleu")
cerr << setw(25) << "approx. B discount " << approx_bleu_d << endl;
@@ -267,6 +276,8 @@ main(int argc, char** argv)
cerr << setw(25) << "rescale " << rescale << endl;
cerr << setw(25) << "pclr " << pclr << endl;
cerr << setw(25) << "max pairs " << max_pairs << endl;
+ //cerr << setw(25) << "repeat " << repeat << endl;
+ //cerr << setw(25) << "test k-best " << test_k_best << endl;
cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as() << "'" << endl;
cerr << setw(25) << "input " << "'" << input_fn << "'" << endl;
if (!read_bitext)
@@ -281,17 +292,25 @@ main(int argc, char** argv)
// pclr
SparseVector learning_rates;
+ // batch
+ SparseVector batch_updates;
+ weight_t batch_loss;
+
+ //int did_improve; // FIXME for test-k-best
for (unsigned t = 0; t < T; t++) // T epochs
{
-
+
time_t start, end;
time(&start);
score_t score_sum = 0.;
score_t model_sum(0);
unsigned ii = 0, rank_errors = 0, margin_violations = 0, npairs = 0, f_count = 0, list_sz = 0;
+ batch_loss = 0.;
if (!quiet) cerr << "Iteration #" << t+1 << " of " << T << "." << endl;
+ //did_improve = 0;
+
while(true)
{
@@ -337,7 +356,7 @@ main(int argc, char** argv)
if (next || stop) break;
// weights
- lambdas.init_vector(&dense_weights);
+ lambdas.init_vector(&decoder_weights);
// getting input
vector ref_ids; // reference as vector
@@ -392,33 +411,51 @@ main(int argc, char** argv)
partXYX(samples, pairs, pair_threshold, max_pairs, faster_perceptron, hi_lo);
if (pair_sampling == "PRO")
PROsampling(samples, pairs, pair_threshold, max_pairs);
- npairs += pairs.size();
+ int cur_npairs = pairs.size();
+ npairs += cur_npairs;
+
+ weight_t kbest_loss_first, kbest_loss_last = 0.0;
+//for (int q=0; q < repeat; q++) { // repeat
+
+ weight_t kbest_loss = 0.0; // test-k-best
SparseVector lambdas_copy; // for l1 regularization
SparseVector sum_up; // for pclr
if (l1naive||l1clip||l1cumul) lambdas_copy = lambdas;
for (vector >::iterator it = pairs.begin();
it != pairs.end(); it++) {
- bool rank_error;
+
+ /*if (repeat > 1) {
+ double x = max(0.0, -1.0 * (lambdas.dot(it->first.f) - lambdas.dot(it->second.f)));
+ kbest_loss += x;
+ }*/
+
+ score_t model_diff = it->first.model - it->second.model;
+ bool rank_error = false;
score_t margin;
if (faster_perceptron) { // we only have considering misranked pairs
rank_error = true; // pair sampling already did this for us
margin = std::numeric_limits::max();
} else {
- rank_error = it->first.model <= it->second.model;
- margin = fabs(it->first.model - it->second.model);
+ rank_error = model_diff<=0.0;
+ margin = fabs(model_diff);
if (!rank_error && margin < loss_margin) margin_violations++;
}
if (rank_error) rank_errors++;
if (scale_bleu_diff) eta = it->first.score - it->second.score;
if (rank_error || margin < loss_margin) {
SparseVector diff_vec = it->first.f - it->second.f;
+ if (batch) {
+ batch_loss += max(0., -1.0*model_diff);
+ batch_updates += diff_vec;
+ continue;
+ }
if (pclr != "no") {
sum_up += diff_vec;
} else {
lambdas.plus_eq_v_times_s(diff_vec, eta);
- if (gamma) lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs)); // FIXME
+ if (gamma) lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./cur_npairs));
}
}
}
@@ -487,6 +524,11 @@ main(int argc, char** argv)
}
}
+ //if (q==0) { kbest_loss_first = kbest_loss; }
+ //if (q==repeat-1) { kbest_loss_last = kbest_loss; }
+//}//repeat
+//if((kbest_loss_first - kbest_loss_last) > 0) did_improve++;
+
}
if (rescale) lambdas /= lambdas.l2norm();
@@ -495,14 +537,20 @@ main(int argc, char** argv)
} // input loop
- if (average) w_average += lambdas;
+ if (t == 0) in_sz = ii; // remember size of input (# lines)
- if (scorer_str == "approx_bleu" || scorer_str == "lc_bleu") scorer->Reset();
+ //if (repeat > 1) cout << "did improve? " << did_improve << " out of " << in_sz << endl;
- if (t == 0) {
- in_sz = ii; // remember size of input (# lines)
+ if (batch) {
+ lambdas.plus_eq_v_times_s(batch_updates, eta);
+ if (gamma) lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs));
+ batch_updates.clear();
}
+ if (average) w_average += lambdas;
+
+ if (scorer_str == "approx_bleu" || scorer_str == "lc_bleu") scorer->Reset();
+
// print some stats
score_t score_avg = score_sum/(score_t)in_sz;
score_t model_avg = model_sum/(score_t)in_sz;
@@ -534,6 +582,7 @@ main(int argc, char** argv)
cerr << endl;
cerr << " avg # rank err: ";
cerr << rank_errors/(float)in_sz << endl;
+ if (batch) cerr << " batch loss: " << batch_loss << endl;
cerr << " avg # margin viol: ";
cerr << margin_violations/(float)in_sz << endl;
cerr << " non0 feature count: " << nonz << endl;
@@ -562,9 +611,9 @@ main(int argc, char** argv)
// write weights to file
if (select_weights == "best" || keep) {
- lambdas.init_vector(&dense_weights);
+ lambdas.init_vector(&decoder_weights);
string w_fn = "weights." + boost::lexical_cast(t) + ".gz";
- Weights::WriteToFile(w_fn, dense_weights, true);
+ Weights::WriteToFile(w_fn, decoder_weights, true);
}
} // outer loop
diff --git a/training/dtrain/examples/standard/dtrain.ini b/training/dtrain/examples/standard/dtrain.ini
index 7dbb4ff0..4d096dfb 100644
--- a/training/dtrain/examples/standard/dtrain.ini
+++ b/training/dtrain/examples/standard/dtrain.ini
@@ -11,11 +11,11 @@ print_weights=Glue WordPenalty LanguageModel LanguageModel_OOV PhraseModel_0 Phr
stop_after=10 # stop epoch after 10 inputs
# interesting stuff
-epochs=3 # run over input 3 times
+epochs=100 # run over input 3 times
k=100 # use 100best lists
N=4 # optimize (approx) BLEU4
scorer=fixed_stupid_bleu # use 'stupid' BLEU+1
-learning_rate=1.0 # learning rate, don't care if gamma=0 (perceptron) and loss_margin=0 (not margin perceptron)
+learning_rate=0.0001 # learning rate, don't care if gamma=0 (perceptron) and loss_margin=0 (not margin perceptron)
gamma=0 # use SVM reg
sample_from=kbest # use kbest lists (as opposed to forest)
filter=uniq # only unique entries in kbest (surface form)
--
cgit v1.2.3