summaryrefslogtreecommitdiff
path: root/python/src/_cdec.pyx
diff options
context:
space:
mode:
Diffstat (limited to 'python/src/_cdec.pyx')
-rw-r--r--python/src/_cdec.pyx29
1 files changed, 16 insertions, 13 deletions
diff --git a/python/src/_cdec.pyx b/python/src/_cdec.pyx
index 164d6570..c60f342f 100644
--- a/python/src/_cdec.pyx
+++ b/python/src/_cdec.pyx
@@ -3,17 +3,18 @@ from libcpp.vector cimport vector
from utils cimport *
cimport decoder
-cdef char* as_str(sentence, error_msg='Cannot convert type %s to str'):
+cdef char* as_str(data, error_msg='Cannot convert type %s to str'):
cdef bytes ret
- if isinstance(sentence, unicode):
- ret = sentence.encode('utf8')
- elif isinstance(sentence, str):
- ret = sentence
+ if isinstance(data, unicode):
+ ret = data.encode('utf8')
+ elif isinstance(data, str):
+ ret = data
else:
- raise TypeError(error_msg % type(sentence))
+ raise TypeError(error_msg % type(data))
return ret
include "vectors.pxi"
+include "grammar.pxi"
include "hypergraph.pxi"
include "lattice.pxi"
include "mteval.pxi"
@@ -89,18 +90,20 @@ cdef class Decoder:
self.weights[fname.strip()] = float(value)
def translate(self, sentence, grammar=None):
- if isinstance(sentence, unicode):
- inp = sentence.strip().encode('utf8')
- elif isinstance(sentence, str):
- inp = sentence.strip()
+ cdef bytes input_str
+ if isinstance(sentence, unicode) or isinstance(sentence, str):
+ input_str = as_str(sentence.strip())
elif isinstance(sentence, Lattice):
- inp = str(sentence) # PLF format
+ input_str = str(sentence) # PLF format
else:
raise TypeError('Cannot translate input type %s' % type(sentence))
if grammar:
- self.dec.AddSupplementalGrammarFromString(string(<char *> grammar))
+ if isinstance(grammar, str) or isinstance(grammar, unicode):
+ self.dec.AddSupplementalGrammarFromString(string(as_str(grammar)))
+ else:
+ self.dec.AddSupplementalGrammar(TextGrammar(grammar).grammar[0])
cdef decoder.BasicObserver observer = decoder.BasicObserver()
- self.dec.Decode(string(<char *>inp), &observer)
+ self.dec.Decode(string(input_str), &observer)
if observer.hypergraph == NULL:
raise ParseFailed()
cdef Hypergraph hg = Hypergraph()