diff options
Diffstat (limited to 'klm/lm/search_trie.hh')
-rw-r--r-- | klm/lm/search_trie.hh | 132 |
1 files changed, 74 insertions, 58 deletions
diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 0f720217..0a52acb5 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -13,72 +13,88 @@ struct Backing; class SortedVocabulary; namespace trie { -struct TrieSearch { - typedef NodeRange Node; +template <class Quant> class TrieSearch; +template <class Quant> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant> &out, Quant &quant, Backing &backing); - typedef ::lm::ngram::trie::Unigram Unigram; - Unigram unigram; +template <class Quant> class TrieSearch { + public: + typedef NodeRange Node; - typedef trie::BitPackedMiddle Middle; - std::vector<Middle> middle; + typedef ::lm::ngram::trie::Unigram Unigram; + Unigram unigram; - typedef trie::BitPackedLongest Longest; - Longest longest; + typedef trie::BitPackedMiddle<typename Quant::Middle> Middle; - static const ModelType kModelType = TRIE_SORTED; + typedef trie::BitPackedLongest<typename Quant::Longest> Longest; + Longest longest; - static std::size_t Size(const std::vector<uint64_t> &counts, const Config &/*config*/) { - std::size_t ret = Unigram::Size(counts[0]); - for (unsigned char i = 1; i < counts.size() - 1; ++i) { - ret += Middle::Size(counts[i], counts[0], counts[i+1]); + static const ModelType kModelType = Quant::kModelType; + + static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) { + Quant::UpdateConfigFromBinary(fd, counts, config); } - return ret + Longest::Size(counts.back(), counts[0]); - } - - uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &/*config*/) { - unigram.Init(start); - start += Unigram::Size(counts[0]); - middle.resize(counts.size() - 2); - for (unsigned char i = 1; i < counts.size() - 1; ++i) { - middle[i-1].Init( - start, - counts[0], - counts[i+1], - (i == counts.size() - 2) ? static_cast<const BitPacked&>(longest) : static_cast<const BitPacked &>(middle[i])); - start += Middle::Size(counts[i], counts[0], counts[i+1]); + + static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) { + std::size_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]); + for (unsigned char i = 1; i < counts.size() - 1; ++i) { + ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1]); + } + return ret + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); } - longest.Init(start, counts[0]); - return start + Longest::Size(counts.back(), counts[0]); - } - - void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); - - bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { - return unigram.Find(word, prob, backoff, node); - } - - bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const { - return mid.Find(word, prob, backoff, node); - } - - bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const { - return mid.FindNoProb(word, backoff, node); - } - - bool LookupLongest(WordIndex word, float &prob, const Node &node) const { - return longest.Find(word, prob, node); - } - - bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { - // TODO: don't decode backoff. - assert(begin != end); - float ignored_prob, ignored_backoff; - LookupUnigram(*begin, ignored_prob, ignored_backoff, node); - for (const WordIndex *i = begin + 1; i < end; ++i) { - if (!LookupMiddleNoProb(middle[i - begin - 1], *i, ignored_backoff, node)) return false; + + TrieSearch() : middle_begin_(NULL), middle_end_(NULL) {} + + ~TrieSearch() { FreeMiddles(); } + + uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config); + + void LoadedBinary(); + + const Middle *MiddleBegin() const { return middle_begin_; } + const Middle *MiddleEnd() const { return middle_end_; } + + void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); + + bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { + return unigram.Find(word, prob, backoff, node); + } + + bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const { + return mid.Find(word, prob, backoff, node); } - return true; - } + + bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const { + return mid.FindNoProb(word, backoff, node); + } + + bool LookupLongest(WordIndex word, float &prob, const Node &node) const { + return longest.Find(word, prob, node); + } + + bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { + // TODO: don't decode backoff. + assert(begin != end); + float ignored_prob, ignored_backoff; + LookupUnigram(*begin, ignored_prob, ignored_backoff, node); + for (const WordIndex *i = begin + 1; i < end; ++i) { + if (!LookupMiddleNoProb(middle_begin_[i - begin - 1], *i, ignored_backoff, node)) return false; + } + return true; + } + + private: + friend void BuildTrie<Quant>(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant> &out, Quant &quant, Backing &backing); + + // Middles are managed manually so we can delay construction and they don't have to be copyable. + void FreeMiddles() { + for (const Middle *i = middle_begin_; i != middle_end_; ++i) { + i->~Middle(); + } + free(middle_begin_); + } + + Middle *middle_begin_, *middle_end_; + Quant quant_; }; } // namespace trie |