summaryrefslogtreecommitdiff
path: root/python/src/sa
diff options
context:
space:
mode:
authorMichael Denkowski <michael.j.denkowski@gmail.com>2012-12-23 21:09:41 -0500
committerMichael Denkowski <michael.j.denkowski@gmail.com>2012-12-23 21:09:41 -0500
commit0b9a1c5ac21b7c403cf6bef017ed692c250b297e (patch)
tree4cdfc7db52a275cede104a2c8ee32e929a782e65 /python/src/sa
parent597d89c11db53e91bc011eab70fd613bbe6453e8 (diff)
NonTerminal class
Diffstat (limited to 'python/src/sa')
-rwxr-xr-xpython/src/sa/online_extractor.py98
1 files changed, 58 insertions, 40 deletions
diff --git a/python/src/sa/online_extractor.py b/python/src/sa/online_extractor.py
index 06eb5357..0c013cb3 100755
--- a/python/src/sa/online_extractor.py
+++ b/python/src/sa/online_extractor.py
@@ -4,11 +4,11 @@ import collections, sys
import cdec.configobj
-CAT = '[X]' # Default non-terminal
-MAX_SIZE = 15 # Max span of a grammar rule (source)
-MAX_LEN = 5 # Max number of terminals and non-terminals in a rule (source)
-MAX_NT = 2 # Max number of non-terminals in a rule
-MIN_GAP = 1 # Min number of terminals between non-terminals (source)
+CAT = '[X]' # Default non-terminal
+MAX_SIZE = 15 # Max span of a grammar rule (source)
+MAX_LEN = 5 # Max number of terminals and non-terminals in a rule (source)
+MAX_NT = 2 # Max number of non-terminals in a rule
+MIN_GAP = 1 # Min number of terminals between non-terminals (source)
# Spans are _inclusive_ on both ends [i, j]
# TODO: Replace all of this with bit vectors?
@@ -17,14 +17,14 @@ def span_check(vec, i, j):
while k <= j:
if vec[k]:
return False
- k += 1
+ k += 1
return True
def span_flip(vec, i, j):
k = i
while k <= j:
vec[k] = ~vec[k]
- k += 1
+ k += 1
# Next non-terminal
def next_nt(nt):
@@ -32,13 +32,27 @@ def next_nt(nt):
return 1
return nt[-1][0] + 1
+class NonTerminal:
+ def __init__(self, index):
+ self.index = index
+ def __str__(self):
+ return '[X,{0}]'.format(self.index)
+
+def fmt_rule(f_sym, e_sym, links):
+ a_str = ' '.join('{0}-{1}'.format(i, j) for (i, j) in links)
+ return '[X] ||| {0} ||| {1} ||| {2}'.format(' '.join(str(sym) for sym in f_sym),
+ ' '.join(str(sym) for sym in e_sym),
+ a_str)
+
# Create a rule from source, target, non-terminals, and alignments
-def form_rule(f_i, e_i, f_span, e_span, nt, al):
-
+def form_rules(f_i, e_i, f_span, e_span, nt, al):
+
# This could be more efficient but is unlikely to be the bottleneck
-
+
+ rules = []
+
nt_inv = sorted(nt, cmp=lambda x, y: cmp(x[3], y[3]))
-
+
f_sym = f_span[:]
off = f_i
for next_nt in nt:
@@ -47,7 +61,7 @@ def form_rule(f_i, e_i, f_span, e_span, nt, al):
while i < nt_len:
f_sym.pop(next_nt[1] - off)
i += 1
- f_sym.insert(next_nt[1] - off, '[X,{0}]'.format(next_nt[0]))
+ f_sym.insert(next_nt[1] - off, NonTerminal(next_nt[0]))
off += (nt_len - 1)
e_sym = e_span[:]
@@ -58,9 +72,9 @@ def form_rule(f_i, e_i, f_span, e_span, nt, al):
while i < nt_len:
e_sym.pop(next_nt[3] - off)
i += 1
- e_sym.insert(next_nt[3] - off, '[X,{0}]'.format(next_nt[0]))
+ e_sym.insert(next_nt[3] - off, NonTerminal(next_nt[0]))
off += (nt_len - 1)
-
+
# Adjusting alignment links takes some doing
links = [list(link) for sub in al for link in sub]
links_len = len(links)
@@ -73,7 +87,7 @@ def form_rule(f_i, e_i, f_span, e_span, nt, al):
off += (nt[nt_i][2] - nt[nt_i][1])
nt_i += 1
links[i][0] -= off
- i += 1
+ i += 1
nt_i = 0
off = e_i
i = 0
@@ -82,13 +96,13 @@ def form_rule(f_i, e_i, f_span, e_span, nt, al):
off += (nt_inv[nt_i][4] - nt_inv[nt_i][3])
nt_i += 1
links[i][1] -= off
- i += 1
- a_str = ' '.join('{0}-{1}'.format(i, j) for (i, j) in links)
-
- return '[X] ||| {0} ||| {1} ||| {2}'.format(' '.join(f_sym), ' '.join(e_sym), a_str)
+ i += 1
+
+ rules.append(fmt_rule(f_sym, e_sym, links))
+ return rules
class OnlineGrammarExtractor:
-
+
def __init__(self, config=None):
if isinstance(config, str) or isinstance(config, unicode):
if not os.path.exists(config):
@@ -103,23 +117,23 @@ class OnlineGrammarExtractor:
self.min_gap_size = MIN_GAP
# Hard coded: require at least one aligned word
# Hard coded: require tight phrases
-
+
# Phrase counts
self.phrases_f = collections.defaultdict(lambda: 0)
self.phrases_e = collections.defaultdict(lambda: 0)
self.phrases_fe = collections.defaultdict(lambda: collections.defaultdict(lambda: 0))
-
+
# Bilexical counts
self.bilex_f = collections.defaultdict(lambda: 0)
self.bilex_e = collections.defaultdict(lambda: 0)
self.bilex_fe = collections.defaultdict(lambda: collections.defaultdict(lambda: 0))
-
+
# Aggregate bilexical counts
def aggr_bilex(self, f_words, e_words):
-
+
for e_w in e_words:
self.bilex_e[e_w] += 1
-
+
for f_w in f_words:
self.bilex_f[f_w] += 1
for e_w in e_words:
@@ -129,15 +143,15 @@ class OnlineGrammarExtractor:
# Extract hierarchical phrase pairs
# Update bilexical counts
def add_instance(self, f_words, e_words, alignment):
-
+
# Bilexical counts
self.aggr_bilex(f_words, e_words)
-
+
# Phrase pairs extracted from this instance
phrases = set()
-
+
f_len = len(f_words)
-
+
# Pre-compute alignment info
al = [[] for i in range(f_len)]
al_span = [[f_len + 1, -1] for i in range(f_len)]
@@ -145,11 +159,11 @@ class OnlineGrammarExtractor:
al[f].append(e)
al_span[f][0] = min(al_span[f][0], e)
al_span[f][1] = max(al_span[f][1], e)
-
+
# Target side word coverage
# TODO: Does Cython do bit vectors?
cover = [0] * f_len
-
+
# Extract all possible hierarchical phrases starting at a source index
# f_ i and j are current, e_ i and j are previous
def extract(f_i, f_j, e_i, e_j, wc, links, nt, nt_open):
@@ -170,7 +184,7 @@ class OnlineGrammarExtractor:
return
# Aligned word
link_i = al_span[f_j][0]
- link_j = al_span[f_j][1]
+ link_j = al_span[f_j][1]
new_e_i = min(link_i, e_i)
new_e_j = max(link_j, e_j)
# Open non-terminal: close, extract, extend
@@ -190,7 +204,8 @@ class OnlineGrammarExtractor:
return
span_flip(cover, nt[-1][4] + 1, link_j)
nt[-1][4] = link_j
- phrases.add(form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
+ for rule in form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
+ phrases.add(rule)
extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False)
nt[-1] = old_last_nt
if link_i < nt[-1][3]:
@@ -212,7 +227,8 @@ class OnlineGrammarExtractor:
plus_links.append((f_j, link))
cover[link] = ~cover[link]
links.append(plus_links)
- phrases.add(form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
+ for rule in form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
+ phrases.add(rule)
extract(f_i, f_j + 1, new_e_i, new_e_j, wc + 1, links, nt, False)
links.pop()
for link in al[f_j]:
@@ -236,7 +252,8 @@ class OnlineGrammarExtractor:
nt[-1][4] = link_j
# Require at least one word in phrase
if links:
- phrases.add(form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
+ for rule in form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
+ phrases.add(rule)
extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False)
nt[-1] = old_last_nt
if new_e_i < nt[-1][3]:
@@ -252,13 +269,14 @@ class OnlineGrammarExtractor:
nt.append([next_nt(nt), f_j, f_j, link_i, link_j])
# Require at least one word in phrase
if links:
- phrases.add(form_rule(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links))
+ for rule in form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
+ phrases.add(rule)
extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False)
nt.pop()
span_flip(cover, link_i, link_j)
# TODO: try adding NT to start, end, both
# check: one aligned word on boundary that is not part of a NT
-
+
# Try to extract phrases from every f index
f_i = 0
while f_i < f_len:
@@ -268,18 +286,18 @@ class OnlineGrammarExtractor:
continue
extract(f_i, f_i, f_len + 1, -1, 1, [], [], False)
f_i += 1
-
+
for rule in sorted(phrases):
print rule
def main(argv):
extractor = OnlineGrammarExtractor()
-
+
for line in sys.stdin:
f_words, e_words, a_str = (x.split() for x in line.split('|||'))
alignment = sorted(tuple(int(y) for y in x.split('-')) for x in a_str)
extractor.add_instance(f_words, e_words, alignment)
if __name__ == '__main__':
- main(sys.argv) \ No newline at end of file
+ main(sys.argv)