summaryrefslogtreecommitdiff
path: root/python/cdec/scfg/extractor.py
blob: 1dfa2421bde0323b88a5f0c51d7f3f16869238a5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import sys, os
import re
import StringIO
from itertools import chain

import clex
import rulefactory
import calignment
import csuf
import cdat
import sym
import log

from features import EgivenFCoherent, SampleCountF, CountEF,\
        MaxLexEgivenF, MaxLexFgivenE, IsSingletonF, IsSingletonFE
from features import contextless

# Set the project ``log`` module's level to -1 at import time.
# NOTE(review): presumably this suppresses its messages — confirm in log.py.
log.level = -1

class Output(StringIO.StringIO):
    def close(self):
        pass

    def __str__(self):
        return self.getvalue()

def get_cn(sentence):
    """Convert *sentence* into the ``cn`` structure passed to the rule factory.

    The whitespace-split tokens are wrapped in ``<s>`` ... ``</s>``. Each token
    is converted with ``sym.fromstring(token, terminal=True)``, and every
    position becomes a one-element tuple ``((symbol, None, 1),)``.
    (Presumably a confusion network with a single arc per column, as the name
    suggests.) The result is a tuple of those per-position tuples.
    """
    tokens = ['<s>'] + sentence.split() + ['</s>']
    columns = []
    for token in tokens:
        symbol = sym.fromstring(token, terminal=True)
        columns.append(((symbol, None, 1),))
    return tuple(columns)

class PhonyGrammar:
    """Stand-in grammar object whose ``add`` ignores everything it is given.

    GrammarExtractor passes an instance as the rule factory's ``grammar``
    argument.
    """

    def add(self, thing):
        # Intentionally a no-op: the argument is discarded.
        return None

