summary | refs | log | tree | commit | diff
path: root/klm/lm/binary_format.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/binary_format.hh')
-rw-r--r-- klm/lm/binary_format.hh | 22
1 files changed, 12 insertions, 10 deletions
diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh
index d28cb6c5..a83f6b89 100644
--- a/klm/lm/binary_format.hh
+++ b/klm/lm/binary_format.hh
@@ -2,6 +2,7 @@
#define LM_BINARY_FORMAT__
#include "lm/config.hh"
+#include "lm/model_type.hh"
#include "lm/read_arpa.hh"
#include "util/file_piece.hh"
@@ -16,13 +17,6 @@
namespace lm {
namespace ngram {
-/* Not the best numbering system, but it grew this way for historical reasons
- * and I want to preserve existing binary files. */
-typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3, ARRAY_TRIE_SORTED=4, QUANT_ARRAY_TRIE_SORTED=5} ModelType;
-
-const static ModelType kQuantAdd = static_cast<ModelType>(QUANT_TRIE_SORTED - TRIE_SORTED);
-const static ModelType kArrayAdd = static_cast<ModelType>(ARRAY_TRIE_SORTED - TRIE_SORTED);
-
/*Inspect a file to determine if it is a binary lm. If not, return false.
* If so, return true and set recognized to the type. This is the only API in
* this header designed for use by decoder authors.
@@ -36,8 +30,14 @@ struct FixedWidthParameters {
ModelType model_type;
// Does the end of the file have the actual strings in the vocabulary?
bool has_vocabulary;
+ unsigned int search_version;
};
+inline std::size_t Align8(std::size_t in) {  // Rounds `in` up to the next multiple of 8; performs no overflow check.
+ std::size_t off = in % 8;  // How far `in` lies past the previous multiple of 8; 0 when already aligned.
+ return off ? (in + 8 - off) : in;  // Already-aligned values are returned unchanged.
+}
+
// Parameters stored in the header of a binary file.
struct Parameters {
FixedWidthParameters fixed;
@@ -64,7 +64,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t
// Write header to binary file. This is done last to prevent incomplete files
// from loading.
-void FinishFile(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, Backing &backing);
+void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, Backing &backing);
namespace detail {
@@ -72,7 +72,9 @@ bool IsBinaryFormat(int fd);
void ReadHeader(int fd, Parameters &params);
-void MatchCheck(ModelType model_type, const Parameters &params);
+void MatchCheck(ModelType model_type, unsigned int search_version, const Parameters &params);  // LoadLM calls this with To::kModelType and To::kVersion right after ReadHeader(fd, params).
+
+void SeekPastHeader(int fd, const Parameters &params);  // NOTE(review): this added line repeats the unchanged declaration on the next line; the redeclaration is legal but redundant -- confirm it was intended.
void SeekPastHeader(int fd, const Parameters &params);
@@ -90,7 +92,7 @@ template <class To> void LoadLM(const char *file, const Config &config, To &to)
if (detail::IsBinaryFormat(backing.file.get())) {
Parameters params;
detail::ReadHeader(backing.file.get(), params);
- detail::MatchCheck(To::kModelType, params);
+ detail::MatchCheck(To::kModelType, To::kVersion, params);
// Replace the run-time configured probing_multiplier with the one in the file.
Config new_config(config);
new_config.probing_multiplier = params.fixed.probing_multiplier;