diff options
author | Michael Denkowski <michael.j.denkowski@gmail.com> | 2012-12-23 21:09:41 -0500 |
---|---|---|
committer | Michael Denkowski <michael.j.denkowski@gmail.com> | 2012-12-23 21:09:41 -0500 |
commit | 29190ecb9a1a1771459a5c1bf0c7afa54d4b0416 (patch) | |
tree | 2759e1405d8aca12c55fb9384f67d798fd50320a /python | |
parent | 778a4cec55f82bcc66b3f52de7cc871e8daaeb92 (diff) |
NonTerminal class
Diffstat (limited to 'python')
-rwxr-xr-x | python/src/sa/online_extractor.py | 98 |
1 files changed, 58 insertions, 40 deletions
diff --git a/python/src/sa/online_extractor.py b/python/src/sa/online_extractor.py index 06eb5357..0c013cb3 100755 --- a/python/src/sa/online_extractor.py +++ b/python/src/sa/online_extractor.py @@ -4,11 +4,11 @@ import collections, sys import cdec.configobj -CAT = '[X]' # Default non-terminal -MAX_SIZE = 15 # Max span of a grammar rule (source) -MAX_LEN = 5 # Max number of terminals and non-terminals in a rule (source) -MAX_NT = 2 # Max number of non-terminals in a rule -MIN_GAP = 1 # Min number of terminals between non-terminals (source) +CAT = '[X]' # Default non-terminal +MAX_SIZE = 15 # Max span of a grammar rule (source) +MAX_LEN = 5 # Max number of terminals and non-terminals in a rule (source) +MAX_NT = 2 # Max number of non-terminals in a rule +MIN_GAP = 1 # Min number of terminals between non-terminals (source) # Spans are _inclusive_ on both ends [i, j] # TODO: Replace all of this with bit vectors? @@ -17,14 +17,14 @@ def span_check(vec, i, j): while k <= j: if vec[k]: return False - k += 1 + k += 1 return True def span_flip(vec, i, j): k = i while k <= j: vec[k] = ~vec[k] - k += 1 + k += 1 # Next non-terminal def next_nt(nt): @@ -32,13 +32,27 @@ def next_nt(nt): return 1 return nt[-1][0] + 1 +class NonTerminal: + def __init__(self, index): + self.index = index + def __str__(self): + return '[X,{0}]'.format(self.index) + +def fmt_rule(f_sym, e_sym, links): + a_str = ' '.join('{0}-{1}'.format(i, j) for (i, j) in links) + return '[X] ||| {0} ||| {1} ||| {2}'.format(' '.join(str(sym) for sym in f_sym), + ' '.join(str(sym) for sym in e_sym), + a_str) + # Create a rule from source, target, non-terminals, and alignments -def form_rule(f_i, e_i, f_span, e_span, nt, al): - +def form_rules(f_i, e_i, f_span, e_span, nt, al): + # This could be more efficient but is unlikely to be the bottleneck - + + rules = [] + nt_inv = sorted(nt, cmp=lambda x, y: cmp(x[3], y[3])) - + f_sym = f_span[:] off = f_i for next_nt in nt: @@ -47,7 +61,7 @@ def form_rule(f_i, e_i, f_span, e_span, nt, al): while i < nt_len: f_sym.pop(next_nt[1] - off) i += 1 - f_sym.insert(next_nt[1] - off, '[X,{0}]'.format(next_nt[0])) + f_sym.insert(next_nt[1] - off, NonTerminal(next_nt[0])) off += (nt_len - 1) e_sym = e_span[:] @@ -58,9 +72,9 @@ def form_rule(f_i, e_i, f_span, e_span, nt, al): while i < nt_len: e_sym.pop(next_nt[3] - off) i += 1 - e_sym.insert(next_nt[3] - off, '[X,{0}]'.format(next_nt[0])) + e_sym.insert(next_nt[3] - off, NonTerminal(next_nt[0])) off += (nt_len - 1) - + # Adjusting alignment links takes some doing links = [list(link) for sub in al for link in sub] links_len = len(links) @@ -73,7 +87,7 @@ def form_rule(f_i, e_i, f_span, e_span, nt, al): off += (nt[nt_i][2] - nt[nt_i][1]) nt_i += 1 links[i][0] -= off - i += 1 + i += 1 nt_i = 0 off = e_i i = 0 @@ -82,13 +96,13 @@ def form_rule(f_i, e_i, f_span, e_span, nt, al): off += (nt_inv[nt_i][4] - nt_inv[nt_i][3]) nt_i += 1 links[i][1] -= off - i += 1 - a_str = ' '.join('{0}-{1}'.format(i, j) for (i, j) in links) - - return '[X] ||| {0} ||| {1} ||| {2}'.format(' '.join(f_sym), ' '.join(e_sym), a_str) + i += 1 + + rules.append(fmt_rule(f_sym, e_sym, links)) + return rules class OnlineGrammarExtractor: - + def __init__(self, config=None): if isinstance(config, str) or isinstance(config, unicode): if not os.path.exists(config): @@ -103,23 +117,23 @@ class OnlineGrammarExtractor: self.min_gap_size = MIN_GAP # Hard coded: require at least one aligned word # Hard coded: require tight phrases - + # Phrase counts self.phrases_f = collections.defaultdict(lambda: 0) self.phrases_e = collections.defaultdict(lambda: 0) self.phrases_fe = collections.defaultdict(lambda: collections.defaultdict(lambda: 0)) - + # Bilexical counts self.bilex_f = collections.defaultdict(lambda: 0) self.bilex_e = collections.defaultdict(lambda: 0) self.bilex_fe = collections.defaultdict(lambda: collections.defaultdict(lambda: 0)) - + # Aggregate bilexical counts def aggr_bilex(self, f_words, e_words): - + for e_w in e_words: self.bilex_e[e_w] += 1 - + for f_w in f_words: self.bilex_f[f_w] += 1 for e_w in e_words: @@ -129,15 +143,15 @@ class OnlineGrammarExtractor: # Extract hierarchical phrase pairs # Update bilexical counts def add_instance(self, f_words, e_words, alignment): - + # Bilexical counts self.aggr_bilex(f_words, e_words) - + # Phrase pairs extracted from this instance phrases = set() - + f_len = len(f_words) - + # Pre-compute alignment info al = [[] for i in range(f_len)] al_span = [[f_len + 1, -1] for i in range(f_len)] @@ -145,11 +159,11 @@ class OnlineGrammarExtractor: al[f].append(e) al_span[f][0] = min(al_span[f][0], e) al_span[f][1] = max(al_span[f][1], e) - + # Target side word coverage # TODO: Does Cython do bit vectors? cover = [0] * f_len - + # Extract all possible hierarchical phrases starting at a source index # f_ i and j are current, e_ i and j are previous def extract(f_i, f_j, e_i, e_j, wc, links, nt, nt_open): @@ -170,7 +184,7 @@ class OnlineGrammarExtractor: return # Aligned word link_i = al_span[f_j][0] - link_j = al_span[f_j][1] + link_j = al_span[f_j][1] new_e_i = min(link_i, e_i) new_e_j = max(link_j, e_j) # Open non-terminal: close, extract, extend @@ -190,7 +204,8 @@ class OnlineGrammarExtractor: return span_flip(cover, nt[-1][4] + 1, link_j) nt[-1][4] = link_j - phrases.add(form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links)) + for rule in form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): + phrases.add(rule) extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False) nt[-1] = old_last_nt if link_i < nt[-1][3]: @@ -212,7 +227,8 @@ class OnlineGrammarExtractor: plus_links.append((f_j, link)) cover[link] = ~cover[link] links.append(plus_links) - phrases.add(form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links)) + for rule in form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): + phrases.add(rule) extract(f_i, f_j + 1, new_e_i, new_e_j, wc + 1, links, nt, False) links.pop() for link in al[f_j]: @@ -236,7 +252,8 @@ class OnlineGrammarExtractor: nt[-1][4] = link_j # Require at least one word in phrase if links: - phrases.add(form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links)) + for rule in form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): + phrases.add(rule) extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False) nt[-1] = old_last_nt if new_e_i < nt[-1][3]: @@ -252,13 +269,14 @@ class OnlineGrammarExtractor: nt.append([next_nt(nt), f_j, f_j, link_i, link_j]) # Require at least one word in phrase if links: - phrases.add(form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links)) + for rule in form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): + phrases.add(rule) extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False) nt.pop() span_flip(cover, link_i, link_j) # TODO: try adding NT to start, end, both # check: one aligned word on boundary that is not part of a NT - + # Try to extract phrases from every f index f_i = 0 while f_i < f_len: @@ -268,18 +286,18 @@ class OnlineGrammarExtractor: continue extract(f_i, f_i, f_len + 1, -1, 1, [], [], False) f_i += 1 - + for rule in sorted(phrases): print rule def main(argv): extractor = OnlineGrammarExtractor() - + for line in sys.stdin: f_words, e_words, a_str = (x.split() for x in line.split('|||')) alignment = sorted(tuple(int(y) for y in x.split('-')) for x in a_str) extractor.add_instance(f_words, e_words, alignment) if __name__ == '__main__': - main(sys.argv)
\ No newline at end of file + main(sys.argv) |