summaryrefslogtreecommitdiff
path: root/mteval
diff options
context:
space:
mode:
Diffstat (limited to 'mteval')
-rw-r--r--mteval/Jamfile2
-rw-r--r--mteval/Makefile.am2
-rw-r--r--mteval/ns.cc3
-rw-r--r--mteval/ns_cer.cc55
-rw-r--r--mteval/ns_cer.h23
5 files changed, 83 insertions, 2 deletions
diff --git a/mteval/Jamfile b/mteval/Jamfile
index 6260caea..3ed2c2cc 100644
--- a/mteval/Jamfile
+++ b/mteval/Jamfile
@@ -1,6 +1,6 @@
import testing ;
-lib mteval : ter.cc comb_scorer.cc aer_scorer.cc scorer.cc external_scorer.cc ns.cc ns_ter.cc ns_ext.cc ns_comb.cc ns_docscorer.cc ..//utils : <include>. : : <include>. <library>..//z ;
+lib mteval : ter.cc comb_scorer.cc aer_scorer.cc scorer.cc external_scorer.cc ns.cc ns_ter.cc ns_ext.cc ns_comb.cc ns_docscorer.cc ns_cer.cc ..//utils : <include>. : : <include>. <library>..//z ;
exe fast_score : fast_score.cc mteval ..//utils ..//boost_program_options ;
exe mbr_kbest : mbr_kbest.cc mteval ..//utils ..//boost_program_options ;
alias programs : fast_score mbr_kbest ;
diff --git a/mteval/Makefile.am b/mteval/Makefile.am
index 8d844e24..22550c99 100644
--- a/mteval/Makefile.am
+++ b/mteval/Makefile.am
@@ -8,7 +8,7 @@ TESTS = scorer_test
noinst_LIBRARIES = libmteval.a
-libmteval_a_SOURCES = ter.cc comb_scorer.cc aer_scorer.cc scorer.cc external_scorer.cc ns.cc ns_ter.cc ns_ext.cc ns_comb.cc ns_docscorer.cc
+libmteval_a_SOURCES = ter.cc comb_scorer.cc aer_scorer.cc scorer.cc external_scorer.cc ns.cc ns_ter.cc ns_ext.cc ns_comb.cc ns_docscorer.cc ns_cer.cc
fast_score_SOURCES = fast_score.cc
fast_score_LDADD = libmteval.a $(top_srcdir)/utils/libutils.a -lz
diff --git a/mteval/ns.cc b/mteval/ns.cc
index 8d354677..33952da7 100644
--- a/mteval/ns.cc
+++ b/mteval/ns.cc
@@ -2,6 +2,7 @@
#include "ns_ter.h"
#include "ns_ext.h"
#include "ns_comb.h"
+#include "ns_cer.h"
#include <cstdio>
#include <cassert>
@@ -254,6 +255,8 @@ EvaluationMetric* EvaluationMetric::Instance(const string& imetric_id) {
m = new ExternalMetric("METEOR", "java -Xmx1536m -jar /Users/cdyer/software/meteor/meteor-1.3.jar - - -mira -lower -t tune -l en");
} else if (metric_id.find("COMB:") == 0) {
m = new CombinationMetric(metric_id);
+ } else if (metric_id == "CER") {
+ m = new CERMetric;
} else {
cerr << "Implement please: " << metric_id << endl;
abort();
diff --git a/mteval/ns_cer.cc b/mteval/ns_cer.cc
new file mode 100644
index 00000000..a843d471
--- /dev/null
+++ b/mteval/ns_cer.cc
@@ -0,0 +1,55 @@
+#include "ns_cer.h"
+#include "tdict.h"
+
+static const unsigned kNUMFIELDS = 2;
+static const unsigned kEDITDISTANCE = 0;
+static const unsigned kCHARCOUNT = 1;
+
+bool CERMetric::IsErrorMetric() const {
+ return true;
+}
+
+unsigned CERMetric::SufficientStatisticsVectorSize() const {
+ return 2;
+}
+
+unsigned CERMetric::EditDistance(const std::string& hyp,
+ const std::string& ref) const {
+ const unsigned m = hyp.size(), n = ref.size();
+ std::vector<unsigned> edit((m + 1) * 2);
+ for(unsigned i = 0; i < n + 1; i++) {
+ for(unsigned j = 0; j < m + 1; j++) {
+ if(i == 0)
+ edit[j] = j;
+ else if(j == 0)
+ edit[(i%2)*(m+1)] = i;
+ else
+ edit[(i%2)*(m+1) + j] = std::min(std::min(edit[(i%2)*(m+1) + j-1] + 1,
+ edit[((i-1)%2)*(m+1) + j] + 1),
+ edit[((i-1)%2)*(m+1) + (j-1)]
+ + (hyp[j-1] == ref[i-1] ? 0 : 1));
+
+ }
+ }
+ return edit[(n%2)*(m+1) + m];
+}
+
+void CERMetric::ComputeSufficientStatistics(const std::vector<WordID>& hyp,
+ const std::vector<std::vector<WordID> >& refs,
+ SufficientStats* out) const {
+ out->fields.resize(kNUMFIELDS);
+ std::string hyp_str(TD::GetString(hyp));
+ float best_score = hyp_str.size();
+ for (size_t i = 0; i < refs.size(); ++i) {
+ std::string ref_str(TD::GetString(refs[i]));
+ float score = EditDistance(hyp_str, ref_str);
+ if (score < best_score) {
+ out->fields[kEDITDISTANCE] = score;
+ out->fields[kCHARCOUNT] = ref_str.size();
+ best_score = score;
+ }
+ }
+}
+float CERMetric::ComputeScore(const SufficientStats& stats) const {
+ return stats.fields[kEDITDISTANCE] / stats.fields[kCHARCOUNT];
+}
diff --git a/mteval/ns_cer.h b/mteval/ns_cer.h
new file mode 100644
index 00000000..9d211181
--- /dev/null
+++ b/mteval/ns_cer.h
@@ -0,0 +1,23 @@
+#ifndef _NS_CER_H_
+#define _NS_CER_H_
+
+#include "ns.h"
+
+class CERMetric : public EvaluationMetric {
+ friend class EvaluationMetric;
+ private:
+ unsigned EditDistance(const std::string& hyp,
+ const std::string& ref) const;
+ protected:
+ CERMetric() : EvaluationMetric("CER") {}
+
+ public:
+ virtual bool IsErrorMetric() const;
+ virtual unsigned SufficientStatisticsVectorSize() const;
+ virtual void ComputeSufficientStatistics(const std::vector<WordID>& hyp,
+ const std::vector<std::vector<WordID> >& refs,
+ SufficientStats* out) const;
+ virtual float ComputeScore(const SufficientStats& stats) const;
+};
+
+#endif