diff options
Diffstat (limited to 'training/dtrain')
-rw-r--r-- | training/dtrain/Makefile.am | 2 | ||||
-rw-r--r-- | training/dtrain/dtrain.cc | 49 | ||||
-rw-r--r-- | training/dtrain/dtrain.h | 1 |
3 files changed, 11 insertions, 41 deletions
diff --git a/training/dtrain/Makefile.am b/training/dtrain/Makefile.am index 844c790d..3c072ffc 100644 --- a/training/dtrain/Makefile.am +++ b/training/dtrain/Makefile.am @@ -1,6 +1,6 @@ bin_PROGRAMS = dtrain -dtrain_SOURCES = dtrain.cc score.cc dtrain.h kbestget.h ksampler.h pairsampling.h score.h +dtrain_SOURCES = dtrain.cc score.cc dtrain.h sample.h pairs.h score.h dtrain_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index 64fbf80d..67e16d23 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -1,8 +1,7 @@ #include "dtrain.h" #include "score.h" -#include "kbestget.h" -#include "ksampler.h" -#include "pairsampling.h" +#include "sample.h" +#include "pairs.h" using namespace dtrain; @@ -13,15 +12,14 @@ dtrain_init(int argc, char** argv, po::variables_map* conf) po::options_description ini("Configuration File Options"); ini.add_options() ("bitext,b", po::value<string>(), "bitext: 'src ||| tgt ||| tgt ||| ...'") - ("output", po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT") - ("input_weights", po::value<string>(), "input weights file (e.g. from previous iteration)") - ("decoder_config", po::value<string>(), "configuration file for cdec") + ("output,o", po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT") + ("input_weights,w", po::value<string>(), "input weights file (e.g. from previous iteration)") + ("decoder_config,d", po::value<string>(), "configuration file for cdec") ("print_weights", po::value<string>(), "weights to print on each iteration") ("stop_after", po::value<unsigned>()->default_value(0), "stop after X input sentences") ("keep", po::value<bool>()->zero_tokens(), "keep weights files for each iteration") ("epochs", po::value<unsigned>()->default_value(10), "# of iterations T (per shard)") ("k", po::value<unsigned>()->default_value(100), "how many translations to sample") - ("sample_from", po::value<string>()->default_value("kbest"), "where to sample translations from: 'kbest', 'forest'") ("filter", po::value<string>()->default_value("uniq"), "filter kbest list: 'not', 'uniq'") ("pair_sampling", po::value<string>()->default_value("XYX"), "how to sample pairs: 'all', 'XYX' or 'PRO'") ("hi_lo", po::value<float>()->default_value(0.1), "hi and lo (X) for XYX (default 0.1), <= 0.5") @@ -41,9 +39,8 @@ dtrain_init(int argc, char** argv, po::variables_map* conf) ("pclr", po::value<string>()->default_value("no"), "use a (simple|adagrad) per-coordinate learning rate") ("batch", po::value<bool>()->zero_tokens(), "do batch optimization") ("repeat", po::value<unsigned>()->default_value(1), "repeat optimization over kbest list this number of times") - ("check", po::value<bool>()->zero_tokens(), "produce list of loss differentials") - ("output_ranking", po::value<string>()->default_value(""), "Output kbests with model scores and metric per iteration to this folder.") - ("noup", po::value<bool>()->zero_tokens(), "do not update weights"); + ("output_ranking", po::value<string>()->default_value(""), "output scored kbests to dir") + ("noup", po::value<bool>()->zero_tokens(), "dont't optimize"); po::options_description cl("Command Line Options"); cl.add_options() ("config,c", po::value<string>(), "dtrain config file") @@ -60,16 +57,6 @@ dtrain_init(int argc, char** argv, po::variables_map* conf) cerr << cl << endl; return false; } - if ((*conf)["sample_from"].as<string>() != "kbest" - && (*conf)["sample_from"].as<string>() != "forest") { - cerr << "Wrong 'sample_from' param: '" << (*conf)["sample_from"].as<string>() << "', use 'kbest' or 'forest'." << endl; - return false; - } - if ((*conf)["sample_from"].as<string>() == "kbest" && (*conf)["filter"].as<string>() != "uniq" && - (*conf)["filter"].as<string>() != "not") { - cerr << "Wrong 'filter' param: '" << (*conf)["filter"].as<string>() << "', use 'uniq' or 'not'." << endl; - return false; - } if ((*conf)["pair_sampling"].as<string>() != "all" && (*conf)["pair_sampling"].as<string>() != "XYX" && (*conf)["pair_sampling"].as<string>() != "PRO" && (*conf)["pair_sampling"].as<string>() != "output_pairs") { cerr << "Wrong 'pair_sampling' param: '" << (*conf)["pair_sampling"].as<string>() << "'." << endl; @@ -121,7 +108,6 @@ main(int argc, char** argv) const unsigned T = conf["epochs"].as<unsigned>(); const unsigned stop_after = conf["stop_after"].as<unsigned>(); const string filter_type = conf["filter"].as<string>(); - const string sample_from = conf["sample_from"].as<string>(); const string pair_sampling = conf["pair_sampling"].as<string>(); const score_t pair_threshold = conf["pair_threshold"].as<score_t>(); const string select_weights = conf["select_weights"].as<string>(); @@ -130,8 +116,6 @@ main(int argc, char** argv) const score_t approx_bleu_d = conf["approx_bleu_d"].as<score_t>(); const unsigned max_pairs = conf["max_pairs"].as<unsigned>(); int repeat = conf["repeat"].as<unsigned>(); - bool check = false; - if (conf.count("check")) check = true; weight_t loss_margin = conf["loss_margin"].as<weight_t>(); bool batch = false; if (conf.count("batch")) batch = true; @@ -183,10 +167,7 @@ main(int argc, char** argv) // setup decoder observer MT19937 rng; // random number generator, only for forest sampling HypSampler* observer; - if (sample_from == "kbest") - observer = static_cast<KBestGetter*>(new KBestGetter(k, filter_type)); - else - observer = static_cast<KSampler*>(new KSampler(k, &rng)); + observer = static_cast<KBestGetter*>(new KBestGetter(k, filter_type)); observer->SetScorer(scorer); // init weights @@ -244,9 +225,7 @@ main(int argc, char** argv) cerr << setw(26) << "scorer '" << scorer_str << "'" << endl; if (scorer_str == "approx_bleu") cerr << setw(25) << "approx. B discount " << approx_bleu_d << endl; - cerr << setw(25) << "sample from " << "'" << sample_from << "'" << endl; - if (sample_from == "kbest") - cerr << setw(25) << "filter " << "'" << filter_type << "'" << endl; + cerr << setw(25) << "filter " << "'" << filter_type << "'" << endl; cerr << setw(25) << "learning rate " << eta << endl; cerr << setw(25) << "gamma " << gamma << endl; cerr << setw(25) << "loss margin " << loss_margin << endl; @@ -408,9 +387,6 @@ main(int argc, char** argv) score_t kbest_loss_first = 0.0, kbest_loss_last = 0.0; - if (check) repeat = 2; - vector<float> losses; // for check - if (pair_sampling == "output_pairs") { for (auto p: pairs) { cout << p.first.model << " ||| " << p.first.score << " ||| " << p.first.f << endl; @@ -428,7 +404,6 @@ main(int argc, char** argv) } score_t model_diff = it->first.model - it->second.model; score_t loss = max(0.0, -1.0 * model_diff); - losses.push_back(loss); kbest_loss_first += loss; } @@ -439,15 +414,11 @@ main(int argc, char** argv) SparseVector<weight_t> sum_up; // for pclr if (l1naive||l1clip||l1cumul) lambdas_copy = lambdas; - unsigned pair_idx = 0; // for check for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin(); it != pairs.end(); it++) { score_t model_diff = it->first.model - it->second.model; score_t loss = max(0.0, -1.0 * model_diff); - if (check && ki==repeat-1) cout << losses[pair_idx] - loss << endl; - pair_idx++; - if (repeat > 1) { model_diff = lambdas.dot(it->first.f) - lambdas.dot(it->second.f); kbest_loss += loss; @@ -649,8 +620,6 @@ main(int argc, char** argv) Weights::WriteToFile(w_fn, decoder_weights, true); } - if (check) cout << "---" << endl; - } // outer loop if (average) w_average /= (weight_t)T; diff --git a/training/dtrain/dtrain.h b/training/dtrain/dtrain.h index d7980688..e25c6f24 100644 --- a/training/dtrain/dtrain.h +++ b/training/dtrain/dtrain.h @@ -15,6 +15,7 @@ #include "decoder.h" #include "ff_register.h" +#include "sampler.h" #include "sentence_metadata.h" #include "verbose.h" #include "viterbi.h" |