From dedfa33b198c8bbb4af879efe73607f2dc4de584 Mon Sep 17 00:00:00 2001 From: "trevor.cohn" Date: Tue, 20 Jul 2010 21:54:01 +0000 Subject: fixed typo git-svn-id: https://ws10smt.googlecode.com/svn/trunk@343 ec762483-ff6d-05da-a07a-a48fb63a330f --- gi/evaluation/entropy.py | 4 +- gi/evaluation/evaluate_entropy.py | 117 ++++++++++++++++++++++++++++++++++++ gi/evaluation/extract_ccg_labels.py | 10 ++- 3 files changed, 126 insertions(+), 5 deletions(-) create mode 100644 gi/evaluation/evaluate_entropy.py (limited to 'gi') diff --git a/gi/evaluation/entropy.py b/gi/evaluation/entropy.py index cef0dbb4..ec1ef502 100644 --- a/gi/evaluation/entropy.py +++ b/gi/evaluation/entropy.py @@ -26,8 +26,8 @@ for line in infile: tag = part.split(':',1)[1] if slash_threshold == None or tag.count('/') + tag.count('\\') <= slash_threshold: - frequencies.setdefault(gtag, 0) - frequencies[gtag] += 1 + frequencies.setdefault(tag, 0) + frequencies[tag] += 1 N += 1 h = 0 diff --git a/gi/evaluation/evaluate_entropy.py b/gi/evaluation/evaluate_entropy.py new file mode 100644 index 00000000..43edc376 --- /dev/null +++ b/gi/evaluation/evaluate_entropy.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python + +import sys, math, itertools + +ginfile = open(sys.argv[1]) +pinfile = open(sys.argv[2]) +if len(sys.argv) > 3: + slash_threshold = int(sys.argv[3]) + #print >>sys.stderr, 'slash threshold', slash_threshold +else: + slash_threshold = 99999 + +# evaluating: H(G | P) = sum_{g,p} p(g,p) log { p(p) / p(g,p) } +# = sum_{g,p} c(g,p)/N { log c(p) - log N - log c(g,p) + log N } +# = 1/N sum_{g,p} c(g,p) { log c(p) - log c(g,p) } +# where G = gold, P = predicted, N = number of events + +N = 0 +gold_frequencies = {} +predict_frequencies = {} +joint_frequencies = {} + +for gline, pline in itertools.izip(ginfile, pinfile): + gparts = gline.split('||| ')[1].split() + pparts = pline.split('||| ')[1].split() + assert len(gparts) == len(pparts) + + for gpart, ppart in zip(gparts, pparts): + gtag = gpart.split(':',1)[1] + ptag = ppart.split(':',1)[1] + + if gtag.count('/') + gtag.count('\\') <= slash_threshold: + joint_frequencies.setdefault((gtag, ptag), 0) + joint_frequencies[gtag,ptag] += 1 + + predict_frequencies.setdefault(ptag, 0) + predict_frequencies[ptag] += 1 + + gold_frequencies.setdefault(gtag, 0) + gold_frequencies[gtag] += 1 + + N += 1 + +hg2p = 0 +hp2g = 0 +for (gtag, ptag), cgp in joint_frequencies.items(): + hp2g += cgp * (math.log(predict_frequencies[ptag], 2) - math.log(cgp, 2)) + hg2p += cgp * (math.log(gold_frequencies[gtag], 2) - math.log(cgp, 2)) +hg2p /= N +hp2g /= N + +hg = 0 +for gtag, c in gold_frequencies.items(): + hg -= c * (math.log(c, 2) - math.log(N, 2)) +hg /= N + +print 'H(P|G)', hg2p, 'H(G|P)', hp2g, 'VI', hg2p + hp2g, 'H(G)', hg +#sys.exit(0) + +# find top tags +gtags = gold_frequencies.items() +gtags.sort(lambda x,y: x[1]-y[1]) +gtags.reverse() +#gtags = gtags[:50] + +print '%7s %7s' % ('pred', 'cnt'), +for gtag, gcount in gtags: print '%7s' % gtag, +print +print '=' * 80 + +preds = predict_frequencies.items() +preds.sort(lambda x,y: x[1]-y[1]) +preds.reverse() +for ptag, pcount in preds: + print '%7s %7d' % (ptag, pcount), + for gtag, gcount in gtags: + print '%7d' % joint_frequencies.get((gtag, ptag), 0), + print + +print '%7s %7d' % ('total', N), +for gtag, gcount in gtags: print '%7d' % gcount, +print + +if len(sys.argv) > 4: + # needs Python Image Library (PIL) + import Image, ImageDraw + + offset=10 + + image = Image.new("RGB", (len(preds), len(gtags)), (255, 255, 255)) + #hsl(hue, saturation%, lightness%) + + # resort preds to get a better diagonal + ptags = [] + remaining = set(predict_frequencies.keys()) + for y, (gtag, gcount) in enumerate(gtags): + best = (None, 0) + for ptag in remaining: + #pcount = predict_frequencies[ptag] + p = joint_frequencies.get((gtag, ptag), 0)# / float(pcount) + if p > best[1]: best = (ptag, p) + ptags.append(ptag) + remaining.remove(ptag) + if not remaining: break + + draw = ImageDraw.Draw(image) + for x, ptag in enumerate(ptags): + pcount = predict_frequencies[ptag] + minval = math.log(offset) + maxval = math.log(pcount + offset) + for y, (gtag, gcount) in enumerate(gtags): + f = math.log(offset + joint_frequencies.get((gtag, ptag), 0)) + z = int(240. * (maxval - f) / float(maxval - minval)) + #print x, y, z, f, maxval + draw.point([(x,y)], fill='hsl(%d, 100%%, 50%%)' % z) + del draw + image.save(sys.argv[4]) diff --git a/gi/evaluation/extract_ccg_labels.py b/gi/evaluation/extract_ccg_labels.py index 77f21004..e0034648 100644 --- a/gi/evaluation/extract_ccg_labels.py +++ b/gi/evaluation/extract_ccg_labels.py @@ -90,10 +90,13 @@ for tline, eline in itertools.izip(tinfile, einfile): else: tr = None - zh, en, spans = eline.strip().split(" ||| ") + parts = eline.strip().split(" ||| ") + zh, en = parts[:2] + spans = parts[-1] print '|||', for span in spans.split(): - i, j, x, y = map(int, span.split("-")) + sps = span.split(":") + i, j, x, y = map(int, sps[0].split("-")) if tr: a = ancestor(tr, range(x,y)) @@ -113,7 +116,8 @@ for tline, eline in itertools.izip(tinfile, einfile): cat += '\\' + f.data.tag else: break - for f in reversed(fs): + fs.reverse() + for f in fs: if f.left >= y: cat += '/' + f.data.tag else: -- cgit v1.2.3