summaryrefslogtreecommitdiff
path: root/klm/lm/trie.hh
diff options
context:
space:
mode:
authorredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-11-10 22:45:13 +0000
committerredpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-11-10 22:45:13 +0000
commit48c848908a391c157af4ad9266e8b616e8106d2e (patch)
treecf9444481236b5deb32dede4d9c6496b46a2d011 /klm/lm/trie.hh
parent3911c3c95647f97cdfffa1ae4a8ddc7f06d51b81 (diff)
forgotten files
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@710 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'klm/lm/trie.hh')
-rw-r--r--klm/lm/trie.hh129
1 files changed, 129 insertions, 0 deletions
diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh
new file mode 100644
index 00000000..35dc2c96
--- /dev/null
+++ b/klm/lm/trie.hh
@@ -0,0 +1,129 @@
+#ifndef LM_TRIE__
+#define LM_TRIE__
+
+#include <inttypes.h>
+
+#include <cstddef>
+
+#include "lm/word_index.hh"
+#include "lm/weights.hh"
+
+namespace lm {
+namespace ngram {
+namespace trie {
+
+struct NodeRange {
+ uint64_t begin, end;
+};
+
+// TODO: if the number of unigrams is a concern, also bit pack these records.
+struct UnigramValue {
+ ProbBackoff weights;
+ uint64_t next;
+ uint64_t Next() const { return next; }
+};
+
+class Unigram {
+ public:
+ Unigram() {}
+
+ void Init(void *start) {
+ unigram_ = static_cast<UnigramValue*>(start);
+ }
+
+ static std::size_t Size(uint64_t count) {
+ // +1 in case unknown doesn't appear. +1 for the final next.
+ return (count + 2) * sizeof(UnigramValue);
+ }
+
+ const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index].weights; }
+
+ ProbBackoff &Unknown() { return unigram_[0].weights; }
+
+ UnigramValue *Raw() {
+ return unigram_;
+ }
+
+ void LoadedBinary() {}
+
+ bool Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const {
+ UnigramValue *val = unigram_ + word;
+ prob = val->weights.prob;
+ backoff = val->weights.backoff;
+ next.begin = val->next;
+ next.end = (val+1)->next;
+ return true;
+ }
+
+ private:
+ UnigramValue *unigram_;
+};
+
+class BitPacked {
+ public:
+ BitPacked() {}
+
+ uint64_t InsertIndex() const {
+ return insert_index_;
+ }
+
+ void LoadedBinary() {}
+
+ protected:
+ static std::size_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits);
+
+ void BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits);
+
+ uint8_t word_bits_, prob_bits_;
+ uint8_t total_bits_;
+ uint64_t word_mask_;
+
+ uint8_t *base_;
+
+ uint64_t insert_index_;
+};
+
+class BitPackedMiddle : public BitPacked {
+ public:
+ BitPackedMiddle() {}
+
+ static std::size_t Size(uint64_t entries, uint64_t max_vocab, uint64_t max_next);
+
+ void Init(void *base, uint64_t max_vocab, uint64_t max_next);
+
+ void Insert(WordIndex word, float prob, float backoff, uint64_t next);
+
+ bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const;
+
+ bool FindNoProb(WordIndex word, float &backoff, NodeRange &range) const;
+
+ void FinishedLoading(uint64_t next_end);
+
+ private:
+ uint8_t backoff_bits_, next_bits_;
+ uint64_t next_mask_;
+};
+
+
+class BitPackedLongest : public BitPacked {
+ public:
+ BitPackedLongest() {}
+
+ static std::size_t Size(uint64_t entries, uint64_t max_vocab) {
+ return BaseSize(entries, max_vocab, 0);
+ }
+
+ void Init(void *base, uint64_t max_vocab) {
+ return BaseInit(base, max_vocab, 0);
+ }
+
+ void Insert(WordIndex word, float prob);
+
+ bool Find(WordIndex word, float &prob, const NodeRange &node) const;
+};
+
+} // namespace trie
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_TRIE__