diff options
Diffstat (limited to 'klm/lm/binary_format.cc')
-rw-r--r-- | klm/lm/binary_format.cc | 73 |
1 files changed, 43 insertions, 30 deletions
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index 69a06355..3d9700da 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -18,8 +18,8 @@ namespace lm { namespace ngram { namespace { const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version"; -const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 1\n\0"; -const long int kMagicVersion = 1; +const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 3\n\0"; +const long int kMagicVersion = 2; // Test values. struct Sanity { @@ -76,6 +76,45 @@ void WriteHeader(void *to, const Parameters &params) { } } // namespace + +uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing) { + if (config.write_mmap) { + std::size_t total = TotalHeaderSize(order) + memory_size; + backing.vocab.reset(util::MapZeroedWrite(config.write_mmap, total, backing.file), total, util::scoped_memory::MMAP_ALLOCATED); + return reinterpret_cast<uint8_t*>(backing.vocab.get()) + TotalHeaderSize(order); + } else { + backing.vocab.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED); + return reinterpret_cast<uint8_t*>(backing.vocab.get()); + } +} + +uint8_t *GrowForSearch(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, std::size_t memory_size, Backing &backing) { + if (config.write_mmap) { + // header and vocab share the same mmap. The header is written here because we know the counts. + Parameters params; + params.counts = counts; + params.fixed.order = counts.size(); + params.fixed.probing_multiplier = config.probing_multiplier; + params.fixed.model_type = model_type; + params.fixed.has_vocabulary = config.include_vocab; + WriteHeader(backing.vocab.get(), params); + + // Grow the file to accomodate the search, using zeros. 
+ if (-1 == ftruncate(backing.file.get(), backing.vocab.size() + memory_size)) + UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (backing.vocab.size() + memory_size) << " failed"); + + // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down. + off_t page_size = sysconf(_SC_PAGE_SIZE); + off_t alignment_cruft = backing.vocab.size() % page_size; + backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), backing.vocab.size() - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED); + + return reinterpret_cast<uint8_t*>(backing.search.get()) + alignment_cruft; + } else { + backing.search.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED); + return reinterpret_cast<uint8_t*>(backing.search.get()); + } +} + namespace detail { bool IsBinaryFormat(int fd) { @@ -128,7 +167,7 @@ uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map) UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map); - util::MapRead(config.load_method, backing.file.get(), 0, total_map, backing.memory); + util::MapRead(config.load_method, backing.file.get(), 0, total_map, backing.search); if (config.enumerate_vocab && !params.fixed.has_vocabulary) UTIL_THROW(FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. 
You may need to rebuild the binary file with an updated version of build_binary."); @@ -137,33 +176,7 @@ uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t if ((off_t)-1 == lseek(backing.file.get(), total_map, SEEK_SET)) UTIL_THROW(util::ErrnoException, "Failed to seek in binary file to vocab words"); } - return reinterpret_cast<uint8_t*>(backing.memory.get()) + TotalHeaderSize(params.counts.size()); -} - -uint8_t *SetupZeroed(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, std::size_t memory_size, Backing &backing) { - if (config.write_mmap) { - std::size_t total_map = TotalHeaderSize(counts.size()) + memory_size; - // Write out an mmap file. - backing.memory.reset(util::MapZeroedWrite(config.write_mmap, total_map, backing.file), total_map, util::scoped_memory::MMAP_ALLOCATED); - - Parameters params; - params.counts = counts; - params.fixed.order = counts.size(); - params.fixed.probing_multiplier = config.probing_multiplier; - params.fixed.model_type = model_type; - params.fixed.has_vocabulary = config.include_vocab; - - WriteHeader(backing.memory.get(), params); - - if (params.fixed.has_vocabulary) { - if ((off_t)-1 == lseek(backing.file.get(), total_map, SEEK_SET)) - UTIL_THROW(util::ErrnoException, "Failed to seek in binary file " << config.write_mmap << " to vocab words"); - } - return reinterpret_cast<uint8_t*>(backing.memory.get()) + TotalHeaderSize(counts.size()); - } else { - backing.memory.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED); - return reinterpret_cast<uint8_t*>(backing.memory.get()); - } + return reinterpret_cast<uint8_t*>(backing.search.get()) + TotalHeaderSize(params.counts.size()); } void ComplainAboutARPA(const Config &config, ModelType model_type) { |