diff options
Diffstat (limited to 'python/src/mteval.pxi')
-rw-r--r-- | python/src/mteval.pxi | 74 |
1 files changed, 56 insertions, 18 deletions
diff --git a/python/src/mteval.pxi b/python/src/mteval.pxi index d90eb9a6..cd1c3c81 100644 --- a/python/src/mteval.pxi +++ b/python/src/mteval.pxi @@ -1,15 +1,5 @@ cimport mteval -cdef char* as_str(sentence, error_msg='Cannot convert type %s to str'): - cdef bytes ret - if isinstance(sentence, unicode): - ret = sentence.encode('utf8') - elif isinstance(sentence, str): - ret = sentence - else: - raise TypeError(error_msg % type(sentence)) - return ret - cdef SufficientStats as_stats(x, y): if isinstance(x, SufficientStats): return x @@ -20,7 +10,7 @@ cdef SufficientStats as_stats(x, y): return stats cdef class Candidate: - cdef mteval.Candidate* candidate + cdef mteval.const_Candidate* candidate cdef public float score property words: @@ -29,7 +19,7 @@ cdef class Candidate: property fmap: def __get__(self): - cdef SparseVector fmap = SparseVector() + cdef SparseVector fmap = SparseVector.__new__(SparseVector) fmap.vector = new FastSparseVector[weight_t](self.candidate.fmap) return fmap @@ -53,7 +43,12 @@ cdef class SufficientStats: def __iter__(self): for i in range(len(self)): - yield self.stats[0][i] + yield self[i] + + def __getitem__(self, int index): + if not 0 <= index < len(self): + raise IndexError('sufficient stats vector index out of range') + return self.stats[0][index] def __iadd__(SufficientStats self, SufficientStats other): self.stats[0] += other.stats[0] @@ -121,15 +116,17 @@ cdef class SegmentEvaluator: cdef class Scorer: cdef string* name + cdef mteval.EvaluationMetric* metric - def __cinit__(self, char* name): - self.name = new string(name) + def __cinit__(self, bytes name=None): + if name: + self.name = new string(name) + self.metric = mteval.MetricInstance(self.name[0]) def __dealloc__(self): del self.name def __call__(self, refs): - cdef mteval.EvaluationMetric* metric = mteval.Instance(self.name[0]) if isinstance(refs, unicode) or isinstance(refs, str): refs = [refs] cdef vector[vector[WordID]]* refsv = new vector[vector[WordID]]() @@ -142,13 +139,54 @@ cdef class Scorer: del refv cdef unsigned i cdef SegmentEvaluator evaluator = SegmentEvaluator() - evaluator.metric = metric - evaluator.scorer = new shared_ptr[mteval.SegmentEvaluator](metric.CreateSegmentEvaluator(refsv[0])) + evaluator.metric = self.metric + evaluator.scorer = new shared_ptr[mteval.SegmentEvaluator]( + self.metric.CreateSegmentEvaluator(refsv[0])) del refsv # in theory should not delete but store in SegmentEvaluator return evaluator def __str__(self): return self.name.c_str() +cdef float _compute_score(void* metric_, mteval.SufficientStats* stats): + cdef Metric metric = <Metric> metric_ + cdef list ss = [] + cdef unsigned i + for i in range(stats.size()): + ss.append(stats[0][i]) + return metric.score(ss) + +cdef void _compute_sufficient_stats(void* metric_, + string* hyp, + vector[string]* refs, + mteval.SufficientStats* out): + cdef Metric metric = <Metric> metric_ + cdef list refs_ = [] + cdef unsigned i + for i in range(refs.size()): + refs_.append(refs[0][i].c_str()) + cdef list ss = metric.evaluate(hyp.c_str(), refs_) + out.fields.resize(len(ss)) + for i in range(len(ss)): + out.fields[i] = ss[i] + +cdef class Metric: + cdef Scorer scorer + def __cinit__(self): + self.scorer = Scorer() + self.scorer.name = new string(as_str(self.__class__.__name__)) + self.scorer.metric = mteval.PyMetricInstance(self.scorer.name[0], + <void*> self, _compute_sufficient_stats, _compute_score) + + def __call__(self, refs): + return self.scorer(refs) + + def score(SufficientStats stats): + return 0 + + def evaluate(self, hyp, refs): + return [] + BLEU = Scorer('IBM_BLEU') TER = Scorer('TER') +CER = Scorer('CER') |