summaryrefslogtreecommitdiff
path: root/python/src/sa/rulefactory.pxi
diff options
context:
space:
mode:
Diffstat (limited to 'python/src/sa/rulefactory.pxi')
-rw-r--r--python/src/sa/rulefactory.pxi156
1 files changed, 102 insertions, 54 deletions
diff --git a/python/src/sa/rulefactory.pxi b/python/src/sa/rulefactory.pxi
index c26f5c43..be73f567 100644
--- a/python/src/sa/rulefactory.pxi
+++ b/python/src/sa/rulefactory.pxi
@@ -1836,18 +1836,24 @@ cdef class HieroCachingRuleFactory:
# Pre-compute alignment info
al = [[] for i in range(f_len)]
- al_span = [[f_len + 1, -1] for i in range(f_len)]
+ fe_span = [[f_len + 1, -1] for i in range(f_len)]
+ ef_span = [[e_len + 1, -1] for i in range(e_len)]
for (f, e) in alignment:
al[f].append(e)
- al_span[f][0] = min(al_span[f][0], e)
- al_span[f][1] = max(al_span[f][1], e)
+ fe_span[f][0] = min(fe_span[f][0], e)
+ fe_span[f][1] = max(fe_span[f][1], e)
+ ef_span[e][0] = min(ef_span[e][0], f)
+ ef_span[e][1] = max(ef_span[e][1], f)
# Target side word coverage
cover = [0] * e_len
-
+ # Non-terminal coverage
+ f_nt_cover = [0] * e_len
+ e_nt_cover = [0] * e_len
+
# Extract all possible hierarchical phrases starting at a source index
# f_ i and j are current, e_ i and j are previous
- def extract(f_i, f_j, e_i, e_j, wc, links, nt, nt_open):
+ def extract(f_i, f_j, e_i, e_j, min_bound, wc, links, nt, nt_open):
# Phrase extraction limits
if wc + len(nt) > self.max_length or (f_j + 1) > f_len or \
(f_j - f_i) + 1 > self.max_initial_size:
@@ -1857,17 +1863,36 @@ cdef class HieroCachingRuleFactory:
# Open non-terminal: extend
if nt_open:
nt[-1][2] += 1
- extract(f_i, f_j + 1, e_i, e_j, wc, links, nt, True)
+ extract(f_i, f_j + 1, e_i, e_j, min_bound, wc, links, nt, True)
nt[-1][2] -= 1
# No open non-terminal: extend with word
else:
- extract(f_i, f_j + 1, e_i, e_j, wc + 1, links, nt, False)
+ extract(f_i, f_j + 1, e_i, e_j, min_bound, wc + 1, links, nt, False)
return
# Aligned word
- link_i = al_span[f_j][0]
- link_j = al_span[f_j][1]
+ link_i = fe_span[f_j][0]
+ link_j = fe_span[f_j][1]
new_e_i = min(link_i, e_i)
new_e_j = max(link_j, e_j)
+ # Check reverse links of newly covered words to see if they violate left
+ # bound (return) or extend minimum right bound for chunk
+ new_min_bound = min_bound
+ # First aligned word creates span
+ if e_j == -1:
+ for i from new_e_i <= i <= new_e_j:
+ if ef_span[i][0] < f_i:
+ return
+ new_min_bound = max(new_min_bound, ef_span[i][1])
+ # Other aligned words extend span
+ else:
+ for i from new_e_i <= i < e_i:
+ if ef_span[i][0] < f_i:
+ return
+ new_min_bound = max(new_min_bound, ef_span[i][1])
+ for i from e_j < i <= new_e_j:
+ if ef_span[i][0] < f_i:
+ return
+ new_min_bound = max(new_min_bound, ef_span[i][1])
# Open non-terminal: close, extract, extend
if nt_open:
# Close non-terminal, checking for collisions
@@ -1877,43 +1902,50 @@ cdef class HieroCachingRuleFactory:
if not span_check(cover, link_i, nt[-1][3] - 1):
nt[-1] = old_last_nt
return
- span_flip(cover, link_i, nt[-1][3] - 1)
+ span_inc(cover, link_i, nt[-1][3] - 1)
+ span_inc(e_nt_cover, link_i, nt[-1][3] - 1)
nt[-1][3] = link_i
if link_j > nt[-1][4]:
if not span_check(cover, nt[-1][4] + 1, link_j):
nt[-1] = old_last_nt
return
- span_flip(cover, nt[-1][4] + 1, link_j)
+ span_inc(cover, nt[-1][4] + 1, link_j)
+ span_inc(e_nt_cover, nt[-1][4] + 1, link_j)
nt[-1][4] = link_j
- for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
- rules.add(rule)
- extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False)
+ # Make sure we cover all aligned words
+ if f_j >= new_min_bound:
+ for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
+ rules.add(rule)
+ extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc, links, nt, False)
nt[-1] = old_last_nt
if link_i < nt[-1][3]:
- span_flip(cover, link_i, nt[-1][3] - 1)
+ span_dec(cover, link_i, nt[-1][3] - 1)
+ span_dec(e_nt_cover, link_i, nt[-1][3] - 1)
if link_j > nt[-1][4]:
- span_flip(cover, nt[-1][4] + 1, link_j)
+ span_dec(cover, nt[-1][4] + 1, link_j)
+ span_dec(e_nt_cover, nt[-1][4] + 1, link_j)
return
# No open non-terminal
# Extract, extend with word
- collision = False
+ nt_collision = False
for link in al[f_j]:
- if cover[link]:
- collision = True
- # Collisions block extraction and extension, but may be okay for
- # continuing non-terminals
- if not collision:
+ if e_nt_cover[link]:
+ nt_collision = True
+ # Non-terminal collisions block word extraction and extension, but
+ # may be okay for continuing non-terminals
+ if not nt_collision:
plus_links = []
for link in al[f_j]:
plus_links.append((f_j, link))
- cover[link] = ~cover[link]
+ cover[link] += 1
links.append(plus_links)
- for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
- rules.add(rule)
- extract(f_i, f_j + 1, new_e_i, new_e_j, wc + 1, links, nt, False)
+ if f_j >= new_min_bound:
+ for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
+ rules.add(rule)
+ extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc + 1, links, nt, False)
links.pop()
for link in al[f_j]:
- cover[link] = ~cover[link]
+ cover[link] -= 1
# Try to add a word to a (closed) non-terminal, extract, extend
if nt and nt[-1][2] == f_j - 1:
# Add to non-terminal, checking for collisions
@@ -1923,38 +1955,46 @@ cdef class HieroCachingRuleFactory:
if not span_check(cover, link_i, nt[-1][3] - 1):
nt[-1] = old_last_nt
return
- span_flip(cover, link_i, nt[-1][3] - 1)
+ span_inc(cover, link_i, nt[-1][3] - 1)
+ span_inc(e_nt_cover, link_i, nt[-1][3] - 1)
nt[-1][3] = link_i
if link_j > nt[-1][4]:
if not span_check(cover, nt[-1][4] + 1, link_j):
nt[-1] = old_last_nt
return
- span_flip(cover, nt[-1][4] + 1, link_j)
+ span_inc(cover, nt[-1][4] + 1, link_j)
+ span_inc(e_nt_cover, nt[-1][4] + 1, link_j)
nt[-1][4] = link_j
# Require at least one word in phrase
if links:
- for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
- rules.add(rule)
- extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False)
+ if f_j >= new_min_bound:
+ for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
+ rules.add(rule)
+ extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc, links, nt, False)
nt[-1] = old_last_nt
if new_e_i < nt[-1][3]:
- span_flip(cover, link_i, nt[-1][3] - 1)
+ span_dec(cover, link_i, nt[-1][3] - 1)
+ span_dec(e_nt_cover, link_i, nt[-1][3] - 1)
if link_j > nt[-1][4]:
- span_flip(cover, nt[-1][4] + 1, link_j)
+ span_dec(cover, nt[-1][4] + 1, link_j)
+ span_dec(e_nt_cover, nt[-1][4] + 1, link_j)
# Try to start a new non-terminal, extract, extend
if (not nt or f_j - nt[-1][2] > 1) and len(nt) < self.max_nonterminals:
# Check for collisions
if not span_check(cover, link_i, link_j):
return
- span_flip(cover, link_i, link_j)
+ span_inc(cover, link_i, link_j)
+ span_inc(e_nt_cover, link_i, link_j)
nt.append([(nt[-1][0] + 1) if nt else 1, f_j, f_j, link_i, link_j])
# Require at least one word in phrase
if links:
- for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
- rules.add(rule)
- extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False)
+ if f_j >= new_min_bound:
+ for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
+ rules.add(rule)
+ extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc, links, nt, False)
nt.pop()
- span_flip(cover, link_i, link_j)
+ span_dec(cover, link_i, link_j)
+ span_dec(e_nt_cover, link_i, link_j)
# Try to extract phrases from every f index
f_i = 0
@@ -1963,7 +2003,7 @@ cdef class HieroCachingRuleFactory:
if not al[f_i]:
f_i += 1
continue
- extract(f_i, f_i, f_len + 1, -1, 1, [], [], False)
+ extract(f_i, f_i, f_len + 1, -1, f_i, 1, [], [], False)
f_i += 1
for rule in sorted(rules):
@@ -2045,11 +2085,13 @@ cdef class HieroCachingRuleFactory:
nt_i += 1
links[i][1] -= off
i += 1
-
+
# Rule
rules.append(self.new_rule(f_sym, e_sym, links))
if len(f_sym) >= self.max_length or len(nt) >= self.max_nonterminals:
return rules
+ # DEBUG: no boundary NTs
+ return rules
last_index = nt[-1][0] if nt else 0
# Rule [X]
if not nt or not sym_isvar(f_sym[-1]):
@@ -2098,16 +2140,6 @@ cdef class HieroCachingRuleFactory:
logger.info('------------------------------')
logger.info(' Online Stats ')
logger.info('------------------------------')
- logger.info('F')
- for ph in self.phrases_f:
- logger.info(str(ph) + ' ||| ' + str(self.phrases_f[ph]))
- logger.info('E')
- for ph in self.phrases_e:
- logger.info(str(ph) + ' ||| ' + str(self.phrases_e[ph]))
- logger.info('FE')
- for ph in self.phrases_fe:
- for ph2 in self.phrases_fe[ph]:
- logger.info(str(ph) + ' ||| ' + str(ph2) + ' ||| ' + str(self.phrases_fe[ph][ph2]))
logger.info('f')
for w in self.bilex_f:
logger.info(sym_tostring(w) + ' : ' + str(self.bilex_f[w]))
@@ -2118,6 +2150,16 @@ cdef class HieroCachingRuleFactory:
for w in self.bilex_fe:
for w2 in self.bilex_fe[w]:
logger.info(sym_tostring(w) + ' : ' + sym_tostring(w2) + ' : ' + str(self.bilex_fe[w][w2]))
+ logger.info('F')
+ for ph in self.phrases_f:
+ logger.info(str(ph) + ' ||| ' + str(self.phrases_f[ph]))
+ logger.info('E')
+ for ph in self.phrases_e:
+ logger.info(str(ph) + ' ||| ' + str(self.phrases_e[ph]))
+ logger.info('FE')
+ for ph in self.phrases_fe:
+ for ph2 in self.phrases_fe[ph]:
+ logger.info(self.fmt_rule(str(ph), str(ph2), self.phrases_al[ph][ph2]) + ' ||| ' + str(self.phrases_fe[ph][ph2]))
# Spans are _inclusive_ on both ends [i, j]
# Could be more efficient but probably not a bottleneck
@@ -2129,8 +2171,14 @@ def span_check(vec, i, j):
k += 1
return True
-def span_flip(vec, i, j):
+def span_inc(vec, i, j):
k = i
while k <= j:
- vec[k] = ~vec[k]
- k += 1 \ No newline at end of file
+ vec[k] += 1
+ k += 1
+
+def span_dec(vec, i, j):
+ k = i
+ while k <= j:
+ vec[k] -= 1
+ k += 1