summaryrefslogtreecommitdiff
path: root/klm/lm/search_trie.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/search_trie.cc')
-rw-r--r--klm/lm/search_trie.cc45
1 files changed, 25 insertions, 20 deletions
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
index 91f87f1c..05059ffb 100644
--- a/klm/lm/search_trie.cc
+++ b/klm/lm/search_trie.cc
@@ -1,6 +1,7 @@
/* This is where the trie is built. It's on-disk. */
#include "lm/search_trie.hh"
+#include "lm/bhiksha.hh"
#include "lm/blank.hh"
#include "lm/lm_exception.hh"
#include "lm/max_order.hh"
@@ -543,8 +544,8 @@ void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uin
std::string unigram_name = file_prefix + "unigrams";
util::scoped_fd unigram_file;
// In case <unk> appears.
- size_t extra_count = counts[0] + 1;
- util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), extra_count * sizeof(ProbBackoff), unigram_file), extra_count * sizeof(ProbBackoff));
+ size_t file_out = (counts[0] + 1) * sizeof(ProbBackoff);
+ util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), file_out, unigram_file), file_out);
Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()), warn);
CheckSpecials(config, vocab);
if (!vocab.SawUnk()) ++counts[0];
@@ -610,9 +611,9 @@ class JustCount {
};
// Phase to actually write n-grams to the trie.
-template <class Quant> class WriteEntries {
+template <class Quant, class Bhiksha> class WriteEntries {
public:
- WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle<typename Quant::Middle> *middle, BitPackedLongest<typename Quant::Longest> &longest, const uint64_t * /*counts*/, unsigned char order) :
+ WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle<typename Quant::Middle, Bhiksha> *middle, BitPackedLongest<typename Quant::Longest> &longest, const uint64_t * /*counts*/, unsigned char order) :
contexts_(contexts),
unigrams_(unigrams),
middle_(middle),
@@ -649,7 +650,7 @@ template <class Quant> class WriteEntries {
private:
ContextReader *contexts_;
UnigramValue *const unigrams_;
- BitPackedMiddle<typename Quant::Middle> *const middle_;
+ BitPackedMiddle<typename Quant::Middle, Bhiksha> *const middle_;
BitPackedLongest<typename Quant::Longest> &longest_;
BitPacked &bigram_pack_;
};
@@ -821,7 +822,7 @@ template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, So
} // namespace
-template <class Quant> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant> &out, Quant &quant, Backing &backing) {
+template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) {
std::vector<SortedFileReader> inputs(counts.size() - 1);
std::vector<ContextReader> contexts(counts.size() - 1);
@@ -846,7 +847,7 @@ template <class Quant> void BuildTrie(const std::string &file_prefix, std::vecto
SanityCheckCounts(counts, fixed_counts);
counts = fixed_counts;
- out.SetupMemory(GrowForSearch(config, TrieSearch<Quant>::Size(fixed_counts, config), backing), fixed_counts, config);
+ out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), backing), fixed_counts, config);
if (Quant::kTrain) {
util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0));
@@ -863,7 +864,7 @@ template <class Quant> void BuildTrie(const std::string &file_prefix, std::vecto
UnigramValue *unigrams = out.unigram.Raw();
// Fill entries except unigram probabilities.
{
- RecursiveInsert<WriteEntries<Quant> > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size());
+ RecursiveInsert<WriteEntries<Quant, Bhiksha> > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size());
inserter.Apply(config.messages, "Building trie", fixed_counts[0]);
}
@@ -901,14 +902,14 @@ template <class Quant> void BuildTrie(const std::string &file_prefix, std::vecto
/* Set ending offsets so the last entry will be sized properly */
// Last entry for unigrams was already set.
if (out.middle_begin_ != out.middle_end_) {
- for (typename TrieSearch<Quant>::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) {
- i->FinishedLoading((i+1)->InsertIndex());
+ for (typename TrieSearch<Quant, Bhiksha>::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) {
+ i->FinishedLoading((i+1)->InsertIndex(), config);
}
- (out.middle_end_ - 1)->FinishedLoading(out.longest.InsertIndex());
+ (out.middle_end_ - 1)->FinishedLoading(out.longest.InsertIndex(), config);
}
}
-template <class Quant> uint8_t *TrieSearch<Quant>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
+template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
quant_.SetupMemory(start, config);
start += Quant::Size(counts.size(), config);
unigram.Init(start);
@@ -919,22 +920,24 @@ template <class Quant> uint8_t *TrieSearch<Quant>::SetupMemory(uint8_t *start, c
std::vector<uint8_t*> middle_starts(counts.size() - 2);
for (unsigned char i = 2; i < counts.size(); ++i) {
middle_starts[i-2] = start;
- start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i]);
+ start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i], config);
}
- // Crazy backwards thing so we initialize in the correct order.
+ // Crazy backwards thing so we initialize using pointers to ones that have already been initialized
for (unsigned char i = counts.size() - 1; i >= 2; --i) {
new (middle_begin_ + i - 2) Middle(
middle_starts[i-2],
quant_.Mid(i),
+ counts[i-1],
counts[0],
counts[i],
- (i == counts.size() - 1) ? static_cast<const BitPacked&>(longest) : static_cast<const BitPacked &>(middle_begin_[i-1]));
+ (i == counts.size() - 1) ? static_cast<const BitPacked&>(longest) : static_cast<const BitPacked &>(middle_begin_[i-1]),
+ config);
}
longest.Init(start, quant_.Long(counts.size()), counts[0]);
return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);
}
-template <class Quant> void TrieSearch<Quant>::LoadedBinary() {
+template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBinary() {
unigram.LoadedBinary();
for (Middle *i = middle_begin_; i != middle_end_; ++i) {
i->LoadedBinary();
@@ -942,7 +945,7 @@ template <class Quant> void TrieSearch<Quant>::LoadedBinary() {
longest.LoadedBinary();
}
-template <class Quant> void TrieSearch<Quant>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) {
+template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) {
std::string temporary_directory;
if (config.temporary_directory_prefix) {
temporary_directory = config.temporary_directory_prefix;
@@ -966,14 +969,16 @@ template <class Quant> void TrieSearch<Quant>::InitializeFromARPA(const char *fi
// At least 1MB sorting memory.
ARPAToSortedFiles(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab);
- BuildTrie(temporary_directory, counts, config, *this, quant_, backing);
+ BuildTrie(temporary_directory, counts, config, *this, quant_, vocab, backing);
if (rmdir(temporary_directory.c_str()) && config.messages) {
*config.messages << "Failed to delete " << temporary_directory << std::endl;
}
}
-template class TrieSearch<DontQuantize>;
-template class TrieSearch<SeparatelyQuantize>;
+template class TrieSearch<DontQuantize, DontBhiksha>;
+template class TrieSearch<DontQuantize, ArrayBhiksha>;
+template class TrieSearch<SeparatelyQuantize, DontBhiksha>;
+template class TrieSearch<SeparatelyQuantize, ArrayBhiksha>;
} // namespace trie
} // namespace ngram