diff options
Diffstat (limited to 'klm/lm/vocab.hh')
| -rw-r--r-- | klm/lm/vocab.hh | 90 | 
1 files changed, 76 insertions, 14 deletions
| diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index 074b74d8..d6ae07b8 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -1,9 +1,11 @@ -#ifndef LM_VOCAB__ -#define LM_VOCAB__ +#ifndef LM_VOCAB_H +#define LM_VOCAB_H  #include "lm/enumerate_vocab.hh"  #include "lm/lm_exception.hh"  #include "lm/virtual_interface.hh" +#include "util/fake_ofstream.hh" +#include "util/murmur_hash.hh"  #include "util/pool.hh"  #include "util/probing_hash_table.hh"  #include "util/sorted_uniform.hh" @@ -104,17 +106,16 @@ class SortedVocabulary : public base::Vocabulary {  #pragma pack(push)  #pragma pack(4) -struct ProbingVocabuaryEntry { +struct ProbingVocabularyEntry {    uint64_t key;    WordIndex value;    typedef uint64_t Key; -  uint64_t GetKey() const { -    return key; -  } +  uint64_t GetKey() const { return key; } +  void SetKey(uint64_t to) { key = to; } -  static ProbingVocabuaryEntry Make(uint64_t key, WordIndex value) { -    ProbingVocabuaryEntry ret; +  static ProbingVocabularyEntry Make(uint64_t key, WordIndex value) { +    ProbingVocabularyEntry ret;      ret.key = key;      ret.value = value;      return ret; @@ -132,13 +133,18 @@ class ProbingVocabulary : public base::Vocabulary {        return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0;      } +    static uint64_t Size(uint64_t entries, float probing_multiplier); +    // This just unwraps Config to get the probing_multiplier.      static uint64_t Size(uint64_t entries, const Config &config);      // Vocab words are [0, Bound()).        WordIndex Bound() const { return bound_; }      // Everything else is for populating.  I'm too lazy to hide and friend these, but you'll only get a const reference anyway. -    void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config); +    void SetupMemory(void *start, std::size_t allocated); +    void SetupMemory(void *start, std::size_t allocated, std::size_t /*entries*/, const Config &/*config*/) { +      SetupMemory(start, allocated); +    }      void Relocate(void *new_start); @@ -147,8 +153,9 @@ class ProbingVocabulary : public base::Vocabulary {      WordIndex Insert(const StringPiece &str);      template <class Weights> void FinishedLoading(Weights * /*reorder_vocab*/) { -      InternalFinishedLoading(); +      FinishedLoading();      } +    void FinishedLoading();      std::size_t UnkCountChangePadding() const { return 0; } @@ -157,9 +164,7 @@ class ProbingVocabulary : public base::Vocabulary {      void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset);    private: -    void InternalFinishedLoading(); - -    typedef util::ProbingHashTable<ProbingVocabuaryEntry, util::IdentityHash> Lookup; +    typedef util::ProbingHashTable<ProbingVocabularyEntry, util::IdentityHash> Lookup;      Lookup lookup_; @@ -181,7 +186,64 @@ template <class Vocab> void CheckSpecials(const Config &config, const Vocab &voc    if (vocab.EndSentence() == vocab.NotFound()) MissingSentenceMarker(config, "</s>");  } +class WriteUniqueWords { +  public: +    explicit WriteUniqueWords(int fd) : word_list_(fd) {} + +    void operator()(const StringPiece &word) { +      word_list_ << word << '\0'; +    } + +  private: +    util::FakeOFStream word_list_; +}; + +class NoOpUniqueWords { +  public: +    NoOpUniqueWords() {} +    void operator()(const StringPiece &word) {} +}; + +template <class NewWordAction = NoOpUniqueWords> class GrowableVocab { +  public: +    static std::size_t MemUsage(WordIndex content) { +      return Lookup::MemUsage(content > 2 ? content : 2); +    } + +    // Does not take ownership of write_wordi +    template <class NewWordConstruct> GrowableVocab(WordIndex initial_size, const NewWordConstruct &new_word_construct = NewWordAction()) +      : lookup_(initial_size), new_word_(new_word_construct) { +      FindOrInsert("<unk>"); // Force 0 +      FindOrInsert("<s>"); // Force 1 +      FindOrInsert("</s>"); // Force 2 +    } + +    WordIndex Index(const StringPiece &str) const { +      Lookup::ConstIterator i; +      return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0; +    } + +    WordIndex FindOrInsert(const StringPiece &word) { +      ProbingVocabularyEntry entry = ProbingVocabularyEntry::Make(util::MurmurHashNative(word.data(), word.size()), Size()); +      Lookup::MutableIterator it; +      if (!lookup_.FindOrInsert(entry, it)) { +        new_word_(word); +        UTIL_THROW_IF(Size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words.  Change WordIndex to uint64_t in lm/word_index.hh"); +      } +      return it->value; +    } + +    WordIndex Size() const { return lookup_.Size(); } + +  private: +    typedef util::AutoProbing<ProbingVocabularyEntry, util::IdentityHash> Lookup; + +    Lookup lookup_; + +    NewWordAction new_word_; +}; +  } // namespace ngram  } // namespace lm -#endif // LM_VOCAB__ +#endif // LM_VOCAB_H | 
