From 783c57b2d3312738ddcf992ac55ff750afe7cb47 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Mon, 27 Jan 2014 17:42:19 -0800 Subject: KenLM 5cc905bc2d214efa7de2db56a9a672b749a95591 --- klm/lm/search_trie.cc | 35 +++++++++++------------------------ 1 file changed, 11 insertions(+), 24 deletions(-) (limited to 'klm/lm/search_trie.cc') diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 1b0d9b26..4a88194e 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -253,11 +253,6 @@ class FindBlanks { ++counts_.back(); } - // Unigrams wrote one past. - void Cleanup() { - --counts_[0]; - } - const std::vector &Counts() const { return counts_; } @@ -310,8 +305,6 @@ template class WriteEntries { typename Quant::LongestPointer(quant_, longest_.Insert(words[order_ - 1])).Write(reinterpret_cast(words + order_)->prob); } - void Cleanup() {} - private: RecordReader *contexts_; const Quant &quant_; @@ -385,14 +378,14 @@ template void RecursiveInsert(const unsigned char total_order, con util::ErsatzProgress progress(unigram_count + 1, progress_out, message); WordIndex unigram = 0; std::priority_queue grams; - grams.push(Gram(&unigram, 1)); + if (unigram_count) grams.push(Gram(&unigram, 1)); for (unsigned char i = 2; i <= total_order; ++i) { if (input[i-2]) grams.push(Gram(reinterpret_cast(input[i-2].Data()), i)); } BlankManager blank(total_order, doing); - while (true) { + while (!grams.empty()) { Gram top = grams.top(); grams.pop(); unsigned char order = top.end - top.begin; @@ -400,8 +393,7 @@ template void RecursiveInsert(const unsigned char total_order, con blank.Visit(&unigram, 1, doing.UnigramProb(unigram)); doing.Unigram(unigram); progress.Set(unigram); - if (++unigram == unigram_count + 1) break; - grams.push(top); + if (++unigram < unigram_count) grams.push(top); } else { if (order == total_order) { blank.Visit(top.begin, order, reinterpret_cast(top.end)->prob); @@ -414,8 +406,6 @@ template void RecursiveInsert(const unsigned char total_order, con if (++reader) grams.push(top); } } - assert(grams.empty()); - doing.Cleanup(); } void SanityCheckCounts(const std::vector &initial, const std::vector &fixed) { @@ -469,7 +459,7 @@ void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &c } // namespace -template void BuildTrie(SortedFiles &files, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) { +template void BuildTrie(SortedFiles &files, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing) { RecordReader inputs[KENLM_MAX_ORDER - 1]; RecordReader contexts[KENLM_MAX_ORDER - 1]; @@ -498,7 +488,10 @@ template void BuildTrie(SortedFiles &files, std::ve sri.ObtainBackoffs(counts.size(), unigram_file.get(), inputs); - out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); + void *vocab_relocate; + void *search_base = backing.GrowForSearch(TrieSearch::Size(fixed_counts, config), vocab.UnkCountChangePadding(), vocab_relocate); + vocab.Relocate(vocab_relocate); + out.SetupMemory(reinterpret_cast(search_base), fixed_counts, config); for (unsigned char i = 2; i <= counts.size(); ++i) { inputs[i-2].Rewind(); @@ -524,6 +517,8 @@ template void BuildTrie(SortedFiles &files, std::ve { WriteEntries writer(contexts, quant, unigrams, out.middle_begin_, out.longest_, counts.size(), sri); RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Writing trie", writer); + // Write the last unigram entry, which is the end pointer for the bigrams. + writer.Unigram(counts[0]); } // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation. @@ -579,15 +574,7 @@ template uint8_t *TrieSearch::Setup return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); } -template void TrieSearch::LoadedBinary() { - unigram_.LoadedBinary(); - for (Middle *i = middle_begin_; i != middle_end_; ++i) { - i->LoadedBinary(); - } - longest_.LoadedBinary(); -} - -template void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { +template void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, BinaryFormat &backing) { std::string temporary_prefix; if (config.temporary_directory_prefix) { temporary_prefix = config.temporary_directory_prefix; -- cgit v1.2.3