diff options
| author | trevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-20 18:37:04 +0000 | 
|---|---|---|
| committer | trevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-20 18:37:04 +0000 | 
| commit | f281f2deac864d57a0eb566ae1f1c203ee5a8623 (patch) | |
| tree | 9de5753f91edab5b89fd40152360f0e7135818cb /gi/evaluation/conditional_entropy.py | |
| parent | 9380fb4819f3ed56cb7ad77a43728718039389cc (diff) | |
Cleaned up scripts
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@336 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/evaluation/conditional_entropy.py')
| -rw-r--r-- | gi/evaluation/conditional_entropy.py | 61 | 
1 files changed, 61 insertions, 0 deletions
diff --git a/gi/evaluation/conditional_entropy.py b/gi/evaluation/conditional_entropy.py new file mode 100644 index 00000000..356d3b1d --- /dev/null +++ b/gi/evaluation/conditional_entropy.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python + +import sys, math, itertools, getopt + +def usage(): +    print >>sys.stderr, 'Usage:', sys.argv[0], '[-s slash_threshold] input-1 input-2' +    sys.exit(0) + +optlist, args = getopt.getopt(sys.argv[1:], 'hs:') +slash_threshold = None +for opt, arg in optlist: +    if opt == '-s': +        slash_threshold = int(arg) +    else: +        usage() +if len(args) != 2: +    usage() + +ginfile = open(args[0]) +pinfile = open(args[1]) + +# evaluating: H(G | P) = sum_{g,p} p(g,p) log { p(p) / p(g,p) } +#                      = sum_{g,p} c(g,p)/N { log c(p) - log N - log c(g,p) + log N } +#                      = 1/N sum_{g,p} c(g,p) { log c(p) - log c(g,p) } +# where G = gold, P = predicted, N = number of events + +N = 0 +gold_frequencies = {} +predict_frequencies = {} +joint_frequencies = {} + +for gline, pline in itertools.izip(ginfile, pinfile): +    gparts = gline.split('||| ')[1].split() +    pparts = pline.split('||| ')[1].split() +    assert len(gparts) == len(pparts) + +    for gpart, ppart in zip(gparts, pparts): +        gtag = gpart.split(':',1)[1] +        ptag = ppart.split(':',1)[1] + +        if slash_threshold == None or gtag.count('/') + gtag.count('\\') <= slash_threshold: +            joint_frequencies.setdefault((gtag, ptag), 0) +            joint_frequencies[gtag,ptag] += 1 + +            predict_frequencies.setdefault(ptag, 0) +            predict_frequencies[ptag] += 1 + +            gold_frequencies.setdefault(gtag, 0) +            gold_frequencies[gtag] += 1 + +            N += 1 + +hg2p = 0 +hp2g = 0 +for (gtag, ptag), cgp in joint_frequencies.items(): +    hp2g += cgp * (math.log(predict_frequencies[ptag], 2) - math.log(cgp, 2)) +    hg2p += cgp * (math.log(gold_frequencies[gtag], 2) - math.log(cgp, 2)) +hg2p /= N +hp2g /= N + +print 'H(P|G)', hg2p, 'H(G|P)', hp2g, 'VI', hg2p + hp2g  | 
