# Precomputes a set of collocations of frequent patterns by advancing over the text.
# Warning: nasty C-style code (manual malloc/realloc/free throughout).

import log
import monitor

cimport csuf
cimport cdat
cimport cintlist

from libc.stdio cimport FILE, fopen, fread, fwrite, fclose
from libc.stdlib cimport malloc, realloc, free
from libc.string cimport memset, memcpy

cdef struct _Trie_Node  # forward decl

cdef struct _Trie_Edge:
  int val
  _Trie_Node* node
  _Trie_Edge* bigger
  _Trie_Edge* smaller

cdef struct _Trie_Node:
  _Trie_Edge* root
  int* arr
  int arr_len

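# Each trie node keeps its outgoing edges in an unbalanced binary search tree
# keyed on symbol id (val); node.arr is a growable int array holding the data
# (corpus positions) recorded for the pattern spelled out on the path to it.
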
cdef _Trie_Node* new_trie_node():
  cdef _Trie_Node* node
  node = <_Trie_Node*> malloc(sizeof(_Trie_Node))
  node.root = NULL
  node.arr_len = 0
  node.arr = <int*> malloc(0*sizeof(int))  # empty; realloc grows it on first append
  return node

cdef _Trie_Edge* new_trie_edge(int val):
  cdef _Trie_Edge* edge
  edge = <_Trie_Edge*> malloc(sizeof(_Trie_Edge))
  edge.node = new_trie_node()
  edge.bigger = NULL
  edge.smaller = NULL
  edge.val = val
  return edge

cdef free_trie_node(_Trie_Node* node):
  if node != NULL:
    free_trie_edge(node.root)
    free(node.arr)
    free(node)  # release the node struct itself, not just its contents

cdef free_trie_edge(_Trie_Edge* edge):
  if edge != NULL:
    free_trie_node(edge.node)
    free_trie_edge(edge.bigger)
    free_trie_edge(edge.smaller)
    free(edge)  # release the edge struct itself, not just its subtrees

cdef _Trie_Node* trie_find(_Trie_Node* node, int val):
  cdef _Trie_Edge* cur
  cur = node.root
  while cur != NULL and cur.val != val:
    if val > cur.val:
      cur = cur.bigger
    elif val < cur.val:
      cur = cur.smaller
  if cur == NULL:
    return NULL
  else:
    return cur.node

cdef trie_node_data_append(_Trie_Node* node, int val):
  cdef int new_len
  new_len = node.arr_len + 1
  node.arr = <int*> realloc(node.arr, new_len*sizeof(int))
  node.arr[node.arr_len] = val
  node.arr_len = new_len

cdef trie_node_data_extend(_Trie_Node* node, int* vals, int num_vals):
  cdef int new_len
  new_len = node.arr_len + num_vals
  node.arr = <int*> realloc(node.arr, new_len*sizeof(int))
  memcpy(node.arr + node.arr_len, vals, num_vals*sizeof(int))
  node.arr_len = new_len


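# Insert val as a child of node, creating the edge if needed, and return the
# child node. The pointer-to-pointer walk lets the new edge be hung off
# whichever link (root, bigger, or smaller) ended the search.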
cdef _Trie_Node* trie_insert(_Trie_Node* node, int val):
  cdef _Trie_Edge** cur
  cur = &node.root
  while cur[0] != NULL and cur[0].val != val:
    if val > cur[0].val:
      cur = &cur[0].bigger
    elif val < cur[0].val:
      cur = &cur[0].smaller
  if cur[0] == NULL:
    cur[0] = new_trie_edge(val)
  return cur[0].node

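# Flatten a (sub)trie into `result`, mapping each word-id tuple to a CIntList
# copy of its node's data; include_zeros controls whether nodes with empty
# arrays (pure prefixes) get entries too.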
cdef trie_node_to_map(_Trie_Node* node, result, prefix, int include_zeros):
  cdef cintlist.CIntList arr

  if include_zeros or node.arr_len > 0:
    arr = cintlist.CIntList()
    free(arr.arr)  # discard the list's own buffer; replace it with a copy of the node's data
    arr.arr = <int*> malloc(node.arr_len * sizeof(int))
    memcpy(arr.arr, node.arr, node.arr_len * sizeof(int))
    arr.len = node.arr_len
    arr.size = node.arr_len
    result[prefix] = arr
  trie_edge_to_map(node.root, result, prefix, include_zeros)

cdef trie_edge_to_map(_Trie_Edge* edge, result, prefix, int include_zeros):
  if edge != NULL:
    trie_edge_to_map(edge.smaller, result, prefix, include_zeros)
    trie_edge_to_map(edge.bigger, result, prefix, include_zeros)
    prefix = prefix + (edge.val,)
    trie_node_to_map(edge.node, result, prefix, include_zeros)

cdef class TrieMap:

  cdef _Trie_Node** root
  cdef int V

  def __cinit__(self, alphabet_size):
    # __cinit__ rather than __init__, so self.root is always allocated before
    # __dealloc__ can possibly run
    self.V = alphabet_size
    self.root = <_Trie_Node**> malloc(self.V * sizeof(_Trie_Node*))
    memset(self.root, 0, self.V * sizeof(_Trie_Node*))


  def __dealloc__(self):
    cdef int i
    for i from 0 <= i < self.V:
      if self.root[i] != NULL:
        free_trie_node(self.root[i])
    free(self.root)


  def insert(self, pattern):
    cdef int* p
    cdef int i, l
    l = len(pattern)
    p = <int*> malloc(l*sizeof(int))
    for i from 0 <= i < l:
      p[i] = pattern[i]
    self._insert(p,l)
    free(p)


  cdef _Trie_Node* _insert(self, int* pattern, int pattern_len):
    cdef int i
    cdef _Trie_Node* node
    if self.root[pattern[0]] == NULL:
      self.root[pattern[0]] = new_trie_node()
    node = self.root[pattern[0]]
    for i from 1 <= i < pattern_len:
      node = trie_insert(node, pattern[i])
    return node

  def contains(self, pattern):
    cdef int* p
    cdef int i, l
    cdef _Trie_Node* node
    l = len(pattern)
    p = <int*> malloc(l*sizeof(int))
    for i from 0 <= i < l:
      p[i] = pattern[i]
    node = self._contains(p,l)
    free(p)
    if node == NULL:
      return False
    else:
      return True

  cdef _Trie_Node* _contains(self, int* pattern, int pattern_len):
    cdef int i
    cdef _Trie_Node* node
    node = self.root[pattern[0]]
    i = 1
    while node != NULL and i < pattern_len:
      node = trie_find(node, pattern[i])
      i = i+1
    return node

  def toMap(self, flag):
    cdef int i, include_zeros

    if flag:
      include_zeros=1
    else:
      include_zeros=0
    result = {}
    for i from 0 <= i < self.V:
      if self.root[i] != NULL:
        trie_node_to_map(self.root[i], result, (i,), include_zeros)
    return result

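# A minimal usage sketch (word ids invented; patterns must be non-empty since
# the first symbol indexes the root array). Note that contains() is really a
# prefix test: every prefix of an inserted pattern is also "contained".
#
#   patterns = TrieMap(100)          # alphabet of 100 word ids
#   patterns.insert((5, 7, 2))
#   patterns.insert((5, 9))
#   patterns.contains((5, 7))        # True: prefix of an inserted pattern
#   patterns.contains((7, 2))        # False: nothing starts with 7
#   index = patterns.toMap(True)     # {(5,): CIntList, (5, 7): CIntList, ...}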

cdef class Precomputation:

#  Defined in .pxd file, here for reference:
#  cdef int precompute_rank
#  cdef int precompute_secondary_rank
#  cdef int max_length
#  cdef int max_nonterminals
#  cdef int train_max_initial_size
#  cdef int train_min_gap_size
#  cdef precomputed_index
#  cdef precomputed_collocations

  def __init__(self, filename, sa=None, precompute_rank=1000, precompute_secondary_rank=20, max_length=5, 
        max_nonterminals=2, train_max_initial_size=10, train_min_gap_size=2, from_binary=False):
    self.precompute_rank = precompute_rank
    self.precompute_secondary_rank = precompute_secondary_rank
    self.max_length = max_length
    self.max_nonterminals = max_nonterminals
    self.train_max_initial_size = train_max_initial_size
    self.train_min_gap_size = train_min_gap_size
    if from_binary:
      self.read_binary(filename)
    else:
      self.precompute(filename, sa)

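  # A hypothetical round-trip sketch (file names invented; sa is assumed to be
  # a csuf.SuffixArray over the training text):
  #
  #   pre = Precomputation("frequent_patterns.txt", sa=sa, max_length=5)
  #   pre.write_binary("precompute.bin")
  #   pre2 = Precomputation("precompute.bin", from_binary=True)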

  def read_binary(self, filename):
    cdef FILE* f
    cdef bytes bfilename = filename
    cdef char* cfilename = bfilename
    f = fopen(cfilename, "r")
    fread(&(self.precompute_rank), sizeof(int), 1, f)
    fread(&(self.precompute_secondary_rank), sizeof(int), 1, f)
    fread(&(self.max_length), sizeof(int), 1, f)
    fread(&(self.max_nonterminals), sizeof(int), 1, f)
    fread(&(self.train_max_initial_size), sizeof(int), 1, f)
    fread(&(self.train_min_gap_size), sizeof(int), 1, f)
    self.precomputed_index = self.read_map(f)
    self.precomputed_collocations = self.read_map(f)
    fclose(f)


  def write_binary(self, filename):
    cdef FILE* f
    cdef bytes bfilename = filename
    cdef char* cfilename = bfilename

    f = fopen(cfilename, "w")
    fwrite(&(self.precompute_rank), sizeof(int), 1, f)
    fwrite(&(self.precompute_secondary_rank), sizeof(int), 1, f)
    fwrite(&(self.max_length), sizeof(int), 1, f)
    fwrite(&(self.max_nonterminals), sizeof(int), 1, f)
    fwrite(&(self.train_max_initial_size), sizeof(int), 1, f)
    fwrite(&(self.train_min_gap_size), sizeof(int), 1, f)
    self.write_map(self.precomputed_index, f)
    self.write_map(self.precomputed_collocations, f)
    fclose(f)


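  # On-disk map layout, shared by write_map/read_map: the entry count, then
  # for each entry its key length, the key's word ids, and the value CIntList
  # serialized via CIntList.write_handle/read_handle.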
  cdef write_map(self, m, FILE* f):
    cdef int i, N
    cdef cintlist.CIntList arr

    N = len(m)
    fwrite(&(N), sizeof(int), 1, f)
    for pattern, val in m.iteritems():
      N = len(pattern)
      fwrite(&(N), sizeof(int), 1, f)
      for word_id in pattern:
        i = word_id
        fwrite(&(i), sizeof(int), 1, f)
      arr = val
      arr.write_handle(f)


  cdef read_map(self, FILE* f):
    cdef int i, j, k, word_id, N
    cdef cintlist.CIntList arr

    m = {}
    fread(&(N), sizeof(int), 1, f)
    for j from 0 <= j < N:
      fread(&(i), sizeof(int), 1, f)
      key = ()
      for k from 0 <= k < i:
        fread(&(word_id), sizeof(int), 1, f)
        key = key + (word_id,)
      arr = cintlist.CIntList()
      arr.read_handle(f)
      m[key] = arr
    return m


  def precompute(self, filename, sa):
    cdef int i, l, N, max_pattern_len
    cdef int i1, l1, i2, l2, i3, l3, ptr1, ptr2, ptr3
    cdef int is_super, sent_count, max_rank
    cdef csuf.SuffixArray sarray
    cdef cdat.DataArray darray
    cdef cintlist.CIntList data, queue, cost_by_rank, count_by_rank
    cdef TrieMap frequent_patterns, super_frequent_patterns, collocations
    cdef _Trie_Node* node

    sarray = sa
    darray = sarray.darray
    data = darray.data

    frequent_patterns = TrieMap(len(darray.id2word))
    super_frequent_patterns = TrieMap(len(darray.id2word))
    collocations = TrieMap(len(darray.id2word))

    I_set = set()
    J_set = set()
    J2_set = set()
    IJ_set = set()
    pattern_rank = {}

    log.writeln("Precomputing frequent intersections\n", 1)
    start_time = monitor.cpu()

    max_pattern_len = 0
    if filename is not None:
      precompute_file = open(filename)
      for rank, line in enumerate(precompute_file):
        if rank >= self.precompute_rank:
          break
        phrase_words = line.split()[2:]  # skip the first two fields (presumably count metadata)
        phrase = ()
        for word in phrase_words:
          phrase = phrase + (darray.word2id[word],)
        max_pattern_len = max(max_pattern_len, len(phrase))
        frequent_patterns.insert(phrase)
        I_set.add(phrase)
        pattern_rank[phrase] = rank
        if rank < self.precompute_secondary_rank:
          super_frequent_patterns.insert(phrase)
          J_set.add(phrase)
      precompute_file.close()

    queue = cintlist.CIntList(increment=1000)

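    # Flatten the occurrence stream into one int list: -1 marks a sentence
    # boundary; otherwise each position i contributes an (i, l) pair for every
    # length l whose prefix data[i:i+l] is a path in the frequent-pattern trie
    # (breaking at the first miss is safe, since any longer match would have
    # to extend that path). Each position is also recorded on the trie node
    # itself, which later becomes the precomputed inverted index.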
    log.writeln("  Computing inverted indexes...", 1)
    N = len(data)
    for i from 0 <= i < N:
      sa_word_id = data.arr[i]
      if sa_word_id == 1:  # word id 1 marks a sentence boundary in the data array
        queue._append(-1)
      else:
        for l from 1 <= l <= max_pattern_len:
          node = frequent_patterns._contains(data.arr+i, l)
          if node == NULL:
            break
          queue._append(i)
          queue._append(l)
          trie_node_data_append(node, i)

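    # Scan the flattened stream: each occurrence (i1, l1) is paired with later
    # occurrences (i2, l2), and, for super-frequent patterns, also (i3, l3),
    # that fit within the same train_max_initial_size window (the -1 sentinel
    # stops the scan at sentence boundaries). Matching positions are recorded
    # under the gapped pattern p1 X p2 (X p3), with gaps encoded as -1.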
    log.writeln("  Computing collocations...", 1)
    N = len(queue)
    ptr1 = 0
    sent_count = 0
    while ptr1 < N:  # main loop
      i1 = queue.arr[ptr1]
      if i1 > -1:
        l1 = queue.arr[ptr1+1]
        ptr2 = ptr1 + 2
        while ptr2 < N:
          i2 = queue.arr[ptr2]
          if i2 == -1 or i2 - i1 >= self.train_max_initial_size:
            break
          l2 = queue.arr[ptr2+1]
          # require a gap of at least train_min_gap_size words, a total span
          # within train_max_initial_size, and a combined pattern (including
          # one gap symbol) within max_length
          if i2 - i1 - l1 >= self.train_min_gap_size and i2 + l2 - i1 <= self.train_max_initial_size and l1+l2+1 <= self.max_length:
            node = collocations._insert(data.arr+i1, l1)
            node = trie_insert(node, -1)
            for i from i2 <= i < i2+l2:
              node = trie_insert(node, data.arr[i])
            trie_node_data_append(node, i1)
            trie_node_data_append(node, i2)
            if super_frequent_patterns._contains(data.arr+i2, l2) != NULL:
              if super_frequent_patterns._contains(data.arr+i1, l1) != NULL:
                is_super = 1
              else:
                is_super = 0
              ptr3 = ptr2 + 2
              while ptr3 < N:
                i3 = queue.arr[ptr3]
                if i3 == -1 or i3 - i1 >= self.train_max_initial_size:
                  break
                l3 = queue.arr[ptr3+1]
                # same constraints for the third subpattern; the two gap
                # symbols both count toward max_length
                if i3 - i2 - l2 >= self.train_min_gap_size and i3 + l3 - i1 <= self.train_max_initial_size and l1+l2+l3+2 <= self.max_length:
                  if is_super or super_frequent_patterns._contains(data.arr+i3, l3) != NULL:
                    node = collocations._insert(data.arr+i1, l1)
                    node = trie_insert(node, -1)
                    for i from i2 <= i < i2+l2:
                      node = trie_insert(node, data.arr[i])
                    node = trie_insert(node, -1)
                    for i from i3 <= i < i3+l3:
                      node = trie_insert(node, data.arr[i])
                    trie_node_data_append(node, i1)
                    trie_node_data_append(node, i2)
                    trie_node_data_append(node, i3)
                ptr3 = ptr3 + 2
          ptr2 = ptr2 + 2
        ptr1 = ptr1 + 2
      else:
        sent_count = sent_count + 1
        if sent_count % 10000 == 0:
          log.writeln("    %d sentences" % sent_count)
        ptr1 = ptr1 + 1

    self.precomputed_collocations = collocations.toMap(False)
    self.precomputed_index = frequent_patterns.toMap(True)

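    # Enumerate every pattern whose collocations may have been precomputed:
    # I x I pairs joined by one gap, plus I x J2 combinations in both orders
    # (J2 = pairs of super-frequent patterns already joined by a gap). x is
    # an upper bound on the combinations, reported in the summary below.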
    x = 0
    for pattern1 in J_set:
      for pattern2 in J_set:
        # strict <, since a J2 pattern must leave room to combine with one more subpattern
        if len(pattern1) + len(pattern2) + 1 < self.max_length:
          combined_pattern = pattern1 + (-1,) + pattern2
          J2_set.add(combined_pattern)

    for pattern1 in I_set:
      for pattern2 in I_set:
        x = x+1
        if len(pattern1) + len(pattern2) + 1 <= self.max_length:
          combined_pattern = pattern1 + (-1,) + pattern2
          IJ_set.add(combined_pattern)

    for pattern1 in I_set:
      for pattern2 in J2_set:
        x = x+2
        if len(pattern1) + len(pattern2) + 1 <= self.max_length:
          combined_pattern = pattern1 + (-1,) + pattern2
          IJ_set.add(combined_pattern)
          combined_pattern = pattern2 + (-1,) + pattern1
          IJ_set.add(combined_pattern)

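    # Tally, per frequency rank, the memory cost (bytes) and occurrence count
    # of the precomputed collocations; a gapped pattern is charged to the
    # worst (highest) rank among its contiguous chunks.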
    N = len(pattern_rank)
    cost_by_rank = cintlist.CIntList(initial_len=N)
    count_by_rank = cintlist.CIntList(initial_len=N)
    for pattern, arr in self.precomputed_collocations.iteritems():
      if pattern not in IJ_set:
        s = ""
        for word_id in pattern:
          if word_id == -1:
            s = s + "X "
          else:
            s = s + darray.id2word[word_id] + " "
        log.writeln("ERROR: unexpected pattern %s in set of precomputed collocations" % (s), 1)
      else:
        chunk = ()
        max_rank = 0
        arity = 0
        for word_id in pattern:
          if word_id == -1:
            max_rank = max(max_rank, pattern_rank[chunk])
            arity = arity + 1
            chunk = ()
          else:
            chunk = chunk + (word_id,)
        max_rank = max(max_rank, pattern_rank[chunk])
        cost_by_rank.arr[max_rank] = cost_by_rank.arr[max_rank] + (4*len(arr))  # 4 bytes per stored int
        count_by_rank.arr[max_rank] = count_by_rank.arr[max_rank] + (len(arr)/(arity+1))  # arity+1 ints per occurrence

    log.writeln("Detailed statistics:")
    cumul_cost = 0
    cumul_count = 0
    for i from 0 <= i < N:
      cumul_cost = cumul_cost + cost_by_rank.arr[i]
      cumul_count = cumul_count + count_by_rank.arr[i]
      log.writeln("RANK %d\tCOUNT, COST: %d  %d\tCUMUL: %d, %d" % (i, count_by_rank.arr[i], cost_by_rank.arr[i], cumul_count, cumul_cost))

    num_found_patterns = len(self.precomputed_collocations)
    for pattern in IJ_set:
      if pattern not in self.precomputed_collocations:
        self.precomputed_collocations[pattern] = cintlist.CIntList()

    stop_time = monitor.cpu()
    log.writeln("Precomputed collocations for %d patterns out of %d possible (upper bound %d)" % (num_found_patterns,len(self.precomputed_collocations),x))
    log.writeln("Precomputed inverted index for %d patterns " % len(self.precomputed_index))
    log.writeln("Precomputation took %f seconds" % (stop_time - start_time))
    log.writeln("Detailed statistics:")