path: root/dtrain/dtrain.cc
Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r--  dtrain/dtrain.cc | 29
1 file changed, 20 insertions(+), 9 deletions(-)
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index e817e7ab..b662cd26 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -21,17 +21,18 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
("filter", po::value<string>()->default_value("uniq"), "filter kbest list: 'not', 'uniq'")
("pair_sampling", po::value<string>()->default_value("XYX"), "how to sample pairs: 'all', 'XYX' or 'PRO'")
("hi_lo", po::value<float>()->default_value(0.1), "hi and lo (X) for XYX (default 0.1), <= 0.5")
- ("pair_threshold", po::value<score_t>()->default_value(0), "bleu [0,1] threshold to filter pairs")
+ ("pair_threshold", po::value<score_t>()->default_value(0.), "bleu [0,1] threshold to filter pairs")
("N", po::value<unsigned>()->default_value(4), "N for Ngrams (BLEU)")
("scorer", po::value<string>()->default_value("stupid_bleu"), "scoring: bleu, stupid_, smooth_, approx_")
("learning_rate", po::value<weight_t>()->default_value(0.0001), "learning rate")
- ("gamma", po::value<weight_t>()->default_value(0), "gamma for SVM (0 for perceptron)")
+ ("gamma", po::value<weight_t>()->default_value(0.), "gamma for SVM (0 for perceptron)")
("select_weights", po::value<string>()->default_value("last"), "output best, last, avg weights ('VOID' to throw away)")
("rescale", po::value<bool>()->zero_tokens(), "rescale weight vector after each input")
("l1_reg", po::value<string>()->default_value("none"), "apply l1 regularization as in 'Tsuroka et al' (2010)")
("l1_reg_strength", po::value<weight_t>(), "l1 regularization strength")
- ("fselect", po::value<weight_t>()->default_value(-1), "TODO select top x percent of features after each epoch")
+ ("fselect", po::value<weight_t>()->default_value(-1), "TODO select top x percent (or by threshold) of features after each epoch")
("approx_bleu_d", po::value<score_t>()->default_value(0.9), "discount for approx. BLEU")
+ ("scale_bleu_diff", po::value<bool>()->zero_tokens(), "learning rate <- bleu diff of a misranked pair")
#ifdef DTRAIN_LOCAL
("refs,r", po::value<string>(), "references in local mode")
#endif
@@ -133,6 +134,8 @@ main(int argc, char** argv)
const string select_weights = cfg["select_weights"].as<string>();
const float hi_lo = cfg["hi_lo"].as<float>();
const score_t approx_bleu_d = cfg["approx_bleu_d"].as<score_t>();
+ bool scale_bleu_diff = false;
+ if (cfg.count("scale_bleu_diff")) scale_bleu_diff = true;
bool average = false;
if (select_weights == "avg")
average = true;
@@ -236,7 +239,8 @@ main(int argc, char** argv)
cerr << setw(25) << "sample from " << "'" << sample_from << "'" << endl;
if (sample_from == "kbest")
cerr << setw(25) << "filter " << "'" << filter_type << "'" << endl;
- cerr << setw(25) << "learning rate " << eta << endl;
+ if (!scale_bleu_diff) cerr << setw(25) << "learning rate " << eta << endl;
+ else cerr << setw(25) << "learning rate " << "bleu diff" << endl;
cerr << setw(25) << "gamma " << gamma << endl;
cerr << setw(25) << "pairs " << "'" << pair_sampling << "'" << endl;
if (pair_sampling == "XYX")
@@ -255,7 +259,7 @@ main(int argc, char** argv)
cerr << setw(25) << "output " << "'" << output_fn << "'" << endl;
if (cfg.count("input_weights"))
cerr << setw(25) << "weights in " << "'" << cfg["input_weights"].as<string>() << "'" << endl;
- if (cfg.count("stop-after"))
+ if (stop_after > 0)
cerr << setw(25) << "stop_after " << stop_after << endl;
if (!verbose) cerr << "(a dot represents " << DTRAIN_DOTS << " inputs)" << endl;
}
@@ -274,7 +278,7 @@ main(int argc, char** argv)
#endif
score_t score_sum = 0.;
score_t model_sum(0);
- unsigned ii = 0, rank_errors = 0, margin_violations = 0, npairs = 0;
+ unsigned ii = 0, rank_errors = 0, margin_violations = 0, npairs = 0, f_count = 0, list_sz = 0;
if (!quiet) cerr << "Iteration #" << t+1 << " of " << T << "." << endl;
while(true)
@@ -392,7 +396,7 @@ main(int argc, char** argv)
else printWordIDVec(ref_ids);
cerr << endl;
for (unsigned u = 0; u < samples->size(); u++) {
- cerr << _p5 << _np << "[" << u << ". '";
+ cerr << _p2 << _np << "[" << u << ". '";
printWordIDVec((*samples)[u].w);
cerr << "'" << endl;
cerr << "SCORE=" << (*samples)[u].score << ",model="<< (*samples)[u].model << endl;
@@ -403,8 +407,12 @@ main(int argc, char** argv)
score_sum += (*samples)[0].score; // stats for 1best
model_sum += (*samples)[0].model;
+ f_count += observer->get_f_count();
+ list_sz += observer->get_sz();
+
// weight updates
if (!noup) {
+ // get pairs
vector<pair<ScoredHyp,ScoredHyp> > pairs;
if (pair_sampling == "all")
all_pairs(samples, pairs, pair_threshold);
@@ -420,6 +428,7 @@ main(int argc, char** argv)
if (rank_error) rank_errors++;
score_t margin = fabs(it->first.model - it->second.model);
if (!rank_error && margin < 1) margin_violations++;
+ if (scale_bleu_diff) eta = it->first.score - it->second.score;
if (rank_error || (gamma && margin<1)) {
SparseVector<weight_t> diff_vec = it->first.f - it->second.f;
lambdas.plus_eq_v_times_s(diff_vec, eta);
@@ -512,7 +521,7 @@ main(int argc, char** argv)
if (!quiet || hstreaming) nonz = (unsigned)lambdas.size_nonzero();
if (!quiet) {
- cerr << _p9 << _p << "WEIGHTS" << endl;
+ cerr << _p5 << _p << "WEIGHTS" << endl;
for (vector<string>::iterator it = print_weights.begin(); it != print_weights.end(); it++) {
cerr << setw(18) << *it << " = " << lambdas.get(FD::Convert(*it)) << endl;
}
@@ -528,6 +537,8 @@ main(int argc, char** argv)
cerr << " avg # margin viol: ";
cerr << margin_violations/(float)in_sz << endl;
cerr << " non0 feature count: " << nonz << endl;
+ cerr << " avg list sz: " << list_sz/(float)in_sz << endl;
+ cerr << " avg f count: " << f_count/(float)list_sz << endl;
}
if (hstreaming) {
@@ -617,7 +628,7 @@ main(int argc, char** argv)
if (!quiet) {
cerr << _p5 << _np << endl << "---" << endl << "Best iteration: ";
cerr << best_it+1 << " [SCORE '" << scorer_str << "'=" << max_score << "]." << endl;
- cerr << _p2 << "This took " << overall_time/60. << " min." << endl;
+ cerr << "This took " << overall_time/60. << " min." << endl;
}
}
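
The main functional change in this commit is the new --scale_bleu_diff option: when it is set, the per-pair learning rate eta is overwritten with the BLEU difference of the misranked pair before the usual perceptron/SVM-style update. Below is a minimal standalone sketch of that update rule, not the dtrain code itself; it uses simplified stand-in types (a plain map instead of dtrain's SparseVector<weight_t> and ScoredHyp) and hypothetical toy feature values, so it only illustrates the rule this diff adds.

    // Sketch (assumption-laden): pairwise update with eta <- BLEU diff of the pair.
    #include <cmath>
    #include <iostream>
    #include <map>
    #include <string>
    #include <utility>
    #include <vector>

    typedef std::map<std::string, double> FeatVec; // stand-in for SparseVector<weight_t>

    struct Hyp {
      FeatVec f;      // feature vector
      double  model;  // model score
      double  score;  // metric (BLEU) score
    };

    static void plus_eq_v_times_s(FeatVec& w, const FeatVec& v, double s) {
      for (FeatVec::const_iterator it = v.begin(); it != v.end(); ++it)
        w[it->first] += it->second * s;
    }

    int main() {
      bool   scale_bleu_diff = true;   // corresponds to the new --scale_bleu_diff flag
      double eta             = 0.0001; // default learning rate, used when the flag is off
      double gamma           = 0.;     // 0 => plain perceptron-style update

      // One sampled pair; 'first' is assumed to be the better hypothesis by BLEU.
      Hyp a, b;
      a.f["LanguageModel"] = 1.0; a.model = 0.2; a.score = 0.35;
      b.f["LanguageModel"] = 2.0; b.model = 0.7; b.score = 0.20;
      std::vector<std::pair<Hyp, Hyp> > pairs;
      pairs.push_back(std::make_pair(a, b));

      FeatVec lambdas; // weight vector
      for (size_t i = 0; i < pairs.size(); ++i) {
        bool rank_error = pairs[i].first.model <= pairs[i].second.model;
        double margin = std::fabs(pairs[i].first.model - pairs[i].second.model);
        if (scale_bleu_diff) // the change this commit introduces
          eta = pairs[i].first.score - pairs[i].second.score;
        if (rank_error || (gamma && margin < 1)) {
          FeatVec diff; // f(better) - f(worse)
          for (FeatVec::const_iterator it = pairs[i].first.f.begin(); it != pairs[i].first.f.end(); ++it)
            diff[it->first] += it->second;
          for (FeatVec::const_iterator it = pairs[i].second.f.begin(); it != pairs[i].second.f.end(); ++it)
            diff[it->first] -= it->second;
          plus_eq_v_times_s(lambdas, diff, eta); // lambdas += eta * diff
        }
      }
      std::cout << "LanguageModel weight: " << lambdas["LanguageModel"] << std::endl;
      return 0;
    }

With the toy numbers above the pair is misranked (model score 0.2 vs 0.7), so the update fires with eta = 0.15 and the LanguageModel weight moves to -0.15; with --scale_bleu_diff off, the fixed learning_rate would be used instead.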