diff options
Diffstat (limited to 'klm/lm/model.hh')
-rw-r--r-- | klm/lm/model.hh | 62 |
1 files changed, 42 insertions, 20 deletions
diff --git a/klm/lm/model.hh b/klm/lm/model.hh index 21595321..c278acd6 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -12,6 +12,8 @@ #include "lm/vocab.hh" #include "lm/weights.hh" +#include "util/murmur_hash.hh" + #include <algorithm> #include <vector> @@ -27,42 +29,41 @@ namespace ngram { class State { public: bool operator==(const State &other) const { - if (valid_length_ != other.valid_length_) return false; - const WordIndex *end = history_ + valid_length_; - for (const WordIndex *first = history_, *second = other.history_; - first != end; ++first, ++second) { - if (*first != *second) return false; - } - // If the histories are equal, so are the backoffs. - return true; + if (length != other.length) return false; + return !memcmp(words, other.words, length * sizeof(WordIndex)); } // Three way comparison function. int Compare(const State &other) const { - if (valid_length_ == other.valid_length_) { - return memcmp(history_, other.history_, valid_length_ * sizeof(WordIndex)); - } - return (valid_length_ < other.valid_length_) ? -1 : 1; + if (length != other.length) return length < other.length ? -1 : 1; + return memcmp(words, other.words, length * sizeof(WordIndex)); + } + + bool operator<(const State &other) const { + if (length != other.length) return length < other.length; + return memcmp(words, other.words, length * sizeof(WordIndex)) < 0; } // Call this before using raw memcmp. void ZeroRemaining() { - for (unsigned char i = valid_length_; i < kMaxOrder - 1; ++i) { - history_[i] = 0; - backoff_[i] = 0.0; + for (unsigned char i = length; i < kMaxOrder - 1; ++i) { + words[i] = 0; + backoff[i] = 0.0; } } - unsigned char ValidLength() const { return valid_length_; } + unsigned char Length() const { return length; } // You shouldn't need to touch anything below this line, but the members are public so FullState will qualify as a POD. // This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit. - WordIndex history_[kMaxOrder - 1]; - float backoff_[kMaxOrder - 1]; - unsigned char valid_length_; + WordIndex words[kMaxOrder - 1]; + float backoff[kMaxOrder - 1]; + unsigned char length; }; -size_t hash_value(const State &state); +inline size_t hash_value(const State &state) { + return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length); +} namespace detail { @@ -75,6 +76,8 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod // This is the model type returned by RecognizeBinary. static const ModelType kModelType; + static const unsigned int kVersion = Search::kVersion; + /* Get the size of memory that will be mapped given ngram counts. This * does not include small non-mapped control structures, such as this class * itself. @@ -114,6 +117,25 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod */ void GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const; + /* More efficient version of FullScore where a partial n-gram has already + * been scored. + * NOTE: THE RETURNED .prob IS RELATIVE, NOT ABSOLUTE. So for example, if + * the n-gram does not end up extending further left, then 0 is returned. + */ + FullScoreReturn ExtendLeft( + // Additional context in reverse order. This will update add_rend to + const WordIndex *add_rbegin, const WordIndex *add_rend, + // Backoff weights to use. + const float *backoff_in, + // extend_left returned by a previous query. + uint64_t extend_pointer, + // Length of n-gram that the pointer corresponds to. + unsigned char extend_length, + // Where to write additional backoffs for [extend_length + 1, min(Order() - 1, return.ngram_length)] + float *backoff_out, + // Amount of additional content that should be considered by the next call. + unsigned char &next_use) const; + private: friend void LoadLM<>(const char *file, const Config &config, GenericModel<Search, VocabularyT> &to); |