summaryrefslogtreecommitdiff
path: root/python/src
diff options
context:
space:
mode:
Diffstat (limited to 'python/src')
-rwxr-xr-xpython/src/sa/online_extractor.py48
1 files changed, 34 insertions, 14 deletions
diff --git a/python/src/sa/online_extractor.py b/python/src/sa/online_extractor.py
index fd4bb5f5..06eb5357 100755
--- a/python/src/sa/online_extractor.py
+++ b/python/src/sa/online_extractor.py
@@ -34,17 +34,11 @@ def next_nt(nt):
# Create a rule from source, target, non-terminals, and alignments
def form_rule(f_i, e_i, f_span, e_span, nt, al):
- flat = (item for sub in al for item in sub)
- astr = ' '.join('{0}-{1}'.format(x[0], x[1]) for x in flat)
-
-# print '--- Rule'
-# print f_span
-# print e_span
-# print nt
-# print astr
-# print '---'
# This could be more efficient but is unlikely to be the bottleneck
+
+ nt_inv = sorted(nt, cmp=lambda x, y: cmp(x[3], y[3]))
+
f_sym = f_span[:]
off = f_i
for next_nt in nt:
@@ -55,9 +49,10 @@ def form_rule(f_i, e_i, f_span, e_span, nt, al):
i += 1
f_sym.insert(next_nt[1] - off, '[X,{0}]'.format(next_nt[0]))
off += (nt_len - 1)
+
e_sym = e_span[:]
off = e_i
- for next_nt in sorted(nt, cmp=lambda x, y: cmp(x[3], y[3])):
+ for next_nt in nt_inv:
nt_len = (next_nt[4] - next_nt[3]) + 1
i = 0
while i < nt_len:
@@ -65,7 +60,32 @@ def form_rule(f_i, e_i, f_span, e_span, nt, al):
i += 1
e_sym.insert(next_nt[3] - off, '[X,{0}]'.format(next_nt[0]))
off += (nt_len - 1)
- return '[X] ||| {0} ||| {1} ||| {2}'.format(' '.join(f_sym), ' '.join(e_sym), astr)
+
+ # Adjusting alignment links takes some doing
+ links = [list(link) for sub in al for link in sub]
+ links_len = len(links)
+ nt_len = len(nt)
+ nt_i = 0
+ off = f_i
+ i = 0
+ while i < links_len:
+ while nt_i < nt_len and links[i][0] > nt[nt_i][1]:
+ off += (nt[nt_i][2] - nt[nt_i][1])
+ nt_i += 1
+ links[i][0] -= off
+ i += 1
+ nt_i = 0
+ off = e_i
+ i = 0
+ while i < links_len:
+ while nt_i < nt_len and links[i][1] > nt_inv[nt_i][3]:
+ off += (nt_inv[nt_i][4] - nt_inv[nt_i][3])
+ nt_i += 1
+ links[i][1] -= off
+ i += 1
+ a_str = ' '.join('{0}-{1}'.format(i, j) for (i, j) in links)
+
+ return '[X] ||| {0} ||| {1} ||| {2}'.format(' '.join(f_sym), ' '.join(e_sym), a_str)
class OnlineGrammarExtractor:
@@ -135,7 +155,7 @@ class OnlineGrammarExtractor:
def extract(f_i, f_j, e_i, e_j, wc, links, nt, nt_open):
# Phrase extraction limits
if wc > self.max_length or (f_j + 1) >= f_len or \
- (f_j - f_i) + 1 > self.max_size or len(nt) > self.max_nonterminals:
+ (f_j - f_i) + 1 > self.max_size:
return
# Unaligned word
if not al[f_j]:
@@ -224,7 +244,7 @@ class OnlineGrammarExtractor:
if link_j > nt[-1][4]:
span_flip(cover, nt[-1][4] + 1, link_j)
# Try to start a new non-terminal, extract, extend
- if not nt or f_j - nt[-1][2] > 1:
+ if (not nt or f_j - nt[-1][2] > 1) and len(nt) < self.max_nonterminals:
# Check for collisions
if not span_check(cover, link_i, link_j):
return
@@ -238,7 +258,7 @@ class OnlineGrammarExtractor:
span_flip(cover, link_i, link_j)
# TODO: try adding NT to start, end, both
# check: one aligned word on boundary that is not part of a NT
-
+
# Try to extract phrases from every f index
f_i = 0
while f_i < f_len: