diff options
author | Victor Chahuneau <vchahune@cs.cmu.edu> | 2012-07-11 16:08:43 +0900 |
---|---|---|
committer | Victor Chahuneau <vchahune@cs.cmu.edu> | 2012-07-11 16:08:43 +0900 |
commit | 42532406b1246e6f17766b804b8bd5cce828f0fa (patch) | |
tree | b867b7fded89c3de32f0629241a86ea735186884 /python/src/hypergraph.pxi | |
parent | 757f56e391bd2e1d7442ab38fc98aff00d064d38 (diff) |
[python] Direct hypergraph access
- small API changes (*_trees methods)
- decoder config can now passed as arguments
Diffstat (limited to 'python/src/hypergraph.pxi')
-rw-r--r-- | python/src/hypergraph.pxi | 161 |
1 files changed, 130 insertions, 31 deletions
diff --git a/python/src/hypergraph.pxi b/python/src/hypergraph.pxi index 9d09722e..2e2c04a2 100644 --- a/python/src/hypergraph.pxi +++ b/python/src/hypergraph.pxi @@ -1,5 +1,5 @@ cimport hypergraph -cimport kbest as kb +cimport kbest cdef class Hypergraph: cdef hypergraph.Hypergraph* hg @@ -13,53 +13,54 @@ cdef class Hypergraph: def viterbi(self): cdef vector[WordID] trans hypergraph.ViterbiESentence(self.hg[0], &trans) - cdef str sentence = GetString(trans).c_str() - return sentence.decode('utf8') + return unicode(GetString(trans).c_str(), 'utf8') - def viterbi_tree(self): - cdef str tree = hypergraph.ViterbiETree(self.hg[0]).c_str() - return tree.decode('utf8') - - def viterbi_source_tree(self): - cdef str tree = hypergraph.ViterbiFTree(self.hg[0]).c_str() - return tree.decode('utf8') + def viterbi_trees(self): + f_tree = unicode(hypergraph.ViterbiFTree(self.hg[0]).c_str(), 'utf8') + e_tree = unicode(hypergraph.ViterbiETree(self.hg[0]).c_str(), 'utf8') + return (f_tree, e_tree) def viterbi_features(self): cdef SparseVector fmap = SparseVector() fmap.vector = new FastSparseVector[weight_t](hypergraph.ViterbiFeatures(self.hg[0])) return fmap + def viterbi_joshua(self): + return unicode(hypergraph.JoshuaVisualizationString(self.hg[0]).c_str(), 'utf8') + def kbest(self, size): - cdef kb.KBestDerivations[vector[WordID], kb.ESentenceTraversal]* derivations = new kb.KBestDerivations[vector[WordID], kb.ESentenceTraversal](self.hg[0], size) - cdef kb.KBestDerivations[vector[WordID], kb.ESentenceTraversal].Derivation* derivation - cdef bytes sentence + cdef kbest.KBestDerivations[vector[WordID], kbest.ESentenceTraversal]* derivations = new kbest.KBestDerivations[vector[WordID], kbest.ESentenceTraversal](self.hg[0], size) + cdef kbest.KBestDerivations[vector[WordID], kbest.ESentenceTraversal].Derivation* derivation cdef unsigned k try: for k in range(size): derivation = derivations.LazyKthBest(self.hg.nodes_.size() - 1, k) if not derivation: break - sentence = GetString(derivation._yield).c_str() - yield sentence.decode('utf8') + yield unicode(GetString(derivation._yield).c_str(), 'utf8') finally: del derivations - def kbest_tree(self, size): - cdef kb.KBestDerivations[vector[WordID], kb.ETreeTraversal]* derivations = new kb.KBestDerivations[vector[WordID], kb.ETreeTraversal](self.hg[0], size) - cdef kb.KBestDerivations[vector[WordID], kb.ETreeTraversal].Derivation* derivation - cdef str tree + def kbest_trees(self, size): + cdef kbest.KBestDerivations[vector[WordID], kbest.FTreeTraversal]* f_derivations = new kbest.KBestDerivations[vector[WordID], kbest.FTreeTraversal](self.hg[0], size) + cdef kbest.KBestDerivations[vector[WordID], kbest.FTreeTraversal].Derivation* f_derivation + cdef kbest.KBestDerivations[vector[WordID], kbest.ETreeTraversal]* e_derivations = new kbest.KBestDerivations[vector[WordID], kbest.ETreeTraversal](self.hg[0], size) + cdef kbest.KBestDerivations[vector[WordID], kbest.ETreeTraversal].Derivation* e_derivation cdef unsigned k try: for k in range(size): - derivation = derivations.LazyKthBest(self.hg.nodes_.size() - 1, k) - if not derivation: break - tree = GetString(derivation._yield).c_str() - yield tree.decode('utf8') + f_derivation = f_derivations.LazyKthBest(self.hg.nodes_.size() - 1, k) + e_derivation = e_derivations.LazyKthBest(self.hg.nodes_.size() - 1, k) + if not f_derivation or not e_derivation: break + f_tree = unicode(GetString(f_derivation._yield).c_str(), 'utf8') + e_tree = unicode(GetString(e_derivation._yield).c_str(), 'utf8') + yield (f_tree, e_tree) finally: - del derivations + del f_derivations + del e_derivations def kbest_features(self, size): - cdef kb.KBestDerivations[FastSparseVector[weight_t], kb.FeatureVectorTraversal]* derivations = new kb.KBestDerivations[FastSparseVector[weight_t], kb.FeatureVectorTraversal](self.hg[0], size) - cdef kb.KBestDerivations[FastSparseVector[weight_t], kb.FeatureVectorTraversal].Derivation* derivation + cdef kbest.KBestDerivations[FastSparseVector[weight_t], kbest.FeatureVectorTraversal]* derivations = new kbest.KBestDerivations[FastSparseVector[weight_t], kbest.FeatureVectorTraversal](self.hg[0], size) + cdef kbest.KBestDerivations[FastSparseVector[weight_t], kbest.FeatureVectorTraversal].Derivation* derivation cdef SparseVector fmap cdef unsigned k try: @@ -77,17 +78,13 @@ cdef class Hypergraph: if self.rng == NULL: self.rng = new MT19937() hypergraph.sample_hypotheses(self.hg[0], n, self.rng, hypos) - cdef str sentence cdef unsigned k try: for k in range(hypos.size()): - sentence = GetString(hypos[0][k].words).c_str() - yield sentence.decode('utf8') + yield unicode(GetString(hypos[0][k].words).c_str(), 'utf8') finally: del hypos - # TODO richer k-best/sample output (feature vectors, trees?) - def intersect(self, Lattice lat): return hypergraph.Intersect(lat.lattice[0], self.hg) @@ -113,3 +110,105 @@ cdef class Hypergraph: raise TypeError('cannot reweight hypergraph with %s' % type(weights)) # TODO get feature expectations, get partition function ("inside" score) + + property edges: + def __get__(self): + cdef unsigned i + for i in range(self.hg.edges_.size()): + yield HypergraphEdge().init(self.hg, i) + + property nodes: + def __get__(self): + cdef unsigned i + for i in range(self.hg.nodes_.size()): + yield HypergraphNode().init(self.hg, i) + + property goal: + def __get__(self): + return HypergraphNode().init(self.hg, self.hg.GoalNode()) + + +include "trule.pxi" + +cdef class HypergraphEdge: + cdef hypergraph.Hypergraph* hg + cdef hypergraph.HypergraphEdge* edge + cdef public TRule trule + + cdef init(self, hypergraph.Hypergraph* hg, unsigned i): + self.hg = hg + self.edge = &hg.edges_[i] + self.trule = TRule() + self.trule.rule = self.edge.rule_.get() + return self + + def __len__(self): + return self.edge.tail_nodes_.size() + + property head_node: + def __get__(self): + return HypergraphNode().init(self.hg, self.edge.head_node_) + + property tail_nodes: + def __get__(self): + cdef unsigned i + for i in range(self.edge.tail_nodes_.size()): + yield HypergraphNode().init(self.hg, self.edge.tail_nodes_[i]) + + property span: + def __get__(self): + return (self.edge.i_, self.edge.j_) + + property feature_values: + def __get__(self): + cdef SparseVector vector = SparseVector() + vector.vector = new FastSparseVector[double](self.edge.feature_values_) + return vector + + property prob: + def __get__(self): + return self.edge.edge_prob_.as_float() + + def __richcmp__(HypergraphEdge x, HypergraphEdge y, int op): + if op == 2: # == + return x.edge == y.edge + elif op == 3: # != + return not (x == y) + raise NotImplemented('comparison not implemented for HypergraphEdge') + +cdef class HypergraphNode: + cdef hypergraph.Hypergraph* hg + cdef hypergraph.HypergraphNode* node + + cdef init(self, hypergraph.Hypergraph* hg, unsigned i): + self.hg = hg + self.node = &hg.nodes_[i] + return self + + property in_edges: + def __get__(self): + cdef unsigned i + for i in range(self.node.in_edges_.size()): + yield HypergraphEdge().init(self.hg, self.node.in_edges_[i]) + + property out_edges: + def __get__(self): + cdef unsigned i + for i in range(self.node.out_edges_.size()): + yield HypergraphEdge().init(self.hg, self.node.out_edges_[i]) + + property span: + def __get__(self): + return next(self.in_edges).span + + property cat: + def __get__(self): + if self.node.cat_: + return TDConvert(-self.node.cat_) + + def __richcmp__(HypergraphNode x, HypergraphNode y, int op): + if op == 2: # == + return x.node == y.node + elif op == 3: # != + return not (x == y) + raise NotImplemented('comparison not implemented for HypergraphNode') |