# clex.pyx
# defines bilexical dictionaries in C, with some convenience methods
# for reading arrays as globs directly from disk.
# Adam Lopez <alopez@cs.umd.edu>

import gzip
import sys
import context_model

cimport cintlist
cimport cfloatlist
cimport calignment
cimport csuf
cimport cdat

from libc.stdio cimport FILE, fopen, fread, fwrite, fclose
from libc.stdlib cimport malloc, realloc, free
from libc.string cimport memset, strcpy, strlen

# Node of an unbalanced binary search tree mapping an int key to an int
# value.  compute_from_data keeps one such tree per source word, keyed by
# target word id, with the co-occurrence count stored in `val`.
cdef struct _node:
  _node* smaller  # subtree of keys less than `key`, or NULL
  _node* bigger   # subtree of keys greater than `key`, or NULL
  int key
  int val

cdef _node* new_node(int key):
  """Allocate a childless tree node for `key` with its value set to 0.

  The node comes from malloc (the result is not checked) and is owned by
  the caller; del_node releases it together with its subtrees.
  """
  cdef _node* fresh = <_node*> malloc(sizeof(_node))
  fresh.key = key
  fresh.val = 0
  fresh.bigger = NULL
  fresh.smaller = NULL
  return fresh


cdef del_node(_node* n):
  """Free the node `n` and, recursively, every node below it.

  `n` itself must not be NULL; NULL children are skipped.
  """
  cdef _node* lower = n.smaller
  cdef _node* upper = n.bigger
  if lower != NULL:
    del_node(lower)
  if upper != NULL:
    del_node(upper)
  free(n)


cdef int* get_val(_node* n, int key):
  """Return a pointer to the value stored for `key` in the tree rooted at `n`.

  If `key` is absent, a node for it is created (value 0) and linked in at
  the position where the search fell off the tree.  `n` must not be NULL.
  """
  cdef _node* cur = n
  while cur.key != key:
    if key < cur.key:
      if cur.smaller == NULL:
        cur.smaller = new_node(key)
      cur = cur.smaller
    else:
      if cur.bigger == NULL:
        cur.bigger = new_node(key)
      cur = cur.bigger
  # Either an existing node matched or we just stepped into the fresh one.
  return &cur.val


