diff options
Diffstat (limited to 'klm/lm/vocab.cc')
-rw-r--r-- | klm/lm/vocab.cc | 18 |
1 files changed, 10 insertions, 8 deletions
diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 515af5db..7defd5c1 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -28,8 +28,8 @@ const uint64_t kUnknownHash = detail::HashForVocab("<unk>", 5); // Sadly some LMs have <UNK>. const uint64_t kUnknownCapHash = detail::HashForVocab("<UNK>", 5); -void ReadWords(int fd, EnumerateVocab *enumerate) { - if (!enumerate) return; +WordIndex ReadWords(int fd, EnumerateVocab *enumerate) { + if (!enumerate) return std::numeric_limits<WordIndex>::max(); const std::size_t kInitialRead = 16384; std::string buf; buf.reserve(kInitialRead + 100); @@ -38,7 +38,7 @@ void ReadWords(int fd, EnumerateVocab *enumerate) { while (true) { ssize_t got = read(fd, &buf[0], kInitialRead); if (got == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words"); - if (got == 0) return; + if (got == 0) return index; buf.resize(got); while (buf[buf.size() - 1]) { char next_char; @@ -87,13 +87,13 @@ SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL std::size_t SortedVocabulary::Size(std::size_t entries, const Config &/*config*/) { // Lead with the number of entries. - return sizeof(uint64_t) + sizeof(Entry) * entries; + return sizeof(uint64_t) + sizeof(uint64_t) * entries; } void SortedVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config) { assert(allocated >= Size(entries, config)); // Leave space for number of entries. - begin_ = reinterpret_cast<Entry*>(reinterpret_cast<uint64_t*>(start) + 1); + begin_ = reinterpret_cast<uint64_t*>(start) + 1; end_ = begin_; saw_unk_ = false; } @@ -112,7 +112,7 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) { saw_unk_ = true; return 0; } - end_->key = hashed; + *end_ = hashed; if (enumerate_) { strings_to_enumerate_[end_ - begin_].assign(str.data(), str.size()); } @@ -134,8 +134,10 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { util::JointSort(begin_, end_, reorder_vocab + 1); } SetSpecial(Index("<s>"), Index("</s>"), 0); - // Save size. + // Save size. Excludes UNK. *(reinterpret_cast<uint64_t*>(begin_) - 1) = end_ - begin_; + // Includes UNK. + bound_ = end_ - begin_ + 1; } void SortedVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { @@ -183,7 +185,7 @@ void ProbingVocabulary::FinishedLoading(ProbBackoff * /*reorder_vocab*/) { void ProbingVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { lookup_.LoadedBinary(); - ReadWords(fd, to); + available_ = ReadWords(fd, to); SetSpecial(Index("<s>"), Index("</s>"), 0); } |