From 2b63fa0755954edf467a2421997eaf72771260cf Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Wed, 16 May 2012 13:24:08 -0700 Subject: Big kenlm change includes lower order models for probing only. And other stuff. --- klm/lm/model.cc | 192 ++++++++++++++++++++++++++++++-------------------------- 1 file changed, 103 insertions(+), 89 deletions(-) (limited to 'klm/lm/model.cc') diff --git a/klm/lm/model.cc b/klm/lm/model.cc index 478ebed1..c081788c 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -38,10 +38,13 @@ template GenericModel::Ge State begin_sentence = State(); begin_sentence.length = 1; begin_sentence.words[0] = vocab_.BeginSentence(); - begin_sentence.backoff[0] = search_.unigram.Lookup(begin_sentence.words[0]).backoff; + typename Search::Node ignored_node; + bool ignored_independent_left; + uint64_t ignored_extend_left; + begin_sentence.backoff[0] = search_.LookupUnigram(begin_sentence.words[0], ignored_node, ignored_independent_left, ignored_extend_left).Backoff(); State null_context = State(); null_context.length = 0; - P::Init(begin_sentence, null_context, vocab_, search_.MiddleEnd() - search_.MiddleBegin() + 2); + P::Init(begin_sentence, null_context, vocab_, search_.Order()); } template void GenericModel::InitializeFromBinary(void *start, const Parameters ¶ms, const Config &config, int fd) { @@ -50,6 +53,9 @@ template void GenericModel void GenericModel::InitializeFromARPA(const char *file, const Config &config) { // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any. util::FilePiece f(backing_.file.release(), file, config.messages); @@ -79,8 +85,8 @@ template void GenericModel FullScoreReturn GenericModel(start)) return ret; + + bool independent_left; + uint64_t extend_left; + typename Search::Node node; if (start <= 1) { - ret.prob += search_.unigram.Lookup(*context_rbegin).backoff; + ret.prob += search_.LookupUnigram(*context_rbegin, node, independent_left, extend_left).Backoff(); start = 2; - } - typename Search::Node node; - if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) { + } else if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) { return ret; } - float backoff; // i is the order of the backoff we're looking for. - typename Search::MiddleIter mid_iter = search_.MiddleBegin() + start - 2; - for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i, ++mid_iter) { - if (!search_.LookupMiddleNoProb(*mid_iter, *i, backoff, node)) break; - ret.prob += backoff; + unsigned char order_minus_2 = 0; + for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i, ++order_minus_2) { + typename Search::MiddlePointer p(search_.LookupMiddle(order_minus_2, *i, node, independent_left, extend_left)); + if (!p.Found()) break; + ret.prob += p.Backoff(); } return ret; } @@ -134,17 +142,20 @@ template void GenericModel FullScoreReturn GenericModel(extend_pointer), node, ret.independent_left, ret.extend_left)); + ret.rest = ptr.Rest(); + ret.prob = ptr.Prob(); + assert(!ret.independent_left); } else { - ret.ngram_length = P::Order(); + typename Search::MiddlePointer ptr(search_.Unpack(extend_pointer, extend_length, node)); + ret.rest = ptr.Rest(); + ret.prob = ptr.Prob(); + ret.extend_left = extend_pointer; + // If this function is called, then it does depend on left words. + ret.independent_left = false; } - ret.independent_left = true; + float subtract_me = ret.rest; + ret.ngram_length = extend_length; + next_use = extend_length; + ResumeScore(add_rbegin, add_rend, extend_length - 1, node, backoff_out, next_use, ret); + next_use -= extend_length; + // Charge backoffs. + for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b; ret.prob -= subtract_me; + ret.rest -= subtract_me; return ret; } @@ -215,66 +212,83 @@ void CopyRemainingHistory(const WordIndex *from, State &out_state) { * new_word. */ template FullScoreReturn GenericModel::ScoreExceptBackoff( - const WordIndex *context_rbegin, - const WordIndex *context_rend, + const WordIndex *const context_rbegin, + const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const { FullScoreReturn ret; // ret.ngram_length contains the last known non-blank ngram length. ret.ngram_length = 1; - float *backoff_out(out_state.backoff); typename Search::Node node; - search_.LookupUnigram(new_word, *backoff_out, node, ret); + typename Search::UnigramPointer uni(search_.LookupUnigram(new_word, node, ret.independent_left, ret.extend_left)); + out_state.backoff[0] = uni.Backoff(); + ret.prob = uni.Prob(); + ret.rest = uni.Rest(); + // This is the length of the context that should be used for continuation to the right. - out_state.length = HasExtension(*backoff_out) ? 1 : 0; + out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0; // We'll write the word anyway since it will probably be used and does no harm being there. out_state.words[0] = new_word; if (context_rbegin == context_rend) return ret; - ++backoff_out; - - // Ok start by looking up the bigram. - const WordIndex *hist_iter = context_rbegin; - typename Search::MiddleIter mid_iter(search_.MiddleBegin()); - for (; ; ++mid_iter, ++hist_iter, ++backoff_out) { - if (hist_iter == context_rend) { - // Ran out of history. Typically no backoff, but this could be a blank. - CopyRemainingHistory(context_rbegin, out_state); - // ret.prob was already set. - return ret; - } - if (mid_iter == search_.MiddleEnd()) break; + ResumeScore(context_rbegin, context_rend, 0, node, out_state.backoff + 1, out_state.length, ret); + CopyRemainingHistory(context_rbegin, out_state); + return ret; +} - if (ret.independent_left || !search_.LookupMiddle(*mid_iter, *hist_iter, *backoff_out, node, ret)) { - // Didn't find an ngram using hist_iter. - CopyRemainingHistory(context_rbegin, out_state); - // ret.prob was already set. - ret.independent_left = true; - return ret; - } - ret.ngram_length = hist_iter - context_rbegin + 2; +template void GenericModel::ResumeScore(const WordIndex *hist_iter, const WordIndex *const context_rend, unsigned char order_minus_2, typename Search::Node &node, float *backoff_out, unsigned char &next_use, FullScoreReturn &ret) const { + for (; ; ++order_minus_2, ++hist_iter, ++backoff_out) { + if (hist_iter == context_rend) return; + if (ret.independent_left) return; + if (order_minus_2 == P::Order() - 2) break; + + typename Search::MiddlePointer pointer(search_.LookupMiddle(order_minus_2, *hist_iter, node, ret.independent_left, ret.extend_left)); + if (!pointer.Found()) return; + *backoff_out = pointer.Backoff(); + ret.prob = pointer.Prob(); + ret.rest = pointer.Rest(); + ret.ngram_length = order_minus_2 + 2; if (HasExtension(*backoff_out)) { - out_state.length = ret.ngram_length; + next_use = ret.ngram_length; } } - - // It passed every lookup in search_.middle. All that's left is to check search_.longest. - if (!ret.independent_left && search_.LookupLongest(*hist_iter, ret.prob, node)) { - // It's an P::Order()-gram. + ret.independent_left = true; + typename Search::LongestPointer longest(search_.LookupLongest(*hist_iter, node)); + if (longest.Found()) { + ret.prob = longest.Prob(); + ret.rest = ret.prob; // There is no blank in longest_. ret.ngram_length = P::Order(); } - // This handles (N-1)-grams and N-grams. - CopyRemainingHistory(context_rbegin, out_state); - ret.independent_left = true; +} + +template float GenericModel::InternalUnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const { + float ret; + typename Search::Node node; + if (first_length == 1) { + if (pointers_begin >= pointers_end) return 0.0; + bool independent_left; + uint64_t extend_left; + typename Search::UnigramPointer ptr(search_.LookupUnigram(static_cast(*pointers_begin), node, independent_left, extend_left)); + ret = ptr.Prob() - ptr.Rest(); + ++first_length; + ++pointers_begin; + } else { + ret = 0.0; + } + for (const uint64_t *i = pointers_begin; i < pointers_end; ++i, ++first_length) { + typename Search::MiddlePointer ptr(search_.Unpack(*i, first_length, node)); + ret += ptr.Prob() - ptr.Rest(); + } return ret; } -template class GenericModel; // HASH_PROBING -template class GenericModel, SortedVocabulary>; // TRIE_SORTED +template class GenericModel, ProbingVocabulary>; +template class GenericModel, ProbingVocabulary>; +template class GenericModel, SortedVocabulary>; template class GenericModel, SortedVocabulary>; -template class GenericModel, SortedVocabulary>; // TRIE_SORTED_QUANT +template class GenericModel, SortedVocabulary>; template class GenericModel, SortedVocabulary>; } // namespace detail -- cgit v1.2.3