diff options
Diffstat (limited to 'python/cdec/sa/rulefactory.pxi')
-rw-r--r-- | python/cdec/sa/rulefactory.pxi | 122 |
1 files changed, 51 insertions, 71 deletions
diff --git a/python/cdec/sa/rulefactory.pxi b/python/cdec/sa/rulefactory.pxi index 10bb9737..a96b01e7 100644 --- a/python/cdec/sa/rulefactory.pxi +++ b/python/cdec/sa/rulefactory.pxi @@ -36,6 +36,31 @@ OnlineFeatureContext = namedtuple('OnlineFeatureContext', 'bilex_fe' ]) +cdef class OnlineStats: + cdef public samples_f + cdef public phrases_f + cdef public phrases_e + cdef public phrases_fe + cdef public phrases_al + cdef public bilex_f + cdef public bilex_e + cdef public bilex_fe + + def __cinit__(self): + # Keep track of everything that can be sampled: + self.samples_f = defaultdict(int) + + # Phrase counts + self.phrases_f = defaultdict(int) + self.phrases_e = defaultdict(int) + self.phrases_fe = defaultdict(lambda: defaultdict(int)) + self.phrases_al = defaultdict(lambda: defaultdict(tuple)) + + # Bilexical counts + self.bilex_f = defaultdict(int) + self.bilex_e = defaultdict(int) + self.bilex_fe = defaultdict(lambda: defaultdict(int)) + cdef int PRECOMPUTE = 0 cdef int MERGE = 1 cdef int BAEZA_YATES = 2 @@ -276,14 +301,7 @@ cdef class HieroCachingRuleFactory: cdef IntList findexes1 cdef bint online - cdef samples_f - cdef phrases_f - cdef phrases_e - cdef phrases_fe - cdef phrases_al - cdef bilex_f - cdef bilex_e - cdef bilex_fe + cdef online_stats def __cinit__(self, # compiled alignment object (REQUIRED) @@ -396,21 +414,8 @@ cdef class HieroCachingRuleFactory: # True after data is added self.online = False + self.online_stats = defaultdict(OnlineStats) - # Keep track of everything that can be sampled: - self.samples_f = defaultdict(int) - - # Phrase counts - self.phrases_f = defaultdict(int) - self.phrases_e = defaultdict(int) - self.phrases_fe = defaultdict(lambda: defaultdict(int)) - self.phrases_al = defaultdict(lambda: defaultdict(tuple)) - - # Bilexical counts - self.bilex_f = defaultdict(int) - self.bilex_e = defaultdict(int) - self.bilex_fe = defaultdict(lambda: defaultdict(int)) - def configure(self, SuffixArray fsarray, DataArray edarray, Sampler sampler, Scorer scorer): '''This gives the RuleFactory access to the Context object. @@ -970,7 +975,7 @@ cdef class HieroCachingRuleFactory: candidate.append([next_id,curr[1]+jump]) return sorted(result); - def input(self, fwords, meta): + def input(self, fwords, meta, ctx_name=None): '''When this function is called on the RuleFactory, it looks up all of the rules that can be used to translate the input sentence''' @@ -1154,7 +1159,7 @@ cdef class HieroCachingRuleFactory: fwords, self.fda, self.eda, meta, # Include online stats. None if none. - self.online_ctx_lookup(f, e))) + self.online_ctx_lookup(f, e, ctx_name))) # Phrase pair processed if self.online: seen_phrases.add((f, e)) @@ -1184,6 +1189,7 @@ cdef class HieroCachingRuleFactory: # Online rule extraction and scoring if self.online: + stats = self.online_stats[ctx_name] f_syms = tuple(word[0][0] for word in fwords) for f, lex_i, lex_j in self.get_f_phrases(f_syms): spanlen = (lex_j - lex_i) + 1 @@ -1191,7 +1197,7 @@ cdef class HieroCachingRuleFactory: spanlen += 1 if not sym_isvar(f[1]): spanlen += 1 - for e in self.phrases_fe.get(f, ()): + for e in stats.phrases_fe.get(f, ()): if (f, e) not in seen_phrases: # Don't add multiple instances of the same phrase here seen_phrases.add((f, e)) @@ -1200,8 +1206,8 @@ cdef class HieroCachingRuleFactory: spanlen, None, None, fwords, self.fda, self.eda, meta, - self.online_ctx_lookup(f, e))) - alignment = self.phrases_al[f][e] + self.online_ctx_lookup(f, e, ctx_name))) + alignment = stats.phrases_al[f][e] yield Rule(self.category, f, e, scores, alignment) stop_time = monitor_cpu() @@ -1882,7 +1888,7 @@ cdef class HieroCachingRuleFactory: # Aggregate stats from a training instance # (Extract rules, update counts) - def add_instance(self, f_words, e_words, alignment): + def add_instance(self, f_words, e_words, alignment, ctx_name=None): self.online = True @@ -2028,28 +2034,30 @@ cdef class HieroCachingRuleFactory: continue extract(f_i, f_i, f_len + 1, -1, f_i, 0, [], [], False) + stats = self.online_stats[ctx_name] + # Update possible phrases (samples) # This could be more efficiently integrated with extraction # at the cost of readability for f, lex_i, lex_j in self.get_f_phrases(f_words): - self.samples_f[f] += 1 + stats.samples_f[f] += 1 # Update phrase counts for rule in rules: (f_ph, e_ph, al) = rule[:3] - self.phrases_f[f_ph] += 1 - self.phrases_e[e_ph] += 1 - self.phrases_fe[f_ph][e_ph] += 1 - if not self.phrases_al[f_ph][e_ph]: - self.phrases_al[f_ph][e_ph] = al + stats.phrases_f[f_ph] += 1 + stats.phrases_e[e_ph] += 1 + stats.phrases_fe[f_ph][e_ph] += 1 + if not stats.phrases_al[f_ph][e_ph]: + stats.phrases_al[f_ph][e_ph] = al # Update Bilexical counts for e_w in e_words: - self.bilex_e[e_w] += 1 + stats.bilex_e[e_w] += 1 for f_w in f_words: - self.bilex_f[f_w] += 1 + stats.bilex_f[f_w] += 1 for e_w in e_words: - self.bilex_fe[f_w][e_w] += 1 + stats.bilex_fe[f_w][e_w] += 1 # Create a rule from source, target, non-terminals, and alignments def form_rule(self, f_i, e_i, f_span, e_span, nt, al): @@ -2121,44 +2129,16 @@ cdef class HieroCachingRuleFactory: a_str = ' '.join('{0}-{1}'.format(*self.alignment.unlink(packed)) for packed in a) return '[X] ||| {0} ||| {1} ||| {2}'.format(f, e, a_str) - # Debugging - def dump_online_stats(self): - logger.info('------------------------------') - logger.info(' Online Stats ') - logger.info('------------------------------') - logger.info('f') - for w in self.bilex_f: - logger.info(sym_tostring(w) + ' : ' + str(self.bilex_f[w])) - logger.info('e') - for w in self.bilex_e: - logger.info(sym_tostring(w) + ' : ' + str(self.bilex_e[w])) - logger.info('fe') - for w in self.bilex_fe: - for w2 in self.bilex_fe[w]: - logger.info(sym_tostring(w) + ' : ' + sym_tostring(w2) + ' : ' + str(self.bilex_fe[w][w2])) - logger.info('F') - for ph in self.phrases_f: - logger.info(str(ph) + ' ||| ' + str(self.phrases_f[ph])) - logger.info('E') - for ph in self.phrases_e: - logger.info(str(ph) + ' ||| ' + str(self.phrases_e[ph])) - logger.info('FE') - self.dump_online_rules() - - def dump_online_rules(self): - for ph in self.phrases_fe: - for ph2 in self.phrases_fe[ph]: - logger.info(self.fmt_rule(str(ph), str(ph2), self.phrases_al[ph][ph2]) + ' ||| ' + str(self.phrases_fe[ph][ph2])) - # Lookup online stats for phrase pair (f, e). Return None if no match. # IMPORTANT: use get() to avoid adding items to defaultdict - def online_ctx_lookup(self, f, e): + def online_ctx_lookup(self, f, e, ctx_name=None): if self.online: - fcount = self.phrases_f.get(f, 0) - fsample_count = self.samples_f.get(f, 0) - d = self.phrases_fe.get(f, None) + stats = self.online_stats[ctx_name] + fcount = stats.phrases_f.get(f, 0) + fsample_count = stats.samples_f.get(f, 0) + d = stats.phrases_fe.get(f, None) paircount = d.get(e, 0) if d else 0 - return OnlineFeatureContext(fcount, fsample_count, paircount, self.bilex_f, self.bilex_e, self.bilex_fe) + return OnlineFeatureContext(fcount, fsample_count, paircount, stats.bilex_f, stats.bilex_e, stats.bilex_fe) return None # Find all phrases that we might try to extract |