path: root/python/src/sa/online_extractor.py
#!/usr/bin/env python3

import collections
import sys

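# Read one aligned sentence pair per line from stdin and write Hiero-style
# hierarchical rules ([X] ||| src ||| tgt ||| alignment) to stdout.
# Input format per line (illustrative example):
#   el gato ||| the cat ||| 0-0 1-1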
def main(argv):

    for line in sys.stdin:
        src, tgt, astr = (x.split() for x in line.split('|||'))
        al = sorted(tuple(int(y) for y in x.split('-')) for x in astr)
        extract_and_aggr(src, tgt, al)

# Extract hierarchical phrase pairs
# This could be far better optimized by integrating it
# with suffix array code.  For now, it gets the job done.
def extract_and_aggr(src, tgt, al, max_len=5, max_size=15, max_nt=2, boundary_nt=True):
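    # src, tgt: tokenized source and target sentences (lists of words)
    # al: sorted list of (src index, tgt index) alignment links
    # max_len: maximum number of terminal words per rule
    # max_size: maximum width of the source span a rule may cover
    # max_nt: maximum number of non-terminals per rule
    # boundary_nt: currently unused; apparently reserved for the
    # boundary non-terminal TODO below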
        
    src_ph = collections.defaultdict(int) # src = count
    tgt_ph = collections.defaultdict(int) # tgt = count
    # [src][tgt] = count
    phrase_pairs = collections.defaultdict(lambda: collections.defaultdict(int))

    src_w = collections.defaultdict(int) # count
    tgt_w = collections.defaultdict(int) # count
    # [src][tgt] = count
    cooc_w = collections.defaultdict(lambda: collections.defaultdict(int))
    
    # Bilexical counts
    for word in tgt:
        tgt_w[word] += 1
    for word in src:
        src_w[word] += 1
        for t_word in tgt:
            cooc_w[word][t_word] += 1

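    # Non-terminals are tracked as [id, src start, src end, tgt start, tgt end]
    # (all indices inclusive); next_nt gives the id for a new non-terminal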
    def next_nt(nt):
        if not nt:
            return 1
        return nt[-1][0] + 1
    
    src_len = len(src)
    
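    # a[s]: target indices aligned to source word s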
    a = [[] for i in range(src_len)]
    
    # Pre-compute alignment min and max for each word
    a_span = [[src_len + 1, -1] for i in range(src_len)]
    for (s, t) in al:
        a[s].append(t)
        a_span[s][0] = min(a_span[s][0], t)
        a_span[s][1] = max(a_span[s][1], t)

    # Target side non-terminal coverage
    # Cython bit vector?
    cover = [0] * src_len
    
    # Debug output
    #print(src)
    #print(tgt)
    #print(a_span)
    
    # Spans are _inclusive_ on both ends [i, j]
    def span_check(vec, i, j):
        return not any(vec[i:j + 1])

    # Toggle coverage over [i, j]; values flip between 0 and -1 via ~
    # and are only ever tested for truthiness
    def span_flip(vec, i, j):
        for k in range(i, j + 1):
            vec[k] = ~vec[k]

    # Extract all possible hierarchical phrases starting at a source index
    # src i and j are current, tgt i and j are previous
    def extract(src_i, src_j, tgt_i, tgt_j, wc, al, nt, nt_open):
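        # src_i, src_j: source span so far (inclusive)
        # tgt_i, tgt_j: min/max target indices linked so far
        # wc: terminal word count, al: alignment links collected for this rule
        # nt: non-terminals in the rule, nt_open: whether the last one is still extendable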
        # Phrase extraction limits
        if wc > max_len or src_j >= src_len or \
                (src_j - src_i) + 1 > max_size or len(nt) > max_nt:
            return
        # Unaligned word
        if not a[src_j]:
            # Open non-terminal: extend
            if nt_open:
                nt[-1][2] += 1
                extract(src_i, src_j + 1, tgt_i, tgt_j, wc, al, nt, True)
                nt[-1][2] -= 1
            # No open non-terminal: extend with word
            else:
                extract(src_i, src_j + 1, tgt_i, tgt_j, wc + 1, al, nt, False)
            return
        # Aligned word
        link_i = a_span[src_j][0]
        link_j = a_span[src_j][1]
        new_tgt_i = min(link_i, tgt_i)
        new_tgt_j = max(link_j, tgt_j)
        # Open non-terminal: close, extract, extend
        if nt_open:
            # Close non-terminal, checking for collisions
            old_last_nt = nt[-1][:]
            nt[-1][2] = src_j
            if link_i < nt[-1][3]:
                if not span_check(cover, link_i, nt[-1][3] - 1):
                    nt[-1] = old_last_nt
                    return
                span_flip(cover, link_i, nt[-1][3] - 1)
                nt[-1][3] = link_i
            if link_j > nt[-1][4]:
                if not span_check(cover, nt[-1][4] + 1, link_j):
                    nt[-1] = old_last_nt
                    return
                span_flip(cover, nt[-1][4] + 1, link_j)
                nt[-1][4] = link_j
            add_rule(src_i, new_tgt_i, src[src_i:src_j + 1], tgt[new_tgt_i:new_tgt_j + 1], nt, al)
            extract(src_i, src_j + 1, new_tgt_i, new_tgt_j, wc, al, nt, False)
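            # Undo: restore the non-terminal and release any coverage claimed above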
            nt[-1] = old_last_nt
            if link_i < nt[-1][3]:
                span_flip(cover, link_i, nt[-1][3] - 1)
            if link_j > nt[-1][4]:
                span_flip(cover, nt[-1][4] + 1, link_j)
            return
        # No open non-terminal
        # Extract, extend with word
        collision = False
        for link in a[src_j]:
            if cover[link]:
                collision = True
        # Collisions block extraction and extension, but may be okay for
        # continuing non-terminals
        if not collision:
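            # Claim this word's links in the coverage vector while recursing,
            # then release them afterward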
            plus_al = []
            for link in a[src_j]:
                plus_al.append((src_j, link))
                cover[link] = ~cover[link]
            al.append(plus_al)
            add_rule(src_i, new_tgt_i, src[src_i:src_j + 1], tgt[new_tgt_i:new_tgt_j + 1], nt, al)
            extract(src_i, src_j + 1, new_tgt_i, new_tgt_j, wc + 1, al, nt, False)
            al.pop()
            for link in a[src_j]:
                cover[link] = ~cover[link]
        # Try to add a word to a (closed) non-terminal, extract, extend
        if nt and nt[-1][2] == src_j - 1:
            # Add to non-terminal, checking for collisions
            old_last_nt = nt[-1][:]
            nt[-1][2] = src_j
            if link_i < nt[-1][3]:
                if not span_check(cover, link_i, nt[-1][3] - 1):
                    nt[-1] = old_last_nt
                    return
                span_flip(cover, link_i, nt[-1][3] - 1)
                nt[-1][3] = link_i
            if link_j > nt[-1][4]:
                if not span_check(cover, nt[-1][4] + 1, link_j):
                    nt[-1] = old_last_nt
                    return
                span_flip(cover, nt[-1][4] + 1, link_j)
                nt[-1][4] = link_j
            # Require at least one word in phrase
            if al:
                add_rule(src_i, new_tgt_i, src[src_i:src_j + 1], tgt[new_tgt_i:new_tgt_j + 1], nt, al)
            extract(src_i, src_j + 1, new_tgt_i, new_tgt_j, wc, al, nt, False)
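            # Undo: restore the non-terminal and release any coverage claimed above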
            nt[-1] = old_last_nt
            if link_i < nt[-1][3]:
                span_flip(cover, link_i, nt[-1][3] - 1)
            if link_j > nt[-1][4]:
                span_flip(cover, nt[-1][4] + 1, link_j)
        # Try to start a new non-terminal, extract, extend
        if not nt or src_j - nt[-1][2] > 1:
            # Check for collisions
            if not span_check(cover, link_i, link_j):
                return
            span_flip(cover, link_i, link_j)
            nt.append([next_nt(nt), src_j, src_j, link_i, link_j])
            # Require at least one word in phrase
            if al:
                add_rule(src_i, new_tgt_i, src[src_i:src_j + 1], tgt[new_tgt_i:new_tgt_j + 1], nt, al)
            extract(src_i, src_j + 1, new_tgt_i, new_tgt_j, wc, al, nt, False)
            nt.pop()
            span_flip(cover, link_i, link_j)
        # TODO: try adding NT to start, end, both
        # check: one aligned word on boundary that is not part of a NT
            
    # Try to extract phrases from every src index
    for src_i in range(src_len):
        # Skip if phrases won't be tight on the left side
        if not a[src_i]:
            continue
        extract(src_i, src_i, src_len + 1, -1, 1, [], [], False)

# Create a rule from source, target, non-terminals, and alignments
def add_rule(src_i, tgt_i, src_span, tgt_span, nt, al):
    flat = (item for sub in al for item in sub)
    astr = ' '.join('{0}-{1}'.format(x[0], x[1]) for x in flat)
    
#    print('--- Rule')
#    print(src_span)
#    print(tgt_span)
#    print(nt)
#    print(astr)
#    print('---')
    
    # This could be more efficient but is probably not going to
    # be the bottleneck
    src_sym = src_span[:]
    off = src_i
    for next_nt in nt:
        nt_len = (next_nt[2] - next_nt[1]) + 1
        for _ in range(nt_len):
            src_sym.pop(next_nt[1] - off)
        src_sym.insert(next_nt[1] - off, '[X,{0}]'.format(next_nt[0]))
        off += (nt_len - 1)
    tgt_sym = tgt_span[:]
    off = tgt_i
    for next_nt in sorted(nt, key=lambda x: x[3]):
        nt_len = (next_nt[4] - next_nt[3]) + 1
        for _ in range(nt_len):
            tgt_sym.pop(next_nt[3] - off)
        tgt_sym.insert(next_nt[3] - off, '[X,{0}]'.format(next_nt[0]))
        off += (nt_len - 1)
    print('[X] ||| {0} ||| {1} ||| {2}'.format(' '.join(src_sym), ' '.join(tgt_sym), astr))

if __name__ == '__main__':
    main(sys.argv)
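
# Example run (illustrative file names):
#   ./online_extractor.py < corpus.src-tgt.aligned > rules.txt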