summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xpython/src/sa/online_extractor.py429
1 files changed, 232 insertions, 197 deletions
diff --git a/python/src/sa/online_extractor.py b/python/src/sa/online_extractor.py
index d41f3b39..fd4bb5f5 100755
--- a/python/src/sa/online_extractor.py
+++ b/python/src/sa/online_extractor.py
@@ -2,229 +2,264 @@
import collections, sys
-def main(argv):
-
- for line in sys.stdin:
- src, tgt, astr = (x.split() for x in line.split('|||'))
- al = sorted(tuple(int(y) for y in x.split('-')) for x in astr)
- extract_and_aggr(src, tgt, al)
+import cdec.configobj
-# Extract hierarchical phrase pairs
-# This could be far better optimized by integrating it
-# with suffix array code. For now, it gets the job done.
-def extract_and_aggr(src, tgt, al, max_len=5, max_size=15, max_nt=2, boundary_nt=True):
-
- src_ph = collections.defaultdict(lambda: 0) # src = count
- tgt_ph = collections.defaultdict(lambda: 0) # tgt = count
- # [src][tgt] = count
- phrase_pairs = collections.defaultdict(lambda: collections.defaultdict(lambda: 0))
-
- src_w = collections.defaultdict(lambda: 0) # count
- tgt_w = collections.defaultdict(lambda: 0) # count
- # [src][tgt] = count
- cooc_w = collections.defaultdict(lambda: collections.defaultdict(lambda: 0))
-
- # Bilexical counts
- for word in tgt:
- tgt_w[word] += 1
- for word in src:
- src_w[word] += 1
- for t_word in tgt:
- cooc_w[word][t_word] += 1
+CAT = '[X]' # Default non-terminal
+MAX_SIZE = 15 # Max span of a grammar rule (source)
+MAX_LEN = 5 # Max number of terminals and non-terminals in a rule (source)
+MAX_NT = 2 # Max number of non-terminals in a rule
+MIN_GAP = 1 # Min number of terminals between non-terminals (source)
- def next_nt(nt):
- if not nt:
- return 1
- return nt[-1][0] + 1
-
- src_len = len(src)
-
- a = [[] for i in range(src_len)]
-
- # Pre-compute alignment min and max for each word
- a_span = [[src_len + 1, -1] for i in range(src_len)]
- for (s, t) in al:
- a[s].append(t)
- a_span[s][0] = min(a_span[s][0], t)
- a_span[s][1] = max(a_span[s][1], t)
+# Spans are _inclusive_ on both ends [i, j]
+# TODO: Replace all of this with bit vectors?
+def span_check(vec, i, j):
+ k = i
+ while k <= j:
+ if vec[k]:
+ return False
+ k += 1
+ return True
- # Target side non-terimnal coverage
- # Cython bit vector?
- cover = [0] * src_len
-
- print src
- print tgt
- print a_span
-
- # Spans are _inclusive_ on both ends [i, j]
- def span_check(vec, i, j):
- k = i
- while k <= j:
- if vec[k]:
- return False
- k += 1
- return True
-
- def span_flip(vec, i, j):
- k = i
- while k <= j:
- vec[k] = ~vec[k]
- k += 1
+def span_flip(vec, i, j):
+ k = i
+ while k <= j:
+ vec[k] = ~vec[k]
+ k += 1
- # Extract all possible hierarchical phrases starting at a source index
- # src i and j are current, tgt i and j are previous
- def extract(src_i, src_j, tgt_i, tgt_j, wc, al, nt, nt_open):
- # Phrase extraction limits
- if wc > max_len or (src_j + 1) >= src_len or \
- (src_j - src_i) + 1 > max_size or len(nt) > max_nt:
- return
- # Unaligned word
- if not a[src_j]:
- # Open non-terminal: extend
- if nt_open:
- nt[-1][2] += 1
- extract(src_i, src_j + 1, tgt_i, tgt_j, wc, al, nt, True)
- nt[-1][2] -= 1
- # No open non-terminal: extend with word
- else:
- extract(src_i, src_j + 1, tgt_i, tgt_j, wc + 1, al, nt, False)
- return
- # Aligned word
- link_i = a_span[src_j][0]
- link_j = a_span[src_j][1]
- new_tgt_i = min(link_i, tgt_i)
- new_tgt_j = max(link_j, tgt_j)
- # Open non-terminal: close, extract, extend
- if nt_open:
- # Close non-terminal, checking for collisions
- old_last_nt = nt[-1][:]
- nt[-1][2] = src_j
- if link_i < nt[-1][3]:
- if not span_check(cover, link_i, nt[-1][3] - 1):
- nt[-1] = old_last_nt
- return
- span_flip(cover, link_i, nt[-1][3] - 1)
- nt[-1][3] = link_i
- if link_j > nt[-1][4]:
- if not span_check(cover, nt[-1][4] + 1, link_j):
- nt[-1] = old_last_nt
- return
- span_flip(cover, nt[-1][4] + 1, link_j)
- nt[-1][4] = link_j
- add_rule(src_i, new_tgt_i, src[src_i:src_j + 1], tgt[new_tgt_i:new_tgt_j + 1], nt, al)
- extract(src_i, src_j + 1, new_tgt_i, new_tgt_j, wc, al, nt, False)
- nt[-1] = old_last_nt
- if link_i < nt[-1][3]:
- span_flip(cover, link_i, nt[-1][3] - 1)
- if link_j > nt[-1][4]:
- span_flip(cover, nt[-1][4] + 1, link_j)
- return
- # No open non-terminal
- # Extract, extend with word
- collision = False
- for link in a[src_j]:
- if cover[link]:
- collision = True
- # Collisions block extraction and extension, but may be okay for
- # continuing non-terminals
- if not collision:
- plus_al = []
- for link in a[src_j]:
- plus_al.append((src_j, link))
- cover[link] = ~cover[link]
- al.append(plus_al)
- add_rule(src_i, new_tgt_i, src[src_i:src_j + 1], tgt[new_tgt_i:new_tgt_j + 1], nt, al)
- extract(src_i, src_j + 1, new_tgt_i, new_tgt_j, wc + 1, al, nt, False)
- al.pop()
- for link in a[src_j]:
- cover[link] = ~cover[link]
- # Try to add a word to a (closed) non-terminal, extract, extend
- if nt and nt[-1][2] == src_j - 1:
- # Add to non-terminal, checking for collisions
- old_last_nt = nt[-1][:]
- nt[-1][2] = src_j
- if link_i < nt[-1][3]:
- if not span_check(cover, link_i, nt[-1][3] - 1):
- nt[-1] = old_last_nt
- return
- span_flip(cover, link_i, nt[-1][3] - 1)
- nt[-1][3] = link_i
- if link_j > nt[-1][4]:
- if not span_check(cover, nt[-1][4] + 1, link_j):
- nt[-1] = old_last_nt
- return
- span_flip(cover, nt[-1][4] + 1, link_j)
- nt[-1][4] = link_j
- # Require at least one word in phrase
- if al:
- add_rule(src_i, new_tgt_i, src[src_i:src_j + 1], tgt[new_tgt_i:new_tgt_j + 1], nt, al)
- extract(src_i, src_j + 1, new_tgt_i, new_tgt_j, wc, al, nt, False)
- nt[-1] = old_last_nt
- if new_tgt_i < nt[-1][3]:
- span_flip(cover, link_i, nt[-1][3] - 1)
- if link_j > nt[-1][4]:
- span_flip(cover, nt[-1][4] + 1, link_j)
- # Try to start a new non-terminal, extract, extend
- if not nt or src_j - nt[-1][2] > 1:
- # Check for collisions
- if not span_check(cover, link_i, link_j):
- return
- span_flip(cover, link_i, link_j)
- nt.append([next_nt(nt), src_j, src_j, link_i, link_j])
- # Require at least one word in phrase
- if al:
- add_rule(src_i, new_tgt_i, src[src_i:src_j + 1], tgt[new_tgt_i:new_tgt_j + 1], nt, al)
- extract(src_i, src_j + 1, new_tgt_i, new_tgt_j, wc, al, nt, False)
- nt.pop()
- span_flip(cover, link_i, link_j)
- # TODO: try adding NT to start, end, both
- # check: one aligned word on boundary that is not part of a NT
-
- # Try to extract phrases from every src index
- src_i = 0
- while src_i < src_len:
- # Skip if phrases won't be tight on left side
- if not a[src_i]:
- src_i += 1
- continue
- extract(src_i, src_i, src_len + 1, -1, 1, [], [], False)
- src_i += 1
+# Next non-terminal
+def next_nt(nt):
+ if not nt:
+ return 1
+ return nt[-1][0] + 1
# Create a rule from source, target, non-terminals, and alignments
-def add_rule(src_i, tgt_i, src_span, tgt_span, nt, al):
+def form_rule(f_i, e_i, f_span, e_span, nt, al):
flat = (item for sub in al for item in sub)
astr = ' '.join('{0}-{1}'.format(x[0], x[1]) for x in flat)
# print '--- Rule'
-# print src_span
-# print tgt_span
+# print f_span
+# print e_span
# print nt
# print astr
# print '---'
- # This could be more efficient but is probably not going to
- # be the bottleneck
- src_sym = src_span[:]
- off = src_i
+ # This could be more efficient but is unlikely to be the bottleneck
+ f_sym = f_span[:]
+ off = f_i
for next_nt in nt:
nt_len = (next_nt[2] - next_nt[1]) + 1
i = 0
while i < nt_len:
- src_sym.pop(next_nt[1] - off)
+ f_sym.pop(next_nt[1] - off)
i += 1
- src_sym.insert(next_nt[1] - off, '[X,{0}]'.format(next_nt[0]))
+ f_sym.insert(next_nt[1] - off, '[X,{0}]'.format(next_nt[0]))
off += (nt_len - 1)
- tgt_sym = tgt_span[:]
- off = tgt_i
+ e_sym = e_span[:]
+ off = e_i
for next_nt in sorted(nt, cmp=lambda x, y: cmp(x[3], y[3])):
nt_len = (next_nt[4] - next_nt[3]) + 1
i = 0
while i < nt_len:
- tgt_sym.pop(next_nt[3] - off)
+ e_sym.pop(next_nt[3] - off)
i += 1
- tgt_sym.insert(next_nt[3] - off, '[X,{0}]'.format(next_nt[0]))
+ e_sym.insert(next_nt[3] - off, '[X,{0}]'.format(next_nt[0]))
off += (nt_len - 1)
- print '[X] ||| {0} ||| {1} ||| {2}'.format(' '.join(src_sym), ' '.join(tgt_sym), astr)
+ return '[X] ||| {0} ||| {1} ||| {2}'.format(' '.join(f_sym), ' '.join(e_sym), astr)
+
+class OnlineGrammarExtractor:
+
+ def __init__(self, config=None):
+ if isinstance(config, str) or isinstance(config, unicode):
+ if not os.path.exists(config):
+ raise IOError('cannot read configuration from {0}'.format(config))
+ config = cdec.configobj.ConfigObj(config, unrepr=True)
+ elif not config:
+ config = collections.defaultdict(lambda: None)
+ self.category = CAT
+ self.max_size = MAX_SIZE
+ self.max_length = config['max_len'] or MAX_LEN
+ self.max_nonterminals = config['max_nt'] or MAX_NT
+ self.min_gap_size = MIN_GAP
+ # Hard coded: require at least one aligned word
+ # Hard coded: require tight phrases
+
+ # Phrase counts
+ self.phrases_f = collections.defaultdict(lambda: 0)
+ self.phrases_e = collections.defaultdict(lambda: 0)
+ self.phrases_fe = collections.defaultdict(lambda: collections.defaultdict(lambda: 0))
+
+ # Bilexical counts
+ self.bilex_f = collections.defaultdict(lambda: 0)
+ self.bilex_e = collections.defaultdict(lambda: 0)
+ self.bilex_fe = collections.defaultdict(lambda: collections.defaultdict(lambda: 0))
+
+ # Aggregate bilexical counts
+ def aggr_bilex(self, f_words, e_words):
+
+ for e_w in e_words:
+ self.bilex_e[e_w] += 1
+
+ for f_w in f_words:
+ self.bilex_f[f_w] += 1
+ for e_w in e_words:
+ self.bilex_fe[f_w][e_w] += 1
+
+ # Aggregate stats from a training instance:
+ # Extract hierarchical phrase pairs
+ # Update bilexical counts
+ def add_instance(self, f_words, e_words, alignment):
+
+ # Bilexical counts
+ self.aggr_bilex(f_words, e_words)
+
+ # Phrase pairs extracted from this instance
+ phrases = set()
+
+ f_len = len(f_words)
+
+ # Pre-compute alignment info
+ al = [[] for i in range(f_len)]
+ al_span = [[f_len + 1, -1] for i in range(f_len)]
+ for (f, e) in alignment:
+ al[f].append(e)
+ al_span[f][0] = min(al_span[f][0], e)
+ al_span[f][1] = max(al_span[f][1], e)
+
+ # Target side word coverage
+ # TODO: Does Cython do bit vectors?
+ cover = [0] * f_len
+
+ # Extract all possible hierarchical phrases starting at a source index
+ # f_ i and j are current, e_ i and j are previous
+ def extract(f_i, f_j, e_i, e_j, wc, links, nt, nt_open):
+ # Phrase extraction limits
+ if wc > self.max_length or (f_j + 1) >= f_len or \
+ (f_j - f_i) + 1 > self.max_size or len(nt) > self.max_nonterminals:
+ return
+ # Unaligned word
+ if not al[f_j]:
+ # Open non-terminal: extend
+ if nt_open:
+ nt[-1][2] += 1
+ extract(f_i, f_j + 1, e_i, e_j, wc, links, nt, True)
+ nt[-1][2] -= 1
+ # No open non-terminal: extend with word
+ else:
+ extract(f_i, f_j + 1, e_i, e_j, wc + 1, links, nt, False)
+ return
+ # Aligned word
+ link_i = al_span[f_j][0]
+ link_j = al_span[f_j][1]
+ new_e_i = min(link_i, e_i)
+ new_e_j = max(link_j, e_j)
+ # Open non-terminal: close, extract, extend
+ if nt_open:
+ # Close non-terminal, checking for collisions
+ old_last_nt = nt[-1][:]
+ nt[-1][2] = f_j
+ if link_i < nt[-1][3]:
+ if not span_check(cover, link_i, nt[-1][3] - 1):
+ nt[-1] = old_last_nt
+ return
+ span_flip(cover, link_i, nt[-1][3] - 1)
+ nt[-1][3] = link_i
+ if link_j > nt[-1][4]:
+ if not span_check(cover, nt[-1][4] + 1, link_j):
+ nt[-1] = old_last_nt
+ return
+ span_flip(cover, nt[-1][4] + 1, link_j)
+ nt[-1][4] = link_j
+ phrases.add(form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
+ extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False)
+ nt[-1] = old_last_nt
+ if link_i < nt[-1][3]:
+ span_flip(cover, link_i, nt[-1][3] - 1)
+ if link_j > nt[-1][4]:
+ span_flip(cover, nt[-1][4] + 1, link_j)
+ return
+ # No open non-terminal
+ # Extract, extend with word
+ collision = False
+ for link in al[f_j]:
+ if cover[link]:
+ collision = True
+ # Collisions block extraction and extension, but may be okay for
+ # continuing non-terminals
+ if not collision:
+ plus_links = []
+ for link in al[f_j]:
+ plus_links.append((f_j, link))
+ cover[link] = ~cover[link]
+ links.append(plus_links)
+ phrases.add(form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
+ extract(f_i, f_j + 1, new_e_i, new_e_j, wc + 1, links, nt, False)
+ links.pop()
+ for link in al[f_j]:
+ cover[link] = ~cover[link]
+ # Try to add a word to a (closed) non-terminal, extract, extend
+ if nt and nt[-1][2] == f_j - 1:
+ # Add to non-terminal, checking for collisions
+ old_last_nt = nt[-1][:]
+ nt[-1][2] = f_j
+ if link_i < nt[-1][3]:
+ if not span_check(cover, link_i, nt[-1][3] - 1):
+ nt[-1] = old_last_nt
+ return
+ span_flip(cover, link_i, nt[-1][3] - 1)
+ nt[-1][3] = link_i
+ if link_j > nt[-1][4]:
+ if not span_check(cover, nt[-1][4] + 1, link_j):
+ nt[-1] = old_last_nt
+ return
+ span_flip(cover, nt[-1][4] + 1, link_j)
+ nt[-1][4] = link_j
+ # Require at least one word in phrase
+ if links:
+ phrases.add(form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
+ extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False)
+ nt[-1] = old_last_nt
+ if new_e_i < nt[-1][3]:
+ span_flip(cover, link_i, nt[-1][3] - 1)
+ if link_j > nt[-1][4]:
+ span_flip(cover, nt[-1][4] + 1, link_j)
+ # Try to start a new non-terminal, extract, extend
+ if not nt or f_j - nt[-1][2] > 1:
+ # Check for collisions
+ if not span_check(cover, link_i, link_j):
+ return
+ span_flip(cover, link_i, link_j)
+ nt.append([next_nt(nt), f_j, f_j, link_i, link_j])
+ # Require at least one word in phrase
+ if links:
+ phrases.add(form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
+ extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False)
+ nt.pop()
+ span_flip(cover, link_i, link_j)
+ # TODO: try adding NT to start, end, both
+ # check: one aligned word on boundary that is not part of a NT
+
+ # Try to extract phrases from every f index
+ f_i = 0
+ while f_i < f_len:
+ # Skip if phrases won't be tight on left side
+ if not al[f_i]:
+ f_i += 1
+ continue
+ extract(f_i, f_i, f_len + 1, -1, 1, [], [], False)
+ f_i += 1
+
+ for rule in sorted(phrases):
+ print rule
+
+def main(argv):
+
+ extractor = OnlineGrammarExtractor()
+
+ for line in sys.stdin:
+ f_words, e_words, a_str = (x.split() for x in line.split('|||'))
+ alignment = sorted(tuple(int(y) for y in x.split('-')) for x in a_str)
+ extractor.add_instance(f_words, e_words, alignment)
if __name__ == '__main__':
main(sys.argv) \ No newline at end of file