summaryrefslogtreecommitdiff
path: root/klm/lm/binary_format.cc
diff options
context:
space:
mode:
authorPatrick Simianer <p@simianer.de>2012-03-13 09:24:47 +0100
committerPatrick Simianer <p@simianer.de>2012-03-13 09:24:47 +0100
commitef6085e558e26c8819f1735425761103021b6470 (patch)
tree5cf70e4c48c64d838e1326b5a505c8c4061bff4a /klm/lm/binary_format.cc
parent10a232656a0c882b3b955d2bcfac138ce11e8a2e (diff)
parentdfbc278c1057555fda9312291c8024049e00b7d8 (diff)
merge with upstream
Diffstat (limited to 'klm/lm/binary_format.cc')
-rw-r--r--klm/lm/binary_format.cc139
1 files changed, 78 insertions, 61 deletions
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc
index 27cada13..4796f6d1 100644
--- a/klm/lm/binary_format.cc
+++ b/klm/lm/binary_format.cc
@@ -1,19 +1,15 @@
#include "lm/binary_format.hh"
#include "lm/lm_exception.hh"
+#include "util/file.hh"
#include "util/file_piece.hh"
+#include <cstddef>
+#include <cstring>
#include <limits>
#include <string>
-#include <fcntl.h>
-#include <errno.h>
-#include <stdlib.h>
-#include <string.h>
-#include <sys/mman.h>
-#include <sys/types.h>
-#include <sys/stat.h>
-#include <unistd.h>
+#include <stdint.h>
namespace lm {
namespace ngram {
@@ -24,14 +20,16 @@ const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n
const char kMagicIncomplete[] = "mmap lm http://kheafield.com/code incomplete\n";
const long int kMagicVersion = 5;
-// Test values.
-struct Sanity {
+// Old binary files built on 32-bit machines have this header.
+// TODO: eliminate with next binary release.
+struct OldSanity {
char magic[sizeof(kMagicBytes)];
float zero_f, one_f, minus_half_f;
WordIndex one_word_index, max_word_index;
uint64_t one_uint64;
void SetToReference() {
+ std::memset(this, 0, sizeof(OldSanity));
std::memcpy(magic, kMagicBytes, sizeof(magic));
zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5;
one_word_index = 1;
@@ -40,27 +38,35 @@ struct Sanity {
}
};
-const char *kModelNames[6] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"};
-std::size_t TotalHeaderSize(unsigned char order) {
- return Align8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order);
-}
+// Test values aligned to 8 bytes.
+struct Sanity {
+ char magic[ALIGN8(sizeof(kMagicBytes))];
+ float zero_f, one_f, minus_half_f;
+ WordIndex one_word_index, max_word_index, padding_to_8;
+ uint64_t one_uint64;
-void ReadLoop(int fd, void *to_void, std::size_t size) {
- uint8_t *to = static_cast<uint8_t*>(to_void);
- while (size) {
- ssize_t ret = read(fd, to, size);
- if (ret == -1) UTIL_THROW(util::ErrnoException, "Failed to read from binary file");
- if (ret == 0) UTIL_THROW(util::ErrnoException, "Binary file too short");
- to += ret;
- size -= ret;
+ void SetToReference() {
+ std::memset(this, 0, sizeof(Sanity));
+ std::memcpy(magic, kMagicBytes, sizeof(kMagicBytes));
+ zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5;
+ one_word_index = 1;
+ max_word_index = std::numeric_limits<WordIndex>::max();
+ padding_to_8 = 0;
+ one_uint64 = 1;
}
+};
+
+const char *kModelNames[6] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"};
+
+std::size_t TotalHeaderSize(unsigned char order) {
+ return ALIGN8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order);
}
void WriteHeader(void *to, const Parameters &params) {
Sanity header = Sanity();
header.SetToReference();
- memcpy(to, &header, sizeof(Sanity));
+ std::memcpy(to, &header, sizeof(Sanity));
char *out = reinterpret_cast<char*>(to) + sizeof(Sanity);
*reinterpret_cast<FixedWidthParameters*>(out) = params.fixed;
@@ -74,14 +80,6 @@ void WriteHeader(void *to, const Parameters &params) {
} // namespace
-void SeekOrThrow(int fd, off_t off) {
- if ((off_t)-1 == lseek(fd, off, SEEK_SET)) UTIL_THROW(util::ErrnoException, "Seek failed");
-}
-
-void AdvanceOrThrow(int fd, off_t off) {
- if ((off_t)-1 == lseek(fd, off, SEEK_CUR)) UTIL_THROW(util::ErrnoException, "Seek failed");
-}
-
uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing) {
if (config.write_mmap) {
std::size_t total = TotalHeaderSize(order) + memory_size;
@@ -89,7 +87,7 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_
strncpy(reinterpret_cast<char*>(backing.vocab.get()), kMagicIncomplete, TotalHeaderSize(order));
return reinterpret_cast<uint8_t*>(backing.vocab.get()) + TotalHeaderSize(order);
} else {
- backing.vocab.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED);
+ util::MapAnonymous(memory_size, backing.vocab);
return reinterpret_cast<uint8_t*>(backing.vocab.get());
}
}
@@ -98,42 +96,58 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t
std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad;
if (config.write_mmap) {
// Grow the file to accomodate the search, using zeros.
- if (-1 == ftruncate(backing.file.get(), adjusted_vocab + memory_size))
- UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (adjusted_vocab + memory_size) << " failed");
+ try {
+ util::ResizeOrThrow(backing.file.get(), adjusted_vocab + memory_size);
+ } catch (util::ErrnoException &e) {
+ e << " for file " << config.write_mmap;
+ throw e;
+ }
+ if (config.write_method == Config::WRITE_AFTER) {
+ util::MapAnonymous(memory_size, backing.search);
+ return reinterpret_cast<uint8_t*>(backing.search.get());
+ }
+ // mmap it now.
// We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down.
- off_t page_size = sysconf(_SC_PAGE_SIZE);
- off_t alignment_cruft = adjusted_vocab % page_size;
+ std::size_t page_size = util::SizePage();
+ std::size_t alignment_cruft = adjusted_vocab % page_size;
backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED);
-
return reinterpret_cast<uint8_t*>(backing.search.get()) + alignment_cruft;
} else {
- backing.search.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED);
+ util::MapAnonymous(memory_size, backing.search);
return reinterpret_cast<uint8_t*>(backing.search.get());
}
}
-void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, Backing &backing) {
- if (config.write_mmap) {
- if (msync(backing.search.get(), backing.search.size(), MS_SYNC) || msync(backing.vocab.get(), backing.vocab.size(), MS_SYNC))
- UTIL_THROW(util::ErrnoException, "msync failed for " << config.write_mmap);
- // header and vocab share the same mmap. The header is written here because we know the counts.
- Parameters params;
- params.counts = counts;
- params.fixed.order = counts.size();
- params.fixed.probing_multiplier = config.probing_multiplier;
- params.fixed.model_type = model_type;
- params.fixed.has_vocabulary = config.include_vocab;
- params.fixed.search_version = search_version;
- WriteHeader(backing.vocab.get(), params);
+void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing) {
+ if (!config.write_mmap) return;
+ util::SyncOrThrow(backing.vocab.get(), backing.vocab.size());
+ switch (config.write_method) {
+ case Config::WRITE_MMAP:
+ util::SyncOrThrow(backing.search.get(), backing.search.size());
+ break;
+ case Config::WRITE_AFTER:
+ util::SeekOrThrow(backing.file.get(), backing.vocab.size() + vocab_pad);
+ util::WriteOrThrow(backing.file.get(), backing.search.get(), backing.search.size());
+ util::FSyncOrThrow(backing.file.get());
+ break;
}
+ // header and vocab share the same mmap. The header is written here because we know the counts.
+ Parameters params = Parameters();
+ params.counts = counts;
+ params.fixed.order = counts.size();
+ params.fixed.probing_multiplier = config.probing_multiplier;
+ params.fixed.model_type = model_type;
+ params.fixed.has_vocabulary = config.include_vocab;
+ params.fixed.search_version = search_version;
+ WriteHeader(backing.vocab.get(), params);
}
namespace detail {
bool IsBinaryFormat(int fd) {
- const off_t size = util::SizeFile(fd);
- if (size == util::kBadSize || (size <= static_cast<off_t>(sizeof(Sanity)))) return false;
+ const uint64_t size = util::SizeFile(fd);
+ if (size == util::kBadSize || (size <= static_cast<uint64_t>(sizeof(Sanity)))) return false;
// Try reading the header.
util::scoped_memory memory;
try {
@@ -154,19 +168,23 @@ bool IsBinaryFormat(int fd) {
if ((end_ptr != begin_version) && version != kMagicVersion) {
UTIL_THROW(FormatLoadException, "Binary file has version " << version << " but this implementation expects version " << kMagicVersion << " so you'll have to use the ARPA to rebuild your binary");
}
+
+ OldSanity old_sanity = OldSanity();
+ old_sanity.SetToReference();
+ UTIL_THROW_IF(!memcmp(memory.get(), &old_sanity, sizeof(OldSanity)), FormatLoadException, "Looks like this is an old 32-bit format. The old 32-bit format has been removed so that 64-bit and 32-bit files are exchangeable.");
UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Try rebuilding the binary format LM using the same code revision, compiler, and architecture");
}
return false;
}
void ReadHeader(int fd, Parameters &out) {
- SeekOrThrow(fd, sizeof(Sanity));
- ReadLoop(fd, &out.fixed, sizeof(out.fixed));
+ util::SeekOrThrow(fd, sizeof(Sanity));
+ util::ReadOrThrow(fd, &out.fixed, sizeof(out.fixed));
if (out.fixed.probing_multiplier < 1.0)
UTIL_THROW(FormatLoadException, "Binary format claims to have a probing multiplier of " << out.fixed.probing_multiplier << " which is < 1.0.");
out.counts.resize(static_cast<std::size_t>(out.fixed.order));
- ReadLoop(fd, &*out.counts.begin(), sizeof(uint64_t) * out.fixed.order);
+ if (out.fixed.order) util::ReadOrThrow(fd, &*out.counts.begin(), sizeof(uint64_t) * out.fixed.order);
}
void MatchCheck(ModelType model_type, unsigned int search_version, const Parameters &params) {
@@ -179,11 +197,11 @@ void MatchCheck(ModelType model_type, unsigned int search_version, const Paramet
}
void SeekPastHeader(int fd, const Parameters &params) {
- SeekOrThrow(fd, TotalHeaderSize(params.counts.size()));
+ util::SeekOrThrow(fd, TotalHeaderSize(params.counts.size()));
}
uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t memory_size, Backing &backing) {
- const off_t file_size = util::SizeFile(backing.file.get());
+ const uint64_t file_size = util::SizeFile(backing.file.get());
// The header is smaller than a page, so we have to map the whole header as well.
std::size_t total_map = TotalHeaderSize(params.counts.size()) + memory_size;
if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map)
@@ -194,9 +212,8 @@ uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t
if (config.enumerate_vocab && !params.fixed.has_vocabulary)
UTIL_THROW(FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary.");
- if (config.enumerate_vocab) {
- SeekOrThrow(backing.file.get(), total_map);
- }
+ // Seek to vocabulary words
+ util::SeekOrThrow(backing.file.get(), total_map);
return reinterpret_cast<uint8_t*>(backing.search.get()) + TotalHeaderSize(params.counts.size());
}