summaryrefslogtreecommitdiff
path: root/klm/lm/model.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/model.cc')
-rw-r--r--klm/lm/model.cc84
1 files changed, 52 insertions, 32 deletions
diff --git a/klm/lm/model.cc b/klm/lm/model.cc
index a26654a6..a5a16bf8 100644
--- a/klm/lm/model.cc
+++ b/klm/lm/model.cc
@@ -34,23 +34,17 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
if (static_cast<std::size_t>(start - static_cast<uint8_t*>(base)) != goal_size) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast<uint8_t*>(base)) << " but Size says they should take " << goal_size);
}
-template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) {
- LoadLM(file, config, *this);
-
- // g++ prints warnings unless these are fully initialized.
- State begin_sentence = State();
- begin_sentence.length = 1;
- begin_sentence.words[0] = vocab_.BeginSentence();
- typename Search::Node ignored_node;
- bool ignored_independent_left;
- uint64_t ignored_extend_left;
- begin_sentence.backoff[0] = search_.LookupUnigram(begin_sentence.words[0], ignored_node, ignored_independent_left, ignored_extend_left).Backoff();
- State null_context = State();
- null_context.length = 0;
- P::Init(begin_sentence, null_context, vocab_, search_.Order());
+namespace {
+void ComplainAboutARPA(const Config &config, ModelType model_type) {
+ if (config.write_mmap || !config.messages) return;
+ if (config.arpa_complain == Config::ALL) {
+ *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl;
+ } else if (config.arpa_complain == Config::EXPENSIVE &&
+ (model_type == TRIE || model_type == QUANT_TRIE || model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE)) {
+ *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl;
+ }
}
-namespace {
void CheckCounts(const std::vector<uint64_t> &counts) {
UTIL_THROW_IF(counts.size() > KENLM_MAX_ORDER, FormatLoadException, "This model has order " << counts.size() << " but KenLM was compiled to support up to " << KENLM_MAX_ORDER << ". " << KENLM_ORDER_MESSAGE);
if (sizeof(uint64_t) > sizeof(std::size_t)) {
@@ -59,18 +53,45 @@ void CheckCounts(const std::vector<uint64_t> &counts) {
}
}
}
+
} // namespace
-template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd) {
- CheckCounts(params.counts);
- SetupMemory(start, params.counts, config);
- vocab_.LoadedBinary(params.fixed.has_vocabulary, fd, config.enumerate_vocab);
- search_.LoadedBinary();
+template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &init_config) : backing_(init_config) {
+ util::scoped_fd fd(util::OpenReadOrThrow(file));
+ if (IsBinaryFormat(fd.get())) {
+ Parameters parameters;
+ int fd_shallow = fd.release();
+ backing_.InitializeBinary(fd_shallow, kModelType, kVersion, parameters);
+ CheckCounts(parameters.counts);
+
+ Config new_config(init_config);
+ new_config.probing_multiplier = parameters.fixed.probing_multiplier;
+ Search::UpdateConfigFromBinary(backing_, parameters.counts, VocabularyT::Size(parameters.counts[0], new_config), new_config);
+ UTIL_THROW_IF(new_config.enumerate_vocab && !parameters.fixed.has_vocabulary, FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary.");
+
+ SetupMemory(backing_.LoadBinary(Size(parameters.counts, new_config)), parameters.counts, new_config);
+ vocab_.LoadedBinary(parameters.fixed.has_vocabulary, fd_shallow, new_config.enumerate_vocab, backing_.VocabStringReadingOffset());
+ } else {
+ ComplainAboutARPA(init_config, kModelType);
+ InitializeFromARPA(fd.release(), file, init_config);
+ }
+
+ // g++ prints warnings unless these are fully initialized.
+ State begin_sentence = State();
+ begin_sentence.length = 1;
+ begin_sentence.words[0] = vocab_.BeginSentence();
+ typename Search::Node ignored_node;
+ bool ignored_independent_left;
+ uint64_t ignored_extend_left;
+ begin_sentence.backoff[0] = search_.LookupUnigram(begin_sentence.words[0], ignored_node, ignored_independent_left, ignored_extend_left).Backoff();
+ State null_context = State();
+ null_context.length = 0;
+ P::Init(begin_sentence, null_context, vocab_, search_.Order());
}
-template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) {
- // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any.
- util::FilePiece f(backing_.file.release(), file, config.ProgressMessages());
+template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(int fd, const char *file, const Config &config) {
+ // Backing file is the ARPA.
+ util::FilePiece f(fd, file, config.ProgressMessages());
try {
std::vector<uint64_t> counts;
// File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_.
@@ -81,13 +102,17 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
std::size_t vocab_size = util::CheckOverflow(VocabularyT::Size(counts[0], config));
// Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs.
- vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config);
+ vocab_.SetupMemory(backing_.SetupJustVocab(vocab_size, counts.size()), vocab_size, counts[0], config);
- if (config.write_mmap) {
+ if (config.write_mmap && config.include_vocab) {
WriteWordsWrapper wrap(config.enumerate_vocab);
vocab_.ConfigureEnumerate(&wrap, counts[0]);
search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
- wrap.Write(backing_.file.get(), backing_.vocab.size() + vocab_.UnkCountChangePadding() + Search::Size(counts, config));
+ void *vocab_rebase, *search_rebase;
+ backing_.WriteVocabWords(wrap.Buffer(), vocab_rebase, search_rebase);
+ // Due to writing at the end of file, mmap may have relocated data. So remap.
+ vocab_.Relocate(vocab_rebase);
+ search_.SetupMemory(reinterpret_cast<uint8_t*>(search_rebase), counts, config);
} else {
vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]);
search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
@@ -99,18 +124,13 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
search_.UnknownUnigram().backoff = 0.0;
search_.UnknownUnigram().prob = config.unknown_missing_logprob;
}
- FinishFile(config, kModelType, kVersion, counts, vocab_.UnkCountChangePadding(), backing_);
+ backing_.FinishFile(config, kModelType, kVersion, counts);
} catch (util::Exception &e) {
e << " Byte: " << f.Offset();
throw;
}
}
-template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {
- util::AdvanceOrThrow(fd, VocabularyT::Size(counts[0], config));
- Search::UpdateConfigFromBinary(fd, counts, config);
-}
-
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const {
FullScoreReturn ret = ScoreExceptBackoff(in_state.words, in_state.words + in_state.length, new_word, out_state);
for (const float *i = in_state.backoff + ret.ngram_length - 1; i < in_state.backoff + in_state.length; ++i) {