diff options
author | Patrick Simianer <simianer@cl.uni-heidelberg.de> | 2012-04-29 14:55:27 +0200 |
---|---|---|
committer | Patrick Simianer <simianer@cl.uni-heidelberg.de> | 2012-04-29 14:55:27 +0200 |
commit | 70585a59a738d0148ed2da90252050f4d86f4a22 (patch) | |
tree | c803b3eb3df5ddc90637f8d33ce7e824c6f0b98b /dtrain/dtrain.cc | |
parent | 810f2bc32c796d270ff2209183ce13e69d4b1627 (diff) |
added params, output
Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r-- | dtrain/dtrain.cc | 29 |
1 file changed, 20 insertions, 9 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc index e817e7ab..b662cd26 100644 --- a/dtrain/dtrain.cc +++ b/dtrain/dtrain.cc @@ -21,17 +21,18 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) ("filter", po::value<string>()->default_value("uniq"), "filter kbest list: 'not', 'uniq'") ("pair_sampling", po::value<string>()->default_value("XYX"), "how to sample pairs: 'all', 'XYX' or 'PRO'") ("hi_lo", po::value<float>()->default_value(0.1), "hi and lo (X) for XYX (default 0.1), <= 0.5") - ("pair_threshold", po::value<score_t>()->default_value(0), "bleu [0,1] threshold to filter pairs") + ("pair_threshold", po::value<score_t>()->default_value(0.), "bleu [0,1] threshold to filter pairs") ("N", po::value<unsigned>()->default_value(4), "N for Ngrams (BLEU)") ("scorer", po::value<string>()->default_value("stupid_bleu"), "scoring: bleu, stupid_, smooth_, approx_") ("learning_rate", po::value<weight_t>()->default_value(0.0001), "learning rate") - ("gamma", po::value<weight_t>()->default_value(0), "gamma for SVM (0 for perceptron)") + ("gamma", po::value<weight_t>()->default_value(0.), "gamma for SVM (0 for perceptron)") ("select_weights", po::value<string>()->default_value("last"), "output best, last, avg weights ('VOID' to throw away)") ("rescale", po::value<bool>()->zero_tokens(), "rescale weight vector after each input") ("l1_reg", po::value<string>()->default_value("none"), "apply l1 regularization as in 'Tsuroka et al' (2010)") ("l1_reg_strength", po::value<weight_t>(), "l1 regularization strength") - ("fselect", po::value<weight_t>()->default_value(-1), "TODO select top x percent of features after each epoch") + ("fselect", po::value<weight_t>()->default_value(-1), "TODO select top x percent (or by threshold) of features after each epoch") ("approx_bleu_d", po::value<score_t>()->default_value(0.9), "discount for approx. 
BLEU") + ("scale_bleu_diff", po::value<bool>()->zero_tokens(), "learning rate <- bleu diff of a misranked pair") #ifdef DTRAIN_LOCAL ("refs,r", po::value<string>(), "references in local mode") #endif @@ -133,6 +134,8 @@ main(int argc, char** argv) const string select_weights = cfg["select_weights"].as<string>(); const float hi_lo = cfg["hi_lo"].as<float>(); const score_t approx_bleu_d = cfg["approx_bleu_d"].as<score_t>(); + bool scale_bleu_diff = false; + if (cfg.count("scale_bleu_diff")) scale_bleu_diff = true; bool average = false; if (select_weights == "avg") average = true; @@ -236,7 +239,8 @@ main(int argc, char** argv) cerr << setw(25) << "sample from " << "'" << sample_from << "'" << endl; if (sample_from == "kbest") cerr << setw(25) << "filter " << "'" << filter_type << "'" << endl; - cerr << setw(25) << "learning rate " << eta << endl; + if (!scale_bleu_diff) cerr << setw(25) << "learning rate " << eta << endl; + else cerr << setw(25) << "learning rate " << "bleu diff" << endl; cerr << setw(25) << "gamma " << gamma << endl; cerr << setw(25) << "pairs " << "'" << pair_sampling << "'" << endl; if (pair_sampling == "XYX") @@ -255,7 +259,7 @@ main(int argc, char** argv) cerr << setw(25) << "output " << "'" << output_fn << "'" << endl; if (cfg.count("input_weights")) cerr << setw(25) << "weights in " << "'" << cfg["input_weights"].as<string>() << "'" << endl; - if (cfg.count("stop-after")) + if (stop_after > 0) cerr << setw(25) << "stop_after " << stop_after << endl; if (!verbose) cerr << "(a dot represents " << DTRAIN_DOTS << " inputs)" << endl; } @@ -274,7 +278,7 @@ main(int argc, char** argv) #endif score_t score_sum = 0.; score_t model_sum(0); - unsigned ii = 0, rank_errors = 0, margin_violations = 0, npairs = 0; + unsigned ii = 0, rank_errors = 0, margin_violations = 0, npairs = 0, f_count = 0, list_sz = 0; if (!quiet) cerr << "Iteration #" << t+1 << " of " << T << "." 
<< endl; while(true) @@ -392,7 +396,7 @@ main(int argc, char** argv) else printWordIDVec(ref_ids); cerr << endl; for (unsigned u = 0; u < samples->size(); u++) { - cerr << _p5 << _np << "[" << u << ". '"; + cerr << _p2 << _np << "[" << u << ". '"; printWordIDVec((*samples)[u].w); cerr << "'" << endl; cerr << "SCORE=" << (*samples)[u].score << ",model="<< (*samples)[u].model << endl; @@ -403,8 +407,12 @@ main(int argc, char** argv) score_sum += (*samples)[0].score; // stats for 1best model_sum += (*samples)[0].model; + f_count += observer->get_f_count(); + list_sz += observer->get_sz(); + // weight updates if (!noup) { + // get pairs vector<pair<ScoredHyp,ScoredHyp> > pairs; if (pair_sampling == "all") all_pairs(samples, pairs, pair_threshold); @@ -420,6 +428,7 @@ main(int argc, char** argv) if (rank_error) rank_errors++; score_t margin = fabs(it->first.model - it->second.model); if (!rank_error && margin < 1) margin_violations++; + if (scale_bleu_diff) eta = it->first.score - it->second.score; if (rank_error || (gamma && margin<1)) { SparseVector<weight_t> diff_vec = it->first.f - it->second.f; lambdas.plus_eq_v_times_s(diff_vec, eta); @@ -512,7 +521,7 @@ main(int argc, char** argv) if (!quiet || hstreaming) nonz = (unsigned)lambdas.size_nonzero(); if (!quiet) { - cerr << _p9 << _p << "WEIGHTS" << endl; + cerr << _p5 << _p << "WEIGHTS" << endl; for (vector<string>::iterator it = print_weights.begin(); it != print_weights.end(); it++) { cerr << setw(18) << *it << " = " << lambdas.get(FD::Convert(*it)) << endl; } @@ -528,6 +537,8 @@ main(int argc, char** argv) cerr << " avg # margin viol: "; cerr << margin_violations/(float)in_sz << endl; cerr << " non0 feature count: " << nonz << endl; + cerr << " avg list sz: " << list_sz/(float)in_sz << endl; + cerr << " avg f count: " << f_count/(float)list_sz << endl; } if (hstreaming) { @@ -617,7 +628,7 @@ main(int argc, char** argv) if (!quiet) { cerr << _p5 << _np << endl << "---" << endl << "Best iteration: "; cerr << 
best_it+1 << " [SCORE '" << scorer_str << "'=" << max_score << "]." << endl; - cerr << _p2 << "This took " << overall_time/60. << " min." << endl; + cerr << "This took " << overall_time/60. << " min." << endl; } } |