diff options
author | Michael Denkowski <michael.j.denkowski@gmail.com> | 2012-12-21 20:13:27 -0500 |
---|---|---|
committer | Michael Denkowski <michael.j.denkowski@gmail.com> | 2012-12-21 20:13:27 -0500 |
commit | a125581487dfd9da4400bab51815252b02861f31 (patch) | |
tree | 918ce9c22cff4a464f1d0bb9e60da92377028176 /python/src | |
parent | 2d90a579b59b89b518db677406e48dc7b4d77dc1 (diff) |
New classy version
Diffstat (limited to 'python/src')
-rwxr-xr-x | python/src/sa/online_extractor.py | 429 |
1 files changed, 232 insertions, 197 deletions
diff --git a/python/src/sa/online_extractor.py b/python/src/sa/online_extractor.py index d41f3b39..fd4bb5f5 100755 --- a/python/src/sa/online_extractor.py +++ b/python/src/sa/online_extractor.py @@ -2,229 +2,264 @@ import collections, sys -def main(argv): - - for line in sys.stdin: - src, tgt, astr = (x.split() for x in line.split('|||')) - al = sorted(tuple(int(y) for y in x.split('-')) for x in astr) - extract_and_aggr(src, tgt, al) +import cdec.configobj -# Extract hierarchical phrase pairs -# This could be far better optimized by integrating it -# with suffix array code. For now, it gets the job done. -def extract_and_aggr(src, tgt, al, max_len=5, max_size=15, max_nt=2, boundary_nt=True): - - src_ph = collections.defaultdict(lambda: 0) # src = count - tgt_ph = collections.defaultdict(lambda: 0) # tgt = count - # [src][tgt] = count - phrase_pairs = collections.defaultdict(lambda: collections.defaultdict(lambda: 0)) - - src_w = collections.defaultdict(lambda: 0) # count - tgt_w = collections.defaultdict(lambda: 0) # count - # [src][tgt] = count - cooc_w = collections.defaultdict(lambda: collections.defaultdict(lambda: 0)) - - # Bilexical counts - for word in tgt: - tgt_w[word] += 1 - for word in src: - src_w[word] += 1 - for t_word in tgt: - cooc_w[word][t_word] += 1 +CAT = '[X]' # Default non-terminal +MAX_SIZE = 15 # Max span of a grammar rule (source) +MAX_LEN = 5 # Max number of terminals and non-terminals in a rule (source) +MAX_NT = 2 # Max number of non-terminals in a rule +MIN_GAP = 1 # Min number of terminals between non-terminals (source) - def next_nt(nt): - if not nt: - return 1 - return nt[-1][0] + 1 - - src_len = len(src) - - a = [[] for i in range(src_len)] - - # Pre-compute alignment min and max for each word - a_span = [[src_len + 1, -1] for i in range(src_len)] - for (s, t) in al: - a[s].append(t) - a_span[s][0] = min(a_span[s][0], t) - a_span[s][1] = max(a_span[s][1], t) +# Spans are _inclusive_ on both ends [i, j] +# TODO: Replace all of this with bit vectors? +def span_check(vec, i, j): + k = i + while k <= j: + if vec[k]: + return False + k += 1 + return True - # Target side non-terimnal coverage - # Cython bit vector? - cover = [0] * src_len - - print src - print tgt - print a_span - - # Spans are _inclusive_ on both ends [i, j] - def span_check(vec, i, j): - k = i - while k <= j: - if vec[k]: - return False - k += 1 - return True - - def span_flip(vec, i, j): - k = i - while k <= j: - vec[k] = ~vec[k] - k += 1 +def span_flip(vec, i, j): + k = i + while k <= j: + vec[k] = ~vec[k] + k += 1 - # Extract all possible hierarchical phrases starting at a source index - # src i and j are current, tgt i and j are previous - def extract(src_i, src_j, tgt_i, tgt_j, wc, al, nt, nt_open): - # Phrase extraction limits - if wc > max_len or (src_j + 1) >= src_len or \ - (src_j - src_i) + 1 > max_size or len(nt) > max_nt: - return - # Unaligned word - if not a[src_j]: - # Open non-terminal: extend - if nt_open: - nt[-1][2] += 1 - extract(src_i, src_j + 1, tgt_i, tgt_j, wc, al, nt, True) - nt[-1][2] -= 1 - # No open non-terminal: extend with word - else: - extract(src_i, src_j + 1, tgt_i, tgt_j, wc + 1, al, nt, False) - return - # Aligned word - link_i = a_span[src_j][0] - link_j = a_span[src_j][1] - new_tgt_i = min(link_i, tgt_i) - new_tgt_j = max(link_j, tgt_j) - # Open non-terminal: close, extract, extend - if nt_open: - # Close non-terminal, checking for collisions - old_last_nt = nt[-1][:] - nt[-1][2] = src_j - if link_i < nt[-1][3]: - if not span_check(cover, link_i, nt[-1][3] - 1): - nt[-1] = old_last_nt - return - span_flip(cover, link_i, nt[-1][3] - 1) - nt[-1][3] = link_i - if link_j > nt[-1][4]: - if not span_check(cover, nt[-1][4] + 1, link_j): - nt[-1] = old_last_nt - return - span_flip(cover, nt[-1][4] + 1, link_j) - nt[-1][4] = link_j - add_rule(src_i, new_tgt_i, src[src_i:src_j + 1], tgt[new_tgt_i:new_tgt_j + 1], nt, al) - extract(src_i, src_j + 1, new_tgt_i, new_tgt_j, wc, al, nt, False) - nt[-1] = old_last_nt - if link_i < nt[-1][3]: - span_flip(cover, link_i, nt[-1][3] - 1) - if link_j > nt[-1][4]: - span_flip(cover, nt[-1][4] + 1, link_j) - return - # No open non-terminal - # Extract, extend with word - collision = False - for link in a[src_j]: - if cover[link]: - collision = True - # Collisions block extraction and extension, but may be okay for - # continuing non-terminals - if not collision: - plus_al = [] - for link in a[src_j]: - plus_al.append((src_j, link)) - cover[link] = ~cover[link] - al.append(plus_al) - add_rule(src_i, new_tgt_i, src[src_i:src_j + 1], tgt[new_tgt_i:new_tgt_j + 1], nt, al) - extract(src_i, src_j + 1, new_tgt_i, new_tgt_j, wc + 1, al, nt, False) - al.pop() - for link in a[src_j]: - cover[link] = ~cover[link] - # Try to add a word to a (closed) non-terminal, extract, extend - if nt and nt[-1][2] == src_j - 1: - # Add to non-terminal, checking for collisions - old_last_nt = nt[-1][:] - nt[-1][2] = src_j - if link_i < nt[-1][3]: - if not span_check(cover, link_i, nt[-1][3] - 1): - nt[-1] = old_last_nt - return - span_flip(cover, link_i, nt[-1][3] - 1) - nt[-1][3] = link_i - if link_j > nt[-1][4]: - if not span_check(cover, nt[-1][4] + 1, link_j): - nt[-1] = old_last_nt - return - span_flip(cover, nt[-1][4] + 1, link_j) - nt[-1][4] = link_j - # Require at least one word in phrase - if al: - add_rule(src_i, new_tgt_i, src[src_i:src_j + 1], tgt[new_tgt_i:new_tgt_j + 1], nt, al) - extract(src_i, src_j + 1, new_tgt_i, new_tgt_j, wc, al, nt, False) - nt[-1] = old_last_nt - if new_tgt_i < nt[-1][3]: - span_flip(cover, link_i, nt[-1][3] - 1) - if link_j > nt[-1][4]: - span_flip(cover, nt[-1][4] + 1, link_j) - # Try to start a new non-terminal, extract, extend - if not nt or src_j - nt[-1][2] > 1: - # Check for collisions - if not span_check(cover, link_i, link_j): - return - span_flip(cover, link_i, link_j) - nt.append([next_nt(nt), src_j, src_j, link_i, link_j]) - # Require at least one word in phrase - if al: - add_rule(src_i, new_tgt_i, src[src_i:src_j + 1], tgt[new_tgt_i:new_tgt_j + 1], nt, al) - extract(src_i, src_j + 1, new_tgt_i, new_tgt_j, wc, al, nt, False) - nt.pop() - span_flip(cover, link_i, link_j) - # TODO: try adding NT to start, end, both - # check: one aligned word on boundary that is not part of a NT - - # Try to extract phrases from every src index - src_i = 0 - while src_i < src_len: - # Skip if phrases won't be tight on left side - if not a[src_i]: - src_i += 1 - continue - extract(src_i, src_i, src_len + 1, -1, 1, [], [], False) - src_i += 1 +# Next non-terminal +def next_nt(nt): + if not nt: + return 1 + return nt[-1][0] + 1 # Create a rule from source, target, non-terminals, and alignments -def add_rule(src_i, tgt_i, src_span, tgt_span, nt, al): +def form_rule(f_i, e_i, f_span, e_span, nt, al): flat = (item for sub in al for item in sub) astr = ' '.join('{0}-{1}'.format(x[0], x[1]) for x in flat) # print '--- Rule' -# print src_span -# print tgt_span +# print f_span +# print e_span # print nt # print astr # print '---' - # This could be more efficient but is probably not going to - # be the bottleneck - src_sym = src_span[:] - off = src_i + # This could be more efficient but is unlikely to be the bottleneck + f_sym = f_span[:] + off = f_i for next_nt in nt: nt_len = (next_nt[2] - next_nt[1]) + 1 i = 0 while i < nt_len: - src_sym.pop(next_nt[1] - off) + f_sym.pop(next_nt[1] - off) i += 1 - src_sym.insert(next_nt[1] - off, '[X,{0}]'.format(next_nt[0])) + f_sym.insert(next_nt[1] - off, '[X,{0}]'.format(next_nt[0])) off += (nt_len - 1) - tgt_sym = tgt_span[:] - off = tgt_i + e_sym = e_span[:] + off = e_i for next_nt in sorted(nt, cmp=lambda x, y: cmp(x[3], y[3])): nt_len = (next_nt[4] - next_nt[3]) + 1 i = 0 while i < nt_len: - tgt_sym.pop(next_nt[3] - off) + e_sym.pop(next_nt[3] - off) i += 1 - tgt_sym.insert(next_nt[3] - off, '[X,{0}]'.format(next_nt[0])) + e_sym.insert(next_nt[3] - off, '[X,{0}]'.format(next_nt[0])) off += (nt_len - 1) - print '[X] ||| {0} ||| {1} ||| {2}'.format(' '.join(src_sym), ' '.join(tgt_sym), astr) + return '[X] ||| {0} ||| {1} ||| {2}'.format(' '.join(f_sym), ' '.join(e_sym), astr) + +class OnlineGrammarExtractor: + + def __init__(self, config=None): + if isinstance(config, str) or isinstance(config, unicode): + if not os.path.exists(config): + raise IOError('cannot read configuration from {0}'.format(config)) + config = cdec.configobj.ConfigObj(config, unrepr=True) + elif not config: + config = collections.defaultdict(lambda: None) + self.category = CAT + self.max_size = MAX_SIZE + self.max_length = config['max_len'] or MAX_LEN + self.max_nonterminals = config['max_nt'] or MAX_NT + self.min_gap_size = MIN_GAP + # Hard coded: require at least one aligned word + # Hard coded: require tight phrases + + # Phrase counts + self.phrases_f = collections.defaultdict(lambda: 0) + self.phrases_e = collections.defaultdict(lambda: 0) + self.phrases_fe = collections.defaultdict(lambda: collections.defaultdict(lambda: 0)) + + # Bilexical counts + self.bilex_f = collections.defaultdict(lambda: 0) + self.bilex_e = collections.defaultdict(lambda: 0) + self.bilex_fe = collections.defaultdict(lambda: collections.defaultdict(lambda: 0)) + + # Aggregate bilexical counts + def aggr_bilex(self, f_words, e_words): + + for e_w in e_words: + self.bilex_e[e_w] += 1 + + for f_w in f_words: + self.bilex_f[f_w] += 1 + for e_w in e_words: + self.bilex_fe[f_w][e_w] += 1 + + # Aggregate stats from a training instance: + # Extract hierarchical phrase pairs + # Update bilexical counts + def add_instance(self, f_words, e_words, alignment): + + # Bilexical counts + self.aggr_bilex(f_words, e_words) + + # Phrase pairs extracted from this instance + phrases = set() + + f_len = len(f_words) + + # Pre-compute alignment info + al = [[] for i in range(f_len)] + al_span = [[f_len + 1, -1] for i in range(f_len)] + for (f, e) in alignment: + al[f].append(e) + al_span[f][0] = min(al_span[f][0], e) + al_span[f][1] = max(al_span[f][1], e) + + # Target side word coverage + # TODO: Does Cython do bit vectors? + cover = [0] * f_len + + # Extract all possible hierarchical phrases starting at a source index + # f_ i and j are current, e_ i and j are previous + def extract(f_i, f_j, e_i, e_j, wc, links, nt, nt_open): + # Phrase extraction limits + if wc > self.max_length or (f_j + 1) >= f_len or \ + (f_j - f_i) + 1 > self.max_size or len(nt) > self.max_nonterminals: + return + # Unaligned word + if not al[f_j]: + # Open non-terminal: extend + if nt_open: + nt[-1][2] += 1 + extract(f_i, f_j + 1, e_i, e_j, wc, links, nt, True) + nt[-1][2] -= 1 + # No open non-terminal: extend with word + else: + extract(f_i, f_j + 1, e_i, e_j, wc + 1, links, nt, False) + return + # Aligned word + link_i = al_span[f_j][0] + link_j = al_span[f_j][1] + new_e_i = min(link_i, e_i) + new_e_j = max(link_j, e_j) + # Open non-terminal: close, extract, extend + if nt_open: + # Close non-terminal, checking for collisions + old_last_nt = nt[-1][:] + nt[-1][2] = f_j + if link_i < nt[-1][3]: + if not span_check(cover, link_i, nt[-1][3] - 1): + nt[-1] = old_last_nt + return + span_flip(cover, link_i, nt[-1][3] - 1) + nt[-1][3] = link_i + if link_j > nt[-1][4]: + if not span_check(cover, nt[-1][4] + 1, link_j): + nt[-1] = old_last_nt + return + span_flip(cover, nt[-1][4] + 1, link_j) + nt[-1][4] = link_j + phrases.add(form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links)) + extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False) + nt[-1] = old_last_nt + if link_i < nt[-1][3]: + span_flip(cover, link_i, nt[-1][3] - 1) + if link_j > nt[-1][4]: + span_flip(cover, nt[-1][4] + 1, link_j) + return + # No open non-terminal + # Extract, extend with word + collision = False + for link in al[f_j]: + if cover[link]: + collision = True + # Collisions block extraction and extension, but may be okay for + # continuing non-terminals + if not collision: + plus_links = [] + for link in al[f_j]: + plus_links.append((f_j, link)) + cover[link] = ~cover[link] + links.append(plus_links) + phrases.add(form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links)) + extract(f_i, f_j + 1, new_e_i, new_e_j, wc + 1, links, nt, False) + links.pop() + for link in al[f_j]: + cover[link] = ~cover[link] + # Try to add a word to a (closed) non-terminal, extract, extend + if nt and nt[-1][2] == f_j - 1: + # Add to non-terminal, checking for collisions + old_last_nt = nt[-1][:] + nt[-1][2] = f_j + if link_i < nt[-1][3]: + if not span_check(cover, link_i, nt[-1][3] - 1): + nt[-1] = old_last_nt + return + span_flip(cover, link_i, nt[-1][3] - 1) + nt[-1][3] = link_i + if link_j > nt[-1][4]: + if not span_check(cover, nt[-1][4] + 1, link_j): + nt[-1] = old_last_nt + return + span_flip(cover, nt[-1][4] + 1, link_j) + nt[-1][4] = link_j + # Require at least one word in phrase + if links: + phrases.add(form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links)) + extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False) + nt[-1] = old_last_nt + if new_e_i < nt[-1][3]: + span_flip(cover, link_i, nt[-1][3] - 1) + if link_j > nt[-1][4]: + span_flip(cover, nt[-1][4] + 1, link_j) + # Try to start a new non-terminal, extract, extend + if not nt or f_j - nt[-1][2] > 1: + # Check for collisions + if not span_check(cover, link_i, link_j): + return + span_flip(cover, link_i, link_j) + nt.append([next_nt(nt), f_j, f_j, link_i, link_j]) + # Require at least one word in phrase + if links: + phrases.add(form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links)) + extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False) + nt.pop() + span_flip(cover, link_i, link_j) + # TODO: try adding NT to start, end, both + # check: one aligned word on boundary that is not part of a NT + + # Try to extract phrases from every f index + f_i = 0 + while f_i < f_len: + # Skip if phrases won't be tight on left side + if not al[f_i]: + f_i += 1 + continue + extract(f_i, f_i, f_len + 1, -1, 1, [], [], False) + f_i += 1 + + for rule in sorted(phrases): + print rule + +def main(argv): + + extractor = OnlineGrammarExtractor() + + for line in sys.stdin: + f_words, e_words, a_str = (x.split() for x in line.split('|||')) + alignment = sorted(tuple(int(y) for y in x.split('-')) for x in a_str) + extractor.add_instance(f_words, e_words, alignment) if __name__ == '__main__': main(sys.argv)
\ No newline at end of file |