diff options
author | Kenneth Heafield <github@kheafield.com> | 2012-05-16 13:24:08 -0700 |
---|---|---|
committer | Chris Dyer <cdyer@cab.ark.cs.cmu.edu> | 2012-05-26 22:59:54 -0400 |
commit | 149232c38eec558ddb1097698d1570aacb67b59f (patch) | |
tree | 5860b4d6f681eeb04a1020cbb2fe7e6ac394af99 /klm/lm/search_trie.cc | |
parent | 01ecc09f8e3a82c32bf7dd2f90c12554becea71d (diff) |
Big kenlm change includes lower order models for probing only. And other stuff.
Diffstat (limited to 'klm/lm/search_trie.cc')
-rw-r--r-- | klm/lm/search_trie.cc | 38 |
1 files changed, 20 insertions, 18 deletions
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index ffadfa94..18e80d5a 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -273,8 +273,9 @@ class FindBlanks { // Phase to actually write n-grams to the trie. template <class Quant, class Bhiksha> class WriteEntries { public: - WriteEntries(RecordReader *contexts, UnigramValue *unigrams, BitPackedMiddle<typename Quant::Middle, Bhiksha> *middle, BitPackedLongest<typename Quant::Longest> &longest, unsigned char order, SRISucks &sri) : + WriteEntries(RecordReader *contexts, const Quant &quant, UnigramValue *unigrams, BitPackedMiddle<Bhiksha> *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) : contexts_(contexts), + quant_(quant), unigrams_(unigrams), middle_(middle), longest_(longest), @@ -290,7 +291,7 @@ template <class Quant, class Bhiksha> class WriteEntries { void MiddleBlank(const unsigned char order, const WordIndex *indices, unsigned char /*lower*/, float /*prob_base*/) { ProbBackoff weights = sri_.GetBlank(order_, order, indices); - middle_[order - 2].Insert(indices[order - 1], weights.prob, weights.backoff); + typename Quant::MiddlePointer(quant_, order - 2, middle_[order - 2].Insert(indices[order - 1])).Write(weights.prob, weights.backoff); } void Middle(const unsigned char order, const void *data) { @@ -301,21 +302,22 @@ template <class Quant, class Bhiksha> class WriteEntries { SetExtension(weights.backoff); ++context; } - middle_[order - 2].Insert(words[order - 1], weights.prob, weights.backoff); + typename Quant::MiddlePointer(quant_, order - 2, middle_[order - 2].Insert(words[order - 1])).Write(weights.prob, weights.backoff); } void Longest(const void *data) { const WordIndex *words = reinterpret_cast<const WordIndex*>(data); - longest_.Insert(words[order_ - 1], reinterpret_cast<const Prob*>(words + order_)->prob); + typename Quant::LongestPointer(quant_, longest_.Insert(words[order_ - 1])).Write(reinterpret_cast<const Prob*>(words + order_)->prob); } void Cleanup() {} private: RecordReader *contexts_; + const Quant &quant_; UnigramValue *const unigrams_; - BitPackedMiddle<typename Quant::Middle, Bhiksha> *const middle_; - BitPackedLongest<typename Quant::Longest> &longest_; + BitPackedMiddle<Bhiksha> *const middle_; + BitPackedLongest &longest_; BitPacked &bigram_pack_; const unsigned char order_; SRISucks &sri_; @@ -380,7 +382,7 @@ template <class Doing> class BlankManager { }; template <class Doing> void RecursiveInsert(const unsigned char total_order, const WordIndex unigram_count, RecordReader *input, std::ostream *progress_out, const char *message, Doing &doing) { - util::ErsatzProgress progress(progress_out, message, unigram_count + 1); + util::ErsatzProgress progress(unigram_count + 1, progress_out, message); WordIndex unigram = 0; std::priority_queue<Gram> grams; grams.push(Gram(&unigram, 1)); @@ -502,7 +504,7 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve inputs[i-2].Rewind(); } if (Quant::kTrain) { - util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0)); + util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), 0), config.messages, "Quantizing"); for (unsigned char i = 2; i < counts.size(); ++i) { TrainQuantizer(i, counts[i-1], sri.Values(i), inputs[i-2], progress, quant); } @@ -510,7 +512,7 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve quant.FinishedLoading(config); } - UnigramValue *unigrams = out.unigram.Raw(); + UnigramValue *unigrams = out.unigram_.Raw(); PopulateUnigramWeights(unigram_file.get(), counts[0], contexts[0], unigrams); unigram_file.reset(); @@ -519,7 +521,7 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve } // Fill entries except unigram probabilities. { - WriteEntries<Quant, Bhiksha> writer(contexts, unigrams, out.middle_begin_, out.longest, counts.size(), sri); + WriteEntries<Quant, Bhiksha> writer(contexts, quant, unigrams, out.middle_begin_, out.longest_, counts.size(), sri); RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Writing trie", writer); } @@ -544,14 +546,14 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve for (typename TrieSearch<Quant, Bhiksha>::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) { i->FinishedLoading((i+1)->InsertIndex(), config); } - (out.middle_end_ - 1)->FinishedLoading(out.longest.InsertIndex(), config); + (out.middle_end_ - 1)->FinishedLoading(out.longest_.InsertIndex(), config); } } template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) { - quant_.SetupMemory(start, config); + quant_.SetupMemory(start, counts.size(), config); start += Quant::Size(counts.size(), config); - unigram.Init(start); + unigram_.Init(start); start += Unigram::Size(counts[0]); FreeMiddles(); middle_begin_ = static_cast<Middle*>(malloc(sizeof(Middle) * (counts.size() - 2))); @@ -565,23 +567,23 @@ template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::Setup for (unsigned char i = counts.size() - 1; i >= 2; --i) { new (middle_begin_ + i - 2) Middle( middle_starts[i-2], - quant_.Mid(i), + quant_.MiddleBits(config), counts[i-1], counts[0], counts[i], - (i == counts.size() - 1) ? static_cast<const BitPacked&>(longest) : static_cast<const BitPacked &>(middle_begin_[i-1]), + (i == counts.size() - 1) ? static_cast<const BitPacked&>(longest_) : static_cast<const BitPacked &>(middle_begin_[i-1]), config); } - longest.Init(start, quant_.Long(counts.size()), counts[0]); + longest_.Init(start, quant_.LongestBits(config), counts[0]); return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); } template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBinary() { - unigram.LoadedBinary(); + unigram_.LoadedBinary(); for (Middle *i = middle_begin_; i != middle_end_; ++i) { i->LoadedBinary(); } - longest.LoadedBinary(); + longest_.LoadedBinary(); } template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { |