summaryrefslogtreecommitdiff
path: root/python/src/_cdec.pyx
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2012-06-05 15:47:09 -0400
committerChris Dyer <cdyer@cs.cmu.edu>2012-06-05 15:47:09 -0400
commit23c7a1ed9f4d2360f8e1183539e3569fdbf60b48 (patch)
tree5cffb41d2cccb7ec3cc4583be12d78d1cedd9544 /python/src/_cdec.pyx
parent84bc2dd5a966eea898d4c0213205fd3917b285f4 (diff)
parent98e13188c40adf13b73ca994d3410f3e87d20355 (diff)
Merge branch 'master' of github.com:redpony/cdec
Diffstat (limited to 'python/src/_cdec.pyx')
-rw-r--r--python/src/_cdec.pyx99
1 files changed, 99 insertions, 0 deletions
diff --git a/python/src/_cdec.pyx b/python/src/_cdec.pyx
new file mode 100644
index 00000000..b99f087d
--- /dev/null
+++ b/python/src/_cdec.pyx
@@ -0,0 +1,99 @@
+from libcpp.string cimport string
+from libcpp.vector cimport vector
+from utils cimport *
+cimport hypergraph
+cimport decoder
+
+SetSilent(True)
+
+class ParseFailed(Exception):
+ pass
+
+cdef class Weights:
+ cdef vector[weight_t]* weights
+
+ def __cinit__(self, Decoder decoder):
+ self.weights = &decoder.dec.CurrentWeightVector()
+
+ def __getitem__(self, char* fname):
+ cdef unsigned fid = Convert(fname)
+ if fid <= self.weights.size():
+ return self.weights[0][fid]
+ raise KeyError(fname)
+
+ def __setitem__(self, char* fname, float value):
+ cdef unsigned fid = Convert(<char *>fname)
+ if self.weights.size() <= fid:
+ self.weights.resize(fid + 1)
+ self.weights[0][fid] = value
+
+ def __iter__(self):
+ cdef unsigned fid
+ for fid in range(1, self.weights.size()):
+ yield Convert(fid).c_str(), self.weights[0][fid]
+
+cdef class Decoder:
+ cdef decoder.Decoder* dec
+ cdef public Weights weights
+
+ def __cinit__(self, char* config):
+ decoder.register_feature_functions()
+ cdef istringstream* config_stream = new istringstream(config) # ConfigStream(kwargs)
+ #cdef ReadFile* config_file = new ReadFile(string(config))
+ #cdef istream* config_stream = config_file.stream()
+ self.dec = new decoder.Decoder(config_stream)
+ del config_stream
+ #del config_file
+ self.weights = Weights(self)
+
+ def __dealloc__(self):
+ del self.dec
+
+ @classmethod
+ def fromconfig(cls, ini):
+ cdef dict config = {}
+ with open(ini) as fp:
+ for line in fp:
+ line = line.strip()
+ if not line or line.startswith('#'): continue
+ param, value = line.split('=')
+ config[param.strip()] = value.strip()
+ return cls(**config)
+
+ def read_weights(self, cfg):
+ with open(cfg) as fp:
+ for line in fp:
+ fname, value = line.split()
+ self.weights[fname.strip()] = float(value)
+
+ def translate(self, unicode sentence, grammar=None):
+ if grammar:
+ self.dec.SetSentenceGrammarFromString(string(<char *> grammar))
+ #sgml = '<seg grammar="%s">%s</seg>' % (grammar, sentence.encode('utf8'))
+ sgml = sentence.strip().encode('utf8')
+ cdef decoder.BasicObserver observer = decoder.BasicObserver()
+ self.dec.Decode(string(<char *>sgml), &observer)
+ if observer.hypergraph == NULL:
+ raise ParseFailed()
+ cdef Hypergraph hg = Hypergraph()
+ hg.hg = new hypergraph.Hypergraph(observer.hypergraph[0])
+ return hg
+
+cdef class Hypergraph:
+ cdef hypergraph.Hypergraph* hg
+
+ def viterbi(self):
+ assert (self.hg != NULL)
+ cdef vector[WordID] trans
+ hypergraph.ViterbiESentence(self.hg[0], &trans)
+ cdef str sentence = GetString(trans).c_str()
+ return sentence.decode('utf8')
+
+"""
+def params_str(params):
+ return '\n'.join('%s=%s' % (param, value) for param, value in params.iteritems())
+
+cdef istringstream* ConfigStream(dict params):
+ ini = params_str(params)
+ return new istringstream(<char *> ini)
+"""