summaryrefslogtreecommitdiff
path: root/klm/lm/search_trie.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/search_trie.cc')
-rw-r--r--klm/lm/search_trie.cc126
1 files changed, 105 insertions, 21 deletions
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
index 7c57072b..91f87f1c 100644
--- a/klm/lm/search_trie.cc
+++ b/klm/lm/search_trie.cc
@@ -4,6 +4,7 @@
#include "lm/blank.hh"
#include "lm/lm_exception.hh"
#include "lm/max_order.hh"
+#include "lm/quantize.hh"
#include "lm/read_arpa.hh"
#include "lm/trie.hh"
#include "lm/vocab.hh"
@@ -21,6 +22,7 @@
#include <cstdio>
#include <deque>
#include <limits>
+#include <numeric>
#include <vector>
#include <sys/mman.h>
@@ -579,7 +581,7 @@ bool HeadMatch(const WordIndex *words, const WordIndex *const words_end, const W
// Phase to count n-grams, including blanks inserted because they were pruned but have extensions
class JustCount {
public:
- JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, BitPackedMiddle * /*middle*/, BitPackedLongest &/*longest*/, uint64_t *counts, unsigned char order)
+ template <class Middle, class Longest> JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, Middle * /*middle*/, Longest &/*longest*/, uint64_t *counts, unsigned char order)
: counts_(counts), longest_counts_(counts + order - 1) {}
void Unigrams(WordIndex begin, WordIndex end) {
@@ -608,9 +610,9 @@ class JustCount {
};
// Phase to actually write n-grams to the trie.
-class WriteEntries {
+template <class Quant> class WriteEntries {
public:
- WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) :
+ WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle<typename Quant::Middle> *middle, BitPackedLongest<typename Quant::Longest> &longest, const uint64_t * /*counts*/, unsigned char order) :
contexts_(contexts),
unigrams_(unigrams),
middle_(middle),
@@ -647,14 +649,14 @@ class WriteEntries {
private:
ContextReader *contexts_;
UnigramValue *const unigrams_;
- BitPackedMiddle *const middle_;
- BitPackedLongest &longest_;
+ BitPackedMiddle<typename Quant::Middle> *const middle_;
+ BitPackedLongest<typename Quant::Longest> &longest_;
BitPacked &bigram_pack_;
};
template <class Doing> class RecursiveInsert {
public:
- RecursiveInsert(SortedFileReader *inputs, ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, uint64_t *counts, unsigned char order) :
+ template <class MiddleT, class LongestT> RecursiveInsert(SortedFileReader *inputs, ContextReader *contexts, UnigramValue *unigrams, MiddleT *middle, LongestT &longest, uint64_t *counts, unsigned char order) :
doing_(contexts, unigrams, middle, longest, counts, order), inputs_(inputs), inputs_end_(inputs + order - 1), order_minus_2_(order - 2) {
}
@@ -775,7 +777,51 @@ void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<u
}
}
-void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch &out, Backing &backing) {
+bool IsDirectory(const char *path) {
+ struct stat info;
+ if (0 != stat(path, &info)) return false;
+ return S_ISDIR(info.st_mode);
+}
+
+template <class Quant> void TrainQuantizer(uint8_t order, uint64_t count, SortedFileReader &reader, util::ErsatzProgress &progress, Quant &quant) {
+ ProbBackoff weights;
+ std::vector<float> probs, backoffs;
+ probs.reserve(count);
+ backoffs.reserve(count);
+ for (reader.Rewind(); !reader.Ended(); reader.NextHeader()) {
+ uint64_t entries = reader.ReadCount();
+ for (uint64_t c = 0; c < entries; ++c) {
+ reader.ReadWord();
+ reader.ReadWeights(weights);
+ // kBlankProb isn't added yet.
+ probs.push_back(weights.prob);
+ if (weights.backoff != 0.0) backoffs.push_back(weights.backoff);
+ ++progress;
+ }
+ }
+ quant.Train(order, probs, backoffs);
+}
+
+template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, SortedFileReader &reader, util::ErsatzProgress &progress, Quant &quant) {
+ Prob weights;
+ std::vector<float> probs, backoffs;
+ probs.reserve(count);
+ for (reader.Rewind(); !reader.Ended(); reader.NextHeader()) {
+ uint64_t entries = reader.ReadCount();
+ for (uint64_t c = 0; c < entries; ++c) {
+ reader.ReadWord();
+ reader.ReadWeights(weights);
+ // kBlankProb isn't added yet.
+ probs.push_back(weights.prob);
+ ++progress;
+ }
+ }
+ quant.TrainProb(order, probs);
+}
+
+} // namespace
+
+template <class Quant> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant> &out, Quant &quant, Backing &backing) {
std::vector<SortedFileReader> inputs(counts.size() - 1);
std::vector<ContextReader> contexts(counts.size() - 1);
@@ -791,7 +837,7 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co
std::vector<uint64_t> fixed_counts(counts.size());
{
- RecursiveInsert<JustCount> counter(&*inputs.begin(), &*contexts.begin(), NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size());
+ RecursiveInsert<JustCount> counter(&*inputs.begin(), &*contexts.begin(), NULL, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size());
counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]);
}
for (std::vector<SortedFileReader>::const_iterator i = inputs.begin(); i != inputs.end(); ++i) {
@@ -800,7 +846,16 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co
SanityCheckCounts(counts, fixed_counts);
counts = fixed_counts;
- out.SetupMemory(GrowForSearch(config, TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config);
+ out.SetupMemory(GrowForSearch(config, TrieSearch<Quant>::Size(fixed_counts, config), backing), fixed_counts, config);
+
+ if (Quant::kTrain) {
+ util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0));
+ for (unsigned char i = 2; i < counts.size(); ++i) {
+ TrainQuantizer(i, counts[i-1], inputs[i-2], progress, quant);
+ }
+ TrainProbQuantizer(counts.size(), counts.back(), inputs[counts.size() - 2], progress, quant);
+ quant.FinishedLoading(config);
+ }
for (unsigned char i = 2; i <= counts.size(); ++i) {
inputs[i-2].Rewind();
@@ -808,7 +863,7 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co
UnigramValue *unigrams = out.unigram.Raw();
// Fill entries except unigram probabilities.
{
- RecursiveInsert<WriteEntries> inserter(&*inputs.begin(), &*contexts.begin(), unigrams, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size());
+ RecursiveInsert<WriteEntries<Quant> > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size());
inserter.Apply(config.messages, "Building trie", fixed_counts[0]);
}
@@ -845,23 +900,49 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co
/* Set ending offsets so the last entry will be sized properly */
// Last entry for unigrams was already set.
- if (!out.middle.empty()) {
- for (size_t i = 0; i < out.middle.size() - 1; ++i) {
- out.middle[i].FinishedLoading(out.middle[i+1].InsertIndex());
+ if (out.middle_begin_ != out.middle_end_) {
+ for (typename TrieSearch<Quant>::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) {
+ i->FinishedLoading((i+1)->InsertIndex());
}
- out.middle.back().FinishedLoading(out.longest.InsertIndex());
+ (out.middle_end_ - 1)->FinishedLoading(out.longest.InsertIndex());
}
}
-bool IsDirectory(const char *path) {
- struct stat info;
- if (0 != stat(path, &info)) return false;
- return S_ISDIR(info.st_mode);
+template <class Quant> uint8_t *TrieSearch<Quant>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
+ quant_.SetupMemory(start, config);
+ start += Quant::Size(counts.size(), config);
+ unigram.Init(start);
+ start += Unigram::Size(counts[0]);
+ FreeMiddles();
+ middle_begin_ = static_cast<Middle*>(malloc(sizeof(Middle) * (counts.size() - 2)));
+ middle_end_ = middle_begin_ + (counts.size() - 2);
+ std::vector<uint8_t*> middle_starts(counts.size() - 2);
+ for (unsigned char i = 2; i < counts.size(); ++i) {
+ middle_starts[i-2] = start;
+ start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i]);
+ }
+ // Crazy backwards thing so we initialize in the correct order.
+ for (unsigned char i = counts.size() - 1; i >= 2; --i) {
+ new (middle_begin_ + i - 2) Middle(
+ middle_starts[i-2],
+ quant_.Mid(i),
+ counts[0],
+ counts[i],
+ (i == counts.size() - 1) ? static_cast<const BitPacked&>(longest) : static_cast<const BitPacked &>(middle_begin_[i-1]));
+ }
+ longest.Init(start, quant_.Long(counts.size()), counts[0]);
+ return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);
}
-} // namespace
+template <class Quant> void TrieSearch<Quant>::LoadedBinary() {
+ unigram.LoadedBinary();
+ for (Middle *i = middle_begin_; i != middle_end_; ++i) {
+ i->LoadedBinary();
+ }
+ longest.LoadedBinary();
+}
-void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) {
+template <class Quant> void TrieSearch<Quant>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) {
std::string temporary_directory;
if (config.temporary_directory_prefix) {
temporary_directory = config.temporary_directory_prefix;
@@ -885,12 +966,15 @@ void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::v
// At least 1MB sorting memory.
ARPAToSortedFiles(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab);
- BuildTrie(temporary_directory, counts, config, *this, backing);
+ BuildTrie(temporary_directory, counts, config, *this, quant_, backing);
if (rmdir(temporary_directory.c_str()) && config.messages) {
*config.messages << "Failed to delete " << temporary_directory << std::endl;
}
}
+template class TrieSearch<DontQuantize>;
+template class TrieSearch<SeparatelyQuantize>;
+
} // namespace trie
} // namespace ngram
} // namespace lm