diff options
author | Patrick Simianer <simianer@cl.uni-heidelberg.de> | 2012-05-31 13:57:24 +0200 |
---|---|---|
committer | Patrick Simianer <simianer@cl.uni-heidelberg.de> | 2012-05-31 13:57:24 +0200 |
commit | 6f6601111710aa67eee5169e5b7d89102cc33bb8 (patch) | |
tree | 0872544abd6bc76162f3f80eb3920999afbf2c34 /klm/lm/quantize.hh | |
parent | 8cee8b565a9c56a7732365e9563f52ff3c4ff7fd (diff) | |
parent | 090a64e73f94a6a35e5364a9d416dcf75c0a2938 (diff) |
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'klm/lm/quantize.hh')
-rw-r--r-- | klm/lm/quantize.hh | 164 |
1 files changed, 91 insertions, 73 deletions
diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh index 6d130a57..3e9153e3 100644 --- a/klm/lm/quantize.hh +++ b/klm/lm/quantize.hh @@ -3,6 +3,7 @@ #include "lm/blank.hh" #include "lm/config.hh" +#include "lm/max_order.hh" #include "lm/model_type.hh" #include "util/bit_packing.hh" @@ -27,37 +28,60 @@ class DontQuantize { static uint8_t MiddleBits(const Config &/*config*/) { return 63; } static uint8_t LongestBits(const Config &/*config*/) { return 31; } - struct Middle { - void Write(void *base, uint64_t bit_offset, float prob, float backoff) const { - util::WriteNonPositiveFloat31(base, bit_offset, prob); - util::WriteFloat32(base, bit_offset + 31, backoff); - } - void Read(const void *base, uint64_t bit_offset, float &prob, float &backoff) const { - prob = util::ReadNonPositiveFloat31(base, bit_offset); - backoff = util::ReadFloat32(base, bit_offset + 31); - } - void ReadProb(const void *base, uint64_t bit_offset, float &prob) const { - prob = util::ReadNonPositiveFloat31(base, bit_offset); - } - void ReadBackoff(const void *base, uint64_t bit_offset, float &backoff) const { - backoff = util::ReadFloat32(base, bit_offset + 31); - } - uint8_t TotalBits() const { return 63; } + class MiddlePointer { + public: + MiddlePointer(const DontQuantize & /*quant*/, unsigned char /*order_minus_2*/, util::BitAddress address) : address_(address) {} + + MiddlePointer() : address_(NULL, 0) {} + + bool Found() const { + return address_.base != NULL; + } + + float Prob() const { + return util::ReadNonPositiveFloat31(address_.base, address_.offset); + } + + float Backoff() const { + return util::ReadFloat32(address_.base, address_.offset + 31); + } + + float Rest() const { return Prob(); } + + void Write(float prob, float backoff) { + util::WriteNonPositiveFloat31(address_.base, address_.offset, prob); + util::WriteFloat32(address_.base, address_.offset + 31, backoff); + } + + private: + util::BitAddress address_; }; - struct Longest { - void Write(void *base, uint64_t bit_offset, float prob) const { - util::WriteNonPositiveFloat31(base, bit_offset, prob); - } - void Read(const void *base, uint64_t bit_offset, float &prob) const { - prob = util::ReadNonPositiveFloat31(base, bit_offset); - } - uint8_t TotalBits() const { return 31; } + class LongestPointer { + public: + explicit LongestPointer(const DontQuantize &/*quant*/, util::BitAddress address) : address_(address) {} + + LongestPointer() : address_(NULL, 0) {} + + bool Found() const { + return address_.base != NULL; + } + + float Prob() const { + return util::ReadNonPositiveFloat31(address_.base, address_.offset); + } + + void Write(float prob) { + util::WriteNonPositiveFloat31(address_.base, address_.offset, prob); + } + + private: + util::BitAddress address_; }; DontQuantize() {} - void SetupMemory(void * /*start*/, const Config & /*config*/) {} + void SetupMemory(void * /*start*/, unsigned char /*order*/, const Config & /*config*/) {} static const bool kTrain = false; // These should never be called because kTrain is false. @@ -65,9 +89,6 @@ class DontQuantize { void TrainProb(uint8_t, std::vector<float> &/*prob*/) {} void FinishedLoading(const Config &) {} - - Middle Mid(uint8_t /*order*/) const { return Middle(); } - Longest Long(uint8_t /*order*/) const { return Longest(); } }; class SeparatelyQuantize { @@ -77,7 +98,9 @@ class SeparatelyQuantize { // Sigh C++ default constructor Bins() {} - Bins(uint8_t bits, const float *const begin) : begin_(begin), end_(begin_ + (1ULL << bits)), bits_(bits), mask_((1ULL << bits) - 1) {} + Bins(uint8_t bits, float *begin) : begin_(begin), end_(begin_ + (1ULL << bits)), bits_(bits), mask_((1ULL << bits) - 1) {} + + float *Populate() { return begin_; } uint64_t EncodeProb(float value) const { return Encode(value, 0); @@ -98,13 +121,13 @@ class SeparatelyQuantize { private: uint64_t Encode(float value, size_t reserved) const { - const float *above = std::lower_bound(begin_ + reserved, end_, value); + const float *above = std::lower_bound(static_cast<const float*>(begin_) + reserved, end_, value); if (above == begin_ + reserved) return reserved; if (above == end_) return end_ - begin_ - 1; return above - begin_ - (value - *(above - 1) < *above - value); } - const float *begin_; + float *begin_; const float *end_; uint8_t bits_; uint64_t mask_; @@ -125,65 +148,61 @@ class SeparatelyQuantize { static uint8_t MiddleBits(const Config &config) { return config.prob_bits + config.backoff_bits; } static uint8_t LongestBits(const Config &config) { return config.prob_bits; } - class Middle { + class MiddlePointer { public: - Middle(uint8_t prob_bits, const float *prob_begin, uint8_t backoff_bits, const float *backoff_begin) : - total_bits_(prob_bits + backoff_bits), total_mask_((1ULL << total_bits_) - 1), prob_(prob_bits, prob_begin), backoff_(backoff_bits, backoff_begin) {} + MiddlePointer(const SeparatelyQuantize &quant, unsigned char order_minus_2, const util::BitAddress &address) : bins_(quant.GetTables(order_minus_2)), address_(address) {} - void Write(void *base, uint64_t bit_offset, float prob, float backoff) const { - util::WriteInt57(base, bit_offset, total_bits_, - (prob_.EncodeProb(prob) << backoff_.Bits()) | backoff_.EncodeBackoff(backoff)); - } + MiddlePointer() : address_(NULL, 0) {} - void ReadProb(const void *base, uint64_t bit_offset, float &prob) const { - prob = prob_.Decode(util::ReadInt25(base, bit_offset + backoff_.Bits(), prob_.Bits(), prob_.Mask())); - } + bool Found() const { return address_.base != NULL; } - void Read(const void *base, uint64_t bit_offset, float &prob, float &backoff) const { - uint64_t both = util::ReadInt57(base, bit_offset, total_bits_, total_mask_); - prob = prob_.Decode(both >> backoff_.Bits()); - backoff = backoff_.Decode(both & backoff_.Mask()); + float Prob() const { + return ProbBins().Decode(util::ReadInt25(address_.base, address_.offset + BackoffBins().Bits(), ProbBins().Bits(), ProbBins().Mask())); } - void ReadBackoff(const void *base, uint64_t bit_offset, float &backoff) const { - backoff = backoff_.Decode(util::ReadInt25(base, bit_offset, backoff_.Bits(), backoff_.Mask())); + float Backoff() const { + return BackoffBins().Decode(util::ReadInt25(address_.base, address_.offset, BackoffBins().Bits(), BackoffBins().Mask())); } - uint8_t TotalBits() const { - return total_bits_; + float Rest() const { return Prob(); } + + void Write(float prob, float backoff) const { + util::WriteInt57(address_.base, address_.offset, ProbBins().Bits() + BackoffBins().Bits(), + (ProbBins().EncodeProb(prob) << BackoffBins().Bits()) | BackoffBins().EncodeBackoff(backoff)); } private: - const uint8_t total_bits_; - const uint64_t total_mask_; - const Bins prob_; - const Bins backoff_; + const Bins &ProbBins() const { return bins_[0]; } + const Bins &BackoffBins() const { return bins_[1]; } + const Bins *bins_; + + util::BitAddress address_; }; - class Longest { + class LongestPointer { public: - // Sigh C++ default constructor - Longest() {} + LongestPointer(const SeparatelyQuantize &quant, const util::BitAddress &address) : table_(&quant.LongestTable()), address_(address) {} + + LongestPointer() : address_(NULL, 0) {} - Longest(uint8_t prob_bits, const float *prob_begin) : prob_(prob_bits, prob_begin) {} + bool Found() const { return address_.base != NULL; } - void Write(void *base, uint64_t bit_offset, float prob) const { - util::WriteInt25(base, bit_offset, prob_.Bits(), prob_.EncodeProb(prob)); + void Write(float prob) const { + util::WriteInt25(address_.base, address_.offset, table_->Bits(), table_->EncodeProb(prob)); } - void Read(const void *base, uint64_t bit_offset, float &prob) const { - prob = prob_.Decode(util::ReadInt25(base, bit_offset, prob_.Bits(), prob_.Mask())); + float Prob() const { + return table_->Decode(util::ReadInt25(address_.base, address_.offset, table_->Bits(), table_->Mask())); } - uint8_t TotalBits() const { return prob_.Bits(); } - private: - Bins prob_; + const Bins *table_; + util::BitAddress address_; }; SeparatelyQuantize() {} - void SetupMemory(void *start, const Config &config); + void SetupMemory(void *start, unsigned char order, const Config &config); static const bool kTrain = true; // Assumes 0.0 is removed from backoff. @@ -193,18 +212,17 @@ class SeparatelyQuantize { void FinishedLoading(const Config &config); - Middle Mid(uint8_t order) const { - const float *table = start_ + TableStart(order); - return Middle(prob_bits_, table, backoff_bits_, table + ProbTableLength()); - } + const Bins *GetTables(unsigned char order_minus_2) const { return tables_[order_minus_2]; } - Longest Long(uint8_t order) const { return Longest(prob_bits_, start_ + TableStart(order)); } + const Bins &LongestTable() const { return longest_; } private: - size_t TableStart(uint8_t order) const { return ((1ULL << prob_bits_) + (1ULL << backoff_bits_)) * static_cast<uint64_t>(order - 2); } - size_t ProbTableLength() const { return (1ULL << prob_bits_); } + Bins tables_[kMaxOrder - 1][2]; + + Bins longest_; + + uint8_t *actual_base_; - float *start_; uint8_t prob_bits_, backoff_bits_; }; |