summaryrefslogtreecommitdiff
path: root/python/src/mteval.pxi
diff options
context:
space:
mode:
Diffstat (limited to 'python/src/mteval.pxi')
-rw-r--r--python/src/mteval.pxi17
1 files changed, 14 insertions, 3 deletions
diff --git a/python/src/mteval.pxi b/python/src/mteval.pxi
index 67a29f6f..d90eb9a6 100644
--- a/python/src/mteval.pxi
+++ b/python/src/mteval.pxi
@@ -10,6 +10,15 @@ cdef char* as_str(sentence, error_msg='Cannot convert type %s to str'):
raise TypeError(error_msg % type(sentence))
return ret
+cdef SufficientStats as_stats(x, y):
+ if isinstance(x, SufficientStats):
+ return x
+ elif x == 0 and isinstance(y, SufficientStats):
+ stats = SufficientStats()
+ stats.stats = new mteval.SufficientStats()
+ stats.metric = (<SufficientStats> y).metric
+ return stats
+
cdef class Candidate:
cdef mteval.Candidate* candidate
cdef public float score
@@ -50,10 +59,12 @@ cdef class SufficientStats:
self.stats[0] += other.stats[0]
return self
- def __add__(SufficientStats x, SufficientStats y):
+ def __add__(x, y):
+ cdef SufficientStats sx = as_stats(x, y)
+ cdef SufficientStats sy = as_stats(y, x)
cdef SufficientStats result = SufficientStats()
- result.stats = new mteval.SufficientStats(mteval.add(x.stats[0], y.stats[0]))
- result.metric = x.metric
+ result.stats = new mteval.SufficientStats(mteval.add(sx.stats[0], sy.stats[0]))
+ result.metric = sx.metric
return result
cdef class CandidateSet: