summaryrefslogtreecommitdiff
path: root/python/src/trule.pxi
blob: 6168014d016d684f27d92f845a7a71a5ccb804a3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def _phrase(phrase):
    return ' '.join('[%s,%d]' % w if isinstance(w, tuple) else w.encode('utf8') for w in phrase)

cdef class TRule:
    """Python wrapper exposing the fields of a C++ ``hypergraph.TRule``.

    The wrapper only reads through ``rule``. No visible code in this class
    allocates or frees the pointed-to rule, so whoever creates the wrapper
    must set the pointer and keep it alive.
    NOTE(review): confirm the ownership rules at the construction sites.
    """
    cdef hypergraph.TRule* rule

    property arity:
        # The arity_ of the wrapped rule.
        # Presumably this is the number of nonterminals it contains -- TODO confirm.
        def __get__(self):
            return self.rule.arity_

    property f:
        # The f side as a list:
        # - Terminals are unicode strings decoded from UTF-8.
        # - A nonterminal is stored as a negative WordID; negating it gives
        #   its category ID. It is returned as a (category, index) tuple.
        # - Indices are 1, 2, ... in order of appearance.
        def __get__(self):
            cdef vector[WordID]* f = &self.rule.f_
            cdef WordID w
            cdef words = []
            cdef unsigned i
            cdef int idx = 0
            for i in range(f.size()):
                w = f[0][i]
                if w < 0:
                    idx += 1
                    words.append((TDConvert(-w), idx))
                else:
                    words.append(unicode(TDConvert(w), encoding='utf8'))
            return words

    property e:
        # The e side as a list, in the same format as ``f``.
        # - A nonterminal slot is stored as a WordID w < 1, where 1 - w is
        #   the 1-based index of the f-side nonterminal it is linked to
        #   (cdec's e_ encoding). This is the same index that ``f`` reports.
        # - The slot is returned as (category, 1 - w), with the category
        #   taken from that f-side nonterminal.
        def __get__(self):
            cdef vector[WordID]* f = &self.rule.f_
            cdef vector[WordID]* e = &self.rule.e_
            cdef WordID w
            cdef unsigned i
            cdef int ref
            cdef categories = []
            cdef words = []
            # Collect the f-side nonterminal categories in f order, so that
            # categories[ref - 1] is the category of nonterminal number ref.
            for i in range(f.size()):
                w = f[0][i]
                if w < 0:
                    categories.append(TDConvert(-w))
            for i in range(e.size()):
                w = e[0][i]
                if w < 1:
                    # 1 - w is a link index, not a word ID, so it must not go
                    # through TDConvert. It must also be kept as-is rather
                    # than renumbered in e order, because the e side may list
                    # nonterminals in a different order than the f side.
                    ref = 1 - w
                    words.append((categories[ref - 1], ref))
                else:
                    words.append(unicode(TDConvert(w), encoding='utf8'))
            return words

    property scores:
        # A new SparseVector holding a copy of the rule's scores_. It is built
        # with FastSparseVector's copy constructor, so changes to the result
        # do not write back to the rule.
        # NOTE(review): if SparseVector() already allocates ``vector`` in its
        # constructor, that allocation is overwritten here and leaks -- confirm.
        def __get__(self):
            cdef SparseVector scores = SparseVector()
            scores.vector = new FastSparseVector[double](self.rule.scores_)
            return scores

    property lhs:
        # The left-hand-side category. lhs_ stores it negated, in the same way
        # as f-side nonterminals.
        def __get__(self):
            return TDConvert(-self.rule.lhs_)

    def __str__(self):
        # Output format: "[LHS] ||| <f side> ||| <e side> ||| name=value ..."
        scores = ' '.join('%s=%s' % feat for feat in self.scores)
        return '[%s] ||| %s ||| %s ||| %s' % (self.lhs, _phrase(self.f), _phrase(self.e), scores)