diff options
Diffstat (limited to 'klm/lm/model.cc')
-rw-r--r-- | klm/lm/model.cc | 239 |
1 files changed, 239 insertions, 0 deletions
diff --git a/klm/lm/model.cc b/klm/lm/model.cc new file mode 100644 index 00000000..6921d4d9 --- /dev/null +++ b/klm/lm/model.cc @@ -0,0 +1,239 @@ +#include "lm/model.hh" + +#include "lm/lm_exception.hh" +#include "lm/search_hashed.hh" +#include "lm/search_trie.hh" +#include "lm/read_arpa.hh" +#include "util/murmur_hash.hh" + +#include <algorithm> +#include <functional> +#include <numeric> +#include <cmath> + +namespace lm { +namespace ngram { + +size_t hash_value(const State &state) { + return util::MurmurHashNative(state.history_, sizeof(WordIndex) * state.valid_length_); +} + +namespace detail { + +template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) { + if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit ngram.hh's kMaxOrder to at least this value and recompile."); + if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model."); + return VocabularyT::Size(counts[0], config) + Search::Size(counts, config); +} + +template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::SetupMemory(void *base, const std::vector<uint64_t> &counts, const Config &config) { + uint8_t *start = static_cast<uint8_t*>(base); + size_t allocated = VocabularyT::Size(counts[0], config); + vocab_.SetupMemory(start, allocated, counts[0], config); + start += allocated; + start = search_.SetupMemory(start, counts, config); + if (static_cast<std::size_t>(start - static_cast<uint8_t*>(base)) != Size(counts, config)) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast<uint8_t*>(base)) << " but Size says they should take " << Size(counts, config)); +} + +template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) { + LoadLM(file, config, *this); + + // g++ prints warnings unless these are fully initialized. + State begin_sentence = State(); + begin_sentence.valid_length_ = 1; + begin_sentence.history_[0] = vocab_.BeginSentence(); + begin_sentence.backoff_[0] = search_.unigram.Lookup(begin_sentence.history_[0]).backoff; + State null_context = State(); + null_context.valid_length_ = 0; + P::Init(begin_sentence, null_context, vocab_, search_.middle.size() + 2); +} + +template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters ¶ms, const Config &config, int fd) { + SetupMemory(start, params.counts, config); + vocab_.LoadedBinary(fd, config.enumerate_vocab); + search_.unigram.LoadedBinary(); + for (typename std::vector<Middle>::iterator i = search_.middle.begin(); i != search_.middle.end(); ++i) { + i->LoadedBinary(); + } + search_.longest.LoadedBinary(); +} + +template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, util::FilePiece &f, void *start, const Parameters ¶ms, const Config &config) { + SetupMemory(start, params.counts, config); + + if (config.write_mmap) { + WriteWordsWrapper wrap(config.enumerate_vocab, backing_.file.get()); + vocab_.ConfigureEnumerate(&wrap, params.counts[0]); + search_.InitializeFromARPA(file, f, params.counts, config, vocab_); + } else { + vocab_.ConfigureEnumerate(config.enumerate_vocab, params.counts[0]); + search_.InitializeFromARPA(file, f, params.counts, config, vocab_); + } + // TODO: fail faster? + if (!vocab_.SawUnk()) { + switch(config.unknown_missing) { + case Config::THROW_UP: + { + SpecialWordMissingException e("<unk>"); + e << " and configuration was set to throw if unknown is missing"; + throw e; + } + case Config::COMPLAIN: + if (config.messages) *config.messages << "Language model is missing <unk>. Substituting probability " << config.unknown_missing_prob << "." << std::endl; + // There's no break;. This is by design. + case Config::SILENT: + // Default probabilities for unknown. + search_.unigram.Unknown().backoff = 0.0; + search_.unigram.Unknown().prob = config.unknown_missing_prob; + break; + } + } + if (std::fabs(search_.unigram.Unknown().backoff) > 0.0000001) UTIL_THROW(FormatLoadException, "Backoff for unknown word should be zero, but was given as " << search_.unigram.Unknown().backoff); +} + +template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const { + unsigned char backoff_start; + FullScoreReturn ret = ScoreExceptBackoff(in_state.history_, in_state.history_ + in_state.valid_length_, new_word, backoff_start, out_state); + if (backoff_start - 1 < in_state.valid_length_) { + ret.prob = std::accumulate(in_state.backoff_ + backoff_start - 1, in_state.backoff_ + in_state.valid_length_, ret.prob); + } + return ret; +} + +template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const { + unsigned char backoff_start; + context_rend = std::min(context_rend, context_rbegin + P::Order() - 1); + FullScoreReturn ret = ScoreExceptBackoff(context_rbegin, context_rend, new_word, backoff_start, out_state); + ret.prob += SlowBackoffLookup(context_rbegin, context_rend, backoff_start); + return ret; +} + +template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const { + context_rend = std::min(context_rend, context_rbegin + P::Order() - 1); + if (context_rend == context_rbegin || *context_rbegin == 0) { + out_state.valid_length_ = 0; + return; + } + float ignored_prob; + typename Search::Node node; + search_.LookupUnigram(*context_rbegin, ignored_prob, out_state.backoff_[0], node); + float *backoff_out = out_state.backoff_ + 1; + const WordIndex *i = context_rbegin + 1; + for (; i < context_rend; ++i, ++backoff_out) { + if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, *backoff_out, node)) { + out_state.valid_length_ = i - context_rbegin; + std::copy(context_rbegin, i, out_state.history_); + return; + } + } + std::copy(context_rbegin, context_rend, out_state.history_); + out_state.valid_length_ = static_cast<unsigned char>(context_rend - context_rbegin); +} + +template <class Search, class VocabularyT> float GenericModel<Search, VocabularyT>::SlowBackoffLookup( + const WordIndex *const context_rbegin, const WordIndex *const context_rend, unsigned char start) const { + // Add the backoff weights for n-grams of order start to (context_rend - context_rbegin). + if (context_rend - context_rbegin < static_cast<std::ptrdiff_t>(start)) return 0.0; + float ret = 0.0; + if (start == 1) { + ret += search_.unigram.Lookup(*context_rbegin).backoff; + start = 2; + } + typename Search::Node node; + if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) { + return 0.0; + } + float backoff; + // i is the order of the backoff we're looking for. + for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i) { + if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, backoff, node)) break; + ret += backoff; + } + return ret; +} + +/* Ugly optimized function. Produce a score excluding backoff. + * The search goes in increasing order of ngram length. + * Context goes backward, so context_begin is the word immediately preceeding + * new_word. + */ +template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::ScoreExceptBackoff( + const WordIndex *context_rbegin, + const WordIndex *context_rend, + const WordIndex new_word, + unsigned char &backoff_start, + State &out_state) const { + FullScoreReturn ret; + typename Search::Node node; + float *backoff_out(out_state.backoff_); + search_.LookupUnigram(new_word, ret.prob, *backoff_out, node); + if (new_word == 0) { + ret.ngram_length = out_state.valid_length_ = 0; + // All of backoff. + backoff_start = 1; + return ret; + } + out_state.history_[0] = new_word; + if (context_rbegin == context_rend) { + ret.ngram_length = out_state.valid_length_ = 1; + // No backoff because we don't have the history for it. + backoff_start = P::Order(); + return ret; + } + ++backoff_out; + + // Ok now we now that the bigram contains known words. Start by looking it up. + + const WordIndex *hist_iter = context_rbegin; + typename std::vector<Middle>::const_iterator mid_iter = search_.middle.begin(); + for (; ; ++mid_iter, ++hist_iter, ++backoff_out) { + if (hist_iter == context_rend) { + // Ran out of history. No backoff. + backoff_start = P::Order(); + std::copy(context_rbegin, context_rend, out_state.history_ + 1); + ret.ngram_length = out_state.valid_length_ = (context_rend - context_rbegin) + 1; + // ret.prob was already set. + return ret; + } + + if (mid_iter == search_.middle.end()) break; + + if (!search_.LookupMiddle(*mid_iter, *hist_iter, ret.prob, *backoff_out, node)) { + // Didn't find an ngram using hist_iter. + // The history used in the found n-gram is [context_rbegin, hist_iter). + std::copy(context_rbegin, hist_iter, out_state.history_ + 1); + // Therefore, we found a (hist_iter - context_rbegin + 1)-gram including the last word. + ret.ngram_length = out_state.valid_length_ = (hist_iter - context_rbegin) + 1; + backoff_start = mid_iter - search_.middle.begin() + 1; + // ret.prob was already set. + return ret; + } + } + + // It passed every lookup in search_.middle. That means it's at least a (P::Order() - 1)-gram. + // All that's left is to check search_.longest. + + if (!search_.LookupLongest(*hist_iter, ret.prob, node)) { + // It's an (P::Order()-1)-gram + std::copy(context_rbegin, context_rbegin + P::Order() - 2, out_state.history_ + 1); + ret.ngram_length = out_state.valid_length_ = P::Order() - 1; + backoff_start = P::Order() - 1; + // ret.prob was already set. + return ret; + } + // It's an P::Order()-gram + // out_state.valid_length_ is still P::Order() - 1 because the next lookup will only need that much. + std::copy(context_rbegin, context_rbegin + P::Order() - 2, out_state.history_ + 1); + out_state.valid_length_ = P::Order() - 1; + ret.ngram_length = P::Order(); + backoff_start = P::Order(); + return ret; +} + +template class GenericModel<ProbingHashedSearch, ProbingVocabulary>; +template class GenericModel<SortedHashedSearch, SortedVocabulary>; +template class GenericModel<trie::TrieSearch, SortedVocabulary>; + +} // namespace detail +} // namespace ngram +} // namespace lm |