summaryrefslogtreecommitdiff
path: root/python/src/sa/rulefactory.pxi
diff options
context:
space:
mode:
Diffstat (limited to 'python/src/sa/rulefactory.pxi')
-rw-r--r--python/src/sa/rulefactory.pxi40
1 files changed, 26 insertions, 14 deletions
diff --git a/python/src/sa/rulefactory.pxi b/python/src/sa/rulefactory.pxi
index afd83785..69cadac9 100644
--- a/python/src/sa/rulefactory.pxi
+++ b/python/src/sa/rulefactory.pxi
@@ -3,12 +3,24 @@
# Much faster than the Python numbers reported there.
# Note to reader: this code is closer to C than Python
import gc
+import itertools
from libc.stdlib cimport malloc, realloc, free
from libc.string cimport memset, memcpy
from libc.math cimport fmod, ceil, floor, log
-from collections import defaultdict
+from collections import defaultdict, Counter, namedtuple
+
+FeatureContext = namedtuple("FeatureContext",
+ ["fphrase",
+ "ephrase",
+ "paircount",
+ "fcount",
+ "fsample_count",
+ "input_span",
+ "matches",
+ "test_sentence"
+ ])
cdef int PRECOMPUTE = 0
cdef int MERGE = 1
@@ -1070,29 +1082,29 @@ cdef class HieroCachingRuleFactory:
extract = []
assign_matching(&matching, sample.arr, j, num_subpatterns, self.fda.sent_id.arr)
+ loc = tuple(sample[j:j+num_subpatterns])
extract = self.extract(hiero_phrase, &matching, chunklen.arr, num_subpatterns)
- extracts.extend(extract)
+ extracts.extend([(e, loc) for e in extract])
j = j + num_subpatterns
num_samples = sample.len/num_subpatterns
extract_stop = monitor_cpu()
self.extract_time = self.extract_time + extract_stop - extract_start
if len(extracts) > 0:
- fcount = defaultdict(int)
- fphrases = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
- for f, e, count, als in extracts:
+ fcount = Counter()
+ fphrases = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
+ for (f, e, count, als), loc in extracts:
fcount[f] += count
- fphrases[f][e][als] += count
+ fphrases[f][e][als].append(loc)
for f, elist in fphrases.iteritems():
for e, alslist in elist.iteritems():
- alignment = None
- count = 0
- for als, currcount in alslist.iteritems():
- if currcount > count:
- alignment = als
- count = currcount
- scores = self.scorer.score(f, e, count,
- fcount[f], num_samples)
+ alignment = max(alslist.iteritems(), key=lambda x: len(x[1]))[0]
+ locs = tuple(itertools.chain(alslist.itervalues()))
+ count = len(locs)
+ scores = self.scorer.score(FeatureContext(
+ f, e, count, fcount[f], num_samples,
+ (i,k), locs, fwords
+ ))
yield Rule(self.category, f, e, scores, alignment)
if len(phrase) < self.max_length and i+spanlen < len(fwords) and pathlen+1 <= self.max_initial_size: