summaryrefslogtreecommitdiff
path: root/gi/evaluation/evaluate_entropy.py
diff options
context:
space:
mode:
authortrevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-16 21:45:45 +0000
committertrevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-16 21:45:45 +0000
commit55933074dcf0e3d5e38fe6ad87b925d7f694ceb1 (patch)
tree1d8dfd42b3f272b044bda5a59a1c76be663b0ee7 /gi/evaluation/evaluate_entropy.py
parent1207aaee1f55dbaac8a46f37635a4d1baf392760 (diff)
Added pictures
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@299 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/evaluation/evaluate_entropy.py')
-rw-r--r--gi/evaluation/evaluate_entropy.py71
1 files changed, 68 insertions, 3 deletions
diff --git a/gi/evaluation/evaluate_entropy.py b/gi/evaluation/evaluate_entropy.py
index 55449c59..e4980ccf 100644
--- a/gi/evaluation/evaluate_entropy.py
+++ b/gi/evaluation/evaluate_entropy.py
@@ -4,8 +4,9 @@ import sys, math, itertools
ginfile = open(sys.argv[1])
pinfile = open(sys.argv[2])
-if len(sys.argv) >= 4:
+if len(sys.argv) > 3:
slash_threshold = int(sys.argv[3])
+ #print >>sys.stderr, 'slash threshold', slash_threshold
else:
slash_threshold = 99999
@@ -28,7 +29,7 @@ for gline, pline in itertools.izip(ginfile, pinfile):
gtag = gpart.split(':',1)[1]
ptag = ppart.split(':',1)[1]
- if gtag.count('/') <= slash_threshold:
+ if gtag.count('/') + gtag.count('\\') <= slash_threshold:
joint_frequencies.setdefault((gtag, ptag), 0)
joint_frequencies[gtag,ptag] += 1
@@ -48,4 +49,68 @@ for (gtag, ptag), cgp in joint_frequencies.items():
hg2p /= N
hp2g /= N
-print 'H(P|G)', hg2p, 'H(G|P)', hp2g, 'VI', hg2p + hp2g
+hg = 0
+for gtag, c in gold_frequencies.items():
+ hg -= c * (math.log(c, 2) - math.log(N, 2))
+hg /= N
+
+print 'H(P|G)', hg2p, 'H(G|P)', hp2g, 'VI', hg2p + hp2g, 'H(G)', hg
+
+# find top tags
+gtags = gold_frequencies.items()
+gtags.sort(lambda x,y: x[1]-y[1])
+gtags.reverse()
+#gtags = gtags[:50]
+
+print '%7s %7s' % ('pred', 'cnt'),
+for gtag, gcount in gtags: print '%7s' % gtag,
+print
+print '=' * 80
+
+preds = predict_frequencies.items()
+preds.sort(lambda x,y: x[1]-y[1])
+preds.reverse()
+for ptag, pcount in preds:
+ print '%7s %7d' % (ptag, pcount),
+ for gtag, gcount in gtags:
+ print '%7d' % joint_frequencies.get((gtag, ptag), 0),
+ print
+
+print '%7s %7d' % ('total', N),
+for gtag, gcount in gtags: print '%7d' % gcount,
+print
+
+if len(sys.argv) > 4:
+ # needs Python Image Library (PIL)
+ import Image, ImageDraw
+
+ offset=10
+
+ image = Image.new("RGB", (len(preds), len(gtags)), (255, 255, 255))
+ #hsl(hue, saturation%, lightness%)
+
+ # resort preds to get a better diagonal
+ ptags = []
+ remaining = set(predict_frequencies.keys())
+ for y, (gtag, gcount) in enumerate(gtags):
+ best = (None, 0)
+ for ptag in remaining:
+ #pcount = predict_frequencies[ptag]
+ p = joint_frequencies.get((gtag, ptag), 0)# / float(pcount)
+ if p > best[1]: best = (ptag, p)
+ ptags.append(ptag)
+ remaining.remove(ptag)
+ if not remaining: break
+
+ draw = ImageDraw.Draw(image)
+ for x, ptag in enumerate(ptags):
+ pcount = predict_frequencies[ptag]
+ minval = math.log(offset)
+ maxval = math.log(pcount + offset)
+ for y, (gtag, gcount) in enumerate(gtags):
+ f = math.log(offset + joint_frequencies.get((gtag, ptag), 0))
+ z = int(240. * (maxval - f) / float(maxval - minval))
+ #print x, y, z, f, maxval
+ draw.point([(x,y)], fill='hsl(%d, 100%%, 50%%)' % z)
+ del draw
+ image.save(sys.argv[4])