From bee6a3c3f6c54cf7449229488c6124dddc7e2f31 Mon Sep 17 00:00:00 2001
From: Chris Dyer
Date: Tue, 18 Jan 2011 15:55:40 -0500
Subject: new version of klm

---
 klm/lm/model.cc | 117 ++++++++++++++++++++++++++++++--------------------------
 1 file changed, 63 insertions(+), 54 deletions(-)

(limited to 'klm/lm/model.cc')

diff --git a/klm/lm/model.cc b/klm/lm/model.cc
index 421e72fa..c7ba4908 100644
--- a/klm/lm/model.cc
+++ b/klm/lm/model.cc
@@ -1,5 +1,6 @@
 #include "lm/model.hh"
 
+#include "lm/blank.hh"
 #include "lm/lm_exception.hh"
 #include "lm/search_hashed.hh"
 #include "lm/search_trie.hh"
@@ -21,9 +22,6 @@ size_t hash_value(const State &state) {
 namespace detail {
 
 template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) {
-  if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ".  Edit ngram.hh's kMaxOrder to at least this value and recompile.");
-  if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
-  if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");
   return VocabularyT::Size(counts[0], config) + Search::Size(counts, config);
 }
 
@@ -59,17 +57,31 @@ template <class Search, class VocabularyT> void GenericModel<Search, Vocabulary
 
-template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, util::FilePiece &f, void *start, const Parameters &params, const Config &config) {
-  SetupMemory(start, params.counts, config);
+template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) {
+  // Backing file is the ARPA.  Steal it so we can make the backing file the mmap output if any.
+  util::FilePiece f(backing_.file.release(), file, config.messages);
+  std::vector<uint64_t> counts;
+  // File counts do not include pruned trigrams that extend to quadgrams etc.  These will be fixed with search_.VariableSizeLoad
+  ReadARPACounts(f, counts);
+
+  if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ".  Edit ngram.hh's kMaxOrder to at least this value and recompile.");
+  if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
+  if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");
+
+  std::size_t vocab_size = VocabularyT::Size(counts[0], config);
+  // Setup the binary file for writing the vocab lookup table.  The search_ is responsible for growing the binary file to its needs.
+  vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config);
 
   if (config.write_mmap) {
-    WriteWordsWrapper wrap(config.enumerate_vocab, backing_.file.get());
-    vocab_.ConfigureEnumerate(&wrap, params.counts[0]);
-    search_.InitializeFromARPA(file, f, params.counts, config, vocab_);
+    WriteWordsWrapper wrap(config.enumerate_vocab);
+    vocab_.ConfigureEnumerate(&wrap, counts[0]);
+    search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
+    wrap.Write(backing_.file.get());
   } else {
-    vocab_.ConfigureEnumerate(config.enumerate_vocab, params.counts[0]);
-    search_.InitializeFromARPA(file, f, params.counts, config, vocab_);
+    vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]);
+    search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
   }
+  // TODO: fail faster?
   if (!vocab_.SawUnk()) {
     switch(config.unknown_missing) {
@@ -89,46 +101,49 @@ template <class Search, class VocabularyT> void GenericModel<Search, Vocabulary
     if (std::fabs(search_.unigram.Unknown().backoff) > 0.0000001) UTIL_THROW(FormatLoadException, "Backoff for unknown word should be zero, but was given as " << search_.unigram.Unknown().backoff);
   }
 }
 
 template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const {
-  unsigned char backoff_start;
-  FullScoreReturn ret = ScoreExceptBackoff(in_state.history_, in_state.history_ + in_state.valid_length_, new_word, backoff_start, out_state);
-  if (backoff_start - 1 < in_state.valid_length_) {
-    ret.prob = std::accumulate(in_state.backoff_ + backoff_start - 1, in_state.backoff_ + in_state.valid_length_, ret.prob);
+  FullScoreReturn ret = ScoreExceptBackoff(in_state.history_, in_state.history_ + in_state.valid_length_, new_word, out_state);
+  if (ret.ngram_length - 1 < in_state.valid_length_) {
+    ret.prob = std::accumulate(in_state.backoff_ + ret.ngram_length - 1, in_state.backoff_ + in_state.valid_length_, ret.prob);
   }
   return ret;
 }
 
 template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const {
-  unsigned char backoff_start;
   context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);
-  FullScoreReturn ret = ScoreExceptBackoff(context_rbegin, context_rend, new_word, backoff_start, out_state);
-  ret.prob += SlowBackoffLookup(context_rbegin, context_rend, backoff_start);
+  FullScoreReturn ret = ScoreExceptBackoff(context_rbegin, context_rend, new_word, out_state);
+  ret.prob += SlowBackoffLookup(context_rbegin, context_rend, ret.ngram_length);
   return ret;
 }
 
 template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const {
+  // Generate a state from context.
   context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);
-  if (context_rend == context_rbegin || *context_rbegin == 0) {
+  if (context_rend == context_rbegin) {
     out_state.valid_length_ = 0;
     return;
   }
   float ignored_prob;
   typename Search::Node node;
   search_.LookupUnigram(*context_rbegin, ignored_prob, out_state.backoff_[0], node);
+  // Tricky part is that an entry might be blank, but out_state.valid_length_ always has the last non-blank n-gram length.
+  out_state.valid_length_ = 1;
   float *backoff_out = out_state.backoff_ + 1;
-  const WordIndex *i = context_rbegin + 1;
-  for (; i < context_rend; ++i, ++backoff_out) {
-    if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, *backoff_out, node)) {
-      out_state.valid_length_ = i - context_rbegin;
-      std::copy(context_rbegin, i, out_state.history_);
+  const typename Search::Middle *mid = &*search_.middle.begin();
+  for (const WordIndex *i = context_rbegin + 1; i < context_rend; ++i, ++backoff_out, ++mid) {
+    if (!search_.LookupMiddleNoProb(*mid, *i, *backoff_out, node)) {
+      std::copy(context_rbegin, context_rbegin + out_state.valid_length_, out_state.history_);
       return;
     }
+    if (*backoff_out != kBlankBackoff) {
+      out_state.valid_length_ = i - context_rbegin + 1;
+    } else {
+      *backoff_out = 0.0;
+    }
   }
-  std::copy(context_rbegin, context_rend, out_state.history_);
-  out_state.valid_length_ = static_cast<unsigned char>(context_rend - context_rbegin);
+  std::copy(context_rbegin, context_rbegin + out_state.valid_length_, out_state.history_);
 }
 
 template <class Search, class VocabularyT> float GenericModel<Search, VocabularyT>::SlowBackoffLookup(
@@ -148,7 +163,7 @@ template <class Search, class VocabularyT> float GenericModel<Search, Vocabular
  */
-template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::ScoreExceptBackoff(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, unsigned char &backoff_start, State &out_state) const {
+template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::ScoreExceptBackoff(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const {
   FullScoreReturn ret;
@@ -160,43 +175,43 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
   typename std::vector<Middle>::const_iterator mid_iter = search_.middle.begin();
   for (; ; ++mid_iter, ++hist_iter, ++backoff_out) {
     if (hist_iter == context_rend) {
-      // Ran out of history.  No backoff.
-      backoff_start = P::Order();
-      std::copy(context_rbegin, context_rend, out_state.history_ + 1);
-      ret.ngram_length = out_state.valid_length_ = (context_rend - context_rbegin) + 1;
+      // Ran out of history.  Typically no backoff, but this could be a blank.
+      out_state.valid_length_ = ret.ngram_length;
+      std::copy(context_rbegin, context_rbegin + ret.ngram_length - 1, out_state.history_ + 1);
       // ret.prob was already set.
       return ret;
     }
 
     if (mid_iter == search_.middle.end()) break;
 
+    float revert = ret.prob;
     if (!search_.LookupMiddle(*mid_iter, *hist_iter, ret.prob, *backoff_out, node)) {
       // Didn't find an ngram using hist_iter.
-      // The history used in the found n-gram is [context_rbegin, hist_iter).
-      std::copy(context_rbegin, hist_iter, out_state.history_ + 1);
-      // Therefore, we found a (hist_iter - context_rbegin + 1)-gram including the last word.
-      ret.ngram_length = out_state.valid_length_ = (hist_iter - context_rbegin) + 1;
-      backoff_start = mid_iter - search_.middle.begin() + 1;
+      std::copy(context_rbegin, context_rbegin + ret.ngram_length - 1, out_state.history_ + 1);
+      out_state.valid_length_ = ret.ngram_length;
       // ret.prob was already set.
       return ret;
     }
+    if (*backoff_out == kBlankBackoff) {
+      *backoff_out = 0.0;
+      ret.prob = revert;
+    } else {
+      ret.ngram_length = hist_iter - context_rbegin + 2;
+    }
   }
 
-  // It passed every lookup in search_.middle.  That means it's at least a (P::Order() - 1)-gram.
-  // All that's left is to check search_.longest.
+  // It passed every lookup in search_.middle.  All that's left is to check search_.longest.
   if (!search_.LookupLongest(*hist_iter, ret.prob, node)) {
-    // It's an (P::Order()-1)-gram
-    std::copy(context_rbegin, context_rbegin + P::Order() - 2, out_state.history_ + 1);
-    ret.ngram_length = out_state.valid_length_ = P::Order() - 1;
-    backoff_start = P::Order() - 1;
+    //assert(ret.ngram_length <= P::Order() - 1);
+    out_state.valid_length_ = ret.ngram_length;
+    std::copy(context_rbegin, context_rbegin + ret.ngram_length - 1, out_state.history_ + 1);
     // ret.prob was already set.
     return ret;
  }
-  // It's an P::Order()-gram
+  // It's an P::Order()-gram.  There is no blank in longest_.
   // out_state.valid_length_ is still P::Order() - 1 because the next lookup will only need that much.
   std::copy(context_rbegin, context_rbegin + P::Order() - 2, out_state.history_ + 1);
   out_state.valid_length_ = P::Order() - 1;
   ret.ngram_length = P::Order();
-  backoff_start = P::Order();
   return ret;
 }
--
cgit v1.2.3
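
A note on the scoring change in this patch: the old code returned a separate backoff_start out-parameter from ScoreExceptBackoff; the new code reuses ret.ngram_length, the length of the longest non-blank n-gram that matched, and charges backoff weights for exactly the context words beyond that match. Below is a minimal, self-contained sketch of that accumulation step, not the library's API: the function name and parameters are illustrative, and it assumes log10-space probabilities and backoffs as in ARPA files, where multiplying probabilities becomes adding logs.

#include <numeric>

// Sketch: charge backoff for the unmatched tail of the context.
// prob is log10 p(new_word | matched context); backoff[i] is the
// log10 backoff weight of the (i+1)-word context; ngram_length is
// the longest n-gram found for new_word; valid_length is how much
// context the state actually holds.  std::accumulate sums the
// remaining backoff weights onto prob, as FullScore does above.
float ChargeBackoff(float prob, const float *backoff,
                    unsigned char ngram_length, unsigned char valid_length) {
  if (ngram_length - 1 < valid_length) {
    prob = std::accumulate(backoff + ngram_length - 1, backoff + valid_length, prob);
  }
  return prob;
}

For example, with a 4-word history where only a trigram matched (ngram_length = 3), the sum adds backoff[2] and backoff[3]. The kBlankBackoff checks exist because pruned models store blank placeholder entries: a blank match contributes no real probability, so the patch reverts ret.prob and rewrites the sentinel backoff as 0.0 (log10 of 1, i.e. no penalty).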