summaryrefslogtreecommitdiff
path: root/python/src/sa/rulefactory.pxi
diff options
context:
space:
mode:
Diffstat (limited to 'python/src/sa/rulefactory.pxi')
-rw-r--r--python/src/sa/rulefactory.pxi311
1 files changed, 307 insertions, 4 deletions
diff --git a/python/src/sa/rulefactory.pxi b/python/src/sa/rulefactory.pxi
index a0bda793..81ea7960 100644
--- a/python/src/sa/rulefactory.pxi
+++ b/python/src/sa/rulefactory.pxi
@@ -264,6 +264,14 @@ cdef class HieroCachingRuleFactory:
cdef IntList findexes
cdef IntList findexes1
+ cdef phrases_f
+ cdef phrases_e
+ cdef phrases_fe
+ cdef phrases_al
+ cdef bilex_f
+ cdef bilex_e
+ cdef bilex_fe
+
def __cinit__(self,
# compiled alignment object (REQUIRED)
Alignment alignment,
@@ -370,6 +378,19 @@ cdef class HieroCachingRuleFactory:
self.findexes = IntList(initial_len=10)
self.findexes1 = IntList(initial_len=10)
+
+ # Online stats
+
+ # Phrase counts
+ self.phrases_f = defaultdict(int)
+ self.phrases_e = defaultdict(int)
+ self.phrases_fe = defaultdict(lambda: defaultdict(int))
+ self.phrases_al = defaultdict(dict)
+
+ # Bilexical counts
+ self.bilex_f = defaultdict(int)
+ self.bilex_e = defaultdict(int)
+ self.bilex_fe = defaultdict(lambda: defaultdict(int))
def configure(self, SuffixArray fsarray, DataArray edarray,
Sampler sampler, Scorer scorer):
@@ -1799,8 +1820,290 @@ cdef class HieroCachingRuleFactory:
return extracts
+ # Aggregate stats from a training instance:
+ # Extract hierarchical phrase pairs
+ # Update bilexical counts
def add_instance(self, f_words, e_words, alignment):
- logger.info("I would add:")
- logger.info(decode_words(f_words))
- logger.info(decode_words(e_words))
- logger.info(alignment) \ No newline at end of file
+
+ # Bilexical counts
+ self.aggr_bilex(f_words, e_words)
+
+ # Rules extracted from this instance
+ rules = set()
+
+ f_len = len(f_words)
+ e_len = len(e_words)
+
+ # Pre-compute alignment info
+ al = [[] for i in range(f_len)]
+ al_span = [[f_len + 1, -1] for i in range(f_len)]
+ for (f, e) in alignment:
+ al[f].append(e)
+ al_span[f][0] = min(al_span[f][0], e)
+ al_span[f][1] = max(al_span[f][1], e)
+
+ # Target side word coverage
+ # TODO: Does Cython do bit vectors?
+ cover = [0] * e_len
+
+ # Extract all possible hierarchical phrases starting at a source index
+ # f_ i and j are current, e_ i and j are previous
+ def extract(f_i, f_j, e_i, e_j, wc, links, nt, nt_open):
+ # Phrase extraction limits
+ if wc + len(nt) > self.max_length or (f_j + 1) > f_len or \
+ (f_j - f_i) + 1 > self.max_initial_size:
+ return
+ # Unaligned word
+ if not al[f_j]:
+ # Open non-terminal: extend
+ if nt_open:
+ nt[-1][2] += 1
+ extract(f_i, f_j + 1, e_i, e_j, wc, links, nt, True)
+ nt[-1][2] -= 1
+ # No open non-terminal: extend with word
+ else:
+ extract(f_i, f_j + 1, e_i, e_j, wc + 1, links, nt, False)
+ return
+ # Aligned word
+ link_i = al_span[f_j][0]
+ link_j = al_span[f_j][1]
+ new_e_i = min(link_i, e_i)
+ new_e_j = max(link_j, e_j)
+ # Open non-terminal: close, extract, extend
+ if nt_open:
+ # Close non-terminal, checking for collisions
+ old_last_nt = nt[-1][:]
+ nt[-1][2] = f_j
+ if link_i < nt[-1][3]:
+ if not span_check(cover, link_i, nt[-1][3] - 1):
+ nt[-1] = old_last_nt
+ return
+ span_flip(cover, link_i, nt[-1][3] - 1)
+ nt[-1][3] = link_i
+ if link_j > nt[-1][4]:
+ if not span_check(cover, nt[-1][4] + 1, link_j):
+ nt[-1] = old_last_nt
+ return
+ span_flip(cover, nt[-1][4] + 1, link_j)
+ nt[-1][4] = link_j
+ for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
+ rules.add(rule)
+ extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False)
+ nt[-1] = old_last_nt
+ if link_i < nt[-1][3]:
+ span_flip(cover, link_i, nt[-1][3] - 1)
+ if link_j > nt[-1][4]:
+ span_flip(cover, nt[-1][4] + 1, link_j)
+ return
+ # No open non-terminal
+ # Extract, extend with word
+ collision = False
+ for link in al[f_j]:
+ if cover[link]:
+ collision = True
+ # Collisions block extraction and extension, but may be okay for
+ # continuing non-terminals
+ if not collision:
+ plus_links = []
+ for link in al[f_j]:
+ plus_links.append((f_j, link))
+ cover[link] = ~cover[link]
+ links.append(plus_links)
+ for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
+ rules.add(rule)
+ extract(f_i, f_j + 1, new_e_i, new_e_j, wc + 1, links, nt, False)
+ links.pop()
+ for link in al[f_j]:
+ cover[link] = ~cover[link]
+ # Try to add a word to a (closed) non-terminal, extract, extend
+ if nt and nt[-1][2] == f_j - 1:
+ # Add to non-terminal, checking for collisions
+ old_last_nt = nt[-1][:]
+ nt[-1][2] = f_j
+ if link_i < nt[-1][3]:
+ if not span_check(cover, link_i, nt[-1][3] - 1):
+ nt[-1] = old_last_nt
+ return
+ span_flip(cover, link_i, nt[-1][3] - 1)
+ nt[-1][3] = link_i
+ if link_j > nt[-1][4]:
+ if not span_check(cover, nt[-1][4] + 1, link_j):
+ nt[-1] = old_last_nt
+ return
+ span_flip(cover, nt[-1][4] + 1, link_j)
+ nt[-1][4] = link_j
+ # Require at least one word in phrase
+ if links:
+ for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
+ rules.add(rule)
+ extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False)
+ nt[-1] = old_last_nt
+ if new_e_i < nt[-1][3]:
+ span_flip(cover, link_i, nt[-1][3] - 1)
+ if link_j > nt[-1][4]:
+ span_flip(cover, nt[-1][4] + 1, link_j)
+ # Try to start a new non-terminal, extract, extend
+ if (not nt or f_j - nt[-1][2] > 1) and len(nt) < self.max_nonterminals:
+ # Check for collisions
+ if not span_check(cover, link_i, link_j):
+ return
+ span_flip(cover, link_i, link_j)
+ nt.append([(nt[-1][0] + 1) if nt else 1, f_j, f_j, link_i, link_j])
+ # Require at least one word in phrase
+ if links:
+ for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
+ rules.add(rule)
+ extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False)
+ nt.pop()
+ span_flip(cover, link_i, link_j)
+ # TODO: try adding NT to start, end, both
+ # check: one aligned word on boundary that is not part of a NT
+
+ # Try to extract phrases from every f index
+ f_i = 0
+ while f_i < f_len:
+ # Skip if phrases won't be tight on left side
+ if not al[f_i]:
+ f_i += 1
+ continue
+ extract(f_i, f_i, f_len + 1, -1, 1, [], [], False)
+ f_i += 1
+
+ for rule in sorted(rules):
+ logger.info(rule)
+
+ # Aggregate bilexical counts
+ def aggr_bilex(self, f_words, e_words):
+
+ for e_w in e_words:
+ self.bilex_e[e_w] += 1
+
+ for f_w in f_words:
+ self.bilex_f[f_w] += 1
+ for e_w in e_words:
+ self.bilex_fe[f_w][e_w] += 1
+
+ # Create a rule from source, target, non-terminals, and alignments
+ def form_rules(self, f_i, e_i, f_span, e_span, nt, al):
+
+ # This could be more efficient but is unlikely to be the bottleneck
+
+ rules = []
+
+ nt_inv = sorted(nt, cmp=lambda x, y: cmp(x[3], y[3]))
+
+ logger.info(nt)
+
+ f_sym = list(f_span[:])
+ off = f_i
+ for next_nt in nt:
+ nt_len = (next_nt[2] - next_nt[1]) + 1
+ i = 0
+ while i < nt_len:
+ f_sym.pop(next_nt[1] - off)
+ i += 1
+ f_sym.insert(next_nt[1] - off, sym_setindex(self.category, next_nt[0]))
+ off += (nt_len - 1)
+
+ e_sym = list(e_span[:])
+ off = e_i
+ for next_nt in nt_inv:
+ nt_len = (next_nt[4] - next_nt[3]) + 1
+ i = 0
+ while i < nt_len:
+ e_sym.pop(next_nt[3] - off)
+ i += 1
+ e_sym.insert(next_nt[3] - off, sym_setindex(self.category, next_nt[0]))
+ off += (nt_len - 1)
+
+ # Adjusting alignment links takes some doing
+ links = [list(link) for sub in al for link in sub]
+ links_len = len(links)
+ nt_len = len(nt)
+ nt_i = 0
+ off = f_i
+ i = 0
+ while i < links_len:
+ while nt_i < nt_len and links[i][0] > nt[nt_i][1]:
+ off += (nt[nt_i][2] - nt[nt_i][1])
+ nt_i += 1
+ links[i][0] -= off
+ i += 1
+ nt_i = 0
+ off = e_i
+ i = 0
+ while i < links_len:
+ while nt_i < nt_len and links[i][1] > nt_inv[nt_i][3]:
+ off += (nt_inv[nt_i][4] - nt_inv[nt_i][3])
+ nt_i += 1
+ links[i][1] -= off
+ i += 1
+
+ # Rule
+ rules.append(fmt_rule(f_sym, e_sym, links))
+ if len(f_sym) >= self.max_length or len(nt) >= self.max_nonterminals:
+ return rules
+ last_index = nt[-1][0] if nt else 0
+ # Rule [X]
+ if not nt or not sym_isvar(f_sym[-1]):
+ f_sym.append(sym_setindex(self.category, last_index + 1))
+ e_sym.append(sym_setindex(self.category, last_index + 1))
+ rules.append(fmt_rule(f_sym, e_sym, links))
+ f_sym.pop()
+ e_sym.pop()
+ # [X] Rule
+ f_len = len(f_sym)
+ e_len = len(e_sym)
+ if not nt or not sym_isvar(f_sym[0]):
+ for i from 0 <= i < f_len:
+ if sym_isvar(f_sym[i]):
+ f_sym[i] = sym_setindex(self.category, sym_getindex(f_sym[i]) + 1)
+ for i from 0 <= i < e_len:
+ if sym_isvar(e_sym[i]):
+ e_sym[i] = sym_setindex(self.category, sym_getindex(e_sym[i]) + 1)
+ for link in links:
+ link[0] += 1
+ link[1] += 1
+ f_sym.insert(0, sym_setindex(self.category, 1))
+ e_sym.insert(0, sym_setindex(self.category, 1))
+ rules.append(fmt_rule(f_sym, e_sym, links))
+ if len(f_sym) >= self.max_length or len(nt) + 1 >= self.max_nonterminals:
+ return rules
+ # [X] Rule [X]
+ if not nt or not sym_isvar(f_sym[-1]):
+ f_sym.append(sym_setindex(self.category, last_index + 2))
+ e_sym.append(sym_setindex(self.category, last_index + 2))
+ rules.append(fmt_rule(f_sym, e_sym, links))
+ return rules
+
+ # Debugging
+ def dump_online_stats(self):
+ logger.info(self.phrases_f)
+ logger.info(self.phrases_e)
+ logger.info(self.phrases_fe)
+
+# Spans are _inclusive_ on both ends [i, j]
+# TODO: Replace all of this with bit vectors?
+def span_check(vec, i, j):
+ k = i
+ while k <= j:
+ if vec[k]:
+ return False
+ k += 1
+ return True
+
+def span_flip(vec, i, j):
+ k = i
+ while k <= j:
+ vec[k] = ~vec[k]
+ k += 1
+
+def fmt_rule(f_sym, e_sym, links):
+ a_str = ' '.join('{0}-{1}'.format(i, j) for (i, j) in links)
+ return '[X] ||| {0} ||| {1} ||| {2}'.format(' '.join(sym_tostring(sym) for sym in f_sym),
+ ' '.join(sym_tostring(sym) for sym in e_sym),
+ a_str)
+
+ #(f, e, count, als) = e
+ #a = tuple('{0}-{1}'.format(packed/65536, packed%65536) for packed in als)
+ #logger.info("f: {0}, e: {1}, count: {2}, a: {3}".format(f, e, count, a))