diff options
Diffstat (limited to 'klm/lm/trie.cc')
-rw-r--r-- | klm/lm/trie.cc | 61 |
1 files changed, 20 insertions, 41 deletions
diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index 20075bb8..0f1ca574 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -1,7 +1,6 @@ #include "lm/trie.hh" #include "lm/bhiksha.hh" -#include "lm/quantize.hh" #include "util/bit_packing.hh" #include "util/exception.hh" #include "util/sorted_uniform.hh" @@ -58,91 +57,71 @@ void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits) max_vocab_ = max_vocab; } -template <class Quant, class Bhiksha> std::size_t BitPackedMiddle<Quant, Bhiksha>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) { +template <class Bhiksha> std::size_t BitPackedMiddle<Bhiksha>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) { return Bhiksha::Size(entries + 1, max_ptr, config) + BaseSize(entries, max_vocab, quant_bits + Bhiksha::InlineBits(entries + 1, max_ptr, config)); } -template <class Quant, class Bhiksha> BitPackedMiddle<Quant, Bhiksha>::BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config) : +template <class Bhiksha> BitPackedMiddle<Bhiksha>::BitPackedMiddle(void *base, uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config) : BitPacked(), - quant_(quant), + quant_bits_(quant_bits), // If the offset of the method changes, also change TrieSearch::UpdateConfigFromBinary. bhiksha_(base, entries + 1, max_next, config), next_source_(&next_source) { if (entries + 1 >= (1ULL << 57) || (max_next >= (1ULL << 57))) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions."); - BaseInit(reinterpret_cast<uint8_t*>(base) + Bhiksha::Size(entries + 1, max_next, config), max_vocab, quant.TotalBits() + bhiksha_.InlineBits()); + BaseInit(reinterpret_cast<uint8_t*>(base) + Bhiksha::Size(entries + 1, max_next, config), max_vocab, quant_bits_ + bhiksha_.InlineBits()); } -template <class Quant, class Bhiksha> void BitPackedMiddle<Quant, Bhiksha>::Insert(WordIndex word, float prob, float backoff) { +template <class Bhiksha> util::BitAddress BitPackedMiddle<Bhiksha>::Insert(WordIndex word) { assert(word <= word_mask_); uint64_t at_pointer = insert_index_ * total_bits_; util::WriteInt57(base_, at_pointer, word_bits_, word); at_pointer += word_bits_; - quant_.Write(base_, at_pointer, prob, backoff); - at_pointer += quant_.TotalBits(); + util::BitAddress ret(base_, at_pointer); + at_pointer += quant_bits_; uint64_t next = next_source_->InsertIndex(); bhiksha_.WriteNext(base_, at_pointer, insert_index_, next); - ++insert_index_; + return ret; } -template <class Quant, class Bhiksha> bool BitPackedMiddle<Quant, Bhiksha>::Find(WordIndex word, float &prob, float &backoff, NodeRange &range, uint64_t &pointer) const { +template <class Bhiksha> util::BitAddress BitPackedMiddle<Bhiksha>::Find(WordIndex word, NodeRange &range, uint64_t &pointer) const { uint64_t at_pointer; if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) { - return false; + return util::BitAddress(NULL, 0); } pointer = at_pointer; at_pointer *= total_bits_; at_pointer += word_bits_; + bhiksha_.ReadNext(base_, at_pointer + quant_bits_, pointer, total_bits_, range); - quant_.Read(base_, at_pointer, prob, backoff); - at_pointer += quant_.TotalBits(); - - bhiksha_.ReadNext(base_, at_pointer, pointer, total_bits_, range); - - return true; + return util::BitAddress(base_, at_pointer); } -template <class Quant, class Bhiksha> bool BitPackedMiddle<Quant, Bhiksha>::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const { - uint64_t index; - if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, index)) return false; - uint64_t at_pointer = index * total_bits_; - at_pointer += word_bits_; - quant_.ReadBackoff(base_, at_pointer, backoff); - at_pointer += quant_.TotalBits(); - bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range); - return true; -} - -template <class Quant, class Bhiksha> void BitPackedMiddle<Quant, Bhiksha>::FinishedLoading(uint64_t next_end, const Config &config) { +template <class Bhiksha> void BitPackedMiddle<Bhiksha>::FinishedLoading(uint64_t next_end, const Config &config) { uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - bhiksha_.InlineBits(); bhiksha_.WriteNext(base_, last_next_write, insert_index_ + 1, next_end); bhiksha_.FinishedLoading(config); } -template <class Quant> void BitPackedLongest<Quant>::Insert(WordIndex index, float prob) { +util::BitAddress BitPackedLongest::Insert(WordIndex index) { assert(index <= word_mask_); uint64_t at_pointer = insert_index_ * total_bits_; util::WriteInt57(base_, at_pointer, word_bits_, index); at_pointer += word_bits_; - quant_.Write(base_, at_pointer, prob); ++insert_index_; + return util::BitAddress(base_, at_pointer); } -template <class Quant> bool BitPackedLongest<Quant>::Find(WordIndex word, float &prob, const NodeRange &range) const { +util::BitAddress BitPackedLongest::Find(WordIndex word, const NodeRange &range) const { uint64_t at_pointer; - if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false; + if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return util::BitAddress(NULL, 0); at_pointer = at_pointer * total_bits_ + word_bits_; - quant_.Read(base_, at_pointer, prob); - return true; + return util::BitAddress(base_, at_pointer); } -template class BitPackedMiddle<DontQuantize::Middle, DontBhiksha>; -template class BitPackedMiddle<DontQuantize::Middle, ArrayBhiksha>; -template class BitPackedMiddle<SeparatelyQuantize::Middle, DontBhiksha>; -template class BitPackedMiddle<SeparatelyQuantize::Middle, ArrayBhiksha>; -template class BitPackedLongest<DontQuantize::Longest>; -template class BitPackedLongest<SeparatelyQuantize::Longest>; +template class BitPackedMiddle<DontBhiksha>; +template class BitPackedMiddle<ArrayBhiksha>; } // namespace trie } // namespace ngram |