diff options
Diffstat (limited to 'python/src/sa/rulefactory.pxi')
-rw-r--r-- | python/src/sa/rulefactory.pxi | 83 |
1 files changed, 48 insertions, 35 deletions
diff --git a/python/src/sa/rulefactory.pxi b/python/src/sa/rulefactory.pxi index 1c8d25a4..5f6558b3 100644 --- a/python/src/sa/rulefactory.pxi +++ b/python/src/sa/rulefactory.pxi @@ -3,11 +3,29 @@ # Much faster than the Python numbers reported there. # Note to reader: this code is closer to C than Python import gc +import itertools from libc.stdlib cimport malloc, realloc, free from libc.string cimport memset, memcpy from libc.math cimport fmod, ceil, floor, log +from collections import defaultdict, Counter, namedtuple + +FeatureContext = namedtuple('FeatureContext', + ['fphrase', + 'ephrase', + 'paircount', + 'fcount', + 'fsample_count', + 'input_span', + 'matches', + 'input_match', + 'test_sentence', + 'f_text', + 'e_text', + 'meta' + ]) + cdef int PRECOMPUTE = 0 cdef int MERGE = 1 cdef int BAEZA_YATES = 2 @@ -73,8 +91,7 @@ cdef class PhraseLocation: self.arr_high = arr_high self.arr = arr self.num_subpatterns = num_subpatterns - - + cdef class Sampler: '''A Sampler implements a logic for choosing @@ -208,6 +225,7 @@ cdef class HieroCachingRuleFactory: cdef TrieTable rules cdef Sampler sampler + cdef Scorer scorer cdef int max_chunks cdef int max_target_chunks @@ -359,7 +377,8 @@ cdef class HieroCachingRuleFactory: self.findexes = IntList(initial_len=10) self.findexes1 = IntList(initial_len=10) - def configure(self, SuffixArray fsarray, DataArray edarray, Sampler sampler): + def configure(self, SuffixArray fsarray, DataArray edarray, + Sampler sampler, Scorer scorer): '''This gives the RuleFactory access to the Context object. Here we also use it to precompute the most expensive intersections in the corpus quickly.''' @@ -370,6 +389,7 @@ cdef class HieroCachingRuleFactory: self.eid2symid = self.set_idmap(self.eda) self.precompute() self.sampler = sampler + self.scorer = scorer cdef set_idmap(self, DataArray darray): cdef int word_id, new_word_id, N @@ -916,7 +936,7 @@ cdef class HieroCachingRuleFactory: candidate.append([next_id,curr[1]+jump]) return sorted(result); - def input(self, fwords, models): + def input(self, fwords, meta): '''When this function is called on the RuleFactory, it looks up all of the rules that can be used to translate the input sentence''' @@ -941,7 +961,7 @@ cdef class HieroCachingRuleFactory: for i in range(len(fwords)): for alt in range(0, len(fwords[i])): if fwords[i][alt][0] != EPSILON: - frontier.append((i, i, alt, 0, self.rules.root, (), False)) + frontier.append((i, i, (i,), alt, 0, self.rules.root, (), False)) xroot = None x1 = sym_setindex(self.category, 1) @@ -954,7 +974,7 @@ cdef class HieroCachingRuleFactory: for i in range(self.min_gap_size, len(fwords)): for alt in range(0, len(fwords[i])): if fwords[i][alt][0] != EPSILON: - frontier.append((i-self.min_gap_size, i, alt, self.min_gap_size, xroot, (x1,), True)) + frontier.append((i-self.min_gap_size, i, (i,), alt, self.min_gap_size, xroot, (x1,), True)) next_states = [] for i in range(len(fwords)): @@ -962,7 +982,7 @@ cdef class HieroCachingRuleFactory: while len(frontier) > 0: new_frontier = [] - for k, i, alt, pathlen, node, prefix, is_shadow_path in frontier: + for k, i, input_match, alt, pathlen, node, prefix, is_shadow_path in frontier: word_id = fwords[i][alt][0] spanlen = fwords[i][alt][2] # TODO get rid of k -- pathlen is replacing it @@ -971,7 +991,7 @@ cdef class HieroCachingRuleFactory: if i+spanlen >= len(fwords): continue for nualt in range(0,len(fwords[i+spanlen])): - frontier.append((k, i+spanlen, nualt, pathlen, node, prefix, is_shadow_path)) + frontier.append((k, i+spanlen, input_match, nualt, pathlen, node, prefix, is_shadow_path)) continue phrase = prefix + (word_id,) @@ -1066,42 +1086,35 @@ cdef class HieroCachingRuleFactory: extract = [] assign_matching(&matching, sample.arr, j, num_subpatterns, self.fda.sent_id.arr) + loc = tuple(sample[j:j+num_subpatterns]) extract = self.extract(hiero_phrase, &matching, chunklen.arr, num_subpatterns) - extracts.extend(extract) + extracts.extend([(e, loc) for e in extract]) j = j + num_subpatterns num_samples = sample.len/num_subpatterns extract_stop = monitor_cpu() self.extract_time = self.extract_time + extract_stop - extract_start if len(extracts) > 0: - fphrases = {} - fals = {} - fcount = {} - for f, e, count, als in extracts: - fcount.setdefault(f, 0.0) - fcount[f] = fcount[f] + count - fphrases.setdefault(f, {}) - fphrases[f].setdefault(e, {}) - fphrases[f][e].setdefault(als,0.0) - fphrases[f][e][als] = fphrases[f][e][als] + count + fcount = Counter() + fphrases = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + for (f, e, count, als), loc in extracts: + fcount[f] += count + fphrases[f][e][als].append(loc) for f, elist in fphrases.iteritems(): - f_margin = fcount[f] for e, alslist in elist.iteritems(): - alignment = None - count = 0 - for als, currcount in alslist.iteritems(): - if currcount > count: - alignment = als - count = currcount - scores = [] - for model in models: - scores.append(model(f, e, count, fcount[f], num_samples)) - yield Rule(self.category, f, e, - scores=scores, word_alignments=alignment) + alignment, max_locs = max(alslist.iteritems(), key=lambda x: len(x[1])) + locs = tuple(itertools.chain.from_iterable(alslist.itervalues())) + count = len(locs) + scores = self.scorer.score(FeatureContext( + f, e, count, fcount[f], num_samples, + (k,i+spanlen), locs, input_match, + fwords, self.fda, self.eda, + meta)) + yield Rule(self.category, f, e, scores, alignment) if len(phrase) < self.max_length and i+spanlen < len(fwords) and pathlen+1 <= self.max_initial_size: for alt_id in range(len(fwords[i+spanlen])): - new_frontier.append((k, i+spanlen, alt_id, pathlen + 1, node, phrase, is_shadow_path)) + new_frontier.append((k, i+spanlen, input_match, alt_id, pathlen + 1, node, phrase, is_shadow_path)) num_subpatterns = arity if not is_shadow_path: num_subpatterns = num_subpatterns + 1 @@ -1118,7 +1131,7 @@ cdef class HieroCachingRuleFactory: nodes_isteps_away_buffer[key] = frontier_nodes for (i, alt, pathlen) in frontier_nodes: - new_frontier.append((k, i, alt, pathlen, xnode, phrase +(xcat,), is_shadow_path)) + new_frontier.append((k, i, input_match + (i,), alt, pathlen, xnode, phrase +(xcat,), is_shadow_path)) frontier = new_frontier stop_time = monitor_cpu() @@ -1377,9 +1390,9 @@ cdef class HieroCachingRuleFactory: free(e_gap_order) return result - cdef create_alignments(self, int* sent_links, int num_links, findexes, eindexes): + cdef IntList create_alignments(self, int* sent_links, int num_links, findexes, eindexes): cdef unsigned i - ret = IntList() + cdef IntList ret = IntList() for i in range(len(findexes)): s = findexes[i] if (s<0): |