summaryrefslogtreecommitdiff
path: root/mteval/ns.cc
diff options
context:
space:
mode:
Diffstat (limited to 'mteval/ns.cc')
-rw-r--r--mteval/ns.cc26
1 files changed, 15 insertions, 11 deletions
diff --git a/mteval/ns.cc b/mteval/ns.cc
index 2c8bd806..1d37c436 100644
--- a/mteval/ns.cc
+++ b/mteval/ns.cc
@@ -65,38 +65,41 @@ string EvaluationMetric::DetailedScore(const SufficientStats& stats) const {
}
enum BleuType { IBM, Koehn, NIST, QCRI };
-template <unsigned int N = 4u, BleuType BrevityType = IBM>
+template <unsigned int N = 4u, BleuType BrevityType = IBM, bool CharBased = false>
struct BleuSegmentEvaluator : public SegmentEvaluator {
BleuSegmentEvaluator(const vector<vector<WordID> >& refs, const EvaluationMetric* em) : evaluation_metric(em) {
- assert(refs.size() > 0);
+ const vector<vector<WordID> >& local_refs = (CharBased ? Characterize(refs) : refs);
+
+ assert(local_refs.size() > 0);
float tot = 0;
int smallest = 9999999;
- for (vector<vector<WordID> >::const_iterator ci = refs.begin();
- ci != refs.end(); ++ci) {
+ for (vector<vector<WordID> >::const_iterator ci = local_refs.begin();
+ ci != local_refs.end(); ++ci) {
lengths_.push_back(ci->size());
tot += lengths_.back();
if (lengths_.back() < smallest) smallest = lengths_.back();
CountRef(*ci);
}
if (BrevityType == Koehn)
- lengths_[0] = tot / refs.size();
+ lengths_[0] = tot / local_refs.size();
if (BrevityType == NIST)
lengths_[0] = smallest;
}
void Evaluate(const vector<WordID>& hyp, SufficientStats* out) const {
+ const vector<WordID>& local_hyp = (CharBased ? Characterize(hyp) : hyp);
out->fields.resize(N + N + 2);
out->id_ = evaluation_metric->MetricId();
for (unsigned i = 0; i < N+N+2; ++i) out->fields[i] = 0;
- ComputeNgramStats(hyp, &out->fields[0], &out->fields[N], true);
+ ComputeNgramStats(local_hyp, &out->fields[0], &out->fields[N], true);
float& hyp_len = out->fields[2*N];
float& ref_len = out->fields[2*N + 1];
- hyp_len = hyp.size();
+ hyp_len = local_hyp.size();
ref_len = lengths_[0];
if (lengths_.size() > 1 && (BrevityType == IBM || BrevityType == QCRI)) {
float bestd = 2000000;
- float hl = hyp.size();
+ float hl = local_hyp.size();
float bl = -1;
for (vector<float>::const_iterator ci = lengths_.begin(); ci != lengths_.end(); ++ci) {
if (fabs(*ci - hl) < bestd) {
@@ -187,12 +190,12 @@ struct BleuSegmentEvaluator : public SegmentEvaluator {
mutable NGramCountMap ngrams_;
};
-template <unsigned int N = 4u, BleuType BrevityType = IBM>
+template <unsigned int N = 4u, BleuType BrevityType = IBM, bool CharBased = false>
struct BleuMetric : public EvaluationMetric {
BleuMetric() : EvaluationMetric(BrevityType == IBM ? "IBM_BLEU" : (BrevityType == Koehn ? "KOEHN_BLEU" : (BrevityType == NIST ? "NIST_BLEU" : "QCRI_BLEU"))) {}
unsigned SufficientStatisticsVectorSize() const { return N*2 + 2; }
boost::shared_ptr<SegmentEvaluator> CreateSegmentEvaluator(const vector<vector<WordID> >& refs) const {
- return boost::shared_ptr<SegmentEvaluator>(new BleuSegmentEvaluator<N,BrevityType>(refs, this));
+ return boost::shared_ptr<SegmentEvaluator>(new BleuSegmentEvaluator<N,BrevityType, CharBased>(refs, this));
}
float ComputeBreakdown(const SufficientStats& stats, float* bp, vector<float>* out) const {
if (out) { out->clear(); }
@@ -290,6 +293,8 @@ EvaluationMetric* EvaluationMetric::Instance(const string& imetric_id) {
m = new CERMetric;
} else if (metric_id == "WER") {
m = new WERMetric;
+ } else if (metric_id == "CBLEU") {
+ return new BleuMetric<5, IBM, true>;
} else {
cerr << "Implement please: " << metric_id << endl;
abort();
@@ -322,4 +327,3 @@ void SufficientStats::Encode(string* out) const {
os << ' ' << fields[i];
*out = os.str();
}
-