summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--python/src/sa/rulefactory.pxi109
1 files changed, 42 insertions, 67 deletions
diff --git a/python/src/sa/rulefactory.pxi b/python/src/sa/rulefactory.pxi
index 29bd809c..2bc65da2 100644
--- a/python/src/sa/rulefactory.pxi
+++ b/python/src/sa/rulefactory.pxi
@@ -1829,10 +1829,10 @@ cdef class HieroCachingRuleFactory:
def add_instance(self, f_words, e_words, alignment):
# Rules extracted from this instance
- # Track span of absolute alignment links to make
+ # Track span of lexical items (terminals) to make
# sure we don't extract the same rule for the same
# span more than once.
- # (f, e, al, f_min_link, f_max_link)
+ # (f, e, al, lex_f_i, lex_f_j)
rules = set()
f_len = len(f_words)
@@ -1852,14 +1852,14 @@ cdef class HieroCachingRuleFactory:
# Target side word coverage
cover = [0] * e_len
# Non-terminal coverage
- f_nt_cover = [0] * e_len
+ f_nt_cover = [0] * f_len
e_nt_cover = [0] * e_len
# Extract all possible hierarchical phrases starting at a source index
# f_ i and j are current, e_ i and j are previous
def extract(f_i, f_j, e_i, e_j, min_bound, wc, links, nt, nt_open):
# Phrase extraction limits
- if wc + len(nt) > self.max_length or (f_j + 1) > f_len or \
+ if wc + len(nt) > self.max_length or f_j > (f_len - 1) or \
(f_j - f_i) + 1 > self.max_initial_size:
return
# Unaligned word
@@ -1881,7 +1881,10 @@ cdef class HieroCachingRuleFactory:
# Check reverse links of newly covered words to see if they violate left
# bound (return) or extend minimum right bound for chunk
new_min_bound = min_bound
+ #new_min_nt_bound = min_nt_bound
+ # violates_nt = False
# First aligned word creates span
+ # TODO: NO bound
if e_j == -1:
for i from new_e_i <= i <= new_e_j:
if ef_span[i][0] < f_i:
@@ -1897,61 +1900,28 @@ cdef class HieroCachingRuleFactory:
if ef_span[i][0] < f_i:
return
new_min_bound = max(new_min_bound, ef_span[i][1])
- # Open non-terminal: close, extract, extend
- if nt_open:
- # Close non-terminal, checking for collisions
- old_last_nt = nt[-1][:]
- nt[-1][2] = f_j
- if link_i < nt[-1][3]:
- if not span_check(cover, link_i, nt[-1][3] - 1):
- nt[-1] = old_last_nt
- return
- span_inc(cover, link_i, nt[-1][3] - 1)
- span_inc(e_nt_cover, link_i, nt[-1][3] - 1)
- nt[-1][3] = link_i
- if link_j > nt[-1][4]:
- if not span_check(cover, nt[-1][4] + 1, link_j):
- nt[-1] = old_last_nt
- return
- span_inc(cover, nt[-1][4] + 1, link_j)
- span_inc(e_nt_cover, nt[-1][4] + 1, link_j)
- nt[-1][4] = link_j
- # Make sure we have at least one lexical alignment link
- if links:
- # Make sure we cover all aligned words
- if f_j >= new_min_bound:
- rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
- extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc, links, nt, False)
- nt[-1] = old_last_nt
- if link_i < nt[-1][3]:
- span_dec(cover, link_i, nt[-1][3] - 1)
- span_dec(e_nt_cover, link_i, nt[-1][3] - 1)
- if link_j > nt[-1][4]:
- span_dec(cover, nt[-1][4] + 1, link_j)
- span_dec(e_nt_cover, nt[-1][4] + 1, link_j)
- return
- # No open non-terminal
- # Extract, extend with word
- nt_collision = False
- for link in al[f_j]:
- if e_nt_cover[link]:
- nt_collision = True
- # Non-terminal collisions block word extraction and extension, but
- # may be okay for continuing non-terminals
- if not nt_collision:
- plus_links = []
- for link in al[f_j]:
- plus_links.append((f_j, link))
- cover[link] += 1
- links.append(plus_links)
- if links:
- if f_j >= new_min_bound:
- rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
- extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc + 1, links, nt, False)
- links.pop()
+ # Extract, extend with word (unless non-terminal open)
+ if not nt_open:
+ nt_collision = False
for link in al[f_j]:
- cover[link] -= 1
- # Try to add a word to a (closed) non-terminal, extract, extend
+ if e_nt_cover[link]:
+ nt_collision = True
+ # Non-terminal collisions block word extraction and extension, but
+ # may be okay for continuing non-terminals
+ if not nt_collision:
+ plus_links = []
+ for link in al[f_j]:
+ plus_links.append((f_j, link))
+ cover[link] += 1
+ links.append(plus_links)
+ if links:
+ if f_j >= new_min_bound:
+ rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
+ extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc + 1, links, nt, False)
+ links.pop()
+ for link in al[f_j]:
+ cover[link] -= 1
+ # Try to add a word to current non-terminal (if any), extract, extend
if nt and nt[-1][2] == f_j - 1:
# Add to non-terminal, checking for collisions
old_last_nt = nt[-1][:]
@@ -1972,10 +1942,10 @@ cdef class HieroCachingRuleFactory:
nt[-1][4] = link_j
if links:
if f_j >= new_min_bound:
- rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
+ rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc, links, nt, False)
nt[-1] = old_last_nt
- if new_e_i < nt[-1][3]:
+ if link_i < nt[-1][3]:
span_dec(cover, link_i, nt[-1][3] - 1)
span_dec(e_nt_cover, link_i, nt[-1][3] - 1)
if link_j > nt[-1][4]:
@@ -2032,10 +2002,8 @@ cdef class HieroCachingRuleFactory:
# Create a rule from source, target, non-terminals, and alignments
def form_rule(self, f_i, e_i, f_span, e_span, nt, al):
- # Handle non-terminals
-
+ # Substitute in non-terminals
nt_inv = sorted(nt, cmp=lambda x, y: cmp(x[3], y[3]))
-
f_sym = list(f_span[:])
off = f_i
for next_nt in nt:
@@ -2046,7 +2014,6 @@ cdef class HieroCachingRuleFactory:
i += 1
f_sym.insert(next_nt[1] - off, sym_setindex(self.category, next_nt[0]))
off += (nt_len - 1)
-
e_sym = list(e_span[:])
off = e_i
for next_nt in nt_inv:
@@ -2060,8 +2027,6 @@ cdef class HieroCachingRuleFactory:
# Adjusting alignment links takes some doing
links = [list(link) for sub in al for link in sub]
- f_link_min = links[0][0]
- f_link_max = links[-1][0]
links_inv = sorted(links, cmp=lambda x, y: cmp(x[1], y[1]))
links_len = len(links)
nt_len = len(nt)
@@ -2084,11 +2049,21 @@ cdef class HieroCachingRuleFactory:
links_inv[i][1] -= off
i += 1
+ # Find lexical span
+ lex_f_i = f_i
+ lex_f_j = f_i + (len(f_span) - 1)
+ if nt:
+ if nt[0][1] == lex_f_i:
+ lex_f_i += (nt[0][2] - nt[0][1]) + 1
+ if nt[-1][2] == lex_f_j:
+ lex_f_j -= (nt[-1][2] - nt[-1][1]) + 1
+
# Create rule (f_phrase, e_phrase, links, f_link_min, f_link_max)
f = Phrase(f_sym)
e = Phrase(e_sym)
a = tuple(self.alignment.link(i, j) for (i, j) in links)
- return (f, e, a, f_link_min, f_link_max)
+ print 'New rule:', self.fmt_rule(f, e, a)
+ return (f, e, a, lex_f_i, lex_f_j)
# Rule string from rule
def fmt_rule(self, f, e, a):