-rw-r--r--  klm/lm/enumerate_vocab.hh |  29
-rw-r--r--  klm/lm/model.hh           | 126
-rw-r--r--  klm/lm/search_hashed.hh   | 156
-rw-r--r--  klm/lm/search_trie.hh     |  83
-rw-r--r--  klm/lm/trie.hh            | 129
-rw-r--r--  klm/lm/vocab.hh           | 138
6 files changed, 661 insertions, 0 deletions
diff --git a/klm/lm/enumerate_vocab.hh b/klm/lm/enumerate_vocab.hh
new file mode 100644
index 00000000..7a2f7d12
--- /dev/null
+++ b/klm/lm/enumerate_vocab.hh
@@ -0,0 +1,29 @@
+#ifndef LM_ENUMERATE_VOCAB__
+#define LM_ENUMERATE_VOCAB__
+
+#include "lm/word_index.hh"
+#include "util/string_piece.hh"
+
+namespace lm {
+namespace ngram {
+
+/* If you need the actual strings in the vocabulary, inherit from this class
+ * and implement Add. Then put a pointer in Config.enumerate_vocab.
+ * Add is called once per vocabulary word. index starts at 0 and increases
+ * by 1 each time.
+ */
+class EnumerateVocab {
+ public:
+ virtual ~EnumerateVocab() {}
+
+ virtual void Add(WordIndex index, const StringPiece &str) = 0;
+
+ protected:
+ EnumerateVocab() {}
+};
+
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_ENUMERATE_VOCAB__
+
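As a sketch of how this hook is used (not part of the patch; CollectVocab is a
hypothetical name): subclass EnumerateVocab, then point Config::enumerate_vocab
at an instance before constructing the model.

    #include <string>
    #include <vector>
    #include "lm/enumerate_vocab.hh"

    class CollectVocab : public lm::ngram::EnumerateVocab {
      public:
        // Indices arrive as 0, 1, 2, ... so push_back keeps words_[index] aligned.
        void Add(lm::WordIndex index, const StringPiece &str) {
          words_.push_back(std::string(str.data(), str.size()));
        }
        const std::vector<std::string> &Words() const { return words_; }
      private:
        std::vector<std::string> words_;
    };
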
diff --git a/klm/lm/model.hh b/klm/lm/model.hh
new file mode 100644
index 00000000..e0eeee17
--- /dev/null
+++ b/klm/lm/model.hh
@@ -0,0 +1,126 @@
+#ifndef LM_MODEL__
+#define LM_MODEL__
+
+#include "lm/binary_format.hh"
+#include "lm/config.hh"
+#include "lm/facade.hh"
+#include "lm/search_hashed.hh"
+#include "lm/search_trie.hh"
+#include "lm/vocab.hh"
+#include "lm/weights.hh"
+
+#include <algorithm>
+#include <vector>
+
+namespace util { class FilePiece; }
+
+namespace lm {
+namespace ngram {
+
+// If you need higher order, change this and recompile.
+// Having this limit means that State can be
+// (kMaxOrder - 1) * sizeof(float) bytes instead of
+// sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead
+const std::size_t kMaxOrder = 6;
+
+// This is a POD.
+class State {
+ public:
+ bool operator==(const State &other) const {
+ if (valid_length_ != other.valid_length_) return false;
+ const WordIndex *end = history_ + valid_length_;
+ for (const WordIndex *first = history_, *second = other.history_;
+ first != end; ++first, ++second) {
+ if (*first != *second) return false;
+ }
+ // If the histories are equal, so are the backoffs.
+ return true;
+ }
+
+ // You shouldn't need to touch anything below this line, but the members are public so State will qualify as a POD.
+ // This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit.
+ WordIndex history_[kMaxOrder - 1];
+ float backoff_[kMaxOrder - 1];
+ unsigned char valid_length_;
+};
+
+size_t hash_value(const State &state);
+
+namespace detail {
+
+// Should return the same results as SRI.
+// Why VocabularyT instead of just Vocabulary? ModelFacade defines Vocabulary.
+template <class Search, class VocabularyT> class GenericModel : public base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> {
+ private:
+ typedef base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> P;
+ public:
+ // Get the size of memory that will be mapped given ngram counts. This
+ // does not include small non-mapped control structures, such as this class
+ // itself.
+ static size_t Size(const std::vector<uint64_t> &counts, const Config &config = Config());
+
+ GenericModel(const char *file, const Config &config = Config());
+
+ FullScoreReturn FullScore(const State &in_state, const WordIndex new_word, State &out_state) const;
+
+ /* Slower call without in_state. Don't use this if you can avoid it. This
+ * is mostly a hack for Hieu to integrate it into Moses which sometimes
+ * forgets LM state (i.e. it doesn't store it with the phrase). Sigh.
+ * The context indices should be in an array.
+ * If context_rbegin != context_rend then *context_rbegin is the word
+ * before new_word.
+ */
+ FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const;
+
+ /* Get the state for a context. Don't use this if you can avoid it. Use
+ * BeginSentenceState or EmptyContextState and extend from those. If
+ * you're only going to use this state to call FullScore once, use
+ * FullScoreForgotState. */
+ void GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const;
+
+ private:
+ friend void LoadLM<>(const char *file, const Config &config, GenericModel<Search, VocabularyT> &to);
+
+ float SlowBackoffLookup(const WordIndex *const context_rbegin, const WordIndex *const context_rend, unsigned char start) const;
+
+ FullScoreReturn ScoreExceptBackoff(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, unsigned char &backoff_start, State &out_state) const;
+
+ // Appears after Size in the cc file.
+ void SetupMemory(void *start, const std::vector<uint64_t> &counts, const Config &config);
+
+ void InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd);
+
+ void InitializeFromARPA(const char *file, util::FilePiece &f, void *start, const Parameters &params, const Config &config);
+
+ Backing &MutableBacking() { return backing_; }
+
+ static const ModelType kModelType = Search::kModelType;
+
+ Backing backing_;
+
+ VocabularyT vocab_;
+
+ typedef typename Search::Unigram Unigram;
+ typedef typename Search::Middle Middle;
+ typedef typename Search::Longest Longest;
+
+ Search search_;
+};
+
+} // namespace detail
+
+// These must also be instantiated in the cc file.
+typedef ::lm::ngram::ProbingVocabulary Vocabulary;
+typedef detail::GenericModel<detail::ProbingHashedSearch, Vocabulary> ProbingModel;
+// Default implementation. No real reason for it to be the default.
+typedef ProbingModel Model;
+
+typedef ::lm::ngram::SortedVocabulary SortedVocabulary;
+typedef detail::GenericModel<detail::SortedHashedSearch, SortedVocabulary> SortedModel;
+
+typedef detail::GenericModel<trie::TrieSearch, SortedVocabulary> TrieModel;
+
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_MODEL__
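For orientation, typical scoring usage of the class above (a sketch; it assumes
BeginSentenceState() and GetVocabulary() come from the ModelFacade base in
facade.hh, and the file name is a placeholder):

    #include "lm/model.hh"

    float ScoreSentence(const char *arpa_file) {
      using namespace lm::ngram;
      Model model(arpa_file);  // ProbingModel, per the typedef above
      State state(model.BeginSentenceState()), out_state;
      const Vocabulary &vocab = model.GetVocabulary();
      const char *words[] = {"this", "is", "a", "sentence", "</s>"};
      float total = 0.0;
      for (unsigned i = 0; i < sizeof(words) / sizeof(*words); ++i) {
        // FullScore returns the log10 probability plus the extended state.
        total += model.FullScore(state, vocab.Index(words[i]), out_state).prob;
        state = out_state;
      }
      return total;
    }
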
diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh
new file mode 100644
index 00000000..1ee2b9e9
--- /dev/null
+++ b/klm/lm/search_hashed.hh
@@ -0,0 +1,156 @@
+#ifndef LM_SEARCH_HASHED__
+#define LM_SEARCH_HASHED__
+
+#include "lm/binary_format.hh"
+#include "lm/config.hh"
+#include "lm/read_arpa.hh"
+#include "lm/weights.hh"
+
+#include "util/key_value_packing.hh"
+#include "util/probing_hash_table.hh"
+#include "util/sorted_uniform.hh"
+
+#include <algorithm>
+#include <vector>
+
+namespace util { class FilePiece; }
+
+namespace lm {
+namespace ngram {
+namespace detail {
+
+inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) {
+ uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast<uint64_t>(next) * 17894857484156487943ULL);
+ return ret;
+}
+
+struct HashedSearch {
+ typedef uint64_t Node;
+
+ class Unigram {
+ public:
+ Unigram() {}
+
+ Unigram(void *start, std::size_t /*allocated*/) : unigram_(static_cast<ProbBackoff*>(start)) {}
+
+ static std::size_t Size(uint64_t count) {
+ return (count + 1) * sizeof(ProbBackoff); // +1 to hallucinate <unk> if absent
+ }
+
+ const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index]; }
+
+ ProbBackoff &Unknown() { return unigram_[0]; }
+
+ void LoadedBinary() {}
+
+ // For building.
+ ProbBackoff *Raw() { return unigram_; }
+
+ private:
+ ProbBackoff *unigram_;
+ };
+
+ Unigram unigram;
+
+ bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const {
+ const ProbBackoff &entry = unigram.Lookup(word);
+ prob = entry.prob;
+ backoff = entry.backoff;
+ next = static_cast<Node>(word);
+ return true;
+ }
+};
+
+template <class MiddleT, class LongestT> struct TemplateHashedSearch : public HashedSearch {
+ typedef MiddleT Middle;
+ std::vector<Middle> middle;
+
+ typedef LongestT Longest;
+ Longest longest;
+
+ static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) {
+ std::size_t ret = Unigram::Size(counts[0]);
+ for (unsigned char n = 1; n < counts.size() - 1; ++n) {
+ ret += Middle::Size(counts[n], config.probing_multiplier);
+ }
+ return ret + Longest::Size(counts.back(), config.probing_multiplier);
+ }
+
+ uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
+ std::size_t allocated = Unigram::Size(counts[0]);
+ unigram = Unigram(start, allocated);
+ start += allocated;
+ for (unsigned int n = 2; n < counts.size(); ++n) {
+ allocated = Middle::Size(counts[n - 1], config.probing_multiplier);
+ middle.push_back(Middle(start, allocated));
+ start += allocated;
+ }
+ allocated = Longest::Size(counts.back(), config.probing_multiplier);
+ longest = Longest(start, allocated);
+ start += allocated;
+ return start;
+ }
+
+ template <class Voc> void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab);
+
+ bool LookupMiddle(const Middle &middle, WordIndex word, float &prob, float &backoff, Node &node) const {
+ node = CombineWordHash(node, word);
+ typename Middle::ConstIterator found;
+ if (!middle.Find(node, found)) return false;
+ prob = found->GetValue().prob;
+ backoff = found->GetValue().backoff;
+ return true;
+ }
+
+ bool LookupMiddleNoProb(const Middle &middle, WordIndex word, float &backoff, Node &node) const {
+ node = CombineWordHash(node, word);
+ typename Middle::ConstIterator found;
+ if (!middle.Find(node, found)) return false;
+ backoff = found->GetValue().backoff;
+ return true;
+ }
+
+ bool LookupLongest(WordIndex word, float &prob, Node &node) const {
+ node = CombineWordHash(node, word);
+ typename Longest::ConstIterator found;
+ if (!longest.Find(node, found)) return false;
+ prob = found->GetValue().prob;
+ return true;
+ }
+
+ // Generate a node without necessarily checking that it actually exists.
+ // Optionally return false if it's known to not exist.
+ bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
+ assert(begin != end);
+ node = static_cast<Node>(*begin);
+ for (const WordIndex *i = begin + 1; i < end; ++i) {
+ node = CombineWordHash(node, *i);
+ }
+ return true;
+ }
+};
+
+// std::identity is an SGI extension :-(
+struct IdentityHash : public std::unary_function<uint64_t, size_t> {
+ size_t operator()(uint64_t arg) const { return static_cast<size_t>(arg); }
+};
+
+struct ProbingHashedSearch : public TemplateHashedSearch<
+ util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, ProbBackoff>, IdentityHash>,
+ util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, Prob>, IdentityHash> > {
+
+ static const ModelType kModelType = HASH_PROBING;
+};
+
+struct SortedHashedSearch : public TemplateHashedSearch<
+ util::SortedUniformMap<util::ByteAlignedPacking<uint64_t, ProbBackoff> >,
+ util::SortedUniformMap<util::ByteAlignedPacking<uint64_t, Prob> > > {
+
+ static const ModelType kModelType = HASH_SORTED;
+};
+
+} // namespace detail
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_SEARCH_HASHED__
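The idea above, in brief: an n-gram is identified purely by chaining
CombineWordHash over its words, so extending a match by one word is one
multiply-xor plus one table probe, and hash collisions are accepted as a rare
source of error rather than resolved against the original words. A sketch of
what FastMakeNode computes:

    // Same value FastMakeNode produces for the words in [words, words + n).
    uint64_t NodeFor(const lm::WordIndex *words, std::size_t n) {
      uint64_t node = static_cast<uint64_t>(words[0]);
      for (std::size_t i = 1; i < n; ++i)
        node = lm::ngram::detail::CombineWordHash(node, words[i]);
      return node;
    }
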
diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh
new file mode 100644
index 00000000..902f6ce6
--- /dev/null
+++ b/klm/lm/search_trie.hh
@@ -0,0 +1,83 @@
+#ifndef LM_SEARCH_TRIE__
+#define LM_SEARCH_TRIE__
+
+#include "lm/binary_format.hh"
+#include "lm/trie.hh"
+#include "lm/weights.hh"
+
+#include <assert.h>
+
+namespace lm {
+namespace ngram {
+class SortedVocabulary;
+namespace trie {
+
+struct TrieSearch {
+ typedef NodeRange Node;
+
+ typedef ::lm::ngram::trie::Unigram Unigram;
+ Unigram unigram;
+
+ typedef trie::BitPackedMiddle Middle;
+ std::vector<Middle> middle;
+
+ typedef trie::BitPackedLongest Longest;
+ Longest longest;
+
+ static const ModelType kModelType = TRIE_SORTED;
+
+ static std::size_t Size(const std::vector<uint64_t> &counts, const Config &/*config*/) {
+ std::size_t ret = Unigram::Size(counts[0]);
+ for (unsigned char i = 1; i < counts.size() - 1; ++i) {
+ ret += Middle::Size(counts[i], counts[0], counts[i+1]);
+ }
+ return ret + Longest::Size(counts.back(), counts[0]);
+ }
+
+ uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &/*config*/) {
+ unigram.Init(start);
+ start += Unigram::Size(counts[0]);
+ middle.resize(counts.size() - 2);
+ for (unsigned char i = 1; i < counts.size() - 1; ++i) {
+ middle[i-1].Init(start, counts[0], counts[i+1]);
+ start += Middle::Size(counts[i], counts[0], counts[i+1]);
+ }
+ longest.Init(start, counts[0]);
+ return start + Longest::Size(counts.back(), counts[0]);
+ }
+
+ void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab);
+
+ bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const {
+ return unigram.Find(word, prob, backoff, node);
+ }
+
+ bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const {
+ return mid.Find(word, prob, backoff, node);
+ }
+
+ bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const {
+ return mid.FindNoProb(word, backoff, node);
+ }
+
+ bool LookupLongest(WordIndex word, float &prob, const Node &node) const {
+ return longest.Find(word, prob, node);
+ }
+
+ bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
+ // TODO: don't decode prob.
+ assert(begin != end);
+ float ignored_prob, ignored_backoff;
+ LookupUnigram(*begin, ignored_prob, ignored_backoff, node);
+ for (const WordIndex *i = begin + 1; i < end; ++i) {
+ if (!LookupMiddleNoProb(middle[i - begin - 1], *i, ignored_backoff, node)) return false;
+ }
+ return true;
+ }
+};
+
+} // namespace trie
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_SEARCH_TRIE__
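Here Node is a NodeRange: the half-open block of next-order records whose
context is everything matched so far, so each Find narrows a range rather than
probing a hash table. A sketch of a full-order lookup under that reading
(FindFullOrder is hypothetical; word order follows the convention FullScore
uses, predicted word first, then context from nearest to farthest):

    // Returns false as soon as the walk falls out of the trie.
    bool FindFullOrder(const lm::ngram::trie::TrieSearch &s,
                       const lm::WordIndex *words, std::size_t n, float &prob) {
      float backoff;
      lm::ngram::trie::NodeRange node;
      if (!s.LookupUnigram(words[0], prob, backoff, node)) return false;
      for (std::size_t i = 1; i + 1 < n; ++i)
        if (!s.LookupMiddle(s.middle[i - 1], words[i], prob, backoff, node))
          return false;
      return n == 1 || s.LookupLongest(words[n - 1], prob, node);
    }
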
diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh
new file mode 100644
index 00000000..35dc2c96
--- /dev/null
+++ b/klm/lm/trie.hh
@@ -0,0 +1,129 @@
+#ifndef LM_TRIE__
+#define LM_TRIE__
+
+#include <inttypes.h>
+
+#include <cstddef>
+
+#include "lm/word_index.hh"
+#include "lm/weights.hh"
+
+namespace lm {
+namespace ngram {
+namespace trie {
+
+struct NodeRange {
+ uint64_t begin, end;
+};
+
+// TODO: if the number of unigrams is a concern, also bit pack these records.
+struct UnigramValue {
+ ProbBackoff weights;
+ uint64_t next;
+ uint64_t Next() const { return next; }
+};
+
+class Unigram {
+ public:
+ Unigram() {}
+
+ void Init(void *start) {
+ unigram_ = static_cast<UnigramValue*>(start);
+ }
+
+ static std::size_t Size(uint64_t count) {
+ // +1 in case unknown doesn't appear. +1 for the final next.
+ return (count + 2) * sizeof(UnigramValue);
+ }
+
+ const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index].weights; }
+
+ ProbBackoff &Unknown() { return unigram_[0].weights; }
+
+ UnigramValue *Raw() {
+ return unigram_;
+ }
+
+ void LoadedBinary() {}
+
+ bool Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const {
+ UnigramValue *val = unigram_ + word;
+ prob = val->weights.prob;
+ backoff = val->weights.backoff;
+ next.begin = val->next;
+ next.end = (val+1)->next;
+ return true;
+ }
+
+ private:
+ UnigramValue *unigram_;
+};
+
+class BitPacked {
+ public:
+ BitPacked() {}
+
+ uint64_t InsertIndex() const {
+ return insert_index_;
+ }
+
+ void LoadedBinary() {}
+
+ protected:
+ static std::size_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits);
+
+ void BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits);
+
+ uint8_t word_bits_, prob_bits_;
+ uint8_t total_bits_;
+ uint64_t word_mask_;
+
+ uint8_t *base_;
+
+ uint64_t insert_index_;
+};
+
+class BitPackedMiddle : public BitPacked {
+ public:
+ BitPackedMiddle() {}
+
+ static std::size_t Size(uint64_t entries, uint64_t max_vocab, uint64_t max_next);
+
+ void Init(void *base, uint64_t max_vocab, uint64_t max_next);
+
+ void Insert(WordIndex word, float prob, float backoff, uint64_t next);
+
+ bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const;
+
+ bool FindNoProb(WordIndex word, float &backoff, NodeRange &range) const;
+
+ void FinishedLoading(uint64_t next_end);
+
+ private:
+ uint8_t backoff_bits_, next_bits_;
+ uint64_t next_mask_;
+};
+
+
+class BitPackedLongest : public BitPacked {
+ public:
+ BitPackedLongest() {}
+
+ static std::size_t Size(uint64_t entries, uint64_t max_vocab) {
+ return BaseSize(entries, max_vocab, 0);
+ }
+
+ void Init(void *base, uint64_t max_vocab) {
+ return BaseInit(base, max_vocab, 0);
+ }
+
+ void Insert(WordIndex word, float prob);
+
+ bool Find(WordIndex word, float &prob, const NodeRange &node) const;
+};
+
+} // namespace trie
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_TRIE__
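For orientation: BitPacked stores every record of one order at the same fixed
bit width (word index and prob, plus backoff and a next pointer for middles),
so record i begins at bit offset i * total_bits_ and fields come out with a
shift and a mask (word_mask_ and next_mask_ above). A sketch of the extraction
step under that assumption; the real code lives in trie.cc, and this version
ignores fields that straddle past the 64-bit window and assumes little-endian:

    #include <cstring>
    #include <inttypes.h>

    uint64_t ReadField(const uint8_t *base, uint64_t bit_off, uint64_t mask) {
      uint64_t window;
      // Unaligned-safe 64-bit read starting at the byte containing bit_off.
      std::memcpy(&window, base + (bit_off >> 3), sizeof(window));
      return (window >> (bit_off & 7)) & mask;
    }
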
diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh
new file mode 100644
index 00000000..bb5d789b
--- /dev/null
+++ b/klm/lm/vocab.hh
@@ -0,0 +1,138 @@
+#ifndef LM_VOCAB__
+#define LM_VOCAB__
+
+#include "lm/enumerate_vocab.hh"
+#include "lm/virtual_interface.hh"
+#include "util/key_value_packing.hh"
+#include "util/probing_hash_table.hh"
+#include "util/sorted_uniform.hh"
+#include "util/string_piece.hh"
+
+#include <string>
+#include <vector>
+
+namespace lm {
+class ProbBackoff;
+
+namespace ngram {
+class Config;
+class EnumerateVocab;
+
+namespace detail {
+uint64_t HashForVocab(const char *str, std::size_t len);
+inline uint64_t HashForVocab(const StringPiece &str) {
+ return HashForVocab(str.data(), str.length());
+}
+} // namespace detail
+
+class WriteWordsWrapper : public EnumerateVocab {
+ public:
+ WriteWordsWrapper(EnumerateVocab *inner, int fd);
+
+ ~WriteWordsWrapper();
+
+ void Add(WordIndex index, const StringPiece &str);
+
+ private:
+ EnumerateVocab *inner_;
+ int fd_;
+};
+
+// Vocabulary based on sorted uniform find storing only uint64_t values and using their offsets as indices.
+class SortedVocabulary : public base::Vocabulary {
+ private:
+ // Sorted uniform requires a GetKey function.
+ struct Entry {
+ uint64_t GetKey() const { return key; }
+ uint64_t key;
+ bool operator<(const Entry &other) const {
+ return key < other.key;
+ }
+ };
+
+ public:
+ SortedVocabulary();
+
+ WordIndex Index(const StringPiece &str) const {
+ const Entry *found;
+ if (util::SortedUniformFind<const Entry *, uint64_t>(begin_, end_, detail::HashForVocab(str), found)) {
+ return found - begin_ + 1; // +1 because <unk> is 0 and does not appear in the lookup table.
+ } else {
+ return 0;
+ }
+ }
+
+ // Ignores second argument for consistency with probing hash which has a float here.
+ static size_t Size(std::size_t entries, const Config &config);
+
+ // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
+ void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config);
+
+ void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);
+
+ WordIndex Insert(const StringPiece &str);
+
+ // Reorders reorder_vocab so that the IDs are sorted.
+ void FinishedLoading(ProbBackoff *reorder_vocab);
+
+ bool SawUnk() const { return saw_unk_; }
+
+ void LoadedBinary(int fd, EnumerateVocab *to);
+
+ private:
+ Entry *begin_, *end_;
+
+ bool saw_unk_;
+
+ EnumerateVocab *enumerate_;
+
+ // Actual strings. Used only when loading from ARPA and enumerate_ != NULL
+ std::vector<std::string> strings_to_enumerate_;
+};
+
+// Vocabulary storing a map from uint64_t to WordIndex.
+class ProbingVocabulary : public base::Vocabulary {
+ public:
+ ProbingVocabulary();
+
+ WordIndex Index(const StringPiece &str) const {
+ Lookup::ConstIterator i;
+ return lookup_.Find(detail::HashForVocab(str), i) ? i->GetValue() : 0;
+ }
+
+ static size_t Size(std::size_t entries, const Config &config);
+
+ // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
+ void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config);
+
+ void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);
+
+ WordIndex Insert(const StringPiece &str);
+
+ void FinishedLoading(ProbBackoff *reorder_vocab);
+
+ bool SawUnk() const { return saw_unk_; }
+
+ void LoadedBinary(int fd, EnumerateVocab *to);
+
+ private:
+ // std::identity is an SGI extension :-(
+ struct IdentityHash : public std::unary_function<uint64_t, std::size_t> {
+ std::size_t operator()(uint64_t arg) const { return static_cast<std::size_t>(arg); }
+ };
+
+ typedef util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, WordIndex>, IdentityHash> Lookup;
+
+ Lookup lookup_;
+
+ WordIndex available_;
+
+ bool saw_unk_;
+
+ EnumerateVocab *enumerate_;
+};
+
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_VOCAB__
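One contract both vocabularies share: Index never fails, because 0 is reserved
for <unk>, so any string missing from the table scores as the unknown word
(and, since only 64-bit hashes are stored, a colliding OOV string can silently
map to a real word). A short sketch (ReportOOVs is hypothetical; GetVocabulary
comes from the ModelFacade base):

    #include <iostream>
    #include <string>
    #include <vector>
    #include "lm/model.hh"

    void ReportOOVs(const lm::ngram::Model &model,
                    const std::vector<std::string> &words) {
      const lm::ngram::Vocabulary &vocab = model.GetVocabulary();
      for (std::size_t i = 0; i < words.size(); ++i) {
        // Index returns 0 for anything not in the table; that is <unk>.
        if (vocab.Index(words[i]) == 0)
          std::cerr << "OOV: " << words[i] << '\n';
      }
    }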