class GrammarExtractor:
    """Extract per-sentence grammars from compiled (binary) training data.

    The constructor reads its settings from *cfg* and loads the resources
    named there:
      - the alignment;
      - the source suffix array;
      - the target data array;
      - the lexical table.
    It wires them into a ``rulefactory.HieroCachingRuleFactory`` and a
    ``rulefactory.Sampler``. :meth:`grammar` returns the text the factory
    produces for one input sentence.
    """

    def __init__(self, cfg):
        """Load the data described by *cfg* and build the extraction pipeline.

        Args:
            cfg: a dict of settings, or the path of a ``*.py`` module whose
                module globals are used as that dict. Keys read here:
                ``a_file``, ``max_len``, ``max_nt``, ``precompute_file``,
                ``rank1``, ``rank2``, ``max_size``, ``min_gap``,
                ``f_sa_file``, ``e_file`` and ``lex_file``.

        Raises:
            TypeError: *cfg* is neither a dict nor a str.
            ValueError: *cfg* is a path whose file name is not a plain
                ``<word characters>.py`` name.
            ImportError: the config module could not be imported.
            KeyError: a required setting is missing from the config.
        """
        config = self._load_config(cfg)
        alignment = calignment.Alignment(config['a_file'], from_binary=True)
        self.factory = rulefactory.HieroCachingRuleFactory(
                # compiled alignment object (REQUIRED)
                alignment=alignment,
                # name of generic nonterminal used by Hiero
                category="[X]",
                # do not change for extraction
                grammar=PhonyGrammar(), # TODO: set to None?
                # maximum number of contiguous chunks of terminal symbols in RHS of a rule. If None, defaults to max_nonterminals+1
                max_chunks=None,
                # maximum span of a grammar rule in TEST DATA
                max_initial_size=15,
                # maximum number of symbols (both T and NT) allowed in a rule
                max_length=config['max_len'],
                # maximum number of nonterminals allowed in a rule (set >2 at your own risk)
                max_nonterminals=config['max_nt'],
                # maximum number of contiguous chunks of terminal symbols in target-side RHS of a rule. If None, defaults to max_nonterminals+1
                max_target_chunks=None,
                # maximum number of target side symbols (both T and NT) allowed in a rule. If None, defaults to max_initial_size
                max_target_length=None,
                # minimum span of a nonterminal in the RHS of a rule in TEST DATA
                min_gap_size=1,
                # filename of file containing precomputed collocations
                precompute_file=config['precompute_file'],
                # maximum frequency rank of patterns used to compute triples (don't set higher than 20).
                precompute_secondary_rank=config['rank2'],
                # maximum frequency rank of patterns used to compute collocations (no need to set higher than maybe 200-300)
                precompute_rank=config['rank1'],
                # require extracted rules to have at least one aligned word
                require_aligned_terminal=True,
                # require each contiguous chunk of extracted rules to have at least one aligned word
                require_aligned_chunks=False,
                # generate a complete grammar for each input sentence
                per_sentence_grammar=True,
                # maximum span of a grammar rule extracted from TRAINING DATA
                train_max_initial_size=config['max_size'],
                # minimum span of an RHS nonterminal in a rule extracted from TRAINING DATA
                train_min_gap_size=config['min_gap'],
                # True if phrases should be tight, False otherwise (False seems to give better results but is slower)
                tight_phrases=True,
                )
        self.fsarray = csuf.SuffixArray(config['f_sa_file'], from_binary=True)
        self.edarray = cdat.DataArray(config['e_file'], from_binary=True)

        # Register this object as the factory's context. What the factory
        # reads from it is defined in rulefactory.
        self.factory.registerContext(self)

        # lower=faster, higher=better; improvements level off above 200-300 range, -1 = don't sample, use all data (VERY SLOW!)
        self.sampler = rulefactory.Sampler(300)
        self.sampler.registerContext(self)

        # lexical weighting tables
        tt = clex.CLex(config['lex_file'], from_binary=True)

        self.models = (EgivenFCoherent, SampleCountF, CountEF,
                MaxLexFgivenE(tt), MaxLexEgivenF(tt), IsSingletonF, IsSingletonFE)
        self.models = tuple(contextless(feature) for feature in self.models)

    @staticmethod
    def _load_config(cfg):
        """Return the settings dict for *cfg* (a dict, or a path to a *.py module).

        A dict is returned unchanged. For a path, the module is imported with
        its directory temporarily appended to ``sys.path``, and the module's
        ``__dict__`` is returned.
        """
        if isinstance(cfg, dict):
            return cfg
        if not isinstance(cfg, str):
            # Previously fell through and failed with UnboundLocalError.
            raise TypeError('Config must be a dict or a path to a *.py file, '
                            'got %s' % type(cfg).__name__)
        cfg_file = os.path.basename(cfg)
        if not re.match(r'^\w+\.py$', cfg_file):
            raise ValueError('Config must be a *.py file')
        # NOTE(review): __import__ returns any already-imported module with the
        # same name (e.g. a config named log.py would yield the project's log
        # module). Because the directory is appended, earlier sys.path entries
        # also take precedence.
        sys.path.append(os.path.dirname(cfg))
        try:
            module = __import__(cfg_file[:-len('.py')])
        finally:
            # Undo the sys.path change even when the import fails.
            sys.path.pop()
        return module.__dict__

    def grammar(self, sentence):
        """Return the grammar text extracted for *sentence*.

        Args:
            sentence: one whitespace-tokenized source sentence. ``unicode``
                input is UTF-8 encoded first.

        Returns:
            Everything the rule factory wrote to its output buffer while
            processing this sentence.
        """
        if isinstance(sentence, unicode):
            sentence = sentence.encode('utf8')
        out = Output()
        cn = get_cn(sentence)
        self.factory.input(cn, output=out)
        return str(out)

def main(config):
    """Print the grammar extracted for the first line read from stdin.

    *config* is passed straight to GrammarExtractor (a dict or a ``*.py``
    path). The extractor is built before stdin is read.
    """
    grammar_text = GrammarExtractor(config).grammar(next(sys.stdin))
    sys.stdout.write(grammar_text)

if __name__ == '__main__':
    # Expect exactly one argument: the path of a *.py config file.
    cli_args = sys.argv[1:]
    if len(cli_args) == 1 and cli_args[0].endswith('.py'):
        main(cli_args[0])
    else:
        sys.stderr.write('Usage: %s config.py\n' % sys.argv[0])
        sys.exit(1)