# cn.py
# Chris Dyer <redpony@umd.edu>
# Copyright (c) 2006 University of Maryland.

# vim:tabstop=4:autoindent:expandtab

import sys
import math
import sym
import log
import sgml

# Symbol for the '*EPS*' token; ConfusionNet.is_epsilon compares column
# entries against this value.
epsilon = sym.fromstring('*EPS*')

class CNStats(object):
    """Running totals over the confusion nets passed to collect()."""

    def __init__(self):
        self.read = 0   # number of nets collected so far
        self.colls = 0  # total number of columns over all collected nets
        self.words = 0  # total number of alternatives over all columns

    def collect(self, cn):
        """Add one confusion net to the totals.

        ``cn`` must provide ``get_length()`` and an iterable ``columns``
        attribute whose items support ``len()``.
        """
        self.read += 1
        self.colls += cn.get_length()
        self.words += sum(len(col) for col in cn.columns)

    def __str__(self):
        """Return a multi-line summary of the totals and the two averages.

        An average whose denominator is zero (nothing collected yet, or only
        empty nets collected) is reported as 0.0 instead of raising
        ZeroDivisionError.
        """
        words_per_col = 0.0
        if self.colls:
            words_per_col = float(self.words) / float(self.colls)
        cols_per_sent = 0.0
        if self.read:
            cols_per_sent = float(self.colls) / float(self.read)
        return "confusion net statistics:\n succ. read: %d\n columns:    %d\n words:      %d\n avg. words/column:\t%f\n avg. cols/sent:\t%f\n\n" % (self.read, self.colls, self.words, words_per_col, cols_per_sent)

