diff options
author | Patrick Simianer <p@simianer.de> | 2013-11-12 18:36:03 +0100 |
---|---|---|
committer | Patrick Simianer <p@simianer.de> | 2013-11-12 18:36:03 +0100 |
commit | a8ea0a66b798326061bc9f0da153b96b730130f1 (patch) | |
tree | 3bd01a848b10c43be7182add1fc4edb19d73244c /training/dtrain/dtrain.cc | |
parent | decd2c4b1d4fb42a73a3217f347ea8f317e50869 (diff) |
implemented batch tuning
Diffstat (limited to 'training/dtrain/dtrain.cc')
-rw-r--r-- | training/dtrain/dtrain.cc | 81 |
1 files changed, 65 insertions, 16 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index a496f08a..23131810 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -42,6 +42,9 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) ("loss_margin", po::value<weight_t>()->default_value(0.), "update if no error in pref pair but model scores this near") ("max_pairs", po::value<unsigned>()->default_value(std::numeric_limits<unsigned>::max()), "max. # of pairs per Sent.") ("pclr", po::value<string>()->default_value("no"), "use a (simple|adagrad) per-coordinate learning rate") + ("batch", po::value<bool>()->zero_tokens(), "do batch optimization") + //("repeat", po::value<int>()->default_value(1), "repeat optimization over kbest list this number of times") + //("test-k-best", po::value<bool>()->zero_tokens(), "check if optimization works (use repeat >= 2)") ("noup", po::value<bool>()->zero_tokens(), "do not update weights"); po::options_description cl("Command Line Options"); cl.add_options() @@ -126,7 +129,12 @@ main(int argc, char** argv) const float hi_lo = cfg["hi_lo"].as<float>(); const score_t approx_bleu_d = cfg["approx_bleu_d"].as<score_t>(); const unsigned max_pairs = cfg["max_pairs"].as<unsigned>(); + //int repeat = cfg["repeat"].as<int>(); + //bool test_k_best = false; + //if (cfg.count("test-k-best")) test_k_best = true; weight_t loss_margin = cfg["loss_margin"].as<weight_t>(); + bool batch = false; + if (cfg.count("batch")) batch = true; if (loss_margin > 9998.) loss_margin = std::numeric_limits<float>::max(); bool scale_bleu_diff = false; if (cfg.count("scale_bleu_diff")) scale_bleu_diff = true; @@ -184,10 +192,10 @@ main(int argc, char** argv) observer->SetScorer(scorer); // init weights - vector<weight_t>& dense_weights = decoder.CurrentWeightVector(); + vector<weight_t>& decoder_weights = decoder.CurrentWeightVector(); SparseVector<weight_t> lambdas, cumulative_penalties, w_average; - if (cfg.count("input_weights")) Weights::InitFromFile(cfg["input_weights"].as<string>(), &dense_weights); - Weights::InitSparseVector(dense_weights, &lambdas); + if (cfg.count("input_weights")) Weights::InitFromFile(cfg["input_weights"].as<string>(), &decoder_weights); + Weights::InitSparseVector(decoder_weights, &lambdas); // meta params for perceptron, SVM weight_t eta = cfg["learning_rate"].as<weight_t>(); @@ -245,6 +253,7 @@ main(int argc, char** argv) cerr << setw(25) << "k " << k << endl; cerr << setw(25) << "N " << N << endl; cerr << setw(25) << "T " << T << endl; + cerr << setw(25) << "batch " << batch << endl; cerr << setw(26) << "scorer '" << scorer_str << "'" << endl; if (scorer_str == "approx_bleu") cerr << setw(25) << "approx. B discount " << approx_bleu_d << endl; @@ -267,6 +276,8 @@ main(int argc, char** argv) cerr << setw(25) << "rescale " << rescale << endl; cerr << setw(25) << "pclr " << pclr << endl; cerr << setw(25) << "max pairs " << max_pairs << endl; + //cerr << setw(25) << "repeat " << repeat << endl; + //cerr << setw(25) << "test k-best " << test_k_best << endl; cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as<string>() << "'" << endl; cerr << setw(25) << "input " << "'" << input_fn << "'" << endl; if (!read_bitext) @@ -281,17 +292,25 @@ main(int argc, char** argv) // pclr SparseVector<weight_t> learning_rates; + // batch + SparseVector<weight_t> batch_updates; + weight_t batch_loss; + + //int did_improve; // FIXME for test-k-best for (unsigned t = 0; t < T; t++) // T epochs { - + time_t start, end; time(&start); score_t score_sum = 0.; score_t model_sum(0); unsigned ii = 0, rank_errors = 0, margin_violations = 0, npairs = 0, f_count = 0, list_sz = 0; + batch_loss = 0.; if (!quiet) cerr << "Iteration #" << t+1 << " of " << T << "." << endl; + //did_improve = 0; + while(true) { @@ -337,7 +356,7 @@ main(int argc, char** argv) if (next || stop) break; // weights - lambdas.init_vector(&dense_weights); + lambdas.init_vector(&decoder_weights); // getting input vector<WordID> ref_ids; // reference as vector<WordID> @@ -392,33 +411,51 @@ main(int argc, char** argv) partXYX(samples, pairs, pair_threshold, max_pairs, faster_perceptron, hi_lo); if (pair_sampling == "PRO") PROsampling(samples, pairs, pair_threshold, max_pairs); - npairs += pairs.size(); + int cur_npairs = pairs.size(); + npairs += cur_npairs; + + weight_t kbest_loss_first, kbest_loss_last = 0.0; +//for (int q=0; q < repeat; q++) { // repeat + + weight_t kbest_loss = 0.0; // test-k-best SparseVector<weight_t> lambdas_copy; // for l1 regularization SparseVector<weight_t> sum_up; // for pclr if (l1naive||l1clip||l1cumul) lambdas_copy = lambdas; for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin(); it != pairs.end(); it++) { - bool rank_error; + + /*if (repeat > 1) { + double x = max(0.0, -1.0 * (lambdas.dot(it->first.f) - lambdas.dot(it->second.f))); + kbest_loss += x; + }*/ + + score_t model_diff = it->first.model - it->second.model; + bool rank_error = false; score_t margin; if (faster_perceptron) { // we only have considering misranked pairs rank_error = true; // pair sampling already did this for us margin = std::numeric_limits<float>::max(); } else { - rank_error = it->first.model <= it->second.model; - margin = fabs(it->first.model - it->second.model); + rank_error = model_diff<=0.0; + margin = fabs(model_diff); if (!rank_error && margin < loss_margin) margin_violations++; } if (rank_error) rank_errors++; if (scale_bleu_diff) eta = it->first.score - it->second.score; if (rank_error || margin < loss_margin) { SparseVector<weight_t> diff_vec = it->first.f - it->second.f; + if (batch) { + batch_loss += max(0., -1.0*model_diff); + batch_updates += diff_vec; + continue; + } if (pclr != "no") { sum_up += diff_vec; } else { lambdas.plus_eq_v_times_s(diff_vec, eta); - if (gamma) lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs)); // FIXME + if (gamma) lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./cur_npairs)); } } } @@ -487,6 +524,11 @@ main(int argc, char** argv) } } + //if (q==0) { kbest_loss_first = kbest_loss; } + //if (q==repeat-1) { kbest_loss_last = kbest_loss; } +//}//repeat +//if((kbest_loss_first - kbest_loss_last) > 0) did_improve++; + } if (rescale) lambdas /= lambdas.l2norm(); @@ -495,14 +537,20 @@ main(int argc, char** argv) } // input loop - if (average) w_average += lambdas; + if (t == 0) in_sz = ii; // remember size of input (# lines) - if (scorer_str == "approx_bleu" || scorer_str == "lc_bleu") scorer->Reset(); + //if (repeat > 1) cout << "did improve? " << did_improve << " out of " << in_sz << endl; - if (t == 0) { - in_sz = ii; // remember size of input (# lines) + if (batch) { + lambdas.plus_eq_v_times_s(batch_updates, eta); + if (gamma) lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs)); + batch_updates.clear(); } + if (average) w_average += lambdas; + + if (scorer_str == "approx_bleu" || scorer_str == "lc_bleu") scorer->Reset(); + // print some stats score_t score_avg = score_sum/(score_t)in_sz; score_t model_avg = model_sum/(score_t)in_sz; @@ -534,6 +582,7 @@ main(int argc, char** argv) cerr << endl; cerr << " avg # rank err: "; cerr << rank_errors/(float)in_sz << endl; + if (batch) cerr << " batch loss: " << batch_loss << endl; cerr << " avg # margin viol: "; cerr << margin_violations/(float)in_sz << endl; cerr << " non0 feature count: " << nonz << endl; @@ -562,9 +611,9 @@ main(int argc, char** argv) // write weights to file if (select_weights == "best" || keep) { - lambdas.init_vector(&dense_weights); + lambdas.init_vector(&decoder_weights); string w_fn = "weights." + boost::lexical_cast<string>(t) + ".gz"; - Weights::WriteToFile(w_fn, dense_weights, true); + Weights::WriteToFile(w_fn, decoder_weights, true); } } // outer loop |