diff options
Diffstat (limited to 'python/src/_cdec.pyx')
-rw-r--r-- | python/src/_cdec.pyx | 29 |
1 files changed, 16 insertions, 13 deletions
diff --git a/python/src/_cdec.pyx b/python/src/_cdec.pyx index 164d6570..c60f342f 100644 --- a/python/src/_cdec.pyx +++ b/python/src/_cdec.pyx @@ -3,17 +3,18 @@ from libcpp.vector cimport vector from utils cimport * cimport decoder -cdef char* as_str(sentence, error_msg='Cannot convert type %s to str'): +cdef char* as_str(data, error_msg='Cannot convert type %s to str'): cdef bytes ret - if isinstance(sentence, unicode): - ret = sentence.encode('utf8') - elif isinstance(sentence, str): - ret = sentence + if isinstance(data, unicode): + ret = data.encode('utf8') + elif isinstance(data, str): + ret = data else: - raise TypeError(error_msg % type(sentence)) + raise TypeError(error_msg % type(data)) return ret include "vectors.pxi" +include "grammar.pxi" include "hypergraph.pxi" include "lattice.pxi" include "mteval.pxi" @@ -89,18 +90,20 @@ cdef class Decoder: self.weights[fname.strip()] = float(value) def translate(self, sentence, grammar=None): - if isinstance(sentence, unicode): - inp = sentence.strip().encode('utf8') - elif isinstance(sentence, str): - inp = sentence.strip() + cdef bytes input_str + if isinstance(sentence, unicode) or isinstance(sentence, str): + input_str = as_str(sentence.strip()) elif isinstance(sentence, Lattice): - inp = str(sentence) # PLF format + input_str = str(sentence) # PLF format else: raise TypeError('Cannot translate input type %s' % type(sentence)) if grammar: - self.dec.AddSupplementalGrammarFromString(string(<char *> grammar)) + if isinstance(grammar, str) or isinstance(grammar, unicode): + self.dec.AddSupplementalGrammarFromString(string(as_str(grammar))) + else: + self.dec.AddSupplementalGrammar(TextGrammar(grammar).grammar[0]) cdef decoder.BasicObserver observer = decoder.BasicObserver() - self.dec.Decode(string(<char *>inp), &observer) + self.dec.Decode(string(input_str), &observer) if observer.hypergraph == NULL: raise ParseFailed() cdef Hypergraph hg = Hypergraph() |