summaryrefslogtreecommitdiff
path: root/sa-extract/cn.py
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2012-08-03 07:46:54 -0400
committerKenneth Heafield <github@kheafield.com>2012-08-03 07:46:54 -0400
commitbe1ab0a8937f9c5668ea5e6c31b798e87672e55e (patch)
treea13aad60ab6cced213401bce6a38ac885ba171ba /sa-extract/cn.py
parente5d6f4ae41009c26978ecd62668501af9762b0bc (diff)
parent9fe0219562e5db25171cce8776381600ff9a5649 (diff)
Merge branch 'master' of github.com:redpony/cdec
Diffstat (limited to 'sa-extract/cn.py')
-rw-r--r--sa-extract/cn.py164
1 files changed, 0 insertions, 164 deletions
diff --git a/sa-extract/cn.py b/sa-extract/cn.py
deleted file mode 100644
index e534783f..00000000
--- a/sa-extract/cn.py
+++ /dev/null
@@ -1,164 +0,0 @@
-# cn.py
-# Chris Dyer <redpony@umd.edu>
-# Copyright (c) 2006 University of Maryland.
-
-# vim:tabstop=4:autoindent:expandtab
-
-import sys
-import math
-import sym
-import log
-import sgml
-
-epsilon = sym.fromstring('*EPS*');
-
-class CNStats(object):
- def __init__(self):
- self.read = 0
- self.colls = 0
- self.words = 0
-
- def collect(self, cn):
- self.read += 1
- self.colls += cn.get_length()
- for col in cn.columns:
- self.words += len(col)
-
- def __str__(self):
- return "confusion net statistics:\n succ. read: %d\n columns: %d\n words: %d\n avg. words/column:\t%f\n avg. cols/sent:\t%f\n\n" % (self.read, self.colls, self.words, float(self.words)/float(self.colls), float(self.colls)/float(self.read))
-
-class ConfusionNet(object):
- def __init__(self, sent):
- object.__init__(self)
- if (len(sent.words) == 0):
- self.columns = ()
- return # empty line, it happens
- line = sent.words[0]
- if (line.startswith("(((")):
- if (len(sent.words) > 1):
- log.write("Bad sentence: %s\n" % (line))
- assert(len(sent.words) == 1) # make sure there are no spaces in your confusion nets!
- line = "((('<s>',1.0,1),),"+line[1:len(line)-1]+"(('</s>',1.0,1),))"
- cols = eval(line)
- res = []
- for col in cols:
- x = []
- for alt in col:
- costs = alt[1]
- if (type(costs) != type((1,2))):
- costs=(float(costs),)
- j=[]
- for c in costs:
- j.append(float(c))
- cost = tuple(j)
- spanlen = 1
- if (len(alt) == 3):
- spanlen = alt[2]
- x.append((sym.fromstring(alt[0],terminal=True), None, spanlen))
- res.append(tuple(x))
- self.columns = tuple(res)
- else: # convert a string of input into a CN
- res = [];
- res.append(((sym.fromstring('<s>',terminal=True), None, 1), ))
- for word in sent.words:
- res.append(((sym.fromstring(word,terminal=True), None, 1), )); # (alt=word, cost=0.0)
- res.append(((sym.fromstring('</s>',terminal=True), None, 1), ))
- self.columns = tuple(res)
-
- def is_epsilon(self, position):
- x = self.columns[position[0]][position[1]][0]
- return x == epsilon
-
- def compute_epsilon_run_length(self, cn_path):
- if (len(cn_path) == 0):
- return 0
- x = len(cn_path) - 1
- res = 0
- ''' -1 denotes a non-terminal '''
- while (x >= 0 and cn_path[x][0] >= 0 and self.is_epsilon(cn_path[x])):
- res += 1
- x -= 1
- return res
-
- def compute_cn_cost(self, cn_path):
- c = None
- for (col, row) in cn_path:
- if (col >= 0):
- if c is None:
- c = self.columns[col][row][1].clone()
- else:
- c += self.columns[col][row][1]
- return c
-
- def get_column(self, col):
- return self.columns[col]
-
- def get_length(self):
- return len(self.columns)
-
- def __str__(self):
- r = "conf net: %d\n" % (len(self.columns),)
- i = 0
- for col in self.columns:
- r += "%d -- " % i
- i += 1
- for alternative in col:
- r += "(%s, %s, %s) " % (sym.tostring(alternative[0]), alternative[1], alternative[2])
- r += "\n"
- return r
-
- def listdown(_columns, col = 0):
- # output all the possible sentences out of the self lattice
- # will be used by the "dumb" adaptation of lattice decoding with suffix array
- result = []
- for entry in _columns[col]:
- if col+entry[2]+1<=len(_columns) :
- for suffix in self.listdown(_columns,col+entry[2]):
- result.append(entry[0]+' '+suffix)
- #result.append(entry[0]+' '+suffix)
- else:
- result.append(entry[0])
- #result.append(entry[0])
- return result
-
- def next(self,_columns,curr_idx, min_dist=1):
- # can be used only when prev_id is defined
- result = []
- #print "curr_idx=%i\n" % curr_idx
- if curr_idx+min_dist >= len(_columns):
- return result
- for alt_idx in xrange(len(_columns[curr_idx])):
- alt = _columns[curr_idx][alt_idx]
- #print "checking %i alternative : " % alt_idx
- #print "%s %f %i\n" % (alt[0],alt[1],alt[2])
- #print alt
- if alt[2]<min_dist:
- #print "recursive next(%i, %i, %i)\n" % (curr_idx,alt_idx,min_dist-alt[2])
- result.extend(self.next(_columns,curr_idx+alt[2],min_dist-alt[2]))
- elif curr_idx+alt[2]<len(_columns):
- #print "adding because the skip %i doesn't go beyong the length\n" % alt[2]
- result.append(curr_idx+alt[2])
- return set(result)
-
-
-
-
-#file = open(sys.argv[1], "rb")
-#sent = sgml.process_sgml_line(file.read())
-#print sent
-#cn = ConfusionNet(sent)
-#print cn
-#results = cn.listdown()
-#for result in results:
-# print sym.tostring(result)
-#print cn.next(0);
-#print cn.next(1);
-#print cn.next(2);
-#print cn.next(3);
-#print cn
-#cn = ConfusionNet()
-#k = 0
-#while (cn.read(file)):
-# print cn
-
-#print cn.stats