summaryrefslogtreecommitdiff
path: root/klm/lm/model.hh
blob: fd9640c335db98faab37e4a537b2080bfaeb1301 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#ifndef LM_MODEL__
#define LM_MODEL__

#include "lm/binary_format.hh"
#include "lm/config.hh"
#include "lm/facade.hh"
#include "lm/max_order.hh"
#include "lm/search_hashed.hh"
#include "lm/search_trie.hh"
#include "lm/vocab.hh"
#include "lm/weights.hh"

#include <algorithm>
#include <vector>

#include <string.h>

namespace util { class FilePiece; }

namespace lm {
namespace ngram {

// N-gram state: up to kMaxOrder - 1 history word indices, a parallel array of
// backoff weights, and the number of leading entries that are valid.  It is a
// POD, but operator== and Compare look only at the valid prefix; call
// ZeroRemaining first if raw memcmp must agree with operator==.
class State {
  public:
    // Equal when both states have the same valid length and the same valid
    // history words.  Backoffs are not compared: if the histories are equal,
    // so are the backoffs.
    bool operator==(const State &other) const {
      return valid_length_ == other.valid_length_ &&
          std::equal(history_, history_ + valid_length_, other.history_);
    }

    // Three way comparison: negative, zero, or positive.  Shorter states sort
    // first; states of equal length are ordered by memcmp over the bytes of
    // the valid history entries.
    int Compare(const State &other) const {
      if (valid_length_ < other.valid_length_) return -1;
      if (valid_length_ > other.valid_length_) return 1;
      return memcmp(history_, other.history_, valid_length_ * sizeof(WordIndex));
    }

    // Clear every history and backoff slot at or beyond valid_length_.  Call
    // this before using raw memcmp.
    // NOTE(review): only the unused array slots are cleared; any padding
    // bytes the compiler adds to the struct are not touched here -- confirm
    // that raw memcmp callers do not depend on them.
    void ZeroRemaining() {
      const unsigned int capacity = kMaxOrder - 1;
      unsigned int slot = valid_length_;
      while (slot < capacity) {
        history_[slot] = 0;
        backoff_[slot] = 0.0f;
        ++slot;
      }
    }

    // Number of leading entries of history_ and backoff_ that are in use.
    unsigned char ValidLength() const { return valid_length_; }

    // You shouldn't need to touch anything below this line; the data members
    // are public only so that FullState will qualify as a POD.
    // This member order minimizes the total size of the struct if WordIndex
    // is 64 bit, float is 32 bit, and 64 bit integers are 64 bit aligned.
    WordIndex history_[kMaxOrder - 1];
    float backoff_[kMaxOrder - 1];
    unsigned char valid_length_;
};

// Hash of a State.  Only declared here; the definition lives outside this
// header.  NOTE(review): the free-function name matches the boost::hash
// extension point -- confirm that is the intended use.
size_t hash_value(const State &state);

namespace detail {

// N-gram model parameterized on the lookup structure (Search) and the
// vocabulary implementation (VocabularyT).  It derives from
// base::ModelFacade, passing itself as the first template argument and State
// as the state type.  Only MutableBacking is defined inline; every other
// member function is declared here and defined elsewhere.
// Should return the same results as SRI.
// Why VocabularyT instead of just Vocabulary?  ModelFacade defines Vocabulary.
template <class Search, class VocabularyT> class GenericModel : public base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> {
  private:
    // Shorthand for the ModelFacade base class.
    typedef base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> P;
  public:
    // Get the size of memory that will be mapped given ngram counts.  This
    // does not include small non-mapped control structures, such as this class
    // itself.
    static size_t Size(const std::vector<uint64_t> &counts, const Config &config = Config());

    // Construct a model from the named file using config (default-constructed
    // Config if omitted).
    // NOTE(review): presumably dispatches to InitializeFromBinary or
    // InitializeFromARPA depending on the file format -- confirm in the cc file.
    GenericModel(const char *file, const Config &config = Config());

    // Takes an input state and a word, fills in out_state, and returns a
    // FullScoreReturn.
    // NOTE(review): presumably scores new_word in the context held by in_state
    // and leaves the successor state in out_state -- confirm in the cc file.
    FullScoreReturn FullScore(const State &in_state, const WordIndex new_word, State &out_state) const;

    /* Slower call without in_state.  Don't use this if you can avoid it.  This
     * is mostly a hack for Hieu to integrate it into Moses which sometimes
     * forgets LM state (i.e. it doesn't store it with the phrase).  Sigh.
     * The context indices should be in an array.
     * If context_rbegin != context_rend then *context_rbegin is the word
     * before new_word.
     */
    FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const;

    /* Get the state for a context.  Don't use this if you can avoid it.  Use
     * BeginSentenceState or EmptyContextState and extend from those.  If
     * you're only going to use this state to call FullScore once, use
     * FullScoreForgotState. */
    void GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const;

  private:
    // LoadLM is a friend so it can reach the private members of the model it
    // is handed in `to`.
    friend void LoadLM<>(const char *file, const Config &config, GenericModel<Search, VocabularyT> &to);

    // NOTE(review): behaviour is not visible in this header; by its name this
    // looks up backoff for the given context range beginning at index
    // `start` -- confirm in the cc file.
    float SlowBackoffLookup(const WordIndex *const context_rbegin, const WordIndex *const context_rend, unsigned char start) const;

    // NOTE(review): by its name, the scoring step without backoff applied;
    // takes the context as a range like FullScoreForgotState -- confirm in
    // the cc file.
    FullScoreReturn ScoreExceptBackoff(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const;

    // Appears after Size in the cc file.
    void SetupMemory(void *start, const std::vector<uint64_t> &counts, const Config &config);

    // Initialization entry points for the two input formats named in their
    // identifiers (binary and ARPA); defined in the cc file.
    void InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd);

    void InitializeFromARPA(const char *file, const Config &config);

    // Non-const access to backing_.
    Backing &MutableBacking() { return backing_; }

    // Model type tag, taken from the Search policy.
    static const ModelType kModelType = Search::kModelType;

    // Data members.  Declaration order is also construction order, so keep
    // backing_, vocab_, search_ in this sequence.
    Backing backing_;
    
    VocabularyT vocab_;

    // Record types supplied by the Search policy.
    typedef typename Search::Unigram Unigram;
    typedef typename Search::Middle Middle;
    typedef typename Search::Longest Longest;

    Search search_;
};

} // namespace detail

// Concrete model types.  Each GenericModel instantiation named here must also
// be instantiated in the cc file.

// Model using ProbingHashedSearch with ProbingVocabulary.
typedef ::lm::ngram::ProbingVocabulary Vocabulary;
typedef detail::GenericModel<detail::ProbingHashedSearch, Vocabulary> ProbingModel;
// Default implementation.  No real reason for it to be the default.
typedef ProbingModel Model;

// Model using SortedHashedSearch with SortedVocabulary.  The first typedef
// re-declares SortedVocabulary under its own name in this namespace.
typedef ::lm::ngram::SortedVocabulary SortedVocabulary;
typedef detail::GenericModel<detail::SortedHashedSearch, SortedVocabulary> SortedModel;

// Model using trie::TrieSearch with SortedVocabulary.
typedef detail::GenericModel<trie::TrieSearch, SortedVocabulary> TrieModel;

} // namespace ngram
} // namespace lm

#endif // LM_MODEL__