summaryrefslogtreecommitdiff
path: root/klm/lm/trie.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/trie.hh')
-rw-r--r--klm/lm/trie.hh24
1 files changed, 12 insertions, 12 deletions
diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh
index 8fa21aaf..53612064 100644
--- a/klm/lm/trie.hh
+++ b/klm/lm/trie.hh
@@ -10,6 +10,7 @@
namespace lm {
namespace ngram {
+class Config;
namespace trie {
struct NodeRange {
@@ -46,13 +47,12 @@ class Unigram {
void LoadedBinary() {}
- bool Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const {
+ void Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const {
UnigramValue *val = unigram_ + word;
prob = val->weights.prob;
backoff = val->weights.backoff;
next.begin = val->next;
next.end = (val+1)->next;
- return true;
}
private:
@@ -67,8 +67,6 @@ class BitPacked {
return insert_index_;
}
- void LoadedBinary() {}
-
protected:
static std::size_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits);
@@ -83,30 +81,30 @@ class BitPacked {
uint64_t insert_index_, max_vocab_;
};
-template <class Quant> class BitPackedMiddle : public BitPacked {
+template <class Quant, class Bhiksha> class BitPackedMiddle : public BitPacked {
public:
- static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next);
+ static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config);
// next_source need not be initialized.
- BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source);
+ BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config);
void Insert(WordIndex word, float prob, float backoff);
+ void FinishedLoading(uint64_t next_end, const Config &config);
+
+ void LoadedBinary() { bhiksha_.LoadedBinary(); }
+
bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const;
bool FindNoProb(WordIndex word, float &backoff, NodeRange &range) const;
- void FinishedLoading(uint64_t next_end);
-
private:
Quant quant_;
- uint8_t next_bits_;
- uint64_t next_mask_;
+ Bhiksha bhiksha_;
const BitPacked *next_source_;
};
-
template <class Quant> class BitPackedLongest : public BitPacked {
public:
static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) {
@@ -120,6 +118,8 @@ template <class Quant> class BitPackedLongest : public BitPacked {
BaseInit(base, max_vocab, quant_.TotalBits());
}
+ void LoadedBinary() {}
+
void Insert(WordIndex word, float prob);
bool Find(WordIndex word, float &prob, const NodeRange &node) const;