summaryrefslogtreecommitdiff
path: root/klm/lm/trie.cc
diff options
context:
space:
mode:
authorPatrick Simianer <simianer@cl.uni-heidelberg.de>2012-05-31 13:57:24 +0200
committerPatrick Simianer <simianer@cl.uni-heidelberg.de>2012-05-31 13:57:24 +0200
commitf1ba05780db1705493d9afb562332498b93d26f1 (patch)
treefb429a657ba97f33e8140742de9bc74d9fc88e75 /klm/lm/trie.cc
parentaadabfdf37dfd451485277cb77fad02f77b361c6 (diff)
parent317d650f6cb1e24ac6f3be6f7bf9d4246a59e0e5 (diff)
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'klm/lm/trie.cc')
-rw-r--r--klm/lm/trie.cc61
1 files changed, 20 insertions, 41 deletions
diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc
index 20075bb8..0f1ca574 100644
--- a/klm/lm/trie.cc
+++ b/klm/lm/trie.cc
@@ -1,7 +1,6 @@
#include "lm/trie.hh"
#include "lm/bhiksha.hh"
-#include "lm/quantize.hh"
#include "util/bit_packing.hh"
#include "util/exception.hh"
#include "util/sorted_uniform.hh"
@@ -58,91 +57,71 @@ void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits)
max_vocab_ = max_vocab;
}
-template <class Quant, class Bhiksha> std::size_t BitPackedMiddle<Quant, Bhiksha>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) {
+template <class Bhiksha> std::size_t BitPackedMiddle<Bhiksha>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) {
return Bhiksha::Size(entries + 1, max_ptr, config) + BaseSize(entries, max_vocab, quant_bits + Bhiksha::InlineBits(entries + 1, max_ptr, config));
}
-template <class Quant, class Bhiksha> BitPackedMiddle<Quant, Bhiksha>::BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config) :
+template <class Bhiksha> BitPackedMiddle<Bhiksha>::BitPackedMiddle(void *base, uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config) :
BitPacked(),
- quant_(quant),
+ quant_bits_(quant_bits),
// If the offset of the method changes, also change TrieSearch::UpdateConfigFromBinary.
bhiksha_(base, entries + 1, max_next, config),
next_source_(&next_source) {
if (entries + 1 >= (1ULL << 57) || (max_next >= (1ULL << 57))) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions.");
- BaseInit(reinterpret_cast<uint8_t*>(base) + Bhiksha::Size(entries + 1, max_next, config), max_vocab, quant.TotalBits() + bhiksha_.InlineBits());
+ BaseInit(reinterpret_cast<uint8_t*>(base) + Bhiksha::Size(entries + 1, max_next, config), max_vocab, quant_bits_ + bhiksha_.InlineBits());
}
-template <class Quant, class Bhiksha> void BitPackedMiddle<Quant, Bhiksha>::Insert(WordIndex word, float prob, float backoff) {
+template <class Bhiksha> util::BitAddress BitPackedMiddle<Bhiksha>::Insert(WordIndex word) {
assert(word <= word_mask_);
uint64_t at_pointer = insert_index_ * total_bits_;
util::WriteInt57(base_, at_pointer, word_bits_, word);
at_pointer += word_bits_;
- quant_.Write(base_, at_pointer, prob, backoff);
- at_pointer += quant_.TotalBits();
+ util::BitAddress ret(base_, at_pointer);
+ at_pointer += quant_bits_;
uint64_t next = next_source_->InsertIndex();
bhiksha_.WriteNext(base_, at_pointer, insert_index_, next);
-
++insert_index_;
+ return ret;
}
-template <class Quant, class Bhiksha> bool BitPackedMiddle<Quant, Bhiksha>::Find(WordIndex word, float &prob, float &backoff, NodeRange &range, uint64_t &pointer) const {
+template <class Bhiksha> util::BitAddress BitPackedMiddle<Bhiksha>::Find(WordIndex word, NodeRange &range, uint64_t &pointer) const {
uint64_t at_pointer;
if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) {
- return false;
+ return util::BitAddress(NULL, 0);
}
pointer = at_pointer;
at_pointer *= total_bits_;
at_pointer += word_bits_;
+ bhiksha_.ReadNext(base_, at_pointer + quant_bits_, pointer, total_bits_, range);
- quant_.Read(base_, at_pointer, prob, backoff);
- at_pointer += quant_.TotalBits();
-
- bhiksha_.ReadNext(base_, at_pointer, pointer, total_bits_, range);
-
- return true;
+ return util::BitAddress(base_, at_pointer);
}
-template <class Quant, class Bhiksha> bool BitPackedMiddle<Quant, Bhiksha>::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const {
- uint64_t index;
- if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, index)) return false;
- uint64_t at_pointer = index * total_bits_;
- at_pointer += word_bits_;
- quant_.ReadBackoff(base_, at_pointer, backoff);
- at_pointer += quant_.TotalBits();
- bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range);
- return true;
-}
-
-template <class Quant, class Bhiksha> void BitPackedMiddle<Quant, Bhiksha>::FinishedLoading(uint64_t next_end, const Config &config) {
+template <class Bhiksha> void BitPackedMiddle<Bhiksha>::FinishedLoading(uint64_t next_end, const Config &config) {
uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - bhiksha_.InlineBits();
bhiksha_.WriteNext(base_, last_next_write, insert_index_ + 1, next_end);
bhiksha_.FinishedLoading(config);
}
-template <class Quant> void BitPackedLongest<Quant>::Insert(WordIndex index, float prob) {
+util::BitAddress BitPackedLongest::Insert(WordIndex index) {
assert(index <= word_mask_);
uint64_t at_pointer = insert_index_ * total_bits_;
util::WriteInt57(base_, at_pointer, word_bits_, index);
at_pointer += word_bits_;
- quant_.Write(base_, at_pointer, prob);
++insert_index_;
+ return util::BitAddress(base_, at_pointer);
}
-template <class Quant> bool BitPackedLongest<Quant>::Find(WordIndex word, float &prob, const NodeRange &range) const {
+util::BitAddress BitPackedLongest::Find(WordIndex word, const NodeRange &range) const {
uint64_t at_pointer;
- if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false;
+ if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return util::BitAddress(NULL, 0);
at_pointer = at_pointer * total_bits_ + word_bits_;
- quant_.Read(base_, at_pointer, prob);
- return true;
+ return util::BitAddress(base_, at_pointer);
}
-template class BitPackedMiddle<DontQuantize::Middle, DontBhiksha>;
-template class BitPackedMiddle<DontQuantize::Middle, ArrayBhiksha>;
-template class BitPackedMiddle<SeparatelyQuantize::Middle, DontBhiksha>;
-template class BitPackedMiddle<SeparatelyQuantize::Middle, ArrayBhiksha>;
-template class BitPackedLongest<DontQuantize::Longest>;
-template class BitPackedLongest<SeparatelyQuantize::Longest>;
+template class BitPackedMiddle<DontBhiksha>;
+template class BitPackedMiddle<ArrayBhiksha>;
} // namespace trie
} // namespace ngram