From a94a2016915d91d01a102c56b86a54e5fe6e647a Mon Sep 17 00:00:00 2001
From: Patrick Simianer
Date: Fri, 27 Apr 2012 01:54:47 +0200
Subject: fix approx. BLEU of (Chiang et al. '08)

---
 dtrain/dtrain.cc | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

(limited to 'dtrain/dtrain.cc')

diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index ea5b8835..3dee10f2 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -32,7 +32,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
     ("l1_reg_strength", po::value(), "l1 regularization strength")
     ("inc_correct", po::value()->zero_tokens(), "include correctly ranked pairs into updates")
     ("fselect", po::value()->default_value(-1), "TODO select top x percent of features after each epoch")
-    ("approx_bleu_scale", po::value()->default_value(0.9), "scaling for approx. BLEU")
+    ("approx_bleu_d", po::value()->default_value(0.9), "discount for approx. BLEU")
 #ifdef DTRAIN_LOCAL
     ("refs,r", po::value(), "references in local mode")
 #endif
@@ -136,6 +136,7 @@ main(int argc, char** argv)
   const score_t pair_threshold = cfg["pair_threshold"].as();
   const string select_weights = cfg["select_weights"].as();
   const float hi_lo = cfg["hi_lo"].as();
+  const score_t approx_bleu_d = cfg["approx_bleu_d"].as();
   bool average = false;
   if (select_weights == "avg")
     average = true;
@@ -161,7 +162,7 @@ main(int argc, char** argv)
   } else if (scorer_str == "smooth_bleu") {
     scorer = dynamic_cast(new SmoothBleuScorer);
   } else if (scorer_str == "approx_bleu") {
-    scorer = dynamic_cast(new ApproxBleuScorer(N));
+    scorer = dynamic_cast(new ApproxBleuScorer(N, approx_bleu_d));
   } else {
     cerr << "Don't know scoring metric: '" << scorer_str << "', exiting." << endl;
     exit(1);
@@ -235,6 +236,8 @@ main(int argc, char** argv)
     cerr << setw(25) << "N " << N << endl;
     cerr << setw(25) << "T " << T << endl;
     cerr << setw(25) << "scorer '" << scorer_str << "'" << endl;
+    if (scorer_str == "approx_bleu")
+      cerr << setw(25) << "approx. B discount " << approx_bleu_d << endl;
     cerr << setw(25) << "sample from " << "'" << sample_from << "'" << endl;
     if (sample_from == "kbest")
       cerr << setw(25) << "filter " << "'" << filter_type << "'" << endl;
@@ -242,7 +245,7 @@ main(int argc, char** argv)
     cerr << setw(25) << "gamma " << gamma << endl;
     cerr << setw(25) << "pairs " << "'" << pair_sampling << "'" << endl;
     if (pair_sampling == "XYX")
-      cerr << setw(25) << "hi lo " << "'" << hi_lo << "'" << endl;
+      cerr << setw(25) << "hi lo " << hi_lo << endl;
     cerr << setw(25) << "pair threshold " << pair_threshold << endl;
     cerr << setw(25) << "select weights " << "'" << select_weights << "'" << endl;
     if (cfg.count("l1_reg"))
@@ -261,7 +264,7 @@ main(int argc, char** argv)
       cerr << setw(25) << "weights in " << "'" << cfg["input_weights"].as() << "'" << endl;
     if (cfg.count("stop-after"))
       cerr << setw(25) << "stop_after " << stop_after << endl;
-    if (!verbose) cerr << "(a dot represents " << DTRAIN_DOTS << " lines of input)" << endl;
+    if (!verbose) cerr << "(a dot represents " << DTRAIN_DOTS << " inputs)" << endl;
   }
--
cgit v1.2.3