diff options
Diffstat (limited to 'klm/lm/vocab.cc')
-rw-r--r-- | klm/lm/vocab.cc | 15 |
1 files changed, 9 insertions, 6 deletions
diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 5de68f16..fd7f96dc 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -80,14 +80,14 @@ void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) { buffer_.push_back(0); } -void WriteWordsWrapper::Write(int fd) { - util::SeekEnd(fd); +void WriteWordsWrapper::Write(int fd, uint64_t start) { + util::SeekOrThrow(fd, start); util::WriteOrThrow(fd, buffer_.data(), buffer_.size()); } SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {} -std::size_t SortedVocabulary::Size(std::size_t entries, const Config &/*config*/) { +uint64_t SortedVocabulary::Size(uint64_t entries, const Config &/*config*/) { // Lead with the number of entries. return sizeof(uint64_t) + sizeof(uint64_t) * entries; } @@ -116,7 +116,9 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) { } *end_ = hashed; if (enumerate_) { - strings_to_enumerate_[end_ - begin_].assign(str.data(), str.size()); + void *copied = string_backing_.Allocate(str.size()); + memcpy(copied, str.data(), str.size()); + strings_to_enumerate_[end_ - begin_] = StringPiece(static_cast<const char*>(copied), str.size()); } ++end_; // This is 1 + the offset where it was inserted to make room for unk. @@ -126,7 +128,7 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) { void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { if (enumerate_) { if (!strings_to_enumerate_.empty()) { - util::PairedIterator<ProbBackoff*, std::string*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin()); + util::PairedIterator<ProbBackoff*, StringPiece*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin()); util::JointSort(begin_, end_, values); } for (WordIndex i = 0; i < static_cast<WordIndex>(end_ - begin_); ++i) { @@ -134,6 +136,7 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { enumerate_->Add(i + 1, strings_to_enumerate_[i]); } strings_to_enumerate_.clear(); + string_backing_.FreeAll(); } else { util::JointSort(begin_, end_, reorder_vocab + 1); } @@ -165,7 +168,7 @@ struct ProbingVocabularyHeader { ProbingVocabulary::ProbingVocabulary() : enumerate_(NULL) {} -std::size_t ProbingVocabulary::Size(std::size_t entries, const Config &config) { +uint64_t ProbingVocabulary::Size(uint64_t entries, const Config &config) { return ALIGN8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, config.probing_multiplier); } |