#!/usr/bin/env python import sys, math, itertools, getopt def usage(): print >>sys.stderr, 'Usage:', sys.argv[0], '[-s slash_threshold] [-p output] [-m] input-1 input-2' sys.exit(0) optlist, args = getopt.getopt(sys.argv[1:], 'hs:mp:') slash_threshold = None output_fname = None show_matrix = False for opt, arg in optlist: if opt == '-s': slash_threshold = int(arg) elif opt == '-p': output_fname = arg elif opt == '-m': show_matrix = True else: usage() if len(args) != 2 or (not show_matrix and not output_fname): usage() ginfile = open(args[0]) pinfile = open(args[1]) if output_fname: try: import Image, ImageDraw except ImportError: print >>sys.stderr, "Error: Python Image Library not available. Did you forget to set your PYTHONPATH environment variable?" sys.exit(1) N = 0 gold_frequencies = {} predict_frequencies = {} joint_frequencies = {} for gline, pline in itertools.izip(ginfile, pinfile): gparts = gline.split('||| ')[1].split() pparts = pline.split('||| ')[1].split() assert len(gparts) == len(pparts) for gpart, ppart in zip(gparts, pparts): gtag = gpart.split(':',1)[1] ptag = ppart.split(':',1)[1] if slash_threshold == None or gtag.count('/') + gtag.count('\\') <= slash_threshold: joint_frequencies.setdefault((gtag, ptag), 0) joint_frequencies[gtag,ptag] += 1 predict_frequencies.setdefault(ptag, 0) predict_frequencies[ptag] += 1 gold_frequencies.setdefault(gtag, 0) gold_frequencies[gtag] += 1 N += 1 # find top tags gtags = gold_frequencies.items() gtags.sort(lambda x,y: x[1]-y[1]) gtags.reverse() #gtags = gtags[:50] preds = predict_frequencies.items() preds.sort(lambda x,y: x[1]-y[1]) preds.reverse() if show_matrix: print '%7s %7s' % ('pred', 'cnt'), for gtag, gcount in gtags: print '%7s' % gtag, print print '=' * 80 for ptag, pcount in preds: print '%7s %7d' % (ptag, pcount), for gtag, gcount in gtags: print '%7d' % joint_frequencies.get((gtag, ptag), 0), print print '%7s %7d' % ('total', N), for gtag, gcount in gtags: print '%7d' % gcount, print if output_fname: offset=10 image = Image.new("RGB", (len(preds), len(gtags)), (255, 255, 255)) #hsl(hue, saturation%, lightness%) # re-sort preds to get a better diagonal ptags=[] if True: ptags = map(lambda (p,c): p, preds) else: remaining = set(predict_frequencies.keys()) for y, (gtag, gcount) in enumerate(gtags): best = (None, 0) for ptag in remaining: #pcount = predict_frequencies[ptag] p = joint_frequencies.get((gtag, ptag), 0)# / float(pcount) if p > best[1]: best = (ptag, p) ptags.append(ptag) remaining.remove(ptag) if not remaining: break print 'Predicted tag ordering:', ' '.join(ptags) print 'Gold tag ordering:', ' '.join(map(lambda (t,c): t, gtags)) draw = ImageDraw.Draw(image) for x, ptag in enumerate(ptags): pcount = predict_frequencies[ptag] minval = math.log(offset) maxval = math.log(pcount + offset) for y, (gtag, gcount) in enumerate(gtags): f = math.log(offset + joint_frequencies.get((gtag, ptag), 0)) z = int(240. * (maxval - f) / float(maxval - minval)) #print x, y, z, f, maxval draw.point([(x,y)], fill='hsl(%d, 100%%, 50%%)' % z) del draw image.save(output_fname)