summary | refs | log | tree | commit | diff
path: root/python/src/sa/rulefactory.pxi
diff options
context:
space:
mode:
Diffstat (limited to 'python/src/sa/rulefactory.pxi')
-rw-r--r-- python/src/sa/rulefactory.pxi 38
1 files changed, 17 insertions, 21 deletions
diff --git a/python/src/sa/rulefactory.pxi b/python/src/sa/rulefactory.pxi
index 1c8d25a4..34a002c5 100644
--- a/python/src/sa/rulefactory.pxi
+++ b/python/src/sa/rulefactory.pxi
@@ -8,6 +8,8 @@ from libc.stdlib cimport malloc, realloc, free
from libc.string cimport memset, memcpy
from libc.math cimport fmod, ceil, floor, log
+from collections import defaultdict, Counter
+
cdef int PRECOMPUTE = 0
cdef int MERGE = 1
cdef int BAEZA_YATES = 2
@@ -73,8 +75,7 @@ cdef class PhraseLocation:
self.arr_high = arr_high
self.arr = arr
self.num_subpatterns = num_subpatterns
-
-
+
cdef class Sampler:
'''A Sampler implements a logic for choosing
@@ -208,6 +209,7 @@ cdef class HieroCachingRuleFactory:
cdef TrieTable rules
cdef Sampler sampler
+ cdef Scorer scorer
cdef int max_chunks
cdef int max_target_chunks
@@ -359,7 +361,8 @@ cdef class HieroCachingRuleFactory:
self.findexes = IntList(initial_len=10)
self.findexes1 = IntList(initial_len=10)
- def configure(self, SuffixArray fsarray, DataArray edarray, Sampler sampler):
+ def configure(self, SuffixArray fsarray, DataArray edarray,
+ Sampler sampler, Scorer scorer):
'''This gives the RuleFactory access to the Context object.
Here we also use it to precompute the most expensive intersections
in the corpus quickly.'''
@@ -370,6 +373,7 @@ cdef class HieroCachingRuleFactory:
self.eid2symid = self.set_idmap(self.eda)
self.precompute()
self.sampler = sampler
+ self.scorer = scorer
cdef set_idmap(self, DataArray darray):
cdef int word_id, new_word_id, N
@@ -916,7 +920,7 @@ cdef class HieroCachingRuleFactory:
candidate.append([next_id,curr[1]+jump])
return sorted(result);
- def input(self, fwords, models):
+ def input(self, fwords):
'''When this function is called on the RuleFactory,
it looks up all of the rules that can be used to translate
the input sentence'''
@@ -1074,18 +1078,12 @@ cdef class HieroCachingRuleFactory:
extract_stop = monitor_cpu()
self.extract_time = self.extract_time + extract_stop - extract_start
if len(extracts) > 0:
- fphrases = {}
- fals = {}
- fcount = {}
+ fcount = Counter()
+ fphrases = defaultdict(lambda: defaultdict(Counter))
for f, e, count, als in extracts:
- fcount.setdefault(f, 0.0)
- fcount[f] = fcount[f] + count
- fphrases.setdefault(f, {})
- fphrases[f].setdefault(e, {})
- fphrases[f][e].setdefault(als,0.0)
- fphrases[f][e][als] = fphrases[f][e][als] + count
+ fcount[f] += count
+ fphrases[f][e][als] += count
for f, elist in fphrases.iteritems():
- f_margin = fcount[f]
for e, alslist in elist.iteritems():
alignment = None
count = 0
@@ -1093,11 +1091,9 @@ cdef class HieroCachingRuleFactory:
if currcount > count:
alignment = als
count = currcount
- scores = []
- for model in models:
- scores.append(model(f, e, count, fcount[f], num_samples))
- yield Rule(self.category, f, e,
- scores=scores, word_alignments=alignment)
+ scores = self.scorer.score(f, e, count,
+ fcount[f], num_samples)
+ yield Rule(self.category, f, e, scores, alignment)
if len(phrase) < self.max_length and i+spanlen < len(fwords) and pathlen+1 <= self.max_initial_size:
for alt_id in range(len(fwords[i+spanlen])):
@@ -1377,9 +1373,9 @@ cdef class HieroCachingRuleFactory:
free(e_gap_order)
return result
- cdef create_alignments(self, int* sent_links, int num_links, findexes, eindexes):
+ cdef IntList create_alignments(self, int* sent_links, int num_links, findexes, eindexes):
cdef unsigned i
- ret = IntList()
+ cdef IntList ret = IntList()
for i in range(len(findexes)):
s = findexes[i]
if (s<0):