diff options
author | Kenneth Heafield <github@kheafield.com> | 2012-05-16 13:24:08 -0700 |
---|---|---|
committer | Chris Dyer <cdyer@cab.ark.cs.cmu.edu> | 2012-05-26 22:59:54 -0400 |
commit | 2b63fa0755954edf467a2421997eaf72771260cf (patch) | |
tree | ffb22b22540cd59f20f7de6bfed4313f8b946407 /klm/lm/value_build.hh | |
parent | e331ea8e69489cfd727c0ad106c76efa69f3e06c (diff) |
Big kenlm change includes lower order models for probing only. And other stuff.
Diffstat (limited to 'klm/lm/value_build.hh')
-rw-r--r-- | klm/lm/value_build.hh | 97 |
1 files changed, 97 insertions, 0 deletions
diff --git a/klm/lm/value_build.hh b/klm/lm/value_build.hh new file mode 100644 index 00000000..687a41a0 --- /dev/null +++ b/klm/lm/value_build.hh @@ -0,0 +1,97 @@ +#ifndef LM_VALUE_BUILD__ +#define LM_VALUE_BUILD__ + +#include "lm/weights.hh" +#include "lm/word_index.hh" +#include "util/bit_packing.hh" + +#include <vector> + +namespace lm { +namespace ngram { + +class Config; +class BackoffValue; +class RestValue; + +class NoRestBuild { + public: + typedef BackoffValue Value; + + NoRestBuild() {} + + void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {} + void SetRest(const WordIndex *, unsigned int, const ProbBackoff &) const {} + + template <class Second> bool MarkExtends(ProbBackoff &weights, const Second &) const { + util::UnsetSign(weights.prob); + return false; + } + + // Probing doesn't need to go back to unigram. + const static bool kMarkEvenLower = false; +}; + +class MaxRestBuild { + public: + typedef RestValue Value; + + MaxRestBuild() {} + + void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {} + void SetRest(const WordIndex *, unsigned int, RestWeights &weights) const { + weights.rest = weights.prob; + util::SetSign(weights.rest); + } + + bool MarkExtends(RestWeights &weights, const RestWeights &to) const { + util::UnsetSign(weights.prob); + if (weights.rest >= to.rest) return false; + weights.rest = to.rest; + return true; + } + bool MarkExtends(RestWeights &weights, const Prob &to) const { + util::UnsetSign(weights.prob); + if (weights.rest >= to.prob) return false; + weights.rest = to.prob; + return true; + } + + // Probing does need to go back to unigram. + const static bool kMarkEvenLower = true; +}; + +template <class Model> class LowerRestBuild { + public: + typedef RestValue Value; + + LowerRestBuild(const Config &config, unsigned int order, const typename Model::Vocabulary &vocab); + + ~LowerRestBuild(); + + void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {} + void SetRest(const WordIndex *vocab_ids, unsigned int n, RestWeights &weights) const { + typename Model::State ignored; + if (n == 1) { + weights.rest = unigrams_[*vocab_ids]; + } else { + weights.rest = models_[n-2]->FullScoreForgotState(vocab_ids + 1, vocab_ids + n, *vocab_ids, ignored).prob; + } + } + + template <class Second> bool MarkExtends(RestWeights &weights, const Second &) const { + util::UnsetSign(weights.prob); + return false; + } + + const static bool kMarkEvenLower = false; + + std::vector<float> unigrams_; + + std::vector<const Model*> models_; +}; + +} // namespace ngram +} // namespace lm + +#endif // LM_VALUE_BUILD__ |