diff options
Diffstat (limited to 'python/src/_cdec.pyx')
-rw-r--r-- | python/src/_cdec.pyx | 87 |
1 files changed, 84 insertions, 3 deletions
diff --git a/python/src/_cdec.pyx b/python/src/_cdec.pyx index b99f087d..45320c46 100644 --- a/python/src/_cdec.pyx +++ b/python/src/_cdec.pyx @@ -1,8 +1,11 @@ from libcpp.string cimport string from libcpp.vector cimport vector +from cython.operator cimport dereference as deref from utils cimport * cimport hypergraph cimport decoder +cimport lattice +cimport kbest as kb SetSilent(True) @@ -16,13 +19,13 @@ cdef class Weights: self.weights = &decoder.dec.CurrentWeightVector() def __getitem__(self, char* fname): - cdef unsigned fid = Convert(fname) + cdef unsigned fid = FDConvert(fname) if fid <= self.weights.size(): return self.weights[0][fid] raise KeyError(fname) def __setitem__(self, char* fname, float value): - cdef unsigned fid = Convert(<char *>fname) + cdef unsigned fid = FDConvert(<char *>fname) if self.weights.size() <= fid: self.weights.resize(fid + 1) self.weights[0][fid] = value @@ -30,7 +33,7 @@ cdef class Weights: def __iter__(self): cdef unsigned fid for fid in range(1, self.weights.size()): - yield Convert(fid).c_str(), self.weights[0][fid] + yield FDConvert(fid).c_str(), self.weights[0][fid] cdef class Decoder: cdef decoder.Decoder* dec @@ -66,6 +69,7 @@ cdef class Decoder: fname, value = line.split() self.weights[fname.strip()] = float(value) + # TODO: list, lattice translation def translate(self, unicode sentence, grammar=None): if grammar: self.dec.SetSentenceGrammarFromString(string(<char *> grammar)) @@ -81,6 +85,12 @@ cdef class Decoder: cdef class Hypergraph: cdef hypergraph.Hypergraph* hg + cdef MT19937* rng + + def __dealloc__(self): + del self.hg + if self.rng != NULL: + del self.rng def viterbi(self): assert (self.hg != NULL) @@ -89,6 +99,77 @@ cdef class Hypergraph: cdef str sentence = GetString(trans).c_str() return sentence.decode('utf8') + def viterbi_tree(self): + assert (self.hg != NULL) + cdef str tree = hypergraph.ViterbiETree(self.hg[0]).c_str() + return tree.decode('utf8') + + def kbest(self, size): + assert (self.hg != NULL) + cdef kb.KBestDerivations[vector[WordID], kb.ESentenceTraversal]* derivations = new kb.KBestDerivations[vector[WordID], kb.ESentenceTraversal](self.hg[0], size) + cdef kb.KBestDerivations[vector[WordID], kb.ESentenceTraversal].Derivation* derivation + cdef str tree + cdef unsigned k + for k in range(size): + derivation = derivations.LazyKthBest(self.hg.nodes_.size() - 1, k) + if not derivation: break + tree = GetString(derivation._yield).c_str() + yield tree.decode('utf8') + del derivations + + def kbest_tree(self, size): + assert (self.hg != NULL) + cdef kb.KBestDerivations[vector[WordID], kb.ETreeTraversal]* derivations = new kb.KBestDerivations[vector[WordID], kb.ETreeTraversal](self.hg[0], size) + cdef kb.KBestDerivations[vector[WordID], kb.ETreeTraversal].Derivation* derivation + cdef str sentence + cdef unsigned k + for k in range(size): + derivation = derivations.LazyKthBest(self.hg.nodes_.size() - 1, k) + if not derivation: break + sentence = GetString(derivation._yield).c_str() + yield sentence.decode('utf8') + del derivations + + def intersect(self, Lattice lat): + assert (self.hg != NULL) + hypergraph.Intersect(lat.lattice[0], self.hg) + + def sample(self, unsigned n): + assert (self.hg != NULL) + cdef vector[hypergraph.Hypothesis]* hypos = new vector[hypergraph.Hypothesis]() + if self.rng == NULL: + self.rng = new MT19937() + hypergraph.sample_hypotheses(self.hg[0], n, self.rng, hypos) + cdef str sentence + cdef unsigned k + for k in range(hypos.size()): + sentence = GetString(hypos[0][k].words).c_str() + yield sentence.decode('utf8') + del hypos + + # TODO: get feature expectations, get partition function ("inside" score) + # TODO: reweight the forest with different weights (Hypergraph::Reweight) + # TODO: inside-outside pruning + +cdef class Lattice: + cdef lattice.Lattice* lattice + + def __init__(self, tuple plf_tuple): + self.lattice = new lattice.Lattice() + cdef bytes plf = str(plf_tuple) + hypergraph.PLFtoLattice(string(<char *>plf), self.lattice) + + def __str__(self): + return hypergraph.AsPLF(self.lattice[0]).c_str() + + def __iter__(self): + return iter(eval(str(self))) + + def __dealloc__(self): + del self.lattice + +# TODO: wrap SparseVector + """ def params_str(params): return '\n'.join('%s=%s' % (param, value) for param, value in params.iteritems()) |