summaryrefslogtreecommitdiff
path: root/klm/lm/binary_format.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/binary_format.cc')
-rw-r--r--klm/lm/binary_format.cc191
1 files changed, 191 insertions, 0 deletions
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc
new file mode 100644
index 00000000..2a075b6b
--- /dev/null
+++ b/klm/lm/binary_format.cc
@@ -0,0 +1,191 @@
+#include "lm/binary_format.hh"
+
+#include "lm/lm_exception.hh"
+#include "util/file_piece.hh"
+
+#include <limits>
+#include <string>
+
+#include <fcntl.h>
+#include <errno.h>
+#include <stdlib.h>
+#include <sys/mman.h>
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+namespace lm {
+namespace ngram {
+namespace {
+const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version";
+const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 1\n\0";
+const long int kMagicVersion = 1;
+
+// Test values.
+struct Sanity {
+ char magic[sizeof(kMagicBytes)];
+ float zero_f, one_f, minus_half_f;
+ WordIndex one_word_index, max_word_index;
+ uint64_t one_uint64;
+
+ void SetToReference() {
+ std::memcpy(magic, kMagicBytes, sizeof(magic));
+ zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5;
+ one_word_index = 1;
+ max_word_index = std::numeric_limits<WordIndex>::max();
+ one_uint64 = 1;
+ }
+};
+
+const char *kModelNames[3] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "bit packed trie"};
+
+std::size_t Align8(std::size_t in) {
+ std::size_t off = in % 8;
+ if (!off) return in;
+ return in + 8 - off;
+}
+
+std::size_t TotalHeaderSize(unsigned char order) {
+ return Align8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order);
+}
+
+void ReadLoop(int fd, void *to_void, std::size_t size) {
+ uint8_t *to = static_cast<uint8_t*>(to_void);
+ while (size) {
+ ssize_t ret = read(fd, to, size);
+ if (ret == -1) UTIL_THROW(util::ErrnoException, "Failed to read from binary file");
+ if (ret == 0) UTIL_THROW(util::ErrnoException, "Binary file too short");
+ to += ret;
+ size -= ret;
+ }
+}
+
+void WriteHeader(void *to, const Parameters &params) {
+ Sanity header = Sanity();
+ header.SetToReference();
+ memcpy(to, &header, sizeof(Sanity));
+ char *out = reinterpret_cast<char*>(to) + sizeof(Sanity);
+
+ *reinterpret_cast<FixedWidthParameters*>(out) = params.fixed;
+ out += sizeof(FixedWidthParameters);
+
+ uint64_t *counts = reinterpret_cast<uint64_t*>(out);
+ for (std::size_t i = 0; i < params.counts.size(); ++i) {
+ counts[i] = params.counts[i];
+ }
+}
+
+} // namespace
+namespace detail {
+
+bool IsBinaryFormat(int fd) {
+ const off_t size = util::SizeFile(fd);
+ if (size == util::kBadSize || (size <= static_cast<off_t>(sizeof(Sanity)))) return false;
+ // Try reading the header.
+ util::scoped_memory memory;
+ try {
+ util::MapRead(util::LAZY, fd, 0, sizeof(Sanity), memory);
+ } catch (const util::Exception &e) {
+ return false;
+ }
+ Sanity reference_header = Sanity();
+ reference_header.SetToReference();
+ if (!memcmp(memory.get(), &reference_header, sizeof(Sanity))) return true;
+ if (!memcmp(memory.get(), kMagicBeforeVersion, strlen(kMagicBeforeVersion))) {
+ char *end_ptr;
+ const char *begin_version = static_cast<const char*>(memory.get()) + strlen(kMagicBeforeVersion);
+ long int version = strtol(begin_version, &end_ptr, 10);
+ if ((end_ptr != begin_version) && version != kMagicVersion) {
+ UTIL_THROW(FormatLoadException, "Binary file has version " << version << " but this implementation expects version " << kMagicVersion << " so you'll have to rebuild your binary LM from the ARPA. Sorry.");
+ }
+ UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Try rebuilding the binary format LM using the same code revision, compiler, and architecture.");
+ }
+ return false;
+}
+
+void ReadHeader(int fd, Parameters &out) {
+ if ((off_t)-1 == lseek(fd, sizeof(Sanity), SEEK_SET)) UTIL_THROW(util::ErrnoException, "Seek failed in binary file");
+ ReadLoop(fd, &out.fixed, sizeof(out.fixed));
+ if (out.fixed.probing_multiplier < 1.0)
+ UTIL_THROW(FormatLoadException, "Binary format claims to have a probing multiplier of " << out.fixed.probing_multiplier << " which is < 1.0.");
+
+ out.counts.resize(static_cast<std::size_t>(out.fixed.order));
+ ReadLoop(fd, &*out.counts.begin(), sizeof(uint64_t) * out.fixed.order);
+}
+
+void MatchCheck(ModelType model_type, const Parameters &params) {
+ if (params.fixed.model_type != model_type) {
+ if (static_cast<unsigned int>(params.fixed.model_type) >= (sizeof(kModelNames) / sizeof(const char *)))
+ UTIL_THROW(FormatLoadException, "The binary file claims to be model type " << static_cast<unsigned int>(params.fixed.model_type) << " but this is not implemented for in this inference code.");
+ UTIL_THROW(FormatLoadException, "The binary file was built for " << kModelNames[params.fixed.model_type] << " but the inference code is trying to load " << kModelNames[model_type]);
+ }
+}
+
+uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t memory_size, Backing &backing) {
+ const off_t file_size = util::SizeFile(backing.file.get());
+ // The header is smaller than a page, so we have to map the whole header as well.
+ std::size_t total_map = TotalHeaderSize(params.counts.size()) + memory_size;
+ if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map)
+ UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map);
+
+ util::MapRead(config.load_method, backing.file.get(), 0, total_map, backing.memory);
+
+ if (config.enumerate_vocab && !params.fixed.has_vocabulary)
+ UTIL_THROW(FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary.");
+
+ if (config.enumerate_vocab) {
+ if ((off_t)-1 == lseek(backing.file.get(), total_map, SEEK_SET))
+ UTIL_THROW(util::ErrnoException, "Failed to seek in binary file to vocab words");
+ }
+ return reinterpret_cast<uint8_t*>(backing.memory.get()) + TotalHeaderSize(params.counts.size());
+}
+
+uint8_t *SetupZeroed(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, std::size_t memory_size, Backing &backing) {
+ if (config.probing_multiplier <= 1.0) UTIL_THROW(FormatLoadException, "probing multiplier must be > 1.0");
+ if (config.write_mmap) {
+ std::size_t total_map = TotalHeaderSize(counts.size()) + memory_size;
+ // Write out an mmap file.
+ backing.memory.reset(util::MapZeroedWrite(config.write_mmap, total_map, backing.file), total_map, util::scoped_memory::MMAP_ALLOCATED);
+
+ Parameters params;
+ params.counts = counts;
+ params.fixed.order = counts.size();
+ params.fixed.probing_multiplier = config.probing_multiplier;
+ params.fixed.model_type = model_type;
+ params.fixed.has_vocabulary = config.include_vocab;
+
+ WriteHeader(backing.memory.get(), params);
+
+ if (params.fixed.has_vocabulary) {
+ if ((off_t)-1 == lseek(backing.file.get(), total_map, SEEK_SET))
+ UTIL_THROW(util::ErrnoException, "Failed to seek in binary file " << config.write_mmap << " to vocab words");
+ }
+ return reinterpret_cast<uint8_t*>(backing.memory.get()) + TotalHeaderSize(counts.size());
+ } else {
+ backing.memory.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED);
+ return reinterpret_cast<uint8_t*>(backing.memory.get());
+ }
+}
+
+void ComplainAboutARPA(const Config &config, ModelType model_type) {
+ if (config.write_mmap || !config.messages) return;
+ if (config.arpa_complain == Config::ALL) {
+ *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl;
+ } else if (config.arpa_complain == Config::EXPENSIVE && model_type == TRIE_SORTED) {
+ *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl;
+ }
+}
+
+} // namespace detail
+
+bool RecognizeBinary(const char *file, ModelType &recognized) {
+ util::scoped_fd fd(util::OpenReadOrThrow(file));
+ if (!detail::IsBinaryFormat(fd.get())) return false;
+ Parameters params;
+ detail::ReadHeader(fd.get(), params);
+ recognized = params.fixed.model_type;
+ return true;
+}
+
+} // namespace ngram
+} // namespace lm