diff options
author | trevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-16 21:45:45 +0000 |
---|---|---|
committer | trevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-16 21:45:45 +0000 |
commit | 18209d048bfc1e9a1cd3edcfb80b3741869a9c34 (patch) | |
tree | fc93cee6ec967c93cbed9b2034833c61ecccf006 /gi/evaluation | |
parent | a900eeb513e71ecbf5de9ba545da052002184fdc (diff) |
Added pictures
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@299 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/evaluation')
-rw-r--r-- | gi/evaluation/evaluate_entropy.py | 71 | ||||
-rw-r--r-- | gi/evaluation/extract_ccg_labels.py | 11 |
2 files changed, 76 insertions, 6 deletions
diff --git a/gi/evaluation/evaluate_entropy.py b/gi/evaluation/evaluate_entropy.py index 55449c59..e4980ccf 100644 --- a/gi/evaluation/evaluate_entropy.py +++ b/gi/evaluation/evaluate_entropy.py @@ -4,8 +4,9 @@ import sys, math, itertools ginfile = open(sys.argv[1]) pinfile = open(sys.argv[2]) -if len(sys.argv) >= 4: +if len(sys.argv) > 3: slash_threshold = int(sys.argv[3]) + #print >>sys.stderr, 'slash threshold', slash_threshold else: slash_threshold = 99999 @@ -28,7 +29,7 @@ for gline, pline in itertools.izip(ginfile, pinfile): gtag = gpart.split(':',1)[1] ptag = ppart.split(':',1)[1] - if gtag.count('/') <= slash_threshold: + if gtag.count('/') + gtag.count('\\') <= slash_threshold: joint_frequencies.setdefault((gtag, ptag), 0) joint_frequencies[gtag,ptag] += 1 @@ -48,4 +49,68 @@ for (gtag, ptag), cgp in joint_frequencies.items(): hg2p /= N hp2g /= N -print 'H(P|G)', hg2p, 'H(G|P)', hp2g, 'VI', hg2p + hp2g +hg = 0 +for gtag, c in gold_frequencies.items(): + hg -= c * (math.log(c, 2) - math.log(N, 2)) +hg /= N + +print 'H(P|G)', hg2p, 'H(G|P)', hp2g, 'VI', hg2p + hp2g, 'H(G)', hg + +# find top tags +gtags = gold_frequencies.items() +gtags.sort(lambda x,y: x[1]-y[1]) +gtags.reverse() +#gtags = gtags[:50] + +print '%7s %7s' % ('pred', 'cnt'), +for gtag, gcount in gtags: print '%7s' % gtag, +print +print '=' * 80 + +preds = predict_frequencies.items() +preds.sort(lambda x,y: x[1]-y[1]) +preds.reverse() +for ptag, pcount in preds: + print '%7s %7d' % (ptag, pcount), + for gtag, gcount in gtags: + print '%7d' % joint_frequencies.get((gtag, ptag), 0), + print + +print '%7s %7d' % ('total', N), +for gtag, gcount in gtags: print '%7d' % gcount, +print + +if len(sys.argv) > 4: + # needs Python Image Library (PIL) + import Image, ImageDraw + + offset=10 + + image = Image.new("RGB", (len(preds), len(gtags)), (255, 255, 255)) + #hsl(hue, saturation%, lightness%) + + # resort preds to get a better diagonal + ptags = [] + remaining = set(predict_frequencies.keys()) + for y, (gtag, gcount) in enumerate(gtags): + best = (None, 0) + for ptag in remaining: + #pcount = predict_frequencies[ptag] + p = joint_frequencies.get((gtag, ptag), 0)# / float(pcount) + if p > best[1]: best = (ptag, p) + ptags.append(ptag) + remaining.remove(ptag) + if not remaining: break + + draw = ImageDraw.Draw(image) + for x, ptag in enumerate(ptags): + pcount = predict_frequencies[ptag] + minval = math.log(offset) + maxval = math.log(pcount + offset) + for y, (gtag, gcount) in enumerate(gtags): + f = math.log(offset + joint_frequencies.get((gtag, ptag), 0)) + z = int(240. * (maxval - f) / float(maxval - minval)) + #print x, y, z, f, maxval + draw.point([(x,y)], fill='hsl(%d, 100%%, 50%%)' % z) + del draw + image.save(sys.argv[4]) diff --git a/gi/evaluation/extract_ccg_labels.py b/gi/evaluation/extract_ccg_labels.py index 5dd6eb65..77f21004 100644 --- a/gi/evaluation/extract_ccg_labels.py +++ b/gi/evaluation/extract_ccg_labels.py @@ -84,8 +84,9 @@ for tline, eline in itertools.izip(tinfile, einfile): if tline.startswith('( '): tline = tline[2:-1].strip() tr = tree.parse_PST(tline) - number_leaves(tr) - #project_heads(tr) # assumes Bikel-style head annotation for the input trees + if tr != None: + number_leaves(tr) + #project_heads(tr) # assumes Bikel-style head annotation for the input trees else: tr = None @@ -96,7 +97,11 @@ for tline, eline in itertools.izip(tinfile, einfile): if tr: a = ancestor(tr, range(x,y)) - fs = frontier(a, range(x,y)) + try: + fs = frontier(a, range(x,y)) + except: + print >>sys.stderr, "problem with line", tline.strip(), "--", eline.strip() + raise #print x, y #print 'ancestor', a |