summaryrefslogtreecommitdiff
path: root/klm/lm/model.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/model.cc')
-rw-r--r--klm/lm/model.cc132
1 files changed, 86 insertions, 46 deletions
diff --git a/klm/lm/model.cc b/klm/lm/model.cc
index 27e24b1c..ca581d8a 100644
--- a/klm/lm/model.cc
+++ b/klm/lm/model.cc
@@ -16,7 +16,7 @@ namespace lm {
namespace ngram {
size_t hash_value(const State &state) {
- return util::MurmurHashNative(state.history_, sizeof(WordIndex) * state.valid_length_);
+ return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length);
}
namespace detail {
@@ -41,11 +41,11 @@ template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::Ge
// g++ prints warnings unless these are fully initialized.
State begin_sentence = State();
- begin_sentence.valid_length_ = 1;
- begin_sentence.history_[0] = vocab_.BeginSentence();
- begin_sentence.backoff_[0] = search_.unigram.Lookup(begin_sentence.history_[0]).backoff;
+ begin_sentence.length = 1;
+ begin_sentence.words[0] = vocab_.BeginSentence();
+ begin_sentence.backoff[0] = search_.unigram.Lookup(begin_sentence.words[0]).backoff;
State null_context = State();
- null_context.valid_length_ = 0;
+ null_context.length = 0;
P::Init(begin_sentence, null_context, vocab_, search_.MiddleEnd() - search_.MiddleBegin() + 2);
}
@@ -87,7 +87,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
search_.unigram.Unknown().backoff = 0.0;
search_.unigram.Unknown().prob = config.unknown_missing_logprob;
}
- FinishFile(config, kModelType, counts, backing_);
+ FinishFile(config, kModelType, kVersion, counts, backing_);
} catch (util::Exception &e) {
e << " Byte: " << f.Offset();
throw;
@@ -95,9 +95,9 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
}
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const {
- FullScoreReturn ret = ScoreExceptBackoff(in_state.history_, in_state.history_ + in_state.valid_length_, new_word, out_state);
- if (ret.ngram_length - 1 < in_state.valid_length_) {
- ret.prob = std::accumulate(in_state.backoff_ + ret.ngram_length - 1, in_state.backoff_ + in_state.valid_length_, ret.prob);
+ FullScoreReturn ret = ScoreExceptBackoff(in_state.words, in_state.words + in_state.length, new_word, out_state);
+ if (ret.ngram_length - 1 < in_state.length) {
+ ret.prob = std::accumulate(in_state.backoff + ret.ngram_length - 1, in_state.backoff + in_state.length, ret.prob);
}
return ret;
}
@@ -131,32 +131,80 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
// Generate a state from context.
context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);
if (context_rend == context_rbegin) {
- out_state.valid_length_ = 0;
+ out_state.length = 0;
return;
}
- float ignored_prob;
+ FullScoreReturn ignored;
typename Search::Node node;
- search_.LookupUnigram(*context_rbegin, ignored_prob, out_state.backoff_[0], node);
- out_state.valid_length_ = HasExtension(out_state.backoff_[0]) ? 1 : 0;
- float *backoff_out = out_state.backoff_ + 1;
+ search_.LookupUnigram(*context_rbegin, out_state.backoff[0], node, ignored);
+ out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0;
+ float *backoff_out = out_state.backoff + 1;
const typename Search::Middle *mid = search_.MiddleBegin();
for (const WordIndex *i = context_rbegin + 1; i < context_rend; ++i, ++backoff_out, ++mid) {
if (!search_.LookupMiddleNoProb(*mid, *i, *backoff_out, node)) {
- std::copy(context_rbegin, context_rbegin + out_state.valid_length_, out_state.history_);
+ std::copy(context_rbegin, context_rbegin + out_state.length, out_state.words);
return;
}
- if (HasExtension(*backoff_out)) out_state.valid_length_ = i - context_rbegin + 1;
+ if (HasExtension(*backoff_out)) out_state.length = i - context_rbegin + 1;
}
- std::copy(context_rbegin, context_rbegin + out_state.valid_length_, out_state.history_);
+ std::copy(context_rbegin, context_rbegin + out_state.length, out_state.words);
+}
+
+template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::ExtendLeft(
+ const WordIndex *add_rbegin, const WordIndex *add_rend,
+ const float *backoff_in,
+ uint64_t extend_pointer,
+ unsigned char extend_length,
+ float *backoff_out,
+ unsigned char &next_use) const {
+ FullScoreReturn ret;
+ float subtract_me;
+ typename Search::Node node(search_.Unpack(extend_pointer, extend_length, subtract_me));
+ ret.prob = subtract_me;
+ ret.ngram_length = extend_length;
+ next_use = 0;
+ // If this function is called, then it does depend on left words.
+ ret.independent_left = false;
+ ret.extend_left = extend_pointer;
+ const typename Search::Middle *mid_iter = search_.MiddleBegin() + extend_length - 1;
+ const WordIndex *i = add_rbegin;
+ for (; ; ++i, ++backoff_out, ++mid_iter) {
+ if (i == add_rend) {
+ // Ran out of words.
+ for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b;
+ ret.prob -= subtract_me;
+ return ret;
+ }
+ if (mid_iter == search_.MiddleEnd()) break;
+ if (ret.independent_left || !search_.LookupMiddle(*mid_iter, *i, *backoff_out, node, ret)) {
+ // Didn't match a word.
+ ret.independent_left = true;
+ for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b;
+ ret.prob -= subtract_me;
+ return ret;
+ }
+ ret.ngram_length = mid_iter - search_.MiddleBegin() + 2;
+ if (HasExtension(*backoff_out)) next_use = i - add_rbegin + 1;
+ }
+
+ if (ret.independent_left || !search_.LookupLongest(*i, ret.prob, node)) {
+ // The last backoff weight, for Order() - 1.
+ ret.prob += backoff_in[i - add_rbegin];
+ } else {
+ ret.ngram_length = P::Order();
+ }
+ ret.independent_left = true;
+ ret.prob -= subtract_me;
+ return ret;
}
namespace {
// Do a paraonoid copy of history, assuming new_word has already been copied
-// (hence the -1). out_state.valid_length_ could be zero so I avoided using
+// (hence the -1). out_state.length could be zero so I avoided using
// std::copy.
void CopyRemainingHistory(const WordIndex *from, State &out_state) {
- WordIndex *out = out_state.history_ + 1;
- const WordIndex *in_end = from + static_cast<ptrdiff_t>(out_state.valid_length_) - 1;
+ WordIndex *out = out_state.words + 1;
+ const WordIndex *in_end = from + static_cast<ptrdiff_t>(out_state.length) - 1;
for (const WordIndex *in = from; in < in_end; ++in, ++out) *out = *in;
}
} // namespace
@@ -175,17 +223,17 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
// ret.ngram_length contains the last known non-blank ngram length.
ret.ngram_length = 1;
+ float *backoff_out(out_state.backoff);
typename Search::Node node;
- float *backoff_out(out_state.backoff_);
- search_.LookupUnigram(new_word, ret.prob, *backoff_out, node);
- // This is the length of the context that should be used for continuation.
- out_state.valid_length_ = HasExtension(*backoff_out) ? 1 : 0;
+ search_.LookupUnigram(new_word, *backoff_out, node, ret);
+ // This is the length of the context that should be used for continuation to the right.
+ out_state.length = HasExtension(*backoff_out) ? 1 : 0;
// We'll write the word anyway since it will probably be used and does no harm being there.
- out_state.history_[0] = new_word;
+ out_state.words[0] = new_word;
if (context_rbegin == context_rend) return ret;
++backoff_out;
- // Ok now we now that the bigram contains known words. Start by looking it up.
+ // Ok start by looking up the bigram.
const WordIndex *hist_iter = context_rbegin;
const typename Search::Middle *mid_iter = search_.MiddleBegin();
for (; ; ++mid_iter, ++hist_iter, ++backoff_out) {
@@ -198,36 +246,28 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
if (mid_iter == search_.MiddleEnd()) break;
- float revert = ret.prob;
- if (!search_.LookupMiddle(*mid_iter, *hist_iter, ret.prob, *backoff_out, node)) {
+ if (ret.independent_left || !search_.LookupMiddle(*mid_iter, *hist_iter, *backoff_out, node, ret)) {
// Didn't find an ngram using hist_iter.
CopyRemainingHistory(context_rbegin, out_state);
- // ret.prob was already set.
+ // ret.prob was already set.
+ ret.independent_left = true;
return ret;
}
- if (ret.prob == kBlankProb) {
- // It's a blank. Go back to the old probability.
- ret.prob = revert;
- } else {
- ret.ngram_length = hist_iter - context_rbegin + 2;
- if (HasExtension(*backoff_out)) {
- out_state.valid_length_ = ret.ngram_length;
- }
+ ret.ngram_length = hist_iter - context_rbegin + 2;
+ if (HasExtension(*backoff_out)) {
+ out_state.length = ret.ngram_length;
}
}
// It passed every lookup in search_.middle. All that's left is to check search_.longest.
-
- if (!search_.LookupLongest(*hist_iter, ret.prob, node)) {
- // Failed to find a longest n-gram. Fall back to the most recent non-blank.
- CopyRemainingHistory(context_rbegin, out_state);
- // ret.prob was already set.
- return ret;
+ if (!ret.independent_left && search_.LookupLongest(*hist_iter, ret.prob, node)) {
+ // It's an P::Order()-gram.
+ // There is no blank in longest_.
+ ret.ngram_length = P::Order();
}
- // It's an P::Order()-gram.
+ // This handles (N-1)-grams and N-grams.
CopyRemainingHistory(context_rbegin, out_state);
- // There is no blank in longest_.
- ret.ngram_length = P::Order();
+ ret.independent_left = true;
return ret;
}