From 47a656c0b6fdba8f91f2c5808234cbb1de682652 Mon Sep 17 00:00:00 2001 From: redpony Date: Wed, 10 Nov 2010 02:02:04 +0000 Subject: new version of klm git-svn-id: https://ws10smt.googlecode.com/svn/trunk@706 ec762483-ff6d-05da-a07a-a48fb63a330f --- klm/lm/binary_format.cc | 191 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 191 insertions(+) create mode 100644 klm/lm/binary_format.cc (limited to 'klm/lm/binary_format.cc') diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc new file mode 100644 index 00000000..2a075b6b --- /dev/null +++ b/klm/lm/binary_format.cc @@ -0,0 +1,191 @@ +#include "lm/binary_format.hh" + +#include "lm/lm_exception.hh" +#include "util/file_piece.hh" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace lm { +namespace ngram { +namespace { +const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version"; +const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 1\n\0"; +const long int kMagicVersion = 1; + +// Test values. +struct Sanity { + char magic[sizeof(kMagicBytes)]; + float zero_f, one_f, minus_half_f; + WordIndex one_word_index, max_word_index; + uint64_t one_uint64; + + void SetToReference() { + std::memcpy(magic, kMagicBytes, sizeof(magic)); + zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5; + one_word_index = 1; + max_word_index = std::numeric_limits::max(); + one_uint64 = 1; + } +}; + +const char *kModelNames[3] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "bit packed trie"}; + +std::size_t Align8(std::size_t in) { + std::size_t off = in % 8; + if (!off) return in; + return in + 8 - off; +} + +std::size_t TotalHeaderSize(unsigned char order) { + return Align8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order); +} + +void ReadLoop(int fd, void *to_void, std::size_t size) { + uint8_t *to = static_cast(to_void); + while (size) { + ssize_t ret = read(fd, to, size); + if (ret == -1) UTIL_THROW(util::ErrnoException, "Failed to read from binary file"); + if (ret == 0) UTIL_THROW(util::ErrnoException, "Binary file too short"); + to += ret; + size -= ret; + } +} + +void WriteHeader(void *to, const Parameters ¶ms) { + Sanity header = Sanity(); + header.SetToReference(); + memcpy(to, &header, sizeof(Sanity)); + char *out = reinterpret_cast(to) + sizeof(Sanity); + + *reinterpret_cast(out) = params.fixed; + out += sizeof(FixedWidthParameters); + + uint64_t *counts = reinterpret_cast(out); + for (std::size_t i = 0; i < params.counts.size(); ++i) { + counts[i] = params.counts[i]; + } +} + +} // namespace +namespace detail { + +bool IsBinaryFormat(int fd) { + const off_t size = util::SizeFile(fd); + if (size == util::kBadSize || (size <= static_cast(sizeof(Sanity)))) return false; + // Try reading the header. + util::scoped_memory memory; + try { + util::MapRead(util::LAZY, fd, 0, sizeof(Sanity), memory); + } catch (const util::Exception &e) { + return false; + } + Sanity reference_header = Sanity(); + reference_header.SetToReference(); + if (!memcmp(memory.get(), &reference_header, sizeof(Sanity))) return true; + if (!memcmp(memory.get(), kMagicBeforeVersion, strlen(kMagicBeforeVersion))) { + char *end_ptr; + const char *begin_version = static_cast(memory.get()) + strlen(kMagicBeforeVersion); + long int version = strtol(begin_version, &end_ptr, 10); + if ((end_ptr != begin_version) && version != kMagicVersion) { + UTIL_THROW(FormatLoadException, "Binary file has version " << version << " but this implementation expects version " << kMagicVersion << " so you'll have to rebuild your binary LM from the ARPA. Sorry."); + } + UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Try rebuilding the binary format LM using the same code revision, compiler, and architecture."); + } + return false; +} + +void ReadHeader(int fd, Parameters &out) { + if ((off_t)-1 == lseek(fd, sizeof(Sanity), SEEK_SET)) UTIL_THROW(util::ErrnoException, "Seek failed in binary file"); + ReadLoop(fd, &out.fixed, sizeof(out.fixed)); + if (out.fixed.probing_multiplier < 1.0) + UTIL_THROW(FormatLoadException, "Binary format claims to have a probing multiplier of " << out.fixed.probing_multiplier << " which is < 1.0."); + + out.counts.resize(static_cast(out.fixed.order)); + ReadLoop(fd, &*out.counts.begin(), sizeof(uint64_t) * out.fixed.order); +} + +void MatchCheck(ModelType model_type, const Parameters ¶ms) { + if (params.fixed.model_type != model_type) { + if (static_cast(params.fixed.model_type) >= (sizeof(kModelNames) / sizeof(const char *))) + UTIL_THROW(FormatLoadException, "The binary file claims to be model type " << static_cast(params.fixed.model_type) << " but this is not implemented for in this inference code."); + UTIL_THROW(FormatLoadException, "The binary file was built for " << kModelNames[params.fixed.model_type] << " but the inference code is trying to load " << kModelNames[model_type]); + } +} + +uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t memory_size, Backing &backing) { + const off_t file_size = util::SizeFile(backing.file.get()); + // The header is smaller than a page, so we have to map the whole header as well. + std::size_t total_map = TotalHeaderSize(params.counts.size()) + memory_size; + if (file_size != util::kBadSize && static_cast(file_size) < total_map) + UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map); + + util::MapRead(config.load_method, backing.file.get(), 0, total_map, backing.memory); + + if (config.enumerate_vocab && !params.fixed.has_vocabulary) + UTIL_THROW(FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary."); + + if (config.enumerate_vocab) { + if ((off_t)-1 == lseek(backing.file.get(), total_map, SEEK_SET)) + UTIL_THROW(util::ErrnoException, "Failed to seek in binary file to vocab words"); + } + return reinterpret_cast(backing.memory.get()) + TotalHeaderSize(params.counts.size()); +} + +uint8_t *SetupZeroed(const Config &config, ModelType model_type, const std::vector &counts, std::size_t memory_size, Backing &backing) { + if (config.probing_multiplier <= 1.0) UTIL_THROW(FormatLoadException, "probing multiplier must be > 1.0"); + if (config.write_mmap) { + std::size_t total_map = TotalHeaderSize(counts.size()) + memory_size; + // Write out an mmap file. + backing.memory.reset(util::MapZeroedWrite(config.write_mmap, total_map, backing.file), total_map, util::scoped_memory::MMAP_ALLOCATED); + + Parameters params; + params.counts = counts; + params.fixed.order = counts.size(); + params.fixed.probing_multiplier = config.probing_multiplier; + params.fixed.model_type = model_type; + params.fixed.has_vocabulary = config.include_vocab; + + WriteHeader(backing.memory.get(), params); + + if (params.fixed.has_vocabulary) { + if ((off_t)-1 == lseek(backing.file.get(), total_map, SEEK_SET)) + UTIL_THROW(util::ErrnoException, "Failed to seek in binary file " << config.write_mmap << " to vocab words"); + } + return reinterpret_cast(backing.memory.get()) + TotalHeaderSize(counts.size()); + } else { + backing.memory.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED); + return reinterpret_cast(backing.memory.get()); + } +} + +void ComplainAboutARPA(const Config &config, ModelType model_type) { + if (config.write_mmap || !config.messages) return; + if (config.arpa_complain == Config::ALL) { + *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl; + } else if (config.arpa_complain == Config::EXPENSIVE && model_type == TRIE_SORTED) { + *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl; + } +} + +} // namespace detail + +bool RecognizeBinary(const char *file, ModelType &recognized) { + util::scoped_fd fd(util::OpenReadOrThrow(file)); + if (!detail::IsBinaryFormat(fd.get())) return false; + Parameters params; + detail::ReadHeader(fd.get(), params); + recognized = params.fixed.model_type; + return true; +} + +} // namespace ngram +} // namespace lm -- cgit v1.2.3