summaryrefslogtreecommitdiff
path: root/python/src/_cdec.pyx
diff options
context:
space:
mode:
Diffstat (limited to 'python/src/_cdec.pyx')
-rw-r--r--python/src/_cdec.pyx87
1 files changed, 84 insertions, 3 deletions
diff --git a/python/src/_cdec.pyx b/python/src/_cdec.pyx
index b99f087d..45320c46 100644
--- a/python/src/_cdec.pyx
+++ b/python/src/_cdec.pyx
@@ -1,8 +1,11 @@
from libcpp.string cimport string
from libcpp.vector cimport vector
+from cython.operator cimport dereference as deref
from utils cimport *
cimport hypergraph
cimport decoder
+cimport lattice
+cimport kbest as kb
SetSilent(True)
@@ -16,13 +19,13 @@ cdef class Weights:
self.weights = &decoder.dec.CurrentWeightVector()
def __getitem__(self, char* fname):
- cdef unsigned fid = Convert(fname)
+ cdef unsigned fid = FDConvert(fname)
if fid <= self.weights.size():
return self.weights[0][fid]
raise KeyError(fname)
def __setitem__(self, char* fname, float value):
- cdef unsigned fid = Convert(<char *>fname)
+ cdef unsigned fid = FDConvert(<char *>fname)
if self.weights.size() <= fid:
self.weights.resize(fid + 1)
self.weights[0][fid] = value
@@ -30,7 +33,7 @@ cdef class Weights:
def __iter__(self):
cdef unsigned fid
for fid in range(1, self.weights.size()):
- yield Convert(fid).c_str(), self.weights[0][fid]
+ yield FDConvert(fid).c_str(), self.weights[0][fid]
cdef class Decoder:
cdef decoder.Decoder* dec
@@ -66,6 +69,7 @@ cdef class Decoder:
fname, value = line.split()
self.weights[fname.strip()] = float(value)
+ # TODO: list, lattice translation
def translate(self, unicode sentence, grammar=None):
if grammar:
self.dec.SetSentenceGrammarFromString(string(<char *> grammar))
@@ -81,6 +85,12 @@ cdef class Decoder:
cdef class Hypergraph:
cdef hypergraph.Hypergraph* hg
+ cdef MT19937* rng
+
+ def __dealloc__(self):
+ del self.hg
+ if self.rng != NULL:
+ del self.rng
def viterbi(self):
assert (self.hg != NULL)
@@ -89,6 +99,77 @@ cdef class Hypergraph:
cdef str sentence = GetString(trans).c_str()
return sentence.decode('utf8')
+ def viterbi_tree(self):
+ assert (self.hg != NULL)
+ cdef str tree = hypergraph.ViterbiETree(self.hg[0]).c_str()
+ return tree.decode('utf8')
+
+ def kbest(self, size):
+ assert (self.hg != NULL)
+ cdef kb.KBestDerivations[vector[WordID], kb.ESentenceTraversal]* derivations = new kb.KBestDerivations[vector[WordID], kb.ESentenceTraversal](self.hg[0], size)
+ cdef kb.KBestDerivations[vector[WordID], kb.ESentenceTraversal].Derivation* derivation
+ cdef str tree
+ cdef unsigned k
+ for k in range(size):
+ derivation = derivations.LazyKthBest(self.hg.nodes_.size() - 1, k)
+ if not derivation: break
+ tree = GetString(derivation._yield).c_str()
+ yield tree.decode('utf8')
+ del derivations
+
+ def kbest_tree(self, size):
+ assert (self.hg != NULL)
+ cdef kb.KBestDerivations[vector[WordID], kb.ETreeTraversal]* derivations = new kb.KBestDerivations[vector[WordID], kb.ETreeTraversal](self.hg[0], size)
+ cdef kb.KBestDerivations[vector[WordID], kb.ETreeTraversal].Derivation* derivation
+ cdef str sentence
+ cdef unsigned k
+ for k in range(size):
+ derivation = derivations.LazyKthBest(self.hg.nodes_.size() - 1, k)
+ if not derivation: break
+ sentence = GetString(derivation._yield).c_str()
+ yield sentence.decode('utf8')
+ del derivations
+
+ def intersect(self, Lattice lat):
+ assert (self.hg != NULL)
+ hypergraph.Intersect(lat.lattice[0], self.hg)
+
+ def sample(self, unsigned n):
+ assert (self.hg != NULL)
+ cdef vector[hypergraph.Hypothesis]* hypos = new vector[hypergraph.Hypothesis]()
+ if self.rng == NULL:
+ self.rng = new MT19937()
+ hypergraph.sample_hypotheses(self.hg[0], n, self.rng, hypos)
+ cdef str sentence
+ cdef unsigned k
+ for k in range(hypos.size()):
+ sentence = GetString(hypos[0][k].words).c_str()
+ yield sentence.decode('utf8')
+ del hypos
+
+ # TODO: get feature expectations, get partition function ("inside" score)
+ # TODO: reweight the forest with different weights (Hypergraph::Reweight)
+ # TODO: inside-outside pruning
+
+cdef class Lattice:
+ cdef lattice.Lattice* lattice
+
+ def __init__(self, tuple plf_tuple):
+ self.lattice = new lattice.Lattice()
+ cdef bytes plf = str(plf_tuple)
+ hypergraph.PLFtoLattice(string(<char *>plf), self.lattice)
+
+ def __str__(self):
+ return hypergraph.AsPLF(self.lattice[0]).c_str()
+
+ def __iter__(self):
+ return iter(eval(str(self)))
+
+ def __dealloc__(self):
+ del self.lattice
+
+# TODO: wrap SparseVector
+
"""
def params_str(params):
return '\n'.join('%s=%s' % (param, value) for param, value in params.iteritems())