diff options
Diffstat (limited to 'python/src/sa/online_extractor.py')
-rwxr-xr-x | python/src/sa/online_extractor.py | 48 |
1 files changed, 34 insertions, 14 deletions
diff --git a/python/src/sa/online_extractor.py b/python/src/sa/online_extractor.py index fd4bb5f5..06eb5357 100755 --- a/python/src/sa/online_extractor.py +++ b/python/src/sa/online_extractor.py @@ -34,17 +34,11 @@ def next_nt(nt): # Create a rule from source, target, non-terminals, and alignments def form_rule(f_i, e_i, f_span, e_span, nt, al): - flat = (item for sub in al for item in sub) - astr = ' '.join('{0}-{1}'.format(x[0], x[1]) for x in flat) - -# print '--- Rule' -# print f_span -# print e_span -# print nt -# print astr -# print '---' # This could be more efficient but is unlikely to be the bottleneck + + nt_inv = sorted(nt, cmp=lambda x, y: cmp(x[3], y[3])) + f_sym = f_span[:] off = f_i for next_nt in nt: @@ -55,9 +49,10 @@ def form_rule(f_i, e_i, f_span, e_span, nt, al): i += 1 f_sym.insert(next_nt[1] - off, '[X,{0}]'.format(next_nt[0])) off += (nt_len - 1) + e_sym = e_span[:] off = e_i - for next_nt in sorted(nt, cmp=lambda x, y: cmp(x[3], y[3])): + for next_nt in nt_inv: nt_len = (next_nt[4] - next_nt[3]) + 1 i = 0 while i < nt_len: @@ -65,7 +60,32 @@ def form_rule(f_i, e_i, f_span, e_span, nt, al): i += 1 e_sym.insert(next_nt[3] - off, '[X,{0}]'.format(next_nt[0])) off += (nt_len - 1) - return '[X] ||| {0} ||| {1} ||| {2}'.format(' '.join(f_sym), ' '.join(e_sym), astr) + + # Adjusting alignment links takes some doing + links = [list(link) for sub in al for link in sub] + links_len = len(links) + nt_len = len(nt) + nt_i = 0 + off = f_i + i = 0 + while i < links_len: + while nt_i < nt_len and links[i][0] > nt[nt_i][1]: + off += (nt[nt_i][2] - nt[nt_i][1]) + nt_i += 1 + links[i][0] -= off + i += 1 + nt_i = 0 + off = e_i + i = 0 + while i < links_len: + while nt_i < nt_len and links[i][1] > nt_inv[nt_i][3]: + off += (nt_inv[nt_i][4] - nt_inv[nt_i][3]) + nt_i += 1 + links[i][1] -= off + i += 1 + a_str = ' '.join('{0}-{1}'.format(i, j) for (i, j) in links) + + return '[X] ||| {0} ||| {1} ||| {2}'.format(' '.join(f_sym), ' '.join(e_sym), a_str) class OnlineGrammarExtractor: @@ -135,7 +155,7 @@ class OnlineGrammarExtractor: def extract(f_i, f_j, e_i, e_j, wc, links, nt, nt_open): # Phrase extraction limits if wc > self.max_length or (f_j + 1) >= f_len or \ - (f_j - f_i) + 1 > self.max_size or len(nt) > self.max_nonterminals: + (f_j - f_i) + 1 > self.max_size: return # Unaligned word if not al[f_j]: @@ -224,7 +244,7 @@ class OnlineGrammarExtractor: if link_j > nt[-1][4]: span_flip(cover, nt[-1][4] + 1, link_j) # Try to start a new non-terminal, extract, extend - if not nt or f_j - nt[-1][2] > 1: + if (not nt or f_j - nt[-1][2] > 1) and len(nt) < self.max_nonterminals: # Check for collisions if not span_check(cover, link_i, link_j): return @@ -238,7 +258,7 @@ class OnlineGrammarExtractor: span_flip(cover, link_i, link_j) # TODO: try adding NT to start, end, both # check: one aligned word on boundary that is not part of a NT - + # Try to extract phrases from every f index f_i = 0 while f_i < f_len: |