From b5aaa40ef7c9d216ae905b76f50e68a8f94656c9 Mon Sep 17 00:00:00 2001 From: armatthews Date: Thu, 9 Apr 2015 00:46:45 -0400 Subject: added missed files --- mteval/wer.cc | 117 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ mteval/wer.h | 16 ++++++++ 2 files changed, 133 insertions(+) create mode 100644 mteval/wer.cc create mode 100644 mteval/wer.h diff --git a/mteval/wer.cc b/mteval/wer.cc new file mode 100644 index 00000000..c806b3be --- /dev/null +++ b/mteval/wer.cc @@ -0,0 +1,117 @@ +#include "wer.h" + +#include +#include +#include +#include +#include +#ifndef HAVE_OLD_CPP +# include +#else +# include +namespace std { using std::tr1::unordered_map; } +#endif +#include +#include +#include +#include +#include "tdict.h" +#include "levenshtein.h" + +using namespace std; + +class WERScore : public ScoreBase { + friend class WERScorer; + + public: + static const unsigned kEDITDISTANCE = 0; + static const unsigned kCHARCOUNT = 1; + static const unsigned kDUMMY_LAST_ENTRY = 2; + + WERScore() : stats(0,kDUMMY_LAST_ENTRY) {} + float ComputePartialScore() const { return 0.0;} + float ComputeScore() const { + return static_cast(stats[kEDITDISTANCE]) / static_cast(stats[kCHARCOUNT]); + } + void ScoreDetails(string* details) const; + void PlusPartialEquals(const Score& rhs, int oracle_e_cover, int oracle_f_cover, int src_len){} + void PlusEquals(const Score& delta, const float scale) { + if (scale==1) + stats += static_cast(delta).stats; + if (scale==-1) + stats -= static_cast(delta).stats; + throw std::runtime_error("WERScore::PlusEquals with scale != +-1"); + } + void PlusEquals(const Score& delta) { + stats += static_cast(delta).stats; + } + + ScoreP GetZero() const { + return ScoreP(new WERScore); + } + ScoreP GetOne() const { + return ScoreP(new WERScore); + } + void Subtract(const Score& rhs, Score* res) const { + static_cast(res)->stats = stats - static_cast(rhs).stats; + } + void Encode(std::string* out) const { + ostringstream os; + os << stats[kEDITDISTANCE] << ' ' + << stats[kCHARCOUNT]; + *out = os.str(); + } + bool IsAdditiveIdentity() const { + for (int i = 0; i < kDUMMY_LAST_ENTRY; ++i) + if (stats[i] != 0) return false; + return true; + } + private: + valarray stats; +}; + +ScoreP WERScorer::ScoreFromString(const std::string& data) { + istringstream is(data); + WERScore* r = new WERScore; + is >> r->stats[WERScore::kEDITDISTANCE] + >> r->stats[WERScore::kCHARCOUNT]; + return ScoreP(r); +} + +void WERScore::ScoreDetails(std::string* details) const { + char buf[200]; + sprintf(buf, "WER = %.2f, edits=%d, len=%d", + ComputeScore() * 100.0f, + stats[kEDITDISTANCE], + stats[kCHARCOUNT]); + *details = buf; +} + +WERScorer::~WERScorer() {} +WERScorer::WERScorer(const vector >& refs) {} + +ScoreP WERScorer::ScoreCCandidate(const vector& hyp) const { + return ScoreP(); +} + +float WERScorer::Calculate(const std::vector& hyp, const Sentence& ref, int& edits, int& char_count) const { + edits = cdec::LevenshteinDistance(hyp, ref); + char_count = ref.size(); + return static_cast(edits) / static_cast(char_count); +} + +ScoreP WERScorer::ScoreCandidate(const std::vector& hyp) const { + float best_score = numeric_limits::max(); + WERScore* res = new WERScore; + for (int i = 0; i < refs.size(); ++i) { + int edits, char_count; + const vector& ref = refs[i]; + float score = Calculate(hyp, ref, edits, char_count); + if (score < best_score) { + res->stats[WERScore::kEDITDISTANCE] = edits; + res->stats[WERScore::kCHARCOUNT] = char_count; + best_score = score; + } + } + return ScoreP(res); +} diff --git a/mteval/wer.h b/mteval/wer.h new file mode 100644 index 00000000..3a799f58 --- /dev/null +++ b/mteval/wer.h @@ -0,0 +1,16 @@ +#ifndef WER_H_ +#define WER_H_ + +#include "scorer.h" + +class WERScorer : public SentenceScorer { + public: + WERScorer(const std::vector >& references); + ~WERScorer(); + ScoreP ScoreCandidate(const std::vector& hyp) const; + ScoreP ScoreCCandidate(const std::vector& hyp) const; + static ScoreP ScoreFromString(const std::string& data); + float Calculate(const std::vector& hyp, const Sentence& ref, int& edits, int& char_count) const; +}; + +#endif -- cgit v1.2.3