summaryrefslogtreecommitdiff
path: root/mteval/ns.cc
diff options
context:
space:
mode:
authorPatrick Simianer <p@simianer.de>2012-03-13 09:24:47 +0100
committerPatrick Simianer <p@simianer.de>2012-03-13 09:24:47 +0100
commitc3a9ea64251605532c7954959662643a6a927bb7 (patch)
treefed6048a5acdaf3834740107771c2bc48f26fd4d /mteval/ns.cc
parent867bca3e5fa0cdd63bf032e5859fb5092d9a4ca1 (diff)
parenta45af4a3704531a8382cd231f6445b3a33b598a3 (diff)
merge with upstream
Diffstat (limited to 'mteval/ns.cc')
-rw-r--r--mteval/ns.cc290
1 files changed, 290 insertions, 0 deletions
diff --git a/mteval/ns.cc b/mteval/ns.cc
new file mode 100644
index 00000000..788f809a
--- /dev/null
+++ b/mteval/ns.cc
@@ -0,0 +1,290 @@
+#include "ns.h"
+#include "ns_ter.h"
+#include "ns_ext.h"
+#include "ns_comb.h"
+
+#include <cstdio>
+#include <cassert>
+#include <cmath>
+#include <cstdlib>
+#include <iostream>
+#include <sstream>
+
+#include "tdict.h"
+#include "stringlib.h"
+
+using namespace std;
+using boost::shared_ptr;
+
+map<string, EvaluationMetric*> EvaluationMetric::instances_;
+
+SegmentEvaluator::~SegmentEvaluator() {}
+EvaluationMetric::~EvaluationMetric() {}
+
+bool EvaluationMetric::IsErrorMetric() const {
+ return false;
+}
+
+struct DefaultSegmentEvaluator : public SegmentEvaluator {
+ DefaultSegmentEvaluator(const vector<vector<WordID> >& refs, const EvaluationMetric* em) : refs_(refs), em_(em) {}
+ void Evaluate(const vector<WordID>& hyp, SufficientStats* out) const {
+ em_->ComputeSufficientStatistics(hyp, refs_, out);
+ out->id_ = em_->MetricId();
+ }
+ const vector<vector<WordID> > refs_;
+ const EvaluationMetric* em_;
+};
+
+shared_ptr<SegmentEvaluator> EvaluationMetric::CreateSegmentEvaluator(const vector<vector<WordID> >& refs) const {
+ return shared_ptr<SegmentEvaluator>(new DefaultSegmentEvaluator(refs, this));
+}
+
+#define MAX_SS_VECTOR_SIZE 50
+unsigned EvaluationMetric::SufficientStatisticsVectorSize() const {
+ return MAX_SS_VECTOR_SIZE;
+}
+
+void EvaluationMetric::ComputeSufficientStatistics(const vector<WordID>&,
+ const vector<vector<WordID> >&,
+ SufficientStats*) const {
+ cerr << "Base class ComputeSufficientStatistics should not be called.\n";
+ abort();
+}
+
+string EvaluationMetric::DetailedScore(const SufficientStats& stats) const {
+ ostringstream os;
+ os << MetricId() << "=" << ComputeScore(stats);
+ return os.str();
+}
+
+enum BleuType { IBM, Koehn, NIST };
+template <unsigned int N = 4u, BleuType BrevityType = IBM>
+struct BleuSegmentEvaluator : public SegmentEvaluator {
+ BleuSegmentEvaluator(const vector<vector<WordID> >& refs, const EvaluationMetric* em) : evaluation_metric(em) {
+ assert(refs.size() > 0);
+ float tot = 0;
+ int smallest = 9999999;
+ for (vector<vector<WordID> >::const_iterator ci = refs.begin();
+ ci != refs.end(); ++ci) {
+ lengths_.push_back(ci->size());
+ tot += lengths_.back();
+ if (lengths_.back() < smallest) smallest = lengths_.back();
+ CountRef(*ci);
+ }
+ if (BrevityType == Koehn)
+ lengths_[0] = tot / refs.size();
+ if (BrevityType == NIST)
+ lengths_[0] = smallest;
+ }
+
+ void Evaluate(const vector<WordID>& hyp, SufficientStats* out) const {
+ out->fields.resize(N + N + 2);
+ out->id_ = evaluation_metric->MetricId();
+ for (unsigned i = 0; i < N+N+2; ++i) out->fields[i] = 0;
+
+ ComputeNgramStats(hyp, &out->fields[0], &out->fields[N], true);
+ float& hyp_len = out->fields[2*N];
+ float& ref_len = out->fields[2*N + 1];
+ hyp_len = hyp.size();
+ ref_len = lengths_[0];
+ if (lengths_.size() > 1 && BrevityType == IBM) {
+ float bestd = 2000000;
+ float hl = hyp.size();
+ float bl = -1;
+ for (vector<float>::const_iterator ci = lengths_.begin(); ci != lengths_.end(); ++ci) {
+ if (fabs(*ci - hl) < bestd) {
+ bestd = fabs(*ci - hl);
+ bl = *ci;
+ }
+ }
+ ref_len = bl;
+ }
+ }
+
+ struct NGramCompare {
+ int operator() (const vector<WordID>& a, const vector<WordID>& b) {
+ const size_t as = a.size();
+ const size_t bs = b.size();
+ const size_t s = (as < bs ? as : bs);
+ for (size_t i = 0; i < s; ++i) {
+ int d = a[i] - b[i];
+ if (d < 0) return true;
+ if (d > 0) return false;
+ }
+ return as < bs;
+ }
+ };
+ typedef map<vector<WordID>, pair<int,int>, NGramCompare> NGramCountMap;
+
+ void CountRef(const vector<WordID>& ref) {
+ NGramCountMap tc;
+ vector<WordID> ngram(N);
+ int s = ref.size();
+ for (int j=0; j<s; ++j) {
+ int remaining = s-j;
+ int k = (N < remaining ? N : remaining);
+ ngram.clear();
+ for (int i=1; i<=k; ++i) {
+ ngram.push_back(ref[j + i - 1]);
+ tc[ngram].first++;
+ }
+ }
+ for (typename NGramCountMap::iterator i = tc.begin(); i != tc.end(); ++i) {
+ pair<int,int>& p = ngrams_[i->first];
+ if (p.first < i->second.first)
+ p = i->second;
+ }
+ }
+
+ void ComputeNgramStats(const vector<WordID>& sent,
+ float* correct, // N elements reserved
+ float* hyp, // N elements reserved
+ bool clip_counts = true) const {
+ // clear clipping stats
+ for (typename NGramCountMap::iterator it = ngrams_.begin(); it != ngrams_.end(); ++it)
+ it->second.second = 0;
+
+ vector<WordID> ngram(N);
+ *correct *= 0;
+ *hyp *= 0;
+ int s = sent.size();
+ for (int j=0; j<s; ++j) {
+ int remaining = s-j;
+ int k = (N < remaining ? N : remaining);
+ ngram.clear();
+ for (int i=1; i<=k; ++i) {
+ ngram.push_back(sent[j + i - 1]);
+ pair<int,int>& p = ngrams_[ngram];
+ if(clip_counts){
+ if (p.second < p.first) {
+ ++p.second;
+ correct[i-1]++;
+ }
+ } else {
+ ++p.second;
+ correct[i-1]++;
+ }
+ // if the 1 gram isn't found, don't try to match don't need to match any 2- 3- .. grams:
+ if (!p.first) {
+ for (; i<=k; ++i)
+ hyp[i-1]++;
+ } else {
+ hyp[i-1]++;
+ }
+ }
+ }
+ }
+
+ const EvaluationMetric* evaluation_metric;
+ vector<float> lengths_;
+ mutable NGramCountMap ngrams_;
+};
+
+template <unsigned int N = 4u, BleuType BrevityType = IBM>
+struct BleuMetric : public EvaluationMetric {
+ BleuMetric() : EvaluationMetric(BrevityType == IBM ? "IBM_BLEU" : (BrevityType == Koehn ? "KOEHN_BLEU" : "NIST_BLEU")) {}
+ unsigned SufficientStatisticsVectorSize() const { return N*2 + 2; }
+ shared_ptr<SegmentEvaluator> CreateSegmentEvaluator(const vector<vector<WordID> >& refs) const {
+ return shared_ptr<SegmentEvaluator>(new BleuSegmentEvaluator<N,BrevityType>(refs, this));
+ }
+ float ComputeBreakdown(const SufficientStats& stats, float* bp, vector<float>* out) const {
+ if (out) { out->clear(); }
+ float log_bleu = 0;
+ int count = 0;
+ for (int i = 0; i < N; ++i) {
+ if (stats.fields[i+N] > 0) {
+ float cor_count = stats.fields[i]; // correct_ngram_hit_counts[i];
+ // smooth bleu
+ if (!cor_count) { cor_count = 0.01; }
+ float lprec = log(cor_count) - log(stats.fields[i+N]); // log(hyp_ngram_counts[i]);
+ if (out) out->push_back(exp(lprec));
+ log_bleu += lprec;
+ ++count;
+ }
+ }
+ log_bleu /= count;
+ float lbp = 0.0;
+ const float& hyp_len = stats.fields[2*N];
+ const float& ref_len = stats.fields[2*N + 1];
+ if (hyp_len < ref_len)
+ lbp = (hyp_len - ref_len) / hyp_len;
+ log_bleu += lbp;
+ if (bp) *bp = exp(lbp);
+ return exp(log_bleu);
+ }
+ string DetailedScore(const SufficientStats& stats) const {
+ char buf[2000];
+ vector<float> precs(N);
+ float bp;
+ float bleu = ComputeBreakdown(stats, &bp, &precs);
+ sprintf(buf, "%s = %.2f, %.1f|%.1f|%.1f|%.1f (brev=%.3f)",
+ MetricId().c_str(),
+ bleu*100.0,
+ precs[0]*100.0,
+ precs[1]*100.0,
+ precs[2]*100.0,
+ precs[3]*100.0,
+ bp);
+ return buf;
+ }
+ float ComputeScore(const SufficientStats& stats) const {
+ return ComputeBreakdown(stats, NULL, NULL);
+ }
+};
+
+EvaluationMetric* EvaluationMetric::Instance(const string& imetric_id) {
+ static bool is_first = true;
+ if (is_first) {
+ instances_["NULL"] = NULL;
+ is_first = false;
+ }
+ const string metric_id = UppercaseString(imetric_id);
+
+ map<string, EvaluationMetric*>::iterator it = instances_.find(metric_id);
+ if (it == instances_.end()) {
+ EvaluationMetric* m = NULL;
+ if (metric_id == "IBM_BLEU") {
+ m = new BleuMetric<4, IBM>;
+ } else if (metric_id == "NIST_BLEU") {
+ m = new BleuMetric<4, NIST>;
+ } else if (metric_id == "KOEHN_BLEU") {
+ m = new BleuMetric<4, Koehn>;
+ } else if (metric_id == "TER") {
+ m = new TERMetric;
+ } else if (metric_id == "METEOR") {
+ m = new ExternalMetric("METEOR", "java -Xmx1536m -jar /Users/cdyer/software/meteor/meteor-1.3.jar - - -mira -lower -t tune -l en");
+ } else if (metric_id.find("COMB:") == 0) {
+ m = new CombinationMetric(metric_id);
+ } else {
+ cerr << "Implement please: " << metric_id << endl;
+ abort();
+ }
+ if (m->MetricId() != metric_id) {
+ cerr << "Registry error: " << metric_id << " vs. " << m->MetricId() << endl;
+ abort();
+ }
+ return instances_[metric_id] = m;
+ } else {
+ return it->second;
+ }
+}
+
+SufficientStats::SufficientStats(const string& encoded) {
+ istringstream is(encoded);
+ is >> id_;
+ float val;
+ while(is >> val)
+ fields.push_back(val);
+}
+
+void SufficientStats::Encode(string* out) const {
+ ostringstream os;
+ if (id_.size() > 0)
+ os << id_;
+ else
+ os << "NULL";
+ for (unsigned i = 0; i < fields.size(); ++i)
+ os << ' ' << fields[i];
+ *out = os.str();
+}
+