From 89238977fc9d8f8d9a6421b0d4f35afc200f08e7 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Tue, 28 Feb 2012 17:23:55 -0500 Subject: Subject: where's my kenlm update?? From: Chris Dyer --- klm/lm/vocab.cc | 53 +++++++++++++++++++++++++++++++++-------------------- 1 file changed, 33 insertions(+), 20 deletions(-) (limited to 'klm/lm/vocab.cc') diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index ffec41ca..9fd698bb 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -6,12 +6,15 @@ #include "lm/config.hh" #include "lm/weights.hh" #include "util/exception.hh" +#include "util/file.hh" #include "util/joint_sort.hh" #include "util/murmur_hash.hh" #include "util/probing_hash_table.hh" #include +#include + namespace lm { namespace ngram { @@ -29,23 +32,30 @@ const uint64_t kUnknownHash = detail::HashForVocab("<unk>", 5); // Sadly some LMs have <UNK>. const uint64_t kUnknownCapHash = detail::HashForVocab("<UNK>", 5); -WordIndex ReadWords(int fd, EnumerateVocab *enumerate) { - if (!enumerate) return std::numeric_limits<WordIndex>::max(); +void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count) { + // Check that we're at the right place by reading <unk> which is always first. + char check_unk[6]; + util::ReadOrThrow(fd, check_unk, 6); + UTIL_THROW_IF( + memcmp(check_unk, "<unk>", 6), + FormatLoadException, + "Vocabulary words are in the wrong place. This could be because the binary file was built with stale gcc and old kenlm. Stale gcc, including the gcc distributed with RedHat and OS X, has a bug that ignores pragma pack for template-dependent types. New kenlm works around this, so you'll save memory but have to rebuild any binary files using the probing data structure."); + if (!enumerate) return; + enumerate->Add(0, "<unk>"); + + // Read all the words after unk. const std::size_t kInitialRead = 16384; std::string buf; buf.reserve(kInitialRead + 100); buf.resize(kInitialRead); - WordIndex index = 0; + WordIndex index = 1; // Read <unk> already. 
while (true) { - ssize_t got = read(fd, &buf[0], kInitialRead); - UTIL_THROW_IF(got == -1, util::ErrnoException, "Reading vocabulary words"); - if (got == 0) return index; + std::size_t got = util::ReadOrEOF(fd, &buf[0], kInitialRead); + if (got == 0) break; buf.resize(got); while (buf[buf.size() - 1]) { char next_char; - ssize_t ret = read(fd, &next_char, 1); - UTIL_THROW_IF(ret == -1, util::ErrnoException, "Reading vocabulary words"); - UTIL_THROW_IF(ret == 0, FormatLoadException, "Missing null terminator on a vocab word."); + util::ReadOrThrow(fd, &next_char, 1); buf.push_back(next_char); } // Ok now we have null terminated strings. @@ -55,6 +65,8 @@ WordIndex ReadWords(int fd, EnumerateVocab *enumerate) { i += length + 1 /* null byte */; } } + + UTIL_THROW_IF(expected_count != index, FormatLoadException, "The binary file has the wrong number of words at the end. This could be caused by a truncated binary file."); } } // namespace @@ -69,8 +81,7 @@ void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) { } void WriteWordsWrapper::Write(int fd) { - if ((off_t)-1 == lseek(fd, 0, SEEK_END)) - UTIL_THROW(util::ErrnoException, "Failed to seek in binary to vocab words"); + util::SeekEnd(fd); util::WriteOrThrow(fd, buffer_.data(), buffer_.size()); } @@ -114,8 +125,10 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) { void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { if (enumerate_) { - util::PairedIterator values(reorder_vocab + 1, &*strings_to_enumerate_.begin()); - util::JointSort(begin_, end_, values); + if (!strings_to_enumerate_.empty()) { + util::PairedIterator values(reorder_vocab + 1, &*strings_to_enumerate_.begin()); + util::JointSort(begin_, end_, values); + } for (WordIndex i = 0; i < static_cast<WordIndex>(end_ - begin_); ++i) { // <unk> strikes again: +1 here. 
enumerate_->Add(i + 1, strings_to_enumerate_[i]); @@ -131,11 +144,11 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { bound_ = end_ - begin_ + 1; } -void SortedVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { +void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) { end_ = begin_ + *(reinterpret_cast(begin_) - 1); - ReadWords(fd, to); SetSpecial(Index("<s>"), Index("</s>"), 0); bound_ = end_ - begin_ + 1; + if (have_words) ReadWords(fd, to, bound_); } namespace { @@ -153,12 +166,12 @@ struct ProbingVocabularyHeader { ProbingVocabulary::ProbingVocabulary() : enumerate_(NULL) {} std::size_t ProbingVocabulary::Size(std::size_t entries, const Config &config) { - return Align8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, config.probing_multiplier); + return ALIGN8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, config.probing_multiplier); } void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t /*entries*/, const Config &/*config*/) { header_ = static_cast(start); - lookup_ = Lookup(static_cast(start) + Align8(sizeof(detail::ProbingVocabularyHeader)), allocated); + lookup_ = Lookup(static_cast(start) + ALIGN8(sizeof(detail::ProbingVocabularyHeader)), allocated); bound_ = 1; saw_unk_ = false; } @@ -178,7 +191,7 @@ WordIndex ProbingVocabulary::Insert(const StringPiece &str) { return 0; } else { if (enumerate_) enumerate_->Add(bound_, str); - lookup_.Insert(Lookup::Packing::Make(hashed, bound_)); + lookup_.Insert(ProbingVocabuaryEntry::Make(hashed, bound_)); return bound_++; } } @@ -190,12 +203,12 @@ void ProbingVocabulary::FinishedLoading(ProbBackoff * /*reorder_vocab*/) { SetSpecial(Index("<s>"), Index("</s>"), 0); } -void ProbingVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { +void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) { UTIL_THROW_IF(header_->version != kProbingVocabularyVersion, FormatLoadException, "The binary file 
has probing version " << header_->version << " but the code expects version " << kProbingVocabularyVersion << ". Please rerun build_binary using the same version of the code."); lookup_.LoadedBinary(); - ReadWords(fd, to); bound_ = header_->bound; SetSpecial(Index("<s>"), Index("</s>"), 0); + if (have_words) ReadWords(fd, to, bound_); } void MissingUnknown(const Config &config) throw(SpecialWordMissingException) { -- cgit v1.2.3