summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--training/dtrain/Makefile.am2
-rw-r--r--training/dtrain/dtrain.cc49
-rw-r--r--training/dtrain/dtrain.h1
3 files changed, 11 insertions, 41 deletions
diff --git a/training/dtrain/Makefile.am b/training/dtrain/Makefile.am
index 844c790d..3c072ffc 100644
--- a/training/dtrain/Makefile.am
+++ b/training/dtrain/Makefile.am
@@ -1,6 +1,6 @@
bin_PROGRAMS = dtrain
-dtrain_SOURCES = dtrain.cc score.cc dtrain.h kbestget.h ksampler.h pairsampling.h score.h
+dtrain_SOURCES = dtrain.cc score.cc dtrain.h sample.h pairs.h score.h
dtrain_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a
AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index 64fbf80d..67e16d23 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -1,8 +1,7 @@
#include "dtrain.h"
#include "score.h"
-#include "kbestget.h"
-#include "ksampler.h"
-#include "pairsampling.h"
+#include "sample.h"
+#include "pairs.h"
using namespace dtrain;
@@ -13,15 +12,14 @@ dtrain_init(int argc, char** argv, po::variables_map* conf)
po::options_description ini("Configuration File Options");
ini.add_options()
("bitext,b", po::value<string>(), "bitext: 'src ||| tgt ||| tgt ||| ...'")
- ("output", po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT")
- ("input_weights", po::value<string>(), "input weights file (e.g. from previous iteration)")
- ("decoder_config", po::value<string>(), "configuration file for cdec")
+ ("output,o", po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT")
+ ("input_weights,w", po::value<string>(), "input weights file (e.g. from previous iteration)")
+ ("decoder_config,d", po::value<string>(), "configuration file for cdec")
("print_weights", po::value<string>(), "weights to print on each iteration")
("stop_after", po::value<unsigned>()->default_value(0), "stop after X input sentences")
("keep", po::value<bool>()->zero_tokens(), "keep weights files for each iteration")
("epochs", po::value<unsigned>()->default_value(10), "# of iterations T (per shard)")
("k", po::value<unsigned>()->default_value(100), "how many translations to sample")
- ("sample_from", po::value<string>()->default_value("kbest"), "where to sample translations from: 'kbest', 'forest'")
("filter", po::value<string>()->default_value("uniq"), "filter kbest list: 'not', 'uniq'")
("pair_sampling", po::value<string>()->default_value("XYX"), "how to sample pairs: 'all', 'XYX' or 'PRO'")
("hi_lo", po::value<float>()->default_value(0.1), "hi and lo (X) for XYX (default 0.1), <= 0.5")
@@ -41,9 +39,8 @@ dtrain_init(int argc, char** argv, po::variables_map* conf)
("pclr", po::value<string>()->default_value("no"), "use a (simple|adagrad) per-coordinate learning rate")
("batch", po::value<bool>()->zero_tokens(), "do batch optimization")
("repeat", po::value<unsigned>()->default_value(1), "repeat optimization over kbest list this number of times")
- ("check", po::value<bool>()->zero_tokens(), "produce list of loss differentials")
- ("output_ranking", po::value<string>()->default_value(""), "Output kbests with model scores and metric per iteration to this folder.")
- ("noup", po::value<bool>()->zero_tokens(), "do not update weights");
+ ("output_ranking", po::value<string>()->default_value(""), "output scored kbests to dir")
+ ("noup", po::value<bool>()->zero_tokens(), "dont't optimize");
po::options_description cl("Command Line Options");
cl.add_options()
("config,c", po::value<string>(), "dtrain config file")
@@ -60,16 +57,6 @@ dtrain_init(int argc, char** argv, po::variables_map* conf)
cerr << cl << endl;
return false;
}
- if ((*conf)["sample_from"].as<string>() != "kbest"
- && (*conf)["sample_from"].as<string>() != "forest") {
- cerr << "Wrong 'sample_from' param: '" << (*conf)["sample_from"].as<string>() << "', use 'kbest' or 'forest'." << endl;
- return false;
- }
- if ((*conf)["sample_from"].as<string>() == "kbest" && (*conf)["filter"].as<string>() != "uniq" &&
- (*conf)["filter"].as<string>() != "not") {
- cerr << "Wrong 'filter' param: '" << (*conf)["filter"].as<string>() << "', use 'uniq' or 'not'." << endl;
- return false;
- }
if ((*conf)["pair_sampling"].as<string>() != "all" && (*conf)["pair_sampling"].as<string>() != "XYX" &&
(*conf)["pair_sampling"].as<string>() != "PRO" && (*conf)["pair_sampling"].as<string>() != "output_pairs") {
cerr << "Wrong 'pair_sampling' param: '" << (*conf)["pair_sampling"].as<string>() << "'." << endl;
@@ -121,7 +108,6 @@ main(int argc, char** argv)
const unsigned T = conf["epochs"].as<unsigned>();
const unsigned stop_after = conf["stop_after"].as<unsigned>();
const string filter_type = conf["filter"].as<string>();
- const string sample_from = conf["sample_from"].as<string>();
const string pair_sampling = conf["pair_sampling"].as<string>();
const score_t pair_threshold = conf["pair_threshold"].as<score_t>();
const string select_weights = conf["select_weights"].as<string>();
@@ -130,8 +116,6 @@ main(int argc, char** argv)
const score_t approx_bleu_d = conf["approx_bleu_d"].as<score_t>();
const unsigned max_pairs = conf["max_pairs"].as<unsigned>();
int repeat = conf["repeat"].as<unsigned>();
- bool check = false;
- if (conf.count("check")) check = true;
weight_t loss_margin = conf["loss_margin"].as<weight_t>();
bool batch = false;
if (conf.count("batch")) batch = true;
@@ -183,10 +167,7 @@ main(int argc, char** argv)
// setup decoder observer
MT19937 rng; // random number generator, only for forest sampling
HypSampler* observer;
- if (sample_from == "kbest")
- observer = static_cast<KBestGetter*>(new KBestGetter(k, filter_type));
- else
- observer = static_cast<KSampler*>(new KSampler(k, &rng));
+ observer = static_cast<KBestGetter*>(new KBestGetter(k, filter_type));
observer->SetScorer(scorer);
// init weights
@@ -244,9 +225,7 @@ main(int argc, char** argv)
cerr << setw(26) << "scorer '" << scorer_str << "'" << endl;
if (scorer_str == "approx_bleu")
cerr << setw(25) << "approx. B discount " << approx_bleu_d << endl;
- cerr << setw(25) << "sample from " << "'" << sample_from << "'" << endl;
- if (sample_from == "kbest")
- cerr << setw(25) << "filter " << "'" << filter_type << "'" << endl;
+ cerr << setw(25) << "filter " << "'" << filter_type << "'" << endl;
cerr << setw(25) << "learning rate " << eta << endl;
cerr << setw(25) << "gamma " << gamma << endl;
cerr << setw(25) << "loss margin " << loss_margin << endl;
@@ -408,9 +387,6 @@ main(int argc, char** argv)
score_t kbest_loss_first = 0.0, kbest_loss_last = 0.0;
- if (check) repeat = 2;
- vector<float> losses; // for check
-
if (pair_sampling == "output_pairs") {
for (auto p: pairs) {
cout << p.first.model << " ||| " << p.first.score << " ||| " << p.first.f << endl;
@@ -428,7 +404,6 @@ main(int argc, char** argv)
}
score_t model_diff = it->first.model - it->second.model;
score_t loss = max(0.0, -1.0 * model_diff);
- losses.push_back(loss);
kbest_loss_first += loss;
}
@@ -439,15 +414,11 @@ main(int argc, char** argv)
SparseVector<weight_t> sum_up; // for pclr
if (l1naive||l1clip||l1cumul) lambdas_copy = lambdas;
- unsigned pair_idx = 0; // for check
for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin();
it != pairs.end(); it++) {
score_t model_diff = it->first.model - it->second.model;
score_t loss = max(0.0, -1.0 * model_diff);
- if (check && ki==repeat-1) cout << losses[pair_idx] - loss << endl;
- pair_idx++;
-
if (repeat > 1) {
model_diff = lambdas.dot(it->first.f) - lambdas.dot(it->second.f);
kbest_loss += loss;
@@ -649,8 +620,6 @@ main(int argc, char** argv)
Weights::WriteToFile(w_fn, decoder_weights, true);
}
- if (check) cout << "---" << endl;
-
} // outer loop
if (average) w_average /= (weight_t)T;
diff --git a/training/dtrain/dtrain.h b/training/dtrain/dtrain.h
index d7980688..e25c6f24 100644
--- a/training/dtrain/dtrain.h
+++ b/training/dtrain/dtrain.h
@@ -15,6 +15,7 @@
#include "decoder.h"
#include "ff_register.h"
+#include "sampler.h"
#include "sentence_metadata.h"
#include "verbose.h"
#include "viterbi.h"