diff options
Diffstat (limited to 'python/src/sa/rulefactory.pxi')
-rw-r--r-- | python/src/sa/rulefactory.pxi | 156 |
1 files changed, 102 insertions, 54 deletions
diff --git a/python/src/sa/rulefactory.pxi b/python/src/sa/rulefactory.pxi index c26f5c43..be73f567 100644 --- a/python/src/sa/rulefactory.pxi +++ b/python/src/sa/rulefactory.pxi @@ -1836,18 +1836,24 @@ cdef class HieroCachingRuleFactory: # Pre-compute alignment info al = [[] for i in range(f_len)] - al_span = [[f_len + 1, -1] for i in range(f_len)] + fe_span = [[f_len + 1, -1] for i in range(f_len)] + ef_span = [[e_len + 1, -1] for i in range(e_len)] for (f, e) in alignment: al[f].append(e) - al_span[f][0] = min(al_span[f][0], e) - al_span[f][1] = max(al_span[f][1], e) + fe_span[f][0] = min(fe_span[f][0], e) + fe_span[f][1] = max(fe_span[f][1], e) + ef_span[e][0] = min(ef_span[e][0], f) + ef_span[e][1] = max(ef_span[e][1], f) # Target side word coverage cover = [0] * e_len - + # Non-terminal coverage + f_nt_cover = [0] * e_len + e_nt_cover = [0] * e_len + # Extract all possible hierarchical phrases starting at a source index # f_ i and j are current, e_ i and j are previous - def extract(f_i, f_j, e_i, e_j, wc, links, nt, nt_open): + def extract(f_i, f_j, e_i, e_j, min_bound, wc, links, nt, nt_open): # Phrase extraction limits if wc + len(nt) > self.max_length or (f_j + 1) > f_len or \ (f_j - f_i) + 1 > self.max_initial_size: @@ -1857,17 +1863,36 @@ cdef class HieroCachingRuleFactory: # Open non-terminal: extend if nt_open: nt[-1][2] += 1 - extract(f_i, f_j + 1, e_i, e_j, wc, links, nt, True) + extract(f_i, f_j + 1, e_i, e_j, min_bound, wc, links, nt, True) nt[-1][2] -= 1 # No open non-terminal: extend with word else: - extract(f_i, f_j + 1, e_i, e_j, wc + 1, links, nt, False) + extract(f_i, f_j + 1, e_i, e_j, min_bound, wc + 1, links, nt, False) return # Aligned word - link_i = al_span[f_j][0] - link_j = al_span[f_j][1] + link_i = fe_span[f_j][0] + link_j = fe_span[f_j][1] new_e_i = min(link_i, e_i) new_e_j = max(link_j, e_j) + # Check reverse links of newly covered words to see if they violate left + # bound (return) or extend minimum right bound for chunk + new_min_bound = min_bound + # First aligned word creates span + if e_j == -1: + for i from new_e_i <= i <= new_e_j: + if ef_span[i][0] < f_i: + return + new_min_bound = max(new_min_bound, ef_span[i][1]) + # Other aligned words extend span + else: + for i from new_e_i <= i < e_i: + if ef_span[i][0] < f_i: + return + new_min_bound = max(new_min_bound, ef_span[i][1]) + for i from e_j < i <= new_e_j: + if ef_span[i][0] < f_i: + return + new_min_bound = max(new_min_bound, ef_span[i][1]) # Open non-terminal: close, extract, extend if nt_open: # Close non-terminal, checking for collisions @@ -1877,43 +1902,50 @@ cdef class HieroCachingRuleFactory: if not span_check(cover, link_i, nt[-1][3] - 1): nt[-1] = old_last_nt return - span_flip(cover, link_i, nt[-1][3] - 1) + span_inc(cover, link_i, nt[-1][3] - 1) + span_inc(e_nt_cover, link_i, nt[-1][3] - 1) nt[-1][3] = link_i if link_j > nt[-1][4]: if not span_check(cover, nt[-1][4] + 1, link_j): nt[-1] = old_last_nt return - span_flip(cover, nt[-1][4] + 1, link_j) + span_inc(cover, nt[-1][4] + 1, link_j) + span_inc(e_nt_cover, nt[-1][4] + 1, link_j) nt[-1][4] = link_j - for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): - rules.add(rule) - extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False) + # Make sure we cover all aligned words + if f_j >= new_min_bound: + for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): + rules.add(rule) + extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc, links, nt, False) nt[-1] = old_last_nt if link_i < nt[-1][3]: - span_flip(cover, link_i, nt[-1][3] - 1) + span_dec(cover, link_i, nt[-1][3] - 1) + span_dec(e_nt_cover, link_i, nt[-1][3] - 1) if link_j > nt[-1][4]: - span_flip(cover, nt[-1][4] + 1, link_j) + span_dec(cover, nt[-1][4] + 1, link_j) + span_dec(e_nt_cover, nt[-1][4] + 1, link_j) return # No open non-terminal # Extract, extend with word - collision = False + nt_collision = False for link in al[f_j]: - if cover[link]: - collision = True - # Collisions block extraction and extension, but may be okay for - # continuing non-terminals - if not collision: + if e_nt_cover[link]: + nt_collision = True + # Non-terminal collisions block word extraction and extension, but + # may be okay for continuing non-terminals + if not nt_collision: plus_links = [] for link in al[f_j]: plus_links.append((f_j, link)) - cover[link] = ~cover[link] + cover[link] += 1 links.append(plus_links) - for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): - rules.add(rule) - extract(f_i, f_j + 1, new_e_i, new_e_j, wc + 1, links, nt, False) + if f_j >= new_min_bound: + for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): + rules.add(rule) + extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc + 1, links, nt, False) links.pop() for link in al[f_j]: - cover[link] = ~cover[link] + cover[link] -= 1 # Try to add a word to a (closed) non-terminal, extract, extend if nt and nt[-1][2] == f_j - 1: # Add to non-terminal, checking for collisions @@ -1923,38 +1955,46 @@ cdef class HieroCachingRuleFactory: if not span_check(cover, link_i, nt[-1][3] - 1): nt[-1] = old_last_nt return - span_flip(cover, link_i, nt[-1][3] - 1) + span_inc(cover, link_i, nt[-1][3] - 1) + span_inc(e_nt_cover, link_i, nt[-1][3] - 1) nt[-1][3] = link_i if link_j > nt[-1][4]: if not span_check(cover, nt[-1][4] + 1, link_j): nt[-1] = old_last_nt return - span_flip(cover, nt[-1][4] + 1, link_j) + span_inc(cover, nt[-1][4] + 1, link_j) + span_inc(e_nt_cover, nt[-1][4] + 1, link_j) nt[-1][4] = link_j # Require at least one word in phrase if links: - for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): - rules.add(rule) - extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False) + if f_j >= new_min_bound: + for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): + rules.add(rule) + extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc, links, nt, False) nt[-1] = old_last_nt if new_e_i < nt[-1][3]: - span_flip(cover, link_i, nt[-1][3] - 1) + span_dec(cover, link_i, nt[-1][3] - 1) + span_dec(e_nt_cover, link_i, nt[-1][3] - 1) if link_j > nt[-1][4]: - span_flip(cover, nt[-1][4] + 1, link_j) + span_dec(cover, nt[-1][4] + 1, link_j) + span_dec(e_nt_cover, nt[-1][4] + 1, link_j) # Try to start a new non-terminal, extract, extend if (not nt or f_j - nt[-1][2] > 1) and len(nt) < self.max_nonterminals: # Check for collisions if not span_check(cover, link_i, link_j): return - span_flip(cover, link_i, link_j) + span_inc(cover, link_i, link_j) + span_inc(e_nt_cover, link_i, link_j) nt.append([(nt[-1][0] + 1) if nt else 1, f_j, f_j, link_i, link_j]) # Require at least one word in phrase if links: - for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): - rules.add(rule) - extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False) + if f_j >= new_min_bound: + for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): + rules.add(rule) + extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc, links, nt, False) nt.pop() - span_flip(cover, link_i, link_j) + span_dec(cover, link_i, link_j) + span_dec(e_nt_cover, link_i, link_j) # Try to extract phrases from every f index f_i = 0 @@ -1963,7 +2003,7 @@ cdef class HieroCachingRuleFactory: if not al[f_i]: f_i += 1 continue - extract(f_i, f_i, f_len + 1, -1, 1, [], [], False) + extract(f_i, f_i, f_len + 1, -1, f_i, 1, [], [], False) f_i += 1 for rule in sorted(rules): @@ -2045,11 +2085,13 @@ cdef class HieroCachingRuleFactory: nt_i += 1 links[i][1] -= off i += 1 - + # Rule rules.append(self.new_rule(f_sym, e_sym, links)) if len(f_sym) >= self.max_length or len(nt) >= self.max_nonterminals: return rules + # DEBUG: no boundary NTs + return rules last_index = nt[-1][0] if nt else 0 # Rule [X] if not nt or not sym_isvar(f_sym[-1]): @@ -2098,16 +2140,6 @@ cdef class HieroCachingRuleFactory: logger.info('------------------------------') logger.info(' Online Stats ') logger.info('------------------------------') - logger.info('F') - for ph in self.phrases_f: - logger.info(str(ph) + ' ||| ' + str(self.phrases_f[ph])) - logger.info('E') - for ph in self.phrases_e: - logger.info(str(ph) + ' ||| ' + str(self.phrases_e[ph])) - logger.info('FE') - for ph in self.phrases_fe: - for ph2 in self.phrases_fe[ph]: - logger.info(str(ph) + ' ||| ' + str(ph2) + ' ||| ' + str(self.phrases_fe[ph][ph2])) logger.info('f') for w in self.bilex_f: logger.info(sym_tostring(w) + ' : ' + str(self.bilex_f[w])) @@ -2118,6 +2150,16 @@ cdef class HieroCachingRuleFactory: for w in self.bilex_fe: for w2 in self.bilex_fe[w]: logger.info(sym_tostring(w) + ' : ' + sym_tostring(w2) + ' : ' + str(self.bilex_fe[w][w2])) + logger.info('F') + for ph in self.phrases_f: + logger.info(str(ph) + ' ||| ' + str(self.phrases_f[ph])) + logger.info('E') + for ph in self.phrases_e: + logger.info(str(ph) + ' ||| ' + str(self.phrases_e[ph])) + logger.info('FE') + for ph in self.phrases_fe: + for ph2 in self.phrases_fe[ph]: + logger.info(self.fmt_rule(str(ph), str(ph2), self.phrases_al[ph][ph2]) + ' ||| ' + str(self.phrases_fe[ph][ph2])) # Spans are _inclusive_ on both ends [i, j] # Could be more efficient but probably not a bottleneck @@ -2129,8 +2171,14 @@ def span_check(vec, i, j): k += 1 return True -def span_flip(vec, i, j): +def span_inc(vec, i, j): k = i while k <= j: - vec[k] = ~vec[k] - k += 1
\ No newline at end of file + vec[k] += 1 + k += 1 + +def span_dec(vec, i, j): + k = i + while k <= j: + vec[k] -= 1 + k += 1 |