summaryrefslogtreecommitdiff
path: root/klm/lm/trie.cc
blob: 8c536e66771634ef7aa612e37087408fab18cd71 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#include "lm/trie.hh"

#include "lm/bhiksha.hh"
#include "lm/quantize.hh"
#include "util/bit_packing.hh"
#include "util/exception.hh"
#include "util/sorted_uniform.hh"

#include <assert.h>

namespace lm {
namespace ngram {
namespace trie {
namespace {

// Accessor that reads the key (word index) field of fixed-width bit-packed
// records, so the records can be searched with util::BoundedSortedUniformFind.
// Record i starts at bit i * total_bits, and its key is the first key_bits bits
// of the record, extracted with key_mask.
class KeyAccessor {
  public:
    KeyAccessor(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits)
      : base_(reinterpret_cast<const uint8_t*>(base)), key_mask_(key_mask), key_bits_(key_bits), total_bits_(total_bits) {}

    typedef uint64_t Key;

    // Returns the key stored in record `index`.
    Key operator()(uint64_t index) const {
      return util::ReadInt57(base_, index * static_cast<uint64_t>(total_bits_), key_bits_, key_mask_);
    }

  private:
    const uint8_t *const base_;
    // Kept at full 64-bit width to match the constructor parameter and the
    // mask passed to ReadInt57.  Declaring it as WordIndex silently narrowed
    // the mask whenever key_bits exceeded WordIndex's width.
    const uint64_t key_mask_;
    const uint8_t key_bits_, total_bits_;
};

bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, const uint64_t max_vocab, const uint64_t key, uint64_t &at_index) {
  KeyAccessor accessor(base, key_mask, key_bits, total_bits);
  if (!util::BoundedSortedUniformFind<uint64_t, KeyAccessor, util::PivotSelect<sizeof(WordIndex)>::T>(accessor, begin_index - 1, (uint64_t)0, end_index, max_vocab, key, at_index)) return false;
  return true;
}
} // namespace

// Number of bytes needed to store `entries` bit-packed records.  Each record
// is a word index wide enough for max_vocab followed by remaining_bits of
// other data.
std::size_t BitPacked::BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits) {
  const uint8_t bits_per_record = util::RequiredBits(max_vocab) + remaining_bits;
  // Reserve one record beyond `entries`; the trailing record holds the final
  // next pointer.
  const uint64_t record_bits = (entries + 1) * bits_per_record;
  // Round the bit count up to whole bytes.
  const uint64_t record_bytes = (record_bits + 7) / 8;
  // Pad with one uint64_t so that reads by ReadInt57 and friends near the end
  // of the array stay inside the allocation.  The padding is paid once per
  // array (O(order)), not once per n-gram.
  return record_bytes + sizeof(uint64_t);
}

// Sets up the bit-packed record array that starts at `base`.  Records hold a
// word index (wide enough for max_vocab) followed by remaining_bits of other
// data.  Throws util::Exception if word indices need more than 57 bits, the
// limit of the util bit packing functions.
void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits) {
  util::BitPackingSanity();
  word_bits_ = util::RequiredBits(max_vocab);
  // Validate before building the mask.  Shifting 1ULL by 64 or more bits is
  // undefined behavior, and RequiredBits of a 64-bit max_vocab can presumably
  // reach 64.
  if (word_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, word indices more than " << (1ULL << 57) << " are not implemented.  Edit util/bit_packing.hh and fix the bit packing functions.");
  word_mask_ = (1ULL << word_bits_) - 1ULL;
  total_bits_ = word_bits_ + remaining_bits;

  base_ = static_cast<uint8_t*>(base);
  insert_index_ = 0;
  max_vocab_ = max_vocab;
}

// Total bytes for a middle-order array: the Bhiksha storage followed by the
// bit-packed records.  Each record carries quant_bits plus the bits Bhiksha
// keeps inline beside the word index.
template <class Quant, class Bhiksha> std::size_t BitPackedMiddle<Quant, Bhiksha>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) {
  // Bhiksha is sized for entries + 1 records.  BaseSize adds its own extra
  // trailing record internally, so it receives plain `entries`.
  const uint64_t records_with_sentinel = entries + 1;
  const uint8_t payload_bits = quant_bits + Bhiksha::InlineBits(records_with_sentinel, max_ptr, config);
  const std::size_t record_bytes = BaseSize(entries, max_vocab, payload_bits);
  return Bhiksha::Size(records_with_sentinel, max_ptr, config) + record_bytes;
}

// Builds a middle-order array in `base`.  The first Bhiksha::Size(entries + 1,
// max_next, config) bytes are handed to bhiksha_, and the bit-packed records
// follow them.  next_source is the array whose InsertIndex() supplies each new
// record's next pointer.  Throws util::Exception when entries + 1 or max_next
// reaches 2^57.
template <class Quant, class Bhiksha> BitPackedMiddle<Quant, Bhiksha>::BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config) :
  BitPacked(),
  quant_(quant),
  // If the offset of the method changes, also change TrieSearch::UpdateConfigFromBinary.
  bhiksha_(base, entries + 1, max_next, config),
  next_source_(&next_source) {
  const uint64_t record_count = entries + 1;
  const bool exceeds_packing = (record_count >= (1ULL << 57)) || (max_next >= (1ULL << 57));
  if (exceeds_packing) {
    UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order.  Edit util/bit_packing.hh and fix the bit packing functions.");
  }
  // The records begin right after the bytes reserved for bhiksha_.
  uint8_t *const records_start = reinterpret_cast<uint8_t*>(base) + Bhiksha::Size(record_count, max_next, config);
  const uint8_t payload_bits = quant.TotalBits() + bhiksha_.InlineBits();
  BaseInit(records_start, max_vocab, payload_bits);
}

