Diffstat (limited to 'klm/lm/vocab.cc')
-rw-r--r--  klm/lm/vocab.cc  187
1 file changed, 187 insertions, 0 deletions
diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc
new file mode 100644
index 00000000..c30428b2
--- /dev/null
+++ b/klm/lm/vocab.cc
@@ -0,0 +1,187 @@
+#include "lm/vocab.hh"
+
+#include "lm/enumerate_vocab.hh"
+#include "lm/lm_exception.hh"
+#include "lm/config.hh"
+#include "lm/weights.hh"
+#include "util/exception.hh"
+#include "util/joint_sort.hh"
+#include "util/murmur_hash.hh"
+#include "util/probing_hash_table.hh"
+
+#include <cassert>
+#include <cstring>
+#include <string>
+
+#include <unistd.h>  // read(), write()
+
+namespace lm {
+namespace ngram {
+
+namespace detail {
+uint64_t HashForVocab(const char *str, std::size_t len) {
+  // This proved faster than Boost's hash in speed trials: total load time was 67090000 with Murmur versus 72210000 with Boost.
+  // MurmurHash64A is used instead of the native-width variant so the binary format stays portable between 64-bit and 32-bit machines.
+ return util::MurmurHash64A(str, len, 0);
+}
+} // namespace detail
+
+namespace {
+// Normally static initialization is a bad idea but MurmurHash is pure arithmetic, so this is ok.
+const uint64_t kUnknownHash = detail::HashForVocab("<unk>", 5);
+// Sadly some LMs have <UNK>.
+const uint64_t kUnknownCapHash = detail::HashForVocab("<UNK>", 5);
+
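+// Stream null-delimited words from fd and report each one to enumerate in index order.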
+void ReadWords(int fd, EnumerateVocab *enumerate) {
+ if (!enumerate) return;
+ const std::size_t kInitialRead = 16384;
+ std::string buf;
+ buf.reserve(kInitialRead + 100);
+ buf.resize(kInitialRead);
+ WordIndex index = 0;
+ while (true) {
+ ssize_t got = read(fd, &buf[0], kInitialRead);
+ if (got == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words");
+ if (got == 0) return;
+ buf.resize(got);
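+    // The read may have stopped mid-word; pull one byte at a time until the last word is null-terminated.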
+ while (buf[buf.size() - 1]) {
+ char next_char;
+ ssize_t ret = read(fd, &next_char, 1);
+ if (ret == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words");
+ if (ret == 0) UTIL_THROW(FormatLoadException, "Missing null terminator on a vocab word.");
+ buf.push_back(next_char);
+ }
+    // Now buf holds only complete null-terminated strings; enumerate each one.
+ for (const char *i = buf.data(); i != buf.data() + buf.size();) {
+ std::size_t length = strlen(i);
+ enumerate->Add(index++, StringPiece(i, length));
+ i += length + 1 /* null byte */;
+ }
+ }
+}
+
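+// Write the entire buffer, looping over short writes; throw on any error.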
+void WriteOrThrow(int fd, const void *data_void, std::size_t size) {
+ const uint8_t *data = static_cast<const uint8_t*>(data_void);
+ while (size) {
+ ssize_t ret = write(fd, data, size);
+ if (ret < 1) UTIL_THROW(util::ErrnoException, "Write failed");
+ data += ret;
+ size -= ret;
+ }
+}
+
+} // namespace
+
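+// Forwards each word to an optional inner EnumerateVocab and also writes it, null-terminated, to fd.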
+WriteWordsWrapper::WriteWordsWrapper(EnumerateVocab *inner, int fd) : inner_(inner), fd_(fd) {}
+WriteWordsWrapper::~WriteWordsWrapper() {}
+
+void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) {
+ if (inner_) inner_->Add(index, str);
+ WriteOrThrow(fd_, str.data(), str.size());
+ char null_byte = 0;
+ // Inefficient because it's unbuffered. Sue me.
+ WriteOrThrow(fd_, &null_byte, 1);
+}
+
+SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {}
+
+std::size_t SortedVocabulary::Size(std::size_t entries, const Config &/*config*/) {
+ // Lead with the number of entries.
+ return sizeof(uint64_t) + sizeof(Entry) * entries;
+}
+
+void SortedVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config) {
+ assert(allocated >= Size(entries, config));
+ // Leave space for number of entries.
+ begin_ = reinterpret_cast<Entry*>(reinterpret_cast<uint64_t*>(start) + 1);
+ end_ = begin_;
+ saw_unk_ = false;
+}
+
+void SortedVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries) {
+ enumerate_ = to;
+ if (enumerate_) {
+ enumerate_->Add(0, "<unk>");
+ strings_to_enumerate_.resize(max_entries);
+ }
+}
+
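+// <unk> (in either case) is never stored and always maps to index 0; other words are appended in arrival order and re-indexed when FinishedLoading sorts by hash.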
+WordIndex SortedVocabulary::Insert(const StringPiece &str) {
+ uint64_t hashed = detail::HashForVocab(str);
+ if (hashed == kUnknownHash || hashed == kUnknownCapHash) {
+ saw_unk_ = true;
+ return 0;
+ }
+ end_->key = hashed;
+ if (enumerate_) {
+ strings_to_enumerate_[end_ - begin_].assign(str.data(), str.size());
+ }
+ ++end_;
+  // Return 1 + the offset at which the word was inserted; index 0 is reserved for <unk>.
+ return end_ - begin_;
+}
+
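+// Sort the entries by hash so word lookup can search the sorted array, permuting the unigram weights (and the saved strings, when enumerating) in lockstep.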
+void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
+ if (enumerate_) {
+ util::PairedIterator<ProbBackoff*, std::string*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin());
+ util::JointSort(begin_, end_, values);
+ for (WordIndex i = 0; i < static_cast<WordIndex>(end_ - begin_); ++i) {
+ // <unk> strikes again: +1 here.
+ enumerate_->Add(i + 1, strings_to_enumerate_[i]);
+ }
+ strings_to_enumerate_.clear();
+ } else {
+ util::JointSort(begin_, end_, reorder_vocab + 1);
+ }
+ SetSpecial(Index("<s>"), Index("</s>"), 0);
+ // Save size.
+ *(reinterpret_cast<uint64_t*>(begin_) - 1) = end_ - begin_;
+}
+
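+// Binary load: recover the entry count that FinishedLoading stored just before begin_, then stream the saved words to the enumerator.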
+void SortedVocabulary::LoadedBinary(int fd, EnumerateVocab *to) {
+ end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1);
+ ReadWords(fd, to);
+ SetSpecial(Index("<s>"), Index("</s>"), 0);
+}
+
+ProbingVocabulary::ProbingVocabulary() : enumerate_(NULL) {}
+
+std::size_t ProbingVocabulary::Size(std::size_t entries, const Config &config) {
+ return Lookup::Size(entries, config.probing_multiplier);
+}
+
+void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t /*entries*/, const Config &/*config*/) {
+ lookup_ = Lookup(start, allocated);
+ available_ = 1;
+ saw_unk_ = false;
+}
+
+void ProbingVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t /*max_entries*/) {
+ enumerate_ = to;
+ if (enumerate_) {
+ enumerate_->Add(0, "<unk>");
+ }
+}
+
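+// Unlike the sorted vocabulary, indices are final at insertion time; nothing is renumbered later.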
+WordIndex ProbingVocabulary::Insert(const StringPiece &str) {
+ uint64_t hashed = detail::HashForVocab(str);
+ // Prevent unknown from going into the table.
+ if (hashed == kUnknownHash || hashed == kUnknownCapHash) {
+ saw_unk_ = true;
+ return 0;
+ } else {
+ if (enumerate_) enumerate_->Add(available_, str);
+ lookup_.Insert(Lookup::Packing::Make(hashed, available_));
+ return available_++;
+ }
+}
+
+void ProbingVocabulary::FinishedLoading(ProbBackoff * /*reorder_vocab*/) {
+ lookup_.FinishedInserting();
+ SetSpecial(Index("<s>"), Index("</s>"), 0);
+}
+
+void ProbingVocabulary::LoadedBinary(int fd, EnumerateVocab *to) {
+ lookup_.LoadedBinary();
+ ReadWords(fd, to);
+ SetSpecial(Index("<s>"), Index("</s>"), 0);
+}
+
+} // namespace ngram
+} // namespace lm