diff options
Diffstat (limited to 'mteval')
| -rw-r--r-- | mteval/Jamfile | 2 | ||||
| -rw-r--r-- | mteval/Makefile.am | 2 | ||||
| -rw-r--r-- | mteval/ns.cc | 3 | ||||
| -rw-r--r-- | mteval/ns_cer.cc | 55 | ||||
| -rw-r--r-- | mteval/ns_cer.h | 23 | 
5 files changed, 83 insertions, 2 deletions
| diff --git a/mteval/Jamfile b/mteval/Jamfile index 6260caea..3ed2c2cc 100644 --- a/mteval/Jamfile +++ b/mteval/Jamfile @@ -1,6 +1,6 @@  import testing ; -lib mteval : ter.cc comb_scorer.cc aer_scorer.cc scorer.cc external_scorer.cc ns.cc ns_ter.cc ns_ext.cc ns_comb.cc ns_docscorer.cc ..//utils : <include>. : : <include>. <library>..//z ; +lib mteval : ter.cc comb_scorer.cc aer_scorer.cc scorer.cc external_scorer.cc ns.cc ns_ter.cc ns_ext.cc ns_comb.cc ns_docscorer.cc ns_cer.cc ..//utils : <include>. : : <include>. <library>..//z ;  exe fast_score : fast_score.cc mteval ..//utils ..//boost_program_options ;  exe mbr_kbest : mbr_kbest.cc mteval ..//utils ..//boost_program_options ;  alias programs : fast_score mbr_kbest ; diff --git a/mteval/Makefile.am b/mteval/Makefile.am index 8d844e24..22550c99 100644 --- a/mteval/Makefile.am +++ b/mteval/Makefile.am @@ -8,7 +8,7 @@ TESTS = scorer_test  noinst_LIBRARIES = libmteval.a -libmteval_a_SOURCES = ter.cc comb_scorer.cc aer_scorer.cc scorer.cc external_scorer.cc ns.cc ns_ter.cc ns_ext.cc ns_comb.cc ns_docscorer.cc +libmteval_a_SOURCES = ter.cc comb_scorer.cc aer_scorer.cc scorer.cc external_scorer.cc ns.cc ns_ter.cc ns_ext.cc ns_comb.cc ns_docscorer.cc ns_cer.cc  fast_score_SOURCES = fast_score.cc  fast_score_LDADD = libmteval.a $(top_srcdir)/utils/libutils.a -lz diff --git a/mteval/ns.cc b/mteval/ns.cc index 8d354677..33952da7 100644 --- a/mteval/ns.cc +++ b/mteval/ns.cc @@ -2,6 +2,7 @@  #include "ns_ter.h"  #include "ns_ext.h"  #include "ns_comb.h" +#include "ns_cer.h"  #include <cstdio>  #include <cassert> @@ -254,6 +255,8 @@ EvaluationMetric* EvaluationMetric::Instance(const string& imetric_id) {        m = new ExternalMetric("METEOR", "java -Xmx1536m -jar /Users/cdyer/software/meteor/meteor-1.3.jar - - -mira -lower -t tune -l en");      } else if (metric_id.find("COMB:") == 0) {        m = new CombinationMetric(metric_id); +    } else if (metric_id == "CER") { +      m = new CERMetric;      } else {        cerr << "Implement please: " << metric_id << endl;        abort(); diff --git a/mteval/ns_cer.cc b/mteval/ns_cer.cc new file mode 100644 index 00000000..a843d471 --- /dev/null +++ b/mteval/ns_cer.cc @@ -0,0 +1,55 @@ +#include "ns_cer.h" +#include "tdict.h" + +static const unsigned kNUMFIELDS = 2; +static const unsigned kEDITDISTANCE = 0; +static const unsigned kCHARCOUNT = 1; + +bool CERMetric::IsErrorMetric() const { +  return true; +} + +unsigned CERMetric::SufficientStatisticsVectorSize() const { +  return 2; +} + +unsigned CERMetric::EditDistance(const std::string& hyp, +                                 const std::string& ref) const { +  const unsigned m = hyp.size(), n = ref.size(); +  std::vector<unsigned> edit((m + 1) * 2); +  for(unsigned i = 0; i < n + 1; i++) { +    for(unsigned j = 0; j < m + 1; j++) { +      if(i == 0) +        edit[j] = j; +      else if(j == 0) +        edit[(i%2)*(m+1)] = i; +      else +        edit[(i%2)*(m+1) + j] = std::min(std::min(edit[(i%2)*(m+1) + j-1] + 1, +                                                   edit[((i-1)%2)*(m+1) + j] + 1), +                                                   edit[((i-1)%2)*(m+1) + (j-1)]  +                                                   + (hyp[j-1] == ref[i-1] ? 0 : 1)); +       +    } +  } +  return edit[(n%2)*(m+1) + m]; +} + +void CERMetric::ComputeSufficientStatistics(const std::vector<WordID>& hyp, +                                            const std::vector<std::vector<WordID> >& refs, +                                            SufficientStats* out) const { +  out->fields.resize(kNUMFIELDS); +  std::string hyp_str(TD::GetString(hyp)); +  float best_score = hyp_str.size(); +  for (size_t i = 0; i < refs.size(); ++i) { +    std::string ref_str(TD::GetString(refs[i])); +    float score = EditDistance(hyp_str, ref_str); +    if (score < best_score) { +      out->fields[kEDITDISTANCE] = score; +      out->fields[kCHARCOUNT] = ref_str.size(); +      best_score = score; +    } +  } +} +float CERMetric::ComputeScore(const SufficientStats& stats) const { +  return stats.fields[kEDITDISTANCE] / stats.fields[kCHARCOUNT]; +} diff --git a/mteval/ns_cer.h b/mteval/ns_cer.h new file mode 100644 index 00000000..9d211181 --- /dev/null +++ b/mteval/ns_cer.h @@ -0,0 +1,23 @@ +#ifndef _NS_CER_H_ +#define _NS_CER_H_ + +#include "ns.h" + +class CERMetric : public EvaluationMetric { +  friend class EvaluationMetric; + private: +  unsigned EditDistance(const std::string& hyp, +                        const std::string& ref) const; + protected: +  CERMetric() : EvaluationMetric("CER") {} + + public: +  virtual bool IsErrorMetric() const; +  virtual unsigned SufficientStatisticsVectorSize() const; +  virtual void ComputeSufficientStatistics(const std::vector<WordID>& hyp, +                                           const std::vector<std::vector<WordID> >& refs, +                                           SufficientStats* out) const; +  virtual float ComputeScore(const SufficientStats& stats) const; +}; + +#endif | 
