#!/usr/bin/python
"""Score a word clustering against gold tags using many-to-one accuracy.

Usage: score-mkcls.py gold classes

``gold`` contains whitespace-separated ``term|tag`` tokens (the tag is
whatever follows the last ``|``).  ``classes`` contains one
whitespace-separated ``term class`` pair per line.

Each class is mapped to the gold tag it co-occurs with most often.  Every
gold token is then echoed to stdout as ``term|tag|class|majority_tag`` (one
output line per gold line) and the fraction of tokens whose gold tag equals
the majority tag of their class is written to stderr.

The script only uses constructs that behave the same on Python 2.7 and 3.
"""

import sys
from collections import defaultdict


def dict_max(d):
    """Return the key of ``d`` whose value is largest.

    Ties are resolved in favour of the first such key in iteration order.

    Raises:
        ValueError: if ``d`` is empty.
    """
    if not d:
        # An explicit error instead of ``assert``: asserts vanish under -O and
        # the old ``assert max_key`` also rejected valid falsy keys such as "".
        raise ValueError("dict_max() requires a non-empty dict")
    return max(d, key=d.get)


def _read_classes(lines):
    """Build the term -> class mapping from ``term class`` lines."""
    term_to_topics = {}
    for line in lines:
        fields = line.split()
        if not fields:
            # Tolerate blank lines rather than failing on the unpack below.
            continue
        term, cls = fields
        term_to_topics[term] = cls
    return term_to_topics


def _parse_gold_line(gold_line):
    """Split a gold line into ``(token, term, tag)`` triples."""
    parsed = []
    for gold_token in gold_line.split():
        # rsplit so that a term may itself contain '|'.
        gold_term, gold_tag = gold_token.rsplit('|', 1)
        parsed.append((gold_token, gold_term, gold_tag))
    return parsed


def _count_tags_per_class(gold_lines, term_to_topics):
    """Count, for every class, how often each gold tag occurs in it.

    Raises KeyError if a gold term is missing from ``term_to_topics``.
    """
    topics_to_gold = defaultdict(lambda: defaultdict(int))
    for gold_line in gold_lines:
        for _, gold_term, gold_tag in _parse_gold_line(gold_line):
            topics_to_gold[term_to_topics[gold_term]][gold_tag] += 1
    return topics_to_gold


def _annotate(gold_lines, term_to_topics, class_to_tag, out):
    """Write the annotated tokens to ``out`` and return ``(correct, total)``."""
    correct = 0
    total = 0
    for gold_line in gold_lines:
        annotated = []
        for gold_token, gold_term, gold_tag in _parse_gold_line(gold_line):
            pred_token = term_to_topics[gold_term]
            majority_tag = class_to_tag[pred_token]
            annotated.append(
                "%s|%s|%s" % (gold_token, pred_token, majority_tag))
            total += 1
            if gold_tag == majority_tag:
                correct += 1
        # One output line per gold line, tokens separated by single spaces.
        out.write(" ".join(annotated) + "\n")
    return correct, total


def main(argv=None):
    """Run the scorer; return the process exit status.

    ``argv`` defaults to ``sys.argv`` and must be
    ``[program, gold_path, classes_path]``.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) != 3:
        sys.stdout.write("Usage: score-mkcls.py gold classes\n")
        return 1
    gold_path, classes_path = argv[1], argv[2]

    with open(classes_path, 'r') as classes_file:
        term_to_topics = _read_classes(classes_file)

    # First pass over the gold file: collect class/tag co-occurrence counts.
    with open(gold_path, 'r') as gold_file:
        topics_to_gold = _count_tags_per_class(gold_file, term_to_topics)

    # Resolve each class's majority tag once instead of once per token.
    class_to_tag = {
        cls: dict_max(tag_counts)
        for cls, tag_counts in topics_to_gold.items()
    }

    # Second pass: emit annotated tokens and tally the accuracy.
    with open(gold_path, 'r') as gold_file:
        correct, total = _annotate(
            gold_file, term_to_topics, class_to_tag, sys.stdout)

    # An empty gold file has no tokens; report 0.0 rather than dividing by 0.
    accuracy = float(correct) / total if total else 0.0
    sys.stderr.write("Many-to-One Accuracy = %f\n" % accuracy)
    return 0


if __name__ == "__main__":
    sys.exit(main())