#ifndef LM_TRIE__ #define LM_TRIE__ #include "lm/weights.hh" #include "lm/word_index.hh" #include "util/bit_packing.hh" #include <cstddef> #include <stdint.h> namespace lm { namespace ngram { class Config; namespace trie { struct NodeRange { uint64_t begin, end; }; // TODO: if the number of unigrams is a concern, also bit pack these records. struct UnigramValue { ProbBackoff weights; uint64_t next; uint64_t Next() const { return next; } }; class UnigramPointer { public: explicit UnigramPointer(const ProbBackoff &to) : to_(&to) {} UnigramPointer() : to_(NULL) {} bool Found() const { return to_ != NULL; } float Prob() const { return to_->prob; } float Backoff() const { return to_->backoff; } float Rest() const { return Prob(); } private: const ProbBackoff *to_; }; class Unigram { public: Unigram() {} void Init(void *start) { unigram_ = static_cast<UnigramValue*>(start); } static std::size_t Size(uint64_t count) { // +1 in case unknown doesn't appear. +1 for the final next. return (count + 2) * sizeof(UnigramValue); } const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index].weights; } ProbBackoff &Unknown() { return unigram_[0].weights; } UnigramValue *Raw() { return unigram_; } void LoadedBinary() {} UnigramPointer Find(WordIndex word, NodeRange &next) const { UnigramValue *val = unigram_ + word; next.begin = val->next; next.end = (val+1)->next; return UnigramPointer(val->weights); } private: UnigramValue *unigram_; }; class BitPacked { public: BitPacked() {} uint64_t InsertIndex() const { return insert_index_; } protected: static std::size_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits); void BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits); uint8_t word_bits_; uint8_t total_bits_; uint64_t word_mask_; uint8_t *base_; uint64_t insert_index_, max_vocab_; }; template <class Bhiksha> class BitPackedMiddle : public BitPacked { public: static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config); // next_source need not be initialized. BitPackedMiddle(void *base, uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config); util::BitAddress Insert(WordIndex word); void FinishedLoading(uint64_t next_end, const Config &config); void LoadedBinary() { bhiksha_.LoadedBinary(); } util::BitAddress Find(WordIndex word, NodeRange &range, uint64_t &pointer) const; util::BitAddress ReadEntry(uint64_t pointer, NodeRange &range) { uint64_t addr = pointer * total_bits_; addr += word_bits_; bhiksha_.ReadNext(base_, addr + quant_bits_, pointer, total_bits_, range); return util::BitAddress(base_, addr); } private: uint8_t quant_bits_; Bhiksha bhiksha_; const BitPacked *next_source_; }; class BitPackedLongest : public BitPacked { public: static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) { return BaseSize(entries, max_vocab, quant_bits); } BitPackedLongest() {} void Init(void *base, uint8_t quant_bits, uint64_t max_vocab) { BaseInit(base, max_vocab, quant_bits); } void LoadedBinary() {} util::BitAddress Insert(WordIndex word); util::BitAddress Find(WordIndex word, const NodeRange &node) const; private: uint8_t quant_bits_; }; } // namespace trie } // namespace ngram } // namespace lm #endif // LM_TRIE__