diff options
author | Patrick Simianer <simianer@cl.uni-heidelberg.de> | 2012-05-31 13:57:24 +0200 |
---|---|---|
committer | Patrick Simianer <simianer@cl.uni-heidelberg.de> | 2012-05-31 13:57:24 +0200 |
commit | f1ba05780db1705493d9afb562332498b93d26f1 (patch) | |
tree | fb429a657ba97f33e8140742de9bc74d9fc88e75 /klm/lm/value_build.hh | |
parent | aadabfdf37dfd451485277cb77fad02f77b361c6 (diff) | |
parent | 317d650f6cb1e24ac6f3be6f7bf9d4246a59e0e5 (diff) |
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'klm/lm/value_build.hh')
-rw-r--r-- | klm/lm/value_build.hh | 97 |
1 files changed, 97 insertions, 0 deletions
diff --git a/klm/lm/value_build.hh b/klm/lm/value_build.hh new file mode 100644 index 00000000..687a41a0 --- /dev/null +++ b/klm/lm/value_build.hh @@ -0,0 +1,97 @@ +#ifndef LM_VALUE_BUILD__ +#define LM_VALUE_BUILD__ + +#include "lm/weights.hh" +#include "lm/word_index.hh" +#include "util/bit_packing.hh" + +#include <vector> + +namespace lm { +namespace ngram { + +class Config; +class BackoffValue; +class RestValue; + +class NoRestBuild { + public: + typedef BackoffValue Value; + + NoRestBuild() {} + + void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {} + void SetRest(const WordIndex *, unsigned int, const ProbBackoff &) const {} + + template <class Second> bool MarkExtends(ProbBackoff &weights, const Second &) const { + util::UnsetSign(weights.prob); + return false; + } + + // Probing doesn't need to go back to unigram. + const static bool kMarkEvenLower = false; +}; + +class MaxRestBuild { + public: + typedef RestValue Value; + + MaxRestBuild() {} + + void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {} + void SetRest(const WordIndex *, unsigned int, RestWeights &weights) const { + weights.rest = weights.prob; + util::SetSign(weights.rest); + } + + bool MarkExtends(RestWeights &weights, const RestWeights &to) const { + util::UnsetSign(weights.prob); + if (weights.rest >= to.rest) return false; + weights.rest = to.rest; + return true; + } + bool MarkExtends(RestWeights &weights, const Prob &to) const { + util::UnsetSign(weights.prob); + if (weights.rest >= to.prob) return false; + weights.rest = to.prob; + return true; + } + + // Probing does need to go back to unigram. + const static bool kMarkEvenLower = true; +}; + +template <class Model> class LowerRestBuild { + public: + typedef RestValue Value; + + LowerRestBuild(const Config &config, unsigned int order, const typename Model::Vocabulary &vocab); + + ~LowerRestBuild(); + + void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {} + void SetRest(const WordIndex *vocab_ids, unsigned int n, RestWeights &weights) const { + typename Model::State ignored; + if (n == 1) { + weights.rest = unigrams_[*vocab_ids]; + } else { + weights.rest = models_[n-2]->FullScoreForgotState(vocab_ids + 1, vocab_ids + n, *vocab_ids, ignored).prob; + } + } + + template <class Second> bool MarkExtends(RestWeights &weights, const Second &) const { + util::UnsetSign(weights.prob); + return false; + } + + const static bool kMarkEvenLower = false; + + std::vector<float> unigrams_; + + std::vector<const Model*> models_; +}; + +} // namespace ngram +} // namespace lm + +#endif // LM_VALUE_BUILD__ |