diff options
| author | Patrick Simianer <p@simianer.de> | 2012-03-13 09:24:47 +0100 | 
|---|---|---|
| committer | Patrick Simianer <p@simianer.de> | 2012-03-13 09:24:47 +0100 | 
| commit | ef6085e558e26c8819f1735425761103021b6470 (patch) | |
| tree | 5cf70e4c48c64d838e1326b5a505c8c4061bff4a /mteval/mbr_kbest.cc | |
| parent | 10a232656a0c882b3b955d2bcfac138ce11e8a2e (diff) | |
| parent | dfbc278c1057555fda9312291c8024049e00b7d8 (diff) | |
merge with upstream
Diffstat (limited to 'mteval/mbr_kbest.cc')
| -rw-r--r-- | mteval/mbr_kbest.cc | 24 | 
1 files changed, 14 insertions, 10 deletions
| diff --git a/mteval/mbr_kbest.cc b/mteval/mbr_kbest.cc index 64a6a8bf..2bd31566 100644 --- a/mteval/mbr_kbest.cc +++ b/mteval/mbr_kbest.cc @@ -5,7 +5,7 @@  #include "prob.h"  #include "tdict.h" -#include "scorer.h" +#include "ns.h"  #include "filelib.h"  #include "stringlib.h" @@ -17,7 +17,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {    po::options_description opts("Configuration options");    opts.add_options()          ("scale,a",po::value<double>()->default_value(1.0), "Posterior scaling factor (alpha)") -        ("loss_function,l",po::value<string>()->default_value("bleu"), "Loss function") +        ("evaluation_metric,m",po::value<string>()->default_value("ibm_bleu"), "Evaluation metric")          ("input,i",po::value<string>()->default_value("-"), "File to read k-best lists from")          ("output_list,L", "Show reranked list as output")          ("help,h", "Help"); @@ -75,13 +75,15 @@ bool ReadKBestList(istream* in, string* sent_id, vector<pair<vector<WordID>, pro  int main(int argc, char** argv) {    po::variables_map conf;    InitCommandLine(argc, argv, &conf); -  const string metric = conf["loss_function"].as<string>(); +  const string smetric = conf["evaluation_metric"].as<string>(); +  EvaluationMetric* metric = EvaluationMetric::Instance(smetric); + +  const bool is_loss = (UppercaseString(smetric) == "TER");    const bool output_list = conf.count("output_list") > 0;    const string file = conf["input"].as<string>();    const double mbr_scale = conf["scale"].as<double>();    cerr << "Posterior scaling factor (alpha) = " << mbr_scale << endl; -  ScoreType type = ScoreTypeFromString(metric);    vector<pair<vector<WordID>, prob_t> > list;    ReadFile rf(file);    string sent_id; @@ -99,15 +101,17 @@ int main(int argc, char** argv) {      vector<double> mbr_scores(output_list ? list.size() : 0);      double mbr_loss = numeric_limits<double>::max();      for (int i = 0 ; i < list.size(); ++i) { -      vector<vector<WordID> > refs(1, list[i].first); -      //cerr << i << ": " << list[i].second <<"\t" << TD::GetString(list[i].first) << endl; -      ScorerP scorer = SentenceScorer::CreateSentenceScorer(type, refs); +      const vector<vector<WordID> > refs(1, list[i].first); +      boost::shared_ptr<SegmentEvaluator> segeval = metric-> +          CreateSegmentEvaluator(refs); +        double wl_acc = 0;        for (int j = 0; j < list.size(); ++j) {          if (i != j) { -          ScoreP s = scorer->ScoreCandidate(list[j].first); -          double loss = 1.0 - s->ComputeScore(); -          if (type == TER || type == AER) loss = 1.0 - loss; +          SufficientStats ss; +          segeval->Evaluate(list[j].first, &ss); +          double loss = 1.0 - metric->ComputeScore(ss); +          if (is_loss) loss = 1.0 - loss;            double weighted_loss = loss * (joints[j] / marginal).as_float();            wl_acc += weighted_loss;            if ((!output_list) && wl_acc > mbr_loss) break; | 
