diff options
Diffstat (limited to 'klm/lm/model.cc')
-rw-r--r-- | klm/lm/model.cc | 192 |
1 files changed, 103 insertions, 89 deletions
diff --git a/klm/lm/model.cc b/klm/lm/model.cc index 478ebed1..c081788c 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -38,10 +38,13 @@ template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::Ge State begin_sentence = State(); begin_sentence.length = 1; begin_sentence.words[0] = vocab_.BeginSentence(); - begin_sentence.backoff[0] = search_.unigram.Lookup(begin_sentence.words[0]).backoff; + typename Search::Node ignored_node; + bool ignored_independent_left; + uint64_t ignored_extend_left; + begin_sentence.backoff[0] = search_.LookupUnigram(begin_sentence.words[0], ignored_node, ignored_independent_left, ignored_extend_left).Backoff(); State null_context = State(); null_context.length = 0; - P::Init(begin_sentence, null_context, vocab_, search_.MiddleEnd() - search_.MiddleBegin() + 2); + P::Init(begin_sentence, null_context, vocab_, search_.Order()); } template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters ¶ms, const Config &config, int fd) { @@ -50,6 +53,9 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT search_.LoadedBinary(); } +namespace { +} // namespace + template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) { // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any. util::FilePiece f(backing_.file.release(), file, config.messages); @@ -79,8 +85,8 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT if (!vocab_.SawUnk()) { assert(config.unknown_missing != THROW_UP); // Default probabilities for unknown. - search_.unigram.Unknown().backoff = 0.0; - search_.unigram.Unknown().prob = config.unknown_missing_logprob; + search_.UnknownUnigram().backoff = 0.0; + search_.UnknownUnigram().prob = config.unknown_missing_logprob; } FinishFile(config, kModelType, kVersion, counts, vocab_.UnkCountChangePadding(), backing_); } catch (util::Exception &e) { @@ -109,20 +115,22 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, // Add the backoff weights for n-grams of order start to (context_rend - context_rbegin). unsigned char start = ret.ngram_length; if (context_rend - context_rbegin < static_cast<std::ptrdiff_t>(start)) return ret; + + bool independent_left; + uint64_t extend_left; + typename Search::Node node; if (start <= 1) { - ret.prob += search_.unigram.Lookup(*context_rbegin).backoff; + ret.prob += search_.LookupUnigram(*context_rbegin, node, independent_left, extend_left).Backoff(); start = 2; - } - typename Search::Node node; - if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) { + } else if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) { return ret; } - float backoff; // i is the order of the backoff we're looking for. - typename Search::MiddleIter mid_iter = search_.MiddleBegin() + start - 2; - for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i, ++mid_iter) { - if (!search_.LookupMiddleNoProb(*mid_iter, *i, backoff, node)) break; - ret.prob += backoff; + unsigned char order_minus_2 = 0; + for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i, ++order_minus_2) { + typename Search::MiddlePointer p(search_.LookupMiddle(order_minus_2, *i, node, independent_left, extend_left)); + if (!p.Found()) break; + ret.prob += p.Backoff(); } return ret; } @@ -134,17 +142,20 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT out_state.length = 0; return; } - FullScoreReturn ignored; typename Search::Node node; - search_.LookupUnigram(*context_rbegin, out_state.backoff[0], node, ignored); + bool independent_left; + uint64_t extend_left; + out_state.backoff[0] = search_.LookupUnigram(*context_rbegin, node, independent_left, extend_left).Backoff(); out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0; float *backoff_out = out_state.backoff + 1; - typename Search::MiddleIter mid(search_.MiddleBegin()); - for (const WordIndex *i = context_rbegin + 1; i < context_rend; ++i, ++backoff_out, ++mid) { - if (!search_.LookupMiddleNoProb(*mid, *i, *backoff_out, node)) { + unsigned char order_minus_2 = 0; + for (const WordIndex *i = context_rbegin + 1; i < context_rend; ++i, ++backoff_out, ++order_minus_2) { + typename Search::MiddlePointer p(search_.LookupMiddle(order_minus_2, *i, node, independent_left, extend_left)); + if (!p.Found()) { std::copy(context_rbegin, context_rbegin + out_state.length, out_state.words); return; } + *backoff_out = p.Backoff(); if (HasExtension(*backoff_out)) out_state.length = i - context_rbegin + 1; } std::copy(context_rbegin, context_rbegin + out_state.length, out_state.words); @@ -158,43 +169,29 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, float *backoff_out, unsigned char &next_use) const { FullScoreReturn ret; - float subtract_me; - typename Search::Node node(search_.Unpack(extend_pointer, extend_length, subtract_me)); - ret.prob = subtract_me; - ret.ngram_length = extend_length; - next_use = 0; - // If this function is called, then it does depend on left words. - ret.independent_left = false; - ret.extend_left = extend_pointer; - typename Search::MiddleIter mid_iter(search_.MiddleBegin() + extend_length - 1); - const WordIndex *i = add_rbegin; - for (; ; ++i, ++backoff_out, ++mid_iter) { - if (i == add_rend) { - // Ran out of words. - for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b; - ret.prob -= subtract_me; - return ret; - } - if (mid_iter == search_.MiddleEnd()) break; - if (ret.independent_left || !search_.LookupMiddle(*mid_iter, *i, *backoff_out, node, ret)) { - // Didn't match a word. - ret.independent_left = true; - for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b; - ret.prob -= subtract_me; - return ret; - } - ret.ngram_length = mid_iter - search_.MiddleBegin() + 2; - if (HasExtension(*backoff_out)) next_use = i - add_rbegin + 1; - } - - if (ret.independent_left || !search_.LookupLongest(*i, ret.prob, node)) { - // The last backoff weight, for Order() - 1. - ret.prob += backoff_in[i - add_rbegin]; + typename Search::Node node; + if (extend_length == 1) { + typename Search::UnigramPointer ptr(search_.LookupUnigram(static_cast<WordIndex>(extend_pointer), node, ret.independent_left, ret.extend_left)); + ret.rest = ptr.Rest(); + ret.prob = ptr.Prob(); + assert(!ret.independent_left); } else { - ret.ngram_length = P::Order(); + typename Search::MiddlePointer ptr(search_.Unpack(extend_pointer, extend_length, node)); + ret.rest = ptr.Rest(); + ret.prob = ptr.Prob(); + ret.extend_left = extend_pointer; + // If this function is called, then it does depend on left words. + ret.independent_left = false; } - ret.independent_left = true; + float subtract_me = ret.rest; + ret.ngram_length = extend_length; + next_use = extend_length; + ResumeScore(add_rbegin, add_rend, extend_length - 1, node, backoff_out, next_use, ret); + next_use -= extend_length; + // Charge backoffs. + for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b; ret.prob -= subtract_me; + ret.rest -= subtract_me; return ret; } @@ -215,66 +212,83 @@ void CopyRemainingHistory(const WordIndex *from, State &out_state) { * new_word. */ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::ScoreExceptBackoff( - const WordIndex *context_rbegin, - const WordIndex *context_rend, + const WordIndex *const context_rbegin, + const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const { FullScoreReturn ret; // ret.ngram_length contains the last known non-blank ngram length. ret.ngram_length = 1; - float *backoff_out(out_state.backoff); typename Search::Node node; - search_.LookupUnigram(new_word, *backoff_out, node, ret); + typename Search::UnigramPointer uni(search_.LookupUnigram(new_word, node, ret.independent_left, ret.extend_left)); + out_state.backoff[0] = uni.Backoff(); + ret.prob = uni.Prob(); + ret.rest = uni.Rest(); + // This is the length of the context that should be used for continuation to the right. - out_state.length = HasExtension(*backoff_out) ? 1 : 0; + out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0; // We'll write the word anyway since it will probably be used and does no harm being there. out_state.words[0] = new_word; if (context_rbegin == context_rend) return ret; - ++backoff_out; - - // Ok start by looking up the bigram. - const WordIndex *hist_iter = context_rbegin; - typename Search::MiddleIter mid_iter(search_.MiddleBegin()); - for (; ; ++mid_iter, ++hist_iter, ++backoff_out) { - if (hist_iter == context_rend) { - // Ran out of history. Typically no backoff, but this could be a blank. - CopyRemainingHistory(context_rbegin, out_state); - // ret.prob was already set. - return ret; - } - if (mid_iter == search_.MiddleEnd()) break; + ResumeScore(context_rbegin, context_rend, 0, node, out_state.backoff + 1, out_state.length, ret); + CopyRemainingHistory(context_rbegin, out_state); + return ret; +} - if (ret.independent_left || !search_.LookupMiddle(*mid_iter, *hist_iter, *backoff_out, node, ret)) { - // Didn't find an ngram using hist_iter. - CopyRemainingHistory(context_rbegin, out_state); - // ret.prob was already set. - ret.independent_left = true; - return ret; - } - ret.ngram_length = hist_iter - context_rbegin + 2; +template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::ResumeScore(const WordIndex *hist_iter, const WordIndex *const context_rend, unsigned char order_minus_2, typename Search::Node &node, float *backoff_out, unsigned char &next_use, FullScoreReturn &ret) const { + for (; ; ++order_minus_2, ++hist_iter, ++backoff_out) { + if (hist_iter == context_rend) return; + if (ret.independent_left) return; + if (order_minus_2 == P::Order() - 2) break; + + typename Search::MiddlePointer pointer(search_.LookupMiddle(order_minus_2, *hist_iter, node, ret.independent_left, ret.extend_left)); + if (!pointer.Found()) return; + *backoff_out = pointer.Backoff(); + ret.prob = pointer.Prob(); + ret.rest = pointer.Rest(); + ret.ngram_length = order_minus_2 + 2; if (HasExtension(*backoff_out)) { - out_state.length = ret.ngram_length; + next_use = ret.ngram_length; } } - - // It passed every lookup in search_.middle. All that's left is to check search_.longest. - if (!ret.independent_left && search_.LookupLongest(*hist_iter, ret.prob, node)) { - // It's an P::Order()-gram. + ret.independent_left = true; + typename Search::LongestPointer longest(search_.LookupLongest(*hist_iter, node)); + if (longest.Found()) { + ret.prob = longest.Prob(); + ret.rest = ret.prob; // There is no blank in longest_. ret.ngram_length = P::Order(); } - // This handles (N-1)-grams and N-grams. - CopyRemainingHistory(context_rbegin, out_state); - ret.independent_left = true; +} + +template <class Search, class VocabularyT> float GenericModel<Search, VocabularyT>::InternalUnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const { + float ret; + typename Search::Node node; + if (first_length == 1) { + if (pointers_begin >= pointers_end) return 0.0; + bool independent_left; + uint64_t extend_left; + typename Search::UnigramPointer ptr(search_.LookupUnigram(static_cast<WordIndex>(*pointers_begin), node, independent_left, extend_left)); + ret = ptr.Prob() - ptr.Rest(); + ++first_length; + ++pointers_begin; + } else { + ret = 0.0; + } + for (const uint64_t *i = pointers_begin; i < pointers_end; ++i, ++first_length) { + typename Search::MiddlePointer ptr(search_.Unpack(*i, first_length, node)); + ret += ptr.Prob() - ptr.Rest(); + } return ret; } -template class GenericModel<ProbingHashedSearch, ProbingVocabulary>; // HASH_PROBING -template class GenericModel<trie::TrieSearch<DontQuantize, trie::DontBhiksha>, SortedVocabulary>; // TRIE_SORTED +template class GenericModel<HashedSearch<BackoffValue>, ProbingVocabulary>; +template class GenericModel<HashedSearch<RestValue>, ProbingVocabulary>; +template class GenericModel<trie::TrieSearch<DontQuantize, trie::DontBhiksha>, SortedVocabulary>; template class GenericModel<trie::TrieSearch<DontQuantize, trie::ArrayBhiksha>, SortedVocabulary>; -template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiksha>, SortedVocabulary>; // TRIE_SORTED_QUANT +template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiksha>, SortedVocabulary>; template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::ArrayBhiksha>, SortedVocabulary>; } // namespace detail |