summaryrefslogtreecommitdiff
path: root/training/dtrain/dtrain.cc
diff options
context:
space:
mode:
Diffstat (limited to 'training/dtrain/dtrain.cc')
-rw-r--r--training/dtrain/dtrain.cc61
1 files changed, 10 insertions, 51 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index 67e16d23..18addcb0 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -30,10 +30,9 @@ dtrain_init(int argc, char** argv, po::variables_map* conf)
("gamma", po::value<weight_t>()->default_value(0.), "gamma for SVM (0 for perceptron)")
("select_weights", po::value<string>()->default_value("last"), "output best, last, avg weights ('VOID' to throw away)")
("rescale", po::value<bool>()->zero_tokens(), "(re)scale data and weight vector to unit length")
- ("l1_reg", po::value<string>()->default_value("none"), "apply l1 regularization as in 'Tsuroka et al' (2010)")
+ ("l1_reg", po::value<string>()->default_value("none"), "apply l1 regularization with clipping as in 'Tsuroka et al' (2010)")
("l1_reg_strength", po::value<weight_t>(), "l1 regularization strength")
("fselect", po::value<weight_t>()->default_value(-1), "select top x percent (or by threshold) of features after each epoch NOT IMPLEMENTED") // TODO
- ("approx_bleu_d", po::value<score_t>()->default_value(0.9), "discount for approx. BLEU")
("loss_margin", po::value<weight_t>()->default_value(0.), "update if no error in pref pair but model scores this near")
("max_pairs", po::value<unsigned>()->default_value(std::numeric_limits<unsigned>::max()), "max. # of pairs per Sent.")
("pclr", po::value<string>()->default_value("no"), "use a (simple|adagrad) per-coordinate learning rate")
@@ -107,13 +106,11 @@ main(int argc, char** argv)
const unsigned N = conf["N"].as<unsigned>();
const unsigned T = conf["epochs"].as<unsigned>();
const unsigned stop_after = conf["stop_after"].as<unsigned>();
- const string filter_type = conf["filter"].as<string>();
const string pair_sampling = conf["pair_sampling"].as<string>();
const score_t pair_threshold = conf["pair_threshold"].as<score_t>();
const string select_weights = conf["select_weights"].as<string>();
const string output_ranking = conf["output_ranking"].as<string>();
const float hi_lo = conf["hi_lo"].as<float>();
- const score_t approx_bleu_d = conf["approx_bleu_d"].as<score_t>();
const unsigned max_pairs = conf["max_pairs"].as<unsigned>();
int repeat = conf["repeat"].as<unsigned>();
weight_t loss_margin = conf["loss_margin"].as<weight_t>();
@@ -136,39 +133,8 @@ main(int argc, char** argv)
cerr << setw(25) << "cdec conf " << "'" << conf["decoder_config"].as<string>() << "'" << endl;
Decoder decoder(ini_rf.stream());
- // scoring metric/scorer
- string scorer_str = conf["scorer"].as<string>();
- LocalScorer* scorer;
- if (scorer_str == "bleu") {
- scorer = static_cast<BleuScorer*>(new BleuScorer);
- } else if (scorer_str == "stupid_bleu") {
- scorer = static_cast<StupidBleuScorer*>(new StupidBleuScorer);
- } else if (scorer_str == "fixed_stupid_bleu") {
- scorer = static_cast<FixedStupidBleuScorer*>(new FixedStupidBleuScorer);
- } else if (scorer_str == "smooth_bleu") {
- scorer = static_cast<SmoothBleuScorer*>(new SmoothBleuScorer);
- } else if (scorer_str == "sum_bleu") {
- scorer = static_cast<SumBleuScorer*>(new SumBleuScorer);
- } else if (scorer_str == "sumexp_bleu") {
- scorer = static_cast<SumExpBleuScorer*>(new SumExpBleuScorer);
- } else if (scorer_str == "sumwhatever_bleu") {
- scorer = static_cast<SumWhateverBleuScorer*>(new SumWhateverBleuScorer);
- } else if (scorer_str == "approx_bleu") {
- scorer = static_cast<ApproxBleuScorer*>(new ApproxBleuScorer(N, approx_bleu_d));
- } else if (scorer_str == "lc_bleu") {
- scorer = static_cast<LinearBleuScorer*>(new LinearBleuScorer(N));
- } else {
- cerr << "Don't know scoring metric: '" << scorer_str << "', exiting." << endl;
- exit(1);
- }
- vector<score_t> bleu_weights;
- scorer->Init(N, bleu_weights);
-
// setup decoder observer
- MT19937 rng; // random number generator, only for forest sampling
- HypSampler* observer;
- observer = static_cast<KBestGetter*>(new KBestGetter(k, filter_type));
- observer->SetScorer(scorer);
+ ScoredKbest* observer = new ScoredKbest(k, new PerSentenceBleuScorer(N));
// init weights
vector<weight_t>& decoder_weights = decoder.CurrentWeightVector();
@@ -222,10 +188,6 @@ main(int argc, char** argv)
cerr << setw(25) << "N " << N << endl;
cerr << setw(25) << "T " << T << endl;
cerr << setw(25) << "batch " << batch << endl;
- cerr << setw(26) << "scorer '" << scorer_str << "'" << endl;
- if (scorer_str == "approx_bleu")
- cerr << setw(25) << "approx. B discount " << approx_bleu_d << endl;
- cerr << setw(25) << "filter " << "'" << filter_type << "'" << endl;
cerr << setw(25) << "learning rate " << eta << endl;
cerr << setw(25) << "gamma " << gamma << endl;
cerr << setw(25) << "loss margin " << loss_margin << endl;
@@ -242,7 +204,6 @@ main(int argc, char** argv)
cerr << setw(25) << "pclr " << pclr << endl;
cerr << setw(25) << "max pairs " << max_pairs << endl;
cerr << setw(25) << "repeat " << repeat << endl;
- //cerr << setw(25) << "test k-best " << test_k_best << endl;
cerr << setw(25) << "cdec conf " << "'" << conf["decoder_config"].as<string>() << "'" << endl;
cerr << setw(25) << "input " << "'" << input_fn << "'" << endl;
cerr << setw(25) << "output " << "'" << output_fn << "'" << endl;
@@ -321,13 +282,13 @@ main(int argc, char** argv)
vector<WordID> cur_ref;
vector<string> tok;
boost::split(tok, r, boost::is_any_of(" "));
- register_and_convert(tok, cur_ref);
+ RegisterAndConvert(tok, cur_ref);
cur_refs.push_back(cur_ref);
}
refs_as_ids_buf.push_back(cur_refs);
src_str_buf.push_back(in);
}
- observer->SetRef(refs_as_ids_buf[ii]);
+ observer->SetReference(refs_as_ids_buf[ii]);
if (t == 0)
decoder.Decode(in, observer);
else
@@ -341,7 +302,7 @@ main(int argc, char** argv)
stringstream ss;
for (auto s: *samples) {
ss << ii << " ||| ";
- printWordIDVec(s.w, ss);
+ PrintWordIDVec(s.w, ss);
ss << " ||| " << s.model << " ||| " << s.score << endl;
}
of.get() << ss.str();
@@ -350,12 +311,12 @@ main(int argc, char** argv)
if (verbose) {
cerr << "--- refs for " << ii << ": ";
for (auto r: refs_as_ids_buf[ii]) {
- printWordIDVec(r);
+ PrintWordIDVec(r);
cerr << endl;
}
for (unsigned u = 0; u < samples->size(); u++) {
cerr << _p2 << _np << "[" << u << ". '";
- printWordIDVec((*samples)[u].w);
+ PrintWordIDVec((*samples)[u].w);
cerr << "'" << endl;
cerr << "SCORE=" << (*samples)[u].score << ",model="<< (*samples)[u].model << endl;
cerr << "F{" << (*samples)[u].f << "} ]" << endl << endl;
@@ -367,8 +328,8 @@ main(int argc, char** argv)
model_sum += (*samples)[0].model;
}
- f_count += observer->get_f_count();
- list_sz += observer->get_sz();
+ f_count += observer->GetFeatureCount();
+ list_sz += observer->GetSize();
// weight updates
if (!noup) {
@@ -552,8 +513,6 @@ main(int argc, char** argv)
if (average) w_average += lambdas;
- if (scorer_str == "approx_bleu" || scorer_str == "lc_bleu") scorer->Reset();
-
// print some stats
score_t score_avg = score_sum/(score_t)in_sz;
score_t model_avg = model_sum/(score_t)in_sz;
@@ -665,7 +624,7 @@ main(int argc, char** argv)
if (!quiet) {
cerr << _p5 << _np << endl << "---" << endl << "Best iteration: ";
- cerr << best_it+1 << " [SCORE '" << scorer_str << "'=" << max_score << "]." << endl;
+ cerr << best_it+1 << " [SCORE = " << max_score << "]." << endl;
cerr << "This took " << overall_time/60. << " min." << endl;
}
}