diff options
Diffstat (limited to 'klm/lm/quantize.hh')
-rw-r--r-- | klm/lm/quantize.hh | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh index 0b71d14a..4cf4236e 100644 --- a/klm/lm/quantize.hh +++ b/klm/lm/quantize.hh @@ -1,9 +1,9 @@ #ifndef LM_QUANTIZE_H__ #define LM_QUANTIZE_H__ -#include "lm/binary_format.hh" // for ModelType #include "lm/blank.hh" #include "lm/config.hh" +#include "lm/model_type.hh" #include "util/bit_packing.hh" #include <algorithm> @@ -36,6 +36,9 @@ class DontQuantize { prob = util::ReadNonPositiveFloat31(base, bit_offset); backoff = util::ReadFloat32(base, bit_offset + 31); } + void ReadProb(const void *base, uint64_t bit_offset, float &prob) const { + prob = util::ReadNonPositiveFloat31(base, bit_offset); + } void ReadBackoff(const void *base, uint64_t bit_offset, float &backoff) const { backoff = util::ReadFloat32(base, bit_offset + 31); } @@ -77,7 +80,7 @@ class SeparatelyQuantize { Bins(uint8_t bits, const float *const begin) : begin_(begin), end_(begin_ + (1ULL << bits)), bits_(bits), mask_((1ULL << bits) - 1) {} uint64_t EncodeProb(float value) const { - return(value == kBlankProb ? kBlankProbQuant : Encode(value, 1)); + return Encode(value, 0); } uint64_t EncodeBackoff(float value) const { @@ -132,6 +135,10 @@ class SeparatelyQuantize { (prob_.EncodeProb(prob) << backoff_.Bits()) | backoff_.EncodeBackoff(backoff)); } + void ReadProb(const void *base, uint64_t bit_offset, float &prob) const { + prob = prob_.Decode(util::ReadInt25(base, bit_offset + backoff_.Bits(), prob_.Bits(), prob_.Mask())); + } + void Read(const void *base, uint64_t bit_offset, float &prob, float &backoff) const { uint64_t both = util::ReadInt57(base, bit_offset, total_bits_, total_mask_); prob = prob_.Decode(both >> backoff_.Bits()); @@ -179,7 +186,7 @@ class SeparatelyQuantize { void SetupMemory(void *start, const Config &config); static const bool kTrain = true; - // Assumes kBlankProb is removed from prob and 0.0 is removed from backoff. + // Assumes 0.0 is removed from backoff. void Train(uint8_t order, std::vector<float> &prob, std::vector<float> &backoff); // Train just probabilities (for longest order). void TrainProb(uint8_t order, std::vector<float> &prob); |