Diffstat (limited to 'python/src/sa/rulefactory.pxi')
-rw-r--r-- | python/src/sa/rulefactory.pxi | 86 |
1 file changed, 22 insertions(+), 64 deletions(-)
diff --git a/python/src/sa/rulefactory.pxi b/python/src/sa/rulefactory.pxi
index 3fcf8879..29bd809c 100644
--- a/python/src/sa/rulefactory.pxi
+++ b/python/src/sa/rulefactory.pxi
@@ -1829,6 +1829,10 @@ cdef class HieroCachingRuleFactory:
     def add_instance(self, f_words, e_words, alignment):
 
         # Rules extracted from this instance
+        # Track span of absolute alignment links to make
+        # sure we don't extract the same rule for the same
+        # span more than once.
+        # (f, e, al, f_min_link, f_max_link)
         rules = set()
 
         f_len = len(f_words)
@@ -1916,8 +1920,7 @@ cdef class HieroCachingRuleFactory:
                     if links:
                         # Make sure we cover all aligned words
                         if f_j >= new_min_bound:
-                            for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
-                                rules.add(rule)
+                            rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
                         extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc, links, nt, False)
                     nt[-1] = old_last_nt
                     if link_i < nt[-1][3]:
@@ -1943,8 +1946,7 @@ cdef class HieroCachingRuleFactory:
                    links.append(plus_links)
                    if links:
                        if f_j >= new_min_bound:
-                            for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
-                                rules.add(rule)
+                            rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
                        extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc + 1, links, nt, False)
                    links.pop()
                    for link in al[f_j]:
@@ -1970,8 +1972,7 @@ cdef class HieroCachingRuleFactory:
                    nt[-1][4] = link_j
                    if links:
                        if f_j >= new_min_bound:
-                            for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
-                                rules.add(rule)
+                            rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
                        extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc, links, nt, False)
                    nt[-1] = old_last_nt
                    if new_e_i < nt[-1][3]:
@@ -1991,8 +1992,7 @@ cdef class HieroCachingRuleFactory:
                    # Require at least one word in phrase
                    if links:
                        if f_j >= new_min_bound:
-                            for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
-                                rules.add(rule)
+                            rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
                        extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc, links, nt, False)
                    nt.pop()
                    span_dec(cover, link_i, link_j)
@@ -2009,21 +2009,16 @@ cdef class HieroCachingRuleFactory:
            f_i += 1
 
        for rule in sorted(rules):
-            logger.info(self.fmt_rule(*rule))
+            logger.info(self.fmt_rule(*rule[:3]))
 
        # Update phrase counts
-        f_set = set()
-        e_set = set()
-        for (f_ph, e_ph, al) in rules:
-            f_set.add(f_ph)
-            e_set.add(e_ph)
+        for rule in rules:
+            (f_ph, e_ph, al) = rule[:3]
+            self.phrases_f[f_ph] += 1
+            self.phrases_e[e_ph] += 1
            self.phrases_fe[f_ph][e_ph] += 1
            if not self.phrases_al[f_ph][e_ph]:
                self.phrases_al[f_ph][e_ph] = al
-        for f_ph in f_set:
-            self.phrases_f[f_ph] += 1
-        for e_ph in e_set:
-            self.phrases_e[e_ph] += 1
 
        # Update Bilexical counts
        for e_w in e_words:
@@ -2035,11 +2030,10 @@ cdef class HieroCachingRuleFactory:
 
    # Create a rule from source, target, non-terminals, and alignments
-    def form_rules(self, f_i, e_i, f_span, e_span, nt, al):
-
-        # This could be more efficient but is unlikely to be the bottleneck
-        rules = []
+    def form_rule(self, f_i, e_i, f_span, e_span, nt, al):
 
+        # Handle non-terminals
+        nt_inv = sorted(nt, cmp=lambda x, y: cmp(x[3], y[3]))
 
        f_sym = list(f_span[:])
@@ -2066,6 +2060,8 @@ cdef class HieroCachingRuleFactory:
 
        # Adjusting alignment links takes some doing
        links = [list(link) for sub in al for link in sub]
+        f_link_min = links[0][0]
+        f_link_max = links[-1][0]
        links_inv = sorted(links, cmp=lambda x, y: cmp(x[1], y[1]))
        links_len = len(links)
        nt_len = len(nt)
@@ -2088,51 +2084,13 @@ cdef class HieroCachingRuleFactory:
                links_inv[i][1] -= off
            i += 1
 
-        # Rule
-        rules.append(self.new_rule(f_sym, e_sym, links))
-        if len(f_sym) >= self.max_length or len(nt) >= self.max_nonterminals:
-            return rules
-        # DEBUG: no boundary NTs
-        return rules
-        last_index = nt[-1][0] if nt else 0
-        # Rule [X]
-        if not nt or not sym_isvar(f_sym[-1]):
-            f_sym.append(sym_setindex(self.category, last_index + 1))
-            e_sym.append(sym_setindex(self.category, last_index + 1))
-            rules.append(self.new_rule(f_sym, e_sym, links))
-            f_sym.pop()
-            e_sym.pop()
-        # [X] Rule
-        f_len = len(f_sym)
-        e_len = len(e_sym)
-        if not nt or not sym_isvar(f_sym[0]):
-            for i from 0 <= i < f_len:
-                if sym_isvar(f_sym[i]):
-                    f_sym[i] = sym_setindex(self.category, sym_getindex(f_sym[i]) + 1)
-            for i from 0 <= i < e_len:
-                if sym_isvar(e_sym[i]):
-                    e_sym[i] = sym_setindex(self.category, sym_getindex(e_sym[i]) + 1)
-            for link in links:
-                link[0] += 1
-                link[1] += 1
-            f_sym.insert(0, sym_setindex(self.category, 1))
-            e_sym.insert(0, sym_setindex(self.category, 1))
-            rules.append(self.new_rule(f_sym, e_sym, links))
-        if len(f_sym) >= self.max_length or len(nt) + 1 >= self.max_nonterminals:
-            return rules
-        # [X] Rule [X]
-        if not nt or not sym_isvar(f_sym[-1]):
-            f_sym.append(sym_setindex(self.category, last_index + 2))
-            e_sym.append(sym_setindex(self.category, last_index + 2))
-            rules.append(self.new_rule(f_sym, e_sym, links))
-        return rules
-
-    def new_rule(self, f_sym, e_sym, links):
+        # Create rule (f_phrase, e_phrase, links, f_link_min, f_link_max)
        f = Phrase(f_sym)
        e = Phrase(e_sym)
        a = tuple(self.alignment.link(i, j) for (i, j) in links)
-        return (f, e, a)
-
+        return (f, e, a, f_link_min, f_link_max)
+
+    # Rule string from rule
    def fmt_rule(self, f, e, a):
        a_str = ' '.join('{0}-{1}'.format(*self.alignment.unlink(packed)) for packed in a)
        return '[X] ||| {0} ||| {1} ||| {2}'.format(f, e, a_str)
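The diff above replaces the boundary-non-terminal variants produced by form_rules with a single form_rule that also records the absolute source-side alignment span (f_link_min, f_link_max) of each extracted rule, and it updates phrase counts once per rule instance instead of once per distinct source/target phrase. Below is a minimal plain-Python sketch of the effect of carrying the span inside the rule tuple: re-extracting the same rule over the same span is collapsed by the set, while the same rule found over a different span survives as its own instance and contributes its own counts. The toy phrases, link positions, and counters are made up for illustration and are not the Phrase/Alignment objects or counters used in rulefactory.pxi.

    from collections import defaultdict

    rules = set()

    # Same rule, same absolute source span: the set keeps a single entry.
    rules.add(('ne X1 pas', 'not X1', ((0, 0), (2, 1)), 3, 5))
    rules.add(('ne X1 pas', 'not X1', ((0, 0), (2, 1)), 3, 5))

    # Same rule, different span: kept as a separate instance.
    rules.add(('ne X1 pas', 'not X1', ((0, 0), (2, 1)), 8, 10))

    phrases_f = defaultdict(int)
    phrases_e = defaultdict(int)
    phrases_fe = defaultdict(lambda: defaultdict(int))

    for rule in rules:
        (f_ph, e_ph, al) = rule[:3]  # drop the span fields, as fmt_rule(*rule[:3]) does
        phrases_f[f_ph] += 1
        phrases_e[e_ph] += 1
        phrases_fe[f_ph][e_ph] += 1

    print(len(rules))                           # 2: duplicate at span (3, 5) collapsed
    print(phrases_fe['ne X1 pas']['not X1'])    # 2: one count per distinct span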