#!/usr/bin/env python
# Origin: gi/evaluation/extract_ccg_labels.py (blob 014e03998d22f556d80820fa3033a34a70643949)

#
# Takes spans input along with treebank and spits out CG style categories for each span.
#   spans = output from CDEC's extools/extractor with --base_phrase_spans option
#   treebank = PTB format, one tree per line
# 
# Output is in CDEC labelled-span format
#

import sys, itertools, tree

# Command-line inputs, consumed line-by-line in lockstep by the main loop:
#   argv[1] -- treebank file; each line is handed to tree.parse_PST below
#   argv[2] -- spans file; each line is split on " ||| " into three fields
# Neither handle is closed explicitly anywhere in this script.
tinfile = open(sys.argv[1])
einfile = open(sys.argv[2])

def number_leaves(node, next=0):
    """Annotate *node* and all of its descendants with leaf-index spans.

    Leaves are numbered in traversal order, starting at *next*.  Every
    visited node receives two attributes: ``left`` (smallest leaf index
    beneath it) and ``right`` (largest leaf index beneath it), both
    inclusive.  A leaf gets ``left == right``.

    Returns the ``(left, right)`` pair assigned to *node*.
    """
    lows, highs = [], []
    for kid in node.children:
        lo, hi = number_leaves(kid, next)
        # The next sibling continues numbering after this subtree's last leaf.
        next = max(next, hi + 1)
        lows.append(lo)
        highs.append(hi)

    if lows:
        left, right = min(lows), max(highs)
    else:
        # No child reported a span, so this node must itself be a leaf.
        assert not node.children
        left = right = next

    node.left = left
    node.right = right
    return left, right

def ancestor(node, indices):
    """Return the deepest node under *node* whose span covers all *indices*.

    A node covers the indices when ``node.left <= min(indices)`` and
    ``max(indices) <= node.right`` (the attributes set by number_leaves).
    Children are tried in order and the first covering child is descended
    into; when no child covers the indices the current node is the answer.

    Returns None if *node* itself does not cover every index.
    Raises ValueError (from min/max) if *indices* is empty.
    """
    # The bounds are the same at every level, so compute them only once.
    lo, hi = min(indices), max(indices)
    if lo < node.left or hi > node.right:
        return None

    deepest = node
    while True:
        for child in deepest.children:
            # Explicit span test rather than relying on the truthiness of a
            # returned node: a node class that defines __len__ (e.g. as its
            # child count) would make a covering leaf look like "not found".
            if child.left <= lo and hi <= child.right:
                deepest = child
                break
        else:
            # No child covers the whole range; cannot go any deeper.
            return deepest

def frontier(node, indices):
    """Return the list of nodes left after expanding *node* around *indices*.

    A node whose span lies entirely outside ``[min(indices), max(indices)]``
    is kept whole.  A node that overlaps the range is replaced by the
    concatenated frontiers of its children, in child order; an overlapping
    node without children is kept as-is.
    """
    overlaps = node.left <= max(indices) and node.right >= min(indices)
    if not overlaps or not node.children:
        return [node]
    return [piece
            for child in node.children
            for piece in frontier(child, indices)]

def _span_category(root, x, y):
    """Build the category label for leaves x..y-1 of the numbered tree *root*.

    The label starts with the tag of the deepest node covering the span.
    Frontier nodes lying wholly to the left of the span are appended as
    backslash arguments (leftmost first); frontier nodes lying wholly to the
    right are appended as forward-slash arguments (rightmost first).
    """
    covered = range(x, y)
    top = ancestor(root, covered)
    pieces = frontier(top, covered)

    cat = top.data.tag
    for piece in pieces:
        if piece.right >= x:
            break
        cat += '\\' + piece.data.tag
    for piece in reversed(pieces):
        if piece.left < y:
            break
        cat += '/' + piece.data.tag
    return cat


# Lazy pairwise iteration on both interpreters: itertools.izip where it
# exists (Python 2), otherwise the built-in zip (already lazy on Python 3).
_lockstep = getattr(itertools, 'izip', zip)

# One treebank line is paired with one spans line; one output line is
# written per pair, in the form "||| x-y:CAT x-y:CAT ...".
for tline, eline in _lockstep(tinfile, einfile):
    if tline.strip() != '(())':
        if tline.startswith('( '):
            # NOTE(review): [2:-1] drops the leading "( " and the final
            # character, which is presumably the newline rather than the
            # wrapper's closing paren -- confirm parse_PST tolerates that.
            tline = tline[2:-1].strip()
        tr = tree.parse_PST(tline)
        number_leaves(tr)
    else:
        # '(())' lines get no tree; every span on them is labelled FAIL.
        tr = None

    zh, en, spans = eline.strip().split(" ||| ")
    fields = ['|||']
    for span in spans.split():
        # Each span is "i-j-x-y"; only x and y are used, as a half-open
        # range of leaf indices into the tree.
        _, _, x, y = map(int, span.split("-"))

        # Compare against None explicitly so that a parsed tree is never
        # mistaken for a missing one because of its truth value.
        if tr is not None:
            cat = _span_category(tr, x, y)
        else:
            cat = 'FAIL'

        fields.append('%d-%d:%s' % (x, y, cat))
    # Single-argument print(...) emits the same space-separated line on
    # Python 2 and Python 3.
    print(' '.join(fields))