summaryrefslogtreecommitdiff
path: root/klm/lm/trie.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/trie.cc')
-rw-r--r--klm/lm/trie.cc167
1 files changed, 167 insertions, 0 deletions
diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc
new file mode 100644
index 00000000..8ed7b2a2
--- /dev/null
+++ b/klm/lm/trie.cc
@@ -0,0 +1,167 @@
+#include "lm/trie.hh"
+
+#include "util/bit_packing.hh"
+#include "util/exception.hh"
+#include "util/proxy_iterator.hh"
+#include "util/sorted_uniform.hh"
+
+#include <assert.h>
+
+namespace lm {
+namespace ngram {
+namespace trie {
+namespace {
+
+// Read-only proxy over one fixed-width bit-packed record that exposes only its key; assumes the key is the first field of each record.
+class JustKeyProxy {
+ public:
+ JustKeyProxy() : inner_(), base_(), key_mask_(), total_bits_() {}  // Value-initializes every member.
+
+ operator uint64_t() const { return GetKey(); }  // Implicit conversion yields the record's key.
+
+ uint64_t GetKey() const {
+ uint64_t bit_off = inner_ * static_cast<uint64_t>(total_bits_);  // Bit offset of record number inner_, relative to base_.
+ return util::ReadInt57(base_ + bit_off / 8, bit_off & 7, key_mask_);  // Pass the byte address and the bit position within that byte.
+ }
+
+ private:
+ friend class util::ProxyIterator<JustKeyProxy>;  // Presumably consumes the private typedefs and Inner() below; confirm in util/proxy_iterator.hh.
+ friend bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index);  // Calls the private constructor.
+
+ JustKeyProxy(const void *base, uint64_t index, uint64_t key_mask, uint8_t total_bits)
+ : inner_(index), base_(static_cast<const uint8_t*>(base)), key_mask_(key_mask), total_bits_(total_bits) {}  // index is a record number, not a bit or byte offset.
+
+ // This is a read-only iterator: assignment is declared private and has no definition in this file.
+ JustKeyProxy &operator=(const JustKeyProxy &other);
+
+ typedef uint64_t value_type;  // Matches the return type of GetKey().
+
+ typedef uint64_t InnerIterator;  // The underlying position is a plain record index.
+ uint64_t &Inner() { return inner_; }  // Mutable access to the record index.
+ const uint64_t &Inner() const { return inner_; }  // Read-only access to the record index.
+
+ // The address in bits is base_ * 8 + inner_ * total_bits_.
+ uint64_t inner_;  // Record index; the only non-const member.
+ const uint8_t *const base_;  // Start of the packed array.
+ const uint64_t key_mask_;  // Mask handed to ReadInt57 when extracting the key.
+ const uint8_t total_bits_;  // Width of one record in bits.
+};
+
+// Searches a bit-packed array for a record whose leading key field equals key.
+//   base        start of the packed array
+//   key_mask    mask used when reading each record's key
+//   total_bits  width of one record in bits
+//   begin_index, end_index  record indices delimiting the search, handed to
+//               util::SortedUniformFind as its iterator range
+//   key         value to look for
+//   at_index    receives the matching record index; left untouched on failure
+// Returns true if and only if util::SortedUniformFind reports a match.
+bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index) {
+  typedef util::ProxyIterator<JustKeyProxy> KeyIterator;
+  KeyIterator first(JustKeyProxy(base, begin_index, key_mask, total_bits));
+  KeyIterator last(JustKeyProxy(base, end_index, key_mask, total_bits));
+  KeyIterator match;
+  if (util::SortedUniformFind(first, last, key, match)) {
+    at_index = match.Inner();
+    return true;
+  }
+  return false;
+}
+} // namespace
+
+// Returns the number of bytes needed to store `entries` bit-packed records.
+// Each record holds a word index wide enough for max_vocab, a 31-bit
+// probability, and remaining_bits of caller-defined payload.
+std::size_t BitPacked::BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits) {
+  const uint8_t total_bits = util::RequiredBits(max_vocab) + 31 + remaining_bits;
+  // One record beyond `entries` is reserved for the next pointer at the end.
+  const uint64_t packed_bits = (1 + entries) * total_bits;
+  // Round the bit count up to whole bytes.
+  const uint64_t packed_bytes = (packed_bits + 7) / 8;
+  // Pad by sizeof(uint64_t) so that ReadInt57 etc don't segfault when reading
+  // near the end.  This waste is O(order), not O(number of ngrams).
+  return packed_bytes + sizeof(uint64_t);
+}
+
+// Initializes the record layout shared by the bit-packed arrays and resets
+// the insertion position to the first record.
+//   base            memory that will hold the packed records
+//   max_vocab       largest word index to be stored; determines word_bits_
+//   remaining_bits  bits per record beyond the word index and probability
+// Throws util::Exception if the word index would need more than 57 bits.
+void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits) {
+  util::BitPackingSanity();
+  word_bits_ = util::RequiredBits(max_vocab);
+  // Reject unsupported widths before building the mask: shifting a 64-bit
+  // value by 64 or more bits is undefined, so the check must come first
+  // (same ordering as BitPackedMiddle::Init uses for next_bits_).
+  if (word_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, word indices more than " << (1ULL << 57) << " are not implemented. Edit util/bit_packing.hh and fix the bit packing functions.");
+  word_mask_ = (1ULL << word_bits_) - 1ULL;
+  prob_bits_ = 31;
+  total_bits_ = word_bits_ + prob_bits_ + remaining_bits;
+
+  base_ = static_cast<uint8_t*>(base);
+  insert_index_ = 0;
+}
+
+// Returns the byte size of a middle-order array holding `entries` records.
+// Beyond the base fields, each record carries 32 bits (the backoff width set
+// in Init) plus enough bits to represent any next pointer up to max_ptr.
+std::size_t BitPackedMiddle::Size(uint64_t entries, uint64_t max_vocab, uint64_t max_ptr) {
+  const uint8_t trailing_bits = 32 + util::RequiredBits(max_ptr);
+  return BaseSize(entries, max_vocab, trailing_bits);
+}
+
+// Configures the middle-order layout: a 32-bit backoff followed by a next
+// pointer just wide enough for max_next, then delegates to BaseInit.
+// Throws util::Exception if the next pointer would need more than 57 bits.
+void BitPackedMiddle::Init(void *base, uint64_t max_vocab, uint64_t max_next) {
+  backoff_bits_ = 32;
+  next_bits_ = util::RequiredBits(max_next);
+  if (next_bits_ > 57) {
+    UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions.");
+  }
+  // The width is known to be at most 57 here, so the shift is well defined.
+  next_mask_ = (1ULL << next_bits_) - 1;
+
+  // Everything after the word index and probability counts as "remaining".
+  BaseInit(base, max_vocab, backoff_bits_ + next_bits_);
+}
+
+// Appends one record at insert_index_ and advances insert_index_.
+// Field order within a record: word, prob, backoff, next.
+// word and next must fit their masks (checked by assert only).
+void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff, uint64_t next) {
+  assert(word <= word_mask_);
+  assert(next <= next_mask_);
+
+  // Running bit offset from base_; each write gets the byte address
+  // (offset / 8) and the bit position inside that byte (offset % 8).
+  uint64_t bit_cursor = insert_index_ * total_bits_;
+  util::WriteInt57(base_ + bit_cursor / 8, bit_cursor % 8, word);
+
+  bit_cursor += word_bits_;
+  util::WriteNonPositiveFloat31(base_ + bit_cursor / 8, bit_cursor % 8, prob);
+
+  bit_cursor += prob_bits_;
+  util::WriteFloat32(base_ + bit_cursor / 8, bit_cursor % 8, backoff);
+
+  bit_cursor += backoff_bits_;
+  util::WriteInt57(base_ + bit_cursor / 8, bit_cursor % 8, next);
+
+  ++insert_index_;
+}
+
+// Looks up word among the records delimited by range.begin and range.end.
+// On success fills prob and backoff from the matching record, replaces
+// range.begin with that record's next pointer and range.end with the next
+// pointer of the record immediately after it, and returns true.
+// On failure returns false without touching the outputs.
+bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const {
+  uint64_t record_index;
+  const bool located = FindBitPacked(base_, word_mask_, total_bits_, range.begin, range.end, word, record_index);
+  if (!located) return false;
+
+  // Start just past the word field of the matching record.
+  uint64_t bit_cursor = record_index * total_bits_ + word_bits_;
+  prob = util::ReadNonPositiveFloat31(base_ + bit_cursor / 8, bit_cursor % 8);
+
+  bit_cursor += prob_bits_;
+  backoff = util::ReadFloat32(base_ + bit_cursor / 8, bit_cursor % 8);
+
+  bit_cursor += backoff_bits_;
+  range.begin = util::ReadInt57(base_ + bit_cursor / 8, bit_cursor % 8, next_mask_);
+
+  // Advancing by one whole record lands on the following record's next field.
+  bit_cursor += total_bits_;
+  range.end = util::ReadInt57(base_ + bit_cursor / 8, bit_cursor % 8, next_mask_);
+  return true;
+}
+
+// Same lookup as Find, but the probability field is skipped rather than read.
+// On success fills backoff, replaces range.begin with the matching record's
+// next pointer and range.end with the following record's next pointer, and
+// returns true.  On failure returns false without touching the outputs.
+bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const {
+  uint64_t record_index;
+  const bool located = FindBitPacked(base_, word_mask_, total_bits_, range.begin, range.end, word, record_index);
+  if (!located) return false;
+
+  // Jump past both the word and the probability fields.
+  uint64_t bit_cursor = record_index * total_bits_ + word_bits_ + prob_bits_;
+  backoff = util::ReadFloat32(base_ + bit_cursor / 8, bit_cursor % 8);
+
+  bit_cursor += backoff_bits_;
+  range.begin = util::ReadInt57(base_ + bit_cursor / 8, bit_cursor % 8, next_mask_);
+
+  // Advancing by one whole record lands on the following record's next field.
+  bit_cursor += total_bits_;
+  range.end = util::ReadInt57(base_ + bit_cursor / 8, bit_cursor % 8, next_mask_);
+  return true;
+}
+
+// Writes next_end into the next field of the record at insert_index_, i.e.
+// the slot just past the last inserted record, so that the final real record
+// has a following next pointer for Find/FindNoProb to read as its range end.
+// next_end must fit next_mask_ (checked by assert only).
+void BitPackedMiddle::FinishedLoading(uint64_t next_end) {
+  assert(next_end <= next_mask_);
+  // The next field occupies the last next_bits_ bits of a record, so back up
+  // from the end of record insert_index_.
+  const uint64_t record_end_bit = (insert_index_ + 1) * total_bits_;
+  const uint64_t next_field_bit = record_end_bit - next_bits_;
+  util::WriteInt57(base_ + next_field_bit / 8, next_field_bit % 8, next_end);
+}
+
+
+// Appends one record (word index followed by probability) at insert_index_
+// and advances insert_index_.  index must fit word_mask_ (assert only).
+void BitPackedLongest::Insert(WordIndex index, float prob) {
+  assert(index <= word_mask_);
+
+  const uint64_t word_bit = insert_index_ * total_bits_;
+  util::WriteInt57(base_ + word_bit / 8, word_bit % 8, index);
+
+  // The probability sits directly after the word field.
+  const uint64_t prob_bit = word_bit + word_bits_;
+  util::WriteNonPositiveFloat31(base_ + prob_bit / 8, prob_bit % 8, prob);
+
+  ++insert_index_;
+}
+
+// Looks up word among the records delimited by node.begin and node.end.
+// On success stores the record's probability in prob and returns true;
+// otherwise returns false and leaves prob untouched.
+bool BitPackedLongest::Find(WordIndex word, float &prob, const NodeRange &node) const {
+  uint64_t record_index;
+  const bool located = FindBitPacked(base_, word_mask_, total_bits_, node.begin, node.end, word, record_index);
+  if (!located) return false;
+
+  // The probability follows the word field inside the matching record.
+  const uint64_t prob_bit = record_index * total_bits_ + word_bits_;
+  prob = util::ReadNonPositiveFloat31(base_ + prob_bit / 8, prob_bit % 8);
+  return true;
+}
+
+} // namespace trie
+} // namespace ngram
+} // namespace lm