diff options
Diffstat (limited to 'klm/lm/model.cc')
-rw-r--r-- | klm/lm/model.cc | 33 |
1 files changed, 17 insertions, 16 deletions
diff --git a/klm/lm/model.cc b/klm/lm/model.cc index 2fd20481..a40fd2fb 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -37,7 +37,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) { LoadLM(file, config, *this); - // g++ prints warnings unless these are fully initialized. + // g++ prints warnings unless these are fully initialized. State begin_sentence = State(); begin_sentence.length = 1; begin_sentence.words[0] = vocab_.BeginSentence(); @@ -69,8 +69,8 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT } template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) { - // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any. - util::FilePiece f(backing_.file.release(), file, config.messages); + // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any. + util::FilePiece f(backing_.file.release(), file, config.ProgressMessages()); try { std::vector<uint64_t> counts; // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_. @@ -80,14 +80,14 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0"); std::size_t vocab_size = util::CheckOverflow(VocabularyT::Size(counts[0], config)); - // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs. + // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs. vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config); if (config.write_mmap) { WriteWordsWrapper wrap(config.enumerate_vocab); vocab_.ConfigureEnumerate(&wrap, counts[0]); search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); - wrap.Write(backing_.file.get(), backing_.vocab.size() + vocab_.UnkCountChangePadding() + backing_.search.size()); + wrap.Write(backing_.file.get(), backing_.vocab.size() + vocab_.UnkCountChangePadding() + Search::Size(counts, config)); } else { vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]); search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); @@ -95,7 +95,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT if (!vocab_.SawUnk()) { assert(config.unknown_missing != THROW_UP); - // Default probabilities for unknown. + // Default probabilities for unknown. search_.UnknownUnigram().backoff = 0.0; search_.UnknownUnigram().prob = config.unknown_missing_logprob; } @@ -147,7 +147,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, } template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const { - // Generate a state from context. + // Generate a state from context. context_rend = std::min(context_rend, context_rbegin + P::Order() - 1); if (context_rend == context_rbegin) { out_state.length = 0; @@ -191,7 +191,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, ret.rest = ptr.Rest(); ret.prob = ptr.Prob(); ret.extend_left = extend_pointer; - // If this function is called, then it does depend on left words. + // If this function is called, then it does depend on left words. ret.independent_left = false; } float subtract_me = ret.rest; @@ -199,7 +199,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, next_use = extend_length; ResumeScore(add_rbegin, add_rend, extend_length - 1, node, backoff_out, next_use, ret); next_use -= extend_length; - // Charge backoffs. + // Charge backoffs. for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b; ret.prob -= subtract_me; ret.rest -= subtract_me; @@ -209,7 +209,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, namespace { // Do a paraonoid copy of history, assuming new_word has already been copied // (hence the -1). out_state.length could be zero so I avoided using -// std::copy. +// std::copy. void CopyRemainingHistory(const WordIndex *from, State &out_state) { WordIndex *out = out_state.words + 1; const WordIndex *in_end = from + static_cast<ptrdiff_t>(out_state.length) - 1; @@ -217,18 +217,19 @@ void CopyRemainingHistory(const WordIndex *from, State &out_state) { } } // namespace -/* Ugly optimized function. Produce a score excluding backoff. - * The search goes in increasing order of ngram length. +/* Ugly optimized function. Produce a score excluding backoff. + * The search goes in increasing order of ngram length. * Context goes backward, so context_begin is the word immediately preceeding - * new_word. + * new_word. */ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::ScoreExceptBackoff( const WordIndex *const context_rbegin, const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const { + assert(new_word < vocab_.Bound()); FullScoreReturn ret; - // ret.ngram_length contains the last known non-blank ngram length. + // ret.ngram_length contains the last known non-blank ngram length. ret.ngram_length = 1; typename Search::Node node; @@ -237,9 +238,9 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, ret.prob = uni.Prob(); ret.rest = uni.Rest(); - // This is the length of the context that should be used for continuation to the right. + // This is the length of the context that should be used for continuation to the right. out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0; - // We'll write the word anyway since it will probably be used and does no harm being there. + // We'll write the word anyway since it will probably be used and does no harm being there. out_state.words[0] = new_word; if (context_rbegin == context_rend) return ret; |