diff options
author | Patrick Simianer <simianer@cl.uni-heidelberg.de> | 2012-05-31 13:57:24 +0200 |
---|---|---|
committer | Patrick Simianer <simianer@cl.uni-heidelberg.de> | 2012-05-31 13:57:24 +0200 |
commit | f1ba05780db1705493d9afb562332498b93d26f1 (patch) | |
tree | fb429a657ba97f33e8140742de9bc74d9fc88e75 /klm/lm/search_trie.hh | |
parent | aadabfdf37dfd451485277cb77fad02f77b361c6 (diff) | |
parent | 317d650f6cb1e24ac6f3be6f7bf9d4246a59e0e5 (diff) |
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'klm/lm/search_trie.hh')
-rw-r--r-- | klm/lm/search_trie.hh | 71 |
1 files changed, 34 insertions, 37 deletions
diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 5155ca02..10b22ab1 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -28,13 +28,11 @@ template <class Quant, class Bhiksha> class TrieSearch { public: typedef NodeRange Node; - typedef ::lm::ngram::trie::Unigram Unigram; - Unigram unigram; - - typedef trie::BitPackedMiddle<typename Quant::Middle, Bhiksha> Middle; + typedef ::lm::ngram::trie::UnigramPointer UnigramPointer; + typedef typename Quant::MiddlePointer MiddlePointer; + typedef typename Quant::LongestPointer LongestPointer; - typedef trie::BitPackedLongest<typename Quant::Longest> Longest; - Longest longest; + static const bool kDifferentRest = false; static const ModelType kModelType = static_cast<ModelType>(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd); @@ -62,55 +60,46 @@ template <class Quant, class Bhiksha> class TrieSearch { void LoadedBinary(); - typedef const Middle *MiddleIter; + void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); - const Middle *MiddleBegin() const { return middle_begin_; } - const Middle *MiddleEnd() const { return middle_end_; } + unsigned char Order() const { + return middle_end_ - middle_begin_ + 2; + } - void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); + ProbBackoff &UnknownUnigram() { return unigram_.Unknown(); } - void LookupUnigram(WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const { - unigram.Find(word, ret.prob, backoff, node); - ret.independent_left = (node.begin == node.end); - ret.extend_left = static_cast<uint64_t>(word); + UnigramPointer LookupUnigram(WordIndex word, Node &next, bool &independent_left, uint64_t &extend_left) const { + extend_left = static_cast<uint64_t>(word); + UnigramPointer ret(unigram_.Find(word, next)); + independent_left = (next.begin == next.end); + return ret; } - bool LookupMiddle(const Middle &mid, WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const { - if (!mid.Find(word, ret.prob, backoff, node, ret.extend_left)) return false; - ret.independent_left = (node.begin == node.end); - return true; + MiddlePointer Unpack(uint64_t extend_pointer, unsigned char extend_length, Node &node) const { + return MiddlePointer(quant_, extend_length - 2, middle_begin_[extend_length - 2].ReadEntry(extend_pointer, node)); } - bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const { - return mid.FindNoProb(word, backoff, node); + MiddlePointer LookupMiddle(unsigned char order_minus_2, WordIndex word, Node &node, bool &independent_left, uint64_t &extend_left) const { + util::BitAddress address(middle_begin_[order_minus_2].Find(word, node, extend_left)); + independent_left = (address.base == NULL) || (node.begin == node.end); + return MiddlePointer(quant_, order_minus_2, address); } - bool LookupLongest(WordIndex word, float &prob, const Node &node) const { - return longest.Find(word, prob, node); + LongestPointer LookupLongest(WordIndex word, const Node &node) const { + return LongestPointer(quant_, longest_.Find(word, node)); } bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { - // TODO: don't decode backoff. assert(begin != end); - FullScoreReturn ignored; - float ignored_backoff; - LookupUnigram(*begin, ignored_backoff, node, ignored); + bool independent_left; + uint64_t ignored; + LookupUnigram(*begin, node, independent_left, ignored); for (const WordIndex *i = begin + 1; i < end; ++i) { - if (!LookupMiddleNoProb(middle_begin_[i - begin - 1], *i, ignored_backoff, node)) return false; + if (independent_left || !LookupMiddle(i - begin - 1, *i, node, independent_left, ignored).Found()) return false; } return true; } - Node Unpack(uint64_t extend_pointer, unsigned char extend_length, float &prob) const { - if (extend_length == 1) { - float ignored; - Node ret; - unigram.Find(static_cast<WordIndex>(extend_pointer), prob, ignored, ret); - return ret; - } - return middle_begin_[extend_length - 2].ReadEntry(extend_pointer, prob); - } - private: friend void BuildTrie<Quant, Bhiksha>(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); @@ -122,8 +111,16 @@ template <class Quant, class Bhiksha> class TrieSearch { free(middle_begin_); } + typedef trie::BitPackedMiddle<Bhiksha> Middle; + + typedef trie::BitPackedLongest Longest; + Longest longest_; + Middle *middle_begin_, *middle_end_; Quant quant_; + + typedef ::lm::ngram::trie::Unigram Unigram; + Unigram unigram_; }; } // namespace trie |