Diffstat (limited to 'klm/lm/model.hh')
 klm/lm/model.hh | 47 ++++++++++++++++++++++++++++++++++-------------
 1 file changed, 34 insertions(+), 13 deletions(-)
diff --git a/klm/lm/model.hh b/klm/lm/model.hh
index 21595321..fe91af2e 100644
--- a/klm/lm/model.hh
+++ b/klm/lm/model.hh
@@ -27,9 +27,9 @@ namespace ngram {
 class State {
   public:
     bool operator==(const State &other) const {
-      if (valid_length_ != other.valid_length_) return false;
-      const WordIndex *end = history_ + valid_length_;
-      for (const WordIndex *first = history_, *second = other.history_;
+      if (length != other.length) return false;
+      const WordIndex *end = words + length;
+      for (const WordIndex *first = words, *second = other.words;
           first != end; ++first, ++second) {
         if (*first != *second) return false;
       }
@@ -39,27 +39,27 @@ class State {
 
     // Three way comparison function.
     int Compare(const State &other) const {
-      if (valid_length_ == other.valid_length_) {
-        return memcmp(history_, other.history_, valid_length_ * sizeof(WordIndex));
+      if (length == other.length) {
+        return memcmp(words, other.words, length * sizeof(WordIndex));
       }
-      return (valid_length_ < other.valid_length_) ? -1 : 1;
+      return (length < other.length) ? -1 : 1;
     }
 
     // Call this before using raw memcmp.
     void ZeroRemaining() {
-      for (unsigned char i = valid_length_; i < kMaxOrder - 1; ++i) {
-        history_[i] = 0;
-        backoff_[i] = 0.0;
+      for (unsigned char i = length; i < kMaxOrder - 1; ++i) {
+        words[i] = 0;
+        backoff[i] = 0.0;
       }
     }
 
-    unsigned char ValidLength() const { return valid_length_; }
+    unsigned char Length() const { return length; }
 
     // You shouldn't need to touch anything below this line, but the members are public so FullState will qualify as a POD.
     // This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit.
-    WordIndex history_[kMaxOrder - 1];
-    float backoff_[kMaxOrder - 1];
-    unsigned char valid_length_;
+    WordIndex words[kMaxOrder - 1];
+    float backoff[kMaxOrder - 1];
+    unsigned char length;
 };
 
 size_t hash_value(const State &state);
@@ -75,6 +75,8 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
     // This is the model type returned by RecognizeBinary.
     static const ModelType kModelType;
 
+    static const unsigned int kVersion = Search::kVersion;
+
     /* Get the size of memory that will be mapped given ngram counts.  This
      * does not include small non-mapped control structures, such as this class
      * itself.
@@ -114,6 +116,25 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
      */
     void GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const;
 
+    /* More efficient version of FullScore where a partial n-gram has already
+     * been scored.
+     * NOTE: THE RETURNED .prob IS RELATIVE, NOT ABSOLUTE.  So for example, if
+     * the n-gram does not end up extending further left, then 0 is returned.
+     */
+    FullScoreReturn ExtendLeft(
+        // Additional context in reverse order.  This will update add_rend to
+        const WordIndex *add_rbegin, const WordIndex *add_rend,
+        // Backoff weights to use.
+        const float *backoff_in,
+        // extend_left returned by a previous query.
+        uint64_t extend_pointer,
+        // Length of n-gram that the pointer corresponds to.
+        unsigned char extend_length,
+        // Where to write additional backoffs for [extend_length + 1, min(Order() - 1, return.ngram_length)]
+        float *backoff_out,
+        // Amount of additional content that should be considered by the next call.
+        unsigned char &next_use) const;
+
   private:
     friend void LoadLM<>(const char *file, const Config &config, GenericModel<Search, VocabularyT> &to);
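
For reference, a minimal caller-side sketch of what the State rename means in practice. It assumes the surrounding KenLM query API (lm::ngram::Model, BeginSentenceState(), GetVocabulary(), FullScore()), which is not part of this diff, and the model path is hypothetical:

#include "lm/model.hh"

#include <iostream>

int main() {
  // Hypothetical ARPA/binary file; any model readable by this class would do.
  lm::ngram::Model model("example.arpa");
  lm::ngram::State in_state(model.BeginSentenceState()), out_state;
  lm::WordIndex word = model.GetVocabulary().Index("example");
  model.FullScore(in_state, word, out_state);

  // Before this commit: out_state.ValidLength() and out_state.history_[i].
  // After it: Length() and the public words[] array.
  for (unsigned char i = 0; i < out_state.Length(); ++i) {
    std::cout << out_state.words[i] << ' ';
  }
  std::cout << '\n';
  return 0;
}

The members stay public on purpose: as the in-code comment notes, State must remain a POD so it can be memcmp'd, hashed, and copied around by decoders without special handling.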