diff options
Diffstat (limited to 'klm/lm/binary_format.hh')
-rw-r--r-- | klm/lm/binary_format.hh | 22 |
1 files changed, 10 insertions, 12 deletions
diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index a43c883c..2d66f813 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -35,10 +35,16 @@ struct Parameters { struct Backing { // File behind memory, if any. util::scoped_fd file; + // Vocabulary lookup table. Not to be confused with the vocab words themselves. + util::scoped_memory vocab; // Raw block of memory backing the language model data structures - util::scoped_memory memory; + util::scoped_memory search; }; +uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing); +// Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin. +uint8_t *GrowForSearch(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, std::size_t memory_size, Backing &backing); + namespace detail { bool IsBinaryFormat(int fd); @@ -49,8 +55,6 @@ void MatchCheck(ModelType model_type, const Parameters ¶ms); uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t memory_size, Backing &backing); -uint8_t *SetupZeroed(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, std::size_t memory_size, Backing &backing); - void ComplainAboutARPA(const Config &config, ModelType model_type); } // namespace detail @@ -61,13 +65,12 @@ template <class To> void LoadLM(const char *file, const Config &config, To &to) Backing &backing = to.MutableBacking(); backing.file.reset(util::OpenReadOrThrow(file)); - Parameters params; - try { if (detail::IsBinaryFormat(backing.file.get())) { + Parameters params; detail::ReadHeader(backing.file.get(), params); detail::MatchCheck(To::kModelType, params); - // Replace the probing_multiplier. + // Replace the run-time configured probing_multiplier with the one in the file. Config new_config(config); new_config.probing_multiplier = params.fixed.probing_multiplier; std::size_t memory_size = To::Size(params.counts, new_config); @@ -75,12 +78,7 @@ template <class To> void LoadLM(const char *file, const Config &config, To &to) to.InitializeFromBinary(start, params, new_config, backing.file.get()); } else { detail::ComplainAboutARPA(config, To::kModelType); - util::FilePiece f(backing.file.release(), file, config.messages); - ReadARPACounts(f, params.counts); - std::size_t memory_size = To::Size(params.counts, config); - uint8_t *start = detail::SetupZeroed(config, To::kModelType, params.counts, memory_size, backing); - - to.InitializeFromARPA(file, f, start, params, config); + to.InitializeFromARPA(file, config); } } catch (util::Exception &e) { e << " in file " << file; |