summaryrefslogtreecommitdiff
path: root/klm/lm/model.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/model.cc')
-rw-r--r--klm/lm/model.cc117
1 files changed, 63 insertions, 54 deletions
diff --git a/klm/lm/model.cc b/klm/lm/model.cc
index 421e72fa..c7ba4908 100644
--- a/klm/lm/model.cc
+++ b/klm/lm/model.cc
@@ -1,5 +1,6 @@
#include "lm/model.hh"
+#include "lm/blank.hh"
#include "lm/lm_exception.hh"
#include "lm/search_hashed.hh"
#include "lm/search_trie.hh"
@@ -21,9 +22,6 @@ size_t hash_value(const State &state) {
namespace detail {
template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) {
- if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit ngram.hh's kMaxOrder to at least this value and recompile.");
- if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
- if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");
return VocabularyT::Size(counts[0], config) + Search::Size(counts, config);
}
@@ -59,17 +57,31 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
search_.longest.LoadedBinary();
}
-template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, util::FilePiece &f, void *start, const Parameters &params, const Config &config) {
- SetupMemory(start, params.counts, config);
+template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) {
+ // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any.
+ util::FilePiece f(backing_.file.release(), file, config.messages);
+ std::vector<uint64_t> counts;
+ // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed with search_.VariableSizeLoad
+ ReadARPACounts(f, counts);
+
+ if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit ngram.hh's kMaxOrder to at least this value and recompile.");
+ if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
+ if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");
+
+ std::size_t vocab_size = VocabularyT::Size(counts[0], config);
+ // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs.
+ vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config);
if (config.write_mmap) {
- WriteWordsWrapper wrap(config.enumerate_vocab, backing_.file.get());
- vocab_.ConfigureEnumerate(&wrap, params.counts[0]);
- search_.InitializeFromARPA(file, f, params.counts, config, vocab_);
+ WriteWordsWrapper wrap(config.enumerate_vocab);
+ vocab_.ConfigureEnumerate(&wrap, counts[0]);
+ search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
+ wrap.Write(backing_.file.get());
} else {
- vocab_.ConfigureEnumerate(config.enumerate_vocab, params.counts[0]);
- search_.InitializeFromARPA(file, f, params.counts, config, vocab_);
+ vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]);
+ search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
}
+
// TODO: fail faster?
if (!vocab_.SawUnk()) {
switch(config.unknown_missing) {
@@ -89,46 +101,49 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
break;
}
}
- if (std::fabs(search_.unigram.Unknown().backoff) > 0.0000001) UTIL_THROW(FormatLoadException, "Backoff for unknown word should be zero, but was given as " << search_.unigram.Unknown().backoff);
}
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const {
- unsigned char backoff_start;
- FullScoreReturn ret = ScoreExceptBackoff(in_state.history_, in_state.history_ + in_state.valid_length_, new_word, backoff_start, out_state);
- if (backoff_start - 1 < in_state.valid_length_) {
- ret.prob = std::accumulate(in_state.backoff_ + backoff_start - 1, in_state.backoff_ + in_state.valid_length_, ret.prob);
+ FullScoreReturn ret = ScoreExceptBackoff(in_state.history_, in_state.history_ + in_state.valid_length_, new_word, out_state);
+ if (ret.ngram_length - 1 < in_state.valid_length_) {
+ ret.prob = std::accumulate(in_state.backoff_ + ret.ngram_length - 1, in_state.backoff_ + in_state.valid_length_, ret.prob);
}
return ret;
}
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const {
- unsigned char backoff_start;
context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);
- FullScoreReturn ret = ScoreExceptBackoff(context_rbegin, context_rend, new_word, backoff_start, out_state);
- ret.prob += SlowBackoffLookup(context_rbegin, context_rend, backoff_start);
+ FullScoreReturn ret = ScoreExceptBackoff(context_rbegin, context_rend, new_word, out_state);
+ ret.prob += SlowBackoffLookup(context_rbegin, context_rend, ret.ngram_length);
return ret;
}
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const {
+ // Generate a state from context.
context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);
- if (context_rend == context_rbegin || *context_rbegin == 0) {
+ if (context_rend == context_rbegin) {
out_state.valid_length_ = 0;
return;
}
float ignored_prob;
typename Search::Node node;
search_.LookupUnigram(*context_rbegin, ignored_prob, out_state.backoff_[0], node);
+ // Tricky part is that an entry might be blank, but out_state.valid_length_ always has the last non-blank n-gram length.
+ out_state.valid_length_ = 1;
float *backoff_out = out_state.backoff_ + 1;
- const WordIndex *i = context_rbegin + 1;
- for (; i < context_rend; ++i, ++backoff_out) {
- if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, *backoff_out, node)) {
- out_state.valid_length_ = i - context_rbegin;
- std::copy(context_rbegin, i, out_state.history_);
+ const typename Search::Middle *mid = &*search_.middle.begin();
+ for (const WordIndex *i = context_rbegin + 1; i < context_rend; ++i, ++backoff_out, ++mid) {
+ if (!search_.LookupMiddleNoProb(*mid, *i, *backoff_out, node)) {
+ std::copy(context_rbegin, context_rbegin + out_state.valid_length_, out_state.history_);
return;
}
+ if (*backoff_out != kBlankBackoff) {
+ out_state.valid_length_ = i - context_rbegin + 1;
+ } else {
+ *backoff_out = 0.0;
+ }
}
- std::copy(context_rbegin, context_rend, out_state.history_);
- out_state.valid_length_ = static_cast<unsigned char>(context_rend - context_rbegin);
+ std::copy(context_rbegin, context_rbegin + out_state.valid_length_, out_state.history_);
}
template <class Search, class VocabularyT> float GenericModel<Search, VocabularyT>::SlowBackoffLookup(
@@ -148,7 +163,7 @@ template <class Search, class VocabularyT> float GenericModel<Search, Vocabulary
// i is the order of the backoff we're looking for.
for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i) {
if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, backoff, node)) break;
- ret += backoff;
+ if (backoff != kBlankBackoff) ret += backoff;
}
return ret;
}
@@ -162,23 +177,17 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
const WordIndex *context_rbegin,
const WordIndex *context_rend,
const WordIndex new_word,
- unsigned char &backoff_start,
State &out_state) const {
FullScoreReturn ret;
+ // ret.ngram_length contains the last known good (non-blank) ngram length.
+ ret.ngram_length = 1;
+
typename Search::Node node;
float *backoff_out(out_state.backoff_);
search_.LookupUnigram(new_word, ret.prob, *backoff_out, node);
- if (new_word == 0) {
- ret.ngram_length = out_state.valid_length_ = 0;
- // All of backoff.
- backoff_start = 1;
- return ret;
- }
out_state.history_[0] = new_word;
if (context_rbegin == context_rend) {
- ret.ngram_length = out_state.valid_length_ = 1;
- // No backoff because we don't have the history for it.
- backoff_start = P::Order();
+ out_state.valid_length_ = 1;
return ret;
}
++backoff_out;
@@ -189,45 +198,45 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
typename std::vector<Middle>::const_iterator mid_iter = search_.middle.begin();
for (; ; ++mid_iter, ++hist_iter, ++backoff_out) {
if (hist_iter == context_rend) {
- // Ran out of history. No backoff.
- backoff_start = P::Order();
- std::copy(context_rbegin, context_rend, out_state.history_ + 1);
- ret.ngram_length = out_state.valid_length_ = (context_rend - context_rbegin) + 1;
+ // Ran out of history. Typically no backoff, but this could be a blank.
+ out_state.valid_length_ = ret.ngram_length;
+ std::copy(context_rbegin, context_rbegin + ret.ngram_length - 1, out_state.history_ + 1);
// ret.prob was already set.
return ret;
}
if (mid_iter == search_.middle.end()) break;
+ float revert = ret.prob;
if (!search_.LookupMiddle(*mid_iter, *hist_iter, ret.prob, *backoff_out, node)) {
// Didn't find an ngram using hist_iter.
- // The history used in the found n-gram is [context_rbegin, hist_iter).
- std::copy(context_rbegin, hist_iter, out_state.history_ + 1);
- // Therefore, we found a (hist_iter - context_rbegin + 1)-gram including the last word.
- ret.ngram_length = out_state.valid_length_ = (hist_iter - context_rbegin) + 1;
- backoff_start = mid_iter - search_.middle.begin() + 1;
+ std::copy(context_rbegin, context_rbegin + ret.ngram_length - 1, out_state.history_ + 1);
+ out_state.valid_length_ = ret.ngram_length;
// ret.prob was already set.
return ret;
}
+ if (*backoff_out == kBlankBackoff) {
+ *backoff_out = 0.0;
+ ret.prob = revert;
+ } else {
+ ret.ngram_length = hist_iter - context_rbegin + 2;
+ }
}
- // It passed every lookup in search_.middle. That means it's at least a (P::Order() - 1)-gram.
- // All that's left is to check search_.longest.
+ // It passed every lookup in search_.middle. All that's left is to check search_.longest.
if (!search_.LookupLongest(*hist_iter, ret.prob, node)) {
- // It's an (P::Order()-1)-gram
- std::copy(context_rbegin, context_rbegin + P::Order() - 2, out_state.history_ + 1);
- ret.ngram_length = out_state.valid_length_ = P::Order() - 1;
- backoff_start = P::Order() - 1;
+ //assert(ret.ngram_length <= P::Order() - 1);
+ out_state.valid_length_ = ret.ngram_length;
+ std::copy(context_rbegin, context_rbegin + ret.ngram_length - 1, out_state.history_ + 1);
// ret.prob was already set.
return ret;
}
- // It's an P::Order()-gram
+ // It's an P::Order()-gram. There is no blank in longest_.
// out_state.valid_length_ is still P::Order() - 1 because the next lookup will only need that much.
std::copy(context_rbegin, context_rbegin + P::Order() - 2, out_state.history_ + 1);
out_state.valid_length_ = P::Order() - 1;
ret.ngram_length = P::Order();
- backoff_start = P::Order();
return ret;
}