summaryrefslogtreecommitdiff
path: root/sa-extract/cn.py
diff options
context:
space:
mode:
Diffstat (limited to 'sa-extract/cn.py')
-rw-r--r--sa-extract/cn.py164
1 files changed, 164 insertions, 0 deletions
diff --git a/sa-extract/cn.py b/sa-extract/cn.py
new file mode 100644
index 00000000..e534783f
--- /dev/null
+++ b/sa-extract/cn.py
@@ -0,0 +1,164 @@
+# cn.py
+# Chris Dyer <redpony@umd.edu>
+# Copyright (c) 2006 University of Maryland.
+
+# vim:tabstop=4:autoindent:expandtab
+
+import sys
+import math
+import sym
+import log
+import sgml
+
+epsilon = sym.fromstring('*EPS*');
+
+class CNStats(object):
+ def __init__(self):
+ self.read = 0
+ self.colls = 0
+ self.words = 0
+
+ def collect(self, cn):
+ self.read += 1
+ self.colls += cn.get_length()
+ for col in cn.columns:
+ self.words += len(col)
+
+ def __str__(self):
+ return "confusion net statistics:\n succ. read: %d\n columns: %d\n words: %d\n avg. words/column:\t%f\n avg. cols/sent:\t%f\n\n" % (self.read, self.colls, self.words, float(self.words)/float(self.colls), float(self.colls)/float(self.read))
+
+class ConfusionNet(object):
+ def __init__(self, sent):
+ object.__init__(self)
+ if (len(sent.words) == 0):
+ self.columns = ()
+ return # empty line, it happens
+ line = sent.words[0]
+ if (line.startswith("(((")):
+ if (len(sent.words) > 1):
+ log.write("Bad sentence: %s\n" % (line))
+ assert(len(sent.words) == 1) # make sure there are no spaces in your confusion nets!
+ line = "((('<s>',1.0,1),),"+line[1:len(line)-1]+"(('</s>',1.0,1),))"
+ cols = eval(line)
+ res = []
+ for col in cols:
+ x = []
+ for alt in col:
+ costs = alt[1]
+ if (type(costs) != type((1,2))):
+ costs=(float(costs),)
+ j=[]
+ for c in costs:
+ j.append(float(c))
+ cost = tuple(j)
+ spanlen = 1
+ if (len(alt) == 3):
+ spanlen = alt[2]
+ x.append((sym.fromstring(alt[0],terminal=True), None, spanlen))
+ res.append(tuple(x))
+ self.columns = tuple(res)
+ else: # convert a string of input into a CN
+ res = [];
+ res.append(((sym.fromstring('<s>',terminal=True), None, 1), ))
+ for word in sent.words:
+ res.append(((sym.fromstring(word,terminal=True), None, 1), )); # (alt=word, cost=0.0)
+ res.append(((sym.fromstring('</s>',terminal=True), None, 1), ))
+ self.columns = tuple(res)
+
+ def is_epsilon(self, position):
+ x = self.columns[position[0]][position[1]][0]
+ return x == epsilon
+
+ def compute_epsilon_run_length(self, cn_path):
+ if (len(cn_path) == 0):
+ return 0
+ x = len(cn_path) - 1
+ res = 0
+ ''' -1 denotes a non-terminal '''
+ while (x >= 0 and cn_path[x][0] >= 0 and self.is_epsilon(cn_path[x])):
+ res += 1
+ x -= 1
+ return res
+
+ def compute_cn_cost(self, cn_path):
+ c = None
+ for (col, row) in cn_path:
+ if (col >= 0):
+ if c is None:
+ c = self.columns[col][row][1].clone()
+ else:
+ c += self.columns[col][row][1]
+ return c
+
+ def get_column(self, col):
+ return self.columns[col]
+
+ def get_length(self):
+ return len(self.columns)
+
+ def __str__(self):
+ r = "conf net: %d\n" % (len(self.columns),)
+ i = 0
+ for col in self.columns:
+ r += "%d -- " % i
+ i += 1
+ for alternative in col:
+ r += "(%s, %s, %s) " % (sym.tostring(alternative[0]), alternative[1], alternative[2])
+ r += "\n"
+ return r
+
+ def listdown(_columns, col = 0):
+ # output all the possible sentences out of the self lattice
+ # will be used by the "dumb" adaptation of lattice decoding with suffix array
+ result = []
+ for entry in _columns[col]:
+ if col+entry[2]+1<=len(_columns) :
+ for suffix in self.listdown(_columns,col+entry[2]):
+ result.append(entry[0]+' '+suffix)
+ #result.append(entry[0]+' '+suffix)
+ else:
+ result.append(entry[0])
+ #result.append(entry[0])
+ return result
+
+ def next(self,_columns,curr_idx, min_dist=1):
+ # can be used only when prev_id is defined
+ result = []
+ #print "curr_idx=%i\n" % curr_idx
+ if curr_idx+min_dist >= len(_columns):
+ return result
+ for alt_idx in xrange(len(_columns[curr_idx])):
+ alt = _columns[curr_idx][alt_idx]
+ #print "checking %i alternative : " % alt_idx
+ #print "%s %f %i\n" % (alt[0],alt[1],alt[2])
+ #print alt
+ if alt[2]<min_dist:
+ #print "recursive next(%i, %i, %i)\n" % (curr_idx,alt_idx,min_dist-alt[2])
+ result.extend(self.next(_columns,curr_idx+alt[2],min_dist-alt[2]))
+ elif curr_idx+alt[2]<len(_columns):
+ #print "adding because the skip %i doesn't go beyong the length\n" % alt[2]
+ result.append(curr_idx+alt[2])
+ return set(result)
+
+
+
+
+#file = open(sys.argv[1], "rb")
+#sent = sgml.process_sgml_line(file.read())
+#print sent
+#cn = ConfusionNet(sent)
+#print cn
+#results = cn.listdown()
+#for result in results:
+# print sym.tostring(result)
+#print cn.next(0);
+#print cn.next(1);
+#print cn.next(2);
+#print cn.next(3);
+#print cn
+#cn = ConfusionNet()
+#k = 0
+#while (cn.read(file)):
+# print cn
+
+#print cn.stats