diff options
author | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-10-18 23:24:01 +0000 |
---|---|---|
committer | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-10-18 23:24:01 +0000 |
commit | de379496ee411993dff94e52f393f6e19437a204 (patch) | |
tree | a3fdb3b299100384e0a82dd2bc424fd52177d411 /klm/lm | |
parent | 08ff0e0332b562dd9c1f36fce24439db81287c68 (diff) |
kenneth's LM preliminary integration
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@681 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'klm/lm')
-rw-r--r-- | klm/lm/Makefile.am | 20 | ||||
-rw-r--r-- | klm/lm/exception.cc | 21 | ||||
-rw-r--r-- | klm/lm/exception.hh | 40 | ||||
-rw-r--r-- | klm/lm/facade.hh | 64 | ||||
-rw-r--r-- | klm/lm/ngram.cc | 522 | ||||
-rw-r--r-- | klm/lm/ngram.hh | 226 | ||||
-rw-r--r-- | klm/lm/ngram_build_binary.cc | 13 | ||||
-rw-r--r-- | klm/lm/ngram_config.hh | 58 | ||||
-rw-r--r-- | klm/lm/ngram_query.cc | 72 | ||||
-rw-r--r-- | klm/lm/ngram_test.cc | 91 | ||||
-rw-r--r-- | klm/lm/sri.cc | 115 | ||||
-rw-r--r-- | klm/lm/sri.hh | 102 | ||||
-rw-r--r-- | klm/lm/sri_test.cc | 65 | ||||
-rw-r--r-- | klm/lm/test.arpa | 112 | ||||
-rw-r--r-- | klm/lm/test.binary | bin | 0 -> 1660 bytes | |||
-rw-r--r-- | klm/lm/virtual_interface.cc | 22 | ||||
-rw-r--r-- | klm/lm/virtual_interface.hh | 156 | ||||
-rw-r--r-- | klm/lm/word_index.hh | 11 |
18 files changed, 1710 insertions, 0 deletions
diff --git a/klm/lm/Makefile.am b/klm/lm/Makefile.am new file mode 100644 index 00000000..a0c49eb4 --- /dev/null +++ b/klm/lm/Makefile.am @@ -0,0 +1,20 @@ +if HAVE_GTEST +noinst_PROGRAMS = \ + ngram_test +TESTS = ngram_test +endif + +noinst_LIBRARIES = libklm.a + +libklm_a_SOURCES = \ + exception.cc \ + ngram.cc \ + ngram_build_binary.cc \ + ngram_query.cc \ + virtual_interface.cc + +ngram_test_SOURCES = ngram_test.cc +ngram_test_LDADD = ../util/libklm_util.a + +AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. + diff --git a/klm/lm/exception.cc b/klm/lm/exception.cc new file mode 100644 index 00000000..59a1650d --- /dev/null +++ b/klm/lm/exception.cc @@ -0,0 +1,21 @@ +#include "lm/exception.hh" + +#include<errno.h> +#include<stdio.h> + +namespace lm { + +LoadException::LoadException() throw() {} +LoadException::~LoadException() throw() {} +VocabLoadException::VocabLoadException() throw() {} +VocabLoadException::~VocabLoadException() throw() {} + +FormatLoadException::FormatLoadException() throw() {} +FormatLoadException::~FormatLoadException() throw() {} + +SpecialWordMissingException::SpecialWordMissingException(StringPiece which) throw() { + *this << "Missing special word " << which; +} +SpecialWordMissingException::~SpecialWordMissingException() throw() {} + +} // namespace lm diff --git a/klm/lm/exception.hh b/klm/lm/exception.hh new file mode 100644 index 00000000..95109012 --- /dev/null +++ b/klm/lm/exception.hh @@ -0,0 +1,40 @@ +#ifndef LM_EXCEPTION__ +#define LM_EXCEPTION__ + +#include "util/exception.hh" +#include "util/string_piece.hh" + +#include <exception> +#include <string> + +namespace lm { + +class LoadException : public util::Exception { + public: + virtual ~LoadException() throw(); + + protected: + LoadException() throw(); +}; + +class VocabLoadException : public LoadException { + public: + virtual ~VocabLoadException() throw(); + VocabLoadException() throw(); +}; + +class FormatLoadException : public LoadException { + public: + FormatLoadException() throw(); + ~FormatLoadException() throw(); +}; + +class SpecialWordMissingException : public VocabLoadException { + public: + explicit SpecialWordMissingException(StringPiece which) throw(); + ~SpecialWordMissingException() throw(); +}; + +} // namespace lm + +#endif diff --git a/klm/lm/facade.hh b/klm/lm/facade.hh new file mode 100644 index 00000000..8b186017 --- /dev/null +++ b/klm/lm/facade.hh @@ -0,0 +1,64 @@ +#ifndef LM_FACADE__ +#define LM_FACADE__ + +#include "lm/virtual_interface.hh" +#include "util/string_piece.hh" + +#include <string> + +namespace lm { +namespace base { + +// Common model interface that depends on knowing the specific classes. +// Curiously recurring template pattern. +template <class Child, class StateT, class VocabularyT> class ModelFacade : public Model { + public: + typedef StateT State; + typedef VocabularyT Vocabulary; + + // Default Score function calls FullScore. Model can override this. + float Score(const State &in_state, const WordIndex new_word, State &out_state) const { + return static_cast<const Child*>(this)->FullScore(in_state, new_word, out_state).prob; + } + + /* Translate from void* to State */ + FullScoreReturn FullScore(const void *in_state, const WordIndex new_word, void *out_state) const { + return static_cast<const Child*>(this)->FullScore( + *reinterpret_cast<const State*>(in_state), + new_word, + *reinterpret_cast<State*>(out_state)); + } + float Score(const void *in_state, const WordIndex new_word, void *out_state) const { + return static_cast<const Child*>(this)->Score( + *reinterpret_cast<const State*>(in_state), + new_word, + *reinterpret_cast<State*>(out_state)); + } + + const State &BeginSentenceState() const { return begin_sentence_; } + const State &NullContextState() const { return null_context_; } + const Vocabulary &GetVocabulary() const { return *static_cast<const Vocabulary*>(&BaseVocabulary()); } + + protected: + ModelFacade() : Model(sizeof(State)) {} + + virtual ~ModelFacade() {} + + // begin_sentence and null_context can disappear after. vocab should stay. + void Init(const State &begin_sentence, const State &null_context, const Vocabulary &vocab, unsigned char order) { + begin_sentence_ = begin_sentence; + null_context_ = null_context; + begin_sentence_memory_ = &begin_sentence_; + null_context_memory_ = &null_context_; + base_vocab_ = &vocab; + order_ = order; + } + + private: + State begin_sentence_, null_context_; +}; + +} // mamespace base +} // namespace lm + +#endif // LM_FACADE__ diff --git a/klm/lm/ngram.cc b/klm/lm/ngram.cc new file mode 100644 index 00000000..a87c82aa --- /dev/null +++ b/klm/lm/ngram.cc @@ -0,0 +1,522 @@ +#include "lm/ngram.hh" + +#include "lm/exception.hh" +#include "util/file_piece.hh" +#include "util/joint_sort.hh" +#include "util/murmur_hash.hh" +#include "util/probing_hash_table.hh" + +#include <algorithm> +#include <functional> +#include <numeric> +#include <limits> +#include <string> + +#include <cmath> +#include <fcntl.h> +#include <errno.h> +#include <stdlib.h> +#include <sys/mman.h> +#include <sys/types.h> +#include <sys/stat.h> +#include <unistd.h> + +namespace lm { +namespace ngram { + +size_t hash_value(const State &state) { + return util::MurmurHashNative(state.history_, sizeof(WordIndex) * state.valid_length_); +} + +namespace detail { +uint64_t HashForVocab(const char *str, std::size_t len) { + // This proved faster than Boost's hash in speed trials: total load time Murmur 67090000, Boost 72210000 + // Chose to use 64A instead of native so binary format will be portable across 64 and 32 bit. + return util::MurmurHash64A(str, len, 0); +} + +void Prob::SetBackoff(float to) { + UTIL_THROW(FormatLoadException, "Attempt to set backoff " << to << " for the highest order n-gram"); +} + +// Normally static initialization is a bad idea but MurmurHash is pure arithmetic, so this is ok. +const uint64_t kUnknownHash = HashForVocab("<unk>", 5); +// Sadly some LMs have <UNK>. +const uint64_t kUnknownCapHash = HashForVocab("<UNK>", 5); + +} // namespace detail + +SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL) {} + +std::size_t SortedVocabulary::Size(std::size_t entries, float ignored) { + // Lead with the number of entries. + return sizeof(uint64_t) + sizeof(Entry) * entries; +} + +void SortedVocabulary::Init(void *start, std::size_t allocated, std::size_t entries) { + assert(allocated >= Size(entries)); + // Leave space for number of entries. + begin_ = reinterpret_cast<Entry*>(reinterpret_cast<uint64_t*>(start) + 1); + end_ = begin_; + saw_unk_ = false; +} + +WordIndex SortedVocabulary::Insert(const StringPiece &str) { + uint64_t hashed = detail::HashForVocab(str); + if (hashed == detail::kUnknownHash || hashed == detail::kUnknownCapHash) { + saw_unk_ = true; + return 0; + } + end_->key = hashed; + ++end_; + // This is 1 + the offset where it was inserted to make room for unk. + return end_ - begin_; +} + +bool SortedVocabulary::FinishedLoading(detail::ProbBackoff *reorder_vocab) { + util::JointSort(begin_, end_, reorder_vocab + 1); + SetSpecial(Index("<s>"), Index("</s>"), 0, end_ - begin_ + 1); + // Save size. + *(reinterpret_cast<uint64_t*>(begin_) - 1) = end_ - begin_; + return saw_unk_; +} + +void SortedVocabulary::LoadedBinary() { + end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1); + SetSpecial(Index("<s>"), Index("</s>"), 0, end_ - begin_ + 1); +} + +namespace detail { + +template <class Search> MapVocabulary<Search>::MapVocabulary() {} + +template <class Search> void MapVocabulary<Search>::Init(void *start, std::size_t allocated, std::size_t entries) { + lookup_ = Lookup(start, allocated); + available_ = 1; + // Later if available_ != expected_available_ then we can throw UnknownMissingException. + saw_unk_ = false; +} + +template <class Search> WordIndex MapVocabulary<Search>::Insert(const StringPiece &str) { + uint64_t hashed = HashForVocab(str); + // Prevent unknown from going into the table. + if (hashed == kUnknownHash || hashed == kUnknownCapHash) { + saw_unk_ = true; + return 0; + } else { + lookup_.Insert(Lookup::Packing::Make(hashed, available_)); + return available_++; + } +} + +template <class Search> bool MapVocabulary<Search>::FinishedLoading(ProbBackoff *reorder_vocab) { + lookup_.FinishedInserting(); + SetSpecial(Index("<s>"), Index("</s>"), 0, available_); + return saw_unk_; +} + +template <class Search> void MapVocabulary<Search>::LoadedBinary() { + lookup_.LoadedBinary(); + SetSpecial(Index("<s>"), Index("</s>"), 0, available_); +} + +/* All of the entropy is in low order bits and boost::hash does poorly with + * these. Odd numbers near 2^64 chosen by mashing on the keyboard. There is a + * stable point: 0. But 0 is <unk> which won't be queried here anyway. + */ +inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) { + uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast<uint64_t>(next) * 17894857484156487943ULL); + return ret; +} + +uint64_t ChainedWordHash(const WordIndex *word, const WordIndex *word_end) { + if (word == word_end) return 0; + uint64_t current = static_cast<uint64_t>(*word); + for (++word; word != word_end; ++word) { + current = CombineWordHash(current, *word); + } + return current; +} + +bool IsEntirelyWhiteSpace(const StringPiece &line) { + for (size_t i = 0; i < static_cast<size_t>(line.size()); ++i) { + if (!isspace(line.data()[i])) return false; + } + return true; +} + +void ReadARPACounts(util::FilePiece &in, std::vector<size_t> &number) { + number.clear(); + StringPiece line; + if (!IsEntirelyWhiteSpace(line = in.ReadLine())) UTIL_THROW(FormatLoadException, "First line was \"" << line << "\" not blank"); + if ((line = in.ReadLine()) != "\\data\\") UTIL_THROW(FormatLoadException, "second line was \"" << line << "\" not \\data\\."); + while (!IsEntirelyWhiteSpace(line = in.ReadLine())) { + if (line.size() < 6 || strncmp(line.data(), "ngram ", 6)) UTIL_THROW(FormatLoadException, "count line \"" << line << "\"doesn't begin with \"ngram \""); + // So strtol doesn't go off the end of line. + std::string remaining(line.data() + 6, line.size() - 6); + char *end_ptr; + unsigned long int length = std::strtol(remaining.c_str(), &end_ptr, 10); + if ((end_ptr == remaining.c_str()) || (length - 1 != number.size())) UTIL_THROW(FormatLoadException, "ngram count lengths should be consecutive starting with 1: " << line); + if (*end_ptr != '=') UTIL_THROW(FormatLoadException, "Expected = immediately following the first number in the count line " << line); + ++end_ptr; + const char *start = end_ptr; + long int count = std::strtol(start, &end_ptr, 10); + if (count < 0) UTIL_THROW(FormatLoadException, "Negative n-gram count " << count); + if (start == end_ptr) UTIL_THROW(FormatLoadException, "Couldn't parse n-gram count from " << line); + number.push_back(count); + } +} + +void ReadNGramHeader(util::FilePiece &in, unsigned int length) { + StringPiece line; + while (IsEntirelyWhiteSpace(line = in.ReadLine())) {} + std::stringstream expected; + expected << '\\' << length << "-grams:"; + if (line != expected.str()) UTIL_THROW(FormatLoadException, "Was expecting n-gram header " << expected.str() << " but got " << line << " instead."); +} + +// Special unigram reader because unigram's data structure is different and because we're inserting vocab words. +template <class Voc> void Read1Grams(util::FilePiece &f, const size_t count, Voc &vocab, ProbBackoff *unigrams) { + ReadNGramHeader(f, 1); + for (size_t i = 0; i < count; ++i) { + try { + float prob = f.ReadFloat(); + if (f.get() != '\t') UTIL_THROW(FormatLoadException, "Expected tab after probability"); + ProbBackoff &value = unigrams[vocab.Insert(f.ReadDelimited())]; + value.prob = prob; + switch (f.get()) { + case '\t': + value.SetBackoff(f.ReadFloat()); + if ((f.get() != '\n')) UTIL_THROW(FormatLoadException, "Expected newline after backoff"); + break; + case '\n': + value.ZeroBackoff(); + break; + default: + UTIL_THROW(FormatLoadException, "Expected tab or newline after unigram"); + } + } catch(util::Exception &e) { + e << " in the " << i << "th 1-gram at byte " << f.Offset(); + throw; + } + } + if (f.ReadLine().size()) UTIL_THROW(FormatLoadException, "Expected blank line after unigrams at byte " << f.Offset()); +} + +template <class Voc, class Store> void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, Store &store) { + ReadNGramHeader(f, n); + + // vocab ids of words in reverse order + WordIndex vocab_ids[n]; + typename Store::Packing::Value value; + for (size_t i = 0; i < count; ++i) { + try { + value.prob = f.ReadFloat(); + for (WordIndex *vocab_out = &vocab_ids[n-1]; vocab_out >= vocab_ids; --vocab_out) { + *vocab_out = vocab.Index(f.ReadDelimited()); + } + uint64_t key = ChainedWordHash(vocab_ids, vocab_ids + n); + + switch (f.get()) { + case '\t': + value.SetBackoff(f.ReadFloat()); + if ((f.get() != '\n')) UTIL_THROW(FormatLoadException, "Expected newline after backoff"); + break; + case '\n': + value.ZeroBackoff(); + break; + default: + UTIL_THROW(FormatLoadException, "Expected tab or newline after n-gram"); + } + store.Insert(Store::Packing::Make(key, value)); + } catch(util::Exception &e) { + e << " in the " << i << "th " << n << "-gram at byte " << f.Offset(); + throw; + } + } + + if (f.ReadLine().size()) UTIL_THROW(FormatLoadException, "Expected blank line after " << n << "-grams at byte " << f.Offset()); + store.FinishedInserting(); +} + +template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<size_t> &counts, const Config &config) { + if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit ngram.hh's kMaxOrder to at least this value and recompile."); + if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model."); + size_t memory_size = VocabularyT::Size(counts[0], config.probing_multiplier); + memory_size += sizeof(ProbBackoff) * (counts[0] + 1); // +1 for hallucinate <unk> + for (unsigned char n = 2; n < counts.size(); ++n) { + memory_size += Middle::Size(counts[n - 1], config.probing_multiplier); + } + memory_size += Longest::Size(counts.back(), config.probing_multiplier); + return memory_size; +} + +template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::SetupMemory(char *base, const std::vector<size_t> &counts, const Config &config) { + char *start = base; + size_t allocated = VocabularyT::Size(counts[0], config.probing_multiplier); + vocab_.Init(start, allocated, counts[0]); + start += allocated; + unigram_ = reinterpret_cast<ProbBackoff*>(start); + start += sizeof(ProbBackoff) * (counts[0] + 1); + for (unsigned int n = 2; n < counts.size(); ++n) { + allocated = Middle::Size(counts[n - 1], config.probing_multiplier); + middle_.push_back(Middle(start, allocated)); + start += allocated; + } + allocated = Longest::Size(counts.back(), config.probing_multiplier); + longest_ = Longest(start, allocated); + start += allocated; + if (static_cast<std::size_t>(start - base) != Size(counts, config)) UTIL_THROW(FormatLoadException, "The data structures took " << (start - base) << " but Size says they should take " << Size(counts, config)); +} + +const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 0\n\0"; +struct BinaryFileHeader { + char magic[sizeof(kMagicBytes)]; + float zero_f, one_f, minus_half_f; + WordIndex one_word_index, max_word_index; + uint64_t one_uint64; + + void SetToReference() { + std::memcpy(magic, kMagicBytes, sizeof(magic)); + zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5; + one_word_index = 1; + max_word_index = std::numeric_limits<WordIndex>::max(); + one_uint64 = 1; + } +}; + +bool IsBinaryFormat(int fd, off_t size) { + if (size == util::kBadSize || (size <= static_cast<off_t>(sizeof(BinaryFileHeader)))) return false; + // Try reading the header. + util::scoped_mmap memory(mmap(NULL, sizeof(BinaryFileHeader), PROT_READ, MAP_FILE | MAP_PRIVATE, fd, 0), sizeof(BinaryFileHeader)); + if (memory.get() == MAP_FAILED) return false; + BinaryFileHeader reference_header = BinaryFileHeader(); + reference_header.SetToReference(); + if (!memcmp(memory.get(), &reference_header, sizeof(BinaryFileHeader))) return true; + if (!memcmp(memory.get(), "mmap lm ", 8)) UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Was it built on a different machine or with a different compiler?"); + return false; +} + +std::size_t Align8(std::size_t in) { + std::size_t off = in % 8; + if (!off) return in; + return in + 8 - off; +} + +std::size_t TotalHeaderSize(unsigned int order) { + return Align8(sizeof(BinaryFileHeader) + 1 /* order */ + sizeof(uint64_t) * order /* counts */ + sizeof(float) /* probing multiplier */ + 1 /* search_tag */); +} + +void ReadBinaryHeader(const void *from, off_t size, std::vector<size_t> &out, float &probing_multiplier, unsigned char &search_tag) { + const char *from_char = reinterpret_cast<const char*>(from); + if (size < static_cast<off_t>(1 + sizeof(BinaryFileHeader))) UTIL_THROW(FormatLoadException, "File too short to have count information."); + // Skip over the BinaryFileHeader which was read by IsBinaryFormat. + from_char += sizeof(BinaryFileHeader); + unsigned char order = *reinterpret_cast<const unsigned char*>(from_char); + if (size < static_cast<off_t>(TotalHeaderSize(order))) UTIL_THROW(FormatLoadException, "File too short to have full header."); + out.resize(static_cast<std::size_t>(order)); + const uint64_t *counts = reinterpret_cast<const uint64_t*>(from_char + 1); + for (std::size_t i = 0; i < out.size(); ++i) { + out[i] = static_cast<std::size_t>(counts[i]); + } + const float *probing_ptr = reinterpret_cast<const float*>(counts + out.size()); + probing_multiplier = *probing_ptr; + search_tag = *reinterpret_cast<const char*>(probing_ptr + 1); +} + +void WriteBinaryHeader(void *to, const std::vector<size_t> &from, float probing_multiplier, char search_tag) { + BinaryFileHeader header = BinaryFileHeader(); + header.SetToReference(); + memcpy(to, &header, sizeof(BinaryFileHeader)); + char *out = reinterpret_cast<char*>(to) + sizeof(BinaryFileHeader); + *reinterpret_cast<unsigned char*>(out) = static_cast<unsigned char>(from.size()); + uint64_t *counts = reinterpret_cast<uint64_t*>(out + 1); + for (std::size_t i = 0; i < from.size(); ++i) { + counts[i] = from[i]; + } + float *probing_ptr = reinterpret_cast<float*>(counts + from.size()); + *probing_ptr = probing_multiplier; + *reinterpret_cast<char*>(probing_ptr + 1) = search_tag; +} + +template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, Config config) : mapped_file_(util::OpenReadOrThrow(file)) { + const off_t file_size = util::SizeFile(mapped_file_.get()); + + std::vector<size_t> counts; + + if (IsBinaryFormat(mapped_file_.get(), file_size)) { + memory_.reset(util::MapForRead(file_size, config.prefault, mapped_file_.get()), file_size); + + unsigned char search_tag; + ReadBinaryHeader(memory_.begin(), file_size, counts, config.probing_multiplier, search_tag); + if (config.probing_multiplier < 1.0) UTIL_THROW(FormatLoadException, "Binary format claims to have a probing multiplier of " << config.probing_multiplier << " which is < 1.0."); + if (search_tag != Search::kBinaryTag) UTIL_THROW(FormatLoadException, "The binary file has a different search strategy than the one requested."); + size_t memory_size = Size(counts, config); + + char *start = reinterpret_cast<char*>(memory_.get()) + TotalHeaderSize(counts.size()); + if (memory_size != static_cast<size_t>(memory_.end() - start)) UTIL_THROW(FormatLoadException, "The mmap file " << file << " has size " << file_size << " but " << (memory_size + TotalHeaderSize(counts.size())) << " was expected based on the number of counts and configuration."); + + SetupMemory(start, counts, config); + vocab_.LoadedBinary(); + for (typename std::vector<Middle>::iterator i = middle_.begin(); i != middle_.end(); ++i) { + i->LoadedBinary(); + } + longest_.LoadedBinary(); + + } else { + if (config.probing_multiplier <= 1.0) UTIL_THROW(FormatLoadException, "probing multiplier must be > 1.0"); + + util::FilePiece f(file, mapped_file_.release(), config.messages); + ReadARPACounts(f, counts); + size_t memory_size = Size(counts, config); + char *start; + + if (config.write_mmap) { + // Write out an mmap file. + util::MapZeroedWrite(config.write_mmap, TotalHeaderSize(counts.size()) + memory_size, mapped_file_, memory_); + WriteBinaryHeader(memory_.get(), counts, config.probing_multiplier, Search::kBinaryTag); + start = reinterpret_cast<char*>(memory_.get()) + TotalHeaderSize(counts.size()); + } else { + memory_.reset(util::MapAnonymous(memory_size), memory_size); + start = reinterpret_cast<char*>(memory_.get()); + } + SetupMemory(start, counts, config); + try { + LoadFromARPA(f, counts, config); + } catch (FormatLoadException &e) { + e << " in file " << file; + throw; + } + } + + // g++ prints warnings unless these are fully initialized. + State begin_sentence = State(); + begin_sentence.valid_length_ = 1; + begin_sentence.history_[0] = vocab_.BeginSentence(); + begin_sentence.backoff_[0] = unigram_[begin_sentence.history_[0]].backoff; + State null_context = State(); + null_context.valid_length_ = 0; + P::Init(begin_sentence, null_context, vocab_, counts.size()); +} + +template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::LoadFromARPA(util::FilePiece &f, const std::vector<size_t> &counts, const Config &config) { + // Read the unigrams. + Read1Grams(f, counts[0], vocab_, unigram_); + bool saw_unk = vocab_.FinishedLoading(unigram_); + if (!saw_unk) { + switch(config.unknown_missing) { + case Config::THROW_UP: + { + SpecialWordMissingException e("<unk>"); + e << " and configuration was set to throw if unknown is missing"; + throw e; + } + case Config::COMPLAIN: + if (config.messages) *config.messages << "Language model is missing <unk>. Substituting probability " << config.unknown_missing_prob << "." << std::endl; + // There's no break;. This is by design. + case Config::SILENT: + // Default probabilities for unknown. + unigram_[0].backoff = 0.0; + unigram_[0].prob = config.unknown_missing_prob; + break; + } + } + + // Read the n-grams. + for (unsigned int n = 2; n < counts.size(); ++n) { + ReadNGrams(f, n, counts[n-1], vocab_, middle_[n-2]); + } + ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab_, longest_); + if (std::fabs(unigram_[0].backoff) > 0.0000001) UTIL_THROW(FormatLoadException, "Backoff for unknown word should be zero, but was given as " << unigram_[0].backoff); +} + +/* Ugly optimized function. + * in_state contains the previous ngram's length and backoff probabilites to + * be used here. out_state is populated with the found ngram length and + * backoffs that the next call will find useful. + * + * The search goes in increasing order of ngram length. + */ +template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore( + const State &in_state, + const WordIndex new_word, + State &out_state) const { + + FullScoreReturn ret; + // This is end pointer passed to SumBackoffs. + const ProbBackoff &unigram = unigram_[new_word]; + if (new_word == 0) { + ret.ngram_length = out_state.valid_length_ = 0; + // all of backoff. + ret.prob = std::accumulate( + in_state.backoff_, + in_state.backoff_ + in_state.valid_length_, + unigram.prob); + return ret; + } + float *backoff_out(out_state.backoff_); + *backoff_out = unigram.backoff; + ret.prob = unigram.prob; + out_state.history_[0] = new_word; + if (in_state.valid_length_ == 0) { + ret.ngram_length = out_state.valid_length_ = 1; + // No backoff because NGramLength() == 0 and unknown can't have backoff. + return ret; + } + ++backoff_out; + + // Ok now we now that the bigram contains known words. Start by looking it up. + + uint64_t lookup_hash = static_cast<uint64_t>(new_word); + const WordIndex *hist_iter = in_state.history_; + const WordIndex *const hist_end = hist_iter + in_state.valid_length_; + typename std::vector<Middle>::const_iterator mid_iter = middle_.begin(); + for (; ; ++mid_iter, ++hist_iter, ++backoff_out) { + if (hist_iter == hist_end) { + // Used history [in_state.history_, hist_end) and ran out. No backoff. + std::copy(in_state.history_, hist_end, out_state.history_ + 1); + ret.ngram_length = out_state.valid_length_ = in_state.valid_length_ + 1; + // ret.prob was already set. + return ret; + } + lookup_hash = CombineWordHash(lookup_hash, *hist_iter); + if (mid_iter == middle_.end()) break; + typename Middle::ConstIterator found; + if (!mid_iter->Find(lookup_hash, found)) { + // Didn't find an ngram using hist_iter. + // The history used in the found n-gram is [in_state.history_, hist_iter). + std::copy(in_state.history_, hist_iter, out_state.history_ + 1); + // Therefore, we found a (hist_iter - in_state.history_ + 1)-gram including the last word. + ret.ngram_length = out_state.valid_length_ = (hist_iter - in_state.history_) + 1; + ret.prob = std::accumulate( + in_state.backoff_ + (mid_iter - middle_.begin()), + in_state.backoff_ + in_state.valid_length_, + ret.prob); + return ret; + } + *backoff_out = found->GetValue().backoff; + ret.prob = found->GetValue().prob; + } + + typename Longest::ConstIterator found; + if (!longest_.Find(lookup_hash, found)) { + // It's an (P::Order()-1)-gram + std::copy(in_state.history_, in_state.history_ + P::Order() - 2, out_state.history_ + 1); + ret.ngram_length = out_state.valid_length_ = P::Order() - 1; + ret.prob += in_state.backoff_[P::Order() - 2]; + return ret; + } + // It's an P::Order()-gram + // out_state.valid_length_ is still P::Order() - 1 because the next lookup will only need that much. + std::copy(in_state.history_, in_state.history_ + P::Order() - 2, out_state.history_ + 1); + out_state.valid_length_ = P::Order() - 1; + ret.ngram_length = P::Order(); + ret.prob = found->GetValue().prob; + return ret; +} + +template class GenericModel<ProbingSearch, MapVocabulary<ProbingSearch> >; +template class GenericModel<SortedUniformSearch, SortedVocabulary>; +} // namespace detail +} // namespace ngram +} // namespace lm diff --git a/klm/lm/ngram.hh b/klm/lm/ngram.hh new file mode 100644 index 00000000..899a80e8 --- /dev/null +++ b/klm/lm/ngram.hh @@ -0,0 +1,226 @@ +#ifndef LM_NGRAM__ +#define LM_NGRAM__ + +#include "lm/facade.hh" +#include "lm/ngram_config.hh" +#include "util/key_value_packing.hh" +#include "util/mmap.hh" +#include "util/probing_hash_table.hh" +#include "util/scoped.hh" +#include "util/sorted_uniform.hh" +#include "util/string_piece.hh" + +#include <algorithm> +#include <memory> +#include <vector> + +namespace util { class FilePiece; } + +namespace lm { +namespace ngram { + +// If you need higher order, change this and recompile. +// Having this limit means that State can be +// (kMaxOrder - 1) * sizeof(float) bytes instead of +// sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead +const std::size_t kMaxOrder = 6; + +// This is a POD. +class State { + public: + bool operator==(const State &other) const { + if (valid_length_ != other.valid_length_) return false; + const WordIndex *end = history_ + valid_length_; + for (const WordIndex *first = history_, *second = other.history_; + first != end; ++first, ++second) { + if (*first != *second) return false; + } + // If the histories are equal, so are the backoffs. + return true; + } + + // You shouldn't need to touch anything below this line, but the members are public so FullState will qualify as a POD. + // This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit. + WordIndex history_[kMaxOrder - 1]; + float backoff_[kMaxOrder - 1]; + unsigned char valid_length_; +}; + +size_t hash_value(const State &state); + +namespace detail { + +uint64_t HashForVocab(const char *str, std::size_t len); +inline uint64_t HashForVocab(const StringPiece &str) { + return HashForVocab(str.data(), str.length()); +} + +struct Prob { + float prob; + void SetBackoff(float to); + void ZeroBackoff() {} +}; +// No inheritance so this will be a POD. +struct ProbBackoff { + float prob; + float backoff; + void SetBackoff(float to) { backoff = to; } + void ZeroBackoff() { backoff = 0.0; } +}; + +} // namespace detail + +// Vocabulary based on sorted uniform find storing only uint64_t values and using their offsets as indices. +class SortedVocabulary : public base::Vocabulary { + private: + // Sorted uniform requires a GetKey function. + struct Entry { + uint64_t GetKey() const { return key; } + uint64_t key; + bool operator<(const Entry &other) const { + return key < other.key; + } + }; + + public: + SortedVocabulary(); + + WordIndex Index(const StringPiece &str) const { + const Entry *found; + if (util::SortedUniformFind<const Entry *, uint64_t>(begin_, end_, detail::HashForVocab(str), found)) { + return found - begin_ + 1; // +1 because <unk> is 0 and does not appear in the lookup table. + } else { + return 0; + } + } + + // Ignores second argument for consistency with probing hash which has a float here. + static size_t Size(std::size_t entries, float ignored = 0.0); + + // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. + void Init(void *start, std::size_t allocated, std::size_t entries); + + WordIndex Insert(const StringPiece &str); + + // Returns true if unknown was seen. Reorders reorder_vocab so that the IDs are sorted. + bool FinishedLoading(detail::ProbBackoff *reorder_vocab); + + void LoadedBinary(); + + private: + Entry *begin_, *end_; + + bool saw_unk_; +}; + +namespace detail { + +// Vocabulary storing a map from uint64_t to WordIndex. +template <class Search> class MapVocabulary : public base::Vocabulary { + public: + MapVocabulary(); + + WordIndex Index(const StringPiece &str) const { + typename Lookup::ConstIterator i; + return lookup_.Find(HashForVocab(str), i) ? i->GetValue() : 0; + } + + static size_t Size(std::size_t entries, float probing_multiplier) { + return Lookup::Size(entries, probing_multiplier); + } + + // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. + void Init(void *start, std::size_t allocated, std::size_t entries); + + WordIndex Insert(const StringPiece &str); + + // Returns true if unknown was seen. Does nothing with reorder_vocab. + bool FinishedLoading(ProbBackoff *reorder_vocab); + + void LoadedBinary(); + + private: + typedef typename Search::template Table<WordIndex>::T Lookup; + Lookup lookup_; + + bool saw_unk_; +}; + +// std::identity is an SGI extension :-( +struct IdentityHash : public std::unary_function<uint64_t, size_t> { + size_t operator()(uint64_t arg) const { return static_cast<size_t>(arg); } +}; + +// Should return the same results as SRI. +// Why VocabularyT instead of just Vocabulary? ModelFacade defines Vocabulary. +template <class Search, class VocabularyT> class GenericModel : public base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> { + private: + typedef base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> P; + public: + // Get the size of memory that will be mapped given ngram counts. This + // does not include small non-mapped control structures, such as this class + // itself. + static size_t Size(const std::vector<size_t> &counts, const Config &config = Config()); + + GenericModel(const char *file, Config config = Config()); + + FullScoreReturn FullScore(const State &in_state, const WordIndex new_word, State &out_state) const; + + private: + // Appears after Size in the cc. + void SetupMemory(char *start, const std::vector<size_t> &counts, const Config &config); + + void LoadFromARPA(util::FilePiece &f, const std::vector<size_t> &counts, const Config &config); + + util::scoped_fd mapped_file_; + + // memory_ is the raw block of memory backing vocab_, unigram_, [middle.begin(), middle.end()), and longest_. + util::scoped_mmap memory_; + + VocabularyT vocab_; + + ProbBackoff *unigram_; + + typedef typename Search::template Table<ProbBackoff>::T Middle; + std::vector<Middle> middle_; + + typedef typename Search::template Table<Prob>::T Longest; + Longest longest_; +}; + +struct ProbingSearch { + typedef float Init; + + static const unsigned char kBinaryTag = 1; + + template <class Value> struct Table { + typedef util::ByteAlignedPacking<uint64_t, Value> Packing; + typedef util::ProbingHashTable<Packing, IdentityHash> T; + }; +}; + +struct SortedUniformSearch { + // This is ignored. + typedef float Init; + + static const unsigned char kBinaryTag = 2; + + template <class Value> struct Table { + typedef util::ByteAlignedPacking<uint64_t, Value> Packing; + typedef util::SortedUniformMap<Packing> T; + }; +}; + +} // namespace detail + +// These must also be instantiated in the cc file. +typedef detail::MapVocabulary<detail::ProbingSearch> Vocabulary; +typedef detail::GenericModel<detail::ProbingSearch, Vocabulary> Model; + +// SortedVocabulary was defined above. +typedef detail::GenericModel<detail::SortedUniformSearch, SortedVocabulary> SortedModel; + +} // namespace ngram +} // namespace lm + +#endif // LM_NGRAM__ diff --git a/klm/lm/ngram_build_binary.cc b/klm/lm/ngram_build_binary.cc new file mode 100644 index 00000000..9dab30a1 --- /dev/null +++ b/klm/lm/ngram_build_binary.cc @@ -0,0 +1,13 @@ +#include "lm/ngram.hh" + +#include <iostream> + +int main(int argc, char *argv[]) { + if (argc != 3) { + std::cerr << "Usage: " << argv[0] << " input.arpa output.mmap" << std::endl; + return 1; + } + lm::ngram::Config config; + config.write_mmap = argv[2]; + lm::ngram::Model(argv[1], config); +} diff --git a/klm/lm/ngram_config.hh b/klm/lm/ngram_config.hh new file mode 100644 index 00000000..a7b3afae --- /dev/null +++ b/klm/lm/ngram_config.hh @@ -0,0 +1,58 @@ +#ifndef LM_NGRAM_CONFIG__ +#define LM_NGRAM_CONFIG__ + +/* Configuration for ngram model. Separate header to reduce pollution. */ + +#include <iostream> + +namespace lm { namespace ngram { + +struct Config { + /* EFFECTIVE FOR BOTH ARPA AND BINARY READS */ + // Where to log messages including the progress bar. Set to NULL for + // silence. + std::ostream *messages; + + + + /* ONLY EFFECTIVE WHEN READING ARPA */ + + // What to do when <unk> isn't in the provided model. + typedef enum {THROW_UP, COMPLAIN, SILENT} UnknownMissing; + UnknownMissing unknown_missing; + + // The probability to substitute for <unk> if it's missing from the model. + // No effect if the model has <unk> or unknown_missing == THROW_UP. + float unknown_missing_prob; + + // Size multiplier for probing hash table. Must be > 1. Space is linear in + // this. Time is probing_multiplier / (probing_multiplier - 1). No effect + // for sorted variant. + // If you find yourself setting this to a low number, consider using the + // Sorted version instead which has lower memory consumption. + float probing_multiplier; + + // While loading an ARPA file, also write out this binary format file. Set + // to NULL to disable. + const char *write_mmap; + + + + /* ONLY EFFECTIVE WHEN READING BINARY */ + bool prefault; + + + + // Defaults. + Config() : + messages(&std::cerr), + unknown_missing(COMPLAIN), + unknown_missing_prob(0.0), + probing_multiplier(1.5), + write_mmap(NULL), + prefault(false) {} +}; + +} /* namespace ngram */ } /* namespace lm */ + +#endif // LM_NGRAM_CONFIG__ diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc new file mode 100644 index 00000000..d1970260 --- /dev/null +++ b/klm/lm/ngram_query.cc @@ -0,0 +1,72 @@ +#include "lm/ngram.hh" + +#include <cstdlib> +#include <fstream> +#include <iostream> +#include <string> + +#include <sys/resource.h> +#include <sys/time.h> + +float FloatSec(const struct timeval &tv) { + return static_cast<float>(tv.tv_sec) + (static_cast<float>(tv.tv_usec) / 1000000000.0); +} + +void PrintUsage(const char *message) { + struct rusage usage; + if (getrusage(RUSAGE_SELF, &usage)) { + perror("getrusage"); + return; + } + std::cerr << message; + std::cerr << "user\t" << FloatSec(usage.ru_utime) << "\nsys\t" << FloatSec(usage.ru_stime) << '\n'; + + // Linux doesn't set memory usage :-(. + std::ifstream status("/proc/self/status", std::ios::in); + std::string line; + while (getline(status, line)) { + if (!strncmp(line.c_str(), "VmRSS:\t", 7)) { + std::cerr << "rss " << (line.c_str() + 7) << '\n'; + break; + } + } +} + +template <class Model> void Query(const Model &model) { + PrintUsage("Loading statistics:\n"); + typename Model::State state, out; + lm::FullScoreReturn ret; + std::string word; + + while (std::cin) { + state = model.BeginSentenceState(); + float total = 0.0; + bool got = false; + while (std::cin >> word) { + got = true; + ret = model.FullScore(state, model.GetVocabulary().Index(word), out); + total += ret.prob; + std::cout << word << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << ' '; + state = out; + if (std::cin.get() == '\n') break; + } + if (!got && !std::cin) break; + ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out); + total += ret.prob; + std::cout << "</s> " << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << ' '; + std::cout << "Total: " << total << '\n'; + } + PrintUsage("After queries:\n"); +} + +int main(int argc, char *argv[]) { + if (argc < 2) { + std::cerr << "Pass language model name." << std::endl; + return 0; + } + { + lm::ngram::Model ngram(argv[1]); + Query(ngram); + } + PrintUsage("Total time including destruction:\n"); +} diff --git a/klm/lm/ngram_test.cc b/klm/lm/ngram_test.cc new file mode 100644 index 00000000..031e0348 --- /dev/null +++ b/klm/lm/ngram_test.cc @@ -0,0 +1,91 @@ +#include "lm/ngram.hh" + +#include <stdlib.h> + +#define BOOST_TEST_MODULE NGramTest +#include <boost/test/unit_test.hpp> + +namespace lm { +namespace ngram { +namespace { + +#define StartTest(word, ngram, score) \ + ret = model.FullScore( \ + state, \ + model.GetVocabulary().Index(word), \ + out);\ + BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \ + BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \ + BOOST_CHECK_EQUAL(std::min<unsigned char>(ngram, 5 - 1), out.valid_length_); + +#define AppendTest(word, ngram, score) \ + StartTest(word, ngram, score) \ + state = out; + +template <class M> void Starters(M &model) { + FullScoreReturn ret; + Model::State state(model.BeginSentenceState()); + Model::State out; + + StartTest("looking", 2, -0.4846522); + + // , probability plus <s> backoff + StartTest(",", 1, -1.383514 + -0.4149733); + // <unk> probability plus <s> backoff + StartTest("this_is_not_found", 0, -1.995635 + -0.4149733); +} + +template <class M> void Continuation(M &model) { + FullScoreReturn ret; + Model::State state(model.BeginSentenceState()); + Model::State out; + + AppendTest("looking", 2, -0.484652); + AppendTest("on", 3, -0.348837); + AppendTest("a", 4, -0.0155266); + AppendTest("little", 5, -0.00306122); + State preserve = state; + AppendTest("the", 1, -4.04005); + AppendTest("biarritz", 1, -1.9889); + AppendTest("not_found", 0, -2.29666); + AppendTest("more", 1, -1.20632); + AppendTest(".", 2, -0.51363); + AppendTest("</s>", 3, -0.0191651); + + state = preserve; + AppendTest("more", 5, -0.00181395); + AppendTest("loin", 5, -0.0432557); +} + +BOOST_AUTO_TEST_CASE(starters_probing) { Model m("test.arpa"); Starters(m); } +BOOST_AUTO_TEST_CASE(continuation_probing) { Model m("test.arpa"); Continuation(m); } +BOOST_AUTO_TEST_CASE(starters_sorted) { SortedModel m("test.arpa"); Starters(m); } +BOOST_AUTO_TEST_CASE(continuation_sorted) { SortedModel m("test.arpa"); Continuation(m); } + +BOOST_AUTO_TEST_CASE(write_and_read_probing) { + Config config; + config.write_mmap = "test.binary"; + { + Model copy_model("test.arpa", config); + } + Model binary("test.binary"); + Starters(binary); + Continuation(binary); +} + +BOOST_AUTO_TEST_CASE(write_and_read_sorted) { + Config config; + config.write_mmap = "test.binary"; + config.prefault = true; + { + SortedModel copy_model("test.arpa", config); + } + SortedModel binary("test.binary"); + Starters(binary); + Continuation(binary); +} + + +} // namespace +} // namespace ngram +} // namespace lm diff --git a/klm/lm/sri.cc b/klm/lm/sri.cc new file mode 100644 index 00000000..7bd23d76 --- /dev/null +++ b/klm/lm/sri.cc @@ -0,0 +1,115 @@ +#include "lm/exception.hh" +#include "lm/sri.hh" + +#include <Ngram.h> +#include <Vocab.h> + +#include <errno.h> + +namespace lm { +namespace sri { + +Vocabulary::Vocabulary() : sri_(new Vocab) {} + +Vocabulary::~Vocabulary() {} + +WordIndex Vocabulary::Index(const char *str) const { + WordIndex ret = sri_->getIndex(str); + // NGram wants the index of Vocab_Unknown for unknown words, but for some reason SRI returns Vocab_None here :-(. + if (ret == Vocab_None) { + return not_found_; + } else { + return ret; + } +} + +const char *Vocabulary::Word(WordIndex index) const { + return sri_->getWord(index); +} + +void Vocabulary::FinishedLoading() { + SetSpecial( + sri_->ssIndex(), + sri_->seIndex(), + sri_->unkIndex(), + sri_->highIndex() + 1); +} + +namespace { +Ngram *MakeSRIModel(const char *file_name, unsigned int ngram_length, Vocab &sri_vocab) { + sri_vocab.unkIsWord() = true; + std::auto_ptr<Ngram> ret(new Ngram(sri_vocab, ngram_length)); + File file(file_name, "r"); + errno = 0; + if (!ret->read(file)) { + UTIL_THROW(FormatLoadException, "reading file " << file_name << " with SRI failed."); + } + return ret.release(); +} +} // namespace + +Model::Model(const char *file_name, unsigned int ngram_length) : sri_(MakeSRIModel(file_name, ngram_length, *vocab_.sri_)) { + if (!sri_->setorder()) { + UTIL_THROW(FormatLoadException, "Can't have an SRI model with order 0."); + } + vocab_.FinishedLoading(); + State begin_state = State(); + begin_state.valid_length_ = 1; + if (kMaxOrder > 1) { + begin_state.history_[0] = vocab_.BeginSentence(); + if (kMaxOrder > 2) begin_state.history_[1] = Vocab_None; + } + State null_state = State(); + null_state.valid_length_ = 0; + if (kMaxOrder > 1) null_state.history_[0] = Vocab_None; + Init(begin_state, null_state, vocab_, sri_->setorder()); + not_found_ = vocab_.NotFound(); +} + +Model::~Model() {} + +namespace { + +/* Argh SRI's wordProb knows the ngram length but doesn't return it. One more + * reason you should use my model. */ +// TODO(stolcke): fix SRILM so I don't have to do this. +unsigned int MatchedLength(Ngram &model, const WordIndex new_word, const SRIVocabIndex *const_history) { + unsigned int out_length = 0; + // This gets the length of context used, which is ngram_length - 1 unless new_word is OOV in which case it is 0. + model.contextID(new_word, const_history, out_length); + return out_length + 1; +} + +} // namespace + +FullScoreReturn Model::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const { + // If you get a compiler in this function, change SRIVocabIndex in sri.hh to match the one found in SRI's Vocab.h. + const SRIVocabIndex *const_history; + SRIVocabIndex local_history[Order()]; + if (in_state.valid_length_ < kMaxOrder - 1) { + const_history = in_state.history_; + } else { + std::copy(in_state.history_, in_state.history_ + in_state.valid_length_, local_history); + local_history[in_state.valid_length_] = Vocab_None; + const_history = local_history; + } + FullScoreReturn ret; + if (new_word != not_found_) { + ret.ngram_length = MatchedLength(*sri_, new_word, const_history); + out_state.history_[0] = new_word; + out_state.valid_length_ = std::min<unsigned char>(ret.ngram_length, Order() - 1); + std::copy(const_history, const_history + out_state.valid_length_ - 1, out_state.history_ + 1); + if (out_state.valid_length_ < kMaxOrder - 1) { + out_state.history_[out_state.valid_length_] = Vocab_None; + } + } else { + ret.ngram_length = 0; + if (kMaxOrder > 1) out_state.history_[0] = Vocab_None; + out_state.valid_length_ = 0; + } + ret.prob = sri_->wordProb(new_word, const_history); + return ret; +} + +} // namespace sri +} // namespace lm diff --git a/klm/lm/sri.hh b/klm/lm/sri.hh new file mode 100644 index 00000000..b57e9b73 --- /dev/null +++ b/klm/lm/sri.hh @@ -0,0 +1,102 @@ +#ifndef LM_SRI__ +#define LM_SRI__ + +#include "lm/facade.hh" +#include "util/murmur_hash.hh" + +#include <cmath> +#include <exception> +#include <memory> + +class Ngram; +class Vocab; + +/* The ngram length reported uses some random API I found and may be wrong. + * + * See ngram, which should return equivalent results. + */ + +namespace lm { +namespace sri { + +static const unsigned int kMaxOrder = 6; + +/* This should match VocabIndex found in SRI's Vocab.h + * The reason I define this here independently is that SRI's headers + * pollute and increase compile time. + * It's difficult to extract this from their header and anyway would + * break packaging. + * If these differ there will be a compiler error in ActuallyCall. + */ +typedef unsigned int SRIVocabIndex; + +class State { + public: + // You shouldn't need to touch these, but they're public so State will be a POD. + // If valid_length_ < kMaxOrder - 1 then history_[valid_length_] == Vocab_None. + SRIVocabIndex history_[kMaxOrder - 1]; + unsigned char valid_length_; +}; + +inline bool operator==(const State &left, const State &right) { + if (left.valid_length_ != right.valid_length_) { + return false; + } + for (const SRIVocabIndex *l = left.history_, *r = right.history_; + l != left.history_ + left.valid_length_; + ++l, ++r) { + if (*l != *r) return false; + } + return true; +} + +inline size_t hash_value(const State &state) { + return util::MurmurHashNative(&state.history_, sizeof(SRIVocabIndex) * state.valid_length_); +} + +class Vocabulary : public base::Vocabulary { + public: + Vocabulary(); + + ~Vocabulary(); + + WordIndex Index(const StringPiece &str) const { + std::string temp(str.data(), str.length()); + return Index(temp.c_str()); + } + WordIndex Index(const std::string &str) const { + return Index(str.c_str()); + } + WordIndex Index(const char *str) const; + + const char *Word(WordIndex index) const; + + private: + friend class Model; + void FinishedLoading(); + + // The parent class isn't copyable so auto_ptr is the same as scoped_ptr + // but without the boost dependence. + mutable std::auto_ptr<Vocab> sri_; +}; + +class Model : public base::ModelFacade<Model, State, Vocabulary> { + public: + Model(const char *file_name, unsigned int ngram_length); + + ~Model(); + + FullScoreReturn FullScore(const State &in_state, const WordIndex new_word, State &out_state) const; + + private: + Vocabulary vocab_; + + mutable std::auto_ptr<Ngram> sri_; + + WordIndex not_found_; +}; + +} // namespace sri +} // namespace lm + +#endif // LM_SRI__ diff --git a/klm/lm/sri_test.cc b/klm/lm/sri_test.cc new file mode 100644 index 00000000..e697d722 --- /dev/null +++ b/klm/lm/sri_test.cc @@ -0,0 +1,65 @@ +#include "lm/sri.hh" + +#include <stdlib.h> + +#define BOOST_TEST_MODULE SRITest +#include <boost/test/unit_test.hpp> + +namespace lm { +namespace sri { +namespace { + +#define StartTest(word, ngram, score) \ + ret = model.FullScore( \ + state, \ + model.GetVocabulary().Index(word), \ + out);\ + BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \ + BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \ + BOOST_CHECK_EQUAL(std::min<unsigned char>(ngram, 5 - 1), out.valid_length_); + +#define AppendTest(word, ngram, score) \ + StartTest(word, ngram, score) \ + state = out; + +template <class M> void Starters(M &model) { + FullScoreReturn ret; + Model::State state(model.BeginSentenceState()); + Model::State out; + + StartTest("looking", 2, -0.4846522); + + // , probability plus <s> backoff + StartTest(",", 1, -1.383514 + -0.4149733); + // <unk> probability plus <s> backoff + StartTest("this_is_not_found", 0, -1.995635 + -0.4149733); +} + +template <class M> void Continuation(M &model) { + FullScoreReturn ret; + Model::State state(model.BeginSentenceState()); + Model::State out; + + AppendTest("looking", 2, -0.484652); + AppendTest("on", 3, -0.348837); + AppendTest("a", 4, -0.0155266); + AppendTest("little", 5, -0.00306122); + State preserve = state; + AppendTest("the", 1, -4.04005); + AppendTest("biarritz", 1, -1.9889); + AppendTest("not_found", 0, -2.29666); + AppendTest("more", 1, -1.20632); + AppendTest(".", 2, -0.51363); + AppendTest("</s>", 3, -0.0191651); + + state = preserve; + AppendTest("more", 5, -0.00181395); + AppendTest("loin", 5, -0.0432557); +} + +BOOST_AUTO_TEST_CASE(starters) { Model m("test.arpa", 5); Starters(m); } +BOOST_AUTO_TEST_CASE(continuation) { Model m("test.arpa", 5); Continuation(m); } + +} // namespace +} // namespace sri +} // namespace lm diff --git a/klm/lm/test.arpa b/klm/lm/test.arpa new file mode 100644 index 00000000..9d674e83 --- /dev/null +++ b/klm/lm/test.arpa @@ -0,0 +1,112 @@ + +\data\ +ngram 1=34 +ngram 2=43 +ngram 3=8 +ngram 4=5 +ngram 5=3 + +\1-grams: +-1.383514 , -0.30103 +-1.139057 . -0.845098 +-1.029493 </s> +-99 <s> -0.4149733 +-1.995635 <unk> +-1.285941 a -0.69897 +-1.687872 also -0.30103 +-1.687872 beyond -0.30103 +-1.687872 biarritz -0.30103 +-1.687872 call -0.30103 +-1.687872 concerns -0.30103 +-1.687872 consider -0.30103 +-1.687872 considering -0.30103 +-1.687872 for -0.30103 +-1.509559 higher -0.30103 +-1.687872 however -0.30103 +-1.687872 i -0.30103 +-1.687872 immediate -0.30103 +-1.687872 in -0.30103 +-1.687872 is -0.30103 +-1.285941 little -0.69897 +-1.383514 loin -0.30103 +-1.687872 look -0.30103 +-1.285941 looking -0.4771212 +-1.206319 more -0.544068 +-1.509559 on -0.4771212 +-1.509559 screening -0.4771212 +-1.687872 small -0.30103 +-1.687872 the -0.30103 +-1.687872 to -0.30103 +-1.687872 watch -0.30103 +-1.687872 watching -0.30103 +-1.687872 what -0.30103 +-1.687872 would -0.30103 + +\2-grams: +-0.6925742 , . +-0.7522095 , however +-0.7522095 , is +-0.0602359 . </s> +-0.4846522 <s> looking -0.4771214 +-1.051485 <s> screening +-1.07153 <s> the +-1.07153 <s> watching +-1.07153 <s> what +-0.09132547 a little -0.69897 +-0.2922095 also call +-0.2922095 beyond immediate +-0.2705918 biarritz . +-0.2922095 call for +-0.2922095 concerns in +-0.2922095 consider watch +-0.2922095 considering consider +-0.2834328 for , +-0.5511513 higher more +-0.5845945 higher small +-0.2834328 however , +-0.2922095 i would +-0.2922095 immediate concerns +-0.2922095 in biarritz +-0.2922095 is to +-0.09021038 little more -0.1998621 +-0.7273645 loin , +-0.6925742 loin . +-0.6708385 loin </s> +-0.2922095 look beyond +-0.4638903 looking higher +-0.4638903 looking on -0.4771212 +-0.5136299 more . -0.4771212 +-0.3561665 more loin +-0.1649931 on a -0.4771213 +-0.1649931 screening a -0.4771213 +-0.2705918 small . +-0.287799 the screening +-0.2922095 to look +-0.2622373 watch </s> +-0.2922095 watching considering +-0.2922095 what i +-0.2922095 would also + +\3-grams: +-0.01916512 more . </s> +-0.0283603 on a little -0.4771212 +-0.0283603 screening a little -0.4771212 +-0.01660496 a little more -0.09409451 +-0.3488368 <s> looking higher +-0.3488368 <s> looking on -0.4771212 +-0.1892331 little more loin +-0.04835128 looking on a -0.4771212 + +\4-grams: +-0.009249173 looking on a little -0.4771212 +-0.005464747 on a little more -0.4771212 +-0.005464747 screening a little more +-0.1453306 a little more loin +-0.01552657 <s> looking on a -0.4771212 + +\5-grams: +-0.003061223 <s> looking on a little +-0.001813953 looking on a little more +-0.0432557 on a little more loin + +\end\ diff --git a/klm/lm/test.binary b/klm/lm/test.binary Binary files differnew file mode 100644 index 00000000..90bd2b76 --- /dev/null +++ b/klm/lm/test.binary diff --git a/klm/lm/virtual_interface.cc b/klm/lm/virtual_interface.cc new file mode 100644 index 00000000..9c7151f9 --- /dev/null +++ b/klm/lm/virtual_interface.cc @@ -0,0 +1,22 @@ +#include "lm/virtual_interface.hh" + +#include "lm/exception.hh" + +namespace lm { +namespace base { + +Vocabulary::~Vocabulary() {} + +void Vocabulary::SetSpecial(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found, WordIndex available) { + begin_sentence_ = begin_sentence; + end_sentence_ = end_sentence; + not_found_ = not_found; + available_ = available; + if (begin_sentence_ == not_found_) throw SpecialWordMissingException("<s>"); + if (end_sentence_ == not_found_) throw SpecialWordMissingException("</s>"); +} + +Model::~Model() {} + +} // namespace base +} // namespace lm diff --git a/klm/lm/virtual_interface.hh b/klm/lm/virtual_interface.hh new file mode 100644 index 00000000..621a129e --- /dev/null +++ b/klm/lm/virtual_interface.hh @@ -0,0 +1,156 @@ +#ifndef LM_VIRTUAL_INTERFACE__ +#define LM_VIRTUAL_INTERFACE__ + +#include "lm/word_index.hh" +#include "util/string_piece.hh" + +#include <string> + +namespace lm { + +struct FullScoreReturn { + float prob; + unsigned char ngram_length; +}; + +namespace base { + +template <class T, class U, class V> class ModelFacade; + +/* Vocabulary interface. Call Index(string) and get a word index for use in + * calling Model. It provides faster convenience functions for <s>, </s>, and + * <unk> although you can also find these using Index. + * + * Some models do not load the mapping from index to string. If you need this, + * check if the model Vocabulary class implements such a function and access it + * directly. + * + * The Vocabulary object is always owned by the Model and can be retrieved from + * the Model using BaseVocabulary() for this abstract interface or + * GetVocabulary() for the actual implementation (in which case you'll need the + * actual implementation of the Model too). + */ +class Vocabulary { + public: + virtual ~Vocabulary(); + + WordIndex BeginSentence() const { return begin_sentence_; } + WordIndex EndSentence() const { return end_sentence_; } + WordIndex NotFound() const { return not_found_; } + // FullScoreReturn start index of unused word assignments. + WordIndex Available() const { return available_; } + + /* Most implementations allow StringPiece lookups and need only override + * Index(StringPiece). SRI requires null termination and overrides all + * three methods. + */ + virtual WordIndex Index(const StringPiece &str) const = 0; + virtual WordIndex Index(const std::string &str) const { + return Index(StringPiece(str)); + } + virtual WordIndex Index(const char *str) const { + return Index(StringPiece(str)); + } + + protected: + // Call SetSpecial afterward. + Vocabulary() {} + + Vocabulary(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found, WordIndex available) { + SetSpecial(begin_sentence, end_sentence, not_found, available); + } + + void SetSpecial(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found, WordIndex available); + + WordIndex begin_sentence_, end_sentence_, not_found_, available_; + + private: + // Disable copy constructors. They're private and undefined. + // Ersatz boost::noncopyable. + Vocabulary(const Vocabulary &); + Vocabulary &operator=(const Vocabulary &); +}; + +/* There are two ways to access a Model. + * + * + * OPTION 1: Access the Model directly (e.g. lm::ngram::Model in ngram.hh). + * Every Model implements the scoring function: + * float Score( + * const Model::State &in_state, + * const WordIndex new_word, + * Model::State &out_state) const; + * + * It can also return the length of n-gram matched by the model: + * FullScoreReturn FullScore( + * const Model::State &in_state, + * const WordIndex new_word, + * Model::State &out_state) const; + * + * There are also accessor functions: + * const State &BeginSentenceState() const; + * const State &NullContextState() const; + * const Vocabulary &GetVocabulary() const; + * unsigned int Order() const; + * + * NB: In case you're wondering why the model implementation looks like it's + * missing these methods, see facade.hh. + * + * This is the fastest way to use a model and presents a normal State class to + * be included in hypothesis state structure. + * + * + * OPTION 2: Use the virtual interface below. + * + * The virtual interface allow you to decide which Model to use at runtime + * without templatizing everything on the Model type. However, each Model has + * its own State class, so a single State cannot be efficiently provided (it + * would require using the maximum memory of any Model's State or memory + * allocation with each lookup). This means you become responsible for + * allocating memory with size StateSize() and passing it to the Score or + * FullScore functions provided here. + * + * For example, cdec has a std::string containing the entire state of a + * hypothesis. It can reserve StateSize bytes in this string for the model + * state. + * + * All the State objects are POD, so it's ok to use raw memory for storing + * State. + */ +class Model { + public: + virtual ~Model(); + + size_t StateSize() const { return state_size_; } + const void *BeginSentenceMemory() const { return begin_sentence_memory_; } + const void *NullContextMemory() const { return null_context_memory_; } + + virtual float Score(const void *in_state, const WordIndex new_word, void *out_state) const = 0; + + virtual FullScoreReturn FullScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0; + + unsigned char Order() const { return order_; } + + const Vocabulary &BaseVocabulary() const { return *base_vocab_; } + + private: + template <class T, class U, class V> friend class ModelFacade; + explicit Model(size_t state_size) : state_size_(state_size) {} + + const size_t state_size_; + const void *begin_sentence_memory_, *null_context_memory_; + + const Vocabulary *base_vocab_; + + unsigned char order_; + + // Disable copy constructors. They're private and undefined. + // Ersatz boost::noncopyable. + Model(const Model &); + Model &operator=(const Model &); +}; + +} // mamespace base +} // namespace lm + +#endif // LM_VIRTUAL_INTERFACE__ diff --git a/klm/lm/word_index.hh b/klm/lm/word_index.hh new file mode 100644 index 00000000..67841c30 --- /dev/null +++ b/klm/lm/word_index.hh @@ -0,0 +1,11 @@ +// Separate header because this is used often. +#ifndef LM_WORD_INDEX__ +#define LM_WORD_INDEX__ + +namespace lm { +typedef unsigned int WordIndex; +} // namespace lm + +typedef lm::WordIndex LMWordIndex; + +#endif |