summaryrefslogtreecommitdiff
path: root/klm/lm/model.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/model.hh')
-rw-r--r--klm/lm/model.hh126
1 files changed, 126 insertions, 0 deletions
diff --git a/klm/lm/model.hh b/klm/lm/model.hh
new file mode 100644
index 00000000..e0eeee17
--- /dev/null
+++ b/klm/lm/model.hh
@@ -0,0 +1,126 @@
+#ifndef LM_MODEL__
+#define LM_MODEL__
+
+#include "lm/binary_format.hh"
+#include "lm/config.hh"
+#include "lm/facade.hh"
+#include "lm/search_hashed.hh"
+#include "lm/search_trie.hh"
+#include "lm/vocab.hh"
+#include "lm/weights.hh"
+
+#include <algorithm>
+#include <vector>
+
+namespace util { class FilePiece; }
+
+namespace lm {
+namespace ngram {
+
+// If you need higher order, change this and recompile.
+// Having this limit means that State can be
+// (kMaxOrder - 1) * sizeof(float) bytes instead of
+// sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead
+const std::size_t kMaxOrder = 6;
+
+// This is a POD.
+class State {
+ public:
+ bool operator==(const State &other) const {
+ if (valid_length_ != other.valid_length_) return false;
+ const WordIndex *end = history_ + valid_length_;
+ for (const WordIndex *first = history_, *second = other.history_;
+ first != end; ++first, ++second) {
+ if (*first != *second) return false;
+ }
+ // If the histories are equal, so are the backoffs.
+ return true;
+ }
+
+ // You shouldn't need to touch anything below this line, but the members are public so FullState will qualify as a POD.
+ // This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit.
+ WordIndex history_[kMaxOrder - 1];
+ float backoff_[kMaxOrder - 1];
+ unsigned char valid_length_;
+};
+
+size_t hash_value(const State &state);
+
+namespace detail {
+
+// Should return the same results as SRI.
+// Why VocabularyT instead of just Vocabulary? ModelFacade defines Vocabulary.
+template <class Search, class VocabularyT> class GenericModel : public base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> {
+ private:
+ typedef base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> P;
+ public:
+ // Get the size of memory that will be mapped given ngram counts. This
+ // does not include small non-mapped control structures, such as this class
+ // itself.
+ static size_t Size(const std::vector<uint64_t> &counts, const Config &config = Config());
+
+ GenericModel(const char *file, const Config &config = Config());
+
+ FullScoreReturn FullScore(const State &in_state, const WordIndex new_word, State &out_state) const;
+
+ /* Slower call without in_state. Don't use this if you can avoid it. This
+ * is mostly a hack for Hieu to integrate it into Moses which sometimes
+ * forgets LM state (i.e. it doesn't store it with the phrase). Sigh.
+ * The context indices should be in an array.
+ * If context_rbegin != context_rend then *context_rbegin is the word
+ * before new_word.
+ */
+ FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const;
+
+ /* Get the state for a context. Don't use this if you can avoid it. Use
+ * BeginSentenceState or EmptyContextState and extend from those. If
+ * you're only going to use this state to call FullScore once, use
+ * FullScoreForgotState. */
+ void GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const;
+
+ private:
+ friend void LoadLM<>(const char *file, const Config &config, GenericModel<Search, VocabularyT> &to);
+
+ float SlowBackoffLookup(const WordIndex *const context_rbegin, const WordIndex *const context_rend, unsigned char start) const;
+
+ FullScoreReturn ScoreExceptBackoff(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, unsigned char &backoff_start, State &out_state) const;
+
+ // Appears after Size in the cc file.
+ void SetupMemory(void *start, const std::vector<uint64_t> &counts, const Config &config);
+
+ void InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd);
+
+ void InitializeFromARPA(const char *file, util::FilePiece &f, void *start, const Parameters &params, const Config &config);
+
+ Backing &MutableBacking() { return backing_; }
+
+ static const ModelType kModelType = Search::kModelType;
+
+ Backing backing_;
+
+ VocabularyT vocab_;
+
+ typedef typename Search::Unigram Unigram;
+ typedef typename Search::Middle Middle;
+ typedef typename Search::Longest Longest;
+
+ Search search_;
+};
+
+} // namespace detail
+
+// These must also be instantiated in the cc file.
+typedef ::lm::ngram::ProbingVocabulary Vocabulary;
+typedef detail::GenericModel<detail::ProbingHashedSearch, Vocabulary> ProbingModel;
+// Default implementation. No real reason for it to be the default.
+typedef ProbingModel Model;
+
+typedef ::lm::ngram::SortedVocabulary SortedVocabulary;
+typedef detail::GenericModel<detail::SortedHashedSearch, SortedVocabulary> SortedModel;
+
+typedef detail::GenericModel<trie::TrieSearch, SortedVocabulary> TrieModel;
+
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_MODEL__