diff options
Diffstat (limited to 'training/mira/kbest_mira.cc')
-rw-r--r-- | training/mira/kbest_mira.cc | 41 |
1 files changed, 27 insertions, 14 deletions
diff --git a/training/mira/kbest_mira.cc b/training/mira/kbest_mira.cc index 8b7993dd..bcb261c9 100644 --- a/training/mira/kbest_mira.cc +++ b/training/mira/kbest_mira.cc @@ -8,9 +8,11 @@ #include <boost/program_options.hpp> #include <boost/program_options/variables_map.hpp> +#include "stringlib.h" #include "hg_sampler.h" #include "sentence_metadata.h" -#include "scorer.h" +#include "ns.h" +#include "ns_docscorer.h" #include "verbose.h" #include "viterbi.h" #include "hg.h" @@ -91,8 +93,9 @@ struct GoodBadOracle { }; struct TrainingObserver : public DecoderObserver { - TrainingObserver(const int k, const DocScorer& d, bool sf, vector<GoodBadOracle>* o) : ds(d), oracles(*o), kbest_size(k), sample_forest(sf) {} - const DocScorer& ds; + TrainingObserver(const int k, const DocumentScorer& d, const EvaluationMetric& m, bool sf, vector<GoodBadOracle>* o) : ds(d), metric(m), oracles(*o), kbest_size(k), sample_forest(sf) {} + const DocumentScorer& ds; + const EvaluationMetric& metric; vector<GoodBadOracle>& oracles; std::tr1::shared_ptr<HypothesisInfo> cur_best; const int kbest_size; @@ -121,13 +124,16 @@ struct TrainingObserver : public DecoderObserver { if (sample_forest) { vector<WordID> cur_prediction; ViterbiESentence(forest, &cur_prediction); - float sentscore = ds[sent_id]->ScoreCandidate(cur_prediction)->ComputeScore(); + SufficientStats sstats; + ds[sent_id]->Evaluate(cur_prediction, &sstats); + float sentscore = metric.ComputeScore(sstats); cur_best = MakeHypothesisInfo(ViterbiFeatures(forest), sentscore); vector<HypergraphSampler::Hypothesis> samples; HypergraphSampler::sample_hypotheses(forest, kbest_size, &*rng, &samples); for (unsigned i = 0; i < samples.size(); ++i) { - sentscore = ds[sent_id]->ScoreCandidate(samples[i].words)->ComputeScore(); + ds[sent_id]->Evaluate(samples[i].words, &sstats); + float sentscore = metric.ComputeScore(sstats); if (invert_score) sentscore *= -1.0; if (!cur_good || sentscore > cur_good->mt_metric) cur_good = MakeHypothesisInfo(samples[i].fmap, sentscore); @@ -136,11 +142,13 @@ struct TrainingObserver : public DecoderObserver { } } else { KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, kbest_size); + SufficientStats sstats; for (int i = 0; i < kbest_size; ++i) { const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d = kbest.LazyKthBest(forest.nodes_.size() - 1, i); if (!d) break; - float sentscore = ds[sent_id]->ScoreCandidate(d->yield)->ComputeScore(); + ds[sent_id]->Evaluate(d->yield, &sstats); + float sentscore = metric.ComputeScore(sstats); if (invert_score) sentscore *= -1.0; // cerr << TD::GetString(d->yield) << " ||| " << d->score << " ||| " << sentscore << endl; if (i == 0) @@ -192,15 +200,20 @@ int main(int argc, char** argv) { } vector<string> corpus; ReadTrainingCorpus(conf["source"].as<string>(), &corpus); - const string metric_name = conf["mt_metric"].as<string>(); - ScoreType type = ScoreTypeFromString(metric_name); - if (type == TER) { - invert_score = true; - } else { - invert_score = false; + + string metric_name = UppercaseString(conf["evaluation_metric"].as<string>()); + if (metric_name == "COMBI") { + cerr << "WARNING: 'combi' metric is no longer supported, switching to 'COMB:TER=-0.5;IBM_BLEU=0.5'\n"; + metric_name = "COMB:TER=-0.5;IBM_BLEU=0.5"; + } else if (metric_name == "BLEU") { + cerr << "WARNING: 'BLEU' is ambiguous, assuming 'IBM_BLEU'\n"; + metric_name = "IBM_BLEU"; } - DocScorer ds(type, conf["reference"].as<vector<string> >(), ""); + EvaluationMetric* metric = EvaluationMetric::Instance(metric_name); + DocumentScorer ds(metric, conf["reference"].as<vector<string> >()); cerr << "Loaded " << ds.size() << " references for scoring with " << metric_name << endl; + invert_score = metric->IsErrorMetric(); + if (ds.size() != corpus.size()) { cerr << "Mismatched number of references (" << ds.size() << ") and sources (" << corpus.size() << ")\n"; return 1; @@ -221,7 +234,7 @@ int main(int argc, char** argv) { assert(corpus.size() > 0); vector<GoodBadOracle> oracles(corpus.size()); - TrainingObserver observer(conf["k_best_size"].as<int>(), ds, sample_forest, &oracles); + TrainingObserver observer(conf["k_best_size"].as<int>(), ds, *metric, sample_forest, &oracles); int cur_sent = 0; int lcount = 0; int normalizer = 0; |