summaryrefslogtreecommitdiff
path: root/gi/evaluation/confusion_matrix.py
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2012-10-11 14:06:32 -0400
committerChris Dyer <cdyer@cs.cmu.edu>2012-10-11 14:06:32 -0400
commit9339c80d465545aec5a6dccfef7c83ca715bf11f (patch)
tree64c56d558331edad1db3832018c80e799551c39a /gi/evaluation/confusion_matrix.py
parent438dac41810b7c69fa10203ac5130d20efa2da9f (diff)
parentafd7da3b2338661657ad0c4e9eec681e014d37bf (diff)
Merge branch 'master' of https://github.com/redpony/cdec
Diffstat (limited to 'gi/evaluation/confusion_matrix.py')
-rw-r--r--gi/evaluation/confusion_matrix.py123
1 files changed, 0 insertions, 123 deletions
diff --git a/gi/evaluation/confusion_matrix.py b/gi/evaluation/confusion_matrix.py
deleted file mode 100644
index 2dd7aa47..00000000
--- a/gi/evaluation/confusion_matrix.py
+++ /dev/null
@@ -1,123 +0,0 @@
-#!/usr/bin/env python
-
-import sys, math, itertools, getopt
-
-def usage():
- print >>sys.stderr, 'Usage:', sys.argv[0], '[-s slash_threshold] [-p output] [-m] input-1 input-2'
- sys.exit(0)
-
-optlist, args = getopt.getopt(sys.argv[1:], 'hs:mp:')
-slash_threshold = None
-output_fname = None
-show_matrix = False
-for opt, arg in optlist:
- if opt == '-s':
- slash_threshold = int(arg)
- elif opt == '-p':
- output_fname = arg
- elif opt == '-m':
- show_matrix = True
- else:
- usage()
-if len(args) != 2 or (not show_matrix and not output_fname):
- usage()
-
-ginfile = open(args[0])
-pinfile = open(args[1])
-
-if output_fname:
- try:
- import Image, ImageDraw
- except ImportError:
- print >>sys.stderr, "Error: Python Image Library not available. Did you forget to set your PYTHONPATH environment variable?"
- sys.exit(1)
-
-N = 0
-gold_frequencies = {}
-predict_frequencies = {}
-joint_frequencies = {}
-
-for gline, pline in itertools.izip(ginfile, pinfile):
- gparts = gline.split('||| ')[1].split()
- pparts = pline.split('||| ')[1].split()
- assert len(gparts) == len(pparts)
-
- for gpart, ppart in zip(gparts, pparts):
- gtag = gpart.split(':',1)[1]
- ptag = ppart.split(':',1)[1]
-
- if slash_threshold == None or gtag.count('/') + gtag.count('\\') <= slash_threshold:
- joint_frequencies.setdefault((gtag, ptag), 0)
- joint_frequencies[gtag,ptag] += 1
-
- predict_frequencies.setdefault(ptag, 0)
- predict_frequencies[ptag] += 1
-
- gold_frequencies.setdefault(gtag, 0)
- gold_frequencies[gtag] += 1
-
- N += 1
-
-# find top tags
-gtags = gold_frequencies.items()
-gtags.sort(lambda x,y: x[1]-y[1])
-gtags.reverse()
-#gtags = gtags[:50]
-
-preds = predict_frequencies.items()
-preds.sort(lambda x,y: x[1]-y[1])
-preds.reverse()
-
-if show_matrix:
- print '%7s %7s' % ('pred', 'cnt'),
- for gtag, gcount in gtags: print '%7s' % gtag,
- print
- print '=' * 80
-
- for ptag, pcount in preds:
- print '%7s %7d' % (ptag, pcount),
- for gtag, gcount in gtags:
- print '%7d' % joint_frequencies.get((gtag, ptag), 0),
- print
-
- print '%7s %7d' % ('total', N),
- for gtag, gcount in gtags: print '%7d' % gcount,
- print
-
-if output_fname:
- offset=10
-
- image = Image.new("RGB", (len(preds), len(gtags)), (255, 255, 255))
- #hsl(hue, saturation%, lightness%)
-
- # re-sort preds to get a better diagonal
- ptags=[]
- if True:
- ptags = map(lambda (p,c): p, preds)
- else:
- remaining = set(predict_frequencies.keys())
- for y, (gtag, gcount) in enumerate(gtags):
- best = (None, 0)
- for ptag in remaining:
- #pcount = predict_frequencies[ptag]
- p = joint_frequencies.get((gtag, ptag), 0)# / float(pcount)
- if p > best[1]: best = (ptag, p)
- ptags.append(ptag)
- remaining.remove(ptag)
- if not remaining: break
-
- print 'Predicted tag ordering:', ' '.join(ptags)
- print 'Gold tag ordering:', ' '.join(map(lambda (t,c): t, gtags))
-
- draw = ImageDraw.Draw(image)
- for x, ptag in enumerate(ptags):
- pcount = predict_frequencies[ptag]
- minval = math.log(offset)
- maxval = math.log(pcount + offset)
- for y, (gtag, gcount) in enumerate(gtags):
- f = math.log(offset + joint_frequencies.get((gtag, ptag), 0))
- z = int(240. * (maxval - f) / float(maxval - minval))
- #print x, y, z, f, maxval
- draw.point([(x,y)], fill='hsl(%d, 100%%, 50%%)' % z)
- del draw
- image.save(output_fname)