summaryrefslogtreecommitdiff
path: root/klm/lm/trie.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/trie.hh')
-rw-r--r--klm/lm/trie.hh33
1 files changed, 18 insertions, 15 deletions
diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh
index 6aef050c..8fa21aaf 100644
--- a/klm/lm/trie.hh
+++ b/klm/lm/trie.hh
@@ -74,23 +74,21 @@ class BitPacked {
void BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits);
- uint8_t word_bits_, prob_bits_;
+ uint8_t word_bits_;
uint8_t total_bits_;
uint64_t word_mask_;
uint8_t *base_;
- uint64_t insert_index_;
+ uint64_t insert_index_, max_vocab_;
};
-class BitPackedMiddle : public BitPacked {
+template <class Quant> class BitPackedMiddle : public BitPacked {
public:
- BitPackedMiddle() {}
-
- static std::size_t Size(uint64_t entries, uint64_t max_vocab, uint64_t max_next);
+ static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next);
// next_source need not be initialized.
- void Init(void *base, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source);
+ BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source);
void Insert(WordIndex word, float prob, float backoff);
@@ -101,28 +99,33 @@ class BitPackedMiddle : public BitPacked {
void FinishedLoading(uint64_t next_end);
private:
- uint8_t backoff_bits_, next_bits_;
+ Quant quant_;
+ uint8_t next_bits_;
uint64_t next_mask_;
const BitPacked *next_source_;
};
-class BitPackedLongest : public BitPacked {
+template <class Quant> class BitPackedLongest : public BitPacked {
public:
- BitPackedLongest() {}
-
- static std::size_t Size(uint64_t entries, uint64_t max_vocab) {
- return BaseSize(entries, max_vocab, 0);
+ static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) {
+ return BaseSize(entries, max_vocab, quant_bits);
}
- void Init(void *base, uint64_t max_vocab) {
- return BaseInit(base, max_vocab, 0);
+ BitPackedLongest() {}
+
+ void Init(void *base, const Quant &quant, uint64_t max_vocab) {
+ quant_ = quant;
+ BaseInit(base, max_vocab, quant_.TotalBits());
}
void Insert(WordIndex word, float prob);
bool Find(WordIndex word, float &prob, const NodeRange &node) const;
+
+ private:
+ Quant quant_;
};
} // namespace trie