summaryrefslogtreecommitdiff
path: root/training/dtrain/dtrain.cc
diff options
context:
space:
mode:
authorPaul Baltescu <pauldb89@gmail.com>2013-11-23 17:33:47 +0000
committerPaul Baltescu <pauldb89@gmail.com>2013-11-23 17:33:47 +0000
commitcc6313b23cac25eb05976b6cf64f96faf1ed4163 (patch)
tree3dc28060ad25b43773e875bea7388ab1cefcd927 /training/dtrain/dtrain.cc
parent7990c750829af93f0a1e0fc14534582f52ee9e8c (diff)
parentf2fb69b10a897e8beb4e6e6d6cbb4327096235ef (diff)
Merge branch 'master' of https://github.com/redpony/cdec
Diffstat (limited to 'training/dtrain/dtrain.cc')
-rw-r--r--training/dtrain/dtrain.cc201
1 files changed, 160 insertions, 41 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index 0ee2f124..0a27a068 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -12,8 +12,9 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
{
po::options_description ini("Configuration File Options");
ini.add_options()
- ("input", po::value<string>()->default_value("-"), "input file (src)")
+ ("input", po::value<string>(), "input file (src)")
("refs,r", po::value<string>(), "references")
+ ("bitext,b", po::value<string>(), "bitext: 'src ||| tgt'")
("output", po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT")
("input_weights", po::value<string>(), "input weights file (e.g. from previous iteration)")
("decoder_config", po::value<string>(), "configuration file for cdec")
@@ -40,6 +41,10 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
("scale_bleu_diff", po::value<bool>()->zero_tokens(), "learning rate <- bleu diff of a misranked pair")
("loss_margin", po::value<weight_t>()->default_value(0.), "update if no error in pref pair but model scores this near")
("max_pairs", po::value<unsigned>()->default_value(std::numeric_limits<unsigned>::max()), "max. # of pairs per Sent.")
+ ("pclr", po::value<string>()->default_value("no"), "use a (simple|adagrad) per-coordinate learning rate")
+ ("batch", po::value<bool>()->zero_tokens(), "do batch optimization")
+ ("repeat", po::value<unsigned>()->default_value(1), "repeat optimization over kbest list this number of times")
+ //("test-k-best", po::value<bool>()->zero_tokens(), "check if optimization works (use repeat >= 2)")
("noup", po::value<bool>()->zero_tokens(), "do not update weights");
po::options_description cl("Command Line Options");
cl.add_options()
@@ -72,13 +77,17 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
cerr << "Wrong 'pair_sampling' param: '" << (*cfg)["pair_sampling"].as<string>() << "'." << endl;
return false;
}
- if(cfg->count("hi_lo") && (*cfg)["pair_sampling"].as<string>() != "XYX") {
+ if (cfg->count("hi_lo") && (*cfg)["pair_sampling"].as<string>() != "XYX") {
cerr << "Warning: hi_lo only works with pair_sampling XYX." << endl;
}
- if((*cfg)["hi_lo"].as<float>() > 0.5 || (*cfg)["hi_lo"].as<float>() < 0.01) {
+ if ((*cfg)["hi_lo"].as<float>() > 0.5 || (*cfg)["hi_lo"].as<float>() < 0.01) {
cerr << "hi_lo must lie in [0.01, 0.5]" << endl;
return false;
}
+ if ((cfg->count("input")>0 || cfg->count("refs")>0) && cfg->count("bitext")>0) {
+ cerr << "Provide 'input' and 'refs' or 'bitext', not both." << endl;
+ return false;
+ }
if ((*cfg)["pair_threshold"].as<score_t>() < 0) {
cerr << "The threshold must be >= 0!" << endl;
return false;
@@ -120,10 +129,16 @@ main(int argc, char** argv)
const float hi_lo = cfg["hi_lo"].as<float>();
const score_t approx_bleu_d = cfg["approx_bleu_d"].as<score_t>();
const unsigned max_pairs = cfg["max_pairs"].as<unsigned>();
+ int repeat = cfg["repeat"].as<unsigned>();
+ //bool test_k_best = false;
+ //if (cfg.count("test-k-best")) test_k_best = true;
weight_t loss_margin = cfg["loss_margin"].as<weight_t>();
+ bool batch = false;
+ if (cfg.count("batch")) batch = true;
if (loss_margin > 9998.) loss_margin = std::numeric_limits<float>::max();
bool scale_bleu_diff = false;
if (cfg.count("scale_bleu_diff")) scale_bleu_diff = true;
+ const string pclr = cfg["pclr"].as<string>();
bool average = false;
if (select_weights == "avg")
average = true;
@@ -131,7 +146,6 @@ main(int argc, char** argv)
if (cfg.count("print_weights"))
boost::split(print_weights, cfg["print_weights"].as<string>(), boost::is_any_of(" "));
-
// setup decoder
register_feature_functions();
SetSilent(true);
@@ -178,17 +192,16 @@ main(int argc, char** argv)
observer->SetScorer(scorer);
// init weights
- vector<weight_t>& dense_weights = decoder.CurrentWeightVector();
+ vector<weight_t>& decoder_weights = decoder.CurrentWeightVector();
SparseVector<weight_t> lambdas, cumulative_penalties, w_average;
- if (cfg.count("input_weights")) Weights::InitFromFile(cfg["input_weights"].as<string>(), &dense_weights);
- Weights::InitSparseVector(dense_weights, &lambdas);
+ if (cfg.count("input_weights")) Weights::InitFromFile(cfg["input_weights"].as<string>(), &decoder_weights);
+ Weights::InitSparseVector(decoder_weights, &lambdas);
// meta params for perceptron, SVM
weight_t eta = cfg["learning_rate"].as<weight_t>();
weight_t gamma = cfg["gamma"].as<weight_t>();
// faster perceptron: consider only misranked pairs, see
- // DO NOT ENABLE WITH SVM (gamma > 0) OR loss_margin!
bool faster_perceptron = false;
if (gamma==0 && loss_margin==0) faster_perceptron = true;
@@ -208,13 +221,24 @@ main(int argc, char** argv)
// output
string output_fn = cfg["output"].as<string>();
// input
- string input_fn = cfg["input"].as<string>();
+ bool read_bitext = false;
+ string input_fn;
+ if (cfg.count("bitext")) {
+ read_bitext = true;
+ input_fn = cfg["bitext"].as<string>();
+ } else {
+ input_fn = cfg["input"].as<string>();
+ }
ReadFile input(input_fn);
// buffer input for t > 0
vector<string> src_str_buf; // source strings (decoder takes only strings)
vector<vector<WordID> > ref_ids_buf; // references as WordID vecs
- string refs_fn = cfg["refs"].as<string>();
- ReadFile refs(refs_fn);
+ ReadFile refs;
+ string refs_fn;
+ if (!read_bitext) {
+ refs_fn = cfg["refs"].as<string>();
+ refs.Init(refs_fn);
+ }
unsigned in_sz = std::numeric_limits<unsigned>::max(); // input index, input size
vector<pair<score_t, score_t> > all_scores;
@@ -229,6 +253,7 @@ main(int argc, char** argv)
cerr << setw(25) << "k " << k << endl;
cerr << setw(25) << "N " << N << endl;
cerr << setw(25) << "T " << T << endl;
+ cerr << setw(25) << "batch " << batch << endl;
cerr << setw(26) << "scorer '" << scorer_str << "'" << endl;
if (scorer_str == "approx_bleu")
cerr << setw(25) << "approx. B discount " << approx_bleu_d << endl;
@@ -249,10 +274,14 @@ main(int argc, char** argv)
cerr << setw(25) << "l1 reg " << l1_reg << " '" << cfg["l1_reg"].as<string>() << "'" << endl;
if (rescale)
cerr << setw(25) << "rescale " << rescale << endl;
+ cerr << setw(25) << "pclr " << pclr << endl;
cerr << setw(25) << "max pairs " << max_pairs << endl;
+ cerr << setw(25) << "repeat " << repeat << endl;
+ //cerr << setw(25) << "test k-best " << test_k_best << endl;
cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as<string>() << "'" << endl;
cerr << setw(25) << "input " << "'" << input_fn << "'" << endl;
- cerr << setw(25) << "refs " << "'" << refs_fn << "'" << endl;
+ if (!read_bitext)
+ cerr << setw(25) << "refs " << "'" << refs_fn << "'" << endl;
cerr << setw(25) << "output " << "'" << output_fn << "'" << endl;
if (cfg.count("input_weights"))
cerr << setw(25) << "weights in " << "'" << cfg["input_weights"].as<string>() << "'" << endl;
@@ -261,6 +290,11 @@ main(int argc, char** argv)
if (!verbose) cerr << "(a dot represents " << DTRAIN_DOTS << " inputs)" << endl;
}
+ // pclr
+ SparseVector<weight_t> learning_rates;
+ // batch
+ SparseVector<weight_t> batch_updates;
+ score_t batch_loss;
for (unsigned t = 0; t < T; t++) // T epochs
{
@@ -269,16 +303,24 @@ main(int argc, char** argv)
time(&start);
score_t score_sum = 0.;
score_t model_sum(0);
- unsigned ii = 0, rank_errors = 0, margin_violations = 0, npairs = 0, f_count = 0, list_sz = 0;
+ unsigned ii = 0, rank_errors = 0, margin_violations = 0, npairs = 0, f_count = 0, list_sz = 0, kbest_loss_improve = 0;
+ batch_loss = 0.;
if (!quiet) cerr << "Iteration #" << t+1 << " of " << T << "." << endl;
while(true)
{
string in;
+ string ref;
bool next = false, stop = false; // next iteration or premature stop
if (t == 0) {
if(!getline(*input, in)) next = true;
+ if(read_bitext) {
+ vector<string> strs;
+ boost::algorithm::split_regex(strs, in, boost::regex(" \\|\\|\\| "));
+ in = strs[0];
+ ref = strs[1];
+ }
} else {
if (ii == in_sz) next = true; // stop if we reach the end of our input
}
@@ -310,15 +352,16 @@ main(int argc, char** argv)
if (next || stop) break;
// weights
- lambdas.init_vector(&dense_weights);
+ lambdas.init_vector(&decoder_weights);
// getting input
vector<WordID> ref_ids; // reference as vector<WordID>
if (t == 0) {
- string r_;
- getline(*refs, r_);
+ if (!read_bitext) {
+ getline(*refs, ref);
+ }
vector<string> ref_tok;
- boost::split(ref_tok, r_, boost::is_any_of(" "));
+ boost::split(ref_tok, ref, boost::is_any_of(" "));
register_and_convert(ref_tok, ref_ids);
ref_ids_buf.push_back(ref_ids);
src_str_buf.push_back(in);
@@ -348,8 +391,10 @@ main(int argc, char** argv)
}
}
- score_sum += (*samples)[0].score; // stats for 1best
- model_sum += (*samples)[0].model;
+ if (repeat == 1) {
+ score_sum += (*samples)[0].score; // stats for 1best
+ model_sum += (*samples)[0].model;
+ }
f_count += observer->get_f_count();
list_sz += observer->get_sz();
@@ -364,30 +409,74 @@ main(int argc, char** argv)
partXYX(samples, pairs, pair_threshold, max_pairs, faster_perceptron, hi_lo);
if (pair_sampling == "PRO")
PROsampling(samples, pairs, pair_threshold, max_pairs);
- npairs += pairs.size();
+ int cur_npairs = pairs.size();
+ npairs += cur_npairs;
+
+ score_t kbest_loss_first, kbest_loss_last = 0.0;
- SparseVector<weight_t> lambdas_copy;
+ for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin();
+ it != pairs.end(); it++) {
+ score_t model_diff = it->first.model - it->second.model;
+ kbest_loss_first += max(0.0, -1.0 * model_diff);
+ }
+
+ for (int ki=0; ki < repeat; ki++) {
+
+ score_t kbest_loss = 0.0; // test-k-best
+ SparseVector<weight_t> lambdas_copy; // for l1 regularization
+ SparseVector<weight_t> sum_up; // for pclr
if (l1naive||l1clip||l1cumul) lambdas_copy = lambdas;
for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin();
it != pairs.end(); it++) {
- bool rank_error;
+ score_t model_diff = it->first.model - it->second.model;
+ if (repeat > 1) {
+ model_diff = lambdas.dot(it->first.f) - lambdas.dot(it->second.f);
+ kbest_loss += max(0.0, -1.0 * model_diff);
+ }
+ bool rank_error = false;
score_t margin;
if (faster_perceptron) { // we only have considering misranked pairs
rank_error = true; // pair sampling already did this for us
margin = std::numeric_limits<float>::max();
} else {
- rank_error = it->first.model <= it->second.model;
- margin = fabs(it->first.model - it->second.model);
+ rank_error = model_diff<=0.0;
+ margin = fabs(model_diff);
if (!rank_error && margin < loss_margin) margin_violations++;
}
- if (rank_error) rank_errors++;
+ if (rank_error && ki==1) rank_errors++;
if (scale_bleu_diff) eta = it->first.score - it->second.score;
if (rank_error || margin < loss_margin) {
SparseVector<weight_t> diff_vec = it->first.f - it->second.f;
- lambdas.plus_eq_v_times_s(diff_vec, eta);
- if (gamma)
- lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs));
+ if (batch) {
+ batch_loss += max(0., -1.0*model_diff);
+ batch_updates += diff_vec;
+ continue;
+ }
+ if (pclr != "no") {
+ sum_up += diff_vec;
+ } else {
+ lambdas.plus_eq_v_times_s(diff_vec, eta);
+ if (gamma) lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./cur_npairs));
+ }
+ }
+ }
+
+ // per-coordinate learning rate
+ if (pclr != "no") {
+ SparseVector<weight_t>::iterator it = sum_up.begin();
+ for (; it != sum_up.end(); ++it) {
+ if (pclr == "simple") {
+ lambdas[it->first] += it->second / max(1.0, learning_rates[it->first]);
+ learning_rates[it->first]++;
+ } else if (pclr == "adagrad") {
+ if (learning_rates[it->first] == 0) {
+ lambdas[it->first] += it->second * eta;
+ } else {
+ lambdas[it->first] += it->second * eta * learning_rates[it->first];
+ }
+ learning_rates[it->first] += pow(it->second, 2.0);
+ }
}
}
@@ -395,14 +484,16 @@ main(int argc, char** argv)
// please note that this regularizations happen
// after a _sentence_ -- not after each example/pair!
if (l1naive) {
- FastSparseVector<weight_t>::iterator it = lambdas.begin();
+ SparseVector<weight_t>::iterator it = lambdas.begin();
for (; it != lambdas.end(); ++it) {
if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) {
+ it->second *= max(0.0000001, eta/(eta+learning_rates[it->first])); // FIXME
+ learning_rates[it->first]++;
it->second -= sign(it->second) * l1_reg;
}
}
} else if (l1clip) {
- FastSparseVector<weight_t>::iterator it = lambdas.begin();
+ SparseVector<weight_t>::iterator it = lambdas.begin();
for (; it != lambdas.end(); ++it) {
if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) {
if (it->second != 0) {
@@ -417,7 +508,7 @@ main(int argc, char** argv)
}
} else if (l1cumul) {
weight_t acc_penalty = (ii+1) * l1_reg; // ii is the index of the current input
- FastSparseVector<weight_t>::iterator it = lambdas.begin();
+ SparseVector<weight_t>::iterator it = lambdas.begin();
for (; it != lambdas.end(); ++it) {
if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) {
if (it->second != 0) {
@@ -435,7 +526,28 @@ main(int argc, char** argv)
}
}
- }
+ if (ki==repeat-1) { // done
+ kbest_loss_last = kbest_loss;
+ if (repeat > 1) {
+ score_t best_score = -1.;
+ score_t best_model = -std::numeric_limits<score_t>::max();
+ unsigned best_idx;
+ for (unsigned i=0; i < samples->size(); i++) {
+ score_t s = lambdas.dot((*samples)[i].f);
+ if (s > best_model) {
+ best_idx = i;
+ best_model = s;
+ }
+ }
+ score_sum += (*samples)[best_idx].score;
+ model_sum += best_model;
+ }
+ }
+ } // repeat
+
+ if ((kbest_loss_first - kbest_loss_last) >= 0) kbest_loss_improve++;
+
+ } // noup
if (rescale) lambdas /= lambdas.l2norm();
@@ -443,14 +555,19 @@ main(int argc, char** argv)
} // input loop
- if (average) w_average += lambdas;
+ if (t == 0) in_sz = ii; // remember size of input (# lines)
- if (scorer_str == "approx_bleu" || scorer_str == "lc_bleu") scorer->Reset();
- if (t == 0) {
- in_sz = ii; // remember size of input (# lines)
+ if (batch) {
+ lambdas.plus_eq_v_times_s(batch_updates, eta);
+ if (gamma) lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs));
+ batch_updates.clear();
}
+ if (average) w_average += lambdas;
+
+ if (scorer_str == "approx_bleu" || scorer_str == "lc_bleu") scorer->Reset();
+
// print some stats
score_t score_avg = score_sum/(score_t)in_sz;
score_t model_avg = model_sum/(score_t)in_sz;
@@ -477,13 +594,15 @@ main(int argc, char** argv)
cerr << _np << " 1best avg model score: " << model_avg;
cerr << _p << " (" << model_diff << ")" << endl;
cerr << " avg # pairs: ";
- cerr << _np << npairs/(float)in_sz;
+ cerr << _np << npairs/(float)in_sz << endl;
+ cerr << " avg # rank err: ";
+ cerr << rank_errors/(float)in_sz;
if (faster_perceptron) cerr << " (meaningless)";
cerr << endl;
- cerr << " avg # rank err: ";
- cerr << rank_errors/(float)in_sz << endl;
cerr << " avg # margin viol: ";
cerr << margin_violations/(float)in_sz << endl;
+ if (batch) cerr << " batch loss: " << batch_loss << endl;
+ cerr << " k-best loss imp: " << ((float)kbest_loss_improve/in_sz)*100 << "%" << endl;
cerr << " non0 feature count: " << nonz << endl;
cerr << " avg list sz: " << list_sz/(float)in_sz << endl;
cerr << " avg f count: " << f_count/(float)list_sz << endl;
@@ -510,9 +629,9 @@ main(int argc, char** argv)
// write weights to file
if (select_weights == "best" || keep) {
- lambdas.init_vector(&dense_weights);
+ lambdas.init_vector(&decoder_weights);
string w_fn = "weights." + boost::lexical_cast<string>(t) + ".gz";
- Weights::WriteToFile(w_fn, dense_weights, true);
+ Weights::WriteToFile(w_fn, decoder_weights, true);
}
} // outer loop