summaryrefslogtreecommitdiff
path: root/klm/lm/value_build.hh
blob: 6fd26ef8f99617ab34a25f89f9f0b5ed8518b8da (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#ifndef LM_VALUE_BUILD_H
#define LM_VALUE_BUILD_H

#include "lm/weights.hh"
#include "lm/word_index.hh"
#include "util/bit_packing.hh"

#include <vector>

namespace lm {
namespace ngram {

struct Config;
struct BackoffValue;
struct RestValue;

// Build policy for models that store no rest costs.  The stored value type is
// BackoffValue, both SetRest overloads are no-ops, and MarkExtends only
// touches the probability field.
class NoRestBuild {
  public:
    typedef BackoffValue Value;

    // Probing doesn't need to go back to unigram.
    static const bool kMarkEvenLower = false;

    NoRestBuild() {}

    // Nothing to record: this policy keeps no rest cost for any entry type.
    void SetRest(const WordIndex * /*vocab_ids*/, unsigned int /*n*/, const Prob &/*prob*/) const {}
    void SetRest(const WordIndex * /*vocab_ids*/, unsigned int /*n*/, const ProbBackoff &/*weights*/) const {}

    // Applies util::UnsetSign to weights.prob (presumably clearing a sign-bit
    // marker; see util/bit_packing.hh) and always reports that no further
    // update happened.  The second argument is ignored.
    template <class Other> bool MarkExtends(ProbBackoff &weights, const Other &/*to*/) const {
      util::UnsetSign(weights.prob);
      return false;
    }
};

// Build policy whose rest cost for an entry starts as the entry's own
// probability and is raised to the largest rest/probability among the entries
// that extend it (via MarkExtends).
class MaxRestBuild {
  public:
    typedef RestValue Value;

    // Probing does need to go back to unigram.
    static const bool kMarkEvenLower = true;

    MaxRestBuild() {}

    // Plain Prob entries carry no rest field, so nothing is stored.
    void SetRest(const WordIndex * /*vocab_ids*/, unsigned int /*n*/, const Prob &/*prob*/) const {}

    // Seeds the rest cost from the entry's probability, then applies
    // util::SetSign to it (see util/bit_packing.hh for the encoding).
    void SetRest(const WordIndex * /*vocab_ids*/, unsigned int /*n*/, RestWeights &weights) const {
      weights.rest = weights.prob;
      util::SetSign(weights.rest);
    }

    // Both overloads: apply util::UnsetSign to weights.prob, then raise
    // weights.rest to the extending entry's score if it is not already >= it.
    // Returns true exactly when weights.rest was changed.
    bool MarkExtends(RestWeights &weights, const RestWeights &to) const {
      return RaiseRest(weights, to.rest);
    }
    bool MarkExtends(RestWeights &weights, const Prob &to) const {
      return RaiseRest(weights, to.prob);
    }

  private:
    // Shared body of the MarkExtends overloads.  `candidate` is taken by
    // reference and read only after UnsetSign, matching the original order of
    // operations.  The test stays phrased as `>=` so NaN inputs behave exactly
    // as before.
    template <class Score> static bool RaiseRest(RestWeights &weights, const Score &candidate) {
      util::UnsetSign(weights.prob);
      const bool covered = (weights.rest >= candidate);
      if (!covered) weights.rest = candidate;
      return !covered;
    }
};

// Build policy whose rest cost for an n-gram is taken from separate
// lower-order models: unigram rest costs come from unigrams_, and an n-gram of
// length n >= 2 is scored with models_[n-2].
//
// The constructor and destructor are defined out of line (presumably in
// value_build.cc), which is where unigrams_ and models_ get populated and,
// presumably, where the Model objects are released -- TODO confirm ownership.
template <class Model> class LowerRestBuild {
  public:
    typedef RestValue Value;

    LowerRestBuild(const Config &config, unsigned int order, const typename Model::Vocabulary &vocab);

    ~LowerRestBuild();

    // Plain Prob entries carry no rest field, so nothing is stored.
    void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {}

    // Stores the rest cost of the n-gram held in vocab_ids[0..n).
    //   n == 1: unigrams_[vocab_ids[0]].
    //   n >= 2: the `.prob` of models_[n-2]->FullScoreForgotState, called with
    //           vocab_ids[1..n) as the context range and vocab_ids[0] as the
    //           word being scored (so vocab_ids is presumably stored newest
    //           word first -- confirm against callers).
    // NOTE(review): assumes n >= 1; n == 0 would index models_ out of range.
    void SetRest(const WordIndex *vocab_ids, unsigned int n, RestWeights &weights) const {
      if (n == 1) {
        weights.rest = unigrams_[*vocab_ids];
      } else {
        // Output state is required by the call but unused here; only
        // constructed on the path that needs it.
        typename Model::State ignored;
        weights.rest = models_[n-2]->FullScoreForgotState(vocab_ids + 1, vocab_ids + n, *vocab_ids, ignored).prob;
      }
    }

    // Applies util::UnsetSign to weights.prob and reports that no further
    // update happened; the rest cost is fixed by SetRest and never raised here.
    template <class Second> bool MarkExtends(RestWeights &weights, const Second &) const {
      util::UnsetSign(weights.prob);
      return false;
    }

    // Probing doesn't need to go back to unigram.
    const static bool kMarkEvenLower = false;

    // Unigram rest costs, indexed by WordIndex.
    std::vector<float> unigrams_;

    // Lower-order models; models_[k] scores (k+2)-grams.
    std::vector<const Model*> models_;

  private:
    // Non-copyable.  The class has a user-declared destructor and holds raw
    // Model pointers, so a memberwise copy would share those pointers.  That
    // double-releases them if the destructor deletes them.  Declared but never
    // defined (C++98 idiom; equivalent to `= delete`).
    LowerRestBuild(const LowerRestBuild &);
    LowerRestBuild &operator=(const LowerRestBuild &);
};

} // namespace ngram
} // namespace lm

#endif // LM_VALUE_BUILD_H