summaryrefslogtreecommitdiff
path: root/training/mira/kbest_mira.cc
diff options
context:
space:
mode:
Diffstat (limited to 'training/mira/kbest_mira.cc')
-rw-r--r--training/mira/kbest_mira.cc41
1 files changed, 27 insertions, 14 deletions
diff --git a/training/mira/kbest_mira.cc b/training/mira/kbest_mira.cc
index 8b7993dd..bcb261c9 100644
--- a/training/mira/kbest_mira.cc
+++ b/training/mira/kbest_mira.cc
@@ -8,9 +8,11 @@
#include <boost/program_options.hpp>
#include <boost/program_options/variables_map.hpp>
+#include "stringlib.h"
#include "hg_sampler.h"
#include "sentence_metadata.h"
-#include "scorer.h"
+#include "ns.h"
+#include "ns_docscorer.h"
#include "verbose.h"
#include "viterbi.h"
#include "hg.h"
@@ -91,8 +93,9 @@ struct GoodBadOracle {
};
struct TrainingObserver : public DecoderObserver {
- TrainingObserver(const int k, const DocScorer& d, bool sf, vector<GoodBadOracle>* o) : ds(d), oracles(*o), kbest_size(k), sample_forest(sf) {}
- const DocScorer& ds;
+ TrainingObserver(const int k, const DocumentScorer& d, const EvaluationMetric& m, bool sf, vector<GoodBadOracle>* o) : ds(d), metric(m), oracles(*o), kbest_size(k), sample_forest(sf) {}
+ const DocumentScorer& ds;
+ const EvaluationMetric& metric;
vector<GoodBadOracle>& oracles;
std::tr1::shared_ptr<HypothesisInfo> cur_best;
const int kbest_size;
@@ -121,13 +124,16 @@ struct TrainingObserver : public DecoderObserver {
if (sample_forest) {
vector<WordID> cur_prediction;
ViterbiESentence(forest, &cur_prediction);
- float sentscore = ds[sent_id]->ScoreCandidate(cur_prediction)->ComputeScore();
+ SufficientStats sstats;
+ ds[sent_id]->Evaluate(cur_prediction, &sstats);
+ float sentscore = metric.ComputeScore(sstats);
cur_best = MakeHypothesisInfo(ViterbiFeatures(forest), sentscore);
vector<HypergraphSampler::Hypothesis> samples;
HypergraphSampler::sample_hypotheses(forest, kbest_size, &*rng, &samples);
for (unsigned i = 0; i < samples.size(); ++i) {
- sentscore = ds[sent_id]->ScoreCandidate(samples[i].words)->ComputeScore();
+ ds[sent_id]->Evaluate(samples[i].words, &sstats);
+ float sentscore = metric.ComputeScore(sstats);
if (invert_score) sentscore *= -1.0;
if (!cur_good || sentscore > cur_good->mt_metric)
cur_good = MakeHypothesisInfo(samples[i].fmap, sentscore);
@@ -136,11 +142,13 @@ struct TrainingObserver : public DecoderObserver {
}
} else {
KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, kbest_size);
+ SufficientStats sstats;
for (int i = 0; i < kbest_size; ++i) {
const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
kbest.LazyKthBest(forest.nodes_.size() - 1, i);
if (!d) break;
- float sentscore = ds[sent_id]->ScoreCandidate(d->yield)->ComputeScore();
+ ds[sent_id]->Evaluate(d->yield, &sstats);
+ float sentscore = metric.ComputeScore(sstats);
if (invert_score) sentscore *= -1.0;
// cerr << TD::GetString(d->yield) << " ||| " << d->score << " ||| " << sentscore << endl;
if (i == 0)
@@ -192,15 +200,20 @@ int main(int argc, char** argv) {
}
vector<string> corpus;
ReadTrainingCorpus(conf["source"].as<string>(), &corpus);
- const string metric_name = conf["mt_metric"].as<string>();
- ScoreType type = ScoreTypeFromString(metric_name);
- if (type == TER) {
- invert_score = true;
- } else {
- invert_score = false;
+
+ string metric_name = UppercaseString(conf["evaluation_metric"].as<string>());
+ if (metric_name == "COMBI") {
+ cerr << "WARNING: 'combi' metric is no longer supported, switching to 'COMB:TER=-0.5;IBM_BLEU=0.5'\n";
+ metric_name = "COMB:TER=-0.5;IBM_BLEU=0.5";
+ } else if (metric_name == "BLEU") {
+ cerr << "WARNING: 'BLEU' is ambiguous, assuming 'IBM_BLEU'\n";
+ metric_name = "IBM_BLEU";
}
- DocScorer ds(type, conf["reference"].as<vector<string> >(), "");
+ EvaluationMetric* metric = EvaluationMetric::Instance(metric_name);
+ DocumentScorer ds(metric, conf["reference"].as<vector<string> >());
cerr << "Loaded " << ds.size() << " references for scoring with " << metric_name << endl;
+ invert_score = metric->IsErrorMetric();
+
if (ds.size() != corpus.size()) {
cerr << "Mismatched number of references (" << ds.size() << ") and sources (" << corpus.size() << ")\n";
return 1;
@@ -221,7 +234,7 @@ int main(int argc, char** argv) {
assert(corpus.size() > 0);
vector<GoodBadOracle> oracles(corpus.size());
- TrainingObserver observer(conf["k_best_size"].as<int>(), ds, sample_forest, &oracles);
+ TrainingObserver observer(conf["k_best_size"].as<int>(), ds, *metric, sample_forest, &oracles);
int cur_sent = 0;
int lcount = 0;
int normalizer = 0;