summaryrefslogtreecommitdiff
path: root/klm/lm/binary_format.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/binary_format.cc')
-rw-r--r--klm/lm/binary_format.cc16
1 files changed, 6 insertions, 10 deletions
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc
index e02e621a..27cada13 100644
--- a/klm/lm/binary_format.cc
+++ b/klm/lm/binary_format.cc
@@ -19,10 +19,10 @@ namespace lm {
namespace ngram {
namespace {
const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version";
-const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 4\n\0";
+const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n\0";
// This must be shorter than kMagicBytes and indicates an incomplete binary file (i.e. build failed).
const char kMagicIncomplete[] = "mmap lm http://kheafield.com/code incomplete\n";
-const long int kMagicVersion = 4;
+const long int kMagicVersion = 5;
// Test values.
struct Sanity {
@@ -42,12 +42,6 @@ struct Sanity {
const char *kModelNames[6] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"};
-std::size_t Align8(std::size_t in) {
- std::size_t off = in % 8;
- if (!off) return in;
- return in + 8 - off;
-}
-
std::size_t TotalHeaderSize(unsigned char order) {
return Align8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order);
}
@@ -119,7 +113,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t
}
}
-void FinishFile(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, Backing &backing) {
+void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, Backing &backing) {
if (config.write_mmap) {
if (msync(backing.search.get(), backing.search.size(), MS_SYNC) || msync(backing.vocab.get(), backing.vocab.size(), MS_SYNC))
UTIL_THROW(util::ErrnoException, "msync failed for " << config.write_mmap);
@@ -130,6 +124,7 @@ void FinishFile(const Config &config, ModelType model_type, const std::vector<ui
params.fixed.probing_multiplier = config.probing_multiplier;
params.fixed.model_type = model_type;
params.fixed.has_vocabulary = config.include_vocab;
+ params.fixed.search_version = search_version;
WriteHeader(backing.vocab.get(), params);
}
}
@@ -174,12 +169,13 @@ void ReadHeader(int fd, Parameters &out) {
ReadLoop(fd, &*out.counts.begin(), sizeof(uint64_t) * out.fixed.order);
}
-void MatchCheck(ModelType model_type, const Parameters &params) {
+void MatchCheck(ModelType model_type, unsigned int search_version, const Parameters &params) {
if (params.fixed.model_type != model_type) {
if (static_cast<unsigned int>(params.fixed.model_type) >= (sizeof(kModelNames) / sizeof(const char *)))
UTIL_THROW(FormatLoadException, "The binary file claims to be model type " << static_cast<unsigned int>(params.fixed.model_type) << " but this is not implemented for in this inference code.");
UTIL_THROW(FormatLoadException, "The binary file was built for " << kModelNames[params.fixed.model_type] << " but the inference code is trying to load " << kModelNames[model_type]);
}
+ UTIL_THROW_IF(search_version != params.fixed.search_version, FormatLoadException, "The binary file has " << kModelNames[params.fixed.model_type] << " version " << params.fixed.search_version << " but this code expects " << kModelNames[params.fixed.model_type] << " version " << search_version);
}
void SeekPastHeader(int fd, const Parameters &params) {