1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
|
#!/usr/bin/env python
import sys, math, itertools, getopt
def usage(status=0):
    """Print the command-line synopsis to stderr and exit.

    The script requires at least one output mode: -m (print the confusion
    matrix to stdout) and/or -p (save a heat-map image), so the synopsis
    states that explicitly instead of listing both as plain options.

    :param status: exit status handed to sys.exit(); defaults to 0.
    """
    sys.stderr.write(
        'Usage: %s [-h] [-s slash_threshold] [-p output] [-m] input-1 input-2\n'
        '  at least one of -m (print matrix) or -p output (save image) is required\n'
        % sys.argv[0])
    sys.exit(status)
# Parse command-line options; any problem ends in usage().
try:
    optlist, args = getopt.getopt(sys.argv[1:], 'hs:mp:')
except getopt.GetoptError as err:
    # Unknown option or missing option argument: say what was wrong
    # instead of dying with a traceback.
    sys.stderr.write('%s\n' % err)
    usage()
slash_threshold = None  # -s: ignore gold tags with more slashes than this
output_fname = None     # -p: save the heat-map image to this file
show_matrix = False     # -m: print the confusion matrix to stdout
for opt, arg in optlist:
    if opt == '-s':
        try:
            slash_threshold = int(arg)
        except ValueError:
            sys.stderr.write('-s expects an integer, got %r\n' % arg)
            usage()
    elif opt == '-p':
        output_fname = arg
    elif opt == '-m':
        show_matrix = True
    else:
        # Only -h can get here: getopt already rejected anything not
        # listed in 'hs:mp:'.
        usage()
# Exactly two input files, and at least one output mode, are required.
if len(args) != 2 or (not show_matrix and not output_fname):
    usage()
# Open the two line-aligned inputs: args[0] is treated as the gold
# annotation, args[1] as the prediction.
try:
    ginfile = open(args[0])
    pinfile = open(args[1])
except IOError as err:
    sys.exit('Error: %s' % err)
if output_fname:
    # Pillow (the maintained PIL fork) only provides the package form
    # `PIL.Image`; the bare `import Image` works only with legacy PIL, so
    # try the package form first and fall back to the old one.
    try:
        from PIL import Image, ImageDraw
    except ImportError:
        try:
            import Image, ImageDraw
        except ImportError:
            sys.stderr.write("Error: Python Image Library not available. Did you forget to set your PYTHONPATH environment variable?\n")
            sys.exit(1)
# Tally tag frequencies over the line-aligned gold/predicted files.
# Parsing below expects each line to contain '||| ' followed by
# whitespace-separated tokens whose tag is the text after the first ':'
# (inferred from the code; confirm against the actual input format).
N = 0                     # number of (gold, predicted) token pairs counted
gold_frequencies = {}     # gold tag -> count
predict_frequencies = {}  # predicted tag -> count
joint_frequencies = {}    # (gold tag, predicted tag) -> count
for lineno, gline in enumerate(ginfile, 1):
    # Read the prediction file in lockstep; unlike izip this lets us notice
    # when the two files do not have the same number of lines.
    pline = pinfile.readline()
    if not pline:
        sys.exit('Error: %s has fewer lines than %s' % (args[1], args[0]))
    gparts = gline.split('||| ')[1].split()
    pparts = pline.split('||| ')[1].split()
    # Explicit check rather than assert, so it still runs under python -O.
    if len(gparts) != len(pparts):
        sys.exit('Error: line %d has %d gold tokens but %d predicted tokens'
                 % (lineno, len(gparts), len(pparts)))
    for gpart, ppart in zip(gparts, pparts):
        gtag = gpart.split(':', 1)[1]
        ptag = ppart.split(':', 1)[1]
        # With -s, skip tokens whose gold tag has more forward/back
        # slashes than the threshold.
        if (slash_threshold is not None
                and gtag.count('/') + gtag.count('\\') > slash_threshold):
            continue
        joint_frequencies[gtag, ptag] = joint_frequencies.get((gtag, ptag), 0) + 1
        predict_frequencies[ptag] = predict_frequencies.get(ptag, 0) + 1
        gold_frequencies[gtag] = gold_frequencies.get(gtag, 0) + 1
        N += 1
