summaryrefslogtreecommitdiff
path: root/python/src/sa/rulefactory.pxi
diff options
context:
space:
mode:
Diffstat (limited to 'python/src/sa/rulefactory.pxi')
-rw-r--r--python/src/sa/rulefactory.pxi87
1 files changed, 53 insertions, 34 deletions
diff --git a/python/src/sa/rulefactory.pxi b/python/src/sa/rulefactory.pxi
index b95c23df..88f77a8d 100644
--- a/python/src/sa/rulefactory.pxi
+++ b/python/src/sa/rulefactory.pxi
@@ -29,6 +29,7 @@ FeatureContext = namedtuple('FeatureContext',
OnlineFeatureContext = namedtuple('OnlineFeatureContext',
['fcount',
+ 'fsample_count',
'paircount',
'bilex'
])
@@ -272,6 +273,7 @@ cdef class HieroCachingRuleFactory:
cdef IntList findexes1
cdef bint online
+ cdef samples_f
cdef phrases_f
cdef phrases_e
cdef phrases_fe
@@ -392,6 +394,9 @@ cdef class HieroCachingRuleFactory:
# True after data is added
self.online = False
+ # Keep track of everything that can be sampled:
+ self.samples_f = defaultdict(int)
+
# Phrase counts
self.phrases_f = defaultdict(int)
self.phrases_e = defaultdict(int)
@@ -1173,15 +1178,24 @@ cdef class HieroCachingRuleFactory:
# Online rule extraction and scoring
if self.online:
f_syms = tuple(word[0][0] for word in fwords)
- for (f, e, spanlen) in self.online_match(f_syms, seen_phrases):
- scores = self.scorer.score(FeatureContext(
- f, e, 0, 0, 0,
- spanlen, None, None,
- fwords, self.fda, self.eda,
- meta,
- self.online_ctx_lookup(f, e)))
- alignment = self.phrases_al[f][e]
- yield Rule(self.category, f, e, scores, alignment)
+ for (f, lex_i, lex_j) in self.get_f_phrases(f_syms):
+ spanlen = (lex_j - lex_i) + 1
+ if not sym_isvar(f[0]):
+ spanlen += 1
+ if not sym_isvar(f[1]):
+ spanlen += 1
+ for e in self.phrases_fe.get(f, ()):
+ if (f, e) not in seen_phrases:
+ # Don't add multiple instances of the same phrase here
+ seen_phrases.add((f, e))
+ scores = self.scorer.score(FeatureContext(
+ f, e, 0, 0, 0,
+ spanlen, None, None,
+ fwords, self.fda, self.eda,
+ meta,
+ self.online_ctx_lookup(f, e)))
+ alignment = self.phrases_al[f][e]
+ yield Rule(self.category, f, e, scores, alignment)
stop_time = monitor_cpu()
logger.info("Total time for rule lookup, extraction, and scoring = %f seconds", (stop_time - start_time))
@@ -2006,6 +2020,12 @@ cdef class HieroCachingRuleFactory:
continue
extract(f_i, f_i, f_len + 1, -1, f_i, 0, [], [], False)
+ # Update possible phrases (samples)
+ # This could be more efficiently integrated with extraction
+ # at the cost of readability
+ for (f, lex_i, lex_j) in self.get_f_phrases(f_words):
+ self.samples_f[f] += 1
+
# Update phrase counts
for rule in rules:
(f_ph, e_ph, al) = rule[:3]
@@ -2115,30 +2135,33 @@ cdef class HieroCachingRuleFactory:
for ph in self.phrases_e:
logger.info(str(ph) + ' ||| ' + str(self.phrases_e[ph]))
logger.info('FE')
+ self.dump_online_rules()
+
+ def dump_online_rules(self):
for ph in self.phrases_fe:
for ph2 in self.phrases_fe[ph]:
logger.info(self.fmt_rule(str(ph), str(ph2), self.phrases_al[ph][ph2]) + ' ||| ' + str(self.phrases_fe[ph][ph2]))
-
+
# Lookup online stats for phrase pair (f, e). Return None if no match.
# IMPORTANT: use get() to avoid adding items to defaultdict
def online_ctx_lookup(self, f, e):
if self.online:
fcount = self.phrases_f.get(f, 0)
+ fsample_count = self.samples_f.get(f, 0)
d = self.phrases_fe.get(f, None)
paircount = d.get(e, 0) if d else 0
- if paircount > 0:
- print 'Online support:', f, '|||', e
- return OnlineFeatureContext(fcount, paircount, self.bilex_fe)
+ return OnlineFeatureContext(fcount, fsample_count, paircount, self.bilex_fe)
return None
- # Match source words against online data.
- # Return (fphrase, ephrase, length)
- def online_match(self, f_words, seen_phrases):
-
+ # Find all phrases that we might try to extract
+ # (Used for EGivenFCoherent)
+ # Return set of (fphrase, lex_i, lex_j)
+ def get_f_phrases(self, f_words):
+
f_len = len(f_words)
- matches = {} # (f, e) = len
+ phrases = set() # (fphrase, lex_i, lex_j)
- def extract(f_i, f_j, wc, ntc, syms):
+ def extract(f_i, f_j, lex_i, lex_j, wc, ntc, syms):
# Phrase extraction limits
if f_j > (f_len - 1) or (f_j - f_i) + 1 > self.max_initial_size:
return
@@ -2146,34 +2169,30 @@ cdef class HieroCachingRuleFactory:
if wc + ntc < self.max_length:
syms.append(f_words[f_j])
f = Phrase(syms)
- for e in self.phrases_fe[f]:
- if (f, e) not in seen_phrases:
- matches[(f, e)] = (f_j - f_i) + 1
- extract(f_i, f_j + 1, wc + 1, ntc, syms)
+ new_lex_i = min(lex_i, f_j)
+ new_lex_j = max(lex_j, f_j)
+ phrases.add((f, new_lex_i, new_lex_j))
+ extract(f_i, f_j + 1, new_lex_i, new_lex_j, wc + 1, ntc, syms)
syms.pop()
# Extend with existing non-terminal
if syms and sym_isvar(syms[-1]):
# Don't re-extract the same phrase
- extract(f_i, f_j + 1, wc, ntc, syms)
+ extract(f_i, f_j + 1, lex_i, lex_j, wc, ntc, syms)
# Extend with new non-terminal
if wc + ntc < self.max_length:
if not syms or (ntc < self.max_nonterminals and not sym_isvar(syms[-1])):
- syms.append(sym_setindex(self.category, ntc))
+ syms.append(sym_setindex(self.category, ntc + 1))
f = Phrase(syms)
if wc > 0:
- for e in self.phrases_fe[f]:
- if (f, e) not in seen_phrases:
- matches[(f, e)] = (f_j - f_i) + 1
- extract(f_i, f_j + 1, wc, ntc + 1, syms)
+ phrases.add((f, lex_i, lex_j))
+ extract(f_i, f_j + 1, lex_i, lex_j, wc, ntc + 1, syms)
syms.pop()
# Try to extract phrases from every f index
for f_i from 0 <= f_i < f_len:
- extract(f_i, f_i, 0, 0, [])
-
- for line in sorted(' ||| '.join((str(f), str(e))) for (f, e) in matches):
- print 'Online new:', line
- return ((f, e, matches[(f, e)]) for (f, e) in matches)
+ extract(f_i, f_i, f_len, -1, 0, 0, [])
+
+ return phrases
# Spans are _inclusive_ on both ends [i, j]
def span_check(vec, i, j):