summaryrefslogtreecommitdiff
path: root/klm/lm/trie.hh
diff options
context:
space:
mode:
authorKenneth Heafield <kenlm@kheafield.com>2011-06-26 18:40:15 -0400
committerKenneth Heafield <kenlm@kheafield.com>2011-06-26 18:40:15 -0400
commit44b023d165402406fc20bb5fa240cb9a9ebf32ea (patch)
tree1328ea3c06da5fe3e8733bb1a13fead1855ac947 /klm/lm/trie.hh
parent5fb86e4236cbca8285f2df6e10f44098afbf8dc0 (diff)
Quantization
Diffstat (limited to 'klm/lm/trie.hh')
-rw-r--r--klm/lm/trie.hh33
1 files changed, 18 insertions, 15 deletions
diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh
index 6aef050c..8fa21aaf 100644
--- a/klm/lm/trie.hh
+++ b/klm/lm/trie.hh
@@ -74,23 +74,21 @@ class BitPacked {
void BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits);
- uint8_t word_bits_, prob_bits_;
+ uint8_t word_bits_;
uint8_t total_bits_;
uint64_t word_mask_;
uint8_t *base_;
- uint64_t insert_index_;
+ uint64_t insert_index_, max_vocab_;
};
-class BitPackedMiddle : public BitPacked {
+template <class Quant> class BitPackedMiddle : public BitPacked {
public:
- BitPackedMiddle() {}
-
- static std::size_t Size(uint64_t entries, uint64_t max_vocab, uint64_t max_next);
+ static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next);
// next_source need not be initialized.
- void Init(void *base, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source);
+ BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source);
void Insert(WordIndex word, float prob, float backoff);
@@ -101,28 +99,33 @@ class BitPackedMiddle : public BitPacked {
void FinishedLoading(uint64_t next_end);
private:
- uint8_t backoff_bits_, next_bits_;
+ Quant quant_;
+ uint8_t next_bits_;
uint64_t next_mask_;
const BitPacked *next_source_;
};
-class BitPackedLongest : public BitPacked {
+template <class Quant> class BitPackedLongest : public BitPacked {
public:
- BitPackedLongest() {}
-
- static std::size_t Size(uint64_t entries, uint64_t max_vocab) {
- return BaseSize(entries, max_vocab, 0);
+ static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) {
+ return BaseSize(entries, max_vocab, quant_bits);
}
- void Init(void *base, uint64_t max_vocab) {
- return BaseInit(base, max_vocab, 0);
+ BitPackedLongest() {}
+
+ void Init(void *base, const Quant &quant, uint64_t max_vocab) {
+ quant_ = quant;
+ BaseInit(base, max_vocab, quant_.TotalBits());
}
void Insert(WordIndex word, float prob);
bool Find(WordIndex word, float &prob, const NodeRange &node) const;
+
+ private:
+ Quant quant_;
};
} // namespace trie