diff options
Diffstat (limited to 'python/src/grammar.pxi')
-rw-r--r-- | python/src/grammar.pxi | 61 |
1 files changed, 43 insertions, 18 deletions
diff --git a/python/src/grammar.pxi b/python/src/grammar.pxi index 80d9fbf5..5ec21422 100644 --- a/python/src/grammar.pxi +++ b/python/src/grammar.pxi @@ -1,4 +1,6 @@ cimport grammar +cimport cdec.sa._sa as _sa +import cdec.sa._sa as _sa def _phrase(phrase): return ' '.join(w.encode('utf8') if isinstance(w, unicode) else str(w) for w in phrase) @@ -23,9 +25,41 @@ cdef class NTRef: def __str__(self): return '[%d]' % self.ref -cdef class BaseTRule: +cdef TRule convert_rule(_sa.Rule rule): + cdef unsigned i + cdef lhs = _sa.sym_tocat(rule.lhs) + cdef scores = {} + for i in range(rule.n_scores): + scores['PhraseModel_'+str(i)] = rule.cscores[i] + f, e = [], [] + cdef int* fsyms = rule.f.syms + for i in range(rule.f.n): + if _sa.sym_isvar(fsyms[i]): + f.append(NT(_sa.sym_tocat(fsyms[i]))) + else: + f.append(_sa.sym_tostring(fsyms[i])) + cdef int* esyms = rule.e.syms + for i in range(rule.e.n): + if _sa.sym_isvar(esyms[i]): + e.append(NTRef(_sa.sym_getindex(esyms[i]))) + else: + e.append(_sa.sym_tostring(esyms[i])) + cdef a = [(point/65536, point%65536) for point in rule.word_alignments] + return TRule(lhs, f, e, scores, a) + +cdef class TRule: cdef shared_ptr[grammar.TRule]* rule + def __init__(self, lhs, f, e, scores, a=None): + self.rule = new shared_ptr[grammar.TRule](new grammar.TRule()) + self.lhs = lhs + self.e = e + self.f = f + self.scores = scores + if a: + self.a = a + self.rule.get().ComputeArity() + def __dealloc__(self): del self.rule @@ -104,7 +138,7 @@ cdef class BaseTRule: property scores: def __get__(self): - cdef SparseVector scores = SparseVector() + cdef SparseVector scores = SparseVector.__new__(SparseVector) scores.vector = new FastSparseVector[double](self.rule.get().scores_) return scores @@ -132,17 +166,6 @@ cdef class BaseTRule: return '%s ||| %s ||| %s ||| %s' % (self.lhs, _phrase(self.f), _phrase(self.e), scores) -cdef class TRule(BaseTRule): - def __cinit__(self, lhs, f, e, scores, a=None): - self.rule = new shared_ptr[grammar.TRule](new grammar.TRule()) - self.lhs = lhs - self.e = e - self.f = f - self.scores = scores - if a: - self.a = a - self.rule.get().ComputeArity() - cdef class Grammar: cdef shared_ptr[grammar.Grammar]* grammar @@ -150,12 +173,12 @@ cdef class Grammar: del self.grammar def __iter__(self): - cdef grammar.GrammarIter* root = self.grammar.get().GetRoot() - cdef grammar.RuleBin* rbin = root.GetRules() + cdef grammar.const_GrammarIter* root = self.grammar.get().GetRoot() + cdef grammar.const_RuleBin* rbin = root.GetRules() cdef TRule trule cdef unsigned i for i in range(rbin.GetNumRules()): - trule = TRule() + trule = TRule.__new__(TRule) trule.rule = new shared_ptr[grammar.TRule](rbin.GetIthRule(i)) yield trule @@ -171,6 +194,8 @@ cdef class TextGrammar(Grammar): self.grammar = new shared_ptr[grammar.Grammar](new grammar.TextGrammar()) cdef grammar.TextGrammar* _g = <grammar.TextGrammar*> self.grammar.get() for trule in rules: - if not isinstance(trule, BaseTRule): + if isinstance(trule, _sa.Rule): + trule = convert_rule(trule) + elif not isinstance(trule, TRule): raise ValueError('the grammar should contain TRule objects') - _g.AddRule((<BaseTRule> trule).rule[0]) + _g.AddRule((<TRule> trule).rule[0]) |