summaryrefslogtreecommitdiff
path: root/klm/lm/search_trie.hh
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2012-05-16 13:24:08 -0700
committerChris Dyer <cdyer@cab.ark.cs.cmu.edu>2012-05-26 22:59:54 -0400
commit149232c38eec558ddb1097698d1570aacb67b59f (patch)
tree5860b4d6f681eeb04a1020cbb2fe7e6ac394af99 /klm/lm/search_trie.hh
parent01ecc09f8e3a82c32bf7dd2f90c12554becea71d (diff)
Big kenlm change includes lower order models for probing only. And other stuff.
Diffstat (limited to 'klm/lm/search_trie.hh')
-rw-r--r--klm/lm/search_trie.hh71
1 files changed, 34 insertions, 37 deletions
diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh
index 5155ca02..10b22ab1 100644
--- a/klm/lm/search_trie.hh
+++ b/klm/lm/search_trie.hh
@@ -28,13 +28,11 @@ template <class Quant, class Bhiksha> class TrieSearch {
public:
typedef NodeRange Node;
- typedef ::lm::ngram::trie::Unigram Unigram;
- Unigram unigram;
-
- typedef trie::BitPackedMiddle<typename Quant::Middle, Bhiksha> Middle;
+ typedef ::lm::ngram::trie::UnigramPointer UnigramPointer;
+ typedef typename Quant::MiddlePointer MiddlePointer;
+ typedef typename Quant::LongestPointer LongestPointer;
- typedef trie::BitPackedLongest<typename Quant::Longest> Longest;
- Longest longest;
+ static const bool kDifferentRest = false;
static const ModelType kModelType = static_cast<ModelType>(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd);
@@ -62,55 +60,46 @@ template <class Quant, class Bhiksha> class TrieSearch {
void LoadedBinary();
- typedef const Middle *MiddleIter;
+ void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing);
- const Middle *MiddleBegin() const { return middle_begin_; }
- const Middle *MiddleEnd() const { return middle_end_; }
+ unsigned char Order() const {
+ return middle_end_ - middle_begin_ + 2;
+ }
- void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing);
+ ProbBackoff &UnknownUnigram() { return unigram_.Unknown(); }
- void LookupUnigram(WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const {
- unigram.Find(word, ret.prob, backoff, node);
- ret.independent_left = (node.begin == node.end);
- ret.extend_left = static_cast<uint64_t>(word);
+ UnigramPointer LookupUnigram(WordIndex word, Node &next, bool &independent_left, uint64_t &extend_left) const {
+ extend_left = static_cast<uint64_t>(word);
+ UnigramPointer ret(unigram_.Find(word, next));
+ independent_left = (next.begin == next.end);
+ return ret;
}
- bool LookupMiddle(const Middle &mid, WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const {
- if (!mid.Find(word, ret.prob, backoff, node, ret.extend_left)) return false;
- ret.independent_left = (node.begin == node.end);
- return true;
+ MiddlePointer Unpack(uint64_t extend_pointer, unsigned char extend_length, Node &node) const {
+ return MiddlePointer(quant_, extend_length - 2, middle_begin_[extend_length - 2].ReadEntry(extend_pointer, node));
}
- bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const {
- return mid.FindNoProb(word, backoff, node);
+ MiddlePointer LookupMiddle(unsigned char order_minus_2, WordIndex word, Node &node, bool &independent_left, uint64_t &extend_left) const {
+ util::BitAddress address(middle_begin_[order_minus_2].Find(word, node, extend_left));
+ independent_left = (address.base == NULL) || (node.begin == node.end);
+ return MiddlePointer(quant_, order_minus_2, address);
}
- bool LookupLongest(WordIndex word, float &prob, const Node &node) const {
- return longest.Find(word, prob, node);
+ LongestPointer LookupLongest(WordIndex word, const Node &node) const {
+ return LongestPointer(quant_, longest_.Find(word, node));
}
bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
- // TODO: don't decode backoff.
assert(begin != end);
- FullScoreReturn ignored;
- float ignored_backoff;
- LookupUnigram(*begin, ignored_backoff, node, ignored);
+ bool independent_left;
+ uint64_t ignored;
+ LookupUnigram(*begin, node, independent_left, ignored);
for (const WordIndex *i = begin + 1; i < end; ++i) {
- if (!LookupMiddleNoProb(middle_begin_[i - begin - 1], *i, ignored_backoff, node)) return false;
+ if (independent_left || !LookupMiddle(i - begin - 1, *i, node, independent_left, ignored).Found()) return false;
}
return true;
}
- Node Unpack(uint64_t extend_pointer, unsigned char extend_length, float &prob) const {
- if (extend_length == 1) {
- float ignored;
- Node ret;
- unigram.Find(static_cast<WordIndex>(extend_pointer), prob, ignored, ret);
- return ret;
- }
- return middle_begin_[extend_length - 2].ReadEntry(extend_pointer, prob);
- }
-
private:
friend void BuildTrie<Quant, Bhiksha>(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);
@@ -122,8 +111,16 @@ template <class Quant, class Bhiksha> class TrieSearch {
free(middle_begin_);
}
+ typedef trie::BitPackedMiddle<Bhiksha> Middle;
+
+ typedef trie::BitPackedLongest Longest;
+ Longest longest_;
+
Middle *middle_begin_, *middle_end_;
Quant quant_;
+
+ typedef ::lm::ngram::trie::Unigram Unigram;
+ Unigram unigram_;
};
} // namespace trie