summaryrefslogtreecommitdiff
path: root/python/cdec/sa/rulefactory.pxi
diff options
context:
space:
mode:
authorMichael Denkowski <mdenkows@cs.cmu.edu>2013-09-23 08:21:58 -0700
committerMichael Denkowski <mdenkows@cs.cmu.edu>2013-09-23 08:21:58 -0700
commit6efdb6758dc2d204d203bf3dcdb0ea14b6e2bbd5 (patch)
tree41a5223781397a7c9bcab70054be0495c160101f /python/cdec/sa/rulefactory.pxi
parentee384acf34de0d7613a70b81c674d607a3bd782a (diff)
One extractor, multiple online contexts.
Diffstat (limited to 'python/cdec/sa/rulefactory.pxi')
-rw-r--r--python/cdec/sa/rulefactory.pxi122
1 files changed, 51 insertions, 71 deletions
diff --git a/python/cdec/sa/rulefactory.pxi b/python/cdec/sa/rulefactory.pxi
index 10bb9737..a96b01e7 100644
--- a/python/cdec/sa/rulefactory.pxi
+++ b/python/cdec/sa/rulefactory.pxi
@@ -36,6 +36,31 @@ OnlineFeatureContext = namedtuple('OnlineFeatureContext',
'bilex_fe'
])
+cdef class OnlineStats:
+ cdef public samples_f
+ cdef public phrases_f
+ cdef public phrases_e
+ cdef public phrases_fe
+ cdef public phrases_al
+ cdef public bilex_f
+ cdef public bilex_e
+ cdef public bilex_fe
+
+ def __cinit__(self):
+ # Keep track of everything that can be sampled:
+ self.samples_f = defaultdict(int)
+
+ # Phrase counts
+ self.phrases_f = defaultdict(int)
+ self.phrases_e = defaultdict(int)
+ self.phrases_fe = defaultdict(lambda: defaultdict(int))
+ self.phrases_al = defaultdict(lambda: defaultdict(tuple))
+
+ # Bilexical counts
+ self.bilex_f = defaultdict(int)
+ self.bilex_e = defaultdict(int)
+ self.bilex_fe = defaultdict(lambda: defaultdict(int))
+
cdef int PRECOMPUTE = 0
cdef int MERGE = 1
cdef int BAEZA_YATES = 2
@@ -276,14 +301,7 @@ cdef class HieroCachingRuleFactory:
cdef IntList findexes1
cdef bint online
- cdef samples_f
- cdef phrases_f
- cdef phrases_e
- cdef phrases_fe
- cdef phrases_al
- cdef bilex_f
- cdef bilex_e
- cdef bilex_fe
+ cdef online_stats
def __cinit__(self,
# compiled alignment object (REQUIRED)
@@ -396,21 +414,8 @@ cdef class HieroCachingRuleFactory:
# True after data is added
self.online = False
+ self.online_stats = defaultdict(OnlineStats)
- # Keep track of everything that can be sampled:
- self.samples_f = defaultdict(int)
-
- # Phrase counts
- self.phrases_f = defaultdict(int)
- self.phrases_e = defaultdict(int)
- self.phrases_fe = defaultdict(lambda: defaultdict(int))
- self.phrases_al = defaultdict(lambda: defaultdict(tuple))
-
- # Bilexical counts
- self.bilex_f = defaultdict(int)
- self.bilex_e = defaultdict(int)
- self.bilex_fe = defaultdict(lambda: defaultdict(int))
-
def configure(self, SuffixArray fsarray, DataArray edarray,
Sampler sampler, Scorer scorer):
'''This gives the RuleFactory access to the Context object.
@@ -970,7 +975,7 @@ cdef class HieroCachingRuleFactory:
candidate.append([next_id,curr[1]+jump])
return sorted(result);
- def input(self, fwords, meta):
+ def input(self, fwords, meta, ctx_name=None):
'''When this function is called on the RuleFactory,
it looks up all of the rules that can be used to translate
the input sentence'''
@@ -1154,7 +1159,7 @@ cdef class HieroCachingRuleFactory:
fwords, self.fda, self.eda,
meta,
# Include online stats. None if none.
- self.online_ctx_lookup(f, e)))
+ self.online_ctx_lookup(f, e, ctx_name)))
# Phrase pair processed
if self.online:
seen_phrases.add((f, e))
@@ -1184,6 +1189,7 @@ cdef class HieroCachingRuleFactory:
# Online rule extraction and scoring
if self.online:
+ stats = self.online_stats[ctx_name]
f_syms = tuple(word[0][0] for word in fwords)
for f, lex_i, lex_j in self.get_f_phrases(f_syms):
spanlen = (lex_j - lex_i) + 1
@@ -1191,7 +1197,7 @@ cdef class HieroCachingRuleFactory:
spanlen += 1
if not sym_isvar(f[1]):
spanlen += 1
- for e in self.phrases_fe.get(f, ()):
+ for e in stats.phrases_fe.get(f, ()):
if (f, e) not in seen_phrases:
# Don't add multiple instances of the same phrase here
seen_phrases.add((f, e))
@@ -1200,8 +1206,8 @@ cdef class HieroCachingRuleFactory:
spanlen, None, None,
fwords, self.fda, self.eda,
meta,
- self.online_ctx_lookup(f, e)))
- alignment = self.phrases_al[f][e]
+ self.online_ctx_lookup(f, e, ctx_name)))
+ alignment = stats.phrases_al[f][e]
yield Rule(self.category, f, e, scores, alignment)
stop_time = monitor_cpu()
@@ -1882,7 +1888,7 @@ cdef class HieroCachingRuleFactory:
# Aggregate stats from a training instance
# (Extract rules, update counts)
- def add_instance(self, f_words, e_words, alignment):
+ def add_instance(self, f_words, e_words, alignment, ctx_name=None):
self.online = True
@@ -2028,28 +2034,30 @@ cdef class HieroCachingRuleFactory:
continue
extract(f_i, f_i, f_len + 1, -1, f_i, 0, [], [], False)
+ stats = self.online_stats[ctx_name]
+
# Update possible phrases (samples)
# This could be more efficiently integrated with extraction
# at the cost of readability
for f, lex_i, lex_j in self.get_f_phrases(f_words):
- self.samples_f[f] += 1
+ stats.samples_f[f] += 1
# Update phrase counts
for rule in rules:
(f_ph, e_ph, al) = rule[:3]
- self.phrases_f[f_ph] += 1
- self.phrases_e[e_ph] += 1
- self.phrases_fe[f_ph][e_ph] += 1
- if not self.phrases_al[f_ph][e_ph]:
- self.phrases_al[f_ph][e_ph] = al
+ stats.phrases_f[f_ph] += 1
+ stats.phrases_e[e_ph] += 1
+ stats.phrases_fe[f_ph][e_ph] += 1
+ if not stats.phrases_al[f_ph][e_ph]:
+ stats.phrases_al[f_ph][e_ph] = al
# Update Bilexical counts
for e_w in e_words:
- self.bilex_e[e_w] += 1
+ stats.bilex_e[e_w] += 1
for f_w in f_words:
- self.bilex_f[f_w] += 1
+ stats.bilex_f[f_w] += 1
for e_w in e_words:
- self.bilex_fe[f_w][e_w] += 1
+ stats.bilex_fe[f_w][e_w] += 1
# Create a rule from source, target, non-terminals, and alignments
def form_rule(self, f_i, e_i, f_span, e_span, nt, al):
@@ -2121,44 +2129,16 @@ cdef class HieroCachingRuleFactory:
a_str = ' '.join('{0}-{1}'.format(*self.alignment.unlink(packed)) for packed in a)
return '[X] ||| {0} ||| {1} ||| {2}'.format(f, e, a_str)
- # Debugging
- def dump_online_stats(self):
- logger.info('------------------------------')
- logger.info(' Online Stats ')
- logger.info('------------------------------')
- logger.info('f')
- for w in self.bilex_f:
- logger.info(sym_tostring(w) + ' : ' + str(self.bilex_f[w]))
- logger.info('e')
- for w in self.bilex_e:
- logger.info(sym_tostring(w) + ' : ' + str(self.bilex_e[w]))
- logger.info('fe')
- for w in self.bilex_fe:
- for w2 in self.bilex_fe[w]:
- logger.info(sym_tostring(w) + ' : ' + sym_tostring(w2) + ' : ' + str(self.bilex_fe[w][w2]))
- logger.info('F')
- for ph in self.phrases_f:
- logger.info(str(ph) + ' ||| ' + str(self.phrases_f[ph]))
- logger.info('E')
- for ph in self.phrases_e:
- logger.info(str(ph) + ' ||| ' + str(self.phrases_e[ph]))
- logger.info('FE')
- self.dump_online_rules()
-
- def dump_online_rules(self):
- for ph in self.phrases_fe:
- for ph2 in self.phrases_fe[ph]:
- logger.info(self.fmt_rule(str(ph), str(ph2), self.phrases_al[ph][ph2]) + ' ||| ' + str(self.phrases_fe[ph][ph2]))
-
# Lookup online stats for phrase pair (f, e). Return None if no match.
# IMPORTANT: use get() to avoid adding items to defaultdict
- def online_ctx_lookup(self, f, e):
+ def online_ctx_lookup(self, f, e, ctx_name=None):
if self.online:
- fcount = self.phrases_f.get(f, 0)
- fsample_count = self.samples_f.get(f, 0)
- d = self.phrases_fe.get(f, None)
+ stats = self.online_stats[ctx_name]
+ fcount = stats.phrases_f.get(f, 0)
+ fsample_count = stats.samples_f.get(f, 0)
+ d = stats.phrases_fe.get(f, None)
paircount = d.get(e, 0) if d else 0
- return OnlineFeatureContext(fcount, fsample_count, paircount, self.bilex_f, self.bilex_e, self.bilex_fe)
+ return OnlineFeatureContext(fcount, fsample_count, paircount, stats.bilex_f, stats.bilex_e, stats.bilex_fe)
return None
# Find all phrases that we might try to extract