Diffstat (limited to 'klm/lm/model.cc')
-rw-r--r--  klm/lm/model.cc  97
1 file changed, 48 insertions(+), 49 deletions(-)
diff --git a/klm/lm/model.cc b/klm/lm/model.cc
index c7ba4908..146fe07b 100644
--- a/klm/lm/model.cc
+++ b/klm/lm/model.cc
@@ -61,10 +61,10 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
// Backing file is the ARPA. Steal it so we can make the backing file the mmap output, if any.
util::FilePiece f(backing_.file.release(), file, config.messages);
std::vector<uint64_t> counts;
- // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed with search_.VariableSizeLoad
+ // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_.
ReadARPACounts(f, counts);
- if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit ngram.hh's kMaxOrder to at least this value and recompile.");
+ if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile.");
if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");
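The corrected message points at lm/max_order.hh instead of ngram.hh. For orientation, that header is essentially one compile-time constant; a minimal sketch, assuming the usual layout (the value 6 is illustrative, not read from this tree):

    // Sketch of lm/max_order.hh (hypothetical contents).
    namespace lm {
    namespace ngram {
    // Maximum n-gram order compiled in. Raising it grows State, so it is a
    // compile-time constant rather than a runtime option.
    const unsigned char kMaxOrder = 6;
    } // namespace ngram
    } // namespace lm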
@@ -114,7 +114,24 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const {
context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);
FullScoreReturn ret = ScoreExceptBackoff(context_rbegin, context_rend, new_word, out_state);
- ret.prob += SlowBackoffLookup(context_rbegin, context_rend, ret.ngram_length);
+
+ // Add the backoff weights for n-grams of order start to (context_rend - context_rbegin).
+ unsigned char start = ret.ngram_length;
+ if (context_rend - context_rbegin < static_cast<std::ptrdiff_t>(start)) return ret;
+ if (start <= 1) {
+ ret.prob += search_.unigram.Lookup(*context_rbegin).backoff;
+ start = 2;
+ }
+ typename Search::Node node;
+ if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) {
+ return ret;
+ }
+ float backoff;
+ // The backoff retrieved at i applies to the context suffix of length i - context_rbegin + 1.
+ for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i) {
+ if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, backoff, node)) break;
+ ret.prob += backoff;
+ }
return ret;
}
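The loop inlined above implements the standard backoff identity: when only an n-gram of length ret.ngram_length matched, the score pays the backoff weight of every longer context that exists, stopping at the first context not in the model (no longer one can exist). A self-contained sketch of the arithmetic; ApplyBackoffs and backoff_by_order are illustrative names, not kenlm API:

    #include <cstddef>
    #include <vector>

    // score = log10 p(word | longest matched context) plus b(context) for
    // each longer context, mirroring the loop over search_.middle above.
    // matched_ngram_length is at least 1 because the unigram always matches.
    float ApplyBackoffs(float matched_log10_prob,
                        std::size_t matched_ngram_length,
                        const std::vector<float> &backoff_by_order) {
      float score = matched_log10_prob;
      // backoff_by_order[j] holds b(context suffix of length j + 1); suffixes
      // shorter than the matched n-gram are already inside the stored prob.
      for (std::size_t order = matched_ngram_length;
           order <= backoff_by_order.size(); ++order)
        score += backoff_by_order[order - 1];
      return score;
    }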
@@ -128,8 +145,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
float ignored_prob;
typename Search::Node node;
search_.LookupUnigram(*context_rbegin, ignored_prob, out_state.backoff_[0], node);
- // Tricky part is that an entry might be blank, but out_state.valid_length_ always has the last non-blank n-gram length.
- out_state.valid_length_ = 1;
+ out_state.valid_length_ = HasExtension(out_state.backoff_[0]) ? 1 : 0;
float *backoff_out = out_state.backoff_ + 1;
const typename Search::Middle *mid = &*search_.middle.begin();
for (const WordIndex *i = context_rbegin + 1; i < context_rend; ++i, ++backoff_out, ++mid) {
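This hunk and the next replace the kBlankBackoff sentinel with HasExtension(backoff). The idea, as I recall it from lm/blank.hh (treat the details as an assumption, not verified against this tree), is to store -0.0f as the backoff of a context that no longer n-gram extends; since -0.0f == 0.0f under float comparison, the test has to look at bit patterns:

    #include <stdint.h>
    #include <cstring>

    // Sketch of the HasExtension trick: -0.0f marks "nothing extends this
    // context"; any other bit pattern, including +0.0f, means it extends.
    inline bool HasExtensionSketch(float backoff) {
      const float kNoExtension = -0.0f;
      uint32_t backoff_bits, sentinel_bits;
      std::memcpy(&backoff_bits, &backoff, sizeof(backoff_bits));
      std::memcpy(&sentinel_bits, &kNoExtension, sizeof(sentinel_bits));
      // operator== on float would conflate -0.0f with +0.0f, hence memcpy.
      return backoff_bits != sentinel_bits;
    }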
@@ -137,36 +153,21 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
std::copy(context_rbegin, context_rbegin + out_state.valid_length_, out_state.history_);
return;
}
- if (*backoff_out != kBlankBackoff) {
- out_state.valid_length_ = i - context_rbegin + 1;
- } else {
- *backoff_out = 0.0;
- }
+ if (HasExtension(*backoff_out)) out_state.valid_length_ = i - context_rbegin + 1;
}
std::copy(context_rbegin, context_rbegin + out_state.valid_length_, out_state.history_);
}
-template <class Search, class VocabularyT> float GenericModel<Search, VocabularyT>::SlowBackoffLookup(
- const WordIndex *const context_rbegin, const WordIndex *const context_rend, unsigned char start) const {
- // Add the backoff weights for n-grams of order start to (context_rend - context_rbegin).
- if (context_rend - context_rbegin < static_cast<std::ptrdiff_t>(start)) return 0.0;
- float ret = 0.0;
- if (start == 1) {
- ret += search_.unigram.Lookup(*context_rbegin).backoff;
- start = 2;
- }
- typename Search::Node node;
- if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) {
- return 0.0;
- }
- float backoff;
- // i is the order of the backoff we're looking for.
- for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i) {
- if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, backoff, node)) break;
- if (backoff != kBlankBackoff) ret += backoff;
- }
- return ret;
+namespace {
+// Do a paranoid copy of history, assuming new_word has already been copied
+// (hence the -1). out_state.valid_length_ could be zero, so I avoided using
+// std::copy.
+void CopyRemainingHistory(const WordIndex *from, State &out_state) {
+ WordIndex *out = out_state.history_ + 1;
+ const WordIndex *in_end = from + static_cast<ptrdiff_t>(out_state.valid_length_) - 1;
+ for (const WordIndex *in = from; in < in_end; ++in, ++out) *out = *in;
}
+} // namespace
/* Ugly optimized function. Produce a score excluding backoff.
* The search goes in increasing order of ngram length.
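Why CopyRemainingHistory hand-rolls the loop: with valid_length_ == 0 the range runs from `from` to `from - 1`, i.e. last before first, which is undefined behavior for std::copy, while the `in < in_end` loop simply copies nothing. A self-contained illustration; MiniCopy and the local types are stand-ins, not this tree's code:

    #include <cstddef>
    #include <iostream>

    typedef unsigned int WordIndex;

    // Same shape as CopyRemainingHistory's loop, minus the State plumbing.
    void MiniCopy(const WordIndex *from, WordIndex *out, unsigned char length) {
      const WordIndex *in_end = from + static_cast<std::ptrdiff_t>(length) - 1;
      for (const WordIndex *in = from; in < in_end; ++in, ++out) *out = *in;
    }

    int main() {
      WordIndex src[3] = {7, 8, 9}, dst[2] = {0, 0};
      MiniCopy(src, dst, 0);  // length 0: loop never runs; std::copy would be UB
      MiniCopy(src, dst, 3);  // copies src[0] and src[1]; the -1 skips new_word
      std::cout << dst[0] << ' ' << dst[1] << '\n';  // prints: 7 8
      return 0;
    }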
@@ -179,28 +180,26 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
const WordIndex new_word,
State &out_state) const {
FullScoreReturn ret;
- // ret.ngram_length contains the last known good (non-blank) ngram length.
+ // ret.ngram_length contains the last known non-blank ngram length.
ret.ngram_length = 1;
typename Search::Node node;
float *backoff_out(out_state.backoff_);
search_.LookupUnigram(new_word, ret.prob, *backoff_out, node);
+ // This is the length of the context that should be used for continuation.
+ out_state.valid_length_ = HasExtension(*backoff_out) ? 1 : 0;
+ // We'll write the word anyway since it will probably be used and does no harm being there.
out_state.history_[0] = new_word;
- if (context_rbegin == context_rend) {
- out_state.valid_length_ = 1;
- return ret;
- }
+ if (context_rbegin == context_rend) return ret;
++backoff_out;
// Ok, now we know that the bigram contains known words. Start by looking it up.
-
const WordIndex *hist_iter = context_rbegin;
typename std::vector<Middle>::const_iterator mid_iter = search_.middle.begin();
for (; ; ++mid_iter, ++hist_iter, ++backoff_out) {
if (hist_iter == context_rend) {
// Ran out of history. Typically no backoff, but this could be a blank.
- out_state.valid_length_ = ret.ngram_length;
- std::copy(context_rbegin, context_rbegin + ret.ngram_length - 1, out_state.history_ + 1);
+ CopyRemainingHistory(context_rbegin, out_state);
// ret.prob was already set.
return ret;
}
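The valid_length_ bookkeeping above records how much of the history the next query can actually use; states that agree on that prefix score all future words identically, which is what lets a decoder recombine hypotheses. A hypothetical illustration (MiniState stands in for the real State):

    #include <algorithm>

    typedef unsigned int WordIndex;

    // Stand-in for lm::ngram::State, reduced to the two fields that matter.
    struct MiniState {
      WordIndex history_[5];
      unsigned char valid_length_;
    };

    // Two states are interchangeable for future queries when their first
    // valid_length_ history words agree; the rest of history_ is dead weight.
    bool Interchangeable(const MiniState &a, const MiniState &b) {
      return a.valid_length_ == b.valid_length_ &&
             std::equal(a.history_, a.history_ + a.valid_length_, b.history_);
    }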
@@ -210,32 +209,32 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
float revert = ret.prob;
if (!search_.LookupMiddle(*mid_iter, *hist_iter, ret.prob, *backoff_out, node)) {
// Didn't find an ngram using hist_iter.
- std::copy(context_rbegin, context_rbegin + ret.ngram_length - 1, out_state.history_ + 1);
- out_state.valid_length_ = ret.ngram_length;
+ CopyRemainingHistory(context_rbegin, out_state);
// ret.prob was already set.
return ret;
}
- if (*backoff_out == kBlankBackoff) {
- *backoff_out = 0.0;
+ if (ret.prob == kBlankProb) {
+ // It's a blank. Go back to the old probability.
ret.prob = revert;
} else {
ret.ngram_length = hist_iter - context_rbegin + 2;
+ if (HasExtension(*backoff_out)) {
+ out_state.valid_length_ = ret.ngram_length;
+ }
}
}
// It passed every lookup in search_.middle. All that's left is to check search_.longest.
if (!search_.LookupLongest(*hist_iter, ret.prob, node)) {
- //assert(ret.ngram_length <= P::Order() - 1);
- out_state.valid_length_ = ret.ngram_length;
- std::copy(context_rbegin, context_rbegin + ret.ngram_length - 1, out_state.history_ + 1);
+ // Failed to find a longest n-gram. Fall back to the most recent non-blank.
+ CopyRemainingHistory(context_rbegin, out_state);
// ret.prob was already set.
return ret;
}
- // It's an P::Order()-gram. There is no blank in longest_.
- // out_state.valid_length_ is still P::Order() - 1 because the next lookup will only need that much.
- std::copy(context_rbegin, context_rbegin + P::Order() - 2, out_state.history_ + 1);
- out_state.valid_length_ = P::Order() - 1;
+ // It's a P::Order()-gram.
+ CopyRemainingHistory(context_rbegin, out_state);
+ // There is no blank in longest_.
ret.ngram_length = P::Order();
return ret;
}
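Taken together, ScoreExceptBackoff plus the backoff accumulation implement the public FullScore call. A minimal caller for readers new to the API; the model path is made up, and type locations follow upstream kenlm of this era, so treat them as assumptions:

    #include "lm/model.hh"
    #include <iostream>

    int main() {
      lm::ngram::Model model("test.arpa");  // hypothetical ARPA file
      lm::ngram::State state(model.BeginSentenceState()), out_state;
      lm::WordIndex word = model.GetVocabulary().Index("example");
      // FullScore runs ScoreExceptBackoff, then charges backoff weights.
      float log10_prob = model.FullScore(state, word, out_state).prob;
      std::cout << "log10 p = " << log10_prob << '\n';
      return 0;
    }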