From d3e2ec203a5cf550320caa8023ac3dd103b0be7d Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Mon, 13 Oct 2014 00:42:37 -0400 Subject: new kenlm --- klm/lm/vocab.hh | 90 ++++++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 76 insertions(+), 14 deletions(-) (limited to 'klm/lm/vocab.hh') diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index 074b74d8..d6ae07b8 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -1,9 +1,11 @@ -#ifndef LM_VOCAB__ -#define LM_VOCAB__ +#ifndef LM_VOCAB_H +#define LM_VOCAB_H #include "lm/enumerate_vocab.hh" #include "lm/lm_exception.hh" #include "lm/virtual_interface.hh" +#include "util/fake_ofstream.hh" +#include "util/murmur_hash.hh" #include "util/pool.hh" #include "util/probing_hash_table.hh" #include "util/sorted_uniform.hh" @@ -104,17 +106,16 @@ class SortedVocabulary : public base::Vocabulary { #pragma pack(push) #pragma pack(4) -struct ProbingVocabuaryEntry { +struct ProbingVocabularyEntry { uint64_t key; WordIndex value; typedef uint64_t Key; - uint64_t GetKey() const { - return key; - } + uint64_t GetKey() const { return key; } + void SetKey(uint64_t to) { key = to; } - static ProbingVocabuaryEntry Make(uint64_t key, WordIndex value) { - ProbingVocabuaryEntry ret; + static ProbingVocabularyEntry Make(uint64_t key, WordIndex value) { + ProbingVocabularyEntry ret; ret.key = key; ret.value = value; return ret; @@ -132,13 +133,18 @@ class ProbingVocabulary : public base::Vocabulary { return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0; } + static uint64_t Size(uint64_t entries, float probing_multiplier); + // This just unwraps Config to get the probing_multiplier. static uint64_t Size(uint64_t entries, const Config &config); // Vocab words are [0, Bound()). WordIndex Bound() const { return bound_; } // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. - void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config); + void SetupMemory(void *start, std::size_t allocated); + void SetupMemory(void *start, std::size_t allocated, std::size_t /*entries*/, const Config &/*config*/) { + SetupMemory(start, allocated); + } void Relocate(void *new_start); @@ -147,8 +153,9 @@ class ProbingVocabulary : public base::Vocabulary { WordIndex Insert(const StringPiece &str); template void FinishedLoading(Weights * /*reorder_vocab*/) { - InternalFinishedLoading(); + FinishedLoading(); } + void FinishedLoading(); std::size_t UnkCountChangePadding() const { return 0; } @@ -157,9 +164,7 @@ class ProbingVocabulary : public base::Vocabulary { void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset); private: - void InternalFinishedLoading(); - - typedef util::ProbingHashTable Lookup; + typedef util::ProbingHashTable Lookup; Lookup lookup_; @@ -181,7 +186,64 @@ template void CheckSpecials(const Config &config, const Vocab &voc if (vocab.EndSentence() == vocab.NotFound()) MissingSentenceMarker(config, ""); } +class WriteUniqueWords { + public: + explicit WriteUniqueWords(int fd) : word_list_(fd) {} + + void operator()(const StringPiece &word) { + word_list_ << word << '\0'; + } + + private: + util::FakeOFStream word_list_; +}; + +class NoOpUniqueWords { + public: + NoOpUniqueWords() {} + void operator()(const StringPiece &word) {} +}; + +template class GrowableVocab { + public: + static std::size_t MemUsage(WordIndex content) { + return Lookup::MemUsage(content > 2 ? content : 2); + } + + // Does not take ownership of write_wordi + template GrowableVocab(WordIndex initial_size, const NewWordConstruct &new_word_construct = NewWordAction()) + : lookup_(initial_size), new_word_(new_word_construct) { + FindOrInsert(""); // Force 0 + FindOrInsert(""); // Force 1 + FindOrInsert(""); // Force 2 + } + + WordIndex Index(const StringPiece &str) const { + Lookup::ConstIterator i; + return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0; + } + + WordIndex FindOrInsert(const StringPiece &word) { + ProbingVocabularyEntry entry = ProbingVocabularyEntry::Make(util::MurmurHashNative(word.data(), word.size()), Size()); + Lookup::MutableIterator it; + if (!lookup_.FindOrInsert(entry, it)) { + new_word_(word); + UTIL_THROW_IF(Size() >= std::numeric_limits::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh"); + } + return it->value; + } + + WordIndex Size() const { return lookup_.Size(); } + + private: + typedef util::AutoProbing Lookup; + + Lookup lookup_; + + NewWordAction new_word_; +}; + } // namespace ngram } // namespace lm -#endif // LM_VOCAB__ +#endif // LM_VOCAB_H -- cgit v1.2.3