summaryrefslogtreecommitdiff
path: root/python/cdec/grammar.pxi
diff options
context:
space:
mode:
authorVictor Chahuneau <vchahune@cs.cmu.edu>2013-08-26 20:12:32 -0400
committerVictor Chahuneau <vchahune@cs.cmu.edu>2013-08-26 20:12:32 -0400
commit03799a2d330c6dbbe12154d4bcea236210b4f6ed (patch)
tree7adb0bc8dd2987fa32ee1299d8821dd8b7b06706 /python/cdec/grammar.pxi
parent8b491e57f8a011f4f8496e44bed7eb7a4360bc93 (diff)
Improve the package structure of pycdec
This change should not break anything, but now you can run: python setup.py build_ext --inplace and use the cleaner: PYTHONPATH=/path/to/cdec/python python -m ...
Diffstat (limited to 'python/cdec/grammar.pxi')
-rw-r--r--python/cdec/grammar.pxi227
1 files changed, 227 insertions, 0 deletions
diff --git a/python/cdec/grammar.pxi b/python/cdec/grammar.pxi
new file mode 100644
index 00000000..d523e4d2
--- /dev/null
+++ b/python/cdec/grammar.pxi
@@ -0,0 +1,227 @@
+cimport grammar
+cimport cdec.sa._sa as _sa
+import cdec.sa._sa as _sa
+
+def _phrase(phrase):
+ return ' '.join(w.encode('utf8') if isinstance(w, unicode) else str(w) for w in phrase)
+
+cdef class NT:
+ cdef public bytes cat
+ cdef public unsigned ref
+ def __init__(self, bytes cat, unsigned ref=0):
+ """NT(bytes cat, int ref=0) -> Non-terminal from category `cat`."""
+ self.cat = cat
+ self.ref = ref
+
+ def __str__(self):
+ if self.ref > 0:
+ return '[%s,%d]' % (self.cat, self.ref)
+ return '[%s]' % self.cat
+
+cdef class NTRef:
+ cdef public unsigned ref
+ def __init__(self, unsigned ref):
+ """NTRef(int ref) -> Non-terminal reference."""
+ self.ref = ref
+
+ def __str__(self):
+ return '[%d]' % self.ref
+
+cdef TRule convert_rule(_sa.Rule rule):
+ lhs = _sa.sym_tocat(rule.lhs)
+ scores = dict(rule.scores)
+ f, e = [], []
+ cdef int* fsyms = rule.f.syms
+ for i in range(rule.f.n):
+ if _sa.sym_isvar(fsyms[i]):
+ f.append(NT(_sa.sym_tocat(fsyms[i])))
+ else:
+ f.append(_sa.sym_tostring(fsyms[i]))
+ cdef int* esyms = rule.e.syms
+ for i in range(rule.e.n):
+ if _sa.sym_isvar(esyms[i]):
+ e.append(NTRef(_sa.sym_getindex(esyms[i])))
+ else:
+ e.append(_sa.sym_tostring(esyms[i]))
+ a = list(rule.alignments())
+ return TRule(lhs, f, e, scores, a)
+
+cdef class TRule:
+ cdef shared_ptr[grammar.TRule]* rule
+
+ def __init__(self, lhs, f, e, scores, a=None):
+ """TRule(lhs, f, e, scores, a=None) -> Translation rule.
+ lhs: left hand side non-terminal
+ f: source phrase (list of words/NT)
+ e: target phrase (list of words/NTRef)
+ scores: dictionary of feature scores
+ a: optional list of alignment points"""
+ self.rule = new shared_ptr[grammar.TRule](new grammar.TRule())
+ self.lhs = lhs
+ self.e = e
+ self.f = f
+ self.scores = scores
+ if a:
+ self.a = a
+ self.rule.get().ComputeArity()
+
+ def __dealloc__(self):
+ del self.rule
+
+ property arity:
+ def __get__(self):
+ return self.rule.get().arity_
+
+ property f:
+ def __get__(self):
+ cdef vector[WordID]* f_ = &self.rule.get().f_
+ cdef WordID w
+ cdef f = []
+ cdef unsigned i
+ cdef int idx = 0
+ for i in range(f_.size()):
+ w = f_[0][i]
+ if w < 0:
+ idx += 1
+ f.append(NT(TDConvert(-w).c_str(), idx))
+ else:
+ f.append(unicode(TDConvert(w).c_str(), encoding='utf8'))
+ return f
+
+ def __set__(self, f):
+ cdef vector[WordID]* f_ = &self.rule.get().f_
+ f_.resize(len(f))
+ cdef unsigned i
+ cdef int idx = 0
+ for i in range(len(f)):
+ if isinstance(f[i], NT):
+ f_[0][i] = -TDConvert((<NT> f[i]).cat)
+ else:
+ fi = as_str(f[i])
+ f_[0][i] = TDConvert(fi)
+
+ property e:
+ def __get__(self):
+ cdef vector[WordID]* e_ = &self.rule.get().e_
+ cdef WordID w
+ cdef e = []
+ cdef unsigned i
+ cdef int idx = 0
+ for i in range(e_.size()):
+ w = e_[0][i]
+ if w < 1:
+ idx += 1
+ e.append(NTRef(1-w))
+ else:
+ e.append(unicode(TDConvert(w).c_str(), encoding='utf8'))
+ return e
+
+ def __set__(self, e):
+ cdef vector[WordID]* e_ = &self.rule.get().e_
+ e_.resize(len(e))
+ cdef unsigned i
+ for i in range(len(e)):
+ if isinstance(e[i], NTRef):
+ e_[0][i] = 1-e[i].ref
+ else:
+ ei = as_str(e[i])
+ e_[0][i] = TDConvert(ei)
+
+ property a:
+ def __get__(self):
+ cdef unsigned i
+ cdef vector[grammar.AlignmentPoint]* a = &self.rule.get().a_
+ for i in range(a.size()):
+ yield (a[0][i].s_, a[0][i].t_)
+
+ def __set__(self, a):
+ cdef vector[grammar.AlignmentPoint]* a_ = &self.rule.get().a_
+ a_.resize(len(a))
+ cdef unsigned i
+ cdef int s, t
+ for i in range(len(a)):
+ s, t = a[i]
+ a_[0][i] = grammar.AlignmentPoint(s, t)
+
+ property scores:
+ def __get__(self):
+ cdef SparseVector scores = SparseVector.__new__(SparseVector)
+ scores.vector = new FastSparseVector[double](self.rule.get().scores_)
+ return scores
+
+ def __set__(self, scores):
+ cdef FastSparseVector[double]* scores_ = &self.rule.get().scores_
+ scores_.clear()
+ cdef int fid
+ cdef float fval
+ for fname, fval in scores.items():
+ fn = as_str(fname)
+ fid = FDConvert(fn)
+ if fid < 0: raise KeyError(fname)
+ scores_.set_value(fid, fval)
+
+ property lhs:
+ def __get__(self):
+ return NT(TDConvert(-self.rule.get().lhs_).c_str())
+
+ def __set__(self, lhs):
+ if not isinstance(lhs, NT):
+ lhs = NT(lhs)
+ self.rule.get().lhs_ = -TDConvert((<NT> lhs).cat)
+
+ def __str__(self):
+ scores = ' '.join('%s=%s' % feat for feat in self.scores)
+ return '%s ||| %s ||| %s ||| %s' % (self.lhs,
+ _phrase(self.f), _phrase(self.e), scores)
+
+cdef class MRule(TRule):
+ def __init__(self, lhs, rhs, scores):
+ """MRule(lhs, rhs, scores, a=None) -> Monolingual rule.
+ lhs: left hand side non-terminal
+ rhs: right hand side phrase (list of words/NT)
+ scores: dictionary of feature scores"""
+ cdef unsigned i = 1
+ e = []
+ for s in rhs:
+ if isinstance(s, NT):
+ e.append(NTRef(i))
+ i += 1
+ else:
+ e.append(s)
+ super(MRule, self).__init__(lhs, rhs, e, scores, None)
+
+cdef class Grammar:
+ cdef shared_ptr[grammar.Grammar]* grammar
+
+ def __dealloc__(self):
+ del self.grammar
+
+ def __iter__(self):
+ cdef grammar.const_GrammarIter* root = self.grammar.get().GetRoot()
+ cdef grammar.const_RuleBin* rbin = root.GetRules()
+ cdef TRule trule
+ cdef unsigned i
+ for i in range(rbin.GetNumRules()):
+ trule = TRule.__new__(TRule)
+ trule.rule = new shared_ptr[grammar.TRule](rbin.GetIthRule(i))
+ yield trule
+
+ property name:
+ def __get__(self):
+ str(self.grammar.get().GetGrammarName().c_str())
+
+ def __set__(self, name):
+ name = as_str(name)
+ self.grammar.get().SetGrammarName(name)
+
+cdef class TextGrammar(Grammar):
+ def __init__(self, rules):
+ """TextGrammar(rules) -> SCFG Grammar containing the rules."""
+ self.grammar = new shared_ptr[grammar.Grammar](new grammar.TextGrammar())
+ cdef grammar.TextGrammar* _g = <grammar.TextGrammar*> self.grammar.get()
+ for trule in rules:
+ if isinstance(trule, _sa.Rule):
+ trule = convert_rule(trule)
+ elif not isinstance(trule, TRule):
+ raise ValueError('the grammar should contain TRule objects')
+ _g.AddRule((<TRule> trule).rule[0])