diff options
Diffstat (limited to 'klm/lm/quantize.cc')
-rw-r--r-- | klm/lm/quantize.cc | 20 |
1 files changed, 14 insertions, 6 deletions
diff --git a/klm/lm/quantize.cc b/klm/lm/quantize.cc index a8e0cb21..b58c3f3f 100644 --- a/klm/lm/quantize.cc +++ b/klm/lm/quantize.cc @@ -47,9 +47,7 @@ void SeparatelyQuantize::UpdateConfigFromBinary(int fd, const std::vector<uint64 util::AdvanceOrThrow(fd, -3); } -void SeparatelyQuantize::SetupMemory(void *start, const Config &config) { - // Reserve 8 byte header for bit counts. - start_ = reinterpret_cast<float*>(static_cast<uint8_t*>(start) + 8); +void SeparatelyQuantize::SetupMemory(void *base, unsigned char order, const Config &config) { prob_bits_ = config.prob_bits; backoff_bits_ = config.backoff_bits; // We need the reserved values. @@ -57,25 +55,35 @@ void SeparatelyQuantize::SetupMemory(void *start, const Config &config) { if (config.backoff_bits == 0) UTIL_THROW(ConfigException, "You can't quantize backoff to zero"); if (config.prob_bits > 25) UTIL_THROW(ConfigException, "For efficiency reasons, quantizing probability supports at most 25 bits. Currently you have requested " << static_cast<unsigned>(config.prob_bits) << " bits."); if (config.backoff_bits > 25) UTIL_THROW(ConfigException, "For efficiency reasons, quantizing backoff supports at most 25 bits. Currently you have requested " << static_cast<unsigned>(config.backoff_bits) << " bits."); + // Reserve 8 byte header for bit counts. + actual_base_ = static_cast<uint8_t*>(base); + float *start = reinterpret_cast<float*>(actual_base_ + 8); + for (unsigned char i = 0; i < order - 2; ++i) { + tables_[i][0] = Bins(prob_bits_, start); + start += (1ULL << prob_bits_); + tables_[i][1] = Bins(backoff_bits_, start); + start += (1ULL << backoff_bits_); + } + longest_ = tables_[order - 2][0] = Bins(prob_bits_, start); } void SeparatelyQuantize::Train(uint8_t order, std::vector<float> &prob, std::vector<float> &backoff) { TrainProb(order, prob); // Backoff - float *centers = start_ + TableStart(order) + ProbTableLength(); + float *centers = tables_[order - 2][1].Populate(); *(centers++) = kNoExtensionBackoff; *(centers++) = kExtensionBackoff; MakeBins(backoff, centers, (1ULL << backoff_bits_) - 2); } void SeparatelyQuantize::TrainProb(uint8_t order, std::vector<float> &prob) { - float *centers = start_ + TableStart(order); + float *centers = tables_[order - 2][0].Populate(); MakeBins(prob, centers, (1ULL << prob_bits_)); } void SeparatelyQuantize::FinishedLoading(const Config &config) { - uint8_t *actual_base = reinterpret_cast<uint8_t*>(start_) - 8; + uint8_t *actual_base = actual_base_; *(actual_base++) = kSeparatelyQuantizeVersion; // version *(actual_base++) = config.prob_bits; *(actual_base++) = config.backoff_bits; |