From 149232c38eec558ddb1097698d1570aacb67b59f Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Wed, 16 May 2012 13:24:08 -0700 Subject: Big kenlm change includes lower order models for probing only. And other stuff. --- klm/lm/search_trie.cc | 38 ++++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 18 deletions(-) (limited to 'klm/lm/search_trie.cc') diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index ffadfa94..18e80d5a 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -273,8 +273,9 @@ class FindBlanks { // Phase to actually write n-grams to the trie. template class WriteEntries { public: - WriteEntries(RecordReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) : + WriteEntries(RecordReader *contexts, const Quant &quant, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) : contexts_(contexts), + quant_(quant), unigrams_(unigrams), middle_(middle), longest_(longest), @@ -290,7 +291,7 @@ template class WriteEntries { void MiddleBlank(const unsigned char order, const WordIndex *indices, unsigned char /*lower*/, float /*prob_base*/) { ProbBackoff weights = sri_.GetBlank(order_, order, indices); - middle_[order - 2].Insert(indices[order - 1], weights.prob, weights.backoff); + typename Quant::MiddlePointer(quant_, order - 2, middle_[order - 2].Insert(indices[order - 1])).Write(weights.prob, weights.backoff); } void Middle(const unsigned char order, const void *data) { @@ -301,21 +302,22 @@ template class WriteEntries { SetExtension(weights.backoff); ++context; } - middle_[order - 2].Insert(words[order - 1], weights.prob, weights.backoff); + typename Quant::MiddlePointer(quant_, order - 2, middle_[order - 2].Insert(words[order - 1])).Write(weights.prob, weights.backoff); } void Longest(const void *data) { const WordIndex *words = reinterpret_cast(data); - longest_.Insert(words[order_ - 1], reinterpret_cast(words + order_)->prob); + typename Quant::LongestPointer(quant_, longest_.Insert(words[order_ - 1])).Write(reinterpret_cast(words + order_)->prob); } void Cleanup() {} private: RecordReader *contexts_; + const Quant &quant_; UnigramValue *const unigrams_; - BitPackedMiddle *const middle_; - BitPackedLongest &longest_; + BitPackedMiddle *const middle_; + BitPackedLongest &longest_; BitPacked &bigram_pack_; const unsigned char order_; SRISucks &sri_; @@ -380,7 +382,7 @@ template class BlankManager { }; template void RecursiveInsert(const unsigned char total_order, const WordIndex unigram_count, RecordReader *input, std::ostream *progress_out, const char *message, Doing &doing) { - util::ErsatzProgress progress(progress_out, message, unigram_count + 1); + util::ErsatzProgress progress(unigram_count + 1, progress_out, message); WordIndex unigram = 0; std::priority_queue grams; grams.push(Gram(&unigram, 1)); @@ -502,7 +504,7 @@ template void BuildTrie(SortedFiles &files, std::ve inputs[i-2].Rewind(); } if (Quant::kTrain) { - util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0)); + util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), 0), config.messages, "Quantizing"); for (unsigned char i = 2; i < counts.size(); ++i) { TrainQuantizer(i, counts[i-1], sri.Values(i), inputs[i-2], progress, quant); } @@ -510,7 +512,7 @@ template void BuildTrie(SortedFiles &files, std::ve quant.FinishedLoading(config); } - UnigramValue *unigrams = out.unigram.Raw(); + UnigramValue *unigrams = out.unigram_.Raw(); PopulateUnigramWeights(unigram_file.get(), counts[0], contexts[0], unigrams); unigram_file.reset(); @@ -519,7 +521,7 @@ template void BuildTrie(SortedFiles &files, std::ve } // Fill entries except unigram probabilities. { - WriteEntries writer(contexts, unigrams, out.middle_begin_, out.longest, counts.size(), sri); + WriteEntries writer(contexts, quant, unigrams, out.middle_begin_, out.longest_, counts.size(), sri); RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Writing trie", writer); } @@ -544,14 +546,14 @@ template void BuildTrie(SortedFiles &files, std::ve for (typename TrieSearch::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) { i->FinishedLoading((i+1)->InsertIndex(), config); } - (out.middle_end_ - 1)->FinishedLoading(out.longest.InsertIndex(), config); + (out.middle_end_ - 1)->FinishedLoading(out.longest_.InsertIndex(), config); } } template uint8_t *TrieSearch::SetupMemory(uint8_t *start, const std::vector &counts, const Config &config) { - quant_.SetupMemory(start, config); + quant_.SetupMemory(start, counts.size(), config); start += Quant::Size(counts.size(), config); - unigram.Init(start); + unigram_.Init(start); start += Unigram::Size(counts[0]); FreeMiddles(); middle_begin_ = static_cast(malloc(sizeof(Middle) * (counts.size() - 2))); @@ -565,23 +567,23 @@ template uint8_t *TrieSearch::Setup for (unsigned char i = counts.size() - 1; i >= 2; --i) { new (middle_begin_ + i - 2) Middle( middle_starts[i-2], - quant_.Mid(i), + quant_.MiddleBits(config), counts[i-1], counts[0], counts[i], - (i == counts.size() - 1) ? static_cast(longest) : static_cast(middle_begin_[i-1]), + (i == counts.size() - 1) ? static_cast(longest_) : static_cast(middle_begin_[i-1]), config); } - longest.Init(start, quant_.Long(counts.size()), counts[0]); + longest_.Init(start, quant_.LongestBits(config), counts[0]); return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); } template void TrieSearch::LoadedBinary() { - unigram.LoadedBinary(); + unigram_.LoadedBinary(); for (Middle *i = middle_begin_; i != middle_end_; ++i) { i->LoadedBinary(); } - longest.LoadedBinary(); + longest_.LoadedBinary(); } template void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { -- cgit v1.2.3