summaryrefslogtreecommitdiff
path: root/python/src/sa/rulefactory.pxi
diff options
context:
space:
mode:
Diffstat (limited to 'python/src/sa/rulefactory.pxi')
-rw-r--r--python/src/sa/rulefactory.pxi424
1 files changed, 421 insertions, 3 deletions
diff --git a/python/src/sa/rulefactory.pxi b/python/src/sa/rulefactory.pxi
index 2d996581..d7fca750 100644
--- a/python/src/sa/rulefactory.pxi
+++ b/python/src/sa/rulefactory.pxi
@@ -23,7 +23,17 @@ FeatureContext = namedtuple('FeatureContext',
'test_sentence',
'f_text',
'e_text',
- 'meta'
+ 'meta',
+ 'online'
+ ])
+
+OnlineFeatureContext = namedtuple('OnlineFeatureContext',
+ ['fcount',
+ 'fsample_count',
+ 'paircount',
+ 'bilex_f',
+ 'bilex_e',
+ 'bilex_fe'
])
cdef int PRECOMPUTE = 0
@@ -265,6 +275,16 @@ cdef class HieroCachingRuleFactory:
cdef IntList findexes
cdef IntList findexes1
+ cdef bint online
+ cdef samples_f
+ cdef phrases_f
+ cdef phrases_e
+ cdef phrases_fe
+ cdef phrases_al
+ cdef bilex_f
+ cdef bilex_e
+ cdef bilex_fe
+
def __cinit__(self,
# compiled alignment object (REQUIRED)
Alignment alignment,
@@ -371,6 +391,25 @@ cdef class HieroCachingRuleFactory:
self.findexes = IntList(initial_len=10)
self.findexes1 = IntList(initial_len=10)
+
+ # Online stats
+
+ # True after data is added
+ self.online = False
+
+ # Keep track of everything that can be sampled:
+ self.samples_f = defaultdict(int)
+
+ # Phrase counts
+ self.phrases_f = defaultdict(int)
+ self.phrases_e = defaultdict(int)
+ self.phrases_fe = defaultdict(lambda: defaultdict(int))
+ self.phrases_al = defaultdict(lambda: defaultdict(tuple))
+
+ # Bilexical counts
+ self.bilex_f = defaultdict(int)
+ self.bilex_e = defaultdict(int)
+ self.bilex_fe = defaultdict(lambda: defaultdict(int))
def configure(self, SuffixArray fsarray, DataArray edarray,
Sampler sampler, Scorer scorer):
@@ -950,6 +989,11 @@ cdef class HieroCachingRuleFactory:
hit = 0
reachable_buffer = {}
+ # Phrase pairs processed by suffix array extractor. Do not re-extract
+ # during online extraction. This is probably the hackiest part of
+ # online grammar extraction.
+ seen_phrases = set()
+
# Do not cache between sentences
self.rules.root = ExtendedTrieNode(phrase_location=PhraseLocation())
@@ -1108,7 +1152,12 @@ cdef class HieroCachingRuleFactory:
f, e, count, fcount[f], num_samples,
(k,i+spanlen), locs, input_match,
fwords, self.fda, self.eda,
- meta))
+ meta,
+ # Include online stats. None if none.
+ self.online_ctx_lookup(f, e)))
+ # Phrase pair processed
+ if self.online:
+ seen_phrases.add((f, e))
yield Rule(self.category, f, e, scores, alignment)
if len(phrase) < self.max_length and i+spanlen < len(fwords) and pathlen+1 <= self.max_initial_size:
@@ -1132,7 +1181,29 @@ cdef class HieroCachingRuleFactory:
for (i, alt, pathlen) in frontier_nodes:
new_frontier.append((k, i, input_match + (i,), alt, pathlen, xnode, phrase +(xcat,), is_shadow_path))
frontier = new_frontier
-
+
+ # Online rule extraction and scoring
+ if self.online:
+ f_syms = tuple(word[0][0] for word in fwords)
+ for (f, lex_i, lex_j) in self.get_f_phrases(f_syms):
+ spanlen = (lex_j - lex_i) + 1
+ if not sym_isvar(f[0]):
+ spanlen += 1
+ if not sym_isvar(f[1]):
+ spanlen += 1
+ for e in self.phrases_fe.get(f, ()):
+ if (f, e) not in seen_phrases:
+ # Don't add multiple instances of the same phrase here
+ seen_phrases.add((f, e))
+ scores = self.scorer.score(FeatureContext(
+ f, e, 0, 0, 0,
+ spanlen, None, None,
+ fwords, self.fda, self.eda,
+ meta,
+ self.online_ctx_lookup(f, e)))
+ alignment = self.phrases_al[f][e]
+ yield Rule(self.category, f, e, scores, alignment)
+
stop_time = monitor_cpu()
logger.info("Total time for rule lookup, extraction, and scoring = %f seconds", (stop_time - start_time))
gc.collect()
@@ -1803,3 +1874,350 @@ cdef class HieroCachingRuleFactory:
free(e_gap_high)
return extracts
+
+ #
+ # Online grammar extraction handling
+ #
+
+ # Aggregate stats from a training instance
+ # (Extract rules, update counts)
+ def add_instance(self, f_words, e_words, alignment):
+
+ self.online = True
+
+ # Rules extracted from this instance
+ # Track span of lexical items (terminals) to make
+ # sure we don't extract the same rule for the same
+ # span more than once.
+ # (f, e, al, lex_f_i, lex_f_j)
+ rules = set()
+
+ f_len = len(f_words)
+ e_len = len(e_words)
+
+ # Pre-compute alignment info
+ al = [[] for i in range(f_len)]
+ fe_span = [[e_len + 1, -1] for i in range(f_len)]
+ ef_span = [[f_len + 1, -1] for i in range(e_len)]
+ for (f, e) in alignment:
+ al[f].append(e)
+ fe_span[f][0] = min(fe_span[f][0], e)
+ fe_span[f][1] = max(fe_span[f][1], e)
+ ef_span[e][0] = min(ef_span[e][0], f)
+ ef_span[e][1] = max(ef_span[e][1], f)
+
+ # Target side word coverage
+ cover = [0] * e_len
+ # Non-terminal coverage
+ f_nt_cover = [0] * f_len
+ e_nt_cover = [0] * e_len
+
+ # Extract all possible hierarchical phrases starting at a source index
+ # f_ i and j are current, e_ i and j are previous
+ # We care _considering_ f_j, so it is not yet in counts
+ def extract(f_i, f_j, e_i, e_j, min_bound, wc, links, nt, nt_open):
+ # Phrase extraction limits
+ if f_j > (f_len - 1) or (f_j - f_i) + 1 > self.max_initial_size:
+ return
+ # Unaligned word
+ if not al[f_j]:
+ # Adjacent to non-terminal: extend (non-terminal now open)
+ if nt and nt[-1][2] == f_j - 1:
+ nt[-1][2] += 1
+ extract(f_i, f_j + 1, e_i, e_j, min_bound, wc, links, nt, True)
+ nt[-1][2] -= 1
+ # Unless non-terminal already open, always extend with word
+ # Make sure adding a word doesn't exceed length
+ if not nt_open and wc < self.max_length:
+ extract(f_i, f_j + 1, e_i, e_j, min_bound, wc + 1, links, nt, False)
+ return
+ # Aligned word
+ link_i = fe_span[f_j][0]
+ link_j = fe_span[f_j][1]
+ new_e_i = min(link_i, e_i)
+ new_e_j = max(link_j, e_j)
+ # Check reverse links of newly covered words to see if they violate left
+ # bound (return) or extend minimum right bound for chunk
+ new_min_bound = min_bound
+ # First aligned word creates span
+ if e_j == -1:
+ for i from new_e_i <= i <= new_e_j:
+ if ef_span[i][0] < f_i:
+ return
+ new_min_bound = max(new_min_bound, ef_span[i][1])
+ # Other aligned words extend span
+ else:
+ for i from new_e_i <= i < e_i:
+ if ef_span[i][0] < f_i:
+ return
+ new_min_bound = max(new_min_bound, ef_span[i][1])
+ for i from e_j < i <= new_e_j:
+ if ef_span[i][0] < f_i:
+ return
+ new_min_bound = max(new_min_bound, ef_span[i][1])
+ # Extract, extend with word (unless non-terminal open)
+ if not nt_open:
+ nt_collision = False
+ for link in al[f_j]:
+ if e_nt_cover[link]:
+ nt_collision = True
+ # Non-terminal collisions block word extraction and extension, but
+ # may be okay for continuing non-terminals
+ if not nt_collision and wc < self.max_length:
+ plus_links = []
+ for link in al[f_j]:
+ plus_links.append((f_j, link))
+ cover[link] += 1
+ links.append(plus_links)
+ if links and f_j >= new_min_bound:
+ rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
+ extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc + 1, links, nt, False)
+ links.pop()
+ for link in al[f_j]:
+ cover[link] -= 1
+ # Try to add a word to current non-terminal (if any), extract, extend
+ if nt and nt[-1][2] == f_j - 1:
+ # Add to non-terminal, checking for collisions
+ old_last_nt = nt[-1][:]
+ nt[-1][2] = f_j
+ if link_i < nt[-1][3]:
+ if not span_check(cover, link_i, nt[-1][3] - 1):
+ nt[-1] = old_last_nt
+ return
+ span_inc(cover, link_i, nt[-1][3] - 1)
+ span_inc(e_nt_cover, link_i, nt[-1][3] - 1)
+ nt[-1][3] = link_i
+ if link_j > nt[-1][4]:
+ if not span_check(cover, nt[-1][4] + 1, link_j):
+ nt[-1] = old_last_nt
+ return
+ span_inc(cover, nt[-1][4] + 1, link_j)
+ span_inc(e_nt_cover, nt[-1][4] + 1, link_j)
+ nt[-1][4] = link_j
+ if links and f_j >= new_min_bound:
+ rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
+ extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc, links, nt, False)
+ nt[-1] = old_last_nt
+ if link_i < nt[-1][3]:
+ span_dec(cover, link_i, nt[-1][3] - 1)
+ span_dec(e_nt_cover, link_i, nt[-1][3] - 1)
+ if link_j > nt[-1][4]:
+ span_dec(cover, nt[-1][4] + 1, link_j)
+ span_dec(e_nt_cover, nt[-1][4] + 1, link_j)
+ # Try to start a new non-terminal, extract, extend
+ if (not nt or f_j - nt[-1][2] > 1) and wc < self.max_length and len(nt) < self.max_nonterminals:
+ # Check for collisions
+ if not span_check(cover, link_i, link_j):
+ return
+ span_inc(cover, link_i, link_j)
+ span_inc(e_nt_cover, link_i, link_j)
+ nt.append([(nt[-1][0] + 1) if nt else 1, f_j, f_j, link_i, link_j])
+ # Require at least one word in phrase
+ if links and f_j >= new_min_bound:
+ rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
+ extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc + 1, links, nt, False)
+ nt.pop()
+ span_dec(cover, link_i, link_j)
+ span_dec(e_nt_cover, link_i, link_j)
+
+ # Try to extract phrases from every f index
+ for f_i from 0 <= f_i < f_len:
+ # Skip if phrases won't be tight on left side
+ if not al[f_i]:
+ continue
+ extract(f_i, f_i, f_len + 1, -1, f_i, 0, [], [], False)
+
+ # Update possible phrases (samples)
+ # This could be more efficiently integrated with extraction
+ # at the cost of readability
+ for (f, lex_i, lex_j) in self.get_f_phrases(f_words):
+ self.samples_f[f] += 1
+
+ # Update phrase counts
+ for rule in rules:
+ (f_ph, e_ph, al) = rule[:3]
+ self.phrases_f[f_ph] += 1
+ self.phrases_e[e_ph] += 1
+ self.phrases_fe[f_ph][e_ph] += 1
+ if not self.phrases_al[f_ph][e_ph]:
+ self.phrases_al[f_ph][e_ph] = al
+
+ # Update Bilexical counts
+ for e_w in e_words:
+ self.bilex_e[e_w] += 1
+ for f_w in f_words:
+ self.bilex_f[f_w] += 1
+ for e_w in e_words:
+ self.bilex_fe[f_w][e_w] += 1
+
+ # Create a rule from source, target, non-terminals, and alignments
+ def form_rule(self, f_i, e_i, f_span, e_span, nt, al):
+
+ # Substitute in non-terminals
+ nt_inv = sorted(nt, cmp=lambda x, y: cmp(x[3], y[3]))
+ f_sym = list(f_span[:])
+ off = f_i
+ for next_nt in nt:
+ nt_len = (next_nt[2] - next_nt[1]) + 1
+ i = 0
+ while i < nt_len:
+ f_sym.pop(next_nt[1] - off)
+ i += 1
+ f_sym.insert(next_nt[1] - off, sym_setindex(self.category, next_nt[0]))
+ off += (nt_len - 1)
+ e_sym = list(e_span[:])
+ off = e_i
+ for next_nt in nt_inv:
+ nt_len = (next_nt[4] - next_nt[3]) + 1
+ i = 0
+ while i < nt_len:
+ e_sym.pop(next_nt[3] - off)
+ i += 1
+ e_sym.insert(next_nt[3] - off, sym_setindex(self.category, next_nt[0]))
+ off += (nt_len - 1)
+
+ # Adjusting alignment links takes some doing
+ links = [list(link) for sub in al for link in sub]
+ links_inv = sorted(links, cmp=lambda x, y: cmp(x[1], y[1]))
+ links_len = len(links)
+ nt_len = len(nt)
+ nt_i = 0
+ off = f_i
+ i = 0
+ while i < links_len:
+ while nt_i < nt_len and links[i][0] > nt[nt_i][1]:
+ off += (nt[nt_i][2] - nt[nt_i][1])
+ nt_i += 1
+ links[i][0] -= off
+ i += 1
+ nt_i = 0
+ off = e_i
+ i = 0
+ while i < links_len:
+ while nt_i < nt_len and links_inv[i][1] > nt_inv[nt_i][3]:
+ off += (nt_inv[nt_i][4] - nt_inv[nt_i][3])
+ nt_i += 1
+ links_inv[i][1] -= off
+ i += 1
+
+ # Find lexical span
+ lex_f_i = f_i
+ lex_f_j = f_i + (len(f_span) - 1)
+ if nt:
+ if nt[0][1] == lex_f_i:
+ lex_f_i += (nt[0][2] - nt[0][1]) + 1
+ if nt[-1][2] == lex_f_j:
+ lex_f_j -= (nt[-1][2] - nt[-1][1]) + 1
+
+ # Create rule (f_phrase, e_phrase, links, f_link_min, f_link_max)
+ f = Phrase(f_sym)
+ e = Phrase(e_sym)
+ a = tuple(self.alignment.link(i, j) for (i, j) in links)
+ return (f, e, a, lex_f_i, lex_f_j)
+
+ # Rule string from rule
+ def fmt_rule(self, f, e, a):
+ a_str = ' '.join('{0}-{1}'.format(*self.alignment.unlink(packed)) for packed in a)
+ return '[X] ||| {0} ||| {1} ||| {2}'.format(f, e, a_str)
+
+ # Debugging
+ def dump_online_stats(self):
+ logger.info('------------------------------')
+ logger.info(' Online Stats ')
+ logger.info('------------------------------')
+ logger.info('f')
+ for w in self.bilex_f:
+ logger.info(sym_tostring(w) + ' : ' + str(self.bilex_f[w]))
+ logger.info('e')
+ for w in self.bilex_e:
+ logger.info(sym_tostring(w) + ' : ' + str(self.bilex_e[w]))
+ logger.info('fe')
+ for w in self.bilex_fe:
+ for w2 in self.bilex_fe[w]:
+ logger.info(sym_tostring(w) + ' : ' + sym_tostring(w2) + ' : ' + str(self.bilex_fe[w][w2]))
+ logger.info('F')
+ for ph in self.phrases_f:
+ logger.info(str(ph) + ' ||| ' + str(self.phrases_f[ph]))
+ logger.info('E')
+ for ph in self.phrases_e:
+ logger.info(str(ph) + ' ||| ' + str(self.phrases_e[ph]))
+ logger.info('FE')
+ self.dump_online_rules()
+
+ def dump_online_rules(self):
+ for ph in self.phrases_fe:
+ for ph2 in self.phrases_fe[ph]:
+ logger.info(self.fmt_rule(str(ph), str(ph2), self.phrases_al[ph][ph2]) + ' ||| ' + str(self.phrases_fe[ph][ph2]))
+
+ # Lookup online stats for phrase pair (f, e). Return None if no match.
+ # IMPORTANT: use get() to avoid adding items to defaultdict
+ def online_ctx_lookup(self, f, e):
+ if self.online:
+ fcount = self.phrases_f.get(f, 0)
+ fsample_count = self.samples_f.get(f, 0)
+ d = self.phrases_fe.get(f, None)
+ paircount = d.get(e, 0) if d else 0
+ return OnlineFeatureContext(fcount, fsample_count, paircount, self.bilex_f, self.bilex_e, self.bilex_fe)
+ return None
+
+ # Find all phrases that we might try to extract
+ # (Used for EGivenFCoherent)
+ # Return set of (fphrase, lex_i, lex_j)
+ def get_f_phrases(self, f_words):
+
+ f_len = len(f_words)
+ phrases = set() # (fphrase, lex_i, lex_j)
+
+ def extract(f_i, f_j, lex_i, lex_j, wc, ntc, syms):
+ # Phrase extraction limits
+ if f_j > (f_len - 1) or (f_j - f_i) + 1 > self.max_initial_size:
+ return
+ # Extend with word
+ if wc + ntc < self.max_length:
+ syms.append(f_words[f_j])
+ f = Phrase(syms)
+ new_lex_i = min(lex_i, f_j)
+ new_lex_j = max(lex_j, f_j)
+ phrases.add((f, new_lex_i, new_lex_j))
+ extract(f_i, f_j + 1, new_lex_i, new_lex_j, wc + 1, ntc, syms)
+ syms.pop()
+ # Extend with existing non-terminal
+ if syms and sym_isvar(syms[-1]):
+ # Don't re-extract the same phrase
+ extract(f_i, f_j + 1, lex_i, lex_j, wc, ntc, syms)
+ # Extend with new non-terminal
+ if wc + ntc < self.max_length:
+ if not syms or (ntc < self.max_nonterminals and not sym_isvar(syms[-1])):
+ syms.append(sym_setindex(self.category, ntc + 1))
+ f = Phrase(syms)
+ if wc > 0:
+ phrases.add((f, lex_i, lex_j))
+ extract(f_i, f_j + 1, lex_i, lex_j, wc, ntc + 1, syms)
+ syms.pop()
+
+ # Try to extract phrases from every f index
+ for f_i from 0 <= f_i < f_len:
+ extract(f_i, f_i, f_len, -1, 0, 0, [])
+
+ return phrases
+
+# Spans are _inclusive_ on both ends [i, j]
+def span_check(vec, i, j):
+ k = i
+ while k <= j:
+ if vec[k]:
+ return False
+ k += 1
+ return True
+
+def span_inc(vec, i, j):
+ k = i
+ while k <= j:
+ vec[k] += 1
+ k += 1
+
+def span_dec(vec, i, j):
+ k = i
+ while k <= j:
+ vec[k] -= 1
+ k += 1