diff options
Diffstat (limited to 'python/src/sa/bilex.pxi')
-rw-r--r-- | python/src/sa/bilex.pxi | 148 |
1 files changed, 53 insertions, 95 deletions
diff --git a/python/src/sa/bilex.pxi b/python/src/sa/bilex.pxi index 44bc0ce6..5e2fcd82 100644 --- a/python/src/sa/bilex.pxi +++ b/python/src/sa/bilex.pxi @@ -43,26 +43,26 @@ cdef int* get_val(_node* n, int key): return &(n.bigger.val) return get_val(n.bigger, key) +cdef int NULL_WORD = 0 cdef class BiLex: - cdef FloatList col1 - cdef FloatList col2 - cdef IntList f_index - cdef IntList e_index - cdef id2eword, id2fword, eword2id, fword2id + cdef FloatList col1, col2 + cdef IntList f_index, e_index + cdef Vocabulary f_voc, e_voc def __cinit__(self, from_text=None, from_data=False, from_binary=None, - earray=None, fsarray=None, alignment=None): - self.id2eword = [] - self.id2fword = [] - self.eword2id = {} - self.fword2id = {} + earray=None, fsarray=None, alignment=None, mmaped=False): + self.f_voc = Vocabulary() + self.e_voc = Vocabulary() self.e_index = IntList() self.f_index = IntList() self.col1 = FloatList() self.col2 = FloatList() if from_binary: - self.read_binary(from_binary) + if mmaped: + self.read_mmaped(MemoryMap(from_binary)) + else: + self.read_binary(from_binary) elif from_data: self.compute_from_data(fsarray, earray, alignment) else: @@ -74,25 +74,15 @@ cdef class BiLex: cdef int *fsent, *esent, *alignment, *links, *ealigned, *faligned cdef _node** dict cdef int *fmargin, *emargin, *count - cdef int null_word - null_word = 0 - for word in fsa.darray.id2word: # I miss list comprehensions - self.id2fword.append(word) - self.id2fword[null_word] = "NULL" - for id, word in enumerate(self.id2fword): - self.fword2id[word] = id - - for word in eda.id2word: - self.id2eword.append(word) - self.id2eword[null_word] = "NULL" - for id, word in enumerate(self.id2eword): - self.eword2id[word] = id + self.f_voc.extend(fsa.darray.voc) + self.e_voc.extend(eda.voc) + assert(self.f_voc['NULL'] == self.e_voc['NULL'] == NULL_WORD) num_pairs = 0 - V_E = len(eda.id2word) - V_F = len(fsa.darray.id2word) + V_E = len(eda.voc) + V_F = len(fsa.darray.voc) fmargin = <int*> 
malloc(V_F*sizeof(int)) emargin = <int*> malloc(V_E*sizeof(int)) memset(fmargin, 0, V_F*sizeof(int)) @@ -141,27 +131,27 @@ cdef class BiLex: if faligned[i] == 0: f_i = fsent[i] fmargin[f_i] = fmargin[f_i] + 1 - emargin[null_word] = emargin[null_word] + 1 + emargin[NULL_WORD] = emargin[NULL_WORD] + 1 if dict[f_i] == NULL: - dict[f_i] = new_node(null_word) + dict[f_i] = new_node(NULL_WORD) dict[f_i].val = 1 num_pairs = num_pairs + 1 else: - count = get_val(dict[f_i], null_word) + count = get_val(dict[f_i], NULL_WORD) if count[0] == 0: num_pairs = num_pairs + 1 count[0] = count[0] + 1 for j from 0 <= j < J: if ealigned[j] == 0: e_j = esent[j] - fmargin[null_word] = fmargin[null_word] + 1 + fmargin[NULL_WORD] = fmargin[NULL_WORD] + 1 emargin[e_j] = emargin[e_j] + 1 - if dict[null_word] == NULL: - dict[null_word] = new_node(e_j) - dict[null_word].val = 1 + if dict[NULL_WORD] == NULL: + dict[NULL_WORD] = new_node(e_j) + dict[NULL_WORD].val = 1 num_pairs = num_pairs + 1 else: - count = get_val(dict[null_word], e_j) + count = get_val(dict[NULL_WORD], e_j) if count[0] == 0: num_pairs = num_pairs + 1 count[0] = count[0] + 1 @@ -199,73 +189,43 @@ cdef class BiLex: self._add_node(n.bigger, num_pairs, fmargin, emargin) - def write_binary(self, char* filename): + def write_binary(self, bytes filename): cdef FILE* f f = fopen(filename, "w") self.f_index.write_handle(f) self.e_index.write_handle(f) self.col1.write_handle(f) self.col2.write_handle(f) - self.write_wordlist(self.id2fword, f) - self.write_wordlist(self.id2eword, f) + self.f_voc.write_handle(f) + self.e_voc.write_handle(f) fclose(f) - - cdef write_wordlist(self, wordlist, FILE* f): - cdef int word_len - cdef int num_words - - num_words = len(wordlist) - fwrite(&(num_words), sizeof(int), 1, f) - for word in wordlist: - word_len = len(word) + 1 - fwrite(&(word_len), sizeof(int), 1, f) - fwrite(<char *>word, sizeof(char), word_len, f) - - - cdef read_wordlist(self, word2id, id2word, FILE* f): - cdef int num_words - cdef 
int word_len - cdef char* word - - fread(&(num_words), sizeof(int), 1, f) - for i from 0 <= i < num_words: - fread(&(word_len), sizeof(int), 1, f) - word = <char*> malloc (word_len * sizeof(char)) - fread(word, sizeof(char), word_len, f) - word2id[word] = len(id2word) - id2word.append(word) - free(word) - - def read_binary(self, char* filename): + def read_binary(self, bytes filename): cdef FILE* f f = fopen(filename, "r") self.f_index.read_handle(f) self.e_index.read_handle(f) self.col1.read_handle(f) self.col2.read_handle(f) - self.read_wordlist(self.fword2id, self.id2fword, f) - self.read_wordlist(self.eword2id, self.id2eword, f) + self.f_voc.read_handle(f) + self.e_voc.read_handle(f) fclose(f) + def read_mmaped(self, MemoryMap buf): + self.f_index.read_mmaped(buf) + self.e_index.read_mmaped(buf) + self.col1.read_mmaped(buf) + self.col2.read_mmaped(buf) + self.f_voc.read_mmaped(buf) + self.e_voc.read_mmaped(buf) def get_e_id(self, eword): - if eword not in self.eword2id: - e_id = len(self.id2eword) - self.id2eword.append(eword) - self.eword2id[eword] = e_id - return self.eword2id[eword] - + return self.e_voc[eword] def get_f_id(self, fword): - if fword not in self.fword2id: - f_id = len(self.id2fword) - self.id2fword.append(fword) - self.fword2id[fword] = f_id - return self.fword2id[fword] + return self.f_voc[fword] - - def read_text(self, char* filename): + def read_text(self, bytes filename): cdef i, j, w, e_id, f_id, n_f, n_e, N cdef IntList fcount @@ -309,7 +269,7 @@ cdef class BiLex: for b from 0 <= b < n_f: i = self.f_index.arr[b] j = self.f_index.arr[b+1] - self.qsort(i,j, "") + self.qsort(i, j) cdef swap(self, int i, int j): @@ -332,7 +292,7 @@ cdef class BiLex: self.col2.arr[j] = ftmp - cdef qsort(self, int i, int j, pad): + cdef qsort(self, int i, int j): cdef int pval, p if i > j: @@ -351,11 +311,11 @@ cdef class BiLex: self.swap(p+1, k) self.swap(p, p+1) p = p + 1 - self.qsort(i,p, pad+" ") - self.qsort(p+1,j, pad+" ") + self.qsort(i, p) + 
self.qsort(p+1, j) - def write_enhanced(self, char* filename): + def write_enhanced(self, bytes filename): with open(filename, "w") as f: for i in self.f_index: f.write("%d " % i) @@ -363,10 +323,10 @@ cdef class BiLex: for i, s1, s2 in zip(self.e_index, self.col1, self.col2): f.write("%d %f %f " % (i, s1, s2)) f.write("\n") - for i, w in enumerate(self.id2fword): + for i, w in enumerate(self.f_voc.id2word): f.write("%d %s " % (i, w)) f.write("\n") - for i, w in enumerate(self.id2eword): + for i, w in enumerate(self.e_voc.id2word): f.write("%d %s " % (i, w)) f.write("\n") @@ -374,12 +334,10 @@ cdef class BiLex: def get_score(self, fword, eword, col): cdef e_id, f_id, low, high, midpoint, val - if eword not in self.eword2id: - return None - if fword not in self.fword2id: - return None - f_id = self.fword2id[fword] - e_id = self.eword2id[eword] + f_id = self.f_voc.get(fword, None) + e_id = self.e_voc.get(eword, None) + if f_id is None or e_id is None: return None + low = self.f_index.arr[f_id] high = self.f_index.arr[f_id+1] while high - low > 0: @@ -397,7 +355,7 @@ cdef class BiLex: return None - def write_text(self, char* filename): + def write_text(self, bytes filename): """Note: does not guarantee writing the dictionary in the original order""" cdef i, N, e_id, f_id @@ -410,4 +368,4 @@ cdef class BiLex: e_id = self.e_index.arr[i] score1 = self.col1.arr[i] score2 = self.col2.arr[i] - f.write("%s %s %.6f %.6f\n" % (self.id2fword[f_id], self.id2eword[e_id], score1, score2)) + f.write("%s %s %.6f %.6f\n" % (self.f_voc.id2word[f_id], self.e_voc.id2word[e_id], score1, score2)) |