summaryrefslogtreecommitdiff
path: root/python/src/hypergraph.pxi
diff options
context:
space:
mode:
authorVictor Chahuneau <vchahune@cs.cmu.edu>2012-07-11 16:08:43 +0900
committerVictor Chahuneau <vchahune@cs.cmu.edu>2012-07-11 16:08:43 +0900
commit42532406b1246e6f17766b804b8bd5cce828f0fa (patch)
treeb867b7fded89c3de32f0629241a86ea735186884 /python/src/hypergraph.pxi
parent757f56e391bd2e1d7442ab38fc98aff00d064d38 (diff)
[python] Direct hypergraph access
- small API changes (*_trees methods) - decoder config can now passed as arguments
Diffstat (limited to 'python/src/hypergraph.pxi')
-rw-r--r--python/src/hypergraph.pxi161
1 files changed, 130 insertions, 31 deletions
diff --git a/python/src/hypergraph.pxi b/python/src/hypergraph.pxi
index 9d09722e..2e2c04a2 100644
--- a/python/src/hypergraph.pxi
+++ b/python/src/hypergraph.pxi
@@ -1,5 +1,5 @@
cimport hypergraph
-cimport kbest as kb
+cimport kbest
cdef class Hypergraph:
cdef hypergraph.Hypergraph* hg
@@ -13,53 +13,54 @@ cdef class Hypergraph:
def viterbi(self):
cdef vector[WordID] trans
hypergraph.ViterbiESentence(self.hg[0], &trans)
- cdef str sentence = GetString(trans).c_str()
- return sentence.decode('utf8')
+ return unicode(GetString(trans).c_str(), 'utf8')
- def viterbi_tree(self):
- cdef str tree = hypergraph.ViterbiETree(self.hg[0]).c_str()
- return tree.decode('utf8')
-
- def viterbi_source_tree(self):
- cdef str tree = hypergraph.ViterbiFTree(self.hg[0]).c_str()
- return tree.decode('utf8')
+ def viterbi_trees(self):
+ f_tree = unicode(hypergraph.ViterbiFTree(self.hg[0]).c_str(), 'utf8')
+ e_tree = unicode(hypergraph.ViterbiETree(self.hg[0]).c_str(), 'utf8')
+ return (f_tree, e_tree)
def viterbi_features(self):
cdef SparseVector fmap = SparseVector()
fmap.vector = new FastSparseVector[weight_t](hypergraph.ViterbiFeatures(self.hg[0]))
return fmap
+ def viterbi_joshua(self):
+ return unicode(hypergraph.JoshuaVisualizationString(self.hg[0]).c_str(), 'utf8')
+
def kbest(self, size):
- cdef kb.KBestDerivations[vector[WordID], kb.ESentenceTraversal]* derivations = new kb.KBestDerivations[vector[WordID], kb.ESentenceTraversal](self.hg[0], size)
- cdef kb.KBestDerivations[vector[WordID], kb.ESentenceTraversal].Derivation* derivation
- cdef bytes sentence
+ cdef kbest.KBestDerivations[vector[WordID], kbest.ESentenceTraversal]* derivations = new kbest.KBestDerivations[vector[WordID], kbest.ESentenceTraversal](self.hg[0], size)
+ cdef kbest.KBestDerivations[vector[WordID], kbest.ESentenceTraversal].Derivation* derivation
cdef unsigned k
try:
for k in range(size):
derivation = derivations.LazyKthBest(self.hg.nodes_.size() - 1, k)
if not derivation: break
- sentence = GetString(derivation._yield).c_str()
- yield sentence.decode('utf8')
+ yield unicode(GetString(derivation._yield).c_str(), 'utf8')
finally:
del derivations
- def kbest_tree(self, size):
- cdef kb.KBestDerivations[vector[WordID], kb.ETreeTraversal]* derivations = new kb.KBestDerivations[vector[WordID], kb.ETreeTraversal](self.hg[0], size)
- cdef kb.KBestDerivations[vector[WordID], kb.ETreeTraversal].Derivation* derivation
- cdef str tree
+ def kbest_trees(self, size):
+ cdef kbest.KBestDerivations[vector[WordID], kbest.FTreeTraversal]* f_derivations = new kbest.KBestDerivations[vector[WordID], kbest.FTreeTraversal](self.hg[0], size)
+ cdef kbest.KBestDerivations[vector[WordID], kbest.FTreeTraversal].Derivation* f_derivation
+ cdef kbest.KBestDerivations[vector[WordID], kbest.ETreeTraversal]* e_derivations = new kbest.KBestDerivations[vector[WordID], kbest.ETreeTraversal](self.hg[0], size)
+ cdef kbest.KBestDerivations[vector[WordID], kbest.ETreeTraversal].Derivation* e_derivation
cdef unsigned k
try:
for k in range(size):
- derivation = derivations.LazyKthBest(self.hg.nodes_.size() - 1, k)
- if not derivation: break
- tree = GetString(derivation._yield).c_str()
- yield tree.decode('utf8')
+ f_derivation = f_derivations.LazyKthBest(self.hg.nodes_.size() - 1, k)
+ e_derivation = e_derivations.LazyKthBest(self.hg.nodes_.size() - 1, k)
+ if not f_derivation or not e_derivation: break
+ f_tree = unicode(GetString(f_derivation._yield).c_str(), 'utf8')
+ e_tree = unicode(GetString(e_derivation._yield).c_str(), 'utf8')
+ yield (f_tree, e_tree)
finally:
- del derivations
+ del f_derivations
+ del e_derivations
def kbest_features(self, size):
- cdef kb.KBestDerivations[FastSparseVector[weight_t], kb.FeatureVectorTraversal]* derivations = new kb.KBestDerivations[FastSparseVector[weight_t], kb.FeatureVectorTraversal](self.hg[0], size)
- cdef kb.KBestDerivations[FastSparseVector[weight_t], kb.FeatureVectorTraversal].Derivation* derivation
+ cdef kbest.KBestDerivations[FastSparseVector[weight_t], kbest.FeatureVectorTraversal]* derivations = new kbest.KBestDerivations[FastSparseVector[weight_t], kbest.FeatureVectorTraversal](self.hg[0], size)
+ cdef kbest.KBestDerivations[FastSparseVector[weight_t], kbest.FeatureVectorTraversal].Derivation* derivation
cdef SparseVector fmap
cdef unsigned k
try:
@@ -77,17 +78,13 @@ cdef class Hypergraph:
if self.rng == NULL:
self.rng = new MT19937()
hypergraph.sample_hypotheses(self.hg[0], n, self.rng, hypos)
- cdef str sentence
cdef unsigned k
try:
for k in range(hypos.size()):
- sentence = GetString(hypos[0][k].words).c_str()
- yield sentence.decode('utf8')
+ yield unicode(GetString(hypos[0][k].words).c_str(), 'utf8')
finally:
del hypos
- # TODO richer k-best/sample output (feature vectors, trees?)
-
def intersect(self, Lattice lat):
return hypergraph.Intersect(lat.lattice[0], self.hg)
@@ -113,3 +110,105 @@ cdef class Hypergraph:
raise TypeError('cannot reweight hypergraph with %s' % type(weights))
# TODO get feature expectations, get partition function ("inside" score)
+
+ property edges:
+ def __get__(self):
+ cdef unsigned i
+ for i in range(self.hg.edges_.size()):
+ yield HypergraphEdge().init(self.hg, i)
+
+ property nodes:
+ def __get__(self):
+ cdef unsigned i
+ for i in range(self.hg.nodes_.size()):
+ yield HypergraphNode().init(self.hg, i)
+
+ property goal:
+ def __get__(self):
+ return HypergraphNode().init(self.hg, self.hg.GoalNode())
+
+
+include "trule.pxi"
+
+cdef class HypergraphEdge:
+ cdef hypergraph.Hypergraph* hg
+ cdef hypergraph.HypergraphEdge* edge
+ cdef public TRule trule
+
+ cdef init(self, hypergraph.Hypergraph* hg, unsigned i):
+ self.hg = hg
+ self.edge = &hg.edges_[i]
+ self.trule = TRule()
+ self.trule.rule = self.edge.rule_.get()
+ return self
+
+ def __len__(self):
+ return self.edge.tail_nodes_.size()
+
+ property head_node:
+ def __get__(self):
+ return HypergraphNode().init(self.hg, self.edge.head_node_)
+
+ property tail_nodes:
+ def __get__(self):
+ cdef unsigned i
+ for i in range(self.edge.tail_nodes_.size()):
+ yield HypergraphNode().init(self.hg, self.edge.tail_nodes_[i])
+
+ property span:
+ def __get__(self):
+ return (self.edge.i_, self.edge.j_)
+
+ property feature_values:
+ def __get__(self):
+ cdef SparseVector vector = SparseVector()
+ vector.vector = new FastSparseVector[double](self.edge.feature_values_)
+ return vector
+
+ property prob:
+ def __get__(self):
+ return self.edge.edge_prob_.as_float()
+
+ def __richcmp__(HypergraphEdge x, HypergraphEdge y, int op):
+ if op == 2: # ==
+ return x.edge == y.edge
+ elif op == 3: # !=
+ return not (x == y)
+ raise NotImplemented('comparison not implemented for HypergraphEdge')
+
+cdef class HypergraphNode:
+ cdef hypergraph.Hypergraph* hg
+ cdef hypergraph.HypergraphNode* node
+
+ cdef init(self, hypergraph.Hypergraph* hg, unsigned i):
+ self.hg = hg
+ self.node = &hg.nodes_[i]
+ return self
+
+ property in_edges:
+ def __get__(self):
+ cdef unsigned i
+ for i in range(self.node.in_edges_.size()):
+ yield HypergraphEdge().init(self.hg, self.node.in_edges_[i])
+
+ property out_edges:
+ def __get__(self):
+ cdef unsigned i
+ for i in range(self.node.out_edges_.size()):
+ yield HypergraphEdge().init(self.hg, self.node.out_edges_[i])
+
+ property span:
+ def __get__(self):
+ return next(self.in_edges).span
+
+ property cat:
+ def __get__(self):
+ if self.node.cat_:
+ return TDConvert(-self.node.cat_)
+
+ def __richcmp__(HypergraphNode x, HypergraphNode y, int op):
+ if op == 2: # ==
+ return x.node == y.node
+ elif op == 3: # !=
+ return not (x == y)
+ raise NotImplemented('comparison not implemented for HypergraphNode')