class ConfusionNet(object):
    """Confusion network built from a sentence object.

    ``self.columns`` is a tuple of columns; each column is a tuple of
    alternatives of the form ``(symbol, cost, span_length)``.  The
    constructor always stores ``None`` in the cost slot.
    """

    def __init__(self, sent):
        """Build the network from ``sent.words``.

        If the first word starts with ``(((`` it is treated as a serialized
        confusion net (a nested tuple literal written without spaces).
        Otherwise every word becomes a column with a single alternative of
        span length 1.  In both cases a ``<s>`` column is put in front and a
        ``</s>`` column at the end.  An empty word list yields an empty
        network.
        """
        object.__init__(self)
        if len(sent.words) == 0:
            self.columns = ()
            return  # empty line, it happens
        line = sent.words[0]
        if line.startswith("((("):
            if len(sent.words) > 1:
                log.write("Bad sentence: %s\n" % line)
            # make sure there are no spaces in your confusion nets!
            assert(len(sent.words) == 1)
            # Drop the outermost parentheses and wrap the remaining columns
            # between a <s> column and a </s> column.
            line = "((('<s>',1.0,1),)," + line[1:-1] + "(('</s>',1.0,1),))"
            # NOTE(review): eval() executes arbitrary Python found in the
            # input, so this is only safe for trusted data.  If the nets are
            # always plain literals, ast.literal_eval may be a drop-in
            # replacement -- TODO confirm before switching.
            cols = eval(line)
            res = []
            for col in cols:
                alternatives = []
                for alt in col:
                    costs = alt[1]
                    if not isinstance(costs, tuple):
                        costs = (float(costs),)
                    # The costs are converted only so that a non-numeric
                    # value still raises here; the converted values are not
                    # stored (the cost slot below is None).
                    for c in costs:
                        float(c)
                    spanlen = 1
                    if len(alt) == 3:
                        spanlen = alt[2]
                    alternatives.append(
                        (sym.fromstring(alt[0], terminal=True), None, spanlen))
                res.append(tuple(alternatives))
            self.columns = tuple(res)
        else:  # convert a string of input into a CN
            res = [((sym.fromstring('<s>', terminal=True), None, 1),)]
            for word in sent.words:
                res.append(((sym.fromstring(word, terminal=True), None, 1),))
            res.append(((sym.fromstring('</s>', terminal=True), None, 1),))
            self.columns = tuple(res)

    def is_epsilon(self, position):
        """Return True if the alternative at ``position`` is the epsilon symbol.

        ``position`` is indexed as ``(column_index, row_index)``.
        """
        x = self.columns[position[0]][position[1]][0]
        return x == epsilon

    def compute_epsilon_run_length(self, cn_path):
        """Count the epsilon alternatives at the tail of ``cn_path``.

        Walks ``cn_path`` backwards and stops at the first position that has
        a negative column index or is not an epsilon.  Returns 0 for an empty
        path.
        """
        res = 0
        for position in reversed(cn_path):
            # A negative column index is skipped over by the original author
            # with the note "-1 denotes a non-terminal".
            if position[0] < 0 or not self.is_epsilon(position):
                break
            res += 1
        return res

    def compute_cn_cost(self, cn_path):
        """Accumulate the cost entries of the positions in ``cn_path``.

        Positions with a negative column index are ignored.  The first used
        cost is copied with ``clone()`` and the remaining ones are added with
        ``+=``.  Returns None when no position has a non-negative column.

        NOTE(review): the constructor stores None as every cost, so the
        ``clone()`` call would fail on such a network; presumably the costs
        are expected to be filled in elsewhere -- confirm.
        """
        c = None
        for (col, row) in cn_path:
            if col >= 0:
                if c is None:
                    c = self.columns[col][row][1].clone()
                else:
                    c += self.columns[col][row][1]
        return c

    def get_column(self, col):
        """Return the column (tuple of alternatives) at index ``col``."""
        return self.columns[col]

    def get_length(self):
        """Return the number of columns."""
        return len(self.columns)

    def __str__(self):
        """Return one header line followed by one line per column."""
        pieces = ["conf net: %d\n" % (len(self.columns),)]
        for i, col in enumerate(self.columns):
            pieces.append("%d -- " % i)
            for alternative in col:
                pieces.append("(%s, %s, %s) " % (sym.tostring(alternative[0]), alternative[1], alternative[2]))
            pieces.append("\n")
        return "".join(pieces)

    def listdown(self, col=0):
        """List every path through the network starting at column ``col``.

        Output all the possible sentences out of the self lattice; will be
        used by the "dumb" adaptation of lattice decoding with suffix array.

        Each path is the first elements of the chosen alternatives joined by
        single spaces.  An alternative whose span reaches past the last
        column ends its path.  Returns an empty list when ``col`` is past the
        end of the network.

        The method was previously declared without ``self`` although its body
        called ``self.listdown``, so it could not be invoked; ``col`` keeps
        its position as the first explicit argument.

        NOTE(review): the ``+ ' ' +`` concatenation requires the stored
        symbols to be strings -- confirm against what sym.fromstring returns.
        A span length of 0 would recurse forever.
        """
        columns = self.columns
        if col >= len(columns):
            return []
        result = []
        for entry in columns[col]:
            if col + entry[2] + 1 <= len(columns):
                for suffix in self.listdown(col + entry[2]):
                    result.append(entry[0] + ' ' + suffix)
            else:
                result.append(entry[0])
        return result

    def next(self, _columns, curr_idx, min_dist=1):
        """Return the column indices reachable from ``curr_idx``.

        Only indices at least ``min_dist`` columns ahead and inside
        ``_columns`` are reported.  An alternative whose span is shorter than
        ``min_dist`` is followed recursively with the remaining distance.

        Returns a set of indices, except that an empty *list* is returned
        when ``curr_idx + min_dist`` is already past the end (kept as in the
        original for compatibility with existing callers).

        Original author's note: can be used only when prev_id is defined.
        """
        result = []
        if curr_idx + min_dist >= len(_columns):
            return result
        for alt in _columns[curr_idx]:
            if alt[2] < min_dist:
                # Span too short: continue from where this alternative ends.
                result.extend(self.next(_columns, curr_idx + alt[2], min_dist - alt[2]))
            elif curr_idx + alt[2] < len(_columns):
                result.append(curr_idx + alt[2])
        return set(result)
                
    


#file = open(sys.argv[1], "rb")
#sent = sgml.process_sgml_line(file.read())
#print sent
#cn = ConfusionNet(sent)
#print cn
#results = cn.listdown()
#for result in results:
#    print sym.tostring(result)
#print cn.next(0);
#print cn.next(1);
#print cn.next(2);
#print cn.next(3);
#print cn
#cn = ConfusionNet()
#k = 0
#while (cn.read(file)):
#  print cn
  
#print cn.stats