summaryrefslogtreecommitdiff
path: root/klm/lm/trie.cc
blob: 8ed7b2a2126d27bff820bddc07679f7bb35fc76c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
#include "lm/trie.hh"

#include "util/bit_packing.hh"
#include "util/exception.hh"
#include "util/proxy_iterator.hh"
#include "util/sorted_uniform.hh"

#include <assert.h>

namespace lm {
namespace ngram {
namespace trie {
namespace {

// Read-only proxy onto a single record of a bit-packed array that exposes
// nothing but the record's key.  The key must be the first field of every
// record.
class JustKeyProxy {
  public:
    JustKeyProxy() : inner_(), base_(), key_mask_(), total_bits_() {}

    // Converting the proxy to an integer yields the key.
    operator uint64_t() const { return GetKey(); }

    // Extracts the key of record number inner_.
    uint64_t GetKey() const {
      // Records are total_bits_ wide and packed back to back, so the key of
      // record inner_ starts at this bit offset from base_.
      const uint64_t record_width = static_cast<uint64_t>(total_bits_);
      const uint64_t first_bit = inner_ * record_width;
      const uint8_t *const first_byte = base_ + first_bit / 8;
      return util::ReadInt57(first_byte, first_bit & 7, key_mask_);
    }

  private:
    friend class util::ProxyIterator<JustKeyProxy>;
    friend bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index);

    // Only the friends above may build a proxy aimed at a real array.
    JustKeyProxy(const void *base, uint64_t index, uint64_t key_mask, uint8_t total_bits)
      : inner_(index),
        base_(static_cast<const uint8_t*>(base)),
        key_mask_(key_mask),
        total_bits_(total_bits) {}

    // Declared but never defined: values cannot be written through this proxy.
    JustKeyProxy &operator=(const JustKeyProxy &other);

    typedef uint64_t value_type;

    // The wrapped "iterator" is simply a record index.
    typedef uint64_t InnerIterator;
    uint64_t &Inner() { return inner_; }
    const uint64_t &Inner() const { return inner_; }

    // Bit address of the record: base_ * 8 + inner_ * total_bits_.
    uint64_t inner_;
    const uint8_t *const base_;
    const uint64_t key_mask_;
    const uint8_t total_bits_;
};

bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index) {
  util::ProxyIterator<JustKeyProxy> begin_it(JustKeyProxy(base, begin_index, key_mask, total_bits));
  util::ProxyIterator<JustKeyProxy> end_it(JustKeyProxy(base, end_index, key_mask, total_bits));
  util::ProxyIterator<JustKeyProxy> out;
  if (!util::SortedUniformFind(begin_it, end_it, key, out)) return false;
  at_index = out.Inner();
  return true;
}
} // namespace

// Number of bytes needed to store `entries` records, each holding a word
// index, a 31-bit probability and `remaining_bits` further bits.
std::size_t BitPacked::BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits) {
  const uint8_t bits_per_entry = util::RequiredBits(max_vocab) + 31 + remaining_bits;
  // One record more than requested: the trailing record carries the final
  // next pointer.
  const uint64_t packed_bits = (entries + 1) * bits_per_entry;
  // Round the bit count up to whole bytes.
  const uint64_t packed_bytes = (packed_bits + 7) / 8;
  // Pad by sizeof(uint64_t) so that ReadInt57 and friends never touch memory
  // past the allocation.  This waste is O(order), not O(number of ngrams).
  return packed_bytes + sizeof(uint64_t);
}

// Prepares the common part of a bit-packed array stored at `base`.
//   base            start of the storage the records will be written to
//   max_vocab       largest word index that must be representable
//   remaining_bits  bits each record needs beyond the word index and the
//                   31-bit probability
// Throws util::Exception when a word index would need more than 57 bits.
void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits) {
  util::BitPackingSanity();
  word_bits_ = util::RequiredBits(max_vocab);
  // Reject oversized widths before building the mask: shifting a 64-bit value
  // by 64 or more is undefined, so the check has to come first.
  if (word_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, word indices more than " << (1ULL << 57) << " are not implemented.  Edit util/bit_packing.hh and fix the bit packing functions.");
  word_mask_ = (1ULL << word_bits_) - 1ULL;
  prob_bits_ = 31;
  total_bits_ = word_bits_ + prob_bits_ + remaining_bits;

  base_ = static_cast<uint8_t*>(base);
  // Nothing has been inserted yet.
  insert_index_ = 0;
}

// Bytes required for a middle-order array of `entries` records.  Beyond the
// common fields each record carries a 32-bit backoff and a pointer wide
// enough to hold max_ptr.
std::size_t BitPackedMiddle::Size(uint64_t entries, uint64_t max_vocab, uint64_t max_ptr) {
  const uint8_t pointer_bits = util::RequiredBits(max_ptr);
  return BaseSize(entries, max_vocab, 32 + pointer_bits);
}

// Sets up a middle-order array at `base`.  max_next is the largest value a
// next pointer has to hold.  Throws util::Exception when that pointer would
// need more than 57 bits.
void BitPackedMiddle::Init(void *base, uint64_t max_vocab, uint64_t max_next) {
  next_bits_ = util::RequiredBits(max_next);
  if (next_bits_ > 57) {
    UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order.  Edit util/bit_packing.hh and fix the bit packing functions.");
  }
  next_mask_ = (1ULL << next_bits_) - 1;
  backoff_bits_ = 32;

  // Everything after word and probability: the backoff and the next pointer.
  BaseInit(base, max_vocab, backoff_bits_ + next_bits_);
}

// Appends one record (word, prob, backoff, next) at the current insertion
// position and advances that position by one.
void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff, uint64_t next) {
  assert(word <= word_mask_);
  assert(next <= next_mask_);

  // Bit offsets of the four fields, laid out back to back inside the record.
  const uint64_t word_at = insert_index_ * total_bits_;
  const uint64_t prob_at = word_at + word_bits_;
  const uint64_t backoff_at = prob_at + prob_bits_;
  const uint64_t next_at = backoff_at + backoff_bits_;

  // Each write takes the containing byte plus the bit position within it.
  util::WriteInt57(base_ + (word_at >> 3), word_at & 7, word);
  util::WriteNonPositiveFloat31(base_ + (prob_at >> 3), prob_at & 7, prob);
  util::WriteFloat32(base_ + (backoff_at >> 3), backoff_at & 7, backoff);
  util::WriteInt57(base_ + (next_at >> 3), next_at & 7, next);

  ++insert_index_;
}

