diff options
Diffstat (limited to 'gi/evaluation/confusion_matrix.py')
-rw-r--r-- | gi/evaluation/confusion_matrix.py | 118 |
1 files changed, 118 insertions, 0 deletions
diff --git a/gi/evaluation/confusion_matrix.py b/gi/evaluation/confusion_matrix.py new file mode 100644 index 00000000..c5e2a379 --- /dev/null +++ b/gi/evaluation/confusion_matrix.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python + +import sys, math, itertools, getopt + +def usage(): + print >>sys.stderr, 'Usage:', sys.argv[0], '[-s slash_threshold] [-p output] [-m] input-1 input-2' + sys.exit(0) + +optlist, args = getopt.getopt(sys.argv[1:], 'hs:') +slash_threshold = None +output_fname = None +show_matrix = False +for opt, arg in optlist: + if opt == '-s': + slash_threshold = int(arg) + elif opt == '-p': + output_fname = arg + elif opt == '-m': + show_matrix = True + else: + usage() +if len(args) != 2 or (not show_matrix and not output_fname): + usage() + +ginfile = open(args[0]) +pinfile = open(args[1]) + +if output_fname: + try: + import Image, ImageDraw + except ImportError: + print >>sys.stderr, "Error: Python Image Library not available. Did you forget to set your PYTHONPATH environment variable?" + sys.exit(1) + +N = 0 +gold_frequencies = {} +predict_frequencies = {} +joint_frequencies = {} + +for gline, pline in itertools.izip(ginfile, pinfile): + gparts = gline.split('||| ')[1].split() + pparts = pline.split('||| ')[1].split() + assert len(gparts) == len(pparts) + + for gpart, ppart in zip(gparts, pparts): + gtag = gpart.split(':',1)[1] + ptag = ppart.split(':',1)[1] + + if slash_threshold == None or gtag.count('/') + gtag.count('\\') <= slash_threshold: + joint_frequencies.setdefault((gtag, ptag), 0) + joint_frequencies[gtag,ptag] += 1 + + predict_frequencies.setdefault(ptag, 0) + predict_frequencies[ptag] += 1 + + gold_frequencies.setdefault(gtag, 0) + gold_frequencies[gtag] += 1 + + N += 1 + +# find top tags +gtags = gold_frequencies.items() +gtags.sort(lambda x,y: x[1]-y[1]) +gtags.reverse() +#gtags = gtags[:50] + +preds = predict_frequencies.items() +preds.sort(lambda x,y: x[1]-y[1]) +preds.reverse() + +if show_matrix: + print '%7s %7s' % ('pred', 'cnt'), + for gtag, gcount in gtags: print '%7s' % gtag, + print + print '=' * 80 + + for ptag, pcount in preds: + print '%7s %7d' % (ptag, pcount), + for gtag, gcount in gtags: + print '%7d' % joint_frequencies.get((gtag, ptag), 0), + print + + print '%7s %7d' % ('total', N), + for gtag, gcount in gtags: print '%7d' % gcount, + print + +if output_fname: + offset=10 + + image = Image.new("RGB", (len(preds), len(gtags)), (255, 255, 255)) + #hsl(hue, saturation%, lightness%) + + # re-sort preds to get a better diagonal + if False: + ptags = [] + remaining = set(predict_frequencies.keys()) + for y, (gtag, gcount) in enumerate(gtags): + best = (None, 0) + for ptag in remaining: + #pcount = predict_frequencies[ptag] + p = joint_frequencies.get((gtag, ptag), 0)# / float(pcount) + if p > best[1]: best = (ptag, p) + ptags.append(ptag) + remaining.remove(ptag) + if not remaining: break + + draw = ImageDraw.Draw(image) + for x, ptag in enumerate(ptags): + pcount = predict_frequencies[ptag] + minval = math.log(offset) + maxval = math.log(pcount + offset) + for y, (gtag, gcount) in enumerate(gtags): + f = math.log(offset + joint_frequencies.get((gtag, ptag), 0)) + z = int(240. * (maxval - f) / float(maxval - minval)) + #print x, y, z, f, maxval + draw.point([(x,y)], fill='hsl(%d, 100%%, 50%%)' % z) + del draw + image.save(output_fname) |