path: root/python/src/sa/rulefactory.pxi
author    Patrick Simianer <p@simianer.de>    2013-01-21 12:29:43 +0100
committer Patrick Simianer <p@simianer.de>    2013-01-21 12:29:43 +0100
commit    0d23f8aecbfaf982cd165ebfc2a1611cefcc7275 (patch)
tree      8eafa6ea43224ff70635cadd4d6f027d28f4986f /python/src/sa/rulefactory.pxi
parent    dbc66cd3944321961c5e11d5254fd914f05a98ad (diff)
parent    7cac43b858f3b681555bf0578f54b1f822c43207 (diff)
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'python/src/sa/rulefactory.pxi')
-rw-r--r--  python/src/sa/rulefactory.pxi | 284
1 file changed, 127 insertions(+), 157 deletions(-)
diff --git a/python/src/sa/rulefactory.pxi b/python/src/sa/rulefactory.pxi
index c26f5c43..b10d25dd 100644
--- a/python/src/sa/rulefactory.pxi
+++ b/python/src/sa/rulefactory.pxi
@@ -1829,6 +1829,10 @@ cdef class HieroCachingRuleFactory:
def add_instance(self, f_words, e_words, alignment):
# Rules extracted from this instance
+ # Track span of lexical items (terminals) to make
+ # sure we don't extract the same rule for the same
+ # span more than once.
+ # (f, e, al, lex_f_i, lex_f_j)
rules = set()
f_len = len(f_words)
@@ -1836,85 +1840,85 @@ cdef class HieroCachingRuleFactory:
# Pre-compute alignment info
al = [[] for i in range(f_len)]
- al_span = [[f_len + 1, -1] for i in range(f_len)]
+ fe_span = [[e_len + 1, -1] for i in range(f_len)]
+ ef_span = [[f_len + 1, -1] for i in range(e_len)]
for (f, e) in alignment:
al[f].append(e)
- al_span[f][0] = min(al_span[f][0], e)
- al_span[f][1] = max(al_span[f][1], e)
+ fe_span[f][0] = min(fe_span[f][0], e)
+ fe_span[f][1] = max(fe_span[f][1], e)
+ ef_span[e][0] = min(ef_span[e][0], f)
+ ef_span[e][1] = max(ef_span[e][1], f)
# Target side word coverage
cover = [0] * e_len
-
+ # Non-terminal coverage
+ f_nt_cover = [0] * f_len
+ e_nt_cover = [0] * e_len
+
# Extract all possible hierarchical phrases starting at a source index
# f_i and f_j are current, e_i and e_j are previous
- def extract(f_i, f_j, e_i, e_j, wc, links, nt, nt_open):
+ # We are _considering_ f_j, so it is not yet in counts
+ def extract(f_i, f_j, e_i, e_j, min_bound, wc, links, nt, nt_open):
# Phrase extraction limits
- if wc + len(nt) > self.max_length or (f_j + 1) > f_len or \
- (f_j - f_i) + 1 > self.max_initial_size:
+ if f_j > (f_len - 1) or (f_j - f_i) + 1 > self.max_initial_size:
return
# Unaligned word
if not al[f_j]:
- # Open non-terminal: extend
- if nt_open:
+ # Adjacent to non-terminal: extend (non-terminal now open)
+ if nt and nt[-1][2] == f_j - 1:
nt[-1][2] += 1
- extract(f_i, f_j + 1, e_i, e_j, wc, links, nt, True)
+ extract(f_i, f_j + 1, e_i, e_j, min_bound, wc, links, nt, True)
nt[-1][2] -= 1
- # No open non-terminal: extend with word
- else:
- extract(f_i, f_j + 1, e_i, e_j, wc + 1, links, nt, False)
+ # Unless non-terminal already open, always extend with word
+ # Make sure adding a word doesn't exceed length
+ if not nt_open and wc < self.max_length:
+ extract(f_i, f_j + 1, e_i, e_j, min_bound, wc + 1, links, nt, False)
return
# Aligned word
- link_i = al_span[f_j][0]
- link_j = al_span[f_j][1]
+ link_i = fe_span[f_j][0]
+ link_j = fe_span[f_j][1]
new_e_i = min(link_i, e_i)
new_e_j = max(link_j, e_j)
- # Open non-terminal: close, extract, extend
- if nt_open:
- # Close non-terminal, checking for collisions
- old_last_nt = nt[-1][:]
- nt[-1][2] = f_j
- if link_i < nt[-1][3]:
- if not span_check(cover, link_i, nt[-1][3] - 1):
- nt[-1] = old_last_nt
+ # Check reverse links of newly covered words to see if they violate left
+ # bound (return) or extend minimum right bound for chunk
+ new_min_bound = min_bound
+ # First aligned word creates span
+ if e_j == -1:
+ for i from new_e_i <= i <= new_e_j:
+ if ef_span[i][0] < f_i:
return
- span_flip(cover, link_i, nt[-1][3] - 1)
- nt[-1][3] = link_i
- if link_j > nt[-1][4]:
- if not span_check(cover, nt[-1][4] + 1, link_j):
- nt[-1] = old_last_nt
+ new_min_bound = max(new_min_bound, ef_span[i][1])
+ # Other aligned words extend span
+ else:
+ for i from new_e_i <= i < e_i:
+ if ef_span[i][0] < f_i:
return
- span_flip(cover, nt[-1][4] + 1, link_j)
- nt[-1][4] = link_j
- for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
- rules.add(rule)
- extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False)
- nt[-1] = old_last_nt
- if link_i < nt[-1][3]:
- span_flip(cover, link_i, nt[-1][3] - 1)
- if link_j > nt[-1][4]:
- span_flip(cover, nt[-1][4] + 1, link_j)
- return
- # No open non-terminal
- # Extract, extend with word
- collision = False
- for link in al[f_j]:
- if cover[link]:
- collision = True
- # Collisions block extraction and extension, but may be okay for
- # continuing non-terminals
- if not collision:
- plus_links = []
- for link in al[f_j]:
- plus_links.append((f_j, link))
- cover[link] = ~cover[link]
- links.append(plus_links)
- for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
- rules.add(rule)
- extract(f_i, f_j + 1, new_e_i, new_e_j, wc + 1, links, nt, False)
- links.pop()
+ new_min_bound = max(new_min_bound, ef_span[i][1])
+ for i from e_j < i <= new_e_j:
+ if ef_span[i][0] < f_i:
+ return
+ new_min_bound = max(new_min_bound, ef_span[i][1])
+ # Extract, extend with word (unless non-terminal open)
+ if not nt_open:
+ nt_collision = False
for link in al[f_j]:
- cover[link] = ~cover[link]
- # Try to add a word to a (closed) non-terminal, extract, extend
+ if e_nt_cover[link]:
+ nt_collision = True
+ # Non-terminal collisions block word extraction and extension, but
+ # may be okay for continuing non-terminals
+ if not nt_collision and wc < self.max_length:
+ plus_links = []
+ for link in al[f_j]:
+ plus_links.append((f_j, link))
+ cover[link] += 1
+ links.append(plus_links)
+ if links and f_j >= new_min_bound:
+ rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
+ extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc + 1, links, nt, False)
+ links.pop()
+ for link in al[f_j]:
+ cover[link] -= 1
+ # Try to add a word to current non-terminal (if any), extract, extend
if nt and nt[-1][2] == f_j - 1:
# Add to non-terminal, checking for collisions
old_last_nt = nt[-1][:]
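
The fe_span/ef_span tables introduced at the top of this hunk replace the one-directional al_span with alignment-span projections in both directions. A self-contained plain-Python sketch of that preprocessing (toy alignment and invented sentence lengths):

    # fe_span[f] = [min, max] target index linked to source word f;
    # ef_span[e] = [min, max] source index linked to target word e.
    f_len, e_len = 3, 3
    alignment = [(0, 0), (0, 1), (2, 2)]  # (f, e) links

    al = [[] for _ in range(f_len)]
    fe_span = [[e_len + 1, -1] for _ in range(f_len)]
    ef_span = [[f_len + 1, -1] for _ in range(e_len)]
    for (f, e) in alignment:
        al[f].append(e)
        fe_span[f][0] = min(fe_span[f][0], e)
        fe_span[f][1] = max(fe_span[f][1], e)
        ef_span[e][0] = min(ef_span[e][0], f)
        ef_span[e][1] = max(ef_span[e][1], f)

    assert fe_span[0] == [0, 1]           # source word 0 links to targets 0..1
    assert ef_span[2] == [2, 2]           # target word 2 links only to source 2
    assert fe_span[1] == [e_len + 1, -1]  # unaligned: sentinels untouched
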
@@ -1923,65 +1927,57 @@ cdef class HieroCachingRuleFactory:
if not span_check(cover, link_i, nt[-1][3] - 1):
nt[-1] = old_last_nt
return
- span_flip(cover, link_i, nt[-1][3] - 1)
+ span_inc(cover, link_i, nt[-1][3] - 1)
+ span_inc(e_nt_cover, link_i, nt[-1][3] - 1)
nt[-1][3] = link_i
if link_j > nt[-1][4]:
if not span_check(cover, nt[-1][4] + 1, link_j):
nt[-1] = old_last_nt
return
- span_flip(cover, nt[-1][4] + 1, link_j)
+ span_inc(cover, nt[-1][4] + 1, link_j)
+ span_inc(e_nt_cover, nt[-1][4] + 1, link_j)
nt[-1][4] = link_j
- # Require at least one word in phrase
- if links:
- for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
- rules.add(rule)
- extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False)
+ if links and f_j >= new_min_bound:
+ rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
+ extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc, links, nt, False)
nt[-1] = old_last_nt
- if new_e_i < nt[-1][3]:
- span_flip(cover, link_i, nt[-1][3] - 1)
+ if link_i < nt[-1][3]:
+ span_dec(cover, link_i, nt[-1][3] - 1)
+ span_dec(e_nt_cover, link_i, nt[-1][3] - 1)
if link_j > nt[-1][4]:
- span_flip(cover, nt[-1][4] + 1, link_j)
+ span_dec(cover, nt[-1][4] + 1, link_j)
+ span_dec(e_nt_cover, nt[-1][4] + 1, link_j)
# Try to start a new non-terminal, extract, extend
- if (not nt or f_j - nt[-1][2] > 1) and len(nt) < self.max_nonterminals:
+ if (not nt or f_j - nt[-1][2] > 1) and wc < self.max_length and len(nt) < self.max_nonterminals:
# Check for collisions
if not span_check(cover, link_i, link_j):
return
- span_flip(cover, link_i, link_j)
+ span_inc(cover, link_i, link_j)
+ span_inc(e_nt_cover, link_i, link_j)
nt.append([(nt[-1][0] + 1) if nt else 1, f_j, f_j, link_i, link_j])
# Require at least one word in phrase
- if links:
- for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
- rules.add(rule)
- extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False)
+ if links and f_j >= new_min_bound:
+ rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
+ extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc + 1, links, nt, False)
nt.pop()
- span_flip(cover, link_i, link_j)
+ span_dec(cover, link_i, link_j)
+ span_dec(e_nt_cover, link_i, link_j)
# Try to extract phrases from every f index
- f_i = 0
- while f_i < f_len:
+ for f_i from 0 <= f_i < f_len:
# Skip if phrases won't be tight on left side
if not al[f_i]:
- f_i += 1
continue
- extract(f_i, f_i, f_len + 1, -1, 1, [], [], False)
- f_i += 1
-
- for rule in sorted(rules):
- logger.info(self.fmt_rule(*rule))
+ extract(f_i, f_i, f_len + 1, -1, f_i, 0, [], [], False)
# Update phrase counts
- f_set = set()
- e_set = set()
- for (f_ph, e_ph, al) in rules:
- f_set.add(f_ph)
- e_set.add(e_ph)
+ for rule in rules:
+ (f_ph, e_ph, al) = rule[:3]
+ self.phrases_f[f_ph] += 1
+ self.phrases_e[e_ph] += 1
self.phrases_fe[f_ph][e_ph] += 1
if not self.phrases_al[f_ph][e_ph]:
self.phrases_al[f_ph][e_ph] = al
- for f_ph in f_set:
- self.phrases_f[f_ph] += 1
- for e_ph in e_set:
- self.phrases_e[e_ph] += 1
# Update Bilexical counts
for e_w in e_words:
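
The per-rule count update at the end of this hunk assumes defaultdict-style tables for phrases_f, phrases_e, phrases_fe, and phrases_al; a minimal sketch under that assumption (the toy rules are invented):

    from collections import defaultdict

    phrases_f = defaultdict(int)
    phrases_e = defaultdict(int)
    phrases_fe = defaultdict(lambda: defaultdict(int))
    phrases_al = defaultdict(lambda: defaultdict(tuple))

    rules = {('f1', 'e1', ((0, 0),), 0, 0), ('f1', 'e2', ((0, 0),), 1, 1)}
    for rule in rules:
        (f_ph, e_ph, al) = rule[:3]
        phrases_f[f_ph] += 1
        phrases_e[e_ph] += 1
        phrases_fe[f_ph][e_ph] += 1
        if not phrases_al[f_ph][e_ph]:
            phrases_al[f_ph][e_ph] = al  # keep first-seen alignment

    assert phrases_f['f1'] == 2 and phrases_e['e1'] == 1
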
@@ -1993,14 +1989,10 @@ cdef class HieroCachingRuleFactory:
# Create a rule from source, target, non-terminals, and alignments
- def form_rules(self, f_i, e_i, f_span, e_span, nt, al):
-
- # This could be more efficient but is unlikely to be the bottleneck
-
- rules = []
+ def form_rule(self, f_i, e_i, f_span, e_span, nt, al):
+ # Substitute in non-terminals
nt_inv = sorted(nt, cmp=lambda x, y: cmp(x[3], y[3]))
-
f_sym = list(f_span[:])
off = f_i
for next_nt in nt:
@@ -2011,7 +2003,6 @@ cdef class HieroCachingRuleFactory:
i += 1
f_sym.insert(next_nt[1] - off, sym_setindex(self.category, next_nt[0]))
off += (nt_len - 1)
-
e_sym = list(e_span[:])
off = e_i
for next_nt in nt_inv:
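
The substitution loops in this hunk replace each non-terminal's covered words with a single indexed symbol. A plain-Python sketch of the source-side pass, with strings like 'X1' standing in for sym_setindex(self.category, idx) (words and spans hypothetical):

    # nt entries are [index, f_start, f_end, e_start, e_end], spans inclusive.
    f_i = 2
    f_span = ['ne', 'veux', 'pas']  # source words at positions 2..4
    nt = [[1, 3, 3, 1, 1]]          # X1 covers source word 3

    f_sym = list(f_span)
    off = f_i
    for next_nt in nt:
        nt_len = (next_nt[2] - next_nt[1]) + 1   # words the NT covers
        for _ in range(nt_len):                  # remove covered words
            f_sym.pop(next_nt[1] - off)
        f_sym.insert(next_nt[1] - off, 'X' + str(next_nt[0]))
        off += (nt_len - 1)

    assert f_sym == ['ne', 'X1', 'pas']
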
@@ -2025,6 +2016,7 @@ cdef class HieroCachingRuleFactory:
# Adjusting alignment links takes some doing
links = [list(link) for sub in al for link in sub]
+ links_inv = sorted(links, cmp=lambda x, y: cmp(x[1], y[1]))
links_len = len(links)
nt_len = len(nt)
nt_i = 0
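
A worked toy example of the target-side offset pass shown in the next hunk, assuming nt entries of the form [index, f_start, f_end, e_start, e_end] with inclusive spans (values invented): every non-terminal lying entirely to the left of a link shrinks that link's target index by the span width it collapsed into a single symbol.

    e_i = 0
    nt_inv = [[1, 1, 1, 1, 2]]    # X1 covers target words 1..2
    links_inv = [[0, 0], [3, 3]]  # [f, e] links, sorted by e
    nt_i, off, i = 0, e_i, 0
    while i < len(links_inv):
        while nt_i < len(nt_inv) and links_inv[i][1] > nt_inv[nt_i][3]:
            off += (nt_inv[nt_i][4] - nt_inv[nt_i][3])
            nt_i += 1
        links_inv[i][1] -= off
        i += 1
    # Target word 3 now sits at position 2: words 0, X1, 3.
    assert links_inv == [[0, 0], [3, 2]]
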
@@ -2040,55 +2032,28 @@ cdef class HieroCachingRuleFactory:
off = e_i
i = 0
while i < links_len:
- while nt_i < nt_len and links[i][1] > nt_inv[nt_i][3]:
+ while nt_i < nt_len and links_inv[i][1] > nt_inv[nt_i][3]:
off += (nt_inv[nt_i][4] - nt_inv[nt_i][3])
nt_i += 1
- links[i][1] -= off
+ links_inv[i][1] -= off
i += 1
-
- # Rule
- rules.append(self.new_rule(f_sym, e_sym, links))
- if len(f_sym) >= self.max_length or len(nt) >= self.max_nonterminals:
- return rules
- last_index = nt[-1][0] if nt else 0
- # Rule [X]
- if not nt or not sym_isvar(f_sym[-1]):
- f_sym.append(sym_setindex(self.category, last_index + 1))
- e_sym.append(sym_setindex(self.category, last_index + 1))
- rules.append(self.new_rule(f_sym, e_sym, links))
- f_sym.pop()
- e_sym.pop()
- # [X] Rule
- f_len = len(f_sym)
- e_len = len(e_sym)
- if not nt or not sym_isvar(f_sym[0]):
- for i from 0 <= i < f_len:
- if sym_isvar(f_sym[i]):
- f_sym[i] = sym_setindex(self.category, sym_getindex(f_sym[i]) + 1)
- for i from 0 <= i < e_len:
- if sym_isvar(e_sym[i]):
- e_sym[i] = sym_setindex(self.category, sym_getindex(e_sym[i]) + 1)
- for link in links:
- link[0] += 1
- link[1] += 1
- f_sym.insert(0, sym_setindex(self.category, 1))
- e_sym.insert(0, sym_setindex(self.category, 1))
- rules.append(self.new_rule(f_sym, e_sym, links))
- if len(f_sym) >= self.max_length or len(nt) + 1 >= self.max_nonterminals:
- return rules
- # [X] Rule [X]
- if not nt or not sym_isvar(f_sym[-1]):
- f_sym.append(sym_setindex(self.category, last_index + 2))
- e_sym.append(sym_setindex(self.category, last_index + 2))
- rules.append(self.new_rule(f_sym, e_sym, links))
- return rules
-
- def new_rule(self, f_sym, e_sym, links):
+
+ # Find lexical span
+ lex_f_i = f_i
+ lex_f_j = f_i + (len(f_span) - 1)
+ if nt:
+ if nt[0][1] == lex_f_i:
+ lex_f_i += (nt[0][2] - nt[0][1]) + 1
+ if nt[-1][2] == lex_f_j:
+ lex_f_j -= (nt[-1][2] - nt[-1][1]) + 1
+
+ # Create rule (f_phrase, e_phrase, links, f_link_min, f_link_max)
f = Phrase(f_sym)
e = Phrase(e_sym)
a = tuple(self.alignment.link(i, j) for (i, j) in links)
- return (f, e, a)
-
+ return (f, e, a, lex_f_i, lex_f_j)
+
+ # Rule string from rule
def fmt_rule(self, f, e, a):
a_str = ' '.join('{0}-{1}'.format(*self.alignment.unlink(packed)) for packed in a)
return '[X] ||| {0} ||| {1} ||| {2}'.format(f, e, a_str)
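
The lexical-span computation in the previous hunk trims a non-terminal sitting flush at either edge off the phrase span; a quick plain-Python check of that arithmetic (the function wrapper and values are hypothetical):

    def lexical_span(f_i, f_span_len, nt):
        # Shrink [f_i, f_j] past any NT flush at either edge of the phrase.
        lex_f_i = f_i
        lex_f_j = f_i + (f_span_len - 1)
        if nt:
            if nt[0][1] == lex_f_i:                    # NT at left edge
                lex_f_i += (nt[0][2] - nt[0][1]) + 1
            if nt[-1][2] == lex_f_j:                   # NT at right edge
                lex_f_j -= (nt[-1][2] - nt[-1][1]) + 1
        return lex_f_i, lex_f_j

    # Phrase over source 2..5 with X1 covering words 2..3:
    assert lexical_span(2, 4, [[1, 2, 3, 0, 1]]) == (4, 5)
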
@@ -2098,16 +2063,6 @@ cdef class HieroCachingRuleFactory:
logger.info('------------------------------')
logger.info(' Online Stats ')
logger.info('------------------------------')
- logger.info('F')
- for ph in self.phrases_f:
- logger.info(str(ph) + ' ||| ' + str(self.phrases_f[ph]))
- logger.info('E')
- for ph in self.phrases_e:
- logger.info(str(ph) + ' ||| ' + str(self.phrases_e[ph]))
- logger.info('FE')
- for ph in self.phrases_fe:
- for ph2 in self.phrases_fe[ph]:
- logger.info(str(ph) + ' ||| ' + str(ph2) + ' ||| ' + str(self.phrases_fe[ph][ph2]))
logger.info('f')
for w in self.bilex_f:
logger.info(sym_tostring(w) + ' : ' + str(self.bilex_f[w]))
@@ -2118,9 +2073,18 @@ cdef class HieroCachingRuleFactory:
for w in self.bilex_fe:
for w2 in self.bilex_fe[w]:
logger.info(sym_tostring(w) + ' : ' + sym_tostring(w2) + ' : ' + str(self.bilex_fe[w][w2]))
+ logger.info('F')
+ for ph in self.phrases_f:
+ logger.info(str(ph) + ' ||| ' + str(self.phrases_f[ph]))
+ logger.info('E')
+ for ph in self.phrases_e:
+ logger.info(str(ph) + ' ||| ' + str(self.phrases_e[ph]))
+ logger.info('FE')
+ for ph in self.phrases_fe:
+ for ph2 in self.phrases_fe[ph]:
+ logger.info(self.fmt_rule(str(ph), str(ph2), self.phrases_al[ph][ph2]) + ' ||| ' + str(self.phrases_fe[ph][ph2]))
# Spans are _inclusive_ on both ends [i, j]
-# Could be more efficient but probably not a bottleneck
def span_check(vec, i, j):
k = i
while k <= j:
@@ -2129,8 +2093,14 @@ def span_check(vec, i, j):
k += 1
return True
-def span_flip(vec, i, j):
+def span_inc(vec, i, j):
k = i
while k <= j:
- vec[k] = ~vec[k]
- k += 1
\ No newline at end of file
+ vec[k] += 1
+ k += 1
+
+def span_dec(vec, i, j):
+ k = i
+ while k <= j:
+ vec[k] -= 1
+ k += 1
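
Replacing span_flip with paired span_inc/span_dec turns the cover vector into reference counts, so balanced updates always restore the state even when a position is covered more than once (presumably the motivation, since ~-flipping the same cell twice cancels out). A usage sketch of the helpers defined above (toy vector):

    cover = [0] * 5
    span_inc(cover, 1, 3)
    assert not span_check(cover, 2, 4)  # words 2..3 already covered
    assert span_check(cover, 0, 0)      # word 0 still free
    span_dec(cover, 1, 3)
    assert cover == [0] * 5             # balanced inc/dec restores state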