summary | refs | log | tree | commit | diff
path: root/klm/lm/binary_format.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/binary_format.hh')
-rw-r--r-- klm/lm/binary_format.hh 22
1 files changed, 10 insertions, 12 deletions
diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh
index a43c883c..2d66f813 100644
--- a/klm/lm/binary_format.hh
+++ b/klm/lm/binary_format.hh
@@ -35,10 +35,16 @@ struct Parameters {
struct Backing {
// File behind memory, if any.
util::scoped_fd file;
+ // Vocabulary lookup table. Not to be confused with the vocab words themselves.
+ util::scoped_memory vocab;
// Raw block of memory backing the language model data structures
- util::scoped_memory memory;
+ util::scoped_memory search;
};
+uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing);
+// Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin.
+uint8_t *GrowForSearch(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, std::size_t memory_size, Backing &backing);
+
namespace detail {
bool IsBinaryFormat(int fd);
@@ -49,8 +55,6 @@ void MatchCheck(ModelType model_type, const Parameters &params);
uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t memory_size, Backing &backing);
-uint8_t *SetupZeroed(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, std::size_t memory_size, Backing &backing);
-
void ComplainAboutARPA(const Config &config, ModelType model_type);
} // namespace detail
@@ -61,13 +65,12 @@ template <class To> void LoadLM(const char *file, const Config &config, To &to)
Backing &backing = to.MutableBacking();
backing.file.reset(util::OpenReadOrThrow(file));
- Parameters params;
-
try {
if (detail::IsBinaryFormat(backing.file.get())) {
+ Parameters params;
detail::ReadHeader(backing.file.get(), params);
detail::MatchCheck(To::kModelType, params);
- // Replace the probing_multiplier.
+ // Replace the run-time configured probing_multiplier with the one in the file.
Config new_config(config);
new_config.probing_multiplier = params.fixed.probing_multiplier;
std::size_t memory_size = To::Size(params.counts, new_config);
@@ -75,12 +78,7 @@ template <class To> void LoadLM(const char *file, const Config &config, To &to)
to.InitializeFromBinary(start, params, new_config, backing.file.get());
} else {
detail::ComplainAboutARPA(config, To::kModelType);
- util::FilePiece f(backing.file.release(), file, config.messages);
- ReadARPACounts(f, params.counts);
- std::size_t memory_size = To::Size(params.counts, config);
- uint8_t *start = detail::SetupZeroed(config, To::kModelType, params.counts, memory_size, backing);
-
- to.InitializeFromARPA(file, f, start, params, config);
+ to.InitializeFromARPA(file, config);
}
} catch (util::Exception &e) {
e << " in file " << file;