From 205893513c8343fdc55789e427fab4c8b536dc12 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Sun, 26 Jun 2011 18:40:15 -0400 Subject: Quantization --- klm/lm/search_trie.hh | 132 ++++++++++++++++++++++++++++---------------------- 1 file changed, 74 insertions(+), 58 deletions(-) (limited to 'klm/lm/search_trie.hh') diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 0f720217..0a52acb5 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -13,72 +13,88 @@ struct Backing; class SortedVocabulary; namespace trie { -struct TrieSearch { - typedef NodeRange Node; +template class TrieSearch; +template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, Backing &backing); - typedef ::lm::ngram::trie::Unigram Unigram; - Unigram unigram; +template class TrieSearch { + public: + typedef NodeRange Node; - typedef trie::BitPackedMiddle Middle; - std::vector middle; + typedef ::lm::ngram::trie::Unigram Unigram; + Unigram unigram; - typedef trie::BitPackedLongest Longest; - Longest longest; + typedef trie::BitPackedMiddle Middle; - static const ModelType kModelType = TRIE_SORTED; + typedef trie::BitPackedLongest Longest; + Longest longest; - static std::size_t Size(const std::vector &counts, const Config &/*config*/) { - std::size_t ret = Unigram::Size(counts[0]); - for (unsigned char i = 1; i < counts.size() - 1; ++i) { - ret += Middle::Size(counts[i], counts[0], counts[i+1]); + static const ModelType kModelType = Quant::kModelType; + + static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config) { + Quant::UpdateConfigFromBinary(fd, counts, config); } - return ret + Longest::Size(counts.back(), counts[0]); - } - - uint8_t *SetupMemory(uint8_t *start, const std::vector &counts, const Config &/*config*/) { - unigram.Init(start); - start += Unigram::Size(counts[0]); - middle.resize(counts.size() - 2); - for (unsigned char i = 1; i < counts.size() - 1; ++i) { - middle[i-1].Init( - start, - counts[0], - counts[i+1], - (i == counts.size() - 2) ? static_cast(longest) : static_cast(middle[i])); - start += Middle::Size(counts[i], counts[0], counts[i+1]); + + static std::size_t Size(const std::vector &counts, const Config &config) { + std::size_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]); + for (unsigned char i = 1; i < counts.size() - 1; ++i) { + ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1]); + } + return ret + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); } - longest.Init(start, counts[0]); - return start + Longest::Size(counts.back(), counts[0]); - } - - void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); - - bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { - return unigram.Find(word, prob, backoff, node); - } - - bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const { - return mid.Find(word, prob, backoff, node); - } - - bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const { - return mid.FindNoProb(word, backoff, node); - } - - bool LookupLongest(WordIndex word, float &prob, const Node &node) const { - return longest.Find(word, prob, node); - } - - bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { - // TODO: don't decode backoff. - assert(begin != end); - float ignored_prob, ignored_backoff; - LookupUnigram(*begin, ignored_prob, ignored_backoff, node); - for (const WordIndex *i = begin + 1; i < end; ++i) { - if (!LookupMiddleNoProb(middle[i - begin - 1], *i, ignored_backoff, node)) return false; + + TrieSearch() : middle_begin_(NULL), middle_end_(NULL) {} + + ~TrieSearch() { FreeMiddles(); } + + uint8_t *SetupMemory(uint8_t *start, const std::vector &counts, const Config &config); + + void LoadedBinary(); + + const Middle *MiddleBegin() const { return middle_begin_; } + const Middle *MiddleEnd() const { return middle_end_; } + + void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); + + bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { + return unigram.Find(word, prob, backoff, node); + } + + bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const { + return mid.Find(word, prob, backoff, node); } - return true; - } + + bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const { + return mid.FindNoProb(word, backoff, node); + } + + bool LookupLongest(WordIndex word, float &prob, const Node &node) const { + return longest.Find(word, prob, node); + } + + bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { + // TODO: don't decode backoff. + assert(begin != end); + float ignored_prob, ignored_backoff; + LookupUnigram(*begin, ignored_prob, ignored_backoff, node); + for (const WordIndex *i = begin + 1; i < end; ++i) { + if (!LookupMiddleNoProb(middle_begin_[i - begin - 1], *i, ignored_backoff, node)) return false; + } + return true; + } + + private: + friend void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, Backing &backing); + + // Middles are managed manually so we can delay construction and they don't have to be copyable. + void FreeMiddles() { + for (const Middle *i = middle_begin_; i != middle_end_; ++i) { + i->~Middle(); + } + free(middle_begin_); + } + + Middle *middle_begin_, *middle_end_; + Quant quant_; }; } // namespace trie -- cgit v1.2.3