diff options
author | Kenneth Heafield <github@kheafield.com> | 2012-08-03 07:46:54 -0400 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2012-08-03 07:46:54 -0400 |
commit | be1ab0a8937f9c5668ea5e6c31b798e87672e55e (patch) | |
tree | a13aad60ab6cced213401bce6a38ac885ba171ba /python/src/grammar.pxi | |
parent | e5d6f4ae41009c26978ecd62668501af9762b0bc (diff) | |
parent | 9fe0219562e5db25171cce8776381600ff9a5649 (diff) |
Merge branch 'master' of github.com:redpony/cdec
Diffstat (limited to 'python/src/grammar.pxi')
-rw-r--r-- | python/src/grammar.pxi | 213 |
1 files changed, 213 insertions, 0 deletions
diff --git a/python/src/grammar.pxi b/python/src/grammar.pxi new file mode 100644 index 00000000..a9a5ea14 --- /dev/null +++ b/python/src/grammar.pxi @@ -0,0 +1,213 @@ +cimport grammar +cimport cdec.sa._sa as _sa +import cdec.sa._sa as _sa + +def _phrase(phrase): + return ' '.join(w.encode('utf8') if isinstance(w, unicode) else str(w) for w in phrase) + +cdef class NT: + cdef public bytes cat + cdef public unsigned ref + def __init__(self, char* cat, unsigned ref=0): + self.cat = cat + self.ref = ref + + def __str__(self): + if self.ref > 0: + return '[%s,%d]' % (self.cat, self.ref) + return '[%s]' % self.cat + +cdef class NTRef: + cdef public unsigned ref + def __init__(self, unsigned ref): + self.ref = ref + + def __str__(self): + return '[%d]' % self.ref + +cdef TRule convert_rule(_sa.Rule rule): + cdef unsigned i + cdef lhs = _sa.sym_tocat(rule.lhs) + cdef scores = {} + for i in range(rule.n_scores): + scores['PhraseModel_'+str(i)] = rule.cscores[i] + f, e = [], [] + cdef int* fsyms = rule.f.syms + for i in range(rule.f.n): + if _sa.sym_isvar(fsyms[i]): + f.append(NT(_sa.sym_tocat(fsyms[i]))) + else: + f.append(_sa.sym_tostring(fsyms[i])) + cdef int* esyms = rule.e.syms + for i in range(rule.e.n): + if _sa.sym_isvar(esyms[i]): + e.append(NTRef(_sa.sym_getindex(esyms[i]))) + else: + e.append(_sa.sym_tostring(esyms[i])) + cdef a = [(point/65536, point%65536) for point in rule.word_alignments] + return TRule(lhs, f, e, scores, a) + +cdef class TRule: + cdef shared_ptr[grammar.TRule]* rule + + def __init__(self, lhs, f, e, scores, a=None): + self.rule = new shared_ptr[grammar.TRule](new grammar.TRule()) + self.lhs = lhs + self.e = e + self.f = f + self.scores = scores + if a: + self.a = a + self.rule.get().ComputeArity() + + def __dealloc__(self): + del self.rule + + property arity: + def __get__(self): + return self.rule.get().arity_ + + property f: + def __get__(self): + cdef vector[WordID]* f_ = &self.rule.get().f_ + cdef WordID w + cdef f = [] + cdef unsigned i + cdef int idx = 0 + for i in range(f_.size()): + w = f_[0][i] + if w < 0: + idx += 1 + f.append(NT(TDConvert(-w), idx)) + else: + f.append(unicode(TDConvert(w), encoding='utf8')) + return f + + def __set__(self, f): + cdef vector[WordID]* f_ = &self.rule.get().f_ + f_.resize(len(f)) + cdef unsigned i + cdef int idx = 0 + for i in range(len(f)): + if isinstance(f[i], NT): + f_[0][i] = -TDConvert(<char *>f[i].cat) + else: + f_[0][i] = TDConvert(<char *>as_str(f[i])) + + property e: + def __get__(self): + cdef vector[WordID]* e_ = &self.rule.get().e_ + cdef WordID w + cdef e = [] + cdef unsigned i + cdef int idx = 0 + for i in range(e_.size()): + w = e_[0][i] + if w < 1: + idx += 1 + e.append(NTRef(1-w)) + else: + e.append(unicode(TDConvert(w), encoding='utf8')) + return e + + def __set__(self, e): + cdef vector[WordID]* e_ = &self.rule.get().e_ + e_.resize(len(e)) + cdef unsigned i + for i in range(len(e)): + if isinstance(e[i], NTRef): + e_[0][i] = 1-e[i].ref + else: + e_[0][i] = TDConvert(<char *>as_str(e[i])) + + property a: + def __get__(self): + cdef unsigned i + cdef vector[grammar.AlignmentPoint]* a = &self.rule.get().a_ + for i in range(a.size()): + yield (a[0][i].s_, a[0][i].t_) + + def __set__(self, a): + cdef vector[grammar.AlignmentPoint]* a_ = &self.rule.get().a_ + a_.resize(len(a)) + cdef unsigned i + cdef int s, t + for i in range(len(a)): + s, t = a[i] + a_[0][i] = grammar.AlignmentPoint(s, t) + + property scores: + def __get__(self): + cdef SparseVector scores = SparseVector.__new__(SparseVector) + scores.vector = new FastSparseVector[double](self.rule.get().scores_) + return scores + + def __set__(self, scores): + cdef FastSparseVector[double]* scores_ = &self.rule.get().scores_ + scores_.clear() + cdef int fid + cdef float fval + for fname, fval in scores.items(): + fid = FDConvert(<char *>as_str(fname)) + if fid < 0: raise KeyError(fname) + scores_.set_value(fid, fval) + + property lhs: + def __get__(self): + return NT(TDConvert(-self.rule.get().lhs_)) + + def __set__(self, lhs): + if not isinstance(lhs, NT): + lhs = NT(lhs) + self.rule.get().lhs_ = -TDConvert(<char *>lhs.cat) + + def __str__(self): + scores = ' '.join('%s=%s' % feat for feat in self.scores) + return '%s ||| %s ||| %s ||| %s' % (self.lhs, + _phrase(self.f), _phrase(self.e), scores) + +cdef class MRule(TRule): + def __init__(self, lhs, rhs, scores, a=None): + cdef unsigned i = 1 + e = [] + for s in rhs: + if isinstance(s, NT): + e.append(NTRef(i)) + i += 1 + else: + e.append(s) + super(MRule, self).__init__(lhs, rhs, e, scores, a) + +cdef class Grammar: + cdef shared_ptr[grammar.Grammar]* grammar + + def __dealloc__(self): + del self.grammar + + def __iter__(self): + cdef grammar.const_GrammarIter* root = self.grammar.get().GetRoot() + cdef grammar.const_RuleBin* rbin = root.GetRules() + cdef TRule trule + cdef unsigned i + for i in range(rbin.GetNumRules()): + trule = TRule.__new__(TRule) + trule.rule = new shared_ptr[grammar.TRule](rbin.GetIthRule(i)) + yield trule + + property name: + def __get__(self): + self.grammar.get().GetGrammarName().c_str() + + def __set__(self, name): + self.grammar.get().SetGrammarName(string(<char *>name)) + +cdef class TextGrammar(Grammar): + def __cinit__(self, rules): + self.grammar = new shared_ptr[grammar.Grammar](new grammar.TextGrammar()) + cdef grammar.TextGrammar* _g = <grammar.TextGrammar*> self.grammar.get() + for trule in rules: + if isinstance(trule, _sa.Rule): + trule = convert_rule(trule) + elif not isinstance(trule, TRule): + raise ValueError('the grammar should contain TRule objects') + _g.AddRule((<TRule> trule).rule[0]) |