From 205893513c8343fdc55789e427fab4c8b536dc12 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Sun, 26 Jun 2011 18:40:15 -0400 Subject: Quantization --- klm/lm/search_trie.cc | 121 +++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 100 insertions(+), 21 deletions(-) (limited to 'klm/lm/search_trie.cc') diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 7c57072b..1ce4d278 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -4,6 +4,7 @@ #include "lm/blank.hh" #include "lm/lm_exception.hh" #include "lm/max_order.hh" +#include "lm/quantize.hh" #include "lm/read_arpa.hh" #include "lm/trie.hh" #include "lm/vocab.hh" @@ -21,6 +22,7 @@ #include #include #include +#include #include #include @@ -579,7 +581,7 @@ bool HeadMatch(const WordIndex *words, const WordIndex *const words_end, const W // Phase to count n-grams, including blanks inserted because they were pruned but have extensions class JustCount { public: - JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, BitPackedMiddle * /*middle*/, BitPackedLongest &/*longest*/, uint64_t *counts, unsigned char order) + template JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, Middle * /*middle*/, Longest &/*longest*/, uint64_t *counts, unsigned char order) : counts_(counts), longest_counts_(counts + order - 1) {} void Unigrams(WordIndex begin, WordIndex end) { @@ -608,9 +610,9 @@ class JustCount { }; // Phase to actually write n-grams to the trie. -class WriteEntries { +template class WriteEntries { public: - WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) : + WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) : contexts_(contexts), unigrams_(unigrams), middle_(middle), @@ -647,14 +649,14 @@ class WriteEntries { private: ContextReader *contexts_; UnigramValue *const unigrams_; - BitPackedMiddle *const middle_; - BitPackedLongest &longest_; + BitPackedMiddle *const middle_; + BitPackedLongest &longest_; BitPacked &bigram_pack_; }; template class RecursiveInsert { public: - RecursiveInsert(SortedFileReader *inputs, ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, uint64_t *counts, unsigned char order) : + template RecursiveInsert(SortedFileReader *inputs, ContextReader *contexts, UnigramValue *unigrams, MiddleT *middle, LongestT &longest, uint64_t *counts, unsigned char order) : doing_(contexts, unigrams, middle, longest, counts, order), inputs_(inputs), inputs_end_(inputs + order - 1), order_minus_2_(order - 2) { } @@ -775,7 +777,51 @@ void SanityCheckCounts(const std::vector &initial, const std::vector &counts, const Config &config, TrieSearch &out, Backing &backing) { +bool IsDirectory(const char *path) { + struct stat info; + if (0 != stat(path, &info)) return false; + return S_ISDIR(info.st_mode); +} + +template void TrainQuantizer(uint8_t order, uint64_t count, SortedFileReader &reader, util::ErsatzProgress &progress, Quant &quant) { + ProbBackoff weights; + std::vector probs, backoffs; + probs.reserve(count); + backoffs.reserve(count); + for (reader.Rewind(); !reader.Ended(); reader.NextHeader()) { + uint64_t entries = reader.ReadCount(); + for (uint64_t c = 0; c < entries; ++c) { + reader.ReadWord(); + reader.ReadWeights(weights); + // kBlankProb isn't added yet. + probs.push_back(weights.prob); + if (weights.backoff != 0.0) backoffs.push_back(weights.backoff); + ++progress; + } + } + quant.Train(order, probs, backoffs); +} + +template void TrainProbQuantizer(uint8_t order, uint64_t count, SortedFileReader &reader, util::ErsatzProgress &progress, Quant &quant) { + Prob weights; + std::vector probs, backoffs; + probs.reserve(count); + for (reader.Rewind(); !reader.Ended(); reader.NextHeader()) { + uint64_t entries = reader.ReadCount(); + for (uint64_t c = 0; c < entries; ++c) { + reader.ReadWord(); + reader.ReadWeights(weights); + // kBlankProb isn't added yet. + probs.push_back(weights.prob); + ++progress; + } + } + quant.TrainProb(order, probs); +} + +} // namespace + +template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, Backing &backing) { std::vector inputs(counts.size() - 1); std::vector contexts(counts.size() - 1); @@ -791,7 +837,7 @@ void BuildTrie(const std::string &file_prefix, std::vector &counts, co std::vector fixed_counts(counts.size()); { - RecursiveInsert counter(&*inputs.begin(), &*contexts.begin(), NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); + RecursiveInsert counter(&*inputs.begin(), &*contexts.begin(), NULL, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]); } for (std::vector::const_iterator i = inputs.begin(); i != inputs.end(); ++i) { @@ -800,7 +846,16 @@ void BuildTrie(const std::string &file_prefix, std::vector &counts, co SanityCheckCounts(counts, fixed_counts); counts = fixed_counts; - out.SetupMemory(GrowForSearch(config, TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); + out.SetupMemory(GrowForSearch(config, TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); + + if (Quant::kTrain) { + util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0)); + for (unsigned char i = 2; i < counts.size(); ++i) { + TrainQuantizer(i, counts[i-1], inputs[i-2], progress, quant); + } + TrainProbQuantizer(counts.size(), counts.back(), inputs[counts.size() - 2], progress, quant); + quant.FinishedLoading(config); + } for (unsigned char i = 2; i <= counts.size(); ++i) { inputs[i-2].Rewind(); @@ -808,7 +863,7 @@ void BuildTrie(const std::string &file_prefix, std::vector &counts, co UnigramValue *unigrams = out.unigram.Raw(); // Fill entries except unigram probabilities. { - RecursiveInsert inserter(&*inputs.begin(), &*contexts.begin(), unigrams, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); + RecursiveInsert > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); inserter.Apply(config.messages, "Building trie", fixed_counts[0]); } @@ -845,23 +900,44 @@ void BuildTrie(const std::string &file_prefix, std::vector &counts, co /* Set ending offsets so the last entry will be sized properly */ // Last entry for unigrams was already set. - if (!out.middle.empty()) { - for (size_t i = 0; i < out.middle.size() - 1; ++i) { - out.middle[i].FinishedLoading(out.middle[i+1].InsertIndex()); + if (out.middle_begin_ != out.middle_end_) { + for (typename TrieSearch::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) { + i->FinishedLoading((i+1)->InsertIndex()); } - out.middle.back().FinishedLoading(out.longest.InsertIndex()); + (out.middle_end_ - 1)->FinishedLoading(out.longest.InsertIndex()); } } -bool IsDirectory(const char *path) { - struct stat info; - if (0 != stat(path, &info)) return false; - return S_ISDIR(info.st_mode); +template uint8_t *TrieSearch::SetupMemory(uint8_t *start, const std::vector &counts, const Config &config) { + quant_.SetupMemory(start, config); + start += Quant::Size(counts.size(), config); + unigram.Init(start); + start += Unigram::Size(counts[0]); + FreeMiddles(); + middle_begin_ = static_cast(malloc(sizeof(Middle) * (counts.size() - 2))); + middle_end_ = middle_begin_ + (counts.size() - 2); + for (unsigned char i = counts.size() - 1; i >= 2; --i) { + new (middle_begin_ + i - 2) Middle( + start, + quant_.Mid(i), + counts[0], + counts[i], + (i == counts.size() - 1) ? static_cast(longest) : static_cast(middle_begin_[i-1])); + start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i]); + } + longest.Init(start, quant_.Long(counts.size()), counts[0]); + return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); } -} // namespace +template void TrieSearch::LoadedBinary() { + unigram.LoadedBinary(); + for (Middle *i = middle_begin_; i != middle_end_; ++i) { + i->LoadedBinary(); + } + longest.LoadedBinary(); +} -void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { +template void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { std::string temporary_directory; if (config.temporary_directory_prefix) { temporary_directory = config.temporary_directory_prefix; @@ -885,12 +961,15 @@ void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::v // At least 1MB sorting memory. ARPAToSortedFiles(config, f, counts, std::max(config.building_memory, 1048576), temporary_directory.c_str(), vocab); - BuildTrie(temporary_directory, counts, config, *this, backing); + BuildTrie(temporary_directory, counts, config, *this, quant_, backing); if (rmdir(temporary_directory.c_str()) && config.messages) { *config.messages << "Failed to delete " << temporary_directory << std::endl; } } +template class TrieSearch; +template class TrieSearch; + } // namespace trie } // namespace ngram } // namespace lm -- cgit v1.2.3 From 59932be2de387ecfcaa81a8387e8f21d5123c050 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Mon, 27 Jun 2011 17:50:41 -0400 Subject: Fix binary format for trie --- klm/lm/quantize.cc | 2 +- klm/lm/search_trie.cc | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) (limited to 'klm/lm/search_trie.cc') diff --git a/klm/lm/quantize.cc b/klm/lm/quantize.cc index b4d76893..4bb6b1b8 100644 --- a/klm/lm/quantize.cc +++ b/klm/lm/quantize.cc @@ -34,7 +34,7 @@ void MakeBins(float *values, float *values_end, float *centers, uint32_t bins) { } } -const char kSeparatelyQuantizeVersion = 1; +const char kSeparatelyQuantizeVersion = 2; } // namespace diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 1ce4d278..91f87f1c 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -916,14 +916,19 @@ template uint8_t *TrieSearch::SetupMemory(uint8_t *start, c FreeMiddles(); middle_begin_ = static_cast(malloc(sizeof(Middle) * (counts.size() - 2))); middle_end_ = middle_begin_ + (counts.size() - 2); + std::vector middle_starts(counts.size() - 2); + for (unsigned char i = 2; i < counts.size(); ++i) { + middle_starts[i-2] = start; + start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i]); + } + // Crazy backwards thing so we initialize in the correct order. for (unsigned char i = counts.size() - 1; i >= 2; --i) { new (middle_begin_ + i - 2) Middle( - start, + middle_starts[i-2], quant_.Mid(i), counts[0], counts[i], (i == counts.size() - 1) ? static_cast(longest) : static_cast(middle_begin_[i-1])); - start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i]); } longest.Init(start, quant_.Long(counts.size()), counts[0]); return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); -- cgit v1.2.3