summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xpython/src/sa/online_extractor.py156
1 files changed, 94 insertions, 62 deletions
diff --git a/python/src/sa/online_extractor.py b/python/src/sa/online_extractor.py
index 0c013cb3..90087f30 100755
--- a/python/src/sa/online_extractor.py
+++ b/python/src/sa/online_extractor.py
@@ -44,63 +44,6 @@ def fmt_rule(f_sym, e_sym, links):
' '.join(str(sym) for sym in e_sym),
a_str)
-# Create a rule from source, target, non-terminals, and alignments
-def form_rules(f_i, e_i, f_span, e_span, nt, al):
-
- # This could be more efficient but is unlikely to be the bottleneck
-
- rules = []
-
- nt_inv = sorted(nt, cmp=lambda x, y: cmp(x[3], y[3]))
-
- f_sym = f_span[:]
- off = f_i
- for next_nt in nt:
- nt_len = (next_nt[2] - next_nt[1]) + 1
- i = 0
- while i < nt_len:
- f_sym.pop(next_nt[1] - off)
- i += 1
- f_sym.insert(next_nt[1] - off, NonTerminal(next_nt[0]))
- off += (nt_len - 1)
-
- e_sym = e_span[:]
- off = e_i
- for next_nt in nt_inv:
- nt_len = (next_nt[4] - next_nt[3]) + 1
- i = 0
- while i < nt_len:
- e_sym.pop(next_nt[3] - off)
- i += 1
- e_sym.insert(next_nt[3] - off, NonTerminal(next_nt[0]))
- off += (nt_len - 1)
-
- # Adjusting alignment links takes some doing
- links = [list(link) for sub in al for link in sub]
- links_len = len(links)
- nt_len = len(nt)
- nt_i = 0
- off = f_i
- i = 0
- while i < links_len:
- while nt_i < nt_len and links[i][0] > nt[nt_i][1]:
- off += (nt[nt_i][2] - nt[nt_i][1])
- nt_i += 1
- links[i][0] -= off
- i += 1
- nt_i = 0
- off = e_i
- i = 0
- while i < links_len:
- while nt_i < nt_len and links[i][1] > nt_inv[nt_i][3]:
- off += (nt_inv[nt_i][4] - nt_inv[nt_i][3])
- nt_i += 1
- links[i][1] -= off
- i += 1
-
- rules.append(fmt_rule(f_sym, e_sym, links))
- return rules
-
class OnlineGrammarExtractor:
def __init__(self, config=None):
@@ -168,7 +111,7 @@ class OnlineGrammarExtractor:
# f_ i and j are current, e_ i and j are previous
def extract(f_i, f_j, e_i, e_j, wc, links, nt, nt_open):
# Phrase extraction limits
- if wc > self.max_length or (f_j + 1) >= f_len or \
+ if wc > self.max_length or (f_j + 1) > f_len or \
(f_j - f_i) + 1 > self.max_size:
return
# Unaligned word
@@ -204,7 +147,7 @@ class OnlineGrammarExtractor:
return
span_flip(cover, nt[-1][4] + 1, link_j)
nt[-1][4] = link_j
- for rule in form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
+ for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
phrases.add(rule)
extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False)
nt[-1] = old_last_nt
@@ -227,7 +170,7 @@ class OnlineGrammarExtractor:
plus_links.append((f_j, link))
cover[link] = ~cover[link]
links.append(plus_links)
- for rule in form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
+ for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
phrases.add(rule)
extract(f_i, f_j + 1, new_e_i, new_e_j, wc + 1, links, nt, False)
links.pop()
@@ -252,7 +195,7 @@ class OnlineGrammarExtractor:
nt[-1][4] = link_j
# Require at least one word in phrase
if links:
- for rule in form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
+ for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
phrases.add(rule)
extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False)
nt[-1] = old_last_nt
@@ -269,7 +212,7 @@ class OnlineGrammarExtractor:
nt.append([next_nt(nt), f_j, f_j, link_i, link_j])
# Require at least one word in phrase
if links:
- for rule in form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
+ for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
phrases.add(rule)
extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False)
nt.pop()
@@ -290,6 +233,95 @@ class OnlineGrammarExtractor:
for rule in sorted(phrases):
print rule
+ # Create a rule from source, target, non-terminals, and alignments
+ def form_rules(self, f_i, e_i, f_span, e_span, nt, al):
+
+ # This could be more efficient but is unlikely to be the bottleneck
+
+ rules = []
+
+ nt_inv = sorted(nt, cmp=lambda x, y: cmp(x[3], y[3]))
+
+ f_sym = f_span[:]
+ off = f_i
+ for next_nt in nt:
+ nt_len = (next_nt[2] - next_nt[1]) + 1
+ i = 0
+ while i < nt_len:
+ f_sym.pop(next_nt[1] - off)
+ i += 1
+ f_sym.insert(next_nt[1] - off, NonTerminal(next_nt[0]))
+ off += (nt_len - 1)
+
+ e_sym = e_span[:]
+ off = e_i
+ for next_nt in nt_inv:
+ nt_len = (next_nt[4] - next_nt[3]) + 1
+ i = 0
+ while i < nt_len:
+ e_sym.pop(next_nt[3] - off)
+ i += 1
+ e_sym.insert(next_nt[3] - off, NonTerminal(next_nt[0]))
+ off += (nt_len - 1)
+
+ # Adjusting alignment links takes some doing
+ links = [list(link) for sub in al for link in sub]
+ links_len = len(links)
+ nt_len = len(nt)
+ nt_i = 0
+ off = f_i
+ i = 0
+ while i < links_len:
+ while nt_i < nt_len and links[i][0] > nt[nt_i][1]:
+ off += (nt[nt_i][2] - nt[nt_i][1])
+ nt_i += 1
+ links[i][0] -= off
+ i += 1
+ nt_i = 0
+ off = e_i
+ i = 0
+ while i < links_len:
+ while nt_i < nt_len and links[i][1] > nt_inv[nt_i][3]:
+ off += (nt_inv[nt_i][4] - nt_inv[nt_i][3])
+ nt_i += 1
+ links[i][1] -= off
+ i += 1
+
+ # Rule
+ rules.append(fmt_rule(f_sym, e_sym, links))
+ if len(f_sym) >= self.max_length or len(nt) >= self.max_nonterminals:
+ return rules
+ last_index = nt[-1][0] if nt else 0
+ # Rule [X]
+ if not nt or not isinstance(f_sym[-1], NonTerminal):
+ f_sym.append(NonTerminal(last_index + 1))
+ e_sym.append(NonTerminal(last_index + 1))
+ rules.append(fmt_rule(f_sym, e_sym, links))
+ f_sym.pop()
+ e_sym.pop()
+ # [X] Rule
+ if not nt or not isinstance(f_sym[0], NonTerminal):
+ for sym in f_sym:
+ if isinstance(sym, NonTerminal):
+ sym.index += 1
+ for sym in e_sym:
+ if isinstance(sym, NonTerminal):
+ sym.index += 1
+ for link in links:
+ link[0] += 1
+ link[1] += 1
+ f_sym.insert(0, NonTerminal(1))
+ e_sym.insert(0, NonTerminal(1))
+ rules.append(fmt_rule(f_sym, e_sym, links))
+ if len(f_sym) >= self.max_length or len(nt) + 1 >= self.max_nonterminals:
+ return rules
+ # [X] Rule [X]
+ if not nt or not isinstance(f_sym[-1], NonTerminal):
+ f_sym.append(NonTerminal(last_index + 2))
+ e_sym.append(NonTerminal(last_index + 2))
+ rules.append(fmt_rule(f_sym, e_sym, links))
+ return rules
+
def main(argv):
extractor = OnlineGrammarExtractor()