# Any prediction lines left over mean the files were misaligned.
if pinfile.readline():
    sys.exit('Error: %s has more lines than %s' % (args[1], args[0]))
# Order gold and predicted tags by descending frequency as (tag, count)
# lists.  Equally frequent tags are ordered by tag text so that the row and
# column order does not depend on dict iteration order.
gtags = sorted(gold_frequencies.items(), key=lambda item: (-item[1], item[0]))
preds = sorted(predict_frequencies.items(), key=lambda item: (-item[1], item[0]))
if show_matrix:
    # Confusion matrix on stdout: one row per predicted tag, one column per
    # gold tag, each cell counting tokens with that (gold, predicted) pair,
    # followed by a 'total' row of gold-tag counts.  Columns are at least 7
    # characters wide and grow to fit the longest tag or count (no count
    # exceeds N), so rows stay aligned even for long tags.
    width = max([7, len(str(N))]
                + [len(tag) for tag, count in gtags]
                + [len(tag) for tag, count in preds])
    header = ['%*s' % (width, 'pred'), '%*s' % (width, 'cnt')]
    header.extend('%*s' % (width, gtag) for gtag, gcount in gtags)
    lines = [' '.join(header)]
    # Rule spans the full table width rather than a fixed 80 columns.
    lines.append('=' * len(lines[0]))
    for ptag, pcount in preds:
        row = ['%*s' % (width, ptag), '%*d' % (width, pcount)]
        row.extend('%*d' % (width, joint_frequencies.get((gtag, ptag), 0))
                   for gtag, gcount in gtags)
        lines.append(' '.join(row))
    footer = ['%*s' % (width, 'total'), '%*d' % (width, N)]
    footer.extend('%*d' % (width, gcount) for gtag, gcount in gtags)
    lines.append(' '.join(footer))
    sys.stdout.write('\n'.join(lines) + '\n')
if output_fname:
    # Heat map with one pixel per (predicted tag, gold tag) cell: x indexes
    # predicted tags, y indexes gold tags in the order of gtags.
    if not gtags:
        sys.exit('Error: no tokens were counted, nothing to plot')
    offset = 10  # added before taking logs, so an empty cell maps to log(offset)
    image = Image.new("RGB", (len(preds), len(gtags)), (255, 255, 255))
    # Re-sort predicted tags to get a better diagonal.  Walking the gold tags
    # in gtags order, each gold tag claims the still-unplaced predicted tag
    # it co-occurs with most often (ties go to the one listed first in
    # preds).  Predicted tags never claimed follow in preds order, so every
    # column of the image gets drawn.
    ptags = []
    remaining = set(predict_frequencies)
    for gtag, gcount in gtags:
        if not remaining:
            break
        best_tag, best_count = None, 0
        for ptag, pcount in preds:
            if ptag in remaining:
                count = joint_frequencies.get((gtag, ptag), 0)
                if count > best_count:
                    best_tag, best_count = ptag, count
        # A gold tag that never co-occurs with an unplaced predicted tag
        # claims nothing (the old code appended a stale/None tag here).
        if best_tag is not None:
            ptags.append(best_tag)
            remaining.remove(best_tag)
    ptags.extend(ptag for ptag, pcount in preds if ptag in remaining)

    draw = ImageDraw.Draw(image)
    minval = math.log(offset)
    for x, ptag in enumerate(ptags):
        # Each column is scaled by its own predicted-tag total, which is at
        # least 1, so maxval > minval: hue 0 (red) when the cell holds every
        # token predicted as ptag, 240 (blue) when it holds none.
        maxval = math.log(predict_frequencies[ptag] + offset)
        for y, (gtag, gcount) in enumerate(gtags):
            f = math.log(offset + joint_frequencies.get((gtag, ptag), 0))
            hue = int(240. * (maxval - f) / float(maxval - minval))
            draw.point([(x, y)], fill='hsl(%d, 100%%, 50%%)' % hue)
    del draw
    image.save(output_fname)
|