diff options
Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r-- | dtrain/dtrain.cc | 127 |
1 files changed, 68 insertions, 59 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc index 434ae2d6..581c985a 100644 --- a/dtrain/dtrain.cc +++ b/dtrain/dtrain.cc @@ -10,25 +10,26 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) ("output", po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT") ("input_weights", po::value<string>(), "input weights file (e.g. from previous iteration)") ("decoder_config", po::value<string>(), "configuration file for cdec") - ("sample_from", po::value<string>()->default_value("kbest"), "where to sample translations from: kbest, forest") + ("print_weights", po::value<string>(), "weights to print on each iteration") + ("stop_after", po::value<unsigned>()->default_value(0), "stop after X input sentences") + ("tmp", po::value<string>()->default_value("/tmp"), "temp dir to use") + ("keep", po::value<bool>()->zero_tokens(), "keep weights files for each iteration") + ("hstreaming", po::value<string>(), "run in hadoop streaming mode, arg is a task id") + ("epochs", po::value<unsigned>()->default_value(10), "# of iterations T (per shard)") ("k", po::value<unsigned>()->default_value(100), "how many translations to sample") - ("filter", po::value<string>()->default_value("uniq"), "filter kbest list: no, uniq") - ("pair_sampling", po::value<string>()->default_value("all"), "how to sample pairs: all, 5050, 108010, PRO") - ("N", po::value<unsigned>()->default_value(3), "N for Ngrams (BLEU)") - ("epochs", po::value<unsigned>()->default_value(2), "# of iterations T (per shard)") - ("scorer", po::value<string>()->default_value("stupid_bleu"), "scoring: bleu, stupid_*, smooth_*, approx_*") - ("learning_rate", po::value<weight_t>()->default_value(0.0005), "learning rate") + ("sample_from", po::value<string>()->default_value("kbest"), "where to sample translations from: 'kbest', 'forest'") + ("filter", po::value<string>()->default_value("uniq"), "filter kbest list: 'not', 'uniq'") + ("pair_sampling", po::value<string>()->default_value("108010"), "how to sample pairs: 'all', '108010' or 'PRO'") + ("pair_threshold", po::value<score_t>()->default_value(0), "bleu [0,1] threshold to filter pairs") + ("N", po::value<unsigned>()->default_value(4), "N for Ngrams (BLEU)") + ("scorer", po::value<string>()->default_value("stupid_bleu"), "scoring: bleu, stupid_, smooth_, approx_") + ("learning_rate", po::value<weight_t>()->default_value(0.0001), "learning rate") ("gamma", po::value<weight_t>()->default_value(0), "gamma for SVM (0 for perceptron)") ("select_weights", po::value<string>()->default_value("last"), "output 'best' or 'last' weights ('VOID' to throw away)") - ("unit_wv", po::value<bool>()->zero_tokens(), "Rescale weight vector after each input") - ("l1_reg", po::value<string>()->default_value("no"), "apply l1 regularization as in Tsuroka et al 2010") + ("rescale", po::value<bool>()->zero_tokens(), "rescale weight vector after each input") + ("l1_reg", po::value<string>()->default_value("none"), "apply l1 regularization as in 'Tsuroka et al' (2010)") ("l1_reg_strength", po::value<weight_t>(), "l1 regularization strength") - ("update_ok", po::value<bool>()->zero_tokens(), "include correctly ranked pairs into updates") - ("stop_after", po::value<unsigned>()->default_value(0), "stop after X input sentences") - ("keep_w", po::value<bool>()->zero_tokens(), "keep weights files for each iteration") - ("print_weights", po::value<string>(), "weights to print on each iteration") - ("hstreaming", po::value<string>(), "run in hadoop streaming mode, arg is a task id") - ("tmp", po::value<string>()->default_value("/tmp"), "temp dir to use") + ("funny", po::value<bool>()->zero_tokens(), "include correctly ranked pairs into updates") #ifdef DTRAIN_LOCAL ("refs,r", po::value<string>(), "references in local mode") #endif @@ -64,18 +65,22 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) cerr << "Wrong 'sample_from' param: '" << (*cfg)["sample_from"].as<string>() << "', use 'kbest' or 'forest'." << endl; return false; } - if ((*cfg)["sample_from"].as<string>() == "kbest" && (*cfg)["filter"].as<string>() != "uniq" - && (*cfg)["filter"].as<string>() != "no") { - cerr << "Wrong 'filter' param: '" << (*cfg)["filter"].as<string>() << "', use 'uniq' or 'no'." << endl; + if ((*cfg)["sample_from"].as<string>() == "kbest" && (*cfg)["filter"].as<string>() != "uniq" && + (*cfg)["filter"].as<string>() != "not") { + cerr << "Wrong 'filter' param: '" << (*cfg)["filter"].as<string>() << "', use 'uniq' or 'not'." << endl; return false; } - string s = (*cfg)["pair_sampling"].as<string>(); - if (s != "all" && s != "5050" && s != "108010" && s != "PRO" && s != "alld" && s != "108010d") { + if ((*cfg)["pair_sampling"].as<string>() != "all" && (*cfg)["pair_sampling"].as<string>() != "108010" && + (*cfg)["pair_sampling"].as<string>() != "PRO") { cerr << "Wrong 'pair_sampling' param: '" << (*cfg)["pair_sampling"].as<string>() << "'." << endl; return false; } - if ((*cfg)["select_weights"].as<string>() != "last" - && (*cfg)["select_weights"].as<string>() != "best" && (*cfg)["select_weights"].as<string>() != "VOID") { + if ((*cfg)["pair_threshold"].as<score_t>() < 0) { + cerr << "The threshold must be >= 0!" << endl; + return false; + } + if ((*cfg)["select_weights"].as<string>() != "last" && (*cfg)["select_weights"].as<string>() != "best" && + (*cfg)["select_weights"].as<string>() != "VOID") { cerr << "Wrong 'select_weights' param: '" << (*cfg)["select_weights"].as<string>() << "', use 'last' or 'best'." << endl; return false; } @@ -102,14 +107,14 @@ main(int argc, char** argv) task_id = cfg["hstreaming"].as<string>(); cerr.precision(17); } - bool unit_wv = false; - if (cfg.count("unit_wv")) unit_wv = true; + bool rescale = false; + if (cfg.count("rescale")) rescale = true; HSReporter rep(task_id); - bool keep_w = false; - if (cfg.count("keep_w")) keep_w = true; - bool update_ok = false; - if (cfg.count("update_ok")) - update_ok = true; + bool keep = false; + if (cfg.count("keep")) keep = true; + bool funny = false; + if (cfg.count("funny")) + funny = true; const unsigned k = cfg["k"].as<unsigned>(); const unsigned N = cfg["N"].as<unsigned>(); @@ -118,6 +123,7 @@ main(int argc, char** argv) const string filter_type = cfg["filter"].as<string>(); const string sample_from = cfg["sample_from"].as<string>(); const string pair_sampling = cfg["pair_sampling"].as<string>(); + const score_t pair_threshold = cfg["pair_threshold"].as<score_t>(); const string select_weights = cfg["select_weights"].as<string>(); vector<string> print_weights; if (cfg.count("print_weights")) @@ -168,12 +174,13 @@ main(int argc, char** argv) // meta params for perceptron, SVM weight_t eta = cfg["learning_rate"].as<weight_t>(); weight_t gamma = cfg["gamma"].as<weight_t>(); + // l1 regularization bool l1naive = false; bool l1clip = false; bool l1cumul = false; weight_t l1_reg = 0; - if (cfg["l1_reg"].as<string>() != "no") { + if (cfg["l1_reg"].as<string>() != "none") { string s = cfg["l1_reg"].as<string>(); if (s == "naive") l1naive = true; else if (s == "clip") l1clip = true; @@ -191,7 +198,7 @@ main(int argc, char** argv) vector<vector<WordID> > ref_ids_buf; // references as WordID vecs // where temp files go string tmp_path = cfg["tmp"].as<string>(); - vector<string> w_tmp_files; // used for keep_w + vector<string> w_tmp_files; // used for keep #ifdef DTRAIN_LOCAL string refs_fn = cfg["refs"].as<string>(); ReadFile refs(refs_fn); @@ -214,28 +221,30 @@ main(int argc, char** argv) cerr << setw(25) << "k " << k << endl; cerr << setw(25) << "N " << N << endl; cerr << setw(25) << "T " << T << endl; - if (cfg.count("stop-after")) - cerr << setw(25) << "stop_after " << stop_after << endl; - if (cfg.count("input_weights")) - cerr << setw(25) << "weights in" << cfg["input_weights"].as<string>() << endl; - cerr << setw(25) << "input " << "'" << input_fn << "'" << endl; -#ifdef DTRAIN_LOCAL - cerr << setw(25) << "refs " << "'" << refs_fn << "'" << endl; -#endif - cerr << setw(25) << "output " << "'" << output_fn << "'" << endl; + cerr << setw(25) << "sample from " << "'" << sample_from << "'" << endl; if (sample_from == "kbest") cerr << setw(25) << "filter " << "'" << filter_type << "'" << endl; cerr << setw(25) << "learning rate " << eta << endl; cerr << setw(25) << "gamma " << gamma << endl; - cerr << setw(25) << "sample from " << "'" << sample_from << "'" << endl; cerr << setw(25) << "pairs " << "'" << pair_sampling << "'" << endl; + cerr << setw(25) << "pair threshold " << pair_threshold << endl; cerr << setw(25) << "select weights " << "'" << select_weights << "'" << endl; if (cfg.count("l1_reg")) cerr << setw(25) << "l1 reg " << l1_reg << " '" << cfg["l1_reg"].as<string>() << "'" << endl; - if (update_ok) - cerr << setw(25) << "up ok " << update_ok << endl; - if (unit_wv) - cerr << setw(25) << "unit weight vec " << unit_wv << endl; + if (funny) + cerr << setw(25) << "funny " << funny << endl; + if (rescale) + cerr << setw(25) << "rescale " << rescale << endl; + cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as<string>() << "'" << endl; + cerr << setw(25) << "input " << "'" << input_fn << "'" << endl; +#ifdef DTRAIN_LOCAL + cerr << setw(25) << "refs " << "'" << refs_fn << "'" << endl; +#endif + cerr << setw(25) << "output " << "'" << output_fn << "'" << endl; + if (cfg.count("input_weights")) + cerr << setw(25) << "weights in" << cfg["input_weights"].as<string>() << endl; + if (cfg.count("stop-after")) + cerr << setw(25) << "stop_after " << stop_after << endl; if (!verbose) cerr << "(a dot represents " << DTRAIN_DOTS << " lines of input)" << endl; } @@ -382,17 +391,11 @@ main(int argc, char** argv) if (!noup) { vector<pair<ScoredHyp,ScoredHyp> > pairs; if (pair_sampling == "all") - all_pairs(samples, pairs); - if (pair_sampling == "5050") - rand_pairs_5050(samples, pairs, &rng); + all_pairs(samples, pairs, pair_threshold); if (pair_sampling == "108010") - multpart108010(samples, pairs); + part108010(samples, pairs, pair_threshold); if (pair_sampling == "PRO") PROsampling(samples, pairs); - if (pair_sampling == "alld") - all_pairs_discard(samples, pairs); - if (pair_sampling == "108010d") - multpart108010_discard(samples, pairs); npairs += pairs.size(); for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin(); @@ -405,7 +408,7 @@ main(int argc, char** argv) lambdas.plus_eq_v_times_s(diff_vec, eta); rank_errors++; } else { - if (update_ok) { + if (funny) { SparseVector<weight_t> diff_vec = it->first.f - it->second.f; lambdas.plus_eq_v_times_s(diff_vec, eta); } @@ -429,6 +432,7 @@ main(int argc, char** argv) // TEST THIS // reset cumulative_penalties after 1 iter? // do this only once per INPUT (not per pair) +if (false) { if (l1naive) { for (unsigned d = 0; d < lambdas.size(); d++) { weight_t v = lambdas.get(d); @@ -462,9 +466,10 @@ main(int argc, char** argv) } } } +} //////// - if (unit_wv && sample_from == "forest") lambdas /= lambdas.l2norm(); + if (rescale) lambdas /= lambdas.l2norm(); ++ii; @@ -505,6 +510,9 @@ main(int argc, char** argv) score_diff = score_avg; model_diff = model_avg; } + + unsigned nonz; + if (!quiet || hstreaming) nonz = (unsigned)lambdas.size_nonzero(); if (!quiet) { cerr << _p5 << _p << "WEIGHTS" << endl; @@ -522,6 +530,8 @@ main(int argc, char** argv) cerr << rank_errors/(float)in_sz << endl; cerr << " avg #margin viol: "; cerr << margin_violations/float(in_sz) << endl; + cerr << " non0 feature count: "; + cerr << nonz << endl; } if (hstreaming) { @@ -530,7 +540,6 @@ main(int argc, char** argv) rep.update_counter("Pairs avg #"+boost::lexical_cast<string>(t+1), (unsigned)((npairs/(weight_t)in_sz)*DTRAIN_SCALE)); rep.update_counter("Rank errors avg #"+boost::lexical_cast<string>(t+1), (unsigned)((rank_errors/(weight_t)in_sz)*DTRAIN_SCALE)); rep.update_counter("Margin violations avg #"+boost::lexical_cast<string>(t+1), (unsigned)((margin_violations/(weight_t)in_sz)*DTRAIN_SCALE)); - unsigned nonz = (unsigned)lambdas.size_nonzero(); rep.update_counter("Non zero feature count #"+boost::lexical_cast<string>(t+1), nonz); rep.update_gcounter("Non zero feature count #"+boost::lexical_cast<string>(t+1), nonz); } @@ -555,7 +564,7 @@ main(int argc, char** argv) if (noup) break; // write weights to file - if (select_weights == "best" || keep_w) { + if (select_weights == "best" || keep) { lambdas.init_vector(&dense_weights); string w_fn = "weights." + boost::lexical_cast<string>(t) + ".gz"; Weights::WriteToFile(w_fn, dense_weights, true); @@ -589,7 +598,7 @@ main(int argc, char** argv) cout << _np; while(getline(*bestw, o)) cout << o << endl; } - if (!keep_w) { + if (!keep) { for (unsigned i = 0; i < T; i++) { string s = "weights." + boost::lexical_cast<string>(i) + ".gz"; unlink(s.c_str()); @@ -606,7 +615,7 @@ main(int argc, char** argv) cerr << _p2 << "This took " << overall_time/60. << " min." << endl; } - if (keep_w) { + if (keep) { cout << endl << "Weight files per iteration:" << endl; for (unsigned i = 0; i < w_tmp_files.size(); i++) { cout << w_tmp_files[i] << endl; |