author    | Chris Dyer <cdyer@cs.cmu.edu> | 2011-01-25 22:30:48 +0200
committer | Chris Dyer <cdyer@cs.cmu.edu> | 2011-01-25 22:30:48 +0200
commit    | c4ade3091b812ca135ae6520fa7173e1bbf28754 (patch)
tree      | 2528af208f6dafd0c27dcbec0d2da291a9c93ca2 /klm/lm/model.cc
parent    | d04c0ca2d9df0e147239b18e90650ca8bd51d594 (diff)
update kenlm
Diffstat (limited to 'klm/lm/model.cc')
-rw-r--r-- | klm/lm/model.cc | 97
1 file changed, 48 insertions, 49 deletions
```diff
diff --git a/klm/lm/model.cc b/klm/lm/model.cc
index c7ba4908..146fe07b 100644
--- a/klm/lm/model.cc
+++ b/klm/lm/model.cc
@@ -61,10 +61,10 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
   // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any.
   util::FilePiece f(backing_.file.release(), file, config.messages);
   std::vector<uint64_t> counts;
-  // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed with search_.VariableSizeLoad
+  // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_.
   ReadARPACounts(f, counts);
-  if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit ngram.hh's kMaxOrder to at least this value and recompile.");
+  if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile.");
   if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
   if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");
@@ -114,7 +114,24 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
 
 template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const {
   context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);
   FullScoreReturn ret = ScoreExceptBackoff(context_rbegin, context_rend, new_word, out_state);
-  ret.prob += SlowBackoffLookup(context_rbegin, context_rend, ret.ngram_length);
+
+  // Add the backoff weights for n-grams of order start to (context_rend - context_rbegin).
+  unsigned char start = ret.ngram_length;
+  if (context_rend - context_rbegin < static_cast<std::ptrdiff_t>(start)) return ret;
+  if (start <= 1) {
+    ret.prob += search_.unigram.Lookup(*context_rbegin).backoff;
+    start = 2;
+  }
+  typename Search::Node node;
+  if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) {
+    return ret;
+  }
+  float backoff;
+  // i is the order of the backoff we're looking for.
+  for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i) {
+    if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, backoff, node)) break;
+    ret.prob += backoff;
+  }
   return ret;
 }
@@ -128,8 +145,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
   float ignored_prob;
   typename Search::Node node;
   search_.LookupUnigram(*context_rbegin, ignored_prob, out_state.backoff_[0], node);
-  // Tricky part is that an entry might be blank, but out_state.valid_length_ always has the last non-blank n-gram length.
-  out_state.valid_length_ = 1;
+  out_state.valid_length_ = HasExtension(out_state.backoff_[0]) ? 1 : 0;
   float *backoff_out = out_state.backoff_ + 1;
   const typename Search::Middle *mid = &*search_.middle.begin();
   for (const WordIndex *i = context_rbegin + 1; i < context_rend; ++i, ++backoff_out, ++mid) {
@@ -137,36 +153,21 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
       std::copy(context_rbegin, context_rbegin + out_state.valid_length_, out_state.history_);
       return;
     }
-    if (*backoff_out != kBlankBackoff) {
-      out_state.valid_length_ = i - context_rbegin + 1;
-    } else {
-      *backoff_out = 0.0;
-    }
+    if (HasExtension(*backoff_out)) out_state.valid_length_ = i - context_rbegin + 1;
   }
   std::copy(context_rbegin, context_rbegin + out_state.valid_length_, out_state.history_);
 }
 
-template <class Search, class VocabularyT> float GenericModel<Search, VocabularyT>::SlowBackoffLookup(
-    const WordIndex *const context_rbegin, const WordIndex *const context_rend, unsigned char start) const {
-  // Add the backoff weights for n-grams of order start to (context_rend - context_rbegin).
-  if (context_rend - context_rbegin < static_cast<std::ptrdiff_t>(start)) return 0.0;
-  float ret = 0.0;
-  if (start == 1) {
-    ret += search_.unigram.Lookup(*context_rbegin).backoff;
-    start = 2;
-  }
-  typename Search::Node node;
-  if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) {
-    return 0.0;
-  }
-  float backoff;
-  // i is the order of the backoff we're looking for.
-  for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i) {
-    if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, backoff, node)) break;
-    if (backoff != kBlankBackoff) ret += backoff;
-  }
-  return ret;
+namespace {
+// Do a paranoid copy of history, assuming new_word has already been copied
+// (hence the -1). out_state.valid_length_ could be zero so I avoided using
+// std::copy.
+void CopyRemainingHistory(const WordIndex *from, State &out_state) {
+  WordIndex *out = out_state.history_ + 1;
+  const WordIndex *in_end = from + static_cast<ptrdiff_t>(out_state.valid_length_) - 1;
+  for (const WordIndex *in = from; in < in_end; ++in, ++out) *out = *in;
 }
+} // namespace
 
 /* Ugly optimized function. Produce a score excluding backoff.
  * The search goes in increasing order of ngram length.
@@ -179,28 +180,26 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
     const WordIndex new_word,
     State &out_state) const {
   FullScoreReturn ret;
-  // ret.ngram_length contains the last known good (non-blank) ngram length.
+  // ret.ngram_length contains the last known non-blank ngram length.
   ret.ngram_length = 1;
 
   typename Search::Node node;
   float *backoff_out(out_state.backoff_);
   search_.LookupUnigram(new_word, ret.prob, *backoff_out, node);
+  // This is the length of the context that should be used for continuation.
+  out_state.valid_length_ = HasExtension(*backoff_out) ? 1 : 0;
+  // We'll write the word anyway since it will probably be used and does no harm being there.
   out_state.history_[0] = new_word;
-  if (context_rbegin == context_rend) {
-    out_state.valid_length_ = 1;
-    return ret;
-  }
+  if (context_rbegin == context_rend) return ret;
   ++backoff_out;
 
   // Ok now we know that the bigram contains known words. Start by looking it up.
-
   const WordIndex *hist_iter = context_rbegin;
   typename std::vector<Middle>::const_iterator mid_iter = search_.middle.begin();
   for (; ; ++mid_iter, ++hist_iter, ++backoff_out) {
     if (hist_iter == context_rend) {
       // Ran out of history. Typically no backoff, but this could be a blank.
-      out_state.valid_length_ = ret.ngram_length;
-      std::copy(context_rbegin, context_rbegin + ret.ngram_length - 1, out_state.history_ + 1);
+      CopyRemainingHistory(context_rbegin, out_state);
       // ret.prob was already set.
       return ret;
     }
@@ -210,32 +209,32 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
 
     float revert = ret.prob;
     if (!search_.LookupMiddle(*mid_iter, *hist_iter, ret.prob, *backoff_out, node)) {
       // Didn't find an ngram using hist_iter.
-      std::copy(context_rbegin, context_rbegin + ret.ngram_length - 1, out_state.history_ + 1);
-      out_state.valid_length_ = ret.ngram_length;
+      CopyRemainingHistory(context_rbegin, out_state);
       // ret.prob was already set.
       return ret;
     }
-    if (*backoff_out == kBlankBackoff) {
-      *backoff_out = 0.0;
+    if (ret.prob == kBlankProb) {
+      // It's a blank. Go back to the old probability.
       ret.prob = revert;
     } else {
       ret.ngram_length = hist_iter - context_rbegin + 2;
+      if (HasExtension(*backoff_out)) {
+        out_state.valid_length_ = ret.ngram_length;
+      }
     }
   }
 
   // It passed every lookup in search_.middle. All that's left is to check search_.longest.
   if (!search_.LookupLongest(*hist_iter, ret.prob, node)) {
-    //assert(ret.ngram_length <= P::Order() - 1);
-    out_state.valid_length_ = ret.ngram_length;
-    std::copy(context_rbegin, context_rbegin + ret.ngram_length - 1, out_state.history_ + 1);
+    // Failed to find a longest n-gram. Fall back to the most recent non-blank.
+    CopyRemainingHistory(context_rbegin, out_state);
     // ret.prob was already set.
     return ret;
   }
-  // It's an P::Order()-gram. There is no blank in longest_.
-  // out_state.valid_length_ is still P::Order() - 1 because the next lookup will only need that much.
-  std::copy(context_rbegin, context_rbegin + P::Order() - 2, out_state.history_ + 1);
-  out_state.valid_length_ = P::Order() - 1;
+  // It's a P::Order()-gram.
+  CopyRemainingHistory(context_rbegin, out_state);
+  // There is no blank in longest_.
   ret.ngram_length = P::Order();
   return ret;
 }
```
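The backoff arithmetic that this commit inlines into FullScoreForgotState is easier to see on a toy model. The sketch below is not kenlm's implementation: a std::map keyed on space-joined n-grams stands in for the unigram/middle/longest tables, and Entry, Join, and Score are hypothetical names invented for illustration. It computes the same quantity as ScoreExceptBackoff plus the inlined loop: the log probability of the longest matching n-gram ending in the new word, plus the backoff weights of every longer context that had to be skipped.

```cpp
// Minimal sketch of backoff scoring. NOT kenlm's data structures: a std::map
// keyed on space-joined n-grams (oldest word first) stands in for the
// unigram/middle/longest tables walked by search_.LookupMiddleNoProb.
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct Entry {
  float log_prob;  // log10 P(last word | preceding words)
  float backoff;   // log10 backoff weight when this n-gram is a context
};

typedef std::map<std::string, Entry> Table;

Entry MakeEntry(float log_prob, float backoff) {
  Entry e;
  e.log_prob = log_prob;
  e.backoff = backoff;
  return e;
}

// Join words[from..end) with single spaces.
std::string Join(const std::vector<std::string> &words, size_t from) {
  std::string out;
  for (size_t i = from; i < words.size(); ++i) {
    if (i > from) out += ' ';
    out += words[i];
  }
  return out;
}

// Find the longest n-gram ending in new_word, then charge the backoff weight
// of every longer context that was skipped. Unseen contexts contribute
// nothing, just as absent backoffs contribute nothing in the real loop.
float Score(const Table &table, const std::vector<std::string> &context,
            const std::string &new_word) {
  for (size_t start = 0; start <= context.size(); ++start) {
    std::vector<std::string> ngram(context.begin() + start, context.end());
    ngram.push_back(new_word);
    Table::const_iterator hit = table.find(Join(ngram, 0));
    if (hit == table.end()) continue;
    float prob = hit->second.log_prob;
    // Each context we backed off through adds its backoff weight.
    for (size_t i = 0; i < start; ++i) {
      Table::const_iterator b = table.find(Join(context, i));
      if (b != table.end()) prob += b->second.backoff;
    }
    return prob;
  }
  return -99.0f;  // new_word is entirely unknown
}

int main() {
  Table table;
  table["the"] = MakeEntry(-1.2f, -0.4f);
  table["cat"] = MakeEntry(-2.0f, -0.3f);
  table["sat"] = MakeEntry(-2.5f, 0.0f);
  table["the cat"] = MakeEntry(-0.7f, 0.0f);

  std::vector<std::string> context;
  context.push_back("the");
  context.push_back("cat");
  // "the cat sat" and "cat sat" are unseen, so the score falls back to
  // log P(sat) + backoff(the cat) + backoff(cat) = -2.5 + 0.0 + -0.3 = -2.8.
  std::cout << Score(table, context, "sat") << std::endl;
  return 0;
}
```

Running it prints -2.8: the trigram and bigram lookups miss, so the unigram probability of "sat" is charged the backoff weights of "the cat" and "cat". That sum over skipped contexts is exactly what the new loop accumulates into ret.prob by walking search_.middle with LookupMiddleNoProb.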