summary refs log tree commit diff
path: root/klm/lm/vocab.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/vocab.cc')
-rw-r--r-- klm/lm/vocab.cc 44
1 files changed, 26 insertions, 18 deletions
diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc
index 04979d51..03b0767a 100644
--- a/klm/lm/vocab.cc
+++ b/klm/lm/vocab.cc
@@ -1,5 +1,6 @@
#include "lm/vocab.hh"
+#include "lm/binary_format.hh"
#include "lm/enumerate_vocab.hh"
#include "lm/lm_exception.hh"
#include "lm/config.hh"
@@ -56,16 +57,6 @@ WordIndex ReadWords(int fd, EnumerateVocab *enumerate) {
}
}
-void WriteOrThrow(int fd, const void *data_void, std::size_t size) {
- const uint8_t *data = static_cast<const uint8_t*>(data_void);
- while (size) {
- ssize_t ret = write(fd, data, size);
- if (ret < 1) UTIL_THROW(util::ErrnoException, "Write failed");
- data += ret;
- size -= ret;
- }
-}
-
} // namespace
WriteWordsWrapper::WriteWordsWrapper(EnumerateVocab *inner) : inner_(inner) {}
@@ -80,7 +71,7 @@ void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) {
void WriteWordsWrapper::Write(int fd) {
if ((off_t)-1 == lseek(fd, 0, SEEK_END))
UTIL_THROW(util::ErrnoException, "Failed to seek in binary to vocab words");
- WriteOrThrow(fd, buffer_.data(), buffer_.size());
+ util::WriteOrThrow(fd, buffer_.data(), buffer_.size());
}
SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {}
@@ -146,15 +137,28 @@ void SortedVocabulary::LoadedBinary(int fd, EnumerateVocab *to) {
SetSpecial(Index("<s>"), Index("</s>"), 0);
}
+namespace {
+const unsigned int kProbingVocabularyVersion = 0;
+} // namespace
+
+namespace detail {
+struct ProbingVocabularyHeader {
+ // Header placed at the start of the vocabulary memory, ahead of the probing table. version: format version, checked when a binary is loaded. bound: lowest unused vocab id, which is also the number of words, including <unk>.
+ unsigned int version;
+ WordIndex bound;
+};
+} // namespace detail
+
ProbingVocabulary::ProbingVocabulary() : enumerate_(NULL) {}
std::size_t ProbingVocabulary::Size(std::size_t entries, const Config &config) {
- return Lookup::Size(entries, config.probing_multiplier);
+ return Align8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, config.probing_multiplier);
}
void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t /*entries*/, const Config &/*config*/) {
- lookup_ = Lookup(start, allocated);
- available_ = 1;
+ header_ = static_cast<detail::ProbingVocabularyHeader*>(start);
+ lookup_ = Lookup(static_cast<uint8_t*>(start) + Align8(sizeof(detail::ProbingVocabularyHeader)), allocated);
+ bound_ = 1;
saw_unk_ = false;
}
@@ -172,20 +176,24 @@ WordIndex ProbingVocabulary::Insert(const StringPiece &str) {
saw_unk_ = true;
return 0;
} else {
- if (enumerate_) enumerate_->Add(available_, str);
- lookup_.Insert(Lookup::Packing::Make(hashed, available_));
- return available_++;
+ if (enumerate_) enumerate_->Add(bound_, str);
+ lookup_.Insert(Lookup::Packing::Make(hashed, bound_));
+ return bound_++;
}
}
void ProbingVocabulary::FinishedLoading(ProbBackoff * /*reorder_vocab*/) {
lookup_.FinishedInserting();
+ header_->bound = bound_;
+ header_->version = kProbingVocabularyVersion;
SetSpecial(Index("<s>"), Index("</s>"), 0);
}
void ProbingVocabulary::LoadedBinary(int fd, EnumerateVocab *to) {
+ UTIL_THROW_IF(header_->version != kProbingVocabularyVersion, FormatLoadException, "The binary file has probing version " << header_->version << " but the code expects version " << kProbingVocabularyVersion << ". Please rerun build_binary using the same version of the code.");
lookup_.LoadedBinary();
- available_ = ReadWords(fd, to);
+ ReadWords(fd, to);
+ bound_ = header_->bound;
SetSpecial(Index("<s>"), Index("</s>"), 0);
}