diff options
Diffstat (limited to 'python/src/sa/rulefactory.pxi')
-rw-r--r-- | python/src/sa/rulefactory.pxi | 311 |
1 files changed, 307 insertions, 4 deletions
diff --git a/python/src/sa/rulefactory.pxi b/python/src/sa/rulefactory.pxi index a0bda793..81ea7960 100644 --- a/python/src/sa/rulefactory.pxi +++ b/python/src/sa/rulefactory.pxi @@ -264,6 +264,14 @@ cdef class HieroCachingRuleFactory: cdef IntList findexes cdef IntList findexes1 + cdef phrases_f + cdef phrases_e + cdef phrases_fe + cdef phrases_al + cdef bilex_f + cdef bilex_e + cdef bilex_fe + def __cinit__(self, # compiled alignment object (REQUIRED) Alignment alignment, @@ -370,6 +378,19 @@ cdef class HieroCachingRuleFactory: self.findexes = IntList(initial_len=10) self.findexes1 = IntList(initial_len=10) + + # Online stats + + # Phrase counts + self.phrases_f = defaultdict(int) + self.phrases_e = defaultdict(int) + self.phrases_fe = defaultdict(lambda: defaultdict(int)) + self.phrases_al = defaultdict(dict) + + # Bilexical counts + self.bilex_f = defaultdict(int) + self.bilex_e = defaultdict(int) + self.bilex_fe = defaultdict(lambda: defaultdict(int)) def configure(self, SuffixArray fsarray, DataArray edarray, Sampler sampler, Scorer scorer): @@ -1799,8 +1820,290 @@ cdef class HieroCachingRuleFactory: return extracts + # Aggregate stats from a training instance: + # Extract hierarchical phrase pairs + # Update bilexical counts def add_instance(self, f_words, e_words, alignment): - logger.info("I would add:") - logger.info(decode_words(f_words)) - logger.info(decode_words(e_words)) - logger.info(alignment)
\ No newline at end of file + + # Bilexical counts + self.aggr_bilex(f_words, e_words) + + # Rules extracted from this instance + rules = set() + + f_len = len(f_words) + e_len = len(e_words) + + # Pre-compute alignment info + al = [[] for i in range(f_len)] + al_span = [[f_len + 1, -1] for i in range(f_len)] + for (f, e) in alignment: + al[f].append(e) + al_span[f][0] = min(al_span[f][0], e) + al_span[f][1] = max(al_span[f][1], e) + + # Target side word coverage + # TODO: Does Cython do bit vectors? + cover = [0] * e_len + + # Extract all possible hierarchical phrases starting at a source index + # f_ i and j are current, e_ i and j are previous + def extract(f_i, f_j, e_i, e_j, wc, links, nt, nt_open): + # Phrase extraction limits + if wc + len(nt) > self.max_length or (f_j + 1) > f_len or \ + (f_j - f_i) + 1 > self.max_initial_size: + return + # Unaligned word + if not al[f_j]: + # Open non-terminal: extend + if nt_open: + nt[-1][2] += 1 + extract(f_i, f_j + 1, e_i, e_j, wc, links, nt, True) + nt[-1][2] -= 1 + # No open non-terminal: extend with word + else: + extract(f_i, f_j + 1, e_i, e_j, wc + 1, links, nt, False) + return + # Aligned word + link_i = al_span[f_j][0] + link_j = al_span[f_j][1] + new_e_i = min(link_i, e_i) + new_e_j = max(link_j, e_j) + # Open non-terminal: close, extract, extend + if nt_open: + # Close non-terminal, checking for collisions + old_last_nt = nt[-1][:] + nt[-1][2] = f_j + if link_i < nt[-1][3]: + if not span_check(cover, link_i, nt[-1][3] - 1): + nt[-1] = old_last_nt + return + span_flip(cover, link_i, nt[-1][3] - 1) + nt[-1][3] = link_i + if link_j > nt[-1][4]: + if not span_check(cover, nt[-1][4] + 1, link_j): + nt[-1] = old_last_nt + return + span_flip(cover, nt[-1][4] + 1, link_j) + nt[-1][4] = link_j + for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): + rules.add(rule) + extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False) + nt[-1] = old_last_nt + if link_i < nt[-1][3]: + span_flip(cover, link_i, nt[-1][3] - 1) + if link_j > nt[-1][4]: + span_flip(cover, nt[-1][4] + 1, link_j) + return + # No open non-terminal + # Extract, extend with word + collision = False + for link in al[f_j]: + if cover[link]: + collision = True + # Collisions block extraction and extension, but may be okay for + # continuing non-terminals + if not collision: + plus_links = [] + for link in al[f_j]: + plus_links.append((f_j, link)) + cover[link] = ~cover[link] + links.append(plus_links) + for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): + rules.add(rule) + extract(f_i, f_j + 1, new_e_i, new_e_j, wc + 1, links, nt, False) + links.pop() + for link in al[f_j]: + cover[link] = ~cover[link] + # Try to add a word to a (closed) non-terminal, extract, extend + if nt and nt[-1][2] == f_j - 1: + # Add to non-terminal, checking for collisions + old_last_nt = nt[-1][:] + nt[-1][2] = f_j + if link_i < nt[-1][3]: + if not span_check(cover, link_i, nt[-1][3] - 1): + nt[-1] = old_last_nt + return + span_flip(cover, link_i, nt[-1][3] - 1) + nt[-1][3] = link_i + if link_j > nt[-1][4]: + if not span_check(cover, nt[-1][4] + 1, link_j): + nt[-1] = old_last_nt + return + span_flip(cover, nt[-1][4] + 1, link_j) + nt[-1][4] = link_j + # Require at least one word in phrase + if links: + for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): + rules.add(rule) + extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False) + nt[-1] = old_last_nt + if new_e_i < nt[-1][3]: + span_flip(cover, link_i, nt[-1][3] - 1) + if link_j > nt[-1][4]: + span_flip(cover, nt[-1][4] + 1, link_j) + # Try to start a new non-terminal, extract, extend + if (not nt or f_j - nt[-1][2] > 1) and len(nt) < self.max_nonterminals: + # Check for collisions + if not span_check(cover, link_i, link_j): + return + span_flip(cover, link_i, link_j) + nt.append([(nt[-1][0] + 1) if nt else 1, f_j, f_j, link_i, link_j]) + # Require at least one word in phrase + if links: + for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): + rules.add(rule) + extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False) + nt.pop() + span_flip(cover, link_i, link_j) + # TODO: try adding NT to start, end, both + # check: one aligned word on boundary that is not part of a NT + + # Try to extract phrases from every f index + f_i = 0 + while f_i < f_len: + # Skip if phrases won't be tight on left side + if not al[f_i]: + f_i += 1 + continue + extract(f_i, f_i, f_len + 1, -1, 1, [], [], False) + f_i += 1 + + for rule in sorted(rules): + logger.info(rule) + + # Aggregate bilexical counts + def aggr_bilex(self, f_words, e_words): + + for e_w in e_words: + self.bilex_e[e_w] += 1 + + for f_w in f_words: + self.bilex_f[f_w] += 1 + for e_w in e_words: + self.bilex_fe[f_w][e_w] += 1 + + # Create a rule from source, target, non-terminals, and alignments + def form_rules(self, f_i, e_i, f_span, e_span, nt, al): + + # This could be more efficient but is unlikely to be the bottleneck + + rules = [] + + nt_inv = sorted(nt, cmp=lambda x, y: cmp(x[3], y[3])) + + logger.info(nt) + + f_sym = list(f_span[:]) + off = f_i + for next_nt in nt: + nt_len = (next_nt[2] - next_nt[1]) + 1 + i = 0 + while i < nt_len: + f_sym.pop(next_nt[1] - off) + i += 1 + f_sym.insert(next_nt[1] - off, sym_setindex(self.category, next_nt[0])) + off += (nt_len - 1) + + e_sym = list(e_span[:]) + off = e_i + for next_nt in nt_inv: + nt_len = (next_nt[4] - next_nt[3]) + 1 + i = 0 + while i < nt_len: + e_sym.pop(next_nt[3] - off) + i += 1 + e_sym.insert(next_nt[3] - off, sym_setindex(self.category, next_nt[0])) + off += (nt_len - 1) + + # Adjusting alignment links takes some doing + links = [list(link) for sub in al for link in sub] + links_len = len(links) + nt_len = len(nt) + nt_i = 0 + off = f_i + i = 0 + while i < links_len: + while nt_i < nt_len and links[i][0] > nt[nt_i][1]: + off += (nt[nt_i][2] - nt[nt_i][1]) + nt_i += 1 + links[i][0] -= off + i += 1 + nt_i = 0 + off = e_i + i = 0 + while i < links_len: + while nt_i < nt_len and links[i][1] > nt_inv[nt_i][3]: + off += (nt_inv[nt_i][4] - nt_inv[nt_i][3]) + nt_i += 1 + links[i][1] -= off + i += 1 + + # Rule + rules.append(fmt_rule(f_sym, e_sym, links)) + if len(f_sym) >= self.max_length or len(nt) >= self.max_nonterminals: + return rules + last_index = nt[-1][0] if nt else 0 + # Rule [X] + if not nt or not sym_isvar(f_sym[-1]): + f_sym.append(sym_setindex(self.category, last_index + 1)) + e_sym.append(sym_setindex(self.category, last_index + 1)) + rules.append(fmt_rule(f_sym, e_sym, links)) + f_sym.pop() + e_sym.pop() + # [X] Rule + f_len = len(f_sym) + e_len = len(e_sym) + if not nt or not sym_isvar(f_sym[0]): + for i from 0 <= i < f_len: + if sym_isvar(f_sym[i]): + f_sym[i] = sym_setindex(self.category, sym_getindex(f_sym[i]) + 1) + for i from 0 <= i < e_len: + if sym_isvar(e_sym[i]): + e_sym[i] = sym_setindex(self.category, sym_getindex(e_sym[i]) + 1) + for link in links: + link[0] += 1 + link[1] += 1 + f_sym.insert(0, sym_setindex(self.category, 1)) + e_sym.insert(0, sym_setindex(self.category, 1)) + rules.append(fmt_rule(f_sym, e_sym, links)) + if len(f_sym) >= self.max_length or len(nt) + 1 >= self.max_nonterminals: + return rules + # [X] Rule [X] + if not nt or not sym_isvar(f_sym[-1]): + f_sym.append(sym_setindex(self.category, last_index + 2)) + e_sym.append(sym_setindex(self.category, last_index + 2)) + rules.append(fmt_rule(f_sym, e_sym, links)) + return rules + + # Debugging + def dump_online_stats(self): + logger.info(self.phrases_f) + logger.info(self.phrases_e) + logger.info(self.phrases_fe) + +# Spans are _inclusive_ on both ends [i, j] +# TODO: Replace all of this with bit vectors? +def span_check(vec, i, j): + k = i + while k <= j: + if vec[k]: + return False + k += 1 + return True + +def span_flip(vec, i, j): + k = i + while k <= j: + vec[k] = ~vec[k] + k += 1 + +def fmt_rule(f_sym, e_sym, links): + a_str = ' '.join('{0}-{1}'.format(i, j) for (i, j) in links) + return '[X] ||| {0} ||| {1} ||| {2}'.format(' '.join(sym_tostring(sym) for sym in f_sym), + ' '.join(sym_tostring(sym) for sym in e_sym), + a_str) + + #(f, e, count, als) = e + #a = tuple('{0}-{1}'.format(packed/65536, packed%65536) for packed in als) + #logger.info("f: {0}, e: {1}, count: {2}, a: {3}".format(f, e, count, a)) |