diff options
Diffstat (limited to 'klm/lm/trie.cc')
| -rw-r--r-- | klm/lm/trie.cc | 61 | 
1 files changed, 20 insertions, 41 deletions
| diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index 20075bb8..0f1ca574 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -1,7 +1,6 @@  #include "lm/trie.hh"  #include "lm/bhiksha.hh" -#include "lm/quantize.hh"  #include "util/bit_packing.hh"  #include "util/exception.hh"  #include "util/sorted_uniform.hh" @@ -58,91 +57,71 @@ void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits)    max_vocab_ = max_vocab;  } -template <class Quant, class Bhiksha> std::size_t BitPackedMiddle<Quant, Bhiksha>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) { +template <class Bhiksha> std::size_t BitPackedMiddle<Bhiksha>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) {    return Bhiksha::Size(entries + 1, max_ptr, config) + BaseSize(entries, max_vocab, quant_bits + Bhiksha::InlineBits(entries + 1, max_ptr, config));  } -template <class Quant, class Bhiksha> BitPackedMiddle<Quant, Bhiksha>::BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config) : +template <class Bhiksha> BitPackedMiddle<Bhiksha>::BitPackedMiddle(void *base, uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config) :    BitPacked(), -  quant_(quant), +  quant_bits_(quant_bits),    // If the offset of the method changes, also change TrieSearch::UpdateConfigFromBinary.    bhiksha_(base, entries + 1, max_next, config),    next_source_(&next_source) {    if (entries + 1 >= (1ULL << 57) || (max_next >= (1ULL << 57)))  UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order.  Edit util/bit_packing.hh and fix the bit packing functions."); -  BaseInit(reinterpret_cast<uint8_t*>(base) + Bhiksha::Size(entries + 1, max_next, config), max_vocab, quant.TotalBits() + bhiksha_.InlineBits()); +  BaseInit(reinterpret_cast<uint8_t*>(base) + Bhiksha::Size(entries + 1, max_next, config), max_vocab, quant_bits_ + bhiksha_.InlineBits());  } -template <class Quant, class Bhiksha> void BitPackedMiddle<Quant, Bhiksha>::Insert(WordIndex word, float prob, float backoff) { +template <class Bhiksha> util::BitAddress BitPackedMiddle<Bhiksha>::Insert(WordIndex word) {    assert(word <= word_mask_);    uint64_t at_pointer = insert_index_ * total_bits_;    util::WriteInt57(base_, at_pointer, word_bits_, word);    at_pointer += word_bits_; -  quant_.Write(base_, at_pointer, prob, backoff); -  at_pointer += quant_.TotalBits(); +  util::BitAddress ret(base_, at_pointer); +  at_pointer += quant_bits_;    uint64_t next = next_source_->InsertIndex();    bhiksha_.WriteNext(base_, at_pointer, insert_index_, next); -    ++insert_index_; +  return ret;  } -template <class Quant, class Bhiksha> bool BitPackedMiddle<Quant, Bhiksha>::Find(WordIndex word, float &prob, float &backoff, NodeRange &range, uint64_t &pointer) const { +template <class Bhiksha> util::BitAddress BitPackedMiddle<Bhiksha>::Find(WordIndex word, NodeRange &range, uint64_t &pointer) const {    uint64_t at_pointer;    if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) { -    return false; +    return util::BitAddress(NULL, 0);    }    pointer = at_pointer;    at_pointer *= total_bits_;    at_pointer += word_bits_; +  bhiksha_.ReadNext(base_, at_pointer + quant_bits_, pointer, total_bits_, range); -  quant_.Read(base_, at_pointer, prob, backoff); -  at_pointer += quant_.TotalBits(); - -  bhiksha_.ReadNext(base_, at_pointer, pointer, total_bits_, range); - -  return true; +  return util::BitAddress(base_, at_pointer);  } -template <class Quant, class Bhiksha> bool BitPackedMiddle<Quant, Bhiksha>::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const { -  uint64_t index; -  if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, index)) return false; -  uint64_t at_pointer = index * total_bits_; -  at_pointer += word_bits_; -  quant_.ReadBackoff(base_, at_pointer, backoff); -  at_pointer += quant_.TotalBits(); -  bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range); -  return true; -} - -template <class Quant, class Bhiksha> void BitPackedMiddle<Quant, Bhiksha>::FinishedLoading(uint64_t next_end, const Config &config) { +template <class Bhiksha> void BitPackedMiddle<Bhiksha>::FinishedLoading(uint64_t next_end, const Config &config) {    uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - bhiksha_.InlineBits();    bhiksha_.WriteNext(base_, last_next_write, insert_index_ + 1, next_end);    bhiksha_.FinishedLoading(config);  } -template <class Quant> void BitPackedLongest<Quant>::Insert(WordIndex index, float prob) { +util::BitAddress BitPackedLongest::Insert(WordIndex index) {    assert(index <= word_mask_);    uint64_t at_pointer = insert_index_ * total_bits_;    util::WriteInt57(base_, at_pointer, word_bits_, index);    at_pointer += word_bits_; -  quant_.Write(base_, at_pointer, prob);    ++insert_index_; +  return util::BitAddress(base_, at_pointer);  } -template <class Quant> bool BitPackedLongest<Quant>::Find(WordIndex word, float &prob, const NodeRange &range) const { +util::BitAddress BitPackedLongest::Find(WordIndex word, const NodeRange &range) const {    uint64_t at_pointer; -  if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false; +  if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return util::BitAddress(NULL, 0);    at_pointer = at_pointer * total_bits_ + word_bits_; -  quant_.Read(base_, at_pointer, prob); -  return true; +  return util::BitAddress(base_, at_pointer);  } -template class BitPackedMiddle<DontQuantize::Middle, DontBhiksha>; -template class BitPackedMiddle<DontQuantize::Middle, ArrayBhiksha>; -template class BitPackedMiddle<SeparatelyQuantize::Middle, DontBhiksha>; -template class BitPackedMiddle<SeparatelyQuantize::Middle, ArrayBhiksha>; -template class BitPackedLongest<DontQuantize::Longest>; -template class BitPackedLongest<SeparatelyQuantize::Longest>; +template class BitPackedMiddle<DontBhiksha>; +template class BitPackedMiddle<ArrayBhiksha>;  } // namespace trie  } // namespace ngram | 
