summaryrefslogtreecommitdiff
path: root/klm/lm/search_trie.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/search_trie.cc')
-rw-r--r--klm/lm/search_trie.cc38
1 files changed, 20 insertions, 18 deletions
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
index ffadfa94..18e80d5a 100644
--- a/klm/lm/search_trie.cc
+++ b/klm/lm/search_trie.cc
@@ -273,8 +273,9 @@ class FindBlanks {
// Phase to actually write n-grams to the trie.
template <class Quant, class Bhiksha> class WriteEntries {
public:
- WriteEntries(RecordReader *contexts, UnigramValue *unigrams, BitPackedMiddle<typename Quant::Middle, Bhiksha> *middle, BitPackedLongest<typename Quant::Longest> &longest, unsigned char order, SRISucks &sri) :
+ WriteEntries(RecordReader *contexts, const Quant &quant, UnigramValue *unigrams, BitPackedMiddle<Bhiksha> *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) :
contexts_(contexts),
+ quant_(quant),
unigrams_(unigrams),
middle_(middle),
longest_(longest),
@@ -290,7 +291,7 @@ template <class Quant, class Bhiksha> class WriteEntries {
void MiddleBlank(const unsigned char order, const WordIndex *indices, unsigned char /*lower*/, float /*prob_base*/) {
ProbBackoff weights = sri_.GetBlank(order_, order, indices);
- middle_[order - 2].Insert(indices[order - 1], weights.prob, weights.backoff);
+ typename Quant::MiddlePointer(quant_, order - 2, middle_[order - 2].Insert(indices[order - 1])).Write(weights.prob, weights.backoff);
}
void Middle(const unsigned char order, const void *data) {
@@ -301,21 +302,22 @@ template <class Quant, class Bhiksha> class WriteEntries {
SetExtension(weights.backoff);
++context;
}
- middle_[order - 2].Insert(words[order - 1], weights.prob, weights.backoff);
+ typename Quant::MiddlePointer(quant_, order - 2, middle_[order - 2].Insert(words[order - 1])).Write(weights.prob, weights.backoff);
}
void Longest(const void *data) {
const WordIndex *words = reinterpret_cast<const WordIndex*>(data);
- longest_.Insert(words[order_ - 1], reinterpret_cast<const Prob*>(words + order_)->prob);
+ typename Quant::LongestPointer(quant_, longest_.Insert(words[order_ - 1])).Write(reinterpret_cast<const Prob*>(words + order_)->prob);
}
void Cleanup() {}
private:
RecordReader *contexts_;
+ const Quant &quant_;
UnigramValue *const unigrams_;
- BitPackedMiddle<typename Quant::Middle, Bhiksha> *const middle_;
- BitPackedLongest<typename Quant::Longest> &longest_;
+ BitPackedMiddle<Bhiksha> *const middle_;
+ BitPackedLongest &longest_;
BitPacked &bigram_pack_;
const unsigned char order_;
SRISucks &sri_;
@@ -380,7 +382,7 @@ template <class Doing> class BlankManager {
};
template <class Doing> void RecursiveInsert(const unsigned char total_order, const WordIndex unigram_count, RecordReader *input, std::ostream *progress_out, const char *message, Doing &doing) {
- util::ErsatzProgress progress(progress_out, message, unigram_count + 1);
+ util::ErsatzProgress progress(unigram_count + 1, progress_out, message);
WordIndex unigram = 0;
std::priority_queue<Gram> grams;
grams.push(Gram(&unigram, 1));
@@ -502,7 +504,7 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
inputs[i-2].Rewind();
}
if (Quant::kTrain) {
- util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0));
+ util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), 0), config.messages, "Quantizing");
for (unsigned char i = 2; i < counts.size(); ++i) {
TrainQuantizer(i, counts[i-1], sri.Values(i), inputs[i-2], progress, quant);
}
@@ -510,7 +512,7 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
quant.FinishedLoading(config);
}
- UnigramValue *unigrams = out.unigram.Raw();
+ UnigramValue *unigrams = out.unigram_.Raw();
PopulateUnigramWeights(unigram_file.get(), counts[0], contexts[0], unigrams);
unigram_file.reset();
@@ -519,7 +521,7 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
}
// Fill entries except unigram probabilities.
{
- WriteEntries<Quant, Bhiksha> writer(contexts, unigrams, out.middle_begin_, out.longest, counts.size(), sri);
+ WriteEntries<Quant, Bhiksha> writer(contexts, quant, unigrams, out.middle_begin_, out.longest_, counts.size(), sri);
RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Writing trie", writer);
}
@@ -544,14 +546,14 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve
for (typename TrieSearch<Quant, Bhiksha>::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) {
i->FinishedLoading((i+1)->InsertIndex(), config);
}
- (out.middle_end_ - 1)->FinishedLoading(out.longest.InsertIndex(), config);
+ (out.middle_end_ - 1)->FinishedLoading(out.longest_.InsertIndex(), config);
}
}
template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
- quant_.SetupMemory(start, config);
+ quant_.SetupMemory(start, counts.size(), config);
start += Quant::Size(counts.size(), config);
- unigram.Init(start);
+ unigram_.Init(start);
start += Unigram::Size(counts[0]);
FreeMiddles();
middle_begin_ = static_cast<Middle*>(malloc(sizeof(Middle) * (counts.size() - 2)));
@@ -565,23 +567,23 @@ template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::Setup
for (unsigned char i = counts.size() - 1; i >= 2; --i) {
new (middle_begin_ + i - 2) Middle(
middle_starts[i-2],
- quant_.Mid(i),
+ quant_.MiddleBits(config),
counts[i-1],
counts[0],
counts[i],
- (i == counts.size() - 1) ? static_cast<const BitPacked&>(longest) : static_cast<const BitPacked &>(middle_begin_[i-1]),
+ (i == counts.size() - 1) ? static_cast<const BitPacked&>(longest_) : static_cast<const BitPacked &>(middle_begin_[i-1]),
config);
}
- longest.Init(start, quant_.Long(counts.size()), counts[0]);
+ longest_.Init(start, quant_.LongestBits(config), counts[0]);
return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);
}
template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBinary() {
- unigram.LoadedBinary();
+ unigram_.LoadedBinary();
for (Middle *i = middle_begin_; i != middle_end_; ++i) {
i->LoadedBinary();
}
- longest.LoadedBinary();
+ longest_.LoadedBinary();
}
template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) {