diff options
Diffstat (limited to 'mteval/ns.cc')
-rw-r--r-- | mteval/ns.cc | 26 |
1 files changed, 15 insertions, 11 deletions
diff --git a/mteval/ns.cc b/mteval/ns.cc index 2c8bd806..1d37c436 100644 --- a/mteval/ns.cc +++ b/mteval/ns.cc @@ -65,38 +65,41 @@ string EvaluationMetric::DetailedScore(const SufficientStats& stats) const { } enum BleuType { IBM, Koehn, NIST, QCRI }; -template <unsigned int N = 4u, BleuType BrevityType = IBM> +template <unsigned int N = 4u, BleuType BrevityType = IBM, bool CharBased = false> struct BleuSegmentEvaluator : public SegmentEvaluator { BleuSegmentEvaluator(const vector<vector<WordID> >& refs, const EvaluationMetric* em) : evaluation_metric(em) { - assert(refs.size() > 0); + const vector<vector<WordID> >& local_refs = (CharBased ? Characterize(refs) : refs); + + assert(local_refs.size() > 0); float tot = 0; int smallest = 9999999; - for (vector<vector<WordID> >::const_iterator ci = refs.begin(); - ci != refs.end(); ++ci) { + for (vector<vector<WordID> >::const_iterator ci = local_refs.begin(); + ci != local_refs.end(); ++ci) { lengths_.push_back(ci->size()); tot += lengths_.back(); if (lengths_.back() < smallest) smallest = lengths_.back(); CountRef(*ci); } if (BrevityType == Koehn) - lengths_[0] = tot / refs.size(); + lengths_[0] = tot / local_refs.size(); if (BrevityType == NIST) lengths_[0] = smallest; } void Evaluate(const vector<WordID>& hyp, SufficientStats* out) const { + const vector<WordID>& local_hyp = (CharBased ? Characterize(hyp) : hyp); out->fields.resize(N + N + 2); out->id_ = evaluation_metric->MetricId(); for (unsigned i = 0; i < N+N+2; ++i) out->fields[i] = 0; - ComputeNgramStats(hyp, &out->fields[0], &out->fields[N], true); + ComputeNgramStats(local_hyp, &out->fields[0], &out->fields[N], true); float& hyp_len = out->fields[2*N]; float& ref_len = out->fields[2*N + 1]; - hyp_len = hyp.size(); + hyp_len = local_hyp.size(); ref_len = lengths_[0]; if (lengths_.size() > 1 && (BrevityType == IBM || BrevityType == QCRI)) { float bestd = 2000000; - float hl = hyp.size(); + float hl = local_hyp.size(); float bl = -1; for (vector<float>::const_iterator ci = lengths_.begin(); ci != lengths_.end(); ++ci) { if (fabs(*ci - hl) < bestd) { @@ -187,12 +190,12 @@ struct BleuSegmentEvaluator : public SegmentEvaluator { mutable NGramCountMap ngrams_; }; -template <unsigned int N = 4u, BleuType BrevityType = IBM> +template <unsigned int N = 4u, BleuType BrevityType = IBM, bool CharBased = false> struct BleuMetric : public EvaluationMetric { BleuMetric() : EvaluationMetric(BrevityType == IBM ? "IBM_BLEU" : (BrevityType == Koehn ? "KOEHN_BLEU" : (BrevityType == NIST ? "NIST_BLEU" : "QCRI_BLEU"))) {} unsigned SufficientStatisticsVectorSize() const { return N*2 + 2; } boost::shared_ptr<SegmentEvaluator> CreateSegmentEvaluator(const vector<vector<WordID> >& refs) const { - return boost::shared_ptr<SegmentEvaluator>(new BleuSegmentEvaluator<N,BrevityType>(refs, this)); + return boost::shared_ptr<SegmentEvaluator>(new BleuSegmentEvaluator<N,BrevityType, CharBased>(refs, this)); } float ComputeBreakdown(const SufficientStats& stats, float* bp, vector<float>* out) const { if (out) { out->clear(); } @@ -290,6 +293,8 @@ EvaluationMetric* EvaluationMetric::Instance(const string& imetric_id) { m = new CERMetric; } else if (metric_id == "WER") { m = new WERMetric; + } else if (metric_id == "CBLEU") { + return new BleuMetric<5, IBM, true>; } else { cerr << "Implement please: " << metric_id << endl; abort(); @@ -322,4 +327,3 @@ void SufficientStats::Encode(string* out) const { os << ' ' << fields[i]; *out = os.str(); } - |