From 203c3c3357b9ed8cfe44932c2bf5ea19eba6238c Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 27 Jan 2012 13:19:27 -0500 Subject: migration to new metric api for vest, clean up of unsupported/not functional code --- mteval/mbr_kbest.cc | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) (limited to 'mteval') diff --git a/mteval/mbr_kbest.cc b/mteval/mbr_kbest.cc index 64a6a8bf..b5e4750c 100644 --- a/mteval/mbr_kbest.cc +++ b/mteval/mbr_kbest.cc @@ -5,7 +5,7 @@ #include "prob.h" #include "tdict.h" -#include "scorer.h" +#include "ns.h" #include "filelib.h" #include "stringlib.h" @@ -17,7 +17,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() ("scale,a",po::value()->default_value(1.0), "Posterior scaling factor (alpha)") - ("loss_function,l",po::value()->default_value("bleu"), "Loss function") + ("evaluation_metric,m",po::value()->default_value("ibm_bleu"), "Evaluation metric") ("input,i",po::value()->default_value("-"), "File to read k-best lists from") ("output_list,L", "Show reranked list as output") ("help,h", "Help"); @@ -75,13 +75,14 @@ bool ReadKBestList(istream* in, string* sent_id, vector, pro int main(int argc, char** argv) { po::variables_map conf; InitCommandLine(argc, argv, &conf); - const string metric = conf["loss_function"].as(); + const string smetric = conf["evaluation_metric"].as(); + EvaluationMetric* metric = EvaluationMetric::Instance(smetric); + const bool is_loss = (UppercaseString(smetric) == "TER"); const bool output_list = conf.count("output_list") > 0; const string file = conf["input"].as(); const double mbr_scale = conf["scale"].as(); cerr << "Posterior scaling factor (alpha) = " << mbr_scale << endl; - ScoreType type = ScoreTypeFromString(metric); vector, prob_t> > list; ReadFile rf(file); string sent_id; @@ -99,15 +100,15 @@ int main(int argc, char** argv) { vector mbr_scores(output_list ? list.size() : 0); double mbr_loss = numeric_limits::max(); for (int i = 0 ; i < list.size(); ++i) { - vector > refs(1, list[i].first); - //cerr << i << ": " << list[i].second <<"\t" << TD::GetString(list[i].first) << endl; - ScorerP scorer = SentenceScorer::CreateSentenceScorer(type, refs); + const vector > refs(1, list[i].first); + double wl_acc = 0; for (int j = 0; j < list.size(); ++j) { if (i != j) { - ScoreP s = scorer->ScoreCandidate(list[j].first); - double loss = 1.0 - s->ComputeScore(); - if (type == TER || type == AER) loss = 1.0 - loss; + SufficientStats ss; + metric->ComputeSufficientStatistics(list[j].first, refs, &ss); + double loss = 1.0 - metric->ComputeScore(ss); + if (is_loss) loss = 1.0 - loss; double weighted_loss = loss * (joints[j] / marginal).as_float(); wl_acc += weighted_loss; if ((!output_list) && wl_acc > mbr_loss) break; -- cgit v1.2.3