summaryrefslogtreecommitdiff
path: root/sa-extract/csuf.pyx
diff options
context:
space:
mode:
Diffstat (limited to 'sa-extract/csuf.pyx')
-rw-r--r--sa-extract/csuf.pyx321
1 files changed, 321 insertions, 0 deletions
diff --git a/sa-extract/csuf.pyx b/sa-extract/csuf.pyx
new file mode 100644
index 00000000..64c44788
--- /dev/null
+++ b/sa-extract/csuf.pyx
@@ -0,0 +1,321 @@
+# csuf.pyx
+# Defines suffix arrays that can be directly written to/read from disk in binary format
+# Adam Lopez <alopez@cs.umd.edu>
+
+import sys
+import log
+import cdat
+import cintlist
+import monitor
+
+from libc.stdio cimport FILE, fclose, fopen
+
+cdef class SuffixArray:
+
+ def __init__(self, filename, from_binary=False):
+ self.darray = cdat.DataArray()
+ self.sa = cintlist.CIntList()
+ self.ha = cintlist.CIntList()
+ if from_binary:
+ self.read_binary(filename)
+ else:
+ self.read_text(filename)
+
+
+ def __getitem__(self, i):
+ return self.sa.arr[i]
+
+
+ def getSentId(self, i):
+ return self.darray.getSentId(i)
+
+
+ def getSent(self, i):
+ return self.darray.getSent(i)
+
+
+ def getSentPos(self, loc):
+ return self.darray.getSentPos(loc)
+
+ def read_text(self, filename):
+ '''Constructs suffix array using the algorithm
+ of Larsson & Sadahkane (1999)'''
+ cdef int V, N, i, j, h, a_i, n, current_run, skip
+ cdef cintlist.CIntList isa, word_count
+
+ self.darray = cdat.DataArray(filename, from_binary=False, use_sent_id=True)
+ N = len(self.darray)
+ V = len(self.darray.id2word)
+
+ self.sa = cintlist.CIntList(initial_len=N)
+ self.ha = cintlist.CIntList(initial_len=V+1)
+
+ isa = cintlist.CIntList(initial_len=N)
+ word_count = cintlist.CIntList(initial_len=V+1)
+
+ '''Step 1: bucket sort data'''
+ sort_start_time = monitor.cpu()
+ start_time = sort_start_time
+ for i from 0 <= i < N:
+ a_i = self.darray.data.arr[i]
+ word_count.arr[a_i] = word_count.arr[a_i] + 1
+
+ n = 0
+ for i from 0 <= i < V+1:
+ self.ha.arr[i] = n
+ n = n + word_count.arr[i]
+ word_count.arr[i] = 0
+
+ for i from 0 <= i < N:
+ a_i = self.darray.data.arr[i]
+ self.sa.arr[self.ha.arr[a_i] + word_count.arr[a_i]] = i
+ isa.arr[i] = self.ha.arr[a_i + 1] - 1 # bucket pointer is last index in bucket
+ word_count.arr[a_i] = word_count.arr[a_i] + 1
+
+ '''Determine size of initial runs'''
+ current_run = 0
+ for i from 0 <= i < V+1:
+ if i < V and self.ha.arr[i+1] - self.ha.arr[i] == 1:
+ current_run = current_run + 1
+ else:
+ if current_run > 0:
+ self.sa.arr[self.ha.arr[i] - current_run] = -current_run
+ current_run = 0
+
+ sys.stderr.write(" Bucket sort took %f seconds\n" % (monitor.cpu() - sort_start_time))
+
+ '''Step 2: prefix-doubling sort'''
+ h = 1
+ while self.sa.arr[0] != -N:
+ sort_start_time = monitor.cpu()
+ sys.stderr.write(" Refining, sort depth = %d\n" % (h,))
+ i = 0
+ skip = 0
+ while i < N:
+ if self.sa.arr[i] < 0:
+ #sys.stderr.write("Skip from %d to %d\n" % (i, i-self.sa.arr[i]-1))
+ skip = skip + self.sa.arr[i]
+ i = i - self.sa.arr[i]
+ else:
+ if skip < 0:
+ self.sa.arr[i+skip] = skip
+ skip = 0
+ j = isa.arr[self.sa.arr[i]]
+ #sys.stderr.write("Process from %d to %d (%d, %d, %d)\n" % (i, j, self.sa.arr[i], self.darray.data.arr[self.sa.arr[i]], isa.arr[self.sa.arr[i]]))
+ self.q3sort(i, j, h, isa)
+ i = j+1
+ if skip < 0:
+ self.sa.arr[i+skip] = skip
+ h = h * 2
+ sys.stderr.write(" Refinement took %f seconds\n" % (monitor.cpu() - sort_start_time))
+
+ '''Step 3: read off suffix array from inverse suffix array'''
+ sys.stderr.write(" Finalizing sort...\n")
+ for i from 0 <= i < N:
+ j = isa.arr[i]
+ self.sa.arr[j] = i
+ sys.stderr.write("Suffix array construction took %f seconds\n" % (monitor.cpu() - start_time))
+
+ def q3sort(self, int i, int j, int h, cintlist.CIntList isa, pad=""):
+ '''This is a ternary quicksort. It divides the array into
+ three partitions: items less than the pivot, items equal
+ to pivot, and items greater than pivot. The first and last
+ of these partitions are then recursively sorted'''
+ cdef int k, midpoint, pval, phead, ptail, tmp
+
+ if j-i < -1:
+ raise Exception("Unexpected condition found in q3sort: sort from %d to %d" % (i,j))
+ if j-i == -1: # recursive base case -- empty interval
+ return
+ if (j-i == 0): # recursive base case -- singleton interval
+ isa.arr[self.sa.arr[i]] = i
+ self.sa.arr[i] = -1
+ return
+
+ # NOTE: choosing the first item as a pivot value resulted in
+ # stack overflow for some very large buckets. I think there
+ # is a natural bias towards order due the way the word ids are
+ # assigned; thus this resulted in the range to the left of the
+ # pivot being nearly empty. Therefore, choose the middle item.
+ # If the method of assigning word_id's is changed, this method
+ # may need to be reconsidered as well.
+ midpoint = (i+j)/2
+ pval = isa.arr[self.sa.arr[midpoint] + h]
+ if i != midpoint:
+ tmp = self.sa.arr[midpoint]
+ self.sa.arr[midpoint] = self.sa.arr[i]
+ self.sa.arr[i] = tmp
+ phead = i
+ ptail = i
+
+ # find the three partitions. phead marks the first element
+ # of the middle partition, and ptail marks the last element
+ for k from i+1 <= k < j+1:
+ if isa.arr[self.sa.arr[k] + h] < pval:
+ if k > ptail+1:
+ tmp = self.sa.arr[phead]
+ self.sa.arr[phead] = self.sa.arr[k]
+ self.sa.arr[k] = self.sa.arr[ptail+1]
+ self.sa.arr[ptail+1] = tmp
+ else: # k == ptail+1
+ tmp = self.sa.arr[phead]
+ self.sa.arr[phead] = self.sa.arr[k]
+ self.sa.arr[k] = tmp
+ phead = phead + 1
+ ptail = ptail + 1
+ else:
+ if isa.arr[self.sa.arr[k] + h] == pval:
+ if k > ptail+1:
+ tmp = self.sa.arr[ptail+1]
+ self.sa.arr[ptail+1] = self.sa.arr[k]
+ self.sa.arr[k] = tmp
+ ptail = ptail + 1
+
+ # recursively sort smaller suffixes
+ self.q3sort(i, phead-1, h, isa, pad+" ")
+
+ # update suffixes with pivot value
+ # corresponds to update_group function in Larsson & Sadakane
+ for k from phead <= k < ptail+1:
+ isa.arr[self.sa.arr[k]] = ptail
+ if phead == ptail:
+ self.sa.arr[phead] = -1
+
+ # recursively sort larger suffixes
+ self.q3sort(ptail+1, j, h, isa, pad+" ")
+
+
+ def write_text(self, filename):
+ self.darray.write_text(filename)
+
+
+ def read_binary(self, filename):
+ cdef FILE *f
+ cdef bytes bfilename = filename
+ cdef char* cfilename = bfilename
+ f = fopen(cfilename, "r")
+ self.darray.read_handle(f)
+ self.sa.read_handle(f)
+ self.ha.read_handle(f)
+ fclose(f)
+
+
+ def write_binary(self, filename):
+ cdef FILE* f
+ cdef bytes bfilename = filename
+ cdef char* cfilename = bfilename
+ f = fopen(cfilename, "w")
+ self.darray.write_handle(f)
+ self.sa.write_handle(f)
+ self.ha.write_handle(f)
+ fclose(f)
+
+
+ def write_enhanced(self, filename):
+ f = open(filename, "w")
+ self.darray.write_enhanced_handle(f)
+ for a_i in self.sa:
+ f.write("%d " % a_i)
+ f.write("\n")
+ for w_i in self.ha:
+ f.write("%d " % w_i)
+ f.write("\n")
+ f.close()
+
+
+ cdef __search_high(self, word_id, offset, low, high):
+ cdef int midpoint
+
+ if low >= high:
+ return high
+ midpoint = (high + low) / 2
+ if self.darray.data.arr[self.sa.arr[midpoint] + offset] == word_id:
+ return self.__search_high(word_id, offset, midpoint+1, high)
+ else:
+ return self.__search_high(word_id, offset, low, midpoint)
+
+
+ cdef __search_low(self, int word_id, int offset, int low, int high):
+ cdef int midpoint
+
+ if low >= high:
+ return high
+ midpoint = (high + low) / 2
+ if self.darray.data.arr[self.sa.arr[midpoint] + offset] == word_id:
+ return self.__search_low(word_id, offset, low, midpoint)
+ else:
+ return self.__search_low(word_id, offset, midpoint+1, high)
+
+
+ cdef __get_range(self, int word_id, int offset, int low, int high, int midpoint):
+ return (self.__search_low(word_id, offset, low, midpoint),
+ self.__search_high(word_id, offset, midpoint, high))
+
+
+ cdef __lookup_helper(self, int word_id, int offset, int low, int high):
+ cdef int midpoint
+
+ if offset == 0:
+ return (self.ha.arr[word_id], self.ha.arr[word_id+1])
+ if low >= high:
+ return None
+
+ midpoint = (high + low) / 2
+ if self.darray.data.arr[self.sa.arr[midpoint] + offset] == word_id:
+ return self.__get_range(word_id, offset, low, high, midpoint)
+ if self.darray.data.arr[self.sa.arr[midpoint] + offset] > word_id:
+ return self.__lookup_helper(word_id, offset, low, midpoint)
+ else:
+ return self.__lookup_helper(word_id, offset, midpoint+1, high)
+
+
+ def lookup(self, word, offset, int low, int high):
+ if low == -1:
+ low = 0
+ if high == -1:
+ high = len(self.sa)
+ if word in self.darray.word2id:
+ word_id = self.darray.word2id[word]
+ return self.__lookup_helper(word_id, offset, low, high)
+ else:
+ return None
+
+
+
+ def print_sa(self, isa):
+ '''Slow; Use only in case of emergency'''
+ cdef int i, j, k, N
+ cdef cintlist.CIntList tmp_sa
+
+ N = len(self.sa)
+ for i from 0 <= i < N:
+ sys.stderr.write("%2d " % i)
+ sys.stderr.write("\n")
+ for i from 0 <= i < N:
+ sys.stderr.write("%2d " % self.darray.data.arr[i])
+ sys.stderr.write("\n")
+ for i from 0 <= i < N:
+ sys.stderr.write("%2d " % isa.arr[i])
+ sys.stderr.write("\n\n\n")
+
+ # Recover partially sorted array
+ tmp_sa = cintlist.CIntList(initial_len=N)
+ for i from 0 <= i < N:
+ j = isa.arr[i]
+ tmp_sa.arr[j] = i
+ for i from 0 <= i < N:
+ if self.sa.arr[i] > 0:
+ tmp_sa.arr[i] = self.sa.arr[i]
+
+ for i from 0 <= i < N:
+ j = tmp_sa.arr[i]
+ sys.stderr.write("%2d %2d | " % (i, self.sa.arr[i]))
+ for k from j <= k < N:
+ sys.stderr.write("%2d " % self.darray.data.arr[k])
+ sys.stderr.write("\n")
+ sys.stderr.write("\n")
+
+
+
+
+