diff options
author | Michael Denkowski <michael.j.denkowski@gmail.com> | 2013-01-26 21:12:25 -0500 |
---|---|---|
committer | Michael Denkowski <michael.j.denkowski@gmail.com> | 2013-01-26 21:12:25 -0500 |
commit | 0a6dbb8aefb1662a68f3f14f0c42a72150d8be03 (patch) | |
tree | 5b79d719f4f9e2f37ef73cc3d278ec6667c2b47b /python/src/sa/rulefactory.pxi | |
parent | ca3da3a815b6e85531d6ded07e7d6bec7852748c (diff) |
Online grammars now diff with incremental suffix array (except lex, TODO)
Diffstat (limited to 'python/src/sa/rulefactory.pxi')
-rw-r--r-- | python/src/sa/rulefactory.pxi | 87 |
1 files changed, 53 insertions, 34 deletions
diff --git a/python/src/sa/rulefactory.pxi b/python/src/sa/rulefactory.pxi index b95c23df..88f77a8d 100644 --- a/python/src/sa/rulefactory.pxi +++ b/python/src/sa/rulefactory.pxi @@ -29,6 +29,7 @@ FeatureContext = namedtuple('FeatureContext', OnlineFeatureContext = namedtuple('OnlineFeatureContext', ['fcount', + 'fsample_count', 'paircount', 'bilex' ]) @@ -272,6 +273,7 @@ cdef class HieroCachingRuleFactory: cdef IntList findexes1 cdef bint online + cdef samples_f cdef phrases_f cdef phrases_e cdef phrases_fe @@ -392,6 +394,9 @@ cdef class HieroCachingRuleFactory: # True after data is added self.online = False + # Keep track of everything that can be sampled: + self.samples_f = defaultdict(int) + # Phrase counts self.phrases_f = defaultdict(int) self.phrases_e = defaultdict(int) @@ -1173,15 +1178,24 @@ cdef class HieroCachingRuleFactory: # Online rule extraction and scoring if self.online: f_syms = tuple(word[0][0] for word in fwords) - for (f, e, spanlen) in self.online_match(f_syms, seen_phrases): - scores = self.scorer.score(FeatureContext( - f, e, 0, 0, 0, - spanlen, None, None, - fwords, self.fda, self.eda, - meta, - self.online_ctx_lookup(f, e))) - alignment = self.phrases_al[f][e] - yield Rule(self.category, f, e, scores, alignment) + for (f, lex_i, lex_j) in self.get_f_phrases(f_syms): + spanlen = (lex_j - lex_i) + 1 + if not sym_isvar(f[0]): + spanlen += 1 + if not sym_isvar(f[1]): + spanlen += 1 + for e in self.phrases_fe.get(f, ()): + if (f, e) not in seen_phrases: + # Don't add multiple instances of the same phrase here + seen_phrases.add((f, e)) + scores = self.scorer.score(FeatureContext( + f, e, 0, 0, 0, + spanlen, None, None, + fwords, self.fda, self.eda, + meta, + self.online_ctx_lookup(f, e))) + alignment = self.phrases_al[f][e] + yield Rule(self.category, f, e, scores, alignment) stop_time = monitor_cpu() logger.info("Total time for rule lookup, extraction, and scoring = %f seconds", (stop_time - start_time)) @@ -2006,6 +2020,12 @@ cdef class HieroCachingRuleFactory: continue extract(f_i, f_i, f_len + 1, -1, f_i, 0, [], [], False) + # Update possible phrases (samples) + # This could be more efficiently integrated with extraction + # at the cost of readability + for (f, lex_i, lex_j) in self.get_f_phrases(f_words): + self.samples_f[f] += 1 + # Update phrase counts for rule in rules: (f_ph, e_ph, al) = rule[:3] @@ -2115,30 +2135,33 @@ cdef class HieroCachingRuleFactory: for ph in self.phrases_e: logger.info(str(ph) + ' ||| ' + str(self.phrases_e[ph])) logger.info('FE') + self.dump_online_rules() + + def dump_online_rules(self): for ph in self.phrases_fe: for ph2 in self.phrases_fe[ph]: logger.info(self.fmt_rule(str(ph), str(ph2), self.phrases_al[ph][ph2]) + ' ||| ' + str(self.phrases_fe[ph][ph2])) - + # Lookup online stats for phrase pair (f, e). Return None if no match. # IMPORTANT: use get() to avoid adding items to defaultdict def online_ctx_lookup(self, f, e): if self.online: fcount = self.phrases_f.get(f, 0) + fsample_count = self.samples_f.get(f, 0) d = self.phrases_fe.get(f, None) paircount = d.get(e, 0) if d else 0 - if paircount > 0: - print 'Online support:', f, '|||', e - return OnlineFeatureContext(fcount, paircount, self.bilex_fe) + return OnlineFeatureContext(fcount, fsample_count, paircount, self.bilex_fe) return None - # Match source words against online data. - # Return (fphrase, ephrase, length) - def online_match(self, f_words, seen_phrases): - + # Find all phrases that we might try to extract + # (Used for EGivenFCoherent) + # Return set of (fphrase, lex_i, lex_j) + def get_f_phrases(self, f_words): + f_len = len(f_words) - matches = {} # (f, e) = len + phrases = set() # (fphrase, lex_i, lex_j) - def extract(f_i, f_j, wc, ntc, syms): + def extract(f_i, f_j, lex_i, lex_j, wc, ntc, syms): # Phrase extraction limits if f_j > (f_len - 1) or (f_j - f_i) + 1 > self.max_initial_size: return @@ -2146,34 +2169,30 @@ cdef class HieroCachingRuleFactory: if wc + ntc < self.max_length: syms.append(f_words[f_j]) f = Phrase(syms) - for e in self.phrases_fe[f]: - if (f, e) not in seen_phrases: - matches[(f, e)] = (f_j - f_i) + 1 - extract(f_i, f_j + 1, wc + 1, ntc, syms) + new_lex_i = min(lex_i, f_j) + new_lex_j = max(lex_j, f_j) + phrases.add((f, new_lex_i, new_lex_j)) + extract(f_i, f_j + 1, new_lex_i, new_lex_j, wc + 1, ntc, syms) syms.pop() # Extend with existing non-terminal if syms and sym_isvar(syms[-1]): # Don't re-extract the same phrase - extract(f_i, f_j + 1, wc, ntc, syms) + extract(f_i, f_j + 1, lex_i, lex_j, wc, ntc, syms) # Extend with new non-terminal if wc + ntc < self.max_length: if not syms or (ntc < self.max_nonterminals and not sym_isvar(syms[-1])): - syms.append(sym_setindex(self.category, ntc)) + syms.append(sym_setindex(self.category, ntc + 1)) f = Phrase(syms) if wc > 0: - for e in self.phrases_fe[f]: - if (f, e) not in seen_phrases: - matches[(f, e)] = (f_j - f_i) + 1 - extract(f_i, f_j + 1, wc, ntc + 1, syms) + phrases.add((f, lex_i, lex_j)) + extract(f_i, f_j + 1, lex_i, lex_j, wc, ntc + 1, syms) syms.pop() # Try to extract phrases from every f index for f_i from 0 <= f_i < f_len: - extract(f_i, f_i, 0, 0, []) - - for line in sorted(' ||| '.join((str(f), str(e))) for (f, e) in matches): - print 'Online new:', line - return ((f, e, matches[(f, e)]) for (f, e) in matches) + extract(f_i, f_i, f_len, -1, 0, 0, []) + + return phrases # Spans are _inclusive_ on both ends [i, j] def span_check(vec, i, j): |