diff options
Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r-- | dtrain/dtrain.cc | 76 |
1 files changed, 44 insertions, 32 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc index cf913765..ea5b8835 100644 --- a/dtrain/dtrain.cc +++ b/dtrain/dtrain.cc @@ -6,35 +6,37 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) { po::options_description ini("Configuration File Options"); ini.add_options() - ("input", po::value<string>()->default_value("-"), "input file") - ("output", po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT") - ("input_weights", po::value<string>(), "input weights file (e.g. from previous iteration)") - ("decoder_config", po::value<string>(), "configuration file for cdec") - ("print_weights", po::value<string>(), "weights to print on each iteration") - ("stop_after", po::value<unsigned>()->default_value(0), "stop after X input sentences") - ("tmp", po::value<string>()->default_value("/tmp"), "temp dir to use") - ("keep", po::value<bool>()->zero_tokens(), "keep weights files for each iteration") - ("hstreaming", po::value<string>(), "run in hadoop streaming mode, arg is a task id") - ("epochs", po::value<unsigned>()->default_value(10), "# of iterations T (per shard)") - ("k", po::value<unsigned>()->default_value(100), "how many translations to sample") - ("sample_from", po::value<string>()->default_value("kbest"), "where to sample translations from: 'kbest', 'forest'") - ("filter", po::value<string>()->default_value("uniq"), "filter kbest list: 'not', 'uniq'") - ("pair_sampling", po::value<string>()->default_value("108010"), "how to sample pairs: 'all', '108010' or 'PRO'") - ("pair_threshold", po::value<score_t>()->default_value(0), "bleu [0,1] threshold to filter pairs") - ("N", po::value<unsigned>()->default_value(4), "N for Ngrams (BLEU)") - ("scorer", po::value<string>()->default_value("stupid_bleu"), "scoring: bleu, stupid_, smooth_, approx_") - ("learning_rate", po::value<weight_t>()->default_value(0.0001), "learning rate") - ("gamma", po::value<weight_t>()->default_value(0), "gamma for SVM (0 for perceptron)") - ("select_weights", po::value<string>()->default_value("last"), "output best, last, avg weights ('VOID' to throw away)") - ("rescale", po::value<bool>()->zero_tokens(), "rescale weight vector after each input") - ("l1_reg", po::value<string>()->default_value("none"), "apply l1 regularization as in 'Tsuroka et al' (2010)") - ("l1_reg_strength", po::value<weight_t>(), "l1 regularization strength") - ("inc_correct", po::value<bool>()->zero_tokens(), "include correctly ranked pairs into updates") - ("fselect", po::value<weight_t>()->default_value(-1), "TODO select top x percent of features after each epoch") + ("input", po::value<string>()->default_value("-"), "input file") + ("output", po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT") + ("input_weights", po::value<string>(), "input weights file (e.g. from previous iteration)") + ("decoder_config", po::value<string>(), "configuration file for cdec") + ("print_weights", po::value<string>(), "weights to print on each iteration") + ("stop_after", po::value<unsigned>()->default_value(0), "stop after X input sentences") + ("tmp", po::value<string>()->default_value("/tmp"), "temp dir to use") + ("keep", po::value<bool>()->zero_tokens(), "keep weights files for each iteration") + ("hstreaming", po::value<string>(), "run in hadoop streaming mode, arg is a task id") + ("epochs", po::value<unsigned>()->default_value(10), "# of iterations T (per shard)") + ("k", po::value<unsigned>()->default_value(100), "how many translations to sample") + ("sample_from", po::value<string>()->default_value("kbest"), "where to sample translations from: 'kbest', 'forest'") + ("filter", po::value<string>()->default_value("uniq"), "filter kbest list: 'not', 'uniq'") + ("pair_sampling", po::value<string>()->default_value("XYX"), "how to sample pairs: 'all', 'XYX' or 'PRO'") + ("hi_lo", po::value<float>()->default_value(0.1), "hi and lo (X) for XYX (default 0.1), <= 0.5") + ("pair_threshold", po::value<score_t>()->default_value(0), "bleu [0,1] threshold to filter pairs") + ("N", po::value<unsigned>()->default_value(4), "N for Ngrams (BLEU)") + ("scorer", po::value<string>()->default_value("stupid_bleu"), "scoring: bleu, stupid_, smooth_, approx_") + ("learning_rate", po::value<weight_t>()->default_value(0.0001), "learning rate") + ("gamma", po::value<weight_t>()->default_value(0), "gamma for SVM (0 for perceptron)") + ("select_weights", po::value<string>()->default_value("last"), "output best, last, avg weights ('VOID' to throw away)") + ("rescale", po::value<bool>()->zero_tokens(), "rescale weight vector after each input") + ("l1_reg", po::value<string>()->default_value("none"), "apply l1 regularization as in 'Tsuroka et al' (2010)") + ("l1_reg_strength", po::value<weight_t>(), "l1 regularization strength") + ("inc_correct", po::value<bool>()->zero_tokens(), "include correctly ranked pairs into updates") + ("fselect", po::value<weight_t>()->default_value(-1), "TODO select top x percent of features after each epoch") + ("approx_bleu_scale", po::value<score_t>()->default_value(0.9), "scaling for approx. BLEU") #ifdef DTRAIN_LOCAL - ("refs,r", po::value<string>(), "references in local mode") + ("refs,r", po::value<string>(), "references in local mode") #endif - ("noup", po::value<bool>()->zero_tokens(), "do not update weights"); + ("noup", po::value<bool>()->zero_tokens(), "do not update weights"); po::options_description cl("Command Line Options"); cl.add_options() ("config,c", po::value<string>(), "dtrain config file") @@ -71,11 +73,18 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) cerr << "Wrong 'filter' param: '" << (*cfg)["filter"].as<string>() << "', use 'uniq' or 'not'." << endl; return false; } - if ((*cfg)["pair_sampling"].as<string>() != "all" && (*cfg)["pair_sampling"].as<string>() != "108010" && + if ((*cfg)["pair_sampling"].as<string>() != "all" && (*cfg)["pair_sampling"].as<string>() != "XYX" && (*cfg)["pair_sampling"].as<string>() != "PRO") { cerr << "Wrong 'pair_sampling' param: '" << (*cfg)["pair_sampling"].as<string>() << "'." << endl; return false; } + if(cfg->count("hi_lo") && (*cfg)["pair_sampling"].as<string>() != "XYX") { + cerr << "Warning: hi_lo only works with pair_sampling XYX." << endl; + } + if((*cfg)["hi_lo"].as<float>() > 0.5 || (*cfg)["hi_lo"].as<float>() < 0.01) { + cerr << "hi_lo must lie in [0.01, 0.5]" << endl; + return false; + } if ((*cfg)["pair_threshold"].as<score_t>() < 0) { cerr << "The threshold must be >= 0!" << endl; return false; @@ -126,6 +135,7 @@ main(int argc, char** argv) const string pair_sampling = cfg["pair_sampling"].as<string>(); const score_t pair_threshold = cfg["pair_threshold"].as<score_t>(); const string select_weights = cfg["select_weights"].as<string>(); + const float hi_lo = cfg["hi_lo"].as<float>(); bool average = false; if (select_weights == "avg") average = true; @@ -231,6 +241,8 @@ main(int argc, char** argv) cerr << setw(25) << "learning rate " << eta << endl; cerr << setw(25) << "gamma " << gamma << endl; cerr << setw(25) << "pairs " << "'" << pair_sampling << "'" << endl; + if (pair_sampling == "XYX") + cerr << setw(25) << "hi lo " << "'" << hi_lo << "'" << endl; cerr << setw(25) << "pair threshold " << pair_threshold << endl; cerr << setw(25) << "select weights " << "'" << select_weights << "'" << endl; if (cfg.count("l1_reg")) @@ -400,10 +412,10 @@ main(int argc, char** argv) vector<pair<ScoredHyp,ScoredHyp> > pairs; if (pair_sampling == "all") all_pairs(samples, pairs, pair_threshold); - if (pair_sampling == "108010") - part108010(samples, pairs, pair_threshold); + if (pair_sampling == "XYX") + partXYX(samples, pairs, pair_threshold, hi_lo); if (pair_sampling == "PRO") - PROsampling(samples, pairs); + PROsampling(samples, pairs, pair_threshold); npairs += pairs.size(); pair_count += 2*pairs.size(); @@ -456,7 +468,7 @@ main(int argc, char** argv) } } } else if (l1cumul) { - weight_t acc_penalty = (ii+1) * l1_reg; // Note: ii is the index of the current input + weight_t acc_penalty = (ii+1) * l1_reg; // ii is the index of the current input for (unsigned d = 0; d < lambdas.size(); d++) { if (lambdas.nonzero(d)) { weight_t v = lambdas.get(d); |