diff options
author | Kenneth Heafield <github@kheafield.com> | 2012-08-03 07:46:54 -0400 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2012-08-03 07:46:54 -0400 |
commit | be1ab0a8937f9c5668ea5e6c31b798e87672e55e (patch) | |
tree | a13aad60ab6cced213401bce6a38ac885ba171ba /python/src/_cdec.pyx | |
parent | e5d6f4ae41009c26978ecd62668501af9762b0bc (diff) | |
parent | 9fe0219562e5db25171cce8776381600ff9a5649 (diff) |
Merge branch 'master' of github.com:redpony/cdec
Diffstat (limited to 'python/src/_cdec.pyx')
-rw-r--r-- | python/src/_cdec.pyx | 217 |
1 files changed, 87 insertions, 130 deletions
diff --git a/python/src/_cdec.pyx b/python/src/_cdec.pyx index 664724dd..e93474fe 100644 --- a/python/src/_cdec.pyx +++ b/python/src/_cdec.pyx @@ -1,156 +1,113 @@ from libcpp.string cimport string from libcpp.vector cimport vector -from cython.operator cimport dereference as deref from utils cimport * -cimport hypergraph cimport decoder -cimport lattice -cimport kbest as kb -SetSilent(True) - -class ParseFailed(Exception): - pass - -cdef class Weights: - cdef vector[weight_t]* weights +cdef char* as_str(data, char* error_msg='Cannot convert type %s to str'): + cdef bytes ret + if isinstance(data, unicode): + ret = data.encode('utf8') + elif isinstance(data, str): + ret = data + else: + raise TypeError(error_msg.format(type(data))) + return ret + +include "vectors.pxi" +include "grammar.pxi" +include "hypergraph.pxi" +include "lattice.pxi" +include "mteval.pxi" - def __cinit__(self, Decoder decoder): - self.weights = &decoder.dec.CurrentWeightVector() - - def __getitem__(self, char* fname): - cdef unsigned fid = FDConvert(fname) - if fid <= self.weights.size(): - return self.weights[0][fid] - raise KeyError(fname) - - def __setitem__(self, char* fname, float value): - cdef unsigned fid = FDConvert(<char *>fname) - if self.weights.size() <= fid: - self.weights.resize(fid + 1) - self.weights[0][fid] = value - - def __iter__(self): - cdef unsigned fid - for fid in range(1, self.weights.size()): - yield FDConvert(fid).c_str(), self.weights[0][fid] +SetSilent(True) +decoder.register_feature_functions() + +class InvalidConfig(Exception): pass +class ParseFailed(Exception): pass + +def _make_config(config): + for key, value in config.items(): + if isinstance(value, dict): + for name, info in value.items(): + yield key, '%s %s' % (name, info) + elif isinstance(value, list): + for name in value: + yield key, name + else: + yield key, bytes(value) cdef class Decoder: cdef decoder.Decoder* dec - cdef public Weights weights - - def __cinit__(self, char* config): - decoder.register_feature_functions() - cdef istringstream* config_stream = new istringstream(config) + cdef DenseVector weights + + def __cinit__(self, config_str=None, **config): + """ Configuration can be given as a string: + Decoder('formalism = scfg') + or using keyword arguments: + Decoder(formalism='scfg') + """ + if config_str is None: + formalism = config.get('formalism', None) + if formalism not in ('scfg', 'fst', 'lextrans', 'pb', + 'csplit', 'tagger', 'lexalign'): + raise InvalidConfig('formalism "%s" unknown' % formalism) + config_str = '\n'.join('%s = %s' % kv for kv in _make_config(config)) + cdef istringstream* config_stream = new istringstream(config_str) self.dec = new decoder.Decoder(config_stream) del config_stream - self.weights = Weights(self) + self.weights = DenseVector.__new__(DenseVector) + self.weights.vector = &self.dec.CurrentWeightVector() + self.weights.owned = True def __dealloc__(self): del self.dec - def read_weights(self, cfg): - with open(cfg) as fp: + property weights: + def __get__(self): + return self.weights + + def __set__(self, weights): + if isinstance(weights, DenseVector): + self.weights.vector[0] = (<DenseVector> weights).vector[0] + elif isinstance(weights, SparseVector): + self.weights.vector.clear() + ((<SparseVector> weights).vector[0]).init_vector(self.weights.vector) + elif isinstance(weights, dict): + self.weights.vector.clear() + for fname, fval in weights.items(): + self.weights[fname] = fval + else: + raise TypeError('cannot initialize weights with %s' % type(weights)) + + property formalism: + def __get__(self): + cdef variables_map* conf = &self.dec.GetConf() + return conf[0]['formalism'].as_str() + + def read_weights(self, weights): + with open(weights) as fp: for line in fp: + if line.strip().startswith('#'): continue fname, value = line.split() self.weights[fname.strip()] = float(value) - # TODO: list, lattice translation - def translate(self, unicode sentence, grammar=None): + def translate(self, sentence, grammar=None): + cdef bytes input_str + if isinstance(sentence, unicode) or isinstance(sentence, str): + input_str = as_str(sentence.strip()) + elif isinstance(sentence, Lattice): + input_str = str(sentence) # PLF format + else: + raise TypeError('Cannot translate input type %s' % type(sentence)) if grammar: - self.dec.SetSentenceGrammarFromString(string(<char *> grammar)) - inp = sentence.strip().encode('utf8') + if isinstance(grammar, str) or isinstance(grammar, unicode): + self.dec.AddSupplementalGrammarFromString(string(as_str(grammar))) + else: + self.dec.AddSupplementalGrammar(TextGrammar(grammar).grammar[0]) cdef decoder.BasicObserver observer = decoder.BasicObserver() - self.dec.Decode(string(<char *>inp), &observer) + self.dec.Decode(string(input_str), &observer) if observer.hypergraph == NULL: raise ParseFailed() cdef Hypergraph hg = Hypergraph() hg.hg = new hypergraph.Hypergraph(observer.hypergraph[0]) return hg - -cdef class Hypergraph: - cdef hypergraph.Hypergraph* hg - cdef MT19937* rng - - def __dealloc__(self): - del self.hg - if self.rng != NULL: - del self.rng - - def viterbi(self): - assert (self.hg != NULL) - cdef vector[WordID] trans - hypergraph.ViterbiESentence(self.hg[0], &trans) - cdef str sentence = GetString(trans).c_str() - return sentence.decode('utf8') - - def viterbi_tree(self): - assert (self.hg != NULL) - cdef str tree = hypergraph.ViterbiETree(self.hg[0]).c_str() - return tree.decode('utf8') - - def kbest(self, size): - assert (self.hg != NULL) - cdef kb.KBestDerivations[vector[WordID], kb.ESentenceTraversal]* derivations = new kb.KBestDerivations[vector[WordID], kb.ESentenceTraversal](self.hg[0], size) - cdef kb.KBestDerivations[vector[WordID], kb.ESentenceTraversal].Derivation* derivation - cdef str tree - cdef unsigned k - for k in range(size): - derivation = derivations.LazyKthBest(self.hg.nodes_.size() - 1, k) - if not derivation: break - tree = GetString(derivation._yield).c_str() - yield tree.decode('utf8') - del derivations - - def kbest_tree(self, size): - assert (self.hg != NULL) - cdef kb.KBestDerivations[vector[WordID], kb.ETreeTraversal]* derivations = new kb.KBestDerivations[vector[WordID], kb.ETreeTraversal](self.hg[0], size) - cdef kb.KBestDerivations[vector[WordID], kb.ETreeTraversal].Derivation* derivation - cdef str sentence - cdef unsigned k - for k in range(size): - derivation = derivations.LazyKthBest(self.hg.nodes_.size() - 1, k) - if not derivation: break - sentence = GetString(derivation._yield).c_str() - yield sentence.decode('utf8') - del derivations - - def intersect(self, Lattice lat): - assert (self.hg != NULL) - hypergraph.Intersect(lat.lattice[0], self.hg) - - def sample(self, unsigned n): - assert (self.hg != NULL) - cdef vector[hypergraph.Hypothesis]* hypos = new vector[hypergraph.Hypothesis]() - if self.rng == NULL: - self.rng = new MT19937() - hypergraph.sample_hypotheses(self.hg[0], n, self.rng, hypos) - cdef str sentence - cdef unsigned k - for k in range(hypos.size()): - sentence = GetString(hypos[0][k].words).c_str() - yield sentence.decode('utf8') - del hypos - - # TODO: get feature expectations, get partition function ("inside" score) - # TODO: reweight the forest with different weights (Hypergraph::Reweight) - # TODO: inside-outside pruning - -cdef class Lattice: - cdef lattice.Lattice* lattice - - def __init__(self, tuple plf_tuple): - self.lattice = new lattice.Lattice() - cdef bytes plf = str(plf_tuple) - hypergraph.PLFtoLattice(string(<char *>plf), self.lattice) - - def __str__(self): - return hypergraph.AsPLF(self.lattice[0]).c_str() - - def __iter__(self): - return iter(eval(str(self))) - - def __dealloc__(self): - del self.lattice - -# TODO: wrap SparseVector |