#!/usr/bin/env python

import sys, math, itertools

# Command line: GOLD_FILE PREDICTED_FILE [SLASH_THRESHOLD [IMAGE_FILE]]
if len(sys.argv) < 3:
    # Fail with a usage line instead of an IndexError traceback.
    sys.exit('usage: %s gold-file predicted-file [slash-threshold [image-file]]'
             % sys.argv[0])

ginfile = open(sys.argv[1])
pinfile = open(sys.argv[2])

# Gold tags containing more than this many '/' and '\' characters are left
# out of the statistics; the default is large enough to keep every tag.
if len(sys.argv) > 3:
    slash_threshold = int(sys.argv[3])
else:
    slash_threshold = 99999

# evaluating: H(G | P) = sum_{g,p} p(g,p) log { p(p) / p(g,p) }
#                      = sum_{g,p} c(g,p)/N { log c(p) - log N - log c(g,p) + log N }
#                      = 1/N sum_{g,p} c(g,p) { log c(p) - log c(g,p) }
# where G = gold, P = predicted, N = number of events

# Event counts: per gold tag, per predicted tag and per (gold, predicted)
# pair. N is the total number of tag pairs that were counted.
gold_frequencies = {}
predict_frequencies = {}
joint_frequencies = {}
N = 0

# Walk both files in lockstep, one line from each per iteration.
for gold_line, pred_line in itertools.izip(ginfile, pinfile):
    # Only the field following the first '||| ' separator is used; it is a
    # whitespace-separated list of items.
    gold_items = gold_line.split('||| ')[1].split()
    pred_items = pred_line.split('||| ')[1].split()
    assert len(gold_items) == len(pred_items)

    for gold_item, pred_item in zip(gold_items, pred_items):
        # The tag is whatever follows the first ':' of an item.
        gtag = gold_item.split(':', 1)[1]
        ptag = pred_item.split(':', 1)[1]

        # Skip pairs whose gold tag has too many slashes (either kind).
        slashes = gtag.count('/') + gtag.count('\\')
        if slashes > slash_threshold:
            continue

        pair = (gtag, ptag)
        joint_frequencies[pair] = joint_frequencies.get(pair, 0) + 1
        predict_frequencies[ptag] = predict_frequencies.get(ptag, 0) + 1
        gold_frequencies[gtag] = gold_frequencies.get(gtag, 0) + 1
        N += 1

hg2p = 0
hp2g = 0
for (gtag, ptag), cgp in joint_frequencies.items():
    hp2g += cgp * (math.log(predict_frequencies[ptag], 2) - math.log(cgp, 2))
    hg2p += cgp * (math.log(gold_frequencies[gtag], 2) - math.log(cgp, 2))
hg2p /= N
hp2g /= N

hg = 0
for gtag, c in gold_frequencies.items():
    hg -= c * (math.log(c, 2) - math.log(N, 2))
hg /= N

print 'H(P|G)', hg2p, 'H(G|P)', hp2g, 'VI', hg2p + hp2g, 'H(G)', hg

# find top tags
gtags = gold_frequencies.items()
gtags.sort(lambda x,y: x[1]-y[1])
gtags.reverse()
#gtags = gtags[:50]

print '%7s %7s' % ('pred', 'cnt'),
for gtag, gcount in gtags: print '%7s' % gtag,
print
print '=' * 80

preds = predict_frequencies.items()
preds.sort(lambda x,y: x[1]-y[1])
preds.reverse()
for ptag, pcount in preds:
    print '%7s %7d' % (ptag, pcount),
    for gtag, gcount in gtags:
	print '%7d' % joint_frequencies.get((gtag, ptag), 0),
    print

print '%7s %7d' % ('total', N),
for gtag, gcount in gtags: print '%7d' % gcount,
print

if len(sys.argv) > 4:
    # Optional heat map of the confusion matrix, saved to the path given as
    # the fourth argument. Needs the Python Imaging Library (PIL / Pillow).
    try:
        from PIL import Image, ImageDraw
    except ImportError:
        # Legacy PIL installs expose the modules at top level.
        import Image, ImageDraw

    # Added to every count before taking the log, so that zero counts are
    # representable and small counts are not over-emphasised.
    offset = 10

    # One pixel per (predicted tag, gold tag) pair: x = predicted, y = gold.
    image = Image.new("RGB", (len(preds), len(gtags)), (255, 255, 255))

    # Re-sort the predicted tags to get a better diagonal: walking the gold
    # tags in row order, give each one the still-unplaced predicted tag it
    # co-occurs with most often. Ties (including "never co-occurs") are
    # broken by overall predicted frequency and then by tag name, so the
    # resulting column order is deterministic.
    ptags = []
    remaining = set(predict_frequencies.keys())
    for gtag, gcount in gtags:
        if not remaining:
            break
        best = max(remaining,
                   key=lambda ptag: (joint_frequencies.get((gtag, ptag), 0),
                                     predict_frequencies[ptag],
                                     ptag))
        ptags.append(best)
        remaining.remove(best)

    # With more predicted than gold tags some predicted tags are still
    # unplaced; append them (most frequent first) so that every column of the
    # image is drawn.
    ptags.extend(sorted(remaining,
                        key=lambda ptag: (-predict_frequencies[ptag], ptag)))

    draw = ImageDraw.Draw(image)
    for x, ptag in enumerate(ptags):
        # Each column is scaled on its own log range, from a joint count of
        # zero up to the predicted tag's total count.
        pcount = predict_frequencies[ptag]
        minval = math.log(offset)
        maxval = math.log(pcount + offset)
        for y, (gtag, gcount) in enumerate(gtags):
            f = math.log(offset + joint_frequencies.get((gtag, ptag), 0))
            # Hue 0 (red) when the pair accounts for all of the predicted
            # tag's events, hue 240 (blue) when the pair never occurred.
            z = int(240. * (maxval - f) / float(maxval - minval))
            draw.point([(x, y)], fill='hsl(%d, 100%%, 50%%)' % z)
    del draw
    image.save(sys.argv[4])