cdef class CLex:
  """Bilexical (source word, target word) score table held in C arrays.

  Layout: the entries for source word id f occupy positions
  f_index[f] .. f_index[f+1]-1 of three parallel arrays: e_index (target
  word id), col1 and col2 (the two scores).  Within one source word the
  entries are sorted by target word id so get_score can binary-search.
  f_index carries one extra trailing entry equal to the total number of
  entries, so f_index[f+1] is valid for the last source word too.

  id2fword/id2eword map word ids to words; fword2id/eword2id are the
  inverse dictionaries.
  """

  cdef cfloatlist.CFloatList col1
  cdef cfloatlist.CFloatList col2
  cdef cintlist.CIntList f_index
  cdef cintlist.CIntList e_index
  cdef id2eword, id2fword, eword2id, fword2id

  def __init__(self, filename, from_binary=False,
        from_data=False, earray=None, fsarray=None):
    """Build the table from a file or from aligned corpus data.

    filename    -- path of the table to load.  When from_data is true this
                   argument is instead passed to compute_from_data as the
                   alignment object.
    from_binary -- load `filename` with read_binary (takes precedence).
    from_data   -- compute the table from fsarray, earray and the alignment.
    earray      -- target-side data array (only used when from_data).
    fsarray     -- source-side suffix array (only used when from_data).

    With both flags false, `filename` is parsed by read_text.
    """
    self.id2eword = []
    self.id2fword = []
    self.eword2id = {}
    self.fword2id = {}
    self.e_index = cintlist.CIntList()
    self.f_index = cintlist.CIntList()
    self.col1 = cfloatlist.CFloatList()
    self.col2 = cfloatlist.CFloatList()
    if from_binary:
      self.read_binary(filename)
    elif from_data:
      self.compute_from_data(fsarray, earray, filename)
    else:
      self.read_text(filename)


  cdef int _count_pair(self, _node** table, int f_id, int e_id):
    """Add one to the count of (f_id, e_id) in the per-source-word trees.

    Returns 1 when this is the first time the pair is seen, else 0, so the
    caller can keep a running total of distinct pairs.
    """
    cdef int* count
    if table[f_id] == NULL:
      table[f_id] = new_node(e_id)
      table[f_id].val = 1
      return 1
    count = get_val(table[f_id], e_id)
    count[0] = count[0] + 1
    if count[0] == 1:
      return 1
    return 0


  cdef compute_from_data(self, csuf.SuffixArray fsa, cdat.DataArray eda, calignment.Alignment aa):
    """Fill the table with relative frequencies counted from aligned text.

    For every alignment link the (source, target) pair is counted; source
    or target words left unaligned in a sentence are counted against word
    id 0, which is renamed "NULL" in both vocabularies.  The stored scores
    are col1 = count(f,e)/count(f) and col2 = count(f,e)/count(e).

    Raises AssertionError if a link points outside its sentence and
    MemoryError if a work buffer cannot be allocated.  All C buffers are
    released on every exit path.
    """
    cdef int sent_id, num_sents, num_links, l, i, j, f_i, e_j, I, J, V_E, V_F, num_pairs
    cdef int *fsent, *esent, *links, *ealigned, *faligned
    cdef int *fmargin, *emargin
    cdef _node** table
    cdef int null_word

    null_word = 0
    self.id2fword.extend(fsa.darray.id2word)
    self.id2fword[null_word] = "NULL"
    for word_id, word in enumerate(self.id2fword):
      self.fword2id[word] = word_id

    self.id2eword.extend(eda.id2word)
    self.id2eword[null_word] = "NULL"
    for word_id, word in enumerate(self.id2eword):
      self.eword2id[word] = word_id

    num_pairs = 0
    V_E = len(eda.id2word)
    V_F = len(fsa.darray.id2word)

    # Start every pointer at NULL so the finally clause can free blindly.
    table = NULL
    fmargin = NULL
    emargin = NULL
    faligned = NULL
    ealigned = NULL
    links = NULL
    try:
      # table[f] is the root of the tree of target words seen with f.
      table = <_node**> malloc(V_F*sizeof(_node*))
      if table == NULL:
        raise MemoryError()
      memset(table, 0, V_F*sizeof(_node*))

      # Marginal counts per source / target word id.
      fmargin = <int*> malloc(V_F*sizeof(int))
      emargin = <int*> malloc(V_E*sizeof(int))
      if fmargin == NULL or emargin == NULL:
        raise MemoryError()
      memset(fmargin, 0, V_F*sizeof(int))
      memset(emargin, 0, V_E*sizeof(int))

      num_sents = len(fsa.darray.sent_index)
      for sent_id from 0 <= sent_id < num_sents-1:

        # Sentence length excludes the last stored position of the sentence.
        fsent = fsa.darray.data.arr + fsa.darray.sent_index.arr[sent_id]
        I = fsa.darray.sent_index.arr[sent_id+1] - fsa.darray.sent_index.arr[sent_id] - 1
        faligned = <int*> malloc(I*sizeof(int))

        esent = eda.data.arr + eda.sent_index.arr[sent_id]
        J = eda.sent_index.arr[sent_id+1] - eda.sent_index.arr[sent_id] - 1
        ealigned = <int*> malloc(J*sizeof(int))

        # malloc(0) may legitimately return NULL, so only a non-empty
        # request that came back NULL is an allocation failure.
        if (I > 0 and faligned == NULL) or (J > 0 and ealigned == NULL):
          raise MemoryError()
        if I > 0:
          memset(faligned, 0, I*sizeof(int))
        if J > 0:
          memset(ealigned, 0, J*sizeof(int))

        # links holds num_links (source position, target position) pairs.
        links = aa._get_sent_links(sent_id, &num_links)

        for l from 0 <= l < num_links:
          i = links[l*2]
          j = links[l*2+1]
          if i < 0 or j < 0 or i >= I or j >= J:
            sys.stderr.write("  %d-%d out of bounds (I=%d,J=%d) in line %d\n" % (i,j,I,J,sent_id+1))
            # Raised explicitly so the check survives builds without asserts.
            raise AssertionError("alignment link out of bounds")
          f_i = fsent[i]
          e_j = esent[j]
          fmargin[f_i] = fmargin[f_i] + 1
          emargin[e_j] = emargin[e_j] + 1
          num_pairs = num_pairs + self._count_pair(table, f_i, e_j)
          faligned[i] = 1
          ealigned[j] = 1

        # Unaligned source words pair with the target NULL word.
        for i from 0 <= i < I:
          if faligned[i] == 0:
            f_i = fsent[i]
            fmargin[f_i] = fmargin[f_i] + 1
            emargin[null_word] = emargin[null_word] + 1
            num_pairs = num_pairs + self._count_pair(table, f_i, null_word)

        # Unaligned target words pair with the source NULL word.
        for j from 0 <= j < J:
          if ealigned[j] == 0:
            e_j = esent[j]
            fmargin[null_word] = fmargin[null_word] + 1
            emargin[e_j] = emargin[e_j] + 1
            num_pairs = num_pairs + self._count_pair(table, null_word, e_j)

        free(links)
        links = NULL
        free(faligned)
        faligned = NULL
        free(ealigned)
        ealigned = NULL

      # One extra f_index slot holds the end of the last bucket, matching
      # the layout read_text builds and get_score/write_text rely on.
      self.f_index = cintlist.CIntList(initial_len=V_F+1)
      self.e_index = cintlist.CIntList(initial_len=num_pairs)
      self.col1 = cfloatlist.CFloatList(initial_len=num_pairs)
      self.col2 = cfloatlist.CFloatList(initial_len=num_pairs)

      num_pairs = 0
      for i from 0 <= i < V_F:
        self.f_index.set(i, num_pairs)
        if table[i] != NULL:
          self._add_node(table[i], &num_pairs, float(fmargin[i]), emargin)
      self.f_index.set(V_F, num_pairs)
    finally:
      free(links)
      free(faligned)
      free(ealigned)
      if table != NULL:
        for i from 0 <= i < V_F:
          if table[i] != NULL:
            del_node(table[i])
        free(table)
      free(fmargin)
      free(emargin)


  cdef _add_node(self, _node* n, int* num_pairs, float fmargin, int* emargin):
    """Append the tree under `n` to the table in ascending key order.

    n         -- subtree of target-word counts for one source word.
    num_pairs -- in/out: next free slot in e_index/col1/col2.
    fmargin   -- total count of the source word.
    emargin   -- total counts indexed by target word id.
    """
    cdef int loc
    if n.smaller != NULL:
      self._add_node(n.smaller, num_pairs, fmargin, emargin)
    loc = num_pairs[0]
    self.e_index.set(loc, n.key)
    self.col1.set(loc, float(n.val)/fmargin)
    self.col2.set(loc, float(n.val)/float(emargin[n.key]))
    num_pairs[0] = loc + 1
    if n.bigger != NULL:
      self._add_node(n.bigger, num_pairs, fmargin, emargin)


  def write_binary(self, filename):
    """Write the four arrays and both word lists to `filename`.

    Raises IOError if the file cannot be opened.
    """
    cdef FILE* f
    cdef bytes bfilename = filename
    cdef char* cfilename = bfilename
    # "wb": the payload is raw C data, not text.
    f = fopen(cfilename, "wb")
    if f == NULL:
      raise IOError("cannot open %s for writing" % filename)
    try:
      self.f_index.write_handle(f)
      self.e_index.write_handle(f)
      self.col1.write_handle(f)
      self.col2.write_handle(f)
      self.write_wordlist(self.id2fword, f)
      self.write_wordlist(self.id2eword, f)
    finally:
      fclose(f)


  cdef write_wordlist(self, wordlist, FILE* f):
    """Write `wordlist` to `f`: the word count, then for each word its
    byte length (including the terminating NUL) followed by the bytes."""
    cdef int word_len
    cdef int num_words
    cdef char* c_word

    num_words = len(wordlist)
    fwrite(&(num_words), sizeof(int), 1, f)
    for word in wordlist:
      c_word = word
      word_len = strlen(c_word) + 1  # +1 keeps the NUL terminator on disk
      fwrite(&(word_len), sizeof(int), 1, f)
      fwrite(c_word, sizeof(char), word_len, f)


  cdef read_wordlist(self, word2id, id2word, FILE* f):
    """Read a word list written by write_wordlist from `f`.

    Each word is appended to `id2word` and entered in `word2id` under its
    position in `id2word`.  Raises IOError on a truncated or malformed
    list and MemoryError if a word buffer cannot be allocated.
    """
    cdef int num_words
    cdef int word_len
    cdef int i
    cdef char* c_word
    cdef bytes py_word

    if fread(&(num_words), sizeof(int), 1, f) != 1:
      raise IOError("truncated word list: missing word count")
    for i from 0 <= i < num_words:
      if fread(&(word_len), sizeof(int), 1, f) != 1 or word_len <= 0:
        raise IOError("truncated or corrupt word list: bad word length")
      c_word = <char*> malloc (word_len * sizeof(char))
      if c_word == NULL:
        raise MemoryError()
      try:
        if fread(c_word, sizeof(char), word_len, f) != <size_t> word_len:
          raise IOError("truncated word list: incomplete word")
        # The char* -> bytes conversion scans for NUL; make sure one is
        # present even if the file is damaged.
        c_word[word_len-1] = 0
        py_word = c_word
      finally:
        free(c_word)
      word2id[py_word] = len(id2word)
      id2word.append(py_word)

  def read_binary(self, filename):
    """Load a table previously saved with write_binary.

    Raises IOError if the file cannot be opened or a word list is
    truncated.
    """
    cdef FILE* f
    cdef bytes bfilename = filename
    cdef char* cfilename = bfilename
    f = fopen(cfilename, "rb")
    if f == NULL:
      raise IOError("cannot open %s for reading" % filename)
    try:
      self.f_index.read_handle(f)
      self.e_index.read_handle(f)
      self.col1.read_handle(f)
      self.col2.read_handle(f)
      self.read_wordlist(self.fword2id, self.id2fword, f)
      self.read_wordlist(self.eword2id, self.id2eword, f)
    finally:
      fclose(f)


  def get_e_id(self, eword):
    """Return the id of target word `eword`, assigning a new id if unseen."""
    if eword not in self.eword2id:
      e_id = len(self.id2eword)
      self.id2eword.append(eword)
      self.eword2id[eword] = e_id
    return self.eword2id[eword]


  def get_f_id(self, fword):
    """Return the id of source word `fword`, assigning a new id if unseen."""
    if fword not in self.fword2id:
      f_id = len(self.id2fword)
      self.id2fword.append(fword)
      self.fword2id[fword] = f_id
    return self.fword2id[fword]


  def read_text(self, filename):
    """Load the table from a text file of "fword eword score1 score2" lines.

    A name ending in "gz" is opened with gzip.  The file is read twice:
    once to size the per-source-word buckets, once to fill them; each
    bucket is then sorted by target word id.
    """
    cdef int i, j, b, e_id, f_id, n_f, N, index
    cdef cintlist.CIntList fcount

    fcount = cintlist.CIntList()
    if filename[-2:] == "gz":
      f = gzip.GzipFile(filename)
    else:
      f = open(filename)

    try:
      # first loop merely establishes size of array objects
      sys.stderr.write("Initial read...\n")
      for line in f:
        (fword, eword, score1, score2) = line.split()
        f_id = self.get_f_id(fword)
        e_id = self.get_e_id(eword)
        while f_id >= len(fcount):
          fcount.append(0)
        fcount.arr[f_id] = fcount.arr[f_id] + 1

      # Allocate space for dictionary in arrays; f_index becomes the
      # running sum of bucket sizes with one trailing end marker.
      N = 0
      n_f = len(fcount)
      self.f_index = cintlist.CIntList(initial_len=n_f+1)
      for i from 0 <= i < n_f:
        self.f_index.arr[i] = N
        N = N + fcount.arr[i]
        fcount.arr[i] = 0  # reused below as the per-bucket fill cursor
      self.f_index.arr[n_f] = N
      self.e_index = cintlist.CIntList(initial_len=N)
      self.col1 = cfloatlist.CFloatList(initial_len=N)
      self.col2 = cfloatlist.CFloatList(initial_len=N)

      # Re-read file, placing words into buckets
      sys.stderr.write("Bucket sort...\n")
      f.seek(0)
      for line in f:
        (fword, eword, score1, score2) = line.split()
        f_id = self.get_f_id(fword)
        e_id = self.get_e_id(eword)
        index = self.f_index.arr[f_id] + fcount.arr[f_id]
        fcount.arr[f_id] = fcount.arr[f_id] + 1
        self.e_index.arr[index] = e_id
        self.col1[index] = float(score1)
        self.col2[index] = float(score2)
    finally:
      f.close()

    sys.stderr.write("Final sort...\n")
    # Sort buckets by eword
    for b from 0 <= b < n_f:
      i = self.f_index.arr[b]
      j = self.f_index.arr[b+1]
      self.qsort(i,j, "")


  cdef swap(self, int i, int j):
    """Exchange entries i and j in e_index, col1 and col2 together."""
    cdef int itmp
    cdef float ftmp

    if i == j:
      return

    itmp = self.e_index.arr[i]
    self.e_index.arr[i] = self.e_index.arr[j]
    self.e_index.arr[j] = itmp

    ftmp = self.col1.arr[i]
    self.col1.arr[i] = self.col1.arr[j]
    self.col1.arr[j] = ftmp

    ftmp = self.col2.arr[i]
    self.col2.arr[i] = self.col2.arr[j]
    self.col2.arr[j] = ftmp


  cdef qsort(self, int i, int j, pad):
    """Quicksort entries [i, j) by target word id, keeping scores aligned.

    `pad` is accepted for interface compatibility and is not used.
    Raises Exception if i > j.
    """
    cdef int pval, p, k

    if i > j:
      raise Exception("Sort error in CLex")
    if i == j: #empty interval
      return
    if i == j-1: # singleton interval
      return

    # Floor division: plain "/" is true division under language level 3.
    p = (i+j)//2
    pval = self.e_index.arr[p]
    self.swap(i, p)
    p = i
    # Invariant: [i, p) <= pivot, position p holds the pivot, (p, k) > pivot.
    for k from i+1 <= k < j:
      if pval >= self.e_index.arr[k]:
        self.swap(p+1, k)
        self.swap(p, p+1)
        p = p + 1
    self.qsort(i,p, pad)
    self.qsort(p+1,j, pad)


  def write_enhanced(self, filename):
    """Dump f_index, the (e_index, col1, col2) triples and both id->word
    lists to `filename` as four lines of space-separated text."""
    f = open(filename, "w")
    try:
      for i in self.f_index:
        f.write("%d " % i)
      f.write("\n")
      for i, s1, s2 in zip(self.e_index, self.col1, self.col2):
        f.write("%d %f %f " % (i, s1, s2))
      f.write("\n")
      for i, w in enumerate(self.id2fword):
        f.write("%d %s " % (i, w))
      f.write("\n")
      for i, w in enumerate(self.id2eword):
        f.write("%d %s " % (i, w))
      f.write("\n")
    finally:
      f.close()


  def get_score(self, fword, eword, col):
    """Return a score for the pair (fword, eword), or None.

    col -- 0 selects col1, 1 selects col2.

    None is returned when either word is unknown, when the pair has no
    entry, or when `col` is neither 0 nor 1.
    """
    cdef int e_id, f_id, n_f, low, high, midpoint, val

    if eword not in self.eword2id:
      return None
    if fword not in self.fword2id:
      return None
    # An unknown column can never yield a score; bail out before searching
    # instead of spinning forever on a matching entry.
    if col != 0 and col != 1:
      return None
    f_id = self.fword2id[fword]
    e_id = self.eword2id[eword]
    n_f = len(self.f_index)
    if f_id >= n_f:
      # Word was registered (e.g. via get_f_id) after the table was built.
      return None
    low = self.f_index.arr[f_id]
    if f_id + 1 < n_f:
      high = self.f_index.arr[f_id+1]
    else:
      # f_index without a trailing end marker: last bucket runs to the end.
      high = len(self.e_index)
    while low < high:
      midpoint = low + (high-low)//2
      val = self.e_index.arr[midpoint]
      if val == e_id:
        if col == 0:
          return self.col1.arr[midpoint]
        return self.col2.arr[midpoint]
      elif val > e_id:
        high = midpoint
      else:
        low = midpoint + 1
    return None


  def write_text(self, filename):
    """Write the table as "fword eword score1 score2" lines.

    Note: does not guarantee writing the dictionary in the original order
    """
    cdef int i, N, n_f, e_id, f_id

    f = open(filename, "w")
    try:
      N = len(self.e_index)
      n_f = len(self.f_index)
      f_id = 0
      for i from 0 <= i < N:
        # Advance to the source word whose bucket contains entry i,
        # skipping empty buckets, without reading past the end of f_index.
        while f_id + 1 < n_f and self.f_index.arr[f_id+1] == i:
          f_id = f_id + 1
        e_id = self.e_index.arr[i]
        score1 = self.col1.arr[i]
        score2 = self.col2.arr[i]
        f.write("%s %s %.6f %.6f\n" % (self.id2fword[f_id], self.id2eword[e_id], score1, score2))
    finally:
      f.close()