summaryrefslogtreecommitdiff
path: root/klm/lm/binary_format.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/binary_format.cc')
-rw-r--r--klm/lm/binary_format.cc259
1 files changed, 152 insertions, 107 deletions
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc
index 39c4a9b6..9c744b13 100644
--- a/klm/lm/binary_format.cc
+++ b/klm/lm/binary_format.cc
@@ -8,11 +8,15 @@
#include <cstring>
#include <limits>
#include <string>
+#include <cstdlib>
#include <stdint.h>
namespace lm {
namespace ngram {
+
+const char *kModelNames[6] = {"probing hash tables", "probing hash tables with rest costs", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"};
+
namespace {
const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version";
const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n\0";
@@ -57,8 +61,6 @@ struct Sanity {
}
};
-const char *kModelNames[6] = {"probing hash tables", "probing hash tables with rest costs", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"};
-
std::size_t TotalHeaderSize(unsigned char order) {
return ALIGN8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order);
}
@@ -80,83 +82,6 @@ void WriteHeader(void *to, const Parameters &params) {
} // namespace
-uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing) {
- if (config.write_mmap) {
- std::size_t total = TotalHeaderSize(order) + memory_size;
- backing.file.reset(util::CreateOrThrow(config.write_mmap));
- if (config.write_method == Config::WRITE_MMAP) {
- backing.vocab.reset(util::MapZeroedWrite(backing.file.get(), total), total, util::scoped_memory::MMAP_ALLOCATED);
- } else {
- util::ResizeOrThrow(backing.file.get(), 0);
- util::MapAnonymous(total, backing.vocab);
- }
- strncpy(reinterpret_cast<char*>(backing.vocab.get()), kMagicIncomplete, TotalHeaderSize(order));
- return reinterpret_cast<uint8_t*>(backing.vocab.get()) + TotalHeaderSize(order);
- } else {
- util::MapAnonymous(memory_size, backing.vocab);
- return reinterpret_cast<uint8_t*>(backing.vocab.get());
- }
-}
-
-uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing) {
- std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad;
- if (config.write_mmap) {
- // Grow the file to accomodate the search, using zeros.
- try {
- util::ResizeOrThrow(backing.file.get(), adjusted_vocab + memory_size);
- } catch (util::ErrnoException &e) {
- e << " for file " << config.write_mmap;
- throw e;
- }
-
- if (config.write_method == Config::WRITE_AFTER) {
- util::MapAnonymous(memory_size, backing.search);
- return reinterpret_cast<uint8_t*>(backing.search.get());
- }
- // mmap it now.
- // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down.
- std::size_t page_size = util::SizePage();
- std::size_t alignment_cruft = adjusted_vocab % page_size;
- backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED);
- return reinterpret_cast<uint8_t*>(backing.search.get()) + alignment_cruft;
- } else {
- util::MapAnonymous(memory_size, backing.search);
- return reinterpret_cast<uint8_t*>(backing.search.get());
- }
-}
-
-void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing) {
- if (!config.write_mmap) return;
- switch (config.write_method) {
- case Config::WRITE_MMAP:
- util::SyncOrThrow(backing.vocab.get(), backing.vocab.size());
- util::SyncOrThrow(backing.search.get(), backing.search.size());
- break;
- case Config::WRITE_AFTER:
- util::SeekOrThrow(backing.file.get(), 0);
- util::WriteOrThrow(backing.file.get(), backing.vocab.get(), backing.vocab.size());
- util::SeekOrThrow(backing.file.get(), backing.vocab.size() + vocab_pad);
- util::WriteOrThrow(backing.file.get(), backing.search.get(), backing.search.size());
- util::FSyncOrThrow(backing.file.get());
- break;
- }
- // header and vocab share the same mmap. The header is written here because we know the counts.
- Parameters params = Parameters();
- params.counts = counts;
- params.fixed.order = counts.size();
- params.fixed.probing_multiplier = config.probing_multiplier;
- params.fixed.model_type = model_type;
- params.fixed.has_vocabulary = config.include_vocab;
- params.fixed.search_version = search_version;
- WriteHeader(backing.vocab.get(), params);
- if (config.write_method == Config::WRITE_AFTER) {
- util::SeekOrThrow(backing.file.get(), 0);
- util::WriteOrThrow(backing.file.get(), backing.vocab.get(), TotalHeaderSize(counts.size()));
- }
-}
-
-namespace detail {
-
bool IsBinaryFormat(int fd) {
const uint64_t size = util::SizeFile(fd);
if (size == util::kBadSize || (size <= static_cast<uint64_t>(sizeof(Sanity)))) return false;
@@ -169,21 +94,21 @@ bool IsBinaryFormat(int fd) {
}
Sanity reference_header = Sanity();
reference_header.SetToReference();
- if (!memcmp(memory.get(), &reference_header, sizeof(Sanity))) return true;
- if (!memcmp(memory.get(), kMagicIncomplete, strlen(kMagicIncomplete))) {
+ if (!std::memcmp(memory.get(), &reference_header, sizeof(Sanity))) return true;
+ if (!std::memcmp(memory.get(), kMagicIncomplete, strlen(kMagicIncomplete))) {
UTIL_THROW(FormatLoadException, "This binary file did not finish building");
}
- if (!memcmp(memory.get(), kMagicBeforeVersion, strlen(kMagicBeforeVersion))) {
+ if (!std::memcmp(memory.get(), kMagicBeforeVersion, strlen(kMagicBeforeVersion))) {
char *end_ptr;
const char *begin_version = static_cast<const char*>(memory.get()) + strlen(kMagicBeforeVersion);
- long int version = strtol(begin_version, &end_ptr, 10);
+ long int version = std::strtol(begin_version, &end_ptr, 10);
if ((end_ptr != begin_version) && version != kMagicVersion) {
UTIL_THROW(FormatLoadException, "Binary file has version " << version << " but this implementation expects version " << kMagicVersion << " so you'll have to use the ARPA to rebuild your binary");
}
OldSanity old_sanity = OldSanity();
old_sanity.SetToReference();
- UTIL_THROW_IF(!memcmp(memory.get(), &old_sanity, sizeof(OldSanity)), FormatLoadException, "Looks like this is an old 32-bit format. The old 32-bit format has been removed so that 64-bit and 32-bit files are exchangeable.");
+ UTIL_THROW_IF(!std::memcmp(memory.get(), &old_sanity, sizeof(OldSanity)), FormatLoadException, "Looks like this is an old 32-bit format. The old 32-bit format has been removed so that 64-bit and 32-bit files are exchangeable.");
UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Try rebuilding the binary format LM using the same code revision, compiler, and architecture");
}
return false;
@@ -208,44 +133,164 @@ void MatchCheck(ModelType model_type, unsigned int search_version, const Paramet
UTIL_THROW_IF(search_version != params.fixed.search_version, FormatLoadException, "The binary file has " << kModelNames[params.fixed.model_type] << " version " << params.fixed.search_version << " but this code expects " << kModelNames[params.fixed.model_type] << " version " << search_version);
}
-void SeekPastHeader(int fd, const Parameters &params) {
- util::SeekOrThrow(fd, TotalHeaderSize(params.counts.size()));
+const std::size_t kInvalidSize = static_cast<std::size_t>(-1);
+
+BinaryFormat::BinaryFormat(const Config &config) // Copies the write/load settings out of config.
+ : write_method_(config.write_method), write_mmap_(config.write_mmap), load_method_(config.load_method),
+ header_size_(kInvalidSize), vocab_size_(kInvalidSize), vocab_string_offset_(kInvalidOffset) {} // Sizes and the string offset start as invalid sentinels; later calls assign real values.
+
+void BinaryFormat::InitializeBinary(int fd, ModelType model_type, unsigned int search_version, Parameters &params) { // Stores fd in file_, reads the header into params and checks it against the requested model.
+ file_.reset(fd);
+ write_mmap_ = NULL; // Ignore write requests; this is already in binary format.
+ ReadHeader(fd, params);
+ MatchCheck(model_type, search_version, params);
+ header_size_ = TotalHeaderSize(params.counts.size()); // counts.size() is passed as the order that TotalHeaderSize expects.
+}
+
+void BinaryFormat::ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header) const { // Positional read via util::PReadOrThrow; the caller's offset is shifted by header_size_.
+ assert(header_size_ != kInvalidSize); // header_size_ must have been assigned before this call.
+ util::PReadOrThrow(file_.get(), to, amount, offset_excluding_header + header_size_);
}
-uint8_t *SetupBinary(const Config &config, const Parameters &params, uint64_t memory_size, Backing &backing) {
- const uint64_t file_size = util::SizeFile(backing.file.get());
+void *BinaryFormat::LoadBinary(std::size_t size) {
+ assert(header_size_ != kInvalidSize);
+ const uint64_t file_size = util::SizeFile(file_.get());
// The header is smaller than a page, so we have to map the whole header as well.
- std::size_t total_map = util::CheckOverflow(TotalHeaderSize(params.counts.size()) + memory_size);
- if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map)
- UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map);
+ uint64_t total_map = static_cast<uint64_t>(header_size_) + static_cast<uint64_t>(size);
+ UTIL_THROW_IF(file_size != util::kBadSize && file_size < total_map, FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map);
- util::MapRead(config.load_method, backing.file.get(), 0, total_map, backing.search);
+ util::MapRead(load_method_, file_.get(), 0, util::CheckOverflow(total_map), mapping_);
- if (config.enumerate_vocab && !params.fixed.has_vocabulary)
- UTIL_THROW(FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary.");
+ vocab_string_offset_ = total_map;
+ return reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_;
+}
+
+void *BinaryFormat::SetupJustVocab(std::size_t memory_size, uint8_t order) { // Allocates memory_size bytes for the vocabulary; returns its start (past the header when a file is being written).
+ vocab_size_ = memory_size;
+ if (!write_mmap_) {
+ header_size_ = 0; // No file is being written, so no header is kept in memory.
+ util::MapAnonymous(memory_size, memory_vocab_);
+ return reinterpret_cast<uint8_t*>(memory_vocab_.get());
+ }
+ header_size_ = TotalHeaderSize(order);
+ std::size_t total = util::CheckOverflow(static_cast<uint64_t>(header_size_) + static_cast<uint64_t>(memory_size));
+ file_.reset(util::CreateOrThrow(write_mmap_));
+ // some gccs complain about uninitialized variables even though all enum values are covered.
+ void *vocab_base = NULL;
+ switch (write_method_) {
+ case Config::WRITE_MMAP:
+ mapping_.reset(util::MapZeroedWrite(file_.get(), total), total, util::scoped_memory::MMAP_ALLOCATED);
+ vocab_base = mapping_.get();
+ break;
+ case Config::WRITE_AFTER:
+ util::ResizeOrThrow(file_.get(), 0);
+ util::MapAnonymous(total, memory_vocab_); // Header and vocabulary are built in anonymous memory; FinishFile writes them to the file.
+ vocab_base = memory_vocab_.get();
+ break;
+ }
+ strncpy(reinterpret_cast<char*>(vocab_base), kMagicIncomplete, header_size_); // Placeholder header; IsBinaryFormat rejects files that still start with kMagicIncomplete.
+ return reinterpret_cast<uint8_t*>(vocab_base) + header_size_;
+}
- // Seek to vocabulary words
- util::SeekOrThrow(backing.file.get(), total_map);
- return reinterpret_cast<uint8_t*>(backing.search.get()) + TotalHeaderSize(params.counts.size());
+void *BinaryFormat::GrowForSearch(std::size_t memory_size, std::size_t vocab_pad, void *&vocab_base) { // Allocates memory_size bytes for the search structure; returns its base and refreshes vocab_base.
+ assert(vocab_size_ != kInvalidSize);
+ vocab_pad_ = vocab_pad;
+ std::size_t new_size = header_size_ + vocab_size_ + vocab_pad_ + memory_size;
+ vocab_string_offset_ = new_size;
+ if (!write_mmap_ || write_method_ == Config::WRITE_AFTER) {
+ util::MapAnonymous(memory_size, memory_search_); // Search data lives in its own anonymous mapping in these modes.
+ assert(header_size_ == 0 || write_mmap_);
+ vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get()) + header_size_;
+ return reinterpret_cast<uint8_t*>(memory_search_.get());
+ }
+
+ assert(write_method_ == Config::WRITE_MMAP);
+ // new_size (also stored in vocab_string_offset_) is the total size without vocab words.
+ // Grow the file to accommodate the search, using zeros.
+ // According to man mmap, behavior is undefined when the file is resized
+ // underneath a mmap that is not a multiple of the page size. So to be
+ // safe, we'll unmap it and map it again.
+ mapping_.reset();
+ util::ResizeOrThrow(file_.get(), new_size);
+ void *ret;
+ MapFile(vocab_base, ret);
+ return ret;
}
-void ComplainAboutARPA(const Config &config, ModelType model_type) {
- if (config.write_mmap || !config.messages) return;
- if (config.arpa_complain == Config::ALL) {
- *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl;
- } else if (config.arpa_complain == Config::EXPENSIVE &&
- (model_type == TRIE || model_type == QUANT_TRIE || model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE)) {
- *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl;
+void BinaryFormat::WriteVocabWords(const std::string &buffer, void *&vocab_base, void *&search_base) { // Writes buffer to the file (when one is being written) and refreshes both base pointers.
+ // Checking Config's include_vocab is the responsibility of the caller.
+ assert(header_size_ != kInvalidSize && vocab_size_ != kInvalidSize);
+ if (!write_mmap_) {
+ // Unchanged base.
+ vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get());
+ search_base = reinterpret_cast<uint8_t*>(memory_search_.get());
+ return;
+ }
+ if (write_method_ == Config::WRITE_MMAP) {
+ mapping_.reset(); // Drop the mapping before writing through the descriptor; MapFile re-creates it below.
+ }
+ util::SeekOrThrow(file_.get(), VocabStringReadingOffset());
+ util::WriteOrThrow(file_.get(), &buffer[0], buffer.size());
+ if (write_method_ == Config::WRITE_MMAP) {
+ MapFile(vocab_base, search_base);
+ } else {
+ vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get()) + header_size_; // WRITE_AFTER: the in-memory buffers stay where they were.
+ search_base = reinterpret_cast<uint8_t*>(memory_search_.get());
+ }
+}
+
+// Flushes the model data to the output file, then replaces the placeholder header with the real one. No-op when no file is being written.
+void BinaryFormat::FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts) {
+  if (!write_mmap_) return;
+  switch (write_method_) {
+    case Config::WRITE_MMAP:
+      util::SyncOrThrow(mapping_.get(), mapping_.size());
+      break;
+    case Config::WRITE_AFTER:
+      util::SeekOrThrow(file_.get(), 0);
+      util::WriteOrThrow(file_.get(), memory_vocab_.get(), memory_vocab_.size());
+      util::SeekOrThrow(file_.get(), header_size_ + vocab_size_ + vocab_pad_);
+      util::WriteOrThrow(file_.get(), memory_search_.get(), memory_search_.size());
+      util::FSyncOrThrow(file_.get());
+      break;
+  }
+  // The real header is written only after the data above has been synced; header and vocab share the same mmap.
+  Parameters params = Parameters();
+  std::memset(&params.fixed, 0, sizeof(params.fixed)); // Zero only the fixed-width part; params.counts receives a std::vector below and must not be memset.
+  params.counts = counts;
+  params.fixed.order = counts.size();
+  params.fixed.probing_multiplier = config.probing_multiplier;
+  params.fixed.model_type = model_type;
+  params.fixed.has_vocabulary = config.include_vocab;
+  params.fixed.search_version = search_version;
+  switch (write_method_) {
+    case Config::WRITE_MMAP:
+      WriteHeader(mapping_.get(), params);
+      util::SyncOrThrow(mapping_.get(), mapping_.size());
+      break;
+    case Config::WRITE_AFTER: {
+      std::vector<uint8_t> buffer(TotalHeaderSize(counts.size()));
+      WriteHeader(&buffer[0], params);
+      util::SeekOrThrow(file_.get(), 0);
+      util::WriteOrThrow(file_.get(), &buffer[0], buffer.size());
+      break;
+    }
}
}
-} // namespace detail
+void BinaryFormat::MapFile(void *&vocab_base, void *&search_base) { // Maps vocab_string_offset_ bytes of file_ (presumably from offset 0 - confirm MapOrThrow's default) and derives both base pointers.
+ mapping_.reset(util::MapOrThrow(vocab_string_offset_, true, util::kFileFlags, false, file_.get()), vocab_string_offset_, util::scoped_memory::MMAP_ALLOCATED);
+ vocab_base = reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_; // Vocabulary starts right after the header.
+ search_base = reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_ + vocab_size_ + vocab_pad_; // Search data follows the padded vocabulary.
+}
bool RecognizeBinary(const char *file, ModelType &recognized) {
util::scoped_fd fd(util::OpenReadOrThrow(file));
- if (!detail::IsBinaryFormat(fd.get())) return false;
+ if (!IsBinaryFormat(fd.get())) {
+ return false;
+ }
Parameters params;
- detail::ReadHeader(fd.get(), params);
+ ReadHeader(fd.get(), params);
recognized = params.fixed.model_type;
return true;
}