diff options
Diffstat (limited to 'klm/lm/vocab.cc')
-rw-r--r-- | klm/lm/vocab.cc | 28 |
1 files changed, 17 insertions, 11 deletions
diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index fd7f96dc..7f0878f4 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -32,7 +32,8 @@ const uint64_t kUnknownHash = detail::HashForVocab("<unk>", 5); // Sadly some LMs have <UNK>. const uint64_t kUnknownCapHash = detail::HashForVocab("<UNK>", 5); -void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count) { +void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count, uint64_t offset) { + util::SeekOrThrow(fd, offset); // Check that we're at the right place by reading <unk> which is always first. char check_unk[6]; util::ReadOrThrow(fd, check_unk, 6); @@ -80,11 +81,6 @@ void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) { buffer_.push_back(0); } -void WriteWordsWrapper::Write(int fd, uint64_t start) { - util::SeekOrThrow(fd, start); - util::WriteOrThrow(fd, buffer_.data(), buffer_.size()); -} - SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {} uint64_t SortedVocabulary::Size(uint64_t entries, const Config &/*config*/) { @@ -100,6 +96,12 @@ void SortedVocabulary::SetupMemory(void *start, std::size_t allocated, std::size saw_unk_ = false; } +void SortedVocabulary::Relocate(void *new_start) { + std::size_t delta = end_ - begin_; + begin_ = reinterpret_cast<uint64_t*>(new_start) + 1; + end_ = begin_ + delta; +} + void SortedVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries) { enumerate_ = to; if (enumerate_) { @@ -147,11 +149,11 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { bound_ = end_ - begin_ + 1; } -void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) { +void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) { end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1); SetSpecial(Index("<s>"), Index("</s>"), 0); bound_ = end_ - begin_ + 1; - if (have_words) ReadWords(fd, to, bound_); + if (have_words) ReadWords(fd, to, bound_, offset); } namespace { @@ -179,6 +181,11 @@ void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated, std::siz saw_unk_ = false; } +void ProbingVocabulary::Relocate(void *new_start) { + header_ = static_cast<detail::ProbingVocabularyHeader*>(new_start); + lookup_.Relocate(static_cast<uint8_t*>(new_start) + ALIGN8(sizeof(detail::ProbingVocabularyHeader))); +} + void ProbingVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t /*max_entries*/) { enumerate_ = to; if (enumerate_) { @@ -206,12 +213,11 @@ void ProbingVocabulary::InternalFinishedLoading() { SetSpecial(Index("<s>"), Index("</s>"), 0); } -void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) { +void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) { UTIL_THROW_IF(header_->version != kProbingVocabularyVersion, FormatLoadException, "The binary file has probing version " << header_->version << " but the code expects version " << kProbingVocabularyVersion << ". Please rerun build_binary using the same version of the code."); - lookup_.LoadedBinary(); bound_ = header_->bound; SetSpecial(Index("<s>"), Index("</s>"), 0); - if (have_words) ReadWords(fd, to, bound_); + if (have_words) ReadWords(fd, to, bound_, offset); } void MissingUnknown(const Config &config) throw(SpecialWordMissingException) { |