diff options
Diffstat (limited to 'python/src/sa/rulefactory.pxi')
-rw-r--r-- | python/src/sa/rulefactory.pxi | 122 |
1 files changed, 117 insertions, 5 deletions
diff --git a/python/src/sa/rulefactory.pxi b/python/src/sa/rulefactory.pxi index b10d25dd..7063c2da 100644 --- a/python/src/sa/rulefactory.pxi +++ b/python/src/sa/rulefactory.pxi @@ -23,7 +23,17 @@ FeatureContext = namedtuple('FeatureContext', 'test_sentence', 'f_text', 'e_text', - 'meta' + 'meta', + 'online' + ]) + +OnlineFeatureContext = namedtuple('OnlineFeatureContext', + ['fcount', + 'fsample_count', + 'paircount', + 'bilex_f', + 'bilex_e', + 'bilex_fe' ]) cdef int PRECOMPUTE = 0 @@ -264,6 +274,8 @@ cdef class HieroCachingRuleFactory: cdef IntList findexes cdef IntList findexes1 + cdef bint online + cdef samples_f cdef phrases_f cdef phrases_e cdef phrases_fe @@ -381,6 +393,12 @@ cdef class HieroCachingRuleFactory: # Online stats + # True after data is added + self.online = False + + # Keep track of everything that can be sampled: + self.samples_f = defaultdict(int) + # Phrase counts self.phrases_f = defaultdict(int) self.phrases_e = defaultdict(int) @@ -969,6 +987,11 @@ cdef class HieroCachingRuleFactory: hit = 0 reachable_buffer = {} + # Phrase pairs processed by suffix array extractor. Do not re-extract + # during online extraction. This is probably the hackiest part of + # online grammar extraction. + seen_phrases = set() + # Do not cache between sentences self.rules.root = ExtendedTrieNode(phrase_location=PhraseLocation()) @@ -1124,7 +1147,12 @@ cdef class HieroCachingRuleFactory: f, e, count, fcount[f], num_samples, (k,i+spanlen), locs, input_match, fwords, self.fda, self.eda, - meta)) + meta, + # Include online stats. None if none. + self.online_ctx_lookup(f, e))) + # Phrase pair processed + if self.online: + seen_phrases.add((f, e)) yield Rule(self.category, f, e, scores, alignment) if len(phrase) < self.max_length and i+spanlen < len(fwords) and pathlen+1 <= self.max_initial_size: @@ -1148,7 +1176,29 @@ cdef class HieroCachingRuleFactory: for (i, alt, pathlen) in frontier_nodes: new_frontier.append((k, i, input_match + (i,), alt, pathlen, xnode, phrase +(xcat,), is_shadow_path)) frontier = new_frontier - + + # Online rule extraction and scoring + if self.online: + f_syms = tuple(word[0][0] for word in fwords) + for (f, lex_i, lex_j) in self.get_f_phrases(f_syms): + spanlen = (lex_j - lex_i) + 1 + if not sym_isvar(f[0]): + spanlen += 1 + if not sym_isvar(f[1]): + spanlen += 1 + for e in self.phrases_fe.get(f, ()): + if (f, e) not in seen_phrases: + # Don't add multiple instances of the same phrase here + seen_phrases.add((f, e)) + scores = self.scorer.score(FeatureContext( + f, e, 0, 0, 0, + spanlen, None, None, + fwords, self.fda, self.eda, + meta, + self.online_ctx_lookup(f, e))) + alignment = self.phrases_al[f][e] + yield Rule(self.category, f, e, scores, alignment) + stop_time = monitor_cpu() logger.info("Total time for rule lookup, extraction, and scoring = %f seconds", (stop_time - start_time)) gc.collect() @@ -1828,6 +1878,8 @@ cdef class HieroCachingRuleFactory: # (Extract rules, update counts) def add_instance(self, f_words, e_words, alignment): + self.online = True + # Rules extracted from this instance # Track span of lexical items (terminals) to make # sure we don't extract the same rule for the same @@ -1970,11 +2022,17 @@ cdef class HieroCachingRuleFactory: continue extract(f_i, f_i, f_len + 1, -1, f_i, 0, [], [], False) + # Update possible phrases (samples) + # This could be more efficiently integrated with extraction + # at the cost of readability + for (f, lex_i, lex_j) in self.get_f_phrases(f_words): + self.samples_f[f] += 1 + # Update phrase counts for rule in rules: (f_ph, e_ph, al) = rule[:3] self.phrases_f[f_ph] += 1 - self.phrases_e[e_ph] += 1 + self.phrases_e[e_ph] += 1 self.phrases_fe[f_ph][e_ph] += 1 if not self.phrases_al[f_ph][e_ph]: self.phrases_al[f_ph][e_ph] = al @@ -1987,7 +2045,6 @@ cdef class HieroCachingRuleFactory: for e_w in e_words: self.bilex_fe[f_w][e_w] += 1 - # Create a rule from source, target, non-terminals, and alignments def form_rule(self, f_i, e_i, f_span, e_span, nt, al): @@ -2080,10 +2137,65 @@ cdef class HieroCachingRuleFactory: for ph in self.phrases_e: logger.info(str(ph) + ' ||| ' + str(self.phrases_e[ph])) logger.info('FE') + self.dump_online_rules() + + def dump_online_rules(self): for ph in self.phrases_fe: for ph2 in self.phrases_fe[ph]: logger.info(self.fmt_rule(str(ph), str(ph2), self.phrases_al[ph][ph2]) + ' ||| ' + str(self.phrases_fe[ph][ph2])) + + # Lookup online stats for phrase pair (f, e). Return None if no match. + # IMPORTANT: use get() to avoid adding items to defaultdict + def online_ctx_lookup(self, f, e): + if self.online: + fcount = self.phrases_f.get(f, 0) + fsample_count = self.samples_f.get(f, 0) + d = self.phrases_fe.get(f, None) + paircount = d.get(e, 0) if d else 0 + return OnlineFeatureContext(fcount, fsample_count, paircount, self.bilex_f, self.bilex_e, self.bilex_fe) + return None + + # Find all phrases that we might try to extract + # (Used for EGivenFCoherent) + # Return set of (fphrase, lex_i, lex_j) + def get_f_phrases(self, f_words): + + f_len = len(f_words) + phrases = set() # (fphrase, lex_i, lex_j) + def extract(f_i, f_j, lex_i, lex_j, wc, ntc, syms): + # Phrase extraction limits + if f_j > (f_len - 1) or (f_j - f_i) + 1 > self.max_initial_size: + return + # Extend with word + if wc + ntc < self.max_length: + syms.append(f_words[f_j]) + f = Phrase(syms) + new_lex_i = min(lex_i, f_j) + new_lex_j = max(lex_j, f_j) + phrases.add((f, new_lex_i, new_lex_j)) + extract(f_i, f_j + 1, new_lex_i, new_lex_j, wc + 1, ntc, syms) + syms.pop() + # Extend with existing non-terminal + if syms and sym_isvar(syms[-1]): + # Don't re-extract the same phrase + extract(f_i, f_j + 1, lex_i, lex_j, wc, ntc, syms) + # Extend with new non-terminal + if wc + ntc < self.max_length: + if not syms or (ntc < self.max_nonterminals and not sym_isvar(syms[-1])): + syms.append(sym_setindex(self.category, ntc + 1)) + f = Phrase(syms) + if wc > 0: + phrases.add((f, lex_i, lex_j)) + extract(f_i, f_j + 1, lex_i, lex_j, wc, ntc + 1, syms) + syms.pop() + + # Try to extract phrases from every f index + for f_i from 0 <= f_i < f_len: + extract(f_i, f_i, f_len, -1, 0, 0, []) + + return phrases + # Spans are _inclusive_ on both ends [i, j] def span_check(vec, i, j): k = i |