diff options
Diffstat (limited to 'training/dtrain/dtrain.cc')
-rw-r--r-- | training/dtrain/dtrain.cc | 35 |
1 files changed, 31 insertions, 4 deletions
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index e5cfd50a..97df530b 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -10,9 +10,9 @@ main(int argc, char** argv) { // get configuration po::variables_map conf; - if (!dtrain_init(argc, argv, &conf)) - exit(1); // something is wrong + dtrain_init(argc, argv, &conf); const size_t k = conf["k"].as<size_t>(); + const string score_name = conf["score"].as<string>(); const size_t N = conf["N"].as<size_t>(); const size_t T = conf["iterations"].as<size_t>(); const weight_t eta = conf["learning_rate"].as<weight_t>(); @@ -25,12 +25,28 @@ main(int argc, char** argv) boost::split(print_weights, conf["print_weights"].as<string>(), boost::is_any_of(" ")); - // setup decoder + // setup decoder and scorer register_feature_functions(); SetSilent(true); ReadFile f(conf["decoder_conf"].as<string>()); Decoder decoder(f.stream()); - ScoredKbest* observer = new ScoredKbest(k, new PerSentenceBleuScorer(N)); + Scorer* scorer; + if (score_name == "nakov") { + scorer = static_cast<PerSentenceBleuScorer*>(new PerSentenceBleuScorer(N)); + } else if (score_name == "papineni") { + scorer = static_cast<BleuScorer*>(new BleuScorer(N)); + } else if (score_name == "lin") { + scorer = static_cast<OriginalPerSentenceBleuScorer*>\ + (new OriginalPerSentenceBleuScorer(N)); + } else if (score_name == "liang") { + scorer = static_cast<SmoothPerSentenceBleuScorer*>\ + (new SmoothPerSentenceBleuScorer(N)); + } else if (score_name == "chiang") { + scorer = static_cast<ApproxBleuScorer*>(new ApproxBleuScorer(N)); + } else { + assert(false); + } + ScoredKbest* observer = new ScoredKbest(k, scorer); // weights vector<weight_t>& decoder_weights = decoder.CurrentWeightVector(); @@ -52,6 +68,7 @@ main(int argc, char** argv) // output configuration cerr << "dtrain" << endl << "Parameters:" << endl; cerr << setw(25) << "k " << k << endl; + cerr << setw(25) << "score " << "'" << score_name << "'" << endl; cerr << setw(25) << "N " << N << endl; cerr << setw(25) << "T " << T << endl; cerr << setw(25) << "learning rate " << eta << endl; @@ -149,6 +166,16 @@ main(int argc, char** argv) lambdas_copy = lambdas; lambdas.plus_eq_v_times_s(updates, eta); + // update context for approx. BLEU + if (score_name == "chiang") { + for (auto it: *samples) { + if (it.rank == 0) { + scorer->UpdateContext(it.w, buf_ngs[i], buf_ls[i], 0.9); + break; + } + } + } + // l1 regularization // NB: regularization is done after each sentence, // not after every single pair! |