diff options
Diffstat (limited to 'python/src/sa/rulefactory.pxi')
-rw-r--r-- | python/src/sa/rulefactory.pxi | 424 |
1 files changed, 421 insertions, 3 deletions
diff --git a/python/src/sa/rulefactory.pxi b/python/src/sa/rulefactory.pxi index 2d996581..d7fca750 100644 --- a/python/src/sa/rulefactory.pxi +++ b/python/src/sa/rulefactory.pxi @@ -23,7 +23,17 @@ FeatureContext = namedtuple('FeatureContext', 'test_sentence', 'f_text', 'e_text', - 'meta' + 'meta', + 'online' + ]) + +OnlineFeatureContext = namedtuple('OnlineFeatureContext', + ['fcount', + 'fsample_count', + 'paircount', + 'bilex_f', + 'bilex_e', + 'bilex_fe' ]) cdef int PRECOMPUTE = 0 @@ -265,6 +275,16 @@ cdef class HieroCachingRuleFactory: cdef IntList findexes cdef IntList findexes1 + cdef bint online + cdef samples_f + cdef phrases_f + cdef phrases_e + cdef phrases_fe + cdef phrases_al + cdef bilex_f + cdef bilex_e + cdef bilex_fe + def __cinit__(self, # compiled alignment object (REQUIRED) Alignment alignment, @@ -371,6 +391,25 @@ cdef class HieroCachingRuleFactory: self.findexes = IntList(initial_len=10) self.findexes1 = IntList(initial_len=10) + + # Online stats + + # True after data is added + self.online = False + + # Keep track of everything that can be sampled: + self.samples_f = defaultdict(int) + + # Phrase counts + self.phrases_f = defaultdict(int) + self.phrases_e = defaultdict(int) + self.phrases_fe = defaultdict(lambda: defaultdict(int)) + self.phrases_al = defaultdict(lambda: defaultdict(tuple)) + + # Bilexical counts + self.bilex_f = defaultdict(int) + self.bilex_e = defaultdict(int) + self.bilex_fe = defaultdict(lambda: defaultdict(int)) def configure(self, SuffixArray fsarray, DataArray edarray, Sampler sampler, Scorer scorer): @@ -950,6 +989,11 @@ cdef class HieroCachingRuleFactory: hit = 0 reachable_buffer = {} + # Phrase pairs processed by suffix array extractor. Do not re-extract + # during online extraction. This is probably the hackiest part of + # online grammar extraction. + seen_phrases = set() + # Do not cache between sentences self.rules.root = ExtendedTrieNode(phrase_location=PhraseLocation()) @@ -1108,7 +1152,12 @@ cdef class HieroCachingRuleFactory: f, e, count, fcount[f], num_samples, (k,i+spanlen), locs, input_match, fwords, self.fda, self.eda, - meta)) + meta, + # Include online stats. None if none. + self.online_ctx_lookup(f, e))) + # Phrase pair processed + if self.online: + seen_phrases.add((f, e)) yield Rule(self.category, f, e, scores, alignment) if len(phrase) < self.max_length and i+spanlen < len(fwords) and pathlen+1 <= self.max_initial_size: @@ -1132,7 +1181,29 @@ cdef class HieroCachingRuleFactory: for (i, alt, pathlen) in frontier_nodes: new_frontier.append((k, i, input_match + (i,), alt, pathlen, xnode, phrase +(xcat,), is_shadow_path)) frontier = new_frontier - + + # Online rule extraction and scoring + if self.online: + f_syms = tuple(word[0][0] for word in fwords) + for (f, lex_i, lex_j) in self.get_f_phrases(f_syms): + spanlen = (lex_j - lex_i) + 1 + if not sym_isvar(f[0]): + spanlen += 1 + if not sym_isvar(f[1]): + spanlen += 1 + for e in self.phrases_fe.get(f, ()): + if (f, e) not in seen_phrases: + # Don't add multiple instances of the same phrase here + seen_phrases.add((f, e)) + scores = self.scorer.score(FeatureContext( + f, e, 0, 0, 0, + spanlen, None, None, + fwords, self.fda, self.eda, + meta, + self.online_ctx_lookup(f, e))) + alignment = self.phrases_al[f][e] + yield Rule(self.category, f, e, scores, alignment) + stop_time = monitor_cpu() logger.info("Total time for rule lookup, extraction, and scoring = %f seconds", (stop_time - start_time)) gc.collect() @@ -1803,3 +1874,350 @@ cdef class HieroCachingRuleFactory: free(e_gap_high) return extracts + + # + # Online grammar extraction handling + # + + # Aggregate stats from a training instance + # (Extract rules, update counts) + def add_instance(self, f_words, e_words, alignment): + + self.online = True + + # Rules extracted from this instance + # Track span of lexical items (terminals) to make + # sure we don't extract the same rule for the same + # span more than once. + # (f, e, al, lex_f_i, lex_f_j) + rules = set() + + f_len = len(f_words) + e_len = len(e_words) + + # Pre-compute alignment info + al = [[] for i in range(f_len)] + fe_span = [[e_len + 1, -1] for i in range(f_len)] + ef_span = [[f_len + 1, -1] for i in range(e_len)] + for (f, e) in alignment: + al[f].append(e) + fe_span[f][0] = min(fe_span[f][0], e) + fe_span[f][1] = max(fe_span[f][1], e) + ef_span[e][0] = min(ef_span[e][0], f) + ef_span[e][1] = max(ef_span[e][1], f) + + # Target side word coverage + cover = [0] * e_len + # Non-terminal coverage + f_nt_cover = [0] * f_len + e_nt_cover = [0] * e_len + + # Extract all possible hierarchical phrases starting at a source index + # f_ i and j are current, e_ i and j are previous + # We care _considering_ f_j, so it is not yet in counts + def extract(f_i, f_j, e_i, e_j, min_bound, wc, links, nt, nt_open): + # Phrase extraction limits + if f_j > (f_len - 1) or (f_j - f_i) + 1 > self.max_initial_size: + return + # Unaligned word + if not al[f_j]: + # Adjacent to non-terminal: extend (non-terminal now open) + if nt and nt[-1][2] == f_j - 1: + nt[-1][2] += 1 + extract(f_i, f_j + 1, e_i, e_j, min_bound, wc, links, nt, True) + nt[-1][2] -= 1 + # Unless non-terminal already open, always extend with word + # Make sure adding a word doesn't exceed length + if not nt_open and wc < self.max_length: + extract(f_i, f_j + 1, e_i, e_j, min_bound, wc + 1, links, nt, False) + return + # Aligned word + link_i = fe_span[f_j][0] + link_j = fe_span[f_j][1] + new_e_i = min(link_i, e_i) + new_e_j = max(link_j, e_j) + # Check reverse links of newly covered words to see if they violate left + # bound (return) or extend minimum right bound for chunk + new_min_bound = min_bound + # First aligned word creates span + if e_j == -1: + for i from new_e_i <= i <= new_e_j: + if ef_span[i][0] < f_i: + return + new_min_bound = max(new_min_bound, ef_span[i][1]) + # Other aligned words extend span + else: + for i from new_e_i <= i < e_i: + if ef_span[i][0] < f_i: + return + new_min_bound = max(new_min_bound, ef_span[i][1]) + for i from e_j < i <= new_e_j: + if ef_span[i][0] < f_i: + return + new_min_bound = max(new_min_bound, ef_span[i][1]) + # Extract, extend with word (unless non-terminal open) + if not nt_open: + nt_collision = False + for link in al[f_j]: + if e_nt_cover[link]: + nt_collision = True + # Non-terminal collisions block word extraction and extension, but + # may be okay for continuing non-terminals + if not nt_collision and wc < self.max_length: + plus_links = [] + for link in al[f_j]: + plus_links.append((f_j, link)) + cover[link] += 1 + links.append(plus_links) + if links and f_j >= new_min_bound: + rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links)) + extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc + 1, links, nt, False) + links.pop() + for link in al[f_j]: + cover[link] -= 1 + # Try to add a word to current non-terminal (if any), extract, extend + if nt and nt[-1][2] == f_j - 1: + # Add to non-terminal, checking for collisions + old_last_nt = nt[-1][:] + nt[-1][2] = f_j + if link_i < nt[-1][3]: + if not span_check(cover, link_i, nt[-1][3] - 1): + nt[-1] = old_last_nt + return + span_inc(cover, link_i, nt[-1][3] - 1) + span_inc(e_nt_cover, link_i, nt[-1][3] - 1) + nt[-1][3] = link_i + if link_j > nt[-1][4]: + if not span_check(cover, nt[-1][4] + 1, link_j): + nt[-1] = old_last_nt + return + span_inc(cover, nt[-1][4] + 1, link_j) + span_inc(e_nt_cover, nt[-1][4] + 1, link_j) + nt[-1][4] = link_j + if links and f_j >= new_min_bound: + rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links)) + extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc, links, nt, False) + nt[-1] = old_last_nt + if link_i < nt[-1][3]: + span_dec(cover, link_i, nt[-1][3] - 1) + span_dec(e_nt_cover, link_i, nt[-1][3] - 1) + if link_j > nt[-1][4]: + span_dec(cover, nt[-1][4] + 1, link_j) + span_dec(e_nt_cover, nt[-1][4] + 1, link_j) + # Try to start a new non-terminal, extract, extend + if (not nt or f_j - nt[-1][2] > 1) and wc < self.max_length and len(nt) < self.max_nonterminals: + # Check for collisions + if not span_check(cover, link_i, link_j): + return + span_inc(cover, link_i, link_j) + span_inc(e_nt_cover, link_i, link_j) + nt.append([(nt[-1][0] + 1) if nt else 1, f_j, f_j, link_i, link_j]) + # Require at least one word in phrase + if links and f_j >= new_min_bound: + rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links)) + extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc + 1, links, nt, False) + nt.pop() + span_dec(cover, link_i, link_j) + span_dec(e_nt_cover, link_i, link_j) + + # Try to extract phrases from every f index + for f_i from 0 <= f_i < f_len: + # Skip if phrases won't be tight on left side + if not al[f_i]: + continue + extract(f_i, f_i, f_len + 1, -1, f_i, 0, [], [], False) + + # Update possible phrases (samples) + # This could be more efficiently integrated with extraction + # at the cost of readability + for (f, lex_i, lex_j) in self.get_f_phrases(f_words): + self.samples_f[f] += 1 + + # Update phrase counts + for rule in rules: + (f_ph, e_ph, al) = rule[:3] + self.phrases_f[f_ph] += 1 + self.phrases_e[e_ph] += 1 + self.phrases_fe[f_ph][e_ph] += 1 + if not self.phrases_al[f_ph][e_ph]: + self.phrases_al[f_ph][e_ph] = al + + # Update Bilexical counts + for e_w in e_words: + self.bilex_e[e_w] += 1 + for f_w in f_words: + self.bilex_f[f_w] += 1 + for e_w in e_words: + self.bilex_fe[f_w][e_w] += 1 + + # Create a rule from source, target, non-terminals, and alignments + def form_rule(self, f_i, e_i, f_span, e_span, nt, al): + + # Substitute in non-terminals + nt_inv = sorted(nt, cmp=lambda x, y: cmp(x[3], y[3])) + f_sym = list(f_span[:]) + off = f_i + for next_nt in nt: + nt_len = (next_nt[2] - next_nt[1]) + 1 + i = 0 + while i < nt_len: + f_sym.pop(next_nt[1] - off) + i += 1 + f_sym.insert(next_nt[1] - off, sym_setindex(self.category, next_nt[0])) + off += (nt_len - 1) + e_sym = list(e_span[:]) + off = e_i + for next_nt in nt_inv: + nt_len = (next_nt[4] - next_nt[3]) + 1 + i = 0 + while i < nt_len: + e_sym.pop(next_nt[3] - off) + i += 1 + e_sym.insert(next_nt[3] - off, sym_setindex(self.category, next_nt[0])) + off += (nt_len - 1) + + # Adjusting alignment links takes some doing + links = [list(link) for sub in al for link in sub] + links_inv = sorted(links, cmp=lambda x, y: cmp(x[1], y[1])) + links_len = len(links) + nt_len = len(nt) + nt_i = 0 + off = f_i + i = 0 + while i < links_len: + while nt_i < nt_len and links[i][0] > nt[nt_i][1]: + off += (nt[nt_i][2] - nt[nt_i][1]) + nt_i += 1 + links[i][0] -= off + i += 1 + nt_i = 0 + off = e_i + i = 0 + while i < links_len: + while nt_i < nt_len and links_inv[i][1] > nt_inv[nt_i][3]: + off += (nt_inv[nt_i][4] - nt_inv[nt_i][3]) + nt_i += 1 + links_inv[i][1] -= off + i += 1 + + # Find lexical span + lex_f_i = f_i + lex_f_j = f_i + (len(f_span) - 1) + if nt: + if nt[0][1] == lex_f_i: + lex_f_i += (nt[0][2] - nt[0][1]) + 1 + if nt[-1][2] == lex_f_j: + lex_f_j -= (nt[-1][2] - nt[-1][1]) + 1 + + # Create rule (f_phrase, e_phrase, links, f_link_min, f_link_max) + f = Phrase(f_sym) + e = Phrase(e_sym) + a = tuple(self.alignment.link(i, j) for (i, j) in links) + return (f, e, a, lex_f_i, lex_f_j) + + # Rule string from rule + def fmt_rule(self, f, e, a): + a_str = ' '.join('{0}-{1}'.format(*self.alignment.unlink(packed)) for packed in a) + return '[X] ||| {0} ||| {1} ||| {2}'.format(f, e, a_str) + + # Debugging + def dump_online_stats(self): + logger.info('------------------------------') + logger.info(' Online Stats ') + logger.info('------------------------------') + logger.info('f') + for w in self.bilex_f: + logger.info(sym_tostring(w) + ' : ' + str(self.bilex_f[w])) + logger.info('e') + for w in self.bilex_e: + logger.info(sym_tostring(w) + ' : ' + str(self.bilex_e[w])) + logger.info('fe') + for w in self.bilex_fe: + for w2 in self.bilex_fe[w]: + logger.info(sym_tostring(w) + ' : ' + sym_tostring(w2) + ' : ' + str(self.bilex_fe[w][w2])) + logger.info('F') + for ph in self.phrases_f: + logger.info(str(ph) + ' ||| ' + str(self.phrases_f[ph])) + logger.info('E') + for ph in self.phrases_e: + logger.info(str(ph) + ' ||| ' + str(self.phrases_e[ph])) + logger.info('FE') + self.dump_online_rules() + + def dump_online_rules(self): + for ph in self.phrases_fe: + for ph2 in self.phrases_fe[ph]: + logger.info(self.fmt_rule(str(ph), str(ph2), self.phrases_al[ph][ph2]) + ' ||| ' + str(self.phrases_fe[ph][ph2])) + + # Lookup online stats for phrase pair (f, e). Return None if no match. + # IMPORTANT: use get() to avoid adding items to defaultdict + def online_ctx_lookup(self, f, e): + if self.online: + fcount = self.phrases_f.get(f, 0) + fsample_count = self.samples_f.get(f, 0) + d = self.phrases_fe.get(f, None) + paircount = d.get(e, 0) if d else 0 + return OnlineFeatureContext(fcount, fsample_count, paircount, self.bilex_f, self.bilex_e, self.bilex_fe) + return None + + # Find all phrases that we might try to extract + # (Used for EGivenFCoherent) + # Return set of (fphrase, lex_i, lex_j) + def get_f_phrases(self, f_words): + + f_len = len(f_words) + phrases = set() # (fphrase, lex_i, lex_j) + + def extract(f_i, f_j, lex_i, lex_j, wc, ntc, syms): + # Phrase extraction limits + if f_j > (f_len - 1) or (f_j - f_i) + 1 > self.max_initial_size: + return + # Extend with word + if wc + ntc < self.max_length: + syms.append(f_words[f_j]) + f = Phrase(syms) + new_lex_i = min(lex_i, f_j) + new_lex_j = max(lex_j, f_j) + phrases.add((f, new_lex_i, new_lex_j)) + extract(f_i, f_j + 1, new_lex_i, new_lex_j, wc + 1, ntc, syms) + syms.pop() + # Extend with existing non-terminal + if syms and sym_isvar(syms[-1]): + # Don't re-extract the same phrase + extract(f_i, f_j + 1, lex_i, lex_j, wc, ntc, syms) + # Extend with new non-terminal + if wc + ntc < self.max_length: + if not syms or (ntc < self.max_nonterminals and not sym_isvar(syms[-1])): + syms.append(sym_setindex(self.category, ntc + 1)) + f = Phrase(syms) + if wc > 0: + phrases.add((f, lex_i, lex_j)) + extract(f_i, f_j + 1, lex_i, lex_j, wc, ntc + 1, syms) + syms.pop() + + # Try to extract phrases from every f index + for f_i from 0 <= f_i < f_len: + extract(f_i, f_i, f_len, -1, 0, 0, []) + + return phrases + +# Spans are _inclusive_ on both ends [i, j] +def span_check(vec, i, j): + k = i + while k <= j: + if vec[k]: + return False + k += 1 + return True + +def span_inc(vec, i, j): + k = i + while k <= j: + vec[k] += 1 + k += 1 + +def span_dec(vec, i, j): + k = i + while k <= j: + vec[k] -= 1 + k += 1 |