summaryrefslogtreecommitdiff
path: root/mteval
diff options
context:
space:
mode:
Diffstat (limited to 'mteval')
-rw-r--r--mteval/Makefile.am3
-rw-r--r--mteval/levenshtein.h29
-rw-r--r--mteval/ns.cc3
-rw-r--r--mteval/ns_cer.cc26
-rw-r--r--mteval/ns_cer.h3
-rw-r--r--mteval/ns_wer.cc35
-rw-r--r--mteval/ns_wer.h20
7 files changed, 94 insertions, 25 deletions
diff --git a/mteval/Makefile.am b/mteval/Makefile.am
index c833eb01..aac3e6b5 100644
--- a/mteval/Makefile.am
+++ b/mteval/Makefile.am
@@ -14,6 +14,7 @@ libmteval_a_SOURCES = \
aer_scorer.h \
comb_scorer.h \
external_scorer.h \
+ levenshtein.h \
ns.h \
ns_cer.h \
ns_comb.h \
@@ -21,6 +22,7 @@ libmteval_a_SOURCES = \
ns_ext.h \
ns_ssk.h \
ns_ter.h \
+ ns_wer.h \
scorer.h \
ter.h \
aer_scorer.cc \
@@ -34,6 +36,7 @@ libmteval_a_SOURCES = \
ns_ext.cc \
ns_ssk.cc \
ns_ter.cc \
+ ns_wer.cc \
scorer.cc \
ter.cc
diff --git a/mteval/levenshtein.h b/mteval/levenshtein.h
new file mode 100644
index 00000000..13a97047
--- /dev/null
+++ b/mteval/levenshtein.h
@@ -0,0 +1,29 @@
+#ifndef _LEVENSHTEIN_H_
+#define _LEVENSHTEIN_H_
+
+namespace cdec {
+
+template <typename V>
+inline unsigned LevenshteinDistance(const V& a, const V& b) {
+ const unsigned m = a.size(), n = b.size();
+ std::vector<unsigned> edit((m + 1) * 2);
+ for (unsigned i = 0; i <= n; i++) {
+ for (unsigned j = 0; j <= m; j++) {
+ if (i == 0)
+ edit[j] = j;
+ else if (j == 0)
+ edit[(i % 2) * (m + 1)] = i;
+ else
+ edit[(i % 2) * (m + 1) + j] = std::min(std::min(
+ edit[(i % 2) * (m + 1) + j - 1] + 1,
+ edit[((i - 1) % 2) * (m + 1) + j] + 1),
+ edit[((i - 1) % 2) * (m + 1) + (j - 1)]
+ + (a[j - 1] == b[i - 1] ? 0 : 1));
+ }
+ }
+ return edit[(n % 2) * (m + 1) + m];
+}
+
+}
+
+#endif
diff --git a/mteval/ns.cc b/mteval/ns.cc
index c1ea238b..075e0121 100644
--- a/mteval/ns.cc
+++ b/mteval/ns.cc
@@ -3,6 +3,7 @@
#include "ns_ext.h"
#include "ns_comb.h"
#include "ns_cer.h"
+#include "ns_wer.h"
#include "ns_ssk.h"
#include <cstdio>
@@ -285,6 +286,8 @@ EvaluationMetric* EvaluationMetric::Instance(const string& imetric_id) {
m = new CombinationMetric(metric_id);
} else if (metric_id == "CER") {
m = new CERMetric;
+ } else if (metric_id == "WER") {
+ m = new WERMetric;
} else {
cerr << "Implement please: " << metric_id << endl;
abort();
diff --git a/mteval/ns_cer.cc b/mteval/ns_cer.cc
index a843d471..da6683b1 100644
--- a/mteval/ns_cer.cc
+++ b/mteval/ns_cer.cc
@@ -1,5 +1,6 @@
#include "ns_cer.h"
#include "tdict.h"
+#include "levenshtein.h"
static const unsigned kNUMFIELDS = 2;
static const unsigned kEDITDISTANCE = 0;
@@ -13,27 +14,6 @@ unsigned CERMetric::SufficientStatisticsVectorSize() const {
return 2;
}
-unsigned CERMetric::EditDistance(const std::string& hyp,
- const std::string& ref) const {
- const unsigned m = hyp.size(), n = ref.size();
- std::vector<unsigned> edit((m + 1) * 2);
- for(unsigned i = 0; i < n + 1; i++) {
- for(unsigned j = 0; j < m + 1; j++) {
- if(i == 0)
- edit[j] = j;
- else if(j == 0)
- edit[(i%2)*(m+1)] = i;
- else
- edit[(i%2)*(m+1) + j] = std::min(std::min(edit[(i%2)*(m+1) + j-1] + 1,
- edit[((i-1)%2)*(m+1) + j] + 1),
- edit[((i-1)%2)*(m+1) + (j-1)]
- + (hyp[j-1] == ref[i-1] ? 0 : 1));
-
- }
- }
- return edit[(n%2)*(m+1) + m];
-}
-
void CERMetric::ComputeSufficientStatistics(const std::vector<WordID>& hyp,
const std::vector<std::vector<WordID> >& refs,
SufficientStats* out) const {
@@ -42,7 +22,7 @@ void CERMetric::ComputeSufficientStatistics(const std::vector<WordID>& hyp,
float best_score = hyp_str.size();
for (size_t i = 0; i < refs.size(); ++i) {
std::string ref_str(TD::GetString(refs[i]));
- float score = EditDistance(hyp_str, ref_str);
+ float score = cdec::LevenshteinDistance(hyp_str, ref_str);
if (score < best_score) {
out->fields[kEDITDISTANCE] = score;
out->fields[kCHARCOUNT] = ref_str.size();
@@ -50,6 +30,8 @@ void CERMetric::ComputeSufficientStatistics(const std::vector<WordID>& hyp,
}
}
}
+
float CERMetric::ComputeScore(const SufficientStats& stats) const {
return stats.fields[kEDITDISTANCE] / stats.fields[kCHARCOUNT];
}
+
diff --git a/mteval/ns_cer.h b/mteval/ns_cer.h
index 9d211181..cb2b4b4a 100644
--- a/mteval/ns_cer.h
+++ b/mteval/ns_cer.h
@@ -5,9 +5,6 @@
class CERMetric : public EvaluationMetric {
friend class EvaluationMetric;
- private:
- unsigned EditDistance(const std::string& hyp,
- const std::string& ref) const;
protected:
CERMetric() : EvaluationMetric("CER") {}
diff --git a/mteval/ns_wer.cc b/mteval/ns_wer.cc
new file mode 100644
index 00000000..f9b2bbbb
--- /dev/null
+++ b/mteval/ns_wer.cc
@@ -0,0 +1,35 @@
+#include "ns_wer.h"
+#include "tdict.h"
+#include "levenshtein.h"
+
+static const unsigned kNUMFIELDS = 2;
+static const unsigned kEDITDISTANCE = 0;
+static const unsigned kCHARCOUNT = 1;
+
+bool WERMetric::IsErrorMetric() const {
+ return true;
+}
+
+unsigned WERMetric::SufficientStatisticsVectorSize() const {
+ return 2;
+}
+
+void WERMetric::ComputeSufficientStatistics(const std::vector<WordID>& hyp,
+ const std::vector<std::vector<WordID> >& refs,
+ SufficientStats* out) const {
+ out->fields.resize(kNUMFIELDS);
+ float best_score = hyp.size();
+ for (size_t i = 0; i < refs.size(); ++i) {
+ float score = cdec::LevenshteinDistance(hyp, refs[i]);
+ if (score < best_score) {
+ out->fields[kEDITDISTANCE] = score;
+ out->fields[kCHARCOUNT] = refs[i].size();
+ best_score = score;
+ }
+ }
+}
+
+float WERMetric::ComputeScore(const SufficientStats& stats) const {
+ return stats.fields[kEDITDISTANCE] / stats.fields[kCHARCOUNT];
+}
+
diff --git a/mteval/ns_wer.h b/mteval/ns_wer.h
new file mode 100644
index 00000000..24c85d83
--- /dev/null
+++ b/mteval/ns_wer.h
@@ -0,0 +1,20 @@
+#ifndef _NS_WER_H_
+#define _NS_WER_H_
+
+#include "ns.h"
+
+class WERMetric : public EvaluationMetric {
+ friend class EvaluationMetric;
+ protected:
+ WERMetric() : EvaluationMetric("WER") {}
+
+ public:
+ virtual bool IsErrorMetric() const;
+ virtual unsigned SufficientStatisticsVectorSize() const;
+ virtual void ComputeSufficientStatistics(const std::vector<WordID>& hyp,
+ const std::vector<std::vector<WordID> >& refs,
+ SufficientStats* out) const;
+ virtual float ComputeScore(const SufficientStats& stats) const;
+};
+
+#endif