diff options
Diffstat (limited to 'sa-extract/rule.pyx')
-rw-r--r-- | sa-extract/rule.pyx | 286 |
1 files changed, 286 insertions, 0 deletions
diff --git a/sa-extract/rule.pyx b/sa-extract/rule.pyx new file mode 100644 index 00000000..7cd3efda --- /dev/null +++ b/sa-extract/rule.pyx @@ -0,0 +1,286 @@ +from libc.stdlib cimport malloc, calloc, realloc, free, strtof, strtol +from libc.string cimport strsep, strcpy, strlen + +cdef extern from "strutil.h": + char *strstrsep(char **stringp, char *delim) + char *strip(char *s) + char **split(char *s, char *delim, int *pn) + +import sys + +import sym +cimport sym +cdef sym.Alphabet alphabet +alphabet = sym.alphabet + +global span_limit +span_limit = None + +cdef int bufsize +cdef char *buf +bufsize = 100 +buf = <char *>malloc(bufsize) +cdef ensurebufsize(int size): + global buf, bufsize + if size > bufsize: + buf = <char *>realloc(buf, size*sizeof(char)) + bufsize = size + +cdef class Phrase: + def __cinit__(self, words): + cdef int i, j, n, n_vars + cdef char **toks + cdef bytes bwords + cdef char* cwords + + n_vars = 0 + if type(words) is str: + ensurebufsize(len(words)+1) + bwords = words + cwords = bwords + strcpy(buf, cwords) + toks = split(buf, NULL, &n) + self.syms = <int *>malloc(n*sizeof(int)) + for i from 0 <= i < n: + self.syms[i] = alphabet.fromstring(toks[i], 0) + if alphabet.isvar(self.syms[i]): + n_vars = n_vars + 1 + + else: + n = len(words) + self.syms = <int *>malloc(n*sizeof(int)) + for i from 0 <= i < n: + self.syms[i] = words[i] + if alphabet.isvar(self.syms[i]): + n_vars = n_vars + 1 + self.n = n + self.n_vars = n_vars + self.varpos = <int *>malloc(n_vars*sizeof(int)) + j = 0 + for i from 0 <= i < n: + if alphabet.isvar(self.syms[i]): + self.varpos[j] = i + j = j + 1 + + def __dealloc__(self): + free(self.syms) + free(self.varpos) + + def __str__(self): + strs = [] + cdef int i, s + for i from 0 <= i < self.n: + s = self.syms[i] + strs.append(alphabet.tostring(s)) + return " ".join(strs) + + def instantiable(self, i, j, n): + return span_limit is None or (j-i) <= span_limit + + def handle(self): + """return a hashable representation that normalizes the ordering + of the nonterminal indices""" + norm = [] + cdef int i, j, s + i = 1 + j = 0 + for j from 0 <= j < self.n: + s = self.syms[j] + if alphabet.isvar(s): + s = alphabet.setindex(s,i) + i = i + 1 + norm.append(s) + return tuple(norm) + + def strhandle(self): + strs = [] + norm = [] + cdef int i, j, s + i = 1 + j = 0 + for j from 0 <= j < self.n: + s = self.syms[j] + if alphabet.isvar(s): + s = alphabet.setindex(s,i) + i = i + 1 + norm.append(alphabet.tostring(s)) + return " ".join(norm) + + def arity(self): + return self.n_vars + + def getvarpos(self, i): + if 0 <= i < self.n_vars: + return self.varpos[i] + else: + raise IndexError + + def getvar(self, i): + if 0 <= i < self.n_vars: + return self.syms[self.varpos[i]] + else: + raise IndexError + + cdef int chunkpos(self, int k): + if k == 0: + return 0 + else: + return self.varpos[k-1]+1 + + cdef int chunklen(self, int k): + if self.n_vars == 0: + return self.n + elif k == 0: + return self.varpos[0] + elif k == self.n_vars: + return self.n-self.varpos[k-1]-1 + else: + return self.varpos[k]-self.varpos[k-1]-1 + + def clen(self, k): + return self.chunklen(k) + + def getchunk(self, ci): + cdef int start, stop + start = self.chunkpos(ci) + stop = start+self.chunklen(ci) + chunk = [] + for i from start <= i < stop: + chunk.append(self.syms[i]) + return chunk + + def __cmp__(self, other): + cdef Phrase otherp + cdef int i + otherp = other + for i from 0 <= i < min(self.n, otherp.n): + if self.syms[i] < otherp.syms[i]: + return -1 + elif self.syms[i] > otherp.syms[i]: + return 1 + if self.n < otherp.n: + return -1 + elif self.n > otherp.n: + return 1 + else: + return 0 + + def __hash__(self): + cdef int i + cdef unsigned h + h = 0 + for i from 0 <= i < self.n: + if self.syms[i] > 0: + h = (h << 1) + self.syms[i] + else: + h = (h << 1) + -self.syms[i] + return h + + def __len__(self): + return self.n + + def __getitem__(self, i): + return self.syms[i] + + def __iter__(self): + cdef int i + l = [] + for i from 0 <= i < self.n: + l.append(self.syms[i]) + return iter(l) + + def subst(self, start, children): + cdef int i + for i from 0 <= i < self.n: + if alphabet.isvar(self.syms[i]): + start = start + children[alphabet.getindex(self.syms[i])-1] + else: + start = start + (self.syms[i],) + return start + +cdef class Rule: + def __cinit__(self, lhs, f, e, owner=None, scores=None, word_alignments=None): + cdef int i, n + cdef char *rest + + self.word_alignments = word_alignments + if scores is None: + self.cscores = NULL + self.n_scores = 0 + else: + n = len(scores) + self.cscores = <float *>malloc(n*sizeof(float)) + self.n_scores = n + for i from 0 <= i < n: + self.cscores[i] = scores[i] + + def __init__(self, lhs, f, e, owner=None, scores=None, word_alignments=None): + if not sym.isvar(lhs): + sys.stderr.write("error: lhs=%d\n" % lhs) + self.lhs = lhs + self.f = f + self.e = e + self.word_alignments = word_alignments + + def __dealloc__(self): + if self.cscores != NULL: + free(self.cscores) + + def __str__(self): + return self.to_line() + + def __hash__(self): + return hash((self.lhs, self.f, self.e)) + + def __cmp__(self, Rule other): + return cmp((self.lhs, self.f, self.e, self.word_alignments), (other.lhs, other.f, other.e, self.word_alignments)) + + def __iadd__(self, Rule other): + if self.n_scores != other.n_scores: + raise ValueError + for i from 0 <= i < self.n_scores: + self.cscores[i] = self.cscores[i] + other.cscores[i] + return self + + def fmerge(self, Phrase f): + if self.f == f: + self.f = f + + def arity(self): + return self.f.arity() + + def to_line(self): + scorestrs = [] + for i from 0 <= i < self.n_scores: + scorestrs.append(str(self.cscores[i])) + fields = [alphabet.tostring(self.lhs), str(self.f), str(self.e), " ".join(scorestrs)] + if self.word_alignments is not None: + alignstr = [] + for i from 0 <= i < len(self.word_alignments): + alignstr.append("%d-%d" % (self.word_alignments[i]/65536, self.word_alignments[i]%65536)) + #for s,t in self.word_alignments: + #alignstr.append("%d-%d" % (s,t)) + fields.append(" ".join(alignstr)) + + return " ||| ".join(fields) + + property scores: + def __get__(self): + s = [None]*self.n_scores + for i from 0 <= i < self.n_scores: + s[i] = self.cscores[i] + return s + + def __set__(self, s): + if self.cscores != NULL: + free(self.cscores) + self.cscores = <float *>malloc(len(s)*sizeof(float)) + self.n_scores = len(s) + for i from 0 <= i < self.n_scores: + self.cscores[i] = s[i] + +def rule_copy(r): + r1 = Rule(r.lhs, r.f, r.e, r.owner, r.scores) + r1.word_alignments = r.word_alignments + return r1 + |