diff options
Diffstat (limited to 'training/dtrain/dtrain.cc')
-rw-r--r-- | training/dtrain/dtrain.cc | 61 |
1 file changed, 10 insertions, 51 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index 67e16d23..18addcb0 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -30,10 +30,9 @@ dtrain_init(int argc, char** argv, po::variables_map* conf) ("gamma", po::value<weight_t>()->default_value(0.), "gamma for SVM (0 for perceptron)") ("select_weights", po::value<string>()->default_value("last"), "output best, last, avg weights ('VOID' to throw away)") ("rescale", po::value<bool>()->zero_tokens(), "(re)scale data and weight vector to unit length") - ("l1_reg", po::value<string>()->default_value("none"), "apply l1 regularization as in 'Tsuroka et al' (2010)") + ("l1_reg", po::value<string>()->default_value("none"), "apply l1 regularization with clipping as in 'Tsuroka et al' (2010)") ("l1_reg_strength", po::value<weight_t>(), "l1 regularization strength") ("fselect", po::value<weight_t>()->default_value(-1), "select top x percent (or by threshold) of features after each epoch NOT IMPLEMENTED") // TODO - ("approx_bleu_d", po::value<score_t>()->default_value(0.9), "discount for approx. BLEU") ("loss_margin", po::value<weight_t>()->default_value(0.), "update if no error in pref pair but model scores this near") ("max_pairs", po::value<unsigned>()->default_value(std::numeric_limits<unsigned>::max()), "max. 
# of pairs per Sent.") ("pclr", po::value<string>()->default_value("no"), "use a (simple|adagrad) per-coordinate learning rate") @@ -107,13 +106,11 @@ main(int argc, char** argv) const unsigned N = conf["N"].as<unsigned>(); const unsigned T = conf["epochs"].as<unsigned>(); const unsigned stop_after = conf["stop_after"].as<unsigned>(); - const string filter_type = conf["filter"].as<string>(); const string pair_sampling = conf["pair_sampling"].as<string>(); const score_t pair_threshold = conf["pair_threshold"].as<score_t>(); const string select_weights = conf["select_weights"].as<string>(); const string output_ranking = conf["output_ranking"].as<string>(); const float hi_lo = conf["hi_lo"].as<float>(); - const score_t approx_bleu_d = conf["approx_bleu_d"].as<score_t>(); const unsigned max_pairs = conf["max_pairs"].as<unsigned>(); int repeat = conf["repeat"].as<unsigned>(); weight_t loss_margin = conf["loss_margin"].as<weight_t>(); @@ -136,39 +133,8 @@ main(int argc, char** argv) cerr << setw(25) << "cdec conf " << "'" << conf["decoder_config"].as<string>() << "'" << endl; Decoder decoder(ini_rf.stream()); - // scoring metric/scorer - string scorer_str = conf["scorer"].as<string>(); - LocalScorer* scorer; - if (scorer_str == "bleu") { - scorer = static_cast<BleuScorer*>(new BleuScorer); - } else if (scorer_str == "stupid_bleu") { - scorer = static_cast<StupidBleuScorer*>(new StupidBleuScorer); - } else if (scorer_str == "fixed_stupid_bleu") { - scorer = static_cast<FixedStupidBleuScorer*>(new FixedStupidBleuScorer); - } else if (scorer_str == "smooth_bleu") { - scorer = static_cast<SmoothBleuScorer*>(new SmoothBleuScorer); - } else if (scorer_str == "sum_bleu") { - scorer = static_cast<SumBleuScorer*>(new SumBleuScorer); - } else if (scorer_str == "sumexp_bleu") { - scorer = static_cast<SumExpBleuScorer*>(new SumExpBleuScorer); - } else if (scorer_str == "sumwhatever_bleu") { - scorer = static_cast<SumWhateverBleuScorer*>(new SumWhateverBleuScorer); - } else if 
(scorer_str == "approx_bleu") { - scorer = static_cast<ApproxBleuScorer*>(new ApproxBleuScorer(N, approx_bleu_d)); - } else if (scorer_str == "lc_bleu") { - scorer = static_cast<LinearBleuScorer*>(new LinearBleuScorer(N)); - } else { - cerr << "Don't know scoring metric: '" << scorer_str << "', exiting." << endl; - exit(1); - } - vector<score_t> bleu_weights; - scorer->Init(N, bleu_weights); - // setup decoder observer - MT19937 rng; // random number generator, only for forest sampling - HypSampler* observer; - observer = static_cast<KBestGetter*>(new KBestGetter(k, filter_type)); - observer->SetScorer(scorer); + ScoredKbest* observer = new ScoredKbest(k, new PerSentenceBleuScorer(N)); // init weights vector<weight_t>& decoder_weights = decoder.CurrentWeightVector(); @@ -222,10 +188,6 @@ main(int argc, char** argv) cerr << setw(25) << "N " << N << endl; cerr << setw(25) << "T " << T << endl; cerr << setw(25) << "batch " << batch << endl; - cerr << setw(26) << "scorer '" << scorer_str << "'" << endl; - if (scorer_str == "approx_bleu") - cerr << setw(25) << "approx. 
B discount " << approx_bleu_d << endl; - cerr << setw(25) << "filter " << "'" << filter_type << "'" << endl; cerr << setw(25) << "learning rate " << eta << endl; cerr << setw(25) << "gamma " << gamma << endl; cerr << setw(25) << "loss margin " << loss_margin << endl; @@ -242,7 +204,6 @@ main(int argc, char** argv) cerr << setw(25) << "pclr " << pclr << endl; cerr << setw(25) << "max pairs " << max_pairs << endl; cerr << setw(25) << "repeat " << repeat << endl; - //cerr << setw(25) << "test k-best " << test_k_best << endl; cerr << setw(25) << "cdec conf " << "'" << conf["decoder_config"].as<string>() << "'" << endl; cerr << setw(25) << "input " << "'" << input_fn << "'" << endl; cerr << setw(25) << "output " << "'" << output_fn << "'" << endl; @@ -321,13 +282,13 @@ main(int argc, char** argv) vector<WordID> cur_ref; vector<string> tok; boost::split(tok, r, boost::is_any_of(" ")); - register_and_convert(tok, cur_ref); + RegisterAndConvert(tok, cur_ref); cur_refs.push_back(cur_ref); } refs_as_ids_buf.push_back(cur_refs); src_str_buf.push_back(in); } - observer->SetRef(refs_as_ids_buf[ii]); + observer->SetReference(refs_as_ids_buf[ii]); if (t == 0) decoder.Decode(in, observer); else @@ -341,7 +302,7 @@ main(int argc, char** argv) stringstream ss; for (auto s: *samples) { ss << ii << " ||| "; - printWordIDVec(s.w, ss); + PrintWordIDVec(s.w, ss); ss << " ||| " << s.model << " ||| " << s.score << endl; } of.get() << ss.str(); @@ -350,12 +311,12 @@ main(int argc, char** argv) if (verbose) { cerr << "--- refs for " << ii << ": "; for (auto r: refs_as_ids_buf[ii]) { - printWordIDVec(r); + PrintWordIDVec(r); cerr << endl; } for (unsigned u = 0; u < samples->size(); u++) { cerr << _p2 << _np << "[" << u << ". 
'"; - printWordIDVec((*samples)[u].w); + PrintWordIDVec((*samples)[u].w); cerr << "'" << endl; cerr << "SCORE=" << (*samples)[u].score << ",model="<< (*samples)[u].model << endl; cerr << "F{" << (*samples)[u].f << "} ]" << endl << endl; @@ -367,8 +328,8 @@ main(int argc, char** argv) model_sum += (*samples)[0].model; } - f_count += observer->get_f_count(); - list_sz += observer->get_sz(); + f_count += observer->GetFeatureCount(); + list_sz += observer->GetSize(); // weight updates if (!noup) { @@ -552,8 +513,6 @@ main(int argc, char** argv) if (average) w_average += lambdas; - if (scorer_str == "approx_bleu" || scorer_str == "lc_bleu") scorer->Reset(); - // print some stats score_t score_avg = score_sum/(score_t)in_sz; score_t model_avg = model_sum/(score_t)in_sz; @@ -665,7 +624,7 @@ main(int argc, char** argv) if (!quiet) { cerr << _p5 << _np << endl << "---" << endl << "Best iteration: "; - cerr << best_it+1 << " [SCORE '" << scorer_str << "'=" << max_score << "]." << endl; + cerr << best_it+1 << " [SCORE = " << max_score << "]." << endl; cerr << "This took " << overall_time/60. << " min." << endl; } } |