diff options
author | Michael Denkowski <michael.j.denkowski@gmail.com> | 2012-12-20 15:42:22 -0500 |
---|---|---|
committer | Michael Denkowski <michael.j.denkowski@gmail.com> | 2012-12-20 15:42:22 -0500 |
commit | f6b1b4d636d403fe7258e49c842d7442e638092e (patch) | |
tree | 566235baf46d1a4c9ddea0fd4f70b7941e4b78a2 | |
parent | 1cd434c2c971f23bc9fcbe18062b24c02873126d (diff) |
hiero phrase extraction. Don't trust the word alignments yet.
-rwxr-xr-x | python/src/sa/online_extractor.py | 230 |
1 files changed, 230 insertions, 0 deletions
diff --git a/python/src/sa/online_extractor.py b/python/src/sa/online_extractor.py new file mode 100755 index 00000000..d41f3b39 --- /dev/null +++ b/python/src/sa/online_extractor.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python + +import collections, sys + +def main(argv): + + for line in sys.stdin: + src, tgt, astr = (x.split() for x in line.split('|||')) + al = sorted(tuple(int(y) for y in x.split('-')) for x in astr) + extract_and_aggr(src, tgt, al) + +# Extract hierarchical phrase pairs +# This could be far better optimized by integrating it +# with suffix array code. For now, it gets the job done. +def extract_and_aggr(src, tgt, al, max_len=5, max_size=15, max_nt=2, boundary_nt=True): + + src_ph = collections.defaultdict(lambda: 0) # src = count + tgt_ph = collections.defaultdict(lambda: 0) # tgt = count + # [src][tgt] = count + phrase_pairs = collections.defaultdict(lambda: collections.defaultdict(lambda: 0)) + + src_w = collections.defaultdict(lambda: 0) # count + tgt_w = collections.defaultdict(lambda: 0) # count + # [src][tgt] = count + cooc_w = collections.defaultdict(lambda: collections.defaultdict(lambda: 0)) + + # Bilexical counts + for word in tgt: + tgt_w[word] += 1 + for word in src: + src_w[word] += 1 + for t_word in tgt: + cooc_w[word][t_word] += 1 + + def next_nt(nt): + if not nt: + return 1 + return nt[-1][0] + 1 + + src_len = len(src) + + a = [[] for i in range(src_len)] + + # Pre-compute alignment min and max for each word + a_span = [[src_len + 1, -1] for i in range(src_len)] + for (s, t) in al: + a[s].append(t) + a_span[s][0] = min(a_span[s][0], t) + a_span[s][1] = max(a_span[s][1], t) + + # Target side non-terimnal coverage + # Cython bit vector? + cover = [0] * src_len + + print src + print tgt + print a_span + + # Spans are _inclusive_ on both ends [i, j] + def span_check(vec, i, j): + k = i + while k <= j: + if vec[k]: + return False + k += 1 + return True + + def span_flip(vec, i, j): + k = i + while k <= j: + vec[k] = ~vec[k] + k += 1 + + # Extract all possible hierarchical phrases starting at a source index + # src i and j are current, tgt i and j are previous + def extract(src_i, src_j, tgt_i, tgt_j, wc, al, nt, nt_open): + # Phrase extraction limits + if wc > max_len or (src_j + 1) >= src_len or \ + (src_j - src_i) + 1 > max_size or len(nt) > max_nt: + return + # Unaligned word + if not a[src_j]: + # Open non-terminal: extend + if nt_open: + nt[-1][2] += 1 + extract(src_i, src_j + 1, tgt_i, tgt_j, wc, al, nt, True) + nt[-1][2] -= 1 + # No open non-terminal: extend with word + else: + extract(src_i, src_j + 1, tgt_i, tgt_j, wc + 1, al, nt, False) + return + # Aligned word + link_i = a_span[src_j][0] + link_j = a_span[src_j][1] + new_tgt_i = min(link_i, tgt_i) + new_tgt_j = max(link_j, tgt_j) + # Open non-terminal: close, extract, extend + if nt_open: + # Close non-terminal, checking for collisions + old_last_nt = nt[-1][:] + nt[-1][2] = src_j + if link_i < nt[-1][3]: + if not span_check(cover, link_i, nt[-1][3] - 1): + nt[-1] = old_last_nt + return + span_flip(cover, link_i, nt[-1][3] - 1) + nt[-1][3] = link_i + if link_j > nt[-1][4]: + if not span_check(cover, nt[-1][4] + 1, link_j): + nt[-1] = old_last_nt + return + span_flip(cover, nt[-1][4] + 1, link_j) + nt[-1][4] = link_j + add_rule(src_i, new_tgt_i, src[src_i:src_j + 1], tgt[new_tgt_i:new_tgt_j + 1], nt, al) + extract(src_i, src_j + 1, new_tgt_i, new_tgt_j, wc, al, nt, False) + nt[-1] = old_last_nt + if link_i < nt[-1][3]: + span_flip(cover, link_i, nt[-1][3] - 1) + if link_j > nt[-1][4]: + span_flip(cover, nt[-1][4] + 1, link_j) + return + # No open non-terminal + # Extract, extend with word + collision = False + for link in a[src_j]: + if cover[link]: + collision = True + # Collisions block extraction and extension, but may be okay for + # continuing non-terminals + if not collision: + plus_al = [] + for link in a[src_j]: + plus_al.append((src_j, link)) + cover[link] = ~cover[link] + al.append(plus_al) + add_rule(src_i, new_tgt_i, src[src_i:src_j + 1], tgt[new_tgt_i:new_tgt_j + 1], nt, al) + extract(src_i, src_j + 1, new_tgt_i, new_tgt_j, wc + 1, al, nt, False) + al.pop() + for link in a[src_j]: + cover[link] = ~cover[link] + # Try to add a word to a (closed) non-terminal, extract, extend + if nt and nt[-1][2] == src_j - 1: + # Add to non-terminal, checking for collisions + old_last_nt = nt[-1][:] + nt[-1][2] = src_j + if link_i < nt[-1][3]: + if not span_check(cover, link_i, nt[-1][3] - 1): + nt[-1] = old_last_nt + return + span_flip(cover, link_i, nt[-1][3] - 1) + nt[-1][3] = link_i + if link_j > nt[-1][4]: + if not span_check(cover, nt[-1][4] + 1, link_j): + nt[-1] = old_last_nt + return + span_flip(cover, nt[-1][4] + 1, link_j) + nt[-1][4] = link_j + # Require at least one word in phrase + if al: + add_rule(src_i, new_tgt_i, src[src_i:src_j + 1], tgt[new_tgt_i:new_tgt_j + 1], nt, al) + extract(src_i, src_j + 1, new_tgt_i, new_tgt_j, wc, al, nt, False) + nt[-1] = old_last_nt + if new_tgt_i < nt[-1][3]: + span_flip(cover, link_i, nt[-1][3] - 1) + if link_j > nt[-1][4]: + span_flip(cover, nt[-1][4] + 1, link_j) + # Try to start a new non-terminal, extract, extend + if not nt or src_j - nt[-1][2] > 1: + # Check for collisions + if not span_check(cover, link_i, link_j): + return + span_flip(cover, link_i, link_j) + nt.append([next_nt(nt), src_j, src_j, link_i, link_j]) + # Require at least one word in phrase + if al: + add_rule(src_i, new_tgt_i, src[src_i:src_j + 1], tgt[new_tgt_i:new_tgt_j + 1], nt, al) + extract(src_i, src_j + 1, new_tgt_i, new_tgt_j, wc, al, nt, False) + nt.pop() + span_flip(cover, link_i, link_j) + # TODO: try adding NT to start, end, both + # check: one aligned word on boundary that is not part of a NT + + # Try to extract phrases from every src index + src_i = 0 + while src_i < src_len: + # Skip if phrases won't be tight on left side + if not a[src_i]: + src_i += 1 + continue + extract(src_i, src_i, src_len + 1, -1, 1, [], [], False) + src_i += 1 + +# Create a rule from source, target, non-terminals, and alignments +def add_rule(src_i, tgt_i, src_span, tgt_span, nt, al): + flat = (item for sub in al for item in sub) + astr = ' '.join('{0}-{1}'.format(x[0], x[1]) for x in flat) + +# print '--- Rule' +# print src_span +# print tgt_span +# print nt +# print astr +# print '---' + + # This could be more efficient but is probably not going to + # be the bottleneck + src_sym = src_span[:] + off = src_i + for next_nt in nt: + nt_len = (next_nt[2] - next_nt[1]) + 1 + i = 0 + while i < nt_len: + src_sym.pop(next_nt[1] - off) + i += 1 + src_sym.insert(next_nt[1] - off, '[X,{0}]'.format(next_nt[0])) + off += (nt_len - 1) + tgt_sym = tgt_span[:] + off = tgt_i + for next_nt in sorted(nt, cmp=lambda x, y: cmp(x[3], y[3])): + nt_len = (next_nt[4] - next_nt[3]) + 1 + i = 0 + while i < nt_len: + tgt_sym.pop(next_nt[3] - off) + i += 1 + tgt_sym.insert(next_nt[3] - off, '[X,{0}]'.format(next_nt[0])) + off += (nt_len - 1) + print '[X] ||| {0} ||| {1} ||| {2}'.format(' '.join(src_sym), ' '.join(tgt_sym), astr) + +if __name__ == '__main__': + main(sys.argv)
\ No newline at end of file |