diff options
author | Michael Denkowski <michael.j.denkowski@gmail.com> | 2013-01-07 10:03:35 -0500 |
---|---|---|
committer | Michael Denkowski <michael.j.denkowski@gmail.com> | 2013-01-07 10:03:35 -0500 |
commit | df1793f8ba6b4ea8097c94319eb93838bc497c28 (patch) | |
tree | 5609cae260a57415f18dac13ecd8c138cc1426d3 /python/src | |
parent | 99763ee6ad503f1264df2f661317d6212d30272d (diff) |
code cleanup
Diffstat (limited to 'python/src')
-rw-r--r-- | python/src/sa/rulefactory.pxi | 109 |
1 files changed, 42 insertions, 67 deletions
diff --git a/python/src/sa/rulefactory.pxi b/python/src/sa/rulefactory.pxi index 29bd809c..2bc65da2 100644 --- a/python/src/sa/rulefactory.pxi +++ b/python/src/sa/rulefactory.pxi @@ -1829,10 +1829,10 @@ cdef class HieroCachingRuleFactory: def add_instance(self, f_words, e_words, alignment): # Rules extracted from this instance - # Track span of absolute alignment links to make + # Track span of lexical items (terminals) to make # sure we don't extract the same rule for the same # span more than once. - # (f, e, al, f_min_link, f_max_link) + # (f, e, al, lex_f_i, lex_f_j) rules = set() f_len = len(f_words) @@ -1852,14 +1852,14 @@ cdef class HieroCachingRuleFactory: # Target side word coverage cover = [0] * e_len # Non-terminal coverage - f_nt_cover = [0] * e_len + f_nt_cover = [0] * f_len e_nt_cover = [0] * e_len # Extract all possible hierarchical phrases starting at a source index # f_ i and j are current, e_ i and j are previous def extract(f_i, f_j, e_i, e_j, min_bound, wc, links, nt, nt_open): # Phrase extraction limits - if wc + len(nt) > self.max_length or (f_j + 1) > f_len or \ + if wc + len(nt) > self.max_length or f_j > (f_len - 1) or \ (f_j - f_i) + 1 > self.max_initial_size: return # Unaligned word @@ -1881,7 +1881,10 @@ cdef class HieroCachingRuleFactory: # Check reverse links of newly covered words to see if they violate left # bound (return) or extend minimum right bound for chunk new_min_bound = min_bound + #new_min_nt_bound = min_nt_bound + # violates_nt = False # First aligned word creates span + # TODO: NO bound if e_j == -1: for i from new_e_i <= i <= new_e_j: if ef_span[i][0] < f_i: @@ -1897,61 +1900,28 @@ cdef class HieroCachingRuleFactory: if ef_span[i][0] < f_i: return new_min_bound = max(new_min_bound, ef_span[i][1]) - # Open non-terminal: close, extract, extend - if nt_open: - # Close non-terminal, checking for collisions - old_last_nt = nt[-1][:] - nt[-1][2] = f_j - if link_i < nt[-1][3]: - if not span_check(cover, link_i, nt[-1][3] - 1): - nt[-1] = old_last_nt - return - span_inc(cover, link_i, nt[-1][3] - 1) - span_inc(e_nt_cover, link_i, nt[-1][3] - 1) - nt[-1][3] = link_i - if link_j > nt[-1][4]: - if not span_check(cover, nt[-1][4] + 1, link_j): - nt[-1] = old_last_nt - return - span_inc(cover, nt[-1][4] + 1, link_j) - span_inc(e_nt_cover, nt[-1][4] + 1, link_j) - nt[-1][4] = link_j - # Make sure we have at least one lexical alignment link - if links: - # Make sure we cover all aligned words - if f_j >= new_min_bound: - rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links)) - extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc, links, nt, False) - nt[-1] = old_last_nt - if link_i < nt[-1][3]: - span_dec(cover, link_i, nt[-1][3] - 1) - span_dec(e_nt_cover, link_i, nt[-1][3] - 1) - if link_j > nt[-1][4]: - span_dec(cover, nt[-1][4] + 1, link_j) - span_dec(e_nt_cover, nt[-1][4] + 1, link_j) - return - # No open non-terminal - # Extract, extend with word - nt_collision = False - for link in al[f_j]: - if e_nt_cover[link]: - nt_collision = True - # Non-terminal collisions block word extraction and extension, but - # may be okay for continuing non-terminals - if not nt_collision: - plus_links = [] - for link in al[f_j]: - plus_links.append((f_j, link)) - cover[link] += 1 - links.append(plus_links) - if links: - if f_j >= new_min_bound: - rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links)) - extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc + 1, links, nt, False) - links.pop() + # Extract, extend with word (unless non-terminal open) + if not nt_open: + nt_collision = False for link in al[f_j]: - cover[link] -= 1 - # Try to add a word to a (closed) non-terminal, extract, extend + if e_nt_cover[link]: + nt_collision = True + # Non-terminal collisions block word extraction and extension, but + # may be okay for continuing non-terminals + if not nt_collision: + plus_links = [] + for link in al[f_j]: + plus_links.append((f_j, link)) + cover[link] += 1 + links.append(plus_links) + if links: + if f_j >= new_min_bound: + rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links)) + extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc + 1, links, nt, False) + links.pop() + for link in al[f_j]: + cover[link] -= 1 + # Try to add a word to current non-terminal (if any), extract, extend if nt and nt[-1][2] == f_j - 1: # Add to non-terminal, checking for collisions old_last_nt = nt[-1][:] @@ -1972,10 +1942,10 @@ cdef class HieroCachingRuleFactory: nt[-1][4] = link_j if links: if f_j >= new_min_bound: - rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links)) + rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links)) extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc, links, nt, False) nt[-1] = old_last_nt - if new_e_i < nt[-1][3]: + if link_i < nt[-1][3]: span_dec(cover, link_i, nt[-1][3] - 1) span_dec(e_nt_cover, link_i, nt[-1][3] - 1) if link_j > nt[-1][4]: @@ -2032,10 +2002,8 @@ cdef class HieroCachingRuleFactory: # Create a rule from source, target, non-terminals, and alignments def form_rule(self, f_i, e_i, f_span, e_span, nt, al): - # Handle non-terminals - + # Substitute in non-terminals nt_inv = sorted(nt, cmp=lambda x, y: cmp(x[3], y[3])) - f_sym = list(f_span[:]) off = f_i for next_nt in nt: @@ -2046,7 +2014,6 @@ cdef class HieroCachingRuleFactory: i += 1 f_sym.insert(next_nt[1] - off, sym_setindex(self.category, next_nt[0])) off += (nt_len - 1) - e_sym = list(e_span[:]) off = e_i for next_nt in nt_inv: @@ -2060,8 +2027,6 @@ cdef class HieroCachingRuleFactory: # Adjusting alignment links takes some doing links = [list(link) for sub in al for link in sub] - f_link_min = links[0][0] - f_link_max = links[-1][0] links_inv = sorted(links, cmp=lambda x, y: cmp(x[1], y[1])) links_len = len(links) nt_len = len(nt) @@ -2084,11 +2049,21 @@ cdef class HieroCachingRuleFactory: links_inv[i][1] -= off i += 1 + # Find lexical span + lex_f_i = f_i + lex_f_j = f_i + (len(f_span) - 1) + if nt: + if nt[0][1] == lex_f_i: + lex_f_i += (nt[0][2] - nt[0][1]) + 1 + if nt[-1][2] == lex_f_j: + lex_f_j -= (nt[-1][2] - nt[-1][1]) + 1 + # Create rule (f_phrase, e_phrase, links, f_link_min, f_link_max) f = Phrase(f_sym) e = Phrase(e_sym) a = tuple(self.alignment.link(i, j) for (i, j) in links) - return (f, e, a, f_link_min, f_link_max) + print 'New rule:', self.fmt_rule(f, e, a) + return (f, e, a, lex_f_i, lex_f_j) # Rule string from rule def fmt_rule(self, f, e, a): |