summaryrefslogtreecommitdiff
path: root/python/src/_cdec.pyx
diff options
context:
space:
mode:
Diffstat (limited to 'python/src/_cdec.pyx')
-rw-r--r--python/src/_cdec.pyx180
1 files changed, 180 insertions, 0 deletions
diff --git a/python/src/_cdec.pyx b/python/src/_cdec.pyx
new file mode 100644
index 00000000..45320c46
--- /dev/null
+++ b/python/src/_cdec.pyx
@@ -0,0 +1,180 @@
+from libcpp.string cimport string
+from libcpp.vector cimport vector
+from cython.operator cimport dereference as deref
+from utils cimport *
+cimport hypergraph
+cimport decoder
+cimport lattice
+cimport kbest as kb
+
+SetSilent(True)
+
+class ParseFailed(Exception):
+ pass
+
+cdef class Weights:
+ cdef vector[weight_t]* weights
+
+ def __cinit__(self, Decoder decoder):
+ self.weights = &decoder.dec.CurrentWeightVector()
+
+ def __getitem__(self, char* fname):
+ cdef unsigned fid = FDConvert(fname)
+ if fid <= self.weights.size():
+ return self.weights[0][fid]
+ raise KeyError(fname)
+
+ def __setitem__(self, char* fname, float value):
+ cdef unsigned fid = FDConvert(<char *>fname)
+ if self.weights.size() <= fid:
+ self.weights.resize(fid + 1)
+ self.weights[0][fid] = value
+
+ def __iter__(self):
+ cdef unsigned fid
+ for fid in range(1, self.weights.size()):
+ yield FDConvert(fid).c_str(), self.weights[0][fid]
+
+cdef class Decoder:
+ cdef decoder.Decoder* dec
+ cdef public Weights weights
+
+ def __cinit__(self, char* config):
+ decoder.register_feature_functions()
+ cdef istringstream* config_stream = new istringstream(config) # ConfigStream(kwargs)
+ #cdef ReadFile* config_file = new ReadFile(string(config))
+ #cdef istream* config_stream = config_file.stream()
+ self.dec = new decoder.Decoder(config_stream)
+ del config_stream
+ #del config_file
+ self.weights = Weights(self)
+
+ def __dealloc__(self):
+ del self.dec
+
+ @classmethod
+ def fromconfig(cls, ini):
+ cdef dict config = {}
+ with open(ini) as fp:
+ for line in fp:
+ line = line.strip()
+ if not line or line.startswith('#'): continue
+ param, value = line.split('=')
+ config[param.strip()] = value.strip()
+ return cls(**config)
+
+ def read_weights(self, cfg):
+ with open(cfg) as fp:
+ for line in fp:
+ fname, value = line.split()
+ self.weights[fname.strip()] = float(value)
+
+ # TODO: list, lattice translation
+ def translate(self, unicode sentence, grammar=None):
+ if grammar:
+ self.dec.SetSentenceGrammarFromString(string(<char *> grammar))
+ #sgml = '<seg grammar="%s">%s</seg>' % (grammar, sentence.encode('utf8'))
+ sgml = sentence.strip().encode('utf8')
+ cdef decoder.BasicObserver observer = decoder.BasicObserver()
+ self.dec.Decode(string(<char *>sgml), &observer)
+ if observer.hypergraph == NULL:
+ raise ParseFailed()
+ cdef Hypergraph hg = Hypergraph()
+ hg.hg = new hypergraph.Hypergraph(observer.hypergraph[0])
+ return hg
+
+cdef class Hypergraph:
+ cdef hypergraph.Hypergraph* hg
+ cdef MT19937* rng
+
+ def __dealloc__(self):
+ del self.hg
+ if self.rng != NULL:
+ del self.rng
+
+ def viterbi(self):
+ assert (self.hg != NULL)
+ cdef vector[WordID] trans
+ hypergraph.ViterbiESentence(self.hg[0], &trans)
+ cdef str sentence = GetString(trans).c_str()
+ return sentence.decode('utf8')
+
+ def viterbi_tree(self):
+ assert (self.hg != NULL)
+ cdef str tree = hypergraph.ViterbiETree(self.hg[0]).c_str()
+ return tree.decode('utf8')
+
+ def kbest(self, size):
+ assert (self.hg != NULL)
+ cdef kb.KBestDerivations[vector[WordID], kb.ESentenceTraversal]* derivations = new kb.KBestDerivations[vector[WordID], kb.ESentenceTraversal](self.hg[0], size)
+ cdef kb.KBestDerivations[vector[WordID], kb.ESentenceTraversal].Derivation* derivation
+ cdef str tree
+ cdef unsigned k
+ for k in range(size):
+ derivation = derivations.LazyKthBest(self.hg.nodes_.size() - 1, k)
+ if not derivation: break
+ tree = GetString(derivation._yield).c_str()
+ yield tree.decode('utf8')
+ del derivations
+
+ def kbest_tree(self, size):
+ assert (self.hg != NULL)
+ cdef kb.KBestDerivations[vector[WordID], kb.ETreeTraversal]* derivations = new kb.KBestDerivations[vector[WordID], kb.ETreeTraversal](self.hg[0], size)
+ cdef kb.KBestDerivations[vector[WordID], kb.ETreeTraversal].Derivation* derivation
+ cdef str sentence
+ cdef unsigned k
+ for k in range(size):
+ derivation = derivations.LazyKthBest(self.hg.nodes_.size() - 1, k)
+ if not derivation: break
+ sentence = GetString(derivation._yield).c_str()
+ yield sentence.decode('utf8')
+ del derivations
+
+ def intersect(self, Lattice lat):
+ assert (self.hg != NULL)
+ hypergraph.Intersect(lat.lattice[0], self.hg)
+
+ def sample(self, unsigned n):
+ assert (self.hg != NULL)
+ cdef vector[hypergraph.Hypothesis]* hypos = new vector[hypergraph.Hypothesis]()
+ if self.rng == NULL:
+ self.rng = new MT19937()
+ hypergraph.sample_hypotheses(self.hg[0], n, self.rng, hypos)
+ cdef str sentence
+ cdef unsigned k
+ for k in range(hypos.size()):
+ sentence = GetString(hypos[0][k].words).c_str()
+ yield sentence.decode('utf8')
+ del hypos
+
+ # TODO: get feature expectations, get partition function ("inside" score)
+ # TODO: reweight the forest with different weights (Hypergraph::Reweight)
+ # TODO: inside-outside pruning
+
+cdef class Lattice:
+ cdef lattice.Lattice* lattice
+
+ def __init__(self, tuple plf_tuple):
+ self.lattice = new lattice.Lattice()
+ cdef bytes plf = str(plf_tuple)
+ hypergraph.PLFtoLattice(string(<char *>plf), self.lattice)
+
+ def __str__(self):
+ return hypergraph.AsPLF(self.lattice[0]).c_str()
+
+ def __iter__(self):
+ return iter(eval(str(self)))
+
+ def __dealloc__(self):
+ del self.lattice
+
+# TODO: wrap SparseVector
+
+"""
+def params_str(params):
+ return '\n'.join('%s=%s' % (param, value) for param, value in params.iteritems())
+
+cdef istringstream* ConfigStream(dict params):
+ ini = params_str(params)
+ return new istringstream(<char *> ini)
+"""