#ifndef LM_SEARCH_TRIE__ #define LM_SEARCH_TRIE__ #include "lm/binary_format.hh" #include "lm/trie.hh" #include "lm/weights.hh" #include <assert.h> namespace lm { namespace ngram { struct Backing; class SortedVocabulary; namespace trie { struct TrieSearch { typedef NodeRange Node; typedef ::lm::ngram::trie::Unigram Unigram; Unigram unigram; typedef trie::BitPackedMiddle Middle; std::vector<Middle> middle; typedef trie::BitPackedLongest Longest; Longest longest; static const ModelType kModelType = TRIE_SORTED; static std::size_t Size(const std::vector<uint64_t> &counts, const Config &/*config*/) { std::size_t ret = Unigram::Size(counts[0]); for (unsigned char i = 1; i < counts.size() - 1; ++i) { ret += Middle::Size(counts[i], counts[0], counts[i+1]); } return ret + Longest::Size(counts.back(), counts[0]); } uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &/*config*/) { unigram.Init(start); start += Unigram::Size(counts[0]); middle.resize(counts.size() - 2); for (unsigned char i = 1; i < counts.size() - 1; ++i) { middle[i-1].Init( start, counts[0], counts[i+1], (i == counts.size() - 2) ? static_cast<const BitPacked&>(longest) : static_cast<const BitPacked &>(middle[i])); start += Middle::Size(counts[i], counts[0], counts[i+1]); } longest.Init(start, counts[0]); return start + Longest::Size(counts.back(), counts[0]); } void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { return unigram.Find(word, prob, backoff, node); } bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const { return mid.Find(word, prob, backoff, node); } bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const { return mid.FindNoProb(word, backoff, node); } bool LookupLongest(WordIndex word, float &prob, const Node &node) const { return longest.Find(word, prob, node); } bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { // TODO: don't decode backoff. assert(begin != end); float ignored_prob, ignored_backoff; LookupUnigram(*begin, ignored_prob, ignored_backoff, node); for (const WordIndex *i = begin + 1; i < end; ++i) { if (!LookupMiddleNoProb(middle[i - begin - 1], *i, ignored_backoff, node)) return false; } return true; } }; } // namespace trie } // namespace ngram } // namespace lm #endif // LM_SEARCH_TRIE__