summaryrefslogtreecommitdiff
path: root/klm/lm/trie.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/trie.cc')
-rw-r--r--klm/lm/trie.cc57
1 files changed, 30 insertions, 27 deletions
diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc
index 63c2a612..8c536e66 100644
--- a/klm/lm/trie.cc
+++ b/klm/lm/trie.cc
@@ -1,5 +1,6 @@
#include "lm/trie.hh"
+#include "lm/bhiksha.hh"
#include "lm/quantize.hh"
#include "util/bit_packing.hh"
#include "util/exception.hh"
@@ -57,16 +58,21 @@ void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits)
max_vocab_ = max_vocab;
}
-template <class Quant> std::size_t BitPackedMiddle<Quant>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr) {
- return BaseSize(entries, max_vocab, quant_bits + util::RequiredBits(max_ptr));
+template <class Quant, class Bhiksha> std::size_t BitPackedMiddle<Quant, Bhiksha>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) {
+ return Bhiksha::Size(entries + 1, max_ptr, config) + BaseSize(entries, max_vocab, quant_bits + Bhiksha::InlineBits(entries + 1, max_ptr, config));
}
-template <class Quant> BitPackedMiddle<Quant>::BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source) : BitPacked(), quant_(quant), next_bits_(util::RequiredBits(max_next)), next_mask_((1ULL << next_bits_) - 1), next_source_(&next_source) {
- if (next_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions.");
- BaseInit(base, max_vocab, quant.TotalBits() + next_bits_);
+template <class Quant, class Bhiksha> BitPackedMiddle<Quant, Bhiksha>::BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config) :
+ BitPacked(),
+ quant_(quant),
+ // If the offset of the method changes, also change TrieSearch::UpdateConfigFromBinary.
+ bhiksha_(base, entries + 1, max_next, config),
+ next_source_(&next_source) {
+ if (entries + 1 >= (1ULL << 57) || (max_next >= (1ULL << 57))) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions.");
+ BaseInit(reinterpret_cast<uint8_t*>(base) + Bhiksha::Size(entries + 1, max_next, config), max_vocab, quant.TotalBits() + bhiksha_.InlineBits());
}
-template <class Quant> void BitPackedMiddle<Quant>::Insert(WordIndex word, float prob, float backoff) {
+template <class Quant, class Bhiksha> void BitPackedMiddle<Quant, Bhiksha>::Insert(WordIndex word, float prob, float backoff) {
assert(word <= word_mask_);
uint64_t at_pointer = insert_index_ * total_bits_;
@@ -75,47 +81,42 @@ template <class Quant> void BitPackedMiddle<Quant>::Insert(WordIndex word, float
quant_.Write(base_, at_pointer, prob, backoff);
at_pointer += quant_.TotalBits();
uint64_t next = next_source_->InsertIndex();
- assert(next <= next_mask_);
- util::WriteInt57(base_, at_pointer, next_bits_, next);
+ bhiksha_.WriteNext(base_, at_pointer, insert_index_, next);
++insert_index_;
}
-template <class Quant> bool BitPackedMiddle<Quant>::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const {
+template <class Quant, class Bhiksha> bool BitPackedMiddle<Quant, Bhiksha>::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const {
uint64_t at_pointer;
if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) {
return false;
}
+ uint64_t index = at_pointer;
at_pointer *= total_bits_;
at_pointer += word_bits_;
quant_.Read(base_, at_pointer, prob, backoff);
at_pointer += quant_.TotalBits();
- range.begin = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_);
- // Read the next entry's pointer.
- at_pointer += total_bits_;
- range.end = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_);
+ bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range);
+
return true;
}
-template <class Quant> bool BitPackedMiddle<Quant>::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const {
- uint64_t at_pointer;
- if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false;
- at_pointer *= total_bits_;
+template <class Quant, class Bhiksha> bool BitPackedMiddle<Quant, Bhiksha>::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const {
+ uint64_t index;
+ if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, index)) return false;
+ uint64_t at_pointer = index * total_bits_;
at_pointer += word_bits_;
quant_.ReadBackoff(base_, at_pointer, backoff);
at_pointer += quant_.TotalBits();
- range.begin = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_);
- // Read the next entry's pointer.
- at_pointer += total_bits_;
- range.end = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_);
+ bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range);
return true;
}
-template <class Quant> void BitPackedMiddle<Quant>::FinishedLoading(uint64_t next_end) {
- assert(next_end <= next_mask_);
- uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - next_bits_;
- util::WriteInt57(base_, last_next_write, next_bits_, next_end);
+template <class Quant, class Bhiksha> void BitPackedMiddle<Quant, Bhiksha>::FinishedLoading(uint64_t next_end, const Config &config) {
+ uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - bhiksha_.InlineBits();
+ bhiksha_.WriteNext(base_, last_next_write, insert_index_ + 1, next_end);
+ bhiksha_.FinishedLoading(config);
}
template <class Quant> void BitPackedLongest<Quant>::Insert(WordIndex index, float prob) {
@@ -135,8 +136,10 @@ template <class Quant> bool BitPackedLongest<Quant>::Find(WordIndex word, float
return true;
}
-template class BitPackedMiddle<DontQuantize::Middle>;
-template class BitPackedMiddle<SeparatelyQuantize::Middle>;
+template class BitPackedMiddle<DontQuantize::Middle, DontBhiksha>;
+template class BitPackedMiddle<DontQuantize::Middle, ArrayBhiksha>;
+template class BitPackedMiddle<SeparatelyQuantize::Middle, DontBhiksha>;
+template class BitPackedMiddle<SeparatelyQuantize::Middle, ArrayBhiksha>;
template class BitPackedLongest<DontQuantize::Longest>;
template class BitPackedLongest<SeparatelyQuantize::Longest>;