summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--gi/evaluation/evaluate_entropy.py46
-rw-r--r--gi/evaluation/extract_ccg_labels.py100
-rw-r--r--gi/evaluation/tree.py485
3 files changed, 631 insertions, 0 deletions
diff --git a/gi/evaluation/evaluate_entropy.py b/gi/evaluation/evaluate_entropy.py
new file mode 100644
index 00000000..88533544
--- /dev/null
+++ b/gi/evaluation/evaluate_entropy.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python
+
+import sys, math, itertools
+
+ginfile = open(sys.argv[1])
+pinfile = open(sys.argv[2])
+
+# evaluating: H(G | P) = sum_{g,p} p(g,p) log { p(p) / p(g,p) }
+# = sum_{g,p} c(g,p)/N { log c(p) - log N - log c(g,p) + log N }
+# = 1/N sum_{g,p} c(g,p) { log c(p) - log c(g,p) }
+# where G = gold, P = predicted, N = number of events
+
+N = 0
+gold_frequencies = {}
+predict_frequencies = {}
+joint_frequencies = {}
+
+for gline, pline in itertools.izip(ginfile, pinfile):
+ gparts = gline.split('||| ')[1].split()
+ pparts = pline.split('||| ')[1].split()
+ assert len(gparts) == len(pparts)
+
+ for gpart, ppart in zip(gparts, pparts):
+ gtag = gpart.split(':',1)[1]
+ ptag = ppart.split(':',1)[1]
+
+ joint_frequencies.setdefault((gtag, ptag), 0)
+ joint_frequencies[gtag,ptag] += 1
+
+ predict_frequencies.setdefault(ptag, 0)
+ predict_frequencies[ptag] += 1
+
+ gold_frequencies.setdefault(gtag, 0)
+ gold_frequencies[gtag] += 1
+
+ N += 1
+
+hg2p = 0
+hp2g = 0
+for (gtag, ptag), cgp in joint_frequencies.items():
+ hp2g += cgp * (math.log(predict_frequencies[ptag], 2) - math.log(cgp, 2))
+ hg2p += cgp * (math.log(gold_frequencies[gtag], 2) - math.log(cgp, 2))
+hg2p /= N
+hp2g /= N
+
+print 'H(P|G)', hg2p, 'H(G|P)', hp2g, 'VI', hg2p + hp2g
diff --git a/gi/evaluation/extract_ccg_labels.py b/gi/evaluation/extract_ccg_labels.py
new file mode 100644
index 00000000..014e0399
--- /dev/null
+++ b/gi/evaluation/extract_ccg_labels.py
@@ -0,0 +1,100 @@
+#!/usr/bin/env python
+
+#
+# Takes spans input along with treebank and spits out CG style categories for each span.
+# spans = output from CDEC's extools/extractor with --base_phrase_spans option
+# treebank = PTB format, one tree per line
+#
+# Output is in CDEC labelled-span format
+#
+
+import sys, itertools, tree
+
+tinfile = open(sys.argv[1])
+einfile = open(sys.argv[2])
+
+def number_leaves(node, next=0):
+ left, right = None, None
+ for child in node.children:
+ l, r = number_leaves(child, next)
+ next = max(next, r+1)
+ if left == None or l < left:
+ left = l
+ if right == None or r > right:
+ right = r
+
+ #print node, left, right, next
+ if left == None or right == None:
+ assert not node.children
+ left = right = next
+
+ node.left = left
+ node.right = right
+
+ return left, right
+
+def ancestor(node, indices):
+ #print node, node.left, node.right, indices
+ # returns the deepest node covering all the indices
+ if min(indices) >= node.left and max(indices) <= node.right:
+ # try the children
+ for child in node.children:
+ x = ancestor(child, indices)
+ if x: return x
+ return node
+ else:
+ return None
+
+def frontier(node, indices):
+ #print 'frontier for node', node, 'indices', indices
+ if node.left > max(indices) or node.right < min(indices):
+ #print '\toutside'
+ return [node]
+ elif node.children:
+ #print '\tcovering at least part'
+ ns = []
+ for child in node.children:
+ n = frontier(child, indices)
+ ns.extend(n)
+ return ns
+ else:
+ return [node]
+
+for tline, eline in itertools.izip(tinfile, einfile):
+ if tline.strip() != '(())':
+ if tline.startswith('( '):
+ tline = tline[2:-1].strip()
+ tr = tree.parse_PST(tline)
+ number_leaves(tr)
+ else:
+ tr = None
+
+ zh, en, spans = eline.strip().split(" ||| ")
+ print '|||',
+ for span in spans.split():
+ i, j, x, y = map(int, span.split("-"))
+
+ if tr:
+ a = ancestor(tr, range(x,y))
+ fs = frontier(a, range(x,y))
+
+ #print x, y
+ #print 'ancestor', a
+ #print 'frontier', fs
+
+ cat = a.data.tag
+ for f in fs:
+ if f.right < x:
+ cat += '\\' + f.data.tag
+ else:
+ break
+ for f in reversed(fs):
+ if f.left >= y:
+ cat += '/' + f.data.tag
+ else:
+ break
+ else:
+ cat = 'FAIL'
+
+ print '%d-%d:%s' % (x, y, cat),
+ print
diff --git a/gi/evaluation/tree.py b/gi/evaluation/tree.py
new file mode 100644
index 00000000..702d80b6
--- /dev/null
+++ b/gi/evaluation/tree.py
@@ -0,0 +1,485 @@
+import re, sys
+
+class Symbol:
+ def __init__(self, nonterm, term=None, var=None):
+ assert not (term != None and var != None)
+ self.tag = nonterm
+ self.token = term
+ self.variable = var
+
+ def is_variable(self):
+ return self.variable != None
+
+ def __eq__(self, other):
+ return self.tag == other.tag and self.token == other.token and self.variable == other.variable
+
+ def __ne__(self, other):
+ return not (self == other)
+
+ def __hash__(self):
+ return hash((self.tag, self.token, self.variable))
+
+ def __repr__(self):
+ return str(self)
+
+ def __cmp__(self, other):
+ return cmp((self.tag, self.token, self.variable),
+ (other.tag, other.token, other.variable))
+
+ def __str__(self):
+ parts = []
+ if False: # DEPENDENCY
+ if self.token:
+ parts.append(str(self.token))
+ elif self.variable != None:
+ parts.append('#%d' % self.variable)
+ if self.tag:
+ parts.append(str(self.tag))
+ return '/'.join(parts)
+ else:
+ if self.tag:
+ parts.append(str(self.tag))
+ if self.token:
+ parts.append(str(self.token))
+ elif self.variable != None:
+ parts.append('#%d' % self.variable)
+ return ' '.join(parts)
+
+class TreeNode:
+ def __init__(self, data, children=None, order=-1):
+ self.data = data
+ self.children = []
+ self.order = order
+ self.parent = None
+ if children: self.children = children
+
+ def insert(self, child):
+ self.children.append(child)
+ child.parent = self
+
+ def leaves(self):
+ ls = []
+ for node in self.xtraversal():
+ if not node.children:
+ ls.append(node.data)
+ return ls
+
+ def leaf_nodes(self):
+ ls = []
+ for node in self.xtraversal():
+ if not node.children:
+ ls.append(node)
+ return ls
+
+ def max_depth(self):
+ d = 1
+ for child in self.children:
+ d = max(d, 1 + child.max_depth())
+ if not self.children and self.data.token:
+ d = 2
+ return d
+
+ def max_width(self):
+ w = 0
+ for child in self.children:
+ w += child.max_width()
+ return max(1, w)
+
+ def num_internal_nodes(self):
+ if self.children:
+ n = 1
+ for child in self.children:
+ n += child.num_internal_nodes()
+ return n
+ elif self.data.token:
+ return 1
+ else:
+ return 0
+
+ def postorder_traversal(self, visit):
+ """
+ Postorder traversal; no guarantee that terminals will be read in the
+ correct order for dep. trees.
+ """
+ for child in self.children:
+ child.traversal(visit)
+ visit(self)
+
+ def traversal(self, visit):
+ """
+ Preorder for phrase structure trees, and inorder for dependency trees.
+ In both cases the terminals will be read off in the correct order.
+ """
+ visited_self = False
+ if self.order <= 0:
+ visited_self = True
+ visit(self)
+
+ for i, child in enumerate(self.children):
+ child.traversal(visit)
+ if i + 1 == self.order:
+ visited_self = True
+ visit(self)
+
+ assert visited_self
+
+ def xpostorder_traversal(self):
+ for child in self.children:
+ for node in child.xpostorder_traversal():
+ yield node
+ yield self
+
+ def xtraversal(self):
+ visited_self = False
+ if self.order <= 0:
+ visited_self = True
+ yield self
+
+ for i, child in enumerate(self.children):
+ for d in child.xtraversal():
+ yield d
+
+ if i + 1 == self.order:
+ visited_self = True
+ yield self
+
+ assert visited_self
+
+ def xpostorder_traversal(self):
+ for i, child in enumerate(self.children):
+ for d in child.xpostorder_traversal():
+ yield d
+ yield self
+
+ def edges(self):
+ es = []
+ self.traverse_edges(lambda h,c: es.append((h,c)))
+ return es
+
+ def traverse_edges(self, visit):
+ for child in self.children:
+ visit(self.data, child.data)
+ child.traverse_edges(visit)
+
+ def subtrees(self, include_self=False):
+ st = []
+ if include_self:
+ stack = [self]
+ else:
+ stack = self.children[:]
+
+ while stack:
+ node = stack.pop()
+ st.append(node)
+ stack.extend(node.children)
+ return st
+
+ def find_parent(self, node):
+ try:
+ index = self.children.index(node)
+ return self, index
+ except ValueError:
+ for child in self.children:
+ if isinstance(child, TreeNode):
+ r = child.find_parent(node)
+ if r: return r
+ return None
+
+ def is_ancestor_of(self, node):
+ if self == node:
+ return True
+ for child in self.children:
+ if child.is_ancestor_of(child):
+ return True
+ return False
+
+ def find(self, node):
+ if self == node:
+ return self
+ for child in self.children:
+ if isinstance(child, TreeNode):
+ r = child.find(node)
+ if r: return r
+ else:
+ if child == node:
+ return r
+ return None
+
+ def equals_ignorecase(self, other):
+ if not isinstance(other, TreeNode):
+ return False
+ if self.data != other.data:
+ return False
+ if len(self.children) != len(other.children):
+ return False
+ for mc, oc in zip(self.children, other.children):
+ if isinstance(mc, TreeNode):
+ if not mc.equals_ignorecase(oc):
+ return False
+ else:
+ if mc.lower() != oc.lower():
+ return False
+ return True
+
+ def node_number(self, numbering, next=0):
+ if self.order <= 0:
+ numbering[id(self)] = next
+ next += 1
+
+ for i, child in enumerate(self.children):
+ next = child.node_number(numbering, next)
+ if i + 1 == self.order:
+ numbering[id(self)] = next
+ next += 1
+
+ return next
+
+ def display_conll(self, out):
+ numbering = {}
+ self.node_number(numbering)
+ next = 0
+ self.children[0].traversal(lambda x: \
+ out.write('%d\t%s\t%s\t%s\t%s\t_\t%d\tLAB\n' \
+ % (numbering[id(x)], x.data.token, x.data.token,
+ x.data.tag, x.data.tag, numbering[id(x.parent)])))
+ out.write('\n')
+
+ def size(self):
+ sz = 1
+ for child in self.children:
+ sz += child.size()
+ return sz
+
+ def __eq__(self, other):
+ if isinstance(other, TreeNode) and self.data == other.data \
+ and self.children == other.children:
+ return True
+ return False
+
+ def __cmp__(self, other):
+ if not isinstance(other, TreeNode): return 1
+ n = cmp(self.data, other.data)
+ if n != 0: return n
+ n = len(self.children) - len(other.children)
+ if n != 0: return n
+ for sc, oc in zip(self.children, other.children):
+ n = cmp(sc, oc)
+ if n != 0: return n
+ return 0
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __hash__(self):
+ return hash((self.data, tuple(self.children)))
+
+ def __repr__(self):
+ return str(self)
+
+ def __str__(self):
+ s = '('
+ space = False
+ if self.order <= 0:
+ s += str(self.data)
+ space = True
+ for i, child in enumerate(self.children):
+ if space: s += ' '
+ s += str(child)
+ space = True
+ if i+1 == self.order:
+ s += ' ' + str(self.data)
+ return s + ')'
+
+def read_PSTs(fname):
+ infile = open(fname)
+ trees = []
+ for line in infile:
+ trees.append(parse_PST(line.strip()))
+ infile.close()
+ return trees
+
+def parse_PST_multiline(infile, hash_is_var=True):
+ buf = ''
+ num_open = 0
+ while True:
+ line = infile.readline()
+ if not line:
+ return None
+ buf += ' ' + line.rstrip()
+ num_open += line.count('(') - line.count(')')
+ if num_open == 0:
+ break
+
+ return parse_PST(buf, hash_is_var)
+
+def parse_PST(line, hash_is_var=True):
+ line = line.rstrip()
+ if not line or line.lower() == 'null':
+ return None
+
+ # allow either (a/DT) or (DT a)
+ #parts_re = re.compile(r'(\(*)([^/)]*)(?:/([^)]*))?(\)*)$')
+
+ # only allow (DT a)
+ parts_re = re.compile(r'(\(*)([^)]*)(\)*)$')
+
+ root = TreeNode(Symbol('TOP'))
+ stack = [root]
+ for part in line.rstrip().split():
+ m = parts_re.match(part)
+ #opening, tok_or_tag, tag, closing = m.groups()
+ opening, tok_or_tag, closing = m.groups()
+ tag = None
+ #print 'token', part, 'bits', m.groups()
+ for i in opening:
+ node = TreeNode(Symbol(None))
+ stack[-1].insert(node)
+ stack.append(node)
+
+ if tag:
+ stack[-1].data.tag = tag
+ if hash_is_var and tok_or_tag.startswith('#'):
+ stack[-1].data.variable = int(tok_or_tag[1:])
+ else:
+ stack[-1].data.token = tok_or_tag
+ else:
+ if stack[-1].data.tag == None:
+ stack[-1].data.tag = tok_or_tag
+ else:
+ if hash_is_var and tok_or_tag.startswith('#'):
+ try:
+ stack[-1].data.variable = int(tok_or_tag[1:])
+ except ValueError: # it's really a token!
+ #print >>sys.stderr, 'Warning: # used for token:', tok_or_tag
+ stack[-1].data.token = tok_or_tag
+ else:
+ stack[-1].data.token = tok_or_tag
+
+ for i in closing:
+ stack.pop()
+
+ #assert str(root.children[0]) == line
+ return root.children[0]
+
+def read_DTs(fname):
+ infile = open(fname)
+ trees = []
+ while True:
+ t = parse_DT(infile)
+ if t: trees.append(t)
+ else: break
+ infile.close()
+ return trees
+
+def read_bracketed_DTs(fname):
+ infile = open(fname)
+ trees = []
+ for line in infile:
+ trees.append(parse_bracketed_DT(line))
+ infile.close()
+ return trees
+
+def parse_DT(infile):
+ tokens = [Symbol('ROOT')]
+ children = {}
+
+ for line in infile:
+ parts = line.rstrip().split()
+ #print parts
+ if not parts: break
+ index = len(tokens)
+ token = parts[1]
+ tag = parts[3]
+ parent = int(parts[6])
+ if token.startswith('#'):
+ tokens.append(Symbol(tag, var=int(token[1:])))
+ else:
+ tokens.append(Symbol(tag, token))
+ children.setdefault(parent, set()).add(index)
+
+ if len(tokens) == 1: return None
+
+ root = TreeNode(Symbol('ROOT'), [], 0)
+ schedule = []
+ for child in sorted(children[0]):
+ schedule.append((root, child))
+
+ while schedule:
+ parent, index = schedule[0]
+ del schedule[0]
+
+ node = TreeNode(tokens[index])
+ node.order = 0
+ parent.insert(node)
+
+ for child in sorted(children.get(index, [])):
+ schedule.append((node, child))
+ if child < index:
+ node.order += 1
+
+ return root
+
+_bracket_split_re = re.compile(r'([(]*)([^)/]*)(?:/([^)]*))?([)]*)')
+
+def parse_bracketed_DT(line, insert_root=True):
+ line = line.rstrip()
+ if not line or line == 'NULL': return None
+ #print line
+
+ root = TreeNode(Symbol('ROOT'))
+ stack = [root]
+ for part in line.rstrip().split():
+ m = _bracket_split_re.match(part)
+
+ for c in m.group(1):
+ node = TreeNode(Symbol(None))
+ stack[-1].insert(node)
+ stack.append(node)
+
+ if m.group(3) != None:
+ if m.group(2).startswith('#'):
+ stack[-1].data.variable = int(m.group(2)[1:])
+ else:
+ stack[-1].data.token = m.group(2)
+ stack[-1].data.tag = m.group(3)
+ else:
+ stack[-1].data.tag = m.group(2)
+ stack[-1].order = len(stack[-1].children)
+ # FIXME: also check for vars
+
+ for c in m.group(4):
+ stack.pop()
+
+ assert len(stack) == 1
+ if not insert_root or root.children[0].data.tag == 'ROOT':
+ return root.children[0]
+ else:
+ return root
+
+_bracket_split_notag_re = re.compile(r'([(]*)([^)/]*)([)]*)')
+
+def parse_bracketed_untagged_DT(line):
+ line = line.rstrip()
+ if not line or line == 'NULL': return None
+
+ root = TreeNode(Symbol('TOP'))
+ stack = [root]
+ for part in line.rstrip().split():
+ m = _bracket_split_notag_re.match(part)
+
+ for c in m.group(1):
+ node = TreeNode(Symbol(None))
+ stack[-1].insert(node)
+ stack.append(node)
+
+ if stack[-1].data.token == None:
+ stack[-1].data.token = m.group(2)
+ stack[-1].order = len(stack[-1].children)
+ else:
+ child = TreeNode(Symbol(nonterm=None, term=m.group(2)))
+ stack[-1].insert(child)
+
+ for c in m.group(3):
+ stack.pop()
+
+ return root.children[0]