summary | refs | log | tree | commit | diff
path: root/training/dtrain/dtrain.cc
diff options
context:
space:
mode:
Diffstat (limited to 'training/dtrain/dtrain.cc')
-rw-r--r-- training/dtrain/dtrain.cc | 35
1 files changed, 31 insertions, 4 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index e5cfd50a..97df530b 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -10,9 +10,9 @@ main(int argc, char** argv)
{
// get configuration
po::variables_map conf;
- if (!dtrain_init(argc, argv, &conf))
- exit(1); // something is wrong
+ dtrain_init(argc, argv, &conf);
const size_t k = conf["k"].as<size_t>();
+ const string score_name = conf["score"].as<string>();
const size_t N = conf["N"].as<size_t>();
const size_t T = conf["iterations"].as<size_t>();
const weight_t eta = conf["learning_rate"].as<weight_t>();
@@ -25,12 +25,28 @@ main(int argc, char** argv)
boost::split(print_weights, conf["print_weights"].as<string>(),
boost::is_any_of(" "));
- // setup decoder
+ // setup decoder and scorer
register_feature_functions();
SetSilent(true);
ReadFile f(conf["decoder_conf"].as<string>());
Decoder decoder(f.stream());
- ScoredKbest* observer = new ScoredKbest(k, new PerSentenceBleuScorer(N));
+ Scorer* scorer;
+ if (score_name == "nakov") {
+ scorer = static_cast<PerSentenceBleuScorer*>(new PerSentenceBleuScorer(N));
+ } else if (score_name == "papineni") {
+ scorer = static_cast<BleuScorer*>(new BleuScorer(N));
+ } else if (score_name == "lin") {
+ scorer = static_cast<OriginalPerSentenceBleuScorer*>\
+ (new OriginalPerSentenceBleuScorer(N));
+ } else if (score_name == "liang") {
+ scorer = static_cast<SmoothPerSentenceBleuScorer*>\
+ (new SmoothPerSentenceBleuScorer(N));
+ } else if (score_name == "chiang") {
+ scorer = static_cast<ApproxBleuScorer*>(new ApproxBleuScorer(N));
+ } else {
+ assert(false);
+ }
+ ScoredKbest* observer = new ScoredKbest(k, scorer);
// weights
vector<weight_t>& decoder_weights = decoder.CurrentWeightVector();
@@ -52,6 +68,7 @@ main(int argc, char** argv)
// output configuration
cerr << "dtrain" << endl << "Parameters:" << endl;
cerr << setw(25) << "k " << k << endl;
+ cerr << setw(25) << "score " << "'" << score_name << "'" << endl;
cerr << setw(25) << "N " << N << endl;
cerr << setw(25) << "T " << T << endl;
cerr << setw(25) << "learning rate " << eta << endl;
@@ -149,6 +166,16 @@ main(int argc, char** argv)
lambdas_copy = lambdas;
lambdas.plus_eq_v_times_s(updates, eta);
+ // update context for approx. BLEU
+ if (score_name == "chiang") {
+ for (auto it: *samples) {
+ if (it.rank == 0) {
+ scorer->UpdateContext(it.w, buf_ngs[i], buf_ls[i], 0.9);
+ break;
+ }
+ }
+ }
+
// l1 regularization
// NB: regularization is done after each sentence,
// not after every single pair!