diff options
Diffstat (limited to 'klm/lm/search_trie.cc')
-rw-r--r-- | klm/lm/search_trie.cc | 27 |
1 files changed, 14 insertions, 13 deletions
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index b830dfc3..7c57072b 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -293,7 +293,7 @@ class SortedFileReader { ReadOrThrow(file_.get(), &weights, sizeof(Weights)); } - bool Ended() { + bool Ended() const { return ended_; } @@ -480,7 +480,7 @@ void MergeContextFiles(const std::string &first_base, const std::string &second_ CopyRestOrThrow(remaining.GetFile(), out.get()); } -void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order) { +void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order, PositiveProbWarn &warn) { ReadNGramHeader(f, order); const size_t count = counts[order - 1]; // Size of weights. Does it include backoff? @@ -495,11 +495,11 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size; if (order == counts.size()) { for (; out != out_end; out += entry_size) { - ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<Prob*>(out + words_size)); + ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<Prob*>(out + words_size), warn); } } else { for (; out != out_end; out += entry_size) { - ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<ProbBackoff*>(out + words_size)); + ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<ProbBackoff*>(out + words_size), warn); } } // Sort full records by full n-gram. @@ -536,13 +536,14 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st } void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { + PositiveProbWarn warn(config.positive_log_probability); { std::string unigram_name = file_prefix + "unigrams"; util::scoped_fd unigram_file; // In case <unk> appears. size_t extra_count = counts[0] + 1; util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), extra_count * sizeof(ProbBackoff), unigram_file), extra_count * sizeof(ProbBackoff)); - Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get())); + Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()), warn); CheckSpecials(config, vocab); if (!vocab.SawUnk()) ++counts[0]; } @@ -560,7 +561,7 @@ void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uin if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer); for (unsigned char order = 2; order <= counts.size(); ++order) { - ConvertToSorted(f, vocab, counts, mem, file_prefix, order); + ConvertToSorted(f, vocab, counts, mem, file_prefix, order, warn); } ReadEnd(f); } @@ -775,8 +776,8 @@ void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<u } void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch &out, Backing &backing) { - SortedFileReader inputs[counts.size() - 1]; - ContextReader contexts[counts.size() - 1]; + std::vector<SortedFileReader> inputs(counts.size() - 1); + std::vector<ContextReader> contexts(counts.size() - 1); for (unsigned char i = 2; i <= counts.size(); ++i) { std::stringstream assembled; @@ -790,11 +791,11 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co std::vector<uint64_t> fixed_counts(counts.size()); { - RecursiveInsert<JustCount> counter(inputs, contexts, NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); + RecursiveInsert<JustCount> counter(&*inputs.begin(), &*contexts.begin(), NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]); } - for (SortedFileReader *i = inputs; i < inputs + counts.size() - 1; ++i) { - if (!i->Ended()) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs + 2) << "-gram table did not complete reading"); + for (std::vector<SortedFileReader>::const_iterator i = inputs.begin(); i != inputs.end(); ++i) { + if (!i->Ended()) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs.begin() + 2) << "-gram table did not complete reading"); } SanityCheckCounts(counts, fixed_counts); counts = fixed_counts; @@ -807,7 +808,7 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co UnigramValue *unigrams = out.unigram.Raw(); // Fill entries except unigram probabilities. { - RecursiveInsert<WriteEntries> inserter(inputs, contexts, unigrams, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); + RecursiveInsert<WriteEntries> inserter(&*inputs.begin(), &*contexts.begin(), unigrams, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); inserter.Apply(config.messages, "Building trie", fixed_counts[0]); } @@ -849,7 +850,7 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co out.middle[i].FinishedLoading(out.middle[i+1].InsertIndex()); } out.middle.back().FinishedLoading(out.longest.InsertIndex()); - } + } } bool IsDirectory(const char *path) { |