summaryrefslogtreecommitdiff
path: root/python/pkg/cdec/sa/extractor.py
blob: 0cf5f6b344d804e5a228979cda54c5ecac0a6515 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from itertools import chain
import os
import logging
import cdec.configobj
from cdec.sa.features import EgivenFCoherent, SampleCountF, CountEF,\
        MaxLexEgivenF, MaxLexFgivenE, IsSingletonF, IsSingletonFE
import cdec.sa

# Module-level logger shared by the extractor (configured by the host application).
logger = logging.getLogger('cdec.sa')

# maximum span of a grammar rule in TEST DATA
MAX_INITIAL_SIZE = 15

class GrammarExtractor:
    """Per-sentence grammar extractor backed by suffix-array sampling.

    Wraps a configured ``cdec.sa.HieroCachingRuleFactory`` together with the
    lexical tables, scorer, and sampler it needs, and exposes a single
    ``grammar(sentence)`` entry point that extracts a sentence-specific
    grammar from the training bitext.
    """

    def __init__(self, config, features=None):
        # ``config`` may be a path to a configuration file or an
        # already-parsed mapping; normalize to a mapping first.
        if isinstance(config, basestring):
            if not os.path.exists(config):
                raise IOError('cannot read configuration from {0}'.format(config))
            config = cdec.configobj.ConfigObj(config, unrepr=True)

        use_mmap = config.get('memory_map', False)
        if use_mmap:
            logger.info('Memory mapping parallel data')

        # Compiled word alignment for the training bitext (REQUIRED).
        word_alignment = cdec.sa.Alignment(from_binary=config['a_file'],
                                           mmaped=use_mmap)

        factory_options = dict(
            # name of generic nonterminal used by Hiero
            category="[X]",
            # maximum number of contiguous chunks of terminal symbols in RHS of a rule
            max_chunks=config['max_nt'] + 1,
            # maximum span of a grammar rule in TEST DATA
            max_initial_size=MAX_INITIAL_SIZE,
            # maximum number of symbols (both T and NT) allowed in a rule
            max_length=config['max_len'],
            # maximum number of nonterminals allowed in a rule (set >2 at your own risk)
            max_nonterminals=config['max_nt'],
            # maximum number of contiguous chunks of terminal symbols
            # in target-side RHS of a rule.
            max_target_chunks=config['max_nt'] + 1,
            # maximum number of target side symbols (both T and NT) allowed in a rule.
            max_target_length=MAX_INITIAL_SIZE,
            # minimum span of a nonterminal in the RHS of a rule in TEST DATA
            min_gap_size=1,
            # filename of file containing precomputed collocations
            precompute_file=config['precompute_file'],
            # maximum frequency rank of patterns used to compute triples (< 20)
            precompute_secondary_rank=config['rank2'],
            # maximum frequency rank of patterns used to compute collocations (< 300)
            precompute_rank=config['rank1'],
            # require extracted rules to have at least one aligned word
            require_aligned_terminal=True,
            # require each contiguous chunk of extracted rules
            # to have at least one aligned word
            require_aligned_chunks=False,
            # maximum span of a grammar rule extracted from TRAINING DATA
            train_max_initial_size=config['max_size'],
            # minimum span of an RHS nonterminal in a rule extracted from TRAINING DATA
            train_min_gap_size=config['min_gap'],
            # False if phrases should be loose (better but slower), True otherwise
            tight_phrases=config.get('tight_phrases', True),
        )
        self.factory = cdec.sa.HieroCachingRuleFactory(word_alignment,
                                                       **factory_options)

        # Lexical weighting tables (bilingual lexicon).
        bilex = cdec.sa.BiLex(from_binary=config['lex_file'], mmaped=use_mmap)

        # TODO: use @cdec.sa.features decorator for standard features too
        # + add a mask to disable features
        scorer = cdec.sa.Scorer(EgivenFCoherent, SampleCountF, CountEF,
            MaxLexFgivenE(bilex), MaxLexEgivenF(bilex), IsSingletonF,
            IsSingletonFE, *cdec.sa._SA_FEATURES)

        # Source-side suffix array and target-side data array.
        source_sa = cdec.sa.SuffixArray(from_binary=config['f_sa_file'],
                                        mmaped=use_mmap)
        target_da = cdec.sa.DataArray(from_binary=config['e_file'],
                                      mmaped=use_mmap)

        # lower=faster, higher=better; improvements level off above 200-300 range,
        # -1 = don't sample, use all data (VERY SLOW!)
        sampler = cdec.sa.Sampler(300, source_sa)

        self.factory.configure(source_sa, target_da, sampler, scorer)
        # Initialize feature definitions with configuration
        for configure_fn in cdec.sa._SA_CONFIGURE:
            configure_fn(config)

    def grammar(self, sentence):
        """Extract and return a grammar for a single input *sentence*.

        The sentence is tokenized on whitespace and wrapped in the
        ``<s>`` / ``</s>`` boundary markers expected by the factory.
        """
        if isinstance(sentence, unicode):
            sentence = sentence.encode('utf8')
        tokens = ('<s>',) + tuple(sentence.split()) + ('</s>',)
        annotations = cdec.sa.annotate(tokens)
        lattice = cdec.sa.make_lattice(tokens)
        return self.factory.input(lattice, annotations)