Diffstat (limited to 'klm/lm/search_trie.cc')
-rw-r--r--  klm/lm/search_trie.cc | 27 ++++++++++++++-------------
1 file changed, 14 insertions(+), 13 deletions(-)
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
index b830dfc3..7c57072b 100644
--- a/klm/lm/search_trie.cc
+++ b/klm/lm/search_trie.cc
@@ -293,7 +293,7 @@ class SortedFileReader {
ReadOrThrow(file_.get(), &weights, sizeof(Weights));
}
- bool Ended() {
+ bool Ended() const {
return ended_;
}
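
Why Ended() becomes const: later in this diff the readers are walked through a std::vector<SortedFileReader>::const_iterator, and only const member functions are callable through const access. A minimal sketch of that interaction, with a hypothetical Reader standing in for SortedFileReader:

#include <vector>

struct Reader {
  Reader() : ended_(true) {}
  bool Ended() const { return ended_; }  // const: callable through const access
 private:
  bool ended_;
};

bool AllEnded(const std::vector<Reader> &readers) {
  // Elements are const here, just as a const_iterator exposes const
  // elements; a non-const Ended() would make i->Ended() a compile error.
  for (std::vector<Reader>::const_iterator i = readers.begin();
       i != readers.end(); ++i) {
    if (!i->Ended()) return false;
  }
  return true;
}

int main() {
  std::vector<Reader> readers(3);
  return AllEnded(readers) ? 0 : 1;
}
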
@@ -480,7 +480,7 @@ void MergeContextFiles(const std::string &first_base, const std::string &second_
CopyRestOrThrow(remaining.GetFile(), out.get());
}
-void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order) {
+void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order, PositiveProbWarn &warn) {
ReadNGramHeader(f, order);
const size_t count = counts[order - 1];
// Size of weights. Does it include backoff?
@@ -495,11 +495,11 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st
uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size;
if (order == counts.size()) {
for (; out != out_end; out += entry_size) {
- ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<Prob*>(out + words_size));
+ ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<Prob*>(out + words_size), warn);
}
} else {
for (; out != out_end; out += entry_size) {
- ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<ProbBackoff*>(out + words_size));
+ ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<ProbBackoff*>(out + words_size), warn);
}
}
// Sort full records by full n-gram.
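
The loop above fills packed records: each entry is `order` WordIndex values followed by a Prob or ProbBackoff struct, written in place through reinterpret_cast so the whole batch can then be sorted record by record. A minimal sketch of that layout, using simplified stand-in types rather than kenlm's real definitions:

#include <cassert>
#include <cstddef>
#include <stdint.h>
#include <vector>

typedef unsigned int WordIndex;                     // stand-in
struct ProbBackoff { float prob; float backoff; };  // stand-in

int main() {
  const unsigned char order = 3;
  const std::size_t words_size = sizeof(WordIndex) * order;
  const std::size_t entry_size = words_size + sizeof(ProbBackoff);

  // One slot; the real code lays out min(count - done, batch_size)
  // of these back to back in the sort buffer.
  std::vector<uint8_t> buf(entry_size);
  uint8_t *out = &buf[0];

  // Fill the record through the same casts the ReadNGram calls use.
  WordIndex *words = reinterpret_cast<WordIndex*>(out);
  words[0] = 7; words[1] = 2; words[2] = 9;
  ProbBackoff &weights = *reinterpret_cast<ProbBackoff*>(out + words_size);
  weights.prob = -1.5f;
  weights.backoff = -0.25f;

  assert(reinterpret_cast<WordIndex*>(out)[2] == 9);
  return 0;
}
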
@@ -536,13 +536,14 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st
}
void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) {
+ PositiveProbWarn warn(config.positive_log_probability);
{
std::string unigram_name = file_prefix + "unigrams";
util::scoped_fd unigram_file;
// In case <unk> appears.
size_t extra_count = counts[0] + 1;
util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), extra_count * sizeof(ProbBackoff), unigram_file), extra_count * sizeof(ProbBackoff));
- Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()));
+ Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()), warn);
CheckSpecials(config, vocab);
if (!vocab.SawUnk()) ++counts[0];
}
@@ -560,7 +561,7 @@ void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uin
if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer);
for (unsigned char order = 2; order <= counts.size(); ++order) {
- ConvertToSorted(f, vocab, counts, mem, file_prefix, order);
+ ConvertToSorted(f, vocab, counts, mem, file_prefix, order, warn);
}
ReadEnd(f);
}
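
What the PositiveProbWarn changes accomplish: one warning object is constructed from config.positive_log_probability in ARPAToSortedFiles and passed by reference through ConvertToSorted into ReadNGram and Read1Grams, so the policy for positive log probabilities is decided once per load rather than at each call site. A simplified, self-contained sketch of such an object follows; the WarningAction names are assumed to mirror kenlm's config, and the warn-once behavior is this sketch's simplification, not the real lm/read_arpa.hh implementation:

#include <iostream>

// Assumed mirror of kenlm's WarningAction values; treat as hypothetical.
enum WarningAction { THROW_UP, COMPLAIN, SILENT };

// Simplified stand-in for PositiveProbWarn.
class PositiveProbWarn {
  public:
    explicit PositiveProbWarn(WarningAction action)
      : action_(action), seen_(0) {}

    // Called for each n-gram whose log10 probability is positive,
    // which a well-formed ARPA file should never contain.
    void Warn(float prob) {
      ++seen_;
      if (action_ == COMPLAIN && seen_ == 1) {
        std::cerr << "Positive log probability " << prob
                  << " in the ARPA file; suppressing further warnings.\n";
      }
      // THROW_UP would raise a format exception in the real code.
    }

  private:
    WarningAction action_;
    unsigned long seen_;
};

int main() {
  // Built once per load and passed by reference down the call chain:
  // ARPAToSortedFiles -> ConvertToSorted -> ReadNGram.
  PositiveProbWarn warn(COMPLAIN);
  warn.Warn(0.25f);  // reported
  warn.Warn(0.50f);  // counted silently in this sketch
  return 0;
}
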
@@ -775,8 +776,8 @@ void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<u
}
void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch &out, Backing &backing) {
- SortedFileReader inputs[counts.size() - 1];
- ContextReader contexts[counts.size() - 1];
+ std::vector<SortedFileReader> inputs(counts.size() - 1);
+ std::vector<ContextReader> contexts(counts.size() - 1);
for (unsigned char i = 2; i <= counts.size(); ++i) {
std::stringstream assembled;
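
Motivation for the hunk above: `SortedFileReader inputs[counts.size() - 1]` is a variable-length array, a GCC extension that standard C++ rejects, so the commit switches to std::vector, which default-constructs its elements on the heap and still yields a raw pointer via &*inputs.begin() for the RecursiveInsert calls below. A minimal sketch, with a hypothetical Reader in place of SortedFileReader:

#include <cstddef>
#include <vector>

struct Reader {  // hypothetical stand-in for SortedFileReader
  Reader() : ended(false) {}
  bool ended;
};

void Portable(std::size_t orders) {
  // `Reader inputs[orders - 1];` would be a VLA: legal C99, but only a
  // GCC extension in C++. The vector is standard and default-constructs
  // its elements.
  std::vector<Reader> inputs(orders - 1);

  // &*inputs.begin() recovers the raw pointer RecursiveInsert expects;
  // it assumes the vector is non-empty (C++11's inputs.data() would
  // also cover the empty case).
  Reader *raw = inputs.empty() ? 0 : &*inputs.begin();
  (void)raw;  // silence unused-variable warnings in this sketch
}

int main() {
  Portable(5);  // e.g. a 5-gram model: readers for orders 2 through 5
  return 0;
}
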
@@ -790,11 +791,11 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co
std::vector<uint64_t> fixed_counts(counts.size());
{
- RecursiveInsert<JustCount> counter(inputs, contexts, NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size());
+ RecursiveInsert<JustCount> counter(&*inputs.begin(), &*contexts.begin(), NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size());
counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]);
}
- for (SortedFileReader *i = inputs; i < inputs + counts.size() - 1; ++i) {
- if (!i->Ended()) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs + 2) << "-gram table did not complete reading");
+ for (std::vector<SortedFileReader>::const_iterator i = inputs.begin(); i != inputs.end(); ++i) {
+ if (!i->Ended()) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs.begin() + 2) << "-gram table did not complete reading");
}
SanityCheckCounts(counts, fixed_counts);
counts = fixed_counts;
@@ -807,7 +808,7 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co
UnigramValue *unigrams = out.unigram.Raw();
// Fill entries except unigram probabilities.
{
- RecursiveInsert<WriteEntries> inserter(inputs, contexts, unigrams, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size());
+ RecursiveInsert<WriteEntries> inserter(&*inputs.begin(), &*contexts.begin(), unigrams, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size());
inserter.Apply(config.messages, "Building trie", fixed_counts[0]);
}
@@ -849,7 +850,7 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co
out.middle[i].FinishedLoading(out.middle[i+1].InsertIndex());
}
out.middle.back().FinishedLoading(out.longest.InsertIndex());
- }
+ }
}
bool IsDirectory(const char *path) {