diff options
Diffstat (limited to 'klm/lm/facade.hh')
-rw-r--r-- | klm/lm/facade.hh | 64 |
1 files changed, 64 insertions, 0 deletions
diff --git a/klm/lm/facade.hh b/klm/lm/facade.hh new file mode 100644 index 00000000..8b186017 --- /dev/null +++ b/klm/lm/facade.hh @@ -0,0 +1,64 @@ +#ifndef LM_FACADE__ +#define LM_FACADE__ + +#include "lm/virtual_interface.hh" +#include "util/string_piece.hh" + +#include <string> + +namespace lm { +namespace base { + +// Common model interface that depends on knowing the specific classes. +// Curiously recurring template pattern. +template <class Child, class StateT, class VocabularyT> class ModelFacade : public Model { + public: + typedef StateT State; + typedef VocabularyT Vocabulary; + + // Default Score function calls FullScore. Model can override this. + float Score(const State &in_state, const WordIndex new_word, State &out_state) const { + return static_cast<const Child*>(this)->FullScore(in_state, new_word, out_state).prob; + } + + /* Translate from void* to State */ + FullScoreReturn FullScore(const void *in_state, const WordIndex new_word, void *out_state) const { + return static_cast<const Child*>(this)->FullScore( + *reinterpret_cast<const State*>(in_state), + new_word, + *reinterpret_cast<State*>(out_state)); + } + float Score(const void *in_state, const WordIndex new_word, void *out_state) const { + return static_cast<const Child*>(this)->Score( + *reinterpret_cast<const State*>(in_state), + new_word, + *reinterpret_cast<State*>(out_state)); + } + + const State &BeginSentenceState() const { return begin_sentence_; } + const State &NullContextState() const { return null_context_; } + const Vocabulary &GetVocabulary() const { return *static_cast<const Vocabulary*>(&BaseVocabulary()); } + + protected: + ModelFacade() : Model(sizeof(State)) {} + + virtual ~ModelFacade() {} + + // begin_sentence and null_context can disappear after. vocab should stay. + void Init(const State &begin_sentence, const State &null_context, const Vocabulary &vocab, unsigned char order) { + begin_sentence_ = begin_sentence; + null_context_ = null_context; + begin_sentence_memory_ = &begin_sentence_; + null_context_memory_ = &null_context_; + base_vocab_ = &vocab; + order_ = order; + } + + private: + State begin_sentence_, null_context_; +}; + +} // mamespace base +} // namespace lm + +#endif // LM_FACADE__ |