summaryrefslogtreecommitdiff
path: root/python/src/sa/rulefactory.pxi
diff options
context:
space:
mode:
Diffstat (limited to 'python/src/sa/rulefactory.pxi')
-rw-r--r--python/src/sa/rulefactory.pxi48
1 files changed, 18 insertions, 30 deletions
diff --git a/python/src/sa/rulefactory.pxi b/python/src/sa/rulefactory.pxi
index 2bc65da2..b10d25dd 100644
--- a/python/src/sa/rulefactory.pxi
+++ b/python/src/sa/rulefactory.pxi
@@ -1840,8 +1840,8 @@ cdef class HieroCachingRuleFactory:
# Pre-compute alignment info
al = [[] for i in range(f_len)]
- fe_span = [[f_len + 1, -1] for i in range(f_len)]
- ef_span = [[e_len + 1, -1] for i in range(e_len)]
+ fe_span = [[e_len + 1, -1] for i in range(f_len)]
+ ef_span = [[f_len + 1, -1] for i in range(e_len)]
for (f, e) in alignment:
al[f].append(e)
fe_span[f][0] = min(fe_span[f][0], e)
@@ -1857,10 +1857,10 @@ cdef class HieroCachingRuleFactory:
# Extract all possible hierarchical phrases starting at a source index
# f_ i and j are current, e_ i and j are previous
+ # We care _considering_ f_j, so it is not yet in counts
def extract(f_i, f_j, e_i, e_j, min_bound, wc, links, nt, nt_open):
# Phrase extraction limits
- if wc + len(nt) > self.max_length or f_j > (f_len - 1) or \
- (f_j - f_i) + 1 > self.max_initial_size:
+ if f_j > (f_len - 1) or (f_j - f_i) + 1 > self.max_initial_size:
return
# Unaligned word
if not al[f_j]:
@@ -1870,7 +1870,8 @@ cdef class HieroCachingRuleFactory:
extract(f_i, f_j + 1, e_i, e_j, min_bound, wc, links, nt, True)
nt[-1][2] -= 1
# Unless non-terminal already open, always extend with word
- if not nt_open:
+ # Make sure adding a word doesn't exceed length
+ if not nt_open and wc < self.max_length:
extract(f_i, f_j + 1, e_i, e_j, min_bound, wc + 1, links, nt, False)
return
# Aligned word
@@ -1881,11 +1882,8 @@ cdef class HieroCachingRuleFactory:
# Check reverse links of newly covered words to see if they violate left
# bound (return) or extend minimum right bound for chunk
new_min_bound = min_bound
- #new_min_nt_bound = min_nt_bound
- # violates_nt = False
# First aligned word creates span
- # TODO: NO bound
- if e_j == -1:
+ if e_j == -1:
for i from new_e_i <= i <= new_e_j:
if ef_span[i][0] < f_i:
return
@@ -1908,15 +1906,14 @@ cdef class HieroCachingRuleFactory:
nt_collision = True
# Non-terminal collisions block word extraction and extension, but
# may be okay for continuing non-terminals
- if not nt_collision:
+ if not nt_collision and wc < self.max_length:
plus_links = []
for link in al[f_j]:
plus_links.append((f_j, link))
cover[link] += 1
links.append(plus_links)
- if links:
- if f_j >= new_min_bound:
- rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
+ if links and f_j >= new_min_bound:
+ rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc + 1, links, nt, False)
links.pop()
for link in al[f_j]:
@@ -1940,9 +1937,8 @@ cdef class HieroCachingRuleFactory:
span_inc(cover, nt[-1][4] + 1, link_j)
span_inc(e_nt_cover, nt[-1][4] + 1, link_j)
nt[-1][4] = link_j
- if links:
- if f_j >= new_min_bound:
- rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
+ if links and f_j >= new_min_bound:
+ rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc, links, nt, False)
nt[-1] = old_last_nt
if link_i < nt[-1][3]:
@@ -1952,7 +1948,7 @@ cdef class HieroCachingRuleFactory:
span_dec(cover, nt[-1][4] + 1, link_j)
span_dec(e_nt_cover, nt[-1][4] + 1, link_j)
# Try to start a new non-terminal, extract, extend
- if (not nt or f_j - nt[-1][2] > 1) and len(nt) < self.max_nonterminals:
+ if (not nt or f_j - nt[-1][2] > 1) and wc < self.max_length and len(nt) < self.max_nonterminals:
# Check for collisions
if not span_check(cover, link_i, link_j):
return
@@ -1960,26 +1956,19 @@ cdef class HieroCachingRuleFactory:
span_inc(e_nt_cover, link_i, link_j)
nt.append([(nt[-1][0] + 1) if nt else 1, f_j, f_j, link_i, link_j])
# Require at least one word in phrase
- if links:
- if f_j >= new_min_bound:
- rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
- extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc, links, nt, False)
+ if links and f_j >= new_min_bound:
+ rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
+ extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc + 1, links, nt, False)
nt.pop()
span_dec(cover, link_i, link_j)
span_dec(e_nt_cover, link_i, link_j)
# Try to extract phrases from every f index
- f_i = 0
- while f_i < f_len:
+ for f_i from 0 <= f_i < f_len:
# Skip if phrases won't be tight on left side
if not al[f_i]:
- f_i += 1
continue
- extract(f_i, f_i, f_len + 1, -1, f_i, 1, [], [], False)
- f_i += 1
-
- for rule in sorted(rules):
- logger.info(self.fmt_rule(*rule[:3]))
+ extract(f_i, f_i, f_len + 1, -1, f_i, 0, [], [], False)
# Update phrase counts
for rule in rules:
@@ -2062,7 +2051,6 @@ cdef class HieroCachingRuleFactory:
f = Phrase(f_sym)
e = Phrase(e_sym)
a = tuple(self.alignment.link(i, j) for (i, j) in links)
- print 'New rule:', self.fmt_rule(f, e, a)
return (f, e, a, lex_f_i, lex_f_j)
# Rule string from rule