// Appends one record at insert_index_ with the layout
// [word | prob/backoff as encoded by Quant | Bhiksha inline next bits].
// The next pointer is next_source_->InsertIndex() at the time of the call.
template <class Quant, class Bhiksha> void BitPackedMiddle<Quant, Bhiksha>::Insert(WordIndex word, float prob, float backoff) {
  assert(word <= word_mask_);
  const uint64_t record_start = insert_index_ * total_bits_;
  const uint64_t weights_start = record_start + word_bits_;
  const uint64_t next_start = weights_start + quant_.TotalBits();

  util::WriteInt57(base_, record_start, word_bits_, word);
  quant_.Write(base_, weights_start, prob, backoff);
  // Read the child array's insertion point only after the weights are written,
  // as the original statement order did.
  const uint64_t child_begin = next_source_->InsertIndex();
  bhiksha_.WriteNext(base_, next_start, insert_index_, child_begin);

  ++insert_index_;
}

// Looks up `word` within the records [range.begin, range.end).  On a hit it
// decodes prob and backoff, overwrites `range` with the matching record's next
// range via bhiksha_.ReadNext, and returns true.  Returns false otherwise.
template <class Quant, class Bhiksha> bool BitPackedMiddle<Quant, Bhiksha>::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const {
  uint64_t index;
  const bool found = FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, index);
  if (!found) return false;

  // Skip the word field to reach the weights; the next bits follow them.
  const uint64_t weights_start = index * total_bits_ + word_bits_;
  quant_.Read(base_, weights_start, prob, backoff);
  const uint64_t next_start = weights_start + quant_.TotalBits();
  bhiksha_.ReadNext(base_, next_start, index, total_bits_, range);
  return true;
}

// Like Find, but decodes only the backoff (via quant_.ReadBackoff).  On a hit
// `range` is replaced with the matching record's next range.
template <class Quant, class Bhiksha> bool BitPackedMiddle<Quant, Bhiksha>::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const {
  uint64_t index;
  const bool found = FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, index);
  if (found) {
    const uint64_t weights_start = index * total_bits_ + word_bits_;
    quant_.ReadBackoff(base_, weights_start, backoff);
    bhiksha_.ReadNext(base_, weights_start + quant_.TotalBits(), index, total_bits_, range);
  }
  return found;
}

// Completes loading by filling in the trailing record (the extra record
// reserved by BaseSize).  Only that record's next-pointer field is written.
// It receives next_end, presumably read back as the end of the last real
// record's next range.  Afterwards bhiksha_ is told that loading is complete.
template <class Quant, class Bhiksha> void BitPackedMiddle<Quant, Bhiksha>::FinishedLoading(uint64_t next_end, const Config &config) {
  // The trailing record sits at index insert_index_.  Its next field occupies
  // the last InlineBits() bits of that record, which is the same offset Insert
  // computes (word_bits_ + quant bits past the record start).
  const uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - bhiksha_.InlineBits();
  // Label the write with the trailing record's own index, exactly as Insert
  // does for every other record.  Passing insert_index_ + 1 here (as before)
  // could mis-record the index for Bhiksha implementations that use it.
  bhiksha_.WriteNext(base_, last_next_write, insert_index_, next_end);
  bhiksha_.FinishedLoading(config);
}

// Appends one record [word | prob as encoded by Quant] at insert_index_.
// Records in this array carry no next pointer.
template <class Quant> void BitPackedLongest<Quant>::Insert(WordIndex index, float prob) {
  assert(index <= word_mask_);
  const uint64_t record_start = insert_index_ * total_bits_;
  util::WriteInt57(base_, record_start, word_bits_, index);
  quant_.Write(base_, record_start + word_bits_, prob);
  ++insert_index_;
}

// Looks up `word` within the records [range.begin, range.end).  On a hit it
// decodes the record's prob and returns true; otherwise it returns false and
// leaves prob untouched.
template <class Quant> bool BitPackedLongest<Quant>::Find(WordIndex word, float &prob, const NodeRange &range) const {
  uint64_t index;
  const bool found = FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, index);
  if (found) {
    quant_.Read(base_, index * total_bits_ + word_bits_, prob);
  }
  return found;
}

// The member templates of BitPackedMiddle and BitPackedLongest are defined in
// this .cc file.  These explicit instantiations make the listed
// Quant/Bhiksha combinations available to other translation units.  A new
// quantizer or Bhiksha variant needs a matching line here.
template class BitPackedMiddle<DontQuantize::Middle, DontBhiksha>;
template class BitPackedMiddle<DontQuantize::Middle, ArrayBhiksha>;
template class BitPackedMiddle<SeparatelyQuantize::Middle, DontBhiksha>;
template class BitPackedMiddle<SeparatelyQuantize::Middle, ArrayBhiksha>;
template class BitPackedLongest<DontQuantize::Longest>;
template class BitPackedLongest<SeparatelyQuantize::Longest>;

} // namespace trie
} // namespace ngram
} // namespace lm