summary | refs | log | tree | commit | diff
path: root/python/src/sa/rulefactory.pxi
diff options
context:
space:
mode:
author: Michael Denkowski <michael.j.denkowski@gmail.com> 2013-01-04 23:08:29 -0500
committer: Michael Denkowski <michael.j.denkowski@gmail.com> 2013-01-04 23:08:29 -0500
commit: 9f8970752a07258a1cafd25840b408ccdbb8ee1c (patch)
tree: 57e4e722601ea50a379a38d4b0ab1edbeda7bc74 /python/src/sa/rulefactory.pxi
parent: f036b0b57ebc6081134492fa920bc9a11cff4846 (diff)
Track source span to keep accurate phrase counts
Diffstat (limited to 'python/src/sa/rulefactory.pxi')
-rw-r--r-- python/src/sa/rulefactory.pxi 86
1 files changed, 22 insertions, 64 deletions
diff --git a/python/src/sa/rulefactory.pxi b/python/src/sa/rulefactory.pxi
index 3fcf8879..29bd809c 100644
--- a/python/src/sa/rulefactory.pxi
+++ b/python/src/sa/rulefactory.pxi
@@ -1829,6 +1829,10 @@ cdef class HieroCachingRuleFactory:
def add_instance(self, f_words, e_words, alignment):
# Rules extracted from this instance
+ # Track span of absolute alignment links to make
+ # sure we don't extract the same rule for the same
+ # span more than once.
+ # (f, e, al, f_min_link, f_max_link)
rules = set()
f_len = len(f_words)
@@ -1916,8 +1920,7 @@ cdef class HieroCachingRuleFactory:
if links:
# Make sure we cover all aligned words
if f_j >= new_min_bound:
- for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
- rules.add(rule)
+ rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc, links, nt, False)
nt[-1] = old_last_nt
if link_i < nt[-1][3]:
@@ -1943,8 +1946,7 @@ cdef class HieroCachingRuleFactory:
links.append(plus_links)
if links:
if f_j >= new_min_bound:
- for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
- rules.add(rule)
+ rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc + 1, links, nt, False)
links.pop()
for link in al[f_j]:
@@ -1970,8 +1972,7 @@ cdef class HieroCachingRuleFactory:
nt[-1][4] = link_j
if links:
if f_j >= new_min_bound:
- for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
- rules.add(rule)
+ rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc, links, nt, False)
nt[-1] = old_last_nt
if new_e_i < nt[-1][3]:
@@ -1991,8 +1992,7 @@ cdef class HieroCachingRuleFactory:
# Require at least one word in phrase
if links:
if f_j >= new_min_bound:
- for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
- rules.add(rule)
+ rules.add(self.form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
extract(f_i, f_j + 1, new_e_i, new_e_j, new_min_bound, wc, links, nt, False)
nt.pop()
span_dec(cover, link_i, link_j)
@@ -2009,21 +2009,16 @@ cdef class HieroCachingRuleFactory:
f_i += 1
for rule in sorted(rules):
- logger.info(self.fmt_rule(*rule))
+ logger.info(self.fmt_rule(*rule[:3]))
# Update phrase counts
- f_set = set()
- e_set = set()
- for (f_ph, e_ph, al) in rules:
- f_set.add(f_ph)
- e_set.add(e_ph)
+ for rule in rules:
+ (f_ph, e_ph, al) = rule[:3]
+ self.phrases_f[f_ph] += 1
+ self.phrases_e[e_ph] += 1
self.phrases_fe[f_ph][e_ph] += 1
if not self.phrases_al[f_ph][e_ph]:
self.phrases_al[f_ph][e_ph] = al
- for f_ph in f_set:
- self.phrases_f[f_ph] += 1
- for e_ph in e_set:
- self.phrases_e[e_ph] += 1
# Update Bilexical counts
for e_w in e_words:
@@ -2035,11 +2030,10 @@ cdef class HieroCachingRuleFactory:
# Create a rule from source, target, non-terminals, and alignments
- def form_rules(self, f_i, e_i, f_span, e_span, nt, al):
-
- # This could be more efficient but is unlikely to be the bottleneck
- rules = []
+ def form_rule(self, f_i, e_i, f_span, e_span, nt, al):
+ # Handle non-terminals
+
nt_inv = sorted(nt, cmp=lambda x, y: cmp(x[3], y[3]))
f_sym = list(f_span[:])
@@ -2066,6 +2060,8 @@ cdef class HieroCachingRuleFactory:
# Adjusting alignment links takes some doing
links = [list(link) for sub in al for link in sub]
+ f_link_min = links[0][0]
+ f_link_max = links[-1][0]
links_inv = sorted(links, cmp=lambda x, y: cmp(x[1], y[1]))
links_len = len(links)
nt_len = len(nt)
@@ -2088,51 +2084,13 @@ cdef class HieroCachingRuleFactory:
links_inv[i][1] -= off
i += 1
- # Rule
- rules.append(self.new_rule(f_sym, e_sym, links))
- if len(f_sym) >= self.max_length or len(nt) >= self.max_nonterminals:
- return rules
- # DEBUG: no boundary NTs
- return rules
- last_index = nt[-1][0] if nt else 0
- # Rule [X]
- if not nt or not sym_isvar(f_sym[-1]):
- f_sym.append(sym_setindex(self.category, last_index + 1))
- e_sym.append(sym_setindex(self.category, last_index + 1))
- rules.append(self.new_rule(f_sym, e_sym, links))
- f_sym.pop()
- e_sym.pop()
- # [X] Rule
- f_len = len(f_sym)
- e_len = len(e_sym)
- if not nt or not sym_isvar(f_sym[0]):
- for i from 0 <= i < f_len:
- if sym_isvar(f_sym[i]):
- f_sym[i] = sym_setindex(self.category, sym_getindex(f_sym[i]) + 1)
- for i from 0 <= i < e_len:
- if sym_isvar(e_sym[i]):
- e_sym[i] = sym_setindex(self.category, sym_getindex(e_sym[i]) + 1)
- for link in links:
- link[0] += 1
- link[1] += 1
- f_sym.insert(0, sym_setindex(self.category, 1))
- e_sym.insert(0, sym_setindex(self.category, 1))
- rules.append(self.new_rule(f_sym, e_sym, links))
- if len(f_sym) >= self.max_length or len(nt) + 1 >= self.max_nonterminals:
- return rules
- # [X] Rule [X]
- if not nt or not sym_isvar(f_sym[-1]):
- f_sym.append(sym_setindex(self.category, last_index + 2))
- e_sym.append(sym_setindex(self.category, last_index + 2))
- rules.append(self.new_rule(f_sym, e_sym, links))
- return rules
-
- def new_rule(self, f_sym, e_sym, links):
+ # Create rule (f_phrase, e_phrase, links, f_link_min, f_link_max)
f = Phrase(f_sym)
e = Phrase(e_sym)
a = tuple(self.alignment.link(i, j) for (i, j) in links)
- return (f, e, a)
-
+ return (f, e, a, f_link_min, f_link_max)
+
+ # Rule string from rule
def fmt_rule(self, f, e, a):
# Render one rule as a single text line of the form
# '[X] ||| <f> ||| <e> ||| <links>'.
# f and e are interpolated with str.format; a is an iterable of packed links.
# Each packed link is expanded by self.alignment.unlink() into two values that
# are printed as 'i-j'; the links are joined with single spaces.
# NOTE(review): presumably i is the source index and j the target index --
# confirm against the Alignment.unlink implementation.
a_str = ' '.join('{0}-{1}'.format(*self.alignment.unlink(packed)) for packed in a)
return '[X] ||| {0} ||| {1} ||| {2}'.format(f, e, a_str)