summaryrefslogtreecommitdiff
path: root/klm/lm/vocab.hh
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/vocab.hh')
-rw-r--r--klm/lm/vocab.hh90
1 files changed, 76 insertions, 14 deletions
diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh
index 074b74d8..d6ae07b8 100644
--- a/klm/lm/vocab.hh
+++ b/klm/lm/vocab.hh
@@ -1,9 +1,11 @@
-#ifndef LM_VOCAB__
-#define LM_VOCAB__
+#ifndef LM_VOCAB_H
+#define LM_VOCAB_H
#include "lm/enumerate_vocab.hh"
#include "lm/lm_exception.hh"
#include "lm/virtual_interface.hh"
+#include "util/fake_ofstream.hh"
+#include "util/murmur_hash.hh"
#include "util/pool.hh"
#include "util/probing_hash_table.hh"
#include "util/sorted_uniform.hh"
@@ -104,17 +106,16 @@ class SortedVocabulary : public base::Vocabulary {
#pragma pack(push)
#pragma pack(4)
-struct ProbingVocabuaryEntry {
+struct ProbingVocabularyEntry {
uint64_t key;
WordIndex value;
typedef uint64_t Key;
- uint64_t GetKey() const {
- return key;
- }
+ uint64_t GetKey() const { return key; }
+ void SetKey(uint64_t to) { key = to; }
- static ProbingVocabuaryEntry Make(uint64_t key, WordIndex value) {
- ProbingVocabuaryEntry ret;
+ static ProbingVocabularyEntry Make(uint64_t key, WordIndex value) {
+ ProbingVocabularyEntry ret;
ret.key = key;
ret.value = value;
return ret;
@@ -132,13 +133,18 @@ class ProbingVocabulary : public base::Vocabulary {
return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0;
}
+ static uint64_t Size(uint64_t entries, float probing_multiplier);
+ // This just unwraps Config to get the probing_multiplier.
static uint64_t Size(uint64_t entries, const Config &config);
// Vocab words are [0, Bound()).
WordIndex Bound() const { return bound_; }
// Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
- void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config);
+ void SetupMemory(void *start, std::size_t allocated);
+ void SetupMemory(void *start, std::size_t allocated, std::size_t /*entries*/, const Config &/*config*/) {
+ SetupMemory(start, allocated);
+ }
void Relocate(void *new_start);
@@ -147,8 +153,9 @@ class ProbingVocabulary : public base::Vocabulary {
WordIndex Insert(const StringPiece &str);
template <class Weights> void FinishedLoading(Weights * /*reorder_vocab*/) {
- InternalFinishedLoading();
+ FinishedLoading();
}
+ void FinishedLoading();
std::size_t UnkCountChangePadding() const { return 0; }
@@ -157,9 +164,7 @@ class ProbingVocabulary : public base::Vocabulary {
void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset);
private:
- void InternalFinishedLoading();
-
- typedef util::ProbingHashTable<ProbingVocabuaryEntry, util::IdentityHash> Lookup;
+ typedef util::ProbingHashTable<ProbingVocabularyEntry, util::IdentityHash> Lookup;
Lookup lookup_;
@@ -181,7 +186,64 @@ template <class Vocab> void CheckSpecials(const Config &config, const Vocab &voc
if (vocab.EndSentence() == vocab.NotFound()) MissingSentenceMarker(config, "</s>");
}
+class WriteUniqueWords {
+ public:
+ explicit WriteUniqueWords(int fd) : word_list_(fd) {}
+
+ void operator()(const StringPiece &word) {
+ word_list_ << word << '\0';
+ }
+
+ private:
+ util::FakeOFStream word_list_;
+};
+
+class NoOpUniqueWords {
+ public:
+ NoOpUniqueWords() {}
+ void operator()(const StringPiece &word) {}
+};
+
+template <class NewWordAction = NoOpUniqueWords> class GrowableVocab {
+ public:
+ static std::size_t MemUsage(WordIndex content) {
+ return Lookup::MemUsage(content > 2 ? content : 2);
+ }
+
+ // Does not take ownership of write_wordi
+ template <class NewWordConstruct> GrowableVocab(WordIndex initial_size, const NewWordConstruct &new_word_construct = NewWordAction())
+ : lookup_(initial_size), new_word_(new_word_construct) {
+ FindOrInsert("<unk>"); // Force 0
+ FindOrInsert("<s>"); // Force 1
+ FindOrInsert("</s>"); // Force 2
+ }
+
+ WordIndex Index(const StringPiece &str) const {
+ Lookup::ConstIterator i;
+ return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0;
+ }
+
+ WordIndex FindOrInsert(const StringPiece &word) {
+ ProbingVocabularyEntry entry = ProbingVocabularyEntry::Make(util::MurmurHashNative(word.data(), word.size()), Size());
+ Lookup::MutableIterator it;
+ if (!lookup_.FindOrInsert(entry, it)) {
+ new_word_(word);
+ UTIL_THROW_IF(Size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh");
+ }
+ return it->value;
+ }
+
+ WordIndex Size() const { return lookup_.Size(); }
+
+ private:
+ typedef util::AutoProbing<ProbingVocabularyEntry, util::IdentityHash> Lookup;
+
+ Lookup lookup_;
+
+ NewWordAction new_word_;
+};
+
} // namespace ngram
} // namespace lm
-#endif // LM_VOCAB__
+#endif // LM_VOCAB_H