summaryrefslogtreecommitdiff
path: root/python/cdec/sa/rulefactory.pxi
diff options
context:
space:
mode:
authorPatrick Simianer <p@simianer.de>2014-03-04 22:47:28 +0100
committerPatrick Simianer <p@simianer.de>2014-03-04 22:47:28 +0100
commit3eedf96b5a08b3e3414888d328c505814b84d8db (patch)
tree1bc4b6debf8be58b4180ceb4ae960463e93c4bdf /python/cdec/sa/rulefactory.pxi
parent739a8cd9a92ee10411e352e1677235a8c39ba8b3 (diff)
parent3b00351ff2047c226cde750fe67eae5b34388373 (diff)
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'python/cdec/sa/rulefactory.pxi')
-rw-r--r--python/cdec/sa/rulefactory.pxi31
1 files changed, 23 insertions, 8 deletions
diff --git a/python/cdec/sa/rulefactory.pxi b/python/cdec/sa/rulefactory.pxi
index 78a23196..044a78c8 100644
--- a/python/cdec/sa/rulefactory.pxi
+++ b/python/cdec/sa/rulefactory.pxi
@@ -33,7 +33,8 @@ OnlineFeatureContext = namedtuple('OnlineFeatureContext',
'paircount',
'bilex_f',
'bilex_e',
- 'bilex_fe'
+ 'bilex_fe',
+ 'bilex_ef'
])
cdef class OnlineStats:
@@ -45,6 +46,7 @@ cdef class OnlineStats:
cdef public bilex_f
cdef public bilex_e
cdef public bilex_fe
+ cdef public bilex_ef
def __cinit__(self):
# Keep track of everything that can be sampled:
@@ -60,6 +62,7 @@ cdef class OnlineStats:
self.bilex_f = defaultdict(int)
self.bilex_e = defaultdict(int)
self.bilex_fe = defaultdict(lambda: defaultdict(int))
+ self.bilex_ef = defaultdict(lambda: defaultdict(int))
cdef int PRECOMPUTE = 0
cdef int MERGE = 1
@@ -2052,13 +2055,25 @@ cdef class HieroCachingRuleFactory:
stats.phrases_al[f_ph][e_ph] = al
# Update Bilexical counts
- # TODO: use alignments instead of cooc
- for e_w in e_words:
- stats.bilex_e[e_w] += 1
- for f_w in f_words:
- stats.bilex_f[f_w] += 1
- for e_w in e_words:
- stats.bilex_fe[f_w][e_w] += 1
+ aligned_fe = [list() for _ in range(len(f_words))]
+ aligned_ef = [list() for _ in range(len(e_words))]
+ for (i, j) in alignment:
+ aligned_fe[i].append(j)
+ aligned_ef[j].append(i)
+ for f_i in range(len(f_words)):
+ e_i_aligned = aligned_fe[f_i]
+ lc = len(e_i_aligned)
+ if lc > 0:
+ stats.bilex_f[f_words[f_i]] += 1
+ for e_i in e_i_aligned:
+ stats.bilex_fe[f_words[f_i]][e_words[e_i]] += (1.0) / lc
+ for e_i in range(len(e_words)):
+ f_i_aligned = aligned_ef[e_i]
+ lc = len(f_i_aligned)
+ if lc > 0:
+ stats.bilex_e[e_words[e_i]] += 1
+ for f_i in f_i_aligned:
+ stats.bilex_ef[e_words[e_i]][f_words[f_i]] += (1.0) / lc
# Create a rule from source, target, non-terminals, and alignments
def form_rule(self, f_i, e_i, f_span, e_span, nt, al):