diff options
Diffstat (limited to 'python/cdec/sa/rulefactory.pxi')
-rw-r--r-- | python/cdec/sa/rulefactory.pxi | 51 |
1 files changed, 14 insertions, 37 deletions
diff --git a/python/cdec/sa/rulefactory.pxi b/python/cdec/sa/rulefactory.pxi index 044a78c8..635cca10 100644 --- a/python/cdec/sa/rulefactory.pxi +++ b/python/cdec/sa/rulefactory.pxi @@ -31,10 +31,6 @@ OnlineFeatureContext = namedtuple('OnlineFeatureContext', ['fcount', 'fsample_count', 'paircount', - 'bilex_f', - 'bilex_e', - 'bilex_fe', - 'bilex_ef' ]) cdef class OnlineStats: @@ -43,10 +39,6 @@ cdef class OnlineStats: cdef public phrases_e cdef public phrases_fe cdef public phrases_al - cdef public bilex_f - cdef public bilex_e - cdef public bilex_fe - cdef public bilex_ef def __cinit__(self): # Keep track of everything that can be sampled: @@ -58,12 +50,6 @@ cdef class OnlineStats: self.phrases_fe = defaultdict(lambda: defaultdict(int)) self.phrases_al = defaultdict(lambda: defaultdict(tuple)) - # Bilexical counts - self.bilex_f = defaultdict(int) - self.bilex_e = defaultdict(int) - self.bilex_fe = defaultdict(lambda: defaultdict(int)) - self.bilex_ef = defaultdict(lambda: defaultdict(int)) - cdef int PRECOMPUTE = 0 cdef int MERGE = 1 cdef int BAEZA_YATES = 2 @@ -305,10 +291,13 @@ cdef class HieroCachingRuleFactory: cdef bint online cdef online_stats + cdef bilex def __cinit__(self, # compiled alignment object (REQUIRED) Alignment alignment, + # bilexical dictionary if online + bilex=None, # parameter for double-binary search; doesn't seem to matter much float by_slack_factor=1.0, # name of generic nonterminal used by Hiero @@ -414,7 +403,10 @@ cdef class HieroCachingRuleFactory: self.findexes1 = IntList(initial_len=10) # Online stats - + + # None if not online + self.bilex = bilex + # True after data is added self.online = False self.online_stats = defaultdict(OnlineStats) @@ -2053,28 +2045,13 @@ cdef class HieroCachingRuleFactory: stats.phrases_fe[f_ph][e_ph] += 1 if not stats.phrases_al[f_ph][e_ph]: stats.phrases_al[f_ph][e_ph] = al - - # Update Bilexical counts - aligned_fe = [list() for _ in range(len(f_words))] - aligned_ef = [list() for _ in range(len(e_words))] - for (i, j) in alignment: - aligned_fe[i].append(j) - aligned_ef[j].append(i) - for f_i in range(len(f_words)): - e_i_aligned = aligned_fe[f_i] - lc = len(e_i_aligned) - if lc > 0: - stats.bilex_f[f_words[f_i]] += 1 - for e_i in e_i_aligned: - stats.bilex_fe[f_words[f_i]][e_words[e_i]] += (1.0) / lc - for e_i in range(len(e_words)): - f_i_aligned = aligned_ef[e_i] - lc = len(f_i_aligned) - if lc > 0: - stats.bilex_e[e_words[e_i]] += 1 - for f_i in f_i_aligned: - stats.bilex_ef[e_words[e_i]][f_words[f_i]] += (1.0) / lc + # Update bilexical dictionary (if exists) + if self.bilex: + self.bilex.update(f_words, e_words, alignment) + else: + logger.warning('No online bilexical dictionary specified, not updating lexical weights') + # Create a rule from source, target, non-terminals, and alignments def form_rule(self, f_i, e_i, f_span, e_span, nt, al): @@ -2154,7 +2131,7 @@ cdef class HieroCachingRuleFactory: fsample_count = stats.samples_f.get(f, 0) d = stats.phrases_fe.get(f, None) paircount = d.get(e, 0) if d else 0 - return OnlineFeatureContext(fcount, fsample_count, paircount, stats.bilex_f, stats.bilex_e, stats.bilex_fe) + return OnlineFeatureContext(fcount, fsample_count, paircount) return None # Find all phrases that we might try to extract |