summaryrefslogtreecommitdiff
path: root/klm/lm/trie.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/trie.cc')
-rw-r--r--klm/lm/trie.cc13
1 files changed, 8 insertions, 5 deletions
diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc
index 04bd2079..2c633613 100644
--- a/klm/lm/trie.cc
+++ b/klm/lm/trie.cc
@@ -82,7 +82,8 @@ std::size_t BitPackedMiddle::Size(uint64_t entries, uint64_t max_vocab, uint64_t
return BaseSize(entries, max_vocab, 32 + util::RequiredBits(max_ptr));
}
-void BitPackedMiddle::Init(void *base, uint64_t max_vocab, uint64_t max_next) {
+void BitPackedMiddle::Init(void *base, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source) {
+ next_source_ = &next_source;
backoff_bits_ = 32;
next_bits_ = util::RequiredBits(max_next);
if (next_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions.");
@@ -91,9 +92,8 @@ void BitPackedMiddle::Init(void *base, uint64_t max_vocab, uint64_t max_next) {
BaseInit(base, max_vocab, backoff_bits_ + next_bits_);
}
-void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff, uint64_t next) {
+void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff) {
assert(word <= word_mask_);
- assert(next <= next_mask_);
uint64_t at_pointer = insert_index_ * total_bits_;
util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word_bits_, word);
@@ -102,6 +102,8 @@ void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff, uint64_t
at_pointer += prob_bits_;
util::WriteFloat32(base_ + (at_pointer >> 3), at_pointer & 7, backoff);
at_pointer += backoff_bits_;
+ uint64_t next = next_source_->InsertIndex();
+ assert(next <= next_mask_);
util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next);
++insert_index_;
@@ -109,7 +111,9 @@ void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff, uint64_t
bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const {
uint64_t at_pointer;
- if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false;
+ if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) {
+ return false;
+ }
at_pointer *= total_bits_;
at_pointer += word_bits_;
prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7);
@@ -144,7 +148,6 @@ void BitPackedMiddle::FinishedLoading(uint64_t next_end) {
util::WriteInt57(base_ + (last_next_write >> 3), last_next_write & 7, next_bits_, next_end);
}
-
void BitPackedLongest::Insert(WordIndex index, float prob) {
assert(index <= word_mask_);
uint64_t at_pointer = insert_index_ * total_bits_;