import re, sys class Symbol: def __init__(self, nonterm, term=None, var=None): assert not (term != None and var != None) self.tag = nonterm self.token = term self.variable = var def is_variable(self): return self.variable != None def __eq__(self, other): return self.tag == other.tag and self.token == other.token and self.variable == other.variable def __ne__(self, other): return not (self == other) def __hash__(self): return hash((self.tag, self.token, self.variable)) def __repr__(self): return str(self) def __cmp__(self, other): return cmp((self.tag, self.token, self.variable), (other.tag, other.token, other.variable)) def __str__(self): parts = [] if False: # DEPENDENCY if self.token: parts.append(str(self.token)) elif self.variable != None: parts.append('#%d' % self.variable) if self.tag: parts.append(str(self.tag)) return '/'.join(parts) else: if self.tag: parts.append(str(self.tag)) if self.token: parts.append(str(self.token)) elif self.variable != None: parts.append('#%d' % self.variable) return ' '.join(parts) class TreeNode: def __init__(self, data, children=None, order=-1): self.data = data self.children = [] self.order = order self.parent = None if children: self.children = children def insert(self, child): self.children.append(child) child.parent = self def leaves(self): ls = [] for node in self.xtraversal(): if not node.children: ls.append(node.data) return ls def leaf_nodes(self): ls = [] for node in self.xtraversal(): if not node.children: ls.append(node) return ls def max_depth(self): d = 1 for child in self.children: d = max(d, 1 + child.max_depth()) if not self.children and self.data.token: d = 2 return d def max_width(self): w = 0 for child in self.children: w += child.max_width() return max(1, w) def num_internal_nodes(self): if self.children: n = 1 for child in self.children: n += child.num_internal_nodes() return n elif self.data.token: return 1 else: return 0 def postorder_traversal(self, visit): """ Postorder traversal; no guarantee that terminals will be read in the correct order for dep. trees. """ for child in self.children: child.traversal(visit) visit(self) def traversal(self, visit): """ Preorder for phrase structure trees, and inorder for dependency trees. In both cases the terminals will be read off in the correct order. """ visited_self = False if self.order <= 0: visited_self = True visit(self) for i, child in enumerate(self.children): child.traversal(visit) if i + 1 == self.order: visited_self = True visit(self) assert visited_self def xpostorder_traversal(self): for child in self.children: for node in child.xpostorder_traversal(): yield node yield self def xtraversal(self): visited_self = False if self.order <= 0: visited_self = True yield self for i, child in enumerate(self.children): for d in child.xtraversal(): yield d if i + 1 == self.order: visited_self = True yield self assert visited_self def xpostorder_traversal(self): for i, child in enumerate(self.children): for d in child.xpostorder_traversal(): yield d yield self def edges(self): es = [] self.traverse_edges(lambda h,c: es.append((h,c))) return es def traverse_edges(self, visit): for child in self.children: visit(self.data, child.data) child.traverse_edges(visit) def subtrees(self, include_self=False): st = [] if include_self: stack = [self] else: stack = self.children[:] while stack: node = stack.pop() st.append(node) stack.extend(node.children) return st def find_parent(self, node): try: index = self.children.index(node) return self, index except ValueError: for child in self.children: if isinstance(child, TreeNode): r = child.find_parent(node) if r: return r return None def is_ancestor_of(self, node): if self == node: return True for child in self.children: if child.is_ancestor_of(child): return True return False def find(self, node): if self == node: return self for child in self.children: if isinstance(child, TreeNode): r = child.find(node) if r: return r else: if child == node: return r return None def equals_ignorecase(self, other): if not isinstance(other, TreeNode): return False if self.data != other.data: return False if len(self.children) != len(other.children): return False for mc, oc in zip(self.children, other.children): if isinstance(mc, TreeNode): if not mc.equals_ignorecase(oc): return False else: if mc.lower() != oc.lower(): return False return True def node_number(self, numbering, next=0): if self.order <= 0: numbering[id(self)] = next next += 1 for i, child in enumerate(self.children): next = child.node_number(numbering, next) if i + 1 == self.order: numbering[id(self)] = next next += 1 return next def display_conll(self, out): numbering = {} self.node_number(numbering) next = 0 self.children[0].traversal(lambda x: \ out.write('%d\t%s\t%s\t%s\t%s\t_\t%d\tLAB\n' \ % (numbering[id(x)], x.data.token, x.data.token, x.data.tag, x.data.tag, numbering[id(x.parent)]))) out.write('\n') def size(self): sz = 1 for child in self.children: sz += child.size() return sz def __eq__(self, other): if isinstance(other, TreeNode) and self.data == other.data \ and self.children == other.children: return True return False def __cmp__(self, other): if not isinstance(other, TreeNode): return 1 n = cmp(self.data, other.data) if n != 0: return n n = len(self.children) - len(other.children) if n != 0: return n for sc, oc in zip(self.children, other.children): n = cmp(sc, oc) if n != 0: return n return 0 def __ne__(self, other): return not self.__eq__(other) def __hash__(self): return hash((self.data, tuple(self.children))) def __repr__(self): return str(self) def __str__(self): s = '(' space = False if self.order <= 0: s += str(self.data) space = True for i, child in enumerate(self.children): if space: s += ' ' s += str(child) space = True if i+1 == self.order: s += ' ' + str(self.data) return s + ')' def read_PSTs(fname): infile = open(fname) trees = [] for line in infile: trees.append(parse_PST(line.strip())) infile.close() return trees def parse_PST_multiline(infile, hash_is_var=True): buf = '' num_open = 0 while True: line = infile.readline() if not line: return None buf += ' ' + line.rstrip() num_open += line.count('(') - line.count(')') if num_open == 0: break return parse_PST(buf, hash_is_var) def parse_PST(line, hash_is_var=True): line = line.rstrip() if not line or line.lower() == 'null': return None # allow either (a/DT) or (DT a) #parts_re = re.compile(r'(\(*)([^/)]*)(?:/([^)]*))?(\)*)$') # only allow (DT a) parts_re = re.compile(r'(\(*)([^)]*)(\)*)$') root = TreeNode(Symbol('TOP')) stack = [root] for part in line.rstrip().split(): m = parts_re.match(part) #opening, tok_or_tag, tag, closing = m.groups() opening, tok_or_tag, closing = m.groups() tag = None #print 'token', part, 'bits', m.groups() for i in opening: node = TreeNode(Symbol(None)) stack[-1].insert(node) stack.append(node) if tag: stack[-1].data.tag = tag if hash_is_var and tok_or_tag.startswith('#'): stack[-1].data.variable = int(tok_or_tag[1:]) else: stack[-1].data.token = tok_or_tag else: if stack[-1].data.tag == None: stack[-1].data.tag = tok_or_tag else: if hash_is_var and tok_or_tag.startswith('#'): try: stack[-1].data.variable = int(tok_or_tag[1:]) except ValueError: # it's really a token! #print >>sys.stderr, 'Warning: # used for token:', tok_or_tag stack[-1].data.token = tok_or_tag else: stack[-1].data.token = tok_or_tag for i in closing: stack.pop() #assert str(root.children[0]) == line return root.children[0] def read_DTs(fname): infile = open(fname) trees = [] while True: t = parse_DT(infile) if t: trees.append(t) else: break infile.close() return trees def read_bracketed_DTs(fname): infile = open(fname) trees = [] for line in infile: trees.append(parse_bracketed_DT(line)) infile.close() return trees def parse_DT(infile): tokens = [Symbol('ROOT')] children = {} for line in infile: parts = line.rstrip().split() #print parts if not parts: break index = len(tokens) token = parts[1] tag = parts[3] parent = int(parts[6]) if token.startswith('#'): tokens.append(Symbol(tag, var=int(token[1:]))) else: tokens.append(Symbol(tag, token)) children.setdefault(parent, set()).add(index) if len(tokens) == 1: return None root = TreeNode(Symbol('ROOT'), [], 0) schedule = [] for child in sorted(children[0]): schedule.append((root, child)) while schedule: parent, index = schedule[0] del schedule[0] node = TreeNode(tokens[index]) node.order = 0 parent.insert(node) for child in sorted(children.get(index, [])): schedule.append((node, child)) if child < index: node.order += 1 return root _bracket_split_re = re.compile(r'([(]*)([^)/]*)(?:/([^)]*))?([)]*)') def parse_bracketed_DT(line, insert_root=True): line = line.rstrip() if not line or line == 'NULL': return None #print line root = TreeNode(Symbol('ROOT')) stack = [root] for part in line.rstrip().split(): m = _bracket_split_re.match(part) for c in m.group(1): node = TreeNode(Symbol(None)) stack[-1].insert(node) stack.append(node) if m.group(3) != None: if m.group(2).startswith('#'): stack[-1].data.variable = int(m.group(2)[1:]) else: stack[-1].data.token = m.group(2) stack[-1].data.tag = m.group(3) else: stack[-1].data.tag = m.group(2) stack[-1].order = len(stack[-1].children) # FIXME: also check for vars for c in m.group(4): stack.pop() assert len(stack) == 1 if not insert_root or root.children[0].data.tag == 'ROOT': return root.children[0] else: return root _bracket_split_notag_re = re.compile(r'([(]*)([^)/]*)([)]*)') def parse_bracketed_untagged_DT(line): line = line.rstrip() if not line or line == 'NULL': return None root = TreeNode(Symbol('TOP')) stack = [root] for part in line.rstrip().split(): m = _bracket_split_notag_re.match(part) for c in m.group(1): node = TreeNode(Symbol(None)) stack[-1].insert(node) stack.append(node) if stack[-1].data.token == None: stack[-1].data.token = m.group(2) stack[-1].order = len(stack[-1].children) else: child = TreeNode(Symbol(nonterm=None, term=m.group(2))) stack[-1].insert(child) for c in m.group(3): stack.pop() return root.children[0]