path: root/klm/lm/virtual_interface.hh
blob: 6a5a0196fa48056a496cfc44ad214923a6f94ff2
#ifndef LM_VIRTUAL_INTERFACE__
#define LM_VIRTUAL_INTERFACE__

#include "lm/return.hh"
#include "lm/word_index.hh"
#include "util/string_piece.hh"

#include <string>

namespace lm {
namespace base {

template <class T, class U, class V> class ModelFacade;

/* Vocabulary interface.  Call Index(string) and get a word index for use in
 * calling Model.  It provides faster convenience functions for <s>, </s>, and
 * <unk> although you can also find these using Index.  
 *
 * Some models do not load the mapping from index to string.  If you need this,
 * check if the model Vocabulary class implements such a function and access it
 * directly.  
 *
 * The Vocabulary object is always owned by the Model and can be retrieved from
 * the Model using BaseVocabulary() for this abstract interface or
 * GetVocabulary() for the actual implementation (in which case you'll need the
 * actual implementation of the Model too).  
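 *
 * A minimal lookup sketch (hypothetical; assumes `model` is any loaded Model):
 *
 *   const lm::base::Vocabulary &vocab = model.BaseVocabulary();
 *   lm::WordIndex cat = vocab.Index("cat");
 *   if (cat == vocab.NotFound()) {
 *     // "cat" is out of vocabulary; the model will score it as <unk>.
 *   }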
 */
class Vocabulary {
  public:
    virtual ~Vocabulary();

    WordIndex BeginSentence() const { return begin_sentence_; }
    WordIndex EndSentence() const { return end_sentence_; }
    WordIndex NotFound() const { return not_found_; }

    /* Most implementations allow StringPiece lookups and need only override
     * Index(StringPiece).  SRI requires null termination and overrides all
     * three methods.  
     */
    virtual WordIndex Index(const StringPiece &str) const = 0;
    virtual WordIndex Index(const std::string &str) const {
      return Index(StringPiece(str));
    }
    virtual WordIndex Index(const char *str) const {
      return Index(StringPiece(str));
    }

  protected:
    // Call SetSpecial afterward.  
    Vocabulary() {}

    Vocabulary(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found) {
      SetSpecial(begin_sentence, end_sentence, not_found);
    }

    void SetSpecial(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found);

    WordIndex begin_sentence_, end_sentence_, not_found_;

  private:
    // Disable copying and assignment.  They're private and undefined.
    // Ersatz boost::noncopyable.
    Vocabulary(const Vocabulary &);
    Vocabulary &operator=(const Vocabulary &);
};
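
/* Implementation-side sketch (hypothetical class, not part of this header):
 * a concrete Vocabulary overrides Index(StringPiece) and calls SetSpecial()
 * once the indices of <s>, </s>, and <unk> are known.  Assuming a std::map
 * backing store (requires <map> and <string>):
 *
 *   class MapVocabulary : public lm::base::Vocabulary {
 *     public:
 *       // Call once the word -> index map and the special ids are loaded.
 *       void FinishedLoading(lm::WordIndex bos, lm::WordIndex eos, lm::WordIndex unk) {
 *         SetSpecial(bos, eos, unk);
 *       }
 *       lm::WordIndex Index(const StringPiece &str) const {
 *         std::map<std::string, lm::WordIndex>::const_iterator i =
 *             map_.find(std::string(str.data(), str.size()));
 *         return i == map_.end() ? NotFound() : i->second;
 *       }
 *     private:
 *       std::map<std::string, lm::WordIndex> map_;
 *   };
 */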

/* There are two ways to access a Model.  
 *
 *
 * OPTION 1: Access the Model directly (e.g. lm::ngram::Model in model.hh).
 *
 * Every Model implements the scoring function:
 * float Score(
 *   const Model::State &in_state,
 *   const WordIndex new_word,
 *   Model::State &out_state) const;
 *
 * It can also return the length of the n-gram matched by the model:
 * FullScoreReturn FullScore(
 *   const Model::State &in_state,
 *   const WordIndex new_word,
 *   Model::State &out_state) const;
 *
 *
 * There are also accessor functions:
 * const State &BeginSentenceState() const;
 * const State &NullContextState() const;
 * const Vocabulary &GetVocabulary() const;
 * unsigned int Order() const;
 *
 * NB: In case you're wondering why the model implementation looks like it's
 * missing these methods, see facade.hh.  
 *
 * This is the fastest way to use a model and presents a normal State class to
 * be included in a hypothesis state structure.  
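 *
 * For example (a hypothetical sketch; the file name is a placeholder, and
 * lm::ngram::Model with its nested typedefs comes from lm/model.hh):
 *
 *   lm::ngram::Model model("file.arpa");
 *   lm::ngram::Model::State state(model.BeginSentenceState()), out_state;
 *   const lm::ngram::Model::Vocabulary &vocab = model.GetVocabulary();
 *   float total = 0.0;
 *   total += model.Score(state, vocab.Index("a"), out_state);
 *   state = out_state;
 *   total += model.Score(state, vocab.Index("sentence"), out_state);
 *   state = out_state;
 *   total += model.Score(state, vocab.EndSentence(), out_state);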
 *
 *
 * OPTION 2: Use the virtual interface below.  
 *
 * The virtual interface allows you to decide which Model to use at runtime 
 * without templatizing everything on the Model type.  However, each Model has
 * its own State class, so a single State cannot be efficiently provided (it
 * would require using the maximum memory of any Model's State or memory
 * allocation with each lookup).  This means you become responsible for
 * allocating memory with size StateSize() and passing it to the Score or 
 * FullScore functions provided here.  
 *
 * For example, cdec has a std::string containing the entire state of a
 * hypothesis.  It can reserve StateSize bytes in this string for the model
 * state.  
 *
 * All the State objects are POD, so it's OK to use raw memory for storing
 * State.  Note that in_state and out_state must not have the same address.
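 *
 * A minimal sketch of that pattern (hypothetical; assumes `m` is a const
 * lm::base::Model reference and requires <vector> and <cstring>):
 *
 *   std::vector<char> in_state(m.StateSize()), out_state(m.StateSize());
 *   std::memcpy(&in_state[0], m.BeginSentenceMemory(), m.StateSize());
 *   lm::WordIndex word = m.BaseVocabulary().Index("example");
 *   float score = m.Score(&in_state[0], word, &out_state[0]);
 *   in_state.swap(out_state);  // the scored word's state is the next context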
 */
class Model {
  public:
    virtual ~Model();

    size_t StateSize() const { return state_size_; }
    const void *BeginSentenceMemory() const { return begin_sentence_memory_; }
    const void *NullContextMemory() const { return null_context_memory_; }

    // Requires in_state != out_state
    virtual float Score(const void *in_state, const WordIndex new_word, void *out_state) const = 0;

    // Requires in_state != out_state
    virtual FullScoreReturn FullScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0;

    unsigned char Order() const { return order_; }

    const Vocabulary &BaseVocabulary() const { return *base_vocab_; }

  private:
    template <class T, class U, class V> friend class ModelFacade;
    explicit Model(size_t state_size) : state_size_(state_size) {}

    const size_t state_size_;
    const void *begin_sentence_memory_, *null_context_memory_;

    const Vocabulary *base_vocab_;

    unsigned char order_;

    // Disable copying and assignment.  They're private and undefined.
    // Ersatz boost::noncopyable.
    Model(const Model &);
    Model &operator=(const Model &);
};

} // namespace base
} // namespace lm

#endif // LM_VIRTUAL_INTERFACE__