diff options
Diffstat (limited to 'mteval/ns.cc')
-rw-r--r-- | mteval/ns.cc | 42 |
1 files changed, 35 insertions, 7 deletions
diff --git a/mteval/ns.cc b/mteval/ns.cc index 3af7cc63..b64d4798 100644 --- a/mteval/ns.cc +++ b/mteval/ns.cc @@ -3,6 +3,7 @@ #include "ns_ext.h" #include "ns_comb.h" #include "ns_cer.h" +#include "ns_ssk.h" #include <cstdio> #include <cassert> @@ -12,12 +13,15 @@ #include <sstream> #include "tdict.h" +#include "filelib.h" #include "stringlib.h" using namespace std; map<string, EvaluationMetric*> EvaluationMetric::instances_; +extern const char* meteor_jar_path; + SegmentEvaluator::~SegmentEvaluator() {} EvaluationMetric::~EvaluationMetric() {} @@ -57,7 +61,7 @@ string EvaluationMetric::DetailedScore(const SufficientStats& stats) const { return os.str(); } -enum BleuType { IBM, Koehn, NIST }; +enum BleuType { IBM, Koehn, NIST, QCRI }; template <unsigned int N = 4u, BleuType BrevityType = IBM> struct BleuSegmentEvaluator : public SegmentEvaluator { BleuSegmentEvaluator(const vector<vector<WordID> >& refs, const EvaluationMetric* em) : evaluation_metric(em) { @@ -87,7 +91,7 @@ struct BleuSegmentEvaluator : public SegmentEvaluator { float& ref_len = out->fields[2*N + 1]; hyp_len = hyp.size(); ref_len = lengths_[0]; - if (lengths_.size() > 1 && BrevityType == IBM) { + if (lengths_.size() > 1 && (BrevityType == IBM || BrevityType == QCRI)) { float bestd = 2000000; float hl = hyp.size(); float bl = -1; @@ -182,7 +186,7 @@ struct BleuSegmentEvaluator : public SegmentEvaluator { template <unsigned int N = 4u, BleuType BrevityType = IBM> struct BleuMetric : public EvaluationMetric { - BleuMetric() : EvaluationMetric(BrevityType == IBM ? "IBM_BLEU" : (BrevityType == Koehn ? "KOEHN_BLEU" : "NIST_BLEU")) {} + BleuMetric() : EvaluationMetric(BrevityType == IBM ? "IBM_BLEU" : (BrevityType == Koehn ? "KOEHN_BLEU" : (BrevityType == NIST ? "NIST_BLEU" : "QCRI_BLEU"))) {} unsigned SufficientStatisticsVectorSize() const { return N*2 + 2; } boost::shared_ptr<SegmentEvaluator> CreateSegmentEvaluator(const vector<vector<WordID> >& refs) const { return boost::shared_ptr<SegmentEvaluator>(new BleuSegmentEvaluator<N,BrevityType>(refs, this)); @@ -190,26 +194,37 @@ struct BleuMetric : public EvaluationMetric { float ComputeBreakdown(const SufficientStats& stats, float* bp, vector<float>* out) const { if (out) { out->clear(); } float log_bleu = 0; + float log_bleu_adj = 0; // for QCRI int count = 0; + float alpha = BrevityType == QCRI ? 1 : 0.01; for (int i = 0; i < N; ++i) { if (stats.fields[i+N] > 0) { float cor_count = stats.fields[i]; // correct_ngram_hit_counts[i]; // smooth bleu - if (!cor_count) { cor_count = 0.01; } + if (!cor_count) { cor_count = alpha; } float lprec = log(cor_count) - log(stats.fields[i+N]); // log(hyp_ngram_counts[i]); if (out) out->push_back(exp(lprec)); log_bleu += lprec; + if (BrevityType == QCRI) + log_bleu_adj += log(alpha) - log(stats.fields[i+N] + alpha); ++count; } } log_bleu /= count; + log_bleu_adj /= count; float lbp = 0.0; const float& hyp_len = stats.fields[2*N]; const float& ref_len = stats.fields[2*N + 1]; - if (hyp_len < ref_len) - lbp = (hyp_len - ref_len) / hyp_len; + if (hyp_len < ref_len) { + if (BrevityType == QCRI) + lbp = (hyp_len - ref_len - alpha) / hyp_len; + else + lbp = (hyp_len - ref_len) / hyp_len; + } log_bleu += lbp; if (bp) *bp = exp(lbp); + if (BrevityType == QCRI) + return exp(log_bleu) - exp(lbp + log_bleu_adj); return exp(log_bleu); } string DetailedScore(const SufficientStats& stats) const { @@ -249,10 +264,23 @@ EvaluationMetric* EvaluationMetric::Instance(const string& imetric_id) { m = new BleuMetric<4, NIST>; } else if (metric_id == "KOEHN_BLEU") { m = new BleuMetric<4, Koehn>; + } else if (metric_id == "QCRI_BLEU") { + m = new BleuMetric<4, QCRI>; + } else if (metric_id == "SSK") { + m = new SSKMetric; } else if (metric_id == "TER") { m = new TERMetric; } else if (metric_id == "METEOR") { - m = new ExternalMetric("METEOR", "java -Xmx1536m -jar /cab0/tools/meteor-1.3/meteor-1.3.jar - - -mira -lower -t tune -l en"); +#if HAVE_METEOR + if (!FileExists(meteor_jar_path)) { + cerr << meteor_jar_path << " not found!\n"; + abort(); + } + m = new ExternalMetric("METEOR", string("java -Xmx1536m -jar ") + meteor_jar_path + " - - -mira -lower -t tune -l en"); +#else + cerr << "cdec was not built with the --with-meteor option." << endl; + abort(); +#endif } else if (metric_id.find("COMB:") == 0) { m = new CombinationMetric(metric_id); } else if (metric_id == "CER") { |