summaryrefslogtreecommitdiff
path: root/python/src/sa/rulefactory.pxi
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2012-09-20 21:51:31 -0400
committerChris Dyer <cdyer@cs.cmu.edu>2012-09-20 21:51:31 -0400
commit214f4714d95cb27d31ff976a11dec8a0c0eb438d (patch)
tree0970ab16db5260f128a65d60f1dc60caf831efc5 /python/src/sa/rulefactory.pxi
parent17d085055e24bf189a3b378af77e1071922893cc (diff)
parente26edac51cc47b2b2322fbb870308daa708cec8c (diff)
Merge branch 'master' of https://github.com/redpony/cdec
Diffstat (limited to 'python/src/sa/rulefactory.pxi')
-rw-r--r--python/src/sa/rulefactory.pxi29
1 files changed, 17 insertions, 12 deletions
diff --git a/python/src/sa/rulefactory.pxi b/python/src/sa/rulefactory.pxi
index 287b9a67..5f6558b3 100644
--- a/python/src/sa/rulefactory.pxi
+++ b/python/src/sa/rulefactory.pxi
@@ -19,7 +19,11 @@ FeatureContext = namedtuple('FeatureContext',
'fsample_count',
'input_span',
'matches',
- 'test_sentence'
+ 'input_match',
+ 'test_sentence',
+ 'f_text',
+ 'e_text',
+ 'meta'
])
cdef int PRECOMPUTE = 0
@@ -932,7 +936,7 @@ cdef class HieroCachingRuleFactory:
candidate.append([next_id,curr[1]+jump])
return sorted(result);
- def input(self, fwords):
+ def input(self, fwords, meta):
'''When this function is called on the RuleFactory,
it looks up all of the rules that can be used to translate
the input sentence'''
@@ -957,7 +961,7 @@ cdef class HieroCachingRuleFactory:
for i in range(len(fwords)):
for alt in range(0, len(fwords[i])):
if fwords[i][alt][0] != EPSILON:
- frontier.append((i, i, alt, 0, self.rules.root, (), False))
+ frontier.append((i, i, (i,), alt, 0, self.rules.root, (), False))
xroot = None
x1 = sym_setindex(self.category, 1)
@@ -970,7 +974,7 @@ cdef class HieroCachingRuleFactory:
for i in range(self.min_gap_size, len(fwords)):
for alt in range(0, len(fwords[i])):
if fwords[i][alt][0] != EPSILON:
- frontier.append((i-self.min_gap_size, i, alt, self.min_gap_size, xroot, (x1,), True))
+ frontier.append((i-self.min_gap_size, i, (i,), alt, self.min_gap_size, xroot, (x1,), True))
next_states = []
for i in range(len(fwords)):
@@ -978,7 +982,7 @@ cdef class HieroCachingRuleFactory:
while len(frontier) > 0:
new_frontier = []
- for k, i, alt, pathlen, node, prefix, is_shadow_path in frontier:
+ for k, i, input_match, alt, pathlen, node, prefix, is_shadow_path in frontier:
word_id = fwords[i][alt][0]
spanlen = fwords[i][alt][2]
# TODO get rid of k -- pathlen is replacing it
@@ -987,7 +991,7 @@ cdef class HieroCachingRuleFactory:
if i+spanlen >= len(fwords):
continue
for nualt in range(0,len(fwords[i+spanlen])):
- frontier.append((k, i+spanlen, nualt, pathlen, node, prefix, is_shadow_path))
+ frontier.append((k, i+spanlen, input_match, nualt, pathlen, node, prefix, is_shadow_path))
continue
phrase = prefix + (word_id,)
@@ -1098,18 +1102,19 @@ cdef class HieroCachingRuleFactory:
fphrases[f][e][als].append(loc)
for f, elist in fphrases.iteritems():
for e, alslist in elist.iteritems():
- alignment = max(alslist.iteritems(), key=lambda x: len(x[1]))[0]
- locs = tuple(itertools.chain(alslist.itervalues()))
+ alignment, max_locs = max(alslist.iteritems(), key=lambda x: len(x[1]))
+ locs = tuple(itertools.chain.from_iterable(alslist.itervalues()))
count = len(locs)
scores = self.scorer.score(FeatureContext(
f, e, count, fcount[f], num_samples,
- (i,k), locs, fwords
- ))
+ (k,i+spanlen), locs, input_match,
+ fwords, self.fda, self.eda,
+ meta))
yield Rule(self.category, f, e, scores, alignment)
if len(phrase) < self.max_length and i+spanlen < len(fwords) and pathlen+1 <= self.max_initial_size:
for alt_id in range(len(fwords[i+spanlen])):
- new_frontier.append((k, i+spanlen, alt_id, pathlen + 1, node, phrase, is_shadow_path))
+ new_frontier.append((k, i+spanlen, input_match, alt_id, pathlen + 1, node, phrase, is_shadow_path))
num_subpatterns = arity
if not is_shadow_path:
num_subpatterns = num_subpatterns + 1
@@ -1126,7 +1131,7 @@ cdef class HieroCachingRuleFactory:
nodes_isteps_away_buffer[key] = frontier_nodes
for (i, alt, pathlen) in frontier_nodes:
- new_frontier.append((k, i, alt, pathlen, xnode, phrase +(xcat,), is_shadow_path))
+ new_frontier.append((k, i, input_match + (i,), alt, pathlen, xnode, phrase +(xcat,), is_shadow_path))
frontier = new_frontier
stop_time = monitor_cpu()