diff options
author | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-11-10 02:02:04 +0000 |
---|---|---|
committer | redpony <redpony@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-11-10 02:02:04 +0000 |
commit | 15b03336564d5e57e50693f19dd81b45076af5d4 (patch) | |
tree | c2072893a43f4c75f0ad5ebe3080bfa901faf18f /klm | |
parent | 1336aecfe930546f8836ffe65dd5ff78434084eb (diff) |
new version of klm
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@706 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'klm')
34 files changed, 2097 insertions, 634 deletions
diff --git a/klm/lm/Makefile.am b/klm/lm/Makefile.am index 48f9889e..eb71c0f5 100644 --- a/klm/lm/Makefile.am +++ b/klm/lm/Makefile.am @@ -7,11 +7,17 @@ noinst_LIBRARIES = libklm.a libklm_a_SOURCES = \ - exception.cc \ - ngram.cc \ - ngram_build_binary.cc \ + binary_format.cc \ + config.cc \ + lm_exception.cc \ + model.cc \ ngram_query.cc \ - virtual_interface.cc + read_arpa.cc \ + search_hashed.cc \ + search_trie.cc \ + trie.cc \ + virtual_interface.cc \ + vocab.cc AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc new file mode 100644 index 00000000..2a075b6b --- /dev/null +++ b/klm/lm/binary_format.cc @@ -0,0 +1,191 @@ +#include "lm/binary_format.hh" + +#include "lm/lm_exception.hh" +#include "util/file_piece.hh" + +#include <limits> +#include <string> + +#include <fcntl.h> +#include <errno.h> +#include <stdlib.h> +#include <sys/mman.h> +#include <sys/types.h> +#include <sys/stat.h> +#include <unistd.h> + +namespace lm { +namespace ngram { +namespace { +const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version"; +const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 1\n\0"; +const long int kMagicVersion = 1; + +// Test values. +struct Sanity { + char magic[sizeof(kMagicBytes)]; + float zero_f, one_f, minus_half_f; + WordIndex one_word_index, max_word_index; + uint64_t one_uint64; + + void SetToReference() { + std::memcpy(magic, kMagicBytes, sizeof(magic)); + zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5; + one_word_index = 1; + max_word_index = std::numeric_limits<WordIndex>::max(); + one_uint64 = 1; + } +}; + +const char *kModelNames[3] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "bit packed trie"}; + +std::size_t Align8(std::size_t in) { + std::size_t off = in % 8; + if (!off) return in; + return in + 8 - off; +} + +std::size_t TotalHeaderSize(unsigned char order) { + return Align8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order); +} + +void ReadLoop(int fd, void *to_void, std::size_t size) { + uint8_t *to = static_cast<uint8_t*>(to_void); + while (size) { + ssize_t ret = read(fd, to, size); + if (ret == -1) UTIL_THROW(util::ErrnoException, "Failed to read from binary file"); + if (ret == 0) UTIL_THROW(util::ErrnoException, "Binary file too short"); + to += ret; + size -= ret; + } +} + +void WriteHeader(void *to, const Parameters ¶ms) { + Sanity header = Sanity(); + header.SetToReference(); + memcpy(to, &header, sizeof(Sanity)); + char *out = reinterpret_cast<char*>(to) + sizeof(Sanity); + + *reinterpret_cast<FixedWidthParameters*>(out) = params.fixed; + out += sizeof(FixedWidthParameters); + + uint64_t *counts = reinterpret_cast<uint64_t*>(out); + for (std::size_t i = 0; i < params.counts.size(); ++i) { + counts[i] = params.counts[i]; + } +} + +} // namespace +namespace detail { + +bool IsBinaryFormat(int fd) { + const off_t size = util::SizeFile(fd); + if (size == util::kBadSize || (size <= static_cast<off_t>(sizeof(Sanity)))) return false; + // Try reading the header. + util::scoped_memory memory; + try { + util::MapRead(util::LAZY, fd, 0, sizeof(Sanity), memory); + } catch (const util::Exception &e) { + return false; + } + Sanity reference_header = Sanity(); + reference_header.SetToReference(); + if (!memcmp(memory.get(), &reference_header, sizeof(Sanity))) return true; + if (!memcmp(memory.get(), kMagicBeforeVersion, strlen(kMagicBeforeVersion))) { + char *end_ptr; + const char *begin_version = static_cast<const char*>(memory.get()) + strlen(kMagicBeforeVersion); + long int version = strtol(begin_version, &end_ptr, 10); + if ((end_ptr != begin_version) && version != kMagicVersion) { + UTIL_THROW(FormatLoadException, "Binary file has version " << version << " but this implementation expects version " << kMagicVersion << " so you'll have to rebuild your binary LM from the ARPA. Sorry."); + } + UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Try rebuilding the binary format LM using the same code revision, compiler, and architecture."); + } + return false; +} + +void ReadHeader(int fd, Parameters &out) { + if ((off_t)-1 == lseek(fd, sizeof(Sanity), SEEK_SET)) UTIL_THROW(util::ErrnoException, "Seek failed in binary file"); + ReadLoop(fd, &out.fixed, sizeof(out.fixed)); + if (out.fixed.probing_multiplier < 1.0) + UTIL_THROW(FormatLoadException, "Binary format claims to have a probing multiplier of " << out.fixed.probing_multiplier << " which is < 1.0."); + + out.counts.resize(static_cast<std::size_t>(out.fixed.order)); + ReadLoop(fd, &*out.counts.begin(), sizeof(uint64_t) * out.fixed.order); +} + +void MatchCheck(ModelType model_type, const Parameters ¶ms) { + if (params.fixed.model_type != model_type) { + if (static_cast<unsigned int>(params.fixed.model_type) >= (sizeof(kModelNames) / sizeof(const char *))) + UTIL_THROW(FormatLoadException, "The binary file claims to be model type " << static_cast<unsigned int>(params.fixed.model_type) << " but this is not implemented for in this inference code."); + UTIL_THROW(FormatLoadException, "The binary file was built for " << kModelNames[params.fixed.model_type] << " but the inference code is trying to load " << kModelNames[model_type]); + } +} + +uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t memory_size, Backing &backing) { + const off_t file_size = util::SizeFile(backing.file.get()); + // The header is smaller than a page, so we have to map the whole header as well. + std::size_t total_map = TotalHeaderSize(params.counts.size()) + memory_size; + if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map) + UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map); + + util::MapRead(config.load_method, backing.file.get(), 0, total_map, backing.memory); + + if (config.enumerate_vocab && !params.fixed.has_vocabulary) + UTIL_THROW(FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary."); + + if (config.enumerate_vocab) { + if ((off_t)-1 == lseek(backing.file.get(), total_map, SEEK_SET)) + UTIL_THROW(util::ErrnoException, "Failed to seek in binary file to vocab words"); + } + return reinterpret_cast<uint8_t*>(backing.memory.get()) + TotalHeaderSize(params.counts.size()); +} + +uint8_t *SetupZeroed(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, std::size_t memory_size, Backing &backing) { + if (config.probing_multiplier <= 1.0) UTIL_THROW(FormatLoadException, "probing multiplier must be > 1.0"); + if (config.write_mmap) { + std::size_t total_map = TotalHeaderSize(counts.size()) + memory_size; + // Write out an mmap file. + backing.memory.reset(util::MapZeroedWrite(config.write_mmap, total_map, backing.file), total_map, util::scoped_memory::MMAP_ALLOCATED); + + Parameters params; + params.counts = counts; + params.fixed.order = counts.size(); + params.fixed.probing_multiplier = config.probing_multiplier; + params.fixed.model_type = model_type; + params.fixed.has_vocabulary = config.include_vocab; + + WriteHeader(backing.memory.get(), params); + + if (params.fixed.has_vocabulary) { + if ((off_t)-1 == lseek(backing.file.get(), total_map, SEEK_SET)) + UTIL_THROW(util::ErrnoException, "Failed to seek in binary file " << config.write_mmap << " to vocab words"); + } + return reinterpret_cast<uint8_t*>(backing.memory.get()) + TotalHeaderSize(counts.size()); + } else { + backing.memory.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED); + return reinterpret_cast<uint8_t*>(backing.memory.get()); + } +} + +void ComplainAboutARPA(const Config &config, ModelType model_type) { + if (config.write_mmap || !config.messages) return; + if (config.arpa_complain == Config::ALL) { + *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl; + } else if (config.arpa_complain == Config::EXPENSIVE && model_type == TRIE_SORTED) { + *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl; + } +} + +} // namespace detail + +bool RecognizeBinary(const char *file, ModelType &recognized) { + util::scoped_fd fd(util::OpenReadOrThrow(file)); + if (!detail::IsBinaryFormat(fd.get())) return false; + Parameters params; + detail::ReadHeader(fd.get(), params); + recognized = params.fixed.model_type; + return true; +} + +} // namespace ngram +} // namespace lm diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc new file mode 100644 index 00000000..4db631a2 --- /dev/null +++ b/klm/lm/build_binary.cc @@ -0,0 +1,13 @@ +#include "lm/model.hh" + +#include <iostream> + +int main(int argc, char *argv[]) { + if (argc != 3) { + std::cerr << "Usage: " << argv[0] << " input.arpa output.mmap" << std::endl; + return 1; + } + lm::ngram::Config config; + config.write_mmap = argv[2]; + lm::ngram::Model(argv[1], config); +} diff --git a/klm/lm/config.cc b/klm/lm/config.cc new file mode 100644 index 00000000..2831d578 --- /dev/null +++ b/klm/lm/config.cc @@ -0,0 +1,22 @@ +#include "lm/config.hh" + +#include <iostream> + +namespace lm { +namespace ngram { + +Config::Config() : + messages(&std::cerr), + enumerate_vocab(NULL), + unknown_missing(COMPLAIN), + unknown_missing_prob(0.0), + probing_multiplier(1.5), + building_memory(1073741824ULL), // 1 GB + temporary_directory_prefix(NULL), + arpa_complain(ALL), + write_mmap(NULL), + include_vocab(true), + load_method(util::POPULATE_OR_READ) {} + +} // namespace ngram +} // namespace lm diff --git a/klm/lm/lm_exception.cc b/klm/lm/lm_exception.cc new file mode 100644 index 00000000..ab2ec52f --- /dev/null +++ b/klm/lm/lm_exception.cc @@ -0,0 +1,21 @@ +#include "lm/lm_exception.hh" + +#include<errno.h> +#include<stdio.h> + +namespace lm { + +LoadException::LoadException() throw() {} +LoadException::~LoadException() throw() {} +VocabLoadException::VocabLoadException() throw() {} +VocabLoadException::~VocabLoadException() throw() {} + +FormatLoadException::FormatLoadException() throw() {} +FormatLoadException::~FormatLoadException() throw() {} + +SpecialWordMissingException::SpecialWordMissingException(StringPiece which) throw() { + *this << "Missing special word " << which; +} +SpecialWordMissingException::~SpecialWordMissingException() throw() {} + +} // namespace lm diff --git a/klm/lm/model.cc b/klm/lm/model.cc new file mode 100644 index 00000000..6921d4d9 --- /dev/null +++ b/klm/lm/model.cc @@ -0,0 +1,239 @@ +#include "lm/model.hh" + +#include "lm/lm_exception.hh" +#include "lm/search_hashed.hh" +#include "lm/search_trie.hh" +#include "lm/read_arpa.hh" +#include "util/murmur_hash.hh" + +#include <algorithm> +#include <functional> +#include <numeric> +#include <cmath> + +namespace lm { +namespace ngram { + +size_t hash_value(const State &state) { + return util::MurmurHashNative(state.history_, sizeof(WordIndex) * state.valid_length_); +} + +namespace detail { + +template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) { + if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit ngram.hh's kMaxOrder to at least this value and recompile."); + if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model."); + return VocabularyT::Size(counts[0], config) + Search::Size(counts, config); +} + +template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::SetupMemory(void *base, const std::vector<uint64_t> &counts, const Config &config) { + uint8_t *start = static_cast<uint8_t*>(base); + size_t allocated = VocabularyT::Size(counts[0], config); + vocab_.SetupMemory(start, allocated, counts[0], config); + start += allocated; + start = search_.SetupMemory(start, counts, config); + if (static_cast<std::size_t>(start - static_cast<uint8_t*>(base)) != Size(counts, config)) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast<uint8_t*>(base)) << " but Size says they should take " << Size(counts, config)); +} + +template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) { + LoadLM(file, config, *this); + + // g++ prints warnings unless these are fully initialized. + State begin_sentence = State(); + begin_sentence.valid_length_ = 1; + begin_sentence.history_[0] = vocab_.BeginSentence(); + begin_sentence.backoff_[0] = search_.unigram.Lookup(begin_sentence.history_[0]).backoff; + State null_context = State(); + null_context.valid_length_ = 0; + P::Init(begin_sentence, null_context, vocab_, search_.middle.size() + 2); +} + +template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters ¶ms, const Config &config, int fd) { + SetupMemory(start, params.counts, config); + vocab_.LoadedBinary(fd, config.enumerate_vocab); + search_.unigram.LoadedBinary(); + for (typename std::vector<Middle>::iterator i = search_.middle.begin(); i != search_.middle.end(); ++i) { + i->LoadedBinary(); + } + search_.longest.LoadedBinary(); +} + +template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, util::FilePiece &f, void *start, const Parameters ¶ms, const Config &config) { + SetupMemory(start, params.counts, config); + + if (config.write_mmap) { + WriteWordsWrapper wrap(config.enumerate_vocab, backing_.file.get()); + vocab_.ConfigureEnumerate(&wrap, params.counts[0]); + search_.InitializeFromARPA(file, f, params.counts, config, vocab_); + } else { + vocab_.ConfigureEnumerate(config.enumerate_vocab, params.counts[0]); + search_.InitializeFromARPA(file, f, params.counts, config, vocab_); + } + // TODO: fail faster? + if (!vocab_.SawUnk()) { + switch(config.unknown_missing) { + case Config::THROW_UP: + { + SpecialWordMissingException e("<unk>"); + e << " and configuration was set to throw if unknown is missing"; + throw e; + } + case Config::COMPLAIN: + if (config.messages) *config.messages << "Language model is missing <unk>. Substituting probability " << config.unknown_missing_prob << "." << std::endl; + // There's no break;. This is by design. + case Config::SILENT: + // Default probabilities for unknown. + search_.unigram.Unknown().backoff = 0.0; + search_.unigram.Unknown().prob = config.unknown_missing_prob; + break; + } + } + if (std::fabs(search_.unigram.Unknown().backoff) > 0.0000001) UTIL_THROW(FormatLoadException, "Backoff for unknown word should be zero, but was given as " << search_.unigram.Unknown().backoff); +} + +template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const { + unsigned char backoff_start; + FullScoreReturn ret = ScoreExceptBackoff(in_state.history_, in_state.history_ + in_state.valid_length_, new_word, backoff_start, out_state); + if (backoff_start - 1 < in_state.valid_length_) { + ret.prob = std::accumulate(in_state.backoff_ + backoff_start - 1, in_state.backoff_ + in_state.valid_length_, ret.prob); + } + return ret; +} + +template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const { + unsigned char backoff_start; + context_rend = std::min(context_rend, context_rbegin + P::Order() - 1); + FullScoreReturn ret = ScoreExceptBackoff(context_rbegin, context_rend, new_word, backoff_start, out_state); + ret.prob += SlowBackoffLookup(context_rbegin, context_rend, backoff_start); + return ret; +} + +template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const { + context_rend = std::min(context_rend, context_rbegin + P::Order() - 1); + if (context_rend == context_rbegin || *context_rbegin == 0) { + out_state.valid_length_ = 0; + return; + } + float ignored_prob; + typename Search::Node node; + search_.LookupUnigram(*context_rbegin, ignored_prob, out_state.backoff_[0], node); + float *backoff_out = out_state.backoff_ + 1; + const WordIndex *i = context_rbegin + 1; + for (; i < context_rend; ++i, ++backoff_out) { + if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, *backoff_out, node)) { + out_state.valid_length_ = i - context_rbegin; + std::copy(context_rbegin, i, out_state.history_); + return; + } + } + std::copy(context_rbegin, context_rend, out_state.history_); + out_state.valid_length_ = static_cast<unsigned char>(context_rend - context_rbegin); +} + +template <class Search, class VocabularyT> float GenericModel<Search, VocabularyT>::SlowBackoffLookup( + const WordIndex *const context_rbegin, const WordIndex *const context_rend, unsigned char start) const { + // Add the backoff weights for n-grams of order start to (context_rend - context_rbegin). + if (context_rend - context_rbegin < static_cast<std::ptrdiff_t>(start)) return 0.0; + float ret = 0.0; + if (start == 1) { + ret += search_.unigram.Lookup(*context_rbegin).backoff; + start = 2; + } + typename Search::Node node; + if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) { + return 0.0; + } + float backoff; + // i is the order of the backoff we're looking for. + for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i) { + if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, backoff, node)) break; + ret += backoff; + } + return ret; +} + +/* Ugly optimized function. Produce a score excluding backoff. + * The search goes in increasing order of ngram length. + * Context goes backward, so context_begin is the word immediately preceeding + * new_word. + */ +template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::ScoreExceptBackoff( + const WordIndex *context_rbegin, + const WordIndex *context_rend, + const WordIndex new_word, + unsigned char &backoff_start, + State &out_state) const { + FullScoreReturn ret; + typename Search::Node node; + float *backoff_out(out_state.backoff_); + search_.LookupUnigram(new_word, ret.prob, *backoff_out, node); + if (new_word == 0) { + ret.ngram_length = out_state.valid_length_ = 0; + // All of backoff. + backoff_start = 1; + return ret; + } + out_state.history_[0] = new_word; + if (context_rbegin == context_rend) { + ret.ngram_length = out_state.valid_length_ = 1; + // No backoff because we don't have the history for it. + backoff_start = P::Order(); + return ret; + } + ++backoff_out; + + // Ok now we now that the bigram contains known words. Start by looking it up. + + const WordIndex *hist_iter = context_rbegin; + typename std::vector<Middle>::const_iterator mid_iter = search_.middle.begin(); + for (; ; ++mid_iter, ++hist_iter, ++backoff_out) { + if (hist_iter == context_rend) { + // Ran out of history. No backoff. + backoff_start = P::Order(); + std::copy(context_rbegin, context_rend, out_state.history_ + 1); + ret.ngram_length = out_state.valid_length_ = (context_rend - context_rbegin) + 1; + // ret.prob was already set. + return ret; + } + + if (mid_iter == search_.middle.end()) break; + + if (!search_.LookupMiddle(*mid_iter, *hist_iter, ret.prob, *backoff_out, node)) { + // Didn't find an ngram using hist_iter. + // The history used in the found n-gram is [context_rbegin, hist_iter). + std::copy(context_rbegin, hist_iter, out_state.history_ + 1); + // Therefore, we found a (hist_iter - context_rbegin + 1)-gram including the last word. + ret.ngram_length = out_state.valid_length_ = (hist_iter - context_rbegin) + 1; + backoff_start = mid_iter - search_.middle.begin() + 1; + // ret.prob was already set. + return ret; + } + } + + // It passed every lookup in search_.middle. That means it's at least a (P::Order() - 1)-gram. + // All that's left is to check search_.longest. + + if (!search_.LookupLongest(*hist_iter, ret.prob, node)) { + // It's an (P::Order()-1)-gram + std::copy(context_rbegin, context_rbegin + P::Order() - 2, out_state.history_ + 1); + ret.ngram_length = out_state.valid_length_ = P::Order() - 1; + backoff_start = P::Order() - 1; + // ret.prob was already set. + return ret; + } + // It's an P::Order()-gram + // out_state.valid_length_ is still P::Order() - 1 because the next lookup will only need that much. + std::copy(context_rbegin, context_rbegin + P::Order() - 2, out_state.history_ + 1); + out_state.valid_length_ = P::Order() - 1; + ret.ngram_length = P::Order(); + backoff_start = P::Order(); + return ret; +} + +template class GenericModel<ProbingHashedSearch, ProbingVocabulary>; +template class GenericModel<SortedHashedSearch, SortedVocabulary>; +template class GenericModel<trie::TrieSearch, SortedVocabulary>; + +} // namespace detail +} // namespace ngram +} // namespace lm diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc new file mode 100644 index 00000000..159628d4 --- /dev/null +++ b/klm/lm/model_test.cc @@ -0,0 +1,200 @@ +#include "lm/model.hh" + +#include <stdlib.h> + +#define BOOST_TEST_MODULE ModelTest +#include <boost/test/unit_test.hpp> + +namespace lm { +namespace ngram { +namespace { + +#define StartTest(word, ngram, score) \ + ret = model.FullScore( \ + state, \ + model.GetVocabulary().Index(word), \ + out);\ + BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \ + BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \ + BOOST_CHECK_EQUAL(std::min<unsigned char>(ngram, 5 - 1), out.valid_length_); + +#define AppendTest(word, ngram, score) \ + StartTest(word, ngram, score) \ + state = out; + +template <class M> void Starters(const M &model) { + FullScoreReturn ret; + Model::State state(model.BeginSentenceState()); + Model::State out; + + StartTest("looking", 2, -0.4846522); + + // , probability plus <s> backoff + StartTest(",", 1, -1.383514 + -0.4149733); + // <unk> probability plus <s> backoff + StartTest("this_is_not_found", 0, -1.995635 + -0.4149733); +} + +template <class M> void Continuation(const M &model) { + FullScoreReturn ret; + Model::State state(model.BeginSentenceState()); + Model::State out; + + AppendTest("looking", 2, -0.484652); + AppendTest("on", 3, -0.348837); + AppendTest("a", 4, -0.0155266); + AppendTest("little", 5, -0.00306122); + State preserve = state; + AppendTest("the", 1, -4.04005); + AppendTest("biarritz", 1, -1.9889); + AppendTest("not_found", 0, -2.29666); + AppendTest("more", 1, -1.20632); + AppendTest(".", 2, -0.51363); + AppendTest("</s>", 3, -0.0191651); + + state = preserve; + AppendTest("more", 5, -0.00181395); + AppendTest("loin", 5, -0.0432557); +} + +#define StatelessTest(word, provide, ngram, score) \ + ret = model.FullScoreForgotState(indices + num_words - word, indices + num_words - word + provide, indices[num_words - word - 1], state); \ + BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \ + BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \ + model.GetState(indices + num_words - word, indices + num_words - word + provide, before); \ + ret = model.FullScore(before, indices[num_words - word - 1], out); \ + BOOST_CHECK(state == out); \ + BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \ + BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); + +template <class M> void Stateless(const M &model) { + const char *words[] = {"<s>", "looking", "on", "a", "little", "the", "biarritz", "not_found", "more", ".", "</s>"}; + const size_t num_words = sizeof(words) / sizeof(const char*); + // Silience "array subscript is above array bounds" when extracting end pointer. + WordIndex indices[num_words + 1]; + for (unsigned int i = 0; i < num_words; ++i) { + indices[num_words - 1 - i] = model.GetVocabulary().Index(words[i]); + } + FullScoreReturn ret; + State state, out, before; + + ret = model.FullScoreForgotState(indices + num_words - 1, indices + num_words, indices[num_words - 2], state); + BOOST_CHECK_CLOSE(-0.484652, ret.prob, 0.001); + StatelessTest(1, 1, 2, -0.484652); + + // looking + StatelessTest(1, 2, 2, -0.484652); + // on + AppendTest("on", 3, -0.348837); + StatelessTest(2, 3, 3, -0.348837); + StatelessTest(2, 2, 3, -0.348837); + StatelessTest(2, 1, 2, -0.4638903); + // a + StatelessTest(3, 4, 4, -0.0155266); + // little + AppendTest("little", 5, -0.00306122); + StatelessTest(4, 5, 5, -0.00306122); + // the + AppendTest("the", 1, -4.04005); + StatelessTest(5, 5, 1, -4.04005); + // No context of the. + StatelessTest(5, 0, 1, -1.687872); + // biarritz + StatelessTest(6, 1, 1, -1.9889); + // not found + StatelessTest(7, 1, 0, -2.29666); + StatelessTest(7, 0, 0, -1.995635); + + WordIndex unk[1]; + unk[0] = 0; + model.GetState(unk, unk + 1, state); + BOOST_CHECK_EQUAL(0, state.valid_length_); +} + +//const char *kExpectedOrderProbing[] = {"<unk>", ",", ".", "</s>", "<s>", "a", "also", "beyond", "biarritz", "call", "concerns", "consider", "considering", "for", "higher", "however", "i", "immediate", "in", "is", "little", "loin", "look", "looking", "more", "on", "screening", "small", "the", "to", "watch", "watching", "what", "would"}; + +class ExpectEnumerateVocab : public EnumerateVocab { + public: + ExpectEnumerateVocab() {} + + void Add(WordIndex index, const StringPiece &str) { + BOOST_CHECK_EQUAL(seen.size(), index); + seen.push_back(std::string(str.data(), str.length())); + } + + void Check(const base::Vocabulary &vocab) { + BOOST_CHECK_EQUAL(34, seen.size()); + BOOST_REQUIRE(!seen.empty()); + BOOST_CHECK_EQUAL("<unk>", seen[0]); + for (WordIndex i = 0; i < seen.size(); ++i) { + BOOST_CHECK_EQUAL(i, vocab.Index(seen[i])); + } + } + + void Clear() { + seen.clear(); + } + + std::vector<std::string> seen; +}; + +template <class ModelT> void LoadingTest() { + Config config; + config.arpa_complain = Config::NONE; + config.messages = NULL; + ExpectEnumerateVocab enumerate; + config.enumerate_vocab = &enumerate; + ModelT m("test.arpa", config); + enumerate.Check(m.GetVocabulary()); + Starters(m); + Continuation(m); + Stateless(m); +} + +BOOST_AUTO_TEST_CASE(probing) { + LoadingTest<Model>(); +} + +BOOST_AUTO_TEST_CASE(sorted) { + LoadingTest<SortedModel>(); +} +BOOST_AUTO_TEST_CASE(trie) { + LoadingTest<TrieModel>(); +} + +template <class ModelT> void BinaryTest() { + Config config; + config.write_mmap = "test.binary"; + config.messages = NULL; + ExpectEnumerateVocab enumerate; + config.enumerate_vocab = &enumerate; + + { + ModelT copy_model("test.arpa", config); + enumerate.Check(copy_model.GetVocabulary()); + enumerate.Clear(); + } + + config.write_mmap = NULL; + + ModelT binary("test.binary", config); + enumerate.Check(binary.GetVocabulary()); + Starters(binary); + Continuation(binary); + Stateless(binary); + unlink("test.binary"); +} + +BOOST_AUTO_TEST_CASE(write_and_read_probing) { + BinaryTest<Model>(); +} +BOOST_AUTO_TEST_CASE(write_and_read_sorted) { + BinaryTest<SortedModel>(); +} +BOOST_AUTO_TEST_CASE(write_and_read_trie) { + BinaryTest<TrieModel>(); +} + +} // namespace +} // namespace ngram +} // namespace lm diff --git a/klm/lm/ngram.cc b/klm/lm/ngram.cc deleted file mode 100644 index a87c82aa..00000000 --- a/klm/lm/ngram.cc +++ /dev/null @@ -1,522 +0,0 @@ -#include "lm/ngram.hh" - -#include "lm/exception.hh" -#include "util/file_piece.hh" -#include "util/joint_sort.hh" -#include "util/murmur_hash.hh" -#include "util/probing_hash_table.hh" - -#include <algorithm> -#include <functional> -#include <numeric> -#include <limits> -#include <string> - -#include <cmath> -#include <fcntl.h> -#include <errno.h> -#include <stdlib.h> -#include <sys/mman.h> -#include <sys/types.h> -#include <sys/stat.h> -#include <unistd.h> - -namespace lm { -namespace ngram { - -size_t hash_value(const State &state) { - return util::MurmurHashNative(state.history_, sizeof(WordIndex) * state.valid_length_); -} - -namespace detail { -uint64_t HashForVocab(const char *str, std::size_t len) { - // This proved faster than Boost's hash in speed trials: total load time Murmur 67090000, Boost 72210000 - // Chose to use 64A instead of native so binary format will be portable across 64 and 32 bit. - return util::MurmurHash64A(str, len, 0); -} - -void Prob::SetBackoff(float to) { - UTIL_THROW(FormatLoadException, "Attempt to set backoff " << to << " for the highest order n-gram"); -} - -// Normally static initialization is a bad idea but MurmurHash is pure arithmetic, so this is ok. -const uint64_t kUnknownHash = HashForVocab("<unk>", 5); -// Sadly some LMs have <UNK>. -const uint64_t kUnknownCapHash = HashForVocab("<UNK>", 5); - -} // namespace detail - -SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL) {} - -std::size_t SortedVocabulary::Size(std::size_t entries, float ignored) { - // Lead with the number of entries. - return sizeof(uint64_t) + sizeof(Entry) * entries; -} - -void SortedVocabulary::Init(void *start, std::size_t allocated, std::size_t entries) { - assert(allocated >= Size(entries)); - // Leave space for number of entries. - begin_ = reinterpret_cast<Entry*>(reinterpret_cast<uint64_t*>(start) + 1); - end_ = begin_; - saw_unk_ = false; -} - -WordIndex SortedVocabulary::Insert(const StringPiece &str) { - uint64_t hashed = detail::HashForVocab(str); - if (hashed == detail::kUnknownHash || hashed == detail::kUnknownCapHash) { - saw_unk_ = true; - return 0; - } - end_->key = hashed; - ++end_; - // This is 1 + the offset where it was inserted to make room for unk. - return end_ - begin_; -} - -bool SortedVocabulary::FinishedLoading(detail::ProbBackoff *reorder_vocab) { - util::JointSort(begin_, end_, reorder_vocab + 1); - SetSpecial(Index("<s>"), Index("</s>"), 0, end_ - begin_ + 1); - // Save size. - *(reinterpret_cast<uint64_t*>(begin_) - 1) = end_ - begin_; - return saw_unk_; -} - -void SortedVocabulary::LoadedBinary() { - end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1); - SetSpecial(Index("<s>"), Index("</s>"), 0, end_ - begin_ + 1); -} - -namespace detail { - -template <class Search> MapVocabulary<Search>::MapVocabulary() {} - -template <class Search> void MapVocabulary<Search>::Init(void *start, std::size_t allocated, std::size_t entries) { - lookup_ = Lookup(start, allocated); - available_ = 1; - // Later if available_ != expected_available_ then we can throw UnknownMissingException. - saw_unk_ = false; -} - -template <class Search> WordIndex MapVocabulary<Search>::Insert(const StringPiece &str) { - uint64_t hashed = HashForVocab(str); - // Prevent unknown from going into the table. - if (hashed == kUnknownHash || hashed == kUnknownCapHash) { - saw_unk_ = true; - return 0; - } else { - lookup_.Insert(Lookup::Packing::Make(hashed, available_)); - return available_++; - } -} - -template <class Search> bool MapVocabulary<Search>::FinishedLoading(ProbBackoff *reorder_vocab) { - lookup_.FinishedInserting(); - SetSpecial(Index("<s>"), Index("</s>"), 0, available_); - return saw_unk_; -} - -template <class Search> void MapVocabulary<Search>::LoadedBinary() { - lookup_.LoadedBinary(); - SetSpecial(Index("<s>"), Index("</s>"), 0, available_); -} - -/* All of the entropy is in low order bits and boost::hash does poorly with - * these. Odd numbers near 2^64 chosen by mashing on the keyboard. There is a - * stable point: 0. But 0 is <unk> which won't be queried here anyway. - */ -inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) { - uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast<uint64_t>(next) * 17894857484156487943ULL); - return ret; -} - -uint64_t ChainedWordHash(const WordIndex *word, const WordIndex *word_end) { - if (word == word_end) return 0; - uint64_t current = static_cast<uint64_t>(*word); - for (++word; word != word_end; ++word) { - current = CombineWordHash(current, *word); - } - return current; -} - -bool IsEntirelyWhiteSpace(const StringPiece &line) { - for (size_t i = 0; i < static_cast<size_t>(line.size()); ++i) { - if (!isspace(line.data()[i])) return false; - } - return true; -} - -void ReadARPACounts(util::FilePiece &in, std::vector<size_t> &number) { - number.clear(); - StringPiece line; - if (!IsEntirelyWhiteSpace(line = in.ReadLine())) UTIL_THROW(FormatLoadException, "First line was \"" << line << "\" not blank"); - if ((line = in.ReadLine()) != "\\data\\") UTIL_THROW(FormatLoadException, "second line was \"" << line << "\" not \\data\\."); - while (!IsEntirelyWhiteSpace(line = in.ReadLine())) { - if (line.size() < 6 || strncmp(line.data(), "ngram ", 6)) UTIL_THROW(FormatLoadException, "count line \"" << line << "\"doesn't begin with \"ngram \""); - // So strtol doesn't go off the end of line. - std::string remaining(line.data() + 6, line.size() - 6); - char *end_ptr; - unsigned long int length = std::strtol(remaining.c_str(), &end_ptr, 10); - if ((end_ptr == remaining.c_str()) || (length - 1 != number.size())) UTIL_THROW(FormatLoadException, "ngram count lengths should be consecutive starting with 1: " << line); - if (*end_ptr != '=') UTIL_THROW(FormatLoadException, "Expected = immediately following the first number in the count line " << line); - ++end_ptr; - const char *start = end_ptr; - long int count = std::strtol(start, &end_ptr, 10); - if (count < 0) UTIL_THROW(FormatLoadException, "Negative n-gram count " << count); - if (start == end_ptr) UTIL_THROW(FormatLoadException, "Couldn't parse n-gram count from " << line); - number.push_back(count); - } -} - -void ReadNGramHeader(util::FilePiece &in, unsigned int length) { - StringPiece line; - while (IsEntirelyWhiteSpace(line = in.ReadLine())) {} - std::stringstream expected; - expected << '\\' << length << "-grams:"; - if (line != expected.str()) UTIL_THROW(FormatLoadException, "Was expecting n-gram header " << expected.str() << " but got " << line << " instead."); -} - -// Special unigram reader because unigram's data structure is different and because we're inserting vocab words. -template <class Voc> void Read1Grams(util::FilePiece &f, const size_t count, Voc &vocab, ProbBackoff *unigrams) { - ReadNGramHeader(f, 1); - for (size_t i = 0; i < count; ++i) { - try { - float prob = f.ReadFloat(); - if (f.get() != '\t') UTIL_THROW(FormatLoadException, "Expected tab after probability"); - ProbBackoff &value = unigrams[vocab.Insert(f.ReadDelimited())]; - value.prob = prob; - switch (f.get()) { - case '\t': - value.SetBackoff(f.ReadFloat()); - if ((f.get() != '\n')) UTIL_THROW(FormatLoadException, "Expected newline after backoff"); - break; - case '\n': - value.ZeroBackoff(); - break; - default: - UTIL_THROW(FormatLoadException, "Expected tab or newline after unigram"); - } - } catch(util::Exception &e) { - e << " in the " << i << "th 1-gram at byte " << f.Offset(); - throw; - } - } - if (f.ReadLine().size()) UTIL_THROW(FormatLoadException, "Expected blank line after unigrams at byte " << f.Offset()); -} - -template <class Voc, class Store> void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, Store &store) { - ReadNGramHeader(f, n); - - // vocab ids of words in reverse order - WordIndex vocab_ids[n]; - typename Store::Packing::Value value; - for (size_t i = 0; i < count; ++i) { - try { - value.prob = f.ReadFloat(); - for (WordIndex *vocab_out = &vocab_ids[n-1]; vocab_out >= vocab_ids; --vocab_out) { - *vocab_out = vocab.Index(f.ReadDelimited()); - } - uint64_t key = ChainedWordHash(vocab_ids, vocab_ids + n); - - switch (f.get()) { - case '\t': - value.SetBackoff(f.ReadFloat()); - if ((f.get() != '\n')) UTIL_THROW(FormatLoadException, "Expected newline after backoff"); - break; - case '\n': - value.ZeroBackoff(); - break; - default: - UTIL_THROW(FormatLoadException, "Expected tab or newline after n-gram"); - } - store.Insert(Store::Packing::Make(key, value)); - } catch(util::Exception &e) { - e << " in the " << i << "th " << n << "-gram at byte " << f.Offset(); - throw; - } - } - - if (f.ReadLine().size()) UTIL_THROW(FormatLoadException, "Expected blank line after " << n << "-grams at byte " << f.Offset()); - store.FinishedInserting(); -} - -template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<size_t> &counts, const Config &config) { - if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit ngram.hh's kMaxOrder to at least this value and recompile."); - if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model."); - size_t memory_size = VocabularyT::Size(counts[0], config.probing_multiplier); - memory_size += sizeof(ProbBackoff) * (counts[0] + 1); // +1 for hallucinate <unk> - for (unsigned char n = 2; n < counts.size(); ++n) { - memory_size += Middle::Size(counts[n - 1], config.probing_multiplier); - } - memory_size += Longest::Size(counts.back(), config.probing_multiplier); - return memory_size; -} - -template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::SetupMemory(char *base, const std::vector<size_t> &counts, const Config &config) { - char *start = base; - size_t allocated = VocabularyT::Size(counts[0], config.probing_multiplier); - vocab_.Init(start, allocated, counts[0]); - start += allocated; - unigram_ = reinterpret_cast<ProbBackoff*>(start); - start += sizeof(ProbBackoff) * (counts[0] + 1); - for (unsigned int n = 2; n < counts.size(); ++n) { - allocated = Middle::Size(counts[n - 1], config.probing_multiplier); - middle_.push_back(Middle(start, allocated)); - start += allocated; - } - allocated = Longest::Size(counts.back(), config.probing_multiplier); - longest_ = Longest(start, allocated); - start += allocated; - if (static_cast<std::size_t>(start - base) != Size(counts, config)) UTIL_THROW(FormatLoadException, "The data structures took " << (start - base) << " but Size says they should take " << Size(counts, config)); -} - -const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 0\n\0"; -struct BinaryFileHeader { - char magic[sizeof(kMagicBytes)]; - float zero_f, one_f, minus_half_f; - WordIndex one_word_index, max_word_index; - uint64_t one_uint64; - - void SetToReference() { - std::memcpy(magic, kMagicBytes, sizeof(magic)); - zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5; - one_word_index = 1; - max_word_index = std::numeric_limits<WordIndex>::max(); - one_uint64 = 1; - } -}; - -bool IsBinaryFormat(int fd, off_t size) { - if (size == util::kBadSize || (size <= static_cast<off_t>(sizeof(BinaryFileHeader)))) return false; - // Try reading the header. - util::scoped_mmap memory(mmap(NULL, sizeof(BinaryFileHeader), PROT_READ, MAP_FILE | MAP_PRIVATE, fd, 0), sizeof(BinaryFileHeader)); - if (memory.get() == MAP_FAILED) return false; - BinaryFileHeader reference_header = BinaryFileHeader(); - reference_header.SetToReference(); - if (!memcmp(memory.get(), &reference_header, sizeof(BinaryFileHeader))) return true; - if (!memcmp(memory.get(), "mmap lm ", 8)) UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Was it built on a different machine or with a different compiler?"); - return false; -} - -std::size_t Align8(std::size_t in) { - std::size_t off = in % 8; - if (!off) return in; - return in + 8 - off; -} - -std::size_t TotalHeaderSize(unsigned int order) { - return Align8(sizeof(BinaryFileHeader) + 1 /* order */ + sizeof(uint64_t) * order /* counts */ + sizeof(float) /* probing multiplier */ + 1 /* search_tag */); -} - -void ReadBinaryHeader(const void *from, off_t size, std::vector<size_t> &out, float &probing_multiplier, unsigned char &search_tag) { - const char *from_char = reinterpret_cast<const char*>(from); - if (size < static_cast<off_t>(1 + sizeof(BinaryFileHeader))) UTIL_THROW(FormatLoadException, "File too short to have count information."); - // Skip over the BinaryFileHeader which was read by IsBinaryFormat. - from_char += sizeof(BinaryFileHeader); - unsigned char order = *reinterpret_cast<const unsigned char*>(from_char); - if (size < static_cast<off_t>(TotalHeaderSize(order))) UTIL_THROW(FormatLoadException, "File too short to have full header."); - out.resize(static_cast<std::size_t>(order)); - const uint64_t *counts = reinterpret_cast<const uint64_t*>(from_char + 1); - for (std::size_t i = 0; i < out.size(); ++i) { - out[i] = static_cast<std::size_t>(counts[i]); - } - const float *probing_ptr = reinterpret_cast<const float*>(counts + out.size()); - probing_multiplier = *probing_ptr; - search_tag = *reinterpret_cast<const char*>(probing_ptr + 1); -} - -void WriteBinaryHeader(void *to, const std::vector<size_t> &from, float probing_multiplier, char search_tag) { - BinaryFileHeader header = BinaryFileHeader(); - header.SetToReference(); - memcpy(to, &header, sizeof(BinaryFileHeader)); - char *out = reinterpret_cast<char*>(to) + sizeof(BinaryFileHeader); - *reinterpret_cast<unsigned char*>(out) = static_cast<unsigned char>(from.size()); - uint64_t *counts = reinterpret_cast<uint64_t*>(out + 1); - for (std::size_t i = 0; i < from.size(); ++i) { - counts[i] = from[i]; - } - float *probing_ptr = reinterpret_cast<float*>(counts + from.size()); - *probing_ptr = probing_multiplier; - *reinterpret_cast<char*>(probing_ptr + 1) = search_tag; -} - -template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, Config config) : mapped_file_(util::OpenReadOrThrow(file)) { - const off_t file_size = util::SizeFile(mapped_file_.get()); - - std::vector<size_t> counts; - - if (IsBinaryFormat(mapped_file_.get(), file_size)) { - memory_.reset(util::MapForRead(file_size, config.prefault, mapped_file_.get()), file_size); - - unsigned char search_tag; - ReadBinaryHeader(memory_.begin(), file_size, counts, config.probing_multiplier, search_tag); - if (config.probing_multiplier < 1.0) UTIL_THROW(FormatLoadException, "Binary format claims to have a probing multiplier of " << config.probing_multiplier << " which is < 1.0."); - if (search_tag != Search::kBinaryTag) UTIL_THROW(FormatLoadException, "The binary file has a different search strategy than the one requested."); - size_t memory_size = Size(counts, config); - - char *start = reinterpret_cast<char*>(memory_.get()) + TotalHeaderSize(counts.size()); - if (memory_size != static_cast<size_t>(memory_.end() - start)) UTIL_THROW(FormatLoadException, "The mmap file " << file << " has size " << file_size << " but " << (memory_size + TotalHeaderSize(counts.size())) << " was expected based on the number of counts and configuration."); - - SetupMemory(start, counts, config); - vocab_.LoadedBinary(); - for (typename std::vector<Middle>::iterator i = middle_.begin(); i != middle_.end(); ++i) { - i->LoadedBinary(); - } - longest_.LoadedBinary(); - - } else { - if (config.probing_multiplier <= 1.0) UTIL_THROW(FormatLoadException, "probing multiplier must be > 1.0"); - - util::FilePiece f(file, mapped_file_.release(), config.messages); - ReadARPACounts(f, counts); - size_t memory_size = Size(counts, config); - char *start; - - if (config.write_mmap) { - // Write out an mmap file. - util::MapZeroedWrite(config.write_mmap, TotalHeaderSize(counts.size()) + memory_size, mapped_file_, memory_); - WriteBinaryHeader(memory_.get(), counts, config.probing_multiplier, Search::kBinaryTag); - start = reinterpret_cast<char*>(memory_.get()) + TotalHeaderSize(counts.size()); - } else { - memory_.reset(util::MapAnonymous(memory_size), memory_size); - start = reinterpret_cast<char*>(memory_.get()); - } - SetupMemory(start, counts, config); - try { - LoadFromARPA(f, counts, config); - } catch (FormatLoadException &e) { - e << " in file " << file; - throw; - } - } - - // g++ prints warnings unless these are fully initialized. - State begin_sentence = State(); - begin_sentence.valid_length_ = 1; - begin_sentence.history_[0] = vocab_.BeginSentence(); - begin_sentence.backoff_[0] = unigram_[begin_sentence.history_[0]].backoff; - State null_context = State(); - null_context.valid_length_ = 0; - P::Init(begin_sentence, null_context, vocab_, counts.size()); -} - -template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::LoadFromARPA(util::FilePiece &f, const std::vector<size_t> &counts, const Config &config) { - // Read the unigrams. - Read1Grams(f, counts[0], vocab_, unigram_); - bool saw_unk = vocab_.FinishedLoading(unigram_); - if (!saw_unk) { - switch(config.unknown_missing) { - case Config::THROW_UP: - { - SpecialWordMissingException e("<unk>"); - e << " and configuration was set to throw if unknown is missing"; - throw e; - } - case Config::COMPLAIN: - if (config.messages) *config.messages << "Language model is missing <unk>. Substituting probability " << config.unknown_missing_prob << "." << std::endl; - // There's no break;. This is by design. - case Config::SILENT: - // Default probabilities for unknown. - unigram_[0].backoff = 0.0; - unigram_[0].prob = config.unknown_missing_prob; - break; - } - } - - // Read the n-grams. - for (unsigned int n = 2; n < counts.size(); ++n) { - ReadNGrams(f, n, counts[n-1], vocab_, middle_[n-2]); - } - ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab_, longest_); - if (std::fabs(unigram_[0].backoff) > 0.0000001) UTIL_THROW(FormatLoadException, "Backoff for unknown word should be zero, but was given as " << unigram_[0].backoff); -} - -/* Ugly optimized function. - * in_state contains the previous ngram's length and backoff probabilites to - * be used here. out_state is populated with the found ngram length and - * backoffs that the next call will find useful. - * - * The search goes in increasing order of ngram length. - */ -template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore( - const State &in_state, - const WordIndex new_word, - State &out_state) const { - - FullScoreReturn ret; - // This is end pointer passed to SumBackoffs. - const ProbBackoff &unigram = unigram_[new_word]; - if (new_word == 0) { - ret.ngram_length = out_state.valid_length_ = 0; - // all of backoff. - ret.prob = std::accumulate( - in_state.backoff_, - in_state.backoff_ + in_state.valid_length_, - unigram.prob); - return ret; - } - float *backoff_out(out_state.backoff_); - *backoff_out = unigram.backoff; - ret.prob = unigram.prob; - out_state.history_[0] = new_word; - if (in_state.valid_length_ == 0) { - ret.ngram_length = out_state.valid_length_ = 1; - // No backoff because NGramLength() == 0 and unknown can't have backoff. - return ret; - } - ++backoff_out; - - // Ok now we now that the bigram contains known words. Start by looking it up. - - uint64_t lookup_hash = static_cast<uint64_t>(new_word); - const WordIndex *hist_iter = in_state.history_; - const WordIndex *const hist_end = hist_iter + in_state.valid_length_; - typename std::vector<Middle>::const_iterator mid_iter = middle_.begin(); - for (; ; ++mid_iter, ++hist_iter, ++backoff_out) { - if (hist_iter == hist_end) { - // Used history [in_state.history_, hist_end) and ran out. No backoff. - std::copy(in_state.history_, hist_end, out_state.history_ + 1); - ret.ngram_length = out_state.valid_length_ = in_state.valid_length_ + 1; - // ret.prob was already set. - return ret; - } - lookup_hash = CombineWordHash(lookup_hash, *hist_iter); - if (mid_iter == middle_.end()) break; - typename Middle::ConstIterator found; - if (!mid_iter->Find(lookup_hash, found)) { - // Didn't find an ngram using hist_iter. - // The history used in the found n-gram is [in_state.history_, hist_iter). - std::copy(in_state.history_, hist_iter, out_state.history_ + 1); - // Therefore, we found a (hist_iter - in_state.history_ + 1)-gram including the last word. - ret.ngram_length = out_state.valid_length_ = (hist_iter - in_state.history_) + 1; - ret.prob = std::accumulate( - in_state.backoff_ + (mid_iter - middle_.begin()), - in_state.backoff_ + in_state.valid_length_, - ret.prob); - return ret; - } - *backoff_out = found->GetValue().backoff; - ret.prob = found->GetValue().prob; - } - - typename Longest::ConstIterator found; - if (!longest_.Find(lookup_hash, found)) { - // It's an (P::Order()-1)-gram - std::copy(in_state.history_, in_state.history_ + P::Order() - 2, out_state.history_ + 1); - ret.ngram_length = out_state.valid_length_ = P::Order() - 1; - ret.prob += in_state.backoff_[P::Order() - 2]; - return ret; - } - // It's an P::Order()-gram - // out_state.valid_length_ is still P::Order() - 1 because the next lookup will only need that much. - std::copy(in_state.history_, in_state.history_ + P::Order() - 2, out_state.history_ + 1); - out_state.valid_length_ = P::Order() - 1; - ret.ngram_length = P::Order(); - ret.prob = found->GetValue().prob; - return ret; -} - -template class GenericModel<ProbingSearch, MapVocabulary<ProbingSearch> >; -template class GenericModel<SortedUniformSearch, SortedVocabulary>; -} // namespace detail -} // namespace ngram -} // namespace lm diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc index d1970260..74457a74 100644 --- a/klm/lm/ngram_query.cc +++ b/klm/lm/ngram_query.cc @@ -1,4 +1,4 @@ -#include "lm/ngram.hh" +#include "lm/model.hh" #include <cstdlib> #include <fstream> diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc new file mode 100644 index 00000000..8e9a770d --- /dev/null +++ b/klm/lm/read_arpa.cc @@ -0,0 +1,154 @@ +#include "lm/read_arpa.hh" + +#include <cstdlib> +#include <vector> + +#include <ctype.h> +#include <inttypes.h> + +namespace lm { + +namespace { + +bool IsEntirelyWhiteSpace(const StringPiece &line) { + for (size_t i = 0; i < static_cast<size_t>(line.size()); ++i) { + if (!isspace(line.data()[i])) return false; + } + return true; +} + +template <class F> void GenericReadARPACounts(F &in, std::vector<uint64_t> &number) { + number.clear(); + StringPiece line; + if (!IsEntirelyWhiteSpace(line = in.ReadLine())) { + if ((line.size() >= 2) && (line.data()[0] == 0x1f) && (static_cast<unsigned char>(line.data()[1]) == 0x8b)) { + UTIL_THROW(FormatLoadException, "Looks like a gzip file. If this is an ARPA file, run\nzcat " << in.FileName() << " |kenlm/build_binary /dev/stdin " << in.FileName() << ".binary\nIf this already in binary format, you need to decompress it because mmap doesn't work on top of gzip."); + } + UTIL_THROW(FormatLoadException, "First line was \"" << static_cast<int>(line.data()[1]) << "\" not blank"); + } + if ((line = in.ReadLine()) != "\\data\\") UTIL_THROW(FormatLoadException, "second line was \"" << line << "\" not \\data\\."); + while (!IsEntirelyWhiteSpace(line = in.ReadLine())) { + if (line.size() < 6 || strncmp(line.data(), "ngram ", 6)) UTIL_THROW(FormatLoadException, "count line \"" << line << "\"doesn't begin with \"ngram \""); + // So strtol doesn't go off the end of line. + std::string remaining(line.data() + 6, line.size() - 6); + char *end_ptr; + unsigned long int length = std::strtol(remaining.c_str(), &end_ptr, 10); + if ((end_ptr == remaining.c_str()) || (length - 1 != number.size())) UTIL_THROW(FormatLoadException, "ngram count lengths should be consecutive starting with 1: " << line); + if (*end_ptr != '=') UTIL_THROW(FormatLoadException, "Expected = immediately following the first number in the count line " << line); + ++end_ptr; + const char *start = end_ptr; + long int count = std::strtol(start, &end_ptr, 10); + if (count < 0) UTIL_THROW(FormatLoadException, "Negative n-gram count " << count); + if (start == end_ptr) UTIL_THROW(FormatLoadException, "Couldn't parse n-gram count from " << line); + number.push_back(count); + } +} + +template <class F> void GenericReadNGramHeader(F &in, unsigned int length) { + StringPiece line; + while (IsEntirelyWhiteSpace(line = in.ReadLine())) {} + std::stringstream expected; + expected << '\\' << length << "-grams:"; + if (line != expected.str()) UTIL_THROW(FormatLoadException, "Was expecting n-gram header " << expected.str() << " but got " << line << " instead. "); +} + +template <class F> void GenericReadEnd(F &in) { + StringPiece line; + do { + line = in.ReadLine(); + } while (IsEntirelyWhiteSpace(line)); + if (line != "\\end\\") UTIL_THROW(FormatLoadException, "Expected \\end\\ but the ARPA file has " << line); +} + +class FakeFilePiece { + public: + explicit FakeFilePiece(std::istream &in) : in_(in) { + in_.exceptions(std::ios::failbit | std::ios::badbit | std::ios::eofbit); + } + + StringPiece ReadLine() throw(util::EndOfFileException) { + getline(in_, buffer_); + return StringPiece(buffer_); + } + + float ReadFloat() { + float ret; + in_ >> ret; + return ret; + } + + const char *FileName() const { + // This only used for error messages and we don't know the file name. . . + return "$file"; + } + + private: + std::istream &in_; + std::string buffer_; +}; + +} // namespace + +void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) { + GenericReadARPACounts(in, number); +} +void ReadARPACounts(std::istream &in, std::vector<uint64_t> &number) { + FakeFilePiece fake(in); + GenericReadARPACounts(fake, number); +} +void ReadNGramHeader(util::FilePiece &in, unsigned int length) { + GenericReadNGramHeader(in, length); +} +void ReadNGramHeader(std::istream &in, unsigned int length) { + FakeFilePiece fake(in); + GenericReadNGramHeader(fake, length); +} + +void ReadBackoff(util::FilePiece &in, Prob &/*weights*/) { + switch (in.get()) { + case '\t': + { + float got = in.ReadFloat(); + if (got != 0.0) + UTIL_THROW(FormatLoadException, "Non-zero backoff " << got << " provided for an n-gram that should have no backoff."); + } + break; + case '\n': + break; + default: + UTIL_THROW(FormatLoadException, "Expected tab or newline after unigram"); + } +} + +void ReadBackoff(util::FilePiece &in, ProbBackoff &weights) { + switch (in.get()) { + case '\t': + weights.backoff = in.ReadFloat(); + if ((in.get() != '\n')) UTIL_THROW(FormatLoadException, "Expected newline after backoff"); + break; + case '\n': + weights.backoff = 0.0; + break; + default: + UTIL_THROW(FormatLoadException, "Expected tab or newline after unigram"); + } +} + +void ReadEnd(util::FilePiece &in) { + GenericReadEnd(in); + StringPiece line; + try { + while (true) { + line = in.ReadLine(); + if (!IsEntirelyWhiteSpace(line)) UTIL_THROW(FormatLoadException, "Trailing line " << line); + } + } catch (const util::EndOfFileException &e) { + return; + } +} +void ReadEnd(std::istream &in) { + FakeFilePiece fake(in); + GenericReadEnd(fake); +} + +} // namespace lm diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc new file mode 100644 index 00000000..9cb662a6 --- /dev/null +++ b/klm/lm/search_hashed.cc @@ -0,0 +1,66 @@ +#include "lm/search_hashed.hh" + +#include "lm/lm_exception.hh" +#include "lm/read_arpa.hh" +#include "lm/vocab.hh" + +#include "util/file_piece.hh" + +#include <string> + +namespace lm { +namespace ngram { + +namespace { + +/* All of the entropy is in low order bits and boost::hash does poorly with + * these. Odd numbers near 2^64 chosen by mashing on the keyboard. There is a + * stable point: 0. But 0 is <unk> which won't be queried here anyway. + */ +inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) { + uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast<uint64_t>(next) * 17894857484156487943ULL); + return ret; +} + +uint64_t ChainedWordHash(const WordIndex *word, const WordIndex *word_end) { + if (word == word_end) return 0; + uint64_t current = static_cast<uint64_t>(*word); + for (++word; word != word_end; ++word) { + current = CombineWordHash(current, *word); + } + return current; +} + +template <class Voc, class Store> void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, Store &store) { + ReadNGramHeader(f, n); + + // vocab ids of words in reverse order + WordIndex vocab_ids[n]; + typename Store::Packing::Value value; + for (size_t i = 0; i < count; ++i) { + ReadNGram(f, n, vocab, vocab_ids, value); + uint64_t key = ChainedWordHash(vocab_ids, vocab_ids + n); + store.Insert(Store::Packing::Make(key, value)); + } + + store.FinishedInserting(); +} + +} // namespace +namespace detail { + +template <class MiddleT, class LongestT> template <class Voc> void TemplateHashedSearch<MiddleT, LongestT>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &/*config*/, Voc &vocab) { + Read1Grams(f, counts[0], vocab, unigram.Raw()); + // Read the n-grams. + for (unsigned int n = 2; n < counts.size(); ++n) { + ReadNGrams(f, n, counts[n-1], vocab, middle[n-2]); + } + ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, longest); +} + +template void TemplateHashedSearch<ProbingHashedSearch::Middle, ProbingHashedSearch::Longest>::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &, ProbingVocabulary &vocab); +template void TemplateHashedSearch<SortedHashedSearch::Middle, SortedHashedSearch::Longest>::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &, SortedVocabulary &vocab); + +} // namespace detail +} // namespace ngram +} // namespace lm diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc new file mode 100644 index 00000000..182e27f5 --- /dev/null +++ b/klm/lm/search_trie.cc @@ -0,0 +1,402 @@ +#include "lm/search_trie.hh" + +#include "lm/lm_exception.hh" +#include "lm/read_arpa.hh" +#include "lm/trie.hh" +#include "lm/vocab.hh" +#include "lm/weights.hh" +#include "lm/word_index.hh" +#include "util/ersatz_progress.hh" +#include "util/file_piece.hh" +#include "util/scoped.hh" + +#include <algorithm> +#include <cstring> +#include <cstdio> +#include <deque> +#include <iostream> +#include <limits> +//#include <parallel/algorithm> +#include <vector> + +#include <sys/mman.h> +#include <sys/types.h> +#include <sys/stat.h> +#include <fcntl.h> +#include <stdlib.h> + +namespace lm { +namespace ngram { +namespace trie { +namespace { + +template <unsigned char Order> class FullEntry { + public: + typedef ProbBackoff Weights; + static const unsigned char kOrder = Order; + + // reverse order + WordIndex words[Order]; + Weights weights; + + bool operator<(const FullEntry<Order> &other) const { + for (const WordIndex *i = words, *j = other.words; i != words + Order; ++i, ++j) { + if (*i < *j) return true; + if (*i > *j) return false; + } + return false; + } +}; + +template <unsigned char Order> class ProbEntry { + public: + typedef Prob Weights; + static const unsigned char kOrder = Order; + + // reverse order + WordIndex words[Order]; + Weights weights; + + bool operator<(const ProbEntry<Order> &other) const { + for (const WordIndex *i = words, *j = other.words; i != words + Order; ++i, ++j) { + if (*i < *j) return true; + if (*i > *j) return false; + } + return false; + } +}; + +void WriteOrThrow(FILE *to, const void *data, size_t size) { + if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size); +} + +void ReadOrThrow(FILE *from, void *data, size_t size) { + if (1 != std::fread(data, size, 1, from)) UTIL_THROW(util::ErrnoException, "Short read; requested size " << size); +} + +void CopyOrThrow(FILE *from, FILE *to, size_t size) { + const size_t kBufSize = 512; + char buf[kBufSize]; + for (size_t i = 0; i < size; i += kBufSize) { + std::size_t amount = std::min(size - i, kBufSize); + ReadOrThrow(from, buf, amount); + WriteOrThrow(to, buf, amount); + } +} + +template <class Entry> std::string DiskFlush(const Entry *begin, const Entry *end, const std::string &file_prefix, std::size_t batch) { + std::stringstream assembled; + assembled << file_prefix << static_cast<unsigned int>(Entry::kOrder) << '_' << batch; + std::string ret(assembled.str()); + util::scoped_FILE out(fopen(ret.c_str(), "w")); + if (!out.get()) UTIL_THROW(util::ErrnoException, "Couldn't open " << assembled.str().c_str() << " for writing"); + for (const Entry *group_begin = begin; group_begin != end;) { + const Entry *group_end = group_begin; + for (++group_end; (group_end != end) && !memcmp(group_begin->words, group_end->words, sizeof(WordIndex) * (Entry::kOrder - 1)); ++group_end) {} + WriteOrThrow(out.get(), group_begin->words, sizeof(WordIndex) * (Entry::kOrder - 1)); + WordIndex group_size = group_end - group_begin; + WriteOrThrow(out.get(), &group_size, sizeof(group_size)); + for (const Entry *i = group_begin; i != group_end; ++i) { + WriteOrThrow(out.get(), &i->words[Entry::kOrder - 1], sizeof(WordIndex)); + WriteOrThrow(out.get(), &i->weights, sizeof(typename Entry::Weights)); + } + group_begin = group_end; + } + return ret; +} + +class SortedFileReader { + public: + SortedFileReader() {} + + void Init(const std::string &name, unsigned char order) { + file_.reset(fopen(name.c_str(), "r")); + if (!file_.get()) UTIL_THROW(util::ErrnoException, "Opening " << name << " for read"); + header_.resize(order - 1); + NextHeader(); + } + + // Preceding words. + const WordIndex *Header() const { + return &*header_.begin(); + } + const std::vector<WordIndex> &HeaderVector() const { return header_;} + + std::size_t HeaderBytes() const { return header_.size() * sizeof(WordIndex); } + + void NextHeader() { + if (1 != fread(&*header_.begin(), HeaderBytes(), 1, file_.get()) && !Ended()) { + UTIL_THROW(util::ErrnoException, "Short read of counts"); + } + } + + void ReadCount(WordIndex &to) { + ReadOrThrow(file_.get(), &to, sizeof(WordIndex)); + } + + void ReadWord(WordIndex &to) { + ReadOrThrow(file_.get(), &to, sizeof(WordIndex)); + } + + template <class Weights> void ReadWeights(Weights &to) { + ReadOrThrow(file_.get(), &to, sizeof(Weights)); + } + + bool Ended() { + return feof(file_.get()); + } + + FILE *File() { return file_.get(); } + + private: + util::scoped_FILE file_; + + std::vector<WordIndex> header_; +}; + +void CopyFullRecord(SortedFileReader &from, FILE *to, std::size_t weights_size) { + WriteOrThrow(to, from.Header(), from.HeaderBytes()); + WordIndex count; + from.ReadCount(count); + WriteOrThrow(to, &count, sizeof(WordIndex)); + + CopyOrThrow(from.File(), to, (weights_size + sizeof(WordIndex)) * count); +} + +void MergeSortedFiles(const char *first_name, const char *second_name, const char *out, std::size_t weights_size, unsigned char order) { + SortedFileReader first, second; + first.Init(first_name, order); + second.Init(second_name, order); + util::scoped_FILE out_file(fopen(out, "w")); + if (!out_file.get()) UTIL_THROW(util::ErrnoException, "Could not open " << out << " for write"); + while (!first.Ended() && !second.Ended()) { + if (first.HeaderVector() < second.HeaderVector()) { + CopyFullRecord(first, out_file.get(), weights_size); + first.NextHeader(); + continue; + } + if (first.HeaderVector() > second.HeaderVector()) { + CopyFullRecord(second, out_file.get(), weights_size); + second.NextHeader(); + continue; + } + // Merge at the entry level. + WriteOrThrow(out_file.get(), first.Header(), first.HeaderBytes()); + WordIndex first_count, second_count; + first.ReadCount(first_count); second.ReadCount(second_count); + WordIndex total_count = first_count + second_count; + WriteOrThrow(out_file.get(), &total_count, sizeof(WordIndex)); + + WordIndex first_word, second_word; + first.ReadWord(first_word); second.ReadWord(second_word); + WordIndex first_index = 0, second_index = 0; + while (true) { + if (first_word < second_word) { + WriteOrThrow(out_file.get(), &first_word, sizeof(WordIndex)); + CopyOrThrow(first.File(), out_file.get(), weights_size); + if (++first_index == first_count) break; + first.ReadWord(first_word); + } else { + WriteOrThrow(out_file.get(), &second_word, sizeof(WordIndex)); + CopyOrThrow(second.File(), out_file.get(), weights_size); + if (++second_index == second_count) break; + second.ReadWord(second_word); + } + } + if (first_index == first_count) { + WriteOrThrow(out_file.get(), &second_word, sizeof(WordIndex)); + CopyOrThrow(second.File(), out_file.get(), (second_count - second_index) * (weights_size + sizeof(WordIndex)) - sizeof(WordIndex)); + } else { + WriteOrThrow(out_file.get(), &first_word, sizeof(WordIndex)); + CopyOrThrow(first.File(), out_file.get(), (first_count - first_index) * (weights_size + sizeof(WordIndex)) - sizeof(WordIndex)); + } + first.NextHeader(); + second.NextHeader(); + } + + for (SortedFileReader &remaining = first.Ended() ? second : first; !remaining.Ended(); remaining.NextHeader()) { + CopyFullRecord(remaining, out_file.get(), weights_size); + } +} + +template <class Entry> void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix) { + ConvertToSorted<FullEntry<Entry::kOrder - 1> >(f, vocab, counts, mem, file_prefix); + + ReadNGramHeader(f, Entry::kOrder); + const size_t count = counts[Entry::kOrder - 1]; + const size_t batch_size = std::min(count, mem.size() / sizeof(Entry)); + Entry *const begin = reinterpret_cast<Entry*>(mem.get()); + std::deque<std::string> files; + for (std::size_t batch = 0, done = 0; done < count; ++batch) { + Entry *out = begin; + Entry *out_end = out + std::min(count - done, batch_size); + for (; out != out_end; ++out) { + ReadNGram(f, Entry::kOrder, vocab, out->words, out->weights); + } + //__gnu_parallel::sort(begin, out_end); + std::sort(begin, out_end); + + files.push_back(DiskFlush(begin, out_end, file_prefix, batch)); + done += out_end - begin; + } + + // All individual files created. Merge them. + + std::size_t merge_count = 0; + while (files.size() > 1) { + std::stringstream assembled; + assembled << file_prefix << static_cast<unsigned int>(Entry::kOrder) << "_merge_" << (merge_count++); + files.push_back(assembled.str()); + MergeSortedFiles(files[0].c_str(), files[1].c_str(), files.back().c_str(), sizeof(typename Entry::Weights), Entry::kOrder); + if (std::remove(files[0].c_str())) UTIL_THROW(util::ErrnoException, "Could not remove " << files[0]); + files.pop_front(); + if (std::remove(files[0].c_str())) UTIL_THROW(util::ErrnoException, "Could not remove " << files[0]); + files.pop_front(); + } + if (!files.empty()) { + std::stringstream assembled; + assembled << file_prefix << static_cast<unsigned int>(Entry::kOrder) << "_merged"; + std::string merged_name(assembled.str()); + if (std::rename(files[0].c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << files[0].c_str() << " to " << merged_name.c_str()); + } +} + +template <> void ConvertToSorted<FullEntry<1> >(util::FilePiece &/*f*/, const SortedVocabulary &/*vocab*/, const std::vector<uint64_t> &/*counts*/, util::scoped_memory &/*mem*/, const std::string &/*file_prefix*/) {} + +void ARPAToSortedFiles(util::FilePiece &f, const std::vector<uint64_t> &counts, std::size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { + { + std::string unigram_name = file_prefix + "unigrams"; + util::scoped_fd unigram_file; + util::scoped_mmap unigram_mmap; + unigram_mmap.reset(util::MapZeroedWrite(unigram_name.c_str(), counts[0] * sizeof(ProbBackoff), unigram_file), counts[0] * sizeof(ProbBackoff)); + Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get())); + } + + util::scoped_memory mem; + mem.reset(malloc(buffer), buffer, util::scoped_memory::ARRAY_ALLOCATED); + if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer); + ConvertToSorted<ProbEntry<5> >(f, vocab, counts, mem, file_prefix); + ReadEnd(f); +} + +struct RecursiveInsertParams { + WordIndex *words; + SortedFileReader *files; + unsigned char max_order; + // This is an array of size order - 2. + BitPackedMiddle *middle; + // This has exactly one entry. + BitPackedLongest *longest; +}; + +uint64_t RecursiveInsert(RecursiveInsertParams ¶ms, unsigned char order) { + SortedFileReader &file = params.files[order - 2]; + const uint64_t ret = (order == params.max_order) ? params.longest->InsertIndex() : params.middle[order - 2].InsertIndex(); + if (std::memcmp(params.words, file.Header(), sizeof(WordIndex) * (order - 1))) + return ret; + WordIndex count; + file.ReadCount(count); + WordIndex key; + if (order == params.max_order) { + Prob value; + for (WordIndex i = 0; i < count; ++i) { + file.ReadWord(key); + file.ReadWeights(value); + params.longest->Insert(key, value.prob); + } + file.NextHeader(); + return ret; + } + ProbBackoff value; + for (WordIndex i = 0; i < count; ++i) { + file.ReadWord(params.words[order - 1]); + file.ReadWeights(value); + params.middle[order - 2].Insert( + params.words[order - 1], + value.prob, + value.backoff, + RecursiveInsert(params, order + 1)); + } + file.NextHeader(); + return ret; +} + +void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &counts, std::ostream *messages, TrieSearch &out) { + UnigramValue *unigrams = out.unigram.Raw(); + // Load unigrams. Leave the next pointers uninitialized. + { + std::string name(file_prefix + "unigrams"); + util::scoped_FILE file(fopen(name.c_str(), "r")); + if (!file.get()) UTIL_THROW(util::ErrnoException, "Opening " << name << " failed"); + for (WordIndex i = 0; i < counts[0]; ++i) { + ReadOrThrow(file.get(), &unigrams[i].weights, sizeof(ProbBackoff)); + } + unlink(name.c_str()); + } + + // inputs[0] is bigrams. + SortedFileReader inputs[counts.size() - 1]; + for (unsigned char i = 2; i <= counts.size(); ++i) { + std::stringstream assembled; + assembled << file_prefix << static_cast<unsigned int>(i) << "_merged"; + inputs[i-2].Init(assembled.str(), i); + unlink(assembled.str().c_str()); + } + + // words[0] is unigrams. + WordIndex words[counts.size()]; + RecursiveInsertParams params; + params.words = words; + params.files = inputs; + params.max_order = static_cast<unsigned char>(counts.size()); + params.middle = &*out.middle.begin(); + params.longest = &out.longest; + { + util::ErsatzProgress progress(messages, "Building trie", counts[0]); + for (words[0] = 0; words[0] < counts[0]; ++words[0], ++progress) { + unigrams[words[0]].next = RecursiveInsert(params, 2); + } + } + + /* Set ending offsets so the last entry will be sized properly */ + if (!out.middle.empty()) { + unigrams[counts[0]].next = out.middle.front().InsertIndex(); + for (size_t i = 0; i < out.middle.size() - 1; ++i) { + out.middle[i].FinishedLoading(out.middle[i+1].InsertIndex()); + } + out.middle.back().FinishedLoading(out.longest.InsertIndex()); + } else { + unigrams[counts[0]].next = out.longest.InsertIndex(); + } +} + +} // namespace + +void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab) { + std::string temporary_directory; + if (config.temporary_directory_prefix) { + temporary_directory = config.temporary_directory_prefix; + } else if (config.write_mmap) { + temporary_directory = config.write_mmap; + } else { + temporary_directory = file; + } + // Null on end is kludge to ensure null termination. + temporary_directory += "-tmp-XXXXXX\0"; + if (!mkdtemp(&temporary_directory[0])) { + UTIL_THROW(util::ErrnoException, "Failed to make a temporary directory based on the name " << temporary_directory.c_str()); + } + // Chop off null kludge. + temporary_directory.resize(strlen(temporary_directory.c_str())); + // Add directory delimiter. Assumes a real operating system. + temporary_directory += '/'; + ARPAToSortedFiles(f, counts, config.building_memory, temporary_directory.c_str(), vocab); + BuildTrie(temporary_directory.c_str(), counts, config.messages, *this); + if (rmdir(temporary_directory.c_str())) { + std::cerr << "Failed to delete " << temporary_directory << std::endl; + } +} + +} // namespace trie +} // namespace ngram +} // namespace lm diff --git a/klm/lm/sri.cc b/klm/lm/sri.cc index 7bd23d76..b634d200 100644 --- a/klm/lm/sri.cc +++ b/klm/lm/sri.cc @@ -1,4 +1,4 @@ -#include "lm/exception.hh" +#include "lm/lm_exception.hh" #include "lm/sri.hh" #include <Ngram.h> @@ -31,8 +31,7 @@ void Vocabulary::FinishedLoading() { SetSpecial( sri_->ssIndex(), sri_->seIndex(), - sri_->unkIndex(), - sri_->highIndex() + 1); + sri_->unkIndex()); } namespace { diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc new file mode 100644 index 00000000..8ed7b2a2 --- /dev/null +++ b/klm/lm/trie.cc @@ -0,0 +1,167 @@ +#include "lm/trie.hh" + +#include "util/bit_packing.hh" +#include "util/exception.hh" +#include "util/proxy_iterator.hh" +#include "util/sorted_uniform.hh" + +#include <assert.h> + +namespace lm { +namespace ngram { +namespace trie { +namespace { + +// Assumes key is first. +class JustKeyProxy { + public: + JustKeyProxy() : inner_(), base_(), key_mask_(), total_bits_() {} + + operator uint64_t() const { return GetKey(); } + + uint64_t GetKey() const { + uint64_t bit_off = inner_ * static_cast<uint64_t>(total_bits_); + return util::ReadInt57(base_ + bit_off / 8, bit_off & 7, key_mask_); + } + + private: + friend class util::ProxyIterator<JustKeyProxy>; + friend bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index); + + JustKeyProxy(const void *base, uint64_t index, uint64_t key_mask, uint8_t total_bits) + : inner_(index), base_(static_cast<const uint8_t*>(base)), key_mask_(key_mask), total_bits_(total_bits) {} + + // This is a read-only iterator. + JustKeyProxy &operator=(const JustKeyProxy &other); + + typedef uint64_t value_type; + + typedef uint64_t InnerIterator; + uint64_t &Inner() { return inner_; } + const uint64_t &Inner() const { return inner_; } + + // The address in bits is base_ * 8 + inner_ * total_bits_. + uint64_t inner_; + const uint8_t *const base_; + const uint64_t key_mask_; + const uint8_t total_bits_; +}; + +bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index) { + util::ProxyIterator<JustKeyProxy> begin_it(JustKeyProxy(base, begin_index, key_mask, total_bits)); + util::ProxyIterator<JustKeyProxy> end_it(JustKeyProxy(base, end_index, key_mask, total_bits)); + util::ProxyIterator<JustKeyProxy> out; + if (!util::SortedUniformFind(begin_it, end_it, key, out)) return false; + at_index = out.Inner(); + return true; +} +} // namespace + +std::size_t BitPacked::BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits) { + uint8_t total_bits = util::RequiredBits(max_vocab) + 31 + remaining_bits; + // Extra entry for next pointer at the end. + // +7 then / 8 to round up bits and convert to bytes + // +sizeof(uint64_t) so that ReadInt57 etc don't go segfault. + // Note that this waste is O(order), not O(number of ngrams). + return ((1 + entries) * total_bits + 7) / 8 + sizeof(uint64_t); +} + +void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits) { + util::BitPackingSanity(); + word_bits_ = util::RequiredBits(max_vocab); + word_mask_ = (1ULL << word_bits_) - 1ULL; + if (word_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, word indices more than " << (1ULL << 57) << " are not implemented. Edit util/bit_packing.hh and fix the bit packing functions."); + prob_bits_ = 31; + total_bits_ = word_bits_ + prob_bits_ + remaining_bits; + + base_ = static_cast<uint8_t*>(base); + insert_index_ = 0; +} + +std::size_t BitPackedMiddle::Size(uint64_t entries, uint64_t max_vocab, uint64_t max_ptr) { + return BaseSize(entries, max_vocab, 32 + util::RequiredBits(max_ptr)); +} + +void BitPackedMiddle::Init(void *base, uint64_t max_vocab, uint64_t max_next) { + backoff_bits_ = 32; + next_bits_ = util::RequiredBits(max_next); + if (next_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions."); + next_mask_ = (1ULL << next_bits_) - 1; + + BaseInit(base, max_vocab, backoff_bits_ + next_bits_); +} + +void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff, uint64_t next) { + assert(word <= word_mask_); + assert(next <= next_mask_); + uint64_t at_pointer = insert_index_ * total_bits_; + + util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word); + at_pointer += word_bits_; + util::WriteNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7, prob); + at_pointer += prob_bits_; + util::WriteFloat32(base_ + (at_pointer >> 3), at_pointer & 7, backoff); + at_pointer += backoff_bits_; + util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, next); + + ++insert_index_; +} + +bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const { + uint64_t at_pointer; + if (!FindBitPacked(base_, word_mask_, total_bits_, range.begin, range.end, word, at_pointer)) return false; + at_pointer *= total_bits_; + at_pointer += word_bits_; + prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7); + at_pointer += prob_bits_; + backoff = util::ReadFloat32(base_ + (at_pointer >> 3), at_pointer & 7); + at_pointer += backoff_bits_; + range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_); + // Read the next entry's pointer. + at_pointer += total_bits_; + range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_); + return true; +} + +bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const { + uint64_t at_pointer; + if (!FindBitPacked(base_, word_mask_, total_bits_, range.begin, range.end, word, at_pointer)) return false; + at_pointer *= total_bits_; + at_pointer += word_bits_; + at_pointer += prob_bits_; + backoff = util::ReadFloat32(base_ + (at_pointer >> 3), at_pointer & 7); + at_pointer += backoff_bits_; + range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_); + // Read the next entry's pointer. + at_pointer += total_bits_; + range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_); + return true; +} + +void BitPackedMiddle::FinishedLoading(uint64_t next_end) { + assert(next_end <= next_mask_); + uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - next_bits_; + util::WriteInt57(base_ + (last_next_write >> 3), last_next_write & 7, next_end); +} + + +void BitPackedLongest::Insert(WordIndex index, float prob) { + assert(index <= word_mask_); + uint64_t at_pointer = insert_index_ * total_bits_; + util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, index); + at_pointer += word_bits_; + util::WriteNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7, prob); + ++insert_index_; +} + +bool BitPackedLongest::Find(WordIndex word, float &prob, const NodeRange &node) const { + uint64_t at_pointer; + if (!FindBitPacked(base_, word_mask_, total_bits_, node.begin, node.end, word, at_pointer)) return false; + at_pointer = at_pointer * total_bits_ + word_bits_; + prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7); + return true; +} + +} // namespace trie +} // namespace ngram +} // namespace lm diff --git a/klm/lm/virtual_interface.cc b/klm/lm/virtual_interface.cc index 9c7151f9..c5a64972 100644 --- a/klm/lm/virtual_interface.cc +++ b/klm/lm/virtual_interface.cc @@ -1,17 +1,16 @@ #include "lm/virtual_interface.hh" -#include "lm/exception.hh" +#include "lm/lm_exception.hh" namespace lm { namespace base { Vocabulary::~Vocabulary() {} -void Vocabulary::SetSpecial(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found, WordIndex available) { +void Vocabulary::SetSpecial(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found) { begin_sentence_ = begin_sentence; end_sentence_ = end_sentence; not_found_ = not_found; - available_ = available; if (begin_sentence_ == not_found_) throw SpecialWordMissingException("<s>"); if (end_sentence_ == not_found_) throw SpecialWordMissingException("</s>"); } diff --git a/klm/lm/virtual_interface.hh b/klm/lm/virtual_interface.hh index 621a129e..f15f8789 100644 --- a/klm/lm/virtual_interface.hh +++ b/klm/lm/virtual_interface.hh @@ -37,8 +37,6 @@ class Vocabulary { WordIndex BeginSentence() const { return begin_sentence_; } WordIndex EndSentence() const { return end_sentence_; } WordIndex NotFound() const { return not_found_; } - // FullScoreReturn start index of unused word assignments. - WordIndex Available() const { return available_; } /* Most implementations allow StringPiece lookups and need only override * Index(StringPiece). SRI requires null termination and overrides all @@ -56,13 +54,13 @@ class Vocabulary { // Call SetSpecial afterward. Vocabulary() {} - Vocabulary(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found, WordIndex available) { - SetSpecial(begin_sentence, end_sentence, not_found, available); + Vocabulary(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found) { + SetSpecial(begin_sentence, end_sentence, not_found); } - void SetSpecial(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found, WordIndex available); + void SetSpecial(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found); - WordIndex begin_sentence_, end_sentence_, not_found_, available_; + WordIndex begin_sentence_, end_sentence_, not_found_; private: // Disable copy constructors. They're private and undefined. @@ -97,7 +95,7 @@ class Vocabulary { * missing these methods, see facade.hh. * * This is the fastest way to use a model and presents a normal State class to - * be included in hypothesis state structure. + * be included in a hypothesis state structure. * * * OPTION 2: Use the virtual interface below. diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc new file mode 100644 index 00000000..c30428b2 --- /dev/null +++ b/klm/lm/vocab.cc @@ -0,0 +1,187 @@ +#include "lm/vocab.hh" + +#include "lm/enumerate_vocab.hh" +#include "lm/lm_exception.hh" +#include "lm/config.hh" +#include "lm/weights.hh" +#include "util/exception.hh" +#include "util/joint_sort.hh" +#include "util/murmur_hash.hh" +#include "util/probing_hash_table.hh" + +#include <string> + +namespace lm { +namespace ngram { + +namespace detail { +uint64_t HashForVocab(const char *str, std::size_t len) { + // This proved faster than Boost's hash in speed trials: total load time Murmur 67090000, Boost 72210000 + // Chose to use 64A instead of native so binary format will be portable across 64 and 32 bit. + return util::MurmurHash64A(str, len, 0); +} +} // namespace detail + +namespace { +// Normally static initialization is a bad idea but MurmurHash is pure arithmetic, so this is ok. +const uint64_t kUnknownHash = detail::HashForVocab("<unk>", 5); +// Sadly some LMs have <UNK>. +const uint64_t kUnknownCapHash = detail::HashForVocab("<UNK>", 5); + +void ReadWords(int fd, EnumerateVocab *enumerate) { + if (!enumerate) return; + const std::size_t kInitialRead = 16384; + std::string buf; + buf.reserve(kInitialRead + 100); + buf.resize(kInitialRead); + WordIndex index = 0; + while (true) { + ssize_t got = read(fd, &buf[0], kInitialRead); + if (got == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words"); + if (got == 0) return; + buf.resize(got); + while (buf[buf.size() - 1]) { + char next_char; + ssize_t ret = read(fd, &next_char, 1); + if (ret == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words"); + if (ret == 0) UTIL_THROW(FormatLoadException, "Missing null terminator on a vocab word."); + buf.push_back(next_char); + } + // Ok now we have null terminated strings. + for (const char *i = buf.data(); i != buf.data() + buf.size();) { + std::size_t length = strlen(i); + enumerate->Add(index++, StringPiece(i, length)); + i += length + 1 /* null byte */; + } + } +} + +void WriteOrThrow(int fd, const void *data_void, std::size_t size) { + const uint8_t *data = static_cast<const uint8_t*>(data_void); + while (size) { + ssize_t ret = write(fd, data, size); + if (ret < 1) UTIL_THROW(util::ErrnoException, "Write failed"); + data += ret; + size -= ret; + } +} + +} // namespace + +WriteWordsWrapper::WriteWordsWrapper(EnumerateVocab *inner, int fd) : inner_(inner), fd_(fd) {} +WriteWordsWrapper::~WriteWordsWrapper() {} + +void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) { + if (inner_) inner_->Add(index, str); + WriteOrThrow(fd_, str.data(), str.size()); + char null_byte = 0; + // Inefficient because it's unbuffered. Sue me. + WriteOrThrow(fd_, &null_byte, 1); +} + +SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {} + +std::size_t SortedVocabulary::Size(std::size_t entries, const Config &/*config*/) { + // Lead with the number of entries. + return sizeof(uint64_t) + sizeof(Entry) * entries; +} + +void SortedVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config) { + assert(allocated >= Size(entries, config)); + // Leave space for number of entries. + begin_ = reinterpret_cast<Entry*>(reinterpret_cast<uint64_t*>(start) + 1); + end_ = begin_; + saw_unk_ = false; +} + +void SortedVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries) { + enumerate_ = to; + if (enumerate_) { + enumerate_->Add(0, "<unk>"); + strings_to_enumerate_.resize(max_entries); + } +} + +WordIndex SortedVocabulary::Insert(const StringPiece &str) { + uint64_t hashed = detail::HashForVocab(str); + if (hashed == kUnknownHash || hashed == kUnknownCapHash) { + saw_unk_ = true; + return 0; + } + end_->key = hashed; + if (enumerate_) { + strings_to_enumerate_[end_ - begin_].assign(str.data(), str.size()); + } + ++end_; + // This is 1 + the offset where it was inserted to make room for unk. + return end_ - begin_; +} + +void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { + if (enumerate_) { + util::PairedIterator<ProbBackoff*, std::string*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin()); + util::JointSort(begin_, end_, values); + for (WordIndex i = 0; i < static_cast<WordIndex>(end_ - begin_); ++i) { + // <unk> strikes again: +1 here. + enumerate_->Add(i + 1, strings_to_enumerate_[i]); + } + strings_to_enumerate_.clear(); + } else { + util::JointSort(begin_, end_, reorder_vocab + 1); + } + SetSpecial(Index("<s>"), Index("</s>"), 0); + // Save size. + *(reinterpret_cast<uint64_t*>(begin_) - 1) = end_ - begin_; +} + +void SortedVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { + end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1); + ReadWords(fd, to); + SetSpecial(Index("<s>"), Index("</s>"), 0); +} + +ProbingVocabulary::ProbingVocabulary() : enumerate_(NULL) {} + +std::size_t ProbingVocabulary::Size(std::size_t entries, const Config &config) { + return Lookup::Size(entries, config.probing_multiplier); +} + +void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t /*entries*/, const Config &/*config*/) { + lookup_ = Lookup(start, allocated); + available_ = 1; + saw_unk_ = false; +} + +void ProbingVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t /*max_entries*/) { + enumerate_ = to; + if (enumerate_) { + enumerate_->Add(0, "<unk>"); + } +} + +WordIndex ProbingVocabulary::Insert(const StringPiece &str) { + uint64_t hashed = detail::HashForVocab(str); + // Prevent unknown from going into the table. + if (hashed == kUnknownHash || hashed == kUnknownCapHash) { + saw_unk_ = true; + return 0; + } else { + if (enumerate_) enumerate_->Add(available_, str); + lookup_.Insert(Lookup::Packing::Make(hashed, available_)); + return available_++; + } +} + +void ProbingVocabulary::FinishedLoading(ProbBackoff * /*reorder_vocab*/) { + lookup_.FinishedInserting(); + SetSpecial(Index("<s>"), Index("</s>"), 0); +} + +void ProbingVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { + lookup_.LoadedBinary(); + ReadWords(fd, to); + SetSpecial(Index("<s>"), Index("</s>"), 0); +} + +} // namespace ngram +} // namespace lm diff --git a/klm/util/Makefile.am b/klm/util/Makefile.am index be84736c..9e38e0f1 100644 --- a/klm/util/Makefile.am +++ b/klm/util/Makefile.am @@ -20,6 +20,7 @@ noinst_LIBRARIES = libklm_util.a libklm_util_a_SOURCES = \ ersatz_progress.cc \ + bit_packing.cc \ exception.cc \ file_piece.cc \ mmap.cc \ diff --git a/klm/util/bit_packing.cc b/klm/util/bit_packing.cc new file mode 100644 index 00000000..dd14ffe1 --- /dev/null +++ b/klm/util/bit_packing.cc @@ -0,0 +1,27 @@ +#include "util/bit_packing.hh" +#include "util/exception.hh" + +namespace util { + +namespace { +template <bool> struct StaticCheck {}; +template <> struct StaticCheck<true> { typedef bool StaticAssertionPassed; }; + +typedef StaticCheck<sizeof(float) == 4>::StaticAssertionPassed FloatSize; + +} // namespace + +uint8_t RequiredBits(uint64_t max_value) { + if (!max_value) return 0; + uint8_t ret = 1; + while (max_value >>= 1) ++ret; + return ret; +} + +void BitPackingSanity() { + const detail::FloatEnc neg1 = { -1.0 }, pos1 = { 1.0 }; + if ((neg1.i ^ pos1.i) != 0x80000000) UTIL_THROW(Exception, "Sign bit is not 0x80000000"); + // TODO: more checks. +} + +} // namespace util diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh new file mode 100644 index 00000000..422ed873 --- /dev/null +++ b/klm/util/bit_packing.hh @@ -0,0 +1,88 @@ +#ifndef UTIL_BIT_PACKING__ +#define UTIL_BIT_PACKING__ + +/* Bit-level packing routines */ + +#include <assert.h> +#ifdef __APPLE__ +#include <architecture/byte_order.h> +#else +#include <endian.h> +#endif + +#include <inttypes.h> + +#if __BYTE_ORDER != __LITTLE_ENDIAN +#error The bit aligned storage functions assume little endian architecture +#endif + +namespace util { + +/* WARNING WARNING WARNING: + * The write functions assume that memory is zero initially. This makes them + * faster and is the appropriate case for mmapped language model construction. + * These routines assume that unaligned access to uint64_t is fast and that + * storage is little endian. This is the case on x86_64. It may not be the + * case on 32-bit x86 but my target audience is large language models for which + * 64-bit is necessary. + */ + +/* Pack integers up to 57 bits using their least significant digits. + * The length is specified using mask: + * Assumes mask == (1 << length) - 1 where length <= 57. + */ +inline uint64_t ReadInt57(const void *base, uint8_t bit, uint64_t mask) { + return (*reinterpret_cast<const uint64_t*>(base) >> bit) & mask; +} +/* Assumes value <= mask and mask == (1 << length) - 1 where length <= 57. + * Assumes the memory is zero initially. + */ +inline void WriteInt57(void *base, uint8_t bit, uint64_t value) { + *reinterpret_cast<uint64_t*>(base) |= (value << bit); +} + +namespace detail { typedef union { float f; uint32_t i; } FloatEnc; } +inline float ReadFloat32(const void *base, uint8_t bit) { + detail::FloatEnc encoded; + encoded.i = *reinterpret_cast<const uint64_t*>(base) >> bit; + return encoded.f; +} +inline void WriteFloat32(void *base, uint8_t bit, float value) { + detail::FloatEnc encoded; + encoded.f = value; + WriteInt57(base, bit, encoded.i); +} + +inline float ReadNonPositiveFloat31(const void *base, uint8_t bit) { + detail::FloatEnc encoded; + encoded.i = *reinterpret_cast<const uint64_t*>(base) >> bit; + // Sign bit set means negative. + encoded.i |= 0x80000000; + return encoded.f; +} +inline void WriteNonPositiveFloat31(void *base, uint8_t bit, float value) { + assert(value <= 0.0); + detail::FloatEnc encoded; + encoded.f = value; + encoded.i &= ~0x80000000; + WriteInt57(base, bit, encoded.i); +} + +void BitPackingSanity(); + +// Return bits required to store integers upto max_value. Not the most +// efficient implementation, but this is only called a few times to size tries. +uint8_t RequiredBits(uint64_t max_value); + +struct BitsMask { + void FromMax(uint64_t max_value) { + bits = RequiredBits(max_value); + mask = (1 << bits) - 1; + } + uint8_t bits; + uint64_t mask; +}; + +} // namespace util + +#endif // UTIL_BIT_PACKING__ diff --git a/klm/util/ersatz_progress.cc b/klm/util/ersatz_progress.cc index 09e3a106..55c182bd 100644 --- a/klm/util/ersatz_progress.cc +++ b/klm/util/ersatz_progress.cc @@ -13,10 +13,7 @@ ErsatzProgress::ErsatzProgress() : current_(0), next_(std::numeric_limits<std::s ErsatzProgress::~ErsatzProgress() { if (!out_) return; - for (; stones_written_ < kWidth; ++stones_written_) { - (*out_) << '*'; - } - *out_ << '\n'; + Finished(); } ErsatzProgress::ErsatzProgress(std::ostream *to, const std::string &message, std::size_t complete) @@ -36,8 +33,8 @@ void ErsatzProgress::Milestone() { for (; stones_written_ < stone; ++stones_written_) { (*out_) << '*'; } - - if (current_ >= complete_) { + if (stone == kWidth) { + (*out_) << std::endl; next_ = std::numeric_limits<std::size_t>::max(); } else { next_ = std::max(next_, (stone * complete_) / kWidth); diff --git a/klm/util/ersatz_progress.hh b/klm/util/ersatz_progress.hh index ea6c3bb9..92c345fe 100644 --- a/klm/util/ersatz_progress.hh +++ b/klm/util/ersatz_progress.hh @@ -19,7 +19,7 @@ class ErsatzProgress { ~ErsatzProgress(); ErsatzProgress &operator++() { - if (++current_ == next_) Milestone(); + if (++current_ >= next_) Milestone(); return *this; } @@ -33,6 +33,10 @@ class ErsatzProgress { Milestone(); } + void Finished() { + Set(complete_); + } + private: void Milestone(); diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc index 2b439499..e7bd8659 100644 --- a/klm/util/file_piece.cc +++ b/klm/util/file_piece.cc @@ -2,19 +2,23 @@ #include "util/exception.hh" -#include <iostream> #include <string> #include <limits> #include <assert.h> -#include <cstdlib> #include <ctype.h> +#include <err.h> #include <fcntl.h> +#include <stdlib.h> #include <sys/mman.h> #include <sys/types.h> #include <sys/stat.h> #include <unistd.h> +#ifdef HAVE_ZLIB +#include <zlib.h> +#endif + namespace util { EndOfFileException::EndOfFileException() throw() { @@ -26,6 +30,13 @@ ParseNumberException::ParseNumberException(StringPiece value) throw() { *this << "Could not parse \"" << value << "\" into a float"; } +GZException::GZException(void *file) { +#ifdef HAVE_ZLIB + int num; + *this << gzerror(file, &num) << " from zlib"; +#endif // HAVE_ZLIB +} + int OpenReadOrThrow(const char *name) { int ret = open(name, O_RDONLY); if (ret == -1) UTIL_THROW(ErrnoException, "in open (" << name << ") for reading"); @@ -38,42 +49,73 @@ off_t SizeFile(int fd) { return sb.st_size; } -FilePiece::FilePiece(const char *name, std::ostream *show_progress, off_t min_buffer) : +FilePiece::FilePiece(const char *name, std::ostream *show_progress, off_t min_buffer) throw (GZException) : file_(OpenReadOrThrow(name)), total_size_(SizeFile(file_.get())), page_(sysconf(_SC_PAGE_SIZE)), progress_(total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name, total_size_) { Initialize(name, show_progress, min_buffer); } -FilePiece::FilePiece(const char *name, int fd, std::ostream *show_progress, off_t min_buffer) : +FilePiece::FilePiece(int fd, const char *name, std::ostream *show_progress, off_t min_buffer) throw (GZException) : file_(fd), total_size_(SizeFile(file_.get())), page_(sysconf(_SC_PAGE_SIZE)), progress_(total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name, total_size_) { Initialize(name, show_progress, min_buffer); } -void FilePiece::Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) { - if (total_size_ == kBadSize) { - fallback_to_read_ = true; - if (show_progress) - *show_progress << "File " << name << " isn't normal. Using slower read() instead of mmap(). No progress bar." << std::endl; - } else { - fallback_to_read_ = false; +FilePiece::~FilePiece() { +#ifdef HAVE_ZLIB + if (gz_file_) { + // zlib took ownership + file_.release(); + int ret; + if (Z_OK != (ret = gzclose(gz_file_))) { + errx(1, "could not close file %s using zlib", file_name_.c_str()); + } } +#endif +} + +void FilePiece::Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) throw (GZException) { +#ifdef HAVE_ZLIB + gz_file_ = NULL; +#endif + file_name_ = name; + default_map_size_ = page_ * std::max<off_t>((min_buffer / page_ + 1), 2); position_ = NULL; position_end_ = NULL; mapped_offset_ = 0; at_end_ = false; + + if (total_size_ == kBadSize) { + // So the assertion passes. + fallback_to_read_ = false; + if (show_progress) + *show_progress << "File " << name << " isn't normal. Using slower read() instead of mmap(). No progress bar." << std::endl; + TransitionToRead(); + } else { + fallback_to_read_ = false; + } Shift(); + // gzip detect. + if ((position_end_ - position_) > 2 && *position_ == 0x1f && static_cast<unsigned char>(*(position_ + 1)) == 0x8b) { +#ifndef HAVE_ZLIB + UTIL_THROW(GZException, "Looks like a gzip file but support was not compiled in."); +#endif + if (!fallback_to_read_) { + at_end_ = false; + TransitionToRead(); + } + } } -float FilePiece::ReadFloat() throw(EndOfFileException, ParseNumberException) { +float FilePiece::ReadFloat() throw(GZException, EndOfFileException, ParseNumberException) { SkipSpaces(); while (last_space_ < position_) { if (at_end_) { // Hallucinate a null off the end of the file. std::string buffer(position_, position_end_); char *end; - float ret = std::strtof(buffer.c_str(), &end); + float ret = strtof(buffer.c_str(), &end); if (buffer.c_str() == end) throw ParseNumberException(buffer); position_ += end - buffer.c_str(); return ret; @@ -81,20 +123,20 @@ float FilePiece::ReadFloat() throw(EndOfFileException, ParseNumberException) { Shift(); } char *end; - float ret = std::strtof(position_, &end); + float ret = strtof(position_, &end); if (end == position_) throw ParseNumberException(ReadDelimited()); position_ = end; return ret; } -void FilePiece::SkipSpaces() throw (EndOfFileException) { +void FilePiece::SkipSpaces() throw (GZException, EndOfFileException) { for (; ; ++position_) { if (position_ == position_end_) Shift(); if (!isspace(*position_)) return; } } -const char *FilePiece::FindDelimiterOrEOF() throw (EndOfFileException) { +const char *FilePiece::FindDelimiterOrEOF() throw (GZException, EndOfFileException) { for (const char *i = position_; i <= last_space_; ++i) { if (isspace(*i)) return i; } @@ -108,7 +150,7 @@ const char *FilePiece::FindDelimiterOrEOF() throw (EndOfFileException) { return position_end_; } -StringPiece FilePiece::ReadLine(char delim) throw (EndOfFileException) { +StringPiece FilePiece::ReadLine(char delim) throw (GZException, EndOfFileException) { const char *start = position_; do { for (const char *i = start; i < position_end_; ++i) { @@ -124,17 +166,19 @@ StringPiece FilePiece::ReadLine(char delim) throw (EndOfFileException) { } while (!at_end_); StringPiece ret(position_, position_end_ - position_); position_ = position_end_; - return position_; + return ret; } -void FilePiece::Shift() throw(EndOfFileException) { - if (at_end_) throw EndOfFileException(); +void FilePiece::Shift() throw(GZException, EndOfFileException) { + if (at_end_) { + progress_.Finished(); + throw EndOfFileException(); + } off_t desired_begin = position_ - data_.begin() + mapped_offset_; - progress_.Set(desired_begin); if (!fallback_to_read_) MMapShift(desired_begin); // Notice an mmap failure might set the fallback. - if (fallback_to_read_) ReadShift(desired_begin); + if (fallback_to_read_) ReadShift(); for (last_space_ = position_end_ - 1; last_space_ >= position_; --last_space_) { if (isspace(*last_space_)) break; @@ -163,28 +207,41 @@ void FilePiece::MMapShift(off_t desired_begin) throw() { data_.reset(); data_.reset(mmap(NULL, mapped_size, PROT_READ, MAP_PRIVATE, *file_, mapped_offset), mapped_size, scoped_memory::MMAP_ALLOCATED); if (data_.get() == MAP_FAILED) { - fallback_to_read_ = true; if (desired_begin) { if (((off_t)-1) == lseek(*file_, desired_begin, SEEK_SET)) UTIL_THROW(ErrnoException, "mmap failed even though it worked before. lseek failed too, so using read isn't an option either."); } + // The mmap was scheduled to end the file, but now we're going to read it. + at_end_ = false; + TransitionToRead(); return; } mapped_offset_ = mapped_offset; position_ = data_.begin() + ignore; position_end_ = data_.begin() + mapped_size; + + progress_.Set(desired_begin); +} + +void FilePiece::TransitionToRead() throw (GZException) { + assert(!fallback_to_read_); + fallback_to_read_ = true; + data_.reset(); + data_.reset(malloc(default_map_size_), default_map_size_, scoped_memory::MALLOC_ALLOCATED); + if (!data_.get()) UTIL_THROW(ErrnoException, "malloc failed for " << default_map_size_); + position_ = data_.begin(); + position_end_ = position_; + +#ifdef HAVE_ZLIB + assert(!gz_file_); + gz_file_ = gzdopen(file_.get(), "r"); + if (!gz_file_) { + UTIL_THROW(GZException, "zlib failed to open " << file_name_); + } +#endif } -void FilePiece::ReadShift(off_t desired_begin) throw() { +void FilePiece::ReadShift() throw(GZException, EndOfFileException) { assert(fallback_to_read_); - if (data_.source() != scoped_memory::MALLOC_ALLOCATED) { - // First call. - data_.reset(); - data_.reset(malloc(default_map_size_), default_map_size_, scoped_memory::MALLOC_ALLOCATED); - if (!data_.get()) UTIL_THROW(ErrnoException, "malloc failed for " << default_map_size_); - position_ = data_.begin(); - position_end_ = position_; - } - // Bytes [data_.begin(), position_) have been consumed. // Bytes [position_, position_end_) have been read into the buffer. @@ -215,9 +272,23 @@ void FilePiece::ReadShift(off_t desired_begin) throw() { } } - ssize_t read_return = read(file_.get(), static_cast<char*>(data_.get()) + already_read, default_map_size_ - already_read); + ssize_t read_return; +#ifdef HAVE_ZLIB + read_return = gzread(gz_file_, static_cast<char*>(data_.get()) + already_read, default_map_size_ - already_read); + if (read_return == -1) throw GZException(gz_file_); + if (total_size_ != kBadSize) { + // Just get the position, don't actually seek. Apparently this is how you do it. . . + off_t ret = lseek(file_.get(), 0, SEEK_CUR); + if (ret != -1) progress_.Set(ret); + } +#else + read_return = read(file_.get(), static_cast<char*>(data_.get()) + already_read, default_map_size_ - already_read); if (read_return == -1) UTIL_THROW(ErrnoException, "read failed"); - if (read_return == 0) at_end_ = true; + progress_.Set(mapped_offset_); +#endif + if (read_return == 0) { + at_end_ = true; + } position_end_ += read_return; } diff --git a/klm/util/file_piece.hh b/klm/util/file_piece.hh index 704f0ac6..11d4a751 100644 --- a/klm/util/file_piece.hh +++ b/klm/util/file_piece.hh @@ -11,6 +11,8 @@ #include <cstddef> +#define HAVE_ZLIB + namespace util { class EndOfFileException : public Exception { @@ -25,6 +27,13 @@ class ParseNumberException : public Exception { ~ParseNumberException() throw() {} }; +class GZException : public Exception { + public: + explicit GZException(void *file); + GZException() throw() {} + ~GZException() throw() {} +}; + int OpenReadOrThrow(const char *name); // Return value for SizeFile when it can't size properly. @@ -34,40 +43,42 @@ off_t SizeFile(int fd); class FilePiece { public: // 32 MB default. - explicit FilePiece(const char *file, std::ostream *show_progress = NULL, off_t min_buffer = 33554432); + explicit FilePiece(const char *file, std::ostream *show_progress = NULL, off_t min_buffer = 33554432) throw(GZException); // Takes ownership of fd. name is used for messages. - explicit FilePiece(const char *name, int fd, std::ostream *show_progress = NULL, off_t min_buffer = 33554432); + explicit FilePiece(int fd, const char *name, std::ostream *show_progress = NULL, off_t min_buffer = 33554432) throw(GZException); + + ~FilePiece(); - char get() throw(EndOfFileException) { - if (position_ == position_end_) Shift(); + char get() throw(GZException, EndOfFileException) { + if (position_ == position_end_) { + Shift(); + if (at_end_) throw EndOfFileException(); + } return *(position_++); } // Memory backing the returned StringPiece may vanish on the next call. // Leaves the delimiter, if any, to be returned by get(). - StringPiece ReadDelimited() throw(EndOfFileException) { + StringPiece ReadDelimited() throw(GZException, EndOfFileException) { SkipSpaces(); return Consume(FindDelimiterOrEOF()); } // Unlike ReadDelimited, this includes leading spaces and consumes the delimiter. // It is similar to getline in that way. - StringPiece ReadLine(char delim = '\n') throw(EndOfFileException); + StringPiece ReadLine(char delim = '\n') throw(GZException, EndOfFileException); - float ReadFloat() throw(EndOfFileException, ParseNumberException); + float ReadFloat() throw(GZException, EndOfFileException, ParseNumberException); - void SkipSpaces() throw (EndOfFileException); + void SkipSpaces() throw (GZException, EndOfFileException); off_t Offset() const { return position_ - data_.begin() + mapped_offset_; } - // Only for testing. - void ForceFallbackToRead() { - fallback_to_read_ = true; - } + const std::string &FileName() const { return file_name_; } private: - void Initialize(const char *name, std::ostream *show_progress, off_t min_buffer); + void Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) throw(GZException); StringPiece Consume(const char *to) { StringPiece ret(position_, to - position_); @@ -75,12 +86,14 @@ class FilePiece { return ret; } - const char *FindDelimiterOrEOF() throw(EndOfFileException); + const char *FindDelimiterOrEOF() throw(EndOfFileException, GZException); - void Shift() throw (EndOfFileException); + void Shift() throw (EndOfFileException, GZException); // Backends to Shift(). void MMapShift(off_t desired_begin) throw (); - void ReadShift(off_t desired_begin) throw (); + + void TransitionToRead() throw (GZException); + void ReadShift() throw (GZException, EndOfFileException); const char *position_, *last_space_, *position_end_; @@ -98,6 +111,12 @@ class FilePiece { bool fallback_to_read_; ErsatzProgress progress_; + + std::string file_name_; + +#ifdef HAVE_ZLIB + void *gz_file_; +#endif // HAVE_ZLIB }; } // namespace util diff --git a/klm/util/file_piece_test.cc b/klm/util/file_piece_test.cc index befb7866..23e79fe0 100644 --- a/klm/util/file_piece_test.cc +++ b/klm/util/file_piece_test.cc @@ -1,15 +1,19 @@ #include "util/file_piece.hh" +#include "util/scoped.hh" + #define BOOST_TEST_MODULE FilePieceTest #include <boost/test/unit_test.hpp> #include <fstream> #include <iostream> +#include <stdio.h> + namespace util { namespace { /* mmap implementation */ -BOOST_AUTO_TEST_CASE(MMapLine) { +BOOST_AUTO_TEST_CASE(MMapReadLine) { std::fstream ref("file_piece.cc", std::ios::in); FilePiece test("file_piece.cc", NULL, 1); std::string ref_line; @@ -20,13 +24,17 @@ BOOST_AUTO_TEST_CASE(MMapLine) { BOOST_CHECK_EQUAL(ref_line, test_line); } } + BOOST_CHECK_THROW(test.get(), EndOfFileException); } /* read() implementation */ -BOOST_AUTO_TEST_CASE(ReadLine) { +BOOST_AUTO_TEST_CASE(StreamReadLine) { std::fstream ref("file_piece.cc", std::ios::in); - FilePiece test("file_piece.cc", NULL, 1); - test.ForceFallbackToRead(); + + scoped_FILE catter(popen("cat file_piece.cc", "r")); + BOOST_REQUIRE(catter.get()); + + FilePiece test(dup(fileno(catter.get())), "file_piece.cc", NULL, 1); std::string ref_line; while (getline(ref, ref_line)) { StringPiece test_line(test.ReadLine()); @@ -35,7 +43,47 @@ BOOST_AUTO_TEST_CASE(ReadLine) { BOOST_CHECK_EQUAL(ref_line, test_line); } } + BOOST_CHECK_THROW(test.get(), EndOfFileException); } +#ifdef HAVE_ZLIB + +// gzip file +BOOST_AUTO_TEST_CASE(PlainZipReadLine) { + std::fstream ref("file_piece.cc", std::ios::in); + + BOOST_REQUIRE_EQUAL(0, system("gzip <file_piece.cc >file_piece.cc.gz")); + FilePiece test("file_piece.cc.gz", NULL, 1); + std::string ref_line; + while (getline(ref, ref_line)) { + StringPiece test_line(test.ReadLine()); + // I submitted a bug report to ICU: http://bugs.icu-project.org/trac/ticket/7924 + if (!test_line.empty() || !ref_line.empty()) { + BOOST_CHECK_EQUAL(ref_line, test_line); + } + } + BOOST_CHECK_THROW(test.get(), EndOfFileException); +} +// gzip stream +BOOST_AUTO_TEST_CASE(StreamZipReadLine) { + std::fstream ref("file_piece.cc", std::ios::in); + + scoped_FILE catter(popen("gzip <file_piece.cc", "r")); + BOOST_REQUIRE(catter.get()); + + FilePiece test(dup(fileno(catter.get())), "file_piece.cc", NULL, 1); + std::string ref_line; + while (getline(ref, ref_line)) { + StringPiece test_line(test.ReadLine()); + // I submitted a bug report to ICU: http://bugs.icu-project.org/trac/ticket/7924 + if (!test_line.empty() || !ref_line.empty()) { + BOOST_CHECK_EQUAL(ref_line, test_line); + } + } + BOOST_CHECK_THROW(test.get(), EndOfFileException); +} + +#endif + } // namespace } // namespace util diff --git a/klm/util/joint_sort.hh b/klm/util/joint_sort.hh index a2f1c01d..cf3d8432 100644 --- a/klm/util/joint_sort.hh +++ b/klm/util/joint_sort.hh @@ -119,6 +119,12 @@ template <class Proxy, class Less> class LessWrapper : public std::binary_functi } // namespace detail +template <class KeyIter, class ValueIter> class PairedIterator : public ProxyIterator<detail::JointProxy<KeyIter, ValueIter> > { + public: + PairedIterator(const KeyIter &key, const ValueIter &value) : + ProxyIterator<detail::JointProxy<KeyIter, ValueIter> >(detail::JointProxy<KeyIter, ValueIter>(key, value)) {} +}; + template <class KeyIter, class ValueIter, class Less> void JointSort(const KeyIter &key_begin, const KeyIter &key_end, const ValueIter &value_begin, const Less &less) { ProxyIterator<detail::JointProxy<KeyIter, ValueIter> > full_begin(detail::JointProxy<KeyIter, ValueIter>(key_begin, value_begin)); detail::LessWrapper<detail::JointProxy<KeyIter, ValueIter>, Less> less_wrap(less); diff --git a/klm/util/mmap.cc b/klm/util/mmap.cc index 648b5d0a..8685170f 100644 --- a/klm/util/mmap.cc +++ b/klm/util/mmap.cc @@ -53,10 +53,8 @@ void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, int if (prefault) { flags |= MAP_POPULATE; } - int protect = for_write ? (PROT_READ | PROT_WRITE) : PROT_READ; -#else - int protect = for_write ? (PROT_READ | PROT_WRITE) : PROT_READ; #endif + int protect = for_write ? (PROT_READ | PROT_WRITE) : PROT_READ; void *ret = mmap(NULL, size, protect, flags, fd, offset); if (ret == MAP_FAILED) { UTIL_THROW(ErrnoException, "mmap failed for size " << size << " at offset " << offset); @@ -64,8 +62,40 @@ void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, int return ret; } -void *MapForRead(std::size_t size, bool prefault, int fd, off_t offset) { - return MapOrThrow(size, false, MAP_FILE | MAP_PRIVATE, prefault, fd, offset); +namespace { +void ReadAll(int fd, void *to_void, std::size_t amount) { + uint8_t *to = static_cast<uint8_t*>(to_void); + while (amount) { + ssize_t ret = read(fd, to, amount); + if (ret == -1) UTIL_THROW(ErrnoException, "Reading " << amount << " from fd " << fd << " failed."); + if (ret == 0) UTIL_THROW(Exception, "Hit EOF in fd " << fd << " but there should be " << amount << " more bytes to read."); + amount -= ret; + to += ret; + } +} +} // namespace + +void MapRead(LoadMethod method, int fd, off_t offset, std::size_t size, scoped_memory &out) { + switch (method) { + case LAZY: + out.reset(MapOrThrow(size, false, MAP_FILE | MAP_SHARED, false, fd, offset), size, scoped_memory::MMAP_ALLOCATED); + break; + case POPULATE_OR_LAZY: +#ifdef MAP_POPULATE + case POPULATE_OR_READ: +#endif + out.reset(MapOrThrow(size, false, MAP_FILE | MAP_SHARED, true, fd, offset), size, scoped_memory::MMAP_ALLOCATED); + break; +#ifndef MAP_POPULATE + case POPULATE_OR_READ: +#endif + case READ: + out.reset(malloc(size), size, scoped_memory::MALLOC_ALLOCATED); + if (!out.get()) UTIL_THROW(util::ErrnoException, "Allocating " << size << " bytes with malloc"); + if (-1 == lseek(fd, offset, SEEK_SET)) UTIL_THROW(ErrnoException, "lseek to " << offset << " in fd " << fd << " failed."); + ReadAll(fd, out.get(), size); + break; + } } void *MapAnonymous(std::size_t size) { @@ -78,14 +108,14 @@ void *MapAnonymous(std::size_t size) { | MAP_PRIVATE, false, -1, 0); } -void MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file, scoped_mmap &mem) { +void *MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file) { file.reset(open(name, O_CREAT | O_RDWR | O_TRUNC, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH)); if (-1 == file.get()) UTIL_THROW(ErrnoException, "Failed to open " << name << " for writing"); if (-1 == ftruncate(file.get(), size)) UTIL_THROW(ErrnoException, "ftruncate on " << name << " to " << size << " failed"); try { - mem.reset(MapOrThrow(size, true, MAP_FILE | MAP_SHARED, false, file.get(), 0), size); + return MapOrThrow(size, true, MAP_FILE | MAP_SHARED, false, file.get(), 0); } catch (ErrnoException &e) { e << " in file " << name; throw; diff --git a/klm/util/mmap.hh b/klm/util/mmap.hh index c9068ec9..0a504d89 100644 --- a/klm/util/mmap.hh +++ b/klm/util/mmap.hh @@ -6,6 +6,7 @@ #include <cstddef> +#include <inttypes.h> #include <sys/types.h> namespace util { @@ -19,8 +20,8 @@ class scoped_mmap { void *get() const { return data_; } - const char *begin() const { return reinterpret_cast<char*>(data_); } - const char *end() const { return reinterpret_cast<char*>(data_) + size_; } + const uint8_t *begin() const { return reinterpret_cast<uint8_t*>(data_); } + const uint8_t *end() const { return reinterpret_cast<uint8_t*>(data_) + size_; } std::size_t size() const { return size_; } void reset(void *data, std::size_t size) { @@ -79,23 +80,27 @@ class scoped_memory { scoped_memory &operator=(const scoped_memory &); }; -struct scoped_mapped_file { - scoped_fd fd; - scoped_mmap mem; -}; +typedef enum { + // mmap with no prepopulate + LAZY, + // On linux, pass MAP_POPULATE to mmap. + POPULATE_OR_LAZY, + // Populate on Linux. malloc and read on non-Linux. + POPULATE_OR_READ, + // malloc and read. + READ +} LoadMethod; + // Wrapper around mmap to check it worked and hide some platform macros. void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, int fd, off_t offset = 0); -void *MapForRead(std::size_t size, bool prefault, int fd, off_t offset = 0); +void MapRead(LoadMethod method, int fd, off_t offset, std::size_t size, scoped_memory &out); void *MapAnonymous(std::size_t size); // Open file name with mmap of size bytes, all of which are initially zero. -void MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file, scoped_mmap &mem); -inline void MapZeroedWrite(const char *name, std::size_t size, scoped_mapped_file &out) { - MapZeroedWrite(name, size, out.fd, out.mem); -} - +void *MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file); + } // namespace util -#endif // UTIL_SCOPED__ +#endif // UTIL_MMAP__ diff --git a/klm/util/proxy_iterator.hh b/klm/util/proxy_iterator.hh index 1c5b7089..121a45fa 100644 --- a/klm/util/proxy_iterator.hh +++ b/klm/util/proxy_iterator.hh @@ -78,6 +78,8 @@ template <class Proxy> class ProxyIterator { const Proxy *operator->() const { return &p_; } Proxy operator[](std::ptrdiff_t amount) const { return *(*this + amount); } + const InnerIterator &Inner() { return p_.Inner(); } + private: InnerIterator &I() { return p_.Inner(); } const InnerIterator &I() const { return p_.Inner(); } diff --git a/klm/util/scoped.cc b/klm/util/scoped.cc index 61394ffc..2c6d5394 100644 --- a/klm/util/scoped.cc +++ b/klm/util/scoped.cc @@ -9,4 +9,8 @@ scoped_fd::~scoped_fd() { if (fd_ != -1 && close(fd_)) err(1, "Could not close file %i", fd_); } +scoped_FILE::~scoped_FILE() { + if (file_ && fclose(file_)) err(1, "Could not close file"); +} + } // namespace util diff --git a/klm/util/scoped.hh b/klm/util/scoped.hh index ef62a74f..52864481 100644 --- a/klm/util/scoped.hh +++ b/klm/util/scoped.hh @@ -4,6 +4,7 @@ /* Other scoped objects in the style of scoped_ptr. */ #include <cstddef> +#include <cstdio> namespace util { @@ -61,6 +62,24 @@ class scoped_fd { scoped_fd &operator=(const scoped_fd &); }; +class scoped_FILE { + public: + explicit scoped_FILE(std::FILE *file = NULL) : file_(file) {} + + ~scoped_FILE(); + + std::FILE *get() { return file_; } + const std::FILE *get() const { return file_; } + + void reset(std::FILE *to = NULL) { + scoped_FILE other(file_); + file_ = to; + } + + private: + std::FILE *file_; +}; + } // namespace util #endif // UTIL_SCOPED__ diff --git a/klm/util/sorted_uniform.hh b/klm/util/sorted_uniform.hh index 96ec4866..a8e208fb 100644 --- a/klm/util/sorted_uniform.hh +++ b/klm/util/sorted_uniform.hh @@ -65,7 +65,7 @@ template <class PackingT> class SortedUniformMap { public: // Offer consistent API with probing hash. - static std::size_t Size(std::size_t entries, float ignore = 0.0) { + static std::size_t Size(std::size_t entries, float /*ignore*/ = 0.0) { return sizeof(uint64_t) + entries * Packing::kBytes; } @@ -75,7 +75,7 @@ template <class PackingT> class SortedUniformMap { #endif {} - SortedUniformMap(void *start, std::size_t allocated) : + SortedUniformMap(void *start, std::size_t /*allocated*/) : begin_(Packing::FromVoid(reinterpret_cast<uint64_t*>(start) + 1)), end_(begin_), size_ptr_(reinterpret_cast<uint64_t*>(start)) #ifdef DEBUG diff --git a/klm/util/string_piece.cc b/klm/util/string_piece.cc index 6917a6bc..5b4e98f5 100644 --- a/klm/util/string_piece.cc +++ b/klm/util/string_piece.cc @@ -30,14 +30,14 @@ #include "util/string_piece.hh" -#ifdef USE_BOOST +#ifdef HAVE_BOOST #include <boost/functional/hash/hash.hpp> #endif #include <algorithm> #include <iostream> -#ifdef USE_ICU +#ifdef HAVE_ICU U_NAMESPACE_BEGIN #endif @@ -46,12 +46,12 @@ std::ostream& operator<<(std::ostream& o, const StringPiece& piece) { return o; } -#ifdef USE_BOOST +#ifdef HAVE_BOOST size_t hash_value(const StringPiece &str) { return boost::hash_range(str.data(), str.data() + str.length()); } #endif -#ifdef USE_ICU +#ifdef HAVE_ICU U_NAMESPACE_END #endif diff --git a/klm/util/string_piece.hh b/klm/util/string_piece.hh index 58008d13..3ac2f8a7 100644 --- a/klm/util/string_piece.hh +++ b/klm/util/string_piece.hh @@ -1,4 +1,4 @@ -/* If you use ICU in your program, then compile with -DUSE_ICU -licui18n. If +/* If you use ICU in your program, then compile with -DHAVE_ICU -licui18n. If * you don't use ICU, then this will use the Google implementation from Chrome. * This has been modified from the original version to let you choose. */ @@ -49,14 +49,14 @@ #define BASE_STRING_PIECE_H__ //Uncomment this line if you use ICU in your code. -//#define USE_ICU +//#define HAVE_ICU //Uncomment this line if you want boost hashing for your StringPieces. -//#define USE_BOOST +//#define HAVE_BOOST #include <cstring> #include <iosfwd> -#ifdef USE_ICU +#ifdef HAVE_ICU #include <unicode/stringpiece.h> U_NAMESPACE_BEGIN #else @@ -230,7 +230,7 @@ inline bool operator>=(const StringPiece& x, const StringPiece& y) { // allow StringPiece to be logged (needed for unit testing). extern std::ostream& operator<<(std::ostream& o, const StringPiece& piece); -#ifdef USE_BOOST +#ifdef HAVE_BOOST size_t hash_value(const StringPiece &str); /* Support for lookup of StringPiece in boost::unordered_map<std::string> */ @@ -253,7 +253,7 @@ template <class T> typename T::iterator FindStringPiece(T &t, const StringPiece } #endif -#ifdef USE_ICU +#ifdef HAVE_ICU U_NAMESPACE_END #endif |