summaryrefslogtreecommitdiff
path: root/klm/lm/vocab.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/vocab.cc')
-rw-r--r--klm/lm/vocab.cc18
1 files changed, 10 insertions, 8 deletions
diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc
index 515af5db..7defd5c1 100644
--- a/klm/lm/vocab.cc
+++ b/klm/lm/vocab.cc
@@ -28,8 +28,8 @@ const uint64_t kUnknownHash = detail::HashForVocab("<unk>", 5);
// Sadly some LMs have <UNK>.
const uint64_t kUnknownCapHash = detail::HashForVocab("<UNK>", 5);
-void ReadWords(int fd, EnumerateVocab *enumerate) {
- if (!enumerate) return;
+WordIndex ReadWords(int fd, EnumerateVocab *enumerate) {
+ if (!enumerate) return std::numeric_limits<WordIndex>::max();
const std::size_t kInitialRead = 16384;
std::string buf;
buf.reserve(kInitialRead + 100);
@@ -38,7 +38,7 @@ void ReadWords(int fd, EnumerateVocab *enumerate) {
while (true) {
ssize_t got = read(fd, &buf[0], kInitialRead);
if (got == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words");
- if (got == 0) return;
+ if (got == 0) return index;
buf.resize(got);
while (buf[buf.size() - 1]) {
char next_char;
@@ -87,13 +87,13 @@ SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL
std::size_t SortedVocabulary::Size(std::size_t entries, const Config &/*config*/) {
// Lead with the number of entries.
- return sizeof(uint64_t) + sizeof(Entry) * entries;
+ return sizeof(uint64_t) + sizeof(uint64_t) * entries;
}
void SortedVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config) {
assert(allocated >= Size(entries, config));
// Leave space for number of entries.
- begin_ = reinterpret_cast<Entry*>(reinterpret_cast<uint64_t*>(start) + 1);
+ begin_ = reinterpret_cast<uint64_t*>(start) + 1;
end_ = begin_;
saw_unk_ = false;
}
@@ -112,7 +112,7 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) {
saw_unk_ = true;
return 0;
}
- end_->key = hashed;
+ *end_ = hashed;
if (enumerate_) {
strings_to_enumerate_[end_ - begin_].assign(str.data(), str.size());
}
@@ -134,8 +134,10 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
util::JointSort(begin_, end_, reorder_vocab + 1);
}
SetSpecial(Index("<s>"), Index("</s>"), 0);
- // Save size.
+ // Save size. Excludes UNK.
*(reinterpret_cast<uint64_t*>(begin_) - 1) = end_ - begin_;
+ // Includes UNK.
+ bound_ = end_ - begin_ + 1;
}
void SortedVocabulary::LoadedBinary(int fd, EnumerateVocab *to) {
@@ -183,7 +185,7 @@ void ProbingVocabulary::FinishedLoading(ProbBackoff * /*reorder_vocab*/) {
void ProbingVocabulary::LoadedBinary(int fd, EnumerateVocab *to) {
lookup_.LoadedBinary();
- ReadWords(fd, to);
+ available_ = ReadWords(fd, to);
SetSpecial(Index("<s>"), Index("</s>"), 0);
}