diff options
Diffstat (limited to 'klm/lm/trie.cc')
-rw-r--r-- | klm/lm/trie.cc | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index 04bd2079..2c633613 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -82,7 +82,8 @@ std::size_t BitPackedMiddle::Size(uint64_t entries, uint64_t max_vocab, uint64_t return BaseSize(entries, max_vocab, 32 + util::RequiredBits(max_ptr)); } -void BitPackedMiddle::Init(void *base, uint64_t max_vocab, uint64_t max_next) { +void BitPackedMiddle::Init(void *base, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source) { + next_source_ = &next_source; backoff_bits_ = 32; next_bits_ = util::RequiredBits(max_next); if (next_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions."); @@ -91,9 +92,8 @@ void BitPackedMiddle::Init(void *base, uint64_t max_vocab, uint64_t max_next) { BaseInit(base, max_vocab, backoff_bits_ + next_bits_); } -void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff, uint64_t next) { +void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff) { assert(word <= word_mask_); - assert(next <= next_mask_); uint64_t at_pointer = insert_index_ * total_bits_; util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word_bits_, word); @@ -102,6 +102,8 @@ void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff, uint64_t at_pointer += prob_bits_; util::WriteFloat32(base_ + (at_pointer >> 3), at_pointer & 7, backoff); at_pointer += backoff_bits_; + uint64_t next = next_source_->InsertIndex(); + assert(next <= next_mask_); util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next); ++insert_index_; @@ -109,7 +111,9 @@ void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff, uint64_t bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const { uint64_t at_pointer; - if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false; + if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) { + return false; + } at_pointer *= total_bits_; at_pointer += word_bits_; prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7); @@ -144,7 +148,6 @@ void BitPackedMiddle::FinishedLoading(uint64_t next_end) { util::WriteInt57(base_ + (last_next_write >> 3), last_next_write & 7, next_bits_, next_end); } - void BitPackedLongest::Insert(WordIndex index, float prob) { assert(index <= word_mask_); uint64_t at_pointer = insert_index_ * total_bits_; |