summaryrefslogtreecommitdiff
path: root/klm/lm/model.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/model.cc')
-rw-r--r--klm/lm/model.cc192
1 files changed, 103 insertions, 89 deletions
diff --git a/klm/lm/model.cc b/klm/lm/model.cc
index 478ebed1..c081788c 100644
--- a/klm/lm/model.cc
+++ b/klm/lm/model.cc
@@ -38,10 +38,13 @@ template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::Ge
State begin_sentence = State();
begin_sentence.length = 1;
begin_sentence.words[0] = vocab_.BeginSentence();
- begin_sentence.backoff[0] = search_.unigram.Lookup(begin_sentence.words[0]).backoff;
+ typename Search::Node ignored_node;
+ bool ignored_independent_left;
+ uint64_t ignored_extend_left;
+ begin_sentence.backoff[0] = search_.LookupUnigram(begin_sentence.words[0], ignored_node, ignored_independent_left, ignored_extend_left).Backoff();
State null_context = State();
null_context.length = 0;
- P::Init(begin_sentence, null_context, vocab_, search_.MiddleEnd() - search_.MiddleBegin() + 2);
+ P::Init(begin_sentence, null_context, vocab_, search_.Order());
}
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd) {
@@ -50,6 +53,9 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
search_.LoadedBinary();
}
+namespace {
+} // namespace
+
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) {
// Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any.
util::FilePiece f(backing_.file.release(), file, config.messages);
@@ -79,8 +85,8 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
if (!vocab_.SawUnk()) {
assert(config.unknown_missing != THROW_UP);
// Default probabilities for unknown.
- search_.unigram.Unknown().backoff = 0.0;
- search_.unigram.Unknown().prob = config.unknown_missing_logprob;
+ search_.UnknownUnigram().backoff = 0.0;
+ search_.UnknownUnigram().prob = config.unknown_missing_logprob;
}
FinishFile(config, kModelType, kVersion, counts, vocab_.UnkCountChangePadding(), backing_);
} catch (util::Exception &e) {
@@ -109,20 +115,22 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
// Add the backoff weights for n-grams of order start to (context_rend - context_rbegin).
unsigned char start = ret.ngram_length;
if (context_rend - context_rbegin < static_cast<std::ptrdiff_t>(start)) return ret;
+
+ bool independent_left;
+ uint64_t extend_left;
+ typename Search::Node node;
if (start <= 1) {
- ret.prob += search_.unigram.Lookup(*context_rbegin).backoff;
+ ret.prob += search_.LookupUnigram(*context_rbegin, node, independent_left, extend_left).Backoff();
start = 2;
- }
- typename Search::Node node;
- if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) {
+ } else if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) {
return ret;
}
- float backoff;
// i is the order of the backoff we're looking for.
- typename Search::MiddleIter mid_iter = search_.MiddleBegin() + start - 2;
- for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i, ++mid_iter) {
- if (!search_.LookupMiddleNoProb(*mid_iter, *i, backoff, node)) break;
- ret.prob += backoff;
+ unsigned char order_minus_2 = 0;
+ for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i, ++order_minus_2) {
+ typename Search::MiddlePointer p(search_.LookupMiddle(order_minus_2, *i, node, independent_left, extend_left));
+ if (!p.Found()) break;
+ ret.prob += p.Backoff();
}
return ret;
}
@@ -134,17 +142,20 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
out_state.length = 0;
return;
}
- FullScoreReturn ignored;
typename Search::Node node;
- search_.LookupUnigram(*context_rbegin, out_state.backoff[0], node, ignored);
+ bool independent_left;
+ uint64_t extend_left;
+ out_state.backoff[0] = search_.LookupUnigram(*context_rbegin, node, independent_left, extend_left).Backoff();
out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0;
float *backoff_out = out_state.backoff + 1;
- typename Search::MiddleIter mid(search_.MiddleBegin());
- for (const WordIndex *i = context_rbegin + 1; i < context_rend; ++i, ++backoff_out, ++mid) {
- if (!search_.LookupMiddleNoProb(*mid, *i, *backoff_out, node)) {
+ unsigned char order_minus_2 = 0;
+ for (const WordIndex *i = context_rbegin + 1; i < context_rend; ++i, ++backoff_out, ++order_minus_2) {
+ typename Search::MiddlePointer p(search_.LookupMiddle(order_minus_2, *i, node, independent_left, extend_left));
+ if (!p.Found()) {
std::copy(context_rbegin, context_rbegin + out_state.length, out_state.words);
return;
}
+ *backoff_out = p.Backoff();
if (HasExtension(*backoff_out)) out_state.length = i - context_rbegin + 1;
}
std::copy(context_rbegin, context_rbegin + out_state.length, out_state.words);
@@ -158,43 +169,29 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
float *backoff_out,
unsigned char &next_use) const {
FullScoreReturn ret;
- float subtract_me;
- typename Search::Node node(search_.Unpack(extend_pointer, extend_length, subtract_me));
- ret.prob = subtract_me;
- ret.ngram_length = extend_length;
- next_use = 0;
- // If this function is called, then it does depend on left words.
- ret.independent_left = false;
- ret.extend_left = extend_pointer;
- typename Search::MiddleIter mid_iter(search_.MiddleBegin() + extend_length - 1);
- const WordIndex *i = add_rbegin;
- for (; ; ++i, ++backoff_out, ++mid_iter) {
- if (i == add_rend) {
- // Ran out of words.
- for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b;
- ret.prob -= subtract_me;
- return ret;
- }
- if (mid_iter == search_.MiddleEnd()) break;
- if (ret.independent_left || !search_.LookupMiddle(*mid_iter, *i, *backoff_out, node, ret)) {
- // Didn't match a word.
- ret.independent_left = true;
- for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b;
- ret.prob -= subtract_me;
- return ret;
- }
- ret.ngram_length = mid_iter - search_.MiddleBegin() + 2;
- if (HasExtension(*backoff_out)) next_use = i - add_rbegin + 1;
- }
-
- if (ret.independent_left || !search_.LookupLongest(*i, ret.prob, node)) {
- // The last backoff weight, for Order() - 1.
- ret.prob += backoff_in[i - add_rbegin];
+ typename Search::Node node;
+ if (extend_length == 1) {
+ typename Search::UnigramPointer ptr(search_.LookupUnigram(static_cast<WordIndex>(extend_pointer), node, ret.independent_left, ret.extend_left));
+ ret.rest = ptr.Rest();
+ ret.prob = ptr.Prob();
+ assert(!ret.independent_left);
} else {
- ret.ngram_length = P::Order();
+ typename Search::MiddlePointer ptr(search_.Unpack(extend_pointer, extend_length, node));
+ ret.rest = ptr.Rest();
+ ret.prob = ptr.Prob();
+ ret.extend_left = extend_pointer;
+ // If this function is called, then it does depend on left words.
+ ret.independent_left = false;
}
- ret.independent_left = true;
+ float subtract_me = ret.rest;
+ ret.ngram_length = extend_length;
+ next_use = extend_length;
+ ResumeScore(add_rbegin, add_rend, extend_length - 1, node, backoff_out, next_use, ret);
+ next_use -= extend_length;
+ // Charge backoffs.
+ for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b;
ret.prob -= subtract_me;
+ ret.rest -= subtract_me;
return ret;
}
@@ -215,66 +212,83 @@ void CopyRemainingHistory(const WordIndex *from, State &out_state) {
* new_word.
*/
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::ScoreExceptBackoff(
- const WordIndex *context_rbegin,
- const WordIndex *context_rend,
+ const WordIndex *const context_rbegin,
+ const WordIndex *const context_rend,
const WordIndex new_word,
State &out_state) const {
FullScoreReturn ret;
// ret.ngram_length contains the last known non-blank ngram length.
ret.ngram_length = 1;
- float *backoff_out(out_state.backoff);
typename Search::Node node;
- search_.LookupUnigram(new_word, *backoff_out, node, ret);
+ typename Search::UnigramPointer uni(search_.LookupUnigram(new_word, node, ret.independent_left, ret.extend_left));
+ out_state.backoff[0] = uni.Backoff();
+ ret.prob = uni.Prob();
+ ret.rest = uni.Rest();
+
// This is the length of the context that should be used for continuation to the right.
- out_state.length = HasExtension(*backoff_out) ? 1 : 0;
+ out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0;
// We'll write the word anyway since it will probably be used and does no harm being there.
out_state.words[0] = new_word;
if (context_rbegin == context_rend) return ret;
- ++backoff_out;
-
- // Ok start by looking up the bigram.
- const WordIndex *hist_iter = context_rbegin;
- typename Search::MiddleIter mid_iter(search_.MiddleBegin());
- for (; ; ++mid_iter, ++hist_iter, ++backoff_out) {
- if (hist_iter == context_rend) {
- // Ran out of history. Typically no backoff, but this could be a blank.
- CopyRemainingHistory(context_rbegin, out_state);
- // ret.prob was already set.
- return ret;
- }
- if (mid_iter == search_.MiddleEnd()) break;
+ ResumeScore(context_rbegin, context_rend, 0, node, out_state.backoff + 1, out_state.length, ret);
+ CopyRemainingHistory(context_rbegin, out_state);
+ return ret;
+}
- if (ret.independent_left || !search_.LookupMiddle(*mid_iter, *hist_iter, *backoff_out, node, ret)) {
- // Didn't find an ngram using hist_iter.
- CopyRemainingHistory(context_rbegin, out_state);
- // ret.prob was already set.
- ret.independent_left = true;
- return ret;
- }
- ret.ngram_length = hist_iter - context_rbegin + 2;
+template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::ResumeScore(const WordIndex *hist_iter, const WordIndex *const context_rend, unsigned char order_minus_2, typename Search::Node &node, float *backoff_out, unsigned char &next_use, FullScoreReturn &ret) const {
+ for (; ; ++order_minus_2, ++hist_iter, ++backoff_out) {
+ if (hist_iter == context_rend) return;
+ if (ret.independent_left) return;
+ if (order_minus_2 == P::Order() - 2) break;
+
+ typename Search::MiddlePointer pointer(search_.LookupMiddle(order_minus_2, *hist_iter, node, ret.independent_left, ret.extend_left));
+ if (!pointer.Found()) return;
+ *backoff_out = pointer.Backoff();
+ ret.prob = pointer.Prob();
+ ret.rest = pointer.Rest();
+ ret.ngram_length = order_minus_2 + 2;
if (HasExtension(*backoff_out)) {
- out_state.length = ret.ngram_length;
+ next_use = ret.ngram_length;
}
}
-
- // It passed every lookup in search_.middle. All that's left is to check search_.longest.
- if (!ret.independent_left && search_.LookupLongest(*hist_iter, ret.prob, node)) {
- // It's an P::Order()-gram.
+ ret.independent_left = true;
+ typename Search::LongestPointer longest(search_.LookupLongest(*hist_iter, node));
+ if (longest.Found()) {
+ ret.prob = longest.Prob();
+ ret.rest = ret.prob;
// There is no blank in longest_.
ret.ngram_length = P::Order();
}
- // This handles (N-1)-grams and N-grams.
- CopyRemainingHistory(context_rbegin, out_state);
- ret.independent_left = true;
+}
+
+template <class Search, class VocabularyT> float GenericModel<Search, VocabularyT>::InternalUnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const {
+ float ret;
+ typename Search::Node node;
+ if (first_length == 1) {
+ if (pointers_begin >= pointers_end) return 0.0;
+ bool independent_left;
+ uint64_t extend_left;
+ typename Search::UnigramPointer ptr(search_.LookupUnigram(static_cast<WordIndex>(*pointers_begin), node, independent_left, extend_left));
+ ret = ptr.Prob() - ptr.Rest();
+ ++first_length;
+ ++pointers_begin;
+ } else {
+ ret = 0.0;
+ }
+ for (const uint64_t *i = pointers_begin; i < pointers_end; ++i, ++first_length) {
+ typename Search::MiddlePointer ptr(search_.Unpack(*i, first_length, node));
+ ret += ptr.Prob() - ptr.Rest();
+ }
return ret;
}
-template class GenericModel<ProbingHashedSearch, ProbingVocabulary>; // HASH_PROBING
-template class GenericModel<trie::TrieSearch<DontQuantize, trie::DontBhiksha>, SortedVocabulary>; // TRIE_SORTED
+template class GenericModel<HashedSearch<BackoffValue>, ProbingVocabulary>;
+template class GenericModel<HashedSearch<RestValue>, ProbingVocabulary>;
+template class GenericModel<trie::TrieSearch<DontQuantize, trie::DontBhiksha>, SortedVocabulary>;
template class GenericModel<trie::TrieSearch<DontQuantize, trie::ArrayBhiksha>, SortedVocabulary>;
-template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiksha>, SortedVocabulary>; // TRIE_SORTED_QUANT
+template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiksha>, SortedVocabulary>;
template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::ArrayBhiksha>, SortedVocabulary>;
} // namespace detail