diff options
-rwxr-xr-x | python/src/sa/online_extractor.py | 156 |
1 files changed, 94 insertions, 62 deletions
diff --git a/python/src/sa/online_extractor.py b/python/src/sa/online_extractor.py index 0c013cb3..90087f30 100755 --- a/python/src/sa/online_extractor.py +++ b/python/src/sa/online_extractor.py @@ -44,63 +44,6 @@ def fmt_rule(f_sym, e_sym, links): ' '.join(str(sym) for sym in e_sym), a_str) -# Create a rule from source, target, non-terminals, and alignments -def form_rules(f_i, e_i, f_span, e_span, nt, al): - - # This could be more efficient but is unlikely to be the bottleneck - - rules = [] - - nt_inv = sorted(nt, cmp=lambda x, y: cmp(x[3], y[3])) - - f_sym = f_span[:] - off = f_i - for next_nt in nt: - nt_len = (next_nt[2] - next_nt[1]) + 1 - i = 0 - while i < nt_len: - f_sym.pop(next_nt[1] - off) - i += 1 - f_sym.insert(next_nt[1] - off, NonTerminal(next_nt[0])) - off += (nt_len - 1) - - e_sym = e_span[:] - off = e_i - for next_nt in nt_inv: - nt_len = (next_nt[4] - next_nt[3]) + 1 - i = 0 - while i < nt_len: - e_sym.pop(next_nt[3] - off) - i += 1 - e_sym.insert(next_nt[3] - off, NonTerminal(next_nt[0])) - off += (nt_len - 1) - - # Adjusting alignment links takes some doing - links = [list(link) for sub in al for link in sub] - links_len = len(links) - nt_len = len(nt) - nt_i = 0 - off = f_i - i = 0 - while i < links_len: - while nt_i < nt_len and links[i][0] > nt[nt_i][1]: - off += (nt[nt_i][2] - nt[nt_i][1]) - nt_i += 1 - links[i][0] -= off - i += 1 - nt_i = 0 - off = e_i - i = 0 - while i < links_len: - while nt_i < nt_len and links[i][1] > nt_inv[nt_i][3]: - off += (nt_inv[nt_i][4] - nt_inv[nt_i][3]) - nt_i += 1 - links[i][1] -= off - i += 1 - - rules.append(fmt_rule(f_sym, e_sym, links)) - return rules - class OnlineGrammarExtractor: def __init__(self, config=None): @@ -168,7 +111,7 @@ class OnlineGrammarExtractor: # f_ i and j are current, e_ i and j are previous def extract(f_i, f_j, e_i, e_j, wc, links, nt, nt_open): # Phrase extraction limits - if wc > self.max_length or (f_j + 1) >= f_len or \ + if wc > self.max_length or (f_j + 1) > f_len or \ (f_j - f_i) + 1 > self.max_size: return # Unaligned word @@ -204,7 +147,7 @@ class OnlineGrammarExtractor: return span_flip(cover, nt[-1][4] + 1, link_j) nt[-1][4] = link_j - for rule in form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): + for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): phrases.add(rule) extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False) nt[-1] = old_last_nt @@ -227,7 +170,7 @@ class OnlineGrammarExtractor: plus_links.append((f_j, link)) cover[link] = ~cover[link] links.append(plus_links) - for rule in form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): + for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): phrases.add(rule) extract(f_i, f_j + 1, new_e_i, new_e_j, wc + 1, links, nt, False) links.pop() @@ -252,7 +195,7 @@ class OnlineGrammarExtractor: nt[-1][4] = link_j # Require at least one word in phrase if links: - for rule in form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): + for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): phrases.add(rule) extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False) nt[-1] = old_last_nt @@ -269,7 +212,7 @@ class OnlineGrammarExtractor: nt.append([next_nt(nt), f_j, f_j, link_i, link_j]) # Require at least one word in phrase if links: - for rule in form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): + for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): phrases.add(rule) extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False) nt.pop() @@ -290,6 +233,95 @@ class OnlineGrammarExtractor: for rule in sorted(phrases): print rule + # Create a rule from source, target, non-terminals, and alignments + def form_rules(self, f_i, e_i, f_span, e_span, nt, al): + + # This could be more efficient but is unlikely to be the bottleneck + + rules = [] + + nt_inv = sorted(nt, cmp=lambda x, y: cmp(x[3], y[3])) + + f_sym = f_span[:] + off = f_i + for next_nt in nt: + nt_len = (next_nt[2] - next_nt[1]) + 1 + i = 0 + while i < nt_len: + f_sym.pop(next_nt[1] - off) + i += 1 + f_sym.insert(next_nt[1] - off, NonTerminal(next_nt[0])) + off += (nt_len - 1) + + e_sym = e_span[:] + off = e_i + for next_nt in nt_inv: + nt_len = (next_nt[4] - next_nt[3]) + 1 + i = 0 + while i < nt_len: + e_sym.pop(next_nt[3] - off) + i += 1 + e_sym.insert(next_nt[3] - off, NonTerminal(next_nt[0])) + off += (nt_len - 1) + + # Adjusting alignment links takes some doing + links = [list(link) for sub in al for link in sub] + links_len = len(links) + nt_len = len(nt) + nt_i = 0 + off = f_i + i = 0 + while i < links_len: + while nt_i < nt_len and links[i][0] > nt[nt_i][1]: + off += (nt[nt_i][2] - nt[nt_i][1]) + nt_i += 1 + links[i][0] -= off + i += 1 + nt_i = 0 + off = e_i + i = 0 + while i < links_len: + while nt_i < nt_len and links[i][1] > nt_inv[nt_i][3]: + off += (nt_inv[nt_i][4] - nt_inv[nt_i][3]) + nt_i += 1 + links[i][1] -= off + i += 1 + + # Rule + rules.append(fmt_rule(f_sym, e_sym, links)) + if len(f_sym) >= self.max_length or len(nt) >= self.max_nonterminals: + return rules + last_index = nt[-1][0] if nt else 0 + # Rule [X] + if not nt or not isinstance(f_sym[-1], NonTerminal): + f_sym.append(NonTerminal(last_index + 1)) + e_sym.append(NonTerminal(last_index + 1)) + rules.append(fmt_rule(f_sym, e_sym, links)) + f_sym.pop() + e_sym.pop() + # [X] Rule + if not nt or not isinstance(f_sym[0], NonTerminal): + for sym in f_sym: + if isinstance(sym, NonTerminal): + sym.index += 1 + for sym in e_sym: + if isinstance(sym, NonTerminal): + sym.index += 1 + for link in links: + link[0] += 1 + link[1] += 1 + f_sym.insert(0, NonTerminal(1)) + e_sym.insert(0, NonTerminal(1)) + rules.append(fmt_rule(f_sym, e_sym, links)) + if len(f_sym) >= self.max_length or len(nt) + 1 >= self.max_nonterminals: + return rules + # [X] Rule [X] + if not nt or not isinstance(f_sym[-1], NonTerminal): + f_sym.append(NonTerminal(last_index + 2)) + e_sym.append(NonTerminal(last_index + 2)) + rules.append(fmt_rule(f_sym, e_sym, links)) + return rules + def main(argv): extractor = OnlineGrammarExtractor() |