diff options
author | Chris Dyer <cdyer@cs.cmu.edu> | 2012-06-05 15:47:09 -0400 |
---|---|---|
committer | Chris Dyer <cdyer@cs.cmu.edu> | 2012-06-05 15:47:09 -0400 |
commit | 6d95c0b1edea38a972c287793fabae8f7f0eea1a (patch) | |
tree | a70c0aaddf6e3313ebe349b6a90542ab5359b3be /python/src/_cdec.pyx | |
parent | 76950778554a836230b68aa1ba517bc19521a66f (diff) | |
parent | 57d8aeaf862bad8516e8aa87eca10fdb75bdcb3a (diff) |
Merge branch 'master' of github.com:redpony/cdec
Diffstat (limited to 'python/src/_cdec.pyx')
-rw-r--r-- | python/src/_cdec.pyx | 99 |
1 files changed, 99 insertions, 0 deletions
diff --git a/python/src/_cdec.pyx b/python/src/_cdec.pyx new file mode 100644 index 00000000..b99f087d --- /dev/null +++ b/python/src/_cdec.pyx @@ -0,0 +1,99 @@ +from libcpp.string cimport string +from libcpp.vector cimport vector +from utils cimport * +cimport hypergraph +cimport decoder + +SetSilent(True) + +class ParseFailed(Exception): + pass + +cdef class Weights: + cdef vector[weight_t]* weights + + def __cinit__(self, Decoder decoder): + self.weights = &decoder.dec.CurrentWeightVector() + + def __getitem__(self, char* fname): + cdef unsigned fid = Convert(fname) + if fid <= self.weights.size(): + return self.weights[0][fid] + raise KeyError(fname) + + def __setitem__(self, char* fname, float value): + cdef unsigned fid = Convert(<char *>fname) + if self.weights.size() <= fid: + self.weights.resize(fid + 1) + self.weights[0][fid] = value + + def __iter__(self): + cdef unsigned fid + for fid in range(1, self.weights.size()): + yield Convert(fid).c_str(), self.weights[0][fid] + +cdef class Decoder: + cdef decoder.Decoder* dec + cdef public Weights weights + + def __cinit__(self, char* config): + decoder.register_feature_functions() + cdef istringstream* config_stream = new istringstream(config) # ConfigStream(kwargs) + #cdef ReadFile* config_file = new ReadFile(string(config)) + #cdef istream* config_stream = config_file.stream() + self.dec = new decoder.Decoder(config_stream) + del config_stream + #del config_file + self.weights = Weights(self) + + def __dealloc__(self): + del self.dec + + @classmethod + def fromconfig(cls, ini): + cdef dict config = {} + with open(ini) as fp: + for line in fp: + line = line.strip() + if not line or line.startswith('#'): continue + param, value = line.split('=') + config[param.strip()] = value.strip() + return cls(**config) + + def read_weights(self, cfg): + with open(cfg) as fp: + for line in fp: + fname, value = line.split() + self.weights[fname.strip()] = float(value) + + def translate(self, unicode sentence, grammar=None): + if grammar: + self.dec.SetSentenceGrammarFromString(string(<char *> grammar)) + #sgml = '<seg grammar="%s">%s</seg>' % (grammar, sentence.encode('utf8')) + sgml = sentence.strip().encode('utf8') + cdef decoder.BasicObserver observer = decoder.BasicObserver() + self.dec.Decode(string(<char *>sgml), &observer) + if observer.hypergraph == NULL: + raise ParseFailed() + cdef Hypergraph hg = Hypergraph() + hg.hg = new hypergraph.Hypergraph(observer.hypergraph[0]) + return hg + +cdef class Hypergraph: + cdef hypergraph.Hypergraph* hg + + def viterbi(self): + assert (self.hg != NULL) + cdef vector[WordID] trans + hypergraph.ViterbiESentence(self.hg[0], &trans) + cdef str sentence = GetString(trans).c_str() + return sentence.decode('utf8') + +""" +def params_str(params): + return '\n'.join('%s=%s' % (param, value) for param, value in params.iteritems()) + +cdef istringstream* ConfigStream(dict params): + ini = params_str(params) + return new istringstream(<char *> ini) +""" |