diff options
Diffstat (limited to 'klm/lm/model.cc')
-rw-r--r-- | klm/lm/model.cc | 132 |
1 files changed, 86 insertions, 46 deletions
diff --git a/klm/lm/model.cc b/klm/lm/model.cc index 27e24b1c..ca581d8a 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -16,7 +16,7 @@ namespace lm { namespace ngram { size_t hash_value(const State &state) { - return util::MurmurHashNative(state.history_, sizeof(WordIndex) * state.valid_length_); + return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length); } namespace detail { @@ -41,11 +41,11 @@ template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::Ge // g++ prints warnings unless these are fully initialized. State begin_sentence = State(); - begin_sentence.valid_length_ = 1; - begin_sentence.history_[0] = vocab_.BeginSentence(); - begin_sentence.backoff_[0] = search_.unigram.Lookup(begin_sentence.history_[0]).backoff; + begin_sentence.length = 1; + begin_sentence.words[0] = vocab_.BeginSentence(); + begin_sentence.backoff[0] = search_.unigram.Lookup(begin_sentence.words[0]).backoff; State null_context = State(); - null_context.valid_length_ = 0; + null_context.length = 0; P::Init(begin_sentence, null_context, vocab_, search_.MiddleEnd() - search_.MiddleBegin() + 2); } @@ -87,7 +87,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT search_.unigram.Unknown().backoff = 0.0; search_.unigram.Unknown().prob = config.unknown_missing_logprob; } - FinishFile(config, kModelType, counts, backing_); + FinishFile(config, kModelType, kVersion, counts, backing_); } catch (util::Exception &e) { e << " Byte: " << f.Offset(); throw; @@ -95,9 +95,9 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT } template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const { - FullScoreReturn ret = ScoreExceptBackoff(in_state.history_, in_state.history_ + in_state.valid_length_, new_word, out_state); - if (ret.ngram_length - 1 < in_state.valid_length_) { - ret.prob = std::accumulate(in_state.backoff_ + ret.ngram_length - 1, in_state.backoff_ + in_state.valid_length_, ret.prob); + FullScoreReturn ret = ScoreExceptBackoff(in_state.words, in_state.words + in_state.length, new_word, out_state); + if (ret.ngram_length - 1 < in_state.length) { + ret.prob = std::accumulate(in_state.backoff + ret.ngram_length - 1, in_state.backoff + in_state.length, ret.prob); } return ret; } @@ -131,32 +131,80 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT // Generate a state from context. context_rend = std::min(context_rend, context_rbegin + P::Order() - 1); if (context_rend == context_rbegin) { - out_state.valid_length_ = 0; + out_state.length = 0; return; } - float ignored_prob; + FullScoreReturn ignored; typename Search::Node node; - search_.LookupUnigram(*context_rbegin, ignored_prob, out_state.backoff_[0], node); - out_state.valid_length_ = HasExtension(out_state.backoff_[0]) ? 1 : 0; - float *backoff_out = out_state.backoff_ + 1; + search_.LookupUnigram(*context_rbegin, out_state.backoff[0], node, ignored); + out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0; + float *backoff_out = out_state.backoff + 1; const typename Search::Middle *mid = search_.MiddleBegin(); for (const WordIndex *i = context_rbegin + 1; i < context_rend; ++i, ++backoff_out, ++mid) { if (!search_.LookupMiddleNoProb(*mid, *i, *backoff_out, node)) { - std::copy(context_rbegin, context_rbegin + out_state.valid_length_, out_state.history_); + std::copy(context_rbegin, context_rbegin + out_state.length, out_state.words); return; } - if (HasExtension(*backoff_out)) out_state.valid_length_ = i - context_rbegin + 1; + if (HasExtension(*backoff_out)) out_state.length = i - context_rbegin + 1; } - std::copy(context_rbegin, context_rbegin + out_state.valid_length_, out_state.history_); + std::copy(context_rbegin, context_rbegin + out_state.length, out_state.words); +} + +template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::ExtendLeft( + const WordIndex *add_rbegin, const WordIndex *add_rend, + const float *backoff_in, + uint64_t extend_pointer, + unsigned char extend_length, + float *backoff_out, + unsigned char &next_use) const { + FullScoreReturn ret; + float subtract_me; + typename Search::Node node(search_.Unpack(extend_pointer, extend_length, subtract_me)); + ret.prob = subtract_me; + ret.ngram_length = extend_length; + next_use = 0; + // If this function is called, then it does depend on left words. + ret.independent_left = false; + ret.extend_left = extend_pointer; + const typename Search::Middle *mid_iter = search_.MiddleBegin() + extend_length - 1; + const WordIndex *i = add_rbegin; + for (; ; ++i, ++backoff_out, ++mid_iter) { + if (i == add_rend) { + // Ran out of words. + for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b; + ret.prob -= subtract_me; + return ret; + } + if (mid_iter == search_.MiddleEnd()) break; + if (ret.independent_left || !search_.LookupMiddle(*mid_iter, *i, *backoff_out, node, ret)) { + // Didn't match a word. + ret.independent_left = true; + for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b; + ret.prob -= subtract_me; + return ret; + } + ret.ngram_length = mid_iter - search_.MiddleBegin() + 2; + if (HasExtension(*backoff_out)) next_use = i - add_rbegin + 1; + } + + if (ret.independent_left || !search_.LookupLongest(*i, ret.prob, node)) { + // The last backoff weight, for Order() - 1. + ret.prob += backoff_in[i - add_rbegin]; + } else { + ret.ngram_length = P::Order(); + } + ret.independent_left = true; + ret.prob -= subtract_me; + return ret; } namespace { // Do a paraonoid copy of history, assuming new_word has already been copied -// (hence the -1). out_state.valid_length_ could be zero so I avoided using +// (hence the -1). out_state.length could be zero so I avoided using // std::copy. void CopyRemainingHistory(const WordIndex *from, State &out_state) { - WordIndex *out = out_state.history_ + 1; - const WordIndex *in_end = from + static_cast<ptrdiff_t>(out_state.valid_length_) - 1; + WordIndex *out = out_state.words + 1; + const WordIndex *in_end = from + static_cast<ptrdiff_t>(out_state.length) - 1; for (const WordIndex *in = from; in < in_end; ++in, ++out) *out = *in; } } // namespace @@ -175,17 +223,17 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, // ret.ngram_length contains the last known non-blank ngram length. ret.ngram_length = 1; + float *backoff_out(out_state.backoff); typename Search::Node node; - float *backoff_out(out_state.backoff_); - search_.LookupUnigram(new_word, ret.prob, *backoff_out, node); - // This is the length of the context that should be used for continuation. - out_state.valid_length_ = HasExtension(*backoff_out) ? 1 : 0; + search_.LookupUnigram(new_word, *backoff_out, node, ret); + // This is the length of the context that should be used for continuation to the right. + out_state.length = HasExtension(*backoff_out) ? 1 : 0; // We'll write the word anyway since it will probably be used and does no harm being there. - out_state.history_[0] = new_word; + out_state.words[0] = new_word; if (context_rbegin == context_rend) return ret; ++backoff_out; - // Ok now we now that the bigram contains known words. Start by looking it up. + // Ok start by looking up the bigram. const WordIndex *hist_iter = context_rbegin; const typename Search::Middle *mid_iter = search_.MiddleBegin(); for (; ; ++mid_iter, ++hist_iter, ++backoff_out) { @@ -198,36 +246,28 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, if (mid_iter == search_.MiddleEnd()) break; - float revert = ret.prob; - if (!search_.LookupMiddle(*mid_iter, *hist_iter, ret.prob, *backoff_out, node)) { + if (ret.independent_left || !search_.LookupMiddle(*mid_iter, *hist_iter, *backoff_out, node, ret)) { // Didn't find an ngram using hist_iter. CopyRemainingHistory(context_rbegin, out_state); - // ret.prob was already set. + // ret.prob was already set. + ret.independent_left = true; return ret; } - if (ret.prob == kBlankProb) { - // It's a blank. Go back to the old probability. - ret.prob = revert; - } else { - ret.ngram_length = hist_iter - context_rbegin + 2; - if (HasExtension(*backoff_out)) { - out_state.valid_length_ = ret.ngram_length; - } + ret.ngram_length = hist_iter - context_rbegin + 2; + if (HasExtension(*backoff_out)) { + out_state.length = ret.ngram_length; } } // It passed every lookup in search_.middle. All that's left is to check search_.longest. - - if (!search_.LookupLongest(*hist_iter, ret.prob, node)) { - // Failed to find a longest n-gram. Fall back to the most recent non-blank. - CopyRemainingHistory(context_rbegin, out_state); - // ret.prob was already set. - return ret; + if (!ret.independent_left && search_.LookupLongest(*hist_iter, ret.prob, node)) { + // It's an P::Order()-gram. + // There is no blank in longest_. + ret.ngram_length = P::Order(); } - // It's an P::Order()-gram. + // This handles (N-1)-grams and N-grams. CopyRemainingHistory(context_rbegin, out_state); - // There is no blank in longest_. - ret.ngram_length = P::Order(); + ret.independent_left = true; return ret; } |