summaryrefslogtreecommitdiff
path: root/python/src/sa/rulefactory.pxi
diff options
context:
space:
mode:
Diffstat (limited to 'python/src/sa/rulefactory.pxi')
-rw-r--r--python/src/sa/rulefactory.pxi101
1 files changed, 96 insertions, 5 deletions
diff --git a/python/src/sa/rulefactory.pxi b/python/src/sa/rulefactory.pxi
index b10d25dd..b95c23df 100644
--- a/python/src/sa/rulefactory.pxi
+++ b/python/src/sa/rulefactory.pxi
@@ -23,7 +23,14 @@ FeatureContext = namedtuple('FeatureContext',
'test_sentence',
'f_text',
'e_text',
- 'meta'
+ 'meta',
+ 'online'
+ ])
+
+OnlineFeatureContext = namedtuple('OnlineFeatureContext',
+ ['fcount',
+ 'paircount',
+ 'bilex'
])
cdef int PRECOMPUTE = 0
@@ -264,6 +271,7 @@ cdef class HieroCachingRuleFactory:
cdef IntList findexes
cdef IntList findexes1
+ cdef bint online
cdef phrases_f
cdef phrases_e
cdef phrases_fe
@@ -381,6 +389,9 @@ cdef class HieroCachingRuleFactory:
# Online stats
+ # True after data is added
+ self.online = False
+
# Phrase counts
self.phrases_f = defaultdict(int)
self.phrases_e = defaultdict(int)
@@ -969,6 +980,11 @@ cdef class HieroCachingRuleFactory:
hit = 0
reachable_buffer = {}
+ # Phrase pairs processed by suffix array extractor. Do not re-extract
+ # during online extraction. This is probably the hackiest part of
+ # online grammar extraction.
+ seen_phrases = set()
+
# Do not cache between sentences
self.rules.root = ExtendedTrieNode(phrase_location=PhraseLocation())
@@ -1124,7 +1140,12 @@ cdef class HieroCachingRuleFactory:
f, e, count, fcount[f], num_samples,
(k,i+spanlen), locs, input_match,
fwords, self.fda, self.eda,
- meta))
+ meta,
+ # Include online stats. None if none.
+ self.online_ctx_lookup(f, e)))
+ # Phrase pair processed
+ if self.online:
+ seen_phrases.add((f, e))
yield Rule(self.category, f, e, scores, alignment)
if len(phrase) < self.max_length and i+spanlen < len(fwords) and pathlen+1 <= self.max_initial_size:
@@ -1148,7 +1169,20 @@ cdef class HieroCachingRuleFactory:
for (i, alt, pathlen) in frontier_nodes:
new_frontier.append((k, i, input_match + (i,), alt, pathlen, xnode, phrase +(xcat,), is_shadow_path))
frontier = new_frontier
-
+
+ # Online rule extraction and scoring
+ if self.online:
+ f_syms = tuple(word[0][0] for word in fwords)
+ for (f, e, spanlen) in self.online_match(f_syms, seen_phrases):
+ scores = self.scorer.score(FeatureContext(
+ f, e, 0, 0, 0,
+ spanlen, None, None,
+ fwords, self.fda, self.eda,
+ meta,
+ self.online_ctx_lookup(f, e)))
+ alignment = self.phrases_al[f][e]
+ yield Rule(self.category, f, e, scores, alignment)
+
stop_time = monitor_cpu()
logger.info("Total time for rule lookup, extraction, and scoring = %f seconds", (stop_time - start_time))
gc.collect()
@@ -1828,6 +1862,8 @@ cdef class HieroCachingRuleFactory:
# (Extract rules, update counts)
def add_instance(self, f_words, e_words, alignment):
+ self.online = True
+
# Rules extracted from this instance
# Track span of lexical items (terminals) to make
# sure we don't extract the same rule for the same
@@ -1974,7 +2010,7 @@ cdef class HieroCachingRuleFactory:
for rule in rules:
(f_ph, e_ph, al) = rule[:3]
self.phrases_f[f_ph] += 1
- self.phrases_e[e_ph] += 1
+ self.phrases_e[e_ph] += 1
self.phrases_fe[f_ph][e_ph] += 1
if not self.phrases_al[f_ph][e_ph]:
self.phrases_al[f_ph][e_ph] = al
@@ -1987,7 +2023,6 @@ cdef class HieroCachingRuleFactory:
for e_w in e_words:
self.bilex_fe[f_w][e_w] += 1
-
# Create a rule from source, target, non-terminals, and alignments
def form_rule(self, f_i, e_i, f_span, e_span, nt, al):
@@ -2083,7 +2118,63 @@ cdef class HieroCachingRuleFactory:
for ph in self.phrases_fe:
for ph2 in self.phrases_fe[ph]:
logger.info(self.fmt_rule(str(ph), str(ph2), self.phrases_al[ph][ph2]) + ' ||| ' + str(self.phrases_fe[ph][ph2]))
+
+ # Lookup online stats for phrase pair (f, e). Return None if no match.
+ # IMPORTANT: use get() to avoid adding items to defaultdict
+ def online_ctx_lookup(self, f, e):
+ if self.online:
+ fcount = self.phrases_f.get(f, 0)
+ d = self.phrases_fe.get(f, None)
+ paircount = d.get(e, 0) if d else 0
+ if paircount > 0:
+ print 'Online support:', f, '|||', e
+ return OnlineFeatureContext(fcount, paircount, self.bilex_fe)
+ return None
+
+ # Match source words against online data.
+ # Return (fphrase, ephrase, length)
+ def online_match(self, f_words, seen_phrases):
+ f_len = len(f_words)
+ matches = {} # (f, e) = len
+
+ def extract(f_i, f_j, wc, ntc, syms):
+ # Phrase extraction limits
+ if f_j > (f_len - 1) or (f_j - f_i) + 1 > self.max_initial_size:
+ return
+ # Extend with word
+ if wc + ntc < self.max_length:
+ syms.append(f_words[f_j])
+ f = Phrase(syms)
+ for e in self.phrases_fe[f]:
+ if (f, e) not in seen_phrases:
+ matches[(f, e)] = (f_j - f_i) + 1
+ extract(f_i, f_j + 1, wc + 1, ntc, syms)
+ syms.pop()
+ # Extend with existing non-terminal
+ if syms and sym_isvar(syms[-1]):
+ # Don't re-extract the same phrase
+ extract(f_i, f_j + 1, wc, ntc, syms)
+ # Extend with new non-terminal
+ if wc + ntc < self.max_length:
+ if not syms or (ntc < self.max_nonterminals and not sym_isvar(syms[-1])):
+ syms.append(sym_setindex(self.category, ntc))
+ f = Phrase(syms)
+ if wc > 0:
+ for e in self.phrases_fe[f]:
+ if (f, e) not in seen_phrases:
+ matches[(f, e)] = (f_j - f_i) + 1
+ extract(f_i, f_j + 1, wc, ntc + 1, syms)
+ syms.pop()
+
+ # Try to extract phrases from every f index
+ for f_i from 0 <= f_i < f_len:
+ extract(f_i, f_i, 0, 0, [])
+
+ for line in sorted(' ||| '.join((str(f), str(e))) for (f, e) in matches):
+ print 'Online new:', line
+ return ((f, e, matches[(f, e)]) for (f, e) in matches)
+
# Spans are _inclusive_ on both ends [i, j]
def span_check(vec, i, j):
k = i