diff options
author | Kenneth Heafield <kenlm@kheafield.com> | 2011-06-26 18:40:15 -0400 |
---|---|---|
committer | Patrick Simianer <p@simianer.de> | 2011-09-23 19:13:57 +0200 |
commit | b0b0db256b4379ee5404f728dc85d26690ac729e (patch) | |
tree | 1328ea3c06da5fe3e8733bb1a13fead1855ac947 /klm/lm/trie.hh | |
parent | 9075eb00e694b0ccdd8b2569d1c011cb63df8f2e (diff) |
Quantization
Diffstat (limited to 'klm/lm/trie.hh')
-rw-r--r-- | klm/lm/trie.hh | 33 |
1 files changed, 18 insertions, 15 deletions
diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index 6aef050c..8fa21aaf 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -74,23 +74,21 @@ class BitPacked { void BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits); - uint8_t word_bits_, prob_bits_; + uint8_t word_bits_; uint8_t total_bits_; uint64_t word_mask_; uint8_t *base_; - uint64_t insert_index_; + uint64_t insert_index_, max_vocab_; }; -class BitPackedMiddle : public BitPacked { +template <class Quant> class BitPackedMiddle : public BitPacked { public: - BitPackedMiddle() {} - - static std::size_t Size(uint64_t entries, uint64_t max_vocab, uint64_t max_next); + static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next); // next_source need not be initialized. - void Init(void *base, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source); + BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source); void Insert(WordIndex word, float prob, float backoff); @@ -101,28 +99,33 @@ class BitPackedMiddle : public BitPacked { void FinishedLoading(uint64_t next_end); private: - uint8_t backoff_bits_, next_bits_; + Quant quant_; + uint8_t next_bits_; uint64_t next_mask_; const BitPacked *next_source_; }; -class BitPackedLongest : public BitPacked { +template <class Quant> class BitPackedLongest : public BitPacked { public: - BitPackedLongest() {} - - static std::size_t Size(uint64_t entries, uint64_t max_vocab) { - return BaseSize(entries, max_vocab, 0); + static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) { + return BaseSize(entries, max_vocab, quant_bits); } - void Init(void *base, uint64_t max_vocab) { - return BaseInit(base, max_vocab, 0); + BitPackedLongest() {} + + void Init(void *base, const Quant &quant, uint64_t max_vocab) { + quant_ = quant; + BaseInit(base, max_vocab, quant_.TotalBits()); } void Insert(WordIndex word, float prob); bool Find(WordIndex word, float &prob, const NodeRange &node) const; + + private: + Quant quant_; }; } // namespace trie |