Diffstat (limited to 'klm/lm')
-rw-r--r-- | klm/lm/enumerate_vocab.hh |  29
-rw-r--r-- | klm/lm/model.hh           | 126
-rw-r--r-- | klm/lm/search_hashed.hh   | 156
-rw-r--r-- | klm/lm/search_trie.hh     |  83
-rw-r--r-- | klm/lm/trie.hh            | 129
-rw-r--r-- | klm/lm/vocab.hh           | 138
6 files changed, 661 insertions, 0 deletions
diff --git a/klm/lm/enumerate_vocab.hh b/klm/lm/enumerate_vocab.hh
new file mode 100644
index 00000000..7a2f7d12
--- /dev/null
+++ b/klm/lm/enumerate_vocab.hh
@@ -0,0 +1,29 @@
+#ifndef LM_ENUMERATE_VOCAB__
+#define LM_ENUMERATE_VOCAB__
+
+#include "lm/word_index.hh"
+#include "util/string_piece.hh"
+
+namespace lm {
+namespace ngram {
+
+/* If you need the actual strings in the vocabulary, inherit from this class
+ * and implement Add.  Then put a pointer in Config.enumerate_vocab.
+ * Add is called once per vocabulary word.  index starts at 0 and increases
+ * by 1 each time.
+ */
+class EnumerateVocab {
+  public:
+    virtual ~EnumerateVocab() {}
+
+    virtual void Add(WordIndex index, const StringPiece &str) = 0;
+
+  protected:
+    EnumerateVocab() {}
+};
+
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_ENUMERATE_VOCAB__
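A minimal sketch of the intended usage, assuming the Config.enumerate_vocab hook described in the comment above; CollectVocab is a name invented here for illustration:

    #include <string>
    #include <vector>

    #include "lm/enumerate_vocab.hh"

    // Collects the vocabulary strings as a model loads.
    class CollectVocab : public lm::ngram::EnumerateVocab {
      public:
        // Called once per vocabulary word, with index increasing from 0,
        // so index should always equal words_.size() here.
        void Add(lm::WordIndex index, const StringPiece &str) {
          words_.push_back(std::string(str.data(), str.length()));
        }
        const std::vector<std::string> &Words() const { return words_; }
      private:
        std::vector<std::string> words_;
    };

To wire it up, set config.enumerate_vocab to point at a CollectVocab instance before constructing a model with that Config.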
diff --git a/klm/lm/model.hh b/klm/lm/model.hh
new file mode 100644
index 00000000..e0eeee17
--- /dev/null
+++ b/klm/lm/model.hh
@@ -0,0 +1,126 @@
+#ifndef LM_MODEL__
+#define LM_MODEL__
+
+#include "lm/binary_format.hh"
+#include "lm/config.hh"
+#include "lm/facade.hh"
+#include "lm/search_hashed.hh"
+#include "lm/search_trie.hh"
+#include "lm/vocab.hh"
+#include "lm/weights.hh"
+
+#include <algorithm>
+#include <vector>
+
+namespace util { class FilePiece; }
+
+namespace lm {
+namespace ngram {
+
+// If you need higher order, change this and recompile.
+// Having this limit means that State can be
+// (kMaxOrder - 1) * sizeof(float) bytes instead of
+// sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead
+const std::size_t kMaxOrder = 6;
+
+// This is a POD.
+class State {
+  public:
+    bool operator==(const State &other) const {
+      if (valid_length_ != other.valid_length_) return false;
+      const WordIndex *end = history_ + valid_length_;
+      for (const WordIndex *first = history_, *second = other.history_;
+          first != end; ++first, ++second) {
+        if (*first != *second) return false;
+      }
+      // If the histories are equal, so are the backoffs.
+      return true;
+    }
+
+    // You shouldn't need to touch anything below this line, but the members
+    // are public so State will qualify as a POD.
+    // This order minimizes total size of the struct if WordIndex is 64 bit,
+    // float is 32 bit, and alignment of 64 bit integers is 64 bit.
+    WordIndex history_[kMaxOrder - 1];
+    float backoff_[kMaxOrder - 1];
+    unsigned char valid_length_;
+};
+
+size_t hash_value(const State &state);
+
+namespace detail {
+
+// Should return the same results as SRI.
+// Why VocabularyT instead of just Vocabulary?  ModelFacade defines Vocabulary.
+template <class Search, class VocabularyT> class GenericModel : public base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> {
+  private:
+    typedef base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> P;
+  public:
+    // Get the size of memory that will be mapped given ngram counts.  This
+    // does not include small non-mapped control structures, such as this
+    // class itself.
+    static size_t Size(const std::vector<uint64_t> &counts, const Config &config = Config());
+
+    GenericModel(const char *file, const Config &config = Config());
+
+    FullScoreReturn FullScore(const State &in_state, const WordIndex new_word, State &out_state) const;
+
+    /* Slower call without in_state.  Don't use this if you can avoid it.
+     * This is mostly a hack for Hieu to integrate it into Moses, which
+     * sometimes forgets LM state (i.e. it doesn't store it with the phrase).
+     * Sigh.  The context indices should be in an array.
+     * If context_rbegin != context_rend then *context_rbegin is the word
+     * before new_word.
+     */
+    FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const;
+
+    /* Get the state for a context.  Don't use this if you can avoid it.  Use
+     * BeginSentenceState or EmptyContextState and extend from those.  If
+     * you're only going to use this state to call FullScore once, use
+     * FullScoreForgotState. */
+    void GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const;
+
+  private:
+    friend void LoadLM<>(const char *file, const Config &config, GenericModel<Search, VocabularyT> &to);
+
+    float SlowBackoffLookup(const WordIndex *const context_rbegin, const WordIndex *const context_rend, unsigned char start) const;
+
+    FullScoreReturn ScoreExceptBackoff(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, unsigned char &backoff_start, State &out_state) const;
+
+    // Appears after Size in the cc file.
+    void SetupMemory(void *start, const std::vector<uint64_t> &counts, const Config &config);
+
+    void InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd);
+
+    void InitializeFromARPA(const char *file, util::FilePiece &f, void *start, const Parameters &params, const Config &config);
+
+    Backing &MutableBacking() { return backing_; }
+
+    static const ModelType kModelType = Search::kModelType;
+
+    Backing backing_;
+
+    VocabularyT vocab_;
+
+    typedef typename Search::Unigram Unigram;
+    typedef typename Search::Middle Middle;
+    typedef typename Search::Longest Longest;
+
+    Search search_;
+};
+
+} // namespace detail
+
+// These must also be instantiated in the cc file.
+typedef ::lm::ngram::ProbingVocabulary Vocabulary;
+typedef detail::GenericModel<detail::ProbingHashedSearch, Vocabulary> ProbingModel;
+// Default implementation.  No real reason for it to be the default.
+typedef ProbingModel Model;
+
+typedef ::lm::ngram::SortedVocabulary SortedVocabulary;
+typedef detail::GenericModel<detail::SortedHashedSearch, SortedVocabulary> SortedModel;
+
+typedef detail::GenericModel<trie::TrieSearch, SortedVocabulary> TrieModel;
+
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_MODEL__
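For orientation, a hedged sketch of how the State-threading API above is meant to be driven.  BeginSentenceState() and GetVocabulary() are assumed to come from the ModelFacade base (lm/facade.hh), and FullScoreReturn::prob is assumed to carry the log10 probability from the ARPA file:

    #include <string>
    #include <vector>

    #include "lm/model.hh"

    // Score a tokenized sentence by threading State through FullScore calls.
    float ScoreSentence(const lm::ngram::Model &model, const std::vector<std::string> &words) {
      lm::ngram::State state(model.BeginSentenceState()), out_state;
      float total = 0.0;
      for (std::vector<std::string>::const_iterator i = words.begin(); i != words.end(); ++i) {
        // Out-of-vocabulary words map to index 0 (<unk>).
        total += model.FullScore(state, model.GetVocabulary().Index(*i), out_state).prob;
        state = out_state;  // carry the (history, backoff) state forward
      }
      return total;
    }

Keeping out_state and copying it back is the cheap path; FullScoreForgotState exists only for callers that discarded their state.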
diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh
new file mode 100644
index 00000000..1ee2b9e9
--- /dev/null
+++ b/klm/lm/search_hashed.hh
@@ -0,0 +1,156 @@
+#ifndef LM_SEARCH_HASHED__
+#define LM_SEARCH_HASHED__
+
+#include "lm/binary_format.hh"
+#include "lm/config.hh"
+#include "lm/read_arpa.hh"
+#include "lm/weights.hh"
+
+#include "util/key_value_packing.hh"
+#include "util/probing_hash_table.hh"
+#include "util/sorted_uniform.hh"
+
+#include <algorithm>
+#include <vector>
+
+namespace util { class FilePiece; }
+
+namespace lm {
+namespace ngram {
+namespace detail {
+
+inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) {
+  uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast<uint64_t>(next) * 17894857484156487943ULL);
+  return ret;
+}
+
+struct HashedSearch {
+  typedef uint64_t Node;
+
+  class Unigram {
+    public:
+      Unigram() {}
+
+      Unigram(void *start, std::size_t /*allocated*/) : unigram_(static_cast<ProbBackoff*>(start)) {}
+
+      static std::size_t Size(uint64_t count) {
+        return (count + 1) * sizeof(ProbBackoff); // +1 to hallucinate <unk>
+      }
+
+      const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index]; }
+
+      ProbBackoff &Unknown() { return unigram_[0]; }
+
+      void LoadedBinary() {}
+
+      // For building.
+      ProbBackoff *Raw() { return unigram_; }
+
+    private:
+      ProbBackoff *unigram_;
+  };
+
+  Unigram unigram;
+
+  bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const {
+    const ProbBackoff &entry = unigram.Lookup(word);
+    prob = entry.prob;
+    backoff = entry.backoff;
+    next = static_cast<Node>(word);
+    return true;
+  }
+};
+
+template <class MiddleT, class LongestT> struct TemplateHashedSearch : public HashedSearch {
+  typedef MiddleT Middle;
+  std::vector<Middle> middle;
+
+  typedef LongestT Longest;
+  Longest longest;
+
+  static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) {
+    std::size_t ret = Unigram::Size(counts[0]);
+    for (unsigned char n = 1; n < counts.size() - 1; ++n) {
+      ret += Middle::Size(counts[n], config.probing_multiplier);
+    }
+    return ret + Longest::Size(counts.back(), config.probing_multiplier);
+  }
+
+  uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
+    std::size_t allocated = Unigram::Size(counts[0]);
+    unigram = Unigram(start, allocated);
+    start += allocated;
+    for (unsigned int n = 2; n < counts.size(); ++n) {
+      allocated = Middle::Size(counts[n - 1], config.probing_multiplier);
+      middle.push_back(Middle(start, allocated));
+      start += allocated;
+    }
+    allocated = Longest::Size(counts.back(), config.probing_multiplier);
+    longest = Longest(start, allocated);
+    start += allocated;
+    return start;
+  }
+
+  template <class Voc> void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab);
+
+  bool LookupMiddle(const Middle &middle, WordIndex word, float &prob, float &backoff, Node &node) const {
+    node = CombineWordHash(node, word);
+    typename Middle::ConstIterator found;
+    if (!middle.Find(node, found)) return false;
+    prob = found->GetValue().prob;
+    backoff = found->GetValue().backoff;
+    return true;
+  }
+
+  bool LookupMiddleNoProb(const Middle &middle, WordIndex word, float &backoff, Node &node) const {
+    node = CombineWordHash(node, word);
+    typename Middle::ConstIterator found;
+    if (!middle.Find(node, found)) return false;
+    backoff = found->GetValue().backoff;
+    return true;
+  }
+
+  bool LookupLongest(WordIndex word, float &prob, Node &node) const {
+    node = CombineWordHash(node, word);
+    typename Longest::ConstIterator found;
+    if (!longest.Find(node, found)) return false;
+    prob = found->GetValue().prob;
+    return true;
+  }
+
+  // Generate a node without necessarily checking that it actually exists.
+  // Optionally return false if it's known not to exist.
+  bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
+    assert(begin != end);
+    node = static_cast<Node>(*begin);
+    for (const WordIndex *i = begin + 1; i < end; ++i) {
+      node = CombineWordHash(node, *i);
+    }
+    return true;
+  }
+};
+
+// std::identity is an SGI extension :-(
+struct IdentityHash : public std::unary_function<uint64_t, size_t> {
+  size_t operator()(uint64_t arg) const { return static_cast<size_t>(arg); }
+};
+
+struct ProbingHashedSearch : public TemplateHashedSearch<
+    util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, ProbBackoff>, IdentityHash>,
+    util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, Prob>, IdentityHash> > {
+
+  static const ModelType kModelType = HASH_PROBING;
+};
+
+struct SortedHashedSearch : public TemplateHashedSearch<
+    util::SortedUniformMap<util::ByteAlignedPacking<uint64_t, ProbBackoff> >,
+    util::SortedUniformMap<util::ByteAlignedPacking<uint64_t, Prob> > > {
+
+  static const ModelType kModelType = HASH_SORTED;
+};
+
+} // namespace detail
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_SEARCH_HASHED__
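To make the key scheme concrete: an n-gram's 64-bit hash key (the Node above) is just the left fold of CombineWordHash over its words, so extending a context by one word costs a single multiply-xor step instead of rehashing the whole n-gram.  A small illustration:

    #include "lm/search_hashed.hh"

    // The trigram key equals what FastMakeNode would produce for {w1, w2, w3}.
    uint64_t TrigramNode(lm::WordIndex w1, lm::WordIndex w2, lm::WordIndex w3) {
      uint64_t node = static_cast<uint64_t>(w1);            // unigram node is the word itself
      node = lm::ngram::detail::CombineWordHash(node, w2);  // bigram key
      return lm::ngram::detail::CombineWordHash(node, w3);  // trigram key
    }

This is why LookupMiddle and LookupLongest each take the running node by reference: the lookups for successive orders share all but the last combining step.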
diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh
new file mode 100644
index 00000000..902f6ce6
--- /dev/null
+++ b/klm/lm/search_trie.hh
@@ -0,0 +1,83 @@
+#ifndef LM_SEARCH_TRIE__
+#define LM_SEARCH_TRIE__
+
+#include "lm/binary_format.hh"
+#include "lm/trie.hh"
+#include "lm/weights.hh"
+
+#include <assert.h>
+
+namespace lm {
+namespace ngram {
+class SortedVocabulary;
+namespace trie {
+
+struct TrieSearch {
+  typedef NodeRange Node;
+
+  typedef ::lm::ngram::trie::Unigram Unigram;
+  Unigram unigram;
+
+  typedef trie::BitPackedMiddle Middle;
+  std::vector<Middle> middle;
+
+  typedef trie::BitPackedLongest Longest;
+  Longest longest;
+
+  static const ModelType kModelType = TRIE_SORTED;
+
+  static std::size_t Size(const std::vector<uint64_t> &counts, const Config &/*config*/) {
+    std::size_t ret = Unigram::Size(counts[0]);
+    for (unsigned char i = 1; i < counts.size() - 1; ++i) {
+      ret += Middle::Size(counts[i], counts[0], counts[i+1]);
+    }
+    return ret + Longest::Size(counts.back(), counts[0]);
+  }
+
+  uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &/*config*/) {
+    unigram.Init(start);
+    start += Unigram::Size(counts[0]);
+    middle.resize(counts.size() - 2);
+    for (unsigned char i = 1; i < counts.size() - 1; ++i) {
+      middle[i-1].Init(start, counts[0], counts[i+1]);
+      start += Middle::Size(counts[i], counts[0], counts[i+1]);
+    }
+    longest.Init(start, counts[0]);
+    return start + Longest::Size(counts.back(), counts[0]);
+  }
+
+  void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab);
+
+  bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const {
+    return unigram.Find(word, prob, backoff, node);
+  }
+
+  bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const {
+    return mid.Find(word, prob, backoff, node);
+  }
+
+  bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const {
+    return mid.FindNoProb(word, backoff, node);
+  }
+
+  bool LookupLongest(WordIndex word, float &prob, const Node &node) const {
+    return longest.Find(word, prob, node);
+  }
+
+  bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
+    // TODO: don't decode the probability.
+    assert(begin != end);
+    float ignored_prob, ignored_backoff;
+    LookupUnigram(*begin, ignored_prob, ignored_backoff, node);
+    for (const WordIndex *i = begin + 1; i < end; ++i) {
+      if (!LookupMiddleNoProb(middle[i - begin - 1], *i, ignored_backoff, node)) return false;
+    }
+    return true;
+  }
+};
+
+} // namespace trie
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_SEARCH_TRIE__
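A hedged sketch of how these lookups compose during scoring, following the convention in model.hh that the newest word is looked up first and the context is then applied in reverse (nearest word first).  Each successful lookup narrows the NodeRange to the block of records one order deeper that extend the current n-gram:

    #include "lm/search_trie.hh"

    // Sketch only; variable names are illustrative.  Returns the bigram
    // probability of new_word given the word immediately before it.
    float BigramProb(const lm::ngram::trie::TrieSearch &search,
                     lm::WordIndex new_word, lm::WordIndex prev_word) {
      lm::ngram::trie::NodeRange node;
      float prob, backoff;
      search.LookupUnigram(new_word, prob, backoff, node);  // node spans possible bigram extensions
      // On success, prob is overwritten with the bigram probability and node
      // now addresses possible trigram extensions.
      search.LookupMiddle(search.middle[0], prev_word, prob, backoff, node);
      return prob;
    }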
diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh
new file mode 100644
index 00000000..35dc2c96
--- /dev/null
+++ b/klm/lm/trie.hh
@@ -0,0 +1,129 @@
+#ifndef LM_TRIE__
+#define LM_TRIE__
+
+#include <inttypes.h>
+
+#include <cstddef>
+
+#include "lm/word_index.hh"
+#include "lm/weights.hh"
+
+namespace lm {
+namespace ngram {
+namespace trie {
+
+struct NodeRange {
+  uint64_t begin, end;
+};
+
+// TODO: if the number of unigrams is a concern, also bit pack these records.
+struct UnigramValue {
+  ProbBackoff weights;
+  uint64_t next;
+  uint64_t Next() const { return next; }
+};
+
+class Unigram {
+  public:
+    Unigram() {}
+
+    void Init(void *start) {
+      unigram_ = static_cast<UnigramValue*>(start);
+    }
+
+    static std::size_t Size(uint64_t count) {
+      // +1 in case unknown doesn't appear.  +1 for the final next.
+      return (count + 2) * sizeof(UnigramValue);
+    }
+
+    const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index].weights; }
+
+    ProbBackoff &Unknown() { return unigram_[0].weights; }
+
+    UnigramValue *Raw() {
+      return unigram_;
+    }
+
+    void LoadedBinary() {}
+
+    bool Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const {
+      UnigramValue *val = unigram_ + word;
+      prob = val->weights.prob;
+      backoff = val->weights.backoff;
+      next.begin = val->next;
+      next.end = (val+1)->next;
+      return true;
+    }
+
+  private:
+    UnigramValue *unigram_;
+};
+
+class BitPacked {
+  public:
+    BitPacked() {}
+
+    uint64_t InsertIndex() const {
+      return insert_index_;
+    }
+
+    void LoadedBinary() {}
+
+  protected:
+    static std::size_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits);
+
+    void BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits);
+
+    uint8_t word_bits_, prob_bits_;
+    uint8_t total_bits_;
+    uint64_t word_mask_;
+
+    uint8_t *base_;
+
+    uint64_t insert_index_;
+};
+
+class BitPackedMiddle : public BitPacked {
+  public:
+    BitPackedMiddle() {}
+
+    static std::size_t Size(uint64_t entries, uint64_t max_vocab, uint64_t max_next);
+
+    void Init(void *base, uint64_t max_vocab, uint64_t max_next);
+
+    void Insert(WordIndex word, float prob, float backoff, uint64_t next);
+
+    bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const;
+
+    bool FindNoProb(WordIndex word, float &backoff, NodeRange &range) const;
+
+    void FinishedLoading(uint64_t next_end);
+
+  private:
+    uint8_t backoff_bits_, next_bits_;
+    uint64_t next_mask_;
+};
+
+class BitPackedLongest : public BitPacked {
+  public:
+    BitPackedLongest() {}
+
+    static std::size_t Size(uint64_t entries, uint64_t max_vocab) {
+      return BaseSize(entries, max_vocab, 0);
+    }
+
+    void Init(void *base, uint64_t max_vocab) {
+      return BaseInit(base, max_vocab, 0);
+    }
+
+    void Insert(WordIndex word, float prob);
+
+    bool Find(WordIndex word, float &prob, const NodeRange &node) const;
+};
+
+} // namespace trie
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_TRIE__
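The BitPacked hierarchy stores fixed-width records at bit rather than byte granularity: each record packs a word index in just enough bits for the vocabulary, plus probability, backoff, and next-pointer fields.  The real size computation lives in trie.cc; the following is only back-of-the-envelope arithmetic, assuming a full 32-bit float for the probability:

    #include <stdint.h>
    #include <cstddef>

    // Bits needed to represent values in [0, max_value].
    uint8_t RequiredBits(uint64_t max_value) {
      if (!max_value) return 0;
      uint8_t ret = 1;
      while (max_value >>= 1) ++ret;
      return ret;
    }

    // Rough analogue of BitPacked::BaseSize; the real code may pad further
    // so that any record can be read with aligned word accesses.
    std::size_t ApproxBaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits) {
      uint8_t total_bits = RequiredBits(max_vocab - 1) + 32 /* float prob, assumed */ + remaining_bits;
      return (entries * static_cast<std::size_t>(total_bits) + 7) / 8;  // round up to whole bytes
    }

With a 100k-word vocabulary that is 17 bits per word index, so a longest-order record (no backoff, no next pointer) costs about 49 bits instead of the 12-plus bytes a struct of a uint32_t and two floats would take.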
diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh
new file mode 100644
index 00000000..bb5d789b
--- /dev/null
+++ b/klm/lm/vocab.hh
@@ -0,0 +1,138 @@
+#ifndef LM_VOCAB__
+#define LM_VOCAB__
+
+#include "lm/enumerate_vocab.hh"
+#include "lm/virtual_interface.hh"
+#include "util/key_value_packing.hh"
+#include "util/probing_hash_table.hh"
+#include "util/sorted_uniform.hh"
+#include "util/string_piece.hh"
+
+#include <string>
+#include <vector>
+
+namespace lm {
+class ProbBackoff;
+
+namespace ngram {
+class Config;
+class EnumerateVocab;
+
+namespace detail {
+uint64_t HashForVocab(const char *str, std::size_t len);
+inline uint64_t HashForVocab(const StringPiece &str) {
+  return HashForVocab(str.data(), str.length());
+}
+} // namespace detail
+
+class WriteWordsWrapper : public EnumerateVocab {
+  public:
+    WriteWordsWrapper(EnumerateVocab *inner, int fd);
+
+    ~WriteWordsWrapper();
+
+    void Add(WordIndex index, const StringPiece &str);
+
+  private:
+    EnumerateVocab *inner_;
+    int fd_;
+};
+
+// Vocabulary based on sorted uniform find, storing only uint64_t values and
+// using their offsets as indices.
+class SortedVocabulary : public base::Vocabulary {
+  private:
+    // Sorted uniform requires a GetKey function.
+    struct Entry {
+      uint64_t GetKey() const { return key; }
+      uint64_t key;
+      bool operator<(const Entry &other) const {
+        return key < other.key;
+      }
+    };
+
+  public:
+    SortedVocabulary();
+
+    WordIndex Index(const StringPiece &str) const {
+      const Entry *found;
+      if (util::SortedUniformFind<const Entry *, uint64_t>(begin_, end_, detail::HashForVocab(str), found)) {
+        return found - begin_ + 1; // +1 because <unk> is 0 and does not appear in the lookup table.
+      } else {
+        return 0;
+      }
+    }
+
+    // Ignores second argument for consistency with probing hash, which has a float here.
+    static size_t Size(std::size_t entries, const Config &config);
+
+    // Everything else is for populating.  I'm too lazy to hide and friend
+    // these, but you'll only get a const reference anyway.
+    void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config);
+
+    void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);
+
+    WordIndex Insert(const StringPiece &str);
+
+    // Reorders reorder_vocab so that the IDs are sorted.
+    void FinishedLoading(ProbBackoff *reorder_vocab);
+
+    bool SawUnk() const { return saw_unk_; }
+
+    void LoadedBinary(int fd, EnumerateVocab *to);
+
+  private:
+    Entry *begin_, *end_;
+
+    bool saw_unk_;
+
+    EnumerateVocab *enumerate_;
+
+    // Actual strings.  Used only when loading from ARPA and enumerate_ != NULL.
+    std::vector<std::string> strings_to_enumerate_;
+};
+
+// Vocabulary storing a map from uint64_t to WordIndex.
+class ProbingVocabulary : public base::Vocabulary {
+  public:
+    ProbingVocabulary();
+
+    WordIndex Index(const StringPiece &str) const {
+      Lookup::ConstIterator i;
+      return lookup_.Find(detail::HashForVocab(str), i) ? i->GetValue() : 0;
+    }
+
+    static size_t Size(std::size_t entries, const Config &config);
+
+    // Everything else is for populating.  I'm too lazy to hide and friend
+    // these, but you'll only get a const reference anyway.
+    void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config);
+
+    void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);
+
+    WordIndex Insert(const StringPiece &str);
+
+    void FinishedLoading(ProbBackoff *reorder_vocab);
+
+    bool SawUnk() const { return saw_unk_; }
+
+    void LoadedBinary(int fd, EnumerateVocab *to);
+
+  private:
+    // std::identity is an SGI extension :-(
+    struct IdentityHash : public std::unary_function<uint64_t, std::size_t> {
+      std::size_t operator()(uint64_t arg) const { return static_cast<std::size_t>(arg); }
+    };
+
+    typedef util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, WordIndex>, IdentityHash> Lookup;
+
+    Lookup lookup_;
+
+    WordIndex available_;
+
+    bool saw_unk_;
+
+    EnumerateVocab *enumerate_;
+};
+
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_VOCAB__
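A short usage sketch of the lookup contract shared by both vocabulary classes: a string is resolved via its 64-bit hash, and index 0 is reserved for <unk>, so any word absent from the model maps to the unknown word.  "zzyzzva" below stands in for an arbitrary word assumed not to be in the vocabulary:

    #include <assert.h>

    #include "lm/vocab.hh"

    void CheckLookup(const lm::ngram::SortedVocabulary &vocab) {
      // Found: position in the sorted key array, plus 1 to skip <unk>.
      lm::WordIndex known = vocab.Index(StringPiece("the"));
      // Assumed absent: returns 0, the <unk> index.
      lm::WordIndex oov = vocab.Index(StringPiece("zzyzzva"));
      assert(oov == 0);
      (void)known;
    }

The SortedVocabulary scheme works because word IDs are assigned from positions in the hash-sorted array; the cost is that FinishedLoading must reorder the unigram weights to match, which is why it takes the ProbBackoff array.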