summaryrefslogtreecommitdiff
path: root/python/src/grammar.pxi
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2012-07-28 12:11:44 -0400
committerChris Dyer <cdyer@cs.cmu.edu>2012-07-28 12:11:44 -0400
commit306e0ba4754c6c4f460536cfe8c3f118dc1cc175 (patch)
treead5ea3b0a5370ac613d1bad715fe0f5ab8c91c11 /python/src/grammar.pxi
parent934e55dc12c3f374684bc6a0797e6f85c7abb85a (diff)
parentee5e376e263d9aeabdeee6968b4457f53d3fc772 (diff)
Merge branch 'master' of github.com:redpony/cdec
Diffstat (limited to 'python/src/grammar.pxi')
-rw-r--r--python/src/grammar.pxi61
1 files changed, 43 insertions, 18 deletions
diff --git a/python/src/grammar.pxi b/python/src/grammar.pxi
index 80d9fbf5..5ec21422 100644
--- a/python/src/grammar.pxi
+++ b/python/src/grammar.pxi
@@ -1,4 +1,6 @@
cimport grammar
+cimport cdec.sa._sa as _sa
+import cdec.sa._sa as _sa
def _phrase(phrase):
return ' '.join(w.encode('utf8') if isinstance(w, unicode) else str(w) for w in phrase)
@@ -23,9 +25,41 @@ cdef class NTRef:
def __str__(self):
return '[%d]' % self.ref
-cdef class BaseTRule:
+cdef TRule convert_rule(_sa.Rule rule):
+ cdef unsigned i
+ cdef lhs = _sa.sym_tocat(rule.lhs)
+ cdef scores = {}
+ for i in range(rule.n_scores):
+ scores['PhraseModel_'+str(i)] = rule.cscores[i]
+ f, e = [], []
+ cdef int* fsyms = rule.f.syms
+ for i in range(rule.f.n):
+ if _sa.sym_isvar(fsyms[i]):
+ f.append(NT(_sa.sym_tocat(fsyms[i])))
+ else:
+ f.append(_sa.sym_tostring(fsyms[i]))
+ cdef int* esyms = rule.e.syms
+ for i in range(rule.e.n):
+ if _sa.sym_isvar(esyms[i]):
+ e.append(NTRef(_sa.sym_getindex(esyms[i])))
+ else:
+ e.append(_sa.sym_tostring(esyms[i]))
+ cdef a = [(point/65536, point%65536) for point in rule.word_alignments]
+ return TRule(lhs, f, e, scores, a)
+
+cdef class TRule:
cdef shared_ptr[grammar.TRule]* rule
+ def __init__(self, lhs, f, e, scores, a=None):
+ self.rule = new shared_ptr[grammar.TRule](new grammar.TRule())
+ self.lhs = lhs
+ self.e = e
+ self.f = f
+ self.scores = scores
+ if a:
+ self.a = a
+ self.rule.get().ComputeArity()
+
def __dealloc__(self):
del self.rule
@@ -104,7 +138,7 @@ cdef class BaseTRule:
property scores:
def __get__(self):
- cdef SparseVector scores = SparseVector()
+ cdef SparseVector scores = SparseVector.__new__(SparseVector)
scores.vector = new FastSparseVector[double](self.rule.get().scores_)
return scores
@@ -132,17 +166,6 @@ cdef class BaseTRule:
return '%s ||| %s ||| %s ||| %s' % (self.lhs,
_phrase(self.f), _phrase(self.e), scores)
-cdef class TRule(BaseTRule):
- def __cinit__(self, lhs, f, e, scores, a=None):
- self.rule = new shared_ptr[grammar.TRule](new grammar.TRule())
- self.lhs = lhs
- self.e = e
- self.f = f
- self.scores = scores
- if a:
- self.a = a
- self.rule.get().ComputeArity()
-
cdef class Grammar:
cdef shared_ptr[grammar.Grammar]* grammar
@@ -150,12 +173,12 @@ cdef class Grammar:
del self.grammar
def __iter__(self):
- cdef grammar.GrammarIter* root = self.grammar.get().GetRoot()
- cdef grammar.RuleBin* rbin = root.GetRules()
+ cdef grammar.const_GrammarIter* root = self.grammar.get().GetRoot()
+ cdef grammar.const_RuleBin* rbin = root.GetRules()
cdef TRule trule
cdef unsigned i
for i in range(rbin.GetNumRules()):
- trule = TRule()
+ trule = TRule.__new__(TRule)
trule.rule = new shared_ptr[grammar.TRule](rbin.GetIthRule(i))
yield trule
@@ -171,6 +194,8 @@ cdef class TextGrammar(Grammar):
self.grammar = new shared_ptr[grammar.Grammar](new grammar.TextGrammar())
cdef grammar.TextGrammar* _g = <grammar.TextGrammar*> self.grammar.get()
for trule in rules:
- if not isinstance(trule, BaseTRule):
+ if isinstance(trule, _sa.Rule):
+ trule = convert_rule(trule)
+ elif not isinstance(trule, TRule):
raise ValueError('the grammar should contain TRule objects')
- _g.AddRule((<BaseTRule> trule).rule[0])
+ _g.AddRule((<TRule> trule).rule[0])