Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r-- | dtrain/dtrain.cc | 107
1 file changed, 64 insertions, 43 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index 0853173f..3d3aa2d3 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -6,32 +6,33 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
 {
   po::options_description ini("Configuration File Options");
   ini.add_options()
-    ("input",          po::value<string>()->default_value("-"), "input file")
-    ("output",         po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT")
-    ("input_weights",  po::value<string>(), "input weights file (e.g. from previous iteration)")
-    ("decoder_config", po::value<string>(), "configuration file for cdec")
-    ("sample_from",    po::value<string>()->default_value("kbest"), "where to sample translations from: kbest, forest")
-    ("k",              po::value<unsigned>()->default_value(100), "how many translations to sample")
-    ("filter",         po::value<string>()->default_value("unique"), "filter kbest list: no, unique")
-    ("pair_sampling",  po::value<string>()->default_value("all"), "how to sample pairs: all, rand, 108010")
-    ("N",              po::value<unsigned>()->default_value(3), "N for Ngrams (BLEU)")
-    ("epochs",         po::value<unsigned>()->default_value(2), "# of iterations T (per shard)")
-    ("scorer",         po::value<string>()->default_value("stupid_bleu"), "scoring: bleu, stupid_*, smooth_*, approx_*")
-    ("stop_after",     po::value<unsigned>()->default_value(0), "stop after X input sentences")
-    ("print_weights",  po::value<string>(), "weights to print on each iteration")
-    ("hstreaming",     po::value<string>(), "run in hadoop streaming mode, arg is a task id")
-    ("learning_rate",  po::value<weight_t>()->default_value(0.0005), "learning rate")
-    ("gamma",          po::value<weight_t>()->default_value(0), "gamma for SVM (0 for perceptron)")
-    ("tmp",            po::value<string>()->default_value("/tmp"), "temp dir to use")
-    ("select_weights", po::value<string>()->default_value("last"), "output 'best' or 'last' weights ('VOID' to throw away)")
-    ("keep_w",         po::value<bool>()->zero_tokens(), "protocol weights for each iteration")
-    ("unit_weight_vector", po::value<bool>()->zero_tokens(), "Rescale weight vector after each input")
-    ("l1_reg",         po::value<string>()->default_value("no"), "apply l1 regularization as in Tsuroka et al 2010")
-    ("l1_reg_strength", po::value<weight_t>(), "l1 regularization strength")
+    ("input",           po::value<string>()->default_value("-"), "input file")
+    ("output",          po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT")
+    ("input_weights",   po::value<string>(), "input weights file (e.g. from previous iteration)")
+    ("decoder_config",  po::value<string>(), "configuration file for cdec")
+    ("sample_from",     po::value<string>()->default_value("kbest"), "where to sample translations from: kbest, forest")
+    ("k",               po::value<unsigned>()->default_value(100), "how many translations to sample")
+    ("filter",          po::value<string>()->default_value("uniq"), "filter kbest list: no, uniq")
+    ("pair_sampling",   po::value<string>()->default_value("all"), "how to sample pairs: all, 5050, 108010, PRO")
+    ("N",               po::value<unsigned>()->default_value(3), "N for Ngrams (BLEU)")
+    ("epochs",          po::value<unsigned>()->default_value(2), "# of iterations T (per shard)")
+    ("scorer",          po::value<string>()->default_value("stupid_bleu"), "scoring: bleu, stupid_*, smooth_*, approx_*")
+    ("learning_rate",   po::value<weight_t>()->default_value(0.0005), "learning rate")
+    ("gamma",           po::value<weight_t>()->default_value(0), "gamma for SVM (0 for perceptron)")
+    ("select_weights",  po::value<string>()->default_value("last"), "output 'best' or 'last' weights ('VOID' to throw away)")
+    ("unit_wv",         po::value<bool>()->zero_tokens(), "Rescale weight vector after each input")
+    ("l1_reg",          po::value<string>()->default_value("no"), "apply l1 regularization as in Tsuroka et al 2010")
+    ("l1_reg_strength", po::value<weight_t>(), "l1 regularization strength")
+    ("update_ok",       po::value<bool>()->zero_tokens(), "include correctly ranked pairs into updates")
+    ("stop_after",      po::value<unsigned>()->default_value(0), "stop after X input sentences")
+    ("keep_w",          po::value<bool>()->zero_tokens(), "keep weights files for each iteration")
+    ("print_weights",   po::value<string>(), "weights to print on each iteration")
+    ("hstreaming",      po::value<string>(), "run in hadoop streaming mode, arg is a task id")
+    ("tmp",             po::value<string>()->default_value("/tmp"), "temp dir to use")
 #ifdef DTRAIN_LOCAL
-    ("refs,r", po::value<string>(), "references in local mode")
+    ("refs,r",          po::value<string>(), "references in local mode")
 #endif
-    ("noup", po::value<bool>()->zero_tokens(), "do not update weights");
+    ("noup",            po::value<bool>()->zero_tokens(), "do not update weights");
   po::options_description cl("Command Line Options");
   cl.add_options()
     ("config,c", po::value<string>(), "dtrain config file")
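The zero_tokens() options above (noup, keep_w, unit_wv, update_ok) are argument-less flags whose presence is later tested via cfg.count(). A minimal, self-contained sketch of that Boost.Program_options pattern, not part of the patch:

#include <boost/program_options.hpp>
#include <iostream>

namespace po = boost::program_options;

int main(int argc, char** argv)
{
  po::options_description ini("Configuration File Options");
  ini.add_options()
    ("k",         po::value<unsigned>()->default_value(100), "how many translations to sample")
    ("update_ok", po::value<bool>()->zero_tokens(), "include correctly ranked pairs into updates");
  po::variables_map cfg;
  po::store(po::parse_command_line(argc, argv, ini), cfg);
  po::notify(cfg);
  // zero_tokens(): '--update_ok' consumes no argument, its presence is the value
  bool update_ok = cfg.count("update_ok") > 0;
  std::cout << "k=" << cfg["k"].as<unsigned>() << " update_ok=" << update_ok << std::endl;
  return 0;
}

Run with --update_ok and this prints k=100 update_ok=1; without the flag, update_ok=0.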
from previous iteration)") + ("decoder_config", po::value<string>(), "configuration file for cdec") + ("sample_from", po::value<string>()->default_value("kbest"), "where to sample translations from: kbest, forest") + ("k", po::value<unsigned>()->default_value(100), "how many translations to sample") + ("filter", po::value<string>()->default_value("uniq"), "filter kbest list: no, uniq") + ("pair_sampling", po::value<string>()->default_value("all"), "how to sample pairs: all, 5050, 108010, PRO") + ("N", po::value<unsigned>()->default_value(3), "N for Ngrams (BLEU)") + ("epochs", po::value<unsigned>()->default_value(2), "# of iterations T (per shard)") + ("scorer", po::value<string>()->default_value("stupid_bleu"), "scoring: bleu, stupid_*, smooth_*, approx_*") + ("learning_rate", po::value<weight_t>()->default_value(0.0005), "learning rate") + ("gamma", po::value<weight_t>()->default_value(0), "gamma for SVM (0 for perceptron)") + ("select_weights", po::value<string>()->default_value("last"), "output 'best' or 'last' weights ('VOID' to throw away)") + ("unit_wv", po::value<bool>()->zero_tokens(), "Rescale weight vector after each input") + ("l1_reg", po::value<string>()->default_value("no"), "apply l1 regularization as in Tsuroka et al 2010") + ("l1_reg_strength", po::value<weight_t>(), "l1 regularization strength") + ("update_ok", po::value<bool>()->zero_tokens(), "include correctly ranked pairs into updates") + ("stop_after", po::value<unsigned>()->default_value(0), "stop after X input sentences") + ("keep_w", po::value<bool>()->zero_tokens(), "keep weights files for each iteration") + ("print_weights", po::value<string>(), "weights to print on each iteration") + ("hstreaming", po::value<string>(), "run in hadoop streaming mode, arg is a task id") + ("tmp", po::value<string>()->default_value("/tmp"), "temp dir to use") #ifdef DTRAIN_LOCAL - ("refs,r", po::value<string>(), "references in local mode") + ("refs,r", po::value<string>(), "references in local mode") #endif - ("noup", po::value<bool>()->zero_tokens(), "do not update weights"); + ("noup", po::value<bool>()->zero_tokens(), "do not update weights"); po::options_description cl("Command Line Options"); cl.add_options() ("config,c", po::value<string>(), "dtrain config file") @@ -63,13 +64,14 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) cerr << "Wrong 'sample_from' param: '" << (*cfg)["sample_from"].as<string>() << "', use 'kbest' or 'forest'." << endl; return false; } - if ((*cfg)["sample_from"].as<string>() == "kbest" && (*cfg)["filter"].as<string>() != "unique" + if ((*cfg)["sample_from"].as<string>() == "kbest" && (*cfg)["filter"].as<string>() != "uniq" && (*cfg)["filter"].as<string>() != "no") { - cerr << "Wrong 'filter' param: '" << (*cfg)["filter"].as<string>() << "', use 'unique' or 'no'." << endl; + cerr << "Wrong 'filter' param: '" << (*cfg)["filter"].as<string>() << "', use 'uniq' or 'no'." << endl; return false; } if ((*cfg)["pair_sampling"].as<string>() != "all" - && (*cfg)["pair_sampling"].as<string>() != "rand" && (*cfg)["pair_sampling"].as<string>() != "108010") { + && (*cfg)["pair_sampling"].as<string>() != "5050" && (*cfg)["pair_sampling"].as<string>() != "108010" + && (*cfg)["pair_sampling"].as<string>() != "PRO") { cerr << "Wrong 'pair_sampling' param: '" << (*cfg)["pair_sampling"].as<string>() << "', use 'all' or 'rand'." 
<< endl; return false; } @@ -101,11 +103,14 @@ main(int argc, char** argv) task_id = cfg["hstreaming"].as<string>(); cerr.precision(17); } - bool unit_weight_vector = false; - if (cfg.count("unit_weight_vector")) unit_weight_vector = true; + bool unit_wv = false; + if (cfg.count("unit_wv")) unit_wv = true; HSReporter rep(task_id); bool keep_w = false; if (cfg.count("keep_w")) keep_w = true; + bool update_ok = false; + if (cfg.count("update_ok")) + update_ok = true; const unsigned k = cfg["k"].as<unsigned>(); const unsigned N = cfg["N"].as<unsigned>(); @@ -118,7 +123,7 @@ main(int argc, char** argv) vector<string> print_weights; if (cfg.count("print_weights")) boost::split(print_weights, cfg["print_weights"].as<string>(), boost::is_any_of(" ")); - + // setup decoder register_feature_functions(); SetSilent(true); @@ -187,7 +192,7 @@ main(int argc, char** argv) vector<vector<WordID> > ref_ids_buf; // references as WordID vecs // where temp files go string tmp_path = cfg["tmp"].as<string>(); - vector<string> w_tmp_files; // used for protocol_w + vector<string> w_tmp_files; // used for keep_w #ifdef DTRAIN_LOCAL string refs_fn = cfg["refs"].as<string>(); ReadFile refs(refs_fn); @@ -226,6 +231,12 @@ main(int argc, char** argv) cerr << setw(25) << "sample from " << "'" << sample_from << "'" << endl; cerr << setw(25) << "pairs " << "'" << pair_sampling << "'" << endl; cerr << setw(25) << "select weights " << "'" << select_weights << "'" << endl; + if (cfg.count("l1_reg")) + cerr << setw(25) << "l1 reg " << l1_reg << " '" << cfg["l1_reg"].as<string>() << "'" << endl; + if (update_ok) + cerr << setw(25) << "up ok " << update_ok << endl; + if (unit_wv) + cerr << setw(25) << "unit weight vec " << unit_wv << endl; if (!verbose) cerr << "(a dot represents " << DTRAIN_DOTS << " lines of input)" << endl; } @@ -320,7 +331,7 @@ main(int argc, char** argv) // get buffered grammar string grammar_str; while (true) { - string rule; + string rule; getline(grammar_buf_in, rule); if (boost::starts_with(rule, DTRAIN_GRAMMAR_DELIM)) break; grammar_str += rule + "\n"; @@ -372,13 +383,15 @@ main(int argc, char** argv) if (!noup) { vector<pair<ScoredHyp,ScoredHyp> > pairs; if (pair_sampling == "all") - sample_all_pairs(samples, pairs); - if (pair_sampling == "rand") - sample_rand_pairs(samples, pairs, &rng); + all_pairs(samples, pairs); + if (pair_sampling == "5050") + rand_pairs_5050(samples, pairs, &rng); if (pair_sampling == "108010") - sample108010(samples, pairs); + multpart108010(samples, pairs); + if (pair_sampling == "PRO") + PROsampling(samples, pairs); npairs += pairs.size(); - + for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin(); it != pairs.end(); it++) { score_t rank_error = it->second.score - it->first.score; @@ -388,6 +401,11 @@ main(int argc, char** argv) SparseVector<weight_t> diff_vec = it->second.f - it->first.f; lambdas.plus_eq_v_times_s(diff_vec, eta); rank_errors++; + } else { + if (update_ok) { + SparseVector<weight_t> diff_vec = it->first.f - it->second.f; + lambdas.plus_eq_v_times_s(diff_vec, eta); + } } if (it->first.model - it->second.model < 1) margin_violations++; } else { @@ -404,6 +422,8 @@ main(int argc, char** argv) } } + //////// + // TEST THIS // reset cumulative_penalties after 1 iter? 
@@ -404,6 +422,8 @@ main(int argc, char** argv)
           }
         }
 
+        ////////
+        // TEST THIS
         // reset cumulative_penalties after 1 iter?
         // do this only once per INPUT (not per pair)
         if (l1naive) {
@@ -439,8 +459,9 @@ main(int argc, char** argv)
             }
           }
         }
+        ////////
 
-        if (unit_weight_vector && sample_from == "forest") lambdas /= lambdas.l2norm();
+        if (unit_wv && sample_from == "forest") lambdas /= lambdas.l2norm();
 
         ++ii;
 
@@ -501,11 +522,11 @@ main(int argc, char** argv)
     }
 
     if (hstreaming) {
-      rep.update_counter("Score 1best avg #"+boost::lexical_cast<string>(t+1), (unsigned)(score_avg*_SCALE));
-      rep.update_counter("Model 1best avg #"+boost::lexical_cast<string>(t+1), (unsigned)(model_avg*_SCALE));
-      rep.update_counter("Pairs avg #"+boost::lexical_cast<string>(t+1), (unsigned)((npairs/(weight_t)in_sz)*_SCALE));
-      rep.update_counter("Rank errors avg #"+boost::lexical_cast<string>(t+1), (unsigned)((rank_errors/(weight_t)in_sz)*_SCALE));
-      rep.update_counter("Margin violations avg #"+boost::lexical_cast<string>(t+1), (unsigned)((margin_violations/(weight_t)in_sz)*_SCALE));
+      rep.update_counter("Score 1best avg #"+boost::lexical_cast<string>(t+1), (unsigned)(score_avg*DTRAIN_SCALE));
+      rep.update_counter("Model 1best avg #"+boost::lexical_cast<string>(t+1), (unsigned)(model_avg*DTRAIN_SCALE));
+      rep.update_counter("Pairs avg #"+boost::lexical_cast<string>(t+1), (unsigned)((npairs/(weight_t)in_sz)*DTRAIN_SCALE));
+      rep.update_counter("Rank errors avg #"+boost::lexical_cast<string>(t+1), (unsigned)((rank_errors/(weight_t)in_sz)*DTRAIN_SCALE));
+      rep.update_counter("Margin violations avg #"+boost::lexical_cast<string>(t+1), (unsigned)((margin_violations/(weight_t)in_sz)*DTRAIN_SCALE));
       unsigned nonz = (unsigned)lambdas.size_nonzero();
       rep.update_counter("Non zero feature count #"+boost::lexical_cast<string>(t+1), nonz);
       rep.update_gcounter("Non zero feature count #"+boost::lexical_cast<string>(t+1), nonz);
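The region fenced with //////// and marked TEST THIS applies the l1 regularization that the l1_reg option credits to Tsuroka et al 2010 (i.e. Tsuruoka, Tsujii & Ananiadou). For reference, sketches of the paper's 'naive' and 'clipping' updates on a dense weight vector; the actual code works on SparseVector<weight_t>, and the cumulative variant that the "reset cumulative_penalties" comment alludes to additionally tracks how much penalty each feature has already received:

#include <algorithm>
#include <vector>

// naive: subtract the scaled penalty from every active weight,
// w_i <- w_i - eta * C * sign(w_i)
void l1_naive(std::vector<double>& w, double eta, double C)
{
  for (size_t i = 0; i < w.size(); ++i)
    if (w[i] != 0.) w[i] -= eta * C * (w[i] > 0. ? 1. : -1.);
}

// clipping: same penalty, but never let it flip the sign of a weight,
// so small weights are truncated to exactly zero (yielding sparsity)
void l1_clip(std::vector<double>& w, double eta, double C)
{
  for (size_t i = 0; i < w.size(); ++i) {
    if (w[i] > 0.)      w[i] = std::max(0., w[i] - eta * C);
    else if (w[i] < 0.) w[i] = std::min(0., w[i] + eta * C);
  }
}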