diff options
Diffstat (limited to 'klm/lm/model.hh')
-rw-r--r-- | klm/lm/model.hh | 93 |
1 files changed, 35 insertions, 58 deletions
diff --git a/klm/lm/model.hh b/klm/lm/model.hh index 6ea62a78..be872178 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -9,6 +9,8 @@ #include "lm/quantize.hh" #include "lm/search_hashed.hh" #include "lm/search_trie.hh" +#include "lm/state.hh" +#include "lm/value.hh" #include "lm/vocab.hh" #include "lm/weights.hh" @@ -23,48 +25,6 @@ namespace util { class FilePiece; } namespace lm { namespace ngram { - -// This is a POD but if you want memcmp to return the same as operator==, call -// ZeroRemaining first. -class State { - public: - bool operator==(const State &other) const { - if (length != other.length) return false; - return !memcmp(words, other.words, length * sizeof(WordIndex)); - } - - // Three way comparison function. - int Compare(const State &other) const { - if (length != other.length) return length < other.length ? -1 : 1; - return memcmp(words, other.words, length * sizeof(WordIndex)); - } - - bool operator<(const State &other) const { - if (length != other.length) return length < other.length; - return memcmp(words, other.words, length * sizeof(WordIndex)) < 0; - } - - // Call this before using raw memcmp. - void ZeroRemaining() { - for (unsigned char i = length; i < kMaxOrder - 1; ++i) { - words[i] = 0; - backoff[i] = 0.0; - } - } - - unsigned char Length() const { return length; } - - // You shouldn't need to touch anything below this line, but the members are public so FullState will qualify as a POD. - // This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit. - WordIndex words[kMaxOrder - 1]; - float backoff[kMaxOrder - 1]; - unsigned char length; -}; - -inline size_t hash_value(const State &state) { - return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length); -} - namespace detail { // Should return the same results as SRI. @@ -119,8 +79,7 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod /* More efficient version of FullScore where a partial n-gram has already * been scored. - * NOTE: THE RETURNED .prob IS RELATIVE, NOT ABSOLUTE. So for example, if - * the n-gram does not end up extending further left, then 0 is returned. + * NOTE: THE RETURNED .rest AND .prob ARE RELATIVE TO THE .rest RETURNED BEFORE. */ FullScoreReturn ExtendLeft( // Additional context in reverse order. This will update add_rend to @@ -136,12 +95,24 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod // Amount of additional content that should be considered by the next call. unsigned char &next_use) const; + /* Return probabilities minus rest costs for an array of pointers. The + * first length should be the length of the n-gram to which pointers_begin + * points. + */ + float UnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const { + // Compiler should optimize this if away. + return Search::kDifferentRest ? InternalUnRest(pointers_begin, pointers_end, first_length) : 0.0; + } + private: friend void lm::ngram::LoadLM<>(const char *file, const Config &config, GenericModel<Search, VocabularyT> &to); static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config); - FullScoreReturn ScoreExceptBackoff(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const; + FullScoreReturn ScoreExceptBackoff(const WordIndex *const context_rbegin, const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const; + + // Score bigrams and above. Do not include backoff. + void ResumeScore(const WordIndex *context_rbegin, const WordIndex *const context_rend, unsigned char starting_order_minus_2, typename Search::Node &node, float *backoff_out, unsigned char &next_use, FullScoreReturn &ret) const; // Appears after Size in the cc file. void SetupMemory(void *start, const std::vector<uint64_t> &counts, const Config &config); @@ -150,32 +121,38 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod void InitializeFromARPA(const char *file, const Config &config); + float InternalUnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const; + Backing &MutableBacking() { return backing_; } Backing backing_; VocabularyT vocab_; - typedef typename Search::Middle Middle; - Search search_; }; } // namespace detail -// These must also be instantiated in the cc file. -typedef ::lm::ngram::ProbingVocabulary Vocabulary; -typedef detail::GenericModel<detail::ProbingHashedSearch, Vocabulary> ProbingModel; // HASH_PROBING -// Default implementation. No real reason for it to be the default. -typedef ProbingModel Model; +// Instead of typedef, inherit. This allows the Model etc to be forward declared. +// Oh the joys of C and C++. +#define LM_COMMA() , +#define LM_NAME_MODEL(name, from)\ +class name : public from {\ + public:\ + name(const char *file, const Config &config = Config()) : from(file, config) {}\ +}; -// Smaller implementation. -typedef ::lm::ngram::SortedVocabulary SortedVocabulary; -typedef detail::GenericModel<trie::TrieSearch<DontQuantize, trie::DontBhiksha>, SortedVocabulary> TrieModel; // TRIE_SORTED -typedef detail::GenericModel<trie::TrieSearch<DontQuantize, trie::ArrayBhiksha>, SortedVocabulary> ArrayTrieModel; +LM_NAME_MODEL(ProbingModel, detail::GenericModel<detail::HashedSearch<BackoffValue> LM_COMMA() ProbingVocabulary>); +LM_NAME_MODEL(RestProbingModel, detail::GenericModel<detail::HashedSearch<RestValue> LM_COMMA() ProbingVocabulary>); +LM_NAME_MODEL(TrieModel, detail::GenericModel<trie::TrieSearch<DontQuantize LM_COMMA() trie::DontBhiksha> LM_COMMA() SortedVocabulary>); +LM_NAME_MODEL(ArrayTrieModel, detail::GenericModel<trie::TrieSearch<DontQuantize LM_COMMA() trie::ArrayBhiksha> LM_COMMA() SortedVocabulary>); +LM_NAME_MODEL(QuantTrieModel, detail::GenericModel<trie::TrieSearch<SeparatelyQuantize LM_COMMA() trie::DontBhiksha> LM_COMMA() SortedVocabulary>); +LM_NAME_MODEL(QuantArrayTrieModel, detail::GenericModel<trie::TrieSearch<SeparatelyQuantize LM_COMMA() trie::ArrayBhiksha> LM_COMMA() SortedVocabulary>); -typedef detail::GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiksha>, SortedVocabulary> QuantTrieModel; // QUANT_TRIE_SORTED -typedef detail::GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::ArrayBhiksha>, SortedVocabulary> QuantArrayTrieModel; +// Default implementation. No real reason for it to be the default. +typedef ::lm::ngram::ProbingVocabulary Vocabulary; +typedef ProbingModel Model; } // namespace ngram } // namespace lm |