diff options
Diffstat (limited to 'klm/lm/search_trie.cc')
-rw-r--r-- | klm/lm/search_trie.cc | 49 |
1 files changed, 25 insertions, 24 deletions
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 832cc9f7..1b0d9b26 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -55,7 +55,7 @@ struct ProbPointer { uint64_t index; }; -// Array of n-grams and float indices. +// Array of n-grams and float indices. class BackoffMessages { public: void Init(std::size_t entry_size) { @@ -89,7 +89,7 @@ class BackoffMessages { if (!HasExtension(weights.backoff)) { weights.backoff = kExtensionBackoff; UTIL_THROW_IF(fseek(unigrams, -sizeof(weights), SEEK_CUR), util::ErrnoException, "Seeking backwards to denote unigram extension failed."); - WriteOrThrow(unigrams, &weights, sizeof(weights)); + util::WriteOrThrow(unigrams, &weights, sizeof(weights)); } const ProbPointer &write_to = *reinterpret_cast<const ProbPointer*>(current_ + sizeof(WordIndex)); base[write_to.array][write_to.index] += weights.backoff; @@ -100,7 +100,7 @@ class BackoffMessages { void Apply(float *const *const base, RecordReader &reader) { FinishedAdding(); if (current_ == allocated_) return; - // We'll also use the same buffer to record messages to blanks that they extend. + // We'll also use the same buffer to record messages to blanks that they extend. WordIndex *extend_out = reinterpret_cast<WordIndex*>(current_); const unsigned char order = (entry_size_ - sizeof(ProbPointer)) / sizeof(WordIndex); for (reader.Rewind(); reader && (current_ != allocated_); ) { @@ -109,7 +109,7 @@ class BackoffMessages { ++reader; break; case 1: - // Message but nobody to receive it. Write it down at the beginning of the buffer so we can inform this blank that it extends. + // Message but nobody to receive it. Write it down at the beginning of the buffer so we can inform this blank that it extends. for (const WordIndex *w = reinterpret_cast<const WordIndex *>(current_); w != reinterpret_cast<const WordIndex *>(current_) + order; ++w, ++extend_out) *extend_out = *w; current_ += entry_size_; break; @@ -126,7 +126,7 @@ class BackoffMessages { break; } } - // Now this is a list of blanks that extend right. + // Now this is a list of blanks that extend right. entry_size_ = sizeof(WordIndex) * order; Resize(sizeof(WordIndex) * (extend_out - (const WordIndex*)backing_.get())); current_ = (uint8_t*)backing_.get(); @@ -153,7 +153,7 @@ class BackoffMessages { private: void FinishedAdding() { Resize(current_ - (uint8_t*)backing_.get()); - // Sort requests in same order as files. + // Sort requests in same order as files. std::sort( util::SizedIterator(util::SizedProxy(backing_.get(), entry_size_)), util::SizedIterator(util::SizedProxy(current_, entry_size_)), @@ -220,7 +220,7 @@ class SRISucks { } private: - // This used to be one array. Then I needed to separate it by order for quantization to work. + // This used to be one array. Then I needed to separate it by order for quantization to work. std::vector<float> values_[KENLM_MAX_ORDER - 1]; BackoffMessages messages_[KENLM_MAX_ORDER - 1]; @@ -253,7 +253,7 @@ class FindBlanks { ++counts_.back(); } - // Unigrams wrote one past. + // Unigrams wrote one past. void Cleanup() { --counts_[0]; } @@ -270,15 +270,15 @@ class FindBlanks { SRISucks &sri_; }; -// Phase to actually write n-grams to the trie. +// Phase to actually write n-grams to the trie. template <class Quant, class Bhiksha> class WriteEntries { public: - WriteEntries(RecordReader *contexts, const Quant &quant, UnigramValue *unigrams, BitPackedMiddle<Bhiksha> *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) : + WriteEntries(RecordReader *contexts, const Quant &quant, UnigramValue *unigrams, BitPackedMiddle<Bhiksha> *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) : contexts_(contexts), quant_(quant), unigrams_(unigrams), middle_(middle), - longest_(longest), + longest_(longest), bigram_pack_((order == 2) ? static_cast<BitPacked&>(longest_) : static_cast<BitPacked&>(*middle_)), order_(order), sri_(sri) {} @@ -328,7 +328,7 @@ struct Gram { const WordIndex *begin, *end; - // For queue, this is the direction we want. + // For queue, this is the direction we want. bool operator<(const Gram &other) const { return std::lexicographical_compare(other.begin, other.end, begin, end); } @@ -353,7 +353,7 @@ template <class Doing> class BlankManager { been_length_ = length; return; } - // There are blanks to insert starting with order blank. + // There are blanks to insert starting with order blank. unsigned char blank = cur - to + 1; UTIL_THROW_IF(blank == 1, FormatLoadException, "Missing a unigram that appears as context."); const float *lower_basis; @@ -363,7 +363,7 @@ template <class Doing> class BlankManager { assert(*lower_basis != kBadProb); doing_.MiddleBlank(blank, to, based_on, *lower_basis); *pre = *cur; - // Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram. + // Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram. basis_[blank - 1] = kBadProb; } *pre = *cur; @@ -377,7 +377,7 @@ template <class Doing> class BlankManager { unsigned char been_length_; float basis_[KENLM_MAX_ORDER]; - + Doing &doing_; }; @@ -451,7 +451,7 @@ template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, Re } void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &contexts, UnigramValue *unigrams) { - // Fill unigram probabilities. + // Fill unigram probabilities. try { rewind(file); for (WordIndex i = 0; i < unigram_count; ++i) { @@ -486,7 +486,7 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve util::scoped_memory unigrams; MapRead(util::POPULATE_OR_READ, unigram_fd.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams); FindBlanks finder(counts.size(), reinterpret_cast<const ProbBackoff*>(unigrams.get()), sri); - RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Identifying n-grams omitted by SRI", finder); + RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Identifying n-grams omitted by SRI", finder); fixed_counts = finder.Counts(); } unigram_file.reset(util::FDOpenOrThrow(unigram_fd)); @@ -504,7 +504,8 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve inputs[i-2].Rewind(); } if (Quant::kTrain) { - util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), 0), config.messages, "Quantizing"); + util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), 0), + config.ProgressMessages(), "Quantizing"); for (unsigned char i = 2; i < counts.size(); ++i) { TrainQuantizer(i, counts[i-1], sri.Values(i), inputs[i-2], progress, quant); } @@ -519,13 +520,13 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve for (unsigned char i = 2; i <= counts.size(); ++i) { inputs[i-2].Rewind(); } - // Fill entries except unigram probabilities. + // Fill entries except unigram probabilities. { WriteEntries<Quant, Bhiksha> writer(contexts, quant, unigrams, out.middle_begin_, out.longest_, counts.size(), sri); - RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Writing trie", writer); + RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Writing trie", writer); } - // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation. + // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation. for (unsigned char order = 2; order <= counts.size(); ++order) { const RecordReader &context = contexts[order - 2]; if (context) { @@ -541,13 +542,13 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve } /* Set ending offsets so the last entry will be sized properly */ - // Last entry for unigrams was already set. + // Last entry for unigrams was already set. if (out.middle_begin_ != out.middle_end_) { for (typename TrieSearch<Quant, Bhiksha>::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) { i->FinishedLoading((i+1)->InsertIndex(), config); } (out.middle_end_ - 1)->FinishedLoading(out.longest_.InsertIndex(), config); - } + } } template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) { @@ -595,7 +596,7 @@ template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::Initializ } else { temporary_prefix = file; } - // At least 1MB sorting memory. + // At least 1MB sorting memory. SortedFiles sorted(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_prefix, vocab); BuildTrie(sorted, counts, config, *this, quant_, vocab, backing); |