summaryrefslogtreecommitdiff
path: root/klm/lm/binary_format.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/binary_format.cc')
-rw-r--r--klm/lm/binary_format.cc35
1 files changed, 23 insertions, 12 deletions
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc
index 2a6aff34..9be0bc8e 100644
--- a/klm/lm/binary_format.cc
+++ b/klm/lm/binary_format.cc
@@ -9,6 +9,7 @@
#include <fcntl.h>
#include <errno.h>
#include <stdlib.h>
+#include <string.h>
#include <sys/mman.h>
#include <sys/types.h>
#include <sys/stat.h>
@@ -19,6 +20,8 @@ namespace ngram {
namespace {
const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version";
const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 4\n\0";
+// This must be shorter than kMagicBytes and indicates an incomplete binary file (i.e. build failed).
+const char kMagicIncomplete[] = "mmap lm http://kheafield.com/code incomplete\n";
const long int kMagicVersion = 4;
// Test values.
@@ -81,6 +84,7 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_
if (config.write_mmap) {
std::size_t total = TotalHeaderSize(order) + memory_size;
backing.vocab.reset(util::MapZeroedWrite(config.write_mmap, total, backing.file), total, util::scoped_memory::MMAP_ALLOCATED);
+ strncpy(reinterpret_cast<char*>(backing.vocab.get()), kMagicIncomplete, TotalHeaderSize(order));
return reinterpret_cast<uint8_t*>(backing.vocab.get()) + TotalHeaderSize(order);
} else {
backing.vocab.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED);
@@ -88,17 +92,8 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_
}
}
-uint8_t *GrowForSearch(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, std::size_t memory_size, Backing &backing) {
+uint8_t *GrowForSearch(const Config &config, std::size_t memory_size, Backing &backing) {
if (config.write_mmap) {
- // header and vocab share the same mmap. The header is written here because we know the counts.
- Parameters params;
- params.counts = counts;
- params.fixed.order = counts.size();
- params.fixed.probing_multiplier = config.probing_multiplier;
- params.fixed.model_type = model_type;
- params.fixed.has_vocabulary = config.include_vocab;
- WriteHeader(backing.vocab.get(), params);
-
// Grow the file to accomodate the search, using zeros.
if (-1 == ftruncate(backing.file.get(), backing.vocab.size() + memory_size))
UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (backing.vocab.size() + memory_size) << " failed");
@@ -115,6 +110,19 @@ uint8_t *GrowForSearch(const Config &config, ModelType model_type, const std::ve
}
}
+void FinishFile(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, Backing &backing) {
+ if (config.write_mmap) {
+ // header and vocab share the same mmap. The header is written here because we know the counts.
+ Parameters params;
+ params.counts = counts;
+ params.fixed.order = counts.size();
+ params.fixed.probing_multiplier = config.probing_multiplier;
+ params.fixed.model_type = model_type;
+ params.fixed.has_vocabulary = config.include_vocab;
+ WriteHeader(backing.vocab.get(), params);
+ }
+}
+
namespace detail {
bool IsBinaryFormat(int fd) {
@@ -130,14 +138,17 @@ bool IsBinaryFormat(int fd) {
Sanity reference_header = Sanity();
reference_header.SetToReference();
if (!memcmp(memory.get(), &reference_header, sizeof(Sanity))) return true;
+ if (!memcmp(memory.get(), kMagicIncomplete, strlen(kMagicIncomplete))) {
+ UTIL_THROW(FormatLoadException, "This binary file did not finish building");
+ }
if (!memcmp(memory.get(), kMagicBeforeVersion, strlen(kMagicBeforeVersion))) {
char *end_ptr;
const char *begin_version = static_cast<const char*>(memory.get()) + strlen(kMagicBeforeVersion);
long int version = strtol(begin_version, &end_ptr, 10);
if ((end_ptr != begin_version) && version != kMagicVersion) {
- UTIL_THROW(FormatLoadException, "Binary file has version " << version << " but this implementation expects version " << kMagicVersion << " so you'll have to rebuild your binary LM from the ARPA. Sorry.");
+ UTIL_THROW(FormatLoadException, "Binary file has version " << version << " but this implementation expects version " << kMagicVersion << " so you'll have to use the ARPA to rebuild your binary");
}
- UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Try rebuilding the binary format LM using the same code revision, compiler, and architecture.");
+ UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Try rebuilding the binary format LM using the same code revision, compiler, and architecture");
}
return false;
}