diff options
author | Kenneth Heafield <github@kheafield.com> | 2012-10-22 12:07:20 +0100 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2012-10-22 12:07:20 +0100 |
commit | 5f98fe5c4f2a2090eeb9d30c030305a70a8347d1 (patch) | |
tree | 9b6002f850e6dea1e3400c6b19bb31a9cdf3067f /gi/evaluation/confusion_matrix.py | |
parent | cf9994131993b40be62e90e213b1e11e6b550143 (diff) | |
parent | 21825a09d97c2e0afd20512f306fb25fed55e529 (diff) |
Merge remote branch 'upstream/master'
Conflicts:
Jamroot
bjam
decoder/Jamfile
decoder/cdec.cc
dpmert/Jamfile
jam-files/sanity.jam
klm/lm/Jamfile
klm/util/Jamfile
mira/Jamfile
Diffstat (limited to 'gi/evaluation/confusion_matrix.py')
-rw-r--r-- | gi/evaluation/confusion_matrix.py | 123 |
1 files changed, 0 insertions, 123 deletions
diff --git a/gi/evaluation/confusion_matrix.py b/gi/evaluation/confusion_matrix.py deleted file mode 100644 index 2dd7aa47..00000000 --- a/gi/evaluation/confusion_matrix.py +++ /dev/null @@ -1,123 +0,0 @@ -#!/usr/bin/env python - -import sys, math, itertools, getopt - -def usage(): - print >>sys.stderr, 'Usage:', sys.argv[0], '[-s slash_threshold] [-p output] [-m] input-1 input-2' - sys.exit(0) - -optlist, args = getopt.getopt(sys.argv[1:], 'hs:mp:') -slash_threshold = None -output_fname = None -show_matrix = False -for opt, arg in optlist: - if opt == '-s': - slash_threshold = int(arg) - elif opt == '-p': - output_fname = arg - elif opt == '-m': - show_matrix = True - else: - usage() -if len(args) != 2 or (not show_matrix and not output_fname): - usage() - -ginfile = open(args[0]) -pinfile = open(args[1]) - -if output_fname: - try: - import Image, ImageDraw - except ImportError: - print >>sys.stderr, "Error: Python Image Library not available. Did you forget to set your PYTHONPATH environment variable?" - sys.exit(1) - -N = 0 -gold_frequencies = {} -predict_frequencies = {} -joint_frequencies = {} - -for gline, pline in itertools.izip(ginfile, pinfile): - gparts = gline.split('||| ')[1].split() - pparts = pline.split('||| ')[1].split() - assert len(gparts) == len(pparts) - - for gpart, ppart in zip(gparts, pparts): - gtag = gpart.split(':',1)[1] - ptag = ppart.split(':',1)[1] - - if slash_threshold == None or gtag.count('/') + gtag.count('\\') <= slash_threshold: - joint_frequencies.setdefault((gtag, ptag), 0) - joint_frequencies[gtag,ptag] += 1 - - predict_frequencies.setdefault(ptag, 0) - predict_frequencies[ptag] += 1 - - gold_frequencies.setdefault(gtag, 0) - gold_frequencies[gtag] += 1 - - N += 1 - -# find top tags -gtags = gold_frequencies.items() -gtags.sort(lambda x,y: x[1]-y[1]) -gtags.reverse() -#gtags = gtags[:50] - -preds = predict_frequencies.items() -preds.sort(lambda x,y: x[1]-y[1]) -preds.reverse() - -if show_matrix: - print '%7s %7s' % ('pred', 'cnt'), - for gtag, gcount in gtags: print '%7s' % gtag, - print - print '=' * 80 - - for ptag, pcount in preds: - print '%7s %7d' % (ptag, pcount), - for gtag, gcount in gtags: - print '%7d' % joint_frequencies.get((gtag, ptag), 0), - print - - print '%7s %7d' % ('total', N), - for gtag, gcount in gtags: print '%7d' % gcount, - print - -if output_fname: - offset=10 - - image = Image.new("RGB", (len(preds), len(gtags)), (255, 255, 255)) - #hsl(hue, saturation%, lightness%) - - # re-sort preds to get a better diagonal - ptags=[] - if True: - ptags = map(lambda (p,c): p, preds) - else: - remaining = set(predict_frequencies.keys()) - for y, (gtag, gcount) in enumerate(gtags): - best = (None, 0) - for ptag in remaining: - #pcount = predict_frequencies[ptag] - p = joint_frequencies.get((gtag, ptag), 0)# / float(pcount) - if p > best[1]: best = (ptag, p) - ptags.append(ptag) - remaining.remove(ptag) - if not remaining: break - - print 'Predicted tag ordering:', ' '.join(ptags) - print 'Gold tag ordering:', ' '.join(map(lambda (t,c): t, gtags)) - - draw = ImageDraw.Draw(image) - for x, ptag in enumerate(ptags): - pcount = predict_frequencies[ptag] - minval = math.log(offset) - maxval = math.log(pcount + offset) - for y, (gtag, gcount) in enumerate(gtags): - f = math.log(offset + joint_frequencies.get((gtag, ptag), 0)) - z = int(240. * (maxval - f) / float(maxval - minval)) - #print x, y, z, f, maxval - draw.point([(x,y)], fill='hsl(%d, 100%%, 50%%)' % z) - del draw - image.save(output_fname) |