diff options
-rw-r--r-- | klm/lm/bhiksha.hh | 2 | ||||
-rw-r--r-- | klm/lm/trie.cc | 6 | ||||
-rw-r--r-- | klm/lm/trie.hh | 7 | ||||
-rw-r--r-- | klm/lm/trie_sort.cc | 4 |
4 files changed, 10 insertions, 9 deletions
diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh index ff7fe452..bc705959 100644 --- a/klm/lm/bhiksha.hh +++ b/klm/lm/bhiksha.hh @@ -11,6 +11,7 @@ */ #include <inttypes.h> +#include <assert.h> #include "lm/model_type.hh" #include "lm/trie.hh" @@ -78,6 +79,7 @@ class ArrayBhiksha { util::ReadInt57(base, bit_offset, next_inline_.bits, next_inline_.mask); out.end = ((end_it - offset_begin_) << next_inline_.bits) | util::ReadInt57(base, bit_offset + total_bits, next_inline_.bits, next_inline_.mask); + //assert(out.end >= out.begin); } void WriteNext(void *base, uint64_t bit_offset, uint64_t index, uint64_t value) { diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index 4e60b184..20075bb8 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -91,16 +91,14 @@ template <class Quant, class Bhiksha> bool BitPackedMiddle<Quant, Bhiksha>::Find if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) { return false; } - uint64_t index = at_pointer; + pointer = at_pointer; at_pointer *= total_bits_; at_pointer += word_bits_; - pointer = at_pointer; - quant_.Read(base_, at_pointer, prob, backoff); at_pointer += quant_.TotalBits(); - bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range); + bhiksha_.ReadNext(base_, at_pointer, pointer, total_bits_, range); return true; } diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index a9f5e417..06cc96ac 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -99,10 +99,11 @@ template <class Quant, class Bhiksha> class BitPackedMiddle : public BitPacked { bool FindNoProb(WordIndex word, float &backoff, NodeRange &range) const; NodeRange ReadEntry(uint64_t pointer, float &prob) { - quant_.ReadProb(base_, pointer, prob); + uint64_t addr = pointer * total_bits_; + addr += word_bits_; + quant_.ReadProb(base_, addr, prob); NodeRange ret; - // pointer/total_bits_ should always round down. - bhiksha_.ReadNext(base_, pointer + quant_.TotalBits(), pointer / total_bits_, total_bits_, ret); + bhiksha_.ReadNext(base_, addr + quant_.TotalBits(), pointer, total_bits_, ret); return ret; } diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc index 01c4e490..86f28493 100644 --- a/klm/lm/trie_sort.cc +++ b/klm/lm/trie_sort.cc @@ -146,7 +146,7 @@ template <class Combine> void MergeSortedFiles(const std::string &first_name, co ++first; ++second; } } - for (RecordReader &remains = (first ? second : first); remains; ++remains) { + for (RecordReader &remains = (first ? first : second); remains; ++remains) { WriteOrThrow(out_file.get(), remains.Data(), entry_size); } } @@ -191,7 +191,7 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st assembled << file_prefix << static_cast<unsigned int>(order) << "_merge_" << (merge_count++); files.push_back(assembled.str()); MergeSortedFiles(files[0], files[1], files.back(), weights_size, order, ThrowCombine()); - MergeSortedFiles(files[0], files[1], files.back(), 0, order, FirstCombine()); + MergeSortedFiles(files[0] + kContextSuffix, files[1] + kContextSuffix, files.back() + kContextSuffix, 0, order, FirstCombine()); files.pop_front(); files.pop_front(); } |