summaryrefslogtreecommitdiff
path: root/sa-extract/csuf.pyx
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2012-08-03 07:46:54 -0400
committerKenneth Heafield <github@kheafield.com>2012-08-03 07:46:54 -0400
commitbe1ab0a8937f9c5668ea5e6c31b798e87672e55e (patch)
treea13aad60ab6cced213401bce6a38ac885ba171ba /sa-extract/csuf.pyx
parente5d6f4ae41009c26978ecd62668501af9762b0bc (diff)
parent9fe0219562e5db25171cce8776381600ff9a5649 (diff)
Merge branch 'master' of github.com:redpony/cdec
Diffstat (limited to 'sa-extract/csuf.pyx')
-rw-r--r--sa-extract/csuf.pyx321
1 files changed, 0 insertions, 321 deletions
diff --git a/sa-extract/csuf.pyx b/sa-extract/csuf.pyx
deleted file mode 100644
index 64c44788..00000000
--- a/sa-extract/csuf.pyx
+++ /dev/null
@@ -1,321 +0,0 @@
-# csuf.pyx
-# Defines suffix arrays that can be directly written to/read from disk in binary format
-# Adam Lopez <alopez@cs.umd.edu>
-
-import sys
-import log
-import cdat
-import cintlist
-import monitor
-
-from libc.stdio cimport FILE, fclose, fopen
-
-cdef class SuffixArray:
-
- def __init__(self, filename, from_binary=False):
- self.darray = cdat.DataArray()
- self.sa = cintlist.CIntList()
- self.ha = cintlist.CIntList()
- if from_binary:
- self.read_binary(filename)
- else:
- self.read_text(filename)
-
-
- def __getitem__(self, i):
- return self.sa.arr[i]
-
-
- def getSentId(self, i):
- return self.darray.getSentId(i)
-
-
- def getSent(self, i):
- return self.darray.getSent(i)
-
-
- def getSentPos(self, loc):
- return self.darray.getSentPos(loc)
-
- def read_text(self, filename):
- '''Constructs suffix array using the algorithm
- of Larsson & Sadahkane (1999)'''
- cdef int V, N, i, j, h, a_i, n, current_run, skip
- cdef cintlist.CIntList isa, word_count
-
- self.darray = cdat.DataArray(filename, from_binary=False, use_sent_id=True)
- N = len(self.darray)
- V = len(self.darray.id2word)
-
- self.sa = cintlist.CIntList(initial_len=N)
- self.ha = cintlist.CIntList(initial_len=V+1)
-
- isa = cintlist.CIntList(initial_len=N)
- word_count = cintlist.CIntList(initial_len=V+1)
-
- '''Step 1: bucket sort data'''
- sort_start_time = monitor.cpu()
- start_time = sort_start_time
- for i from 0 <= i < N:
- a_i = self.darray.data.arr[i]
- word_count.arr[a_i] = word_count.arr[a_i] + 1
-
- n = 0
- for i from 0 <= i < V+1:
- self.ha.arr[i] = n
- n = n + word_count.arr[i]
- word_count.arr[i] = 0
-
- for i from 0 <= i < N:
- a_i = self.darray.data.arr[i]
- self.sa.arr[self.ha.arr[a_i] + word_count.arr[a_i]] = i
- isa.arr[i] = self.ha.arr[a_i + 1] - 1 # bucket pointer is last index in bucket
- word_count.arr[a_i] = word_count.arr[a_i] + 1
-
- '''Determine size of initial runs'''
- current_run = 0
- for i from 0 <= i < V+1:
- if i < V and self.ha.arr[i+1] - self.ha.arr[i] == 1:
- current_run = current_run + 1
- else:
- if current_run > 0:
- self.sa.arr[self.ha.arr[i] - current_run] = -current_run
- current_run = 0
-
- sys.stderr.write(" Bucket sort took %f seconds\n" % (monitor.cpu() - sort_start_time))
-
- '''Step 2: prefix-doubling sort'''
- h = 1
- while self.sa.arr[0] != -N:
- sort_start_time = monitor.cpu()
- sys.stderr.write(" Refining, sort depth = %d\n" % (h,))
- i = 0
- skip = 0
- while i < N:
- if self.sa.arr[i] < 0:
- #sys.stderr.write("Skip from %d to %d\n" % (i, i-self.sa.arr[i]-1))
- skip = skip + self.sa.arr[i]
- i = i - self.sa.arr[i]
- else:
- if skip < 0:
- self.sa.arr[i+skip] = skip
- skip = 0
- j = isa.arr[self.sa.arr[i]]
- #sys.stderr.write("Process from %d to %d (%d, %d, %d)\n" % (i, j, self.sa.arr[i], self.darray.data.arr[self.sa.arr[i]], isa.arr[self.sa.arr[i]]))
- self.q3sort(i, j, h, isa)
- i = j+1
- if skip < 0:
- self.sa.arr[i+skip] = skip
- h = h * 2
- sys.stderr.write(" Refinement took %f seconds\n" % (monitor.cpu() - sort_start_time))
-
- '''Step 3: read off suffix array from inverse suffix array'''
- sys.stderr.write(" Finalizing sort...\n")
- for i from 0 <= i < N:
- j = isa.arr[i]
- self.sa.arr[j] = i
- sys.stderr.write("Suffix array construction took %f seconds\n" % (monitor.cpu() - start_time))
-
- def q3sort(self, int i, int j, int h, cintlist.CIntList isa, pad=""):
- '''This is a ternary quicksort. It divides the array into
- three partitions: items less than the pivot, items equal
- to pivot, and items greater than pivot. The first and last
- of these partitions are then recursively sorted'''
- cdef int k, midpoint, pval, phead, ptail, tmp
-
- if j-i < -1:
- raise Exception("Unexpected condition found in q3sort: sort from %d to %d" % (i,j))
- if j-i == -1: # recursive base case -- empty interval
- return
- if (j-i == 0): # recursive base case -- singleton interval
- isa.arr[self.sa.arr[i]] = i
- self.sa.arr[i] = -1
- return
-
- # NOTE: choosing the first item as a pivot value resulted in
- # stack overflow for some very large buckets. I think there
- # is a natural bias towards order due the way the word ids are
- # assigned; thus this resulted in the range to the left of the
- # pivot being nearly empty. Therefore, choose the middle item.
- # If the method of assigning word_id's is changed, this method
- # may need to be reconsidered as well.
- midpoint = (i+j)/2
- pval = isa.arr[self.sa.arr[midpoint] + h]
- if i != midpoint:
- tmp = self.sa.arr[midpoint]
- self.sa.arr[midpoint] = self.sa.arr[i]
- self.sa.arr[i] = tmp
- phead = i
- ptail = i
-
- # find the three partitions. phead marks the first element
- # of the middle partition, and ptail marks the last element
- for k from i+1 <= k < j+1:
- if isa.arr[self.sa.arr[k] + h] < pval:
- if k > ptail+1:
- tmp = self.sa.arr[phead]
- self.sa.arr[phead] = self.sa.arr[k]
- self.sa.arr[k] = self.sa.arr[ptail+1]
- self.sa.arr[ptail+1] = tmp
- else: # k == ptail+1
- tmp = self.sa.arr[phead]
- self.sa.arr[phead] = self.sa.arr[k]
- self.sa.arr[k] = tmp
- phead = phead + 1
- ptail = ptail + 1
- else:
- if isa.arr[self.sa.arr[k] + h] == pval:
- if k > ptail+1:
- tmp = self.sa.arr[ptail+1]
- self.sa.arr[ptail+1] = self.sa.arr[k]
- self.sa.arr[k] = tmp
- ptail = ptail + 1
-
- # recursively sort smaller suffixes
- self.q3sort(i, phead-1, h, isa, pad+" ")
-
- # update suffixes with pivot value
- # corresponds to update_group function in Larsson & Sadakane
- for k from phead <= k < ptail+1:
- isa.arr[self.sa.arr[k]] = ptail
- if phead == ptail:
- self.sa.arr[phead] = -1
-
- # recursively sort larger suffixes
- self.q3sort(ptail+1, j, h, isa, pad+" ")
-
-
- def write_text(self, filename):
- self.darray.write_text(filename)
-
-
- def read_binary(self, filename):
- cdef FILE *f
- cdef bytes bfilename = filename
- cdef char* cfilename = bfilename
- f = fopen(cfilename, "r")
- self.darray.read_handle(f)
- self.sa.read_handle(f)
- self.ha.read_handle(f)
- fclose(f)
-
-
- def write_binary(self, filename):
- cdef FILE* f
- cdef bytes bfilename = filename
- cdef char* cfilename = bfilename
- f = fopen(cfilename, "w")
- self.darray.write_handle(f)
- self.sa.write_handle(f)
- self.ha.write_handle(f)
- fclose(f)
-
-
- def write_enhanced(self, filename):
- f = open(filename, "w")
- self.darray.write_enhanced_handle(f)
- for a_i in self.sa:
- f.write("%d " % a_i)
- f.write("\n")
- for w_i in self.ha:
- f.write("%d " % w_i)
- f.write("\n")
- f.close()
-
-
- cdef __search_high(self, word_id, offset, low, high):
- cdef int midpoint
-
- if low >= high:
- return high
- midpoint = (high + low) / 2
- if self.darray.data.arr[self.sa.arr[midpoint] + offset] == word_id:
- return self.__search_high(word_id, offset, midpoint+1, high)
- else:
- return self.__search_high(word_id, offset, low, midpoint)
-
-
- cdef __search_low(self, int word_id, int offset, int low, int high):
- cdef int midpoint
-
- if low >= high:
- return high
- midpoint = (high + low) / 2
- if self.darray.data.arr[self.sa.arr[midpoint] + offset] == word_id:
- return self.__search_low(word_id, offset, low, midpoint)
- else:
- return self.__search_low(word_id, offset, midpoint+1, high)
-
-
- cdef __get_range(self, int word_id, int offset, int low, int high, int midpoint):
- return (self.__search_low(word_id, offset, low, midpoint),
- self.__search_high(word_id, offset, midpoint, high))
-
-
- cdef __lookup_helper(self, int word_id, int offset, int low, int high):
- cdef int midpoint
-
- if offset == 0:
- return (self.ha.arr[word_id], self.ha.arr[word_id+1])
- if low >= high:
- return None
-
- midpoint = (high + low) / 2
- if self.darray.data.arr[self.sa.arr[midpoint] + offset] == word_id:
- return self.__get_range(word_id, offset, low, high, midpoint)
- if self.darray.data.arr[self.sa.arr[midpoint] + offset] > word_id:
- return self.__lookup_helper(word_id, offset, low, midpoint)
- else:
- return self.__lookup_helper(word_id, offset, midpoint+1, high)
-
-
- def lookup(self, word, offset, int low, int high):
- if low == -1:
- low = 0
- if high == -1:
- high = len(self.sa)
- if word in self.darray.word2id:
- word_id = self.darray.word2id[word]
- return self.__lookup_helper(word_id, offset, low, high)
- else:
- return None
-
-
-
- def print_sa(self, isa):
- '''Slow; Use only in case of emergency'''
- cdef int i, j, k, N
- cdef cintlist.CIntList tmp_sa
-
- N = len(self.sa)
- for i from 0 <= i < N:
- sys.stderr.write("%2d " % i)
- sys.stderr.write("\n")
- for i from 0 <= i < N:
- sys.stderr.write("%2d " % self.darray.data.arr[i])
- sys.stderr.write("\n")
- for i from 0 <= i < N:
- sys.stderr.write("%2d " % isa.arr[i])
- sys.stderr.write("\n\n\n")
-
- # Recover partially sorted array
- tmp_sa = cintlist.CIntList(initial_len=N)
- for i from 0 <= i < N:
- j = isa.arr[i]
- tmp_sa.arr[j] = i
- for i from 0 <= i < N:
- if self.sa.arr[i] > 0:
- tmp_sa.arr[i] = self.sa.arr[i]
-
- for i from 0 <= i < N:
- j = tmp_sa.arr[i]
- sys.stderr.write("%2d %2d | " % (i, self.sa.arr[i]))
- for k from j <= k < N:
- sys.stderr.write("%2d " % self.darray.data.arr[k])
- sys.stderr.write("\n")
- sys.stderr.write("\n")
-
-
-
-
-