summaryrefslogtreecommitdiff
path: root/python/src/mteval.pxi
diff options
context:
space:
mode:
Diffstat (limited to 'python/src/mteval.pxi')
-rw-r--r--python/src/mteval.pxi74
1 files changed, 56 insertions, 18 deletions
diff --git a/python/src/mteval.pxi b/python/src/mteval.pxi
index d90eb9a6..cd1c3c81 100644
--- a/python/src/mteval.pxi
+++ b/python/src/mteval.pxi
@@ -1,15 +1,5 @@
cimport mteval
-cdef char* as_str(sentence, error_msg='Cannot convert type %s to str'):
- cdef bytes ret
- if isinstance(sentence, unicode):
- ret = sentence.encode('utf8')
- elif isinstance(sentence, str):
- ret = sentence
- else:
- raise TypeError(error_msg % type(sentence))
- return ret
-
cdef SufficientStats as_stats(x, y):
if isinstance(x, SufficientStats):
return x
@@ -20,7 +10,7 @@ cdef SufficientStats as_stats(x, y):
return stats
cdef class Candidate:
- cdef mteval.Candidate* candidate
+ cdef mteval.const_Candidate* candidate
cdef public float score
property words:
@@ -29,7 +19,7 @@ cdef class Candidate:
property fmap:
def __get__(self):
- cdef SparseVector fmap = SparseVector()
+ cdef SparseVector fmap = SparseVector.__new__(SparseVector)
fmap.vector = new FastSparseVector[weight_t](self.candidate.fmap)
return fmap
@@ -53,7 +43,12 @@ cdef class SufficientStats:
def __iter__(self):
for i in range(len(self)):
- yield self.stats[0][i]
+ yield self[i]
+
+ def __getitem__(self, int index):
+ if not 0 <= index < len(self):
+ raise IndexError('sufficient stats vector index out of range')
+ return self.stats[0][index]
def __iadd__(SufficientStats self, SufficientStats other):
self.stats[0] += other.stats[0]
@@ -121,15 +116,17 @@ cdef class SegmentEvaluator:
cdef class Scorer:
cdef string* name
+ cdef mteval.EvaluationMetric* metric
- def __cinit__(self, char* name):
- self.name = new string(name)
+ def __cinit__(self, bytes name=None):
+ if name:
+ self.name = new string(name)
+ self.metric = mteval.MetricInstance(self.name[0])
def __dealloc__(self):
del self.name
def __call__(self, refs):
- cdef mteval.EvaluationMetric* metric = mteval.Instance(self.name[0])
if isinstance(refs, unicode) or isinstance(refs, str):
refs = [refs]
cdef vector[vector[WordID]]* refsv = new vector[vector[WordID]]()
@@ -142,13 +139,54 @@ cdef class Scorer:
del refv
cdef unsigned i
cdef SegmentEvaluator evaluator = SegmentEvaluator()
- evaluator.metric = metric
- evaluator.scorer = new shared_ptr[mteval.SegmentEvaluator](metric.CreateSegmentEvaluator(refsv[0]))
+ evaluator.metric = self.metric
+ evaluator.scorer = new shared_ptr[mteval.SegmentEvaluator](
+ self.metric.CreateSegmentEvaluator(refsv[0]))
del refsv # in theory should not delete but store in SegmentEvaluator
return evaluator
def __str__(self):
return self.name.c_str()
+cdef float _compute_score(void* metric_, mteval.SufficientStats* stats):
+ cdef Metric metric = <Metric> metric_
+ cdef list ss = []
+ cdef unsigned i
+ for i in range(stats.size()):
+ ss.append(stats[0][i])
+ return metric.score(ss)
+
+cdef void _compute_sufficient_stats(void* metric_,
+ string* hyp,
+ vector[string]* refs,
+ mteval.SufficientStats* out):
+ cdef Metric metric = <Metric> metric_
+ cdef list refs_ = []
+ cdef unsigned i
+ for i in range(refs.size()):
+ refs_.append(refs[0][i].c_str())
+ cdef list ss = metric.evaluate(hyp.c_str(), refs_)
+ out.fields.resize(len(ss))
+ for i in range(len(ss)):
+ out.fields[i] = ss[i]
+
+cdef class Metric:
+ cdef Scorer scorer
+ def __cinit__(self):
+ self.scorer = Scorer()
+ self.scorer.name = new string(as_str(self.__class__.__name__))
+ self.scorer.metric = mteval.PyMetricInstance(self.scorer.name[0],
+ <void*> self, _compute_sufficient_stats, _compute_score)
+
+ def __call__(self, refs):
+ return self.scorer(refs)
+
+ def score(SufficientStats stats):
+ return 0
+
+ def evaluate(self, hyp, refs):
+ return []
+
BLEU = Scorer('IBM_BLEU')
TER = Scorer('TER')
+CER = Scorer('CER')