summary | refs | log | tree | commit | diff
path: root/klm/lm/value_build.hh
diff options
context:
space:
mode:
author: Kenneth Heafield <github@kheafield.com> 2012-05-16 13:24:08 -0700
committer: Chris Dyer <cdyer@cab.ark.cs.cmu.edu> 2012-05-26 22:59:54 -0400
commit: 2b63fa0755954edf467a2421997eaf72771260cf (patch)
tree: ffb22b22540cd59f20f7de6bfed4313f8b946407 /klm/lm/value_build.hh
parent: e331ea8e69489cfd727c0ad106c76efa69f3e06c (diff)
Big kenlm change includes lower order models for probing only. And other stuff.
Diffstat (limited to 'klm/lm/value_build.hh')
-rw-r--r-- klm/lm/value_build.hh 97
1 files changed, 97 insertions, 0 deletions
diff --git a/klm/lm/value_build.hh b/klm/lm/value_build.hh
new file mode 100644
index 00000000..687a41a0
--- /dev/null
+++ b/klm/lm/value_build.hh
@@ -0,0 +1,97 @@
+#ifndef LM_VALUE_BUILD__
+#define LM_VALUE_BUILD__
+
+#include "lm/weights.hh"
+#include "lm/word_index.hh"
+#include "util/bit_packing.hh"
+
+#include <vector>
+
+namespace lm {
+namespace ngram {
+
+class Config;
+class BackoffValue;
+class RestValue;
+
+class NoRestBuild { // Build-time policy with Value = BackoffValue; it stores no rest score, so SetRest does nothing.
+ public:
+ typedef BackoffValue Value;
+ // Stateless: the constructor takes no arguments and has an empty body.
+ NoRestBuild() {}
+ // Both SetRest overloads ignore every argument and have empty bodies.
+ void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {}
+ void SetRest(const WordIndex *, unsigned int, const ProbBackoff &) const {}
+ // Calls util::UnsetSign on weights.prob, ignores the second argument, and always returns false.
+ template <class Second> bool MarkExtends(ProbBackoff &weights, const Second &) const {
+ util::UnsetSign(weights.prob); // NOTE(review): presumably the sign of prob doubles as an "extends" flag -- confirm in util/bit_packing.hh.
+ return false;
+ }
+ // NOTE(review): presumably read by the table-building code to decide how far down to propagate marks -- confirm at the use site.
+ // Probing doesn't need to go back to unigram.
+ const static bool kMarkEvenLower = false;
+};
+
+class MaxRestBuild { // Build-time policy with Value = RestValue; rest starts from prob and MarkExtends keeps the larger value.
+ public:
+ typedef RestValue Value;
+ // Stateless: the constructor takes no arguments and has an empty body.
+ MaxRestBuild() {}
+ // The Prob overload is a no-op; the RestWeights overload copies prob into rest and then calls util::SetSign on rest.
+ void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {}
+ void SetRest(const WordIndex *, unsigned int, RestWeights &weights) const {
+ weights.rest = weights.prob;
+ util::SetSign(weights.rest);
+ }
+ // Both overloads call util::UnsetSign on weights.prob, then raise weights.rest to the other entry's value if that is larger.
+ bool MarkExtends(RestWeights &weights, const RestWeights &to) const {
+ util::UnsetSign(weights.prob);
+ if (weights.rest >= to.rest) return false; // Already at least as large: nothing changed.
+ weights.rest = to.rest;
+ return true; // rest was raised.
+ }
+ bool MarkExtends(RestWeights &weights, const Prob &to) const {
+ util::UnsetSign(weights.prob);
+ if (weights.rest >= to.prob) return false; // A Prob entry has no rest field, so its prob is compared instead.
+ weights.rest = to.prob;
+ return true; // rest was raised.
+ }
+ // NOTE(review): presumably read by the table-building code to decide how far down to propagate marks -- confirm at the use site.
+ // Probing does need to go back to unigram.
+ const static bool kMarkEvenLower = true;
+};
+
+template <class Model> class LowerRestBuild { // Build-time policy with Value = RestValue; rest is read from unigrams_ or from models_.
+ public:
+ typedef RestValue Value;
+ // Declared only; defined outside this header. NOTE(review): presumably populates unigrams_ and models_ from config -- confirm.
+ LowerRestBuild(const Config &config, unsigned int order, const typename Model::Vocabulary &vocab);
+ // Declared only; defined elsewhere. NOTE(review): models_ holds raw pointers; check whether this frees them and whether copying is safe.
+ ~LowerRestBuild();
+ // The Prob overload is a no-op; the RestWeights overload fills weights.rest (n is presumably the n-gram length -- confirm).
+ void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {}
+ void SetRest(const WordIndex *vocab_ids, unsigned int n, RestWeights &weights) const {
+ typename Model::State ignored; // Only passed to FullScoreForgotState below; its contents are never read here.
+ if (n == 1) {
+ weights.rest = unigrams_[*vocab_ids]; // Direct table lookup by the first word id; no bounds check.
+ } else {
+ weights.rest = models_[n-2]->FullScoreForgotState(vocab_ids + 1, vocab_ids + n, *vocab_ids, ignored).prob; // Range [vocab_ids+1, vocab_ids+n) plus the first id; only .prob of the result is kept.
+ }
+ }
+ // Calls util::UnsetSign on weights.prob, ignores the second argument, and always returns false.
+ template <class Second> bool MarkExtends(RestWeights &weights, const Second &) const {
+ util::UnsetSign(weights.prob);
+ return false;
+ }
+ // NOTE(review): presumably read by the table-building code to decide how far down to propagate marks -- confirm at the use site.
+ const static bool kMarkEvenLower = false;
+ // Rest values used by SetRest when n == 1, indexed by *vocab_ids.
+ std::vector<float> unigrams_;
+ // Models queried by SetRest when n > 1; models_[n-2] serves entries of length n.
+ std::vector<const Model*> models_;
+};
+
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_VALUE_BUILD__