summaryrefslogtreecommitdiff
path: root/klm/lm
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm')
-rw-r--r--klm/lm/bhiksha.hh2
-rw-r--r--klm/lm/trie.cc6
-rw-r--r--klm/lm/trie.hh7
-rw-r--r--klm/lm/trie_sort.cc4
4 files changed, 10 insertions, 9 deletions
diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh
index ff7fe452..bc705959 100644
--- a/klm/lm/bhiksha.hh
+++ b/klm/lm/bhiksha.hh
@@ -11,6 +11,7 @@
*/
#include <inttypes.h>
+#include <assert.h>
#include "lm/model_type.hh"
#include "lm/trie.hh"
@@ -78,6 +79,7 @@ class ArrayBhiksha {
util::ReadInt57(base, bit_offset, next_inline_.bits, next_inline_.mask);
out.end = ((end_it - offset_begin_) << next_inline_.bits) |
util::ReadInt57(base, bit_offset + total_bits, next_inline_.bits, next_inline_.mask);
+ //assert(out.end >= out.begin);
}
void WriteNext(void *base, uint64_t bit_offset, uint64_t index, uint64_t value) {
diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc
index 4e60b184..20075bb8 100644
--- a/klm/lm/trie.cc
+++ b/klm/lm/trie.cc
@@ -91,16 +91,14 @@ template <class Quant, class Bhiksha> bool BitPackedMiddle<Quant, Bhiksha>::Find
if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) {
return false;
}
- uint64_t index = at_pointer;
+ pointer = at_pointer;
at_pointer *= total_bits_;
at_pointer += word_bits_;
- pointer = at_pointer;
-
quant_.Read(base_, at_pointer, prob, backoff);
at_pointer += quant_.TotalBits();
- bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range);
+ bhiksha_.ReadNext(base_, at_pointer, pointer, total_bits_, range);
return true;
}
diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh
index a9f5e417..06cc96ac 100644
--- a/klm/lm/trie.hh
+++ b/klm/lm/trie.hh
@@ -99,10 +99,11 @@ template <class Quant, class Bhiksha> class BitPackedMiddle : public BitPacked {
bool FindNoProb(WordIndex word, float &backoff, NodeRange &range) const;
NodeRange ReadEntry(uint64_t pointer, float &prob) {
- quant_.ReadProb(base_, pointer, prob);
+ uint64_t addr = pointer * total_bits_;
+ addr += word_bits_;
+ quant_.ReadProb(base_, addr, prob);
NodeRange ret;
- // pointer/total_bits_ should always round down.
- bhiksha_.ReadNext(base_, pointer + quant_.TotalBits(), pointer / total_bits_, total_bits_, ret);
+ bhiksha_.ReadNext(base_, addr + quant_.TotalBits(), pointer, total_bits_, ret);
return ret;
}
diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc
index 01c4e490..86f28493 100644
--- a/klm/lm/trie_sort.cc
+++ b/klm/lm/trie_sort.cc
@@ -146,7 +146,7 @@ template <class Combine> void MergeSortedFiles(const std::string &first_name, co
++first; ++second;
}
}
- for (RecordReader &remains = (first ? second : first); remains; ++remains) {
+ for (RecordReader &remains = (first ? first : second); remains; ++remains) {
WriteOrThrow(out_file.get(), remains.Data(), entry_size);
}
}
@@ -191,7 +191,7 @@ void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const st
assembled << file_prefix << static_cast<unsigned int>(order) << "_merge_" << (merge_count++);
files.push_back(assembled.str());
MergeSortedFiles(files[0], files[1], files.back(), weights_size, order, ThrowCombine());
- MergeSortedFiles(files[0], files[1], files.back(), 0, order, FirstCombine());
+ MergeSortedFiles(files[0] + kContextSuffix, files[1] + kContextSuffix, files.back() + kContextSuffix, 0, order, FirstCombine());
files.pop_front();
files.pop_front();
}