34 files changed, 2097 insertions, 634 deletions
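In outline: the monolithic klm/lm/ngram.cc is deleted and the refactored KenLM sources (binary_format, config, lm_exception, model, read_arpa, search_hashed, search_trie, trie, vocab) are imported, adding an mmapable binary format next to ARPA loading. A minimal sketch of the workflow the new files support, mirroring build_binary.cc below; the file names are placeholders and passing a default-constructed Config on the second load is an assumption:

#include "lm/model.hh"

int main() {
  // One-time conversion: loading an ARPA file with write_mmap set writes
  // the binary image as a side effect (this is what build_binary.cc does).
  lm::ngram::Config config;
  config.write_mmap = "lm.binary";  // placeholder output path
  lm::ngram::Model convert("lm.arpa", config);

  // Later runs point the same constructor at the binary file; the header
  // sniffing in binary_format.cc picks the loading path automatically.
  lm::ngram::Model query("lm.binary", lm::ngram::Config());
}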
diff --git a/klm/lm/Makefile.am b/klm/lm/Makefile.am index 48f9889e..eb71c0f5 100644 --- a/klm/lm/Makefile.am +++ b/klm/lm/Makefile.am @@ -7,11 +7,17 @@  noinst_LIBRARIES = libklm.a  libklm_a_SOURCES = \ -  exception.cc \ -  ngram.cc \ -  ngram_build_binary.cc \ +  binary_format.cc \ +  config.cc \ +  lm_exception.cc \ +  model.cc \    ngram_query.cc \ -  virtual_interface.cc +  read_arpa.cc \ +  search_hashed.cc \ +  search_trie.cc \ +  trie.cc \ +  virtual_interface.cc \ +  vocab.cc  AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc new file mode 100644 index 00000000..2a075b6b --- /dev/null +++ b/klm/lm/binary_format.cc @@ -0,0 +1,191 @@ +#include "lm/binary_format.hh" + +#include "lm/lm_exception.hh" +#include "util/file_piece.hh" + +#include <limits> +#include <string> + +#include <fcntl.h> +#include <errno.h> +#include <stdlib.h> +#include <sys/mman.h> +#include <sys/types.h> +#include <sys/stat.h> +#include <unistd.h> + +namespace lm { +namespace ngram { +namespace { +const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version"; +const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 1\n\0"; +const long int kMagicVersion = 1; + +// Test values.   +struct Sanity { +  char magic[sizeof(kMagicBytes)]; +  float zero_f, one_f, minus_half_f; +  WordIndex one_word_index, max_word_index; +  uint64_t one_uint64; + +  void SetToReference() { +    std::memcpy(magic, kMagicBytes, sizeof(magic)); +    zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5; +    one_word_index = 1; +    max_word_index = std::numeric_limits<WordIndex>::max(); +    one_uint64 = 1; +  } +}; + +const char *kModelNames[3] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "bit packed trie"}; + +std::size_t Align8(std::size_t in) { +  std::size_t off = in % 8; +  if (!off) return in; +  return in + 8 - off; +} + +std::size_t TotalHeaderSize(unsigned char order) { +  return Align8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order); +} + +void ReadLoop(int fd, void *to_void, std::size_t size) { +  uint8_t *to = static_cast<uint8_t*>(to_void); +  while (size) { +    ssize_t ret = read(fd, to, size); +    if (ret == -1) UTIL_THROW(util::ErrnoException, "Failed to read from binary file"); +    if (ret == 0) UTIL_THROW(util::ErrnoException, "Binary file too short"); +    to += ret; +    size -= ret; +  } +} + +void WriteHeader(void *to, const Parameters &params) { +  Sanity header = Sanity(); +  header.SetToReference(); +  memcpy(to, &header, sizeof(Sanity)); +  char *out = reinterpret_cast<char*>(to) + sizeof(Sanity); + +  *reinterpret_cast<FixedWidthParameters*>(out) = params.fixed; +  out += sizeof(FixedWidthParameters); + +  uint64_t *counts = reinterpret_cast<uint64_t*>(out); +  for (std::size_t i = 0; i < params.counts.size(); ++i) { +    counts[i] = params.counts[i]; +  } +} + +} // namespace +namespace detail { + +bool IsBinaryFormat(int fd) { +  const off_t size = util::SizeFile(fd); +  if (size == util::kBadSize || (size <= static_cast<off_t>(sizeof(Sanity)))) return false; +  // Try reading the header.   
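An aside on the header logic above: Sanity freezes a magic string plus reference float, WordIndex, and uint64_t values, and IsBinaryFormat (continuing below) compares it byte-for-byte against the start of the file, so a binary built with a different compiler, architecture, or word size is rejected cleanly instead of being loaded as garbage. Callers can reuse that sniffing through RecognizeBinary at the end of this file; a hedged sketch, assuming RecognizeBinary and ModelType are declared in lm/binary_format.hh and using a placeholder path:

#include "lm/binary_format.hh"
#include <iostream>

void Sniff(const char *file) {
  lm::ngram::ModelType type;
  if (lm::ngram::RecognizeBinary(file, type)) {
    // The header records which search structure wrote the file.
    std::cout << file << " is a KenLM binary of model type " << static_cast<unsigned int>(type) << '\n';
  } else {
    std::cout << file << " is not binary; treat it as ARPA text.\n";
  }
}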
+  util::scoped_memory memory; +  try { +    util::MapRead(util::LAZY, fd, 0, sizeof(Sanity), memory); +  } catch (const util::Exception &e) { +    return false; +  } +  Sanity reference_header = Sanity(); +  reference_header.SetToReference(); +  if (!memcmp(memory.get(), &reference_header, sizeof(Sanity))) return true; +  if (!memcmp(memory.get(), kMagicBeforeVersion, strlen(kMagicBeforeVersion))) { +    char *end_ptr; +    const char *begin_version = static_cast<const char*>(memory.get()) + strlen(kMagicBeforeVersion); +    long int version = strtol(begin_version, &end_ptr, 10); +    if ((end_ptr != begin_version) && version != kMagicVersion) { +      UTIL_THROW(FormatLoadException, "Binary file has version " << version << " but this implementation expects version " << kMagicVersion << " so you'll have to rebuild your binary LM from the ARPA.  Sorry."); +    } +    UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match.  Try rebuilding the binary format LM using the same code revision, compiler, and architecture."); +  } +  return false; +} + +void ReadHeader(int fd, Parameters &out) { +  if ((off_t)-1 == lseek(fd, sizeof(Sanity), SEEK_SET)) UTIL_THROW(util::ErrnoException, "Seek failed in binary file"); +  ReadLoop(fd, &out.fixed, sizeof(out.fixed)); +  if (out.fixed.probing_multiplier < 1.0) +    UTIL_THROW(FormatLoadException, "Binary format claims to have a probing multiplier of " << out.fixed.probing_multiplier << " which is < 1.0."); + +  out.counts.resize(static_cast<std::size_t>(out.fixed.order)); +  ReadLoop(fd, &*out.counts.begin(), sizeof(uint64_t) * out.fixed.order); +} + +void MatchCheck(ModelType model_type, const Parameters &params) { +  if (params.fixed.model_type != model_type) { +    if (static_cast<unsigned int>(params.fixed.model_type) >= (sizeof(kModelNames) / sizeof(const char *))) +      UTIL_THROW(FormatLoadException, "The binary file claims to be model type " << static_cast<unsigned int>(params.fixed.model_type) << " but this is not implemented in this inference code."); +    UTIL_THROW(FormatLoadException, "The binary file was built for " << kModelNames[params.fixed.model_type] << " but the inference code is trying to load " << kModelNames[model_type]); +  } +} + +uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t memory_size, Backing &backing) { +  const off_t file_size = util::SizeFile(backing.file.get()); +  // The header is smaller than a page, so we have to map the whole header as well.   +  std::size_t total_map = TotalHeaderSize(params.counts.size()) + memory_size; +  if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map) +    UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map); + +  util::MapRead(config.load_method, backing.file.get(), 0, total_map, backing.memory); + +  if (config.enumerate_vocab && !params.fixed.has_vocabulary) +    UTIL_THROW(FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them.  
You may need to rebuild the binary file with an updated version of build_binary."); + +  if (config.enumerate_vocab) { +    if ((off_t)-1 == lseek(backing.file.get(), total_map, SEEK_SET)) +      UTIL_THROW(util::ErrnoException, "Failed to seek in binary file to vocab words"); +  } +  return reinterpret_cast<uint8_t*>(backing.memory.get()) + TotalHeaderSize(params.counts.size()); +} + +uint8_t *SetupZeroed(const Config &config, ModelType model_type, const std::vector<uint64_t> &counts, std::size_t memory_size, Backing &backing) { +  if (config.probing_multiplier <= 1.0) UTIL_THROW(FormatLoadException, "probing multiplier must be > 1.0"); +  if (config.write_mmap) { +    std::size_t total_map = TotalHeaderSize(counts.size()) + memory_size; +    // Write out an mmap file.   +    backing.memory.reset(util::MapZeroedWrite(config.write_mmap, total_map, backing.file), total_map, util::scoped_memory::MMAP_ALLOCATED); + +    Parameters params; +    params.counts = counts; +    params.fixed.order = counts.size(); +    params.fixed.probing_multiplier = config.probing_multiplier; +    params.fixed.model_type = model_type; +    params.fixed.has_vocabulary = config.include_vocab; + +    WriteHeader(backing.memory.get(), params); + +    if (params.fixed.has_vocabulary) { +      if ((off_t)-1 == lseek(backing.file.get(), total_map, SEEK_SET)) +        UTIL_THROW(util::ErrnoException, "Failed to seek in binary file " << config.write_mmap << " to vocab words"); +    } +    return reinterpret_cast<uint8_t*>(backing.memory.get()) + TotalHeaderSize(counts.size()); +  } else { +    backing.memory.reset(util::MapAnonymous(memory_size), memory_size, util::scoped_memory::MMAP_ALLOCATED); +    return reinterpret_cast<uint8_t*>(backing.memory.get()); +  }  +} + +void ComplainAboutARPA(const Config &config, ModelType model_type) { +  if (config.write_mmap || !config.messages) return; +  if (config.arpa_complain == Config::ALL) { +    *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl; +  } else if (config.arpa_complain == Config::EXPENSIVE && model_type == TRIE_SORTED) { +    *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive.  Save time by building a binary format." 
<< std::endl; +  } +} + +} // namespace detail + +bool RecognizeBinary(const char *file, ModelType &recognized) { +  util::scoped_fd fd(util::OpenReadOrThrow(file)); +  if (!detail::IsBinaryFormat(fd.get())) return false; +  Parameters params; +  detail::ReadHeader(fd.get(), params); +  recognized = params.fixed.model_type; +  return true; +} + +} // namespace ngram +} // namespace lm diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc new file mode 100644 index 00000000..4db631a2 --- /dev/null +++ b/klm/lm/build_binary.cc @@ -0,0 +1,13 @@ +#include "lm/model.hh" + +#include <iostream> + +int main(int argc, char *argv[]) { +  if (argc != 3) { +    std::cerr << "Usage: " << argv[0] << " input.arpa output.mmap" << std::endl; +    return 1; +  } +  lm::ngram::Config config; +  config.write_mmap = argv[2]; +  lm::ngram::Model(argv[1], config); +} diff --git a/klm/lm/config.cc b/klm/lm/config.cc new file mode 100644 index 00000000..2831d578 --- /dev/null +++ b/klm/lm/config.cc @@ -0,0 +1,22 @@ +#include "lm/config.hh" + +#include <iostream> + +namespace lm { +namespace ngram { + +Config::Config() : +  messages(&std::cerr), +  enumerate_vocab(NULL), +  unknown_missing(COMPLAIN), +  unknown_missing_prob(0.0), +  probing_multiplier(1.5), +  building_memory(1073741824ULL), // 1 GB +  temporary_directory_prefix(NULL), +  arpa_complain(ALL), +  write_mmap(NULL), +  include_vocab(true), +  load_method(util::POPULATE_OR_READ) {} + +} // namespace ngram +} // namespace lm diff --git a/klm/lm/lm_exception.cc b/klm/lm/lm_exception.cc new file mode 100644 index 00000000..ab2ec52f --- /dev/null +++ b/klm/lm/lm_exception.cc @@ -0,0 +1,21 @@ +#include "lm/lm_exception.hh" + +#include<errno.h> +#include<stdio.h> + +namespace lm { + +LoadException::LoadException() throw() {} +LoadException::~LoadException() throw() {} +VocabLoadException::VocabLoadException() throw() {} +VocabLoadException::~VocabLoadException() throw() {} + +FormatLoadException::FormatLoadException() throw() {} +FormatLoadException::~FormatLoadException() throw() {} + +SpecialWordMissingException::SpecialWordMissingException(StringPiece which) throw() { +  *this << "Missing special word " << which; +} +SpecialWordMissingException::~SpecialWordMissingException() throw() {} + +} // namespace lm diff --git a/klm/lm/model.cc b/klm/lm/model.cc new file mode 100644 index 00000000..6921d4d9 --- /dev/null +++ b/klm/lm/model.cc @@ -0,0 +1,239 @@ +#include "lm/model.hh" + +#include "lm/lm_exception.hh" +#include "lm/search_hashed.hh" +#include "lm/search_trie.hh" +#include "lm/read_arpa.hh" +#include "util/murmur_hash.hh" + +#include <algorithm> +#include <functional> +#include <numeric> +#include <cmath> + +namespace lm { +namespace ngram { + +size_t hash_value(const State &state) { +  return util::MurmurHashNative(state.history_, sizeof(WordIndex) * state.valid_length_); +} + +namespace detail { + +template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) { +  if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ".  
Edit ngram.hh's kMaxOrder to at least this value and recompile."); +  if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model."); +  return VocabularyT::Size(counts[0], config) + Search::Size(counts, config); +} + +template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::SetupMemory(void *base, const std::vector<uint64_t> &counts, const Config &config) { +  uint8_t *start = static_cast<uint8_t*>(base); +  size_t allocated = VocabularyT::Size(counts[0], config); +  vocab_.SetupMemory(start, allocated, counts[0], config); +  start += allocated; +  start = search_.SetupMemory(start, counts, config); +  if (static_cast<std::size_t>(start - static_cast<uint8_t*>(base)) != Size(counts, config)) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast<uint8_t*>(base)) << " but Size says they should take " << Size(counts, config)); +} + +template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) { +  LoadLM(file, config, *this); + +  // g++ prints warnings unless these are fully initialized.   +  State begin_sentence = State(); +  begin_sentence.valid_length_ = 1; +  begin_sentence.history_[0] = vocab_.BeginSentence(); +  begin_sentence.backoff_[0] = search_.unigram.Lookup(begin_sentence.history_[0]).backoff; +  State null_context = State(); +  null_context.valid_length_ = 0; +  P::Init(begin_sentence, null_context, vocab_, search_.middle.size() + 2); +} + +template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters &params, const Config &config, int fd) { +  SetupMemory(start, params.counts, config); +  vocab_.LoadedBinary(fd, config.enumerate_vocab); +  search_.unigram.LoadedBinary(); +  for (typename std::vector<Middle>::iterator i = search_.middle.begin(); i != search_.middle.end(); ++i) { +    i->LoadedBinary(); +  } +  search_.longest.LoadedBinary(); +} + +template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, util::FilePiece &f, void *start, const Parameters &params, const Config &config) { +  SetupMemory(start, params.counts, config); + +  if (config.write_mmap) { +    WriteWordsWrapper wrap(config.enumerate_vocab, backing_.file.get()); +    vocab_.ConfigureEnumerate(&wrap, params.counts[0]); +    search_.InitializeFromARPA(file, f, params.counts, config, vocab_); +  } else { +    vocab_.ConfigureEnumerate(config.enumerate_vocab, params.counts[0]); +    search_.InitializeFromARPA(file, f, params.counts, config, vocab_); +  } +  // TODO: fail faster?   +  if (!vocab_.SawUnk()) { +    switch(config.unknown_missing) { +      case Config::THROW_UP: +        { +          SpecialWordMissingException e("<unk>"); +          e << " and configuration was set to throw if unknown is missing"; +          throw e; +        } +      case Config::COMPLAIN: +        if (config.messages) *config.messages << "Language model is missing <unk>.  Substituting probability " << config.unknown_missing_prob << "." << std::endl;  +        // There's no break;.  This is by design.   +      case Config::SILENT: +        // Default probabilities for unknown.   
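The <unk> handling above, which finishes just below, is driven by Config fields initialized in config.cc: unknown_missing picks the policy (THROW_UP, COMPLAIN, or SILENT) and unknown_missing_prob (default 0.0) is substituted when the ARPA file lacks <unk>. A short sketch of opting into silent substitution; treating the value as log10 is an assumption based on ARPA conventions:

#include "lm/model.hh"

void LoadWithoutUnkWarnings(const char *arpa) {
  lm::ngram::Config config;
  config.unknown_missing = lm::ngram::Config::SILENT;  // skip the stderr complaint
  config.unknown_missing_prob = -100.0;                // illustrative log10 value (assumption)
  lm::ngram::Model model(arpa, config);
  // ... query model ...
}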
+        search_.unigram.Unknown().backoff = 0.0; +        search_.unigram.Unknown().prob = config.unknown_missing_prob; +        break; +    } +  } +  if (std::fabs(search_.unigram.Unknown().backoff) > 0.0000001) UTIL_THROW(FormatLoadException, "Backoff for unknown word should be zero, but was given as " << search_.unigram.Unknown().backoff);   +} + +template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const { +  unsigned char backoff_start; +  FullScoreReturn ret = ScoreExceptBackoff(in_state.history_, in_state.history_ + in_state.valid_length_, new_word, backoff_start, out_state); +  if (backoff_start - 1 < in_state.valid_length_) { +    ret.prob = std::accumulate(in_state.backoff_ + backoff_start - 1, in_state.backoff_ + in_state.valid_length_, ret.prob); +  } +  return ret; +} + +template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const { +  unsigned char backoff_start; +  context_rend = std::min(context_rend, context_rbegin + P::Order() - 1); +  FullScoreReturn ret = ScoreExceptBackoff(context_rbegin, context_rend, new_word, backoff_start, out_state); +  ret.prob += SlowBackoffLookup(context_rbegin, context_rend, backoff_start); +  return ret; +} + +template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const { +  context_rend = std::min(context_rend, context_rbegin + P::Order() - 1); +  if (context_rend == context_rbegin || *context_rbegin == 0) { +    out_state.valid_length_ = 0; +    return; +  } +  float ignored_prob; +  typename Search::Node node; +  search_.LookupUnigram(*context_rbegin, ignored_prob, out_state.backoff_[0], node); +  float *backoff_out = out_state.backoff_ + 1; +  const WordIndex *i = context_rbegin + 1; +  for (; i < context_rend; ++i, ++backoff_out) { +    if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, *backoff_out, node)) { +      out_state.valid_length_ = i - context_rbegin; +      std::copy(context_rbegin, i, out_state.history_); +      return; +    } +  } +  std::copy(context_rbegin, context_rend, out_state.history_); +  out_state.valid_length_ = static_cast<unsigned char>(context_rend - context_rbegin); +} + +template <class Search, class VocabularyT> float GenericModel<Search, VocabularyT>::SlowBackoffLookup( +    const WordIndex *const context_rbegin, const WordIndex *const context_rend, unsigned char start) const { +  // Add the backoff weights for n-grams of order start to (context_rend - context_rbegin). +  if (context_rend - context_rbegin < static_cast<std::ptrdiff_t>(start)) return 0.0; +  float ret = 0.0; +  if (start == 1) { +    ret += search_.unigram.Lookup(*context_rbegin).backoff; +    start = 2; +  } +  typename Search::Node node; +  if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) { +    return 0.0; +  } +  float backoff; +  // i is the order of the backoff we're looking for. +  for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i) { +    if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, backoff, node)) break; +    ret += backoff; +  } +  return ret; +} + +/* Ugly optimized function.  Produce a score excluding backoff.   
+ * The search goes in increasing order of ngram length.   + * Context goes backward, so context_begin is the word immediately preceding + * new_word.   + */ +template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::ScoreExceptBackoff( +    const WordIndex *context_rbegin, +    const WordIndex *context_rend, +    const WordIndex new_word, +    unsigned char &backoff_start, +    State &out_state) const { +  FullScoreReturn ret; +  typename Search::Node node; +  float *backoff_out(out_state.backoff_); +  search_.LookupUnigram(new_word, ret.prob, *backoff_out, node); +  if (new_word == 0) { +    ret.ngram_length = out_state.valid_length_ = 0; +    // All of backoff.   +    backoff_start = 1; +    return ret; +  } +  out_state.history_[0] = new_word; +  if (context_rbegin == context_rend) { +    ret.ngram_length = out_state.valid_length_ = 1; +    // No backoff because we don't have the history for it.   +    backoff_start = P::Order(); +    return ret; +  } +  ++backoff_out; + +  // Ok now we know that the bigram contains known words.  Start by looking it up. + +  const WordIndex *hist_iter = context_rbegin; +  typename std::vector<Middle>::const_iterator mid_iter = search_.middle.begin(); +  for (; ; ++mid_iter, ++hist_iter, ++backoff_out) { +    if (hist_iter == context_rend) { +      // Ran out of history.  No backoff.   +      backoff_start = P::Order(); +      std::copy(context_rbegin, context_rend, out_state.history_ + 1); +      ret.ngram_length = out_state.valid_length_ = (context_rend - context_rbegin) + 1; +      // ret.prob was already set. +      return ret; +    } + +    if (mid_iter == search_.middle.end()) break; + +    if (!search_.LookupMiddle(*mid_iter, *hist_iter, ret.prob, *backoff_out, node)) { +      // Didn't find an ngram using hist_iter.   +      // The history used in the found n-gram is [context_rbegin, hist_iter).   +      std::copy(context_rbegin, hist_iter, out_state.history_ + 1); +      // Therefore, we found a (hist_iter - context_rbegin + 1)-gram including the last word.   +      ret.ngram_length = out_state.valid_length_ = (hist_iter - context_rbegin) + 1; +      backoff_start = mid_iter - search_.middle.begin() + 1; +      // ret.prob was already set.   +      return ret; +    } +  } + +  // It passed every lookup in search_.middle.  That means it's at least a (P::Order() - 1)-gram.  +  // All that's left is to check search_.longest.   +   +  if (!search_.LookupLongest(*hist_iter, ret.prob, node)) { +    // It's an (P::Order()-1)-gram +    std::copy(context_rbegin, context_rbegin + P::Order() - 2, out_state.history_ + 1); +    ret.ngram_length = out_state.valid_length_ = P::Order() - 1; +    backoff_start = P::Order() - 1; +    // ret.prob was already set.   +    return ret; +  } +  // It's an P::Order()-gram +  // out_state.valid_length_ is still P::Order() - 1 because the next lookup will only need that much. 
+  std::copy(context_rbegin, context_rbegin + P::Order() - 2, out_state.history_ + 1); +  out_state.valid_length_ = P::Order() - 1; +  ret.ngram_length = P::Order(); +  backoff_start = P::Order(); +  return ret; +} + +template class GenericModel<ProbingHashedSearch, ProbingVocabulary>; +template class GenericModel<SortedHashedSearch, SortedVocabulary>; +template class GenericModel<trie::TrieSearch, SortedVocabulary>; + +} // namespace detail +} // namespace ngram +} // namespace lm diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc new file mode 100644 index 00000000..159628d4 --- /dev/null +++ b/klm/lm/model_test.cc @@ -0,0 +1,200 @@ +#include "lm/model.hh" + +#include <stdlib.h> + +#define BOOST_TEST_MODULE ModelTest +#include <boost/test/unit_test.hpp> + +namespace lm { +namespace ngram { +namespace { + +#define StartTest(word, ngram, score) \ +  ret = model.FullScore( \ +      state, \ +      model.GetVocabulary().Index(word), \ +      out);\ +  BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \ +  BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \ +  BOOST_CHECK_EQUAL(std::min<unsigned char>(ngram, 5 - 1), out.valid_length_); + +#define AppendTest(word, ngram, score) \ +  StartTest(word, ngram, score) \ +  state = out; + +template <class M> void Starters(const M &model) { +  FullScoreReturn ret; +  Model::State state(model.BeginSentenceState()); +  Model::State out; + +  StartTest("looking", 2, -0.4846522); + +  // , probability plus <s> backoff +  StartTest(",", 1, -1.383514 + -0.4149733); +  // <unk> probability plus <s> backoff +  StartTest("this_is_not_found", 0, -1.995635 + -0.4149733); +} + +template <class M> void Continuation(const M &model) { +  FullScoreReturn ret; +  Model::State state(model.BeginSentenceState()); +  Model::State out; + +  AppendTest("looking", 2, -0.484652); +  AppendTest("on", 3, -0.348837); +  AppendTest("a", 4, -0.0155266); +  AppendTest("little", 5, -0.00306122); +  State preserve = state; +  AppendTest("the", 1, -4.04005); +  AppendTest("biarritz", 1, -1.9889); +  AppendTest("not_found", 0, -2.29666); +  AppendTest("more", 1, -1.20632); +  AppendTest(".", 2, -0.51363); +  AppendTest("</s>", 3, -0.0191651); + +  state = preserve; +  AppendTest("more", 5, -0.00181395); +  AppendTest("loin", 5, -0.0432557); +} + +#define StatelessTest(word, provide, ngram, score) \ +  ret = model.FullScoreForgotState(indices + num_words - word, indices + num_words - word + provide, indices[num_words - word - 1], state); \ +  BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \ +  BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \ +  model.GetState(indices + num_words - word, indices + num_words - word + provide, before); \ +  ret = model.FullScore(before, indices[num_words - word - 1], out); \ +  BOOST_CHECK(state == out); \ +  BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \ +  BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); + +template <class M> void Stateless(const M &model) { +  const char *words[] = {"<s>", "looking", "on", "a", "little", "the", "biarritz", "not_found", "more", ".", "</s>"}; +  const size_t num_words = sizeof(words) / sizeof(const char*); +  // Silence "array subscript is above array bounds" when extracting end pointer. 
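Taken together, FullScore and ScoreExceptBackoff above implement standard back-off in log10 space: the probability of the longest matching n-gram, plus the backoff weights of every longer context that failed to match, summed by the std::accumulate over in_state.backoff_ starting at backoff_start - 1. The StartTest/AppendTest macros in model_test.cc show the intended query pattern; stripped of the Boost harness it is roughly the following sketch (the sentence is borrowed from the test data, the rest is illustrative):

#include "lm/model.hh"
#include <iostream>

void ScoreSentence(const lm::ngram::Model &model) {
  const char *words[] = {"looking", "on", "a", "little", "</s>"};
  lm::ngram::Model::State state(model.BeginSentenceState()), out;
  float total = 0.0;
  for (unsigned int i = 0; i < sizeof(words) / sizeof(*words); ++i) {
    // FullScore consumes the carried-in state and fills the state for the next call.
    total += model.FullScore(state, model.GetVocabulary().Index(words[i]), out).prob;
    state = out;
  }
  std::cout << "total log10 probability: " << total << std::endl;
}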
+  WordIndex indices[num_words + 1]; +  for (unsigned int i = 0; i < num_words; ++i) { +    indices[num_words - 1 - i] = model.GetVocabulary().Index(words[i]); +  } +  FullScoreReturn ret; +  State state, out, before; + +  ret = model.FullScoreForgotState(indices + num_words - 1, indices + num_words, indices[num_words - 2], state); +  BOOST_CHECK_CLOSE(-0.484652, ret.prob, 0.001); +  StatelessTest(1, 1, 2, -0.484652); + +  // looking +  StatelessTest(1, 2, 2, -0.484652); +  // on +  AppendTest("on", 3, -0.348837); +  StatelessTest(2, 3, 3, -0.348837); +  StatelessTest(2, 2, 3, -0.348837); +  StatelessTest(2, 1, 2, -0.4638903); +  // a +  StatelessTest(3, 4, 4, -0.0155266); +  // little +  AppendTest("little", 5, -0.00306122); +  StatelessTest(4, 5, 5, -0.00306122); +  // the +  AppendTest("the", 1, -4.04005); +  StatelessTest(5, 5, 1, -4.04005); +  // No context of the.   +  StatelessTest(5, 0, 1, -1.687872); +  // biarritz +  StatelessTest(6, 1, 1, -1.9889); +  // not found +  StatelessTest(7, 1, 0, -2.29666); +  StatelessTest(7, 0, 0, -1.995635); + +  WordIndex unk[1]; +  unk[0] = 0; +  model.GetState(unk, unk + 1, state); +  BOOST_CHECK_EQUAL(0, state.valid_length_); +} + +//const char *kExpectedOrderProbing[] = {"<unk>", ",", ".", "</s>", "<s>", "a", "also", "beyond", "biarritz", "call", "concerns", "consider", "considering", "for", "higher", "however", "i", "immediate", "in", "is", "little", "loin", "look", "looking", "more", "on", "screening", "small", "the", "to", "watch", "watching", "what", "would"}; + +class ExpectEnumerateVocab : public EnumerateVocab { +  public: +    ExpectEnumerateVocab() {} + +    void Add(WordIndex index, const StringPiece &str) { +      BOOST_CHECK_EQUAL(seen.size(), index); +      seen.push_back(std::string(str.data(), str.length())); +    } + +    void Check(const base::Vocabulary &vocab) { +      BOOST_CHECK_EQUAL(34, seen.size()); +      BOOST_REQUIRE(!seen.empty()); +      BOOST_CHECK_EQUAL("<unk>", seen[0]); +      for (WordIndex i = 0; i < seen.size(); ++i) { +        BOOST_CHECK_EQUAL(i, vocab.Index(seen[i])); +      } +    } + +    void Clear() { +      seen.clear(); +    } + +    std::vector<std::string> seen; +}; + +template <class ModelT> void LoadingTest() { +  Config config; +  config.arpa_complain = Config::NONE; +  config.messages = NULL; +  ExpectEnumerateVocab enumerate; +  config.enumerate_vocab = &enumerate; +  ModelT m("test.arpa", config); +  enumerate.Check(m.GetVocabulary()); +  Starters(m); +  Continuation(m); +  Stateless(m); +} + +BOOST_AUTO_TEST_CASE(probing) { +  LoadingTest<Model>(); +} + +BOOST_AUTO_TEST_CASE(sorted) { +  LoadingTest<SortedModel>(); +} +BOOST_AUTO_TEST_CASE(trie) { +  LoadingTest<TrieModel>(); +} + +template <class ModelT> void BinaryTest() { +  Config config; +  config.write_mmap = "test.binary"; +  config.messages = NULL; +  ExpectEnumerateVocab enumerate; +  config.enumerate_vocab = &enumerate; + +  { +    ModelT copy_model("test.arpa", config); +    enumerate.Check(copy_model.GetVocabulary()); +    enumerate.Clear(); +  } + +  config.write_mmap = NULL; + +  ModelT binary("test.binary", config); +  enumerate.Check(binary.GetVocabulary()); +  Starters(binary); +  Continuation(binary); +  Stateless(binary); +  unlink("test.binary"); +} + +BOOST_AUTO_TEST_CASE(write_and_read_probing) { +  BinaryTest<Model>(); +} +BOOST_AUTO_TEST_CASE(write_and_read_sorted) { +  BinaryTest<SortedModel>(); +} +BOOST_AUTO_TEST_CASE(write_and_read_trie) { +  BinaryTest<TrieModel>(); +} + +} // namespace +} // namespace ngram +} 
// namespace lm diff --git a/klm/lm/ngram.cc b/klm/lm/ngram.cc deleted file mode 100644 index a87c82aa..00000000 --- a/klm/lm/ngram.cc +++ /dev/null @@ -1,522 +0,0 @@ -#include "lm/ngram.hh" - -#include "lm/exception.hh" -#include "util/file_piece.hh" -#include "util/joint_sort.hh" -#include "util/murmur_hash.hh" -#include "util/probing_hash_table.hh" - -#include <algorithm> -#include <functional> -#include <numeric> -#include <limits> -#include <string> - -#include <cmath> -#include <fcntl.h> -#include <errno.h> -#include <stdlib.h> -#include <sys/mman.h> -#include <sys/types.h> -#include <sys/stat.h> -#include <unistd.h> - -namespace lm { -namespace ngram { - -size_t hash_value(const State &state) { -  return util::MurmurHashNative(state.history_, sizeof(WordIndex) * state.valid_length_); -} - -namespace detail { -uint64_t HashForVocab(const char *str, std::size_t len) { -  // This proved faster than Boost's hash in speed trials: total load time Murmur 67090000, Boost 72210000 -  // Chose to use 64A instead of native so binary format will be portable across 64 and 32 bit.   -  return util::MurmurHash64A(str, len, 0); -} -  -void Prob::SetBackoff(float to) { -  UTIL_THROW(FormatLoadException, "Attempt to set backoff " << to << " for the highest order n-gram"); -} - -// Normally static initialization is a bad idea but MurmurHash is pure arithmetic, so this is ok.   -const uint64_t kUnknownHash = HashForVocab("<unk>", 5); -// Sadly some LMs have <UNK>.   -const uint64_t kUnknownCapHash = HashForVocab("<UNK>", 5); - -} // namespace detail - -SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL) {} - -std::size_t SortedVocabulary::Size(std::size_t entries, float ignored) { -  // Lead with the number of entries.   -  return sizeof(uint64_t) + sizeof(Entry) * entries; -} - -void SortedVocabulary::Init(void *start, std::size_t allocated, std::size_t entries) { -  assert(allocated >= Size(entries)); -  // Leave space for number of entries.   -  begin_ = reinterpret_cast<Entry*>(reinterpret_cast<uint64_t*>(start) + 1); -  end_ = begin_; -  saw_unk_ = false; -} - -WordIndex SortedVocabulary::Insert(const StringPiece &str) { -  uint64_t hashed = detail::HashForVocab(str); -  if (hashed == detail::kUnknownHash || hashed == detail::kUnknownCapHash) { -    saw_unk_ = true; -    return 0; -  } -  end_->key = hashed; -  ++end_; -  // This is 1 + the offset where it was inserted to make room for unk.   -  return end_ - begin_; -} - -bool SortedVocabulary::FinishedLoading(detail::ProbBackoff *reorder_vocab) { -  util::JointSort(begin_, end_, reorder_vocab + 1); -  SetSpecial(Index("<s>"), Index("</s>"), 0, end_ - begin_ + 1); -  // Save size.   -  *(reinterpret_cast<uint64_t*>(begin_) - 1) = end_ - begin_; -  return saw_unk_; -} - -void SortedVocabulary::LoadedBinary() { -  end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1); -  SetSpecial(Index("<s>"), Index("</s>"), 0, end_ - begin_ + 1); -} - -namespace detail { - -template <class Search> MapVocabulary<Search>::MapVocabulary() {} - -template <class Search> void MapVocabulary<Search>::Init(void *start, std::size_t allocated, std::size_t entries) { -  lookup_ = Lookup(start, allocated); -  available_ = 1; -  // Later if available_ != expected_available_ then we can throw UnknownMissingException. -  saw_unk_ = false; -} - -template <class Search> WordIndex MapVocabulary<Search>::Insert(const StringPiece &str) { -  uint64_t hashed = HashForVocab(str); -  // Prevent unknown from going into the table.   
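One detail of the deleted vocabulary code above carries over to the refactor: words are hashed with util::MurmurHash64A over the string bytes and seed 0, deliberately the fixed 64-bit variant rather than MurmurHashNative, so the same string gets the same id on 32-bit and 64-bit machines and binary files stay portable. Restated on its own (the probing table around it is omitted):

#include "util/murmur_hash.hh"
#include <stdint.h>
#include <cstddef>

// Mirrors detail::HashForVocab above; same string => same id on every architecture.
uint64_t HashForVocab(const char *str, std::size_t len) {
  return util::MurmurHash64A(str, len, 0);
}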
-  if (hashed == kUnknownHash || hashed == kUnknownCapHash) { -    saw_unk_ = true; -    return 0; -  } else { -    lookup_.Insert(Lookup::Packing::Make(hashed, available_)); -    return available_++; -  } -} - -template <class Search> bool MapVocabulary<Search>::FinishedLoading(ProbBackoff *reorder_vocab) { -  lookup_.FinishedInserting(); -  SetSpecial(Index("<s>"), Index("</s>"), 0, available_); -  return saw_unk_; -} - -template <class Search> void MapVocabulary<Search>::LoadedBinary() { -  lookup_.LoadedBinary(); -  SetSpecial(Index("<s>"), Index("</s>"), 0, available_); -} - -/* All of the entropy is in low order bits and boost::hash does poorly with - * these.  Odd numbers near 2^64 chosen by mashing on the keyboard.  There is a - * stable point: 0.  But 0 is <unk> which won't be queried here anyway.   - */ -inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) { -  uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast<uint64_t>(next) * 17894857484156487943ULL); -  return ret; -} - -uint64_t ChainedWordHash(const WordIndex *word, const WordIndex *word_end) { -  if (word == word_end) return 0; -  uint64_t current = static_cast<uint64_t>(*word); -  for (++word; word != word_end; ++word) { -    current = CombineWordHash(current, *word); -  } -  return current; -} - -bool IsEntirelyWhiteSpace(const StringPiece &line) { -  for (size_t i = 0; i < static_cast<size_t>(line.size()); ++i) { -    if (!isspace(line.data()[i])) return false; -  } -  return true; -} - -void ReadARPACounts(util::FilePiece &in, std::vector<size_t> &number) { -  number.clear(); -  StringPiece line; -  if (!IsEntirelyWhiteSpace(line = in.ReadLine())) UTIL_THROW(FormatLoadException, "First line was \"" << line << "\" not blank"); -  if ((line = in.ReadLine()) != "\\data\\") UTIL_THROW(FormatLoadException, "second line was \"" << line << "\" not \\data\\."); -  while (!IsEntirelyWhiteSpace(line = in.ReadLine())) { -    if (line.size() < 6 || strncmp(line.data(), "ngram ", 6)) UTIL_THROW(FormatLoadException, "count line \"" << line << "\"doesn't begin with \"ngram \""); -    // So strtol doesn't go off the end of line.   -    std::string remaining(line.data() + 6, line.size() - 6); -    char *end_ptr; -    unsigned long int length = std::strtol(remaining.c_str(), &end_ptr, 10); -    if ((end_ptr == remaining.c_str()) || (length - 1 != number.size())) UTIL_THROW(FormatLoadException, "ngram count lengths should be consecutive starting with 1: " << line); -    if (*end_ptr != '=') UTIL_THROW(FormatLoadException, "Expected = immediately following the first number in the count line " << line); -    ++end_ptr; -    const char *start = end_ptr; -    long int count = std::strtol(start, &end_ptr, 10); -    if (count < 0) UTIL_THROW(FormatLoadException, "Negative n-gram count " << count); -    if (start == end_ptr) UTIL_THROW(FormatLoadException, "Couldn't parse n-gram count from " << line); -    number.push_back(count); -  } -} - -void ReadNGramHeader(util::FilePiece &in, unsigned int length) { -  StringPiece line; -  while (IsEntirelyWhiteSpace(line = in.ReadLine())) {} -  std::stringstream expected; -  expected << '\\' << length << "-grams:"; -  if (line != expected.str()) UTIL_THROW(FormatLoadException, "Was expecting n-gram header " << expected.str() << " but got " << line << " instead."); -} - -// Special unigram reader because unigram's data structure is different and because we're inserting vocab words. 
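Read1Grams, defined next, leans on the ARPA body-line grammar: a log10 probability, a tab, the n-gram, then either a newline (backoff implicitly zero) or a tab and a log10 backoff, which is why the reader switches on the character that follows the word. A standalone, hedged parse of one such line; the library itself reads through util::FilePiece, not sscanf:

#include <cstdio>

// Parses "prob<TAB>word[<TAB>backoff]".  word must point at a buffer of at
// least 64 bytes (an assumption of this sketch); a missing backoff stays 0.
bool ParseUnigramLine(const char *line, float &prob, char *word, float &backoff) {
  backoff = 0.0f;
  return std::sscanf(line, "%f\t%63s\t%f", &prob, word, &backoff) >= 2;
}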
-template <class Voc> void Read1Grams(util::FilePiece &f, const size_t count, Voc &vocab, ProbBackoff *unigrams) { -  ReadNGramHeader(f, 1); -  for (size_t i = 0; i < count; ++i) { -    try { -      float prob = f.ReadFloat(); -      if (f.get() != '\t') UTIL_THROW(FormatLoadException, "Expected tab after probability"); -      ProbBackoff &value = unigrams[vocab.Insert(f.ReadDelimited())]; -      value.prob = prob; -      switch (f.get()) { -        case '\t': -          value.SetBackoff(f.ReadFloat()); -          if ((f.get() != '\n')) UTIL_THROW(FormatLoadException, "Expected newline after backoff"); -          break; -        case '\n': -          value.ZeroBackoff(); -          break; -        default: -          UTIL_THROW(FormatLoadException, "Expected tab or newline after unigram"); -      } -     } catch(util::Exception &e) { -      e << " in the " << i << "th 1-gram at byte " << f.Offset(); -      throw; -    } -  } -  if (f.ReadLine().size()) UTIL_THROW(FormatLoadException, "Expected blank line after unigrams at byte " << f.Offset()); -} - -template <class Voc, class Store> void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, Store &store) { -  ReadNGramHeader(f, n); - -  // vocab ids of words in reverse order -  WordIndex vocab_ids[n]; -  typename Store::Packing::Value value; -  for (size_t i = 0; i < count; ++i) { -    try { -      value.prob = f.ReadFloat(); -      for (WordIndex *vocab_out = &vocab_ids[n-1]; vocab_out >= vocab_ids; --vocab_out) { -        *vocab_out = vocab.Index(f.ReadDelimited()); -      } -      uint64_t key = ChainedWordHash(vocab_ids, vocab_ids + n); - -      switch (f.get()) { -        case '\t': -          value.SetBackoff(f.ReadFloat()); -          if ((f.get() != '\n')) UTIL_THROW(FormatLoadException, "Expected newline after backoff"); -          break; -        case '\n': -          value.ZeroBackoff(); -          break; -        default: -          UTIL_THROW(FormatLoadException, "Expected tab or newline after n-gram"); -      } -      store.Insert(Store::Packing::Make(key, value)); -    } catch(util::Exception &e) { -      e << " in the " << i << "th " << n << "-gram at byte " << f.Offset(); -      throw; -    } -  } - -  if (f.ReadLine().size()) UTIL_THROW(FormatLoadException, "Expected blank line after " << n << "-grams at byte " << f.Offset()); -  store.FinishedInserting(); -} - -template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<size_t> &counts, const Config &config) { -  if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ".  
Edit ngram.hh's kMaxOrder to at least this value and recompile."); -  if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model."); -  size_t memory_size = VocabularyT::Size(counts[0], config.probing_multiplier); -  memory_size += sizeof(ProbBackoff) * (counts[0] + 1); // +1 for hallucinate <unk> -  for (unsigned char n = 2; n < counts.size(); ++n) { -    memory_size += Middle::Size(counts[n - 1], config.probing_multiplier); -  } -  memory_size += Longest::Size(counts.back(), config.probing_multiplier); -  return memory_size; -} - -template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::SetupMemory(char *base, const std::vector<size_t> &counts, const Config &config) { -  char *start = base; -  size_t allocated = VocabularyT::Size(counts[0], config.probing_multiplier); -  vocab_.Init(start, allocated, counts[0]); -  start += allocated; -  unigram_ = reinterpret_cast<ProbBackoff*>(start); -  start += sizeof(ProbBackoff) * (counts[0] + 1); -  for (unsigned int n = 2; n < counts.size(); ++n) { -    allocated = Middle::Size(counts[n - 1], config.probing_multiplier); -    middle_.push_back(Middle(start, allocated)); -    start += allocated; -  } -  allocated = Longest::Size(counts.back(), config.probing_multiplier); -  longest_ = Longest(start, allocated); -  start += allocated; -  if (static_cast<std::size_t>(start - base) != Size(counts, config)) UTIL_THROW(FormatLoadException, "The data structures took " << (start - base) << " but Size says they should take " << Size(counts, config)); -} - -const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 0\n\0"; -struct BinaryFileHeader { -  char magic[sizeof(kMagicBytes)]; -  float zero_f, one_f, minus_half_f; -  WordIndex one_word_index, max_word_index; -  uint64_t one_uint64; - -  void SetToReference() { -    std::memcpy(magic, kMagicBytes, sizeof(magic)); -    zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5; -    one_word_index = 1; -    max_word_index = std::numeric_limits<WordIndex>::max(); -    one_uint64 = 1; -  } -}; - -bool IsBinaryFormat(int fd, off_t size) { -  if (size == util::kBadSize || (size <= static_cast<off_t>(sizeof(BinaryFileHeader)))) return false; -  // Try reading the header.   -  util::scoped_mmap memory(mmap(NULL, sizeof(BinaryFileHeader), PROT_READ, MAP_FILE | MAP_PRIVATE, fd, 0), sizeof(BinaryFileHeader)); -  if (memory.get() == MAP_FAILED) return false; -  BinaryFileHeader reference_header = BinaryFileHeader(); -  reference_header.SetToReference(); -  if (!memcmp(memory.get(), &reference_header, sizeof(BinaryFileHeader))) return true; -  if (!memcmp(memory.get(), "mmap lm ", 8)) UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match.  
Was it built on a different machine or with a different compiler?"); -  return false; -} - -std::size_t Align8(std::size_t in) { -  std::size_t off = in % 8; -  if (!off) return in; -  return in + 8 - off; -} - -std::size_t TotalHeaderSize(unsigned int order) { -  return Align8(sizeof(BinaryFileHeader) + 1 /* order */ + sizeof(uint64_t) * order /* counts */ + sizeof(float) /* probing multiplier */ + 1 /* search_tag */); -} - -void ReadBinaryHeader(const void *from, off_t size, std::vector<size_t> &out, float &probing_multiplier, unsigned char &search_tag) { -  const char *from_char = reinterpret_cast<const char*>(from); -  if (size < static_cast<off_t>(1 + sizeof(BinaryFileHeader))) UTIL_THROW(FormatLoadException, "File too short to have count information."); -  // Skip over the BinaryFileHeader which was read by IsBinaryFormat.   -  from_char += sizeof(BinaryFileHeader); -  unsigned char order = *reinterpret_cast<const unsigned char*>(from_char); -  if (size < static_cast<off_t>(TotalHeaderSize(order))) UTIL_THROW(FormatLoadException, "File too short to have full header."); -  out.resize(static_cast<std::size_t>(order)); -  const uint64_t *counts = reinterpret_cast<const uint64_t*>(from_char + 1); -  for (std::size_t i = 0; i < out.size(); ++i) { -    out[i] = static_cast<std::size_t>(counts[i]); -  } -  const float *probing_ptr = reinterpret_cast<const float*>(counts + out.size()); -  probing_multiplier = *probing_ptr; -  search_tag = *reinterpret_cast<const char*>(probing_ptr + 1); -} - -void WriteBinaryHeader(void *to, const std::vector<size_t> &from, float probing_multiplier, char search_tag) { -  BinaryFileHeader header = BinaryFileHeader(); -  header.SetToReference(); -  memcpy(to, &header, sizeof(BinaryFileHeader)); -  char *out = reinterpret_cast<char*>(to) + sizeof(BinaryFileHeader); -  *reinterpret_cast<unsigned char*>(out) = static_cast<unsigned char>(from.size()); -  uint64_t *counts = reinterpret_cast<uint64_t*>(out + 1); -  for (std::size_t i = 0; i < from.size(); ++i) { -    counts[i] = from[i]; -  } -  float *probing_ptr = reinterpret_cast<float*>(counts + from.size()); -  *probing_ptr = probing_multiplier; -  *reinterpret_cast<char*>(probing_ptr + 1) = search_tag; -} - -template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, Config config) : mapped_file_(util::OpenReadOrThrow(file)) { -  const off_t file_size = util::SizeFile(mapped_file_.get()); - -  std::vector<size_t> counts; - -  if (IsBinaryFormat(mapped_file_.get(), file_size)) { -    memory_.reset(util::MapForRead(file_size, config.prefault, mapped_file_.get()), file_size); - -    unsigned char search_tag; -    ReadBinaryHeader(memory_.begin(), file_size, counts, config.probing_multiplier, search_tag); -    if (config.probing_multiplier < 1.0) UTIL_THROW(FormatLoadException, "Binary format claims to have a probing multiplier of " << config.probing_multiplier << " which is < 1.0."); -    if (search_tag != Search::kBinaryTag) UTIL_THROW(FormatLoadException, "The binary file has a different search strategy than the one requested."); -    size_t memory_size = Size(counts, config); - -    char *start = reinterpret_cast<char*>(memory_.get()) + TotalHeaderSize(counts.size()); -    if (memory_size != static_cast<size_t>(memory_.end() - start)) UTIL_THROW(FormatLoadException, "The mmap file " << file << " has size " << file_size << " but " << (memory_size + TotalHeaderSize(counts.size())) << " was expected based on the number of counts and configuration."); - -    
SetupMemory(start, counts, config); -    vocab_.LoadedBinary(); -    for (typename std::vector<Middle>::iterator i = middle_.begin(); i != middle_.end(); ++i) { -      i->LoadedBinary(); -    } -    longest_.LoadedBinary(); - -  } else { -    if (config.probing_multiplier <= 1.0) UTIL_THROW(FormatLoadException, "probing multiplier must be > 1.0"); - -    util::FilePiece f(file, mapped_file_.release(), config.messages); -    ReadARPACounts(f, counts); -    size_t memory_size = Size(counts, config); -    char *start; - -    if (config.write_mmap) { -      // Write out an mmap file.   -      util::MapZeroedWrite(config.write_mmap, TotalHeaderSize(counts.size()) + memory_size, mapped_file_, memory_); -      WriteBinaryHeader(memory_.get(), counts, config.probing_multiplier, Search::kBinaryTag); -      start = reinterpret_cast<char*>(memory_.get()) + TotalHeaderSize(counts.size()); -    } else { -      memory_.reset(util::MapAnonymous(memory_size), memory_size); -      start = reinterpret_cast<char*>(memory_.get()); -    } -    SetupMemory(start, counts, config); -    try { -      LoadFromARPA(f, counts, config); -    } catch (FormatLoadException &e) { -      e << " in file " << file; -      throw; -    } -  } - -  // g++ prints warnings unless these are fully initialized.   -  State begin_sentence = State(); -  begin_sentence.valid_length_ = 1; -  begin_sentence.history_[0] = vocab_.BeginSentence(); -  begin_sentence.backoff_[0] = unigram_[begin_sentence.history_[0]].backoff; -  State null_context = State(); -  null_context.valid_length_ = 0; -  P::Init(begin_sentence, null_context, vocab_, counts.size()); -} - -template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::LoadFromARPA(util::FilePiece &f, const std::vector<size_t> &counts, const Config &config) { -  // Read the unigrams. -  Read1Grams(f, counts[0], vocab_, unigram_); -  bool saw_unk = vocab_.FinishedLoading(unigram_); -  if (!saw_unk) { -    switch(config.unknown_missing) { -      case Config::THROW_UP: -        { -          SpecialWordMissingException e("<unk>"); -          e << " and configuration was set to throw if unknown is missing"; -          throw e; -        } -      case Config::COMPLAIN: -        if (config.messages) *config.messages << "Language model is missing <unk>.  Substituting probability " << config.unknown_missing_prob << "." << std::endl;  -        // There's no break;.  This is by design.   -      case Config::SILENT: -        // Default probabilities for unknown.   -        unigram_[0].backoff = 0.0; -        unigram_[0].prob = config.unknown_missing_prob; -        break; -    } -  } -   -  // Read the n-grams. -  for (unsigned int n = 2; n < counts.size(); ++n) { -    ReadNGrams(f, n, counts[n-1], vocab_, middle_[n-2]); -  } -  ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab_, longest_); -  if (std::fabs(unigram_[0].backoff) > 0.0000001) UTIL_THROW(FormatLoadException, "Backoff for unknown word should be zero, but was given as " << unigram_[0].backoff); -} - -/* Ugly optimized function. - * in_state contains the previous ngram's length and backoff probabilites to - * be used here.  out_state is populated with the found ngram length and - * backoffs that the next call will find useful.   - * - * The search goes in increasing order of ngram length.   
- */ -template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore( -    const State &in_state, -    const WordIndex new_word, -    State &out_state) const { - -  FullScoreReturn ret; -  // This is end pointer passed to SumBackoffs. -  const ProbBackoff &unigram = unigram_[new_word]; -  if (new_word == 0) { -    ret.ngram_length = out_state.valid_length_ = 0; -    // all of backoff. -    ret.prob = std::accumulate( -        in_state.backoff_, -        in_state.backoff_ + in_state.valid_length_, -        unigram.prob); -    return ret; -  } -  float *backoff_out(out_state.backoff_); -  *backoff_out = unigram.backoff; -  ret.prob = unigram.prob; -  out_state.history_[0] = new_word; -  if (in_state.valid_length_ == 0) { -    ret.ngram_length = out_state.valid_length_ = 1; -    // No backoff because NGramLength() == 0 and unknown can't have backoff. -    return ret; -  } -  ++backoff_out; - -  // Ok now we now that the bigram contains known words.  Start by looking it up. - -  uint64_t lookup_hash = static_cast<uint64_t>(new_word); -  const WordIndex *hist_iter = in_state.history_; -  const WordIndex *const hist_end = hist_iter + in_state.valid_length_; -  typename std::vector<Middle>::const_iterator mid_iter = middle_.begin(); -  for (; ; ++mid_iter, ++hist_iter, ++backoff_out) { -    if (hist_iter == hist_end) { -      // Used history [in_state.history_, hist_end) and ran out.  No backoff.   -      std::copy(in_state.history_, hist_end, out_state.history_ + 1); -      ret.ngram_length = out_state.valid_length_ = in_state.valid_length_ + 1; -      // ret.prob was already set. -      return ret; -    } -    lookup_hash = CombineWordHash(lookup_hash, *hist_iter); -    if (mid_iter == middle_.end()) break; -    typename Middle::ConstIterator found; -    if (!mid_iter->Find(lookup_hash, found)) { -      // Didn't find an ngram using hist_iter.   -      // The history used in the found n-gram is [in_state.history_, hist_iter).   -      std::copy(in_state.history_, hist_iter, out_state.history_ + 1); -      // Therefore, we found a (hist_iter - in_state.history_ + 1)-gram including the last word.   -      ret.ngram_length = out_state.valid_length_ = (hist_iter - in_state.history_) + 1; -      ret.prob = std::accumulate( -          in_state.backoff_ + (mid_iter - middle_.begin()), -          in_state.backoff_ + in_state.valid_length_, -          ret.prob); -      return ret; -    } -    *backoff_out = found->GetValue().backoff; -    ret.prob = found->GetValue().prob; -  } -   -  typename Longest::ConstIterator found; -  if (!longest_.Find(lookup_hash, found)) { -    // It's an (P::Order()-1)-gram -    std::copy(in_state.history_, in_state.history_ + P::Order() - 2, out_state.history_ + 1); -    ret.ngram_length = out_state.valid_length_ = P::Order() - 1; -    ret.prob += in_state.backoff_[P::Order() - 2]; -    return ret; -  } -  // It's an P::Order()-gram -  // out_state.valid_length_ is still P::Order() - 1 because the next lookup will only need that much. 
-  std::copy(in_state.history_, in_state.history_ + P::Order() - 2, out_state.history_ + 1); -  out_state.valid_length_ = P::Order() - 1; -  ret.ngram_length = P::Order(); -  ret.prob = found->GetValue().prob; -  return ret; -} - -template class GenericModel<ProbingSearch, MapVocabulary<ProbingSearch> >; -template class GenericModel<SortedUniformSearch, SortedVocabulary>; -} // namespace detail -} // namespace ngram -} // namespace lm diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc index d1970260..74457a74 100644 --- a/klm/lm/ngram_query.cc +++ b/klm/lm/ngram_query.cc @@ -1,4 +1,4 @@ -#include "lm/ngram.hh" +#include "lm/model.hh"  #include <cstdlib>  #include <fstream> diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc new file mode 100644 index 00000000..8e9a770d --- /dev/null +++ b/klm/lm/read_arpa.cc @@ -0,0 +1,154 @@ +#include "lm/read_arpa.hh" + +#include <cstdlib> +#include <vector> + +#include <ctype.h> +#include <inttypes.h> + +namespace lm { + +namespace { + +bool IsEntirelyWhiteSpace(const StringPiece &line) { +  for (size_t i = 0; i < static_cast<size_t>(line.size()); ++i) { +    if (!isspace(line.data()[i])) return false; +  } +  return true; +} + +template <class F> void GenericReadARPACounts(F &in, std::vector<uint64_t> &number) { +  number.clear(); +  StringPiece line; +  if (!IsEntirelyWhiteSpace(line = in.ReadLine())) { +    if ((line.size() >= 2) && (line.data()[0] == 0x1f) && (static_cast<unsigned char>(line.data()[1]) == 0x8b)) { +      UTIL_THROW(FormatLoadException, "Looks like a gzip file.  If this is an ARPA file, run\nzcat " << in.FileName() << " |kenlm/build_binary /dev/stdin " << in.FileName() << ".binary\nIf this is already in binary format, you need to decompress it because mmap doesn't work on top of gzip."); +    } +    UTIL_THROW(FormatLoadException, "First line was \"" << static_cast<int>(line.data()[1]) << "\" not blank"); +  } +  if ((line = in.ReadLine()) != "\\data\\") UTIL_THROW(FormatLoadException, "second line was \"" << line << "\" not \\data\\."); +  while (!IsEntirelyWhiteSpace(line = in.ReadLine())) { +    if (line.size() < 6 || strncmp(line.data(), "ngram ", 6)) UTIL_THROW(FormatLoadException, "count line \"" << line << "\" doesn't begin with \"ngram \""); +    // So strtol doesn't go off the end of line.   +    std::string remaining(line.data() + 6, line.size() - 6); +    char *end_ptr; +    unsigned long int length = std::strtol(remaining.c_str(), &end_ptr, 10); +    if ((end_ptr == remaining.c_str()) || (length - 1 != number.size())) UTIL_THROW(FormatLoadException, "ngram count lengths should be consecutive starting with 1: " << line); +    if (*end_ptr != '=') UTIL_THROW(FormatLoadException, "Expected = immediately following the first number in the count line " << line); +    ++end_ptr; +    const char *start = end_ptr; +    long int count = std::strtol(start, &end_ptr, 10); +    if (count < 0) UTIL_THROW(FormatLoadException, "Negative n-gram count " << count); +    if (start == end_ptr) UTIL_THROW(FormatLoadException, "Couldn't parse n-gram count from " << line); +    number.push_back(count); +  } +} + +template <class F> void GenericReadNGramHeader(F &in, unsigned int length) { +  StringPiece line; +  while (IsEntirelyWhiteSpace(line = in.ReadLine())) {} +  std::stringstream expected; +  expected << '\\' << length << "-grams:"; +  if (line != expected.str()) UTIL_THROW(FormatLoadException, "Was expecting n-gram header " << expected.str() << " but got " << line << " instead."); +} + +template <class F> void GenericReadEnd(F &in) { +  StringPiece line; +  do { +    line = in.ReadLine(); +  } while (IsEntirelyWhiteSpace(line)); +  if (line != "\\end\\") UTIL_THROW(FormatLoadException, "Expected \\end\\ but the ARPA file has " << line); +} + +class FakeFilePiece { +  public: +    explicit FakeFilePiece(std::istream &in) : in_(in) { +      in_.exceptions(std::ios::failbit | std::ios::badbit | std::ios::eofbit); +    } + +    StringPiece ReadLine() throw(util::EndOfFileException) { +      getline(in_, buffer_); +      return StringPiece(buffer_); +    } + +    float ReadFloat() { +      float ret; +      in_ >> ret; +      return ret; +    } + +    const char *FileName() const { +      // This is only used for error messages and we don't know the file name. . .  +      return "$file"; +    } + +  private: +    std::istream &in_; +    std::string buffer_; +}; + +} // namespace + +void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) { +  GenericReadARPACounts(in, number); +} +void ReadARPACounts(std::istream &in, std::vector<uint64_t> &number) { +  FakeFilePiece fake(in); +  GenericReadARPACounts(fake, number); +} +void ReadNGramHeader(util::FilePiece &in, unsigned int length) { +  GenericReadNGramHeader(in, length); +} +void ReadNGramHeader(std::istream &in, unsigned int length) { +  FakeFilePiece fake(in); +  GenericReadNGramHeader(fake, length); +} + +void ReadBackoff(util::FilePiece &in, Prob &/*weights*/) { +  switch (in.get()) { +    case '\t': +      { +        float got = in.ReadFloat(); +        if (got != 0.0) +          UTIL_THROW(FormatLoadException, "Non-zero backoff " << got << " provided for an n-gram that should have no backoff."); +      } +      break; +    case '\n': +      break; +    default: +      UTIL_THROW(FormatLoadException, "Expected tab or newline after unigram"); +  } +} + +void ReadBackoff(util::FilePiece &in, ProbBackoff &weights) { +  switch (in.get()) { +    case '\t': +      weights.backoff = in.ReadFloat(); +      if ((in.get() != '\n')) UTIL_THROW(FormatLoadException, "Expected newline after backoff"); +      break; +    case '\n': +      weights.backoff = 0.0; +      break; +    default: +      UTIL_THROW(FormatLoadException, "Expected tab or newline after unigram"); +  } +} + +void ReadEnd(util::FilePiece &in) { +  GenericReadEnd(in); +  StringPiece line; +  try { +    while (true) { +      line = in.ReadLine(); +      if (!IsEntirelyWhiteSpace(line)) UTIL_THROW(FormatLoadException, "Trailing line " << line); +    } +  } catch (const util::EndOfFileException &e) { +    return; +  } +} +void ReadEnd(std::istream &in) { +  FakeFilePiece fake(in); +  GenericReadEnd(fake); +} + +} // namespace lm diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc new file mode 100644 index 00000000..9cb662a6 --- /dev/null +++ b/klm/lm/search_hashed.cc @@ -0,0 +1,66 @@ +#include "lm/search_hashed.hh" + +#include "lm/lm_exception.hh" +#include "lm/read_arpa.hh" +#include "lm/vocab.hh" + +#include "util/file_piece.hh" + +#include <string> + +namespace lm { +namespace ngram { + +namespace { + +/* All of the entropy is in low order bits and boost::hash does poorly with + * these.  Odd numbers near 2^64 chosen by mashing on the keyboard.  There is a + * stable point: 0.  But 0 is <unk> which won't be queried here anyway.   
+ */ +inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) { +  uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast<uint64_t>(next) * 17894857484156487943ULL); +  return ret; +} + +uint64_t ChainedWordHash(const WordIndex *word, const WordIndex *word_end) { +  if (word == word_end) return 0; +  uint64_t current = static_cast<uint64_t>(*word); +  for (++word; word != word_end; ++word) { +    current = CombineWordHash(current, *word); +  } +  return current; +} + +template <class Voc, class Store> void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, Store &store) { +  ReadNGramHeader(f, n); + +  // vocab ids of words in reverse order +  WordIndex vocab_ids[n]; +  typename Store::Packing::Value value; +  for (size_t i = 0; i < count; ++i) { +    ReadNGram(f, n, vocab, vocab_ids, value); +    uint64_t key = ChainedWordHash(vocab_ids, vocab_ids + n); +    store.Insert(Store::Packing::Make(key, value)); +  } + +  store.FinishedInserting(); +} + +} // namespace +namespace detail { + +template <class MiddleT, class LongestT> template <class Voc> void TemplateHashedSearch<MiddleT, LongestT>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &/*config*/, Voc &vocab) { +  Read1Grams(f, counts[0], vocab, unigram.Raw());   +  // Read the n-grams. +  for (unsigned int n = 2; n < counts.size(); ++n) { +    ReadNGrams(f, n, counts[n-1], vocab, middle[n-2]); +  } +  ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, longest); +} + +template void TemplateHashedSearch<ProbingHashedSearch::Middle, ProbingHashedSearch::Longest>::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &, ProbingVocabulary &vocab); +template void TemplateHashedSearch<SortedHashedSearch::Middle, SortedHashedSearch::Longest>::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &, SortedVocabulary &vocab); + +} // namespace detail +} // namespace ngram +} // namespace lm diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc new file mode 100644 index 00000000..182e27f5 --- /dev/null +++ b/klm/lm/search_trie.cc @@ -0,0 +1,402 @@ +#include "lm/search_trie.hh" + +#include "lm/lm_exception.hh" +#include "lm/read_arpa.hh" +#include "lm/trie.hh" +#include "lm/vocab.hh" +#include "lm/weights.hh" +#include "lm/word_index.hh" +#include "util/ersatz_progress.hh" +#include "util/file_piece.hh" +#include "util/scoped.hh" + +#include <algorithm> +#include <cstring> +#include <cstdio> +#include <deque> +#include <iostream> +#include <limits> +//#include <parallel/algorithm> +#include <vector> + +#include <sys/mman.h> +#include <sys/types.h> +#include <sys/stat.h> +#include <fcntl.h> +#include <stdlib.h> + +namespace lm { +namespace ngram { +namespace trie { +namespace { + +template <unsigned char Order> class FullEntry { +  public: +    typedef ProbBackoff Weights; +    static const unsigned char kOrder = Order; + +    // reverse order +    WordIndex words[Order]; +    Weights weights; + +    bool operator<(const FullEntry<Order> &other) const { +      for (const WordIndex *i = words, *j = other.words; i != words + Order; ++i, ++j) { +        if (*i < *j) return true; +        if (*i > *j) return false; +      } +      return false; +    } +}; + +template <unsigned char Order> class ProbEntry { +  public: +    typedef Prob Weights; +    static const unsigned char kOrder = Order; + +    // reverse 
order +    WordIndex words[Order]; +    Weights weights; + +    bool operator<(const ProbEntry<Order> &other) const { +      for (const WordIndex *i = words, *j = other.words; i != words + Order; ++i, ++j) { +        if (*i < *j) return true; +        if (*i > *j) return false; +      } +      return false; +    } +}; + +void WriteOrThrow(FILE *to, const void *data, size_t size) { +  if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size); +} + +void ReadOrThrow(FILE *from, void *data, size_t size) { +  if (1 != std::fread(data, size, 1, from)) UTIL_THROW(util::ErrnoException, "Short read; requested size " << size); +} + +void CopyOrThrow(FILE *from, FILE *to, size_t size) { +  const size_t kBufSize = 512; +  char buf[kBufSize]; +  for (size_t i = 0; i < size; i += kBufSize) { +    std::size_t amount = std::min(size - i, kBufSize); +    ReadOrThrow(from, buf, amount); +    WriteOrThrow(to, buf, amount); +  } +} + +template <class Entry> std::string DiskFlush(const Entry *begin, const Entry *end, const std::string &file_prefix, std::size_t batch) { +  std::stringstream assembled; +  assembled << file_prefix << static_cast<unsigned int>(Entry::kOrder) << '_' << batch; +  std::string ret(assembled.str()); +  util::scoped_FILE out(fopen(ret.c_str(), "w")); +  if (!out.get()) UTIL_THROW(util::ErrnoException, "Couldn't open " << assembled.str().c_str() << " for writing"); +  for (const Entry *group_begin = begin; group_begin != end;) { +    const Entry *group_end = group_begin; +    for (++group_end; (group_end != end) && !memcmp(group_begin->words, group_end->words, sizeof(WordIndex) * (Entry::kOrder - 1)); ++group_end) {} +    WriteOrThrow(out.get(), group_begin->words, sizeof(WordIndex) * (Entry::kOrder - 1)); +    WordIndex group_size = group_end - group_begin; +    WriteOrThrow(out.get(), &group_size, sizeof(group_size)); +    for (const Entry *i = group_begin; i != group_end; ++i) { +      WriteOrThrow(out.get(), &i->words[Entry::kOrder - 1], sizeof(WordIndex)); +      WriteOrThrow(out.get(), &i->weights, sizeof(typename Entry::Weights)); +    } +    group_begin = group_end; +  } +  return ret; +} + +class SortedFileReader { +  public: +    SortedFileReader() {} + +    void Init(const std::string &name, unsigned char order) { +      file_.reset(fopen(name.c_str(), "r")); +      if (!file_.get()) UTIL_THROW(util::ErrnoException, "Opening " << name << " for read"); +      header_.resize(order - 1); +      NextHeader(); +    } + +    // Preceding words. 
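+    // A record on disk (as written by DiskFlush) is: the (order - 1) shared
+    // context words, a WordIndex count, then count (final word, weights) pairs.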
+    const WordIndex *Header() const { +      return &*header_.begin(); +    } +    const std::vector<WordIndex> &HeaderVector() const { return header_;}  + +    std::size_t HeaderBytes() const { return header_.size() * sizeof(WordIndex); } + +    void NextHeader() { +      if (1 != fread(&*header_.begin(), HeaderBytes(), 1, file_.get()) && !Ended()) { +        UTIL_THROW(util::ErrnoException, "Short read of counts"); +      } +    } + +    void ReadCount(WordIndex &to) { +      ReadOrThrow(file_.get(), &to, sizeof(WordIndex)); +    } + +    void ReadWord(WordIndex &to) { +      ReadOrThrow(file_.get(), &to, sizeof(WordIndex)); +    } + +    template <class Weights> void ReadWeights(Weights &to) { +      ReadOrThrow(file_.get(), &to, sizeof(Weights)); +    } + +    bool Ended() { +      return feof(file_.get()); +    } + +    FILE *File() { return file_.get(); } + +  private: +    util::scoped_FILE file_; + +    std::vector<WordIndex> header_; +}; + +void CopyFullRecord(SortedFileReader &from, FILE *to, std::size_t weights_size) { +  WriteOrThrow(to, from.Header(), from.HeaderBytes()); +  WordIndex count; +  from.ReadCount(count); +  WriteOrThrow(to, &count, sizeof(WordIndex)); + +  CopyOrThrow(from.File(), to, (weights_size + sizeof(WordIndex)) * count); +} + +void MergeSortedFiles(const char *first_name, const char *second_name, const char *out, std::size_t weights_size, unsigned char order) { +  SortedFileReader first, second; +  first.Init(first_name, order); +  second.Init(second_name, order); +  util::scoped_FILE out_file(fopen(out, "w")); +  if (!out_file.get()) UTIL_THROW(util::ErrnoException, "Could not open " << out << " for write"); +  while (!first.Ended() && !second.Ended()) { +    if (first.HeaderVector() < second.HeaderVector()) { +      CopyFullRecord(first, out_file.get(), weights_size); +      first.NextHeader(); +      continue; +    }  +    if (first.HeaderVector() > second.HeaderVector()) { +      CopyFullRecord(second, out_file.get(), weights_size); +      second.NextHeader(); +      continue; +    } +    // Merge at the entry level. 
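+    // Both records share the same (order - 1) word context, so write a single
+    // header with the combined count and interleave the final words in order.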
+    WriteOrThrow(out_file.get(), first.Header(), first.HeaderBytes()); +    WordIndex first_count, second_count; +    first.ReadCount(first_count); second.ReadCount(second_count); +    WordIndex total_count = first_count + second_count; +    WriteOrThrow(out_file.get(), &total_count, sizeof(WordIndex)); + +    WordIndex first_word, second_word; +    first.ReadWord(first_word); second.ReadWord(second_word); +    WordIndex first_index = 0, second_index = 0; +    while (true) { +      if (first_word < second_word) { +        WriteOrThrow(out_file.get(), &first_word, sizeof(WordIndex)); +        CopyOrThrow(first.File(), out_file.get(), weights_size); +        if (++first_index == first_count) break; +        first.ReadWord(first_word); +      } else { +        WriteOrThrow(out_file.get(), &second_word, sizeof(WordIndex)); +        CopyOrThrow(second.File(), out_file.get(), weights_size); +        if (++second_index == second_count) break; +        second.ReadWord(second_word); +      } +    } +    if (first_index == first_count) { +      WriteOrThrow(out_file.get(), &second_word, sizeof(WordIndex)); +      CopyOrThrow(second.File(), out_file.get(), (second_count - second_index) * (weights_size + sizeof(WordIndex)) - sizeof(WordIndex)); +    } else { +      WriteOrThrow(out_file.get(), &first_word, sizeof(WordIndex)); +      CopyOrThrow(first.File(), out_file.get(), (first_count - first_index) * (weights_size + sizeof(WordIndex)) - sizeof(WordIndex)); +    } +    first.NextHeader(); +    second.NextHeader(); +  } + +  for (SortedFileReader &remaining = first.Ended() ? second : first; !remaining.Ended(); remaining.NextHeader()) { +    CopyFullRecord(remaining, out_file.get(), weights_size); +  } +} + +template <class Entry> void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, util::scoped_memory &mem, const std::string &file_prefix) { +  ConvertToSorted<FullEntry<Entry::kOrder - 1> >(f, vocab, counts, mem, file_prefix); + +  ReadNGramHeader(f, Entry::kOrder); +  const size_t count = counts[Entry::kOrder - 1]; +  const size_t batch_size = std::min(count, mem.size() / sizeof(Entry)); +  Entry *const begin = reinterpret_cast<Entry*>(mem.get()); +  std::deque<std::string> files; +  for (std::size_t batch = 0, done = 0; done < count; ++batch) { +    Entry *out = begin; +    Entry *out_end = out + std::min(count - done, batch_size); +    for (; out != out_end; ++out) { +      ReadNGram(f, Entry::kOrder, vocab, out->words, out->weights); +    } +    //__gnu_parallel::sort(begin, out_end); +    std::sort(begin, out_end); +     +    files.push_back(DiskFlush(begin, out_end, file_prefix, batch)); +    done += out_end - begin; +  } + +  // All individual files created.  Merge them.   
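+  // Illustrative schedule, assuming three flushed batches of 3-grams (file
+  // prefix omitted): {3_0, 3_1, 3_2} -> 3_merge_0 = merge(3_0, 3_1)
+  // -> {3_2, 3_merge_0} -> 3_merge_1 = merge(3_2, 3_merge_0) -> {3_merge_1},
+  // which the rename below turns into 3_merged.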
+ +  std::size_t merge_count = 0; +  while (files.size() > 1) { +    std::stringstream assembled; +    assembled << file_prefix << static_cast<unsigned int>(Entry::kOrder) << "_merge_" << (merge_count++); +    files.push_back(assembled.str()); +    MergeSortedFiles(files[0].c_str(), files[1].c_str(), files.back().c_str(), sizeof(typename Entry::Weights), Entry::kOrder); +    if (std::remove(files[0].c_str())) UTIL_THROW(util::ErrnoException, "Could not remove " << files[0]); +    files.pop_front(); +    if (std::remove(files[0].c_str())) UTIL_THROW(util::ErrnoException, "Could not remove " << files[0]); +    files.pop_front(); +  } +  if (!files.empty()) { +    std::stringstream assembled; +    assembled << file_prefix << static_cast<unsigned int>(Entry::kOrder) << "_merged"; +    std::string merged_name(assembled.str()); +    if (std::rename(files[0].c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << files[0].c_str() << " to " << merged_name.c_str()); +  } +} + +template <> void ConvertToSorted<FullEntry<1> >(util::FilePiece &/*f*/, const SortedVocabulary &/*vocab*/, const std::vector<uint64_t> &/*counts*/, util::scoped_memory &/*mem*/, const std::string &/*file_prefix*/) {} + +void ARPAToSortedFiles(util::FilePiece &f, const std::vector<uint64_t> &counts, std::size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { +  { +    std::string unigram_name = file_prefix + "unigrams"; +    util::scoped_fd unigram_file; +    util::scoped_mmap unigram_mmap; +    unigram_mmap.reset(util::MapZeroedWrite(unigram_name.c_str(), counts[0] * sizeof(ProbBackoff), unigram_file), counts[0] * sizeof(ProbBackoff)); +    Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get())); +  } + +  util::scoped_memory mem; +  mem.reset(malloc(buffer), buffer, util::scoped_memory::ARRAY_ALLOCATED); +  if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer); +  ConvertToSorted<ProbEntry<5> >(f, vocab, counts, mem, file_prefix); +  ReadEnd(f); +} + +struct RecursiveInsertParams { +  WordIndex *words; +  SortedFileReader *files; +  unsigned char max_order; +  // This is an array of size order - 2. +  BitPackedMiddle *middle; +  // This has exactly one entry. +  BitPackedLongest *longest; +}; + +uint64_t RecursiveInsert(RecursiveInsertParams ¶ms, unsigned char order) { +  SortedFileReader &file = params.files[order - 2]; +  const uint64_t ret = (order == params.max_order) ? params.longest->InsertIndex() : params.middle[order - 2].InsertIndex(); +  if (std::memcmp(params.words, file.Header(), sizeof(WordIndex) * (order - 1))) +    return ret; +  WordIndex count; +  file.ReadCount(count); +  WordIndex key; +  if (order == params.max_order) { +    Prob value; +    for (WordIndex i = 0; i < count; ++i) { +      file.ReadWord(key); +      file.ReadWeights(value); +      params.longest->Insert(key, value.prob); +    } +    file.NextHeader(); +    return ret; +  } +  ProbBackoff value; +  for (WordIndex i = 0; i < count; ++i) { +    file.ReadWord(params.words[order - 1]); +    file.ReadWeights(value); +    params.middle[order - 2].Insert( +        params.words[order - 1], +        value.prob, +        value.backoff, +        RecursiveInsert(params, order + 1)); +  } +  file.NextHeader(); +  return ret; +} + +void BuildTrie(const std::string &file_prefix, const std::vector<uint64_t> &counts, std::ostream *messages, TrieSearch &out) { +  UnigramValue *unigrams = out.unigram.Raw(); +  // Load unigrams.  
Leave the next pointers uninitialized.
+  {
+    std::string name(file_prefix + "unigrams");
+    util::scoped_FILE file(fopen(name.c_str(), "r"));
+    if (!file.get()) UTIL_THROW(util::ErrnoException, "Opening " << name << " failed");
+    for (WordIndex i = 0; i < counts[0]; ++i) {
+      ReadOrThrow(file.get(), &unigrams[i].weights, sizeof(ProbBackoff));
+    }
+    unlink(name.c_str());
+  }
+
+  // inputs[0] is bigrams.
+  SortedFileReader inputs[counts.size() - 1];
+  for (unsigned char i = 2; i <= counts.size(); ++i) {
+    std::stringstream assembled;
+    assembled << file_prefix << static_cast<unsigned int>(i) << "_merged";
+    inputs[i-2].Init(assembled.str(), i);
+    unlink(assembled.str().c_str());
+  }
+
+  // words[0] is unigrams.
+  WordIndex words[counts.size()];
+  RecursiveInsertParams params;
+  params.words = words;
+  params.files = inputs;
+  params.max_order = static_cast<unsigned char>(counts.size());
+  params.middle = &*out.middle.begin();
+  params.longest = &out.longest;
+  {
+    util::ErsatzProgress progress(messages, "Building trie", counts[0]);
+    for (words[0] = 0; words[0] < counts[0]; ++words[0], ++progress) {
+      unigrams[words[0]].next = RecursiveInsert(params, 2);
+    }
+  }
+
+  /* Set ending offsets so the last entry will be sized properly */
+  if (!out.middle.empty()) {
+    unigrams[counts[0]].next = out.middle.front().InsertIndex();
+    for (size_t i = 0; i < out.middle.size() - 1; ++i) {
+      out.middle[i].FinishedLoading(out.middle[i+1].InsertIndex());
+    }
+    out.middle.back().FinishedLoading(out.longest.InsertIndex());
+  } else {
+    unigrams[counts[0]].next = out.longest.InsertIndex();
+  }
+}
+
+} // namespace
+
+void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab) {
+  std::string temporary_directory;
+  if (config.temporary_directory_prefix) {
+    temporary_directory = config.temporary_directory_prefix;
+  } else if (config.write_mmap) {
+    temporary_directory = config.write_mmap;
+  } else {
+    temporary_directory = file;
+  }
+  // Null on the end is a kludge so mkdtemp has a terminated template it can
+  // edit in place.  Append with an explicit length because += on a string
+  // literal would stop at the null byte.
+  temporary_directory.append("-tmp-XXXXXX", 12);
+  if (!mkdtemp(&temporary_directory[0])) {
+    UTIL_THROW(util::ErrnoException, "Failed to make a temporary directory based on the name " << temporary_directory.c_str());
+  }
+  // Chop off null kludge.
+  temporary_directory.resize(strlen(temporary_directory.c_str()));
+  // Add directory delimiter.  Assumes a real operating system.
+  temporary_directory += '/'; +  ARPAToSortedFiles(f, counts, config.building_memory, temporary_directory.c_str(), vocab); +  BuildTrie(temporary_directory.c_str(), counts, config.messages, *this); +  if (rmdir(temporary_directory.c_str())) { +    std::cerr << "Failed to delete " << temporary_directory << std::endl; +  } +} + +} // namespace trie +} // namespace ngram +} // namespace lm diff --git a/klm/lm/sri.cc b/klm/lm/sri.cc index 7bd23d76..b634d200 100644 --- a/klm/lm/sri.cc +++ b/klm/lm/sri.cc @@ -1,4 +1,4 @@ -#include "lm/exception.hh" +#include "lm/lm_exception.hh"  #include "lm/sri.hh"  #include <Ngram.h> @@ -31,8 +31,7 @@ void Vocabulary::FinishedLoading() {    SetSpecial(      sri_->ssIndex(),      sri_->seIndex(), -    sri_->unkIndex(), -    sri_->highIndex() + 1); +    sri_->unkIndex());  }  namespace { diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc new file mode 100644 index 00000000..8ed7b2a2 --- /dev/null +++ b/klm/lm/trie.cc @@ -0,0 +1,167 @@ +#include "lm/trie.hh" + +#include "util/bit_packing.hh" +#include "util/exception.hh" +#include "util/proxy_iterator.hh" +#include "util/sorted_uniform.hh" + +#include <assert.h> + +namespace lm { +namespace ngram { +namespace trie { +namespace { + +// Assumes key is first.   +class JustKeyProxy { +  public: +    JustKeyProxy() : inner_(), base_(), key_mask_(), total_bits_() {} + +    operator uint64_t() const { return GetKey(); } + +    uint64_t GetKey() const { +      uint64_t bit_off = inner_ * static_cast<uint64_t>(total_bits_); +      return util::ReadInt57(base_ + bit_off / 8, bit_off & 7, key_mask_); +    } + +  private: +    friend class util::ProxyIterator<JustKeyProxy>; +    friend bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index); + +    JustKeyProxy(const void *base, uint64_t index, uint64_t key_mask, uint8_t total_bits) +      : inner_(index), base_(static_cast<const uint8_t*>(base)), key_mask_(key_mask), total_bits_(total_bits) {} + +    // This is a read-only iterator.   +    JustKeyProxy &operator=(const JustKeyProxy &other); + +    typedef uint64_t value_type; + +    typedef uint64_t InnerIterator; +    uint64_t &Inner() { return inner_; } +    const uint64_t &Inner() const { return inner_; } + +    // The address in bits is base_ * 8 + inner_ * total_bits_.   +    uint64_t inner_; +    const uint8_t *const base_; +    const uint64_t key_mask_; +    const uint8_t total_bits_; +}; + +bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index) { +  util::ProxyIterator<JustKeyProxy> begin_it(JustKeyProxy(base, begin_index, key_mask, total_bits)); +  util::ProxyIterator<JustKeyProxy> end_it(JustKeyProxy(base, end_index, key_mask, total_bits)); +  util::ProxyIterator<JustKeyProxy> out; +  if (!util::SortedUniformFind(begin_it, end_it, key, out)) return false; +  at_index = out.Inner(); +  return true; +} +} // namespace + +std::size_t BitPacked::BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits) { +  uint8_t total_bits = util::RequiredBits(max_vocab) + 31 + remaining_bits; +  // Extra entry for next pointer at the end.   +  // +7 then / 8 to round up bits and convert to bytes +  // +sizeof(uint64_t) so that ReadInt57 etc don't go segfault.   +  // Note that this waste is O(order), not O(number of ngrams). 
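+  // Worked example with hypothetical numbers: 1000 entries of 41 total bits
+  // each gives ((1 + 1000) * 41 + 7) / 8 + 8 = 5139 bytes.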
+  return ((1 + entries) * total_bits + 7) / 8 + sizeof(uint64_t); +} + +void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits) { +  util::BitPackingSanity(); +  word_bits_ = util::RequiredBits(max_vocab); +  word_mask_ = (1ULL << word_bits_) - 1ULL; +  if (word_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, word indices more than " << (1ULL << 57) << " are not implemented.  Edit util/bit_packing.hh and fix the bit packing functions."); +  prob_bits_ = 31; +  total_bits_ = word_bits_ + prob_bits_ + remaining_bits; + +  base_ = static_cast<uint8_t*>(base); +  insert_index_ = 0; +} + +std::size_t BitPackedMiddle::Size(uint64_t entries, uint64_t max_vocab, uint64_t max_ptr) { +  return BaseSize(entries, max_vocab, 32 + util::RequiredBits(max_ptr)); +} + +void BitPackedMiddle::Init(void *base, uint64_t max_vocab, uint64_t max_next) { +  backoff_bits_ = 32; +  next_bits_ = util::RequiredBits(max_next); +  if (next_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order.  Edit util/bit_packing.hh and fix the bit packing functions."); +  next_mask_ = (1ULL << next_bits_)  - 1; + +  BaseInit(base, max_vocab, backoff_bits_ + next_bits_); +} + +void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff, uint64_t next) { +  assert(word <= word_mask_); +  assert(next <= next_mask_); +  uint64_t at_pointer = insert_index_ * total_bits_; + +  util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word); +  at_pointer += word_bits_; +  util::WriteNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7, prob); +  at_pointer += prob_bits_; +  util::WriteFloat32(base_ + (at_pointer >> 3), at_pointer & 7, backoff); +  at_pointer += backoff_bits_; +  util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, next); + +  ++insert_index_; +} + +bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const { +  uint64_t at_pointer; +  if (!FindBitPacked(base_, word_mask_, total_bits_, range.begin, range.end, word, at_pointer)) return false; +  at_pointer *= total_bits_; +  at_pointer += word_bits_; +  prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7); +  at_pointer += prob_bits_; +  backoff = util::ReadFloat32(base_ + (at_pointer >> 3), at_pointer & 7); +  at_pointer += backoff_bits_; +  range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_); +  // Read the next entry's pointer.   +  at_pointer += total_bits_; +  range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_); +  return true; +} + +bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const { +  uint64_t at_pointer; +  if (!FindBitPacked(base_, word_mask_, total_bits_, range.begin, range.end, word, at_pointer)) return false; +  at_pointer *= total_bits_; +  at_pointer += word_bits_; +  at_pointer += prob_bits_; +  backoff = util::ReadFloat32(base_ + (at_pointer >> 3), at_pointer & 7); +  at_pointer += backoff_bits_; +  range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_); +  // Read the next entry's pointer.   
+  at_pointer += total_bits_; +  range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_); +  return true; +} + +void BitPackedMiddle::FinishedLoading(uint64_t next_end) { +  assert(next_end <= next_mask_); +  uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - next_bits_; +  util::WriteInt57(base_ + (last_next_write >> 3), last_next_write & 7, next_end); +} + + +void BitPackedLongest::Insert(WordIndex index, float prob) { +  assert(index <= word_mask_); +  uint64_t at_pointer = insert_index_ * total_bits_; +  util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, index); +  at_pointer += word_bits_; +  util::WriteNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7, prob); +  ++insert_index_; +} + +bool BitPackedLongest::Find(WordIndex word, float &prob, const NodeRange &node) const { +  uint64_t at_pointer; +  if (!FindBitPacked(base_, word_mask_, total_bits_, node.begin, node.end, word, at_pointer)) return false; +  at_pointer = at_pointer * total_bits_ + word_bits_; +  prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7); +  return true; +} + +} // namespace trie +} // namespace ngram +} // namespace lm diff --git a/klm/lm/virtual_interface.cc b/klm/lm/virtual_interface.cc index 9c7151f9..c5a64972 100644 --- a/klm/lm/virtual_interface.cc +++ b/klm/lm/virtual_interface.cc @@ -1,17 +1,16 @@  #include "lm/virtual_interface.hh" -#include "lm/exception.hh" +#include "lm/lm_exception.hh"  namespace lm {  namespace base {  Vocabulary::~Vocabulary() {} -void Vocabulary::SetSpecial(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found, WordIndex available) { +void Vocabulary::SetSpecial(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found) {    begin_sentence_ = begin_sentence;    end_sentence_ = end_sentence;    not_found_ = not_found; -  available_ = available;    if (begin_sentence_ == not_found_) throw SpecialWordMissingException("<s>");    if (end_sentence_ == not_found_) throw SpecialWordMissingException("</s>");  } diff --git a/klm/lm/virtual_interface.hh b/klm/lm/virtual_interface.hh index 621a129e..f15f8789 100644 --- a/klm/lm/virtual_interface.hh +++ b/klm/lm/virtual_interface.hh @@ -37,8 +37,6 @@ class Vocabulary {      WordIndex BeginSentence() const { return begin_sentence_; }      WordIndex EndSentence() const { return end_sentence_; }      WordIndex NotFound() const { return not_found_; } -    // FullScoreReturn start index of unused word assignments. -    WordIndex Available() const { return available_; }      /* Most implementations allow StringPiece lookups and need only override       * Index(StringPiece).  SRI requires null termination and overrides all @@ -56,13 +54,13 @@ class Vocabulary {      // Call SetSpecial afterward.        Vocabulary() {} -    Vocabulary(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found, WordIndex available) { -      SetSpecial(begin_sentence, end_sentence, not_found, available); +    Vocabulary(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found) { +      SetSpecial(begin_sentence, end_sentence, not_found);      } -    void SetSpecial(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found, WordIndex available); +    void SetSpecial(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found); -    WordIndex begin_sentence_, end_sentence_, not_found_, available_; +    WordIndex begin_sentence_, end_sentence_, not_found_;    private:      // Disable copy constructors.  
They're private and undefined.  @@ -97,7 +95,7 @@ class Vocabulary {   * missing these methods, see facade.hh.     *   * This is the fastest way to use a model and presents a normal State class to - * be included in hypothesis state structure.   + * be included in a hypothesis state structure.     *   *   * OPTION 2: Use the virtual interface below.   diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc new file mode 100644 index 00000000..c30428b2 --- /dev/null +++ b/klm/lm/vocab.cc @@ -0,0 +1,187 @@ +#include "lm/vocab.hh" + +#include "lm/enumerate_vocab.hh" +#include "lm/lm_exception.hh" +#include "lm/config.hh" +#include "lm/weights.hh" +#include "util/exception.hh" +#include "util/joint_sort.hh" +#include "util/murmur_hash.hh" +#include "util/probing_hash_table.hh" + +#include <string> + +namespace lm { +namespace ngram { + +namespace detail { +uint64_t HashForVocab(const char *str, std::size_t len) { +  // This proved faster than Boost's hash in speed trials: total load time Murmur 67090000, Boost 72210000 +  // Chose to use 64A instead of native so binary format will be portable across 64 and 32 bit.   +  return util::MurmurHash64A(str, len, 0); +} +} // namespace detail + +namespace { +// Normally static initialization is a bad idea but MurmurHash is pure arithmetic, so this is ok.   +const uint64_t kUnknownHash = detail::HashForVocab("<unk>", 5); +// Sadly some LMs have <UNK>.   +const uint64_t kUnknownCapHash = detail::HashForVocab("<UNK>", 5); + +void ReadWords(int fd, EnumerateVocab *enumerate) { +  if (!enumerate) return; +  const std::size_t kInitialRead = 16384; +  std::string buf; +  buf.reserve(kInitialRead + 100); +  buf.resize(kInitialRead); +  WordIndex index = 0; +  while (true) { +    ssize_t got = read(fd, &buf[0], kInitialRead); +    if (got == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words"); +    if (got == 0) return; +    buf.resize(got); +    while (buf[buf.size() - 1]) { +      char next_char; +      ssize_t ret = read(fd, &next_char, 1); +      if (ret == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words"); +      if (ret == 0) UTIL_THROW(FormatLoadException, "Missing null terminator on a vocab word."); +      buf.push_back(next_char); +    } +    // Ok now we have null terminated strings.   +    for (const char *i = buf.data(); i != buf.data() + buf.size();) { +      std::size_t length = strlen(i); +      enumerate->Add(index++, StringPiece(i, length)); +      i += length + 1 /* null byte */; +    } +  } +} + +void WriteOrThrow(int fd, const void *data_void, std::size_t size) { +  const uint8_t *data = static_cast<const uint8_t*>(data_void); +  while (size) { +    ssize_t ret = write(fd, data, size); +    if (ret < 1) UTIL_THROW(util::ErrnoException, "Write failed"); +    data += ret; +    size -= ret; +  } +} + +} // namespace + +WriteWordsWrapper::WriteWordsWrapper(EnumerateVocab *inner, int fd) : inner_(inner), fd_(fd) {} +WriteWordsWrapper::~WriteWordsWrapper() {} + +void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) { +  if (inner_) inner_->Add(index, str); +  WriteOrThrow(fd_, str.data(), str.size()); +  char null_byte = 0; +  // Inefficient because it's unbuffered.  Sue me.   +  WriteOrThrow(fd_, &null_byte, 1); +} + +SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {} + +std::size_t SortedVocabulary::Size(std::size_t entries, const Config &/*config*/) { +  // Lead with the number of entries.   
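+  // Layout: [uint64_t entry count][Entry 0][Entry 1] ...; SetupMemory points
+  // begin_ just past the count.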
+  return sizeof(uint64_t) + sizeof(Entry) * entries; +} + +void SortedVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config) { +  assert(allocated >= Size(entries, config)); +  // Leave space for number of entries.   +  begin_ = reinterpret_cast<Entry*>(reinterpret_cast<uint64_t*>(start) + 1); +  end_ = begin_; +  saw_unk_ = false; +} + +void SortedVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries) { +  enumerate_ = to; +  if (enumerate_) { +    enumerate_->Add(0, "<unk>"); +    strings_to_enumerate_.resize(max_entries); +  } +} + +WordIndex SortedVocabulary::Insert(const StringPiece &str) { +  uint64_t hashed = detail::HashForVocab(str); +  if (hashed == kUnknownHash || hashed == kUnknownCapHash) { +    saw_unk_ = true; +    return 0; +  } +  end_->key = hashed; +  if (enumerate_) { +    strings_to_enumerate_[end_ - begin_].assign(str.data(), str.size()); +  } +  ++end_; +  // This is 1 + the offset where it was inserted to make room for unk.   +  return end_ - begin_; +} + +void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { +  if (enumerate_) { +    util::PairedIterator<ProbBackoff*, std::string*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin()); +    util::JointSort(begin_, end_, values); +    for (WordIndex i = 0; i < static_cast<WordIndex>(end_ - begin_); ++i) { +      // <unk> strikes again: +1 here.   +      enumerate_->Add(i + 1, strings_to_enumerate_[i]); +    } +    strings_to_enumerate_.clear(); +  } else { +    util::JointSort(begin_, end_, reorder_vocab + 1); +  } +  SetSpecial(Index("<s>"), Index("</s>"), 0); +  // Save size.   +  *(reinterpret_cast<uint64_t*>(begin_) - 1) = end_ - begin_; +} + +void SortedVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { +  end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1); +  ReadWords(fd, to); +  SetSpecial(Index("<s>"), Index("</s>"), 0); +} + +ProbingVocabulary::ProbingVocabulary() : enumerate_(NULL) {} + +std::size_t ProbingVocabulary::Size(std::size_t entries, const Config &config) { +  return Lookup::Size(entries, config.probing_multiplier); +} + +void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t /*entries*/, const Config &/*config*/) { +  lookup_ = Lookup(start, allocated); +  available_ = 1; +  saw_unk_ = false; +} + +void ProbingVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t /*max_entries*/) { +  enumerate_ = to; +  if (enumerate_) { +    enumerate_->Add(0, "<unk>"); +  } +} + +WordIndex ProbingVocabulary::Insert(const StringPiece &str) { +  uint64_t hashed = detail::HashForVocab(str); +  // Prevent unknown from going into the table.   
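+  // <unk> (and the <UNK> variant some LMs use) is pinned to index 0, so it
+  // never takes a slot in the probing table.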
+  if (hashed == kUnknownHash || hashed == kUnknownCapHash) { +    saw_unk_ = true; +    return 0; +  } else { +    if (enumerate_) enumerate_->Add(available_, str); +    lookup_.Insert(Lookup::Packing::Make(hashed, available_)); +    return available_++; +  } +} + +void ProbingVocabulary::FinishedLoading(ProbBackoff * /*reorder_vocab*/) { +  lookup_.FinishedInserting(); +  SetSpecial(Index("<s>"), Index("</s>"), 0); +} + +void ProbingVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { +  lookup_.LoadedBinary(); +  ReadWords(fd, to); +  SetSpecial(Index("<s>"), Index("</s>"), 0); +} + +} // namespace ngram +} // namespace lm diff --git a/klm/util/Makefile.am b/klm/util/Makefile.am index be84736c..9e38e0f1 100644 --- a/klm/util/Makefile.am +++ b/klm/util/Makefile.am @@ -20,6 +20,7 @@ noinst_LIBRARIES = libklm_util.a  libklm_util_a_SOURCES = \    ersatz_progress.cc \ +  bit_packing.cc \    exception.cc \    file_piece.cc \    mmap.cc \ diff --git a/klm/util/bit_packing.cc b/klm/util/bit_packing.cc new file mode 100644 index 00000000..dd14ffe1 --- /dev/null +++ b/klm/util/bit_packing.cc @@ -0,0 +1,27 @@ +#include "util/bit_packing.hh" +#include "util/exception.hh" + +namespace util { + +namespace { +template <bool> struct StaticCheck {}; +template <> struct StaticCheck<true> { typedef bool StaticAssertionPassed; }; + +typedef StaticCheck<sizeof(float) == 4>::StaticAssertionPassed FloatSize; + +} // namespace + +uint8_t RequiredBits(uint64_t max_value) { +  if (!max_value) return 0; +  uint8_t ret = 1; +  while (max_value >>= 1) ++ret; +  return ret; +} + +void BitPackingSanity() { +  const detail::FloatEnc neg1 = { -1.0 }, pos1 = { 1.0 }; +  if ((neg1.i ^ pos1.i) != 0x80000000) UTIL_THROW(Exception, "Sign bit is not 0x80000000"); +  // TODO: more checks.   +} + +} // namespace util diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh new file mode 100644 index 00000000..422ed873 --- /dev/null +++ b/klm/util/bit_packing.hh @@ -0,0 +1,88 @@ +#ifndef UTIL_BIT_PACKING__ +#define UTIL_BIT_PACKING__ + +/* Bit-level packing routines */ + +#include <assert.h> +#ifdef __APPLE__ +#include <architecture/byte_order.h> +#else +#include <endian.h> +#endif + +#include <inttypes.h> + +#if __BYTE_ORDER != __LITTLE_ENDIAN +#error The bit aligned storage functions assume little endian architecture +#endif + +namespace util { + +/* WARNING WARNING WARNING: + * The write functions assume that memory is zero initially.  This makes them + * faster and is the appropriate case for mmapped language model construction. + * These routines assume that unaligned access to uint64_t is fast and that + * storage is little endian.  This is the case on x86_64.  It may not be the + * case on 32-bit x86 but my target audience is large language models for which + * 64-bit is necessary.   + */ + +/* Pack integers up to 57 bits using their least significant digits.  + * The length is specified using mask: + * Assumes mask == (1 << length) - 1 where length <= 57.    + */ +inline uint64_t ReadInt57(const void *base, uint8_t bit, uint64_t mask) { +  return (*reinterpret_cast<const uint64_t*>(base) >> bit) & mask; +} +/* Assumes value <= mask and mask == (1 << length) - 1 where length <= 57. + * Assumes the memory is zero initially.  
+ */
+inline void WriteInt57(void *base, uint8_t bit, uint64_t value) {
+  *reinterpret_cast<uint64_t*>(base) |= (value << bit);
+}
+
+namespace detail { typedef union { float f; uint32_t i; } FloatEnc; }
+inline float ReadFloat32(const void *base, uint8_t bit) {
+  detail::FloatEnc encoded;
+  encoded.i = *reinterpret_cast<const uint64_t*>(base) >> bit;
+  return encoded.f;
+}
+inline void WriteFloat32(void *base, uint8_t bit, float value) {
+  detail::FloatEnc encoded;
+  encoded.f = value;
+  WriteInt57(base, bit, encoded.i);
+}
+
+inline float ReadNonPositiveFloat31(const void *base, uint8_t bit) {
+  detail::FloatEnc encoded;
+  encoded.i = *reinterpret_cast<const uint64_t*>(base) >> bit;
+  // Sign bit set means negative.
+  encoded.i |= 0x80000000;
+  return encoded.f;
+}
+inline void WriteNonPositiveFloat31(void *base, uint8_t bit, float value) {
+  assert(value <= 0.0);
+  detail::FloatEnc encoded;
+  encoded.f = value;
+  encoded.i &= ~0x80000000;
+  WriteInt57(base, bit, encoded.i);
+}
+
+void BitPackingSanity();
+
+// Return bits required to store integers up to max_value.  Not the most
+// efficient implementation, but this is only called a few times to size tries.
+uint8_t RequiredBits(uint64_t max_value);
+
+struct BitsMask {
+  void FromMax(uint64_t max_value) {
+    bits = RequiredBits(max_value);
+    // 1ULL, not 1: bits can exceed 31, so shifting a plain int would overflow.
+    mask = (1ULL << bits) - 1;
+  }
+  uint8_t bits;
+  uint64_t mask;
+};
+
+} // namespace util
+
+#endif // UTIL_BIT_PACKING__
diff --git a/klm/util/ersatz_progress.cc b/klm/util/ersatz_progress.cc
index 09e3a106..55c182bd 100644
--- a/klm/util/ersatz_progress.cc
+++ b/klm/util/ersatz_progress.cc
@@ -13,10 +13,7 @@ ErsatzProgress::ErsatzProgress() : current_(0), next_(std::numeric_limits<std::s
 ErsatzProgress::~ErsatzProgress() {
   if (!out_) return;
-  for (; stones_written_ < kWidth; ++stones_written_) {
-    (*out_) << '*';
-  }
-  *out_ << '\n';
+  Finished();
 }
 
 ErsatzProgress::ErsatzProgress(std::ostream *to, const std::string &message, std::size_t complete) 
@@ -36,8 +33,8 @@ void ErsatzProgress::Milestone() {
   for (; stones_written_ < stone; ++stones_written_) {
     (*out_) << '*';
   }
-
-  if (current_ >= complete_) {
+  if (stone == kWidth) {
+    (*out_) << std::endl;
     next_ = std::numeric_limits<std::size_t>::max();
   } else {
     next_ = std::max(next_, (stone * complete_) / kWidth);
diff --git a/klm/util/ersatz_progress.hh b/klm/util/ersatz_progress.hh
index ea6c3bb9..92c345fe 100644
--- a/klm/util/ersatz_progress.hh
+++ b/klm/util/ersatz_progress.hh
@@ -19,7 +19,7 @@ class ErsatzProgress {
     ~ErsatzProgress();
 
     ErsatzProgress &operator++() {
-      if (++current_ == next_) Milestone();
+      if (++current_ >= next_) Milestone();
       return *this;
     }
 
@@ -33,6 +33,10 @@ class ErsatzProgress {
       Milestone();
     }
 
+    void Finished() {
+      Set(complete_);
+    }
+
   private:
     void Milestone();
 
diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc
index 2b439499..e7bd8659 100644
--- a/klm/util/file_piece.cc
+++ b/klm/util/file_piece.cc
@@ -2,19 +2,23 @@
 
 #include "util/exception.hh"
 
-#include <iostream>
 #include <string>
 #include <limits>
 
 #include <assert.h>
-#include <cstdlib>
 #include <ctype.h>
+#include <err.h>
 #include <fcntl.h>
+#include <stdlib.h>
 #include <sys/mman.h>
 #include <sys/types.h>
 #include <sys/stat.h>
 #include <unistd.h>
 
+#ifdef HAVE_ZLIB
+#include <zlib.h>
+#endif
+
 namespace util {
 
 EndOfFileException::EndOfFileException() throw() {
@@ -26,6 +30,13 @@
ParseNumberException::ParseNumberException(StringPiece value) throw() {    *this << "Could not parse \"" << value << "\" into a float";  } +GZException::GZException(void *file) { +#ifdef HAVE_ZLIB +  int num; +  *this << gzerror(file, &num) << " from zlib"; +#endif // HAVE_ZLIB +} +  int OpenReadOrThrow(const char *name) {    int ret = open(name, O_RDONLY);    if (ret == -1) UTIL_THROW(ErrnoException, "in open (" << name << ") for reading"); @@ -38,42 +49,73 @@ off_t SizeFile(int fd) {    return sb.st_size;  } -FilePiece::FilePiece(const char *name, std::ostream *show_progress, off_t min_buffer) :  +FilePiece::FilePiece(const char *name, std::ostream *show_progress, off_t min_buffer) throw (GZException) :     file_(OpenReadOrThrow(name)), total_size_(SizeFile(file_.get())), page_(sysconf(_SC_PAGE_SIZE)),    progress_(total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name, total_size_) {    Initialize(name, show_progress, min_buffer);  } -FilePiece::FilePiece(const char *name, int fd, std::ostream *show_progress, off_t min_buffer) :  +FilePiece::FilePiece(int fd, const char *name, std::ostream *show_progress, off_t min_buffer) throw (GZException) :     file_(fd), total_size_(SizeFile(file_.get())), page_(sysconf(_SC_PAGE_SIZE)),    progress_(total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name, total_size_) {    Initialize(name, show_progress, min_buffer);  } -void FilePiece::Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) { -  if (total_size_ == kBadSize) { -    fallback_to_read_ = true; -    if (show_progress)  -      *show_progress << "File " << name << " isn't normal.  Using slower read() instead of mmap().  No progress bar." << std::endl; -  } else { -    fallback_to_read_ = false; +FilePiece::~FilePiece() { +#ifdef HAVE_ZLIB +  if (gz_file_) { +    // zlib took ownership +    file_.release(); +    int ret; +    if (Z_OK != (ret = gzclose(gz_file_))) { +      errx(1, "could not close file %s using zlib", file_name_.c_str()); +    }    } +#endif +} + +void FilePiece::Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) throw (GZException) { +#ifdef HAVE_ZLIB +  gz_file_ = NULL; +#endif +  file_name_ = name; +    default_map_size_ = page_ * std::max<off_t>((min_buffer / page_ + 1), 2);    position_ = NULL;    position_end_ = NULL;    mapped_offset_ = 0;    at_end_ = false; + +  if (total_size_ == kBadSize) { +    // So the assertion passes.   +    fallback_to_read_ = false; +    if (show_progress)  +      *show_progress << "File " << name << " isn't normal.  Using slower read() instead of mmap().  No progress bar." << std::endl; +    TransitionToRead(); +  } else { +    fallback_to_read_ = false; +  }    Shift(); +  // gzip detect. +  if ((position_end_ - position_) > 2 && *position_ == 0x1f && static_cast<unsigned char>(*(position_ + 1)) == 0x8b) { +#ifndef HAVE_ZLIB +    UTIL_THROW(GZException, "Looks like a gzip file but support was not compiled in."); +#endif +    if (!fallback_to_read_) { +      at_end_ = false; +      TransitionToRead(); +    } +  }  } -float FilePiece::ReadFloat() throw(EndOfFileException, ParseNumberException) { +float FilePiece::ReadFloat() throw(GZException, EndOfFileException, ParseNumberException) {    SkipSpaces();    while (last_space_ < position_) {      if (at_end_) {        // Hallucinate a null off the end of the file.        
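       // The mapped region is not guaranteed to be null terminated, so copy
       // the tail into a std::string to give strtof a terminated buffer.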
std::string buffer(position_, position_end_);        char *end; -      float ret = std::strtof(buffer.c_str(), &end); +      float ret = strtof(buffer.c_str(), &end);        if (buffer.c_str() == end) throw ParseNumberException(buffer);        position_ += end - buffer.c_str();        return ret; @@ -81,20 +123,20 @@ float FilePiece::ReadFloat() throw(EndOfFileException, ParseNumberException) {      Shift();    }    char *end; -  float ret = std::strtof(position_, &end); +  float ret = strtof(position_, &end);    if (end == position_) throw ParseNumberException(ReadDelimited());    position_ = end;    return ret;  } -void FilePiece::SkipSpaces() throw (EndOfFileException) { +void FilePiece::SkipSpaces() throw (GZException, EndOfFileException) {    for (; ; ++position_) {      if (position_ == position_end_) Shift();      if (!isspace(*position_)) return;    }  } -const char *FilePiece::FindDelimiterOrEOF() throw (EndOfFileException) { +const char *FilePiece::FindDelimiterOrEOF() throw (GZException, EndOfFileException) {    for (const char *i = position_; i <= last_space_; ++i) {      if (isspace(*i)) return i;    } @@ -108,7 +150,7 @@ const char *FilePiece::FindDelimiterOrEOF() throw (EndOfFileException) {    return position_end_;  } -StringPiece FilePiece::ReadLine(char delim) throw (EndOfFileException) { +StringPiece FilePiece::ReadLine(char delim) throw (GZException, EndOfFileException) {    const char *start = position_;    do {      for (const char *i = start; i < position_end_; ++i) { @@ -124,17 +166,19 @@ StringPiece FilePiece::ReadLine(char delim) throw (EndOfFileException) {    } while (!at_end_);    StringPiece ret(position_, position_end_ - position_);    position_ = position_end_; -  return position_; +  return ret;  } -void FilePiece::Shift() throw(EndOfFileException) { -  if (at_end_) throw EndOfFileException(); +void FilePiece::Shift() throw(GZException, EndOfFileException) { +  if (at_end_) { +    progress_.Finished(); +    throw EndOfFileException(); +  }    off_t desired_begin = position_ - data_.begin() + mapped_offset_; -  progress_.Set(desired_begin);    if (!fallback_to_read_) MMapShift(desired_begin);    // Notice an mmap failure might set the fallback.   -  if (fallback_to_read_) ReadShift(desired_begin); +  if (fallback_to_read_) ReadShift();    for (last_space_ = position_end_ - 1; last_space_ >= position_; --last_space_) {      if (isspace(*last_space_))  break; @@ -163,28 +207,41 @@ void FilePiece::MMapShift(off_t desired_begin) throw() {    data_.reset();    data_.reset(mmap(NULL, mapped_size, PROT_READ, MAP_PRIVATE, *file_, mapped_offset), mapped_size, scoped_memory::MMAP_ALLOCATED);    if (data_.get() == MAP_FAILED) { -    fallback_to_read_ = true;      if (desired_begin) {        if (((off_t)-1) == lseek(*file_, desired_begin, SEEK_SET)) UTIL_THROW(ErrnoException, "mmap failed even though it worked before.  lseek failed too, so using read isn't an option either.");      } +    // The mmap was scheduled to end the file, but now we're going to read it.   
+    at_end_ = false; +    TransitionToRead();      return;    }    mapped_offset_ = mapped_offset;    position_ = data_.begin() + ignore;    position_end_ = data_.begin() + mapped_size; + +  progress_.Set(desired_begin); +} + +void FilePiece::TransitionToRead() throw (GZException) { +  assert(!fallback_to_read_); +  fallback_to_read_ = true; +  data_.reset(); +  data_.reset(malloc(default_map_size_), default_map_size_, scoped_memory::MALLOC_ALLOCATED); +  if (!data_.get()) UTIL_THROW(ErrnoException, "malloc failed for " << default_map_size_); +  position_ = data_.begin(); +  position_end_ = position_; + +#ifdef HAVE_ZLIB +  assert(!gz_file_); +  gz_file_ = gzdopen(file_.get(), "r"); +  if (!gz_file_) { +    UTIL_THROW(GZException, "zlib failed to open " << file_name_); +  } +#endif  } -void FilePiece::ReadShift(off_t desired_begin) throw() { +void FilePiece::ReadShift() throw(GZException, EndOfFileException) {    assert(fallback_to_read_); -  if (data_.source() != scoped_memory::MALLOC_ALLOCATED) { -    // First call. -    data_.reset(); -    data_.reset(malloc(default_map_size_), default_map_size_, scoped_memory::MALLOC_ALLOCATED); -    if (!data_.get()) UTIL_THROW(ErrnoException, "malloc failed for " << default_map_size_); -    position_ = data_.begin(); -    position_end_ = position_; -  }  -      // Bytes [data_.begin(), position_) have been consumed.      // Bytes [position_, position_end_) have been read into the buffer.   @@ -215,9 +272,23 @@ void FilePiece::ReadShift(off_t desired_begin) throw() {      }    } -  ssize_t read_return = read(file_.get(), static_cast<char*>(data_.get()) + already_read, default_map_size_ - already_read); +  ssize_t read_return; +#ifdef HAVE_ZLIB +  read_return = gzread(gz_file_, static_cast<char*>(data_.get()) + already_read, default_map_size_ - already_read); +  if (read_return == -1) throw GZException(gz_file_); +  if (total_size_ != kBadSize) { +    // Just get the position, don't actually seek.  Apparently this is how you do it. . .  +    off_t ret = lseek(file_.get(), 0, SEEK_CUR); +    if (ret != -1) progress_.Set(ret); +  } +#else +  read_return = read(file_.get(), static_cast<char*>(data_.get()) + already_read, default_map_size_ - already_read);    if (read_return == -1) UTIL_THROW(ErrnoException, "read failed"); -  if (read_return == 0) at_end_ = true; +  progress_.Set(mapped_offset_); +#endif +  if (read_return == 0) { +    at_end_ = true; +  }    position_end_ += read_return;  } diff --git a/klm/util/file_piece.hh b/klm/util/file_piece.hh index 704f0ac6..11d4a751 100644 --- a/klm/util/file_piece.hh +++ b/klm/util/file_piece.hh @@ -11,6 +11,8 @@  #include <cstddef> +#define HAVE_ZLIB +  namespace util {  class EndOfFileException : public Exception { @@ -25,6 +27,13 @@ class ParseNumberException : public Exception {      ~ParseNumberException() throw() {}  }; +class GZException : public Exception { +  public: +    explicit GZException(void *file); +    GZException() throw() {} +    ~GZException() throw() {} +}; +  int OpenReadOrThrow(const char *name);  // Return value for SizeFile when it can't size properly.   @@ -34,40 +43,42 @@ off_t SizeFile(int fd);  class FilePiece {    public:      // 32 MB default. -    explicit FilePiece(const char *file, std::ostream *show_progress = NULL, off_t min_buffer = 33554432); +    explicit FilePiece(const char *file, std::ostream *show_progress = NULL, off_t min_buffer = 33554432) throw(GZException);      // Takes ownership of fd.  name is used for messages.   
-    explicit FilePiece(const char *name, int fd, std::ostream *show_progress = NULL, off_t min_buffer = 33554432); +    explicit FilePiece(int fd, const char *name, std::ostream *show_progress = NULL, off_t min_buffer = 33554432) throw(GZException); + +    ~FilePiece(); -    char get() throw(EndOfFileException) {  -      if (position_ == position_end_) Shift(); +    char get() throw(GZException, EndOfFileException) {  +      if (position_ == position_end_) { +        Shift(); +        if (at_end_) throw EndOfFileException(); +      }        return *(position_++);      }      // Memory backing the returned StringPiece may vanish on the next call.        // Leaves the delimiter, if any, to be returned by get(). -    StringPiece ReadDelimited() throw(EndOfFileException) { +    StringPiece ReadDelimited() throw(GZException, EndOfFileException) {        SkipSpaces();        return Consume(FindDelimiterOrEOF());      }      // Unlike ReadDelimited, this includes leading spaces and consumes the delimiter.      // It is similar to getline in that way.   -    StringPiece ReadLine(char delim = '\n') throw(EndOfFileException); +    StringPiece ReadLine(char delim = '\n') throw(GZException, EndOfFileException); -    float ReadFloat() throw(EndOfFileException, ParseNumberException); +    float ReadFloat() throw(GZException, EndOfFileException, ParseNumberException); -    void SkipSpaces() throw (EndOfFileException); +    void SkipSpaces() throw (GZException, EndOfFileException);      off_t Offset() const {        return position_ - data_.begin() + mapped_offset_;      } -    // Only for testing.   -    void ForceFallbackToRead() { -      fallback_to_read_ = true; -    } +    const std::string &FileName() const { return file_name_; }    private: -    void Initialize(const char *name, std::ostream *show_progress, off_t min_buffer); +    void Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) throw(GZException);      StringPiece Consume(const char *to) {        StringPiece ret(position_, to - position_); @@ -75,12 +86,14 @@ class FilePiece {        return ret;      } -    const char *FindDelimiterOrEOF() throw(EndOfFileException); +    const char *FindDelimiterOrEOF() throw(EndOfFileException, GZException); -    void Shift() throw (EndOfFileException); +    void Shift() throw (EndOfFileException, GZException);      // Backends to Shift().      
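    // Shift() tries MMapShift first; if the mmap fails it switches backends
    // via TransitionToRead(), after which ReadShift (read() or gzread) is used.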
void MMapShift(off_t desired_begin) throw (); -    void ReadShift(off_t desired_begin) throw (); + +    void TransitionToRead() throw (GZException); +    void ReadShift() throw (GZException, EndOfFileException);      const char *position_, *last_space_, *position_end_; @@ -98,6 +111,12 @@ class FilePiece {      bool fallback_to_read_;      ErsatzProgress progress_; + +    std::string file_name_; + +#ifdef HAVE_ZLIB +    void *gz_file_; +#endif // HAVE_ZLIB  };  } // namespace util diff --git a/klm/util/file_piece_test.cc b/klm/util/file_piece_test.cc index befb7866..23e79fe0 100644 --- a/klm/util/file_piece_test.cc +++ b/klm/util/file_piece_test.cc @@ -1,15 +1,19 @@  #include "util/file_piece.hh" +#include "util/scoped.hh" +  #define BOOST_TEST_MODULE FilePieceTest  #include <boost/test/unit_test.hpp>  #include <fstream>  #include <iostream> +#include <stdio.h> +  namespace util {  namespace {  /* mmap implementation */ -BOOST_AUTO_TEST_CASE(MMapLine) { +BOOST_AUTO_TEST_CASE(MMapReadLine) {    std::fstream ref("file_piece.cc", std::ios::in);    FilePiece test("file_piece.cc", NULL, 1);    std::string ref_line; @@ -20,13 +24,17 @@ BOOST_AUTO_TEST_CASE(MMapLine) {        BOOST_CHECK_EQUAL(ref_line, test_line);      }    } +  BOOST_CHECK_THROW(test.get(), EndOfFileException);  }  /* read() implementation */ -BOOST_AUTO_TEST_CASE(ReadLine) { +BOOST_AUTO_TEST_CASE(StreamReadLine) {    std::fstream ref("file_piece.cc", std::ios::in); -  FilePiece test("file_piece.cc", NULL, 1); -  test.ForceFallbackToRead(); + +  scoped_FILE catter(popen("cat file_piece.cc", "r")); +  BOOST_REQUIRE(catter.get()); +   +  FilePiece test(dup(fileno(catter.get())), "file_piece.cc", NULL, 1);    std::string ref_line;    while (getline(ref, ref_line)) {      StringPiece test_line(test.ReadLine()); @@ -35,7 +43,47 @@ BOOST_AUTO_TEST_CASE(ReadLine) {        BOOST_CHECK_EQUAL(ref_line, test_line);      }    } +  BOOST_CHECK_THROW(test.get(), EndOfFileException);  } +#ifdef HAVE_ZLIB + +// gzip file +BOOST_AUTO_TEST_CASE(PlainZipReadLine) { +  std::fstream ref("file_piece.cc", std::ios::in); + +  BOOST_REQUIRE_EQUAL(0, system("gzip <file_piece.cc >file_piece.cc.gz")); +  FilePiece test("file_piece.cc.gz", NULL, 1); +  std::string ref_line; +  while (getline(ref, ref_line)) { +    StringPiece test_line(test.ReadLine()); +    // I submitted a bug report to ICU: http://bugs.icu-project.org/trac/ticket/7924 +    if (!test_line.empty() || !ref_line.empty()) { +      BOOST_CHECK_EQUAL(ref_line, test_line); +    } +  } +  BOOST_CHECK_THROW(test.get(), EndOfFileException); +} +// gzip stream +BOOST_AUTO_TEST_CASE(StreamZipReadLine) { +  std::fstream ref("file_piece.cc", std::ios::in); + +  scoped_FILE catter(popen("gzip <file_piece.cc", "r")); +  BOOST_REQUIRE(catter.get()); +   +  FilePiece test(dup(fileno(catter.get())), "file_piece.cc", NULL, 1); +  std::string ref_line; +  while (getline(ref, ref_line)) { +    StringPiece test_line(test.ReadLine()); +    // I submitted a bug report to ICU: http://bugs.icu-project.org/trac/ticket/7924 +    if (!test_line.empty() || !ref_line.empty()) { +      BOOST_CHECK_EQUAL(ref_line, test_line); +    } +  } +  BOOST_CHECK_THROW(test.get(), EndOfFileException); +} + +#endif +  } // namespace  } // namespace util diff --git a/klm/util/joint_sort.hh b/klm/util/joint_sort.hh index a2f1c01d..cf3d8432 100644 --- a/klm/util/joint_sort.hh +++ b/klm/util/joint_sort.hh @@ -119,6 +119,12 @@ template <class Proxy, class Less> class LessWrapper : public std::binary_functi  } // namespace detail 
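+// Iterates two parallel sequences as one sequence of (key, value) proxies.
+// Usage sketch: vocab.cc instantiates PairedIterator<ProbBackoff*, std::string*>
+// so JointSort can move each weight and its vocabulary string together.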
+template <class KeyIter, class ValueIter> class PairedIterator : public ProxyIterator<detail::JointProxy<KeyIter, ValueIter> > {
+  public:
+    PairedIterator(const KeyIter &key, const ValueIter &value) :
+      ProxyIterator<detail::JointProxy<KeyIter, ValueIter> >(detail::JointProxy<KeyIter, ValueIter>(key, value)) {}
+};
+
 template <class KeyIter, class ValueIter, class Less> void JointSort(const KeyIter &key_begin, const KeyIter &key_end, const ValueIter &value_begin, const Less &less) {
   ProxyIterator<detail::JointProxy<KeyIter, ValueIter> > full_begin(detail::JointProxy<KeyIter, ValueIter>(key_begin, value_begin));
   detail::LessWrapper<detail::JointProxy<KeyIter, ValueIter>, Less> less_wrap(less);
diff --git a/klm/util/mmap.cc b/klm/util/mmap.cc
index 648b5d0a..8685170f 100644
--- a/klm/util/mmap.cc
+++ b/klm/util/mmap.cc
@@ -53,10 +53,8 @@ void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, int
   if (prefault) {
     flags |= MAP_POPULATE;
   }
-  int protect = for_write ? (PROT_READ | PROT_WRITE) : PROT_READ;
-#else
-  int protect = for_write ? (PROT_READ | PROT_WRITE) : PROT_READ;
 #endif
+  int protect = for_write ? (PROT_READ | PROT_WRITE) : PROT_READ;
   void *ret = mmap(NULL, size, protect, flags, fd, offset);
   if (ret == MAP_FAILED) {
     UTIL_THROW(ErrnoException, "mmap failed for size " << size << " at offset " << offset);
@@ -64,8 +62,40 @@ void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, int
   return ret;
 }
 
-void *MapForRead(std::size_t size, bool prefault, int fd, off_t offset) {
-  return MapOrThrow(size, false, MAP_FILE | MAP_PRIVATE, prefault, fd, offset);
+namespace {
+void ReadAll(int fd, void *to_void, std::size_t amount) {
+  uint8_t *to = static_cast<uint8_t*>(to_void);
+  while (amount) {
+    ssize_t ret = read(fd, to, amount);
+    if (ret == -1) UTIL_THROW(ErrnoException, "Reading " << amount << " from fd " << fd << " failed.");
+    if (ret == 0) UTIL_THROW(Exception, "Hit EOF in fd " << fd << " but there should be " << amount << " more bytes to read.");
+    amount -= ret;
+    to += ret;
+  }
+}
+} // namespace
+
+void MapRead(LoadMethod method, int fd, off_t offset, std::size_t size, scoped_memory &out) {
+  switch (method) {
+    case LAZY:
+      out.reset(MapOrThrow(size, false, MAP_FILE | MAP_SHARED, false, fd, offset), size, scoped_memory::MMAP_ALLOCATED);
+      break;
+    case POPULATE_OR_LAZY:
+#ifdef MAP_POPULATE
+    case POPULATE_OR_READ:
+#endif
+      out.reset(MapOrThrow(size, false, MAP_FILE | MAP_SHARED, true, fd, offset), size, scoped_memory::MMAP_ALLOCATED);
+      break;
+#ifndef MAP_POPULATE
+    case POPULATE_OR_READ:
+#endif
+    case READ:
+      out.reset(malloc(size), size, scoped_memory::MALLOC_ALLOCATED);
+      if (!out.get()) UTIL_THROW(util::ErrnoException, "Allocating " << size << " bytes with malloc");
+      if (-1 == lseek(fd, offset, SEEK_SET)) UTIL_THROW(ErrnoException, "lseek to " << offset << " in fd " << fd << " failed.");
+      ReadAll(fd, out.get(), size);
+      break;
+  }
 }
 
 void *MapAnonymous(std::size_t size) {
@@ -78,14 +108,14 @@ void *MapAnonymous(std::size_t size) {
       | MAP_PRIVATE, false, -1, 0);
 }
 
-void MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file, scoped_mmap &mem) {
+void *MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file) {
   file.reset(open(name, O_CREAT | O_RDWR | O_TRUNC, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH));
   if (-1 == file.get())
     UTIL_THROW(ErrnoException, "Failed to open " << name << " for writing");
   if (-1 == ftruncate(file.get(), size))
     UTIL_THROW(ErrnoException, "ftruncate on " << name << " to " << size << " failed");
   try {
-    mem.reset(MapOrThrow(size, true, MAP_FILE | MAP_SHARED, false, file.get(), 0), size);
+    return MapOrThrow(size, true, MAP_FILE | MAP_SHARED, false, file.get(), 0);
   } catch (ErrnoException &e) {
     e << " in file " << name;
     throw;
diff --git a/klm/util/mmap.hh b/klm/util/mmap.hh
index c9068ec9..0a504d89 100644
--- a/klm/util/mmap.hh
+++ b/klm/util/mmap.hh
@@ -6,6 +6,7 @@
 
 #include <cstddef>
 
+#include <inttypes.h>
 #include <sys/types.h>
 
 namespace util {
@@ -19,8 +20,8 @@ class scoped_mmap {
 
     void *get() const { return data_; }
 
-    const char *begin() const { return reinterpret_cast<char*>(data_); }
-    const char *end() const { return reinterpret_cast<char*>(data_) + size_; }
+    const uint8_t *begin() const { return reinterpret_cast<uint8_t*>(data_); }
+    const uint8_t *end() const { return reinterpret_cast<uint8_t*>(data_) + size_; }
 
     std::size_t size() const { return size_; }
 
     void reset(void *data, std::size_t size) {
@@ -79,23 +80,27 @@ class scoped_memory {
     scoped_memory &operator=(const scoped_memory &);
 };
 
-struct scoped_mapped_file {
-  scoped_fd fd;
-  scoped_mmap mem;
-};
+typedef enum {
+  // mmap with no prepopulate
+  LAZY,
+  // On linux, pass MAP_POPULATE to mmap.
+  POPULATE_OR_LAZY,
+  // Populate on Linux.  malloc and read on non-Linux.
+  POPULATE_OR_READ,
+  // malloc and read.
+  READ
+} LoadMethod;
+
 // Wrapper around mmap to check it worked and hide some platform macros.
 void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, int fd, off_t offset = 0);
-void *MapForRead(std::size_t size, bool prefault, int fd, off_t offset = 0);
+void MapRead(LoadMethod method, int fd, off_t offset, std::size_t size, scoped_memory &out);
 void *MapAnonymous(std::size_t size);
 
 // Open file name with mmap of size bytes, all of which are initially zero.
-void MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file, scoped_mmap &mem);
-inline void MapZeroedWrite(const char *name, std::size_t size, scoped_mapped_file &out) {
-  MapZeroedWrite(name, size, out.fd, out.mem);
-}
-
+void *MapZeroedWrite(const char *name, std::size_t size, scoped_fd &file);
+
 } // namespace util
 
-#endif // UTIL_SCOPED__
+#endif // UTIL_MMAP__
diff --git a/klm/util/proxy_iterator.hh b/klm/util/proxy_iterator.hh
index 1c5b7089..121a45fa 100644
--- a/klm/util/proxy_iterator.hh
+++ b/klm/util/proxy_iterator.hh
@@ -78,6 +78,8 @@ template <class Proxy> class ProxyIterator {
     const Proxy *operator->() const { return &p_; }
     Proxy operator[](std::ptrdiff_t amount) const { return *(*this + amount); }
 
+    const InnerIterator &Inner() { return p_.Inner(); }
+
   private:
     InnerIterator &I() { return p_.Inner(); }
     const InnerIterator &I() const { return p_.Inner(); }
diff --git a/klm/util/scoped.cc b/klm/util/scoped.cc
index 61394ffc..2c6d5394 100644
--- a/klm/util/scoped.cc
+++ b/klm/util/scoped.cc
@@ -9,4 +9,8 @@ scoped_fd::~scoped_fd() {
   if (fd_ != -1 && close(fd_)) err(1, "Could not close file %i", fd_);
 }
 
+scoped_FILE::~scoped_FILE() {
+  if (file_ && fclose(file_)) err(1, "Could not close file");
+}
+
 } // namespace util
diff --git a/klm/util/scoped.hh b/klm/util/scoped.hh
index ef62a74f..52864481 100644
--- a/klm/util/scoped.hh
+++ b/klm/util/scoped.hh
@@ -4,6 +4,7 @@
 /* Other scoped objects in the style of scoped_ptr. */
 
 #include <cstddef>
+#include <cstdio>
 
 namespace util {
 
@@ -61,6 +62,24 @@ class scoped_fd {
     scoped_fd &operator=(const scoped_fd &);
 };
 
+class scoped_FILE {
+  public:
+    explicit scoped_FILE(std::FILE *file = NULL) : file_(file) {}
+
+    ~scoped_FILE();
+
+    std::FILE *get() { return file_; }
+    const std::FILE *get() const { return file_; }
+
+    void reset(std::FILE *to = NULL) {
+      scoped_FILE other(file_);
+      file_ = to;
+    }
+
+  private:
+    std::FILE *file_;
+};
+
 } // namespace util
 
 #endif // UTIL_SCOPED__
diff --git a/klm/util/sorted_uniform.hh b/klm/util/sorted_uniform.hh
index 96ec4866..a8e208fb 100644
--- a/klm/util/sorted_uniform.hh
+++ b/klm/util/sorted_uniform.hh
@@ -65,7 +65,7 @@ template <class PackingT> class SortedUniformMap {
   public:
     // Offer consistent API with probing hash.
-    static std::size_t Size(std::size_t entries, float ignore = 0.0) {
+    static std::size_t Size(std::size_t entries, float /*ignore*/ = 0.0) {
       return sizeof(uint64_t) + entries * Packing::kBytes;
     }
@@ -75,7 +75,7 @@ template <class PackingT> class SortedUniformMap {
 #endif
     {}
 
-    SortedUniformMap(void *start, std::size_t allocated) :
+    SortedUniformMap(void *start, std::size_t /*allocated*/) :
       begin_(Packing::FromVoid(reinterpret_cast<uint64_t*>(start) + 1)),
       end_(begin_), size_ptr_(reinterpret_cast<uint64_t*>(start))
 #ifdef DEBUG
diff --git a/klm/util/string_piece.cc b/klm/util/string_piece.cc
index 6917a6bc..5b4e98f5 100644
--- a/klm/util/string_piece.cc
+++ b/klm/util/string_piece.cc
@@ -30,14 +30,14 @@
 
 #include "util/string_piece.hh"
 
-#ifdef USE_BOOST
+#ifdef HAVE_BOOST
 #include <boost/functional/hash/hash.hpp>
 #endif
 
 #include <algorithm>
 #include <iostream>
 
-#ifdef USE_ICU
+#ifdef HAVE_ICU
 U_NAMESPACE_BEGIN
 #endif
@@ -46,12 +46,12 @@ std::ostream& operator<<(std::ostream& o, const StringPiece& piece) {
   return o;
 }
 
-#ifdef USE_BOOST
+#ifdef HAVE_BOOST
 size_t hash_value(const StringPiece &str) {
   return boost::hash_range(str.data(), str.data() + str.length());
 }
 #endif
 
-#ifdef USE_ICU
+#ifdef HAVE_ICU
 U_NAMESPACE_END
 #endif
diff --git a/klm/util/string_piece.hh b/klm/util/string_piece.hh
index 58008d13..3ac2f8a7 100644
--- a/klm/util/string_piece.hh
+++ b/klm/util/string_piece.hh
@@ -1,4 +1,4 @@
-/* If you use ICU in your program, then compile with -DUSE_ICU -licui18n.  If
+/* If you use ICU in your program, then compile with -DHAVE_ICU -licui18n.  If
  * you don't use ICU, then this will use the Google implementation from Chrome.
  * This has been modified from the original version to let you choose.
  */
@@ -49,14 +49,14 @@
 #define BASE_STRING_PIECE_H__
 
 //Uncomment this line if you use ICU in your code.
-//#define USE_ICU
+//#define HAVE_ICU
 //Uncomment this line if you want boost hashing for your StringPieces.
-//#define USE_BOOST
+//#define HAVE_BOOST
 
 #include <cstring>
 #include <iosfwd>
 
-#ifdef USE_ICU
+#ifdef HAVE_ICU
 #include <unicode/stringpiece.h>
 U_NAMESPACE_BEGIN
 #else
@@ -230,7 +230,7 @@ inline bool operator>=(const StringPiece& x, const StringPiece& y) {
 // allow StringPiece to be logged (needed for unit testing).
 extern std::ostream& operator<<(std::ostream& o, const StringPiece& piece);
 
-#ifdef USE_BOOST
+#ifdef HAVE_BOOST
 size_t hash_value(const StringPiece &str);
 
 /* Support for lookup of StringPiece in boost::unordered_map<std::string> */
@@ -253,7 +253,7 @@ template <class T> typename T::iterator FindStringPiece(T &t, const StringPiece
 }
 #endif
 
-#ifdef USE_ICU
+#ifdef HAVE_ICU
 U_NAMESPACE_END
 #endif
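The USE_* to HAVE_* renames in string_piece.hh/cc gate hash_value() and FindStringPiece() behind -DHAVE_BOOST. A sketch of what that enables, assuming FindStringPiece performs a copy-free compatible-key lookup as its comment suggests; Lookup and Table are hypothetical names for illustration.

// Example (not part of the patch): compile with -DHAVE_BOOST.  hash_value()
// makes StringPiece hashable by boost, and FindStringPiece() probes a
// std::string-keyed table without building a temporary std::string.
#include "util/string_piece.hh"

#include <boost/unordered_map.hpp>

#include <string>

typedef boost::unordered_map<std::string, int> Table;  // hypothetical

int Lookup(Table &table, const StringPiece &key) {
  Table::iterator i = FindStringPiece(table, key);
  return (i == table.end()) ? -1 : i->second;
}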

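Finally, a sketch of the reshaped mmap API from klm/util/mmap.{hh,cc} above. LoadWholeFile, BuildTable, the fstat-based size query, and default construction of scoped_fd/scoped_mmap are assumptions for illustration; MapRead, LoadMethod, and the new MapZeroedWrite return value come from the patch itself.

// Example (not part of the patch): exercise the two new entry points.
#include "util/mmap.hh"
#include "util/scoped.hh"

#include <fcntl.h>
#include <sys/stat.h>

// Map (or read) an entire file; scoped_memory records whether the block came
// from mmap or malloc and releases it accordingly.
bool LoadWholeFile(const char *name, util::LoadMethod method, util::scoped_memory &out) {
  util::scoped_fd fd;
  fd.reset(open(name, O_RDONLY));
  if (fd.get() == -1) return false;
  struct stat sb;
  if (fstat(fd.get(), &sb)) return false;
  // LAZY faults pages in on demand; POPULATE_OR_LAZY and POPULATE_OR_READ ask
  // for MAP_POPULATE where available; READ always does malloc + read(2).
  util::MapRead(method, fd.get(), 0, static_cast<std::size_t>(sb.st_size), out);
  return true;
}

// MapZeroedWrite now returns the raw mapping, so the caller picks the owner;
// here scoped_mmap::reset(void*, size) takes it.
void BuildTable(const char *name, std::size_t size) {
  util::scoped_fd file;
  util::scoped_mmap mem;
  mem.reset(util::MapZeroedWrite(name, size, file), size);
  // mem.get() points at `size` zeroed bytes backed by the file; mem unmaps
  // before file closes because of declaration order.
}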