diff options
Diffstat (limited to 'klm/lm/trie.hh')
| -rw-r--r-- | klm/lm/trie.hh | 24 | 
1 files changed, 12 insertions, 12 deletions
| diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index 8fa21aaf..53612064 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -10,6 +10,7 @@  namespace lm {  namespace ngram { +class Config;  namespace trie {  struct NodeRange { @@ -46,13 +47,12 @@ class Unigram {      void LoadedBinary() {} -    bool Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const { +    void Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const {        UnigramValue *val = unigram_ + word;        prob = val->weights.prob;        backoff = val->weights.backoff;        next.begin = val->next;        next.end = (val+1)->next; -      return true;      }    private: @@ -67,8 +67,6 @@ class BitPacked {        return insert_index_;      } -    void LoadedBinary() {} -    protected:      static std::size_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits); @@ -83,30 +81,30 @@ class BitPacked {      uint64_t insert_index_, max_vocab_;  }; -template <class Quant> class BitPackedMiddle : public BitPacked { +template <class Quant, class Bhiksha> class BitPackedMiddle : public BitPacked {    public: -    static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next); +    static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config);      // next_source need not be initialized.   -    BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source); +    BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config);      void Insert(WordIndex word, float prob, float backoff); +    void FinishedLoading(uint64_t next_end, const Config &config); + +    void LoadedBinary() { bhiksha_.LoadedBinary(); } +      bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const;      bool FindNoProb(WordIndex word, float &backoff, NodeRange &range) const; -    void FinishedLoading(uint64_t next_end); -    private:      Quant quant_; -    uint8_t next_bits_; -    uint64_t next_mask_; +    Bhiksha bhiksha_;      const BitPacked *next_source_;  }; -  template <class Quant> class BitPackedLongest : public BitPacked {    public:      static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) { @@ -120,6 +118,8 @@ template <class Quant> class BitPackedLongest : public BitPacked {        BaseInit(base, max_vocab, quant_.TotalBits());      } +    void LoadedBinary() {} +      void Insert(WordIndex word, float prob);      bool Find(WordIndex word, float &prob, const NodeRange &node) const; | 
