diff options
Diffstat (limited to 'sa-extract/cn.py')
-rw-r--r-- | sa-extract/cn.py | 164 |
1 files changed, 0 insertions, 164 deletions
diff --git a/sa-extract/cn.py b/sa-extract/cn.py deleted file mode 100644 index e534783f..00000000 --- a/sa-extract/cn.py +++ /dev/null @@ -1,164 +0,0 @@ -# cn.py -# Chris Dyer <redpony@umd.edu> -# Copyright (c) 2006 University of Maryland. - -# vim:tabstop=4:autoindent:expandtab - -import sys -import math -import sym -import log -import sgml - -epsilon = sym.fromstring('*EPS*'); - -class CNStats(object): - def __init__(self): - self.read = 0 - self.colls = 0 - self.words = 0 - - def collect(self, cn): - self.read += 1 - self.colls += cn.get_length() - for col in cn.columns: - self.words += len(col) - - def __str__(self): - return "confusion net statistics:\n succ. read: %d\n columns: %d\n words: %d\n avg. words/column:\t%f\n avg. cols/sent:\t%f\n\n" % (self.read, self.colls, self.words, float(self.words)/float(self.colls), float(self.colls)/float(self.read)) - -class ConfusionNet(object): - def __init__(self, sent): - object.__init__(self) - if (len(sent.words) == 0): - self.columns = () - return # empty line, it happens - line = sent.words[0] - if (line.startswith("(((")): - if (len(sent.words) > 1): - log.write("Bad sentence: %s\n" % (line)) - assert(len(sent.words) == 1) # make sure there are no spaces in your confusion nets! - line = "((('<s>',1.0,1),),"+line[1:len(line)-1]+"(('</s>',1.0,1),))" - cols = eval(line) - res = [] - for col in cols: - x = [] - for alt in col: - costs = alt[1] - if (type(costs) != type((1,2))): - costs=(float(costs),) - j=[] - for c in costs: - j.append(float(c)) - cost = tuple(j) - spanlen = 1 - if (len(alt) == 3): - spanlen = alt[2] - x.append((sym.fromstring(alt[0],terminal=True), None, spanlen)) - res.append(tuple(x)) - self.columns = tuple(res) - else: # convert a string of input into a CN - res = []; - res.append(((sym.fromstring('<s>',terminal=True), None, 1), )) - for word in sent.words: - res.append(((sym.fromstring(word,terminal=True), None, 1), )); # (alt=word, cost=0.0) - res.append(((sym.fromstring('</s>',terminal=True), None, 1), )) - self.columns = tuple(res) - - def is_epsilon(self, position): - x = self.columns[position[0]][position[1]][0] - return x == epsilon - - def compute_epsilon_run_length(self, cn_path): - if (len(cn_path) == 0): - return 0 - x = len(cn_path) - 1 - res = 0 - ''' -1 denotes a non-terminal ''' - while (x >= 0 and cn_path[x][0] >= 0 and self.is_epsilon(cn_path[x])): - res += 1 - x -= 1 - return res - - def compute_cn_cost(self, cn_path): - c = None - for (col, row) in cn_path: - if (col >= 0): - if c is None: - c = self.columns[col][row][1].clone() - else: - c += self.columns[col][row][1] - return c - - def get_column(self, col): - return self.columns[col] - - def get_length(self): - return len(self.columns) - - def __str__(self): - r = "conf net: %d\n" % (len(self.columns),) - i = 0 - for col in self.columns: - r += "%d -- " % i - i += 1 - for alternative in col: - r += "(%s, %s, %s) " % (sym.tostring(alternative[0]), alternative[1], alternative[2]) - r += "\n" - return r - - def listdown(_columns, col = 0): - # output all the possible sentences out of the self lattice - # will be used by the "dumb" adaptation of lattice decoding with suffix array - result = [] - for entry in _columns[col]: - if col+entry[2]+1<=len(_columns) : - for suffix in self.listdown(_columns,col+entry[2]): - result.append(entry[0]+' '+suffix) - #result.append(entry[0]+' '+suffix) - else: - result.append(entry[0]) - #result.append(entry[0]) - return result - - def next(self,_columns,curr_idx, min_dist=1): - # can be used only when prev_id is defined - result = [] - #print "curr_idx=%i\n" % curr_idx - if curr_idx+min_dist >= len(_columns): - return result - for alt_idx in xrange(len(_columns[curr_idx])): - alt = _columns[curr_idx][alt_idx] - #print "checking %i alternative : " % alt_idx - #print "%s %f %i\n" % (alt[0],alt[1],alt[2]) - #print alt - if alt[2]<min_dist: - #print "recursive next(%i, %i, %i)\n" % (curr_idx,alt_idx,min_dist-alt[2]) - result.extend(self.next(_columns,curr_idx+alt[2],min_dist-alt[2])) - elif curr_idx+alt[2]<len(_columns): - #print "adding because the skip %i doesn't go beyong the length\n" % alt[2] - result.append(curr_idx+alt[2]) - return set(result) - - - - -#file = open(sys.argv[1], "rb") -#sent = sgml.process_sgml_line(file.read()) -#print sent -#cn = ConfusionNet(sent) -#print cn -#results = cn.listdown() -#for result in results: -# print sym.tostring(result) -#print cn.next(0); -#print cn.next(1); -#print cn.next(2); -#print cn.next(3); -#print cn -#cn = ConfusionNet() -#k = 0 -#while (cn.read(file)): -# print cn - -#print cn.stats |