diff options
author | Chris Dyer <cdyer@cs.cmu.edu> | 2012-09-20 21:51:31 -0400 |
---|---|---|
committer | Chris Dyer <cdyer@cs.cmu.edu> | 2012-09-20 21:51:31 -0400 |
commit | d2bc8694e5450a46c6f851d926c1ebfeb3424cbf (patch) | |
tree | 5619896999d43ca478acee0da2b1c60244aab5b1 /python/src/sa/rulefactory.pxi | |
parent | 78518f1f417616633b300a361cd5e0c1bcb1ff24 (diff) | |
parent | 5d159b948ad71850bcb03d0882ea7183a3a59b7e (diff) |
Merge branch 'master' of https://github.com/redpony/cdec
Diffstat (limited to 'python/src/sa/rulefactory.pxi')
-rw-r--r-- | python/src/sa/rulefactory.pxi | 29 |
1 files changed, 17 insertions, 12 deletions
diff --git a/python/src/sa/rulefactory.pxi b/python/src/sa/rulefactory.pxi index 287b9a67..5f6558b3 100644 --- a/python/src/sa/rulefactory.pxi +++ b/python/src/sa/rulefactory.pxi @@ -19,7 +19,11 @@ FeatureContext = namedtuple('FeatureContext', 'fsample_count', 'input_span', 'matches', - 'test_sentence' + 'input_match', + 'test_sentence', + 'f_text', + 'e_text', + 'meta' ]) cdef int PRECOMPUTE = 0 @@ -932,7 +936,7 @@ cdef class HieroCachingRuleFactory: candidate.append([next_id,curr[1]+jump]) return sorted(result); - def input(self, fwords): + def input(self, fwords, meta): '''When this function is called on the RuleFactory, it looks up all of the rules that can be used to translate the input sentence''' @@ -957,7 +961,7 @@ cdef class HieroCachingRuleFactory: for i in range(len(fwords)): for alt in range(0, len(fwords[i])): if fwords[i][alt][0] != EPSILON: - frontier.append((i, i, alt, 0, self.rules.root, (), False)) + frontier.append((i, i, (i,), alt, 0, self.rules.root, (), False)) xroot = None x1 = sym_setindex(self.category, 1) @@ -970,7 +974,7 @@ cdef class HieroCachingRuleFactory: for i in range(self.min_gap_size, len(fwords)): for alt in range(0, len(fwords[i])): if fwords[i][alt][0] != EPSILON: - frontier.append((i-self.min_gap_size, i, alt, self.min_gap_size, xroot, (x1,), True)) + frontier.append((i-self.min_gap_size, i, (i,), alt, self.min_gap_size, xroot, (x1,), True)) next_states = [] for i in range(len(fwords)): @@ -978,7 +982,7 @@ cdef class HieroCachingRuleFactory: while len(frontier) > 0: new_frontier = [] - for k, i, alt, pathlen, node, prefix, is_shadow_path in frontier: + for k, i, input_match, alt, pathlen, node, prefix, is_shadow_path in frontier: word_id = fwords[i][alt][0] spanlen = fwords[i][alt][2] # TODO get rid of k -- pathlen is replacing it @@ -987,7 +991,7 @@ cdef class HieroCachingRuleFactory: if i+spanlen >= len(fwords): continue for nualt in range(0,len(fwords[i+spanlen])): - frontier.append((k, i+spanlen, nualt, pathlen, node, prefix, is_shadow_path)) + frontier.append((k, i+spanlen, input_match, nualt, pathlen, node, prefix, is_shadow_path)) continue phrase = prefix + (word_id,) @@ -1098,18 +1102,19 @@ cdef class HieroCachingRuleFactory: fphrases[f][e][als].append(loc) for f, elist in fphrases.iteritems(): for e, alslist in elist.iteritems(): - alignment = max(alslist.iteritems(), key=lambda x: len(x[1]))[0] - locs = tuple(itertools.chain(alslist.itervalues())) + alignment, max_locs = max(alslist.iteritems(), key=lambda x: len(x[1])) + locs = tuple(itertools.chain.from_iterable(alslist.itervalues())) count = len(locs) scores = self.scorer.score(FeatureContext( f, e, count, fcount[f], num_samples, - (i,k), locs, fwords - )) + (k,i+spanlen), locs, input_match, + fwords, self.fda, self.eda, + meta)) yield Rule(self.category, f, e, scores, alignment) if len(phrase) < self.max_length and i+spanlen < len(fwords) and pathlen+1 <= self.max_initial_size: for alt_id in range(len(fwords[i+spanlen])): - new_frontier.append((k, i+spanlen, alt_id, pathlen + 1, node, phrase, is_shadow_path)) + new_frontier.append((k, i+spanlen, input_match, alt_id, pathlen + 1, node, phrase, is_shadow_path)) num_subpatterns = arity if not is_shadow_path: num_subpatterns = num_subpatterns + 1 @@ -1126,7 +1131,7 @@ cdef class HieroCachingRuleFactory: nodes_isteps_away_buffer[key] = frontier_nodes for (i, alt, pathlen) in frontier_nodes: - new_frontier.append((k, i, alt, pathlen, xnode, phrase +(xcat,), is_shadow_path)) + new_frontier.append((k, i, input_match + (i,), alt, pathlen, xnode, phrase +(xcat,), is_shadow_path)) frontier = new_frontier stop_time = monitor_cpu() |