From 149232c38eec558ddb1097698d1570aacb67b59f Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Wed, 16 May 2012 13:24:08 -0700 Subject: Big kenlm change includes lower order models for probing only. And other stuff. --- klm/lm/value_build.hh | 97 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 klm/lm/value_build.hh (limited to 'klm/lm/value_build.hh') diff --git a/klm/lm/value_build.hh b/klm/lm/value_build.hh new file mode 100644 index 00000000..687a41a0 --- /dev/null +++ b/klm/lm/value_build.hh @@ -0,0 +1,97 @@ +#ifndef LM_VALUE_BUILD__ +#define LM_VALUE_BUILD__ + +#include "lm/weights.hh" +#include "lm/word_index.hh" +#include "util/bit_packing.hh" + +#include + +namespace lm { +namespace ngram { + +class Config; +class BackoffValue; +class RestValue; + +class NoRestBuild { + public: + typedef BackoffValue Value; + + NoRestBuild() {} + + void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {} + void SetRest(const WordIndex *, unsigned int, const ProbBackoff &) const {} + + template bool MarkExtends(ProbBackoff &weights, const Second &) const { + util::UnsetSign(weights.prob); + return false; + } + + // Probing doesn't need to go back to unigram. + const static bool kMarkEvenLower = false; +}; + +class MaxRestBuild { + public: + typedef RestValue Value; + + MaxRestBuild() {} + + void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {} + void SetRest(const WordIndex *, unsigned int, RestWeights &weights) const { + weights.rest = weights.prob; + util::SetSign(weights.rest); + } + + bool MarkExtends(RestWeights &weights, const RestWeights &to) const { + util::UnsetSign(weights.prob); + if (weights.rest >= to.rest) return false; + weights.rest = to.rest; + return true; + } + bool MarkExtends(RestWeights &weights, const Prob &to) const { + util::UnsetSign(weights.prob); + if (weights.rest >= to.prob) return false; + weights.rest = to.prob; + return true; + } + + // Probing does need to go back to unigram. + const static bool kMarkEvenLower = true; +}; + +template class LowerRestBuild { + public: + typedef RestValue Value; + + LowerRestBuild(const Config &config, unsigned int order, const typename Model::Vocabulary &vocab); + + ~LowerRestBuild(); + + void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {} + void SetRest(const WordIndex *vocab_ids, unsigned int n, RestWeights &weights) const { + typename Model::State ignored; + if (n == 1) { + weights.rest = unigrams_[*vocab_ids]; + } else { + weights.rest = models_[n-2]->FullScoreForgotState(vocab_ids + 1, vocab_ids + n, *vocab_ids, ignored).prob; + } + } + + template bool MarkExtends(RestWeights &weights, const Second &) const { + util::UnsetSign(weights.prob); + return false; + } + + const static bool kMarkEvenLower = false; + + std::vector unigrams_; + + std::vector models_; +}; + +} // namespace ngram +} // namespace lm + +#endif // LM_VALUE_BUILD__ -- cgit v1.2.3