summary refs log tree commit diff
path: root/klm/lm/vocab.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/vocab.cc')
-rw-r--r-- klm/lm/vocab.cc 15
1 files changed, 9 insertions, 6 deletions
diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc
index 5de68f16..fd7f96dc 100644
--- a/klm/lm/vocab.cc
+++ b/klm/lm/vocab.cc
@@ -80,14 +80,14 @@ void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) {
buffer_.push_back(0);
}
-void WriteWordsWrapper::Write(int fd) {
- util::SeekEnd(fd);
+void WriteWordsWrapper::Write(int fd, uint64_t start) {
+ util::SeekOrThrow(fd, start);
util::WriteOrThrow(fd, buffer_.data(), buffer_.size());
}
SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {}
-std::size_t SortedVocabulary::Size(std::size_t entries, const Config &/*config*/) {
+uint64_t SortedVocabulary::Size(uint64_t entries, const Config &/*config*/) {
// Lead with the number of entries.
return sizeof(uint64_t) + sizeof(uint64_t) * entries;
}
@@ -116,7 +116,9 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) {
}
*end_ = hashed;
if (enumerate_) {
- strings_to_enumerate_[end_ - begin_].assign(str.data(), str.size());
+ void *copied = string_backing_.Allocate(str.size());
+ memcpy(copied, str.data(), str.size());
+ strings_to_enumerate_[end_ - begin_] = StringPiece(static_cast<const char*>(copied), str.size());
}
++end_;
// This is 1 + the offset where it was inserted to make room for unk.
@@ -126,7 +128,7 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) {
void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
if (enumerate_) {
if (!strings_to_enumerate_.empty()) {
- util::PairedIterator<ProbBackoff*, std::string*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin());
+ util::PairedIterator<ProbBackoff*, StringPiece*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin());
util::JointSort(begin_, end_, values);
}
for (WordIndex i = 0; i < static_cast<WordIndex>(end_ - begin_); ++i) {
@@ -134,6 +136,7 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
enumerate_->Add(i + 1, strings_to_enumerate_[i]);
}
strings_to_enumerate_.clear();
+ string_backing_.FreeAll();
} else {
util::JointSort(begin_, end_, reorder_vocab + 1);
}
@@ -165,7 +168,7 @@ struct ProbingVocabularyHeader {
ProbingVocabulary::ProbingVocabulary() : enumerate_(NULL) {}
-std::size_t ProbingVocabulary::Size(std::size_t entries, const Config &config) {
+uint64_t ProbingVocabulary::Size(uint64_t entries, const Config &config) {
return ALIGN8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, config.probing_multiplier);
}