#!/usr/bin/env python2.4
'''This module implements a partial stratified tree (van Emde Boas, 1977).
Only insert findsucc, __iter__, and __contains__ are implemented.
Delete is currently not supported.
There is very little error-checking in this code -- it is designed
to be used in the limited situation described in Lopez (EMNLP-CoNLL 2007),
which doesn't cover all of the possible ways that you could misuse it
(e.g. trying to insert a key larger than the universe size)
Other notes -- this code is really rather ugly C code masquerading as
Pyrex/Python.  Virtual function calls are bypassed by hand in several
places for the sake of efficiency, and other Python niceties are 
removed for the same reason.'''

from libc.stdlib cimport malloc, free
from libc.math cimport log, ceil
from libc.string cimport memset

cdef int MIN_BOTTOM_SIZE
cdef int MIN_BOTTOM_BITS

MIN_BOTTOM_SIZE = 32
MIN_BOTTOM_BITS = 5

cdef int lower_mask[32]
cdef int i, mask

for i from 0 <= i < MIN_BOTTOM_SIZE:
	mask = (mask << 1) + 1
	lower_mask[i] = mask


cdef struct _BitSet:
	long bitset
	int min_val
	int max_val
	int size


cdef _BitSet* new_BitSet():
	cdef _BitSet* b

	b = <_BitSet*> malloc(sizeof(_BitSet))
	b.bitset = 0
	b.min_val = -1
	b.max_val = -1
	b.size = 0
	return b


cdef int bitset_findsucc(_BitSet* b, int i):
	cdef int bitset, mask
	cdef int low, high, mid

	if b.max_val == -1 or i >= b.max_val:
		return -1
	if i < b.min_val:
		return b.min_val

	bitset = b.bitset & ~lower_mask[i]
	low = i+1
	high = b.max_val+1
	while low < high-1:
		mid = (high + low)/2
		mask = ~(lower_mask[high-1] ^ lower_mask[mid-1])
		if bitset & mask == 0:
			low = mid
		else: 
			bitset = bitset & mask
			high = mid
	return low


cdef int bitset_insert(_BitSet* b, int i):
	cdef int val

	val = 1 << i
	if b.bitset & val == 0:
		b.bitset = b.bitset | val
		if b.size == 0:
			b.min_val = i
			b.max_val = i
		else:
			if i < b.min_val:
				b.min_val = i
			if i > b.max_val:
				b.max_val = i
		b.size = b.size + 1
		return 1
	return 0


cdef int bitset_contains(_BitSet* b, int i):
	cdef int val

	val = 1 << i
	if b.bitset & val == 0:
		return 0
	else:
		return 1


cdef class BitSetIterator:
	cdef _BitSet* b
	cdef int next_val

	def __next__(self):
		cdef int ret_val

		if self.next_val == -1:
			raise StopIteration()
		ret_val = self.next_val
		self.next_val = bitset_findsucc(self.b, ret_val)
		return ret_val



# This is a Python wrapper class to give access to the
# (entirely C-implemented) _BitSet struct. 
# Very slow; use only for debugging
cdef class BitSet:

	cdef _BitSet* b

	def __cinit__(self):
		self.b = new_BitSet()

	def __dealloc__(self):
		free(self.b)

	def __iter__(self):
		cdef BitSetIterator it
		it = BitSetIterator()
		it.b = self.b
		it.next_val = self.b.min_val
		return it

	def insert(self, i):
		return bitset_insert(self.b, i)

	def findsucc(self, i):
		return bitset_findsucc(self.b, i)

	def __str__(self):
		return dec2bin(self.b.bitset)+"  ("+str(self.b.size)+","+str(self.b.min_val)+","+str(self.b.max_val)+")"

	def min(self):
		return self.b.min_val

	def max(self):
		return self.b.max_val

	def __len__(self):
		return self.b.size

	def __contains__(self, i):
		return bool(bitset_contains(self.b, i))


def dec2bin(i):
	cdef d
	result = ""
	for d from 0 <= d < MIN_BOTTOM_SIZE:
		if i & lower_mask[0] == 0:
			result = "0"+result
		else:
			result = "1"+result
		i = i >> 1
	return result


cdef _VEB* new_VEB(int n):
	cdef _VEB* veb
	cdef int num_bits, num_top_bits, i

	veb = <_VEB*> malloc(sizeof(_VEB))

	num_bits = int(ceil(log(n) / log(2)))
	veb.num_bottom_bits = num_bits/2
	if veb.num_bottom_bits < MIN_BOTTOM_BITS:
		veb.num_bottom_bits = MIN_BOTTOM_BITS
	veb.top_universe_size = (n >> veb.num_bottom_bits) + 1

	veb.bottom = <void**> malloc(veb.top_universe_size * sizeof(void*))
	memset(veb.bottom, 0, veb.top_universe_size * sizeof(void*))

	if veb.top_universe_size > MIN_BOTTOM_SIZE:
		veb.top = new_VEB(veb.top_universe_size)
	else:
		veb.top = new_BitSet()

	veb.max_val = -1
	veb.min_val = -1
	veb.size = 0
	return veb


