summary refs log tree commit diff
path: root/klm/lm/quantize.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/quantize.hh')
-rw-r--r-- klm/lm/quantize.hh 13
1 files changed, 10 insertions, 3 deletions
diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh
index 0b71d14a..4cf4236e 100644
--- a/klm/lm/quantize.hh
+++ b/klm/lm/quantize.hh
@@ -1,9 +1,9 @@
#ifndef LM_QUANTIZE_H__
#define LM_QUANTIZE_H__
-#include "lm/binary_format.hh" // for ModelType
#include "lm/blank.hh"
#include "lm/config.hh"
+#include "lm/model_type.hh"
#include "util/bit_packing.hh"
#include <algorithm>
@@ -36,6 +36,9 @@ class DontQuantize {
prob = util::ReadNonPositiveFloat31(base, bit_offset);
backoff = util::ReadFloat32(base, bit_offset + 31);
}
+ void ReadProb(const void *base, uint64_t bit_offset, float &prob) const { // Reads only prob, stored at bit_offset; backoff (at bit_offset + 31) is not touched.
+ prob = util::ReadNonPositiveFloat31(base, bit_offset);
+ }
void ReadBackoff(const void *base, uint64_t bit_offset, float &backoff) const {
backoff = util::ReadFloat32(base, bit_offset + 31);
}
@@ -77,7 +80,7 @@ class SeparatelyQuantize {
Bins(uint8_t bits, const float *const begin) : begin_(begin), end_(begin_ + (1ULL << bits)), bits_(bits), mask_((1ULL << bits) - 1) {}
uint64_t EncodeProb(float value) const {
- return(value == kBlankProb ? kBlankProbQuant : Encode(value, 1));
+ return Encode(value, 0);
}
uint64_t EncodeBackoff(float value) const {
@@ -132,6 +135,10 @@ class SeparatelyQuantize {
(prob_.EncodeProb(prob) << backoff_.Bits()) | backoff_.EncodeBackoff(backoff));
}
+ void ReadProb(const void *base, uint64_t bit_offset, float &prob) const { // Reads only prob, skipping the backoff_.Bits() backoff bits that start at bit_offset.
+ prob = prob_.Decode(util::ReadInt25(base, bit_offset + backoff_.Bits(), prob_.Bits(), prob_.Mask())); // Extract the prob_.Bits()-wide quantized index, then map it back to a float via prob_.
+ }
+
void Read(const void *base, uint64_t bit_offset, float &prob, float &backoff) const {
uint64_t both = util::ReadInt57(base, bit_offset, total_bits_, total_mask_);
prob = prob_.Decode(both >> backoff_.Bits());
@@ -179,7 +186,7 @@ class SeparatelyQuantize {
void SetupMemory(void *start, const Config &config);
static const bool kTrain = true;
- // Assumes kBlankProb is removed from prob and 0.0 is removed from backoff.
+ // Assumes 0.0 is removed from backoff.
void Train(uint8_t order, std::vector<float> &prob, std::vector<float> &backoff);
// Train just probabilities (for longest order).
void TrainProb(uint8_t order, std::vector<float> &prob);