summaryrefslogtreecommitdiff
path: root/python/cdec/sa/rulefactory.pxi
diff options
context:
space:
mode:
Diffstat (limited to 'python/cdec/sa/rulefactory.pxi')
-rw-r--r--python/cdec/sa/rulefactory.pxi51
1 files changed, 14 insertions, 37 deletions
diff --git a/python/cdec/sa/rulefactory.pxi b/python/cdec/sa/rulefactory.pxi
index 044a78c8..635cca10 100644
--- a/python/cdec/sa/rulefactory.pxi
+++ b/python/cdec/sa/rulefactory.pxi
@@ -31,10 +31,6 @@ OnlineFeatureContext = namedtuple('OnlineFeatureContext',
['fcount',
'fsample_count',
'paircount',
- 'bilex_f',
- 'bilex_e',
- 'bilex_fe',
- 'bilex_ef'
])
cdef class OnlineStats:
@@ -43,10 +39,6 @@ cdef class OnlineStats:
cdef public phrases_e
cdef public phrases_fe
cdef public phrases_al
- cdef public bilex_f
- cdef public bilex_e
- cdef public bilex_fe
- cdef public bilex_ef
def __cinit__(self):
# Keep track of everything that can be sampled:
@@ -58,12 +50,6 @@ cdef class OnlineStats:
self.phrases_fe = defaultdict(lambda: defaultdict(int))
self.phrases_al = defaultdict(lambda: defaultdict(tuple))
- # Bilexical counts
- self.bilex_f = defaultdict(int)
- self.bilex_e = defaultdict(int)
- self.bilex_fe = defaultdict(lambda: defaultdict(int))
- self.bilex_ef = defaultdict(lambda: defaultdict(int))
-
cdef int PRECOMPUTE = 0
cdef int MERGE = 1
cdef int BAEZA_YATES = 2
@@ -305,10 +291,13 @@ cdef class HieroCachingRuleFactory:
cdef bint online
cdef online_stats
+ cdef bilex
def __cinit__(self,
# compiled alignment object (REQUIRED)
Alignment alignment,
+ # bilexical dictionary if online
+ bilex=None,
# parameter for double-binary search; doesn't seem to matter much
float by_slack_factor=1.0,
# name of generic nonterminal used by Hiero
@@ -414,7 +403,10 @@ cdef class HieroCachingRuleFactory:
self.findexes1 = IntList(initial_len=10)
# Online stats
-
+
+ # None if not online
+ self.bilex = bilex
+
# True after data is added
self.online = False
self.online_stats = defaultdict(OnlineStats)
@@ -2053,28 +2045,13 @@ cdef class HieroCachingRuleFactory:
stats.phrases_fe[f_ph][e_ph] += 1
if not stats.phrases_al[f_ph][e_ph]:
stats.phrases_al[f_ph][e_ph] = al
-
- # Update Bilexical counts
- aligned_fe = [list() for _ in range(len(f_words))]
- aligned_ef = [list() for _ in range(len(e_words))]
- for (i, j) in alignment:
- aligned_fe[i].append(j)
- aligned_ef[j].append(i)
- for f_i in range(len(f_words)):
- e_i_aligned = aligned_fe[f_i]
- lc = len(e_i_aligned)
- if lc > 0:
- stats.bilex_f[f_words[f_i]] += 1
- for e_i in e_i_aligned:
- stats.bilex_fe[f_words[f_i]][e_words[e_i]] += (1.0) / lc
- for e_i in range(len(e_words)):
- f_i_aligned = aligned_ef[e_i]
- lc = len(f_i_aligned)
- if lc > 0:
- stats.bilex_e[e_words[e_i]] += 1
- for f_i in f_i_aligned:
- stats.bilex_ef[e_words[e_i]][f_words[f_i]] += (1.0) / lc
+ # Update bilexical dictionary (if exists)
+ if self.bilex:
+ self.bilex.update(f_words, e_words, alignment)
+ else:
+ logger.warning('No online bilexical dictionary specified, not updating lexical weights')
+
# Create a rule from source, target, non-terminals, and alignments
def form_rule(self, f_i, e_i, f_span, e_span, nt, al):
@@ -2154,7 +2131,7 @@ cdef class HieroCachingRuleFactory:
fsample_count = stats.samples_f.get(f, 0)
d = stats.phrases_fe.get(f, None)
paircount = d.get(e, 0) if d else 0
- return OnlineFeatureContext(fcount, fsample_count, paircount, stats.bilex_f, stats.bilex_e, stats.bilex_fe)
+ return OnlineFeatureContext(fcount, fsample_count, paircount)
return None
# Find all phrases that we might try to extract