// Searches records [range.begin, range.end) for `word`.  When found, fills in
// prob and backoff from the record, replaces `range` with the span delimited
// by this record's next pointer and the following record's next pointer, and
// returns true.  Returns false, changing nothing, when the word is absent.
bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const {
  uint64_t record_index;
  if (!FindBitPacked(base_, word_mask_, total_bits_, range.begin, range.end, word, record_index)) return false;

  // Bit offsets of the fields that follow the word inside the matched record.
  const uint64_t prob_at = record_index * total_bits_ + word_bits_;
  const uint64_t backoff_at = prob_at + prob_bits_;
  const uint64_t next_at = backoff_at + backoff_bits_;
  // The same field one record further on marks where the range ends.
  const uint64_t following_next_at = next_at + total_bits_;

  prob = util::ReadNonPositiveFloat31(base_ + (prob_at >> 3), prob_at & 7);
  backoff = util::ReadFloat32(base_ + (backoff_at >> 3), backoff_at & 7);
  range.begin = util::ReadInt57(base_ + (next_at >> 3), next_at & 7, next_mask_);
  range.end = util::ReadInt57(base_ + (following_next_at >> 3), following_next_at & 7, next_mask_);
  return true;
}

// Like Find, except that the stored probability is skipped: only backoff and
// the updated range are produced.  Returns false, changing nothing, when
// `word` is not among records [range.begin, range.end).
bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const {
  uint64_t record_index;
  if (!FindBitPacked(base_, word_mask_, total_bits_, range.begin, range.end, word, record_index)) return false;

  // Step over the word and probability fields to reach the backoff.
  const uint64_t backoff_at = record_index * total_bits_ + word_bits_ + prob_bits_;
  const uint64_t next_at = backoff_at + backoff_bits_;
  // The same field one record further on marks where the range ends.
  const uint64_t following_next_at = next_at + total_bits_;

  backoff = util::ReadFloat32(base_ + (backoff_at >> 3), backoff_at & 7);
  range.begin = util::ReadInt57(base_ + (next_at >> 3), next_at & 7, next_mask_);
  range.end = util::ReadInt57(base_ + (following_next_at >> 3), following_next_at & 7, next_mask_);
  return true;
}

void BitPackedMiddle::FinishedLoading(uint64_t next_end) {
  assert(next_end <= next_mask_);
  uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - next_bits_;
  util::WriteInt57(base_ + (last_next_write >> 3), last_next_write & 7, next_end);
}


// Appends one (word index, probability) record at the current insertion
// position and advances that position by one.
void BitPackedLongest::Insert(WordIndex index, float prob) {
  assert(index <= word_mask_);

  // The word comes first in the record; the probability follows directly.
  const uint64_t word_at = insert_index_ * total_bits_;
  const uint64_t prob_at = word_at + word_bits_;

  util::WriteInt57(base_ + (word_at >> 3), word_at & 7, index);
  util::WriteNonPositiveFloat31(base_ + (prob_at >> 3), prob_at & 7, prob);

  ++insert_index_;
}

// Searches records [node.begin, node.end) for `word`.  On a hit stores the
// record's probability in prob and returns true; otherwise returns false and
// leaves prob untouched.
bool BitPackedLongest::Find(WordIndex word, float &prob, const NodeRange &node) const {
  uint64_t record_index;
  const bool found = FindBitPacked(base_, word_mask_, total_bits_, node.begin, node.end, word, record_index);
  if (!found) return false;

  // The probability sits immediately after the word inside the record.
  const uint64_t prob_at = record_index * total_bits_ + word_bits_;
  prob = util::ReadNonPositiveFloat31(base_ + (prob_at >> 3), prob_at & 7);
  return true;
}

} // namespace trie
} // namespace ngram
} // namespace lm