diff options
Diffstat (limited to 'klm/lm/vocab.hh')
-rw-r--r-- | klm/lm/vocab.hh | 90 |
1 files changed, 76 insertions, 14 deletions
diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index 074b74d8..d6ae07b8 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -1,9 +1,11 @@ -#ifndef LM_VOCAB__ -#define LM_VOCAB__ +#ifndef LM_VOCAB_H +#define LM_VOCAB_H #include "lm/enumerate_vocab.hh" #include "lm/lm_exception.hh" #include "lm/virtual_interface.hh" +#include "util/fake_ofstream.hh" +#include "util/murmur_hash.hh" #include "util/pool.hh" #include "util/probing_hash_table.hh" #include "util/sorted_uniform.hh" @@ -104,17 +106,16 @@ class SortedVocabulary : public base::Vocabulary { #pragma pack(push) #pragma pack(4) -struct ProbingVocabuaryEntry { +struct ProbingVocabularyEntry { uint64_t key; WordIndex value; typedef uint64_t Key; - uint64_t GetKey() const { - return key; - } + uint64_t GetKey() const { return key; } + void SetKey(uint64_t to) { key = to; } - static ProbingVocabuaryEntry Make(uint64_t key, WordIndex value) { - ProbingVocabuaryEntry ret; + static ProbingVocabularyEntry Make(uint64_t key, WordIndex value) { + ProbingVocabularyEntry ret; ret.key = key; ret.value = value; return ret; @@ -132,13 +133,18 @@ class ProbingVocabulary : public base::Vocabulary { return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0; } + static uint64_t Size(uint64_t entries, float probing_multiplier); + // This just unwraps Config to get the probing_multiplier. static uint64_t Size(uint64_t entries, const Config &config); // Vocab words are [0, Bound()). WordIndex Bound() const { return bound_; } // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. - void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config); + void SetupMemory(void *start, std::size_t allocated); + void SetupMemory(void *start, std::size_t allocated, std::size_t /*entries*/, const Config &/*config*/) { + SetupMemory(start, allocated); + } void Relocate(void *new_start); @@ -147,8 +153,9 @@ class ProbingVocabulary : public base::Vocabulary { WordIndex Insert(const StringPiece &str); template <class Weights> void FinishedLoading(Weights * /*reorder_vocab*/) { - InternalFinishedLoading(); + FinishedLoading(); } + void FinishedLoading(); std::size_t UnkCountChangePadding() const { return 0; } @@ -157,9 +164,7 @@ class ProbingVocabulary : public base::Vocabulary { void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset); private: - void InternalFinishedLoading(); - - typedef util::ProbingHashTable<ProbingVocabuaryEntry, util::IdentityHash> Lookup; + typedef util::ProbingHashTable<ProbingVocabularyEntry, util::IdentityHash> Lookup; Lookup lookup_; @@ -181,7 +186,64 @@ template <class Vocab> void CheckSpecials(const Config &config, const Vocab &voc if (vocab.EndSentence() == vocab.NotFound()) MissingSentenceMarker(config, "</s>"); } +class WriteUniqueWords { + public: + explicit WriteUniqueWords(int fd) : word_list_(fd) {} + + void operator()(const StringPiece &word) { + word_list_ << word << '\0'; + } + + private: + util::FakeOFStream word_list_; +}; + +class NoOpUniqueWords { + public: + NoOpUniqueWords() {} + void operator()(const StringPiece &word) {} +}; + +template <class NewWordAction = NoOpUniqueWords> class GrowableVocab { + public: + static std::size_t MemUsage(WordIndex content) { + return Lookup::MemUsage(content > 2 ? content : 2); + } + + // Does not take ownership of write_wordi + template <class NewWordConstruct> GrowableVocab(WordIndex initial_size, const NewWordConstruct &new_word_construct = NewWordAction()) + : lookup_(initial_size), new_word_(new_word_construct) { + FindOrInsert("<unk>"); // Force 0 + FindOrInsert("<s>"); // Force 1 + FindOrInsert("</s>"); // Force 2 + } + + WordIndex Index(const StringPiece &str) const { + Lookup::ConstIterator i; + return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0; + } + + WordIndex FindOrInsert(const StringPiece &word) { + ProbingVocabularyEntry entry = ProbingVocabularyEntry::Make(util::MurmurHashNative(word.data(), word.size()), Size()); + Lookup::MutableIterator it; + if (!lookup_.FindOrInsert(entry, it)) { + new_word_(word); + UTIL_THROW_IF(Size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh"); + } + return it->value; + } + + WordIndex Size() const { return lookup_.Size(); } + + private: + typedef util::AutoProbing<ProbingVocabularyEntry, util::IdentityHash> Lookup; + + Lookup lookup_; + + NewWordAction new_word_; +}; + } // namespace ngram } // namespace lm -#endif // LM_VOCAB__ +#endif // LM_VOCAB_H |