Diffstat (limited to 'python/src/sa/rulefactory.pxi')
-rw-r--r-- | python/src/sa/rulefactory.pxi | 38
1 files changed, 17 insertions, 21 deletions
diff --git a/python/src/sa/rulefactory.pxi b/python/src/sa/rulefactory.pxi
index 1c8d25a4..34a002c5 100644
--- a/python/src/sa/rulefactory.pxi
+++ b/python/src/sa/rulefactory.pxi
@@ -8,6 +8,8 @@
 from libc.stdlib cimport malloc, realloc, free
 from libc.string cimport memset, memcpy
 from libc.math cimport fmod, ceil, floor, log
+from collections import defaultdict, Counter
+
 cdef int PRECOMPUTE = 0
 cdef int MERGE = 1
 cdef int BAEZA_YATES = 2
@@ -73,8 +75,7 @@
         self.arr_high = arr_high
         self.arr = arr
         self.num_subpatterns = num_subpatterns
-
-
+
 cdef class Sampler:
     '''A Sampler implements a logic for choosing
@@ -208,6 +209,7 @@ cdef class HieroCachingRuleFactory:
 
     cdef TrieTable rules
     cdef Sampler sampler
+    cdef Scorer scorer
 
     cdef int max_chunks
     cdef int max_target_chunks
@@ -359,7 +361,8 @@ cdef class HieroCachingRuleFactory:
         self.findexes = IntList(initial_len=10)
         self.findexes1 = IntList(initial_len=10)
 
-    def configure(self, SuffixArray fsarray, DataArray edarray, Sampler sampler):
+    def configure(self, SuffixArray fsarray, DataArray edarray,
+            Sampler sampler, Scorer scorer):
         '''This gives the RuleFactory access to the Context object.
         Here we also use it to precompute the most expensive intersections
         in the corpus quickly.'''
@@ -370,6 +373,7 @@ cdef class HieroCachingRuleFactory:
         self.eid2symid = self.set_idmap(self.eda)
         self.precompute()
         self.sampler = sampler
+        self.scorer = scorer
 
     cdef set_idmap(self, DataArray darray):
         cdef int word_id, new_word_id, N
@@ -916,7 +920,7 @@ cdef class HieroCachingRuleFactory:
                     candidate.append([next_id,curr[1]+jump])
         return sorted(result);
 
-    def input(self, fwords, models):
+    def input(self, fwords):
         '''When this function is called on the RuleFactory,
         it looks up all of the rules that can be used to translate
         the input sentence'''
@@ -1074,18 +1078,12 @@ cdef class HieroCachingRuleFactory:
                     extract_stop = monitor_cpu()
                     self.extract_time = self.extract_time + extract_stop - extract_start
                     if len(extracts) > 0:
-                        fphrases = {}
-                        fals = {}
-                        fcount = {}
+                        fcount = Counter()
+                        fphrases = defaultdict(lambda: defaultdict(Counter))
                         for f, e, count, als in extracts:
-                            fcount.setdefault(f, 0.0)
-                            fcount[f] = fcount[f] + count
-                            fphrases.setdefault(f, {})
-                            fphrases[f].setdefault(e, {})
-                            fphrases[f][e].setdefault(als,0.0)
-                            fphrases[f][e][als] = fphrases[f][e][als] + count
+                            fcount[f] += count
+                            fphrases[f][e][als] += count
                         for f, elist in fphrases.iteritems():
-                            f_margin = fcount[f]
                             for e, alslist in elist.iteritems():
                                 alignment = None
                                 count = 0
@@ -1093,11 +1091,9 @@ cdef class HieroCachingRuleFactory:
                                     if currcount > count:
                                         alignment = als
                                         count = currcount
-                                scores = []
-                                for model in models:
-                                    scores.append(model(f, e, count, fcount[f], num_samples))
-                                yield Rule(self.category, f, e,
-                                    scores=scores, word_alignments=alignment)
+                                scores = self.scorer.score(f, e, count,
+                                        fcount[f], num_samples)
+                                yield Rule(self.category, f, e, scores, alignment)
 
             if len(phrase) < self.max_length and i+spanlen < len(fwords) and pathlen+1 <= self.max_initial_size:
                 for alt_id in range(len(fwords[i+spanlen])):
@@ -1377,9 +1373,9 @@ cdef class HieroCachingRuleFactory:
             free(e_gap_order)
         return result
 
-    cdef create_alignments(self, int* sent_links, int num_links, findexes, eindexes):
+    cdef IntList create_alignments(self, int* sent_links, int num_links, findexes, eindexes):
         cdef unsigned i
-        ret = IntList()
+        cdef IntList ret = IntList()
         for i in range(len(findexes)):
            s = findexes[i]
           if (s<0):
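
The bulk of the change is the aggregation rewrite inside input(): the hand-rolled setdefault chains over fcount and fphrases give way to a Counter plus a nested defaultdict whose leaves are Counters. A minimal standalone sketch of that pattern follows, using made-up extraction tuples (the real code aggregates the `extracts` produced during rule extraction, and uses Python 2 iteritems() rather than items()):

    from collections import defaultdict, Counter

    # Hypothetical (f, e, count, als) tuples standing in for `extracts`.
    extracts = [
        ("la casa", "the house", 1.0, "0-0 1-1"),
        ("la casa", "the house", 1.0, "0-0 1-1"),
        ("la casa", "the home",  1.0, "0-0 1-1"),
    ]

    fcount = Counter()                                    # total count per source phrase
    fphrases = defaultdict(lambda: defaultdict(Counter))  # f -> e -> alignment -> count

    for f, e, count, als in extracts:
        fcount[f] += count            # Counter defaults missing keys to 0
        fphrases[f][e][als] += count  # defaultdict builds the nested levels on demand

    # Keep the best-supported alignment per (f, e) pair, mirroring the explicit
    # max-finding loop that runs before each Rule is scored and yielded.
    for f, elist in fphrases.items():
        for e, alslist in elist.items():
            alignment, count = max(alslist.items(), key=lambda kv: kv[1])
            print(f, "|||", e, count, fcount[f], alignment)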
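The scoring side replaces the per-rule loop over a models list with a single Scorer object wired in through configure(). The Scorer type itself is defined elsewhere in the sa extension; the toy class below is only an illustration of the call shape self.scorer.score(f, e, count, fcount[f], num_samples) and is not cdec's actual API, and the two feature functions are assumed examples:

    import math

    class ToyScorer:
        '''Illustrative stand-in: bundles feature functions behind one score() call.'''
        def __init__(self, *models):
            self.models = models

        def score(self, f, e, count, fcount, num_samples):
            # Same arguments the factory passes; returns one value per feature.
            return [model(f, e, count, fcount, num_samples) for model in self.models]

    # Assumed example features: -log p(e|f) and a sample-count feature.
    def egiven_f(f, e, count, fcount, num_samples):
        return -math.log10(count / fcount)

    def sample_count(f, e, count, fcount, num_samples):
        return math.log10(1 + num_samples)

    scorer = ToyScorer(egiven_f, sample_count)
    print(scorer.score("la casa", "the house", 2.0, 3.0, 300))

Grouping the feature functions behind one object means the grammar extractor no longer needs a models argument threaded through input(), which is exactly what the signature change in this diff removes.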