diff options
author | Patrick Simianer <p@simianer.de> | 2012-03-13 09:24:47 +0100 |
---|---|---|
committer | Patrick Simianer <p@simianer.de> | 2012-03-13 09:24:47 +0100 |
commit | ef6085e558e26c8819f1735425761103021b6470 (patch) | |
tree | 5cf70e4c48c64d838e1326b5a505c8c4061bff4a /sa-extract/cn.py | |
parent | 10a232656a0c882b3b955d2bcfac138ce11e8a2e (diff) | |
parent | dfbc278c1057555fda9312291c8024049e00b7d8 (diff) |
merge with upstream
Diffstat (limited to 'sa-extract/cn.py')
-rw-r--r-- | sa-extract/cn.py | 164 |
1 files changed, 164 insertions, 0 deletions
diff --git a/sa-extract/cn.py b/sa-extract/cn.py new file mode 100644 index 00000000..e534783f --- /dev/null +++ b/sa-extract/cn.py @@ -0,0 +1,164 @@ +# cn.py +# Chris Dyer <redpony@umd.edu> +# Copyright (c) 2006 University of Maryland. + +# vim:tabstop=4:autoindent:expandtab + +import sys +import math +import sym +import log +import sgml + +epsilon = sym.fromstring('*EPS*'); + +class CNStats(object): + def __init__(self): + self.read = 0 + self.colls = 0 + self.words = 0 + + def collect(self, cn): + self.read += 1 + self.colls += cn.get_length() + for col in cn.columns: + self.words += len(col) + + def __str__(self): + return "confusion net statistics:\n succ. read: %d\n columns: %d\n words: %d\n avg. words/column:\t%f\n avg. cols/sent:\t%f\n\n" % (self.read, self.colls, self.words, float(self.words)/float(self.colls), float(self.colls)/float(self.read)) + +class ConfusionNet(object): + def __init__(self, sent): + object.__init__(self) + if (len(sent.words) == 0): + self.columns = () + return # empty line, it happens + line = sent.words[0] + if (line.startswith("(((")): + if (len(sent.words) > 1): + log.write("Bad sentence: %s\n" % (line)) + assert(len(sent.words) == 1) # make sure there are no spaces in your confusion nets! + line = "((('<s>',1.0,1),),"+line[1:len(line)-1]+"(('</s>',1.0,1),))" + cols = eval(line) + res = [] + for col in cols: + x = [] + for alt in col: + costs = alt[1] + if (type(costs) != type((1,2))): + costs=(float(costs),) + j=[] + for c in costs: + j.append(float(c)) + cost = tuple(j) + spanlen = 1 + if (len(alt) == 3): + spanlen = alt[2] + x.append((sym.fromstring(alt[0],terminal=True), None, spanlen)) + res.append(tuple(x)) + self.columns = tuple(res) + else: # convert a string of input into a CN + res = []; + res.append(((sym.fromstring('<s>',terminal=True), None, 1), )) + for word in sent.words: + res.append(((sym.fromstring(word,terminal=True), None, 1), )); # (alt=word, cost=0.0) + res.append(((sym.fromstring('</s>',terminal=True), None, 1), )) + self.columns = tuple(res) + + def is_epsilon(self, position): + x = self.columns[position[0]][position[1]][0] + return x == epsilon + + def compute_epsilon_run_length(self, cn_path): + if (len(cn_path) == 0): + return 0 + x = len(cn_path) - 1 + res = 0 + ''' -1 denotes a non-terminal ''' + while (x >= 0 and cn_path[x][0] >= 0 and self.is_epsilon(cn_path[x])): + res += 1 + x -= 1 + return res + + def compute_cn_cost(self, cn_path): + c = None + for (col, row) in cn_path: + if (col >= 0): + if c is None: + c = self.columns[col][row][1].clone() + else: + c += self.columns[col][row][1] + return c + + def get_column(self, col): + return self.columns[col] + + def get_length(self): + return len(self.columns) + + def __str__(self): + r = "conf net: %d\n" % (len(self.columns),) + i = 0 + for col in self.columns: + r += "%d -- " % i + i += 1 + for alternative in col: + r += "(%s, %s, %s) " % (sym.tostring(alternative[0]), alternative[1], alternative[2]) + r += "\n" + return r + + def listdown(_columns, col = 0): + # output all the possible sentences out of the self lattice + # will be used by the "dumb" adaptation of lattice decoding with suffix array + result = [] + for entry in _columns[col]: + if col+entry[2]+1<=len(_columns) : + for suffix in self.listdown(_columns,col+entry[2]): + result.append(entry[0]+' '+suffix) + #result.append(entry[0]+' '+suffix) + else: + result.append(entry[0]) + #result.append(entry[0]) + return result + + def next(self,_columns,curr_idx, min_dist=1): + # can be used only when prev_id is defined + result = [] + #print "curr_idx=%i\n" % curr_idx + if curr_idx+min_dist >= len(_columns): + return result + for alt_idx in xrange(len(_columns[curr_idx])): + alt = _columns[curr_idx][alt_idx] + #print "checking %i alternative : " % alt_idx + #print "%s %f %i\n" % (alt[0],alt[1],alt[2]) + #print alt + if alt[2]<min_dist: + #print "recursive next(%i, %i, %i)\n" % (curr_idx,alt_idx,min_dist-alt[2]) + result.extend(self.next(_columns,curr_idx+alt[2],min_dist-alt[2])) + elif curr_idx+alt[2]<len(_columns): + #print "adding because the skip %i doesn't go beyong the length\n" % alt[2] + result.append(curr_idx+alt[2]) + return set(result) + + + + +#file = open(sys.argv[1], "rb") +#sent = sgml.process_sgml_line(file.read()) +#print sent +#cn = ConfusionNet(sent) +#print cn +#results = cn.listdown() +#for result in results: +# print sym.tostring(result) +#print cn.next(0); +#print cn.next(1); +#print cn.next(2); +#print cn.next(3); +#print cn +#cn = ConfusionNet() +#k = 0 +#while (cn.read(file)): +# print cn + +#print cn.stats |