diff options
Diffstat (limited to 'klm/lm/binary_format.cc')
-rw-r--r-- | klm/lm/binary_format.cc | 16 |
1 files changed, 6 insertions, 10 deletions
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index e02e621a..27cada13 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -19,10 +19,10 @@ namespace lm { namespace ngram { namespace { const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version"; -const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 4\n\0"; +const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n\0"; // This must be shorter than kMagicBytes and indicates an incomplete binary file (i.e. build failed). const char kMagicIncomplete[] = "mmap lm http://kheafield.com/code incomplete\n"; -const long int kMagicVersion = 4; +const long int kMagicVersion = 5; // Test values. struct Sanity { @@ -42,12 +42,6 @@ struct Sanity { const char *kModelNames[6] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"}; -std::size_t Align8(std::size_t in) { - std::size_t off = in % 8; - if (!off) return in; - return in + 8 - off; -} - std::size_t TotalHeaderSize(unsigned char order) { return Align8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order); } @@ -119,7 +113,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t } } -void FinishFile(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, Backing &backing) { +void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, Backing &backing) { if (config.write_mmap) { if (msync(backing.search.get(), backing.search.size(), MS_SYNC) || msync(backing.vocab.get(), backing.vocab.size(), MS_SYNC)) UTIL_THROW(util::ErrnoException, "msync failed for " << config.write_mmap); @@ -130,6 +124,7 @@ void FinishFile(const Config &config, ModelType model_type, const std::vector<ui params.fixed.probing_multiplier = config.probing_multiplier; params.fixed.model_type = model_type; params.fixed.has_vocabulary = config.include_vocab; + params.fixed.search_version = search_version; WriteHeader(backing.vocab.get(), params); } } @@ -174,12 +169,13 @@ void ReadHeader(int fd, Parameters &out) { ReadLoop(fd, &*out.counts.begin(), sizeof(uint64_t) * out.fixed.order); } -void MatchCheck(ModelType model_type, const Parameters ¶ms) { +void MatchCheck(ModelType model_type, unsigned int search_version, const Parameters ¶ms) { if (params.fixed.model_type != model_type) { if (static_cast<unsigned int>(params.fixed.model_type) >= (sizeof(kModelNames) / sizeof(const char *))) UTIL_THROW(FormatLoadException, "The binary file claims to be model type " << static_cast<unsigned int>(params.fixed.model_type) << " but this is not implemented for in this inference code."); UTIL_THROW(FormatLoadException, "The binary file was built for " << kModelNames[params.fixed.model_type] << " but the inference code is trying to load " << kModelNames[model_type]); } + UTIL_THROW_IF(search_version != params.fixed.search_version, FormatLoadException, "The binary file has " << kModelNames[params.fixed.model_type] << " version " << params.fixed.search_version << " but this code expects " << kModelNames[params.fixed.model_type] << " version " << search_version); } void SeekPastHeader(int fd, const Parameters ¶ms) { |