diff options
Diffstat (limited to 'klm/lm/search_trie.cc')
-rw-r--r-- | klm/lm/search_trie.cc | 35 |
1 files changed, 11 insertions, 24 deletions
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 1b0d9b26..4a88194e 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -253,11 +253,6 @@ class FindBlanks { ++counts_.back(); } - // Unigrams wrote one past. - void Cleanup() { - --counts_[0]; - } - const std::vector<uint64_t> &Counts() const { return counts_; } @@ -310,8 +305,6 @@ template <class Quant, class Bhiksha> class WriteEntries { typename Quant::LongestPointer(quant_, longest_.Insert(words[order_ - 1])).Write(reinterpret_cast<const Prob*>(words + order_)->prob); } - void Cleanup() {} - private: RecordReader *contexts_; const Quant &quant_; @@ -385,14 +378,14 @@ template <class Doing> void RecursiveInsert(const unsigned char total_order, con util::ErsatzProgress progress(unigram_count + 1, progress_out, message); WordIndex unigram = 0; std::priority_queue<Gram> grams; - grams.push(Gram(&unigram, 1)); + if (unigram_count) grams.push(Gram(&unigram, 1)); for (unsigned char i = 2; i <= total_order; ++i) { if (input[i-2]) grams.push(Gram(reinterpret_cast<const WordIndex*>(input[i-2].Data()), i)); } BlankManager<Doing> blank(total_order, doing); - while (true) { + while (!grams.empty()) { Gram top = grams.top(); grams.pop(); unsigned char order = top.end - top.begin; @@ -400,8 +393,7 @@ template <class Doing> void RecursiveInsert(const unsigned char total_order, con blank.Visit(&unigram, 1, doing.UnigramProb(unigram)); doing.Unigram(unigram); progress.Set(unigram); - if (++unigram == unigram_count + 1) break; - grams.push(top); + if (++unigram < unigram_count) grams.push(top); } else { if (order == total_order) { blank.Visit(top.begin, order, reinterpret_cast<const Prob*>(top.end)->prob); @@ -414,8 +406,6 @@ template <class Doing> void RecursiveInsert(const unsigned char total_order, con if (++reader) grams.push(top); } } - assert(grams.empty()); - doing.Cleanup(); } void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<uint64_t> &fixed) { @@ -469,7 +459,7 @@ void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &c } // namespace -template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) { +template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing) { RecordReader inputs[KENLM_MAX_ORDER - 1]; RecordReader contexts[KENLM_MAX_ORDER - 1]; @@ -498,7 +488,10 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve sri.ObtainBackoffs(counts.size(), unigram_file.get(), inputs); - out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), backing), fixed_counts, config); + void *vocab_relocate; + void *search_base = backing.GrowForSearch(TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), vocab.UnkCountChangePadding(), vocab_relocate); + vocab.Relocate(vocab_relocate); + out.SetupMemory(reinterpret_cast<uint8_t*>(search_base), fixed_counts, config); for (unsigned char i = 2; i <= counts.size(); ++i) { inputs[i-2].Rewind(); @@ -524,6 +517,8 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve { WriteEntries<Quant, Bhiksha> writer(contexts, quant, unigrams, out.middle_begin_, out.longest_, counts.size(), sri); RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Writing trie", writer); + // Write the last unigram entry, which is the end pointer for the bigrams. + writer.Unigram(counts[0]); } // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation. @@ -579,15 +574,7 @@ template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::Setup return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); } -template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBinary() { - unigram_.LoadedBinary(); - for (Middle *i = middle_begin_; i != middle_end_; ++i) { - i->LoadedBinary(); - } - longest_.LoadedBinary(); -} - -template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { +template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, BinaryFormat &backing) { std::string temporary_prefix; if (config.temporary_directory_prefix) { temporary_prefix = config.temporary_directory_prefix; |