Diffstat (limited to 'python/src/sa/rulefactory.pxi')
-rw-r--r-- | python/src/sa/rulefactory.pxi | 284
1 file changed, 127 insertions(+), 157 deletions(-)
diff --git a/python/src/sa/rulefactory.pxi b/python/src/sa/rulefactory.pxi
index c26f5c43..b10d25dd 100644
--- a/python/src/sa/rulefactory.pxi
+++ b/python/src/sa/rulefactory.pxi
@@ -1829,6 +1829,10 @@ cdef class HieroCachingRuleFactory:
     def add_instance(self, f_words, e_words, alignment):
 
         # Rules extracted from this instance
+        # Track span of lexical items (terminals) to make
+        # sure we don't extract the same rule for the same
+        # span more than once.
+        # (f, e, al, lex_f_i, lex_f_j)
         rules = set()
 
         f_len = len(f_words)
@@ -1836,85 +1840,85 @@ cdef class HieroCachingRuleFactory:
 
         # Pre-compute alignment info
         al = [[] for i in range(f_len)]
-        al_span = [[f_len + 1, -1] for i in range(f_len)]
+        fe_span = [[e_len + 1, -1] for i in range(f_len)]
+        ef_span = [[f_len + 1, -1] for i in range(e_len)]
         for (f, e) in alignment:
             al[f].append(e)
-            al_span[f][0] = min(al_span[f][0], e)
-            al_span[f][1] = max(al_span[f][1], e)
+            fe_span[f][0] = min(fe_span[f][0], e)
+            fe_span[f][1] = max(fe_span[f][1], e)
+            ef_span[e][0] = min(ef_span[e][0], f)
+            ef_span[e][1] = max(ef_span[e][1], f)
 
         # Target side word coverage
         cover = [0] * e_len
-
+        # Non-terminal coverage
+        f_nt_cover = [0] * f_len
+        e_nt_cover = [0] * e_len
+
         # Extract all possible hierarchical phrases starting at a source index
         # f_ i and j are current, e_ i and j are previous
-        def extract(f_i, f_j, e_i, e_j, wc, links, nt, nt_open):
+        # We are _considering_ f_j, so it is not yet in counts
+        def extract(f_i, f_j, e_i, e_j, min_bound, wc, links, nt, nt_open):
             # Phrase extraction limits
-            if wc + len(nt) > self.max_length or (f_j + 1) > f_len or \
-                    (f_j - f_i) + 1 > self.max_initial_size:
+            if f_j > (f_len - 1) or (f_j - f_i) + 1 > self.max_initial_size:
                 return
             # Unaligned word
            if not al[f_j]:
-                # Open non-terminal: extend
-                if nt_open:
+                # Adjacent to non-terminal: extend (non-terminal now open)
+                if nt and nt[-1][2] == f_j - 1:
                     nt[-1][2] += 1
-                    extract(f_i, f_j + 1, e_i, e_j, wc, links, nt, True)
+                    extract(f_i, f_j + 1, e_i, e_j, min_bound, wc, links, nt, True)
                     nt[-1][2] -= 1
-                # No open non-terminal: extend with word
-                else:
-                    extract(f_i, f_j + 1, e_i, e_j, wc + 1, links, nt, False)
+                # Unless non-terminal already open, always extend with word
+                # Make sure adding a word doesn't exceed length
+                if not nt_open and wc < self.max_length:
+                    extract(f_i, f_j + 1, e_i, e_j, min_bound, wc + 1, links, nt, False)
                 return
             # Aligned word
-            link_i = al_span[f_j][0]
-            link_j = al_span[f_j][1]
+            link_i = fe_span[f_j][0]
+            link_j = fe_span[f_j][1]
             new_e_i = min(link_i, e_i)
             new_e_j = max(link_j, e_j)
-            # Open non-terminal: close, extract, extend
-            if nt_open:
-                # Close non-terminal, checking for collisions
-                old_last_nt = nt[-1][:]
-                nt[-1][2] = f_j
-                if link_i < nt[-1][3]:
-                    if not span_check(cover, link_i, nt[-1][3] - 1):
-                        nt[-1] = old_last_nt
+            # Check reverse links of newly covered words to see if they violate left
+            # bound (return) or extend minimum right bound for chunk
+            new_min_bound = min_bound
+            # First aligned word creates span
+            if e_j == -1:
+                for i from new_e_i <= i <= new_e_j:
+                    if ef_span[i][0] < f_i:
                         return
-                    span_flip(cover, link_i, nt[-1][3] - 1)
-                    nt[-1][3] = link_i
-                if link_j > nt[-1][4]:
-                    if not span_check(cover, nt[-1][4] + 1, link_j):
-                        nt[-1] = old_last_nt
+                    new_min_bound = max(new_min_bound, ef_span[i][1])
+            # Other aligned words extend span
+            else:
+                for i from new_e_i <= i < e_i:
+                    if ef_span[i][0] < f_i:
                         return
-                    span_flip(cover, nt[-1][4] + 1, link_j)
-                    nt[-1][4] = link_j
-                for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
-                    rules.add(rule)
-                extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False)
-                nt[-1] = old_last_nt
-                if link_i < nt[-1][3]:
-                    span_flip(cover, link_i, nt[-1][3] - 1)
-                if link_j > nt[-1][4]:
-                    span_flip(cover, nt[-1][4] + 1, link_j)
-                return
-            # No open non-terminal
-            # Extract, extend with word
-            collision = False
-            for link in al[f_j]:
-                if cover[link]:
-                    collision = True
-            # Collisions block extraction and extension, but may be okay for
-            # continuing non-terminals
-            if not collision:
-                plus_links = []
-                for link in al[f_j]:
-                    plus_links.append((f_j, link))
-                    cover[link] = ~cover[link]
-                links.append(plus_links)
-                for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
-                    rules.add(rule)
-                extract(f_i, f_j + 1, new_e_i, new_e_j, wc + 1, links, nt, False)
-                links.pop()
+                    new_min_bound = max(new_min_bound, ef_span[i][1])
+                for i from e_j < i <= new_e_j:
+                    if ef_span[i][0] < f_i:
+                        return
+                    new_min_bound = max(new_min_bound, ef_span[i][1])
+            # Extract, extend with word (unless non-terminal open)
+            if not nt_open:
+                nt_collision = False
                 for link in al[f_j]:
-                    cover[link] = ~cover[link]
-            # Try to add a word to a (closed) non-terminal, extract, extend
+                    if e_nt_cover[link]:
+                        nt_collision = True
+                # Non-terminal collisions block word extraction and extension, but
+                # may be okay for continuing non-terminals
+                if not nt_collision and wc < self.max_length:
+                    plus_links = []
+                    for link in al[f_j]:
+                        plus_links.append((f_j, link))
+                        cover[link] += 1
+                    links.append(plus_links)
+                    if links and f_j >= new_min_bound:
+                        rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
+                    extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc + 1, links, nt, False)
+                    links.pop()
+                    for link in al[f_j]:
+                        cover[link] -= 1
+            # Try to add a word to current non-terminal (if any), extract, extend
             if nt and nt[-1][2] == f_j - 1:
                 # Add to non-terminal, checking for collisions
                 old_last_nt = nt[-1][:]
@@ -1923,65 +1927,57 @@ cdef class HieroCachingRuleFactory:
                     if not span_check(cover, link_i, nt[-1][3] - 1):
                         nt[-1] = old_last_nt
                         return
-                    span_flip(cover, link_i, nt[-1][3] - 1)
+                    span_inc(cover, link_i, nt[-1][3] - 1)
+                    span_inc(e_nt_cover, link_i, nt[-1][3] - 1)
                     nt[-1][3] = link_i
                 if link_j > nt[-1][4]:
                     if not span_check(cover, nt[-1][4] + 1, link_j):
                         nt[-1] = old_last_nt
                         return
-                    span_flip(cover, nt[-1][4] + 1, link_j)
+                    span_inc(cover, nt[-1][4] + 1, link_j)
+                    span_inc(e_nt_cover, nt[-1][4] + 1, link_j)
                     nt[-1][4] = link_j
-                # Require at least one word in phrase
-                if links:
-                    for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
-                        rules.add(rule)
-                extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False)
+                if links and f_j >= new_min_bound:
+                    rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
+                extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc, links, nt, False)
                 nt[-1] = old_last_nt
-                if new_e_i < nt[-1][3]:
-                    span_flip(cover, link_i, nt[-1][3] - 1)
+                if link_i < nt[-1][3]:
+                    span_dec(cover, link_i, nt[-1][3] - 1)
+                    span_dec(e_nt_cover, link_i, nt[-1][3] - 1)
                 if link_j > nt[-1][4]:
-                    span_flip(cover, nt[-1][4] + 1, link_j)
+                    span_dec(cover, nt[-1][4] + 1, link_j)
+                    span_dec(e_nt_cover, nt[-1][4] + 1, link_j)
             # Try to start a new non-terminal, extract, extend
-            if (not nt or f_j - nt[-1][2] > 1) and len(nt) < self.max_nonterminals:
+            if (not nt or f_j - nt[-1][2] > 1) and wc < self.max_length and len(nt) < self.max_nonterminals:
                 # Check for collisions
                 if not span_check(cover, link_i, link_j):
                     return
-                span_flip(cover, link_i, link_j)
+                span_inc(cover, link_i, link_j)
+                span_inc(e_nt_cover, link_i, link_j)
                 nt.append([(nt[-1][0] + 1) if nt else 1, f_j, f_j, link_i, link_j])
                 # Require at least one word in phrase
-                if links:
-                    for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
-                        rules.add(rule)
-                extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False)
+                if links and f_j >= new_min_bound:
+                    rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
+                extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc + 1, links, nt, False)
                 nt.pop()
-                span_flip(cover, link_i, link_j)
+                span_dec(cover, link_i, link_j)
+                span_dec(e_nt_cover, link_i, link_j)
 
         # Try to extract phrases from every f index
-        f_i = 0
-        while f_i < f_len:
+        for f_i from 0 <= f_i < f_len:
            # Skip if phrases won't be tight on left side
             if not al[f_i]:
-                f_i += 1
                 continue
-            extract(f_i, f_i, f_len + 1, -1, 1, [], [], False)
-            f_i += 1
-
-        for rule in sorted(rules):
-            logger.info(self.fmt_rule(*rule))
+            extract(f_i, f_i, f_len + 1, -1, f_i, 0, [], [], False)
 
         # Update phrase counts
-        f_set = set()
-        e_set = set()
-        for (f_ph, e_ph, al) in rules:
-            f_set.add(f_ph)
-            e_set.add(e_ph)
+        for rule in rules:
+            (f_ph, e_ph, al) = rule[:3]
+            self.phrases_f[f_ph] += 1
+            self.phrases_e[e_ph] += 1
             self.phrases_fe[f_ph][e_ph] += 1
             if not self.phrases_al[f_ph][e_ph]:
                 self.phrases_al[f_ph][e_ph] = al
-        for f_ph in f_set:
-            self.phrases_f[f_ph] += 1
-        for e_ph in e_set:
-            self.phrases_e[e_ph] += 1
 
         # Update Bilexical counts
         for e_w in e_words:
@@ -1993,14 +1989,10 @@ cdef class HieroCachingRuleFactory:
 
     # Create a rule from source, target, non-terminals, and alignments
-    def form_rules(self, f_i, e_i, f_span, e_span, nt, al):
-
-        # This could be more efficient but is unlikely to be the bottleneck
-
-        rules = []
+    def form_rule(self, f_i, e_i, f_span, e_span, nt, al):
 
+        # Substitute in non-terminals
         nt_inv = sorted(nt, cmp=lambda x, y: cmp(x[3], y[3]))
-
         f_sym = list(f_span[:])
         off = f_i
         for next_nt in nt:
@@ -2011,7 +2003,6 @@ cdef class HieroCachingRuleFactory:
                 i += 1
             f_sym.insert(next_nt[1] - off, sym_setindex(self.category, next_nt[0]))
             off += (nt_len - 1)
-
         e_sym = list(e_span[:])
         off = e_i
         for next_nt in nt_inv:
@@ -2025,6 +2016,7 @@ cdef class HieroCachingRuleFactory:
 
         # Adjusting alignment links takes some doing
         links = [list(link) for sub in al for link in sub]
+        links_inv = sorted(links, cmp=lambda x, y: cmp(x[1], y[1]))
         links_len = len(links)
         nt_len = len(nt)
         nt_i = 0
@@ -2040,55 +2032,28 @@ cdef class HieroCachingRuleFactory:
         off = e_i
         i = 0
         while i < links_len:
-            while nt_i < nt_len and links[i][1] > nt_inv[nt_i][3]:
+            while nt_i < nt_len and links_inv[i][1] > nt_inv[nt_i][3]:
                 off += (nt_inv[nt_i][4] - nt_inv[nt_i][3])
                 nt_i += 1
-            links[i][1] -= off
+            links_inv[i][1] -= off
             i += 1
-
-        # Rule
-        rules.append(self.new_rule(f_sym, e_sym, links))
-        if len(f_sym) >= self.max_length or len(nt) >= self.max_nonterminals:
-            return rules
-        last_index = nt[-1][0] if nt else 0
-        # Rule [X]
-        if not nt or not sym_isvar(f_sym[-1]):
-            f_sym.append(sym_setindex(self.category, last_index + 1))
-            e_sym.append(sym_setindex(self.category, last_index + 1))
-            rules.append(self.new_rule(f_sym, e_sym, links))
-            f_sym.pop()
-            e_sym.pop()
-        # [X] Rule
-        f_len = len(f_sym)
-        e_len = len(e_sym)
-        if not nt or not sym_isvar(f_sym[0]):
-            for i from 0 <= i < f_len:
-                if sym_isvar(f_sym[i]):
-                    f_sym[i] = sym_setindex(self.category, sym_getindex(f_sym[i]) + 1)
-            for i from 0 <= i < e_len:
-                if sym_isvar(e_sym[i]):
-                    e_sym[i] = sym_setindex(self.category, sym_getindex(e_sym[i]) + 1)
-            for link in links:
-                link[0] += 1
-                link[1] += 1
-            f_sym.insert(0, sym_setindex(self.category, 1))
-            e_sym.insert(0, sym_setindex(self.category, 1))
-            rules.append(self.new_rule(f_sym, e_sym, links))
-            if len(f_sym) >= self.max_length or len(nt) + 1 >= self.max_nonterminals:
-                return rules
-            # [X] Rule [X]
-            if not nt or not sym_isvar(f_sym[-1]):
-                f_sym.append(sym_setindex(self.category, last_index + 2))
-                e_sym.append(sym_setindex(self.category, last_index + 2))
-                rules.append(self.new_rule(f_sym, e_sym, links))
-        return rules
-
-    def new_rule(self, f_sym, e_sym, links):
+
+        # Find lexical span
+        lex_f_i = f_i
+        lex_f_j = f_i + (len(f_span) - 1)
+        if nt:
+            if nt[0][1] == lex_f_i:
+                lex_f_i += (nt[0][2] - nt[0][1]) + 1
+            if nt[-1][2] == lex_f_j:
+                lex_f_j -= (nt[-1][2] - nt[-1][1]) + 1
+
+        # Create rule (f_phrase, e_phrase, links, f_link_min, f_link_max)
         f = Phrase(f_sym)
         e = Phrase(e_sym)
         a = tuple(self.alignment.link(i, j) for (i, j) in links)
-        return (f, e, a)
-
+        return (f, e, a, lex_f_i, lex_f_j)
+
+    # Rule string from rule
     def fmt_rule(self, f, e, a):
         a_str = ' '.join('{0}-{1}'.format(*self.alignment.unlink(packed)) for packed in a)
         return '[X] ||| {0} ||| {1} ||| {2}'.format(f, e, a_str)
@@ -2098,16 +2063,6 @@ cdef class HieroCachingRuleFactory:
         logger.info('------------------------------')
         logger.info(' Online Stats ')
         logger.info('------------------------------')
-        logger.info('F')
-        for ph in self.phrases_f:
-            logger.info(str(ph) + ' ||| ' + str(self.phrases_f[ph]))
-        logger.info('E')
-        for ph in self.phrases_e:
-            logger.info(str(ph) + ' ||| ' + str(self.phrases_e[ph]))
-        logger.info('FE')
-        for ph in self.phrases_fe:
-            for ph2 in self.phrases_fe[ph]:
-                logger.info(str(ph) + ' ||| ' + str(ph2) + ' ||| ' + str(self.phrases_fe[ph][ph2]))
         logger.info('f')
         for w in self.bilex_f:
             logger.info(sym_tostring(w) + ' : ' + str(self.bilex_f[w]))
@@ -2118,9 +2073,18 @@ cdef class HieroCachingRuleFactory:
         for w in self.bilex_fe:
             for w2 in self.bilex_fe[w]:
                 logger.info(sym_tostring(w) + ' : ' + sym_tostring(w2) + ' : ' + str(self.bilex_fe[w][w2]))
+        logger.info('F')
+        for ph in self.phrases_f:
+            logger.info(str(ph) + ' ||| ' + str(self.phrases_f[ph]))
+        logger.info('E')
+        for ph in self.phrases_e:
+            logger.info(str(ph) + ' ||| ' + str(self.phrases_e[ph]))
+        logger.info('FE')
+        for ph in self.phrases_fe:
+            for ph2 in self.phrases_fe[ph]:
+                logger.info(self.fmt_rule(str(ph), str(ph2), self.phrases_al[ph][ph2]) + ' ||| ' + str(self.phrases_fe[ph][ph2]))
 
 # Spans are _inclusive_ on both ends [i, j]
-# Could be more efficient but probably not a bottleneck
 def span_check(vec, i, j):
     k = i
     while k <= j:
@@ -2129,8 +2093,14 @@ def span_check(vec, i, j):
         k += 1
     return True
 
-def span_flip(vec, i, j):
+def span_inc(vec, i, j):
     k = i
     while k <= j:
-        vec[k] = ~vec[k]
-        k += 1
\ No newline at end of file
+        vec[k] += 1
+        k += 1
+
+def span_dec(vec, i, j):
+    k = i
+    while k <= j:
+        vec[k] -= 1
+        k += 1
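
The key mechanical change in this commit is the move from bit-flipped target coverage (span_flip, vec[k] = ~vec[k]) to reference-counted coverage (span_inc / span_dec), with a separate e_nt_cover vector so word links and non-terminal spans are tracked independently. What follows is a minimal plain-Python sketch of the helper semantics, not part of the patch; the six-word sentence and the spans are invented for illustration.

def span_check(vec, i, j):
    # True if no position in the inclusive span [i, j] is covered
    k = i
    while k <= j:
        if vec[k]:
            return False
        k += 1
    return True

def span_inc(vec, i, j):
    # Cover every position in [i, j] by incrementing reference counts
    k = i
    while k <= j:
        vec[k] += 1
        k += 1

def span_dec(vec, i, j):
    # Uncover every position in [i, j] by decrementing reference counts
    k = i
    while k <= j:
        vec[k] -= 1
        k += 1

# The extractor covers target spans on the way down and uncovers them as it
# backtracks. Counts nest correctly: a position covered twice (e.g. two source
# words aligned to the same target word) still reads as covered, where a
# second bit-flip would have canceled the first.
cover = [0] * 6
assert span_check(cover, 2, 4)
span_inc(cover, 2, 4)                 # a non-terminal claims target span [2, 4]
assert not span_check(cover, 3, 3)    # collision detected while covered
span_inc(cover, 3, 3)                 # a second claim on position 3 nests safely
span_dec(cover, 3, 3)
span_dec(cover, 2, 4)                 # backtrack fully: the span is free again
assert span_check(cover, 2, 4)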