Diffstat (limited to 'python/src/sa/rulefactory.pxi')
-rw-r--r--  python/src/sa/rulefactory.pxi  103
1 file changed, 65 insertions, 38 deletions
diff --git a/python/src/sa/rulefactory.pxi b/python/src/sa/rulefactory.pxi
index 81ea7960..c26f5c43 100644
--- a/python/src/sa/rulefactory.pxi
+++ b/python/src/sa/rulefactory.pxi
@@ -385,7 +385,7 @@ cdef class HieroCachingRuleFactory:
         self.phrases_f = defaultdict(int)
         self.phrases_e = defaultdict(int)
         self.phrases_fe = defaultdict(lambda: defaultdict(int))
-        self.phrases_al = defaultdict(dict)
+        self.phrases_al = defaultdict(lambda: defaultdict(tuple))
 
         # Bilexical counts
         self.bilex_f = defaultdict(int)
@@ -1820,14 +1820,14 @@ cdef class HieroCachingRuleFactory:
 
         return extracts
 
-    # Aggregate stats from a training instance:
-    # Extract hierarchical phrase pairs
-    # Update bilexical counts
+    #
+    # Online grammar extraction handling
+    #
+
+    # Aggregate stats from a training instance
+    # (Extract rules, update counts)
     def add_instance(self, f_words, e_words, alignment):
 
-        # Bilexical counts
-        self.aggr_bilex(f_words, e_words)
-
         # Rules extracted from this instance
         rules = set()
 
@@ -1843,7 +1843,6 @@ cdef class HieroCachingRuleFactory:
                 al_span[f][1] = max(al_span[f][1], e)
 
         # Target side word coverage
-        # TODO: Does Cython do bit vectors?
         cover = [0] * e_len
 
         # Extract all possible hierarchical phrases starting at a source index
@@ -1956,8 +1955,6 @@ cdef class HieroCachingRuleFactory:
                     extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False)
                     nt.pop()
                     span_flip(cover, link_i, link_j)
-        # TODO: try adding NT to start, end, both
-        # check: one aligned word on boundary that is not part of a NT
 
         # Try to extract phrases from every f index
         f_i = 0
@@ -1970,19 +1967,31 @@ cdef class HieroCachingRuleFactory:
             f_i += 1
 
         for rule in sorted(rules):
-            logger.info(rule)
-
-    # Aggregate bilexical counts
-    def aggr_bilex(self, f_words, e_words):
-
+            logger.info(self.fmt_rule(*rule))
+
+        # Update phrase counts
+        f_set = set()
+        e_set = set()
+        for (f_ph, e_ph, al) in rules:
+            f_set.add(f_ph)
+            e_set.add(e_ph)
+            self.phrases_fe[f_ph][e_ph] += 1
+            if not self.phrases_al[f_ph][e_ph]:
+                self.phrases_al[f_ph][e_ph] = al
+        for f_ph in f_set:
+            self.phrases_f[f_ph] += 1
+        for e_ph in e_set:
+            self.phrases_e[e_ph] += 1
+
+        # Update Bilexical counts
         for e_w in e_words:
            self.bilex_e[e_w] += 1
-
         for f_w in f_words:
            self.bilex_f[f_w] += 1
            for e_w in e_words:
                self.bilex_fe[f_w][e_w] += 1
 
+    # Create a rule from source, target, non-terminals, and alignments
     def form_rules(self, f_i, e_i, f_span, e_span, nt, al):
 
@@ -1992,8 +2001,6 @@ cdef class HieroCachingRuleFactory:
 
         nt_inv = sorted(nt, cmp=lambda x, y: cmp(x[3], y[3]))
 
-        logger.info(nt)
-
         f_sym = list(f_span[:])
         off = f_i
         for next_nt in nt:
@@ -2040,7 +2047,7 @@ cdef class HieroCachingRuleFactory:
             i += 1
 
         # Rule
-        rules.append(fmt_rule(f_sym, e_sym, links))
+        rules.append(self.new_rule(f_sym, e_sym, links))
         if len(f_sym) >= self.max_length or len(nt) >= self.max_nonterminals:
             return rules
         last_index = nt[-1][0] if nt else 0
@@ -2048,7 +2055,7 @@ cdef class HieroCachingRuleFactory:
         if not nt or not sym_isvar(f_sym[-1]):
             f_sym.append(sym_setindex(self.category, last_index + 1))
             e_sym.append(sym_setindex(self.category, last_index + 1))
-            rules.append(fmt_rule(f_sym, e_sym, links))
+            rules.append(self.new_rule(f_sym, e_sym, links))
             f_sym.pop()
             e_sym.pop()
         # [X] Rule
@@ -2066,24 +2073,54 @@ cdef class HieroCachingRuleFactory:
                 link[1] += 1
         f_sym.insert(0, sym_setindex(self.category, 1))
         e_sym.insert(0, sym_setindex(self.category, 1))
-        rules.append(fmt_rule(f_sym, e_sym, links))
+        rules.append(self.new_rule(f_sym, e_sym, links))
         if len(f_sym) >= self.max_length or len(nt) + 1 >= self.max_nonterminals:
             return rules
         # [X] Rule [X]
         if not nt or not sym_isvar(f_sym[-1]):
             f_sym.append(sym_setindex(self.category, last_index + 2))
             e_sym.append(sym_setindex(self.category, last_index + 2))
-            rules.append(fmt_rule(f_sym, e_sym, links))
+            rules.append(self.new_rule(f_sym, e_sym, links))
         return rules
+
+    def new_rule(self, f_sym, e_sym, links):
+        f = Phrase(f_sym)
+        e = Phrase(e_sym)
+        a = tuple(self.alignment.link(i, j) for (i, j) in links)
+        return (f, e, a)
+
+    def fmt_rule(self, f, e, a):
+        a_str = ' '.join('{0}-{1}'.format(*self.alignment.unlink(packed)) for packed in a)
+        return '[X] ||| {0} ||| {1} ||| {2}'.format(f, e, a_str)
 
     # Debugging
     def dump_online_stats(self):
-        logger.info(self.phrases_f)
-        logger.info(self.phrases_e)
-        logger.info(self.phrases_fe)
-
+        logger.info('------------------------------')
+        logger.info(' Online Stats ')
+        logger.info('------------------------------')
+        logger.info('F')
+        for ph in self.phrases_f:
+            logger.info(str(ph) + ' ||| ' + str(self.phrases_f[ph]))
+        logger.info('E')
+        for ph in self.phrases_e:
+            logger.info(str(ph) + ' ||| ' + str(self.phrases_e[ph]))
+        logger.info('FE')
+        for ph in self.phrases_fe:
+            for ph2 in self.phrases_fe[ph]:
+                logger.info(str(ph) + ' ||| ' + str(ph2) + ' ||| ' + str(self.phrases_fe[ph][ph2]))
+        logger.info('f')
+        for w in self.bilex_f:
+            logger.info(sym_tostring(w) + ' : ' + str(self.bilex_f[w]))
+        logger.info('e')
+        for w in self.bilex_e:
+            logger.info(sym_tostring(w) + ' : ' + str(self.bilex_e[w]))
+        logger.info('fe')
+        for w in self.bilex_fe:
+            for w2 in self.bilex_fe[w]:
+                logger.info(sym_tostring(w) + ' : ' + sym_tostring(w2) + ' : ' + str(self.bilex_fe[w][w2]))
+
 # Spans are _inclusive_ on both ends [i, j]
-# TODO: Replace all of this with bit vectors?
+# Could be more efficient but probably not a bottleneck
 def span_check(vec, i, j):
     k = i
     while k <= j:
@@ -2096,14 +2133,4 @@ def span_flip(vec, i, j):
     k = i
     while k <= j:
         vec[k] = ~vec[k]
-        k += 1
-
-def fmt_rule(f_sym, e_sym, links):
-    a_str = ' '.join('{0}-{1}'.format(i, j) for (i, j) in links)
-    return '[X] ||| {0} ||| {1} ||| {2}'.format(' '.join(sym_tostring(sym) for sym in f_sym),
-                                                ' '.join(sym_tostring(sym) for sym in e_sym),
-                                                a_str)
-
-    #(f, e, count, als) = e
-    #a = tuple('{0}-{1}'.format(packed/65536, packed%65536) for packed in als)
-    #logger.info("f: {0}, e: {1}, count: {2}, a: {3}".format(f, e, count, a))
+        k += 1
\ No newline at end of file
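
A minimal standalone sketch (not part of the patch) of the counting scheme that add_instance introduces above, with plain strings and (i, j) tuples standing in for the suffix-array Phrase objects and packed alignment links. It also illustrates the phrases_al change from defaultdict(dict) to defaultdict(lambda: defaultdict(tuple)): looking up an unseen (f, e) pair now yields an empty tuple, so the "keep the first alignment seen" test is a simple falsiness check rather than a potential KeyError.

from collections import defaultdict

# Hypothetical module-level counters mirroring the instance attributes in the patch
phrases_f  = defaultdict(int)                        # source phrase -> instance count
phrases_e  = defaultdict(int)                        # target phrase -> instance count
phrases_fe = defaultdict(lambda: defaultdict(int))   # source/target co-occurrence count
phrases_al = defaultdict(lambda: defaultdict(tuple)) # first alignment seen per phrase pair

def add_instance(rules):
    # rules: set of (f_phrase, e_phrase, alignment) tuples extracted from one instance
    f_set, e_set = set(), set()
    for f_ph, e_ph, al in rules:
        f_set.add(f_ph)
        e_set.add(e_ph)
        phrases_fe[f_ph][e_ph] += 1
        if not phrases_al[f_ph][e_ph]:   # default () is falsy, so no KeyError
            phrases_al[f_ph][e_ph] = al
    # Marginals: each distinct phrase is counted once per instance
    for f_ph in f_set:
        phrases_f[f_ph] += 1
    for e_ph in e_set:
        phrases_e[e_ph] += 1

# Toy usage: two extracted rules sharing a source phrase
add_instance({('le chat', 'the cat',    ((0, 0), (1, 1))),
              ('le chat', 'the feline', ((0, 0), (1, 1)))})
assert phrases_f['le chat'] == 1                         # once per instance
assert phrases_fe['le chat']['the cat'] == 1
assert phrases_al['le chat']['the cat'] == ((0, 0), (1, 1))

As in the patch, the marginal counts go through f_set/e_set so that a phrase appearing in several rules of the same instance is still counted once, while phrases_fe is incremented for every distinct extracted rule.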