summaryrefslogtreecommitdiff
path: root/klm
diff options
context:
space:
mode:
Diffstat (limited to 'klm')
-rw-r--r--klm/lm/Makefile.am14
-rw-r--r--klm/lm/binary_format.cc191
-rw-r--r--klm/lm/build_binary.cc13
-rw-r--r--klm/lm/config.cc22
-rw-r--r--klm/lm/lm_exception.cc21
-rw-r--r--klm/lm/model.cc239
-rw-r--r--klm/lm/model_test.cc200
-rw-r--r--klm/lm/ngram.cc522
-rw-r--r--klm/lm/ngram_query.cc2
-rw-r--r--klm/lm/read_arpa.cc154
-rw-r--r--klm/lm/search_hashed.cc66
-rw-r--r--klm/lm/search_trie.cc402
-rw-r--r--klm/lm/sri.cc5
-rw-r--r--klm/lm/trie.cc167
-rw-r--r--klm/lm/virtual_interface.cc5
-rw-r--r--klm/lm/virtual_interface.hh12
-rw-r--r--klm/lm/vocab.cc187
-rw-r--r--klm/util/Makefile.am1
-rw-r--r--klm/util/bit_packing.cc27
-rw-r--r--klm/util/bit_packing.hh88
-rw-r--r--klm/util/ersatz_progress.cc9
-rw-r--r--klm/util/ersatz_progress.hh6
-rw-r--r--klm/util/file_piece.cc141
-rw-r--r--klm/util/file_piece.hh51
-rw-r--r--klm/util/file_piece_test.cc56
-rw-r--r--klm/util/joint_sort.hh6
-rw-r--r--klm/util/mmap.cc44
-rw-r--r--klm/util/mmap.hh31
-rw-r--r--klm/util/proxy_iterator.hh2
-rw-r--r--klm/util/scoped.cc4
-rw-r--r--klm/util/scoped.hh19
-rw-r--r--klm/util/sorted_uniform.hh4
-rw-r--r--klm/util/string_piece.cc8
-rw-r--r--klm/util/string_piece.hh12
34 files changed, 2097 insertions, 634 deletions
diff --git a/klm/lm/Makefile.am b/klm/lm/Makefile.am
index 48f9889e..eb71c0f5 100644
--- a/klm/lm/Makefile.am
+++ b/klm/lm/Makefile.am
@@ -7,11 +7,17 @@
noinst_LIBRARIES = libklm.a
libklm_a_SOURCES = \
- exception.cc \
- ngram.cc \
- ngram_build_binary.cc \
+ binary_format.cc \
+ config.cc \
+ lm_exception.cc \
+ model.cc \
ngram_query.cc \
- virtual_interface.cc
+ read_arpa.cc \
+ search_hashed.cc \
+ search_trie.cc \
+ trie.cc \
+ virtual_interface.cc \
+ vocab.cc
AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I..
diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc
new file mode 100644
index 00000000..2a075b6b
--- /dev/null
+++ b/klm/lm/binary_format.cc
@@ -0,0 +1,191 @@
+#include "lm/binary_format.hh"
+
+#include "lm/lm_exception.hh"
+#include "util/file_piece.hh"
+
+#include <limits>
+#include <string>
+
+#include <fcntl.h>
+#include <errno.h>
+#include <stdlib.h>
+#include <sys/mman.h>
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+namespace lm {
+namespace ngram {
+namespace {
+const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version";
+const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 1\n\0";
+const long int kMagicVersion = 1;
+
+// Test values.
+struct Sanity {
+ char magic[sizeof(kMagicBytes)];
+ float zero_f, one_f, minus_half_f;
+ WordIndex one_word_index, max_word_index;
+ uint64_t one_uint64;
+
+ void SetToReference() {
+ std::memcpy(magic, kMagicBytes, sizeof(magic));
+ zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5;
+ one_word_index = 1;
+ max_word_index = std::numeric_limits<WordIndex>::max();
+ one_uint64 = 1;
+ }
+};
+
+const char *kModelNames[3] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "bit packed trie"};
+
+std::size_t Align8(std::size_t in) {
+ std::size_t off = in % 8;
+ if (!off) return in;
+ return in + 8 - off;
+}
+
+std::size_t TotalHeaderSize(unsigned char order) {
+ return Align8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order);
+}
+
+void ReadLoop(int fd, void *to_void, std::size_t size) {
+ uint8_t *to = static_cast<uint8_t*>(to_void);
+ while (size) {
+ ssize_t ret = read(fd, to, size);
+ if (ret == -1) UTIL_THROW(util::ErrnoException, "Failed to read from binary file");
+ if (ret == 0) UTIL_THROW(util::ErrnoException, "Binary file too short");
+ to += ret;
+ size -= ret;
+ }
+}
+
+void WriteHeader(void *to, const Parameters &params) {
+ Sanity header = Sanity();
+ header.SetToReference();
+ memcpy(to, &header, sizeof(Sanity));
+ char *out = reinterpret_cast<char*>(to) + sizeof(Sanity);
+
+ *reinterpret_cast<FixedWidthParameters*>(out) = params.fixed;
+ out += sizeof(FixedWidthParameters);
+
+ uint64_t *counts = reinterpret_cast<uint64_t*>(out);
+ for (std::size_t i = 0; i < params.counts.size(); ++i) {
+ counts[i] = params.counts[i];
+ }
+}
+
+} // namespace
+namespace detail {
+
+bool IsBinaryFormat(int fd) {
+ const off_t size = util::SizeFile(fd);
+ if (size == util::kBadSize || (size <= static_cast<off_t>(sizeof(Sanity)))) return false;
+ // Try reading the header.
+ util::scoped_memory memory;
+ try {
+ util::MapRead(util::LAZY, fd, 0, sizeof(Sanity), memory);
+ } catch (const util::Exception &e) {
+ return false;
+ }
+ Sanity reference_header = Sanity();
+ reference_header.SetToReference();
+ if (!memcmp(memory.get(), &reference_header, sizeof(Sanity))) return true;
+ if (!memcmp(memory.get(), kMagicBeforeVersion, strlen(kMagicBeforeVersion))) {
+ char *end_ptr;
+ const char *begin_version = static_cast<const char*>(memory.get()) + strlen(kMagicBeforeVersion);
+ long int version = strtol(begin_version, &end_ptr, 10);
+ if ((end_ptr != begin_version) && version != kMagicVersion) {
+ UTIL_THROW(FormatLoadException, "Binary file has version " << version << " but this implementation expects version " << kMagicVersion << " so you'll have to rebuild your binary LM from the ARPA. Sorry.");
+ }
+ UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Try rebuilding the binary format LM using the same code revision, compiler, and architecture.");
+ }
+ return false;
+}
+
+void ReadHeader(int fd, Parameters &out) {
+ if ((off_t)-1 == lseek(fd, sizeof(Sanity), SEEK_SET)) UTIL_THROW(util::ErrnoException, "Seek failed in binary file");
+ ReadLoop(fd, &out.fixed, sizeof(out.fixed));
+ if (out.fixed.probing_multiplier < 1.0)
+ UTIL_THROW(FormatLoadException, "Binary format claims to have a probing multiplier of " << out.fixed.probing_multiplier << " which is < 1.0.");
+
+ out.counts.resize(static_cast<std::size_t>(out.fixed.order));
+ ReadLoop(fd, &*out.counts.begin(), sizeof(uint64_t) * out.fixed.order);
+}
+
+void MatchCheck(ModelType model_type, const Parameters &params) {
+ if (params.fixed.model_type != model_type) {
+ if (static_cast<unsigned int>(params.fixed.model_type) >= (sizeof(kModelNames) / sizeof(const char *)))
+ UTIL_THROW(FormatLoadException, "The binary file claims to be model type " << static_cast<unsigned int>(params.fixed.model_type) << " but this is not implemented for in this inference code.");
+ UTIL_THROW(FormatLoadException, "The binary file was built for " << kModelNames[params.fixed.model_type] << " but the inference code is trying to load " << kModelNames[model_type]);
+ }
+}
+
+uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t memory_size, Backing &backing) {
+ const off_t file_size = util::SizeFile(backing.file.get());
+ // The header is smaller than a page, so we have to map the whole header as well.
+ std::size_t total_map = TotalHeaderSize(params.counts.size()) + memory_size;
+ if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map)
+ UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map);
+
+ util::MapRead(config.load_method, backing.file.get(), 0, total_map, backing.memory);
+
+ if (config.enumerate_vocab && !params.fixed.has_vocabulary)
+ UTIL_THROW(FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary.");
+
+ if (config.enumerate_vocab) {
+ if ((off_t)-1 == lseek(backing.file.get(), total_map, SEEK_SET))
+ UTIL_THROW(util::ErrnoException, "Failed to seek in binary file to vocab words");
+ }
+ return reinterpret_cast<uint8_t*>(backing.memory.get()) + TotalHeaderSize(params.counts.size());
+}
+
+uint8_t *SetupZeroed(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, std::size_t memory_size, Backing &backing) {
+ if (config.probing_multiplier <= 1.0) UTIL_THROW(FormatLoadException, "probing multiplier must be > 1.0");
+ if (config.write_mmap) {
+ std::size_t total_map = TotalHeaderSize(counts.size()) + memory_size;
+ // Write out an mmap file.
+ backing.memory.reset(util::MapZeroedWrite(config.write_mmap, total_map, backing.file), total_map, util::scoped_memory::MMAP_ALLOCATED);
+
+ Parameters params;
+ params.counts = counts;
+ params.fixed.order = counts.size();
+ params.fixed.probing_multiplier = config.probing_multiplier;
+ params.fixed.model_type = model_type;
+ params.fixed.has_vocabulary = config.include_vocab;
+
+ WriteHeader(backing.memory.get(), params);
+
+ if (params.fixed.has_vocabulary) {
+ if ((off_t)-1 == lseek(backing.file.get(), total_map, SEEK_SET))
+ UTIL_THROW(util::ErrnoException, "Failed to seek in binary file " << config.write_mmap << " to vocab words");
+ }
+ return reinterpret_cast<uint8_t*>(backing.memory.get()) + TotalHeaderSize(counts.size());
+ } else {
+ backing.memory.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED);
+ return reinterpret_cast<uint8_t*>(backing.memory.get());
+ }
+}
+
+void ComplainAboutARPA(const Config &config, ModelType model_type) {
+ if (config.write_mmap || !config.messages) return;
+ if (config.arpa_complain == Config::ALL) {
+ *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl;
+ } else if (config.arpa_complain == Config::EXPENSIVE && model_type == TRIE_SORTED) {
+ *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl;
+ }
+}
+
+} // namespace detail
+
+bool RecognizeBinary(const char *file, ModelType &recognized) {
+ util::scoped_fd fd(util::OpenReadOrThrow(file));
+ if (!detail::IsBinaryFormat(fd.get())) return false;
+ Parameters params;
+ detail::ReadHeader(fd.get(), params);
+ recognized = params.fixed.model_type;
+ return true;
+}
+
+} // namespace ngram
+} // namespace lm
diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc
new file mode 100644
index 00000000..4db631a2
--- /dev/null
+++ b/klm/lm/build_binary.cc
@@ -0,0 +1,13 @@
+#include "lm/model.hh"
+
+#include <iostream>
+
+int main(int argc, char *argv[]) {
+ if (argc != 3) {
+ std::cerr << "Usage: " << argv[0] << " input.arpa output.mmap" << std::endl;
+ return 1;
+ }
+ lm::ngram::Config config;
+ config.write_mmap = argv[2];
+ lm::ngram::Model(argv[1], config);
+}
diff --git a/klm/lm/config.cc b/klm/lm/config.cc
new file mode 100644
index 00000000..2831d578
--- /dev/null
+++ b/klm/lm/config.cc
@@ -0,0 +1,22 @@
+#include "lm/config.hh"
+
+#include <iostream>
+
+namespace lm {
+namespace ngram {
+
+Config::Config() :
+ messages(&std::cerr),
+ enumerate_vocab(NULL),
+ unknown_missing(COMPLAIN),
+ unknown_missing_prob(0.0),
+ probing_multiplier(1.5),
+ building_memory(1073741824ULL), // 1 GB
+ temporary_directory_prefix(NULL),
+ arpa_complain(ALL),
+ write_mmap(NULL),
+ include_vocab(true),
+ load_method(util::POPULATE_OR_READ) {}
+
+} // namespace ngram
+} // namespace lm
diff --git a/klm/lm/lm_exception.cc b/klm/lm/lm_exception.cc
new file mode 100644
index 00000000..ab2ec52f
--- /dev/null
+++ b/klm/lm/lm_exception.cc
@@ -0,0 +1,21 @@
+#include "lm/lm_exception.hh"
+
+#include<errno.h>
+#include<stdio.h>
+
+namespace lm {
+
+LoadException::LoadException() throw() {}
+LoadException::~LoadException() throw() {}
+VocabLoadException::VocabLoadException() throw() {}
+VocabLoadException::~VocabLoadException() throw() {}
+
+FormatLoadException::FormatLoadException() throw() {}
+FormatLoadException::~FormatLoadException() throw() {}
+
+SpecialWordMissingException::SpecialWordMissingException(StringPiece which) throw() {
+ *this << "Missing special word " << which;
+}
+SpecialWordMissingException::~SpecialWordMissingException() throw() {}
+
+} // namespace lm
diff --git a/klm/lm/model.cc b/klm/lm/model.cc
new file mode 100644
index 00000000..6921d4d9
--- /dev/null
+++ b/klm/lm/model.cc
@@ -0,0 +1,239 @@
+#include "lm/model.hh"
+
+#include "lm/lm_exception.hh"
+#include "lm/search_hashed.hh"
+#include "lm/search_trie.hh"
+#include "lm/read_arpa.hh"
+#include "util/murmur_hash.hh"
+
+#include <algorithm>
+#include <functional>
+#include <numeric>
+#include <cmath>
+
+namespace lm {
+namespace ngram {
+
+size_t hash_value(const State &state) {
+ return util::MurmurHashNative(state.history_, sizeof(WordIndex) * state.valid_length_);
+}
+
+namespace detail {
+
+template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) {
+ if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit ngram.hh's kMaxOrder to at least this value and recompile.");
+ if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
+ return VocabularyT::Size(counts[0], config) + Search::Size(counts, config);
+}
+
+template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::SetupMemory(void *base, const std::vector<uint64_t> &counts, const Config &config) {
+ uint8_t *start = static_cast<uint8_t*>(base);
+ size_t allocated = VocabularyT::Size(counts[0], config);
+ vocab_.SetupMemory(start, allocated, counts[0], config);
+ start += allocated;
+ start = search_.SetupMemory(start, counts, config);
+ if (static_cast<std::size_t>(start - static_cast<uint8_t*>(base)) != Size(counts, config)) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast<uint8_t*>(base)) << " but Size says they should take " << Size(counts, config));
+}
+
+template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) {
+ LoadLM(file, config, *this);
+
+ // g++ prints warnings unless these are fully initialized.
+ State begin_sentence = State();
+ begin_sentence.valid_length_ = 1;
+ begin_sentence.history_[0] = vocab_.BeginSentence();
+ begin_sentence.backoff_[0] = search_.unigram.Lookup(begin_sentence.history_[0]).backoff;
+ State null_context = State();
+ null_context.valid_length_ = 0;
+ P::Init(begin_sentence, null_context, vocab_, search_.middle.size() + 2);
+}
+
+template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd) {
+ SetupMemory(start, params.counts, config);
+ vocab_.LoadedBinary(fd, config.enumerate_vocab);
+ search_.unigram.LoadedBinary();
+ for (typename std::vector<Middle>::iterator i = search_.middle.begin(); i != search_.middle.end(); ++i) {
+ i->LoadedBinary();
+ }
+ search_.longest.LoadedBinary();
+}
+
+template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, util::FilePiece &f, void *start, const Parameters &params, const Config &config) {
+ SetupMemory(start, params.counts, config);
+
+ if (config.write_mmap) {
+ WriteWordsWrapper wrap(config.enumerate_vocab, backing_.file.get());
+ vocab_.ConfigureEnumerate(&wrap, params.counts[0]);
+ search_.InitializeFromARPA(file, f, params.counts, config, vocab_);
+ } else {
+ vocab_.ConfigureEnumerate(config.enumerate_vocab, params.counts[0]);
+ search_.InitializeFromARPA(file, f, params.counts, config, vocab_);
+ }
+ // TODO: fail faster?
+ if (!vocab_.SawUnk()) {
+ switch(config.unknown_missing) {
+ case Config::THROW_UP:
+ {
+ SpecialWordMissingException e("<unk>");
+ e << " and configuration was set to throw if unknown is missing";
+ throw e;
+ }
+ case Config::COMPLAIN:
+ if (config.messages) *config.messages << "Language model is missing <unk>. Substituting probability " << config.unknown_missing_prob << "." << std::endl;
+ // There's no break;. This is by design.
+ case Config::SILENT:
+ // Default probabilities for unknown.
+ search_.unigram.Unknown().backoff = 0.0;
+ search_.unigram.Unknown().prob = config.unknown_missing_prob;
+ break;
+ }
+ }
+ if (std::fabs(search_.unigram.Unknown().backoff) > 0.0000001) UTIL_THROW(FormatLoadException, "Backoff for unknown word should be zero, but was given as " << search_.unigram.Unknown().backoff);
+}
+
+template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const {
+ unsigned char backoff_start;
+ FullScoreReturn ret = ScoreExceptBackoff(in_state.history_, in_state.history_ + in_state.valid_length_, new_word, backoff_start, out_state);
+ if (backoff_start - 1 < in_state.valid_length_) {
+ ret.prob = std::accumulate(in_state.backoff_ + backoff_start - 1, in_state.backoff_ + in_state.valid_length_, ret.prob);
+ }
+ return ret;
+}
+
+template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const {
+ unsigned char backoff_start;
+ context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);
+ FullScoreReturn ret = ScoreExceptBackoff(context_rbegin, context_rend, new_word, backoff_start, out_state);
+ ret.prob += SlowBackoffLookup(context_rbegin, context_rend, backoff_start);
+ return ret;
+}
+
+template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const {
+ context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);
+ if (context_rend == context_rbegin || *context_rbegin == 0) {
+ out_state.valid_length_ = 0;
+ return;
+ }
+ float ignored_prob;
+ typename Search::Node node;
+ search_.LookupUnigram(*context_rbegin, ignored_prob, out_state.backoff_[0], node);
+ float *backoff_out = out_state.backoff_ + 1;
+ const WordIndex *i = context_rbegin + 1;
+ for (; i < context_rend; ++i, ++backoff_out) {
+ if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, *backoff_out, node)) {
+ out_state.valid_length_ = i - context_rbegin;
+ std::copy(context_rbegin, i, out_state.history_);
+ return;
+ }
+ }
+ std::copy(context_rbegin, context_rend, out_state.history_);
+ out_state.valid_length_ = static_cast<unsigned char>(context_rend - context_rbegin);
+}
+
+template <class Search, class VocabularyT> float GenericModel<Search, VocabularyT>::SlowBackoffLookup(
+ const WordIndex *const context_rbegin, const WordIndex *const context_rend, unsigned char start) const {
+ // Add the backoff weights for n-grams of order start to (context_rend - context_rbegin).
+ if (context_rend - context_rbegin < static_cast<std::ptrdiff_t>(start)) return 0.0;
+ float ret = 0.0;
+ if (start == 1) {
+ ret += search_.unigram.Lookup(*context_rbegin).backoff;
+ start = 2;
+ }
+ typename Search::Node node;
+ if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) {
+ return 0.0;
+ }
+ float backoff;
+ // i is the order of the backoff we're looking for.
+ for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i) {
+ if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, backoff, node)) break;
+ ret += backoff;
+ }
+ return ret;
+}
+
+/* Ugly optimized function. Produce a score excluding backoff.
+ * The search goes in increasing order of ngram length.
+ * Context goes backward, so context_begin is the word immediately preceeding
+ * new_word.
+ */
+template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::ScoreExceptBackoff(
+ const WordIndex *context_rbegin,
+ const WordIndex *context_rend,
+ const WordIndex new_word,
+ unsigned char &backoff_start,
+ State &out_state) const {
+ FullScoreReturn ret;
+ typename Search::Node node;
+ float *backoff_out(out_state.backoff_);
+ search_.LookupUnigram(new_word, ret.prob, *backoff_out, node);
+ if (new_word == 0) {
+ ret.ngram_length = out_state.valid_length_ = 0;
+ // All of backoff.
+ backoff_start = 1;
+ return ret;
+ }
+ out_state.history_[0] = new_word;
+ if (context_rbegin == context_rend) {
+ ret.ngram_length = out_state.valid_length_ = 1;
+ // No backoff because we don't have the history for it.
+ backoff_start = P::Order();
+ return ret;
+ }
+ ++backoff_out;
+
+ // Ok now we now that the bigram contains known words. Start by looking it up.
+
+ const WordIndex *hist_iter = context_rbegin;
+ typename std::vector<Middle>::const_iterator mid_iter = search_.middle.begin();
+ for (; ; ++mid_iter, ++hist_iter, ++backoff_out) {
+ if (hist_iter == context_rend) {
+ // Ran out of history. No backoff.
+ backoff_start = P::Order();
+ std::copy(context_rbegin, context_rend, out_state.history_ + 1);
+ ret.ngram_length = out_state.valid_length_ = (context_rend - context_rbegin) + 1;
+ // ret.prob was already set.
+ return ret;
+ }
+
+ if (mid_iter == search_.middle.end()) break;
+
+ if (!search_.LookupMiddle(*mid_iter, *hist_iter, ret.prob, *backoff_out, node)) {
+ // Didn't find an ngram using hist_iter.
+ // The history used in the found n-gram is [context_rbegin, hist_iter).
+ std::copy(context_rbegin, hist_iter, out_state.history_ + 1);
+ // Therefore, we found a (hist_iter - context_rbegin + 1)-gram including the last word.
+ ret.ngram_length = out_state.valid_length_ = (hist_iter - context_rbegin) + 1;
+ backoff_start = mid_iter - search_.middle.begin() + 1;
+ // ret.prob was already set.
+ return ret;
+ }
+ }
+
+ // It passed every lookup in search_.middle. That means it's at least a (P::Order() - 1)-gram.
+ // All that's left is to check search_.longest.
+
+ if (!search_.LookupLongest(*hist_iter, ret.prob, node)) {
+ // It's an (P::Order()-1)-gram
+ std::copy(context_rbegin, context_rbegin + P::Order() - 2, out_state.history_ + 1);
+ ret.ngram_length = out_state.valid_length_ = P::Order() - 1;
+ backoff_start = P::Order() - 1;
+ // ret.prob was already set.
+ return ret;
+ }
+ // It's an P::Order()-gram
+ // out_state.valid_length_ is still P::Order() - 1 because the next lookup will only need that much.
+ std::copy(context_rbegin, context_rbegin + P::Order() - 2, out_state.history_ + 1);
+ out_state.valid_length_ = P::Order() - 1;
+ ret.ngram_length = P::Order();
+ backoff_start = P::Order();
+ return ret;
+}
+
+template class GenericModel<ProbingHashedSearch, ProbingVocabulary>;
+template class GenericModel<SortedHashedSearch, SortedVocabulary>;
+template class GenericModel<trie::TrieSearch, SortedVocabulary>;
+
+} // namespace detail
+} // namespace ngram
+} // namespace lm
diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc
new file mode 100644
index 00000000..159628d4
--- /dev/null
+++ b/klm/lm/model_test.cc
@@ -0,0 +1,200 @@
+#include "lm/model.hh"
+
+#include <stdlib.h>
+
+#define BOOST_TEST_MODULE ModelTest
+#include <boost/test/unit_test.hpp>
+
+namespace lm {
+namespace ngram {
+namespace {
+
+#define StartTest(word, ngram, score) \
+ ret = model.FullScore( \
+ state, \
+ model.GetVocabulary().Index(word), \
+ out);\
+ BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \
+ BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
+ BOOST_CHECK_EQUAL(std::min<unsigned char>(ngram, 5 - 1), out.valid_length_);
+
+#define AppendTest(word, ngram, score) \
+ StartTest(word, ngram, score) \
+ state = out;
+
+template <class M> void Starters(const M &model) {
+ FullScoreReturn ret;
+ Model::State state(model.BeginSentenceState());
+ Model::State out;
+
+ StartTest("looking", 2, -0.4846522);
+
+ // , probability plus <s> backoff
+ StartTest(",", 1, -1.383514 + -0.4149733);
+ // <unk> probability plus <s> backoff
+ StartTest("this_is_not_found", 0, -1.995635 + -0.4149733);
+}
+
+template <class M> void Continuation(const M &model) {
+ FullScoreReturn ret;
+ Model::State state(model.BeginSentenceState());
+ Model::State out;
+
+ AppendTest("looking", 2, -0.484652);
+ AppendTest("on", 3, -0.348837);
+ AppendTest("a", 4, -0.0155266);
+ AppendTest("little", 5, -0.00306122);
+ State preserve = state;
+ AppendTest("the", 1, -4.04005);
+ AppendTest("biarritz", 1, -1.9889);
+ AppendTest("not_found", 0, -2.29666);
+ AppendTest("more", 1, -1.20632);
+ AppendTest(".", 2, -0.51363);
+ AppendTest("</s>", 3, -0.0191651);
+
+ state = preserve;
+ AppendTest("more", 5, -0.00181395);
+ AppendTest("loin", 5, -0.0432557);
+}
+
+#define StatelessTest(word, provide, ngram, score) \
+ ret = model.FullScoreForgotState(indices + num_words - word, indices + num_words - word + provide, indices[num_words - word - 1], state); \
+ BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \
+ BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
+ model.GetState(indices + num_words - word, indices + num_words - word + provide, before); \
+ ret = model.FullScore(before, indices[num_words - word - 1], out); \
+ BOOST_CHECK(state == out); \
+ BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \
+ BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length);
+
+template <class M> void Stateless(const M &model) {
+ const char *words[] = {"<s>", "looking", "on", "a", "little", "the", "biarritz", "not_found", "more", ".", "</s>"};
+ const size_t num_words = sizeof(words) / sizeof(const char*);
+ // Silience "array subscript is above array bounds" when extracting end pointer.
+ WordIndex indices[num_words + 1];
+ for (unsigned int i = 0; i < num_words; ++i) {
+ indices[num_words - 1 - i] = model.GetVocabulary().Index(words[i]);
+ }
+ FullScoreReturn ret;
+ State state, out, before;
+
+ ret = model.FullScoreForgotState(indices + num_words - 1, indices + num_words, indices[num_words - 2], state);
+ BOOST_CHECK_CLOSE(-0.484652, ret.prob, 0.001);
+ StatelessTest(1, 1, 2, -0.484652);
+
+ // looking
+ StatelessTest(1, 2, 2, -0.484652);
+ // on
+ AppendTest("on", 3, -0.348837);
+ StatelessTest(2, 3, 3, -0.348837);
+ StatelessTest(2, 2, 3, -0.348837);
+ StatelessTest(2, 1, 2, -0.4638903);
+ // a
+ StatelessTest(3, 4, 4, -0.0155266);
+ // little
+ AppendTest("little", 5, -0.00306122);
+ StatelessTest(4, 5, 5, -0.00306122);
+ // the
+ AppendTest("the", 1, -4.04005);
+ StatelessTest(5, 5, 1, -4.04005);
+ // No context of the.
+ StatelessTest(5, 0, 1, -1.687872);
+ // biarritz
+ StatelessTest(6, 1, 1, -1.9889);
+ // not found
+ StatelessTest(7, 1, 0, -2.29666);
+ StatelessTest(7, 0, 0, -1.995635);
+
+ WordIndex unk[1];
+ unk[0] = 0;
+ model.GetState(unk, unk + 1, state);
+ BOOST_CHECK_EQUAL(0, state.valid_length_);
+}
+
+//const char *kExpectedOrderProbing[] = {"<unk>", ",", ".", "</s>", "<s>", "a", "also", "beyond", "biarritz", "call", "concerns", "consider", "considering", "for", "higher", "however", "i", "immediate", "in", "is", "little", "loin", "look", "looking", "more", "on", "screening", "small", "the", "to", "watch", "watching", "what", "would"};
+
+class ExpectEnumerateVocab : public EnumerateVocab {
+ public:
+ ExpectEnumerateVocab() {}
+
+ void Add(WordIndex index, const StringPiece &str) {
+ BOOST_CHECK_EQUAL(seen.size(), index);
+ seen.push_back(std::string(str.data(), str.length()));
+ }
+
+ void Check(const base::Vocabulary &vocab) {
+ BOOST_CHECK_EQUAL(34, seen.size());
+ BOOST_REQUIRE(!seen.empty());
+ BOOST_CHECK_EQUAL("<unk>", seen[0]);
+ for (WordIndex i = 0; i < seen.size(); ++i) {
+ BOOST_CHECK_EQUAL(i, vocab.Index(seen[i]));
+ }
+ }
+
+ void Clear() {
+ seen.clear();
+ }
+
+ std::vector<std::string> seen;
+};
+
+template <class ModelT> void LoadingTest() {
+ Config config;
+ config.arpa_complain = Config::NONE;
+ config.messages = NULL;
+ ExpectEnumerateVocab enumerate;
+ config.enumerate_vocab = &enumerate;
+ ModelT m("test.arpa", config);
+ enumerate.Check(m.GetVocabulary());
+ Starters(m);
+ Continuation(m);
+ Stateless(m);
+}
+
+BOOST_AUTO_TEST_CASE(probing) {
+ LoadingTest<Model>();
+}
+
+BOOST_AUTO_TEST_CASE(sorted) {
+ LoadingTest<SortedModel>();
+}
+BOOST_AUTO_TEST_CASE(trie) {
+ LoadingTest<TrieModel>();
+}
+
+template <class ModelT> void BinaryTest() {
+ Config config;
+ config.write_mmap = "test.binary";
+ config.messages = NULL;
+ ExpectEnumerateVocab enumerate;
+ config.enumerate_vocab = &enumerate;
+
+ {
+ ModelT copy_model("test.arpa", config);
+ enumerate.Check(copy_model.GetVocabulary());
+ enumerate.Clear();
+ }
+
+ config.write_mmap = NULL;
+
+ ModelT binary("test.binary", config);
+ enumerate.Check(binary.GetVocabulary());
+ Starters(binary);
+ Continuation(binary);
+ Stateless(binary);
+ unlink("test.binary");
+}
+
+BOOST_AUTO_TEST_CASE(write_and_read_probing) {
+ BinaryTest<Model>();
+}
+BOOST_AUTO_TEST_CASE(write_and_read_sorted) {
+ BinaryTest<SortedModel>();
+}
+BOOST_AUTO_TEST_CASE(write_and_read_trie) {
+ BinaryTest<TrieModel>();
+}
+
+} // namespace
+} // namespace ngram
+} // namespace lm
diff --git a/klm/lm/ngram.cc b/klm/lm/ngram.cc
deleted file mode 100644
index a87c82aa..00000000
--- a/klm/lm/ngram.cc
+++ /dev/null
@@ -1,522 +0,0 @@
-#include "lm/ngram.hh"
-
-#include "lm/exception.hh"
-#include "util/file_piece.hh"
-#include "util/joint_sort.hh"
-#include "util/murmur_hash.hh"
-#include "util/probing_hash_table.hh"
-
-#include <algorithm>
-#include <functional>
-#include <numeric>
-#include <limits>
-#include <string>
-
-#include <cmath>
-#include <fcntl.h>
-#include <errno.h>
-#include <stdlib.h>
-#include <sys/mman.h>
-#include <sys/types.h>
-#include <sys/stat.h>
-#include <unistd.h>
-
-namespace lm {
-namespace ngram {
-
-size_t hash_value(const State &state) {
- return util::MurmurHashNative(state.history_, sizeof(WordIndex) * state.valid_length_);
-}
-
-namespace detail {
-uint64_t HashForVocab(const char *str, std::size_t len) {
- // This proved faster than Boost's hash in speed trials: total load time Murmur 67090000, Boost 72210000
- // Chose to use 64A instead of native so binary format will be portable across 64 and 32 bit.
- return util::MurmurHash64A(str, len, 0);
-}
-
-void Prob::SetBackoff(float to) {
- UTIL_THROW(FormatLoadException, "Attempt to set backoff " << to << " for the highest order n-gram");
-}
-
-// Normally static initialization is a bad idea but MurmurHash is pure arithmetic, so this is ok.
-const uint64_t kUnknownHash = HashForVocab("<unk>", 5);
-// Sadly some LMs have <UNK>.
-const uint64_t kUnknownCapHash = HashForVocab("<UNK>", 5);
-
-} // namespace detail
-
-SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL) {}
-
-std::size_t SortedVocabulary::Size(std::size_t entries, float ignored) {
- // Lead with the number of entries.
- return sizeof(uint64_t) + sizeof(Entry) * entries;
-}
-
-void SortedVocabulary::Init(void *start, std::size_t allocated, std::size_t entries) {
- assert(allocated >= Size(entries));
- // Leave space for number of entries.
- begin_ = reinterpret_cast<Entry*>(reinterpret_cast<uint64_t*>(start) + 1);
- end_ = begin_;
- saw_unk_ = false;
-}
-
-WordIndex SortedVocabulary::Insert(const StringPiece &str) {
- uint64_t hashed = detail::HashForVocab(str);
- if (hashed == detail::kUnknownHash || hashed == detail::kUnknownCapHash) {
- saw_unk_ = true;
- return 0;
- }
- end_->key = hashed;
- ++end_;
- // This is 1 + the offset where it was inserted to make room for unk.
- return end_ - begin_;
-}
-
-bool SortedVocabulary::FinishedLoading(detail::ProbBackoff *reorder_vocab) {
- util::JointSort(begin_, end_, reorder_vocab + 1);
- SetSpecial(Index("<s>"), Index("</s>"), 0, end_ - begin_ + 1);
- // Save size.
- *(reinterpret_cast<uint64_t*>(begin_) - 1) = end_ - begin_;
- return saw_unk_;
-}
-
-void SortedVocabulary::LoadedBinary() {
- end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1);
- SetSpecial(Index("<s>"), Index("</s>"), 0, end_ - begin_ + 1);
-}
-
-namespace detail {
-
-template <class Search> MapVocabulary<Search>::MapVocabulary() {}
-
-template <class Search> void MapVocabulary<Search>::Init(void *start, std::size_t allocated, std::size_t entries) {
- lookup_ = Lookup(start, allocated);
- available_ = 1;
- // Later if available_ != expected_available_ then we can throw UnknownMissingException.
- saw_unk_ = false;
-}
-
-template <class Search> WordIndex MapVocabulary<Search>::Insert(const StringPiece &str) {
- uint64_t hashed = HashForVocab(str);
- // Prevent unknown from going into the table.
- if (hashed == kUnknownHash || hashed == kUnknownCapHash) {
- saw_unk_ = true;
- return 0;
- } else {
- lookup_.Insert(Lookup::Packing::Make(hashed, available_));
- return available_++;
- }
-}
-
-template <class Search> bool MapVocabulary<Search>::FinishedLoading(ProbBackoff *reorder_vocab) {
- lookup_.FinishedInserting();
- SetSpecial(Index("<s>"), Index("</s>"), 0, available_);
- return saw_unk_;
-}
-
-template <class Search> void MapVocabulary<Search>::LoadedBinary() {
- lookup_.LoadedBinary();
- SetSpecial(Index("<s>"), Index("</s>"), 0, available_);
-}
-
-/* All of the entropy is in low order bits and boost::hash does poorly with
- * these. Odd numbers near 2^64 chosen by mashing on the keyboard. There is a
- * stable point: 0. But 0 is <unk> which won't be queried here anyway.
- */
-inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) {
- uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast<uint64_t>(next) * 17894857484156487943ULL);
- return ret;
-}
-
-uint64_t ChainedWordHash(const WordIndex *word, const WordIndex *word_end) {
- if (word == word_end) return 0;
- uint64_t current = static_cast<uint64_t>(*word);
- for (++word; word != word_end; ++word) {
- current = CombineWordHash(current, *word);
- }
- return current;
-}
-
-bool IsEntirelyWhiteSpace(const StringPiece &line) {
- for (size_t i = 0; i < static_cast<size_t>(line.size()); ++i) {
- if (!isspace(line.data()[i])) return false;
- }
- return true;
-}
-
-void ReadARPACounts(util::FilePiece &in, std::vector<size_t> &number) {
- number.clear();
- StringPiece line;
- if (!IsEntirelyWhiteSpace(line = in.ReadLine())) UTIL_THROW(FormatLoadException, "First line was \"" << line << "\" not blank");
- if ((line = in.ReadLine()) != "\\data\\") UTIL_THROW(FormatLoadException, "second line was \"" << line << "\" not \\data\\.");
- while (!IsEntirelyWhiteSpace(line = in.ReadLine())) {
- if (line.size() < 6 || strncmp(line.data(), "ngram ", 6)) UTIL_THROW(FormatLoadException, "count line \"" << line << "\"doesn't begin with \"ngram \"");
- // So strtol doesn't go off the end of line.
- std::string remaining(line.data() + 6, line.size() - 6);
- char *end_ptr;
- unsigned long int length = std::strtol(remaining.c_str(), &end_ptr, 10);
- if ((end_ptr == remaining.c_str()) || (length - 1 != number.size())) UTIL_THROW(FormatLoadException, "ngram count lengths should be consecutive starting with 1: " << line);
- if (*end_ptr != '=') UTIL_THROW(FormatLoadException, "Expected = immediately following the first number in the count line " << line);
- ++end_ptr;
- const char *start = end_ptr;
- long int count = std::strtol(start, &end_ptr, 10);
- if (count < 0) UTIL_THROW(FormatLoadException, "Negative n-gram count " << count);
- if (start == end_ptr) UTIL_THROW(FormatLoadException, "Couldn't parse n-gram count from " << line);
- number.push_back(count);
- }
-}
-
-void ReadNGramHeader(util::FilePiece &in, unsigned int length) {
- StringPiece line;
- while (IsEntirelyWhiteSpace(line = in.ReadLine())) {}
- std::stringstream expected;
- expected << '\\' << length << "-grams:";
- if (line != expected.str()) UTIL_THROW(FormatLoadException, "Was expecting n-gram header " << expected.str() << " but got " << line << " instead.");
-}
-
-// Special unigram reader because unigram's data structure is different and because we're inserting vocab words.
-template <class Voc> void Read1Grams(util::FilePiece &f, const size_t count, Voc &vocab, ProbBackoff *unigrams) {
- ReadNGramHeader(f, 1);
- for (size_t i = 0; i < count; ++i) {
- try {
- float prob = f.ReadFloat();
- if (f.get() != '\t') UTIL_THROW(FormatLoadException, "Expected tab after probability");
- ProbBackoff &value = unigrams[vocab.Insert(f.ReadDelimited())];
- value.prob = prob;
- switch (f.get()) {
- case '\t':
- value.SetBackoff(f.ReadFloat());
- if ((f.get() != '\n')) UTIL_THROW(FormatLoadException, "Expected newline after backoff");
- break;
- case '\n':
- value.ZeroBackoff();
- break;
- default:
- UTIL_THROW(FormatLoadException, "Expected tab or newline after unigram");
- }
- } catch(util::Exception &e) {
- e << " in the " << i << "th 1-gram at byte " << f.Offset();
- throw;
- }
- }
- if (f.ReadLine().size()) UTIL_THROW(FormatLoadException, "Expected blank line after unigrams at byte " << f.Offset());
-}
-
-template <class Voc, class Store> void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, Store &store) {
- ReadNGramHeader(f, n);
-
- // vocab ids of words in reverse order
- WordIndex vocab_ids[n];
- typename Store::Packing::Value value;
- for (size_t i = 0; i < count; ++i) {
- try {
- value.prob = f.ReadFloat();
- for (WordIndex *vocab_out = &vocab_ids[n-1]; vocab_out >= vocab_ids; --vocab_out) {
- *vocab_out = vocab.Index(f.ReadDelimited());
- }
- uint64_t key = ChainedWordHash(vocab_ids, vocab_ids + n);
-
- switch (f.get()) {
- case '\t':
- value.SetBackoff(f.ReadFloat());
- if ((f.get() != '\n')) UTIL_THROW(FormatLoadException, "Expected newline after backoff");
- break;
- case '\n':
- value.ZeroBackoff();
- break;
- default:
- UTIL_THROW(FormatLoadException, "Expected tab or newline after n-gram");
- }
- store.Insert(Store::Packing::Make(key, value));
- } catch(util::Exception &e) {
- e << " in the " << i << "th " << n << "-gram at byte " << f.Offset();
- throw;
- }
- }
-
- if (f.ReadLine().size()) UTIL_THROW(FormatLoadException, "Expected blank line after " << n << "-grams at byte " << f.Offset());
- store.FinishedInserting();
-}
-
-template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<size_t> &counts, const Config &config) {
- if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit ngram.hh's kMaxOrder to at least this value and recompile.");
- if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
- size_t memory_size = VocabularyT::Size(counts[0], config.probing_multiplier);
- memory_size += sizeof(ProbBackoff) * (counts[0] + 1); // +1 for hallucinate <unk>
- for (unsigned char n = 2; n < counts.size(); ++n) {
- memory_size += Middle::Size(counts[n - 1], config.probing_multiplier);
- }
- memory_size += Longest::Size(counts.back(), config.probing_multiplier);
- return memory_size;
-}
-
-template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::SetupMemory(char *base, const std::vector<size_t> &counts, const Config &config) {
- char *start = base;
- size_t allocated = VocabularyT::Size(counts[0], config.probing_multiplier);
- vocab_.Init(start, allocated, counts[0]);
- start += allocated;
- unigram_ = reinterpret_cast<ProbBackoff*>(start);
- start += sizeof(ProbBackoff) * (counts[0] + 1);
- for (unsigned int n = 2; n < counts.size(); ++n) {
- allocated = Middle::Size(counts[n - 1], config.probing_multiplier);
- middle_.push_back(Middle(start, allocated));
- start += allocated;
- }
- allocated = Longest::Size(counts.back(), config.probing_multiplier);
- longest_ = Longest(start, allocated);
- start += allocated;
- if (static_cast<std::size_t>(start - base) != Size(counts, config)) UTIL_THROW(FormatLoadException, "The data structures took " << (start - base) << " but Size says they should take " << Size(counts, config));
-}
-
-const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 0\n\0";
-struct BinaryFileHeader {
- char magic[sizeof(kMagicBytes)];
- float zero_f, one_f, minus_half_f;
- WordIndex one_word_index, max_word_index;
- uint64_t one_uint64;
-
- void SetToReference() {
- std::memcpy(magic, kMagicBytes, sizeof(magic));
- zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5;
- one_word_index = 1;
- max_word_index = std::numeric_limits<WordIndex>::max();
- one_uint64 = 1;
- }
-};
-
-bool IsBinaryFormat(int fd, off_t size) {
- if (size == util::kBadSize || (size <= static_cast<off_t>(sizeof(BinaryFileHeader)))) return false;
- // Try reading the header.
- util::scoped_mmap memory(mmap(NULL, sizeof(BinaryFileHeader), PROT_READ, MAP_FILE | MAP_PRIVATE, fd, 0), sizeof(BinaryFileHeader));
- if (memory.get() == MAP_FAILED) return false;
- BinaryFileHeader reference_header = BinaryFileHeader();
- reference_header.SetToReference();
- if (!memcmp(memory.get(), &reference_header, sizeof(BinaryFileHeader))) return true;
- if (!memcmp(memory.get(), "mmap lm ", 8)) UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Was it built on a different machine or with a different compiler?");
- return false;
-}
-
-std::size_t Align8(std::size_t in) {
- std::size_t off = in % 8;
- if (!off) return in;
- return in + 8 - off;
-}
-
-std::size_t TotalHeaderSize(unsigned int order) {
- return Align8(sizeof(BinaryFileHeader) + 1 /* order */ + sizeof(uint64_t) * order /* counts */ + sizeof(float) /* probing multiplier */ + 1 /* search_tag */);
-}
-
-void ReadBinaryHeader(const void *from, off_t size, std::vector<size_t> &out, float &probing_multiplier, unsigned char &search_tag) {
- const char *from_char = reinterpret_cast<const char*>(from);
- if (size < static_cast<off_t>(1 + sizeof(BinaryFileHeader))) UTIL_THROW(FormatLoadException, "File too short to have count information.");
- // Skip over the BinaryFileHeader which was read by IsBinaryFormat.
- from_char += sizeof(BinaryFileHeader);
- unsigned char order = *reinterpret_cast<const unsigned char*>(from_char);
- if (size < static_cast<off_t>(TotalHeaderSize(order))) UTIL_THROW(FormatLoadException, "File too short to have full header.");
- out.resize(static_cast<std::size_t>(order));
- const uint64_t *counts = reinterpret_cast<const uint64_t*>(from_char + 1);
- for (std::size_t i = 0; i < out.size(); ++i) {
- out[i] = static_cast<std::size_t>(counts[i]);
- }
- const float *probing_ptr = reinterpret_cast<const float*>(counts + out.size());
- probing_multiplier = *probing_ptr;
- search_tag = *reinterpret_cast<const char*>(probing_ptr + 1);
-}
-
-void WriteBinaryHeader(void *to, const std::vector<size_t> &from, float probing_multiplier, char search_tag) {
- BinaryFileHeader header = BinaryFileHeader();
- header.SetToReference();
- memcpy(to, &header, sizeof(BinaryFileHeader));
- char *out = reinterpret_cast<char*>(to) + sizeof(BinaryFileHeader);
- *reinterpret_cast<unsigned char*>(out) = static_cast<unsigned char>(from.size());
- uint64_t *counts = reinterpret_cast<uint64_t*>(out + 1);
- for (std::size_t i = 0; i < from.size(); ++i) {
- counts[i] = from[i];
- }
- float *probing_ptr = reinterpret_cast<float*>(counts + from.size());
- *probing_ptr = probing_multiplier;
- *reinterpret_cast<char*>(probing_ptr + 1) = search_tag;
-}
-
-template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, Config config) : mapped_file_(util::OpenReadOrThrow(file)) {
- const off_t file_size = util::SizeFile(mapped_file_.get());
-
- std::vector<size_t> counts;
-
- if (IsBinaryFormat(mapped_file_.get(), file_size)) {
- memory_.reset(util::MapForRead(file_size, config.prefault, mapped_file_.get()), file_size);
-
- unsigned char search_tag;
- ReadBinaryHeader(memory_.begin(), file_size, counts, config.probing_multiplier, search_tag);
- if (config.probing_multiplier < 1.0) UTIL_THROW(FormatLoadException, "Binary format claims to have a probing multiplier of " << config.probing_multiplier << " which is < 1.0.");
- if (search_tag != Search::kBinaryTag) UTIL_THROW(FormatLoadException, "The binary file has a different search strategy than the one requested.");
- size_t memory_size = Size(counts, config);
-
- char *start = reinterpret_cast<char*>(memory_.get()) + TotalHeaderSize(counts.size());
- if (memory_size != static_cast<size_t>(memory_.end() - start)) UTIL_THROW(FormatLoadException, "The mmap file " << file << " has size " << file_size << " but " << (memory_size + TotalHeaderSize(counts.size())) << " was expected based on the number of counts and configuration.");
-
- SetupMemory(start, counts, config);
- vocab_.LoadedBinary();
- for (typename std::vector<Middle>::iterator i = middle_.begin(); i != middle_.end(); ++i) {
- i->LoadedBinary();
- }
- longest_.LoadedBinary();
-
- } else {
- if (config.probing_multiplier <= 1.0) UTIL_THROW(FormatLoadException, "probing multiplier must be > 1.0");
-
- util::FilePiece f(file, mapped_file_.release(), config.messages);
- ReadARPACounts(f, counts);
- size_t memory_size = Size(counts, config);
- char *start;
-
- if (config.write_mmap) {
- // Write out an mmap file.
- util::MapZeroedWrite(config.write_mmap, TotalHeaderSize(counts.size()) + memory_size, mapped_file_, memory_);
- WriteBinaryHeader(memory_.get(), counts, config.probing_multiplier, Search::kBinaryTag);
- start = reinterpret_cast<char*>(memory_.get()) + TotalHeaderSize(counts.size());
- } else {
- memory_.reset(util::MapAnonymous(memory_size), memory_size);
- start = reinterpret_cast<char*>(memory_.get());
- }
- SetupMemory(start, counts, config);
- try {
- LoadFromARPA(f, counts, config);
- } catch (FormatLoadException &e) {
- e << " in file " << file;
- throw;
- }
- }
-
- // g++ prints warnings unless these are fully initialized.
- State begin_sentence = State();
- begin_sentence.valid_length_ = 1;
- begin_sentence.history_[0] = vocab_.BeginSentence();
- begin_sentence.backoff_[0] = unigram_[begin_sentence.history_[0]].backoff;
- State null_context = State();
- null_context.valid_length_ = 0;
- P::Init(begin_sentence, null_context, vocab_, counts.size());
-}
-
-template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::LoadFromARPA(util::FilePiece &f, const std::vector<size_t> &counts, const Config &config) {
- // Read the unigrams.
- Read1Grams(f, counts[0], vocab_, unigram_);
- bool saw_unk = vocab_.FinishedLoading(unigram_);
- if (!saw_unk) {
- switch(config.unknown_missing) {
- case Config::THROW_UP:
- {
- SpecialWordMissingException e("<unk>");
- e << " and configuration was set to throw if unknown is missing";
- throw e;
- }
- case Config::COMPLAIN:
- if (config.messages) *config.messages << "Language model is missing <unk>. Substituting probability " << config.unknown_missing_prob << "." << std::endl;
- // There's no break;. This is by design.
- case Config::SILENT:
- // Default probabilities for unknown.
- unigram_[0].backoff = 0.0;
- unigram_[0].prob = config.unknown_missing_prob;
- break;
- }
- }
-
- // Read the n-grams.
- for (unsigned int n = 2; n < counts.size(); ++n) {
- ReadNGrams(f, n, counts[n-1], vocab_, middle_[n-2]);
- }
- ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab_, longest_);
- if (std::fabs(unigram_[0].backoff) > 0.0000001) UTIL_THROW(FormatLoadException, "Backoff for unknown word should be zero, but was given as " << unigram_[0].backoff);
-}
-
-/* Ugly optimized function.
- * in_state contains the previous ngram's length and backoff probabilites to
- * be used here. out_state is populated with the found ngram length and
- * backoffs that the next call will find useful.
- *
- * The search goes in increasing order of ngram length.
- */
-template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(
- const State &in_state,
- const WordIndex new_word,
- State &out_state) const {
-
- FullScoreReturn ret;
- // This is end pointer passed to SumBackoffs.
- const ProbBackoff &unigram = unigram_[new_word];
- if (new_word == 0) {
- ret.ngram_length = out_state.valid_length_ = 0;
- // all of backoff.
- ret.prob = std::accumulate(
- in_state.backoff_,
- in_state.backoff_ + in_state.valid_length_,
- unigram.prob);
- return ret;
- }
- float *backoff_out(out_state.backoff_);
- *backoff_out = unigram.backoff;
- ret.prob = unigram.prob;
- out_state.history_[0] = new_word;
- if (in_state.valid_length_ == 0) {
- ret.ngram_length = out_state.valid_length_ = 1;
- // No backoff because NGramLength() == 0 and unknown can't have backoff.
- return ret;
- }
- ++backoff_out;
-
- // Ok now we now that the bigram contains known words. Start by looking it up.
-
- uint64_t lookup_hash = static_cast<uint64_t>(new_word);
- const WordIndex *hist_iter = in_state.history_;
- const WordIndex *const hist_end = hist_iter + in_state.valid_length_;
- typename std::vector<Middle>::const_iterator mid_iter = middle_.begin();
- for (; ; ++mid_iter, ++hist_iter, ++backoff_out) {
- if (hist_iter == hist_end) {
- // Used history [in_state.history_, hist_end) and ran out. No backoff.
- std::copy(in_state.history_, hist_end, out_state.history_ + 1);
- ret.ngram_length = out_state.valid_length_ = in_state.valid_length_ + 1;
- // ret.prob was already set.
- return ret;
- }
- lookup_hash = CombineWordHash(lookup_hash, *hist_iter);
- if (mid_iter == middle_.end()) break;
- typename Middle::ConstIterator found;
- if (!mid_iter->Find(lookup_hash, found)) {
- // Didn't find an ngram using hist_iter.
- // The history used in the found n-gram is [in_state.history_, hist_iter).
- std::copy(in_state.history_, hist_iter, out_state.history_ + 1);
- // Therefore, we found a (hist_iter - in_state.history_ + 1)-gram including the last word.
- ret.ngram_length = out_state.valid_length_ = (hist_iter - in_state.history_) + 1;
- ret.prob = std::accumulate(
- in_state.backoff_ + (mid_iter - middle_.begin()),
- in_state.backoff_ + in_state.valid_length_,
- ret.prob);
- return ret;
- }
- *backoff_out = found->GetValue().backoff;
- ret.prob = found->GetValue().prob;
- }
-
- typename Longest::ConstIterator found;
- if (!longest_.Find(lookup_hash, found)) {
- // It's an (P::Order()-1)-gram
- std::copy(in_state.history_, in_state.history_ + P::Order() - 2, out_state.history_ + 1);
- ret.ngram_length = out_state.valid_length_ = P::Order() - 1;
- ret.prob += in_state.backoff_[P::Order() - 2];
- return ret;
- }
- // It's an P::Order()-gram
- // out_state.valid_length_ is still P::Order() - 1 because the next lookup will only need that much.
- std::copy(in_state.history_, in_state.history_ + P::Order() - 2, out_state.history_ + 1);
- out_state.valid_length_ = P::Order() - 1;
- ret.ngram_length = P::Order();
- ret.prob = found->GetValue().prob;
- return ret;
-}
-
-template class GenericModel<ProbingSearch, MapVocabulary<ProbingSearch> >;
-template class GenericModel<SortedUniformSearch, SortedVocabulary>;
-} // namespace detail
-} // namespace ngram
-} // namespace lm
diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc
index d1970260..74457a74 100644
--- a/klm/lm/ngram_query.cc
+++ b/klm/lm/ngram_query.cc
@@ -1,4 +1,4 @@
-#include "lm/ngram.hh"
+#include "lm/model.hh"
#include <cstdlib>
#include <fstream>
diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc
new file mode 100644
index 00000000..8e9a770d
--- /dev/null
+++ b/klm/lm/read_arpa.cc
@@ -0,0 +1,154 @@
+#include "lm/read_arpa.hh"
+
+#include <cstdlib>
+#include <vector>
+
+#include <ctype.h>
+#include <inttypes.h>
+
+namespace lm {
+
+namespace {
+
+bool IsEntirelyWhiteSpace(const StringPiece &line) {
+ for (size_t i = 0; i < static_cast<size_t>(line.size()); ++i) {
+ if (!isspace(line.data()[i])) return false;
+ }
+ return true;
+}
+
+template <class F> void GenericReadARPACounts(F &in, std::vector<uint64_t> &number) {
+ number.clear();
+ StringPiece line;
+ if (!IsEntirelyWhiteSpace(line = in.ReadLine())) {
+ if ((line.size() >= 2) && (line.data()[0] == 0x1f) && (static_cast<unsigned char>(line.data()[1]) == 0x8b)) {
+ UTIL_THROW(FormatLoadException, "Looks like a gzip file. If this is an ARPA file, run\nzcat " << in.FileName() << " |kenlm/build_binary /dev/stdin " << in.FileName() << ".binary\nIf this already in binary format, you need to decompress it because mmap doesn't work on top of gzip.");
+ }
+ UTIL_THROW(FormatLoadException, "First line was \"" << static_cast<int>(line.data()[1]) << "\" not blank");
+ }
+ if ((line = in.ReadLine()) != "\\data\\") UTIL_THROW(FormatLoadException, "second line was \"" << line << "\" not \\data\\.");
+ while (!IsEntirelyWhiteSpace(line = in.ReadLine())) {
+ if (line.size() < 6 || strncmp(line.data(), "ngram ", 6)) UTIL_THROW(FormatLoadException, "count line \"" << line << "\"doesn't begin with \"ngram \"");
+ // So strtol doesn't go off the end of line.
+ std::string remaining(line.data() + 6, line.size() - 6);
+ char *end_ptr;
+ unsigned long int length = std::strtol(remaining.c_str(), &end_ptr, 10);
+ if ((end_ptr == remaining.c_str()) || (length - 1 != number.size())) UTIL_THROW(FormatLoadException, "ngram count lengths should be consecutive starting with 1: " << line);
+ if (*end_ptr != '=') UTIL_THROW(FormatLoadException, "Expected = immediately following the first number in the count line " << line);
+ ++end_ptr;
+ const char *start = end_ptr;
+ long int count = std::strtol(start, &end_ptr, 10);
+ if (count < 0) UTIL_THROW(FormatLoadException, "Negative n-gram count " << count);
+ if (start == end_ptr) UTIL_THROW(FormatLoadException, "Couldn't parse n-gram count from " << line);
+ number.push_back(count);
+ }
+}
+
+template <class F> void GenericReadNGramHeader(F &in, unsigned int length) {
+ StringPiece line;
+ while (IsEntirelyWhiteSpace(line = in.ReadLine())) {}
+ std::stringstream expected;
+ expected << '\\' << length << "-grams:";
+ if (line != expected.str()) UTIL_THROW(FormatLoadException, "Was expecting n-gram header " << expected.str() << " but got " << line << " instead. ");
+}
+
+template <class F> void GenericReadEnd(F &in) {
+ StringPiece line;
+ do {
+ line = in.ReadLine();
+ } while (IsEntirelyWhiteSpace(line));
+ if (line != "\\end\\") UTIL_THROW(FormatLoadException, "Expected \\end\\ but the ARPA file has " << line);
+}
+
+class FakeFilePiece {
+ public:
+ explicit FakeFilePiece(std::istream &in) : in_(in) {
+ in_.exceptions(std::ios::failbit | std::ios::badbit | std::ios::eofbit);
+ }
+
+ StringPiece ReadLine() throw(util::EndOfFileException) {
+ getline(in_, buffer_);
+ return StringPiece(buffer_);
+ }
+
+ float ReadFloat() {
+ float ret;
+ in_ >> ret;
+ return ret;
+ }
+
+ const char *FileName() const {
+ // This only used for error messages and we don't know the file name. . .
+ return "$file";
+ }
+
+ private:
+ std::istream &in_;
+ std::string buffer_;
+};
+
+} // namespace
+
+void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) {
+ GenericReadARPACounts(in, number);
+}
+void ReadARPACounts(std::istream &in, std::vector<uint64_t> &number) {
+ FakeFilePiece fake(in);
+ GenericReadARPACounts(fake, number);
+}
+void ReadNGramHeader(util::FilePiece &in, unsigned int length) {
+ GenericReadNGramHeader(in, length);
+}
+void ReadNGramHeader(std::istream &in, unsigned int length) {
+ FakeFilePiece fake(in);
+ GenericReadNGramHeader(fake, length);
+}
+
+void ReadBackoff(util::FilePiece &in, Prob &/*weights*/) {
+ switch (in.get()) {
+ case '\t':
+ {
+ float got = in.ReadFloat();
+ if (got != 0.0)
+ UTIL_THROW(FormatLoadException, "Non-zero backoff " << got << " provided for an n-gram that should have no backoff.");
+ }
+ break;
+ case '\n':
+ break;
+ default:
+ UTIL_THROW(FormatLoadException, "Expected tab or newline after unigram");
+ }
+}
+
+void ReadBackoff(util::FilePiece &in, ProbBackoff &weights) {
+ switch (in.get()) {
+ case '\t':
+ weights.backoff = in.ReadFloat();
+ if ((in.get() != '\n')) UTIL_THROW(FormatLoadException, "Expected newline after backoff");
+ break;
+ case '\n':
+ weights.backoff = 0.0;
+ break;
+ default:
+ UTIL_THROW(FormatLoadException, "Expected tab or newline after unigram");
+ }
+}
+
+void ReadEnd(util::FilePiece &in) {
+ GenericReadEnd(in);
+ StringPiece line;
+ try {
+ while (true) {
+ line = in.ReadLine();
+ if (!IsEntirelyWhiteSpace(line)) UTIL_THROW(FormatLoadException, "Trailing line " << line);
+ }
+ } catch (const util::EndOfFileException &e) {
+ return;
+ }
+}
+void ReadEnd(std::istream &in) {
+ FakeFilePiece fake(in);
+ GenericReadEnd(fake);
+}
+
+} // namespace lm
diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc
new file mode 100644
index 00000000..9cb662a6
--- /dev/null
+++ b/klm/lm/search_hashed.cc
@@ -0,0 +1,66 @@
+#include "lm/search_hashed.hh"
+
+#include "lm/lm_exception.hh"
+#include "lm/read_arpa.hh"
+#include "lm/vocab.hh"
+
+#include "util/file_piece.hh"
+
+#include <string>
+
+namespace lm {
+namespace ngram {
+
+namespace {
+
+/* All of the entropy is in low order bits and boost::hash does poorly with
+ * these. Odd numbers near 2^64 chosen by mashing on the keyboard. There is a
+ * stable point: 0. But 0 is <unk> which won't be queried here anyway.
+ */
+inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) {
+ uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast<uint64_t>(next) * 17894857484156487943ULL);
+ return ret;
+}
+
+uint64_t ChainedWordHash(const WordIndex *word, const WordIndex *word_end) {
+ if (word == word_end) return 0;
+ uint64_t current = static_cast<uint64_t>(*word);
+ for (++word; word != word_end; ++word) {
+ current = CombineWordHash(current, *word);
+ }
+ return current;
+}
+
+template <class Voc, class Store> void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, Store &store) {
+ ReadNGramHeader(f, n);
+
+ // vocab ids of words in reverse order
+ WordIndex vocab_ids[n];
+ typename Store::Packing::Value value;
+ for (size_t i = 0; i < count; ++i) {
+ ReadNGram(f, n, vocab, vocab_ids, value);
+ uint64_t key = ChainedWordHash(vocab_ids, vocab_ids + n);
+ store.Insert(Store::Packing::Make(key, value));
+ }
+
+ store.FinishedInserting();
+}
+
+} // namespace
+namespace detail {
+
+template <class MiddleT, class LongestT> template <class Voc> void TemplateHashedSearch<MiddleT, LongestT>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &/*config*/, Voc &vocab) {
+ Read1Grams(f, counts[0], vocab, unigram.Raw());
+ // Read the n-grams.
+ for (unsigned int n = 2; n < counts.size(); ++n) {
+ ReadNGrams(f, n, counts[n-1], vocab, middle[n-2]);
+ }
+ ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, longest);
+}
+
+template void TemplateHashedSearch<ProbingHashedSearch::Middle, ProbingHashedSearch::Longest>::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &, ProbingVocabulary &vocab);
+template void TemplateHashedSearch<SortedHashedSearch::Middle, SortedHashedSearch::Longest>::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &, SortedVocabulary &vocab);
+
+} // namespace detail
+} // namespace ngram
+} // namespace lm
diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc
new file mode 100644
index 00000000..182e27f5
--- /dev/null
+++ b/klm/lm/search_trie.cc
@@ -0,0 +1,402 @@
+#include "lm/search_trie.hh"
+
+#include "lm/lm_exception.hh"
+#include "lm/read_arpa.hh"
+#include "lm/trie.hh"
+#include "lm/vocab.hh"
+#include "lm/weights.hh"
+#include "lm/word_index.hh"
+#include "util/ersatz_progress.hh"
+#include "util/file_piece.hh"
+#include "util/scoped.hh"
+
+#include <algorithm>
+#include <cstring>
+#include <cstdio>
+#include <deque>
+#include <iostream>
+#include <limits>
+//#include <parallel/algorithm>
+#include <vector>
+
+#include <sys/mman.h>
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <fcntl.h>
+#include <stdlib.h>
+
+namespace lm {
+namespace ngram {
+namespace trie {
+namespace {
+
+template <unsigned char Order> class FullEntry {
+ public:
+ typedef ProbBackoff Weights;
+ static const unsigned char kOrder = Order;
+
+ // reverse order
+ WordIndex words[Order];
+ Weights weights;
+
+ bool operator<(const FullEntry<Order> &other) const {
+ for (const WordIndex *i = words, *j = other.words; i != words + Order; ++i, ++j) {
+ if (*i < *j) return true;
+ if (*i > *j) return false;
+ }
+ return false;
+ }
+};
+
+template <unsigned char Order> class ProbEntry {
+ public:
+ typedef Prob Weights;
+ static const unsigned char kOrder = Order;
+
+ // reverse order
+ WordIndex words[Order];
+ Weights weights;
+
+ bool operator<(const ProbEntry<Order> &other) const {
+ for (const WordIndex *i = words, *j = other.words; i != words + Order; ++i, ++j) {
+ if (*i < *j) return true;
+ if (*i > *j) return false;
+ }
+ return false;
+ }
+};
+
+void WriteOrThrow(FILE *to, const void *data, size_t size) {
+ if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size);
+}
+
+void ReadOrThrow(FILE *from, void *data, size_t size) {
+ if (1 != std::fread(data, size, 1, from)) UTIL_THROW(util::ErrnoException, "Short read; requested size " << size);
+}
+
+void CopyOrThrow(FILE *from, FILE *to, size_t size) {
+ const size_t kBufSize = 512;
+ char buf[kBufSize];
+ for (size_t i = 0; i < size; i += kBufSize) {
+ std::size_t amount = std::min(size - i, kBufSize);
+ ReadOrThrow(from, buf, amount);
+ WriteOrThrow(to, buf, amount);
+ }
+}
+
+template <class Entry> std::string DiskFlush(const Entry *begin, const Entry *end, const std::string &file_prefix, std::size_t batch) {
+ std::stringstream assembled;
+ assembled << file_prefix << static_cast<unsigned int>(Entry::kOrder) << '_' << batch;
+ std::string ret(assembled.str());
+ util::scoped_FILE out(fopen(ret.c_str(), "w"));
+ if (!out.get()) UTIL_THROW(util::ErrnoException, "Couldn't open " << assembled.str().c_str() << " for writing");
+ for (const Entry *group_begin = begin; group_begin != end;) {
+ const Entry *group_end = group_begin;
+ for (++group_end; (group_end != end) && !memcmp(group_begin->words, group_end->words, sizeof(WordIndex) * (Entry::kOrder - 1)); ++group_end) {}
+ WriteOrThrow(out.get(), group_begin->words, sizeof(WordIndex) * (Entry::kOrder - 1));
+ WordIndex group_size = group_end - group_begin;
+ WriteOrThrow(out.get(), &group_size, sizeof(group_size));
+ for (const Entry *i = group_begin; i != group_end; ++i) {
+ WriteOrThrow(out.get(), &i->words[Entry::kOrder - 1], sizeof(WordIndex));
+ WriteOrThrow(out.get(), &i->weights, sizeof(typename Entry::Weights));
+ }
+ group_begin = group_end;
+ }
+ return ret;
+}
+
+class SortedFileReader {
+ public:
+ SortedFileReader() {}
+
+ void Init(const std::string &name, unsigned char order) {
+ file_.reset(fopen(name.c_str(), "r"));
+ if (!file_.get()) UTIL_THROW(util::ErrnoException, "Opening " << name << " for read");
+ header_.resize(order - 1);
+ NextHeader();
+ }
+
+ // Preceding words.
+ const WordIndex *Header() const {
+ return &*header_.begin();
+ }
+ const std::vector<WordIndex> &HeaderVector() const { return header_;}
+
+ std::size_t HeaderBytes() const { return header_.size() * sizeof(WordIndex); }
+
+ void NextHeader() {
+ if (1 != fread(&*header_.begin(), HeaderBytes(), 1, file_.get()) && !Ended()) {
+ UTIL_THROW(util::ErrnoException, "Short read of counts");
+ }
+ }
+
+ void ReadCount(WordIndex &to) {
+ ReadOrThrow(file_.get(), &to, sizeof(WordIndex));
+ }
+
+ void ReadWord(WordIndex &to) {
+ ReadOrThrow(file_.get(), &to, sizeof(WordIndex));
+ }
+
+ template <class Weights> void ReadWeights(Weights &to) {
+ ReadOrThrow(file_.get(), &to, sizeof(Weights));
+ }
+
+ bool Ended() {
+ return feof(file_.get());
+ }
+
+ FILE *File() { return file_.get(); }
+
+ private:
+ util::scoped_FILE file_;
+
+ std::vector<WordIndex> header_;
+};
+
+void CopyFullRecord(SortedFileReader &from, FILE *to, std::size_t weights_size) {
+ WriteOrThrow(to, from.Header(), from.HeaderBytes());
+ WordIndex count;
+ from.ReadCount(count);
+ WriteOrThrow(to, &count, sizeof(WordIndex));
+
+ CopyOrThrow(from.File(), to, (weights_size + sizeof(WordIndex)) * count);
+}
+
+void MergeSortedFiles(const char *first_name, const char *second_name, const char *out, std::size_t weights_size, unsigned char order) {
+ SortedFileReader first, second;
+ first.Init(first_name, order);
+ second.Init(second_name, order);
+ util::scoped_FILE out_file(fopen(out, "w"));
+ if (!out_file.get()) UTIL_THROW(util::ErrnoException, "Could not open " << out << " for write");
+ while (!first.Ended() && !second.Ended()) {
+ if (first.HeaderVector() < second.HeaderVector()) {
+ CopyFullRecord(first, out_file.get(), weights_size);
+ first.NextHeader();
+ continue;
+ }
+ if (first.HeaderVector() > second.HeaderVector()) {
+ CopyFullRecord(second, out_file.get(), weights_size);
+ second.NextHeader();
+ continue;
+ }
+ // Merge at the entry level.
+ WriteOrThrow(out_file.get(), first.Header(), first.HeaderBytes());
+ WordIndex first_count, second_count;
+ first.ReadCount(first_count); second.ReadCount(second_count);
+ WordIndex total_count = first_count + second_count;
+ WriteOrThrow(out_file.get(), &total_count, sizeof(WordIndex));
+
+ WordIndex first_word, second_word;
+ first.ReadWord(first_word); second.ReadWord(second_word);
+ WordIndex first_index = 0, second_index = 0;
+ while (true) {
+ if (first_word < second_word) {
+ WriteOrThrow(out_file.get(), &first_word, sizeof(WordIndex));
+ CopyOrThrow(first.File(), out_file.get(), weights_size);
+ if (++first_index == first_count) break;
+ first.ReadWord(first_word);
+ } else {
+ WriteOrThrow(out_file.get(), &second_word, sizeof(WordIndex));
+ CopyOrThrow(second.File(), out_file.get(), weights_size);
+ if (++second_index == second_count) break;
+ second.ReadWord(second_word);
+ }
+ }
+ if (first_index == first_count) {
+ WriteOrThrow(out_file.get(), &second_word, sizeof(WordIndex));
+ CopyOrThrow(second.File(), out_file.get(), (second_count - second_index) * (weights_size + sizeof(WordIndex)) - sizeof(WordIndex));
+ } else {
+ WriteOrThrow(out_file.get(), &first_word, sizeof(WordIndex));
+ CopyOrThrow(first.File(), out_file.get(), (first_count - first_index) * (weights_size + sizeof(WordIndex)) - sizeof(WordIndex));
+ }
+ first.NextHeader();
+ second.NextHeader();
+ }
+
+ for (SortedFileReader &remaining = first.Ended() ? second : first; !remaining.Ended(); remaining.NextHeader()) {
+ CopyFullRecord(remaining, out_file.get(), weights_size);
+ }
+}
+
+template <class Entry> void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix) {
+ ConvertToSorted<FullEntry<Entry::kOrder - 1> >(f, vocab, counts, mem, file_prefix);
+
+ ReadNGramHeader(f, Entry::kOrder);
+ const size_t count = counts[Entry::kOrder - 1];
+ const size_t batch_size = std::min(count, mem.size() / sizeof(Entry));
+ Entry *const begin = reinterpret_cast<Entry*>(mem.get());
+ std::deque<std::string> files;
+ for (std::size_t batch = 0, done = 0; done < count; ++batch) {
+ Entry *out = begin;
+ Entry *out_end = out + std::min(count - done, batch_size);
+ for (; out != out_end; ++out) {
+ ReadNGram(f, Entry::kOrder, vocab, out->words, out->weights);
+ }
+ //__gnu_parallel::sort(begin, out_end);
+ std::sort(begin, out_end);
+
+ files.push_back(DiskFlush(begin, out_end, file_prefix, batch));
+ done += out_end - begin;
+ }
+
+ // All individual files created. Merge them.
+
+ std::size_t merge_count = 0;
+ while (files.size() > 1) {
+ std::stringstream assembled;
+ assembled << file_prefix << static_cast<unsigned int>(Entry::kOrder) << "_merge_" << (merge_count++);
+ files.push_back(assembled.str());
+ MergeSortedFiles(files[0].c_str(), files[1].c_str(), files.back().c_str(), sizeof(typename Entry::Weights), Entry::kOrder);
+ if (std::remove(files[0].c_str())) UTIL_THROW(util::ErrnoException, "Could not remove " << files[0]);
+ files.pop_front();
+ if (std::remove(files[0].c_str())) UTIL_THROW(util::ErrnoException, "Could not remove " << files[0]);
+ files.pop_front();
+ }
+ if (!files.empty()) {
+ std::stringstream assembled;
+ assembled << file_prefix << static_cast<unsigned int>(Entry::kOrder) << "_merged";
+ std::string merged_name(assembled.str());
+ if (std::rename(files[0].c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << files[0].c_str() << " to " << merged_name.c_str());
+ }
+}
+
+template <> void ConvertToSorted<FullEntry<1> >(util::FilePiece &/*f*/, const SortedVocabulary &/*vocab*/, const std::vector<uint64_t> &/*counts*/, util::scoped_memory &/*mem*/, const std::string &/*file_prefix*/) {}
+
+void ARPAToSortedFiles(util::FilePiece &f, const std::vector<uint64_t> &counts, std::size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) {
+ {
+ std::string unigram_name = file_prefix + "unigrams";
+ util::scoped_fd unigram_file;
+ util::scoped_mmap unigram_mmap;
+ unigram_mmap.reset(util::MapZeroedWrite(unigram_name.c_str(), counts[0] * sizeof(ProbBackoff), unigram_file), counts[0] * sizeof(ProbBackoff));
+ Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()));
+ }
+
+ util::scoped_memory mem;
+ mem.reset(malloc(buffer), buffer, util::scoped_memory::ARRAY_ALLOCATED);
+ if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer);
+ ConvertToSorted<ProbEntry<5> >(f, vocab, counts, mem, file_prefix);
+ ReadEnd(f);
+}
+
+struct RecursiveInsertParams {
+ WordIndex *words;
+ SortedFileReader *files;
+ unsigned char max_order;
+ // This is an array of size order - 2.
+ BitPackedMiddle *middle;
+ // This has exactly one entry.
+ BitPackedLongest *longest;
+};
+
+uint64_t RecursiveInsert(RecursiveInsertParams &params, unsigned char order) {
+ SortedFileReader &file = params.files[order - 2];
+ const uint64_t ret = (order == params.max_order) ? params.longest->InsertIndex() : params.middle[order - 2].InsertIndex();
+ if (std::memcmp(params.words, file.Header(), sizeof(WordIndex) * (order - 1)))
+ return ret;
+ WordIndex count;
+ file.ReadCount(count);
+ WordIndex key;
+ if (order == params.max_order) {
+ Prob value;
+ for (WordIndex i = 0; i < count; ++i) {
+ file.ReadWord(key);
+ file.ReadWeights(value);
+ params.longest->Insert(key, value.prob);
+ }
+ file.NextHeader();
+ return ret;
+ }
+ ProbBackoff value;
+ for (WordIndex i = 0; i < count; ++i) {
+ file.ReadWord(params.words[order - 1]);
+ file.ReadWeights(value);
+ params.middle[order - 2].Insert(
+ params.words[order - 1],
+ value.prob,
+ value.backoff,
+ RecursiveInsert(params, order + 1));
+ }
+ file.NextHeader();
+ return ret;
+}
+
+void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &counts, std::ostream *messages, TrieSearch &out) {
+ UnigramValue *unigrams = out.unigram.Raw();
+ // Load unigrams. Leave the next pointers uninitialized.
+ {
+ std::string name(file_prefix + "unigrams");
+ util::scoped_FILE file(fopen(name.c_str(), "r"));
+ if (!file.get()) UTIL_THROW(util::ErrnoException, "Opening " << name << " failed");
+ for (WordIndex i = 0; i < counts[0]; ++i) {
+ ReadOrThrow(file.get(), &unigrams[i].weights, sizeof(ProbBackoff));
+ }
+ unlink(name.c_str());
+ }
+
+ // inputs[0] is bigrams.
+ SortedFileReader inputs[counts.size() - 1];
+ for (unsigned char i = 2; i <= counts.size(); ++i) {
+ std::stringstream assembled;
+ assembled << file_prefix << static_cast<unsigned int>(i) << "_merged";
+ inputs[i-2].Init(assembled.str(), i);
+ unlink(assembled.str().c_str());
+ }
+
+ // words[0] is unigrams.
+ WordIndex words[counts.size()];
+ RecursiveInsertParams params;
+ params.words = words;
+ params.files = inputs;
+ params.max_order = static_cast<unsigned char>(counts.size());
+ params.middle = &*out.middle.begin();
+ params.longest = &out.longest;
+ {
+ util::ErsatzProgress progress(messages, "Building trie", counts[0]);
+ for (words[0] = 0; words[0] < counts[0]; ++words[0], ++progress) {
+ unigrams[words[0]].next = RecursiveInsert(params, 2);
+ }
+ }
+
+ /* Set ending offsets so the last entry will be sized properly */
+ if (!out.middle.empty()) {
+ unigrams[counts[0]].next = out.middle.front().InsertIndex();
+ for (size_t i = 0; i < out.middle.size() - 1; ++i) {
+ out.middle[i].FinishedLoading(out.middle[i+1].InsertIndex());
+ }
+ out.middle.back().FinishedLoading(out.longest.InsertIndex());
+ } else {
+ unigrams[counts[0]].next = out.longest.InsertIndex();
+ }
+}
+
+} // namespace
+
+void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab) {
+ std::string temporary_directory;
+ if (config.temporary_directory_prefix) {
+ temporary_directory = config.temporary_directory_prefix;
+ } else if (config.write_mmap) {
+ temporary_directory = config.write_mmap;
+ } else {
+ temporary_directory = file;
+ }
+ // Null on end is kludge to ensure null termination.
+ temporary_directory += "-tmp-XXXXXX\0";
+ if (!mkdtemp(&temporary_directory[0])) {
+ UTIL_THROW(util::ErrnoException, "Failed to make a temporary directory based on the name " << temporary_directory.c_str());
+ }
+ // Chop off null kludge.
+ temporary_directory.resize(strlen(temporary_directory.c_str()));
+ // Add directory delimiter. Assumes a real operating system.
+ temporary_directory += '/';
+ ARPAToSortedFiles(f, counts, config.building_memory, temporary_directory.c_str(), vocab);
+ BuildTrie(temporary_directory.c_str(), counts, config.messages, *this);
+ if (rmdir(temporary_directory.c_str())) {
+ std::cerr << "Failed to delete " << temporary_directory << std::endl;
+ }
+}
+
+} // namespace trie
+} // namespace ngram
+} // namespace lm
diff --git a/klm/lm/sri.cc b/klm/lm/sri.cc
index 7bd23d76..b634d200 100644
--- a/klm/lm/sri.cc
+++ b/klm/lm/sri.cc
@@ -1,4 +1,4 @@
-#include "lm/exception.hh"
+#include "lm/lm_exception.hh"
#include "lm/sri.hh"
#include <Ngram.h>
@@ -31,8 +31,7 @@ void Vocabulary::FinishedLoading() {
SetSpecial(
sri_->ssIndex(),
sri_->seIndex(),
- sri_->unkIndex(),
- sri_->highIndex() + 1);
+ sri_->unkIndex());
}
namespace {
diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc
new file mode 100644
index 00000000..8ed7b2a2
--- /dev/null
+++ b/klm/lm/trie.cc
@@ -0,0 +1,167 @@
+#include "lm/trie.hh"
+
+#include "util/bit_packing.hh"
+#include "util/exception.hh"
+#include "util/proxy_iterator.hh"
+#include "util/sorted_uniform.hh"
+
+#include <assert.h>
+
+namespace lm {
+namespace ngram {
+namespace trie {
+namespace {
+
+// Assumes key is first.
+class JustKeyProxy {
+ public:
+ JustKeyProxy() : inner_(), base_(), key_mask_(), total_bits_() {}
+
+ operator uint64_t() const { return GetKey(); }
+
+ uint64_t GetKey() const {
+ uint64_t bit_off = inner_ * static_cast<uint64_t>(total_bits_);
+ return util::ReadInt57(base_ + bit_off / 8, bit_off & 7, key_mask_);
+ }
+
+ private:
+ friend class util::ProxyIterator<JustKeyProxy>;
+ friend bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index);
+
+ JustKeyProxy(const void *base, uint64_t index, uint64_t key_mask, uint8_t total_bits)
+ : inner_(index), base_(static_cast<const uint8_t*>(base)), key_mask_(key_mask), total_bits_(total_bits) {}
+
+ // This is a read-only iterator.
+ JustKeyProxy &operator=(const JustKeyProxy &other);
+
+ typedef uint64_t value_type;
+
+ typedef uint64_t InnerIterator;
+ uint64_t &Inner() { return inner_; }
+ const uint64_t &Inner() const { return inner_; }
+
+ // The address in bits is base_ * 8 + inner_ * total_bits_.
+ uint64_t inner_;
+ const uint8_t *const base_;
+ const uint64_t key_mask_;
+ const uint8_t total_bits_;
+};
+
+bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index) {
+ util::ProxyIterator<JustKeyProxy> begin_it(JustKeyProxy(base, begin_index, key_mask, total_bits));
+ util::ProxyIterator<JustKeyProxy> end_it(JustKeyProxy(base, end_index, key_mask, total_bits));
+ util::ProxyIterator<JustKeyProxy> out;
+ if (!util::SortedUniformFind(begin_it, end_it, key, out)) return false;
+ at_index = out.Inner();
+ return true;
+}
+} // namespace
+
+std::size_t BitPacked::BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits) {
+ uint8_t total_bits = util::RequiredBits(max_vocab) + 31 + remaining_bits;
+ // Extra entry for next pointer at the end.
+ // +7 then / 8 to round up bits and convert to bytes
+ // +sizeof(uint64_t) so that ReadInt57 etc don't go segfault.
+ // Note that this waste is O(order), not O(number of ngrams).
+ return ((1 + entries) * total_bits + 7) / 8 + sizeof(uint64_t);
+}
+
+void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits) {
+ util::BitPackingSanity();
+ word_bits_ = util::RequiredBits(max_vocab);
+ word_mask_ = (1ULL << word_bits_) - 1ULL;
+ if (word_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, word indices more than " << (1ULL << 57) << " are not implemented. Edit util/bit_packing.hh and fix the bit packing functions.");
+ prob_bits_ = 31;
+ total_bits_ = word_bits_ + prob_bits_ + remaining_bits;
+
+ base_ = static_cast<uint8_t*>(base);
+ insert_index_ = 0;
+}
+
+std::size_t BitPackedMiddle::Size(uint64_t entries, uint64_t max_vocab, uint64_t max_ptr) {
+ return BaseSize(entries, max_vocab, 32 + util::RequiredBits(max_ptr));
+}
+
+void BitPackedMiddle::Init(void *base, uint64_t max_vocab, uint64_t max_next) {
+ backoff_bits_ = 32;
+ next_bits_ = util::RequiredBits(max_next);
+ if (next_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions.");
+ next_mask_ = (1ULL << next_bits_) - 1;
+
+ BaseInit(base, max_vocab, backoff_bits_ + next_bits_);
+}
+
+void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff, uint64_t next) {
+ assert(word <= word_mask_);
+ assert(next <= next_mask_);
+ uint64_t at_pointer = insert_index_ * total_bits_;
+
+ util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word);
+ at_pointer += word_bits_;
+ util::WriteNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7, prob);
+ at_pointer += prob_bits_;
+ util::WriteFloat32(base_ + (at_pointer >> 3), at_pointer & 7, backoff);
+ at_pointer += backoff_bits_;
+ util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, next);
+
+ ++insert_index_;
+}
+
+bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const {
+ uint64_t at_pointer;
+ if (!FindBitPacked(base_, word_mask_, total_bits_, range.begin, range.end, word, at_pointer)) return false;
+ at_pointer *= total_bits_;
+ at_pointer += word_bits_;
+ prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7);
+ at_pointer += prob_bits_;
+ backoff = util::ReadFloat32(base_ + (at_pointer >> 3), at_pointer & 7);
+ at_pointer += backoff_bits_;
+ range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_);
+ // Read the next entry's pointer.
+ at_pointer += total_bits_;
+ range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_);
+ return true;
+}
+
+bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const {
+ uint64_t at_pointer;
+ if (!FindBitPacked(base_, word_mask_, total_bits_, range.begin, range.end, word, at_pointer)) return false;
+ at_pointer *= total_bits_;
+ at_pointer += word_bits_;
+ at_pointer += prob_bits_;
+ backoff = util::ReadFloat32(base_ + (at_pointer >> 3), at_pointer & 7);
+ at_pointer += backoff_bits_;
+ range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_);
+ // Read the next entry's pointer.
+ at_pointer += total_bits_;
+ range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_);
+ return true;
+}
+
+void BitPackedMiddle::FinishedLoading(uint64_t next_end) {
+ assert(next_end <= next_mask_);
+ uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - next_bits_;
+ util::WriteInt57(base_ + (last_next_write >> 3), last_next_write & 7, next_end);
+}
+
+
+void BitPackedLongest::Insert(WordIndex index, float prob) {
+ assert(index <= word_mask_);
+ uint64_t at_pointer = insert_index_ * total_bits_;
+ util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, index);
+ at_pointer += word_bits_;
+ util::WriteNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7, prob);
+ ++insert_index_;
+}
+
+bool BitPackedLongest::Find(WordIndex word, float &prob, const NodeRange &node) const {
+ uint64_t at_pointer;
+ if (!FindBitPacked(base_, word_mask_, total_bits_, node.begin, node.end, word, at_pointer)) return false;
+ at_pointer = at_pointer * total_bits_ + word_bits_;
+ prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7);
+ return true;
+}
+
+} // namespace trie
+} // namespace ngram
+} // namespace lm
diff --git a/klm/lm/virtual_interface.cc b/klm/lm/virtual_interface.cc
index 9c7151f9..c5a64972 100644
--- a/klm/lm/virtual_interface.cc
+++ b/klm/lm/virtual_interface.cc
@@ -1,17 +1,16 @@
#include "lm/virtual_interface.hh"
-#include "lm/exception.hh"
+#include "lm/lm_exception.hh"
namespace lm {
namespace base {
Vocabulary::~Vocabulary() {}
-void Vocabulary::SetSpecial(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found, WordIndex available) {
+void Vocabulary::SetSpecial(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found) {
begin_sentence_ = begin_sentence;
end_sentence_ = end_sentence;
not_found_ = not_found;
- available_ = available;
if (begin_sentence_ == not_found_) throw SpecialWordMissingException("<s>");
if (end_sentence_ == not_found_) throw SpecialWordMissingException("</s>");
}
diff --git a/klm/lm/virtual_interface.hh b/klm/lm/virtual_interface.hh
index 621a129e..f15f8789 100644
--- a/klm/lm/virtual_interface.hh
+++ b/klm/lm/virtual_interface.hh
@@ -37,8 +37,6 @@ class Vocabulary {
WordIndex BeginSentence() const { return begin_sentence_; }
WordIndex EndSentence() const { return end_sentence_; }
WordIndex NotFound() const { return not_found_; }
- // FullScoreReturn start index of unused word assignments.
- WordIndex Available() const { return available_; }
/* Most implementations allow StringPiece lookups and need only override
* Index(StringPiece). SRI requires null termination and overrides all
@@ -56,13 +54,13 @@ class Vocabulary {
// Call SetSpecial afterward.
Vocabulary() {}
- Vocabulary(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found, WordIndex available) {
- SetSpecial(begin_sentence, end_sentence, not_found, available);
+ Vocabulary(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found) {
+ SetSpecial(begin_sentence, end_sentence, not_found);
}
- void SetSpecial(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found, WordIndex available);
+ void SetSpecial(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found);
- WordIndex begin_sentence_, end_sentence_, not_found_, available_;
+ WordIndex begin_sentence_, end_sentence_, not_found_;
private:
// Disable copy constructors. They're private and undefined.
@@ -97,7 +95,7 @@ class Vocabulary {
* missing these methods, see facade.hh.
*
* This is the fastest way to use a model and presents a normal State class to
- * be included in hypothesis state structure.
+ * be included in a hypothesis state structure.
*
*
* OPTION 2: Use the virtual interface below.
diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc
new file mode 100644
index 00000000..c30428b2
--- /dev/null
+++ b/klm/lm/vocab.cc
@@ -0,0 +1,187 @@
+#include "lm/vocab.hh"
+
+#include "lm/enumerate_vocab.hh"
+#include "lm/lm_exception.hh"
+#include "lm/config.hh"
+#include "lm/weights.hh"
+#include "util/exception.hh"
+#include "util/joint_sort.hh"
+#include "util/murmur_hash.hh"
+#include "util/probing_hash_table.hh"
+
+#include <string>
+
+namespace lm {
+namespace ngram {
+
+namespace detail {
+uint64_t HashForVocab(const char *str, std::size_t len) {
+ // This proved faster than Boost's hash in speed trials: total load time Murmur 67090000, Boost 72210000
+ // Chose to use 64A instead of native so binary format will be portable across 64 and 32 bit.
+ return util::MurmurHash64A(str, len, 0);
+}
+} // namespace detail
+
+namespace {
+// Normally static initialization is a bad idea but MurmurHash is pure arithmetic, so this is ok.
+const uint64_t kUnknownHash = detail::HashForVocab("<unk>", 5);
+// Sadly some LMs have <UNK>.
+const uint64_t kUnknownCapHash = detail::HashForVocab("<UNK>", 5);
+
+void ReadWords(int fd, EnumerateVocab *enumerate) {
+ if (!enumerate) return;
+ const std::size_t kInitialRead = 16384;
+ std::string buf;
+ buf.reserve(kInitialRead + 100);
+ buf.resize(kInitialRead);
+ WordIndex index = 0;
+ while (true) {
+ ssize_t got = read(fd, &buf[0], kInitialRead);
+ if (got == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words");
+ if (got == 0) return;
+ buf.resize(got);
+ while (buf[buf.size() - 1]) {
+ char next_char;
+ ssize_t ret = read(fd, &next_char, 1);
+ if (ret == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words");
+ if (ret == 0) UTIL_THROW(FormatLoadException, "Missing null terminator on a vocab word.");
+ buf.push_back(next_char);
+ }
+ // Ok now we have null terminated strings.
+ for (const char *i = buf.data(); i != buf.data() + buf.size();) {
+ std::size_t length = strlen(i);
+ enumerate->Add(index++, StringPiece(i, length));
+ i += length + 1 /* null byte */;
+ }
+ }
+}
+
+void WriteOrThrow(int fd, const void *data_void, std::size_t size) {
+ const uint8_t *data = static_cast<const uint8_t*>(data_void);
+ while (size) {
+ ssize_t ret = write(fd, data, size);
+ if (ret < 1) UTIL_THROW(util::ErrnoException, "Write failed");
+ data += ret;
+ size -= ret;
+ }
+}
+
+} // namespace
+
+WriteWordsWrapper::WriteWordsWrapper(EnumerateVocab *inner, int fd) : inner_(inner), fd_(fd) {}
+WriteWordsWrapper::~WriteWordsWrapper() {}
+
+void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) {
+ if (inner_) inner_->Add(index, str);
+ WriteOrThrow(fd_, str.data(), str.size());
+ char null_byte = 0;
+ // Inefficient because it's unbuffered. Sue me.
+ WriteOrThrow(fd_, &null_byte, 1);
+}
+
+SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {}
+
+std::size_t SortedVocabulary::Size(std::size_t entries, const Config &/*config*/) {
+ // Lead with the number of entries.
+ return sizeof(uint64_t) + sizeof(Entry) * entries;
+}
+
+void SortedVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config) {
+ assert(allocated >= Size(entries, config));
+ // Leave space for number of entries.
+ begin_ = reinterpret_cast<Entry*>(reinterpret_cast<uint64_t*>(start) + 1);
+ end_ = begin_;
+ saw_unk_ = false;
+}
+
+void SortedVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries) {
+ enumerate_ = to;
+ if (enumerate_) {
+ enumerate_->Add(0, "<unk>");
+ strings_to_enumerate_.resize(max_entries);
+ }
+}
+
+WordIndex SortedVocabulary::Insert(const StringPiece &str) {
+ uint64_t hashed = detail::HashForVocab(str);
+ if (hashed == kUnknownHash || hashed == kUnknownCapHash) {
+ saw_unk_ = true;
+ return 0;
+ }
+ end_->key = hashed;
+ if (enumerate_) {
+ strings_to_enumerate_[end_ - begin_].assign(str.data(), str.size());
+ }
+ ++end_;
+ // This is 1 + the offset where it was inserted to make room for unk.
+ return end_ - begin_;
+}
+
+void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
+ if (enumerate_) {
+ util::PairedIterator<ProbBackoff*, std::string*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin());
+ util::JointSort(begin_, end_, values);
+ for (WordIndex i = 0; i < static_cast<WordIndex>(end_ - begin_); ++i) {
+ // <unk> strikes again: +1 here.
+ enumerate_->Add(i + 1, strings_to_enumerate_[i]);
+ }
+ strings_to_enumerate_.clear();
+ } else {
+ util::JointSort(begin_, end_, reorder_vocab + 1);
+ }
+ SetSpecial(Index("<s>"), Index("</s>"), 0);
+ // Save size.
+ *(reinterpret_cast<uint64_t*>(begin_) - 1) = end_ - begin_;
+}
+
+void SortedVocabulary::LoadedBinary(int fd, EnumerateVocab *to) {
+ end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1);
+ ReadWords(fd, to);
+ SetSpecial(Index("<s>"), Index("</s>"), 0);
+}
+
+ProbingVocabulary::ProbingVocabulary() : enumerate_(NULL) {}
+
+std::size_t ProbingVocabulary::Size(std::size_t entries, const Config &config) {
+ return Lookup::Size(entries, config.probing_multiplier);
+}
+
+void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t /*entries*/, const Config &/*config*/) {
+ lookup_ = Lookup(start, allocated);
+ available_ = 1;
+ saw_unk_ = false;
+}
+
+void ProbingVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t /*max_entries*/) {
+ enumerate_ = to;
+ if (enumerate_) {
+ enumerate_->Add(0, "<unk>");
+ }
+}
+
+WordIndex ProbingVocabulary::Insert(const StringPiece &str) {
+ uint64_t hashed = detail::HashForVocab(str);
+ // Prevent unknown from going into the table.
+ if (hashed == kUnknownHash || hashed == kUnknownCapHash) {
+ saw_unk_ = true;
+ return 0;
+ } else {
+ if (enumerate_) enumerate_->Add(available_, str);
+ lookup_.Insert(Lookup::Packing::Make(hashed, available_));
+ return available_++;
+ }
+}
+
+void ProbingVocabulary::FinishedLoading(ProbBackoff * /*reorder_vocab*/) {
+ lookup_.FinishedInserting();
+ SetSpecial(Index("<s>"), Index("</s>"), 0);
+}
+
+void ProbingVocabulary::LoadedBinary(int fd, EnumerateVocab *to) {
+ lookup_.LoadedBinary();
+ ReadWords(fd, to);
+ SetSpecial(Index("<s>"), Index("</s>"), 0);
+}
+
+} // namespace ngram
+} // namespace lm
diff --git a/klm/util/Makefile.am b/klm/util/Makefile.am
index be84736c..9e38e0f1 100644
--- a/klm/util/Makefile.am
+++ b/klm/util/Makefile.am
@@ -20,6 +20,7 @@ noinst_LIBRARIES = libklm_util.a
libklm_util_a_SOURCES = \
ersatz_progress.cc \
+ bit_packing.cc \
exception.cc \
file_piece.cc \
mmap.cc \
diff --git a/klm/util/bit_packing.cc b/klm/util/bit_packing.cc
new file mode 100644
index 00000000..dd14ffe1
--- /dev/null
+++ b/klm/util/bit_packing.cc
@@ -0,0 +1,27 @@
+#include "util/bit_packing.hh"
+#include "util/exception.hh"
+
+namespace util {
+
+namespace {
+template <bool> struct StaticCheck {};
+template <> struct StaticCheck<true> { typedef bool StaticAssertionPassed; };
+
+typedef StaticCheck<sizeof(float) == 4>::StaticAssertionPassed FloatSize;
+
+} // namespace
+
+uint8_t RequiredBits(uint64_t max_value) {
+ if (!max_value) return 0;
+ uint8_t ret = 1;
+ while (max_value >>= 1) ++ret;
+ return ret;
+}
+
+void BitPackingSanity() {
+ const detail::FloatEnc neg1 = { -1.0 }, pos1 = { 1.0 };
+ if ((neg1.i ^ pos1.i) != 0x80000000) UTIL_THROW(Exception, "Sign bit is not 0x80000000");
+ // TODO: more checks.
+}
+
+} // namespace util
diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh
new file mode 100644
index 00000000..422ed873
--- /dev/null
+++ b/klm/util/bit_packing.hh
@@ -0,0 +1,88 @@
+#ifndef UTIL_BIT_PACKING__
+#define UTIL_BIT_PACKING__
+
+/* Bit-level packing routines */
+
+#include <assert.h>
+#ifdef __APPLE__
+#include <architecture/byte_order.h>
+#else
+#include <endian.h>
+#endif
+
+#include <inttypes.h>
+
+#if __BYTE_ORDER != __LITTLE_ENDIAN
+#error The bit aligned storage functions assume little endian architecture
+#endif
+
+namespace util {
+
+/* WARNING WARNING WARNING:
+ * The write functions assume that memory is zero initially. This makes them
+ * faster and is the appropriate case for mmapped language model construction.
+ * These routines assume that unaligned access to uint64_t is fast and that
+ * storage is little endian. This is the case on x86_64. It may not be the
+ * case on 32-bit x86 but my target audience is large language models for which
+ * 64-bit is necessary.
+ */
+
+/* Pack integers up to 57 bits using their least significant digits.
+ * The length is specified using mask:
+ * Assumes mask == (1 << length) - 1 where length <= 57.
+ */
+inline uint64_t ReadInt57(const void *base, uint8_t bit, uint64_t mask) {
+ return (*reinterpret_cast<const uint64_t*>(base) >> bit) & mask;
+}
+/* Assumes value <= mask and mask == (1 << length) - 1 where length <= 57.
+ * Assumes the memory is zero initially.
+ */
+inline void WriteInt57(void *base, uint8_t bit, uint64_t value) {
+ *reinterpret_cast<uint64_t*>(base) |= (value << bit);
+}
+
+namespace detail { typedef union { float f; uint32_t i; } FloatEnc; }
+inline float ReadFloat32(const void *base, uint8_t bit) {
+ detail::FloatEnc encoded;
+ encoded.i = *reinterpret_cast<const uint64_t*>(base) >> bit;
+ return encoded.f;
+}
+inline void WriteFloat32(void *base, uint8_t bit, float value) {
+ detail::FloatEnc encoded;
+ encoded.f = value;
+ WriteInt57(base, bit, encoded.i);
+}
+
+inline float ReadNonPositiveFloat31(const void *base, uint8_t bit) {
+ detail::FloatEnc encoded;
+ encoded.i = *reinterpret_cast<const uint64_t*>(base) >> bit;
+ // Sign bit set means negative.
+ encoded.i |= 0x80000000;
+ return encoded.f;
+}
+inline void WriteNonPositiveFloat31(void *base, uint8_t bit, float value) {
+ assert(value <= 0.0);
+ detail::FloatEnc encoded;
+ encoded.f = value;
+ encoded.i &= ~0x80000000;
+ WriteInt57(base, bit, encoded.i);
+}
+
+void BitPackingSanity();
+
+// Return bits required to store integers upto max_value. Not the most
+// efficient implementation, but this is only called a few times to size tries.
+uint8_t RequiredBits(uint64_t max_value);
+
+struct BitsMask {
+ void FromMax(uint64_t max_value) {
+ bits = RequiredBits(max_value);
+ mask = (1 << bits) - 1;
+ }
+ uint8_t bits;
+ uint64_t mask;
+};
+
+} // namespace util
+
+#endif // UTIL_BIT_PACKING__
diff --git a/klm/util/ersatz_progress.cc b/klm/util/ersatz_progress.cc
index 09e3a106..55c182bd 100644
--- a/klm/util/ersatz_progress.cc
+++ b/klm/util/ersatz_progress.cc
@@ -13,10 +13,7 @@ ErsatzProgress::ErsatzProgress() : current_(0), next_(std::numeric_limits<std::s
ErsatzProgress::~ErsatzProgress() {
if (!out_) return;
- for (; stones_written_ < kWidth; ++stones_written_) {
- (*out_) << '*';
- }
- *out_ << '\n';
+ Finished();
}
ErsatzProgress::ErsatzProgress(std::ostream *to, const std::string &message, std::size_t complete)
@@ -36,8 +33,8 @@ void ErsatzProgress::Milestone() {
for (; stones_written_ < stone; ++stones_written_) {
(*out_) << '*';
}
-
- if (current_ >= complete_) {
+ if (stone == kWidth) {
+ (*out_) << std::endl;
next_ = std::numeric_limits<std::size_t>::max();
} else {
next_ = std::max(next_, (stone * complete_) / kWidth);
diff --git a/klm/util/ersatz_progress.hh b/klm/util/ersatz_progress.hh
index ea6c3bb9..92c345fe 100644
--- a/klm/util/ersatz_progress.hh
+++ b/klm/util/ersatz_progress.hh
@@ -19,7 +19,7 @@ class ErsatzProgress {
~ErsatzProgress();
ErsatzProgress &operator++() {
- if (++current_ == next_) Milestone();
+ if (++current_ >= next_) Milestone();
return *this;
}
@@ -33,6 +33,10 @@ class ErsatzProgress {
Milestone();
}
+ void Finished() {
+ Set(complete_);
+ }
+
private:
void Milestone();
diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc
index 2b439499..e7bd8659 100644
--- a/klm/util/file_piece.cc
+++ b/klm/util/file_piece.cc
@@ -2,19 +2,23 @@
#include "util/exception.hh"
-#include <iostream>
#include <string>
#include <limits>
#include <assert.h>
-#include <cstdlib>
#include <ctype.h>
+#include <err.h>
#include <fcntl.h>
+#include <stdlib.h>
#include <sys/mman.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h>
+#ifdef HAVE_ZLIB
+#include <zlib.h>
+#endif
+
namespace util {
EndOfFileException::EndOfFileException() throw() {
@@ -26,6 +30,13 @@ ParseNumberException::ParseNumberException(StringPiece value) throw() {
*this << "Could not parse \"" << value << "\" into a float";
}
+GZException::GZException(void *file) {
+#ifdef HAVE_ZLIB
+ int num;
+ *this << gzerror(file, &num) << " from zlib";
+#endif // HAVE_ZLIB
+}
+
int OpenReadOrThrow(const char *name) {
int ret = open(name, O_RDONLY);
if (ret == -1) UTIL_THROW(ErrnoException, "in open (" << name << ") for reading");
@@ -38,42 +49,73 @@ off_t SizeFile(int fd) {
return sb.st_size;
}
-FilePiece::FilePiece(const char *name, std::ostream *show_progress, off_t min_buffer) :
+FilePiece::FilePiece(const char *name, std::ostream *show_progress, off_t min_buffer) throw (GZException) :
file_(OpenReadOrThrow(name)), total_size_(SizeFile(file_.get())), page_(sysconf(_SC_PAGE_SIZE)),
progress_(total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name, total_size_) {
Initialize(name, show_progress, min_buffer);
}
-FilePiece::FilePiece(const char *name, int fd, std::ostream *show_progress, off_t min_buffer) :
+FilePiece::FilePiece(int fd, const char *name, std::ostream *show_progress, off_t min_buffer) throw (GZException) :
file_(fd), total_size_(SizeFile(file_.get())), page_(sysconf(_SC_PAGE_SIZE)),
progress_(total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name, total_size_) {
Initialize(name, show_progress, min_buffer);
}
-void FilePiece::Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) {
- if (total_size_ == kBadSize) {
- fallback_to_read_ = true;
- if (show_progress)
- *show_progress << "File " << name << " isn't normal. Using slower read() instead of mmap(). No progress bar." << std::endl;
- } else {
- fallback_to_read_ = false;
+FilePiece::~FilePiece() {
+#ifdef HAVE_ZLIB
+ if (gz_file_) {
+ // zlib took ownership
+ file_.release();
+ int ret;
+ if (Z_OK != (ret = gzclose(gz_file_))) {
+ errx(1, "could not close file %s using zlib", file_name_.c_str());
+ }
}
+#endif
+}
+
+void FilePiece::Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) throw (GZException) {
+#ifdef HAVE_ZLIB
+ gz_file_ = NULL;
+#endif
+ file_name_ = name;
+
default_map_size_ = page_ * std::max<off_t>((min_buffer / page_ + 1), 2);
position_ = NULL;
position_end_ = NULL;
mapped_offset_ = 0;
at_end_ = false;
+
+ if (total_size_ == kBadSize) {
+ // So the assertion passes.
+ fallback_to_read_ = false;
+ if (show_progress)
+ *show_progress << "File " << name << " isn't normal. Using slower read() instead of mmap(). No progress bar." << std::endl;
+ TransitionToRead();
+ } else {
+ fallback_to_read_ = false;
+ }
Shift();
+ // gzip detect.
+ if ((position_end_ - position_) > 2 && *position_ == 0x1f && static_cast<unsigned char>(*(position_ + 1)) == 0x8b) {
+#ifndef HAVE_ZLIB
+ UTIL_THROW(GZException, "Looks like a gzip file but support was not compiled in.");
+#endif
+ if (!fallback_to_read_) {
+ at_end_ = false;
+ TransitionToRead();
+ }
+ }
}
-float FilePiece::ReadFloat() throw(EndOfFileException, ParseNumberException) {
+float FilePiece::ReadFloat() throw(GZException, EndOfFileException, ParseNumberException) {
SkipSpaces();
while (last_space_ < position_) {
if (at_end_) {
// Hallucinate a null off the end of the file.
std::string buffer(position_, position_end_);
char *end;
- float ret = std::strtof(buffer.c_str(), &end);
+ float ret = strtof(buffer.c_str(), &end);
if (buffer.c_str() == end) throw ParseNumberException(buffer);
position_ += end - buffer.c_str();
return ret;
@@ -81,20 +123,20 @@ float FilePiece::ReadFloat() throw(EndOfFileException, ParseNumberException) {
Shift();
}
char *end;
- float ret = std::strtof(position_, &end);
+ float ret = strtof(position_, &end);
if (end == position_) throw ParseNumberException(ReadDelimited());
position_ = end;
return ret;
}
-void FilePiece::SkipSpaces() throw (EndOfFileException) {
+void FilePiece::SkipSpaces() throw (GZException, EndOfFileException) {
for (; ; ++position_) {
if (position_ == position_end_) Shift();
if (!isspace(*position_)) return;
}
}
-const char *FilePiece::FindDelimiterOrEOF() throw (EndOfFileException) {
+const char *FilePiece::FindDelimiterOrEOF() throw (GZException, EndOfFileException) {
for (const char *i = position_; i <= last_space_; ++i) {
if (isspace(*i)) return i;
}
@@ -108,7 +150,7 @@ const char *FilePiece::FindDelimiterOrEOF() throw (EndOfFileException) {
return position_end_;
}
-StringPiece FilePiece::ReadLine(char delim) throw (EndOfFileException) {
+StringPiece FilePiece::ReadLine(char delim) throw (GZException, EndOfFileException) {
const char *start = position_;
do {
for (const char *i = start; i < position_end_; ++i) {
@@ -124,17 +166,19 @@ StringPiece FilePiece::ReadLine(char delim) throw (EndOfFileException) {
} while (!at_end_);
StringPiece ret(position_, position_end_ - position_);
position_ = position_end_;
- return position_;
+ return ret;
}
-void FilePiece::Shift() throw(EndOfFileException) {
- if (at_end_) throw EndOfFileException();
+void FilePiece::Shift() throw(GZException, EndOfFileException) {
+ if (at_end_) {
+ progress_.Finished();
+ throw EndOfFileException();
+ }
off_t desired_begin = position_ - data_.begin() + mapped_offset_;
- progress_.Set(desired_begin);
if (!fallback_to_read_) MMapShift(desired_begin);
// Notice an mmap failure might set the fallback.
- if (fallback_to_read_) ReadShift(desired_begin);
+ if (fallback_to_read_) ReadShift();
for (last_space_ = position_end_ - 1; last_space_ >= position_; --last_space_) {
if (isspace(*last_space_)) break;
@@ -163,28 +207,41 @@ void FilePiece::MMapShift(off_t desired_begin) throw() {
data_.reset();
data_.reset(mmap(NULL, mapped_size, PROT_READ, MAP_PRIVATE, *file_, mapped_offset), mapped_size, scoped_memory::MMAP_ALLOCATED);
if (data_.get() == MAP_FAILED) {
- fallback_to_read_ = true;
if (desired_begin) {
if (((off_t)-1) == lseek(*file_, desired_begin, SEEK_SET)) UTIL_THROW(ErrnoException, "mmap failed even though it worked before. lseek failed too, so using read isn't an option either.");
}
+ // The mmap was scheduled to end the file, but now we're going to read it.
+ at_end_ = false;
+ TransitionToRead();
return;
}
mapped_offset_ = mapped_offset;
position_ = data_.begin() + ignore;
position_end_ = data_.begin() + mapped_size;
+
+ progress_.Set(desired_begin);
+}
+
+void FilePiece::TransitionToRead() throw (GZException) {
+ assert(!fallback_to_read_);
+ fallback_to_read_ = true;
+ data_.reset();
+ data_.reset(malloc(default_map_size_), default_map_size_, scoped_memory::MALLOC_ALLOCATED);
+ if (!data_.get()) UTIL_THROW(ErrnoException, "malloc failed for " << default_map_size_);
+ position_ = data_.begin();
+ position_end_ = position_;
+
+#ifdef HAVE_ZLIB
+ assert(!gz_file_);
+ gz_file_ = gzdopen(file_.get(), "r");
+ if (!gz_file_) {
+ UTIL_THROW(GZException, "zlib failed to open " << file_name_);
+ }
+#endif
}
-void FilePiece::ReadShift(off_t desired_begin) throw() {
+void FilePiece::ReadShift() throw(GZException, EndOfFileException) {
assert(fallback_to_read_);
- if (data_.source() != scoped_memory::MALLOC_ALLOCATED) {
- // First call.
- data_.reset();
- data_.reset(malloc(default_map_size_), default_map_size_, scoped_memory::MALLOC_ALLOCATED);
- if (!data_.get()) UTIL_THROW(ErrnoException, "malloc failed for " << default_map_size_);
- position_ = data_.begin();
- position_end_ = position_;
- }
-
// Bytes [data_.begin(), position_) have been consumed.
// Bytes [position_, position_end_) have been read into the buffer.
@@ -215,9 +272,23 @@ void FilePiece::ReadShift(off_t desired_begin) throw() {
}
}
- ssize_t read_return = read(file_.get(), static_cast<char*>(data_.get()) + already_read, default_map_size_ - already_read);
+ ssize_t read_return;
+#ifdef HAVE_ZLIB
+ read_return = gzread(gz_file_, static_cast<char*>(data_.get()) + already_read, default_map_size_ - already_read);
+ if (read_return == -1) throw GZException(gz_file_);
+ if (total_size_ != kBadSize) {
+ // Just get the position, don't actually seek. Apparently this is how you do it. . .
+ off_t ret = lseek(file_.get(), 0, SEEK_CUR);
+ if (ret != -1) progress_.Set(ret);
+ }
+#else
+ read_return = read(file_.get(), static_cast<char*>(data_.get()) + already_read, default_map_size_ - already_read);
if (read_return == -1) UTIL_THROW(ErrnoException, "read failed");
- if (read_return == 0) at_end_ = true;
+ progress_.Set(mapped_offset_);
+#endif
+ if (read_return == 0) {
+ at_end_ = true;
+ }
position_end_ += read_return;
}
diff --git a/klm/util/file_piece.hh b/klm/util/file_piece.hh
index 704f0ac6..11d4a751 100644
--- a/klm/util/file_piece.hh
+++ b/klm/util/file_piece.hh
@@ -11,6 +11,8 @@
#include <cstddef>
+#define HAVE_ZLIB
+
namespace util {
class EndOfFileException : public Exception {
@@ -25,6 +27,13 @@ class ParseNumberException : public Exception {
~ParseNumberException() throw() {}
};
+class GZException : public Exception {
+ public:
+ explicit GZException(void *file);
+ GZException() throw() {}
+ ~GZException() throw() {}
+};
+
int OpenReadOrThrow(const char *name);
// Return value for SizeFile when it can't size properly.
@@ -34,40 +43,42 @@ off_t SizeFile(int fd);
class FilePiece {
public:
// 32 MB default.
- explicit FilePiece(const char *file, std::ostream *show_progress = NULL, off_t min_buffer = 33554432);
+ explicit FilePiece(const char *file, std::ostream *show_progress = NULL, off_t min_buffer = 33554432) throw(GZException);
// Takes ownership of fd. name is used for messages.
- explicit FilePiece(const char *name, int fd, std::ostream *show_progress = NULL, off_t min_buffer = 33554432);
+ explicit FilePiece(int fd, const char *name, std::ostream *show_progress = NULL, off_t min_buffer = 33554432) throw(GZException);
+
+ ~FilePiece();
- char get() throw(EndOfFileException) {
- if (position_ == position_end_) Shift();
+ char get() throw(GZException, EndOfFileException) {
+ if (position_ == position_end_) {
+ Shift();
+ if (at_end_) throw EndOfFileException();
+ }
return *(position_++);
}
// Memory backing the returned StringPiece may vanish on the next call.
// Leaves the delimiter, if any, to be returned by get().
- StringPiece ReadDelimited() throw(EndOfFileException) {
+ StringPiece ReadDelimited() throw(GZException, EndOfFileException) {
SkipSpaces();
return Consume(FindDelimiterOrEOF());
}
// Unlike ReadDelimited, this includes leading spaces and consumes the delimiter.
// It is similar to getline in that way.
- StringPiece ReadLine(char delim = '\n') throw(EndOfFileException);
+ StringPiece ReadLine(char delim = '\n') throw(GZException, EndOfFileException);
- float ReadFloat() throw(EndOfFileException, ParseNumberException);
+ float ReadFloat() throw(GZException, EndOfFileException, ParseNumberException);
- void SkipSpaces() throw (EndOfFileException);
+ void SkipSpaces() throw (GZException, EndOfFileException);
off_t Offset() const {
return position_ - data_.begin() + mapped_offset_;
}
- // Only for testing.
- void ForceFallbackToRead() {
- fallback_to_read_ = true;
- }
+ const std::string &FileName() const { return file_name_; }
private:
- void Initialize(const char *name, std::ostream *show_progress, off_t min_buffer);
+ void Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) throw(GZException);
StringPiece Consume(const char *to) {
StringPiece ret(position_, to - position_);
@@ -75,12 +86,14 @@ class FilePiece {
return ret;
}
- const char *FindDelimiterOrEOF() throw(EndOfFileException);
+ const char *FindDelimiterOrEOF() throw(EndOfFileException, GZException);
- void Shift() throw (EndOfFileException);
+ void Shift() throw (EndOfFileException, GZException);
// Backends to Shift().
void MMapShift(off_t desired_begin) throw ();
- void ReadShift(off_t desired_begin) throw ();
+
+ void TransitionToRead() throw (GZException);
+ void ReadShift() throw (GZException, EndOfFileException);
const char *position_, *last_space_, *position_end_;
@@ -98,6 +111,12 @@ class FilePiece {
bool fallback_to_read_;
ErsatzProgress progress_;
+
+ std::string file_name_;
+
+#ifdef HAVE_ZLIB
+ void *gz_file_;
+#endif // HAVE_ZLIB
};
} // namespace util
diff --git a/klm/util/file_piece_test.cc b/klm/util/file_piece_test.cc
index befb7866..23e79fe0 100644
--- a/klm/util/file_piece_test.cc
+++ b/klm/util/file_piece_test.cc
@@ -1,15 +1,19 @@
#include "util/file_piece.hh"
+#include "util/scoped.hh"
+
#define BOOST_TEST_MODULE FilePieceTest
#include <boost/test/unit_test.hpp>
#include <fstream>
#include <iostream>
+#include <stdio.h>
+
namespace util {
namespace {
/* mmap implementation */
-BOOST_AUTO_TEST_CASE(MMapLine) {
+BOOST_AUTO_TEST_CASE(MMapReadLine) {
std::fstream ref("file_piece.cc", std::ios::in);
FilePiece test("file_piece.cc", NULL, 1);
std::string ref_line;
@@ -20,13 +24,17 @@ BOOST_AUTO_TEST_CASE(MMapLine) {
BOOST_CHECK_EQUAL(ref_line, test_line);
}
}
+ BOOST_CHECK_THROW(test.get(), EndOfFileException);
}
/* read() implementation */
-BOOST_AUTO_TEST_CASE(ReadLine) {
+BOOST_AUTO_TEST_CASE(StreamReadLine) {
std::fstream ref("file_piece.cc", std::ios::in);
- FilePiece test("file_piece.cc", NULL, 1);
- test.ForceFallbackToRead();
+
+ scoped_FILE catter(popen("cat file_piece.cc", "r"));
+ BOOST_REQUIRE(catter.get());
+
+ FilePiece test(dup(fileno(catter.get())), "file_piece.cc", NULL, 1);
std::string ref_line;
while (getline(ref, ref_line)) {
StringPiece test_line(test.ReadLine());
@@ -35,7 +43,47 @@ BOOST_AUTO_TEST_CASE(ReadLine) {
BOOST_CHECK_EQUAL(ref_line, test_line);
}
}
+ BOOST_CHECK_THROW(test.get(), EndOfFileException);
}
+#ifdef HAVE_ZLIB
+
+// gzip file
+BOOST_AUTO_TEST_CASE(PlainZipReadLine) {
+ std::fstream ref("file_piece.cc", std::ios::in);
+
+ BOOST_REQUIRE_EQUAL(0, system("gzip <file_piece.cc >file_piece.cc.gz"));
+ FilePiece test("file_piece.cc.gz", NULL, 1);
+ std::string ref_line;
+ while (getline(ref, ref_line)) {
+ StringPiece test_line(test.ReadLine());
+ // I submitted a bug report to ICU: http://bugs.icu-project.org/trac/ticket/7924
+ if (!test_line.empty() || !ref_line.empty()) {
+ BOOST_CHECK_EQUAL(ref_line, test_line);
+ }
+ }
+ BOOST_CHECK_THROW(test.get(), EndOfFileException);
+}
+// gzip stream
+BOOST_AUTO_TEST_CASE(StreamZipReadLine) {
+ std::fstream ref("file_piece.cc", std::ios::in);
+
+ scoped_FILE catter(popen("gzip <file_piece.cc", "r"));
+ BOOST_REQUIRE(catter.get());
+
+ FilePiece test(dup(fileno(catter.get())), "file_piece.cc", NULL, 1);
+ std::string ref_line;
+ while (getline(ref, ref_line)) {
+ StringPiece test_line(test.ReadLine());
+ // I submitted a bug report to ICU: http://bugs.icu-project.org/trac/ticket/7924
+ if (!test_line.empty() || !ref_line.empty()) {
+ BOOST_CHECK_EQUAL(ref_line, test_line);
+ }
+ }
+ BOOST_CHECK_THROW(test.get(), EndOfFileException);
+}
+
+#endif
+
} // namespace
} // namespace util
diff --git a/klm/util/joint_sort.hh b/klm/util/joint_sort.hh
index a2f1c01d..cf3d8432 100644
--- a/klm/util/joint_sort.hh
+++ b/klm/util/joint_sort.hh
@@ -119,6 +119,12 @@ template <class Proxy, class Less> class LessWrapper : public std::binary_functi
} // namespace detail
+template <class KeyIter, class ValueIter> class PairedIterator : public ProxyIterator<detail::JointProxy<KeyIter, ValueIter> > {
+ public:
+ PairedIterator(const KeyIter &key, const ValueIter &value) :
+ ProxyIterator<detail::JointProxy<KeyIter, ValueIter> >(detail::JointProxy<KeyIter, ValueIter>(key, value)) {}
+};
+
template <class KeyIter, class ValueIter, class Less> void JointSort(const KeyIter &key_begin, const KeyIter &key_end, const ValueIter &value_begin, const Less &less) {
ProxyIterator<detail::JointProxy<KeyIter, ValueIter> > full_begin(detail::JointProxy<KeyIter, ValueIter>(key_begin, value_begin));
detail::LessWrapper<detail::JointProxy<KeyIter, ValueIter>, Less> less_wrap(less);
diff --git a/klm/util/mmap.cc b/klm/util/mmap.cc
index 648b5d0a..8685170f 100644
--- a/klm/util/mmap.cc
+++ b/klm/util/mmap.cc
@@ -53,10 +53,8 @@ void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, int
if (prefault) {
flags |= MAP_POPULATE;
}
- int protect = for_write ? (PROT_READ | PROT_WRITE) : PROT_READ;
-#else
- int protect = for_write ? (PROT_READ | PROT_WRITE) : PROT_READ;
#endif
+ int protect = for_write ? (PROT_READ | PROT_WRITE) : PROT_READ;
void *ret = mmap(NULL, size, protect, flags, fd, offset);
if (ret == MAP_FAILED) {
UTIL_THROW(ErrnoException, "mmap failed for size " << size << " at offset " << offset);
@@ -64,8 +62,40 @@ void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, int
return ret;
}
-void *MapForRead(std::size_t size, bool prefault, int fd, off_t offset) {
- return MapOrThrow(size, false, MAP_FILE | MAP_PRIVATE, prefault, fd, offset);
+namespace {
+void ReadAll(int fd, void *to_void, std::size_t amount) {
+ uint8_t *to = static_cast<uint8_t*>(to_void);
+ while (amount) {
+ ssize_t ret = read(fd, to, amount);
+ if (ret == -1) UTIL_THROW(ErrnoException, "Reading " << amount << " from fd " << fd << " failed.");
+ if (ret == 0) UTIL_THROW(Exception, "Hit EOF in fd " << fd << " but there should be " << amount << " more bytes to read.");
+ amount -= ret;
+ to += ret;
+ }
+}
+} // namespace
+
+void MapRead(LoadMethod method, int fd, off_t offset, std::size_t size, scoped_memory &out) {
+ switch (method) {
+ case LAZY:
+ out.reset(MapOrThrow(size, false, MAP_FILE | MAP_SHARED, false, fd, offset), size, scoped_memory::MMAP_ALLOCATED);
+ break;
+ case POPULATE_OR_LAZY:
+#ifdef MAP_POPULATE
+ case POPULATE_OR_READ:
+#endif
+ out.reset(MapOrThrow(size, false, MAP_FILE | MAP_SHARED, true, fd, offset), size, scoped_memory::MMAP_ALLOCATED);
+ break;
+#ifndef MAP_POPULATE
+ case POPULATE_OR_READ:
+#endif
+ case READ:
+ out.reset(malloc(size), size, scoped_memory::MALLOC_ALLOCATED);
+ if (!out.get()) UTIL_THROW(util::ErrnoException, "Allocating " << size << " bytes with malloc");
+ if (-1 == lseek(fd, offset, SEEK_SET)) UTIL_THROW(ErrnoException, "lseek to " << offset << " in fd " << fd << " failed.");
+ ReadAll(fd, out.get(), size);
+ break;
+ }
}
void *MapAnonymous(std::size_t size) {
@@ -78,14 +108,14 @@ void *MapAnonymous(std::size_t size) {
| MAP_PRIVATE, false, -1, 0);
}
-void MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file, scoped_mmap &mem) {
+void *MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file) {
file.reset(open(name, O_CREAT | O_RDWR | O_TRUNC, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH));
if (-1 == file.get())
UTIL_THROW(ErrnoException, "Failed to open " << name << " for writing");
if (-1 == ftruncate(file.get(), size))
UTIL_THROW(ErrnoException, "ftruncate on " << name << " to " << size << " failed");
try {
- mem.reset(MapOrThrow(size, true, MAP_FILE | MAP_SHARED, false, file.get(), 0), size);
+ return MapOrThrow(size, true, MAP_FILE | MAP_SHARED, false, file.get(), 0);
} catch (ErrnoException &e) {
e << " in file " << name;
throw;
diff --git a/klm/util/mmap.hh b/klm/util/mmap.hh
index c9068ec9..0a504d89 100644
--- a/klm/util/mmap.hh
+++ b/klm/util/mmap.hh
@@ -6,6 +6,7 @@
#include <cstddef>
+#include <inttypes.h>
#include <sys/types.h>
namespace util {
@@ -19,8 +20,8 @@ class scoped_mmap {
void *get() const { return data_; }
- const char *begin() const { return reinterpret_cast<char*>(data_); }
- const char *end() const { return reinterpret_cast<char*>(data_) + size_; }
+ const uint8_t *begin() const { return reinterpret_cast<uint8_t*>(data_); }
+ const uint8_t *end() const { return reinterpret_cast<uint8_t*>(data_) + size_; }
std::size_t size() const { return size_; }
void reset(void *data, std::size_t size) {
@@ -79,23 +80,27 @@ class scoped_memory {
scoped_memory &operator=(const scoped_memory &);
};
-struct scoped_mapped_file {
- scoped_fd fd;
- scoped_mmap mem;
-};
+typedef enum {
+ // mmap with no prepopulate
+ LAZY,
+ // On linux, pass MAP_POPULATE to mmap.
+ POPULATE_OR_LAZY,
+ // Populate on Linux. malloc and read on non-Linux.
+ POPULATE_OR_READ,
+ // malloc and read.
+ READ
+} LoadMethod;
+
// Wrapper around mmap to check it worked and hide some platform macros.
void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, int fd, off_t offset = 0);
-void *MapForRead(std::size_t size, bool prefault, int fd, off_t offset = 0);
+void MapRead(LoadMethod method, int fd, off_t offset, std::size_t size, scoped_memory &out);
void *MapAnonymous(std::size_t size);
// Open file name with mmap of size bytes, all of which are initially zero.
-void MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file, scoped_mmap &mem);
-inline void MapZeroedWrite(const char *name, std::size_t size, scoped_mapped_file &out) {
- MapZeroedWrite(name, size, out.fd, out.mem);
-}
-
+void *MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file);
+
} // namespace util
-#endif // UTIL_SCOPED__
+#endif // UTIL_MMAP__
diff --git a/klm/util/proxy_iterator.hh b/klm/util/proxy_iterator.hh
index 1c5b7089..121a45fa 100644
--- a/klm/util/proxy_iterator.hh
+++ b/klm/util/proxy_iterator.hh
@@ -78,6 +78,8 @@ template <class Proxy> class ProxyIterator {
const Proxy *operator->() const { return &p_; }
Proxy operator[](std::ptrdiff_t amount) const { return *(*this + amount); }
+ const InnerIterator &Inner() { return p_.Inner(); }
+
private:
InnerIterator &I() { return p_.Inner(); }
const InnerIterator &I() const { return p_.Inner(); }
diff --git a/klm/util/scoped.cc b/klm/util/scoped.cc
index 61394ffc..2c6d5394 100644
--- a/klm/util/scoped.cc
+++ b/klm/util/scoped.cc
@@ -9,4 +9,8 @@ scoped_fd::~scoped_fd() {
if (fd_ != -1 && close(fd_)) err(1, "Could not close file %i", fd_);
}
+scoped_FILE::~scoped_FILE() {
+ if (file_ && fclose(file_)) err(1, "Could not close file");
+}
+
} // namespace util
diff --git a/klm/util/scoped.hh b/klm/util/scoped.hh
index ef62a74f..52864481 100644
--- a/klm/util/scoped.hh
+++ b/klm/util/scoped.hh
@@ -4,6 +4,7 @@
/* Other scoped objects in the style of scoped_ptr. */
#include <cstddef>
+#include <cstdio>
namespace util {
@@ -61,6 +62,24 @@ class scoped_fd {
scoped_fd &operator=(const scoped_fd &);
};
+class scoped_FILE {
+ public:
+ explicit scoped_FILE(std::FILE *file = NULL) : file_(file) {}
+
+ ~scoped_FILE();
+
+ std::FILE *get() { return file_; }
+ const std::FILE *get() const { return file_; }
+
+ void reset(std::FILE *to = NULL) {
+ scoped_FILE other(file_);
+ file_ = to;
+ }
+
+ private:
+ std::FILE *file_;
+};
+
} // namespace util
#endif // UTIL_SCOPED__
diff --git a/klm/util/sorted_uniform.hh b/klm/util/sorted_uniform.hh
index 96ec4866..a8e208fb 100644
--- a/klm/util/sorted_uniform.hh
+++ b/klm/util/sorted_uniform.hh
@@ -65,7 +65,7 @@ template <class PackingT> class SortedUniformMap {
public:
// Offer consistent API with probing hash.
- static std::size_t Size(std::size_t entries, float ignore = 0.0) {
+ static std::size_t Size(std::size_t entries, float /*ignore*/ = 0.0) {
return sizeof(uint64_t) + entries * Packing::kBytes;
}
@@ -75,7 +75,7 @@ template <class PackingT> class SortedUniformMap {
#endif
{}
- SortedUniformMap(void *start, std::size_t allocated) :
+ SortedUniformMap(void *start, std::size_t /*allocated*/) :
begin_(Packing::FromVoid(reinterpret_cast<uint64_t*>(start) + 1)),
end_(begin_), size_ptr_(reinterpret_cast<uint64_t*>(start))
#ifdef DEBUG
diff --git a/klm/util/string_piece.cc b/klm/util/string_piece.cc
index 6917a6bc..5b4e98f5 100644
--- a/klm/util/string_piece.cc
+++ b/klm/util/string_piece.cc
@@ -30,14 +30,14 @@
#include "util/string_piece.hh"
-#ifdef USE_BOOST
+#ifdef HAVE_BOOST
#include <boost/functional/hash/hash.hpp>
#endif
#include <algorithm>
#include <iostream>
-#ifdef USE_ICU
+#ifdef HAVE_ICU
U_NAMESPACE_BEGIN
#endif
@@ -46,12 +46,12 @@ std::ostream& operator<<(std::ostream& o, const StringPiece& piece) {
return o;
}
-#ifdef USE_BOOST
+#ifdef HAVE_BOOST
size_t hash_value(const StringPiece &str) {
return boost::hash_range(str.data(), str.data() + str.length());
}
#endif
-#ifdef USE_ICU
+#ifdef HAVE_ICU
U_NAMESPACE_END
#endif
diff --git a/klm/util/string_piece.hh b/klm/util/string_piece.hh
index 58008d13..3ac2f8a7 100644
--- a/klm/util/string_piece.hh
+++ b/klm/util/string_piece.hh
@@ -1,4 +1,4 @@
-/* If you use ICU in your program, then compile with -DUSE_ICU -licui18n. If
+/* If you use ICU in your program, then compile with -DHAVE_ICU -licui18n. If
* you don't use ICU, then this will use the Google implementation from Chrome.
* This has been modified from the original version to let you choose.
*/
@@ -49,14 +49,14 @@
#define BASE_STRING_PIECE_H__
//Uncomment this line if you use ICU in your code.
-//#define USE_ICU
+//#define HAVE_ICU
//Uncomment this line if you want boost hashing for your StringPieces.
-//#define USE_BOOST
+//#define HAVE_BOOST
#include <cstring>
#include <iosfwd>
-#ifdef USE_ICU
+#ifdef HAVE_ICU
#include <unicode/stringpiece.h>
U_NAMESPACE_BEGIN
#else
@@ -230,7 +230,7 @@ inline bool operator>=(const StringPiece& x, const StringPiece& y) {
// allow StringPiece to be logged (needed for unit testing).
extern std::ostream& operator<<(std::ostream& o, const StringPiece& piece);
-#ifdef USE_BOOST
+#ifdef HAVE_BOOST
size_t hash_value(const StringPiece &str);
/* Support for lookup of StringPiece in boost::unordered_map<std::string> */
@@ -253,7 +253,7 @@ template <class T> typename T::iterator FindStringPiece(T &t, const StringPiece
}
#endif
-#ifdef USE_ICU
+#ifdef HAVE_ICU
U_NAMESPACE_END
#endif