summaryrefslogtreecommitdiff
path: root/klm/lm/model.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/model.cc')
-rw-r--r--klm/lm/model.cc67
1 files changed, 38 insertions, 29 deletions
diff --git a/klm/lm/model.cc b/klm/lm/model.cc
index a1d10b3d..27e24b1c 100644
--- a/klm/lm/model.cc
+++ b/klm/lm/model.cc
@@ -21,6 +21,8 @@ size_t hash_value(const State &state) {
namespace detail {
+template <class Search, class VocabularyT> const ModelType GenericModel<Search, VocabularyT>::kModelType = Search::kModelType;
+
template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) {
return VocabularyT::Size(counts[0], config) + Search::Size(counts, config);
}
@@ -56,35 +58,40 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) {
// Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any.
util::FilePiece f(backing_.file.release(), file, config.messages);
- std::vector<uint64_t> counts;
- // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_.
- ReadARPACounts(f, counts);
-
- if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile.");
- if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
- if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");
-
- std::size_t vocab_size = VocabularyT::Size(counts[0], config);
- // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs.
- vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config);
-
- if (config.write_mmap) {
- WriteWordsWrapper wrap(config.enumerate_vocab);
- vocab_.ConfigureEnumerate(&wrap, counts[0]);
- search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
- wrap.Write(backing_.file.get());
- } else {
- vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]);
- search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
- }
+ try {
+ std::vector<uint64_t> counts;
+ // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_.
+ ReadARPACounts(f, counts);
+
+ if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile.");
+ if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
+ if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");
+
+ std::size_t vocab_size = VocabularyT::Size(counts[0], config);
+ // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs.
+ vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config);
+
+ if (config.write_mmap) {
+ WriteWordsWrapper wrap(config.enumerate_vocab);
+ vocab_.ConfigureEnumerate(&wrap, counts[0]);
+ search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
+ wrap.Write(backing_.file.get());
+ } else {
+ vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]);
+ search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
+ }
- if (!vocab_.SawUnk()) {
- assert(config.unknown_missing != THROW_UP);
- // Default probabilities for unknown.
- search_.unigram.Unknown().backoff = 0.0;
- search_.unigram.Unknown().prob = config.unknown_missing_logprob;
+ if (!vocab_.SawUnk()) {
+ assert(config.unknown_missing != THROW_UP);
+ // Default probabilities for unknown.
+ search_.unigram.Unknown().backoff = 0.0;
+ search_.unigram.Unknown().prob = config.unknown_missing_logprob;
+ }
+ FinishFile(config, kModelType, counts, backing_);
+ } catch (util::Exception &e) {
+ e << " Byte: " << f.Offset();
+ throw;
}
- FinishFile(config, kModelType, counts, backing_);
}
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const {
@@ -225,8 +232,10 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
}
template class GenericModel<ProbingHashedSearch, ProbingVocabulary>; // HASH_PROBING
-template class GenericModel<trie::TrieSearch<DontQuantize>, SortedVocabulary>; // TRIE_SORTED
-template class GenericModel<trie::TrieSearch<SeparatelyQuantize>, SortedVocabulary>; // TRIE_SORTED_QUANT
+template class GenericModel<trie::TrieSearch<DontQuantize, trie::DontBhiksha>, SortedVocabulary>; // TRIE_SORTED
+template class GenericModel<trie::TrieSearch<DontQuantize, trie::ArrayBhiksha>, SortedVocabulary>;
+template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiksha>, SortedVocabulary>; // TRIE_SORTED_QUANT
+template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::ArrayBhiksha>, SortedVocabulary>;
} // namespace detail
} // namespace ngram