summaryrefslogtreecommitdiff
path: root/python/src/_cdec.pyx
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2012-08-03 07:46:54 -0400
committerKenneth Heafield <github@kheafield.com>2012-08-03 07:46:54 -0400
commit122f46c31102b683eaab3ad81a3a98accbc694bb (patch)
tree8d499d789b159ebed25bb23b6983813d064a6296 /python/src/_cdec.pyx
parentac664bdb0e481539cf77098a7dd0e1ec8d937ba0 (diff)
parent193d137056c3c4f73d66f8db84691d63307de894 (diff)
Merge branch 'master' of github.com:redpony/cdec
Diffstat (limited to 'python/src/_cdec.pyx')
-rw-r--r--python/src/_cdec.pyx217
1 files changed, 87 insertions, 130 deletions
diff --git a/python/src/_cdec.pyx b/python/src/_cdec.pyx
index 664724dd..e93474fe 100644
--- a/python/src/_cdec.pyx
+++ b/python/src/_cdec.pyx
@@ -1,156 +1,113 @@
from libcpp.string cimport string
from libcpp.vector cimport vector
-from cython.operator cimport dereference as deref
from utils cimport *
-cimport hypergraph
cimport decoder
-cimport lattice
-cimport kbest as kb
-SetSilent(True)
-
-class ParseFailed(Exception):
- pass
-
-cdef class Weights:
- cdef vector[weight_t]* weights
+cdef char* as_str(data, char* error_msg='Cannot convert type %s to str'):
+ cdef bytes ret
+ if isinstance(data, unicode):
+ ret = data.encode('utf8')
+ elif isinstance(data, str):
+ ret = data
+ else:
+ raise TypeError(error_msg.format(type(data)))
+ return ret
+
+include "vectors.pxi"
+include "grammar.pxi"
+include "hypergraph.pxi"
+include "lattice.pxi"
+include "mteval.pxi"
- def __cinit__(self, Decoder decoder):
- self.weights = &decoder.dec.CurrentWeightVector()
-
- def __getitem__(self, char* fname):
- cdef unsigned fid = FDConvert(fname)
- if fid <= self.weights.size():
- return self.weights[0][fid]
- raise KeyError(fname)
-
- def __setitem__(self, char* fname, float value):
- cdef unsigned fid = FDConvert(<char *>fname)
- if self.weights.size() <= fid:
- self.weights.resize(fid + 1)
- self.weights[0][fid] = value
-
- def __iter__(self):
- cdef unsigned fid
- for fid in range(1, self.weights.size()):
- yield FDConvert(fid).c_str(), self.weights[0][fid]
+SetSilent(True)
+decoder.register_feature_functions()
+
+class InvalidConfig(Exception): pass
+class ParseFailed(Exception): pass
+
+def _make_config(config):
+ for key, value in config.items():
+ if isinstance(value, dict):
+ for name, info in value.items():
+ yield key, '%s %s' % (name, info)
+ elif isinstance(value, list):
+ for name in value:
+ yield key, name
+ else:
+ yield key, bytes(value)
cdef class Decoder:
cdef decoder.Decoder* dec
- cdef public Weights weights
-
- def __cinit__(self, char* config):
- decoder.register_feature_functions()
- cdef istringstream* config_stream = new istringstream(config)
+ cdef DenseVector weights
+
+ def __cinit__(self, config_str=None, **config):
+ """ Configuration can be given as a string:
+ Decoder('formalism = scfg')
+ or using keyword arguments:
+ Decoder(formalism='scfg')
+ """
+ if config_str is None:
+ formalism = config.get('formalism', None)
+ if formalism not in ('scfg', 'fst', 'lextrans', 'pb',
+ 'csplit', 'tagger', 'lexalign'):
+ raise InvalidConfig('formalism "%s" unknown' % formalism)
+ config_str = '\n'.join('%s = %s' % kv for kv in _make_config(config))
+ cdef istringstream* config_stream = new istringstream(config_str)
self.dec = new decoder.Decoder(config_stream)
del config_stream
- self.weights = Weights(self)
+ self.weights = DenseVector.__new__(DenseVector)
+ self.weights.vector = &self.dec.CurrentWeightVector()
+ self.weights.owned = True
def __dealloc__(self):
del self.dec
- def read_weights(self, cfg):
- with open(cfg) as fp:
+ property weights:
+ def __get__(self):
+ return self.weights
+
+ def __set__(self, weights):
+ if isinstance(weights, DenseVector):
+ self.weights.vector[0] = (<DenseVector> weights).vector[0]
+ elif isinstance(weights, SparseVector):
+ self.weights.vector.clear()
+ ((<SparseVector> weights).vector[0]).init_vector(self.weights.vector)
+ elif isinstance(weights, dict):
+ self.weights.vector.clear()
+ for fname, fval in weights.items():
+ self.weights[fname] = fval
+ else:
+ raise TypeError('cannot initialize weights with %s' % type(weights))
+
+ property formalism:
+ def __get__(self):
+ cdef variables_map* conf = &self.dec.GetConf()
+ return conf[0]['formalism'].as_str()
+
+ def read_weights(self, weights):
+ with open(weights) as fp:
for line in fp:
+ if line.strip().startswith('#'): continue
fname, value = line.split()
self.weights[fname.strip()] = float(value)
- # TODO: list, lattice translation
- def translate(self, unicode sentence, grammar=None):
+ def translate(self, sentence, grammar=None):
+ cdef bytes input_str
+ if isinstance(sentence, unicode) or isinstance(sentence, str):
+ input_str = as_str(sentence.strip())
+ elif isinstance(sentence, Lattice):
+ input_str = str(sentence) # PLF format
+ else:
+ raise TypeError('Cannot translate input type %s' % type(sentence))
if grammar:
- self.dec.SetSentenceGrammarFromString(string(<char *> grammar))
- inp = sentence.strip().encode('utf8')
+ if isinstance(grammar, str) or isinstance(grammar, unicode):
+ self.dec.AddSupplementalGrammarFromString(string(as_str(grammar)))
+ else:
+ self.dec.AddSupplementalGrammar(TextGrammar(grammar).grammar[0])
cdef decoder.BasicObserver observer = decoder.BasicObserver()
- self.dec.Decode(string(<char *>inp), &observer)
+ self.dec.Decode(string(input_str), &observer)
if observer.hypergraph == NULL:
raise ParseFailed()
cdef Hypergraph hg = Hypergraph()
hg.hg = new hypergraph.Hypergraph(observer.hypergraph[0])
return hg
-
-cdef class Hypergraph:
- cdef hypergraph.Hypergraph* hg
- cdef MT19937* rng
-
- def __dealloc__(self):
- del self.hg
- if self.rng != NULL:
- del self.rng
-
- def viterbi(self):
- assert (self.hg != NULL)
- cdef vector[WordID] trans
- hypergraph.ViterbiESentence(self.hg[0], &trans)
- cdef str sentence = GetString(trans).c_str()
- return sentence.decode('utf8')
-
- def viterbi_tree(self):
- assert (self.hg != NULL)
- cdef str tree = hypergraph.ViterbiETree(self.hg[0]).c_str()
- return tree.decode('utf8')
-
- def kbest(self, size):
- assert (self.hg != NULL)
- cdef kb.KBestDerivations[vector[WordID], kb.ESentenceTraversal]* derivations = new kb.KBestDerivations[vector[WordID], kb.ESentenceTraversal](self.hg[0], size)
- cdef kb.KBestDerivations[vector[WordID], kb.ESentenceTraversal].Derivation* derivation
- cdef str tree
- cdef unsigned k
- for k in range(size):
- derivation = derivations.LazyKthBest(self.hg.nodes_.size() - 1, k)
- if not derivation: break
- tree = GetString(derivation._yield).c_str()
- yield tree.decode('utf8')
- del derivations
-
- def kbest_tree(self, size):
- assert (self.hg != NULL)
- cdef kb.KBestDerivations[vector[WordID], kb.ETreeTraversal]* derivations = new kb.KBestDerivations[vector[WordID], kb.ETreeTraversal](self.hg[0], size)
- cdef kb.KBestDerivations[vector[WordID], kb.ETreeTraversal].Derivation* derivation
- cdef str sentence
- cdef unsigned k
- for k in range(size):
- derivation = derivations.LazyKthBest(self.hg.nodes_.size() - 1, k)
- if not derivation: break
- sentence = GetString(derivation._yield).c_str()
- yield sentence.decode('utf8')
- del derivations
-
- def intersect(self, Lattice lat):
- assert (self.hg != NULL)
- hypergraph.Intersect(lat.lattice[0], self.hg)
-
- def sample(self, unsigned n):
- assert (self.hg != NULL)
- cdef vector[hypergraph.Hypothesis]* hypos = new vector[hypergraph.Hypothesis]()
- if self.rng == NULL:
- self.rng = new MT19937()
- hypergraph.sample_hypotheses(self.hg[0], n, self.rng, hypos)
- cdef str sentence
- cdef unsigned k
- for k in range(hypos.size()):
- sentence = GetString(hypos[0][k].words).c_str()
- yield sentence.decode('utf8')
- del hypos
-
- # TODO: get feature expectations, get partition function ("inside" score)
- # TODO: reweight the forest with different weights (Hypergraph::Reweight)
- # TODO: inside-outside pruning
-
-cdef class Lattice:
- cdef lattice.Lattice* lattice
-
- def __init__(self, tuple plf_tuple):
- self.lattice = new lattice.Lattice()
- cdef bytes plf = str(plf_tuple)
- hypergraph.PLFtoLattice(string(<char *>plf), self.lattice)
-
- def __str__(self):
- return hypergraph.AsPLF(self.lattice[0]).c_str()
-
- def __iter__(self):
- return iter(eval(str(self)))
-
- def __dealloc__(self):
- del self.lattice
-
-# TODO: wrap SparseVector