From 2b63fa0755954edf467a2421997eaf72771260cf Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Wed, 16 May 2012 13:24:08 -0700 Subject: Big kenlm change includes lower order models for probing only. And other stuff. --- klm/lm/trie.cc | 61 +++++++++++++++++++--------------------------------------- 1 file changed, 20 insertions(+), 41 deletions(-) (limited to 'klm/lm/trie.cc') diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index 20075bb8..0f1ca574 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -1,7 +1,6 @@ #include "lm/trie.hh" #include "lm/bhiksha.hh" -#include "lm/quantize.hh" #include "util/bit_packing.hh" #include "util/exception.hh" #include "util/sorted_uniform.hh" @@ -58,91 +57,71 @@ void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits) max_vocab_ = max_vocab; } -template std::size_t BitPackedMiddle::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) { +template std::size_t BitPackedMiddle::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) { return Bhiksha::Size(entries + 1, max_ptr, config) + BaseSize(entries, max_vocab, quant_bits + Bhiksha::InlineBits(entries + 1, max_ptr, config)); } -template BitPackedMiddle::BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config) : +template BitPackedMiddle::BitPackedMiddle(void *base, uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config) : BitPacked(), - quant_(quant), + quant_bits_(quant_bits), // If the offset of the method changes, also change TrieSearch::UpdateConfigFromBinary. bhiksha_(base, entries + 1, max_next, config), next_source_(&next_source) { if (entries + 1 >= (1ULL << 57) || (max_next >= (1ULL << 57))) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions."); - BaseInit(reinterpret_cast(base) + Bhiksha::Size(entries + 1, max_next, config), max_vocab, quant.TotalBits() + bhiksha_.InlineBits()); + BaseInit(reinterpret_cast(base) + Bhiksha::Size(entries + 1, max_next, config), max_vocab, quant_bits_ + bhiksha_.InlineBits()); } -template void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff) { +template util::BitAddress BitPackedMiddle::Insert(WordIndex word) { assert(word <= word_mask_); uint64_t at_pointer = insert_index_ * total_bits_; util::WriteInt57(base_, at_pointer, word_bits_, word); at_pointer += word_bits_; - quant_.Write(base_, at_pointer, prob, backoff); - at_pointer += quant_.TotalBits(); + util::BitAddress ret(base_, at_pointer); + at_pointer += quant_bits_; uint64_t next = next_source_->InsertIndex(); bhiksha_.WriteNext(base_, at_pointer, insert_index_, next); - ++insert_index_; + return ret; } -template bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range, uint64_t &pointer) const { +template util::BitAddress BitPackedMiddle::Find(WordIndex word, NodeRange &range, uint64_t &pointer) const { uint64_t at_pointer; if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) { - return false; + return util::BitAddress(NULL, 0); } pointer = at_pointer; at_pointer *= total_bits_; at_pointer += word_bits_; + bhiksha_.ReadNext(base_, at_pointer + quant_bits_, pointer, total_bits_, range); - quant_.Read(base_, at_pointer, prob, backoff); - at_pointer += quant_.TotalBits(); - - bhiksha_.ReadNext(base_, at_pointer, pointer, total_bits_, range); - - return true; + return util::BitAddress(base_, at_pointer); } -template bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const { - uint64_t index; - if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, index)) return false; - uint64_t at_pointer = index * total_bits_; - at_pointer += word_bits_; - quant_.ReadBackoff(base_, at_pointer, backoff); - at_pointer += quant_.TotalBits(); - bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range); - return true; -} - -template void BitPackedMiddle::FinishedLoading(uint64_t next_end, const Config &config) { +template void BitPackedMiddle::FinishedLoading(uint64_t next_end, const Config &config) { uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - bhiksha_.InlineBits(); bhiksha_.WriteNext(base_, last_next_write, insert_index_ + 1, next_end); bhiksha_.FinishedLoading(config); } -template void BitPackedLongest::Insert(WordIndex index, float prob) { +util::BitAddress BitPackedLongest::Insert(WordIndex index) { assert(index <= word_mask_); uint64_t at_pointer = insert_index_ * total_bits_; util::WriteInt57(base_, at_pointer, word_bits_, index); at_pointer += word_bits_; - quant_.Write(base_, at_pointer, prob); ++insert_index_; + return util::BitAddress(base_, at_pointer); } -template bool BitPackedLongest::Find(WordIndex word, float &prob, const NodeRange &range) const { +util::BitAddress BitPackedLongest::Find(WordIndex word, const NodeRange &range) const { uint64_t at_pointer; - if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false; + if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return util::BitAddress(NULL, 0); at_pointer = at_pointer * total_bits_ + word_bits_; - quant_.Read(base_, at_pointer, prob); - return true; + return util::BitAddress(base_, at_pointer); } -template class BitPackedMiddle; -template class BitPackedMiddle; -template class BitPackedMiddle; -template class BitPackedMiddle; -template class BitPackedLongest; -template class BitPackedLongest; +template class BitPackedMiddle; +template class BitPackedMiddle; } // namespace trie } // namespace ngram -- cgit v1.2.3