diff options
Diffstat (limited to 'klm/lm/trie.cc')
-rw-r--r-- | klm/lm/trie.cc | 42 |
1 files changed, 21 insertions, 21 deletions
diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index 8ed7b2a2..04bd2079 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -15,21 +15,21 @@ namespace { // Assumes key is first. class JustKeyProxy { public: - JustKeyProxy() : inner_(), base_(), key_mask_(), total_bits_() {} + JustKeyProxy() : inner_(), base_(), key_mask_(), key_bits_(), total_bits_() {} operator uint64_t() const { return GetKey(); } uint64_t GetKey() const { uint64_t bit_off = inner_ * static_cast<uint64_t>(total_bits_); - return util::ReadInt57(base_ + bit_off / 8, bit_off & 7, key_mask_); + return util::ReadInt57(base_ + bit_off / 8, bit_off & 7, key_bits_, key_mask_); } private: friend class util::ProxyIterator<JustKeyProxy>; - friend bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index); + friend bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index); - JustKeyProxy(const void *base, uint64_t index, uint64_t key_mask, uint8_t total_bits) - : inner_(index), base_(static_cast<const uint8_t*>(base)), key_mask_(key_mask), total_bits_(total_bits) {} + JustKeyProxy(const void *base, uint64_t index, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits) + : inner_(index), base_(static_cast<const uint8_t*>(base)), key_mask_(key_mask), key_bits_(key_bits), total_bits_(total_bits) {} // This is a read-only iterator. JustKeyProxy &operator=(const JustKeyProxy &other); @@ -44,12 +44,12 @@ class JustKeyProxy { uint64_t inner_; const uint8_t *const base_; const uint64_t key_mask_; - const uint8_t total_bits_; + const uint8_t key_bits_, total_bits_; }; -bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index) { - util::ProxyIterator<JustKeyProxy> begin_it(JustKeyProxy(base, begin_index, key_mask, total_bits)); - util::ProxyIterator<JustKeyProxy> end_it(JustKeyProxy(base, end_index, key_mask, total_bits)); +bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index) { + util::ProxyIterator<JustKeyProxy> begin_it(JustKeyProxy(base, begin_index, key_mask, key_bits, total_bits)); + util::ProxyIterator<JustKeyProxy> end_it(JustKeyProxy(base, end_index, key_mask, key_bits, total_bits)); util::ProxyIterator<JustKeyProxy> out; if (!util::SortedUniformFind(begin_it, end_it, key, out)) return false; at_index = out.Inner(); @@ -96,67 +96,67 @@ void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff, uint64_t assert(next <= next_mask_); uint64_t at_pointer = insert_index_ * total_bits_; - util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word); + util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word_bits_, word); at_pointer += word_bits_; util::WriteNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7, prob); at_pointer += prob_bits_; util::WriteFloat32(base_ + (at_pointer >> 3), at_pointer & 7, backoff); at_pointer += backoff_bits_; - util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, next); + util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next); ++insert_index_; } bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const { uint64_t at_pointer; - if (!FindBitPacked(base_, word_mask_, total_bits_, range.begin, range.end, word, at_pointer)) return false; + if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false; at_pointer *= total_bits_; at_pointer += word_bits_; prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7); at_pointer += prob_bits_; backoff = util::ReadFloat32(base_ + (at_pointer >> 3), at_pointer & 7); at_pointer += backoff_bits_; - range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_); + range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_); // Read the next entry's pointer. at_pointer += total_bits_; - range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_); + range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_); return true; } bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const { uint64_t at_pointer; - if (!FindBitPacked(base_, word_mask_, total_bits_, range.begin, range.end, word, at_pointer)) return false; + if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false; at_pointer *= total_bits_; at_pointer += word_bits_; at_pointer += prob_bits_; backoff = util::ReadFloat32(base_ + (at_pointer >> 3), at_pointer & 7); at_pointer += backoff_bits_; - range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_); + range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_); // Read the next entry's pointer. at_pointer += total_bits_; - range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_); + range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_); return true; } void BitPackedMiddle::FinishedLoading(uint64_t next_end) { assert(next_end <= next_mask_); uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - next_bits_; - util::WriteInt57(base_ + (last_next_write >> 3), last_next_write & 7, next_end); + util::WriteInt57(base_ + (last_next_write >> 3), last_next_write & 7, next_bits_, next_end); } void BitPackedLongest::Insert(WordIndex index, float prob) { assert(index <= word_mask_); uint64_t at_pointer = insert_index_ * total_bits_; - util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, index); + util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word_bits_, index); at_pointer += word_bits_; util::WriteNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7, prob); ++insert_index_; } -bool BitPackedLongest::Find(WordIndex word, float &prob, const NodeRange &node) const { +bool BitPackedLongest::Find(WordIndex word, float &prob, const NodeRange &range) const { uint64_t at_pointer; - if (!FindBitPacked(base_, word_mask_, total_bits_, node.begin, node.end, word, at_pointer)) return false; + if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false; at_pointer = at_pointer * total_bits_ + word_bits_; prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7); return true; |