diff options
| author | Kenneth Heafield <github@kheafield.com> | 2012-05-16 13:24:08 -0700 | 
|---|---|---|
| committer | Chris Dyer <cdyer@cab.ark.cs.cmu.edu> | 2012-05-26 22:59:54 -0400 | 
| commit | 149232c38eec558ddb1097698d1570aacb67b59f (patch) | |
| tree | 5860b4d6f681eeb04a1020cbb2fe7e6ac394af99 /klm/lm/quantize.hh | |
| parent | 01ecc09f8e3a82c32bf7dd2f90c12554becea71d (diff) | |
Big kenlm change includes lower order models for probing only.  And other stuff.
Diffstat (limited to 'klm/lm/quantize.hh')
| -rw-r--r-- | klm/lm/quantize.hh | 164 | 
1 file changed, 91 insertions, 73 deletions
| diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh index 6d130a57..3e9153e3 100644 --- a/klm/lm/quantize.hh +++ b/klm/lm/quantize.hh @@ -3,6 +3,7 @@  #include "lm/blank.hh"  #include "lm/config.hh" +#include "lm/max_order.hh"  #include "lm/model_type.hh"  #include "util/bit_packing.hh" @@ -27,37 +28,60 @@ class DontQuantize {      static uint8_t MiddleBits(const Config &/*config*/) { return 63; }      static uint8_t LongestBits(const Config &/*config*/) { return 31; } -    struct Middle { -      void Write(void *base, uint64_t bit_offset, float prob, float backoff) const { -        util::WriteNonPositiveFloat31(base, bit_offset, prob); -        util::WriteFloat32(base, bit_offset + 31, backoff); -      } -      void Read(const void *base, uint64_t bit_offset, float &prob, float &backoff) const { -        prob = util::ReadNonPositiveFloat31(base, bit_offset); -        backoff = util::ReadFloat32(base, bit_offset + 31); -      } -      void ReadProb(const void *base, uint64_t bit_offset, float &prob) const { -        prob = util::ReadNonPositiveFloat31(base, bit_offset); -      } -      void ReadBackoff(const void *base, uint64_t bit_offset, float &backoff) const { -        backoff = util::ReadFloat32(base, bit_offset + 31); -      } -      uint8_t TotalBits() const { return 63; } +    class MiddlePointer { +      public: +        MiddlePointer(const DontQuantize & /*quant*/, unsigned char /*order_minus_2*/, util::BitAddress address) : address_(address) {} + +        MiddlePointer() : address_(NULL, 0) {} + +        bool Found() const { +          return address_.base != NULL; +        } + +        float Prob() const { +          return util::ReadNonPositiveFloat31(address_.base, address_.offset); +        } + +        float Backoff() const { +          return util::ReadFloat32(address_.base, address_.offset + 31); +        } + +        float Rest() const { return Prob(); } + +        void Write(float prob, float backoff) { +          
util::WriteNonPositiveFloat31(address_.base, address_.offset, prob); +          util::WriteFloat32(address_.base, address_.offset + 31, backoff); +        } + +      private: +        util::BitAddress address_;      }; -    struct Longest { -      void Write(void *base, uint64_t bit_offset, float prob) const { -        util::WriteNonPositiveFloat31(base, bit_offset, prob); -      } -      void Read(const void *base, uint64_t bit_offset, float &prob) const { -        prob = util::ReadNonPositiveFloat31(base, bit_offset); -      } -      uint8_t TotalBits() const { return 31; } +    class LongestPointer { +      public: +        explicit LongestPointer(const DontQuantize &/*quant*/, util::BitAddress address) : address_(address) {} + +        LongestPointer() : address_(NULL, 0) {} + +        bool Found() const { +          return address_.base != NULL; +        } + +        float Prob() const { +          return util::ReadNonPositiveFloat31(address_.base, address_.offset); +        } + +        void Write(float prob) { +          util::WriteNonPositiveFloat31(address_.base, address_.offset, prob); +        } + +      private: +        util::BitAddress address_;      };      DontQuantize() {} -    void SetupMemory(void * /*start*/, const Config & /*config*/) {} +    void SetupMemory(void * /*start*/, unsigned char /*order*/, const Config & /*config*/) {}      static const bool kTrain = false;      // These should never be called because kTrain is false.   
@@ -65,9 +89,6 @@ class DontQuantize {      void TrainProb(uint8_t, std::vector<float> &/*prob*/) {}      void FinishedLoading(const Config &) {} - -    Middle Mid(uint8_t /*order*/) const { return Middle(); } -    Longest Long(uint8_t /*order*/) const { return Longest(); }  };  class SeparatelyQuantize { @@ -77,7 +98,9 @@ class SeparatelyQuantize {          // Sigh C++ default constructor          Bins() {} -        Bins(uint8_t bits, const float *const begin) : begin_(begin), end_(begin_ + (1ULL << bits)), bits_(bits), mask_((1ULL << bits) - 1) {} +        Bins(uint8_t bits, float *begin) : begin_(begin), end_(begin_ + (1ULL << bits)), bits_(bits), mask_((1ULL << bits) - 1) {} + +        float *Populate() { return begin_; }          uint64_t EncodeProb(float value) const {            return Encode(value, 0); @@ -98,13 +121,13 @@ class SeparatelyQuantize {        private:          uint64_t Encode(float value, size_t reserved) const { -          const float *above = std::lower_bound(begin_ + reserved, end_, value); +          const float *above = std::lower_bound(static_cast<const float*>(begin_) + reserved, end_, value);            if (above == begin_ + reserved) return reserved;            if (above == end_) return end_ - begin_ - 1;            return above - begin_ - (value - *(above - 1) < *above - value);          } -        const float *begin_; +        float *begin_;          const float *end_;          uint8_t bits_;          uint64_t mask_; @@ -125,65 +148,61 @@ class SeparatelyQuantize {      static uint8_t MiddleBits(const Config &config) { return config.prob_bits + config.backoff_bits; }      static uint8_t LongestBits(const Config &config) { return config.prob_bits; } -    class Middle { +    class MiddlePointer {        public: -        Middle(uint8_t prob_bits, const float *prob_begin, uint8_t backoff_bits, const float *backoff_begin) :  -          total_bits_(prob_bits + backoff_bits), total_mask_((1ULL << total_bits_) - 1), prob_(prob_bits, 
prob_begin), backoff_(backoff_bits, backoff_begin) {} +        MiddlePointer(const SeparatelyQuantize &quant, unsigned char order_minus_2, const util::BitAddress &address) : bins_(quant.GetTables(order_minus_2)), address_(address) {} -        void Write(void *base, uint64_t bit_offset, float prob, float backoff) const { -          util::WriteInt57(base, bit_offset, total_bits_,  -              (prob_.EncodeProb(prob) << backoff_.Bits()) | backoff_.EncodeBackoff(backoff)); -        } +        MiddlePointer() : address_(NULL, 0) {} -        void ReadProb(const void *base, uint64_t bit_offset, float &prob) const { -          prob = prob_.Decode(util::ReadInt25(base, bit_offset + backoff_.Bits(), prob_.Bits(), prob_.Mask())); -        } +        bool Found() const { return address_.base != NULL; } -        void Read(const void *base, uint64_t bit_offset, float &prob, float &backoff) const { -          uint64_t both = util::ReadInt57(base, bit_offset, total_bits_, total_mask_); -          prob = prob_.Decode(both >> backoff_.Bits()); -          backoff = backoff_.Decode(both & backoff_.Mask()); +        float Prob() const { +          return ProbBins().Decode(util::ReadInt25(address_.base, address_.offset + BackoffBins().Bits(), ProbBins().Bits(), ProbBins().Mask()));          } -        void ReadBackoff(const void *base, uint64_t bit_offset, float &backoff) const { -          backoff = backoff_.Decode(util::ReadInt25(base, bit_offset, backoff_.Bits(), backoff_.Mask())); +        float Backoff() const { +          return BackoffBins().Decode(util::ReadInt25(address_.base, address_.offset, BackoffBins().Bits(), BackoffBins().Mask()));          } -        uint8_t TotalBits() const { -          return total_bits_; +        float Rest() const { return Prob(); } + +        void Write(float prob, float backoff) const { +          util::WriteInt57(address_.base, address_.offset, ProbBins().Bits() + BackoffBins().Bits(),  +              (ProbBins().EncodeProb(prob) << 
BackoffBins().Bits()) | BackoffBins().EncodeBackoff(backoff));          }        private: -        const uint8_t total_bits_; -        const uint64_t total_mask_; -        const Bins prob_; -        const Bins backoff_; +        const Bins &ProbBins() const { return bins_[0]; } +        const Bins &BackoffBins() const { return bins_[1]; } +        const Bins *bins_; + +        util::BitAddress address_;      }; -    class Longest { +    class LongestPointer {        public: -        // Sigh C++ default constructor -        Longest() {} +        LongestPointer(const SeparatelyQuantize &quant, const util::BitAddress &address) : table_(&quant.LongestTable()), address_(address) {} +         +        LongestPointer() : address_(NULL, 0) {} -        Longest(uint8_t prob_bits, const float *prob_begin) : prob_(prob_bits, prob_begin) {} +        bool Found() const { return address_.base != NULL; } -        void Write(void *base, uint64_t bit_offset, float prob) const { -          util::WriteInt25(base, bit_offset, prob_.Bits(), prob_.EncodeProb(prob)); +        void Write(float prob) const { +          util::WriteInt25(address_.base, address_.offset, table_->Bits(), table_->EncodeProb(prob));          } -        void Read(const void *base, uint64_t bit_offset, float &prob) const { -          prob = prob_.Decode(util::ReadInt25(base, bit_offset, prob_.Bits(), prob_.Mask())); +        float Prob() const { +          return table_->Decode(util::ReadInt25(address_.base, address_.offset, table_->Bits(), table_->Mask()));          } -        uint8_t TotalBits() const { return prob_.Bits(); } -        private: -        Bins prob_; +        const Bins *table_; +        util::BitAddress address_;      };      SeparatelyQuantize() {} -    void SetupMemory(void *start, const Config &config); +    void SetupMemory(void *start, unsigned char order, const Config &config);      static const bool kTrain = true;      // Assumes 0.0 is removed from backoff.   
@@ -193,18 +212,17 @@ class SeparatelyQuantize {      void FinishedLoading(const Config &config); -    Middle Mid(uint8_t order) const { -      const float *table = start_ + TableStart(order); -      return Middle(prob_bits_, table, backoff_bits_, table + ProbTableLength()); -    } +    const Bins *GetTables(unsigned char order_minus_2) const { return tables_[order_minus_2]; } -    Longest Long(uint8_t order) const { return Longest(prob_bits_, start_ + TableStart(order)); } +    const Bins &LongestTable() const { return longest_; }    private: -    size_t TableStart(uint8_t order) const { return ((1ULL << prob_bits_) + (1ULL << backoff_bits_)) * static_cast<uint64_t>(order - 2); } -    size_t ProbTableLength() const { return (1ULL << prob_bits_); } +    Bins tables_[kMaxOrder - 1][2]; + +    Bins longest_; + +    uint8_t *actual_base_; -    float *start_;      uint8_t prob_bits_, backoff_bits_;  }; | 
