summaryrefslogtreecommitdiff
path: root/gi/evaluation
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2012-10-11 14:06:32 -0400
committerChris Dyer <cdyer@cs.cmu.edu>2012-10-11 14:06:32 -0400
commit07ea7b64b6f85e5798a8068453ed9fd2b97396db (patch)
tree644496a1690d84d82a396bbc1e39160788beb2cd /gi/evaluation
parent37b9e45e5cb29d708f7249dbe0b0fb27685282a0 (diff)
parenta36fcc5d55c1de84ae68c1091ebff2b1c32dc3b7 (diff)
Merge branch 'master' of https://github.com/redpony/cdec
Diffstat (limited to 'gi/evaluation')
-rw-r--r--gi/evaluation/conditional_entropy.py61
-rw-r--r--gi/evaluation/confusion_matrix.py123
-rw-r--r--gi/evaluation/entropy.py38
-rw-r--r--gi/evaluation/extract_ccg_labels.py129
-rw-r--r--gi/evaluation/tree.py485
5 files changed, 0 insertions, 836 deletions
diff --git a/gi/evaluation/conditional_entropy.py b/gi/evaluation/conditional_entropy.py
deleted file mode 100644
index 356d3b1d..00000000
--- a/gi/evaluation/conditional_entropy.py
+++ /dev/null
@@ -1,61 +0,0 @@
-#!/usr/bin/env python
-
-import sys, math, itertools, getopt
-
-def usage():
- print >>sys.stderr, 'Usage:', sys.argv[0], '[-s slash_threshold] input-1 input-2'
- sys.exit(0)
-
-optlist, args = getopt.getopt(sys.argv[1:], 'hs:')
-slash_threshold = None
-for opt, arg in optlist:
- if opt == '-s':
- slash_threshold = int(arg)
- else:
- usage()
-if len(args) != 2:
- usage()
-
-ginfile = open(args[0])
-pinfile = open(args[1])
-
-# evaluating: H(G | P) = sum_{g,p} p(g,p) log { p(p) / p(g,p) }
-# = sum_{g,p} c(g,p)/N { log c(p) - log N - log c(g,p) + log N }
-# = 1/N sum_{g,p} c(g,p) { log c(p) - log c(g,p) }
-# where G = gold, P = predicted, N = number of events
-
-N = 0
-gold_frequencies = {}
-predict_frequencies = {}
-joint_frequencies = {}
-
-for gline, pline in itertools.izip(ginfile, pinfile):
- gparts = gline.split('||| ')[1].split()
- pparts = pline.split('||| ')[1].split()
- assert len(gparts) == len(pparts)
-
- for gpart, ppart in zip(gparts, pparts):
- gtag = gpart.split(':',1)[1]
- ptag = ppart.split(':',1)[1]
-
- if slash_threshold == None or gtag.count('/') + gtag.count('\\') <= slash_threshold:
- joint_frequencies.setdefault((gtag, ptag), 0)
- joint_frequencies[gtag,ptag] += 1
-
- predict_frequencies.setdefault(ptag, 0)
- predict_frequencies[ptag] += 1
-
- gold_frequencies.setdefault(gtag, 0)
- gold_frequencies[gtag] += 1
-
- N += 1
-
-hg2p = 0
-hp2g = 0
-for (gtag, ptag), cgp in joint_frequencies.items():
- hp2g += cgp * (math.log(predict_frequencies[ptag], 2) - math.log(cgp, 2))
- hg2p += cgp * (math.log(gold_frequencies[gtag], 2) - math.log(cgp, 2))
-hg2p /= N
-hp2g /= N
-
-print 'H(P|G)', hg2p, 'H(G|P)', hp2g, 'VI', hg2p + hp2g
diff --git a/gi/evaluation/confusion_matrix.py b/gi/evaluation/confusion_matrix.py
deleted file mode 100644
index 2dd7aa47..00000000
--- a/gi/evaluation/confusion_matrix.py
+++ /dev/null
@@ -1,123 +0,0 @@
-#!/usr/bin/env python
-
-import sys, math, itertools, getopt
-
-def usage():
- print >>sys.stderr, 'Usage:', sys.argv[0], '[-s slash_threshold] [-p output] [-m] input-1 input-2'
- sys.exit(0)
-
-optlist, args = getopt.getopt(sys.argv[1:], 'hs:mp:')
-slash_threshold = None
-output_fname = None
-show_matrix = False
-for opt, arg in optlist:
- if opt == '-s':
- slash_threshold = int(arg)
- elif opt == '-p':
- output_fname = arg
- elif opt == '-m':
- show_matrix = True
- else:
- usage()
-if len(args) != 2 or (not show_matrix and not output_fname):
- usage()
-
-ginfile = open(args[0])
-pinfile = open(args[1])
-
-if output_fname:
- try:
- import Image, ImageDraw
- except ImportError:
- print >>sys.stderr, "Error: Python Image Library not available. Did you forget to set your PYTHONPATH environment variable?"
- sys.exit(1)
-
-N = 0
-gold_frequencies = {}
-predict_frequencies = {}
-joint_frequencies = {}
-
-for gline, pline in itertools.izip(ginfile, pinfile):
- gparts = gline.split('||| ')[1].split()
- pparts = pline.split('||| ')[1].split()
- assert len(gparts) == len(pparts)
-
- for gpart, ppart in zip(gparts, pparts):
- gtag = gpart.split(':',1)[1]
- ptag = ppart.split(':',1)[1]
-
- if slash_threshold == None or gtag.count('/') + gtag.count('\\') <= slash_threshold:
- joint_frequencies.setdefault((gtag, ptag), 0)
- joint_frequencies[gtag,ptag] += 1
-
- predict_frequencies.setdefault(ptag, 0)
- predict_frequencies[ptag] += 1
-
- gold_frequencies.setdefault(gtag, 0)
- gold_frequencies[gtag] += 1
-
- N += 1
-
-# find top tags
-gtags = gold_frequencies.items()
-gtags.sort(lambda x,y: x[1]-y[1])
-gtags.reverse()
-#gtags = gtags[:50]
-
-preds = predict_frequencies.items()
-preds.sort(lambda x,y: x[1]-y[1])
-preds.reverse()
-
-if show_matrix:
- print '%7s %7s' % ('pred', 'cnt'),
- for gtag, gcount in gtags: print '%7s' % gtag,
- print
- print '=' * 80
-
- for ptag, pcount in preds:
- print '%7s %7d' % (ptag, pcount),
- for gtag, gcount in gtags:
- print '%7d' % joint_frequencies.get((gtag, ptag), 0),
- print
-
- print '%7s %7d' % ('total', N),
- for gtag, gcount in gtags: print '%7d' % gcount,
- print
-
-if output_fname:
- offset=10
-
- image = Image.new("RGB", (len(preds), len(gtags)), (255, 255, 255))
- #hsl(hue, saturation%, lightness%)
-
- # re-sort preds to get a better diagonal
- ptags=[]
- if True:
- ptags = map(lambda (p,c): p, preds)
- else:
- remaining = set(predict_frequencies.keys())
- for y, (gtag, gcount) in enumerate(gtags):
- best = (None, 0)
- for ptag in remaining:
- #pcount = predict_frequencies[ptag]
- p = joint_frequencies.get((gtag, ptag), 0)# / float(pcount)
- if p > best[1]: best = (ptag, p)
- ptags.append(ptag)
- remaining.remove(ptag)
- if not remaining: break
-
- print 'Predicted tag ordering:', ' '.join(ptags)
- print 'Gold tag ordering:', ' '.join(map(lambda (t,c): t, gtags))
-
- draw = ImageDraw.Draw(image)
- for x, ptag in enumerate(ptags):
- pcount = predict_frequencies[ptag]
- minval = math.log(offset)
- maxval = math.log(pcount + offset)
- for y, (gtag, gcount) in enumerate(gtags):
- f = math.log(offset + joint_frequencies.get((gtag, ptag), 0))
- z = int(240. * (maxval - f) / float(maxval - minval))
- #print x, y, z, f, maxval
- draw.point([(x,y)], fill='hsl(%d, 100%%, 50%%)' % z)
- del draw
- image.save(output_fname)
diff --git a/gi/evaluation/entropy.py b/gi/evaluation/entropy.py
deleted file mode 100644
index ec1ef502..00000000
--- a/gi/evaluation/entropy.py
+++ /dev/null
@@ -1,38 +0,0 @@
-#!/usr/bin/env python
-
-import sys, math, itertools, getopt
-
-def usage():
- print >>sys.stderr, 'Usage:', sys.argv[0], '[-s slash_threshold] input file'
- sys.exit(0)
-
-optlist, args = getopt.getopt(sys.argv[1:], 'hs:')
-slash_threshold = None
-for opt, arg in optlist:
- if opt == '-s':
- slash_threshold = int(arg)
- else:
- usage()
-if len(args) != 1:
- usage()
-
-infile = open(args[0])
-N = 0
-frequencies = {}
-
-for line in infile:
-
- for part in line.split('||| ')[1].split():
- tag = part.split(':',1)[1]
-
- if slash_threshold == None or tag.count('/') + tag.count('\\') <= slash_threshold:
- frequencies.setdefault(tag, 0)
- frequencies[tag] += 1
- N += 1
-
-h = 0
-for tag, c in frequencies.items():
- h -= c * (math.log(c, 2) - math.log(N, 2))
-h /= N
-
-print 'entropy', h
diff --git a/gi/evaluation/extract_ccg_labels.py b/gi/evaluation/extract_ccg_labels.py
deleted file mode 100644
index e0034648..00000000
--- a/gi/evaluation/extract_ccg_labels.py
+++ /dev/null
@@ -1,129 +0,0 @@
-#!/usr/bin/env python
-
-#
-# Takes spans input along with treebank and spits out CG style categories for each span.
-# spans = output from CDEC's extools/extractor with --base_phrase_spans option
-# treebank = PTB format, one tree per line
-#
-# Output is in CDEC labelled-span format
-#
-
-import sys, itertools, tree
-
-tinfile = open(sys.argv[1])
-einfile = open(sys.argv[2])
-
-def number_leaves(node, next=0):
- left, right = None, None
- for child in node.children:
- l, r = number_leaves(child, next)
- next = max(next, r+1)
- if left == None or l < left:
- left = l
- if right == None or r > right:
- right = r
-
- #print node, left, right, next
- if left == None or right == None:
- assert not node.children
- left = right = next
-
- node.left = left
- node.right = right
-
- return left, right
-
-def ancestor(node, indices):
- #print node, node.left, node.right, indices
- # returns the deepest node covering all the indices
- if min(indices) >= node.left and max(indices) <= node.right:
- # try the children
- for child in node.children:
- x = ancestor(child, indices)
- if x: return x
- return node
- else:
- return None
-
-def frontier(node, indices):
- #print 'frontier for node', node, 'indices', indices
- if node.left > max(indices) or node.right < min(indices):
- #print '\toutside'
- return [node]
- elif node.children:
- #print '\tcovering at least part'
- ns = []
- for child in node.children:
- n = frontier(child, indices)
- ns.extend(n)
- return ns
- else:
- return [node]
-
-def project_heads(node):
- #print 'project_heads', node
- is_head = node.data.tag.endswith('-HEAD')
- if node.children:
- found = 0
- for child in node.children:
- x = project_heads(child)
- if x:
- node.data.tag = x
- found += 1
- assert found == 1
- elif is_head:
- node.data.tag = node.data.tag[:-len('-HEAD')]
-
- if is_head:
- return node.data.tag
- else:
- return None
-
-for tline, eline in itertools.izip(tinfile, einfile):
- if tline.strip() != '(())':
- if tline.startswith('( '):
- tline = tline[2:-1].strip()
- tr = tree.parse_PST(tline)
- if tr != None:
- number_leaves(tr)
- #project_heads(tr) # assumes Bikel-style head annotation for the input trees
- else:
- tr = None
-
- parts = eline.strip().split(" ||| ")
- zh, en = parts[:2]
- spans = parts[-1]
- print '|||',
- for span in spans.split():
- sps = span.split(":")
- i, j, x, y = map(int, sps[0].split("-"))
-
- if tr:
- a = ancestor(tr, range(x,y))
- try:
- fs = frontier(a, range(x,y))
- except:
- print >>sys.stderr, "problem with line", tline.strip(), "--", eline.strip()
- raise
-
- #print x, y
- #print 'ancestor', a
- #print 'frontier', fs
-
- cat = a.data.tag
- for f in fs:
- if f.right < x:
- cat += '\\' + f.data.tag
- else:
- break
- fs.reverse()
- for f in fs:
- if f.left >= y:
- cat += '/' + f.data.tag
- else:
- break
- else:
- cat = 'FAIL'
-
- print '%d-%d:%s' % (x, y, cat),
- print
diff --git a/gi/evaluation/tree.py b/gi/evaluation/tree.py
deleted file mode 100644
index 702d80b6..00000000
--- a/gi/evaluation/tree.py
+++ /dev/null
@@ -1,485 +0,0 @@
-import re, sys
-
-class Symbol:
- def __init__(self, nonterm, term=None, var=None):
- assert not (term != None and var != None)
- self.tag = nonterm
- self.token = term
- self.variable = var
-
- def is_variable(self):
- return self.variable != None
-
- def __eq__(self, other):
- return self.tag == other.tag and self.token == other.token and self.variable == other.variable
-
- def __ne__(self, other):
- return not (self == other)
-
- def __hash__(self):
- return hash((self.tag, self.token, self.variable))
-
- def __repr__(self):
- return str(self)
-
- def __cmp__(self, other):
- return cmp((self.tag, self.token, self.variable),
- (other.tag, other.token, other.variable))
-
- def __str__(self):
- parts = []
- if False: # DEPENDENCY
- if self.token:
- parts.append(str(self.token))
- elif self.variable != None:
- parts.append('#%d' % self.variable)
- if self.tag:
- parts.append(str(self.tag))
- return '/'.join(parts)
- else:
- if self.tag:
- parts.append(str(self.tag))
- if self.token:
- parts.append(str(self.token))
- elif self.variable != None:
- parts.append('#%d' % self.variable)
- return ' '.join(parts)
-
-class TreeNode:
- def __init__(self, data, children=None, order=-1):
- self.data = data
- self.children = []
- self.order = order
- self.parent = None
- if children: self.children = children
-
- def insert(self, child):
- self.children.append(child)
- child.parent = self
-
- def leaves(self):
- ls = []
- for node in self.xtraversal():
- if not node.children:
- ls.append(node.data)
- return ls
-
- def leaf_nodes(self):
- ls = []
- for node in self.xtraversal():
- if not node.children:
- ls.append(node)
- return ls
-
- def max_depth(self):
- d = 1
- for child in self.children:
- d = max(d, 1 + child.max_depth())
- if not self.children and self.data.token:
- d = 2
- return d
-
- def max_width(self):
- w = 0
- for child in self.children:
- w += child.max_width()
- return max(1, w)
-
- def num_internal_nodes(self):
- if self.children:
- n = 1
- for child in self.children:
- n += child.num_internal_nodes()
- return n
- elif self.data.token:
- return 1
- else:
- return 0
-
- def postorder_traversal(self, visit):
- """
- Postorder traversal; no guarantee that terminals will be read in the
- correct order for dep. trees.
- """
- for child in self.children:
- child.traversal(visit)
- visit(self)
-
- def traversal(self, visit):
- """
- Preorder for phrase structure trees, and inorder for dependency trees.
- In both cases the terminals will be read off in the correct order.
- """
- visited_self = False
- if self.order <= 0:
- visited_self = True
- visit(self)
-
- for i, child in enumerate(self.children):
- child.traversal(visit)
- if i + 1 == self.order:
- visited_self = True
- visit(self)
-
- assert visited_self
-
- def xpostorder_traversal(self):
- for child in self.children:
- for node in child.xpostorder_traversal():
- yield node
- yield self
-
- def xtraversal(self):
- visited_self = False
- if self.order <= 0:
- visited_self = True
- yield self
-
- for i, child in enumerate(self.children):
- for d in child.xtraversal():
- yield d
-
- if i + 1 == self.order:
- visited_self = True
- yield self
-
- assert visited_self
-
- def xpostorder_traversal(self):
- for i, child in enumerate(self.children):
- for d in child.xpostorder_traversal():
- yield d
- yield self
-
- def edges(self):
- es = []
- self.traverse_edges(lambda h,c: es.append((h,c)))
- return es
-
- def traverse_edges(self, visit):
- for child in self.children:
- visit(self.data, child.data)
- child.traverse_edges(visit)
-
- def subtrees(self, include_self=False):
- st = []
- if include_self:
- stack = [self]
- else:
- stack = self.children[:]
-
- while stack:
- node = stack.pop()
- st.append(node)
- stack.extend(node.children)
- return st
-
- def find_parent(self, node):
- try:
- index = self.children.index(node)
- return self, index
- except ValueError:
- for child in self.children:
- if isinstance(child, TreeNode):
- r = child.find_parent(node)
- if r: return r
- return None
-
- def is_ancestor_of(self, node):
- if self == node:
- return True
- for child in self.children:
- if child.is_ancestor_of(child):
- return True
- return False
-
- def find(self, node):
- if self == node:
- return self
- for child in self.children:
- if isinstance(child, TreeNode):
- r = child.find(node)
- if r: return r
- else:
- if child == node:
- return r
- return None
-
- def equals_ignorecase(self, other):
- if not isinstance(other, TreeNode):
- return False
- if self.data != other.data:
- return False
- if len(self.children) != len(other.children):
- return False
- for mc, oc in zip(self.children, other.children):
- if isinstance(mc, TreeNode):
- if not mc.equals_ignorecase(oc):
- return False
- else:
- if mc.lower() != oc.lower():
- return False
- return True
-
- def node_number(self, numbering, next=0):
- if self.order <= 0:
- numbering[id(self)] = next
- next += 1
-
- for i, child in enumerate(self.children):
- next = child.node_number(numbering, next)
- if i + 1 == self.order:
- numbering[id(self)] = next
- next += 1
-
- return next
-
- def display_conll(self, out):
- numbering = {}
- self.node_number(numbering)
- next = 0
- self.children[0].traversal(lambda x: \
- out.write('%d\t%s\t%s\t%s\t%s\t_\t%d\tLAB\n' \
- % (numbering[id(x)], x.data.token, x.data.token,
- x.data.tag, x.data.tag, numbering[id(x.parent)])))
- out.write('\n')
-
- def size(self):
- sz = 1
- for child in self.children:
- sz += child.size()
- return sz
-
- def __eq__(self, other):
- if isinstance(other, TreeNode) and self.data == other.data \
- and self.children == other.children:
- return True
- return False
-
- def __cmp__(self, other):
- if not isinstance(other, TreeNode): return 1
- n = cmp(self.data, other.data)
- if n != 0: return n
- n = len(self.children) - len(other.children)
- if n != 0: return n
- for sc, oc in zip(self.children, other.children):
- n = cmp(sc, oc)
- if n != 0: return n
- return 0
-
- def __ne__(self, other):
- return not self.__eq__(other)
-
- def __hash__(self):
- return hash((self.data, tuple(self.children)))
-
- def __repr__(self):
- return str(self)
-
- def __str__(self):
- s = '('
- space = False
- if self.order <= 0:
- s += str(self.data)
- space = True
- for i, child in enumerate(self.children):
- if space: s += ' '
- s += str(child)
- space = True
- if i+1 == self.order:
- s += ' ' + str(self.data)
- return s + ')'
-
-def read_PSTs(fname):
- infile = open(fname)
- trees = []
- for line in infile:
- trees.append(parse_PST(line.strip()))
- infile.close()
- return trees
-
-def parse_PST_multiline(infile, hash_is_var=True):
- buf = ''
- num_open = 0
- while True:
- line = infile.readline()
- if not line:
- return None
- buf += ' ' + line.rstrip()
- num_open += line.count('(') - line.count(')')
- if num_open == 0:
- break
-
- return parse_PST(buf, hash_is_var)
-
-def parse_PST(line, hash_is_var=True):
- line = line.rstrip()
- if not line or line.lower() == 'null':
- return None
-
- # allow either (a/DT) or (DT a)
- #parts_re = re.compile(r'(\(*)([^/)]*)(?:/([^)]*))?(\)*)$')
-
- # only allow (DT a)
- parts_re = re.compile(r'(\(*)([^)]*)(\)*)$')
-
- root = TreeNode(Symbol('TOP'))
- stack = [root]
- for part in line.rstrip().split():
- m = parts_re.match(part)
- #opening, tok_or_tag, tag, closing = m.groups()
- opening, tok_or_tag, closing = m.groups()
- tag = None
- #print 'token', part, 'bits', m.groups()
- for i in opening:
- node = TreeNode(Symbol(None))
- stack[-1].insert(node)
- stack.append(node)
-
- if tag:
- stack[-1].data.tag = tag
- if hash_is_var and tok_or_tag.startswith('#'):
- stack[-1].data.variable = int(tok_or_tag[1:])
- else:
- stack[-1].data.token = tok_or_tag
- else:
- if stack[-1].data.tag == None:
- stack[-1].data.tag = tok_or_tag
- else:
- if hash_is_var and tok_or_tag.startswith('#'):
- try:
- stack[-1].data.variable = int(tok_or_tag[1:])
- except ValueError: # it's really a token!
- #print >>sys.stderr, 'Warning: # used for token:', tok_or_tag
- stack[-1].data.token = tok_or_tag
- else:
- stack[-1].data.token = tok_or_tag
-
- for i in closing:
- stack.pop()
-
- #assert str(root.children[0]) == line
- return root.children[0]
-
-def read_DTs(fname):
- infile = open(fname)
- trees = []
- while True:
- t = parse_DT(infile)
- if t: trees.append(t)
- else: break
- infile.close()
- return trees
-
-def read_bracketed_DTs(fname):
- infile = open(fname)
- trees = []
- for line in infile:
- trees.append(parse_bracketed_DT(line))
- infile.close()
- return trees
-
-def parse_DT(infile):
- tokens = [Symbol('ROOT')]
- children = {}
-
- for line in infile:
- parts = line.rstrip().split()
- #print parts
- if not parts: break
- index = len(tokens)
- token = parts[1]
- tag = parts[3]
- parent = int(parts[6])
- if token.startswith('#'):
- tokens.append(Symbol(tag, var=int(token[1:])))
- else:
- tokens.append(Symbol(tag, token))
- children.setdefault(parent, set()).add(index)
-
- if len(tokens) == 1: return None
-
- root = TreeNode(Symbol('ROOT'), [], 0)
- schedule = []
- for child in sorted(children[0]):
- schedule.append((root, child))
-
- while schedule:
- parent, index = schedule[0]
- del schedule[0]
-
- node = TreeNode(tokens[index])
- node.order = 0
- parent.insert(node)
-
- for child in sorted(children.get(index, [])):
- schedule.append((node, child))
- if child < index:
- node.order += 1
-
- return root
-
-_bracket_split_re = re.compile(r'([(]*)([^)/]*)(?:/([^)]*))?([)]*)')
-
-def parse_bracketed_DT(line, insert_root=True):
- line = line.rstrip()
- if not line or line == 'NULL': return None
- #print line
-
- root = TreeNode(Symbol('ROOT'))
- stack = [root]
- for part in line.rstrip().split():
- m = _bracket_split_re.match(part)
-
- for c in m.group(1):
- node = TreeNode(Symbol(None))
- stack[-1].insert(node)
- stack.append(node)
-
- if m.group(3) != None:
- if m.group(2).startswith('#'):
- stack[-1].data.variable = int(m.group(2)[1:])
- else:
- stack[-1].data.token = m.group(2)
- stack[-1].data.tag = m.group(3)
- else:
- stack[-1].data.tag = m.group(2)
- stack[-1].order = len(stack[-1].children)
- # FIXME: also check for vars
-
- for c in m.group(4):
- stack.pop()
-
- assert len(stack) == 1
- if not insert_root or root.children[0].data.tag == 'ROOT':
- return root.children[0]
- else:
- return root
-
-_bracket_split_notag_re = re.compile(r'([(]*)([^)/]*)([)]*)')
-
-def parse_bracketed_untagged_DT(line):
- line = line.rstrip()
- if not line or line == 'NULL': return None
-
- root = TreeNode(Symbol('TOP'))
- stack = [root]
- for part in line.rstrip().split():
- m = _bracket_split_notag_re.match(part)
-
- for c in m.group(1):
- node = TreeNode(Symbol(None))
- stack[-1].insert(node)
- stack.append(node)
-
- if stack[-1].data.token == None:
- stack[-1].data.token = m.group(2)
- stack[-1].order = len(stack[-1].children)
- else:
- child = TreeNode(Symbol(nonterm=None, term=m.group(2)))
- stack[-1].insert(child)
-
- for c in m.group(3):
- stack.pop()
-
- return root.children[0]