diff options
author | Chris Dyer <cdyer@cs.cmu.edu> | 2012-01-27 13:19:27 -0500 |
---|---|---|
committer | Chris Dyer <cdyer@cs.cmu.edu> | 2012-01-27 13:19:27 -0500 |
commit | 203c3c3357b9ed8cfe44932c2bf5ea19eba6238c (patch) | |
tree | c446f8e8afbe194ef656b33cfc643f83633cf18c /mteval | |
parent | 481a120564fdb73c8c6833e2102acb533683261c (diff) |
migration to new metric api for vest, clean up of unsupported/not functional code
Diffstat (limited to 'mteval')
-rw-r--r-- | mteval/mbr_kbest.cc | 21 |
1 files changed, 11 insertions, 10 deletions
diff --git a/mteval/mbr_kbest.cc b/mteval/mbr_kbest.cc index 64a6a8bf..b5e4750c 100644 --- a/mteval/mbr_kbest.cc +++ b/mteval/mbr_kbest.cc @@ -5,7 +5,7 @@ #include "prob.h" #include "tdict.h" -#include "scorer.h" +#include "ns.h" #include "filelib.h" #include "stringlib.h" @@ -17,7 +17,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() ("scale,a",po::value<double>()->default_value(1.0), "Posterior scaling factor (alpha)") - ("loss_function,l",po::value<string>()->default_value("bleu"), "Loss function") + ("evaluation_metric,m",po::value<string>()->default_value("ibm_bleu"), "Evaluation metric") ("input,i",po::value<string>()->default_value("-"), "File to read k-best lists from") ("output_list,L", "Show reranked list as output") ("help,h", "Help"); @@ -75,13 +75,14 @@ bool ReadKBestList(istream* in, string* sent_id, vector<pair<vector<WordID>, pro int main(int argc, char** argv) { po::variables_map conf; InitCommandLine(argc, argv, &conf); - const string metric = conf["loss_function"].as<string>(); + const string smetric = conf["evaluation_metric"].as<string>(); + EvaluationMetric* metric = EvaluationMetric::Instance(smetric); + const bool is_loss = (UppercaseString(smetric) == "TER"); const bool output_list = conf.count("output_list") > 0; const string file = conf["input"].as<string>(); const double mbr_scale = conf["scale"].as<double>(); cerr << "Posterior scaling factor (alpha) = " << mbr_scale << endl; - ScoreType type = ScoreTypeFromString(metric); vector<pair<vector<WordID>, prob_t> > list; ReadFile rf(file); string sent_id; @@ -99,15 +100,15 @@ int main(int argc, char** argv) { vector<double> mbr_scores(output_list ? list.size() : 0); double mbr_loss = numeric_limits<double>::max(); for (int i = 0 ; i < list.size(); ++i) { - vector<vector<WordID> > refs(1, list[i].first); - //cerr << i << ": " << list[i].second <<"\t" << TD::GetString(list[i].first) << endl; - ScorerP scorer = SentenceScorer::CreateSentenceScorer(type, refs); + const vector<vector<WordID> > refs(1, list[i].first); + double wl_acc = 0; for (int j = 0; j < list.size(); ++j) { if (i != j) { - ScoreP s = scorer->ScoreCandidate(list[j].first); - double loss = 1.0 - s->ComputeScore(); - if (type == TER || type == AER) loss = 1.0 - loss; + SufficientStats ss; + metric->ComputeSufficientStatistics(list[j].first, refs, &ss); + double loss = 1.0 - metric->ComputeScore(ss); + if (is_loss) loss = 1.0 - loss; double weighted_loss = loss * (joints[j] / marginal).as_float(); wl_acc += weighted_loss; if ((!output_list) && wl_acc > mbr_loss) break; |