summaryrefslogtreecommitdiff
path: root/klm/lm/model.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/model.hh')
-rw-r--r--klm/lm/model.hh93
1 files changed, 35 insertions, 58 deletions
diff --git a/klm/lm/model.hh b/klm/lm/model.hh
index 6ea62a78..be872178 100644
--- a/klm/lm/model.hh
+++ b/klm/lm/model.hh
@@ -9,6 +9,8 @@
#include "lm/quantize.hh"
#include "lm/search_hashed.hh"
#include "lm/search_trie.hh"
+#include "lm/state.hh"
+#include "lm/value.hh"
#include "lm/vocab.hh"
#include "lm/weights.hh"
@@ -23,48 +25,6 @@ namespace util { class FilePiece; }
namespace lm {
namespace ngram {
-
-// This is a POD but if you want memcmp to return the same as operator==, call
-// ZeroRemaining first.
-class State {
- public:
- bool operator==(const State &other) const {
- if (length != other.length) return false;
- return !memcmp(words, other.words, length * sizeof(WordIndex));
- }
-
- // Three way comparison function.
- int Compare(const State &other) const {
- if (length != other.length) return length < other.length ? -1 : 1;
- return memcmp(words, other.words, length * sizeof(WordIndex));
- }
-
- bool operator<(const State &other) const {
- if (length != other.length) return length < other.length;
- return memcmp(words, other.words, length * sizeof(WordIndex)) < 0;
- }
-
- // Call this before using raw memcmp.
- void ZeroRemaining() {
- for (unsigned char i = length; i < kMaxOrder - 1; ++i) {
- words[i] = 0;
- backoff[i] = 0.0;
- }
- }
-
- unsigned char Length() const { return length; }
-
- // You shouldn't need to touch anything below this line, but the members are public so FullState will qualify as a POD.
- // This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit.
- WordIndex words[kMaxOrder - 1];
- float backoff[kMaxOrder - 1];
- unsigned char length;
-};
-
-inline size_t hash_value(const State &state) {
- return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length);
-}
-
namespace detail {
// Should return the same results as SRI.
@@ -119,8 +79,7 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
/* More efficient version of FullScore where a partial n-gram has already
* been scored.
- * NOTE: THE RETURNED .prob IS RELATIVE, NOT ABSOLUTE. So for example, if
- * the n-gram does not end up extending further left, then 0 is returned.
+ * NOTE: THE RETURNED .rest AND .prob ARE RELATIVE TO THE .rest RETURNED BEFORE.
*/
FullScoreReturn ExtendLeft(
// Additional context in reverse order. This will update add_rend to
@@ -136,12 +95,24 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
// Amount of additional content that should be considered by the next call.
unsigned char &next_use) const;
+ /* Return probabilities minus rest costs for an array of pointers. The
+ * first length should be the length of the n-gram to which pointers_begin
+ * points.
+ */
+ float UnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const {
+ // Compiler should optimize this if away.
+ return Search::kDifferentRest ? InternalUnRest(pointers_begin, pointers_end, first_length) : 0.0;
+ }
+
private:
friend void lm::ngram::LoadLM<>(const char *file, const Config &config, GenericModel<Search, VocabularyT> &to);
static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config);
- FullScoreReturn ScoreExceptBackoff(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const;
+ FullScoreReturn ScoreExceptBackoff(const WordIndex *const context_rbegin, const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const;
+
+ // Score bigrams and above. Do not include backoff.
+ void ResumeScore(const WordIndex *context_rbegin, const WordIndex *const context_rend, unsigned char starting_order_minus_2, typename Search::Node &node, float *backoff_out, unsigned char &next_use, FullScoreReturn &ret) const;
// Appears after Size in the cc file.
void SetupMemory(void *start, const std::vector<uint64_t> &counts, const Config &config);
@@ -150,32 +121,38 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
void InitializeFromARPA(const char *file, const Config &config);
+ float InternalUnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const;
+
Backing &MutableBacking() { return backing_; }
Backing backing_;
VocabularyT vocab_;
- typedef typename Search::Middle Middle;
-
Search search_;
};
} // namespace detail
-// These must also be instantiated in the cc file.
-typedef ::lm::ngram::ProbingVocabulary Vocabulary;
-typedef detail::GenericModel<detail::ProbingHashedSearch, Vocabulary> ProbingModel; // HASH_PROBING
-// Default implementation. No real reason for it to be the default.
-typedef ProbingModel Model;
+// Instead of typedef, inherit. This allows the Model etc to be forward declared.
+// Oh the joys of C and C++.
+#define LM_COMMA() ,
+#define LM_NAME_MODEL(name, from)\
+class name : public from {\
+ public:\
+ name(const char *file, const Config &config = Config()) : from(file, config) {}\
+};
-// Smaller implementation.
-typedef ::lm::ngram::SortedVocabulary SortedVocabulary;
-typedef detail::GenericModel<trie::TrieSearch<DontQuantize, trie::DontBhiksha>, SortedVocabulary> TrieModel; // TRIE_SORTED
-typedef detail::GenericModel<trie::TrieSearch<DontQuantize, trie::ArrayBhiksha>, SortedVocabulary> ArrayTrieModel;
+LM_NAME_MODEL(ProbingModel, detail::GenericModel<detail::HashedSearch<BackoffValue> LM_COMMA() ProbingVocabulary>);
+LM_NAME_MODEL(RestProbingModel, detail::GenericModel<detail::HashedSearch<RestValue> LM_COMMA() ProbingVocabulary>);
+LM_NAME_MODEL(TrieModel, detail::GenericModel<trie::TrieSearch<DontQuantize LM_COMMA() trie::DontBhiksha> LM_COMMA() SortedVocabulary>);
+LM_NAME_MODEL(ArrayTrieModel, detail::GenericModel<trie::TrieSearch<DontQuantize LM_COMMA() trie::ArrayBhiksha> LM_COMMA() SortedVocabulary>);
+LM_NAME_MODEL(QuantTrieModel, detail::GenericModel<trie::TrieSearch<SeparatelyQuantize LM_COMMA() trie::DontBhiksha> LM_COMMA() SortedVocabulary>);
+LM_NAME_MODEL(QuantArrayTrieModel, detail::GenericModel<trie::TrieSearch<SeparatelyQuantize LM_COMMA() trie::ArrayBhiksha> LM_COMMA() SortedVocabulary>);
-typedef detail::GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiksha>, SortedVocabulary> QuantTrieModel; // QUANT_TRIE_SORTED
-typedef detail::GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::ArrayBhiksha>, SortedVocabulary> QuantArrayTrieModel;
+// Default implementation. No real reason for it to be the default.
+typedef ::lm::ngram::ProbingVocabulary Vocabulary;
+typedef ProbingModel Model;
} // namespace ngram
} // namespace lm