summaryrefslogtreecommitdiff
path: root/klm/lm/vocab.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/vocab.cc')
-rw-r--r--klm/lm/vocab.cc28
1 files changed, 17 insertions, 11 deletions
diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc
index fd7f96dc..7f0878f4 100644
--- a/klm/lm/vocab.cc
+++ b/klm/lm/vocab.cc
@@ -32,7 +32,8 @@ const uint64_t kUnknownHash = detail::HashForVocab("<unk>", 5);
// Sadly some LMs have <UNK>.
const uint64_t kUnknownCapHash = detail::HashForVocab("<UNK>", 5);
-void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count) {
+void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count, uint64_t offset) {
+ util::SeekOrThrow(fd, offset);
// Check that we're at the right place by reading <unk> which is always first.
char check_unk[6];
util::ReadOrThrow(fd, check_unk, 6);
@@ -80,11 +81,6 @@ void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) {
buffer_.push_back(0);
}
-void WriteWordsWrapper::Write(int fd, uint64_t start) {
- util::SeekOrThrow(fd, start);
- util::WriteOrThrow(fd, buffer_.data(), buffer_.size());
-}
-
SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {}
uint64_t SortedVocabulary::Size(uint64_t entries, const Config &/*config*/) {
@@ -100,6 +96,12 @@ void SortedVocabulary::SetupMemory(void *start, std::size_t allocated, std::size
saw_unk_ = false;
}
+void SortedVocabulary::Relocate(void *new_start) {
+ std::size_t delta = end_ - begin_;
+ begin_ = reinterpret_cast<uint64_t*>(new_start) + 1;
+ end_ = begin_ + delta;
+}
+
void SortedVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries) {
enumerate_ = to;
if (enumerate_) {
@@ -147,11 +149,11 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
bound_ = end_ - begin_ + 1;
}
-void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) {
+void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) {
end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1);
SetSpecial(Index("<s>"), Index("</s>"), 0);
bound_ = end_ - begin_ + 1;
- if (have_words) ReadWords(fd, to, bound_);
+ if (have_words) ReadWords(fd, to, bound_, offset);
}
namespace {
@@ -179,6 +181,11 @@ void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated, std::siz
saw_unk_ = false;
}
+void ProbingVocabulary::Relocate(void *new_start) {
+ header_ = static_cast<detail::ProbingVocabularyHeader*>(new_start);
+ lookup_.Relocate(static_cast<uint8_t*>(new_start) + ALIGN8(sizeof(detail::ProbingVocabularyHeader)));
+}
+
void ProbingVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t /*max_entries*/) {
enumerate_ = to;
if (enumerate_) {
@@ -206,12 +213,11 @@ void ProbingVocabulary::InternalFinishedLoading() {
SetSpecial(Index("<s>"), Index("</s>"), 0);
}
-void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) {
+void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) {
UTIL_THROW_IF(header_->version != kProbingVocabularyVersion, FormatLoadException, "The binary file has probing version " << header_->version << " but the code expects version " << kProbingVocabularyVersion << ". Please rerun build_binary using the same version of the code.");
- lookup_.LoadedBinary();
bound_ = header_->bound;
SetSpecial(Index("<s>"), Index("</s>"), 0);
- if (have_words) ReadWords(fd, to, bound_);
+ if (have_words) ReadWords(fd, to, bound_, offset);
}
void MissingUnknown(const Config &config) throw(SpecialWordMissingException) {