diff options
Diffstat (limited to 'klm/lm/search_trie.cc')
-rw-r--r-- | klm/lm/search_trie.cc | 9 |
1 files changed, 7 insertions, 2 deletions
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 1ce4d278..91f87f1c 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -916,14 +916,19 @@ template <class Quant> uint8_t *TrieSearch<Quant>::SetupMemory(uint8_t *start, c FreeMiddles(); middle_begin_ = static_cast<Middle*>(malloc(sizeof(Middle) * (counts.size() - 2))); middle_end_ = middle_begin_ + (counts.size() - 2); + std::vector<uint8_t*> middle_starts(counts.size() - 2); + for (unsigned char i = 2; i < counts.size(); ++i) { + middle_starts[i-2] = start; + start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i]); + } + // Crazy backwards thing so we initialize in the correct order. for (unsigned char i = counts.size() - 1; i >= 2; --i) { new (middle_begin_ + i - 2) Middle( - start, + middle_starts[i-2], quant_.Mid(i), counts[0], counts[i], (i == counts.size() - 1) ? static_cast<const BitPacked&>(longest) : static_cast<const BitPacked &>(middle_begin_[i-1])); - start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i]); } longest.Init(start, quant_.Long(counts.size()), counts[0]); return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); |