summary refs log tree commit diff
path: root/python/src/sa/rulefactory.pxi
diff options
context:
space:
mode:
Diffstat (limited to 'python/src/sa/rulefactory.pxi')
-rw-r--r-- python/src/sa/rulefactory.pxi 103
1 files changed, 65 insertions, 38 deletions
diff --git a/python/src/sa/rulefactory.pxi b/python/src/sa/rulefactory.pxi
index 81ea7960..c26f5c43 100644
--- a/python/src/sa/rulefactory.pxi
+++ b/python/src/sa/rulefactory.pxi
@@ -385,7 +385,7 @@ cdef class HieroCachingRuleFactory:
self.phrases_f = defaultdict(int)
self.phrases_e = defaultdict(int)
self.phrases_fe = defaultdict(lambda: defaultdict(int))
- self.phrases_al = defaultdict(dict)
+ self.phrases_al = defaultdict(lambda: defaultdict(tuple))
# Bilexical counts
self.bilex_f = defaultdict(int)
@@ -1820,14 +1820,14 @@ cdef class HieroCachingRuleFactory:
return extracts
- # Aggregate stats from a training instance:
- # Extract hierarchical phrase pairs
- # Update bilexical counts
+ #
+ # Online grammar extraction handling
+ #
+
+ # Aggregate stats from a training instance
+ # (Extract rules, update counts)
def add_instance(self, f_words, e_words, alignment):
- # Bilexical counts
- self.aggr_bilex(f_words, e_words)
-
# Rules extracted from this instance
rules = set()
@@ -1843,7 +1843,6 @@ cdef class HieroCachingRuleFactory:
al_span[f][1] = max(al_span[f][1], e)
# Target side word coverage
- # TODO: Does Cython do bit vectors?
cover = [0] * e_len
# Extract all possible hierarchical phrases starting at a source index
@@ -1956,8 +1955,6 @@ cdef class HieroCachingRuleFactory:
extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False)
nt.pop()
span_flip(cover, link_i, link_j)
- # TODO: try adding NT to start, end, both
- # check: one aligned word on boundary that is not part of a NT
# Try to extract phrases from every f index
f_i = 0
@@ -1970,19 +1967,31 @@ cdef class HieroCachingRuleFactory:
f_i += 1
for rule in sorted(rules):
- logger.info(rule)
-
- # Aggregate bilexical counts
- def aggr_bilex(self, f_words, e_words):
-
+ logger.info(self.fmt_rule(*rule))
+
+ # Update phrase counts
+ f_set = set()
+ e_set = set()
+ for (f_ph, e_ph, al) in rules:
+ f_set.add(f_ph)
+ e_set.add(e_ph)
+ self.phrases_fe[f_ph][e_ph] += 1
+ if not self.phrases_al[f_ph][e_ph]:
+ self.phrases_al[f_ph][e_ph] = al
+ for f_ph in f_set:
+ self.phrases_f[f_ph] += 1
+ for e_ph in e_set:
+ self.phrases_e[e_ph] += 1
+
+ # Update Bilexical counts
for e_w in e_words:
self.bilex_e[e_w] += 1
-
for f_w in f_words:
self.bilex_f[f_w] += 1
for e_w in e_words:
self.bilex_fe[f_w][e_w] += 1
+
# Create a rule from source, target, non-terminals, and alignments
def form_rules(self, f_i, e_i, f_span, e_span, nt, al):
@@ -1992,8 +2001,6 @@ cdef class HieroCachingRuleFactory:
nt_inv = sorted(nt, cmp=lambda x, y: cmp(x[3], y[3]))
- logger.info(nt)
-
f_sym = list(f_span[:])
off = f_i
for next_nt in nt:
@@ -2040,7 +2047,7 @@ cdef class HieroCachingRuleFactory:
i += 1
# Rule
- rules.append(fmt_rule(f_sym, e_sym, links))
+ rules.append(self.new_rule(f_sym, e_sym, links))
if len(f_sym) >= self.max_length or len(nt) >= self.max_nonterminals:
return rules
last_index = nt[-1][0] if nt else 0
@@ -2048,7 +2055,7 @@ cdef class HieroCachingRuleFactory:
if not nt or not sym_isvar(f_sym[-1]):
f_sym.append(sym_setindex(self.category, last_index + 1))
e_sym.append(sym_setindex(self.category, last_index + 1))
- rules.append(fmt_rule(f_sym, e_sym, links))
+ rules.append(self.new_rule(f_sym, e_sym, links))
f_sym.pop()
e_sym.pop()
# [X] Rule
@@ -2066,24 +2073,54 @@ cdef class HieroCachingRuleFactory:
link[1] += 1
f_sym.insert(0, sym_setindex(self.category, 1))
e_sym.insert(0, sym_setindex(self.category, 1))
- rules.append(fmt_rule(f_sym, e_sym, links))
+ rules.append(self.new_rule(f_sym, e_sym, links))
if len(f_sym) >= self.max_length or len(nt) + 1 >= self.max_nonterminals:
return rules
# [X] Rule [X]
if not nt or not sym_isvar(f_sym[-1]):
f_sym.append(sym_setindex(self.category, last_index + 2))
e_sym.append(sym_setindex(self.category, last_index + 2))
- rules.append(fmt_rule(f_sym, e_sym, links))
+ rules.append(self.new_rule(f_sym, e_sym, links))
return rules
+
+ def new_rule(self, f_sym, e_sym, links):
+ f = Phrase(f_sym)
+ e = Phrase(e_sym)
+ a = tuple(self.alignment.link(i, j) for (i, j) in links)
+ return (f, e, a)
+
+ def fmt_rule(self, f, e, a):
+ a_str = ' '.join('{0}-{1}'.format(*self.alignment.unlink(packed)) for packed in a)
+ return '[X] ||| {0} ||| {1} ||| {2}'.format(f, e, a_str)
# Debugging
def dump_online_stats(self):
- logger.info(self.phrases_f)
- logger.info(self.phrases_e)
- logger.info(self.phrases_fe)
-
+ logger.info('------------------------------')
+ logger.info(' Online Stats ')
+ logger.info('------------------------------')
+ logger.info('F')
+ for ph in self.phrases_f:
+ logger.info(str(ph) + ' ||| ' + str(self.phrases_f[ph]))
+ logger.info('E')
+ for ph in self.phrases_e:
+ logger.info(str(ph) + ' ||| ' + str(self.phrases_e[ph]))
+ logger.info('FE')
+ for ph in self.phrases_fe:
+ for ph2 in self.phrases_fe[ph]:
+ logger.info(str(ph) + ' ||| ' + str(ph2) + ' ||| ' + str(self.phrases_fe[ph][ph2]))
+ logger.info('f')
+ for w in self.bilex_f:
+ logger.info(sym_tostring(w) + ' : ' + str(self.bilex_f[w]))
+ logger.info('e')
+ for w in self.bilex_e:
+ logger.info(sym_tostring(w) + ' : ' + str(self.bilex_e[w]))
+ logger.info('fe')
+ for w in self.bilex_fe:
+ for w2 in self.bilex_fe[w]:
+ logger.info(sym_tostring(w) + ' : ' + sym_tostring(w2) + ' : ' + str(self.bilex_fe[w][w2]))
+
# Spans are _inclusive_ on both ends [i, j]
-# TODO: Replace all of this with bit vectors?
+# Could be more efficient but probably not a bottleneck
def span_check(vec, i, j):
k = i
while k <= j:
@@ -2096,14 +2133,4 @@ def span_flip(vec, i, j):
k = i
while k <= j:
vec[k] = ~vec[k]
- k += 1
-
-def fmt_rule(f_sym, e_sym, links):
- a_str = ' '.join('{0}-{1}'.format(i, j) for (i, j) in links)
- return '[X] ||| {0} ||| {1} ||| {2}'.format(' '.join(sym_tostring(sym) for sym in f_sym),
- ' '.join(sym_tostring(sym) for sym in e_sym),
- a_str)
-
- #(f, e, count, als) = e
- #a = tuple('{0}-{1}'.format(packed/65536, packed%65536) for packed in als)
- #logger.info("f: {0}, e: {1}, count: {2}, a: {3}".format(f, e, count, a))
+ k += 1
\ No newline at end of file