summaryrefslogtreecommitdiff
path: root/python/src/sa/bilex.pxi
diff options
context:
space:
mode:
Diffstat (limited to 'python/src/sa/bilex.pxi')
-rw-r--r--python/src/sa/bilex.pxi148
1 files changed, 53 insertions, 95 deletions
diff --git a/python/src/sa/bilex.pxi b/python/src/sa/bilex.pxi
index 44bc0ce6..5e2fcd82 100644
--- a/python/src/sa/bilex.pxi
+++ b/python/src/sa/bilex.pxi
@@ -43,26 +43,26 @@ cdef int* get_val(_node* n, int key):
return &(n.bigger.val)
return get_val(n.bigger, key)
+cdef int NULL_WORD = 0
cdef class BiLex:
- cdef FloatList col1
- cdef FloatList col2
- cdef IntList f_index
- cdef IntList e_index
- cdef id2eword, id2fword, eword2id, fword2id
+ cdef FloatList col1, col2
+ cdef IntList f_index, e_index
+ cdef Vocabulary f_voc, e_voc
def __cinit__(self, from_text=None, from_data=False, from_binary=None,
- earray=None, fsarray=None, alignment=None):
- self.id2eword = []
- self.id2fword = []
- self.eword2id = {}
- self.fword2id = {}
+ earray=None, fsarray=None, alignment=None, mmaped=False):
+ self.f_voc = Vocabulary()
+ self.e_voc = Vocabulary()
self.e_index = IntList()
self.f_index = IntList()
self.col1 = FloatList()
self.col2 = FloatList()
if from_binary:
- self.read_binary(from_binary)
+ if mmaped:
+ self.read_mmaped(MemoryMap(from_binary))
+ else:
+ self.read_binary(from_binary)
elif from_data:
self.compute_from_data(fsarray, earray, alignment)
else:
@@ -74,25 +74,15 @@ cdef class BiLex:
cdef int *fsent, *esent, *alignment, *links, *ealigned, *faligned
cdef _node** dict
cdef int *fmargin, *emargin, *count
- cdef int null_word
- null_word = 0
- for word in fsa.darray.id2word: # I miss list comprehensions
- self.id2fword.append(word)
- self.id2fword[null_word] = "NULL"
- for id, word in enumerate(self.id2fword):
- self.fword2id[word] = id
-
- for word in eda.id2word:
- self.id2eword.append(word)
- self.id2eword[null_word] = "NULL"
- for id, word in enumerate(self.id2eword):
- self.eword2id[word] = id
+ self.f_voc.extend(fsa.darray.voc)
+ self.e_voc.extend(eda.voc)
+ assert(self.f_voc['NULL'] == self.e_voc['NULL'] == NULL_WORD)
num_pairs = 0
- V_E = len(eda.id2word)
- V_F = len(fsa.darray.id2word)
+ V_E = len(eda.voc)
+ V_F = len(fsa.darray.voc)
fmargin = <int*> malloc(V_F*sizeof(int))
emargin = <int*> malloc(V_E*sizeof(int))
memset(fmargin, 0, V_F*sizeof(int))
@@ -141,27 +131,27 @@ cdef class BiLex:
if faligned[i] == 0:
f_i = fsent[i]
fmargin[f_i] = fmargin[f_i] + 1
- emargin[null_word] = emargin[null_word] + 1
+ emargin[NULL_WORD] = emargin[NULL_WORD] + 1
if dict[f_i] == NULL:
- dict[f_i] = new_node(null_word)
+ dict[f_i] = new_node(NULL_WORD)
dict[f_i].val = 1
num_pairs = num_pairs + 1
else:
- count = get_val(dict[f_i], null_word)
+ count = get_val(dict[f_i], NULL_WORD)
if count[0] == 0:
num_pairs = num_pairs + 1
count[0] = count[0] + 1
for j from 0 <= j < J:
if ealigned[j] == 0:
e_j = esent[j]
- fmargin[null_word] = fmargin[null_word] + 1
+ fmargin[NULL_WORD] = fmargin[NULL_WORD] + 1
emargin[e_j] = emargin[e_j] + 1
- if dict[null_word] == NULL:
- dict[null_word] = new_node(e_j)
- dict[null_word].val = 1
+ if dict[NULL_WORD] == NULL:
+ dict[NULL_WORD] = new_node(e_j)
+ dict[NULL_WORD].val = 1
num_pairs = num_pairs + 1
else:
- count = get_val(dict[null_word], e_j)
+ count = get_val(dict[NULL_WORD], e_j)
if count[0] == 0:
num_pairs = num_pairs + 1
count[0] = count[0] + 1
@@ -199,73 +189,43 @@ cdef class BiLex:
self._add_node(n.bigger, num_pairs, fmargin, emargin)
- def write_binary(self, char* filename):
+ def write_binary(self, bytes filename):
cdef FILE* f
f = fopen(filename, "w")
self.f_index.write_handle(f)
self.e_index.write_handle(f)
self.col1.write_handle(f)
self.col2.write_handle(f)
- self.write_wordlist(self.id2fword, f)
- self.write_wordlist(self.id2eword, f)
+ self.f_voc.write_handle(f)
+ self.e_voc.write_handle(f)
fclose(f)
-
- cdef write_wordlist(self, wordlist, FILE* f):
- cdef int word_len
- cdef int num_words
-
- num_words = len(wordlist)
- fwrite(&(num_words), sizeof(int), 1, f)
- for word in wordlist:
- word_len = len(word) + 1
- fwrite(&(word_len), sizeof(int), 1, f)
- fwrite(<char *>word, sizeof(char), word_len, f)
-
-
- cdef read_wordlist(self, word2id, id2word, FILE* f):
- cdef int num_words
- cdef int word_len
- cdef char* word
-
- fread(&(num_words), sizeof(int), 1, f)
- for i from 0 <= i < num_words:
- fread(&(word_len), sizeof(int), 1, f)
- word = <char*> malloc (word_len * sizeof(char))
- fread(word, sizeof(char), word_len, f)
- word2id[word] = len(id2word)
- id2word.append(word)
- free(word)
-
- def read_binary(self, char* filename):
+ def read_binary(self, bytes filename):
cdef FILE* f
f = fopen(filename, "r")
self.f_index.read_handle(f)
self.e_index.read_handle(f)
self.col1.read_handle(f)
self.col2.read_handle(f)
- self.read_wordlist(self.fword2id, self.id2fword, f)
- self.read_wordlist(self.eword2id, self.id2eword, f)
+ self.f_voc.read_handle(f)
+ self.e_voc.read_handle(f)
fclose(f)
+ def read_mmaped(self, MemoryMap buf):
+ self.f_index.read_mmaped(buf)
+ self.e_index.read_mmaped(buf)
+ self.col1.read_mmaped(buf)
+ self.col2.read_mmaped(buf)
+ self.f_voc.read_mmaped(buf)
+ self.e_voc.read_mmaped(buf)
def get_e_id(self, eword):
- if eword not in self.eword2id:
- e_id = len(self.id2eword)
- self.id2eword.append(eword)
- self.eword2id[eword] = e_id
- return self.eword2id[eword]
-
+ return self.e_voc[eword]
def get_f_id(self, fword):
- if fword not in self.fword2id:
- f_id = len(self.id2fword)
- self.id2fword.append(fword)
- self.fword2id[fword] = f_id
- return self.fword2id[fword]
+ return self.f_voc[fword]
-
- def read_text(self, char* filename):
+ def read_text(self, bytes filename):
cdef i, j, w, e_id, f_id, n_f, n_e, N
cdef IntList fcount
@@ -309,7 +269,7 @@ cdef class BiLex:
for b from 0 <= b < n_f:
i = self.f_index.arr[b]
j = self.f_index.arr[b+1]
- self.qsort(i,j, "")
+ self.qsort(i, j)
cdef swap(self, int i, int j):
@@ -332,7 +292,7 @@ cdef class BiLex:
self.col2.arr[j] = ftmp
- cdef qsort(self, int i, int j, pad):
+ cdef qsort(self, int i, int j):
cdef int pval, p
if i > j:
@@ -351,11 +311,11 @@ cdef class BiLex:
self.swap(p+1, k)
self.swap(p, p+1)
p = p + 1
- self.qsort(i,p, pad+" ")
- self.qsort(p+1,j, pad+" ")
+ self.qsort(i, p)
+ self.qsort(p+1, j)
- def write_enhanced(self, char* filename):
+ def write_enhanced(self, bytes filename):
with open(filename, "w") as f:
for i in self.f_index:
f.write("%d " % i)
@@ -363,10 +323,10 @@ cdef class BiLex:
for i, s1, s2 in zip(self.e_index, self.col1, self.col2):
f.write("%d %f %f " % (i, s1, s2))
f.write("\n")
- for i, w in enumerate(self.id2fword):
+ for i, w in enumerate(self.f_voc.id2word):
f.write("%d %s " % (i, w))
f.write("\n")
- for i, w in enumerate(self.id2eword):
+ for i, w in enumerate(self.f_voc.id2word):
f.write("%d %s " % (i, w))
f.write("\n")
@@ -374,12 +334,10 @@ cdef class BiLex:
def get_score(self, fword, eword, col):
cdef e_id, f_id, low, high, midpoint, val
- if eword not in self.eword2id:
- return None
- if fword not in self.fword2id:
- return None
- f_id = self.fword2id[fword]
- e_id = self.eword2id[eword]
+ f_id = self.f_voc.get(fword, None)
+ e_id = self.e_voc.get(eword, None)
+ if f_id is None or e_id is None: return None
+
low = self.f_index.arr[f_id]
high = self.f_index.arr[f_id+1]
while high - low > 0:
@@ -397,7 +355,7 @@ cdef class BiLex:
return None
- def write_text(self, char* filename):
+ def write_text(self, bytes filename):
"""Note: does not guarantee writing the dictionary in the original order"""
cdef i, N, e_id, f_id
@@ -410,4 +368,4 @@ cdef class BiLex:
e_id = self.e_index.arr[i]
score1 = self.col1.arr[i]
score2 = self.col2.arr[i]
- f.write("%s %s %.6f %.6f\n" % (self.id2fword[f_id], self.id2eword[e_id], score1, score2))
+ f.write("%s %s %.6f %.6f\n" % (self.f_voc.id2word[f_id], self.e_voc.id2word[e_id], score1, score2))