summaryrefslogtreecommitdiff
path: root/dtrain/dtrain.cc
diff options
context:
space:
mode:
Diffstat (limited to 'dtrain/dtrain.cc')
-rw-r--r--dtrain/dtrain.cc11
1 files changed, 7 insertions, 4 deletions
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index ea5b8835..3dee10f2 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -32,7 +32,7 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
("l1_reg_strength", po::value<weight_t>(), "l1 regularization strength")
("inc_correct", po::value<bool>()->zero_tokens(), "include correctly ranked pairs into updates")
("fselect", po::value<weight_t>()->default_value(-1), "TODO select top x percent of features after each epoch")
- ("approx_bleu_scale", po::value<score_t>()->default_value(0.9), "scaling for approx. BLEU")
+ ("approx_bleu_d", po::value<score_t>()->default_value(0.9), "discount for approx. BLEU")
#ifdef DTRAIN_LOCAL
("refs,r", po::value<string>(), "references in local mode")
#endif
@@ -136,6 +136,7 @@ main(int argc, char** argv)
const score_t pair_threshold = cfg["pair_threshold"].as<score_t>();
const string select_weights = cfg["select_weights"].as<string>();
const float hi_lo = cfg["hi_lo"].as<float>();
+ const score_t approx_bleu_d = cfg["approx_bleu_d"].as<score_t>();
bool average = false;
if (select_weights == "avg")
average = true;
@@ -161,7 +162,7 @@ main(int argc, char** argv)
} else if (scorer_str == "smooth_bleu") {
scorer = dynamic_cast<SmoothBleuScorer*>(new SmoothBleuScorer);
} else if (scorer_str == "approx_bleu") {
- scorer = dynamic_cast<ApproxBleuScorer*>(new ApproxBleuScorer(N));
+ scorer = dynamic_cast<ApproxBleuScorer*>(new ApproxBleuScorer(N, approx_bleu_d));
} else {
cerr << "Don't know scoring metric: '" << scorer_str << "', exiting." << endl;
exit(1);
@@ -235,6 +236,8 @@ main(int argc, char** argv)
cerr << setw(25) << "N " << N << endl;
cerr << setw(25) << "T " << T << endl;
cerr << setw(25) << "scorer '" << scorer_str << "'" << endl;
+ if (scorer_str == "approx_bleu")
+ cerr << setw(25) << "approx. B discount " << approx_bleu_d << endl;
cerr << setw(25) << "sample from " << "'" << sample_from << "'" << endl;
if (sample_from == "kbest")
cerr << setw(25) << "filter " << "'" << filter_type << "'" << endl;
@@ -242,7 +245,7 @@ main(int argc, char** argv)
cerr << setw(25) << "gamma " << gamma << endl;
cerr << setw(25) << "pairs " << "'" << pair_sampling << "'" << endl;
if (pair_sampling == "XYX")
- cerr << setw(25) << "hi lo " << "'" << hi_lo << "'" << endl;
+ cerr << setw(25) << "hi lo " << hi_lo << endl;
cerr << setw(25) << "pair threshold " << pair_threshold << endl;
cerr << setw(25) << "select weights " << "'" << select_weights << "'" << endl;
if (cfg.count("l1_reg"))
@@ -261,7 +264,7 @@ main(int argc, char** argv)
cerr << setw(25) << "weights in " << "'" << cfg["input_weights"].as<string>() << "'" << endl;
if (cfg.count("stop-after"))
cerr << setw(25) << "stop_after " << stop_after << endl;
- if (!verbose) cerr << "(a dot represents " << DTRAIN_DOTS << " lines of input)" << endl;
+ if (!verbose) cerr << "(a dot represents " << DTRAIN_DOTS << " inputs)" << endl;
}