diff options
author | Chris Dyer <cdyer@cs.cmu.edu> | 2012-06-23 15:54:38 -0400 |
---|---|---|
committer | Chris Dyer <cdyer@cs.cmu.edu> | 2012-06-23 15:54:38 -0400 |
commit | 6a2bbfa9b28f8be2189744beb89a975bc3da128f (patch) | |
tree | 4d9ee4e0b6d2c37be8f8a4e05472cd21e66d0ffd /python/src/hypergraph.pxi | |
parent | b9266b068a37dc46f8de813c59ffbef2e4b89280 (diff) | |
parent | b738e349be490c24d3604c224f44fc54e16d3d7b (diff) |
Merge branch 'master' of github.com:redpony/cdec
Diffstat (limited to 'python/src/hypergraph.pxi')
-rw-r--r-- | python/src/hypergraph.pxi | 92 |
1 files changed, 92 insertions, 0 deletions
diff --git a/python/src/hypergraph.pxi b/python/src/hypergraph.pxi new file mode 100644 index 00000000..c226d105 --- /dev/null +++ b/python/src/hypergraph.pxi @@ -0,0 +1,92 @@ +cimport hypergraph +cimport kbest as kb + +cdef class Hypergraph: + cdef hypergraph.Hypergraph* hg + cdef MT19937* rng + + def __dealloc__(self): + del self.hg + if self.rng != NULL: + del self.rng + + def viterbi(self): + cdef vector[WordID] trans + hypergraph.ViterbiESentence(self.hg[0], &trans) + cdef str sentence = GetString(trans).c_str() + return sentence.decode('utf8') + + def viterbi_tree(self): + cdef str tree = hypergraph.ViterbiETree(self.hg[0]).c_str() + return tree.decode('utf8') + + def viterbi_source_tree(self): + cdef str tree = hypergraph.ViterbiFTree(self.hg[0]).c_str() + return tree.decode('utf8') + + def viterbi_features(self): + cdef SparseVector fmap = SparseVector() + fmap.vector = new FastSparseVector[weight_t](hypergraph.ViterbiFeatures(self.hg[0])) + return fmap + + def kbest(self, size): + cdef kb.KBestDerivations[vector[WordID], kb.ESentenceTraversal]* derivations = new kb.KBestDerivations[vector[WordID], kb.ESentenceTraversal](self.hg[0], size) + cdef kb.KBestDerivations[vector[WordID], kb.ESentenceTraversal].Derivation* derivation + cdef str sentence + cdef unsigned k + for k in range(size): + derivation = derivations.LazyKthBest(self.hg.nodes_.size() - 1, k) + if not derivation: break + sentence = GetString(derivation._yield).c_str() + yield sentence.decode('utf8') + del derivations + + def kbest_tree(self, size): + cdef kb.KBestDerivations[vector[WordID], kb.ETreeTraversal]* derivations = new kb.KBestDerivations[vector[WordID], kb.ETreeTraversal](self.hg[0], size) + cdef kb.KBestDerivations[vector[WordID], kb.ETreeTraversal].Derivation* derivation + cdef str tree + cdef unsigned k + for k in range(size): + derivation = derivations.LazyKthBest(self.hg.nodes_.size() - 1, k) + if not derivation: break + tree = GetString(derivation._yield).c_str() + yield tree.decode('utf8') + del derivations + + def sample(self, unsigned n): + cdef vector[hypergraph.Hypothesis]* hypos = new vector[hypergraph.Hypothesis]() + if self.rng == NULL: + self.rng = new MT19937() + hypergraph.sample_hypotheses(self.hg[0], n, self.rng, hypos) + cdef str sentence + cdef unsigned k + for k in range(hypos.size()): + sentence = GetString(hypos[0][k].words).c_str() + yield sentence.decode('utf8') + del hypos + + # TODO richer k-best/sample output (feature vectors, trees?) + + def intersect(self, Lattice lat): + return hypergraph.Intersect(lat.lattice[0], self.hg) + + def prune(self, beam_alpha=0, density=0, **kwargs): + cdef hypergraph.EdgeMask* preserve_mask = NULL + if 'csplit_preserve_full_word' in kwargs: + preserve_mask = new hypergraph.EdgeMask(self.hg.edges_.size()) + preserve_mask[0][hypergraph.GetFullWordEdgeIndex(self.hg[0])] = True + self.hg.PruneInsideOutside(beam_alpha, density, preserve_mask, False, 1, False) + + def lattice(self): # TODO direct hg -> lattice conversion in cdec + cdef str plf = hypergraph.AsPLF(self.hg[0], True).c_str() + return Lattice(eval(plf)) + + def reweight(self, weights): + if isinstance(weights, SparseVector): + self.hg.Reweight((<SparseVector> weights).vector[0]) + elif isinstance(weights, DenseVector): + self.hg.Reweight((<DenseVector> weights).vector[0]) + else: + raise ValueError('cannot reweight hypergraph with %s' % type(weights)) + + # TODO get feature expectations, get partition function ("inside" score) |