summaryrefslogtreecommitdiff
path: root/python/cdec/sa/extractor.py
blob: 777f5afd6c7ee07c818a2358662db8d21eb13fed (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from itertools import chain
import logging
import os
import sys
import cdec.configobj
from cdec.sa._sa import gzip_or_text
from cdec.sa.features import EgivenFCoherent, SampleCountF, CountEF,\
        MaxLexEgivenF, MaxLexFgivenE, IsSingletonF, IsSingletonFE,\
        IsSupportedOnline
import cdec.sa

# Maximum span of a grammar rule in TEST DATA.
# GrammarExtractor.__init__ passes this value to HieroCachingRuleFactory as
# both max_initial_size and max_target_length.
MAX_INITIAL_SIZE = 15

class GrammarExtractor:
    """Suffix-array based Hiero grammar extractor.

    Loads the compiled training data named in a configuration, wires it into
    a ``cdec.sa.HieroCachingRuleFactory`` and exposes per-sentence grammar
    extraction (``grammar``). In online mode, training instances can also be
    added (``add_instance``) and dropped (``drop_ctx``) per context.
    """

    # Sample size passed to cdec.sa.Sampler when the configuration has no
    # 'sample_size' entry. Lower is faster, higher is better; -1 disables
    # sampling and uses all data (very slow).
    DEFAULT_SAMPLE_SIZE = 300

    def __init__(self, config, online=False, features=None):
        """Load compiled data and build the rule factory and scorer.

        Args:
            config: path to a ConfigObj file (parsed with ``unrepr=True``),
                or an already-loaded mapping. Keys read: 'a_file',
                'lex_file' (offline) or 'bilex_file' (online), 'max_nt',
                'max_len', 'precompute_file', 'rank1', 'rank2', 'max_size',
                'min_gap', 'f_sa_file', 'e_file', and optionally
                'tight_phrases' (default True) and 'sample_size'
                (default DEFAULT_SAMPLE_SIZE).
            online: if True, load an online bilexical dictionary, hand it to
                the factory, and add the IsSupportedOnline feature.
            features: accepted for interface compatibility but not used
                here. Extra features come from cdec.sa._SA_FEATURES
                (presumably registered by feature modules the caller
                loaded; verify against callers).

        Raises:
            IOError: if ``config`` is a path that does not exist.
        """
        logging.basicConfig(level=logging.INFO)
        logger = logging.getLogger('cdec.sa')

        if isinstance(config, basestring):
            if not os.path.exists(config):
                raise IOError('cannot read configuration from {0}'.format(config))
            config = cdec.configobj.ConfigObj(config, unrepr=True)

        logger.info('Loading alignment...')
        alignment = cdec.sa.Alignment(from_binary=config['a_file'])

        # lexical weighting tables
        if not online:
            logger.info('Loading bilexical dictionary...')
            tt = cdec.sa.BiLex(from_binary=config['lex_file'])
        else:
            logger.info('Loading online bilexical dictionary...')
            # NOTE(review): cdec.sa.online is not imported in this file; this
            # relies on it being reachable as an attribute of cdec.sa.
            tt = cdec.sa.online.Bilex(config['bilex_file'])

        self.factory = cdec.sa.HieroCachingRuleFactory(
                # compiled alignment object (REQUIRED)
                alignment,
                # bilexical dictionary if online
                bilex=tt if online else None,
                # name of generic nonterminal used by Hiero
                category="[X]",
                # maximum number of contiguous chunks of terminal symbols in RHS of a rule
                max_chunks=config['max_nt']+1,
                # maximum span of a grammar rule in TEST DATA
                max_initial_size=MAX_INITIAL_SIZE,
                # maximum number of symbols (both T and NT) allowed in a rule
                max_length=config['max_len'],
                # maximum number of nonterminals allowed in a rule (set >2 at your own risk)
                max_nonterminals=config['max_nt'],
                # maximum number of contiguous chunks of terminal symbols
                # in target-side RHS of a rule.
                max_target_chunks=config['max_nt']+1,
                # maximum number of target side symbols (both T and NT) allowed in a rule.
                max_target_length=MAX_INITIAL_SIZE,
                # minimum span of a nonterminal in the RHS of a rule in TEST DATA
                min_gap_size=1,
                # filename of file containing precomputed collocations
                precompute_file=config['precompute_file'],
                # maximum frequency rank of patterns used to compute triples (< 20)
                precompute_secondary_rank=config['rank2'],
                # maximum frequency rank of patterns used to compute collocations (< 300)
                precompute_rank=config['rank1'],
                # require extracted rules to have at least one aligned word
                require_aligned_terminal=True,
                # require each contiguous chunk of extracted rules
                # to have at least one aligned word
                require_aligned_chunks=False,
                # maximum span of a grammar rule extracted from TRAINING DATA
                train_max_initial_size=config['max_size'],
                # minimum span of an RHS nonterminal in a rule extracted from TRAINING DATA
                train_min_gap_size=config['min_gap'],
                # False if phrases should be loose (better but slower), True otherwise
                tight_phrases=config.get('tight_phrases', True),
                )

        # TODO: clean this up
        # Load data and add features for online grammar extraction
        extended_features = []
        if online:
            extended_features.append(IsSupportedOnline)

        # TODO: use @cdec.sa.features decorator for standard features too
        # + add a mask to disable features
        extended_features.extend(cdec.sa._SA_FEATURES)

        scorer = cdec.sa.Scorer(EgivenFCoherent, SampleCountF, CountEF,
            MaxLexFgivenE(tt), MaxLexEgivenF(tt), IsSingletonF, IsSingletonFE,
            *extended_features)

        fsarray = cdec.sa.SuffixArray(from_binary=config['f_sa_file'])
        edarray = cdec.sa.DataArray(from_binary=config['e_file'])

        # lower=faster, higher=better; improvements level off above 200-300 range,
        # -1 = don't sample, use all data (VERY SLOW!)
        # Configurable like 'tight_phrases'; defaults to the historical 300.
        sample_size = config.get('sample_size', self.DEFAULT_SAMPLE_SIZE)
        sampler = cdec.sa.Sampler(sample_size, fsarray)

        self.factory.configure(fsarray, edarray, sampler, scorer)
        # Initialize feature definitions with configuration
        for fn in cdec.sa._SA_CONFIGURE:
            fn(config)

    def grammar(self, sentence, ctx_name=None):
        """Extract the grammar for one whitespace-tokenized sentence.

        Unicode input is encoded to UTF-8 first. The tokens are wrapped in
        '<s>' ... '</s>', annotated with cdec.sa.annotate, converted with
        cdec.sa.make_lattice, and passed to the factory along with
        ``ctx_name``. Returns whatever ``self.factory.input`` returns.
        """
        if isinstance(sentence, unicode):
            sentence = sentence.encode('utf8')
        words = tuple(chain(('<s>',), sentence.split(), ('</s>',)))
        meta = cdec.sa.annotate(words)
        cnet = cdec.sa.make_lattice(words)
        return self.factory.input(cnet, meta, ctx_name)

    # Add training instance to data
    def add_instance(self, sentence, reference, alignment, ctx_name=None):
        """Add one training instance to the data for ``ctx_name``.

        Args:
            sentence: whitespace-tokenized source sentence.
            reference: whitespace-tokenized target sentence.
            alignment: whitespace-separated 'i-j' index pairs; parsed into a
                sorted list of integer tuples.
            ctx_name: context the instance is attached to.
        """
        # Encode unicode the same way grammar() does, so non-ASCII input is
        # not subjected to an implicit ASCII conversion downstream.
        if isinstance(sentence, unicode):
            sentence = sentence.encode('utf8')
        if isinstance(reference, unicode):
            reference = reference.encode('utf8')
        f_words = cdec.sa.encode_words(sentence.split())
        e_words = cdec.sa.encode_words(reference.split())
        al = sorted(tuple(int(i) for i in pair.split('-'))
                    for pair in alignment.split())
        self.factory.add_instance(f_words, e_words, al, ctx_name)

    # Remove all incremental data for a context
    def drop_ctx(self, ctx_name=None):
        """Remove all incremental data added for ``ctx_name``."""
        self.factory.drop_ctx(ctx_name)