diff options
| author | Chris Dyer <cdyer@allegro.clab.cs.cmu.edu> | 2013-03-11 17:06:48 -0400 | 
|---|---|---|
| committer | Chris Dyer <cdyer@allegro.clab.cs.cmu.edu> | 2013-03-11 17:06:48 -0400 | 
| commit | 7cc5d9c568890be7fe01363fdf52912d8c6d5665 (patch) | |
| tree | ae0ab29586fd796f358b0fb075ae01a31eedaf5c /python | |
| parent | e81511381148fcb23e7803675ccacfc3c992c221 (diff) | |
| parent | a3a884706a10ee3029a2eec06e12bb610ea9df13 (diff) | |
Merge branch 'master' of https://github.com/redpony/cdec
Diffstat (limited to 'python')
| -rwxr-xr-x | python/pkg/cdec/sa/online_extractor.py | 337 | 
1 files changed, 0 insertions, 337 deletions
| diff --git a/python/pkg/cdec/sa/online_extractor.py b/python/pkg/cdec/sa/online_extractor.py deleted file mode 100755 index 03a46b3b..00000000 --- a/python/pkg/cdec/sa/online_extractor.py +++ /dev/null @@ -1,337 +0,0 @@ -#!/usr/bin/env python - -import collections, sys - -import cdec.configobj - -CAT = '[X]'  # Default non-terminal -MAX_SIZE = 15  # Max span of a grammar rule (source) -MAX_LEN = 5  # Max number of terminals and non-terminals in a rule (source) -MAX_NT = 2  # Max number of non-terminals in a rule -MIN_GAP = 1  # Min number of terminals between non-terminals (source) - -# Spans are _inclusive_ on both ends [i, j] -# TODO: Replace all of this with bit vectors? -def span_check(vec, i, j): -    k = i -    while k <= j: -        if vec[k]: -            return False -        k += 1 -    return True - -def span_flip(vec, i, j): -    k = i -    while k <= j: -        vec[k] = ~vec[k] -        k += 1 - -# Next non-terminal -def next_nt(nt): -     if not nt: -         return 1 -     return nt[-1][0] + 1 - -class NonTerminal: -    def __init__(self, index): -        self.index = index -    def __str__(self): -        return '[X,{0}]'.format(self.index) - -def fmt_rule(f_sym, e_sym, links): -    a_str = ' '.join('{0}-{1}'.format(i, j) for (i, j) in links) -    return '[X] ||| {0} ||| {1} ||| {2}'.format(' '.join(str(sym) for sym in f_sym), -                                                ' '.join(str(sym) for sym in e_sym), -                                                a_str) - -class OnlineGrammarExtractor: - -    def __init__(self, config=None): -        if isinstance(config, str) or isinstance(config, unicode): -            if not os.path.exists(config): -                raise IOError('cannot read configuration from {0}'.format(config)) -            config = cdec.configobj.ConfigObj(config, unrepr=True) -        elif not config: -            config = collections.defaultdict(lambda: None) -        self.category = CAT -        self.max_size = MAX_SIZE -        self.max_length = config['max_len'] or MAX_LEN -        self.max_nonterminals = config['max_nt'] or MAX_NT -        self.min_gap_size = MIN_GAP -        # Hard coded: require at least one aligned word -        # Hard coded: require tight phrases - -        # Phrase counts -        self.phrases_f = collections.defaultdict(lambda: 0) -        self.phrases_e = collections.defaultdict(lambda: 0) -        self.phrases_fe = collections.defaultdict(lambda: collections.defaultdict(lambda: 0)) - -        # Bilexical counts -        self.bilex_f = collections.defaultdict(lambda: 0) -        self.bilex_e = collections.defaultdict(lambda: 0) -        self.bilex_fe = collections.defaultdict(lambda: collections.defaultdict(lambda: 0)) - -    # Aggregate bilexical counts -    def aggr_bilex(self, f_words, e_words): - -        for e_w in e_words: -            self.bilex_e[e_w] += 1 - -        for f_w in f_words: -            self.bilex_f[f_w] += 1 -            for e_w in e_words: -                self.bilex_fe[f_w][e_w] += 1 - -    # Aggregate stats from a training instance: -    # Extract hierarchical phrase pairs -    # Update bilexical counts -    def add_instance(self, f_words, e_words, alignment): - -        # Bilexical counts -        self.aggr_bilex(f_words, e_words) - -        # Phrase pairs extracted from this instance -        phrases = set() - -        f_len = len(f_words) -        e_len = len(e_words) - -        # Pre-compute alignment info -        al = [[] for i in range(f_len)] -        al_span = [[f_len + 1, -1] for i in range(f_len)] -        for (f, e) in alignment: -            al[f].append(e) -            al_span[f][0] = min(al_span[f][0], e) -            al_span[f][1] = max(al_span[f][1], e) - -        # Target side word coverage -        # TODO: Does Cython do bit vectors? -        cover = [0] * e_len - -        # Extract all possible hierarchical phrases starting at a source index -        # f_ i and j are current, e_ i and j are previous -        def extract(f_i, f_j, e_i, e_j, wc, links, nt, nt_open): -            # Phrase extraction limits -            if wc + len(nt) > self.max_length or (f_j + 1) > f_len or \ -                    (f_j - f_i) + 1 > self.max_size: -                return -            # Unaligned word -            if not al[f_j]: -                # Open non-terminal: extend -                if nt_open: -                    nt[-1][2] += 1 -                    extract(f_i, f_j + 1, e_i, e_j, wc, links, nt, True) -                    nt[-1][2] -= 1 -                # No open non-terminal: extend with word -                else: -                    extract(f_i, f_j + 1, e_i, e_j, wc + 1, links, nt, False) -                return -            # Aligned word -            link_i = al_span[f_j][0] -            link_j = al_span[f_j][1] -            new_e_i = min(link_i, e_i) -            new_e_j = max(link_j, e_j) -            # Open non-terminal: close, extract, extend -            if nt_open: -                # Close non-terminal, checking for collisions -                old_last_nt = nt[-1][:] -                nt[-1][2] = f_j -                if link_i < nt[-1][3]: -                    if not span_check(cover, link_i, nt[-1][3] - 1): -                        nt[-1] = old_last_nt -                        return -                    span_flip(cover, link_i, nt[-1][3] - 1) -                    nt[-1][3] = link_i -                if link_j > nt[-1][4]: -                    if not span_check(cover, nt[-1][4] + 1, link_j): -                        nt[-1] = old_last_nt -                        return -                    span_flip(cover, nt[-1][4] + 1, link_j) -                    nt[-1][4] = link_j -                for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): -                    phrases.add(rule) -                extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False) -                nt[-1] = old_last_nt -                if link_i < nt[-1][3]: -                    span_flip(cover, link_i, nt[-1][3] - 1) -                if link_j > nt[-1][4]: -                    span_flip(cover, nt[-1][4] + 1, link_j) -                return -            # No open non-terminal -            # Extract, extend with word -            collision = False -            for link in al[f_j]: -                if cover[link]: -                    collision = True -            # Collisions block extraction and extension, but may be okay for -            # continuing non-terminals -            if not collision: -                plus_links = [] -                for link in al[f_j]: -                    plus_links.append((f_j, link)) -                    cover[link] = ~cover[link] -                links.append(plus_links) -                for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): -                    phrases.add(rule) -                extract(f_i, f_j + 1, new_e_i, new_e_j, wc + 1, links, nt, False) -                links.pop() -                for link in al[f_j]: -                    cover[link] = ~cover[link] -            # Try to add a word to a (closed) non-terminal, extract, extend -            if nt and nt[-1][2] == f_j - 1: -                # Add to non-terminal, checking for collisions -                old_last_nt = nt[-1][:] -                nt[-1][2] = f_j -                if link_i < nt[-1][3]: -                    if not span_check(cover, link_i, nt[-1][3] - 1): -                        nt[-1] = old_last_nt -                        return -                    span_flip(cover, link_i, nt[-1][3] - 1) -                    nt[-1][3] = link_i -                if link_j > nt[-1][4]: -                    if not span_check(cover, nt[-1][4] + 1, link_j): -                        nt[-1] = old_last_nt -                        return -                    span_flip(cover, nt[-1][4] + 1, link_j) -                    nt[-1][4] = link_j -                # Require at least one word in phrase -                if links: -                    for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): -                        phrases.add(rule) -                extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False) -                nt[-1] = old_last_nt -                if new_e_i < nt[-1][3]: -                    span_flip(cover, link_i, nt[-1][3] - 1) -                if link_j > nt[-1][4]: -                    span_flip(cover, nt[-1][4] + 1, link_j) -            # Try to start a new non-terminal, extract, extend -            if (not nt or f_j - nt[-1][2] > 1) and len(nt) < self.max_nonterminals: -                # Check for collisions -                if not span_check(cover, link_i, link_j): -                    return -                span_flip(cover, link_i, link_j) -                nt.append([next_nt(nt), f_j, f_j, link_i, link_j]) -                # Require at least one word in phrase -                if links: -                    for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): -                        phrases.add(rule) -                extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False) -                nt.pop() -                span_flip(cover, link_i, link_j) -            # TODO: try adding NT to start, end, both -            # check: one aligned word on boundary that is not part of a NT - -        # Try to extract phrases from every f index -        f_i = 0 -        while f_i < f_len: -            # Skip if phrases won't be tight on left side -            if not al[f_i]: -                f_i += 1 -                continue -            extract(f_i, f_i, f_len + 1, -1, 1, [], [], False) -            f_i += 1 - -        for rule in sorted(phrases): -            print rule - -    # Create a rule from source, target, non-terminals, and alignments -    def form_rules(self, f_i, e_i, f_span, e_span, nt, al): -     -        # This could be more efficient but is unlikely to be the bottleneck -     -        rules = [] -     -        nt_inv = sorted(nt, cmp=lambda x, y: cmp(x[3], y[3])) -     -        f_sym = f_span[:] -        off = f_i -        for next_nt in nt: -            nt_len = (next_nt[2] - next_nt[1]) + 1 -            i = 0 -            while i < nt_len: -                f_sym.pop(next_nt[1] - off) -                i += 1 -            f_sym.insert(next_nt[1] - off, NonTerminal(next_nt[0])) -            off += (nt_len - 1) -     -        e_sym = e_span[:] -        off = e_i -        for next_nt in nt_inv: -            nt_len = (next_nt[4] - next_nt[3]) + 1 -            i = 0 -            while i < nt_len: -                e_sym.pop(next_nt[3] - off) -                i += 1 -            e_sym.insert(next_nt[3] - off, NonTerminal(next_nt[0])) -            off += (nt_len - 1) -     -        # Adjusting alignment links takes some doing -        links = [list(link) for sub in al for link in sub] -        links_len = len(links) -        nt_len = len(nt) -        nt_i = 0 -        off = f_i -        i = 0 -        while i < links_len: -            while nt_i < nt_len and links[i][0] > nt[nt_i][1]: -                off += (nt[nt_i][2] - nt[nt_i][1]) -                nt_i += 1 -            links[i][0] -= off -            i += 1 -        nt_i = 0 -        off = e_i -        i = 0 -        while i < links_len: -            while nt_i < nt_len and links[i][1] > nt_inv[nt_i][3]: -                off += (nt_inv[nt_i][4] - nt_inv[nt_i][3]) -                nt_i += 1 -            links[i][1] -= off -            i += 1 -     -        # Rule -        rules.append(fmt_rule(f_sym, e_sym, links)) -        if len(f_sym) >= self.max_length or len(nt) >= self.max_nonterminals: -            return rules -        last_index = nt[-1][0] if nt else 0 -        # Rule [X] -        if not nt or not isinstance(f_sym[-1], NonTerminal): -            f_sym.append(NonTerminal(last_index + 1)) -            e_sym.append(NonTerminal(last_index + 1)) -            rules.append(fmt_rule(f_sym, e_sym, links)) -            f_sym.pop() -            e_sym.pop() -        # [X] Rule -        if not nt or not isinstance(f_sym[0], NonTerminal): -            for sym in f_sym: -                if isinstance(sym, NonTerminal): -                    sym.index += 1 -            for sym in e_sym: -                if isinstance(sym, NonTerminal): -                    sym.index += 1 -            for link in links: -                link[0] += 1 -                link[1] += 1 -            f_sym.insert(0, NonTerminal(1)) -            e_sym.insert(0, NonTerminal(1)) -            rules.append(fmt_rule(f_sym, e_sym, links)) -        if len(f_sym) >= self.max_length or len(nt) + 1 >= self.max_nonterminals: -            return rules -        # [X] Rule [X] -        if not nt or not isinstance(f_sym[-1], NonTerminal): -            f_sym.append(NonTerminal(last_index + 2)) -            e_sym.append(NonTerminal(last_index + 2)) -            rules.append(fmt_rule(f_sym, e_sym, links)) -        return rules - -def main(argv): - -    extractor = OnlineGrammarExtractor() - -    for line in sys.stdin: -        print >> sys.stderr, line.strip() -        f_words, e_words, a_str = (x.split() for x in line.split('|||')) -        alignment = sorted(tuple(int(y) for y in x.split('-')) for x in a_str) -        extractor.add_instance(f_words, e_words, alignment) - -if __name__ == '__main__': -    main(sys.argv) | 
