Diffstat (limited to 'klm/lm/model.cc')
-rw-r--r--  klm/lm/model.cc  117
1 file changed, 63 insertions, 54 deletions
diff --git a/klm/lm/model.cc b/klm/lm/model.cc
index 421e72fa..c7ba4908 100644
--- a/klm/lm/model.cc
+++ b/klm/lm/model.cc
@@ -1,5 +1,6 @@
 #include "lm/model.hh"
 
+#include "lm/blank.hh"
 #include "lm/lm_exception.hh"
 #include "lm/search_hashed.hh"
 #include "lm/search_trie.hh"
@@ -21,9 +22,6 @@ size_t hash_value(const State &state) {
 namespace detail {
 
 template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) {
-  if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit ngram.hh's kMaxOrder to at least this value and recompile.");
-  if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
-  if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");
   return VocabularyT::Size(counts[0], config) + Search::Size(counts, config);
 }
 
@@ -59,17 +57,31 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
   search_.longest.LoadedBinary();
 }
 
-template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, util::FilePiece &f, void *start, const Parameters &params, const Config &config) {
-  SetupMemory(start, params.counts, config);
+template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) {
+  // Backing file is the ARPA.  Steal it so we can make the backing file the mmap output if any.
+  util::FilePiece f(backing_.file.release(), file, config.messages);
+  std::vector<uint64_t> counts;
+  // File counts do not include pruned trigrams that extend to quadgrams etc.  These will be fixed with search_.VariableSizeLoad
+  ReadARPACounts(f, counts);
+
+  if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit ngram.hh's kMaxOrder to at least this value and recompile.");
+  if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
+  if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");
+
+  std::size_t vocab_size = VocabularyT::Size(counts[0], config);
+  // Setup the binary file for writing the vocab lookup table.  The search_ is responsible for growing the binary file to its needs.
+  vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config);
 
   if (config.write_mmap) {
-    WriteWordsWrapper wrap(config.enumerate_vocab, backing_.file.get());
-    vocab_.ConfigureEnumerate(&wrap, params.counts[0]);
-    search_.InitializeFromARPA(file, f, params.counts, config, vocab_);
+    WriteWordsWrapper wrap(config.enumerate_vocab);
+    vocab_.ConfigureEnumerate(&wrap, counts[0]);
+    search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
+    wrap.Write(backing_.file.get());
   } else {
-    vocab_.ConfigureEnumerate(config.enumerate_vocab, params.counts[0]);
-    search_.InitializeFromARPA(file, f, params.counts, config, vocab_);
+    vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]);
+    search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
   }
+  // TODO: fail faster?
   if (!vocab_.SawUnk()) {
     switch(config.unknown_missing) {
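The hunk above moves the order and config validation out of Size() and makes InitializeFromARPA own the ARPA file outright: it steals backing_.file into a FilePiece, reads the counts itself, sets up the vocabulary region, and lets search_ grow the binary file as needed. From the caller's side nothing changes; here is a minimal loading sketch, assuming the public Config/Model API from lm/model.hh (the filenames are placeholders):

    #include "lm/model.hh"

    int main() {
      lm::ngram::Config config;
      // Ask the loader to also write a binary image; per the hunk above,
      // InitializeFromARPA releases the ARPA fd into a FilePiece and reuses
      // backing_.file as the mmap output.
      config.write_mmap = "model.binary";
      // Throws FormatLoadException if the order exceeds kMaxOrder or the
      // model is smaller than a bigram model.
      lm::ngram::Model model("model.arpa", config);
      return 0;
    }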
@@ -89,46 +101,49 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
       break;
     }
   }
-  if (std::fabs(search_.unigram.Unknown().backoff) > 0.0000001) UTIL_THROW(FormatLoadException, "Backoff for unknown word should be zero, but was given as " << search_.unigram.Unknown().backoff);
 }
 
 template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const {
-  unsigned char backoff_start;
-  FullScoreReturn ret = ScoreExceptBackoff(in_state.history_, in_state.history_ + in_state.valid_length_, new_word, backoff_start, out_state);
-  if (backoff_start - 1 < in_state.valid_length_) {
-    ret.prob = std::accumulate(in_state.backoff_ + backoff_start - 1, in_state.backoff_ + in_state.valid_length_, ret.prob);
+  FullScoreReturn ret = ScoreExceptBackoff(in_state.history_, in_state.history_ + in_state.valid_length_, new_word, out_state);
+  if (ret.ngram_length - 1 < in_state.valid_length_) {
+    ret.prob = std::accumulate(in_state.backoff_ + ret.ngram_length - 1, in_state.backoff_ + in_state.valid_length_, ret.prob);
   }
   return ret;
 }
 
 template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const {
-  unsigned char backoff_start;
   context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);
-  FullScoreReturn ret = ScoreExceptBackoff(context_rbegin, context_rend, new_word, backoff_start, out_state);
-  ret.prob += SlowBackoffLookup(context_rbegin, context_rend, backoff_start);
+  FullScoreReturn ret = ScoreExceptBackoff(context_rbegin, context_rend, new_word, out_state);
+  ret.prob += SlowBackoffLookup(context_rbegin, context_rend, ret.ngram_length);
   return ret;
 }
 
 template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const {
+  // Generate a state from context.
   context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);
-  if (context_rend == context_rbegin || *context_rbegin == 0) {
+  if (context_rend == context_rbegin) {
     out_state.valid_length_ = 0;
     return;
   }
   float ignored_prob;
   typename Search::Node node;
   search_.LookupUnigram(*context_rbegin, ignored_prob, out_state.backoff_[0], node);
+  // Tricky part is that an entry might be blank, but out_state.valid_length_ always has the last non-blank n-gram length.
+  out_state.valid_length_ = 1;
   float *backoff_out = out_state.backoff_ + 1;
-  const WordIndex *i = context_rbegin + 1;
-  for (; i < context_rend; ++i, ++backoff_out) {
-    if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, *backoff_out, node)) {
-      out_state.valid_length_ = i - context_rbegin;
-      std::copy(context_rbegin, i, out_state.history_);
+  const typename Search::Middle *mid = &*search_.middle.begin();
+  for (const WordIndex *i = context_rbegin + 1; i < context_rend; ++i, ++backoff_out, ++mid) {
+    if (!search_.LookupMiddleNoProb(*mid, *i, *backoff_out, node)) {
+      std::copy(context_rbegin, context_rbegin + out_state.valid_length_, out_state.history_);
       return;
     }
+    if (*backoff_out != kBlankBackoff) {
+      out_state.valid_length_ = i - context_rbegin + 1;
+    } else {
+      *backoff_out = 0.0;
+    }
   }
-  std::copy(context_rbegin, context_rend, out_state.history_);
-  out_state.valid_length_ = static_cast<unsigned char>(context_rend - context_rbegin);
+  std::copy(context_rbegin, context_rbegin + out_state.valid_length_, out_state.history_);
 }
 
 template <class Search, class VocabularyT> float GenericModel<Search, VocabularyT>::SlowBackoffLookup(
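With backoff_start gone, FullScore derives the backoff range directly from ret.ngram_length: if the longest match including the new word used fewer history words than the state holds, the backoff weight of every unused longer context is summed onto the log probability. A standalone sketch of just that accumulation (ApplyBackoff is a hypothetical helper mirroring the std::accumulate call in the hunk above):

    #include <numeric>

    // If the match was an n-gram (ngram_length = n), history orders
    // n-1 .. valid_length-1 went unused, so their backoff weights apply.
    float ApplyBackoff(float log_prob, const float *backoff,
                       unsigned char ngram_length, unsigned char valid_length) {
      if (ngram_length - 1 < valid_length) {
        log_prob = std::accumulate(backoff + ngram_length - 1,
                                   backoff + valid_length, log_prob);
      }
      return log_prob;
    }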
@@ -148,7 +163,7 @@ template <class Search, class VocabularyT> float GenericModel<Search, Vocabulary
   // i is the order of the backoff we're looking for.
   for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i) {
     if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, backoff, node)) break;
-    ret += backoff;
+    if (backoff != kBlankBackoff) ret += backoff;
   }
   return ret;
 }
@@ -162,23 +177,17 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
     const WordIndex *context_rbegin,
     const WordIndex *context_rend,
     const WordIndex new_word,
-    unsigned char &backoff_start,
     State &out_state) const {
   FullScoreReturn ret;
+  // ret.ngram_length contains the last known good (non-blank) ngram length.
+  ret.ngram_length = 1;
+
   typename Search::Node node;
   float *backoff_out(out_state.backoff_);
   search_.LookupUnigram(new_word, ret.prob, *backoff_out, node);
-  if (new_word == 0) {
-    ret.ngram_length = out_state.valid_length_ = 0;
-    // All of backoff.
-    backoff_start = 1;
-    return ret;
-  }
   out_state.history_[0] = new_word;
   if (context_rbegin == context_rend) {
-    ret.ngram_length = out_state.valid_length_ = 1;
-    // No backoff because we don't have the history for it.
-    backoff_start = P::Order();
+    out_state.valid_length_ = 1;
     return ret;
   }
   ++backoff_out;
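Both SlowBackoffLookup and ScoreExceptBackoff now have to cope with blank entries: placeholders for pruned n-grams that exist only so longer n-grams can extend them, marked with the kBlankBackoff sentinel from the newly included lm/blank.hh. A blank contributes no backoff, as in this sketch (the sentinel value below is a stand-in; the real constant is defined in blank.hh):

    #include <limits>

    // Stand-in for the sentinel in lm/blank.hh; the actual value may differ.
    const float kBlankBackoff = -std::numeric_limits<float>::infinity();

    // Sum backoff weights while skipping blank (pruned) entries, mirroring
    // the "if (backoff != kBlankBackoff) ret += backoff;" guard above.
    float SumBackoffs(const float *begin, const float *end) {
      float total = 0.0f;
      for (const float *i = begin; i != end; ++i) {
        if (*i != kBlankBackoff) total += *i;
      }
      return total;
    }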
@@ -189,45 +198,45 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
   typename std::vector<Middle>::const_iterator mid_iter = search_.middle.begin();
   for (; ; ++mid_iter, ++hist_iter, ++backoff_out) {
     if (hist_iter == context_rend) {
-      // Ran out of history.  No backoff.
-      backoff_start = P::Order();
-      std::copy(context_rbegin, context_rend, out_state.history_ + 1);
-      ret.ngram_length = out_state.valid_length_ = (context_rend - context_rbegin) + 1;
+      // Ran out of history.  Typically no backoff, but this could be a blank.
+      out_state.valid_length_ = ret.ngram_length;
+      std::copy(context_rbegin, context_rbegin + ret.ngram_length - 1, out_state.history_ + 1);
       // ret.prob was already set.
       return ret;
     }
 
     if (mid_iter == search_.middle.end()) break;
 
+    float revert = ret.prob;
     if (!search_.LookupMiddle(*mid_iter, *hist_iter, ret.prob, *backoff_out, node)) {
       // Didn't find an ngram using hist_iter.
-      // The history used in the found n-gram is [context_rbegin, hist_iter).
-      std::copy(context_rbegin, hist_iter, out_state.history_ + 1);
-      // Therefore, we found a (hist_iter - context_rbegin + 1)-gram including the last word.
-      ret.ngram_length = out_state.valid_length_ = (hist_iter - context_rbegin) + 1;
-      backoff_start = mid_iter - search_.middle.begin() + 1;
+      std::copy(context_rbegin, context_rbegin + ret.ngram_length - 1, out_state.history_ + 1);
+      out_state.valid_length_ = ret.ngram_length;
       // ret.prob was already set.
       return ret;
     }
+    if (*backoff_out == kBlankBackoff) {
+      *backoff_out = 0.0;
+      ret.prob = revert;
+    } else {
+      ret.ngram_length = hist_iter - context_rbegin + 2;
+    }
   }
 
-  // It passed every lookup in search_.middle.  That means it's at least a (P::Order() - 1)-gram.
-  // All that's left is to check search_.longest.
+  // It passed every lookup in search_.middle.  All that's left is to check search_.longest.
 
   if (!search_.LookupLongest(*hist_iter, ret.prob, node)) {
-    // It's an (P::Order()-1)-gram
-    std::copy(context_rbegin, context_rbegin + P::Order() - 2, out_state.history_ + 1);
-    ret.ngram_length = out_state.valid_length_ = P::Order() - 1;
-    backoff_start = P::Order() - 1;
+    //assert(ret.ngram_length <= P::Order() - 1);
+    out_state.valid_length_ = ret.ngram_length;
+    std::copy(context_rbegin, context_rbegin + ret.ngram_length - 1, out_state.history_ + 1);
     // ret.prob was already set.
     return ret;
   }
-  // It's an P::Order()-gram
+  // It's an P::Order()-gram.  There is no blank in longest_.
   // out_state.valid_length_ is still P::Order() - 1 because the next lookup will only need that much.
   std::copy(context_rbegin, context_rbegin + P::Order() - 2, out_state.history_ + 1);
   out_state.valid_length_ = P::Order() - 1;
   ret.ngram_length = P::Order();
-  backoff_start = P::Order();
   return ret;
 }
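The main loop now snapshots ret.prob before each middle lookup so a blank hit can be reverted: a blank's stored probability is only a placeholder, so the score rolls back to the last real n-gram and ret.ngram_length advances only on non-blank entries. End to end, the reworked FullScore is exercised as below; a minimal scoring sketch assuming the usual public API from lm/model.hh (the file name and words are placeholders):

    #include "lm/model.hh"
    #include <iostream>

    int main() {
      lm::ngram::Model model("model.arpa");
      lm::ngram::State state(model.BeginSentenceState()), out_state;
      const char *words[] = {"iran", "is", "one", "of"};
      float total = 0.0f;
      for (unsigned i = 0; i < 4; ++i) {
        lm::ngram::FullScoreReturn ret = model.FullScore(
            state, model.GetVocabulary().Index(words[i]), out_state);
        // ngram_length reports the order of the longest non-blank match.
        std::cout << words[i] << ": log10 prob " << ret.prob << ", "
                  << static_cast<unsigned>(ret.ngram_length) << "-gram\n";
        total += ret.prob;
        state = out_state;  // carry history and backoff weights forward
      }
      std::cout << "total: " << total << "\n";
      return 0;
    }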