summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mteval/mbr_kbest.cc5
1 files changed, 4 insertions, 1 deletions
diff --git a/mteval/mbr_kbest.cc b/mteval/mbr_kbest.cc
index b5e4750c..2bd31566 100644
--- a/mteval/mbr_kbest.cc
+++ b/mteval/mbr_kbest.cc
@@ -77,6 +77,7 @@ int main(int argc, char** argv) {
InitCommandLine(argc, argv, &conf);
const string smetric = conf["evaluation_metric"].as<string>();
EvaluationMetric* metric = EvaluationMetric::Instance(smetric);
+
const bool is_loss = (UppercaseString(smetric) == "TER");
const bool output_list = conf.count("output_list") > 0;
const string file = conf["input"].as<string>();
@@ -101,12 +102,14 @@ int main(int argc, char** argv) {
double mbr_loss = numeric_limits<double>::max();
for (int i = 0 ; i < list.size(); ++i) {
const vector<vector<WordID> > refs(1, list[i].first);
+ boost::shared_ptr<SegmentEvaluator> segeval = metric->
+ CreateSegmentEvaluator(refs);
double wl_acc = 0;
for (int j = 0; j < list.size(); ++j) {
if (i != j) {
SufficientStats ss;
- metric->ComputeSufficientStatistics(list[j].first, refs, &ss);
+ segeval->Evaluate(list[j].first, &ss);
double loss = 1.0 - metric->ComputeScore(ss);
if (is_loss) loss = 1.0 - loss;
double weighted_loss = loss * (joints[j] / marginal).as_float();