Diffstat (limited to 'klm/lm')
-rw-r--r--  klm/lm/Makefile.am           |  20
-rw-r--r--  klm/lm/exception.cc          |  21
-rw-r--r--  klm/lm/exception.hh          |  40
-rw-r--r--  klm/lm/facade.hh             |  64
-rw-r--r--  klm/lm/ngram.cc              | 522
-rw-r--r--  klm/lm/ngram.hh              | 226
-rw-r--r--  klm/lm/ngram_build_binary.cc |  13
-rw-r--r--  klm/lm/ngram_config.hh       |  58
-rw-r--r--  klm/lm/ngram_query.cc        |  72
-rw-r--r--  klm/lm/ngram_test.cc         |  91
-rw-r--r--  klm/lm/sri.cc                | 115
-rw-r--r--  klm/lm/sri.hh                | 102
-rw-r--r--  klm/lm/sri_test.cc           |  65
-rw-r--r--  klm/lm/test.arpa             | 112
-rw-r--r--  klm/lm/test.binary           | bin 0 -> 1660 bytes
-rw-r--r--  klm/lm/virtual_interface.cc  |  22
-rw-r--r--  klm/lm/virtual_interface.hh  | 156
-rw-r--r--  klm/lm/word_index.hh         |  11
18 files changed, 1710 insertions, 0 deletions
diff --git a/klm/lm/Makefile.am b/klm/lm/Makefile.am
new file mode 100644
index 00000000..a0c49eb4
--- /dev/null
+++ b/klm/lm/Makefile.am
@@ -0,0 +1,20 @@
+if HAVE_GTEST
+noinst_PROGRAMS = \
+ ngram_test
+TESTS = ngram_test
+endif
+
+noinst_LIBRARIES = libklm.a
+
+libklm_a_SOURCES = \
+ exception.cc \
+ ngram.cc \
+ ngram_build_binary.cc \
+ ngram_query.cc \
+ virtual_interface.cc
+
+ngram_test_SOURCES = ngram_test.cc
+ngram_test_LDADD = ../util/libklm_util.a
+
+AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I..
+
diff --git a/klm/lm/exception.cc b/klm/lm/exception.cc
new file mode 100644
index 00000000..59a1650d
--- /dev/null
+++ b/klm/lm/exception.cc
@@ -0,0 +1,21 @@
+#include "lm/exception.hh"
+
+#include<errno.h>
+#include<stdio.h>
+
+namespace lm {
+
+LoadException::LoadException() throw() {}
+LoadException::~LoadException() throw() {}
+VocabLoadException::VocabLoadException() throw() {}
+VocabLoadException::~VocabLoadException() throw() {}
+
+FormatLoadException::FormatLoadException() throw() {}
+FormatLoadException::~FormatLoadException() throw() {}
+
+SpecialWordMissingException::SpecialWordMissingException(StringPiece which) throw() {
+ *this << "Missing special word " << which;
+}
+SpecialWordMissingException::~SpecialWordMissingException() throw() {}
+
+} // namespace lm
diff --git a/klm/lm/exception.hh b/klm/lm/exception.hh
new file mode 100644
index 00000000..95109012
--- /dev/null
+++ b/klm/lm/exception.hh
@@ -0,0 +1,40 @@
+#ifndef LM_EXCEPTION__
+#define LM_EXCEPTION__
+
+#include "util/exception.hh"
+#include "util/string_piece.hh"
+
+#include <exception>
+#include <string>
+
+namespace lm {
+
+class LoadException : public util::Exception {
+ public:
+ virtual ~LoadException() throw();
+
+ protected:
+ LoadException() throw();
+};
+
+class VocabLoadException : public LoadException {
+ public:
+ virtual ~VocabLoadException() throw();
+ VocabLoadException() throw();
+};
+
+class FormatLoadException : public LoadException {
+ public:
+ FormatLoadException() throw();
+ ~FormatLoadException() throw();
+};
+
+class SpecialWordMissingException : public VocabLoadException {
+ public:
+ explicit SpecialWordMissingException(StringPiece which) throw();
+ ~SpecialWordMissingException() throw();
+};
+
+} // namespace lm
+
+#endif
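
Code elsewhere in this diff throws these exceptions via UTIL_THROW and streams extra context into them with operator<<. A caller would typically wrap model construction in a try block; a minimal sketch, assuming util::Exception (defined in util/, outside this diff) derives from std::exception with the usual what():

    #include "lm/exception.hh"
    #include "lm/ngram.hh"

    #include <iostream>

    int main() {
      try {
        lm::ngram::Model model("model.arpa");  // hypothetical file name
      } catch (const lm::LoadException &e) {
        // FormatLoadException and SpecialWordMissingException both derive from LoadException.
        std::cerr << "Failed to load model: " << e.what() << std::endl;
        return 1;
      }
      return 0;
    }
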
diff --git a/klm/lm/facade.hh b/klm/lm/facade.hh
new file mode 100644
index 00000000..8b186017
--- /dev/null
+++ b/klm/lm/facade.hh
@@ -0,0 +1,64 @@
+#ifndef LM_FACADE__
+#define LM_FACADE__
+
+#include "lm/virtual_interface.hh"
+#include "util/string_piece.hh"
+
+#include <string>
+
+namespace lm {
+namespace base {
+
+// Common model interface that depends on knowing the specific classes.
+// Curiously recurring template pattern.
+template <class Child, class StateT, class VocabularyT> class ModelFacade : public Model {
+ public:
+ typedef StateT State;
+ typedef VocabularyT Vocabulary;
+
+ // Default Score function calls FullScore. Model can override this.
+ float Score(const State &in_state, const WordIndex new_word, State &out_state) const {
+ return static_cast<const Child*>(this)->FullScore(in_state, new_word, out_state).prob;
+ }
+
+ /* Translate from void* to State */
+ FullScoreReturn FullScore(const void *in_state, const WordIndex new_word, void *out_state) const {
+ return static_cast<const Child*>(this)->FullScore(
+ *reinterpret_cast<const State*>(in_state),
+ new_word,
+ *reinterpret_cast<State*>(out_state));
+ }
+ float Score(const void *in_state, const WordIndex new_word, void *out_state) const {
+ return static_cast<const Child*>(this)->Score(
+ *reinterpret_cast<const State*>(in_state),
+ new_word,
+ *reinterpret_cast<State*>(out_state));
+ }
+
+ const State &BeginSentenceState() const { return begin_sentence_; }
+ const State &NullContextState() const { return null_context_; }
+ const Vocabulary &GetVocabulary() const { return *static_cast<const Vocabulary*>(&BaseVocabulary()); }
+
+ protected:
+ ModelFacade() : Model(sizeof(State)) {}
+
+ virtual ~ModelFacade() {}
+
+ // begin_sentence and null_context can disappear after. vocab should stay.
+ void Init(const State &begin_sentence, const State &null_context, const Vocabulary &vocab, unsigned char order) {
+ begin_sentence_ = begin_sentence;
+ null_context_ = null_context;
+ begin_sentence_memory_ = &begin_sentence_;
+ null_context_memory_ = &null_context_;
+ base_vocab_ = &vocab;
+ order_ = order;
+ }
+
+ private:
+ State begin_sentence_, null_context_;
+};
+
+} // namespace base
+} // namespace lm
+
+#endif // LM_FACADE__
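
The facade above is the curiously recurring template pattern at work: the base casts this down to Child so the type-erased void* overloads forward to the statically typed ones. A standalone sketch of that dispatch with toy types (not the lm classes):

    #include <iostream>

    // Base provides the type-erased interface; Child supplies the typed implementation.
    template <class Child, class StateT> struct Facade {
      float Score(const void *in, void *out) const {
        // Recover the concrete state type and forward to the derived class.
        return static_cast<const Child*>(this)->Score(
            *static_cast<const StateT*>(in), *static_cast<StateT*>(out));
      }
    };

    struct MyState { int length; };

    struct MyModel : Facade<MyModel, MyState> {
      using Facade<MyModel, MyState>::Score;  // keep the void* overload visible next to this one
      float Score(const MyState &in, MyState &out) const {
        out.length = in.length + 1;
        return -0.5f * out.length;
      }
    };

    int main() {
      MyModel m;
      MyState in = {1}, out;
      const void *erased = &in;  // how a caller that only knows the base interface holds state
      std::cout << m.Score(erased, &out) << '\n';  // dispatches through Facade, prints -1
    }
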
diff --git a/klm/lm/ngram.cc b/klm/lm/ngram.cc
new file mode 100644
index 00000000..a87c82aa
--- /dev/null
+++ b/klm/lm/ngram.cc
@@ -0,0 +1,522 @@
+#include "lm/ngram.hh"
+
+#include "lm/exception.hh"
+#include "util/file_piece.hh"
+#include "util/joint_sort.hh"
+#include "util/murmur_hash.hh"
+#include "util/probing_hash_table.hh"
+
+#include <algorithm>
+#include <functional>
+#include <numeric>
+#include <limits>
+#include <string>
+
+#include <cmath>
+#include <fcntl.h>
+#include <errno.h>
+#include <stdlib.h>
+#include <sys/mman.h>
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+namespace lm {
+namespace ngram {
+
+size_t hash_value(const State &state) {
+ return util::MurmurHashNative(state.history_, sizeof(WordIndex) * state.valid_length_);
+}
+
+namespace detail {
+uint64_t HashForVocab(const char *str, std::size_t len) {
+ // This proved faster than Boost's hash in speed trials: total load time Murmur 67090000, Boost 72210000
+ // Chose to use 64A instead of native so binary format will be portable across 64 and 32 bit.
+ return util::MurmurHash64A(str, len, 0);
+}
+
+void Prob::SetBackoff(float to) {
+ UTIL_THROW(FormatLoadException, "Attempt to set backoff " << to << " for the highest order n-gram");
+}
+
+// Normally static initialization is a bad idea but MurmurHash is pure arithmetic, so this is ok.
+const uint64_t kUnknownHash = HashForVocab("<unk>", 5);
+// Sadly some LMs have <UNK>.
+const uint64_t kUnknownCapHash = HashForVocab("<UNK>", 5);
+
+} // namespace detail
+
+SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL) {}
+
+std::size_t SortedVocabulary::Size(std::size_t entries, float ignored) {
+ // Lead with the number of entries.
+ return sizeof(uint64_t) + sizeof(Entry) * entries;
+}
+
+void SortedVocabulary::Init(void *start, std::size_t allocated, std::size_t entries) {
+ assert(allocated >= Size(entries));
+ // Leave space for number of entries.
+ begin_ = reinterpret_cast<Entry*>(reinterpret_cast<uint64_t*>(start) + 1);
+ end_ = begin_;
+ saw_unk_ = false;
+}
+
+WordIndex SortedVocabulary::Insert(const StringPiece &str) {
+ uint64_t hashed = detail::HashForVocab(str);
+ if (hashed == detail::kUnknownHash || hashed == detail::kUnknownCapHash) {
+ saw_unk_ = true;
+ return 0;
+ }
+ end_->key = hashed;
+ ++end_;
+ // This is 1 + the offset where it was inserted to make room for unk.
+ return end_ - begin_;
+}
+
+bool SortedVocabulary::FinishedLoading(detail::ProbBackoff *reorder_vocab) {
+ util::JointSort(begin_, end_, reorder_vocab + 1);
+ SetSpecial(Index("<s>"), Index("</s>"), 0, end_ - begin_ + 1);
+ // Save size.
+ *(reinterpret_cast<uint64_t*>(begin_) - 1) = end_ - begin_;
+ return saw_unk_;
+}
+
+void SortedVocabulary::LoadedBinary() {
+ end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1);
+ SetSpecial(Index("<s>"), Index("</s>"), 0, end_ - begin_ + 1);
+}
+
+namespace detail {
+
+template <class Search> MapVocabulary<Search>::MapVocabulary() {}
+
+template <class Search> void MapVocabulary<Search>::Init(void *start, std::size_t allocated, std::size_t entries) {
+ lookup_ = Lookup(start, allocated);
+ available_ = 1;
+ // Later if available_ != expected_available_ then we can throw UnknownMissingException.
+ saw_unk_ = false;
+}
+
+template <class Search> WordIndex MapVocabulary<Search>::Insert(const StringPiece &str) {
+ uint64_t hashed = HashForVocab(str);
+ // Prevent unknown from going into the table.
+ if (hashed == kUnknownHash || hashed == kUnknownCapHash) {
+ saw_unk_ = true;
+ return 0;
+ } else {
+ lookup_.Insert(Lookup::Packing::Make(hashed, available_));
+ return available_++;
+ }
+}
+
+template <class Search> bool MapVocabulary<Search>::FinishedLoading(ProbBackoff *reorder_vocab) {
+ lookup_.FinishedInserting();
+ SetSpecial(Index("<s>"), Index("</s>"), 0, available_);
+ return saw_unk_;
+}
+
+template <class Search> void MapVocabulary<Search>::LoadedBinary() {
+ lookup_.LoadedBinary();
+ SetSpecial(Index("<s>"), Index("</s>"), 0, available_);
+}
+
+/* All of the entropy is in low order bits and boost::hash does poorly with
+ * these. Odd numbers near 2^64 chosen by mashing on the keyboard. There is a
+ * stable point: 0. But 0 is <unk> which won't be queried here anyway.
+ */
+inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) {
+ uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast<uint64_t>(next) * 17894857484156487943ULL);
+ return ret;
+}
+
+uint64_t ChainedWordHash(const WordIndex *word, const WordIndex *word_end) {
+ if (word == word_end) return 0;
+ uint64_t current = static_cast<uint64_t>(*word);
+ for (++word; word != word_end; ++word) {
+ current = CombineWordHash(current, *word);
+ }
+ return current;
+}
+
+bool IsEntirelyWhiteSpace(const StringPiece &line) {
+ for (size_t i = 0; i < static_cast<size_t>(line.size()); ++i) {
+ if (!isspace(line.data()[i])) return false;
+ }
+ return true;
+}
+
+void ReadARPACounts(util::FilePiece &in, std::vector<size_t> &number) {
+ number.clear();
+ StringPiece line;
+ if (!IsEntirelyWhiteSpace(line = in.ReadLine())) UTIL_THROW(FormatLoadException, "First line was \"" << line << "\" not blank");
+ if ((line = in.ReadLine()) != "\\data\\") UTIL_THROW(FormatLoadException, "second line was \"" << line << "\" not \\data\\.");
+ while (!IsEntirelyWhiteSpace(line = in.ReadLine())) {
+ if (line.size() < 6 || strncmp(line.data(), "ngram ", 6)) UTIL_THROW(FormatLoadException, "count line \"" << line << "\" doesn't begin with \"ngram \"");
+ // So strtol doesn't go off the end of line.
+ std::string remaining(line.data() + 6, line.size() - 6);
+ char *end_ptr;
+ unsigned long int length = std::strtol(remaining.c_str(), &end_ptr, 10);
+ if ((end_ptr == remaining.c_str()) || (length - 1 != number.size())) UTIL_THROW(FormatLoadException, "ngram count lengths should be consecutive starting with 1: " << line);
+ if (*end_ptr != '=') UTIL_THROW(FormatLoadException, "Expected = immediately following the first number in the count line " << line);
+ ++end_ptr;
+ const char *start = end_ptr;
+ long int count = std::strtol(start, &end_ptr, 10);
+ if (count < 0) UTIL_THROW(FormatLoadException, "Negative n-gram count " << count);
+ if (start == end_ptr) UTIL_THROW(FormatLoadException, "Couldn't parse n-gram count from " << line);
+ number.push_back(count);
+ }
+}
+
+void ReadNGramHeader(util::FilePiece &in, unsigned int length) {
+ StringPiece line;
+ while (IsEntirelyWhiteSpace(line = in.ReadLine())) {}
+ std::stringstream expected;
+ expected << '\\' << length << "-grams:";
+ if (line != expected.str()) UTIL_THROW(FormatLoadException, "Was expecting n-gram header " << expected.str() << " but got " << line << " instead.");
+}
+
+// Special unigram reader because unigram's data structure is different and because we're inserting vocab words.
+template <class Voc> void Read1Grams(util::FilePiece &f, const size_t count, Voc &vocab, ProbBackoff *unigrams) {
+ ReadNGramHeader(f, 1);
+ for (size_t i = 0; i < count; ++i) {
+ try {
+ float prob = f.ReadFloat();
+ if (f.get() != '\t') UTIL_THROW(FormatLoadException, "Expected tab after probability");
+ ProbBackoff &value = unigrams[vocab.Insert(f.ReadDelimited())];
+ value.prob = prob;
+ switch (f.get()) {
+ case '\t':
+ value.SetBackoff(f.ReadFloat());
+ if ((f.get() != '\n')) UTIL_THROW(FormatLoadException, "Expected newline after backoff");
+ break;
+ case '\n':
+ value.ZeroBackoff();
+ break;
+ default:
+ UTIL_THROW(FormatLoadException, "Expected tab or newline after unigram");
+ }
+ } catch(util::Exception &e) {
+ e << " in the " << i << "th 1-gram at byte " << f.Offset();
+ throw;
+ }
+ }
+ if (f.ReadLine().size()) UTIL_THROW(FormatLoadException, "Expected blank line after unigrams at byte " << f.Offset());
+}
+
+template <class Voc, class Store> void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, Store &store) {
+ ReadNGramHeader(f, n);
+
+ // vocab ids of words in reverse order
+ WordIndex vocab_ids[n];
+ typename Store::Packing::Value value;
+ for (size_t i = 0; i < count; ++i) {
+ try {
+ value.prob = f.ReadFloat();
+ for (WordIndex *vocab_out = &vocab_ids[n-1]; vocab_out >= vocab_ids; --vocab_out) {
+ *vocab_out = vocab.Index(f.ReadDelimited());
+ }
+ uint64_t key = ChainedWordHash(vocab_ids, vocab_ids + n);
+
+ switch (f.get()) {
+ case '\t':
+ value.SetBackoff(f.ReadFloat());
+ if ((f.get() != '\n')) UTIL_THROW(FormatLoadException, "Expected newline after backoff");
+ break;
+ case '\n':
+ value.ZeroBackoff();
+ break;
+ default:
+ UTIL_THROW(FormatLoadException, "Expected tab or newline after n-gram");
+ }
+ store.Insert(Store::Packing::Make(key, value));
+ } catch(util::Exception &e) {
+ e << " in the " << i << "th " << n << "-gram at byte " << f.Offset();
+ throw;
+ }
+ }
+
+ if (f.ReadLine().size()) UTIL_THROW(FormatLoadException, "Expected blank line after " << n << "-grams at byte " << f.Offset());
+ store.FinishedInserting();
+}
+
+template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<size_t> &counts, const Config &config) {
+ if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit ngram.hh's kMaxOrder to at least this value and recompile.");
+ if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
+ size_t memory_size = VocabularyT::Size(counts[0], config.probing_multiplier);
+ memory_size += sizeof(ProbBackoff) * (counts[0] + 1); // +1 for hallucinate <unk>
+ for (unsigned char n = 2; n < counts.size(); ++n) {
+ memory_size += Middle::Size(counts[n - 1], config.probing_multiplier);
+ }
+ memory_size += Longest::Size(counts.back(), config.probing_multiplier);
+ return memory_size;
+}
+
+template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::SetupMemory(char *base, const std::vector<size_t> &counts, const Config &config) {
+ char *start = base;
+ size_t allocated = VocabularyT::Size(counts[0], config.probing_multiplier);
+ vocab_.Init(start, allocated, counts[0]);
+ start += allocated;
+ unigram_ = reinterpret_cast<ProbBackoff*>(start);
+ start += sizeof(ProbBackoff) * (counts[0] + 1);
+ for (unsigned int n = 2; n < counts.size(); ++n) {
+ allocated = Middle::Size(counts[n - 1], config.probing_multiplier);
+ middle_.push_back(Middle(start, allocated));
+ start += allocated;
+ }
+ allocated = Longest::Size(counts.back(), config.probing_multiplier);
+ longest_ = Longest(start, allocated);
+ start += allocated;
+ if (static_cast<std::size_t>(start - base) != Size(counts, config)) UTIL_THROW(FormatLoadException, "The data structures took " << (start - base) << " but Size says they should take " << Size(counts, config));
+}
+
+const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 0\n\0";
+struct BinaryFileHeader {
+ char magic[sizeof(kMagicBytes)];
+ float zero_f, one_f, minus_half_f;
+ WordIndex one_word_index, max_word_index;
+ uint64_t one_uint64;
+
+ void SetToReference() {
+ std::memcpy(magic, kMagicBytes, sizeof(magic));
+ zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5;
+ one_word_index = 1;
+ max_word_index = std::numeric_limits<WordIndex>::max();
+ one_uint64 = 1;
+ }
+};
+
+bool IsBinaryFormat(int fd, off_t size) {
+ if (size == util::kBadSize || (size <= static_cast<off_t>(sizeof(BinaryFileHeader)))) return false;
+ // Try reading the header.
+ util::scoped_mmap memory(mmap(NULL, sizeof(BinaryFileHeader), PROT_READ, MAP_FILE | MAP_PRIVATE, fd, 0), sizeof(BinaryFileHeader));
+ if (memory.get() == MAP_FAILED) return false;
+ BinaryFileHeader reference_header = BinaryFileHeader();
+ reference_header.SetToReference();
+ if (!memcmp(memory.get(), &reference_header, sizeof(BinaryFileHeader))) return true;
+ if (!memcmp(memory.get(), "mmap lm ", 8)) UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Was it built on a different machine or with a different compiler?");
+ return false;
+}
+
+std::size_t Align8(std::size_t in) {
+ std::size_t off = in % 8;
+ if (!off) return in;
+ return in + 8 - off;
+}
+
+std::size_t TotalHeaderSize(unsigned int order) {
+ return Align8(sizeof(BinaryFileHeader) + 1 /* order */ + sizeof(uint64_t) * order /* counts */ + sizeof(float) /* probing multiplier */ + 1 /* search_tag */);
+}
+
+void ReadBinaryHeader(const void *from, off_t size, std::vector<size_t> &out, float &probing_multiplier, unsigned char &search_tag) {
+ const char *from_char = reinterpret_cast<const char*>(from);
+ if (size < static_cast<off_t>(1 + sizeof(BinaryFileHeader))) UTIL_THROW(FormatLoadException, "File too short to have count information.");
+ // Skip over the BinaryFileHeader which was read by IsBinaryFormat.
+ from_char += sizeof(BinaryFileHeader);
+ unsigned char order = *reinterpret_cast<const unsigned char*>(from_char);
+ if (size < static_cast<off_t>(TotalHeaderSize(order))) UTIL_THROW(FormatLoadException, "File too short to have full header.");
+ out.resize(static_cast<std::size_t>(order));
+ const uint64_t *counts = reinterpret_cast<const uint64_t*>(from_char + 1);
+ for (std::size_t i = 0; i < out.size(); ++i) {
+ out[i] = static_cast<std::size_t>(counts[i]);
+ }
+ const float *probing_ptr = reinterpret_cast<const float*>(counts + out.size());
+ probing_multiplier = *probing_ptr;
+ search_tag = *reinterpret_cast<const char*>(probing_ptr + 1);
+}
+
+void WriteBinaryHeader(void *to, const std::vector<size_t> &from, float probing_multiplier, char search_tag) {
+ BinaryFileHeader header = BinaryFileHeader();
+ header.SetToReference();
+ memcpy(to, &header, sizeof(BinaryFileHeader));
+ char *out = reinterpret_cast<char*>(to) + sizeof(BinaryFileHeader);
+ *reinterpret_cast<unsigned char*>(out) = static_cast<unsigned char>(from.size());
+ uint64_t *counts = reinterpret_cast<uint64_t*>(out + 1);
+ for (std::size_t i = 0; i < from.size(); ++i) {
+ counts[i] = from[i];
+ }
+ float *probing_ptr = reinterpret_cast<float*>(counts + from.size());
+ *probing_ptr = probing_multiplier;
+ *reinterpret_cast<char*>(probing_ptr + 1) = search_tag;
+}
+
+template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, Config config) : mapped_file_(util::OpenReadOrThrow(file)) {
+ const off_t file_size = util::SizeFile(mapped_file_.get());
+
+ std::vector<size_t> counts;
+
+ if (IsBinaryFormat(mapped_file_.get(), file_size)) {
+ memory_.reset(util::MapForRead(file_size, config.prefault, mapped_file_.get()), file_size);
+
+ unsigned char search_tag;
+ ReadBinaryHeader(memory_.begin(), file_size, counts, config.probing_multiplier, search_tag);
+ if (config.probing_multiplier < 1.0) UTIL_THROW(FormatLoadException, "Binary format claims to have a probing multiplier of " << config.probing_multiplier << " which is < 1.0.");
+ if (search_tag != Search::kBinaryTag) UTIL_THROW(FormatLoadException, "The binary file has a different search strategy than the one requested.");
+ size_t memory_size = Size(counts, config);
+
+ char *start = reinterpret_cast<char*>(memory_.get()) + TotalHeaderSize(counts.size());
+ if (memory_size != static_cast<size_t>(memory_.end() - start)) UTIL_THROW(FormatLoadException, "The mmap file " << file << " has size " << file_size << " but " << (memory_size + TotalHeaderSize(counts.size())) << " was expected based on the number of counts and configuration.");
+
+ SetupMemory(start, counts, config);
+ vocab_.LoadedBinary();
+ for (typename std::vector<Middle>::iterator i = middle_.begin(); i != middle_.end(); ++i) {
+ i->LoadedBinary();
+ }
+ longest_.LoadedBinary();
+
+ } else {
+ if (config.probing_multiplier <= 1.0) UTIL_THROW(FormatLoadException, "probing multiplier must be > 1.0");
+
+ util::FilePiece f(file, mapped_file_.release(), config.messages);
+ ReadARPACounts(f, counts);
+ size_t memory_size = Size(counts, config);
+ char *start;
+
+ if (config.write_mmap) {
+ // Write out an mmap file.
+ util::MapZeroedWrite(config.write_mmap, TotalHeaderSize(counts.size()) + memory_size, mapped_file_, memory_);
+ WriteBinaryHeader(memory_.get(), counts, config.probing_multiplier, Search::kBinaryTag);
+ start = reinterpret_cast<char*>(memory_.get()) + TotalHeaderSize(counts.size());
+ } else {
+ memory_.reset(util::MapAnonymous(memory_size), memory_size);
+ start = reinterpret_cast<char*>(memory_.get());
+ }
+ SetupMemory(start, counts, config);
+ try {
+ LoadFromARPA(f, counts, config);
+ } catch (FormatLoadException &e) {
+ e << " in file " << file;
+ throw;
+ }
+ }
+
+ // g++ prints warnings unless these are fully initialized.
+ State begin_sentence = State();
+ begin_sentence.valid_length_ = 1;
+ begin_sentence.history_[0] = vocab_.BeginSentence();
+ begin_sentence.backoff_[0] = unigram_[begin_sentence.history_[0]].backoff;
+ State null_context = State();
+ null_context.valid_length_ = 0;
+ P::Init(begin_sentence, null_context, vocab_, counts.size());
+}
+
+template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::LoadFromARPA(util::FilePiece &f, const std::vector<size_t> &counts, const Config &config) {
+ // Read the unigrams.
+ Read1Grams(f, counts[0], vocab_, unigram_);
+ bool saw_unk = vocab_.FinishedLoading(unigram_);
+ if (!saw_unk) {
+ switch(config.unknown_missing) {
+ case Config::THROW_UP:
+ {
+ SpecialWordMissingException e("<unk>");
+ e << " and configuration was set to throw if unknown is missing";
+ throw e;
+ }
+ case Config::COMPLAIN:
+ if (config.messages) *config.messages << "Language model is missing <unk>. Substituting probability " << config.unknown_missing_prob << "." << std::endl;
+ // There's no break here; the fall-through from COMPLAIN to SILENT is by design.
+ case Config::SILENT:
+ // Default probabilities for unknown.
+ unigram_[0].backoff = 0.0;
+ unigram_[0].prob = config.unknown_missing_prob;
+ break;
+ }
+ }
+
+ // Read the n-grams.
+ for (unsigned int n = 2; n < counts.size(); ++n) {
+ ReadNGrams(f, n, counts[n-1], vocab_, middle_[n-2]);
+ }
+ ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab_, longest_);
+ if (std::fabs(unigram_[0].backoff) > 0.0000001) UTIL_THROW(FormatLoadException, "Backoff for unknown word should be zero, but was given as " << unigram_[0].backoff);
+}
+
+/* Ugly optimized function.
+ * in_state contains the previous ngram's length and backoff probabilities to
+ * be used here. out_state is populated with the found ngram length and
+ * backoffs that the next call will find useful.
+ *
+ * The search goes in increasing order of ngram length.
+ */
+template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(
+ const State &in_state,
+ const WordIndex new_word,
+ State &out_state) const {
+
+ FullScoreReturn ret;
+ // Unigram entry for the new word; its probability is the base onto which backoffs accumulate below.
+ const ProbBackoff &unigram = unigram_[new_word];
+ if (new_word == 0) {
+ ret.ngram_length = out_state.valid_length_ = 0;
+ // Unknown word: sum every backoff weight in the context onto the <unk> probability.
+ ret.prob = std::accumulate(
+ in_state.backoff_,
+ in_state.backoff_ + in_state.valid_length_,
+ unigram.prob);
+ return ret;
+ }
+ float *backoff_out(out_state.backoff_);
+ *backoff_out = unigram.backoff;
+ ret.prob = unigram.prob;
+ out_state.history_[0] = new_word;
+ if (in_state.valid_length_ == 0) {
+ ret.ngram_length = out_state.valid_length_ = 1;
+ // No backoff because NGramLength() == 0 and unknown can't have backoff.
+ return ret;
+ }
+ ++backoff_out;
+
+ // OK, now we know that the bigram contains known words. Start by looking it up.
+
+ uint64_t lookup_hash = static_cast<uint64_t>(new_word);
+ const WordIndex *hist_iter = in_state.history_;
+ const WordIndex *const hist_end = hist_iter + in_state.valid_length_;
+ typename std::vector<Middle>::const_iterator mid_iter = middle_.begin();
+ for (; ; ++mid_iter, ++hist_iter, ++backoff_out) {
+ if (hist_iter == hist_end) {
+ // Used history [in_state.history_, hist_end) and ran out. No backoff.
+ std::copy(in_state.history_, hist_end, out_state.history_ + 1);
+ ret.ngram_length = out_state.valid_length_ = in_state.valid_length_ + 1;
+ // ret.prob was already set.
+ return ret;
+ }
+ lookup_hash = CombineWordHash(lookup_hash, *hist_iter);
+ if (mid_iter == middle_.end()) break;
+ typename Middle::ConstIterator found;
+ if (!mid_iter->Find(lookup_hash, found)) {
+ // Didn't find an ngram using hist_iter.
+ // The history used in the found n-gram is [in_state.history_, hist_iter).
+ std::copy(in_state.history_, hist_iter, out_state.history_ + 1);
+ // Therefore, we found a (hist_iter - in_state.history_ + 1)-gram including the last word.
+ ret.ngram_length = out_state.valid_length_ = (hist_iter - in_state.history_) + 1;
+ ret.prob = std::accumulate(
+ in_state.backoff_ + (mid_iter - middle_.begin()),
+ in_state.backoff_ + in_state.valid_length_,
+ ret.prob);
+ return ret;
+ }
+ *backoff_out = found->GetValue().backoff;
+ ret.prob = found->GetValue().prob;
+ }
+
+ typename Longest::ConstIterator found;
+ if (!longest_.Find(lookup_hash, found)) {
+ // It's an (P::Order()-1)-gram
+ std::copy(in_state.history_, in_state.history_ + P::Order() - 2, out_state.history_ + 1);
+ ret.ngram_length = out_state.valid_length_ = P::Order() - 1;
+ ret.prob += in_state.backoff_[P::Order() - 2];
+ return ret;
+ }
+ // It's an P::Order()-gram
+ // out_state.valid_length_ is still P::Order() - 1 because the next lookup will only need that much.
+ std::copy(in_state.history_, in_state.history_ + P::Order() - 2, out_state.history_ + 1);
+ out_state.valid_length_ = P::Order() - 1;
+ ret.ngram_length = P::Order();
+ ret.prob = found->GetValue().prob;
+ return ret;
+}
+
+template class GenericModel<ProbingSearch, MapVocabulary<ProbingSearch> >;
+template class GenericModel<SortedUniformSearch, SortedVocabulary>;
+} // namespace detail
+} // namespace ngram
+} // namespace lm
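
As a concrete check of the unknown-word branch of FullScore above (new_word == 0): the score is the <unk> unigram probability plus every backoff weight carried in the incoming state. Querying an out-of-vocabulary word from the begin-sentence state of test.arpa (later in this diff) gives, in log10:

    p(this_is_not_found | <s>) = p(<unk>) + backoff(<s>)
                               = -1.995635 + (-0.4149733)
                               = -2.4106083

which is exactly the value asserted in ngram_test.cc.
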
diff --git a/klm/lm/ngram.hh b/klm/lm/ngram.hh
new file mode 100644
index 00000000..899a80e8
--- /dev/null
+++ b/klm/lm/ngram.hh
@@ -0,0 +1,226 @@
+#ifndef LM_NGRAM__
+#define LM_NGRAM__
+
+#include "lm/facade.hh"
+#include "lm/ngram_config.hh"
+#include "util/key_value_packing.hh"
+#include "util/mmap.hh"
+#include "util/probing_hash_table.hh"
+#include "util/scoped.hh"
+#include "util/sorted_uniform.hh"
+#include "util/string_piece.hh"
+
+#include <algorithm>
+#include <memory>
+#include <vector>
+
+namespace util { class FilePiece; }
+
+namespace lm {
+namespace ngram {
+
+// If you need higher order, change this and recompile.
+// Having this limit means that State can be
+// (kMaxOrder - 1) * sizeof(float) bytes instead of
+// sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead
+const std::size_t kMaxOrder = 6;
+
+// This is a POD.
+class State {
+ public:
+ bool operator==(const State &other) const {
+ if (valid_length_ != other.valid_length_) return false;
+ const WordIndex *end = history_ + valid_length_;
+ for (const WordIndex *first = history_, *second = other.history_;
+ first != end; ++first, ++second) {
+ if (*first != *second) return false;
+ }
+ // If the histories are equal, so are the backoffs.
+ return true;
+ }
+
+ // You shouldn't need to touch anything below this line, but the members are public so State will qualify as a POD.
+ // This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit.
+ WordIndex history_[kMaxOrder - 1];
+ float backoff_[kMaxOrder - 1];
+ unsigned char valid_length_;
+};
+
+size_t hash_value(const State &state);
+
+namespace detail {
+
+uint64_t HashForVocab(const char *str, std::size_t len);
+inline uint64_t HashForVocab(const StringPiece &str) {
+ return HashForVocab(str.data(), str.length());
+}
+
+struct Prob {
+ float prob;
+ void SetBackoff(float to);
+ void ZeroBackoff() {}
+};
+// No inheritance so this will be a POD.
+struct ProbBackoff {
+ float prob;
+ float backoff;
+ void SetBackoff(float to) { backoff = to; }
+ void ZeroBackoff() { backoff = 0.0; }
+};
+
+} // namespace detail
+
+// Vocabulary based on sorted uniform find storing only uint64_t values and using their offsets as indices.
+class SortedVocabulary : public base::Vocabulary {
+ private:
+ // Sorted uniform requires a GetKey function.
+ struct Entry {
+ uint64_t GetKey() const { return key; }
+ uint64_t key;
+ bool operator<(const Entry &other) const {
+ return key < other.key;
+ }
+ };
+
+ public:
+ SortedVocabulary();
+
+ WordIndex Index(const StringPiece &str) const {
+ const Entry *found;
+ if (util::SortedUniformFind<const Entry *, uint64_t>(begin_, end_, detail::HashForVocab(str), found)) {
+ return found - begin_ + 1; // +1 because <unk> is 0 and does not appear in the lookup table.
+ } else {
+ return 0;
+ }
+ }
+
+ // Ignores second argument for consistency with probing hash which has a float here.
+ static size_t Size(std::size_t entries, float ignored = 0.0);
+
+ // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
+ void Init(void *start, std::size_t allocated, std::size_t entries);
+
+ WordIndex Insert(const StringPiece &str);
+
+ // Returns true if unknown was seen. Reorders reorder_vocab so that the IDs are sorted.
+ bool FinishedLoading(detail::ProbBackoff *reorder_vocab);
+
+ void LoadedBinary();
+
+ private:
+ Entry *begin_, *end_;
+
+ bool saw_unk_;
+};
+
+namespace detail {
+
+// Vocabulary storing a map from uint64_t to WordIndex.
+template <class Search> class MapVocabulary : public base::Vocabulary {
+ public:
+ MapVocabulary();
+
+ WordIndex Index(const StringPiece &str) const {
+ typename Lookup::ConstIterator i;
+ return lookup_.Find(HashForVocab(str), i) ? i->GetValue() : 0;
+ }
+
+ static size_t Size(std::size_t entries, float probing_multiplier) {
+ return Lookup::Size(entries, probing_multiplier);
+ }
+
+ // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
+ void Init(void *start, std::size_t allocated, std::size_t entries);
+
+ WordIndex Insert(const StringPiece &str);
+
+ // Returns true if unknown was seen. Does nothing with reorder_vocab.
+ bool FinishedLoading(ProbBackoff *reorder_vocab);
+
+ void LoadedBinary();
+
+ private:
+ typedef typename Search::template Table<WordIndex>::T Lookup;
+ Lookup lookup_;
+
+ bool saw_unk_;
+};
+
+// std::identity is an SGI extension :-(
+struct IdentityHash : public std::unary_function<uint64_t, size_t> {
+ size_t operator()(uint64_t arg) const { return static_cast<size_t>(arg); }
+};
+
+// Should return the same results as SRI.
+// Why VocabularyT instead of just Vocabulary? ModelFacade defines Vocabulary.
+template <class Search, class VocabularyT> class GenericModel : public base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> {
+ private:
+ typedef base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> P;
+ public:
+ // Get the size of memory that will be mapped given ngram counts. This
+ // does not include small non-mapped control structures, such as this class
+ // itself.
+ static size_t Size(const std::vector<size_t> &counts, const Config &config = Config());
+
+ GenericModel(const char *file, Config config = Config());
+
+ FullScoreReturn FullScore(const State &in_state, const WordIndex new_word, State &out_state) const;
+
+ private:
+ // Appears after Size in the cc.
+ void SetupMemory(char *start, const std::vector<size_t> &counts, const Config &config);
+
+ void LoadFromARPA(util::FilePiece &f, const std::vector<size_t> &counts, const Config &config);
+
+ util::scoped_fd mapped_file_;
+
+ // memory_ is the raw block of memory backing vocab_, unigram_, [middle.begin(), middle.end()), and longest_.
+ util::scoped_mmap memory_;
+
+ VocabularyT vocab_;
+
+ ProbBackoff *unigram_;
+
+ typedef typename Search::template Table<ProbBackoff>::T Middle;
+ std::vector<Middle> middle_;
+
+ typedef typename Search::template Table<Prob>::T Longest;
+ Longest longest_;
+};
+
+struct ProbingSearch {
+ typedef float Init;
+
+ static const unsigned char kBinaryTag = 1;
+
+ template <class Value> struct Table {
+ typedef util::ByteAlignedPacking<uint64_t, Value> Packing;
+ typedef util::ProbingHashTable<Packing, IdentityHash> T;
+ };
+};
+
+struct SortedUniformSearch {
+ // This is ignored.
+ typedef float Init;
+
+ static const unsigned char kBinaryTag = 2;
+
+ template <class Value> struct Table {
+ typedef util::ByteAlignedPacking<uint64_t, Value> Packing;
+ typedef util::SortedUniformMap<Packing> T;
+ };
+};
+
+} // namespace detail
+
+// These must also be instantiated in the cc file.
+typedef detail::MapVocabulary<detail::ProbingSearch> Vocabulary;
+typedef detail::GenericModel<detail::ProbingSearch, Vocabulary> Model;
+
+// SortedVocabulary was defined above.
+typedef detail::GenericModel<detail::SortedUniformSearch, SortedVocabulary> SortedModel;
+
+} // namespace ngram
+} // namespace lm
+
+#endif // LM_NGRAM__
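
ngram_query.cc later in this diff exercises this interface in full; a condensed sketch of the direct, statically typed API declared above (Option 1 in virtual_interface.hh), assuming test.arpa is in the working directory:

    #include "lm/ngram.hh"

    #include <iostream>

    int main() {
      lm::ngram::Model model("test.arpa");  // probing variant; SortedModel is used the same way
      lm::ngram::Model::State state(model.BeginSentenceState()), out;
      lm::FullScoreReturn ret = model.FullScore(
          state, model.GetVocabulary().Index("looking"), out);
      std::cout << ret.prob << ' ' << static_cast<unsigned int>(ret.ngram_length) << '\n';
      state = out;  // carry the matched context into the next query
      return 0;
    }
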
diff --git a/klm/lm/ngram_build_binary.cc b/klm/lm/ngram_build_binary.cc
new file mode 100644
index 00000000..9dab30a1
--- /dev/null
+++ b/klm/lm/ngram_build_binary.cc
@@ -0,0 +1,13 @@
+#include "lm/ngram.hh"
+
+#include <iostream>
+
+int main(int argc, char *argv[]) {
+ if (argc != 3) {
+ std::cerr << "Usage: " << argv[0] << " input.arpa output.mmap" << std::endl;
+ return 1;
+ }
+ lm::ngram::Config config;
+ config.write_mmap = argv[2];
+ lm::ngram::Model(argv[1], config);
+}
diff --git a/klm/lm/ngram_config.hh b/klm/lm/ngram_config.hh
new file mode 100644
index 00000000..a7b3afae
--- /dev/null
+++ b/klm/lm/ngram_config.hh
@@ -0,0 +1,58 @@
+#ifndef LM_NGRAM_CONFIG__
+#define LM_NGRAM_CONFIG__
+
+/* Configuration for ngram model. Separate header to reduce pollution. */
+
+#include <iostream>
+
+namespace lm { namespace ngram {
+
+struct Config {
+ /* EFFECTIVE FOR BOTH ARPA AND BINARY READS */
+ // Where to log messages including the progress bar. Set to NULL for
+ // silence.
+ std::ostream *messages;
+
+
+
+ /* ONLY EFFECTIVE WHEN READING ARPA */
+
+ // What to do when <unk> isn't in the provided model.
+ typedef enum {THROW_UP, COMPLAIN, SILENT} UnknownMissing;
+ UnknownMissing unknown_missing;
+
+ // The probability to substitute for <unk> if it's missing from the model.
+ // No effect if the model has <unk> or unknown_missing == THROW_UP.
+ float unknown_missing_prob;
+
+ // Size multiplier for probing hash table. Must be > 1. Space is linear in
+ // this. Time is probing_multiplier / (probing_multiplier - 1). No effect
+ // for sorted variant.
+ // If you find yourself setting this to a low number, consider using the
+ // Sorted version instead which has lower memory consumption.
+ float probing_multiplier;
+
+ // While loading an ARPA file, also write out this binary format file. Set
+ // to NULL to disable.
+ const char *write_mmap;
+
+
+
+ /* ONLY EFFECTIVE WHEN READING BINARY */
+ bool prefault;
+
+
+
+ // Defaults.
+ Config() :
+ messages(&std::cerr),
+ unknown_missing(COMPLAIN),
+ unknown_missing_prob(0.0),
+ probing_multiplier(1.5),
+ write_mmap(NULL),
+ prefault(false) {}
+};
+
+} /* namespace ngram */ } /* namespace lm */
+
+#endif // LM_NGRAM_CONFIG__
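
For reference, a sketch of overriding these defaults when constructing a model (the Config-taking constructor is declared in ngram.hh above; model.arpa and model.binary are hypothetical file names):

    #include "lm/ngram.hh"

    int main() {
      lm::ngram::Config config;
      config.messages = NULL;                                // no progress output
      config.unknown_missing = lm::ngram::Config::THROW_UP;  // refuse a model without <unk>
      config.probing_multiplier = 2.0;                       // more space, fewer probe collisions
      config.write_mmap = "model.binary";                    // also dump the binary format while loading
      lm::ngram::Model model("model.arpa", config);
      return 0;
    }
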
diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc
new file mode 100644
index 00000000..d1970260
--- /dev/null
+++ b/klm/lm/ngram_query.cc
@@ -0,0 +1,72 @@
+#include "lm/ngram.hh"
+
+#include <cstdlib>
+#include <fstream>
+#include <iostream>
+#include <string>
+
+#include <sys/resource.h>
+#include <sys/time.h>
+
+float FloatSec(const struct timeval &tv) {
+ return static_cast<float>(tv.tv_sec) + (static_cast<float>(tv.tv_usec) / 1000000.0); // tv_usec is in microseconds
+}
+
+void PrintUsage(const char *message) {
+ struct rusage usage;
+ if (getrusage(RUSAGE_SELF, &usage)) {
+ perror("getrusage");
+ return;
+ }
+ std::cerr << message;
+ std::cerr << "user\t" << FloatSec(usage.ru_utime) << "\nsys\t" << FloatSec(usage.ru_stime) << '\n';
+
+ // Linux doesn't set memory usage :-(.
+ std::ifstream status("/proc/self/status", std::ios::in);
+ std::string line;
+ while (getline(status, line)) {
+ if (!strncmp(line.c_str(), "VmRSS:\t", 7)) {
+ std::cerr << "rss " << (line.c_str() + 7) << '\n';
+ break;
+ }
+ }
+}
+
+template <class Model> void Query(const Model &model) {
+ PrintUsage("Loading statistics:\n");
+ typename Model::State state, out;
+ lm::FullScoreReturn ret;
+ std::string word;
+
+ while (std::cin) {
+ state = model.BeginSentenceState();
+ float total = 0.0;
+ bool got = false;
+ while (std::cin >> word) {
+ got = true;
+ ret = model.FullScore(state, model.GetVocabulary().Index(word), out);
+ total += ret.prob;
+ std::cout << word << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << ' ';
+ state = out;
+ if (std::cin.get() == '\n') break;
+ }
+ if (!got && !std::cin) break;
+ ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out);
+ total += ret.prob;
+ std::cout << "</s> " << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << ' ';
+ std::cout << "Total: " << total << '\n';
+ }
+ PrintUsage("After queries:\n");
+}
+
+int main(int argc, char *argv[]) {
+ if (argc < 2) {
+ std::cerr << "Pass language model name." << std::endl;
+ return 0;
+ }
+ {
+ lm::ngram::Model ngram(argv[1]);
+ Query(ngram);
+ }
+ PrintUsage("Total time including destruction:\n");
+}
diff --git a/klm/lm/ngram_test.cc b/klm/lm/ngram_test.cc
new file mode 100644
index 00000000..031e0348
--- /dev/null
+++ b/klm/lm/ngram_test.cc
@@ -0,0 +1,91 @@
+#include "lm/ngram.hh"
+
+#include <stdlib.h>
+
+#define BOOST_TEST_MODULE NGramTest
+#include <boost/test/unit_test.hpp>
+
+namespace lm {
+namespace ngram {
+namespace {
+
+#define StartTest(word, ngram, score) \
+ ret = model.FullScore( \
+ state, \
+ model.GetVocabulary().Index(word), \
+ out);\
+ BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \
+ BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
+ BOOST_CHECK_EQUAL(std::min<unsigned char>(ngram, 5 - 1), out.valid_length_);
+
+#define AppendTest(word, ngram, score) \
+ StartTest(word, ngram, score) \
+ state = out;
+
+template <class M> void Starters(M &model) {
+ FullScoreReturn ret;
+ Model::State state(model.BeginSentenceState());
+ Model::State out;
+
+ StartTest("looking", 2, -0.4846522);
+
+ // , probability plus <s> backoff
+ StartTest(",", 1, -1.383514 + -0.4149733);
+ // <unk> probability plus <s> backoff
+ StartTest("this_is_not_found", 0, -1.995635 + -0.4149733);
+}
+
+template <class M> void Continuation(M &model) {
+ FullScoreReturn ret;
+ Model::State state(model.BeginSentenceState());
+ Model::State out;
+
+ AppendTest("looking", 2, -0.484652);
+ AppendTest("on", 3, -0.348837);
+ AppendTest("a", 4, -0.0155266);
+ AppendTest("little", 5, -0.00306122);
+ State preserve = state;
+ AppendTest("the", 1, -4.04005);
+ AppendTest("biarritz", 1, -1.9889);
+ AppendTest("not_found", 0, -2.29666);
+ AppendTest("more", 1, -1.20632);
+ AppendTest(".", 2, -0.51363);
+ AppendTest("</s>", 3, -0.0191651);
+
+ state = preserve;
+ AppendTest("more", 5, -0.00181395);
+ AppendTest("loin", 5, -0.0432557);
+}
+
+BOOST_AUTO_TEST_CASE(starters_probing) { Model m("test.arpa"); Starters(m); }
+BOOST_AUTO_TEST_CASE(continuation_probing) { Model m("test.arpa"); Continuation(m); }
+BOOST_AUTO_TEST_CASE(starters_sorted) { SortedModel m("test.arpa"); Starters(m); }
+BOOST_AUTO_TEST_CASE(continuation_sorted) { SortedModel m("test.arpa"); Continuation(m); }
+
+BOOST_AUTO_TEST_CASE(write_and_read_probing) {
+ Config config;
+ config.write_mmap = "test.binary";
+ {
+ Model copy_model("test.arpa", config);
+ }
+ Model binary("test.binary");
+ Starters(binary);
+ Continuation(binary);
+}
+
+BOOST_AUTO_TEST_CASE(write_and_read_sorted) {
+ Config config;
+ config.write_mmap = "test.binary";
+ config.prefault = true;
+ {
+ SortedModel copy_model("test.arpa", config);
+ }
+ SortedModel binary("test.binary");
+ Starters(binary);
+ Continuation(binary);
+}
+
+
+} // namespace
+} // namespace ngram
+} // namespace lm
diff --git a/klm/lm/sri.cc b/klm/lm/sri.cc
new file mode 100644
index 00000000..7bd23d76
--- /dev/null
+++ b/klm/lm/sri.cc
@@ -0,0 +1,115 @@
+#include "lm/exception.hh"
+#include "lm/sri.hh"
+
+#include <Ngram.h>
+#include <Vocab.h>
+
+#include <errno.h>
+
+namespace lm {
+namespace sri {
+
+Vocabulary::Vocabulary() : sri_(new Vocab) {}
+
+Vocabulary::~Vocabulary() {}
+
+WordIndex Vocabulary::Index(const char *str) const {
+ WordIndex ret = sri_->getIndex(str);
+ // NGram wants the index of Vocab_Unknown for unknown words, but for some reason SRI returns Vocab_None here :-(.
+ if (ret == Vocab_None) {
+ return not_found_;
+ } else {
+ return ret;
+ }
+}
+
+const char *Vocabulary::Word(WordIndex index) const {
+ return sri_->getWord(index);
+}
+
+void Vocabulary::FinishedLoading() {
+ SetSpecial(
+ sri_->ssIndex(),
+ sri_->seIndex(),
+ sri_->unkIndex(),
+ sri_->highIndex() + 1);
+}
+
+namespace {
+Ngram *MakeSRIModel(const char *file_name, unsigned int ngram_length, Vocab &sri_vocab) {
+ sri_vocab.unkIsWord() = true;
+ std::auto_ptr<Ngram> ret(new Ngram(sri_vocab, ngram_length));
+ File file(file_name, "r");
+ errno = 0;
+ if (!ret->read(file)) {
+ UTIL_THROW(FormatLoadException, "reading file " << file_name << " with SRI failed.");
+ }
+ return ret.release();
+}
+} // namespace
+
+Model::Model(const char *file_name, unsigned int ngram_length) : sri_(MakeSRIModel(file_name, ngram_length, *vocab_.sri_)) {
+ if (!sri_->setorder()) {
+ UTIL_THROW(FormatLoadException, "Can't have an SRI model with order 0.");
+ }
+ vocab_.FinishedLoading();
+ State begin_state = State();
+ begin_state.valid_length_ = 1;
+ if (kMaxOrder > 1) {
+ begin_state.history_[0] = vocab_.BeginSentence();
+ if (kMaxOrder > 2) begin_state.history_[1] = Vocab_None;
+ }
+ State null_state = State();
+ null_state.valid_length_ = 0;
+ if (kMaxOrder > 1) null_state.history_[0] = Vocab_None;
+ Init(begin_state, null_state, vocab_, sri_->setorder());
+ not_found_ = vocab_.NotFound();
+}
+
+Model::~Model() {}
+
+namespace {
+
+/* Argh SRI's wordProb knows the ngram length but doesn't return it. One more
+ * reason you should use my model. */
+// TODO(stolcke): fix SRILM so I don't have to do this.
+unsigned int MatchedLength(Ngram &model, const WordIndex new_word, const SRIVocabIndex *const_history) {
+ unsigned int out_length = 0;
+ // This gets the length of context used, which is ngram_length - 1 unless new_word is OOV in which case it is 0.
+ model.contextID(new_word, const_history, out_length);
+ return out_length + 1;
+}
+
+} // namespace
+
+FullScoreReturn Model::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const {
+ // If you get a compiler error in this function, change SRIVocabIndex in sri.hh to match the one found in SRI's Vocab.h.
+ const SRIVocabIndex *const_history;
+ SRIVocabIndex local_history[Order()];
+ if (in_state.valid_length_ < kMaxOrder - 1) {
+ const_history = in_state.history_;
+ } else {
+ std::copy(in_state.history_, in_state.history_ + in_state.valid_length_, local_history);
+ local_history[in_state.valid_length_] = Vocab_None;
+ const_history = local_history;
+ }
+ FullScoreReturn ret;
+ if (new_word != not_found_) {
+ ret.ngram_length = MatchedLength(*sri_, new_word, const_history);
+ out_state.history_[0] = new_word;
+ out_state.valid_length_ = std::min<unsigned char>(ret.ngram_length, Order() - 1);
+ std::copy(const_history, const_history + out_state.valid_length_ - 1, out_state.history_ + 1);
+ if (out_state.valid_length_ < kMaxOrder - 1) {
+ out_state.history_[out_state.valid_length_] = Vocab_None;
+ }
+ } else {
+ ret.ngram_length = 0;
+ if (kMaxOrder > 1) out_state.history_[0] = Vocab_None;
+ out_state.valid_length_ = 0;
+ }
+ ret.prob = sri_->wordProb(new_word, const_history);
+ return ret;
+}
+
+} // namespace sri
+} // namespace lm
diff --git a/klm/lm/sri.hh b/klm/lm/sri.hh
new file mode 100644
index 00000000..b57e9b73
--- /dev/null
+++ b/klm/lm/sri.hh
@@ -0,0 +1,102 @@
+#ifndef LM_SRI__
+#define LM_SRI__
+
+#include "lm/facade.hh"
+#include "util/murmur_hash.hh"
+
+#include <cmath>
+#include <exception>
+#include <memory>
+
+class Ngram;
+class Vocab;
+
+/* The ngram length reported uses some random API I found and may be wrong.
+ *
+ * See ngram, which should return equivalent results.
+ */
+
+namespace lm {
+namespace sri {
+
+static const unsigned int kMaxOrder = 6;
+
+/* This should match VocabIndex found in SRI's Vocab.h
+ * The reason I define this here independently is that SRI's headers
+ * pollute and increase compile time.
+ * It's difficult to extract this from their header and anyway would
+ * break packaging.
+ * If these differ there will be a compiler error in ActuallyCall.
+ */
+typedef unsigned int SRIVocabIndex;
+
+class State {
+ public:
+ // You shouldn't need to touch these, but they're public so State will be a POD.
+ // If valid_length_ < kMaxOrder - 1 then history_[valid_length_] == Vocab_None.
+ SRIVocabIndex history_[kMaxOrder - 1];
+ unsigned char valid_length_;
+};
+
+inline bool operator==(const State &left, const State &right) {
+ if (left.valid_length_ != right.valid_length_) {
+ return false;
+ }
+ for (const SRIVocabIndex *l = left.history_, *r = right.history_;
+ l != left.history_ + left.valid_length_;
+ ++l, ++r) {
+ if (*l != *r) return false;
+ }
+ return true;
+}
+
+inline size_t hash_value(const State &state) {
+ return util::MurmurHashNative(&state.history_, sizeof(SRIVocabIndex) * state.valid_length_);
+}
+
+class Vocabulary : public base::Vocabulary {
+ public:
+ Vocabulary();
+
+ ~Vocabulary();
+
+ WordIndex Index(const StringPiece &str) const {
+ std::string temp(str.data(), str.length());
+ return Index(temp.c_str());
+ }
+ WordIndex Index(const std::string &str) const {
+ return Index(str.c_str());
+ }
+ WordIndex Index(const char *str) const;
+
+ const char *Word(WordIndex index) const;
+
+ private:
+ friend class Model;
+ void FinishedLoading();
+
+ // The parent class isn't copyable so auto_ptr is the same as scoped_ptr
+ // but without the boost dependence.
+ mutable std::auto_ptr<Vocab> sri_;
+};
+
+class Model : public base::ModelFacade<Model, State, Vocabulary> {
+ public:
+ Model(const char *file_name, unsigned int ngram_length);
+
+ ~Model();
+
+ FullScoreReturn FullScore(const State &in_state, const WordIndex new_word, State &out_state) const;
+
+ private:
+ Vocabulary vocab_;
+
+ mutable std::auto_ptr<Ngram> sri_;
+
+ WordIndex not_found_;
+};
+
+} // namespace sri
+} // namespace lm
+
+#endif // LM_SRI__
diff --git a/klm/lm/sri_test.cc b/klm/lm/sri_test.cc
new file mode 100644
index 00000000..e697d722
--- /dev/null
+++ b/klm/lm/sri_test.cc
@@ -0,0 +1,65 @@
+#include "lm/sri.hh"
+
+#include <stdlib.h>
+
+#define BOOST_TEST_MODULE SRITest
+#include <boost/test/unit_test.hpp>
+
+namespace lm {
+namespace sri {
+namespace {
+
+#define StartTest(word, ngram, score) \
+ ret = model.FullScore( \
+ state, \
+ model.GetVocabulary().Index(word), \
+ out);\
+ BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \
+ BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
+ BOOST_CHECK_EQUAL(std::min<unsigned char>(ngram, 5 - 1), out.valid_length_);
+
+#define AppendTest(word, ngram, score) \
+ StartTest(word, ngram, score) \
+ state = out;
+
+template <class M> void Starters(M &model) {
+ FullScoreReturn ret;
+ Model::State state(model.BeginSentenceState());
+ Model::State out;
+
+ StartTest("looking", 2, -0.4846522);
+
+ // , probability plus <s> backoff
+ StartTest(",", 1, -1.383514 + -0.4149733);
+ // <unk> probability plus <s> backoff
+ StartTest("this_is_not_found", 0, -1.995635 + -0.4149733);
+}
+
+template <class M> void Continuation(M &model) {
+ FullScoreReturn ret;
+ Model::State state(model.BeginSentenceState());
+ Model::State out;
+
+ AppendTest("looking", 2, -0.484652);
+ AppendTest("on", 3, -0.348837);
+ AppendTest("a", 4, -0.0155266);
+ AppendTest("little", 5, -0.00306122);
+ State preserve = state;
+ AppendTest("the", 1, -4.04005);
+ AppendTest("biarritz", 1, -1.9889);
+ AppendTest("not_found", 0, -2.29666);
+ AppendTest("more", 1, -1.20632);
+ AppendTest(".", 2, -0.51363);
+ AppendTest("</s>", 3, -0.0191651);
+
+ state = preserve;
+ AppendTest("more", 5, -0.00181395);
+ AppendTest("loin", 5, -0.0432557);
+}
+
+BOOST_AUTO_TEST_CASE(starters) { Model m("test.arpa", 5); Starters(m); }
+BOOST_AUTO_TEST_CASE(continuation) { Model m("test.arpa", 5); Continuation(m); }
+
+} // namespace
+} // namespace sri
+} // namespace lm
diff --git a/klm/lm/test.arpa b/klm/lm/test.arpa
new file mode 100644
index 00000000..9d674e83
--- /dev/null
+++ b/klm/lm/test.arpa
@@ -0,0 +1,112 @@
+
+\data\
+ngram 1=34
+ngram 2=43
+ngram 3=8
+ngram 4=5
+ngram 5=3
+
+\1-grams:
+-1.383514 , -0.30103
+-1.139057 . -0.845098
+-1.029493 </s>
+-99 <s> -0.4149733
+-1.995635 <unk>
+-1.285941 a -0.69897
+-1.687872 also -0.30103
+-1.687872 beyond -0.30103
+-1.687872 biarritz -0.30103
+-1.687872 call -0.30103
+-1.687872 concerns -0.30103
+-1.687872 consider -0.30103
+-1.687872 considering -0.30103
+-1.687872 for -0.30103
+-1.509559 higher -0.30103
+-1.687872 however -0.30103
+-1.687872 i -0.30103
+-1.687872 immediate -0.30103
+-1.687872 in -0.30103
+-1.687872 is -0.30103
+-1.285941 little -0.69897
+-1.383514 loin -0.30103
+-1.687872 look -0.30103
+-1.285941 looking -0.4771212
+-1.206319 more -0.544068
+-1.509559 on -0.4771212
+-1.509559 screening -0.4771212
+-1.687872 small -0.30103
+-1.687872 the -0.30103
+-1.687872 to -0.30103
+-1.687872 watch -0.30103
+-1.687872 watching -0.30103
+-1.687872 what -0.30103
+-1.687872 would -0.30103
+
+\2-grams:
+-0.6925742 , .
+-0.7522095 , however
+-0.7522095 , is
+-0.0602359 . </s>
+-0.4846522 <s> looking -0.4771214
+-1.051485 <s> screening
+-1.07153 <s> the
+-1.07153 <s> watching
+-1.07153 <s> what
+-0.09132547 a little -0.69897
+-0.2922095 also call
+-0.2922095 beyond immediate
+-0.2705918 biarritz .
+-0.2922095 call for
+-0.2922095 concerns in
+-0.2922095 consider watch
+-0.2922095 considering consider
+-0.2834328 for ,
+-0.5511513 higher more
+-0.5845945 higher small
+-0.2834328 however ,
+-0.2922095 i would
+-0.2922095 immediate concerns
+-0.2922095 in biarritz
+-0.2922095 is to
+-0.09021038 little more -0.1998621
+-0.7273645 loin ,
+-0.6925742 loin .
+-0.6708385 loin </s>
+-0.2922095 look beyond
+-0.4638903 looking higher
+-0.4638903 looking on -0.4771212
+-0.5136299 more . -0.4771212
+-0.3561665 more loin
+-0.1649931 on a -0.4771213
+-0.1649931 screening a -0.4771213
+-0.2705918 small .
+-0.287799 the screening
+-0.2922095 to look
+-0.2622373 watch </s>
+-0.2922095 watching considering
+-0.2922095 what i
+-0.2922095 would also
+
+\3-grams:
+-0.01916512 more . </s>
+-0.0283603 on a little -0.4771212
+-0.0283603 screening a little -0.4771212
+-0.01660496 a little more -0.09409451
+-0.3488368 <s> looking higher
+-0.3488368 <s> looking on -0.4771212
+-0.1892331 little more loin
+-0.04835128 looking on a -0.4771212
+
+\4-grams:
+-0.009249173 looking on a little -0.4771212
+-0.005464747 on a little more -0.4771212
+-0.005464747 screening a little more
+-0.1453306 a little more loin
+-0.01552657 <s> looking on a -0.4771212
+
+\5-grams:
+-0.003061223 <s> looking on a little
+-0.001813953 looking on a little more
+-0.0432557 on a little more loin
+
+\end\
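
Each data line in the ARPA body above follows the layout that Read1Grams and ReadNGrams in ngram.cc expect:

    <log10 probability> \t word_1 ... word_n [ \t <log10 backoff> ] \n

Lower-order entries may omit the trailing backoff (it is then treated as zero); the highest-order entries never carry one, which is why Prob::SetBackoff throws.
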
diff --git a/klm/lm/test.binary b/klm/lm/test.binary
new file mode 100644
index 00000000..90bd2b76
--- /dev/null
+++ b/klm/lm/test.binary
Binary files differ
diff --git a/klm/lm/virtual_interface.cc b/klm/lm/virtual_interface.cc
new file mode 100644
index 00000000..9c7151f9
--- /dev/null
+++ b/klm/lm/virtual_interface.cc
@@ -0,0 +1,22 @@
+#include "lm/virtual_interface.hh"
+
+#include "lm/exception.hh"
+
+namespace lm {
+namespace base {
+
+Vocabulary::~Vocabulary() {}
+
+void Vocabulary::SetSpecial(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found, WordIndex available) {
+ begin_sentence_ = begin_sentence;
+ end_sentence_ = end_sentence;
+ not_found_ = not_found;
+ available_ = available;
+ if (begin_sentence_ == not_found_) throw SpecialWordMissingException("<s>");
+ if (end_sentence_ == not_found_) throw SpecialWordMissingException("</s>");
+}
+
+Model::~Model() {}
+
+} // namespace base
+} // namespace lm
diff --git a/klm/lm/virtual_interface.hh b/klm/lm/virtual_interface.hh
new file mode 100644
index 00000000..621a129e
--- /dev/null
+++ b/klm/lm/virtual_interface.hh
@@ -0,0 +1,156 @@
+#ifndef LM_VIRTUAL_INTERFACE__
+#define LM_VIRTUAL_INTERFACE__
+
+#include "lm/word_index.hh"
+#include "util/string_piece.hh"
+
+#include <string>
+
+namespace lm {
+
+struct FullScoreReturn {
+ float prob;
+ unsigned char ngram_length;
+};
+
+namespace base {
+
+template <class T, class U, class V> class ModelFacade;
+
+/* Vocabulary interface. Call Index(string) and get a word index for use in
+ * calling Model. It provides faster convenience functions for <s>, </s>, and
+ * <unk> although you can also find these using Index.
+ *
+ * Some models do not load the mapping from index to string. If you need this,
+ * check if the model Vocabulary class implements such a function and access it
+ * directly.
+ *
+ * The Vocabulary object is always owned by the Model and can be retrieved from
+ * the Model using BaseVocabulary() for this abstract interface or
+ * GetVocabulary() for the actual implementation (in which case you'll need the
+ * actual implementation of the Model too).
+ */
+class Vocabulary {
+ public:
+ virtual ~Vocabulary();
+
+ WordIndex BeginSentence() const { return begin_sentence_; }
+ WordIndex EndSentence() const { return end_sentence_; }
+ WordIndex NotFound() const { return not_found_; }
+ // Returns the start index of unused word assignments.
+ WordIndex Available() const { return available_; }
+
+ /* Most implementations allow StringPiece lookups and need only override
+ * Index(StringPiece). SRI requires null termination and overrides all
+ * three methods.
+ */
+ virtual WordIndex Index(const StringPiece &str) const = 0;
+ virtual WordIndex Index(const std::string &str) const {
+ return Index(StringPiece(str));
+ }
+ virtual WordIndex Index(const char *str) const {
+ return Index(StringPiece(str));
+ }
+
+ protected:
+ // Call SetSpecial afterward.
+ Vocabulary() {}
+
+ Vocabulary(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found, WordIndex available) {
+ SetSpecial(begin_sentence, end_sentence, not_found, available);
+ }
+
+ void SetSpecial(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found, WordIndex available);
+
+ WordIndex begin_sentence_, end_sentence_, not_found_, available_;
+
+ private:
+ // Disable copy constructors. They're private and undefined.
+ // Ersatz boost::noncopyable.
+ Vocabulary(const Vocabulary &);
+ Vocabulary &operator=(const Vocabulary &);
+};
+
+/* There are two ways to access a Model.
+ *
+ *
+ * OPTION 1: Access the Model directly (e.g. lm::ngram::Model in ngram.hh).
+ * Every Model implements the scoring function:
+ * float Score(
+ * const Model::State &in_state,
+ * const WordIndex new_word,
+ * Model::State &out_state) const;
+ *
+ * It can also return the length of n-gram matched by the model:
+ * FullScoreReturn FullScore(
+ * const Model::State &in_state,
+ * const WordIndex new_word,
+ * Model::State &out_state) const;
+ *
+ * There are also accessor functions:
+ * const State &BeginSentenceState() const;
+ * const State &NullContextState() const;
+ * const Vocabulary &GetVocabulary() const;
+ * unsigned int Order() const;
+ *
+ * NB: In case you're wondering why the model implementation looks like it's
+ * missing these methods, see facade.hh.
+ *
+ * This is the fastest way to use a model and presents a normal State class to
+ * be included in hypothesis state structure.
+ *
+ *
+ * OPTION 2: Use the virtual interface below.
+ *
+ * The virtual interface allows you to decide which Model to use at runtime
+ * without templatizing everything on the Model type. However, each Model has
+ * its own State class, so a single State cannot be efficiently provided (it
+ * would require using the maximum memory of any Model's State or memory
+ * allocation with each lookup). This means you become responsible for
+ * allocating memory with size StateSize() and passing it to the Score or
+ * FullScore functions provided here.
+ *
+ * For example, cdec has a std::string containing the entire state of a
+ * hypothesis. It can reserve StateSize bytes in this string for the model
+ * state.
+ *
+ * All the State objects are POD, so it's ok to use raw memory for storing
+ * State.
+ */
+class Model {
+ public:
+ virtual ~Model();
+
+ size_t StateSize() const { return state_size_; }
+ const void *BeginSentenceMemory() const { return begin_sentence_memory_; }
+ const void *NullContextMemory() const { return null_context_memory_; }
+
+ virtual float Score(const void *in_state, const WordIndex new_word, void *out_state) const = 0;
+
+ virtual FullScoreReturn FullScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0;
+
+ unsigned char Order() const { return order_; }
+
+ const Vocabulary &BaseVocabulary() const { return *base_vocab_; }
+
+ private:
+ template <class T, class U, class V> friend class ModelFacade;
+ explicit Model(size_t state_size) : state_size_(state_size) {}
+
+ const size_t state_size_;
+ const void *begin_sentence_memory_, *null_context_memory_;
+
+ const Vocabulary *base_vocab_;
+
+ unsigned char order_;
+
+ // Disable copy constructors. They're private and undefined.
+ // Ersatz boost::noncopyable.
+ Model(const Model &);
+ Model &operator=(const Model &);
+};
+
+} // namespace base
+} // namespace lm
+
+#endif // LM_VIRTUAL_INTERFACE__
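
A minimal sketch of OPTION 2 from the comment block above: hold the model behind the abstract base::Model interface and keep each state in an opaque buffer of StateSize() bytes (a std::string here, as the comment notes cdec does; lm::ngram::Model stands in as the concrete type and test.arpa as the file):

    #include "lm/ngram.hh"
    #include "lm/virtual_interface.hh"

    #include <iostream>
    #include <string>

    int main() {
      lm::ngram::Model concrete("test.arpa");
      lm::base::Model &model = concrete;  // the concrete type could be chosen at runtime instead

      // Opaque state buffers sized by the model; seeded with the begin-sentence state.
      std::string in_state(reinterpret_cast<const char*>(model.BeginSentenceMemory()), model.StateSize());
      std::string out_state(model.StateSize(), '\0');

      lm::WordIndex word = model.BaseVocabulary().Index("looking");
      lm::FullScoreReturn ret = model.FullScore(&in_state[0], word, &out_state[0]);
      std::cout << ret.prob << '\n';
      in_state = out_state;  // carry the context forward to the next word
      return 0;
    }
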
diff --git a/klm/lm/word_index.hh b/klm/lm/word_index.hh
new file mode 100644
index 00000000..67841c30
--- /dev/null
+++ b/klm/lm/word_index.hh
@@ -0,0 +1,11 @@
+// Separate header because this is used often.
+#ifndef LM_WORD_INDEX__
+#define LM_WORD_INDEX__
+
+namespace lm {
+typedef unsigned int WordIndex;
+} // namespace lm
+
+typedef lm::WordIndex LMWordIndex;
+
+#endif