summaryrefslogtreecommitdiff
path: root/gi/evaluation/extract_ccg_labels.py
blob: e0034648637c421aed28fbb1ff89a4b0068d8edf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#!/usr/bin/env python

#
# Takes a treebank and the matching phrase-span file and prints CG/CCG-style
# categories for each span.
# Usage: extract_ccg_labels.py TREEBANK SPANS
#   TREEBANK = PTB format, one tree per line (first argument)
#   SPANS    = output from CDEC's extools/extractor with --base_phrase_spans option (second argument)
# 
# Output is in CDEC labelled-span format
#

import sys, itertools, tree

# Fail with a usage message instead of an IndexError traceback when the
# input files are not given on the command line.
if len(sys.argv) < 3:
    sys.stderr.write('usage: %s TREEBANK SPANS\n' % sys.argv[0])
    sys.exit(1)

# Treebank: PTB-format trees, one per line (parsed with tree.parse_PST below).
tinfile = open(sys.argv[1])
# Span file: extractor output whose last " ||| " field lists the spans.
einfile = open(sys.argv[2])

def number_leaves(node, next=0):
    """Record on each node the inclusive range of leaf positions it spans.

    Leaves are numbered left to right in child order, the first leaf under
    `node` receiving position `next`.  Every node in the subtree gets `left`
    and `right` attributes holding its first and last leaf position.

    Returns the (left, right) pair assigned to `node`.
    """
    cursor = next
    child_spans = []
    for child in node.children:
        child_span = number_leaves(child, cursor)
        child_spans.append(child_span)
        # Continue numbering just after the last leaf this child used.
        cursor = max(cursor, child_span[1] + 1)

    if child_spans:
        left = min(lo for lo, _ in child_spans)
        right = max(hi for _, hi in child_spans)
    else:
        # A leaf occupies exactly one position.
        left = right = next

    node.left = left
    node.right = right
    return left, right

def ancestor(node, indices):
    """Return the deepest node in `node`'s subtree covering every index in `indices`.

    Nodes must carry the inclusive `left`/`right` leaf span set by
    number_leaves().  Returns None if `node` itself does not cover all of
    `indices`.  `indices` must be a non-empty sequence: min()/max() raise
    ValueError on an empty one.
    """
    lo, hi = min(indices), max(indices)

    def covers(n):
        return n.left <= lo and hi <= n.right

    if not covers(node):
        return None
    # Walk down into the first child that still covers the whole range
    # (with spans from number_leaves, at most one child can).  Coverage is
    # tested directly rather than via the truthiness of a returned node, so a
    # leaf is still returned even if the node class is falsy without children.
    while True:
        for child in node.children:
            if covers(child):
                node = child
                break
        else:
            return node

def frontier(node, indices):
    """Return the frontier of `node` with respect to the index range, in child order.

    The result holds every maximal subtree lying entirely outside
    [min(indices), max(indices)] together with every leaf that is not
    entirely outside it.  Nodes must carry `left`/`right` spans from
    number_leaves().
    """
    outside = node.left > max(indices) or node.right < min(indices)
    if outside or not node.children:
        return [node]
    return [piece
            for child in node.children
            for piece in frontier(child, indices)]

def project_heads(node):
    """Propagate head-marked tags up the tree, modifying tags in place.

    A tag ending in '-HEAD' marks its node as the head of its parent.  Leaf
    tags lose the marker.  Every internal node's tag is replaced by the tag
    returned from its head child, and exactly one such child is asserted.

    Returns the node's final tag if the node itself was marked as a head,
    otherwise None.
    """
    suffix = '-HEAD'
    marked = node.data.tag.endswith(suffix)

    if not node.children:
        if marked:
            node.data.tag = node.data.tag[:-len(suffix)]
    else:
        head_tags = [tag for tag in map(project_heads, node.children) if tag]
        for tag in head_tags:
            node.data.tag = tag
        assert len(head_tags) == 1

    return node.data.tag if marked else None

# Stream both files in lockstep: itertools.izip on Python 2, the (already
# lazy) builtin zip on Python 3.
_izip = getattr(itertools, 'izip', zip)


def _span_category(tr, x, y):
    """Return the CG-style category for leaves x..y-1 of tree `tr`.

    `tr` must already be annotated by number_leaves().  The category is the
    tag of the deepest node covering the span.  It is followed by a '\\'-joined
    tag for each frontier node wholly left of the span and a '/'-joined tag
    for each one wholly right of it.  Each side is scanned from the outside
    inward and stops at the first node that is not outside on that side.

    Raises ValueError if the span is not covered by the tree.
    """
    indices = range(x, y)
    top = ancestor(tr, indices)
    if top is None:
        raise ValueError('span %d-%d is not covered by the tree' % (x, y))
    fs = frontier(top, indices)

    cat = top.data.tag
    for f in fs:
        if f.right < x:
            cat += '\\' + f.data.tag
        else:
            break
    for f in reversed(fs):
        if f.left >= y:
            cat += '/' + f.data.tag
        else:
            break
    return cat


for tline, eline in _izip(tinfile, einfile):
    tr = None
    if tline.strip() != '(())':
        if tline.startswith('( '):
            # Remove the "( ... )" wrapper.  Strip first so that [:-1] drops
            # the closing paren rather than the trailing newline.
            tline = tline.strip()[2:-1].strip()
        tr = tree.parse_PST(tline)
        if tr is not None:
            number_leaves(tr)
            # project_heads(tr)  # assumes Bikel-style head annotation for the input trees

    # The spans are the last " ||| "-separated field of the line.
    spans = eline.strip().split(" ||| ")[-1]
    labels = ['|||']
    for span in spans.split():
        # Each span is "i-j-x-y[:...]"; x and y index the tree's leaves.
        i, j, x, y = [int(v) for v in span.split(":")[0].split("-")]

        if tr is not None:
            try:
                cat = _span_category(tr, x, y)
            except Exception:
                sys.stderr.write('problem with line %s -- %s\n'
                                 % (tline.strip(), eline.strip()))
                raise
        else:
            cat = 'FAIL'

        labels.append('%d-%d:%s' % (x, y, cat))
    sys.stdout.write(' '.join(labels) + '\n')