summaryrefslogtreecommitdiff
path: root/python/src/hypergraph.pxi
blob: c226d105cd389c63599e721a956cd2d08fae4648 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
cimport hypergraph
cimport kbest as kb

cdef class Hypergraph:
    """Python wrapper around a cdec C++ hypergraph.

    `hg` points to the wrapped C++ hypergraph. It is presumably assigned by the
    code that creates this object, which is not shown here. `rng` is allocated
    lazily by sample(). Both are freed in __dealloc__.
    """
    cdef hypergraph.Hypergraph* hg
    cdef MT19937* rng  # NULL until sample() first needs it

    def __dealloc__(self):
        del self.hg
        if self.rng != NULL:
            del self.rng

    def viterbi(self):
        """Return the Viterbi target-side sentence as a unicode string.

        Computed with hypergraph.ViterbiESentence; the raw bytes are decoded
        as UTF-8.
        """
        cdef vector[WordID] trans
        hypergraph.ViterbiESentence(self.hg[0], &trans)
        cdef str sentence = GetString(trans).c_str()
        return sentence.decode('utf8')

    def viterbi_tree(self):
        """Return the Viterbi target-side tree (hypergraph.ViterbiETree),
        decoded from UTF-8."""
        cdef str tree = hypergraph.ViterbiETree(self.hg[0]).c_str()
        return tree.decode('utf8')

    def viterbi_source_tree(self):
        """Return the Viterbi source-side tree (hypergraph.ViterbiFTree),
        decoded from UTF-8."""
        cdef str tree = hypergraph.ViterbiFTree(self.hg[0]).c_str()
        return tree.decode('utf8')

    def viterbi_features(self):
        """Return the feature vector of the Viterbi derivation as a new
        SparseVector.

        The SparseVector takes ownership of a heap copy of the
        FastSparseVector returned by hypergraph.ViterbiFeatures.
        """
        cdef SparseVector fmap = SparseVector()
        fmap.vector = new FastSparseVector[weight_t](hypergraph.ViterbiFeatures(self.hg[0]))
        return fmap

    def kbest(self, size):
        """Yield up to `size` target-side sentences in k-best order.

        Iteration stops early when the k-best extractor has no further
        derivation. Each sentence is decoded from UTF-8.
        """
        cdef kb.KBestDerivations[vector[WordID], kb.ESentenceTraversal]* derivations = new kb.KBestDerivations[vector[WordID], kb.ESentenceTraversal](self.hg[0], size)
        cdef kb.KBestDerivations[vector[WordID], kb.ESentenceTraversal].Derivation* derivation
        cdef str sentence
        cdef unsigned k
        # try/finally: the C++ object must also be freed when the caller
        # stops iterating early (generator close / garbage collection).
        try:
            for k in range(size):
                # The last node is used as the root; presumably the goal
                # node. TODO confirm against the cdec hypergraph layout.
                derivation = derivations.LazyKthBest(self.hg.nodes_.size() - 1, k)
                if not derivation: break
                sentence = GetString(derivation._yield).c_str()
                yield sentence.decode('utf8')
        finally:
            del derivations

    def kbest_tree(self, size):
        """Yield up to `size` target-side trees in k-best order.

        Iteration stops early when the k-best extractor has no further
        derivation. Each tree string is decoded from UTF-8.
        """
        cdef kb.KBestDerivations[vector[WordID], kb.ETreeTraversal]* derivations = new kb.KBestDerivations[vector[WordID], kb.ETreeTraversal](self.hg[0], size)
        cdef kb.KBestDerivations[vector[WordID], kb.ETreeTraversal].Derivation* derivation
        cdef str tree
        cdef unsigned k
        # try/finally: free the C++ object even if iteration is abandoned.
        try:
            for k in range(size):
                derivation = derivations.LazyKthBest(self.hg.nodes_.size() - 1, k)
                if not derivation: break
                tree = GetString(derivation._yield).c_str()
                yield tree.decode('utf8')
        finally:
            del derivations

    def sample(self, unsigned n):
        """Yield the target-side sentences of `n` hypotheses drawn with
        hypergraph.sample_hypotheses.

        The MT19937 generator is created on first use and reused on later
        calls.
        """
        cdef vector[hypergraph.Hypothesis]* hypos = new vector[hypergraph.Hypothesis]()
        cdef str sentence
        cdef unsigned k
        # try/finally: free `hypos` even if the caller stops iterating early.
        try:
            if self.rng == NULL:
                self.rng = new MT19937()
            hypergraph.sample_hypotheses(self.hg[0], n, self.rng, hypos)
            for k in range(hypos.size()):
                sentence = GetString(hypos[0][k].words).c_str()
                yield sentence.decode('utf8')
        finally:
            del hypos

    # TODO richer k-best/sample output (feature vectors, trees?)

    def intersect(self, Lattice lat):
        """Intersect this hypergraph with `lat`; return the result of
        hypergraph.Intersect.

        NOTE(review): the raw `hg` pointer is passed, which suggests the
        hypergraph is modified in place. Confirm against cdec's Intersect.
        """
        return hypergraph.Intersect(lat.lattice[0], self.hg)

    def prune(self, beam_alpha=0, density=0, **kwargs):
        """Prune the hypergraph in place with inside-outside pruning.

        beam_alpha: beam pruning parameter (0 by default).
        density: density pruning parameter (0 by default).
        kwargs: if the key 'csplit_preserve_full_word' is present (any
            value), the edge returned by hypergraph.GetFullWordEdgeIndex
            is protected from pruning.
        """
        cdef hypergraph.EdgeMask* preserve_mask = NULL
        try:
            if 'csplit_preserve_full_word' in kwargs:
                preserve_mask = new hypergraph.EdgeMask(self.hg.edges_.size())
                preserve_mask[0][hypergraph.GetFullWordEdgeIndex(self.hg[0])] = True
            self.hg.PruneInsideOutside(beam_alpha, density, preserve_mask, False, 1, False)
        finally:
            # The mask is only needed during pruning. It was previously
            # never freed and leaked on every call.
            if preserve_mask != NULL:
                del preserve_mask

    def lattice(self): # TODO direct hg -> lattice conversion in cdec
        """Return a Lattice built from the PLF rendering of this hypergraph
        (hypergraph.AsPLF)."""
        cdef str plf = hypergraph.AsPLF(self.hg[0], True).c_str()
        # NOTE(review): eval() executes arbitrary Python. The PLF text
        # contains words from the hypergraph, which may originate from
        # untrusted input. If the PLF output is always a pure Python
        # literal, ast.literal_eval would be a safer drop-in. TODO confirm.
        return Lattice(eval(plf))

    def reweight(self, weights):
        """Rescore the hypergraph's edges with `weights`.

        weights: a SparseVector or DenseVector.
        Raises ValueError for any other type.
        """
        if isinstance(weights, SparseVector):
            self.hg.Reweight((<SparseVector> weights).vector[0])
        elif isinstance(weights, DenseVector):
            self.hg.Reweight((<DenseVector> weights).vector[0])
        else:
            raise ValueError('cannot reweight hypergraph with %s' % type(weights))

    # TODO get feature expectations, get partition function ("inside" score)