diff options
author | Patrick Simianer <p@simianer.de> | 2012-03-13 09:24:47 +0100 |
---|---|---|
committer | Patrick Simianer <p@simianer.de> | 2012-03-13 09:24:47 +0100 |
commit | ef6085e558e26c8819f1735425761103021b6470 (patch) | |
tree | 5cf70e4c48c64d838e1326b5a505c8c4061bff4a /klm/lm/binary_format.cc | |
parent | 10a232656a0c882b3b955d2bcfac138ce11e8a2e (diff) | |
parent | dfbc278c1057555fda9312291c8024049e00b7d8 (diff) |
merge with upstream
Diffstat (limited to 'klm/lm/binary_format.cc')
-rw-r--r-- | klm/lm/binary_format.cc | 139 |
1 files changed, 78 insertions, 61 deletions
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index 27cada13..4796f6d1 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -1,19 +1,15 @@ #include "lm/binary_format.hh" #include "lm/lm_exception.hh" +#include "util/file.hh" #include "util/file_piece.hh" +#include <cstddef> +#include <cstring> #include <limits> #include <string> -#include <fcntl.h> -#include <errno.h> -#include <stdlib.h> -#include <string.h> -#include <sys/mman.h> -#include <sys/types.h> -#include <sys/stat.h> -#include <unistd.h> +#include <stdint.h> namespace lm { namespace ngram { @@ -24,14 +20,16 @@ const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n const char kMagicIncomplete[] = "mmap lm http://kheafield.com/code incomplete\n"; const long int kMagicVersion = 5; -// Test values. -struct Sanity { +// Old binary files built on 32-bit machines have this header. +// TODO: eliminate with next binary release. +struct OldSanity { char magic[sizeof(kMagicBytes)]; float zero_f, one_f, minus_half_f; WordIndex one_word_index, max_word_index; uint64_t one_uint64; void SetToReference() { + std::memset(this, 0, sizeof(OldSanity)); std::memcpy(magic, kMagicBytes, sizeof(magic)); zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5; one_word_index = 1; @@ -40,27 +38,35 @@ struct Sanity { } }; -const char *kModelNames[6] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"}; -std::size_t TotalHeaderSize(unsigned char order) { - return Align8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order); -} +// Test values aligned to 8 bytes. +struct Sanity { + char magic[ALIGN8(sizeof(kMagicBytes))]; + float zero_f, one_f, minus_half_f; + WordIndex one_word_index, max_word_index, padding_to_8; + uint64_t one_uint64; -void ReadLoop(int fd, void *to_void, std::size_t size) { - uint8_t *to = static_cast<uint8_t*>(to_void); - while (size) { - ssize_t ret = read(fd, to, size); - if (ret == -1) UTIL_THROW(util::ErrnoException, "Failed to read from binary file"); - if (ret == 0) UTIL_THROW(util::ErrnoException, "Binary file too short"); - to += ret; - size -= ret; + void SetToReference() { + std::memset(this, 0, sizeof(Sanity)); + std::memcpy(magic, kMagicBytes, sizeof(kMagicBytes)); + zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5; + one_word_index = 1; + max_word_index = std::numeric_limits<WordIndex>::max(); + padding_to_8 = 0; + one_uint64 = 1; } +}; + +const char *kModelNames[6] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"}; + +std::size_t TotalHeaderSize(unsigned char order) { + return ALIGN8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order); } void WriteHeader(void *to, const Parameters ¶ms) { Sanity header = Sanity(); header.SetToReference(); - memcpy(to, &header, sizeof(Sanity)); + std::memcpy(to, &header, sizeof(Sanity)); char *out = reinterpret_cast<char*>(to) + sizeof(Sanity); *reinterpret_cast<FixedWidthParameters*>(out) = params.fixed; @@ -74,14 +80,6 @@ void WriteHeader(void *to, const Parameters ¶ms) { } // namespace -void SeekOrThrow(int fd, off_t off) { - if ((off_t)-1 == lseek(fd, off, SEEK_SET)) UTIL_THROW(util::ErrnoException, "Seek failed"); -} - -void AdvanceOrThrow(int fd, off_t off) { - if ((off_t)-1 == lseek(fd, off, SEEK_CUR)) UTIL_THROW(util::ErrnoException, "Seek failed"); -} - uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing) { if (config.write_mmap) { std::size_t total = TotalHeaderSize(order) + memory_size; @@ -89,7 +87,7 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_ strncpy(reinterpret_cast<char*>(backing.vocab.get()), kMagicIncomplete, TotalHeaderSize(order)); return reinterpret_cast<uint8_t*>(backing.vocab.get()) + TotalHeaderSize(order); } else { - backing.vocab.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED); + util::MapAnonymous(memory_size, backing.vocab); return reinterpret_cast<uint8_t*>(backing.vocab.get()); } } @@ -98,42 +96,58 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad; if (config.write_mmap) { // Grow the file to accomodate the search, using zeros. - if (-1 == ftruncate(backing.file.get(), adjusted_vocab + memory_size)) - UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (adjusted_vocab + memory_size) << " failed"); + try { + util::ResizeOrThrow(backing.file.get(), adjusted_vocab + memory_size); + } catch (util::ErrnoException &e) { + e << " for file " << config.write_mmap; + throw e; + } + if (config.write_method == Config::WRITE_AFTER) { + util::MapAnonymous(memory_size, backing.search); + return reinterpret_cast<uint8_t*>(backing.search.get()); + } + // mmap it now. // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down. - off_t page_size = sysconf(_SC_PAGE_SIZE); - off_t alignment_cruft = adjusted_vocab % page_size; + std::size_t page_size = util::SizePage(); + std::size_t alignment_cruft = adjusted_vocab % page_size; backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED); - return reinterpret_cast<uint8_t*>(backing.search.get()) + alignment_cruft; } else { - backing.search.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED); + util::MapAnonymous(memory_size, backing.search); return reinterpret_cast<uint8_t*>(backing.search.get()); } } -void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, Backing &backing) { - if (config.write_mmap) { - if (msync(backing.search.get(), backing.search.size(), MS_SYNC) || msync(backing.vocab.get(), backing.vocab.size(), MS_SYNC)) - UTIL_THROW(util::ErrnoException, "msync failed for " << config.write_mmap); - // header and vocab share the same mmap. The header is written here because we know the counts. - Parameters params; - params.counts = counts; - params.fixed.order = counts.size(); - params.fixed.probing_multiplier = config.probing_multiplier; - params.fixed.model_type = model_type; - params.fixed.has_vocabulary = config.include_vocab; - params.fixed.search_version = search_version; - WriteHeader(backing.vocab.get(), params); +void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing) { + if (!config.write_mmap) return; + util::SyncOrThrow(backing.vocab.get(), backing.vocab.size()); + switch (config.write_method) { + case Config::WRITE_MMAP: + util::SyncOrThrow(backing.search.get(), backing.search.size()); + break; + case Config::WRITE_AFTER: + util::SeekOrThrow(backing.file.get(), backing.vocab.size() + vocab_pad); + util::WriteOrThrow(backing.file.get(), backing.search.get(), backing.search.size()); + util::FSyncOrThrow(backing.file.get()); + break; } + // header and vocab share the same mmap. The header is written here because we know the counts. + Parameters params = Parameters(); + params.counts = counts; + params.fixed.order = counts.size(); + params.fixed.probing_multiplier = config.probing_multiplier; + params.fixed.model_type = model_type; + params.fixed.has_vocabulary = config.include_vocab; + params.fixed.search_version = search_version; + WriteHeader(backing.vocab.get(), params); } namespace detail { bool IsBinaryFormat(int fd) { - const off_t size = util::SizeFile(fd); - if (size == util::kBadSize || (size <= static_cast<off_t>(sizeof(Sanity)))) return false; + const uint64_t size = util::SizeFile(fd); + if (size == util::kBadSize || (size <= static_cast<uint64_t>(sizeof(Sanity)))) return false; // Try reading the header. util::scoped_memory memory; try { @@ -154,19 +168,23 @@ bool IsBinaryFormat(int fd) { if ((end_ptr != begin_version) && version != kMagicVersion) { UTIL_THROW(FormatLoadException, "Binary file has version " << version << " but this implementation expects version " << kMagicVersion << " so you'll have to use the ARPA to rebuild your binary"); } + + OldSanity old_sanity = OldSanity(); + old_sanity.SetToReference(); + UTIL_THROW_IF(!memcmp(memory.get(), &old_sanity, sizeof(OldSanity)), FormatLoadException, "Looks like this is an old 32-bit format. The old 32-bit format has been removed so that 64-bit and 32-bit files are exchangeable."); UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Try rebuilding the binary format LM using the same code revision, compiler, and architecture"); } return false; } void ReadHeader(int fd, Parameters &out) { - SeekOrThrow(fd, sizeof(Sanity)); - ReadLoop(fd, &out.fixed, sizeof(out.fixed)); + util::SeekOrThrow(fd, sizeof(Sanity)); + util::ReadOrThrow(fd, &out.fixed, sizeof(out.fixed)); if (out.fixed.probing_multiplier < 1.0) UTIL_THROW(FormatLoadException, "Binary format claims to have a probing multiplier of " << out.fixed.probing_multiplier << " which is < 1.0."); out.counts.resize(static_cast<std::size_t>(out.fixed.order)); - ReadLoop(fd, &*out.counts.begin(), sizeof(uint64_t) * out.fixed.order); + if (out.fixed.order) util::ReadOrThrow(fd, &*out.counts.begin(), sizeof(uint64_t) * out.fixed.order); } void MatchCheck(ModelType model_type, unsigned int search_version, const Parameters ¶ms) { @@ -179,11 +197,11 @@ void MatchCheck(ModelType model_type, unsigned int search_version, const Paramet } void SeekPastHeader(int fd, const Parameters ¶ms) { - SeekOrThrow(fd, TotalHeaderSize(params.counts.size())); + util::SeekOrThrow(fd, TotalHeaderSize(params.counts.size())); } uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t memory_size, Backing &backing) { - const off_t file_size = util::SizeFile(backing.file.get()); + const uint64_t file_size = util::SizeFile(backing.file.get()); // The header is smaller than a page, so we have to map the whole header as well. std::size_t total_map = TotalHeaderSize(params.counts.size()) + memory_size; if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map) @@ -194,9 +212,8 @@ uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t if (config.enumerate_vocab && !params.fixed.has_vocabulary) UTIL_THROW(FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary."); - if (config.enumerate_vocab) { - SeekOrThrow(backing.file.get(), total_map); - } + // Seek to vocabulary words + util::SeekOrThrow(backing.file.get(), total_map); return reinterpret_cast<uint8_t*>(backing.search.get()) + TotalHeaderSize(params.counts.size()); } |