cdef int VEB_insert(_VEB* veb, int i):
	cdef _VEB* subv
	cdef _BitSet* subb
	cdef int a, b, tmp

	if veb.size == 0:
		veb.min_val = i
		veb.max_val = i
	elif i == veb.min_val or i == veb.max_val:
		return 0
	else:
		if i < veb.min_val:
			tmp = i
			i = veb.min_val
			veb.min_val = tmp
		a = i >> veb.num_bottom_bits
		b = i & lower_mask[veb.num_bottom_bits-1]
		if veb.bottom[a] == NULL:
			if veb.top_universe_size > MIN_BOTTOM_SIZE:
				subv = <_VEB*> veb.top
				VEB_insert(subv, a)
			else:
				subb = <_BitSet*> veb.top
				bitset_insert(subb, a)
			if veb.num_bottom_bits > MIN_BOTTOM_BITS:
				veb.bottom[a] = new_VEB(1 << veb.num_bottom_bits)
			else:
				veb.bottom[a] = new_BitSet()
		if veb.num_bottom_bits > MIN_BOTTOM_BITS:
			subv = <_VEB*> veb.bottom[a]
			if VEB_insert(subv, b) == 0:
				return 0
		else:
			subb = <_BitSet*> veb.bottom[a]
			if bitset_insert(subb, b) == 0:
				return 0

		if i > veb.max_val:
			veb.max_val = i
	veb.size = veb.size + 1
	return 1


cdef del_VEB(_VEB* veb):
	cdef int i

	if veb.top_universe_size > MIN_BOTTOM_SIZE:
		i = (<_VEB*> veb.top).min_val
	else:
		i = (<_BitSet*> veb.top).min_val

	while i != -1:
		if veb.num_bottom_bits > MIN_BOTTOM_BITS:
			del_VEB(<_VEB*> veb.bottom[i])
		else:
			free(<_BitSet*> veb.bottom[i])

		if veb.top_universe_size > MIN_BOTTOM_SIZE:
			i = VEB_findsucc(<_VEB*> veb.top, i)
		else:
			i = bitset_findsucc(<_BitSet*> veb.top, i)

	if veb.top_universe_size > MIN_BOTTOM_SIZE:
		del_VEB(<_VEB*> veb.top)
	else:
		free(<_BitSet*> veb.top)
	free(veb.bottom)
	free(veb)


cdef int VEB_findsucc(_VEB* veb, int i):
	cdef _VEB* subv
	cdef _BitSet* subb
	cdef int a, b, j, c, found

	if veb.max_val == -1 or i>=veb.max_val:
		return -1
	if i < veb.min_val:
		return veb.min_val

	a = i >> veb.num_bottom_bits
	b = i & lower_mask[veb.num_bottom_bits-1]
	found = 0
	if veb.bottom[a] != NULL:
		if veb.num_bottom_bits > MIN_BOTTOM_BITS:
			subv = <_VEB*> veb.bottom[a]
			if subv.max_val > b:
				j = (a << veb.num_bottom_bits) + VEB_findsucc(subv, b)
				found = 1
		else:
			subb = <_BitSet*> veb.bottom[a]
			if subb.max_val > b:
				j = (a << veb.num_bottom_bits) + bitset_findsucc(subb, b)
				found = 1
	if found==0:
		if veb.top_universe_size > MIN_BOTTOM_SIZE:
			subv = <_VEB*> veb.top
			c = VEB_findsucc(subv, a)
		else:
			subb = <_BitSet*> veb.top
			c = bitset_findsucc(subb, a)
		if veb.num_bottom_bits > MIN_BOTTOM_BITS:
			subv = <_VEB*> veb.bottom[c]
			j = (c << veb.num_bottom_bits) + subv.min_val
		else:
			subb = <_BitSet*> veb.bottom[c]
			j = (c << veb.num_bottom_bits) + subb.min_val
	return j


cdef int VEB_contains(_VEB* veb, int i):
	cdef _VEB* subv
	cdef _BitSet* subb
	cdef int a, b

	if veb.size == 0 or i < veb.min_val or i > veb.max_val:
		return 0

	if veb.min_val == i:
		return 1
	else:
		if veb.size == 1:
			return 0

	a = i >> veb.num_bottom_bits
	b = i & lower_mask[veb.num_bottom_bits-1]
	if veb.bottom[a] == NULL:
		return 0
	else:
		if veb.num_bottom_bits > MIN_BOTTOM_BITS:
			subv = <_VEB*> veb.bottom[a]
			return VEB_contains(subv, b)
		else:
			subb = <_BitSet*> veb.bottom[a]
			return bitset_contains(subb, b)


cdef class VEBIterator:
	cdef _VEB* v
	cdef int next_val

	def __next__(self):
		cdef int ret_val

		if self.next_val == -1:
			raise StopIteration()
		ret_val = self.next_val
		self.next_val = VEB_findsucc(self.v, ret_val)
		return ret_val


cdef class VEB:

	def __init__(self, size):
		pass

	def __cinit__(self, int size):
		self.veb = new_VEB(size)

	def __dealloc__(self):
		del_VEB(self.veb)

	def __iter__(self):
		cdef VEBIterator it
		it = VEBIterator()
		it.v = self.veb
		it.next_val = self.veb.min_val
		return it

	def insert(self, i):
		return VEB_insert(self.veb, i)

	cdef int _insert(self, int i):
		return VEB_insert(self.veb, i)

	def findsucc(self, i):
		return VEB_findsucc(self.veb, i)

	cdef int _first(self):
		return self.veb.min_val

	cdef int _findsucc(self, int i):
		return VEB_findsucc(self.veb, i)

	def __len__(self):
		return self.veb.size

	def __contains__(self, i):
		return VEB_contains(self.veb, i)