summaryrefslogtreecommitdiff
path: root/klm/lm/trie.cc
diff options
context:
space:
mode:
authorKenneth Heafield <kenlm@kheafield.com>2011-06-26 18:40:15 -0400
committerPatrick Simianer <p@simianer.de>2011-09-23 19:13:57 +0200
commitb0b0db256b4379ee5404f728dc85d26690ac729e (patch)
tree1328ea3c06da5fe3e8733bb1a13fead1855ac947 /klm/lm/trie.cc
parent9075eb00e694b0ccdd8b2569d1c011cb63df8f2e (diff)
Quantization
Diffstat (limited to 'klm/lm/trie.cc')
-rw-r--r--klm/lm/trie.cc123
1 files changed, 49 insertions, 74 deletions
diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc
index 2c633613..63c2a612 100644
--- a/klm/lm/trie.cc
+++ b/klm/lm/trie.cc
@@ -1,8 +1,8 @@
#include "lm/trie.hh"
+#include "lm/quantize.hh"
#include "util/bit_packing.hh"
#include "util/exception.hh"
-#include "util/proxy_iterator.hh"
#include "util/sorted_uniform.hh"
#include <assert.h>
@@ -12,53 +12,32 @@ namespace ngram {
namespace trie {
namespace {
-// Assumes key is first.
-class JustKeyProxy {
+class KeyAccessor {
public:
- JustKeyProxy() : inner_(), base_(), key_mask_(), key_bits_(), total_bits_() {}
+ KeyAccessor(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits)
+ : base_(reinterpret_cast<const uint8_t*>(base)), key_mask_(key_mask), key_bits_(key_bits), total_bits_(total_bits) {}
- operator uint64_t() const { return GetKey(); }
+ typedef uint64_t Key;
- uint64_t GetKey() const {
- uint64_t bit_off = inner_ * static_cast<uint64_t>(total_bits_);
- return util::ReadInt57(base_ + bit_off / 8, bit_off & 7, key_bits_, key_mask_);
+ Key operator()(uint64_t index) const {
+ return util::ReadInt57(base_, index * static_cast<uint64_t>(total_bits_), key_bits_, key_mask_);
}
private:
- friend class util::ProxyIterator<JustKeyProxy>;
- friend bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index);
-
- JustKeyProxy(const void *base, uint64_t index, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits)
- : inner_(index), base_(static_cast<const uint8_t*>(base)), key_mask_(key_mask), key_bits_(key_bits), total_bits_(total_bits) {}
-
- // This is a read-only iterator.
- JustKeyProxy &operator=(const JustKeyProxy &other);
-
- typedef uint64_t value_type;
-
- typedef uint64_t InnerIterator;
- uint64_t &Inner() { return inner_; }
- const uint64_t &Inner() const { return inner_; }
-
- // The address in bits is base_ * 8 + inner_ * total_bits_.
- uint64_t inner_;
const uint8_t *const base_;
- const uint64_t key_mask_;
+ const WordIndex key_mask_;
const uint8_t key_bits_, total_bits_;
};
-bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index) {
- util::ProxyIterator<JustKeyProxy> begin_it(JustKeyProxy(base, begin_index, key_mask, key_bits, total_bits));
- util::ProxyIterator<JustKeyProxy> end_it(JustKeyProxy(base, end_index, key_mask, key_bits, total_bits));
- util::ProxyIterator<JustKeyProxy> out;
- if (!util::SortedUniformFind(begin_it, end_it, key, out)) return false;
- at_index = out.Inner();
+bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, const uint64_t max_vocab, const uint64_t key, uint64_t &at_index) {
+ KeyAccessor accessor(base, key_mask, key_bits, total_bits);
+ if (!util::BoundedSortedUniformFind<uint64_t, KeyAccessor, util::PivotSelect<sizeof(WordIndex)>::T>(accessor, begin_index - 1, (uint64_t)0, end_index, max_vocab, key, at_index)) return false;
return true;
}
} // namespace
std::size_t BitPacked::BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits) {
- uint8_t total_bits = util::RequiredBits(max_vocab) + 31 + remaining_bits;
+ uint8_t total_bits = util::RequiredBits(max_vocab) + remaining_bits;
// Extra entry for next pointer at the end.
// +7 then / 8 to round up bits and convert to bytes
// +sizeof(uint64_t) so that ReadInt57 etc don't go segfault.
@@ -71,100 +50,96 @@ void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits)
word_bits_ = util::RequiredBits(max_vocab);
word_mask_ = (1ULL << word_bits_) - 1ULL;
if (word_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, word indices more than " << (1ULL << 57) << " are not implemented. Edit util/bit_packing.hh and fix the bit packing functions.");
- prob_bits_ = 31;
- total_bits_ = word_bits_ + prob_bits_ + remaining_bits;
+ total_bits_ = word_bits_ + remaining_bits;
base_ = static_cast<uint8_t*>(base);
insert_index_ = 0;
+ max_vocab_ = max_vocab;
}
-std::size_t BitPackedMiddle::Size(uint64_t entries, uint64_t max_vocab, uint64_t max_ptr) {
- return BaseSize(entries, max_vocab, 32 + util::RequiredBits(max_ptr));
+template <class Quant> std::size_t BitPackedMiddle<Quant>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr) {
+ return BaseSize(entries, max_vocab, quant_bits + util::RequiredBits(max_ptr));
}
-void BitPackedMiddle::Init(void *base, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source) {
- next_source_ = &next_source;
- backoff_bits_ = 32;
- next_bits_ = util::RequiredBits(max_next);
+template <class Quant> BitPackedMiddle<Quant>::BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source) : BitPacked(), quant_(quant), next_bits_(util::RequiredBits(max_next)), next_mask_((1ULL << next_bits_) - 1), next_source_(&next_source) {
if (next_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions.");
- next_mask_ = (1ULL << next_bits_) - 1;
-
- BaseInit(base, max_vocab, backoff_bits_ + next_bits_);
+ BaseInit(base, max_vocab, quant.TotalBits() + next_bits_);
}
-void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff) {
+template <class Quant> void BitPackedMiddle<Quant>::Insert(WordIndex word, float prob, float backoff) {
assert(word <= word_mask_);
uint64_t at_pointer = insert_index_ * total_bits_;
- util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word_bits_, word);
+ util::WriteInt57(base_, at_pointer, word_bits_, word);
at_pointer += word_bits_;
- util::WriteNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7, prob);
- at_pointer += prob_bits_;
- util::WriteFloat32(base_ + (at_pointer >> 3), at_pointer & 7, backoff);
- at_pointer += backoff_bits_;
+ quant_.Write(base_, at_pointer, prob, backoff);
+ at_pointer += quant_.TotalBits();
uint64_t next = next_source_->InsertIndex();
assert(next <= next_mask_);
- util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next);
+ util::WriteInt57(base_, at_pointer, next_bits_, next);
++insert_index_;
}
-bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const {
+template <class Quant> bool BitPackedMiddle<Quant>::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const {
uint64_t at_pointer;
- if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) {
+ if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) {
return false;
}
at_pointer *= total_bits_;
at_pointer += word_bits_;
- prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7);
- at_pointer += prob_bits_;
- backoff = util::ReadFloat32(base_ + (at_pointer >> 3), at_pointer & 7);
- at_pointer += backoff_bits_;
- range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_);
+ quant_.Read(base_, at_pointer, prob, backoff);
+ at_pointer += quant_.TotalBits();
+
+ range.begin = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_);
// Read the next entry's pointer.
at_pointer += total_bits_;
- range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_);
+ range.end = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_);
return true;
}
-bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const {
+template <class Quant> bool BitPackedMiddle<Quant>::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const {
uint64_t at_pointer;
- if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false;
+ if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false;
at_pointer *= total_bits_;
at_pointer += word_bits_;
- at_pointer += prob_bits_;
- backoff = util::ReadFloat32(base_ + (at_pointer >> 3), at_pointer & 7);
- at_pointer += backoff_bits_;
- range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_);
+ quant_.ReadBackoff(base_, at_pointer, backoff);
+ at_pointer += quant_.TotalBits();
+ range.begin = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_);
// Read the next entry's pointer.
at_pointer += total_bits_;
- range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_);
+ range.end = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_);
return true;
}
-void BitPackedMiddle::FinishedLoading(uint64_t next_end) {
+template <class Quant> void BitPackedMiddle<Quant>::FinishedLoading(uint64_t next_end) {
assert(next_end <= next_mask_);
uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - next_bits_;
- util::WriteInt57(base_ + (last_next_write >> 3), last_next_write & 7, next_bits_, next_end);
+ util::WriteInt57(base_, last_next_write, next_bits_, next_end);
}
-void BitPackedLongest::Insert(WordIndex index, float prob) {
+template <class Quant> void BitPackedLongest<Quant>::Insert(WordIndex index, float prob) {
assert(index <= word_mask_);
uint64_t at_pointer = insert_index_ * total_bits_;
- util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word_bits_, index);
+ util::WriteInt57(base_, at_pointer, word_bits_, index);
at_pointer += word_bits_;
- util::WriteNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7, prob);
+ quant_.Write(base_, at_pointer, prob);
++insert_index_;
}
-bool BitPackedLongest::Find(WordIndex word, float &prob, const NodeRange &range) const {
+template <class Quant> bool BitPackedLongest<Quant>::Find(WordIndex word, float &prob, const NodeRange &range) const {
uint64_t at_pointer;
- if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false;
+ if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false;
at_pointer = at_pointer * total_bits_ + word_bits_;
- prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7);
+ quant_.Read(base_, at_pointer, prob);
return true;
}
+template class BitPackedMiddle<DontQuantize::Middle>;
+template class BitPackedMiddle<SeparatelyQuantize::Middle>;
+template class BitPackedLongest<DontQuantize::Longest>;
+template class BitPackedLongest<SeparatelyQuantize::Longest>;
+
} // namespace trie
} // namespace ngram
} // namespace lm