summaryrefslogtreecommitdiff
path: root/klm/lm/trie.hh
blob: cd39298b53976682d17e2c4dbd11dbb1a15c3d32 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
#ifndef LM_TRIE_H
#define LM_TRIE_H

#include "lm/weights.hh"
#include "lm/word_index.hh"
#include "util/bit_packing.hh"

#include <cstddef>

#include <stdint.h>

namespace lm {
namespace ngram {
struct Config;
namespace trie {

// A range of positions in the next trie level.  As Unigram::Find fills it,
// `begin` comes from one record and `end` from the record after it, so the
// range reads as [begin, end).
struct NodeRange {
  uint64_t begin;
  uint64_t end;
};

// One record of the flat unigram table.  `next` is where this word's
// NodeRange begins; the following record's `next` is where it ends (see
// Unigram::Find).  Member order is kept as-is because records are addressed
// directly in caller-provided memory.
// TODO: if the number of unigrams is a concern, also bit pack these records.
struct UnigramValue {
  ProbBackoff weights;
  uint64_t next;

  // Accessor for the raw `next` field.
  uint64_t Next() const {
    return this->next;
  }
};

// Non-owning handle to one unigram's weights.  A default-constructed handle
// is empty and reports Found() == false.
class UnigramPointer {
  public:
    // Empty handle.
    UnigramPointer() : to_(NULL) {}

    // Handle that refers to (but does not own) `to`.
    explicit UnigramPointer(const ProbBackoff &to) : to_(&to) {}

    bool Found() const { return NULL != to_; }

    // The accessors below dereference the stored pointer, so they are only
    // meaningful when Found() is true.
    float Prob() const { return to_->prob; }
    float Backoff() const { return to_->backoff; }

    // There is no separate rest value here; report the probability.
    float Rest() const { return to_->prob; }

  private:
    const ProbBackoff *to_;
};

class Unigram {
  public:
    Unigram() {}
    
    void Init(void *start) {
      unigram_ = static_cast<UnigramValue*>(start);
    }
    
    static uint64_t Size(uint64_t count) {
      // +1 in case unknown doesn't appear.  +1 for the final next.  
      return (count + 2) * sizeof(UnigramValue);
    }
    
    const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index].weights; }
    
    ProbBackoff &Unknown() { return unigram_[0].weights; }

    UnigramValue *Raw() {
      return unigram_;
    }
    
    UnigramPointer Find(WordIndex word, NodeRange &next) const {
      UnigramValue *val = unigram_ + word;
      next.begin = val->next;
      next.end = (val+1)->next;
      return UnigramPointer(val->weights);
    }

  private:
    UnigramValue *unigram_;
};  

// Shared state for the bit-packed trie levels (BitPackedMiddle,
// BitPackedLongest).  BaseSize and BaseInit are defined out of line; the
// default constructor leaves every field for BaseInit to set.
class BitPacked {
  public:
    BitPacked() {}

    // Current value of insert_index_ (presumably the count of entries
    // inserted so far; it is maintained out of line, so confirm in trie.cc).
    uint64_t InsertIndex() const { return insert_index_; }

  protected:
    // Storage needed for `entries` records whose non-word payload is
    // `remaining_bits` wide; defined out of line.
    static uint64_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits);

    // Attach to `base` and set the field layout; defined out of line.
    void BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits);

    // Field layout.  Declaration order is unchanged from the original.
    uint8_t word_bits_;    // width of the word-id field
    uint8_t total_bits_;   // width of one whole entry
    uint64_t word_mask_;   // presumably masks a word id out of word_bits_

    uint8_t *base_;        // start of the packed entries

    uint64_t insert_index_;
    uint64_t max_vocab_;
};

// Bit-packed array for an intermediate order.  An entry occupies total_bits_
// bits: a word_bits_ word id first, then a quant_bits_ field, then the data
// bhiksha_ reads to recover the entry's range in the next level (layout as
// implied by ReadEntry's offsets).  All members except ReadEntry are defined
// out of line.
template <class Bhiksha> class BitPackedMiddle : public BitPacked {
  public:
    // Storage needed for `entries` records (presumably bytes, as with
    // Unigram::Size); defined out of line.
    static uint64_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config);

    // next_source need not be initialized.
    BitPackedMiddle(void *base, uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config);

    util::BitAddress Insert(WordIndex word);

    void FinishedLoading(uint64_t next_end, const Config &config);

    util::BitAddress Find(WordIndex word, NodeRange &range, uint64_t &pointer) const;

    // Decode entry number `pointer`.  `range` receives the next-level range
    // read by bhiksha_; the return value addresses the bit just past the word
    // id, where the quant_bits_ field starts.
    util::BitAddress ReadEntry(uint64_t pointer, NodeRange &range) {
      const uint64_t after_word = pointer * total_bits_ + word_bits_;
      bhiksha_.ReadNext(base_, after_word + quant_bits_, pointer, total_bits_, range);
      return util::BitAddress(base_, after_word);
    }

  private:
    uint8_t quant_bits_;
    Bhiksha bhiksha_;

    // Set from the constructor's next_source (out of line); presumably the
    // next order's array.
    const BitPacked *next_source_;
};

class BitPackedLongest : public BitPacked {
  public:
    static uint64_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) {
      return BaseSize(entries, max_vocab, quant_bits);
    }

    BitPackedLongest() {}

    void Init(void *base, uint8_t quant_bits, uint64_t max_vocab) {
      BaseInit(base, max_vocab, quant_bits);
    }

    util::BitAddress Insert(WordIndex word);

    util::BitAddress Find(WordIndex word, const NodeRange &node) const;
};

} // namespace trie
} // namespace ngram
} // namespace lm

#endif // LM_TRIE_H