summaryrefslogtreecommitdiff
path: root/mteval/ns.cc
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2011-12-20 15:51:11 -0500
committerChris Dyer <cdyer@cs.cmu.edu>2011-12-20 15:51:11 -0500
commit2eb3bb96c6f780c477585b33273fc0c0d56c80e4 (patch)
treebe8fd2a5df3251ce8fa0a908edc0b40cc2c22e9c /mteval/ns.cc
parent0da1f6de1b33bbff5cb99b1938bb07d050479f10 (diff)
new scorer interface is implemented, but not used
Diffstat (limited to 'mteval/ns.cc')
-rw-r--r--mteval/ns.cc67
1 files changed, 53 insertions, 14 deletions
diff --git a/mteval/ns.cc b/mteval/ns.cc
index 1045a51f..6139757d 100644
--- a/mteval/ns.cc
+++ b/mteval/ns.cc
@@ -1,5 +1,7 @@
#include "ns.h"
#include "ns_ter.h"
+#include "ns_ext.h"
+#include "ns_comb.h"
#include <cassert>
#include <cmath>
@@ -7,6 +9,9 @@
#include <iostream>
#include <sstream>
+#include "tdict.h"
+#include "stringlib.h"
+
using namespace std;
using boost::shared_ptr;
@@ -19,6 +24,7 @@ struct DefaultSegmentEvaluator : public SegmentEvaluator {
DefaultSegmentEvaluator(const vector<vector<WordID> >& refs, const EvaluationMetric* em) : refs_(refs), em_(em) {}
void Evaluate(const vector<WordID>& hyp, SufficientStats* out) const {
em_->ComputeSufficientStatistics(hyp, refs_, out);
+ out->id_ = em_->MetricId();
}
const vector<vector<WordID> > refs_;
const EvaluationMetric* em_;
@@ -28,6 +34,11 @@ shared_ptr<SegmentEvaluator> EvaluationMetric::CreateSegmentEvaluator(const vect
return shared_ptr<SegmentEvaluator>(new DefaultSegmentEvaluator(refs, this));
}
+#define MAX_SS_VECTOR_SIZE 50
+unsigned EvaluationMetric::SufficientStatisticsVectorSize() const {
+ return MAX_SS_VECTOR_SIZE;
+}
+
void EvaluationMetric::ComputeSufficientStatistics(const vector<WordID>&,
const vector<vector<WordID> >&,
SufficientStats*) const {
@@ -35,6 +46,12 @@ void EvaluationMetric::ComputeSufficientStatistics(const vector<WordID>&,
abort();
}
+string EvaluationMetric::DetailedScore(const SufficientStats& stats) const {
+ ostringstream os;
+ os << MetricId() << "=" << ComputeScore(stats);
+ return os.str();
+}
+
enum BleuType { IBM, Koehn, NIST };
template <unsigned int N = 4u, BleuType BrevityType = IBM>
struct BleuSegmentEvaluator : public SegmentEvaluator {
@@ -57,7 +74,7 @@ struct BleuSegmentEvaluator : public SegmentEvaluator {
void Evaluate(const vector<WordID>& hyp, SufficientStats* out) const {
out->fields.resize(N + N + 2);
- out->evaluation_metric = evaluation_metric;
+ out->id_ = evaluation_metric->MetricId();
for (unsigned i = 0; i < N+N+2; ++i) out->fields[i] = 0;
ComputeNgramStats(hyp, &out->fields[0], &out->fields[N], true);
@@ -157,7 +174,12 @@ struct BleuSegmentEvaluator : public SegmentEvaluator {
template <unsigned int N = 4u, BleuType BrevityType = IBM>
struct BleuMetric : public EvaluationMetric {
BleuMetric() : EvaluationMetric("IBM_BLEU") {}
- float ComputeScore(const SufficientStats& stats) const {
+ unsigned SufficientStatisticsVectorSize() const { return N*2 + 2; }
+ shared_ptr<SegmentEvaluator> CreateSegmentEvaluator(const vector<vector<WordID> >& refs) const {
+ return shared_ptr<SegmentEvaluator>(new BleuSegmentEvaluator<N,BrevityType>(refs, this));
+ }
+ float ComputeBreakdown(const SufficientStats& stats, float* bp, vector<float>* out) const {
+ if (out) { out->clear(); }
float log_bleu = 0;
int count = 0;
for (int i = 0; i < N; ++i) {
@@ -166,7 +188,7 @@ struct BleuMetric : public EvaluationMetric {
// smooth bleu
if (!cor_count) { cor_count = 0.01; }
float lprec = log(cor_count) - log(stats.fields[i+N]); // log(hyp_ngram_counts[i]);
- // if (precs) precs->push_back(exp(lprec));
+ if (out) out->push_back(exp(lprec));
log_bleu += lprec;
++count;
}
@@ -178,32 +200,51 @@ struct BleuMetric : public EvaluationMetric {
if (hyp_len < ref_len)
lbp = (hyp_len - ref_len) / hyp_len;
log_bleu += lbp;
- //if (bp) *bp = exp(lbp);
+ if (bp) *bp = exp(lbp);
return exp(log_bleu);
}
- shared_ptr<SegmentEvaluator> CreateSegmentEvaluator(const vector<vector<WordID> >& refs) const {
- return shared_ptr<SegmentEvaluator>(new BleuSegmentEvaluator<N,BrevityType>(refs, this));
+ string DetailedScore(const SufficientStats& stats) const {
+ char buf[2000];
+ vector<float> precs(N);
+ float bp;
+ float bleu = ComputeBreakdown(stats, &bp, &precs);
+ sprintf(buf, "BLEU = %.2f, %.1f|%.1f|%.1f|%.1f (brev=%.3f)",
+ bleu*100.0,
+ precs[0]*100.0,
+ precs[1]*100.0,
+ precs[2]*100.0,
+ precs[3]*100.0,
+ bp);
+ return buf;
+ }
+ float ComputeScore(const SufficientStats& stats) const {
+ return ComputeBreakdown(stats, NULL, NULL);
}
};
-EvaluationMetric* EvaluationMetric::Instance(const string& metric_id) {
+EvaluationMetric* EvaluationMetric::Instance(const string& imetric_id) {
static bool is_first = true;
if (is_first) {
instances_["NULL"] = NULL;
is_first = false;
}
+ const string metric_id = UppercaseString(imetric_id);
map<string, EvaluationMetric*>::iterator it = instances_.find(metric_id);
if (it == instances_.end()) {
EvaluationMetric* m = NULL;
- if (metric_id == "IBM_BLEU") {
+ if (metric_id == "IBM_BLEU") {
m = new BleuMetric<4, IBM>;
} else if (metric_id == "NIST_BLEU") {
m = new BleuMetric<4, NIST>;
- } else if (metric_id == "Koehn_BLEU") {
+ } else if (metric_id == "KOEHN_BLEU") {
m = new BleuMetric<4, Koehn>;
} else if (metric_id == "TER") {
m = new TERMetric;
+ } else if (metric_id == "METEOR") {
+ m = new ExternalMetric("METEOR", "java -Xmx1536m -jar /Users/cdyer/software/meteor/meteor-1.3.jar - - -mira -lower -t tune -l en");
+ } else if (metric_id.find("COMB:") == 0) {
+ m = new CombinationMetric(metric_id);
} else {
cerr << "Implement please: " << metric_id << endl;
abort();
@@ -220,9 +261,7 @@ EvaluationMetric* EvaluationMetric::Instance(const string& metric_id) {
SufficientStats::SufficientStats(const string& encoded) {
istringstream is(encoded);
- string type;
- is >> type;
- evaluation_metric = EvaluationMetric::Instance(type);
+ is >> id_;
float val;
while(is >> val)
fields.push_back(val);
@@ -230,8 +269,8 @@ SufficientStats::SufficientStats(const string& encoded) {
void SufficientStats::Encode(string* out) const {
ostringstream os;
- if (evaluation_metric)
- os << evaluation_metric->MetricId();
+ if (id_.size() > 0)
+ os << id_;
else
os << "NULL";
for (unsigned i = 0; i < fields.size(); ++i)