summaryrefslogtreecommitdiff
path: root/klm/lm/trie.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/trie.cc')
-rw-r--r--klm/lm/trie.cc42
1 files changed, 21 insertions, 21 deletions
diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc
index 8ed7b2a2..04bd2079 100644
--- a/klm/lm/trie.cc
+++ b/klm/lm/trie.cc
@@ -15,21 +15,21 @@ namespace {
// Assumes key is first.
class JustKeyProxy {
public:
- JustKeyProxy() : inner_(), base_(), key_mask_(), total_bits_() {}
+ JustKeyProxy() : inner_(), base_(), key_mask_(), key_bits_(), total_bits_() {}
operator uint64_t() const { return GetKey(); }
uint64_t GetKey() const {
uint64_t bit_off = inner_ * static_cast<uint64_t>(total_bits_);
- return util::ReadInt57(base_ + bit_off / 8, bit_off & 7, key_mask_);
+ return util::ReadInt57(base_ + bit_off / 8, bit_off & 7, key_bits_, key_mask_);
}
private:
friend class util::ProxyIterator<JustKeyProxy>;
- friend bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index);
+ friend bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index);
- JustKeyProxy(const void *base, uint64_t index, uint64_t key_mask, uint8_t total_bits)
- : inner_(index), base_(static_cast<const uint8_t*>(base)), key_mask_(key_mask), total_bits_(total_bits) {}
+ JustKeyProxy(const void *base, uint64_t index, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits)
+ : inner_(index), base_(static_cast<const uint8_t*>(base)), key_mask_(key_mask), key_bits_(key_bits), total_bits_(total_bits) {}
// This is a read-only iterator.
JustKeyProxy &operator=(const JustKeyProxy &other);
@@ -44,12 +44,12 @@ class JustKeyProxy {
uint64_t inner_;
const uint8_t *const base_;
const uint64_t key_mask_;
- const uint8_t total_bits_;
+ const uint8_t key_bits_, total_bits_;
};
-bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index) {
- util::ProxyIterator<JustKeyProxy> begin_it(JustKeyProxy(base, begin_index, key_mask, total_bits));
- util::ProxyIterator<JustKeyProxy> end_it(JustKeyProxy(base, end_index, key_mask, total_bits));
+bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index) {
+ util::ProxyIterator<JustKeyProxy> begin_it(JustKeyProxy(base, begin_index, key_mask, key_bits, total_bits));
+ util::ProxyIterator<JustKeyProxy> end_it(JustKeyProxy(base, end_index, key_mask, key_bits, total_bits));
util::ProxyIterator<JustKeyProxy> out;
if (!util::SortedUniformFind(begin_it, end_it, key, out)) return false;
at_index = out.Inner();
@@ -96,67 +96,67 @@ void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff, uint64_t
assert(next <= next_mask_);
uint64_t at_pointer = insert_index_ * total_bits_;
- util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word);
+ util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word_bits_, word);
at_pointer += word_bits_;
util::WriteNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7, prob);
at_pointer += prob_bits_;
util::WriteFloat32(base_ + (at_pointer >> 3), at_pointer & 7, backoff);
at_pointer += backoff_bits_;
- util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, next);
+ util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next);
++insert_index_;
}
bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const {
uint64_t at_pointer;
- if (!FindBitPacked(base_, word_mask_, total_bits_, range.begin, range.end, word, at_pointer)) return false;
+ if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false;
at_pointer *= total_bits_;
at_pointer += word_bits_;
prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7);
at_pointer += prob_bits_;
backoff = util::ReadFloat32(base_ + (at_pointer >> 3), at_pointer & 7);
at_pointer += backoff_bits_;
- range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_);
+ range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_);
// Read the next entry's pointer.
at_pointer += total_bits_;
- range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_);
+ range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_);
return true;
}
bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const {
uint64_t at_pointer;
- if (!FindBitPacked(base_, word_mask_, total_bits_, range.begin, range.end, word, at_pointer)) return false;
+ if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false;
at_pointer *= total_bits_;
at_pointer += word_bits_;
at_pointer += prob_bits_;
backoff = util::ReadFloat32(base_ + (at_pointer >> 3), at_pointer & 7);
at_pointer += backoff_bits_;
- range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_);
+ range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_);
// Read the next entry's pointer.
at_pointer += total_bits_;
- range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_);
+ range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_);
return true;
}
void BitPackedMiddle::FinishedLoading(uint64_t next_end) {
assert(next_end <= next_mask_);
uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - next_bits_;
- util::WriteInt57(base_ + (last_next_write >> 3), last_next_write & 7, next_end);
+ util::WriteInt57(base_ + (last_next_write >> 3), last_next_write & 7, next_bits_, next_end);
}
void BitPackedLongest::Insert(WordIndex index, float prob) {
assert(index <= word_mask_);
uint64_t at_pointer = insert_index_ * total_bits_;
- util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, index);
+ util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word_bits_, index);
at_pointer += word_bits_;
util::WriteNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7, prob);
++insert_index_;
}
-bool BitPackedLongest::Find(WordIndex word, float &prob, const NodeRange &node) const {
+bool BitPackedLongest::Find(WordIndex word, float &prob, const NodeRange &range) const {
uint64_t at_pointer;
- if (!FindBitPacked(base_, word_mask_, total_bits_, node.begin, node.end, word, at_pointer)) return false;
+ if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false;
at_pointer = at_pointer * total_bits_ + word_bits_;
prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7);
return true;