diff options
author | Paul Baltescu <pauldb89@gmail.com> | 2013-11-23 17:33:47 +0000 |
---|---|---|
committer | Paul Baltescu <pauldb89@gmail.com> | 2013-11-23 17:33:47 +0000 |
commit | 072c4bb1edde483b87b93bc6f4eec36fc8a21008 (patch) | |
tree | 6ceaa6ae1e08df9e523282740b14f4857236297c /training/dtrain/dtrain.cc | |
parent | 7e90b8ea10904f9b83f4e77e14c7396a3e6f7d5d (diff) | |
parent | 9e80389b9763aa4f7f626ec71b561ccf6948d3ad (diff) |
Merge branch 'master' of https://github.com/redpony/cdec
Diffstat (limited to 'training/dtrain/dtrain.cc')
-rw-r--r-- | training/dtrain/dtrain.cc | 201 |
1 files changed, 160 insertions, 41 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index 0ee2f124..0a27a068 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -12,8 +12,9 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) { po::options_description ini("Configuration File Options"); ini.add_options() - ("input", po::value<string>()->default_value("-"), "input file (src)") + ("input", po::value<string>(), "input file (src)") ("refs,r", po::value<string>(), "references") + ("bitext,b", po::value<string>(), "bitext: 'src ||| tgt'") ("output", po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT") ("input_weights", po::value<string>(), "input weights file (e.g. from previous iteration)") ("decoder_config", po::value<string>(), "configuration file for cdec") @@ -40,6 +41,10 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) ("scale_bleu_diff", po::value<bool>()->zero_tokens(), "learning rate <- bleu diff of a misranked pair") ("loss_margin", po::value<weight_t>()->default_value(0.), "update if no error in pref pair but model scores this near") ("max_pairs", po::value<unsigned>()->default_value(std::numeric_limits<unsigned>::max()), "max. # of pairs per Sent.") + ("pclr", po::value<string>()->default_value("no"), "use a (simple|adagrad) per-coordinate learning rate") + ("batch", po::value<bool>()->zero_tokens(), "do batch optimization") + ("repeat", po::value<unsigned>()->default_value(1), "repeat optimization over kbest list this number of times") + //("test-k-best", po::value<bool>()->zero_tokens(), "check if optimization works (use repeat >= 2)") ("noup", po::value<bool>()->zero_tokens(), "do not update weights"); po::options_description cl("Command Line Options"); cl.add_options() @@ -72,13 +77,17 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) cerr << "Wrong 'pair_sampling' param: '" << (*cfg)["pair_sampling"].as<string>() << "'." << endl; return false; } - if(cfg->count("hi_lo") && (*cfg)["pair_sampling"].as<string>() != "XYX") { + if (cfg->count("hi_lo") && (*cfg)["pair_sampling"].as<string>() != "XYX") { cerr << "Warning: hi_lo only works with pair_sampling XYX." << endl; } - if((*cfg)["hi_lo"].as<float>() > 0.5 || (*cfg)["hi_lo"].as<float>() < 0.01) { + if ((*cfg)["hi_lo"].as<float>() > 0.5 || (*cfg)["hi_lo"].as<float>() < 0.01) { cerr << "hi_lo must lie in [0.01, 0.5]" << endl; return false; } + if ((cfg->count("input")>0 || cfg->count("refs")>0) && cfg->count("bitext")>0) { + cerr << "Provide 'input' and 'refs' or 'bitext', not both." << endl; + return false; + } if ((*cfg)["pair_threshold"].as<score_t>() < 0) { cerr << "The threshold must be >= 0!" << endl; return false; @@ -120,10 +129,16 @@ main(int argc, char** argv) const float hi_lo = cfg["hi_lo"].as<float>(); const score_t approx_bleu_d = cfg["approx_bleu_d"].as<score_t>(); const unsigned max_pairs = cfg["max_pairs"].as<unsigned>(); + int repeat = cfg["repeat"].as<unsigned>(); + //bool test_k_best = false; + //if (cfg.count("test-k-best")) test_k_best = true; weight_t loss_margin = cfg["loss_margin"].as<weight_t>(); + bool batch = false; + if (cfg.count("batch")) batch = true; if (loss_margin > 9998.) loss_margin = std::numeric_limits<float>::max(); bool scale_bleu_diff = false; if (cfg.count("scale_bleu_diff")) scale_bleu_diff = true; + const string pclr = cfg["pclr"].as<string>(); bool average = false; if (select_weights == "avg") average = true; @@ -131,7 +146,6 @@ main(int argc, char** argv) if (cfg.count("print_weights")) boost::split(print_weights, cfg["print_weights"].as<string>(), boost::is_any_of(" ")); - // setup decoder register_feature_functions(); SetSilent(true); @@ -178,17 +192,16 @@ main(int argc, char** argv) observer->SetScorer(scorer); // init weights - vector<weight_t>& dense_weights = decoder.CurrentWeightVector(); + vector<weight_t>& decoder_weights = decoder.CurrentWeightVector(); SparseVector<weight_t> lambdas, cumulative_penalties, w_average; - if (cfg.count("input_weights")) Weights::InitFromFile(cfg["input_weights"].as<string>(), &dense_weights); - Weights::InitSparseVector(dense_weights, &lambdas); + if (cfg.count("input_weights")) Weights::InitFromFile(cfg["input_weights"].as<string>(), &decoder_weights); + Weights::InitSparseVector(decoder_weights, &lambdas); // meta params for perceptron, SVM weight_t eta = cfg["learning_rate"].as<weight_t>(); weight_t gamma = cfg["gamma"].as<weight_t>(); // faster perceptron: consider only misranked pairs, see - // DO NOT ENABLE WITH SVM (gamma > 0) OR loss_margin! bool faster_perceptron = false; if (gamma==0 && loss_margin==0) faster_perceptron = true; @@ -208,13 +221,24 @@ main(int argc, char** argv) // output string output_fn = cfg["output"].as<string>(); // input - string input_fn = cfg["input"].as<string>(); + bool read_bitext = false; + string input_fn; + if (cfg.count("bitext")) { + read_bitext = true; + input_fn = cfg["bitext"].as<string>(); + } else { + input_fn = cfg["input"].as<string>(); + } ReadFile input(input_fn); // buffer input for t > 0 vector<string> src_str_buf; // source strings (decoder takes only strings) vector<vector<WordID> > ref_ids_buf; // references as WordID vecs - string refs_fn = cfg["refs"].as<string>(); - ReadFile refs(refs_fn); + ReadFile refs; + string refs_fn; + if (!read_bitext) { + refs_fn = cfg["refs"].as<string>(); + refs.Init(refs_fn); + } unsigned in_sz = std::numeric_limits<unsigned>::max(); // input index, input size vector<pair<score_t, score_t> > all_scores; @@ -229,6 +253,7 @@ main(int argc, char** argv) cerr << setw(25) << "k " << k << endl; cerr << setw(25) << "N " << N << endl; cerr << setw(25) << "T " << T << endl; + cerr << setw(25) << "batch " << batch << endl; cerr << setw(26) << "scorer '" << scorer_str << "'" << endl; if (scorer_str == "approx_bleu") cerr << setw(25) << "approx. B discount " << approx_bleu_d << endl; @@ -249,10 +274,14 @@ main(int argc, char** argv) cerr << setw(25) << "l1 reg " << l1_reg << " '" << cfg["l1_reg"].as<string>() << "'" << endl; if (rescale) cerr << setw(25) << "rescale " << rescale << endl; + cerr << setw(25) << "pclr " << pclr << endl; cerr << setw(25) << "max pairs " << max_pairs << endl; + cerr << setw(25) << "repeat " << repeat << endl; + //cerr << setw(25) << "test k-best " << test_k_best << endl; cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as<string>() << "'" << endl; cerr << setw(25) << "input " << "'" << input_fn << "'" << endl; - cerr << setw(25) << "refs " << "'" << refs_fn << "'" << endl; + if (!read_bitext) + cerr << setw(25) << "refs " << "'" << refs_fn << "'" << endl; cerr << setw(25) << "output " << "'" << output_fn << "'" << endl; if (cfg.count("input_weights")) cerr << setw(25) << "weights in " << "'" << cfg["input_weights"].as<string>() << "'" << endl; @@ -261,6 +290,11 @@ main(int argc, char** argv) if (!verbose) cerr << "(a dot represents " << DTRAIN_DOTS << " inputs)" << endl; } + // pclr + SparseVector<weight_t> learning_rates; + // batch + SparseVector<weight_t> batch_updates; + score_t batch_loss; for (unsigned t = 0; t < T; t++) // T epochs { @@ -269,16 +303,24 @@ main(int argc, char** argv) time(&start); score_t score_sum = 0.; score_t model_sum(0); - unsigned ii = 0, rank_errors = 0, margin_violations = 0, npairs = 0, f_count = 0, list_sz = 0; + unsigned ii = 0, rank_errors = 0, margin_violations = 0, npairs = 0, f_count = 0, list_sz = 0, kbest_loss_improve = 0; + batch_loss = 0.; if (!quiet) cerr << "Iteration #" << t+1 << " of " << T << "." << endl; while(true) { string in; + string ref; bool next = false, stop = false; // next iteration or premature stop if (t == 0) { if(!getline(*input, in)) next = true; + if(read_bitext) { + vector<string> strs; + boost::algorithm::split_regex(strs, in, boost::regex(" \\|\\|\\| ")); + in = strs[0]; + ref = strs[1]; + } } else { if (ii == in_sz) next = true; // stop if we reach the end of our input } @@ -310,15 +352,16 @@ main(int argc, char** argv) if (next || stop) break; // weights - lambdas.init_vector(&dense_weights); + lambdas.init_vector(&decoder_weights); // getting input vector<WordID> ref_ids; // reference as vector<WordID> if (t == 0) { - string r_; - getline(*refs, r_); + if (!read_bitext) { + getline(*refs, ref); + } vector<string> ref_tok; - boost::split(ref_tok, r_, boost::is_any_of(" ")); + boost::split(ref_tok, ref, boost::is_any_of(" ")); register_and_convert(ref_tok, ref_ids); ref_ids_buf.push_back(ref_ids); src_str_buf.push_back(in); @@ -348,8 +391,10 @@ main(int argc, char** argv) } } - score_sum += (*samples)[0].score; // stats for 1best - model_sum += (*samples)[0].model; + if (repeat == 1) { + score_sum += (*samples)[0].score; // stats for 1best + model_sum += (*samples)[0].model; + } f_count += observer->get_f_count(); list_sz += observer->get_sz(); @@ -364,30 +409,74 @@ main(int argc, char** argv) partXYX(samples, pairs, pair_threshold, max_pairs, faster_perceptron, hi_lo); if (pair_sampling == "PRO") PROsampling(samples, pairs, pair_threshold, max_pairs); - npairs += pairs.size(); + int cur_npairs = pairs.size(); + npairs += cur_npairs; + + score_t kbest_loss_first, kbest_loss_last = 0.0; - SparseVector<weight_t> lambdas_copy; + for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin(); + it != pairs.end(); it++) { + score_t model_diff = it->first.model - it->second.model; + kbest_loss_first += max(0.0, -1.0 * model_diff); + } + + for (int ki=0; ki < repeat; ki++) { + + score_t kbest_loss = 0.0; // test-k-best + SparseVector<weight_t> lambdas_copy; // for l1 regularization + SparseVector<weight_t> sum_up; // for pclr if (l1naive||l1clip||l1cumul) lambdas_copy = lambdas; for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin(); it != pairs.end(); it++) { - bool rank_error; + score_t model_diff = it->first.model - it->second.model; + if (repeat > 1) { + model_diff = lambdas.dot(it->first.f) - lambdas.dot(it->second.f); + kbest_loss += max(0.0, -1.0 * model_diff); + } + bool rank_error = false; score_t margin; if (faster_perceptron) { // we only have considering misranked pairs rank_error = true; // pair sampling already did this for us margin = std::numeric_limits<float>::max(); } else { - rank_error = it->first.model <= it->second.model; - margin = fabs(it->first.model - it->second.model); + rank_error = model_diff<=0.0; + margin = fabs(model_diff); if (!rank_error && margin < loss_margin) margin_violations++; } - if (rank_error) rank_errors++; + if (rank_error && ki==1) rank_errors++; if (scale_bleu_diff) eta = it->first.score - it->second.score; if (rank_error || margin < loss_margin) { SparseVector<weight_t> diff_vec = it->first.f - it->second.f; - lambdas.plus_eq_v_times_s(diff_vec, eta); - if (gamma) - lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs)); + if (batch) { + batch_loss += max(0., -1.0*model_diff); + batch_updates += diff_vec; + continue; + } + if (pclr != "no") { + sum_up += diff_vec; + } else { + lambdas.plus_eq_v_times_s(diff_vec, eta); + if (gamma) lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./cur_npairs)); + } + } + } + + // per-coordinate learning rate + if (pclr != "no") { + SparseVector<weight_t>::iterator it = sum_up.begin(); + for (; it != sum_up.end(); ++it) { + if (pclr == "simple") { + lambdas[it->first] += it->second / max(1.0, learning_rates[it->first]); + learning_rates[it->first]++; + } else if (pclr == "adagrad") { + if (learning_rates[it->first] == 0) { + lambdas[it->first] += it->second * eta; + } else { + lambdas[it->first] += it->second * eta * learning_rates[it->first]; + } + learning_rates[it->first] += pow(it->second, 2.0); + } } } @@ -395,14 +484,16 @@ main(int argc, char** argv) // please note that this regularizations happen // after a _sentence_ -- not after each example/pair! if (l1naive) { - FastSparseVector<weight_t>::iterator it = lambdas.begin(); + SparseVector<weight_t>::iterator it = lambdas.begin(); for (; it != lambdas.end(); ++it) { if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) { + it->second *= max(0.0000001, eta/(eta+learning_rates[it->first])); // FIXME + learning_rates[it->first]++; it->second -= sign(it->second) * l1_reg; } } } else if (l1clip) { - FastSparseVector<weight_t>::iterator it = lambdas.begin(); + SparseVector<weight_t>::iterator it = lambdas.begin(); for (; it != lambdas.end(); ++it) { if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) { if (it->second != 0) { @@ -417,7 +508,7 @@ main(int argc, char** argv) } } else if (l1cumul) { weight_t acc_penalty = (ii+1) * l1_reg; // ii is the index of the current input - FastSparseVector<weight_t>::iterator it = lambdas.begin(); + SparseVector<weight_t>::iterator it = lambdas.begin(); for (; it != lambdas.end(); ++it) { if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) { if (it->second != 0) { @@ -435,7 +526,28 @@ main(int argc, char** argv) } } - } + if (ki==repeat-1) { // done + kbest_loss_last = kbest_loss; + if (repeat > 1) { + score_t best_score = -1.; + score_t best_model = -std::numeric_limits<score_t>::max(); + unsigned best_idx; + for (unsigned i=0; i < samples->size(); i++) { + score_t s = lambdas.dot((*samples)[i].f); + if (s > best_model) { + best_idx = i; + best_model = s; + } + } + score_sum += (*samples)[best_idx].score; + model_sum += best_model; + } + } + } // repeat + + if ((kbest_loss_first - kbest_loss_last) >= 0) kbest_loss_improve++; + + } // noup if (rescale) lambdas /= lambdas.l2norm(); @@ -443,14 +555,19 @@ main(int argc, char** argv) } // input loop - if (average) w_average += lambdas; + if (t == 0) in_sz = ii; // remember size of input (# lines) - if (scorer_str == "approx_bleu" || scorer_str == "lc_bleu") scorer->Reset(); - if (t == 0) { - in_sz = ii; // remember size of input (# lines) + if (batch) { + lambdas.plus_eq_v_times_s(batch_updates, eta); + if (gamma) lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs)); + batch_updates.clear(); } + if (average) w_average += lambdas; + + if (scorer_str == "approx_bleu" || scorer_str == "lc_bleu") scorer->Reset(); + // print some stats score_t score_avg = score_sum/(score_t)in_sz; score_t model_avg = model_sum/(score_t)in_sz; @@ -477,13 +594,15 @@ main(int argc, char** argv) cerr << _np << " 1best avg model score: " << model_avg; cerr << _p << " (" << model_diff << ")" << endl; cerr << " avg # pairs: "; - cerr << _np << npairs/(float)in_sz; + cerr << _np << npairs/(float)in_sz << endl; + cerr << " avg # rank err: "; + cerr << rank_errors/(float)in_sz; if (faster_perceptron) cerr << " (meaningless)"; cerr << endl; - cerr << " avg # rank err: "; - cerr << rank_errors/(float)in_sz << endl; cerr << " avg # margin viol: "; cerr << margin_violations/(float)in_sz << endl; + if (batch) cerr << " batch loss: " << batch_loss << endl; + cerr << " k-best loss imp: " << ((float)kbest_loss_improve/in_sz)*100 << "%" << endl; cerr << " non0 feature count: " << nonz << endl; cerr << " avg list sz: " << list_sz/(float)in_sz << endl; cerr << " avg f count: " << f_count/(float)list_sz << endl; @@ -510,9 +629,9 @@ main(int argc, char** argv) // write weights to file if (select_weights == "best" || keep) { - lambdas.init_vector(&dense_weights); + lambdas.init_vector(&decoder_weights); string w_fn = "weights." + boost::lexical_cast<string>(t) + ".gz"; - Weights::WriteToFile(w_fn, dense_weights, true); + Weights::WriteToFile(w_fn, decoder_weights, true); } } // outer loop |