summaryrefslogtreecommitdiff
path: root/training/dtrain/dtrain.cc
diff options
context:
space:
mode:
authorPatrick Simianer <p@simianer.de>2013-11-12 18:36:03 +0100
committerPatrick Simianer <p@simianer.de>2013-11-12 18:36:03 +0100
commita8ea0a66b798326061bc9f0da153b96b730130f1 (patch)
tree3bd01a848b10c43be7182add1fc4edb19d73244c /training/dtrain/dtrain.cc
parentdecd2c4b1d4fb42a73a3217f347ea8f317e50869 (diff)
implemented batch tuning
Diffstat (limited to 'training/dtrain/dtrain.cc')
-rw-r--r--training/dtrain/dtrain.cc81
1 files changed, 65 insertions, 16 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index a496f08a..23131810 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -42,6 +42,9 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
("loss_margin", po::value<weight_t>()->default_value(0.), "update if no error in pref pair but model scores this near")
("max_pairs", po::value<unsigned>()->default_value(std::numeric_limits<unsigned>::max()), "max. # of pairs per Sent.")
("pclr", po::value<string>()->default_value("no"), "use a (simple|adagrad) per-coordinate learning rate")
+ ("batch", po::value<bool>()->zero_tokens(), "do batch optimization")
+ //("repeat", po::value<int>()->default_value(1), "repeat optimization over kbest list this number of times")
+ //("test-k-best", po::value<bool>()->zero_tokens(), "check if optimization works (use repeat >= 2)")
("noup", po::value<bool>()->zero_tokens(), "do not update weights");
po::options_description cl("Command Line Options");
cl.add_options()
@@ -126,7 +129,12 @@ main(int argc, char** argv)
const float hi_lo = cfg["hi_lo"].as<float>();
const score_t approx_bleu_d = cfg["approx_bleu_d"].as<score_t>();
const unsigned max_pairs = cfg["max_pairs"].as<unsigned>();
+ //int repeat = cfg["repeat"].as<int>();
+ //bool test_k_best = false;
+ //if (cfg.count("test-k-best")) test_k_best = true;
weight_t loss_margin = cfg["loss_margin"].as<weight_t>();
+ bool batch = false;
+ if (cfg.count("batch")) batch = true;
if (loss_margin > 9998.) loss_margin = std::numeric_limits<float>::max();
bool scale_bleu_diff = false;
if (cfg.count("scale_bleu_diff")) scale_bleu_diff = true;
@@ -184,10 +192,10 @@ main(int argc, char** argv)
observer->SetScorer(scorer);
// init weights
- vector<weight_t>& dense_weights = decoder.CurrentWeightVector();
+ vector<weight_t>& decoder_weights = decoder.CurrentWeightVector();
SparseVector<weight_t> lambdas, cumulative_penalties, w_average;
- if (cfg.count("input_weights")) Weights::InitFromFile(cfg["input_weights"].as<string>(), &dense_weights);
- Weights::InitSparseVector(dense_weights, &lambdas);
+ if (cfg.count("input_weights")) Weights::InitFromFile(cfg["input_weights"].as<string>(), &decoder_weights);
+ Weights::InitSparseVector(decoder_weights, &lambdas);
// meta params for perceptron, SVM
weight_t eta = cfg["learning_rate"].as<weight_t>();
@@ -245,6 +253,7 @@ main(int argc, char** argv)
cerr << setw(25) << "k " << k << endl;
cerr << setw(25) << "N " << N << endl;
cerr << setw(25) << "T " << T << endl;
+ cerr << setw(25) << "batch " << batch << endl;
cerr << setw(26) << "scorer '" << scorer_str << "'" << endl;
if (scorer_str == "approx_bleu")
cerr << setw(25) << "approx. B discount " << approx_bleu_d << endl;
@@ -267,6 +276,8 @@ main(int argc, char** argv)
cerr << setw(25) << "rescale " << rescale << endl;
cerr << setw(25) << "pclr " << pclr << endl;
cerr << setw(25) << "max pairs " << max_pairs << endl;
+ //cerr << setw(25) << "repeat " << repeat << endl;
+ //cerr << setw(25) << "test k-best " << test_k_best << endl;
cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as<string>() << "'" << endl;
cerr << setw(25) << "input " << "'" << input_fn << "'" << endl;
if (!read_bitext)
@@ -281,17 +292,25 @@ main(int argc, char** argv)
// pclr
SparseVector<weight_t> learning_rates;
+ // batch
+ SparseVector<weight_t> batch_updates;
+ weight_t batch_loss;
+
+ //int did_improve; // FIXME for test-k-best
for (unsigned t = 0; t < T; t++) // T epochs
{
-
+
time_t start, end;
time(&start);
score_t score_sum = 0.;
score_t model_sum(0);
unsigned ii = 0, rank_errors = 0, margin_violations = 0, npairs = 0, f_count = 0, list_sz = 0;
+ batch_loss = 0.;
if (!quiet) cerr << "Iteration #" << t+1 << " of " << T << "." << endl;
+ //did_improve = 0;
+
while(true)
{
@@ -337,7 +356,7 @@ main(int argc, char** argv)
if (next || stop) break;
// weights
- lambdas.init_vector(&dense_weights);
+ lambdas.init_vector(&decoder_weights);
// getting input
vector<WordID> ref_ids; // reference as vector<WordID>
@@ -392,33 +411,51 @@ main(int argc, char** argv)
partXYX(samples, pairs, pair_threshold, max_pairs, faster_perceptron, hi_lo);
if (pair_sampling == "PRO")
PROsampling(samples, pairs, pair_threshold, max_pairs);
- npairs += pairs.size();
+ int cur_npairs = pairs.size();
+ npairs += cur_npairs;
+
+ weight_t kbest_loss_first, kbest_loss_last = 0.0;
+//for (int q=0; q < repeat; q++) { // repeat
+
+ weight_t kbest_loss = 0.0; // test-k-best
SparseVector<weight_t> lambdas_copy; // for l1 regularization
SparseVector<weight_t> sum_up; // for pclr
if (l1naive||l1clip||l1cumul) lambdas_copy = lambdas;
for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin();
it != pairs.end(); it++) {
- bool rank_error;
+
+ /*if (repeat > 1) {
+ double x = max(0.0, -1.0 * (lambdas.dot(it->first.f) - lambdas.dot(it->second.f)));
+ kbest_loss += x;
+ }*/
+
+ score_t model_diff = it->first.model - it->second.model;
+ bool rank_error = false;
score_t margin;
if (faster_perceptron) { // we only have considering misranked pairs
rank_error = true; // pair sampling already did this for us
margin = std::numeric_limits<float>::max();
} else {
- rank_error = it->first.model <= it->second.model;
- margin = fabs(it->first.model - it->second.model);
+ rank_error = model_diff<=0.0;
+ margin = fabs(model_diff);
if (!rank_error && margin < loss_margin) margin_violations++;
}
if (rank_error) rank_errors++;
if (scale_bleu_diff) eta = it->first.score - it->second.score;
if (rank_error || margin < loss_margin) {
SparseVector<weight_t> diff_vec = it->first.f - it->second.f;
+ if (batch) {
+ batch_loss += max(0., -1.0*model_diff);
+ batch_updates += diff_vec;
+ continue;
+ }
if (pclr != "no") {
sum_up += diff_vec;
} else {
lambdas.plus_eq_v_times_s(diff_vec, eta);
- if (gamma) lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs)); // FIXME
+ if (gamma) lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./cur_npairs));
}
}
}
@@ -487,6 +524,11 @@ main(int argc, char** argv)
}
}
+ //if (q==0) { kbest_loss_first = kbest_loss; }
+ //if (q==repeat-1) { kbest_loss_last = kbest_loss; }
+//}//repeat
+//if((kbest_loss_first - kbest_loss_last) > 0) did_improve++;
+
}
if (rescale) lambdas /= lambdas.l2norm();
@@ -495,14 +537,20 @@ main(int argc, char** argv)
} // input loop
- if (average) w_average += lambdas;
+ if (t == 0) in_sz = ii; // remember size of input (# lines)
- if (scorer_str == "approx_bleu" || scorer_str == "lc_bleu") scorer->Reset();
+ //if (repeat > 1) cout << "did improve? " << did_improve << " out of " << in_sz << endl;
- if (t == 0) {
- in_sz = ii; // remember size of input (# lines)
+ if (batch) {
+ lambdas.plus_eq_v_times_s(batch_updates, eta);
+ if (gamma) lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs));
+ batch_updates.clear();
}
+ if (average) w_average += lambdas;
+
+ if (scorer_str == "approx_bleu" || scorer_str == "lc_bleu") scorer->Reset();
+
// print some stats
score_t score_avg = score_sum/(score_t)in_sz;
score_t model_avg = model_sum/(score_t)in_sz;
@@ -534,6 +582,7 @@ main(int argc, char** argv)
cerr << endl;
cerr << " avg # rank err: ";
cerr << rank_errors/(float)in_sz << endl;
+ if (batch) cerr << " batch loss: " << batch_loss << endl;
cerr << " avg # margin viol: ";
cerr << margin_violations/(float)in_sz << endl;
cerr << " non0 feature count: " << nonz << endl;
@@ -562,9 +611,9 @@ main(int argc, char** argv)
// write weights to file
if (select_weights == "best" || keep) {
- lambdas.init_vector(&dense_weights);
+ lambdas.init_vector(&decoder_weights);
string w_fn = "weights." + boost::lexical_cast<string>(t) + ".gz";
- Weights::WriteToFile(w_fn, dense_weights, true);
+ Weights::WriteToFile(w_fn, decoder_weights, true);
}
} // outer loop