summaryrefslogtreecommitdiff
path: root/klm/lm/quantize.hh
diff options
context:
space:
mode:
authorPatrick Simianer <simianer@cl.uni-heidelberg.de>2012-05-31 13:57:24 +0200
committerPatrick Simianer <simianer@cl.uni-heidelberg.de>2012-05-31 13:57:24 +0200
commit6f6601111710aa67eee5169e5b7d89102cc33bb8 (patch)
tree0872544abd6bc76162f3f80eb3920999afbf2c34 /klm/lm/quantize.hh
parent8cee8b565a9c56a7732365e9563f52ff3c4ff7fd (diff)
parent090a64e73f94a6a35e5364a9d416dcf75c0a2938 (diff)
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'klm/lm/quantize.hh')
-rw-r--r--klm/lm/quantize.hh164
1 files changed, 91 insertions, 73 deletions
diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh
index 6d130a57..3e9153e3 100644
--- a/klm/lm/quantize.hh
+++ b/klm/lm/quantize.hh
@@ -3,6 +3,7 @@
#include "lm/blank.hh"
#include "lm/config.hh"
+#include "lm/max_order.hh"
#include "lm/model_type.hh"
#include "util/bit_packing.hh"
@@ -27,37 +28,60 @@ class DontQuantize {
static uint8_t MiddleBits(const Config &/*config*/) { return 63; }
static uint8_t LongestBits(const Config &/*config*/) { return 31; }
- struct Middle {
- void Write(void *base, uint64_t bit_offset, float prob, float backoff) const {
- util::WriteNonPositiveFloat31(base, bit_offset, prob);
- util::WriteFloat32(base, bit_offset + 31, backoff);
- }
- void Read(const void *base, uint64_t bit_offset, float &prob, float &backoff) const {
- prob = util::ReadNonPositiveFloat31(base, bit_offset);
- backoff = util::ReadFloat32(base, bit_offset + 31);
- }
- void ReadProb(const void *base, uint64_t bit_offset, float &prob) const {
- prob = util::ReadNonPositiveFloat31(base, bit_offset);
- }
- void ReadBackoff(const void *base, uint64_t bit_offset, float &backoff) const {
- backoff = util::ReadFloat32(base, bit_offset + 31);
- }
- uint8_t TotalBits() const { return 63; }
+ class MiddlePointer {
+ public:
+ MiddlePointer(const DontQuantize & /*quant*/, unsigned char /*order_minus_2*/, util::BitAddress address) : address_(address) {}
+
+ MiddlePointer() : address_(NULL, 0) {}
+
+ bool Found() const {
+ return address_.base != NULL;
+ }
+
+ float Prob() const {
+ return util::ReadNonPositiveFloat31(address_.base, address_.offset);
+ }
+
+ float Backoff() const {
+ return util::ReadFloat32(address_.base, address_.offset + 31);
+ }
+
+ float Rest() const { return Prob(); }
+
+ void Write(float prob, float backoff) {
+ util::WriteNonPositiveFloat31(address_.base, address_.offset, prob);
+ util::WriteFloat32(address_.base, address_.offset + 31, backoff);
+ }
+
+ private:
+ util::BitAddress address_;
};
- struct Longest {
- void Write(void *base, uint64_t bit_offset, float prob) const {
- util::WriteNonPositiveFloat31(base, bit_offset, prob);
- }
- void Read(const void *base, uint64_t bit_offset, float &prob) const {
- prob = util::ReadNonPositiveFloat31(base, bit_offset);
- }
- uint8_t TotalBits() const { return 31; }
+ class LongestPointer {
+ public:
+ explicit LongestPointer(const DontQuantize &/*quant*/, util::BitAddress address) : address_(address) {}
+
+ LongestPointer() : address_(NULL, 0) {}
+
+ bool Found() const {
+ return address_.base != NULL;
+ }
+
+ float Prob() const {
+ return util::ReadNonPositiveFloat31(address_.base, address_.offset);
+ }
+
+ void Write(float prob) {
+ util::WriteNonPositiveFloat31(address_.base, address_.offset, prob);
+ }
+
+ private:
+ util::BitAddress address_;
};
DontQuantize() {}
- void SetupMemory(void * /*start*/, const Config & /*config*/) {}
+ void SetupMemory(void * /*start*/, unsigned char /*order*/, const Config & /*config*/) {}
static const bool kTrain = false;
// These should never be called because kTrain is false.
@@ -65,9 +89,6 @@ class DontQuantize {
void TrainProb(uint8_t, std::vector<float> &/*prob*/) {}
void FinishedLoading(const Config &) {}
-
- Middle Mid(uint8_t /*order*/) const { return Middle(); }
- Longest Long(uint8_t /*order*/) const { return Longest(); }
};
class SeparatelyQuantize {
@@ -77,7 +98,9 @@ class SeparatelyQuantize {
// Sigh C++ default constructor
Bins() {}
- Bins(uint8_t bits, const float *const begin) : begin_(begin), end_(begin_ + (1ULL << bits)), bits_(bits), mask_((1ULL << bits) - 1) {}
+ Bins(uint8_t bits, float *begin) : begin_(begin), end_(begin_ + (1ULL << bits)), bits_(bits), mask_((1ULL << bits) - 1) {}
+
+ float *Populate() { return begin_; }
uint64_t EncodeProb(float value) const {
return Encode(value, 0);
@@ -98,13 +121,13 @@ class SeparatelyQuantize {
private:
uint64_t Encode(float value, size_t reserved) const {
- const float *above = std::lower_bound(begin_ + reserved, end_, value);
+ const float *above = std::lower_bound(static_cast<const float*>(begin_) + reserved, end_, value);
if (above == begin_ + reserved) return reserved;
if (above == end_) return end_ - begin_ - 1;
return above - begin_ - (value - *(above - 1) < *above - value);
}
- const float *begin_;
+ float *begin_;
const float *end_;
uint8_t bits_;
uint64_t mask_;
@@ -125,65 +148,61 @@ class SeparatelyQuantize {
static uint8_t MiddleBits(const Config &config) { return config.prob_bits + config.backoff_bits; }
static uint8_t LongestBits(const Config &config) { return config.prob_bits; }
- class Middle {
+ class MiddlePointer {
public:
- Middle(uint8_t prob_bits, const float *prob_begin, uint8_t backoff_bits, const float *backoff_begin) :
- total_bits_(prob_bits + backoff_bits), total_mask_((1ULL << total_bits_) - 1), prob_(prob_bits, prob_begin), backoff_(backoff_bits, backoff_begin) {}
+ MiddlePointer(const SeparatelyQuantize &quant, unsigned char order_minus_2, const util::BitAddress &address) : bins_(quant.GetTables(order_minus_2)), address_(address) {}
- void Write(void *base, uint64_t bit_offset, float prob, float backoff) const {
- util::WriteInt57(base, bit_offset, total_bits_,
- (prob_.EncodeProb(prob) << backoff_.Bits()) | backoff_.EncodeBackoff(backoff));
- }
+ MiddlePointer() : address_(NULL, 0) {}
- void ReadProb(const void *base, uint64_t bit_offset, float &prob) const {
- prob = prob_.Decode(util::ReadInt25(base, bit_offset + backoff_.Bits(), prob_.Bits(), prob_.Mask()));
- }
+ bool Found() const { return address_.base != NULL; }
- void Read(const void *base, uint64_t bit_offset, float &prob, float &backoff) const {
- uint64_t both = util::ReadInt57(base, bit_offset, total_bits_, total_mask_);
- prob = prob_.Decode(both >> backoff_.Bits());
- backoff = backoff_.Decode(both & backoff_.Mask());
+ float Prob() const {
+ return ProbBins().Decode(util::ReadInt25(address_.base, address_.offset + BackoffBins().Bits(), ProbBins().Bits(), ProbBins().Mask()));
}
- void ReadBackoff(const void *base, uint64_t bit_offset, float &backoff) const {
- backoff = backoff_.Decode(util::ReadInt25(base, bit_offset, backoff_.Bits(), backoff_.Mask()));
+ float Backoff() const {
+ return BackoffBins().Decode(util::ReadInt25(address_.base, address_.offset, BackoffBins().Bits(), BackoffBins().Mask()));
}
- uint8_t TotalBits() const {
- return total_bits_;
+ float Rest() const { return Prob(); }
+
+ void Write(float prob, float backoff) const {
+ util::WriteInt57(address_.base, address_.offset, ProbBins().Bits() + BackoffBins().Bits(),
+ (ProbBins().EncodeProb(prob) << BackoffBins().Bits()) | BackoffBins().EncodeBackoff(backoff));
}
private:
- const uint8_t total_bits_;
- const uint64_t total_mask_;
- const Bins prob_;
- const Bins backoff_;
+ const Bins &ProbBins() const { return bins_[0]; }
+ const Bins &BackoffBins() const { return bins_[1]; }
+ const Bins *bins_;
+
+ util::BitAddress address_;
};
- class Longest {
+ class LongestPointer {
public:
- // Sigh C++ default constructor
- Longest() {}
+ LongestPointer(const SeparatelyQuantize &quant, const util::BitAddress &address) : table_(&quant.LongestTable()), address_(address) {}
+
+ LongestPointer() : address_(NULL, 0) {}
- Longest(uint8_t prob_bits, const float *prob_begin) : prob_(prob_bits, prob_begin) {}
+ bool Found() const { return address_.base != NULL; }
- void Write(void *base, uint64_t bit_offset, float prob) const {
- util::WriteInt25(base, bit_offset, prob_.Bits(), prob_.EncodeProb(prob));
+ void Write(float prob) const {
+ util::WriteInt25(address_.base, address_.offset, table_->Bits(), table_->EncodeProb(prob));
}
- void Read(const void *base, uint64_t bit_offset, float &prob) const {
- prob = prob_.Decode(util::ReadInt25(base, bit_offset, prob_.Bits(), prob_.Mask()));
+ float Prob() const {
+ return table_->Decode(util::ReadInt25(address_.base, address_.offset, table_->Bits(), table_->Mask()));
}
- uint8_t TotalBits() const { return prob_.Bits(); }
-
private:
- Bins prob_;
+ const Bins *table_;
+ util::BitAddress address_;
};
SeparatelyQuantize() {}
- void SetupMemory(void *start, const Config &config);
+ void SetupMemory(void *start, unsigned char order, const Config &config);
static const bool kTrain = true;
// Assumes 0.0 is removed from backoff.
@@ -193,18 +212,17 @@ class SeparatelyQuantize {
void FinishedLoading(const Config &config);
- Middle Mid(uint8_t order) const {
- const float *table = start_ + TableStart(order);
- return Middle(prob_bits_, table, backoff_bits_, table + ProbTableLength());
- }
+ const Bins *GetTables(unsigned char order_minus_2) const { return tables_[order_minus_2]; }
- Longest Long(uint8_t order) const { return Longest(prob_bits_, start_ + TableStart(order)); }
+ const Bins &LongestTable() const { return longest_; }
private:
- size_t TableStart(uint8_t order) const { return ((1ULL << prob_bits_) + (1ULL << backoff_bits_)) * static_cast<uint64_t>(order - 2); }
- size_t ProbTableLength() const { return (1ULL << prob_bits_); }
+ Bins tables_[kMaxOrder - 1][2];
+
+ Bins longest_;
+
+ uint8_t *actual_base_;
- float *start_;
uint8_t prob_bits_, backoff_bits_;
};