diff options
Diffstat (limited to 'klm/lm/trie.hh')
-rw-r--r-- | klm/lm/trie.hh | 61 |
1 files changed, 36 insertions, 25 deletions
diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index ebe9910f..eff93292 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -1,12 +1,13 @@ #ifndef LM_TRIE__ #define LM_TRIE__ -#include <stdint.h> +#include "lm/weights.hh" +#include "lm/word_index.hh" +#include "util/bit_packing.hh" #include <cstddef> -#include "lm/word_index.hh" -#include "lm/weights.hh" +#include <stdint.h> namespace lm { namespace ngram { @@ -24,6 +25,22 @@ struct UnigramValue { uint64_t Next() const { return next; } }; +class UnigramPointer { + public: + explicit UnigramPointer(const ProbBackoff &to) : to_(&to) {} + + UnigramPointer() : to_(NULL) {} + + bool Found() const { return to_ != NULL; } + + float Prob() const { return to_->prob; } + float Backoff() const { return to_->backoff; } + float Rest() const { return Prob(); } + + private: + const ProbBackoff *to_; +}; + class Unigram { public: Unigram() {} @@ -47,12 +64,11 @@ class Unigram { void LoadedBinary() {} - void Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const { + UnigramPointer Find(WordIndex word, NodeRange &next) const { UnigramValue *val = unigram_ + word; - prob = val->weights.prob; - backoff = val->weights.backoff; next.begin = val->next; next.end = (val+1)->next; + return UnigramPointer(val->weights); } private: @@ -81,40 +97,36 @@ class BitPacked { uint64_t insert_index_, max_vocab_; }; -template <class Quant, class Bhiksha> class BitPackedMiddle : public BitPacked { +template <class Bhiksha> class BitPackedMiddle : public BitPacked { public: static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config); // next_source need not be initialized. - BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config); + BitPackedMiddle(void *base, uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config); - void Insert(WordIndex word, float prob, float backoff); + util::BitAddress Insert(WordIndex word); void FinishedLoading(uint64_t next_end, const Config &config); void LoadedBinary() { bhiksha_.LoadedBinary(); } - bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range, uint64_t &pointer) const; - - bool FindNoProb(WordIndex word, float &backoff, NodeRange &range) const; + util::BitAddress Find(WordIndex word, NodeRange &range, uint64_t &pointer) const; - NodeRange ReadEntry(uint64_t pointer, float &prob) { + util::BitAddress ReadEntry(uint64_t pointer, NodeRange &range) { uint64_t addr = pointer * total_bits_; addr += word_bits_; - quant_.ReadProb(base_, addr, prob); - NodeRange ret; - bhiksha_.ReadNext(base_, addr + quant_.TotalBits(), pointer, total_bits_, ret); - return ret; + bhiksha_.ReadNext(base_, addr + quant_bits_, pointer, total_bits_, range); + return util::BitAddress(base_, addr); } private: - Quant quant_; + uint8_t quant_bits_; Bhiksha bhiksha_; const BitPacked *next_source_; }; -template <class Quant> class BitPackedLongest : public BitPacked { +class BitPackedLongest : public BitPacked { public: static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) { return BaseSize(entries, max_vocab, quant_bits); @@ -122,19 +134,18 @@ template <class Quant> class BitPackedLongest : public BitPacked { BitPackedLongest() {} - void Init(void *base, const Quant &quant, uint64_t max_vocab) { - quant_ = quant; - BaseInit(base, max_vocab, quant_.TotalBits()); + void Init(void *base, uint8_t quant_bits, uint64_t max_vocab) { + BaseInit(base, max_vocab, quant_bits); } void LoadedBinary() {} - void Insert(WordIndex word, float prob); + util::BitAddress Insert(WordIndex word); - bool Find(WordIndex word, float &prob, const NodeRange &node) const; + util::BitAddress Find(WordIndex word, const NodeRange &node) const; private: - Quant quant_; + uint8_t quant_bits_; }; } // namespace trie |