diff options
Diffstat (limited to 'python/src/sa/rulefactory.pxi')
-rw-r--r-- | python/src/sa/rulefactory.pxi | 48 |
1 files changed, 18 insertions, 30 deletions
diff --git a/python/src/sa/rulefactory.pxi b/python/src/sa/rulefactory.pxi index 2bc65da2..b10d25dd 100644 --- a/python/src/sa/rulefactory.pxi +++ b/python/src/sa/rulefactory.pxi @@ -1840,8 +1840,8 @@ cdef class HieroCachingRuleFactory: # Pre-compute alignment info al = [[] for i in range(f_len)] - fe_span = [[f_len + 1, -1] for i in range(f_len)] - ef_span = [[e_len + 1, -1] for i in range(e_len)] + fe_span = [[e_len + 1, -1] for i in range(f_len)] + ef_span = [[f_len + 1, -1] for i in range(e_len)] for (f, e) in alignment: al[f].append(e) fe_span[f][0] = min(fe_span[f][0], e) @@ -1857,10 +1857,10 @@ cdef class HieroCachingRuleFactory: # Extract all possible hierarchical phrases starting at a source index # f_ i and j are current, e_ i and j are previous + # We care _considering_ f_j, so it is not yet in counts def extract(f_i, f_j, e_i, e_j, min_bound, wc, links, nt, nt_open): # Phrase extraction limits - if wc + len(nt) > self.max_length or f_j > (f_len - 1) or \ - (f_j - f_i) + 1 > self.max_initial_size: + if f_j > (f_len - 1) or (f_j - f_i) + 1 > self.max_initial_size: return # Unaligned word if not al[f_j]: @@ -1870,7 +1870,8 @@ cdef class HieroCachingRuleFactory: extract(f_i, f_j + 1, e_i, e_j, min_bound, wc, links, nt, True) nt[-1][2] -= 1 # Unless non-terminal already open, always extend with word - if not nt_open: + # Make sure adding a word doesn't exceed length + if not nt_open and wc < self.max_length: extract(f_i, f_j + 1, e_i, e_j, min_bound, wc + 1, links, nt, False) return # Aligned word @@ -1881,11 +1882,8 @@ cdef class HieroCachingRuleFactory: # Check reverse links of newly covered words to see if they violate left # bound (return) or extend minimum right bound for chunk new_min_bound = min_bound - #new_min_nt_bound = min_nt_bound - # violates_nt = False # First aligned word creates span - # TODO: NO bound - if e_j == -1: + if e_j == -1: for i from new_e_i <= i <= new_e_j: if ef_span[i][0] < f_i: return @@ -1908,15 +1906,14 @@ cdef class HieroCachingRuleFactory: nt_collision = True # Non-terminal collisions block word extraction and extension, but # may be okay for continuing non-terminals - if not nt_collision: + if not nt_collision and wc < self.max_length: plus_links = [] for link in al[f_j]: plus_links.append((f_j, link)) cover[link] += 1 links.append(plus_links) - if links: - if f_j >= new_min_bound: - rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links)) + if links and f_j >= new_min_bound: + rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links)) extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc + 1, links, nt, False) links.pop() for link in al[f_j]: @@ -1940,9 +1937,8 @@ cdef class HieroCachingRuleFactory: span_inc(cover, nt[-1][4] + 1, link_j) span_inc(e_nt_cover, nt[-1][4] + 1, link_j) nt[-1][4] = link_j - if links: - if f_j >= new_min_bound: - rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links)) + if links and f_j >= new_min_bound: + rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links)) extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc, links, nt, False) nt[-1] = old_last_nt if link_i < nt[-1][3]: @@ -1952,7 +1948,7 @@ cdef class HieroCachingRuleFactory: span_dec(cover, nt[-1][4] + 1, link_j) span_dec(e_nt_cover, nt[-1][4] + 1, link_j) # Try to start a new non-terminal, extract, extend - if (not nt or f_j - nt[-1][2] > 1) and len(nt) < self.max_nonterminals: + if (not nt or f_j - nt[-1][2] > 1) and wc < self.max_length and len(nt) < self.max_nonterminals: # Check for collisions if not span_check(cover, link_i, link_j): return @@ -1960,26 +1956,19 @@ cdef class HieroCachingRuleFactory: span_inc(e_nt_cover, link_i, link_j) nt.append([(nt[-1][0] + 1) if nt else 1, f_j, f_j, link_i, link_j]) # Require at least one word in phrase - if links: - if f_j >= new_min_bound: - rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links)) - extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc, links, nt, False) + if links and f_j >= new_min_bound: + rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links)) + extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc + 1, links, nt, False) nt.pop() span_dec(cover, link_i, link_j) span_dec(e_nt_cover, link_i, link_j) # Try to extract phrases from every f index - f_i = 0 - while f_i < f_len: + for f_i from 0 <= f_i < f_len: # Skip if phrases won't be tight on left side if not al[f_i]: - f_i += 1 continue - extract(f_i, f_i, f_len + 1, -1, f_i, 1, [], [], False) - f_i += 1 - - for rule in sorted(rules): - logger.info(self.fmt_rule(*rule[:3])) + extract(f_i, f_i, f_len + 1, -1, f_i, 0, [], [], False) # Update phrase counts for rule in rules: @@ -2062,7 +2051,6 @@ cdef class HieroCachingRuleFactory: f = Phrase(f_sym) e = Phrase(e_sym) a = tuple(self.alignment.link(i, j) for (i, j) in links) - print 'New rule:', self.fmt_rule(f, e, a) return (f, e, a, lex_f_i, lex_f_j) # Rule string from rule |