diff options
author | trevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-12 16:38:52 +0000 |
---|---|---|
committer | trevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-12 16:38:52 +0000 |
commit | 767c1816691013585b00bb38264e9f7d32c25747 (patch) | |
tree | d7602605b8ead86a910af948b3460de7d7ce302d /gi/evaluation/tree.py | |
parent | 3368ed9579857982a51d78e834cd6f44e1915deb (diff) |
Code for evaluating clusterings
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@222 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/evaluation/tree.py')
-rw-r--r-- | gi/evaluation/tree.py | 485 |
1 files changed, 485 insertions, 0 deletions
diff --git a/gi/evaluation/tree.py b/gi/evaluation/tree.py new file mode 100644 index 00000000..702d80b6 --- /dev/null +++ b/gi/evaluation/tree.py @@ -0,0 +1,485 @@ +import re, sys + +class Symbol: + def __init__(self, nonterm, term=None, var=None): + assert not (term != None and var != None) + self.tag = nonterm + self.token = term + self.variable = var + + def is_variable(self): + return self.variable != None + + def __eq__(self, other): + return self.tag == other.tag and self.token == other.token and self.variable == other.variable + + def __ne__(self, other): + return not (self == other) + + def __hash__(self): + return hash((self.tag, self.token, self.variable)) + + def __repr__(self): + return str(self) + + def __cmp__(self, other): + return cmp((self.tag, self.token, self.variable), + (other.tag, other.token, other.variable)) + + def __str__(self): + parts = [] + if False: # DEPENDENCY + if self.token: + parts.append(str(self.token)) + elif self.variable != None: + parts.append('#%d' % self.variable) + if self.tag: + parts.append(str(self.tag)) + return '/'.join(parts) + else: + if self.tag: + parts.append(str(self.tag)) + if self.token: + parts.append(str(self.token)) + elif self.variable != None: + parts.append('#%d' % self.variable) + return ' '.join(parts) + +class TreeNode: + def __init__(self, data, children=None, order=-1): + self.data = data + self.children = [] + self.order = order + self.parent = None + if children: self.children = children + + def insert(self, child): + self.children.append(child) + child.parent = self + + def leaves(self): + ls = [] + for node in self.xtraversal(): + if not node.children: + ls.append(node.data) + return ls + + def leaf_nodes(self): + ls = [] + for node in self.xtraversal(): + if not node.children: + ls.append(node) + return ls + + def max_depth(self): + d = 1 + for child in self.children: + d = max(d, 1 + child.max_depth()) + if not self.children and self.data.token: + d = 2 + return d + + def max_width(self): + w = 0 + for child in self.children: + w += child.max_width() + return max(1, w) + + def num_internal_nodes(self): + if self.children: + n = 1 + for child in self.children: + n += child.num_internal_nodes() + return n + elif self.data.token: + return 1 + else: + return 0 + + def postorder_traversal(self, visit): + """ + Postorder traversal; no guarantee that terminals will be read in the + correct order for dep. trees. + """ + for child in self.children: + child.traversal(visit) + visit(self) + + def traversal(self, visit): + """ + Preorder for phrase structure trees, and inorder for dependency trees. + In both cases the terminals will be read off in the correct order. + """ + visited_self = False + if self.order <= 0: + visited_self = True + visit(self) + + for i, child in enumerate(self.children): + child.traversal(visit) + if i + 1 == self.order: + visited_self = True + visit(self) + + assert visited_self + + def xpostorder_traversal(self): + for child in self.children: + for node in child.xpostorder_traversal(): + yield node + yield self + + def xtraversal(self): + visited_self = False + if self.order <= 0: + visited_self = True + yield self + + for i, child in enumerate(self.children): + for d in child.xtraversal(): + yield d + + if i + 1 == self.order: + visited_self = True + yield self + + assert visited_self + + def xpostorder_traversal(self): + for i, child in enumerate(self.children): + for d in child.xpostorder_traversal(): + yield d + yield self + + def edges(self): + es = [] + self.traverse_edges(lambda h,c: es.append((h,c))) + return es + + def traverse_edges(self, visit): + for child in self.children: + visit(self.data, child.data) + child.traverse_edges(visit) + + def subtrees(self, include_self=False): + st = [] + if include_self: + stack = [self] + else: + stack = self.children[:] + + while stack: + node = stack.pop() + st.append(node) + stack.extend(node.children) + return st + + def find_parent(self, node): + try: + index = self.children.index(node) + return self, index + except ValueError: + for child in self.children: + if isinstance(child, TreeNode): + r = child.find_parent(node) + if r: return r + return None + + def is_ancestor_of(self, node): + if self == node: + return True + for child in self.children: + if child.is_ancestor_of(child): + return True + return False + + def find(self, node): + if self == node: + return self + for child in self.children: + if isinstance(child, TreeNode): + r = child.find(node) + if r: return r + else: + if child == node: + return r + return None + + def equals_ignorecase(self, other): + if not isinstance(other, TreeNode): + return False + if self.data != other.data: + return False + if len(self.children) != len(other.children): + return False + for mc, oc in zip(self.children, other.children): + if isinstance(mc, TreeNode): + if not mc.equals_ignorecase(oc): + return False + else: + if mc.lower() != oc.lower(): + return False + return True + + def node_number(self, numbering, next=0): + if self.order <= 0: + numbering[id(self)] = next + next += 1 + + for i, child in enumerate(self.children): + next = child.node_number(numbering, next) + if i + 1 == self.order: + numbering[id(self)] = next + next += 1 + + return next + + def display_conll(self, out): + numbering = {} + self.node_number(numbering) + next = 0 + self.children[0].traversal(lambda x: \ + out.write('%d\t%s\t%s\t%s\t%s\t_\t%d\tLAB\n' \ + % (numbering[id(x)], x.data.token, x.data.token, + x.data.tag, x.data.tag, numbering[id(x.parent)]))) + out.write('\n') + + def size(self): + sz = 1 + for child in self.children: + sz += child.size() + return sz + + def __eq__(self, other): + if isinstance(other, TreeNode) and self.data == other.data \ + and self.children == other.children: + return True + return False + + def __cmp__(self, other): + if not isinstance(other, TreeNode): return 1 + n = cmp(self.data, other.data) + if n != 0: return n + n = len(self.children) - len(other.children) + if n != 0: return n + for sc, oc in zip(self.children, other.children): + n = cmp(sc, oc) + if n != 0: return n + return 0 + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + return hash((self.data, tuple(self.children))) + + def __repr__(self): + return str(self) + + def __str__(self): + s = '(' + space = False + if self.order <= 0: + s += str(self.data) + space = True + for i, child in enumerate(self.children): + if space: s += ' ' + s += str(child) + space = True + if i+1 == self.order: + s += ' ' + str(self.data) + return s + ')' + +def read_PSTs(fname): + infile = open(fname) + trees = [] + for line in infile: + trees.append(parse_PST(line.strip())) + infile.close() + return trees + +def parse_PST_multiline(infile, hash_is_var=True): + buf = '' + num_open = 0 + while True: + line = infile.readline() + if not line: + return None + buf += ' ' + line.rstrip() + num_open += line.count('(') - line.count(')') + if num_open == 0: + break + + return parse_PST(buf, hash_is_var) + +def parse_PST(line, hash_is_var=True): + line = line.rstrip() + if not line or line.lower() == 'null': + return None + + # allow either (a/DT) or (DT a) + #parts_re = re.compile(r'(\(*)([^/)]*)(?:/([^)]*))?(\)*)$') + + # only allow (DT a) + parts_re = re.compile(r'(\(*)([^)]*)(\)*)$') + + root = TreeNode(Symbol('TOP')) + stack = [root] + for part in line.rstrip().split(): + m = parts_re.match(part) + #opening, tok_or_tag, tag, closing = m.groups() + opening, tok_or_tag, closing = m.groups() + tag = None + #print 'token', part, 'bits', m.groups() + for i in opening: + node = TreeNode(Symbol(None)) + stack[-1].insert(node) + stack.append(node) + + if tag: + stack[-1].data.tag = tag + if hash_is_var and tok_or_tag.startswith('#'): + stack[-1].data.variable = int(tok_or_tag[1:]) + else: + stack[-1].data.token = tok_or_tag + else: + if stack[-1].data.tag == None: + stack[-1].data.tag = tok_or_tag + else: + if hash_is_var and tok_or_tag.startswith('#'): + try: + stack[-1].data.variable = int(tok_or_tag[1:]) + except ValueError: # it's really a token! + #print >>sys.stderr, 'Warning: # used for token:', tok_or_tag + stack[-1].data.token = tok_or_tag + else: + stack[-1].data.token = tok_or_tag + + for i in closing: + stack.pop() + + #assert str(root.children[0]) == line + return root.children[0] + +def read_DTs(fname): + infile = open(fname) + trees = [] + while True: + t = parse_DT(infile) + if t: trees.append(t) + else: break + infile.close() + return trees + +def read_bracketed_DTs(fname): + infile = open(fname) + trees = [] + for line in infile: + trees.append(parse_bracketed_DT(line)) + infile.close() + return trees + +def parse_DT(infile): + tokens = [Symbol('ROOT')] + children = {} + + for line in infile: + parts = line.rstrip().split() + #print parts + if not parts: break + index = len(tokens) + token = parts[1] + tag = parts[3] + parent = int(parts[6]) + if token.startswith('#'): + tokens.append(Symbol(tag, var=int(token[1:]))) + else: + tokens.append(Symbol(tag, token)) + children.setdefault(parent, set()).add(index) + + if len(tokens) == 1: return None + + root = TreeNode(Symbol('ROOT'), [], 0) + schedule = [] + for child in sorted(children[0]): + schedule.append((root, child)) + + while schedule: + parent, index = schedule[0] + del schedule[0] + + node = TreeNode(tokens[index]) + node.order = 0 + parent.insert(node) + + for child in sorted(children.get(index, [])): + schedule.append((node, child)) + if child < index: + node.order += 1 + + return root + +_bracket_split_re = re.compile(r'([(]*)([^)/]*)(?:/([^)]*))?([)]*)') + +def parse_bracketed_DT(line, insert_root=True): + line = line.rstrip() + if not line or line == 'NULL': return None + #print line + + root = TreeNode(Symbol('ROOT')) + stack = [root] + for part in line.rstrip().split(): + m = _bracket_split_re.match(part) + + for c in m.group(1): + node = TreeNode(Symbol(None)) + stack[-1].insert(node) + stack.append(node) + + if m.group(3) != None: + if m.group(2).startswith('#'): + stack[-1].data.variable = int(m.group(2)[1:]) + else: + stack[-1].data.token = m.group(2) + stack[-1].data.tag = m.group(3) + else: + stack[-1].data.tag = m.group(2) + stack[-1].order = len(stack[-1].children) + # FIXME: also check for vars + + for c in m.group(4): + stack.pop() + + assert len(stack) == 1 + if not insert_root or root.children[0].data.tag == 'ROOT': + return root.children[0] + else: + return root + +_bracket_split_notag_re = re.compile(r'([(]*)([^)/]*)([)]*)') + +def parse_bracketed_untagged_DT(line): + line = line.rstrip() + if not line or line == 'NULL': return None + + root = TreeNode(Symbol('TOP')) + stack = [root] + for part in line.rstrip().split(): + m = _bracket_split_notag_re.match(part) + + for c in m.group(1): + node = TreeNode(Symbol(None)) + stack[-1].insert(node) + stack.append(node) + + if stack[-1].data.token == None: + stack[-1].data.token = m.group(2) + stack[-1].order = len(stack[-1].children) + else: + child = TreeNode(Symbol(nonterm=None, term=m.group(2))) + stack[-1].insert(child) + + for c in m.group(3): + stack.pop() + + return root.children[0] |