summaryrefslogtreecommitdiff
path: root/klm/lm
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm')
-rw-r--r--klm/lm/quantize.cc2
-rw-r--r--klm/lm/search_trie.cc9
2 files changed, 8 insertions, 3 deletions
diff --git a/klm/lm/quantize.cc b/klm/lm/quantize.cc
index b4d76893..4bb6b1b8 100644
--- a/klm/lm/quantize.cc
+++ b/klm/lm/quantize.cc
@@ -34,7 +34,7 @@ void MakeBins(float *values, float *values_end, float *centers, uint32_t bins) {
}
}
-const char kSeparatelyQuantizeVersion = 1;
+const char kSeparatelyQuantizeVersion = 2;
} // namespace
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
index 1ce4d278..91f87f1c 100644
--- a/klm/lm/search_trie.cc
+++ b/klm/lm/search_trie.cc
@@ -916,14 +916,19 @@ template <class Quant> uint8_t *TrieSearch<Quant>::SetupMemory(uint8_t *start, c
FreeMiddles();
middle_begin_ = static_cast<Middle*>(malloc(sizeof(Middle) * (counts.size() - 2)));
middle_end_ = middle_begin_ + (counts.size() - 2);
+ std::vector<uint8_t*> middle_starts(counts.size() - 2);
+ for (unsigned char i = 2; i < counts.size(); ++i) {
+ middle_starts[i-2] = start;
+ start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i]);
+ }
+ // Crazy backwards thing so we initialize in the correct order.
for (unsigned char i = counts.size() - 1; i >= 2; --i) {
new (middle_begin_ + i - 2) Middle(
- start,
+ middle_starts[i-2],
quant_.Mid(i),
counts[0],
counts[i],
(i == counts.size() - 1) ? static_cast<const BitPacked&>(longest) : static_cast<const BitPacked &>(middle_begin_[i-1]));
- start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i]);
}
longest.Init(start, quant_.Long(counts.size()), counts[0]);
return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);