summaryrefslogtreecommitdiff
path: root/klm/lm/binary_format.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/binary_format.cc')
-rw-r--r--klm/lm/binary_format.cc41
1 files changed, 27 insertions, 14 deletions
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc
index a56e998e..39c4a9b6 100644
--- a/klm/lm/binary_format.cc
+++ b/klm/lm/binary_format.cc
@@ -16,11 +16,11 @@ namespace ngram {
namespace {
const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version";
const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n\0";
-// This must be shorter than kMagicBytes and indicates an incomplete binary file (i.e. build failed).
+// This must be shorter than kMagicBytes and indicates an incomplete binary file (i.e. build failed).
const char kMagicIncomplete[] = "mmap lm http://kheafield.com/code incomplete\n";
const long int kMagicVersion = 5;
-// Old binary files built on 32-bit machines have this header.
+// Old binary files built on 32-bit machines have this header.
// TODO: eliminate with next binary release.
struct OldSanity {
char magic[sizeof(kMagicBytes)];
@@ -39,7 +39,7 @@ struct OldSanity {
};
-// Test values aligned to 8 bytes.
+// Test values aligned to 8 bytes.
struct Sanity {
char magic[ALIGN8(sizeof(kMagicBytes))];
float zero_f, one_f, minus_half_f;
@@ -83,7 +83,13 @@ void WriteHeader(void *to, const Parameters &params) {
uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing) {
if (config.write_mmap) {
std::size_t total = TotalHeaderSize(order) + memory_size;
- backing.vocab.reset(util::MapZeroedWrite(config.write_mmap, total, backing.file), total, util::scoped_memory::MMAP_ALLOCATED);
+ backing.file.reset(util::CreateOrThrow(config.write_mmap));
+ if (config.write_method == Config::WRITE_MMAP) {
+ backing.vocab.reset(util::MapZeroedWrite(backing.file.get(), total), total, util::scoped_memory::MMAP_ALLOCATED);
+ } else {
+ util::ResizeOrThrow(backing.file.get(), 0);
+ util::MapAnonymous(total, backing.vocab);
+ }
strncpy(reinterpret_cast<char*>(backing.vocab.get()), kMagicIncomplete, TotalHeaderSize(order));
return reinterpret_cast<uint8_t*>(backing.vocab.get()) + TotalHeaderSize(order);
} else {
@@ -95,7 +101,7 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_
uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing) {
std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad;
if (config.write_mmap) {
- // Grow the file to accomodate the search, using zeros.
+ // Grow the file to accomodate the search, using zeros.
try {
util::ResizeOrThrow(backing.file.get(), adjusted_vocab + memory_size);
} catch (util::ErrnoException &e) {
@@ -108,7 +114,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t
return reinterpret_cast<uint8_t*>(backing.search.get());
}
// mmap it now.
- // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down.
+ // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down.
std::size_t page_size = util::SizePage();
std::size_t alignment_cruft = adjusted_vocab % page_size;
backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED);
@@ -116,23 +122,25 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t
} else {
util::MapAnonymous(memory_size, backing.search);
return reinterpret_cast<uint8_t*>(backing.search.get());
- }
+ }
}
void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing) {
if (!config.write_mmap) return;
- util::SyncOrThrow(backing.vocab.get(), backing.vocab.size());
switch (config.write_method) {
case Config::WRITE_MMAP:
+ util::SyncOrThrow(backing.vocab.get(), backing.vocab.size());
util::SyncOrThrow(backing.search.get(), backing.search.size());
break;
case Config::WRITE_AFTER:
+ util::SeekOrThrow(backing.file.get(), 0);
+ util::WriteOrThrow(backing.file.get(), backing.vocab.get(), backing.vocab.size());
util::SeekOrThrow(backing.file.get(), backing.vocab.size() + vocab_pad);
util::WriteOrThrow(backing.file.get(), backing.search.get(), backing.search.size());
util::FSyncOrThrow(backing.file.get());
break;
}
- // header and vocab share the same mmap. The header is written here because we know the counts.
+ // header and vocab share the same mmap. The header is written here because we know the counts.
Parameters params = Parameters();
params.counts = counts;
params.fixed.order = counts.size();
@@ -141,6 +149,10 @@ void FinishFile(const Config &config, ModelType model_type, unsigned int search_
params.fixed.has_vocabulary = config.include_vocab;
params.fixed.search_version = search_version;
WriteHeader(backing.vocab.get(), params);
+ if (config.write_method == Config::WRITE_AFTER) {
+ util::SeekOrThrow(backing.file.get(), 0);
+ util::WriteOrThrow(backing.file.get(), backing.vocab.get(), TotalHeaderSize(counts.size()));
+ }
}
namespace detail {
@@ -148,7 +160,7 @@ namespace detail {
bool IsBinaryFormat(int fd) {
const uint64_t size = util::SizeFile(fd);
if (size == util::kBadSize || (size <= static_cast<uint64_t>(sizeof(Sanity)))) return false;
- // Try reading the header.
+ // Try reading the header.
util::scoped_memory memory;
try {
util::MapRead(util::LAZY, fd, 0, sizeof(Sanity), memory);
@@ -200,10 +212,10 @@ void SeekPastHeader(int fd, const Parameters &params) {
util::SeekOrThrow(fd, TotalHeaderSize(params.counts.size()));
}
-uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t memory_size, Backing &backing) {
+uint8_t *SetupBinary(const Config &config, const Parameters &params, uint64_t memory_size, Backing &backing) {
const uint64_t file_size = util::SizeFile(backing.file.get());
- // The header is smaller than a page, so we have to map the whole header as well.
- std::size_t total_map = TotalHeaderSize(params.counts.size()) + memory_size;
+ // The header is smaller than a page, so we have to map the whole header as well.
+ std::size_t total_map = util::CheckOverflow(TotalHeaderSize(params.counts.size()) + memory_size);
if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map)
UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map);
@@ -221,7 +233,8 @@ void ComplainAboutARPA(const Config &config, ModelType model_type) {
if (config.write_mmap || !config.messages) return;
if (config.arpa_complain == Config::ALL) {
*config.messages << "Loading the LM will be faster if you build a binary file." << std::endl;
- } else if (config.arpa_complain == Config::EXPENSIVE && model_type == TRIE_SORTED) {
+ } else if (config.arpa_complain == Config::EXPENSIVE &&
+ (model_type == TRIE || model_type == QUANT_TRIE || model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE)) {
*config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl;
}
}