diff options
Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r-- | dtrain/dtrain.cc | 11 |
1 files changed, 7 insertions, 4 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc index ea5b8835..3dee10f2 100644 --- a/dtrain/dtrain.cc +++ b/dtrain/dtrain.cc @@ -32,7 +32,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) ("l1_reg_strength", po::value<weight_t>(), "l1 regularization strength") ("inc_correct", po::value<bool>()->zero_tokens(), "include correctly ranked pairs into updates") ("fselect", po::value<weight_t>()->default_value(-1), "TODO select top x percent of features after each epoch") - ("approx_bleu_scale", po::value<score_t>()->default_value(0.9), "scaling for approx. BLEU") + ("approx_bleu_d", po::value<score_t>()->default_value(0.9), "discount for approx. BLEU") #ifdef DTRAIN_LOCAL ("refs,r", po::value<string>(), "references in local mode") #endif @@ -136,6 +136,7 @@ main(int argc, char** argv) const score_t pair_threshold = cfg["pair_threshold"].as<score_t>(); const string select_weights = cfg["select_weights"].as<string>(); const float hi_lo = cfg["hi_lo"].as<float>(); + const score_t approx_bleu_d = cfg["approx_bleu_d"].as<score_t>(); bool average = false; if (select_weights == "avg") average = true; @@ -161,7 +162,7 @@ main(int argc, char** argv) } else if (scorer_str == "smooth_bleu") { scorer = dynamic_cast<SmoothBleuScorer*>(new SmoothBleuScorer); } else if (scorer_str == "approx_bleu") { - scorer = dynamic_cast<ApproxBleuScorer*>(new ApproxBleuScorer(N)); + scorer = dynamic_cast<ApproxBleuScorer*>(new ApproxBleuScorer(N, approx_bleu_d)); } else { cerr << "Don't know scoring metric: '" << scorer_str << "', exiting." << endl; exit(1); @@ -235,6 +236,8 @@ main(int argc, char** argv) cerr << setw(25) << "N " << N << endl; cerr << setw(25) << "T " << T << endl; cerr << setw(25) << "scorer '" << scorer_str << "'" << endl; + if (scorer_str == "approx_bleu") + cerr << setw(25) << "approx. B discount " << approx_bleu_d << endl; cerr << setw(25) << "sample from " << "'" << sample_from << "'" << endl; if (sample_from == "kbest") cerr << setw(25) << "filter " << "'" << filter_type << "'" << endl; @@ -242,7 +245,7 @@ main(int argc, char** argv) cerr << setw(25) << "gamma " << gamma << endl; cerr << setw(25) << "pairs " << "'" << pair_sampling << "'" << endl; if (pair_sampling == "XYX") - cerr << setw(25) << "hi lo " << "'" << hi_lo << "'" << endl; + cerr << setw(25) << "hi lo " << hi_lo << endl; cerr << setw(25) << "pair threshold " << pair_threshold << endl; cerr << setw(25) << "select weights " << "'" << select_weights << "'" << endl; if (cfg.count("l1_reg")) @@ -261,7 +264,7 @@ main(int argc, char** argv) cerr << setw(25) << "weights in " << "'" << cfg["input_weights"].as<string>() << "'" << endl; if (cfg.count("stop-after")) cerr << setw(25) << "stop_after " << stop_after << endl; - if (!verbose) cerr << "(a dot represents " << DTRAIN_DOTS << " lines of input)" << endl; + if (!verbose) cerr << "(a dot represents " << DTRAIN_DOTS << " inputs)" << endl; } |