summaryrefslogtreecommitdiff
path: root/klm/lm/trie.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/trie.hh')
-rw-r--r--klm/lm/trie.hh61
1 files changed, 36 insertions, 25 deletions
diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh
index ebe9910f..eff93292 100644
--- a/klm/lm/trie.hh
+++ b/klm/lm/trie.hh
@@ -1,12 +1,13 @@
#ifndef LM_TRIE__
#define LM_TRIE__
-#include <stdint.h>
+#include "lm/weights.hh"
+#include "lm/word_index.hh"
+#include "util/bit_packing.hh"
#include <cstddef>
-#include "lm/word_index.hh"
-#include "lm/weights.hh"
+#include <stdint.h>
namespace lm {
namespace ngram {
@@ -24,6 +25,22 @@ struct UnigramValue {
uint64_t Next() const { return next; }
};
+class UnigramPointer {
+ public:
+ explicit UnigramPointer(const ProbBackoff &to) : to_(&to) {}
+
+ UnigramPointer() : to_(NULL) {}
+
+ bool Found() const { return to_ != NULL; }
+
+ float Prob() const { return to_->prob; }
+ float Backoff() const { return to_->backoff; }
+ float Rest() const { return Prob(); }
+
+ private:
+ const ProbBackoff *to_;
+};
+
class Unigram {
public:
Unigram() {}
@@ -47,12 +64,11 @@ class Unigram {
void LoadedBinary() {}
- void Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const {
+ UnigramPointer Find(WordIndex word, NodeRange &next) const {
UnigramValue *val = unigram_ + word;
- prob = val->weights.prob;
- backoff = val->weights.backoff;
next.begin = val->next;
next.end = (val+1)->next;
+ return UnigramPointer(val->weights);
}
private:
@@ -81,40 +97,36 @@ class BitPacked {
uint64_t insert_index_, max_vocab_;
};
-template <class Quant, class Bhiksha> class BitPackedMiddle : public BitPacked {
+template <class Bhiksha> class BitPackedMiddle : public BitPacked {
public:
static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config);
// next_source need not be initialized.
- BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config);
+ BitPackedMiddle(void *base, uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config);
- void Insert(WordIndex word, float prob, float backoff);
+ util::BitAddress Insert(WordIndex word);
void FinishedLoading(uint64_t next_end, const Config &config);
void LoadedBinary() { bhiksha_.LoadedBinary(); }
- bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range, uint64_t &pointer) const;
-
- bool FindNoProb(WordIndex word, float &backoff, NodeRange &range) const;
+ util::BitAddress Find(WordIndex word, NodeRange &range, uint64_t &pointer) const;
- NodeRange ReadEntry(uint64_t pointer, float &prob) {
+ util::BitAddress ReadEntry(uint64_t pointer, NodeRange &range) {
uint64_t addr = pointer * total_bits_;
addr += word_bits_;
- quant_.ReadProb(base_, addr, prob);
- NodeRange ret;
- bhiksha_.ReadNext(base_, addr + quant_.TotalBits(), pointer, total_bits_, ret);
- return ret;
+ bhiksha_.ReadNext(base_, addr + quant_bits_, pointer, total_bits_, range);
+ return util::BitAddress(base_, addr);
}
private:
- Quant quant_;
+ uint8_t quant_bits_;
Bhiksha bhiksha_;
const BitPacked *next_source_;
};
-template <class Quant> class BitPackedLongest : public BitPacked {
+class BitPackedLongest : public BitPacked {
public:
static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) {
return BaseSize(entries, max_vocab, quant_bits);
@@ -122,19 +134,18 @@ template <class Quant> class BitPackedLongest : public BitPacked {
BitPackedLongest() {}
- void Init(void *base, const Quant &quant, uint64_t max_vocab) {
- quant_ = quant;
- BaseInit(base, max_vocab, quant_.TotalBits());
+ void Init(void *base, uint8_t quant_bits, uint64_t max_vocab) {
+ BaseInit(base, max_vocab, quant_bits);
}
void LoadedBinary() {}
- void Insert(WordIndex word, float prob);
+ util::BitAddress Insert(WordIndex word);
- bool Find(WordIndex word, float &prob, const NodeRange &node) const;
+ util::BitAddress Find(WordIndex word, const NodeRange &node) const;
private:
- Quant quant_;
+ uint8_t quant_bits_;
};
} // namespace trie