summaryrefslogtreecommitdiff
path: root/klm/lm/binary_format.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/binary_format.cc')
-rw-r--r--klm/lm/binary_format.cc73
1 files changed, 43 insertions, 30 deletions
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc
index 69a06355..3d9700da 100644
--- a/klm/lm/binary_format.cc
+++ b/klm/lm/binary_format.cc
@@ -18,8 +18,8 @@ namespace lm {
namespace ngram {
namespace {
const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version";
-const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 1\n\0";
-const long int kMagicVersion = 1;
+const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 3\n\0";
+const long int kMagicVersion = 2;
// Test values.
struct Sanity {
@@ -76,6 +76,45 @@ void WriteHeader(void *to, const Parameters &params) {
}
} // namespace
+
+uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing) {
+ if (config.write_mmap) {
+ std::size_t total = TotalHeaderSize(order) + memory_size;
+ backing.vocab.reset(util::MapZeroedWrite(config.write_mmap, total, backing.file), total, util::scoped_memory::MMAP_ALLOCATED);
+ return reinterpret_cast<uint8_t*>(backing.vocab.get()) + TotalHeaderSize(order);
+ } else {
+ backing.vocab.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED);
+ return reinterpret_cast<uint8_t*>(backing.vocab.get());
+ }
+}
+
+uint8_t *GrowForSearch(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, std::size_t memory_size, Backing &backing) {
+ if (config.write_mmap) {
+ // header and vocab share the same mmap. The header is written here because we know the counts.
+ Parameters params;
+ params.counts = counts;
+ params.fixed.order = counts.size();
+ params.fixed.probing_multiplier = config.probing_multiplier;
+ params.fixed.model_type = model_type;
+ params.fixed.has_vocabulary = config.include_vocab;
+ WriteHeader(backing.vocab.get(), params);
+
+ // Grow the file to accomodate the search, using zeros.
+ if (-1 == ftruncate(backing.file.get(), backing.vocab.size() + memory_size))
+ UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (backing.vocab.size() + memory_size) << " failed");
+
+ // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down.
+ off_t page_size = sysconf(_SC_PAGE_SIZE);
+ off_t alignment_cruft = backing.vocab.size() % page_size;
+ backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), backing.vocab.size() - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED);
+
+ return reinterpret_cast<uint8_t*>(backing.search.get()) + alignment_cruft;
+ } else {
+ backing.search.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED);
+ return reinterpret_cast<uint8_t*>(backing.search.get());
+ }
+}
+
namespace detail {
bool IsBinaryFormat(int fd) {
@@ -128,7 +167,7 @@ uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t
if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map)
UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map);
- util::MapRead(config.load_method, backing.file.get(), 0, total_map, backing.memory);
+ util::MapRead(config.load_method, backing.file.get(), 0, total_map, backing.search);
if (config.enumerate_vocab && !params.fixed.has_vocabulary)
UTIL_THROW(FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary.");
@@ -137,33 +176,7 @@ uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t
if ((off_t)-1 == lseek(backing.file.get(), total_map, SEEK_SET))
UTIL_THROW(util::ErrnoException, "Failed to seek in binary file to vocab words");
}
- return reinterpret_cast<uint8_t*>(backing.memory.get()) + TotalHeaderSize(params.counts.size());
-}
-
-uint8_t *SetupZeroed(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, std::size_t memory_size, Backing &backing) {
- if (config.write_mmap) {
- std::size_t total_map = TotalHeaderSize(counts.size()) + memory_size;
- // Write out an mmap file.
- backing.memory.reset(util::MapZeroedWrite(config.write_mmap, total_map, backing.file), total_map, util::scoped_memory::MMAP_ALLOCATED);
-
- Parameters params;
- params.counts = counts;
- params.fixed.order = counts.size();
- params.fixed.probing_multiplier = config.probing_multiplier;
- params.fixed.model_type = model_type;
- params.fixed.has_vocabulary = config.include_vocab;
-
- WriteHeader(backing.memory.get(), params);
-
- if (params.fixed.has_vocabulary) {
- if ((off_t)-1 == lseek(backing.file.get(), total_map, SEEK_SET))
- UTIL_THROW(util::ErrnoException, "Failed to seek in binary file " << config.write_mmap << " to vocab words");
- }
- return reinterpret_cast<uint8_t*>(backing.memory.get()) + TotalHeaderSize(counts.size());
- } else {
- backing.memory.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED);
- return reinterpret_cast<uint8_t*>(backing.memory.get());
- }
+ return reinterpret_cast<uint8_t*>(backing.search.get()) + TotalHeaderSize(params.counts.size());
}
void ComplainAboutARPA(const Config &config, ModelType model_type) {