summary | refs | log | tree | commit | diff
path: root/mteval/mbr_kbest.cc
diff options
context:
space:
mode:
author: Patrick Simianer <p@simianer.de> 2012-03-13 09:24:47 +0100
committer: Patrick Simianer <p@simianer.de> 2012-03-13 09:24:47 +0100
commit: c3a9ea64251605532c7954959662643a6a927bb7 (patch)
tree: fed6048a5acdaf3834740107771c2bc48f26fd4d /mteval/mbr_kbest.cc
parent: 867bca3e5fa0cdd63bf032e5859fb5092d9a4ca1 (diff)
parent: a45af4a3704531a8382cd231f6445b3a33b598a3 (diff)
merge with upstream
Diffstat (limited to 'mteval/mbr_kbest.cc')
-rw-r--r-- mteval/mbr_kbest.cc | 24
1 files changed, 14 insertions, 10 deletions
diff --git a/mteval/mbr_kbest.cc b/mteval/mbr_kbest.cc
index 64a6a8bf..2bd31566 100644
--- a/mteval/mbr_kbest.cc
+++ b/mteval/mbr_kbest.cc
@@ -5,7 +5,7 @@
#include "prob.h"
#include "tdict.h"
-#include "scorer.h"
+#include "ns.h"
#include "filelib.h"
#include "stringlib.h"
@@ -17,7 +17,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
po::options_description opts("Configuration options");
opts.add_options()
("scale,a",po::value<double>()->default_value(1.0), "Posterior scaling factor (alpha)")
- ("loss_function,l",po::value<string>()->default_value("bleu"), "Loss function")
+ ("evaluation_metric,m",po::value<string>()->default_value("ibm_bleu"), "Evaluation metric")
("input,i",po::value<string>()->default_value("-"), "File to read k-best lists from")
("output_list,L", "Show reranked list as output")
("help,h", "Help");
@@ -75,13 +75,15 @@ bool ReadKBestList(istream* in, string* sent_id, vector<pair<vector<WordID>, pro
int main(int argc, char** argv) {
po::variables_map conf;
InitCommandLine(argc, argv, &conf);
- const string metric = conf["loss_function"].as<string>();
+ const string smetric = conf["evaluation_metric"].as<string>();
+ EvaluationMetric* metric = EvaluationMetric::Instance(smetric);
+
+ const bool is_loss = (UppercaseString(smetric) == "TER");
const bool output_list = conf.count("output_list") > 0;
const string file = conf["input"].as<string>();
const double mbr_scale = conf["scale"].as<double>();
cerr << "Posterior scaling factor (alpha) = " << mbr_scale << endl;
- ScoreType type = ScoreTypeFromString(metric);
vector<pair<vector<WordID>, prob_t> > list;
ReadFile rf(file);
string sent_id;
@@ -99,15 +101,17 @@ int main(int argc, char** argv) {
vector<double> mbr_scores(output_list ? list.size() : 0);
double mbr_loss = numeric_limits<double>::max();
for (int i = 0 ; i < list.size(); ++i) {
- vector<vector<WordID> > refs(1, list[i].first);
- //cerr << i << ": " << list[i].second <<"\t" << TD::GetString(list[i].first) << endl;
- ScorerP scorer = SentenceScorer::CreateSentenceScorer(type, refs);
+ const vector<vector<WordID> > refs(1, list[i].first);
+ boost::shared_ptr<SegmentEvaluator> segeval = metric->
+ CreateSegmentEvaluator(refs);
+
double wl_acc = 0;
for (int j = 0; j < list.size(); ++j) {
if (i != j) {
- ScoreP s = scorer->ScoreCandidate(list[j].first);
- double loss = 1.0 - s->ComputeScore();
- if (type == TER || type == AER) loss = 1.0 - loss;
+ SufficientStats ss;
+ segeval->Evaluate(list[j].first, &ss);
+ double loss = 1.0 - metric->ComputeScore(ss);
+ if (is_loss) loss = 1.0 - loss;
double weighted_loss = loss * (joints[j] / marginal).as_float();
wl_acc += weighted_loss;
if ((!output_list) && wl_acc > mbr_loss) break;