summaryrefslogtreecommitdiff
path: root/klm/lm/vocab.cc
diff options
context:
space:
mode:
Diffstat (limited to 'klm/lm/vocab.cc')
-rw-r--r--klm/lm/vocab.cc53
1 files changed, 33 insertions, 20 deletions
diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc
index ffec41ca..9fd698bb 100644
--- a/klm/lm/vocab.cc
+++ b/klm/lm/vocab.cc
@@ -6,12 +6,15 @@
#include "lm/config.hh"
#include "lm/weights.hh"
#include "util/exception.hh"
+#include "util/file.hh"
#include "util/joint_sort.hh"
#include "util/murmur_hash.hh"
#include "util/probing_hash_table.hh"
#include <string>
+#include <string.h>
+
namespace lm {
namespace ngram {
@@ -29,23 +32,30 @@ const uint64_t kUnknownHash = detail::HashForVocab("<unk>", 5);
// Sadly some LMs have <UNK>.
const uint64_t kUnknownCapHash = detail::HashForVocab("<UNK>", 5);
-WordIndex ReadWords(int fd, EnumerateVocab *enumerate) {
- if (!enumerate) return std::numeric_limits<WordIndex>::max();
+void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count) {
+ // Check that we're at the right place by reading <unk> which is always first.
+ char check_unk[6];
+ util::ReadOrThrow(fd, check_unk, 6);
+ UTIL_THROW_IF(
+ memcmp(check_unk, "<unk>", 6),
+ FormatLoadException,
+ "Vocabulary words are in the wrong place. This could be because the binary file was built with stale gcc and old kenlm. Stale gcc, including the gcc distributed with RedHat and OS X, has a bug that ignores pragma pack for template-dependent types. New kenlm works around this, so you'll save memory but have to rebuild any binary files using the probing data structure.");
+ if (!enumerate) return;
+ enumerate->Add(0, "<unk>");
+
+ // Read all the words after unk.
const std::size_t kInitialRead = 16384;
std::string buf;
buf.reserve(kInitialRead + 100);
buf.resize(kInitialRead);
- WordIndex index = 0;
+ WordIndex index = 1; // Read <unk> already.
while (true) {
- ssize_t got = read(fd, &buf[0], kInitialRead);
- UTIL_THROW_IF(got == -1, util::ErrnoException, "Reading vocabulary words");
- if (got == 0) return index;
+ std::size_t got = util::ReadOrEOF(fd, &buf[0], kInitialRead);
+ if (got == 0) break;
buf.resize(got);
while (buf[buf.size() - 1]) {
char next_char;
- ssize_t ret = read(fd, &next_char, 1);
- UTIL_THROW_IF(ret == -1, util::ErrnoException, "Reading vocabulary words");
- UTIL_THROW_IF(ret == 0, FormatLoadException, "Missing null terminator on a vocab word.");
+ util::ReadOrThrow(fd, &next_char, 1);
buf.push_back(next_char);
}
// Ok now we have null terminated strings.
@@ -55,6 +65,8 @@ WordIndex ReadWords(int fd, EnumerateVocab *enumerate) {
i += length + 1 /* null byte */;
}
}
+
+ UTIL_THROW_IF(expected_count != index, FormatLoadException, "The binary file has the wrong number of words at the end. This could be caused by a truncated binary file.");
}
} // namespace
@@ -69,8 +81,7 @@ void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) {
}
void WriteWordsWrapper::Write(int fd) {
- if ((off_t)-1 == lseek(fd, 0, SEEK_END))
- UTIL_THROW(util::ErrnoException, "Failed to seek in binary to vocab words");
+ util::SeekEnd(fd);
util::WriteOrThrow(fd, buffer_.data(), buffer_.size());
}
@@ -114,8 +125,10 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) {
void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
if (enumerate_) {
- util::PairedIterator<ProbBackoff*, std::string*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin());
- util::JointSort(begin_, end_, values);
+ if (!strings_to_enumerate_.empty()) {
+ util::PairedIterator<ProbBackoff*, std::string*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin());
+ util::JointSort(begin_, end_, values);
+ }
for (WordIndex i = 0; i < static_cast<WordIndex>(end_ - begin_); ++i) {
// <unk> strikes again: +1 here.
enumerate_->Add(i + 1, strings_to_enumerate_[i]);
@@ -131,11 +144,11 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
bound_ = end_ - begin_ + 1;
}
-void SortedVocabulary::LoadedBinary(int fd, EnumerateVocab *to) {
+void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) {
end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1);
- ReadWords(fd, to);
SetSpecial(Index("<s>"), Index("</s>"), 0);
bound_ = end_ - begin_ + 1;
+ if (have_words) ReadWords(fd, to, bound_);
}
namespace {
@@ -153,12 +166,12 @@ struct ProbingVocabularyHeader {
ProbingVocabulary::ProbingVocabulary() : enumerate_(NULL) {}
std::size_t ProbingVocabulary::Size(std::size_t entries, const Config &config) {
- return Align8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, config.probing_multiplier);
+ return ALIGN8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, config.probing_multiplier);
}
void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t /*entries*/, const Config &/*config*/) {
header_ = static_cast<detail::ProbingVocabularyHeader*>(start);
- lookup_ = Lookup(static_cast<uint8_t*>(start) + Align8(sizeof(detail::ProbingVocabularyHeader)), allocated);
+ lookup_ = Lookup(static_cast<uint8_t*>(start) + ALIGN8(sizeof(detail::ProbingVocabularyHeader)), allocated);
bound_ = 1;
saw_unk_ = false;
}
@@ -178,7 +191,7 @@ WordIndex ProbingVocabulary::Insert(const StringPiece &str) {
return 0;
} else {
if (enumerate_) enumerate_->Add(bound_, str);
- lookup_.Insert(Lookup::Packing::Make(hashed, bound_));
+ lookup_.Insert(ProbingVocabuaryEntry::Make(hashed, bound_));
return bound_++;
}
}
@@ -190,12 +203,12 @@ void ProbingVocabulary::FinishedLoading(ProbBackoff * /*reorder_vocab*/) {
SetSpecial(Index("<s>"), Index("</s>"), 0);
}
-void ProbingVocabulary::LoadedBinary(int fd, EnumerateVocab *to) {
+void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) {
UTIL_THROW_IF(header_->version != kProbingVocabularyVersion, FormatLoadException, "The binary file has probing version " << header_->version << " but the code expects version " << kProbingVocabularyVersion << ". Please rerun build_binary using the same version of the code.");
lookup_.LoadedBinary();
- ReadWords(fd, to);
bound_ = header_->bound;
SetSpecial(Index("<s>"), Index("</s>"), 0);
+ if (have_words) ReadWords(fd, to, bound_);
}
void MissingUnknown(const Config &config) throw(SpecialWordMissingException) {