summary | refs | log | tree | commit | diff
path: root/klm/lm/search_trie.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/search_trie.hh')
-rw-r--r-- klm/lm/search_trie.hh | 132
1 files changed, 74 insertions, 58 deletions
diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh
index 0f720217..0a52acb5 100644
--- a/klm/lm/search_trie.hh
+++ b/klm/lm/search_trie.hh
@@ -13,72 +13,88 @@ struct Backing;
class SortedVocabulary;
namespace trie {
-struct TrieSearch {
- typedef NodeRange Node;
+template <class Quant> class TrieSearch;
+template <class Quant> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant> &out, Quant &quant, Backing &backing);
- typedef ::lm::ngram::trie::Unigram Unigram;
- Unigram unigram;
+template <class Quant> class TrieSearch {
+ public:
+ typedef NodeRange Node;
- typedef trie::BitPackedMiddle Middle;
- std::vector<Middle> middle;
+ typedef ::lm::ngram::trie::Unigram Unigram;
+ Unigram unigram;
- typedef trie::BitPackedLongest Longest;
- Longest longest;
+ typedef trie::BitPackedMiddle<typename Quant::Middle> Middle;
- static const ModelType kModelType = TRIE_SORTED;
+ typedef trie::BitPackedLongest<typename Quant::Longest> Longest;
+ Longest longest;
- static std::size_t Size(const std::vector<uint64_t> &counts, const Config &/*config*/) {
- std::size_t ret = Unigram::Size(counts[0]);
- for (unsigned char i = 1; i < counts.size() - 1; ++i) {
- ret += Middle::Size(counts[i], counts[0], counts[i+1]);
+ static const ModelType kModelType = Quant::kModelType;
+
+ static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {
+ Quant::UpdateConfigFromBinary(fd, counts, config);
}
- return ret + Longest::Size(counts.back(), counts[0]);
- }
-
- uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &/*config*/) {
- unigram.Init(start);
- start += Unigram::Size(counts[0]);
- middle.resize(counts.size() - 2);
- for (unsigned char i = 1; i < counts.size() - 1; ++i) {
- middle[i-1].Init(
- start,
- counts[0],
- counts[i+1],
- (i == counts.size() - 2) ? static_cast<const BitPacked&>(longest) : static_cast<const BitPacked &>(middle[i]));
- start += Middle::Size(counts[i], counts[0], counts[i+1]);
+
+ static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) {
+ std::size_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]);
+ for (unsigned char i = 1; i < counts.size() - 1; ++i) {
+ ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1]);
+ }
+ return ret + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);
}
- longest.Init(start, counts[0]);
- return start + Longest::Size(counts.back(), counts[0]);
- }
-
- void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing);
-
- bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const {
- return unigram.Find(word, prob, backoff, node);
- }
-
- bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const {
- return mid.Find(word, prob, backoff, node);
- }
-
- bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const {
- return mid.FindNoProb(word, backoff, node);
- }
-
- bool LookupLongest(WordIndex word, float &prob, const Node &node) const {
- return longest.Find(word, prob, node);
- }
-
- bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
- // TODO: don't decode backoff.
- assert(begin != end);
- float ignored_prob, ignored_backoff;
- LookupUnigram(*begin, ignored_prob, ignored_backoff, node);
- for (const WordIndex *i = begin + 1; i < end; ++i) {
- if (!LookupMiddleNoProb(middle[i - begin - 1], *i, ignored_backoff, node)) return false;
+
+ TrieSearch() : middle_begin_(NULL), middle_end_(NULL) {}
+
+ ~TrieSearch() { FreeMiddles(); }
+
+ uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config);
+
+ void LoadedBinary();
+
+ const Middle *MiddleBegin() const { return middle_begin_; }
+ const Middle *MiddleEnd() const { return middle_end_; }
+
+ void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing);
+
+ bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const {
+ return unigram.Find(word, prob, backoff, node);
+ }
+
+ bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const {
+ return mid.Find(word, prob, backoff, node);
}
- return true;
- }
+
+ bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const {
+ return mid.FindNoProb(word, backoff, node);
+ }
+
+ bool LookupLongest(WordIndex word, float &prob, const Node &node) const {
+ return longest.Find(word, prob, node);
+ }
+
+ bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
+ // TODO: don't decode backoff.
+ assert(begin != end);
+ float ignored_prob, ignored_backoff;
+ LookupUnigram(*begin, ignored_prob, ignored_backoff, node);
+ for (const WordIndex *i = begin + 1; i < end; ++i) {
+ if (!LookupMiddleNoProb(middle_begin_[i - begin - 1], *i, ignored_backoff, node)) return false;
+ }
+ return true;
+ }
+
+ private:
+ friend void BuildTrie<Quant>(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant> &out, Quant &quant, Backing &backing);
+
+ // Middles are managed manually so we can delay construction and they don't have to be copyable.
+ void FreeMiddles() {
+ for (const Middle *i = middle_begin_; i != middle_end_; ++i) {
+ i->~Middle();
+ }
+ free(middle_begin_);
+ }
+
+ Middle *middle_begin_, *middle_end_;
+ Quant quant_;
};
} // namespace trie