Diffstat (limited to 'klm/lm')
37 files changed, 800 insertions, 487 deletions
| diff --git a/klm/lm/bhiksha.cc b/klm/lm/bhiksha.cc index 088ea98d..c8a18dfd 100644 --- a/klm/lm/bhiksha.cc +++ b/klm/lm/bhiksha.cc @@ -1,4 +1,6 @@  #include "lm/bhiksha.hh" + +#include "lm/binary_format.hh"  #include "lm/config.hh"  #include "util/file.hh"  #include "util/exception.hh" @@ -15,11 +17,11 @@ DontBhiksha::DontBhiksha(const void * /*base*/, uint64_t /*max_offset*/, uint64_  const uint8_t kArrayBhikshaVersion = 0;  // TODO: put this in binary file header instead when I change the binary file format again.   -void ArrayBhiksha::UpdateConfigFromBinary(int fd, Config &config) { -  uint8_t version; -  uint8_t configured_bits; -  util::ReadOrThrow(fd, &version, 1); -  util::ReadOrThrow(fd, &configured_bits, 1); +void ArrayBhiksha::UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config) { +  uint8_t buffer[2]; +  file.ReadForConfig(buffer, 2, offset); +  uint8_t version = buffer[0]; +  uint8_t configured_bits = buffer[1];    if (version != kArrayBhikshaVersion) UTIL_THROW(FormatLoadException, "This file has sorted array compression version " << (unsigned) version << " but the code expects version " << (unsigned)kArrayBhikshaVersion);    config.pointer_bhiksha_bits = configured_bits;  } @@ -87,9 +89,6 @@ void ArrayBhiksha::FinishedLoading(const Config &config) {    *(head_write++) = config.pointer_bhiksha_bits;  } -void ArrayBhiksha::LoadedBinary() { -} -  } // namespace trie  } // namespace ngram  } // namespace lm diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh index 8ff88654..350571a6 100644 --- a/klm/lm/bhiksha.hh +++ b/klm/lm/bhiksha.hh @@ -24,6 +24,7 @@  namespace lm {  namespace ngram {  struct Config; +class BinaryFormat;  namespace trie { @@ -31,7 +32,7 @@ class DontBhiksha {    public:      static const ModelType kModelTypeAdd = static_cast<ModelType>(0); -    static void UpdateConfigFromBinary(int /*fd*/, Config &/*config*/) {} +    static void UpdateConfigFromBinary(const BinaryFormat &, uint64_t, Config &/*config*/) {}      static uint64_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; } @@ -53,8 +54,6 @@ class DontBhiksha {      void FinishedLoading(const Config &/*config*/) {} -    void LoadedBinary() {} -      uint8_t InlineBits() const { return next_.bits; }    private: @@ -65,7 +64,7 @@ class ArrayBhiksha {    public:      static const ModelType kModelTypeAdd = kArrayAdd; -    static void UpdateConfigFromBinary(int fd, Config &config); +    static void UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config);      static uint64_t Size(uint64_t max_offset, uint64_t max_next, const Config &config); @@ -93,8 +92,6 @@ class ArrayBhiksha {      void FinishedLoading(const Config &config); -    void LoadedBinary(); -      uint8_t InlineBits() const { return next_inline_.bits; }    private: diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index 39c4a9b6..9c744b13 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -8,11 +8,15 @@  #include <cstring>  #include <limits>  #include <string> +#include <cstdlib>  #include <stdint.h>  namespace lm {  namespace ngram { + +const char *kModelNames[6] = {"probing hash tables", "probing hash tables with rest costs", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"}; +  namespace {  const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version";  const char kMagicBytes[] = "mmap lm http://kheafield.com/code 
format version 5\n\0"; @@ -57,8 +61,6 @@ struct Sanity {    }  }; -const char *kModelNames[6] = {"probing hash tables", "probing hash tables with rest costs", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"}; -  std::size_t TotalHeaderSize(unsigned char order) {    return ALIGN8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order);  } @@ -80,83 +82,6 @@ void WriteHeader(void *to, const Parameters ¶ms) {  } // namespace -uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing) { -  if (config.write_mmap) { -    std::size_t total = TotalHeaderSize(order) + memory_size; -    backing.file.reset(util::CreateOrThrow(config.write_mmap)); -    if (config.write_method == Config::WRITE_MMAP) { -      backing.vocab.reset(util::MapZeroedWrite(backing.file.get(), total), total, util::scoped_memory::MMAP_ALLOCATED); -    } else { -      util::ResizeOrThrow(backing.file.get(), 0); -      util::MapAnonymous(total, backing.vocab); -    } -    strncpy(reinterpret_cast<char*>(backing.vocab.get()), kMagicIncomplete, TotalHeaderSize(order)); -    return reinterpret_cast<uint8_t*>(backing.vocab.get()) + TotalHeaderSize(order); -  } else { -    util::MapAnonymous(memory_size, backing.vocab); -    return reinterpret_cast<uint8_t*>(backing.vocab.get()); -  } -} - -uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing) { -  std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad; -  if (config.write_mmap) { -    // Grow the file to accomodate the search, using zeros. -    try { -      util::ResizeOrThrow(backing.file.get(), adjusted_vocab + memory_size); -    } catch (util::ErrnoException &e) { -      e << " for file " << config.write_mmap; -      throw e; -    } - -    if (config.write_method == Config::WRITE_AFTER) { -      util::MapAnonymous(memory_size, backing.search); -      return reinterpret_cast<uint8_t*>(backing.search.get()); -    } -    // mmap it now. -    // We're skipping over the header and vocab for the search space mmap.  mmap likes page aligned offsets, so some arithmetic to round the offset down. 
-    std::size_t page_size = util::SizePage(); -    std::size_t alignment_cruft = adjusted_vocab % page_size; -    backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED); -    return reinterpret_cast<uint8_t*>(backing.search.get()) + alignment_cruft; -  } else { -    util::MapAnonymous(memory_size, backing.search); -    return reinterpret_cast<uint8_t*>(backing.search.get()); -  } -} - -void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing) { -  if (!config.write_mmap) return; -  switch (config.write_method) { -    case Config::WRITE_MMAP: -      util::SyncOrThrow(backing.vocab.get(), backing.vocab.size()); -      util::SyncOrThrow(backing.search.get(), backing.search.size()); -      break; -    case Config::WRITE_AFTER: -      util::SeekOrThrow(backing.file.get(), 0); -      util::WriteOrThrow(backing.file.get(), backing.vocab.get(), backing.vocab.size()); -      util::SeekOrThrow(backing.file.get(), backing.vocab.size() + vocab_pad); -      util::WriteOrThrow(backing.file.get(), backing.search.get(), backing.search.size()); -      util::FSyncOrThrow(backing.file.get()); -      break; -  } -  // header and vocab share the same mmap.  The header is written here because we know the counts. -  Parameters params = Parameters(); -  params.counts = counts; -  params.fixed.order = counts.size(); -  params.fixed.probing_multiplier = config.probing_multiplier; -  params.fixed.model_type = model_type; -  params.fixed.has_vocabulary = config.include_vocab; -  params.fixed.search_version = search_version; -  WriteHeader(backing.vocab.get(), params); -  if (config.write_method == Config::WRITE_AFTER) { -    util::SeekOrThrow(backing.file.get(), 0); -    util::WriteOrThrow(backing.file.get(), backing.vocab.get(), TotalHeaderSize(counts.size())); -  } -} - -namespace detail { -  bool IsBinaryFormat(int fd) {    const uint64_t size = util::SizeFile(fd);    if (size == util::kBadSize || (size <= static_cast<uint64_t>(sizeof(Sanity)))) return false; @@ -169,21 +94,21 @@ bool IsBinaryFormat(int fd) {    }    Sanity reference_header = Sanity();    reference_header.SetToReference(); -  if (!memcmp(memory.get(), &reference_header, sizeof(Sanity))) return true; -  if (!memcmp(memory.get(), kMagicIncomplete, strlen(kMagicIncomplete))) { +  if (!std::memcmp(memory.get(), &reference_header, sizeof(Sanity))) return true; +  if (!std::memcmp(memory.get(), kMagicIncomplete, strlen(kMagicIncomplete))) {      UTIL_THROW(FormatLoadException, "This binary file did not finish building");    } -  if (!memcmp(memory.get(), kMagicBeforeVersion, strlen(kMagicBeforeVersion))) { +  if (!std::memcmp(memory.get(), kMagicBeforeVersion, strlen(kMagicBeforeVersion))) {      char *end_ptr;      const char *begin_version = static_cast<const char*>(memory.get()) + strlen(kMagicBeforeVersion); -    long int version = strtol(begin_version, &end_ptr, 10); +    long int version = std::strtol(begin_version, &end_ptr, 10);      if ((end_ptr != begin_version) && version != kMagicVersion) {        UTIL_THROW(FormatLoadException, "Binary file has version " << version << " but this implementation expects version " << kMagicVersion << " so you'll have to use the ARPA to rebuild your binary");      }      OldSanity old_sanity = OldSanity();      
old_sanity.SetToReference(); -    UTIL_THROW_IF(!memcmp(memory.get(), &old_sanity, sizeof(OldSanity)), FormatLoadException, "Looks like this is an old 32-bit format.  The old 32-bit format has been removed so that 64-bit and 32-bit files are exchangeable."); +    UTIL_THROW_IF(!std::memcmp(memory.get(), &old_sanity, sizeof(OldSanity)), FormatLoadException, "Looks like this is an old 32-bit format.  The old 32-bit format has been removed so that 64-bit and 32-bit files are exchangeable.");      UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match.  Try rebuilding the binary format LM using the same code revision, compiler, and architecture");    }    return false; @@ -208,44 +133,164 @@ void MatchCheck(ModelType model_type, unsigned int search_version, const Paramet    UTIL_THROW_IF(search_version != params.fixed.search_version, FormatLoadException, "The binary file has " << kModelNames[params.fixed.model_type] << " version " << params.fixed.search_version << " but this code expects " << kModelNames[params.fixed.model_type] << " version " << search_version);  } -void SeekPastHeader(int fd, const Parameters ¶ms) { -  util::SeekOrThrow(fd, TotalHeaderSize(params.counts.size())); +const std::size_t kInvalidSize = static_cast<std::size_t>(-1); + +BinaryFormat::BinaryFormat(const Config &config)  +  : write_method_(config.write_method), write_mmap_(config.write_mmap), load_method_(config.load_method), +    header_size_(kInvalidSize), vocab_size_(kInvalidSize), vocab_string_offset_(kInvalidOffset) {} + +void BinaryFormat::InitializeBinary(int fd, ModelType model_type, unsigned int search_version, Parameters ¶ms) { +  file_.reset(fd); +  write_mmap_ = NULL; // Ignore write requests; this is already in binary format. +  ReadHeader(fd, params); +  MatchCheck(model_type, search_version, params); +  header_size_ = TotalHeaderSize(params.counts.size()); +} + +void BinaryFormat::ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header) const { +  assert(header_size_ != kInvalidSize); +  util::PReadOrThrow(file_.get(), to, amount, offset_excluding_header + header_size_);  } -uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, uint64_t memory_size, Backing &backing) { -  const uint64_t file_size = util::SizeFile(backing.file.get()); +void *BinaryFormat::LoadBinary(std::size_t size) { +  assert(header_size_ != kInvalidSize); +  const uint64_t file_size = util::SizeFile(file_.get());    // The header is smaller than a page, so we have to map the whole header as well. 
-  std::size_t total_map = util::CheckOverflow(TotalHeaderSize(params.counts.size()) + memory_size); -  if (file_size != util::kBadSize && static_cast<uint64_t>(file_size) < total_map) -    UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map); +  uint64_t total_map = static_cast<uint64_t>(header_size_) + static_cast<uint64_t>(size); +  UTIL_THROW_IF(file_size != util::kBadSize && file_size < total_map, FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map); -  util::MapRead(config.load_method, backing.file.get(), 0, total_map, backing.search); +  util::MapRead(load_method_, file_.get(), 0, util::CheckOverflow(total_map), mapping_); -  if (config.enumerate_vocab && !params.fixed.has_vocabulary) -    UTIL_THROW(FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them.  You may need to rebuild the binary file with an updated version of build_binary."); +  vocab_string_offset_ = total_map; +  return reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_; +} + +void *BinaryFormat::SetupJustVocab(std::size_t memory_size, uint8_t order) { +  vocab_size_ = memory_size; +  if (!write_mmap_) { +    header_size_ = 0; +    util::MapAnonymous(memory_size, memory_vocab_); +    return reinterpret_cast<uint8_t*>(memory_vocab_.get()); +  } +  header_size_ = TotalHeaderSize(order); +  std::size_t total = util::CheckOverflow(static_cast<uint64_t>(header_size_) + static_cast<uint64_t>(memory_size)); +  file_.reset(util::CreateOrThrow(write_mmap_)); +  // some gccs complain about uninitialized variables even though all enum values are covered. +  void *vocab_base = NULL; +  switch (write_method_) { +    case Config::WRITE_MMAP: +      mapping_.reset(util::MapZeroedWrite(file_.get(), total), total, util::scoped_memory::MMAP_ALLOCATED); +      vocab_base = mapping_.get(); +      break; +    case Config::WRITE_AFTER: +      util::ResizeOrThrow(file_.get(), 0); +      util::MapAnonymous(total, memory_vocab_); +      vocab_base = memory_vocab_.get(); +      break; +  } +  strncpy(reinterpret_cast<char*>(vocab_base), kMagicIncomplete, header_size_); +  return reinterpret_cast<uint8_t*>(vocab_base) + header_size_; +} -  // Seek to vocabulary words -  util::SeekOrThrow(backing.file.get(), total_map); -  return reinterpret_cast<uint8_t*>(backing.search.get()) + TotalHeaderSize(params.counts.size()); +void *BinaryFormat::GrowForSearch(std::size_t memory_size, std::size_t vocab_pad, void *&vocab_base) { +  assert(vocab_size_ != kInvalidSize); +  vocab_pad_ = vocab_pad; +  std::size_t new_size = header_size_ + vocab_size_ + vocab_pad_ + memory_size; +  vocab_string_offset_ = new_size; +  if (!write_mmap_ || write_method_ == Config::WRITE_AFTER) { +    util::MapAnonymous(memory_size, memory_search_); +    assert(header_size_ == 0 || write_mmap_); +    vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get()) + header_size_; +    return reinterpret_cast<uint8_t*>(memory_search_.get()); +  } + +  assert(write_method_ == Config::WRITE_MMAP); +  // Also known as total size without vocab words. +  // Grow the file to accomodate the search, using zeros. +  // According to man mmap, behavior is undefined when the file is resized +  // underneath a mmap that is not a multiple of the page size.  So to be +  // safe, we'll unmap it and map it again. 
+  mapping_.reset(); +  util::ResizeOrThrow(file_.get(), new_size); +  void *ret; +  MapFile(vocab_base, ret); +  return ret;  } -void ComplainAboutARPA(const Config &config, ModelType model_type) { -  if (config.write_mmap || !config.messages) return; -  if (config.arpa_complain == Config::ALL) { -    *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl; -  } else if (config.arpa_complain == Config::EXPENSIVE && -             (model_type == TRIE || model_type == QUANT_TRIE || model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE)) { -    *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive.  Save time by building a binary format." << std::endl; +void BinaryFormat::WriteVocabWords(const std::string &buffer, void *&vocab_base, void *&search_base) { +  // Checking Config's include_vocab is the responsibility of the caller. +  assert(header_size_ != kInvalidSize && vocab_size_ != kInvalidSize); +  if (!write_mmap_) { +    // Unchanged base. +    vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get()); +    search_base = reinterpret_cast<uint8_t*>(memory_search_.get()); +    return; +  } +  if (write_method_ == Config::WRITE_MMAP) { +    mapping_.reset(); +  } +  util::SeekOrThrow(file_.get(), VocabStringReadingOffset()); +  util::WriteOrThrow(file_.get(), &buffer[0], buffer.size()); +  if (write_method_ == Config::WRITE_MMAP) { +    MapFile(vocab_base, search_base); +  } else { +    vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get()) + header_size_; +    search_base = reinterpret_cast<uint8_t*>(memory_search_.get()); +  } +} + +void BinaryFormat::FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts) { +  if (!write_mmap_) return; +  switch (write_method_) { +    case Config::WRITE_MMAP: +      util::SyncOrThrow(mapping_.get(), mapping_.size()); +      break; +    case Config::WRITE_AFTER: +      util::SeekOrThrow(file_.get(), 0); +      util::WriteOrThrow(file_.get(), memory_vocab_.get(), memory_vocab_.size()); +      util::SeekOrThrow(file_.get(), header_size_ + vocab_size_ + vocab_pad_); +      util::WriteOrThrow(file_.get(), memory_search_.get(), memory_search_.size()); +      util::FSyncOrThrow(file_.get()); +      break; +  } +  // header and vocab share the same mmap. 
+  Parameters params = Parameters(); +  memset(¶ms, 0, sizeof(Parameters)); +  params.counts = counts; +  params.fixed.order = counts.size(); +  params.fixed.probing_multiplier = config.probing_multiplier; +  params.fixed.model_type = model_type; +  params.fixed.has_vocabulary = config.include_vocab; +  params.fixed.search_version = search_version; +  switch (write_method_) { +    case Config::WRITE_MMAP: +      WriteHeader(mapping_.get(), params); +      util::SyncOrThrow(mapping_.get(), mapping_.size()); +      break; +    case Config::WRITE_AFTER: +      { +        std::vector<uint8_t> buffer(TotalHeaderSize(counts.size())); +        WriteHeader(&buffer[0], params); +        util::SeekOrThrow(file_.get(), 0); +        util::WriteOrThrow(file_.get(), &buffer[0], buffer.size()); +      } +      break;    }  } -} // namespace detail +void BinaryFormat::MapFile(void *&vocab_base, void *&search_base) { +  mapping_.reset(util::MapOrThrow(vocab_string_offset_, true, util::kFileFlags, false, file_.get()), vocab_string_offset_, util::scoped_memory::MMAP_ALLOCATED); +  vocab_base = reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_; +  search_base = reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_ + vocab_size_ + vocab_pad_; +}  bool RecognizeBinary(const char *file, ModelType &recognized) {    util::scoped_fd fd(util::OpenReadOrThrow(file)); -  if (!detail::IsBinaryFormat(fd.get())) return false; +  if (!IsBinaryFormat(fd.get())) { +    return false; +  }    Parameters params; -  detail::ReadHeader(fd.get(), params); +  ReadHeader(fd.get(), params);    recognized = params.fixed.model_type;    return true;  } diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index bf699d5f..f33f88d7 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -17,6 +17,8 @@  namespace lm {  namespace ngram { +extern const char *kModelNames[6]; +  /*Inspect a file to determine if it is a binary lm.  If not, return false.     * If so, return true and set recognized to the type.  This is the only API in   * this header designed for use by decoder authors.   @@ -42,67 +44,63 @@ struct Parameters {    std::vector<uint64_t> counts;  }; -struct Backing { -  // File behind memory, if any.   -  util::scoped_fd file; -  // Vocabulary lookup table.  Not to be confused with the vocab words themselves.   -  util::scoped_memory vocab; -  // Raw block of memory backing the language model data structures -  util::scoped_memory search; -}; - -// Create just enough of a binary file to write vocabulary to it.   -uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing); -// Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin.   -uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing); - -// Write header to binary file.  This is done last to prevent incomplete files -// from loading.    -void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts,  std::size_t vocab_pad, Backing &backing); +class BinaryFormat { +  public: +    explicit BinaryFormat(const Config &config); + +    // Reading a binary file: +    // Takes ownership of fd +    void InitializeBinary(int fd, ModelType model_type, unsigned int search_version, Parameters ¶ms); +    // Used to read parts of the file to update the config object before figuring out full size. 
+    void ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header) const; +    // Actually load the binary file and return a pointer to the beginning of the search area. +    void *LoadBinary(std::size_t size); + +    uint64_t VocabStringReadingOffset() const { +      assert(vocab_string_offset_ != kInvalidOffset); +      return vocab_string_offset_; +    } -namespace detail { +    // Writing a binary file or initializing in RAM from ARPA: +    // Size for vocabulary. +    void *SetupJustVocab(std::size_t memory_size, uint8_t order); +    // Warning: can change the vocaulary base pointer. +    void *GrowForSearch(std::size_t memory_size, std::size_t vocab_pad, void *&vocab_base); +    // Warning: can change vocabulary and search base addresses. +    void WriteVocabWords(const std::string &buffer, void *&vocab_base, void *&search_base); +    // Write the header at the beginning of the file. +    void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts); + +  private: +    void MapFile(void *&vocab_base, void *&search_base); + +    // Copied from configuration. +    const Config::WriteMethod write_method_; +    const char *write_mmap_; +    util::LoadMethod load_method_; + +    // File behind memory, if any.   +    util::scoped_fd file_; + +    // If there is a file involved, a single mapping. +    util::scoped_memory mapping_; + +    // If the data is only in memory, separately allocate each because the trie +    // knows vocab's size before it knows search's size (because SRILM might +    // have pruned). +    util::scoped_memory memory_vocab_, memory_search_; + +    // Memory ranges.  Note that these may not be contiguous and may not all +    // exist. +    std::size_t header_size_, vocab_size_, vocab_pad_; +    // aka end of search. +    uint64_t vocab_string_offset_; + +    static const uint64_t kInvalidOffset = (uint64_t)-1; +};  bool IsBinaryFormat(int fd); -void ReadHeader(int fd, Parameters ¶ms); - -void MatchCheck(ModelType model_type, unsigned int search_version, const Parameters ¶ms); - -void SeekPastHeader(int fd, const Parameters ¶ms); - -uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, uint64_t memory_size, Backing &backing); - -void ComplainAboutARPA(const Config &config, ModelType model_type); - -} // namespace detail - -template <class To> void LoadLM(const char *file, const Config &config, To &to) { -  Backing &backing = to.MutableBacking(); -  backing.file.reset(util::OpenReadOrThrow(file)); - -  try { -    if (detail::IsBinaryFormat(backing.file.get())) { -      Parameters params; -      detail::ReadHeader(backing.file.get(), params); -      detail::MatchCheck(To::kModelType, To::kVersion, params); -      // Replace the run-time configured probing_multiplier with the one in the file.   
-      Config new_config(config); -      new_config.probing_multiplier = params.fixed.probing_multiplier; -      detail::SeekPastHeader(backing.file.get(), params); -      To::UpdateConfigFromBinary(backing.file.get(), params.counts, new_config); -      uint64_t memory_size = To::Size(params.counts, new_config); -      uint8_t *start = detail::SetupBinary(new_config, params, memory_size, backing); -      to.InitializeFromBinary(start, params, new_config, backing.file.get()); -    } else { -      detail::ComplainAboutARPA(config, To::kModelType); -      to.InitializeFromARPA(file, config); -    } -  } catch (util::Exception &e) { -    e << " File: " << file; -    throw; -  } -} -  } // namespace ngram  } // namespace lm  #endif // LM_BINARY_FORMAT__ diff --git a/klm/lm/build_binary_main.cc b/klm/lm/build_binary_main.cc index ab2c0c32..15b421e9 100644 --- a/klm/lm/build_binary_main.cc +++ b/klm/lm/build_binary_main.cc @@ -52,6 +52,7 @@ void Usage(const char *name, const char *default_mem) {  "-a compresses pointers using an array of offsets.  The parameter is the\n"  "   maximum number of bits encoded by the array.  Memory is minimized subject\n"  "   to the maximum, so pick 255 to minimize memory.\n\n" +"-h print this help message.\n\n"  "Get a memory estimate by passing an ARPA file without an output file name.\n";    exit(1);  } @@ -104,12 +105,15 @@ int main(int argc, char *argv[]) {    const char *default_mem = util::GuessPhysicalMemory() ? "80%" : "1G"; +  if (argc == 2 && !strcmp(argv[1], "--help")) +    Usage(argv[0], default_mem); +    try {      bool quantize = false, set_backoff_bits = false, bhiksha = false, set_write_method = false, rest = false;      lm::ngram::Config config;      config.building_memory = util::ParseSize(default_mem);      int opt; -    while ((opt = getopt(argc, argv, "q:b:a:u:p:t:T:m:S:w:sir:")) != -1) { +    while ((opt = getopt(argc, argv, "q:b:a:u:p:t:T:m:S:w:sir:h")) != -1) {        switch(opt) {          case 'q':            config.prob_bits = ParseBitCount(optarg); @@ -161,6 +165,7 @@ int main(int argc, char *argv[]) {            ParseFileList(optarg, config.rest_lower_files);            config.rest_function = Config::REST_LOWER;            break; +        case 'h': // help          default:            Usage(argv[0], default_mem);        } @@ -186,6 +191,7 @@ int main(int argc, char *argv[]) {        config.write_mmap = argv[optind + 2];      } else {        Usage(argv[0], default_mem); +      return 1;      }      if (!strcmp(model_type, "probing")) {        if (!set_write_method) config.write_method = Config::WRITE_AFTER; diff --git a/klm/lm/builder/corpus_count.cc b/klm/lm/builder/corpus_count.cc index aea93ad1..ccc06efc 100644 --- a/klm/lm/builder/corpus_count.cc +++ b/klm/lm/builder/corpus_count.cc @@ -238,12 +238,17 @@ void CorpusCount::Run(const util::stream::ChainPosition &position) {    const WordIndex end_sentence = vocab.Lookup("</s>");    Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_);    uint64_t count = 0; -  StringPiece delimiters("\0\t\r ", 4); +  bool delimiters[256]; +  memset(delimiters, 0, sizeof(delimiters)); +  const char kDelimiterSet[] = "\0\t\n\r "; +  for (const char *i = kDelimiterSet; i < kDelimiterSet + sizeof(kDelimiterSet); ++i) { +    delimiters[static_cast<unsigned char>(*i)] = true; +  }    try {      while(true) {        StringPiece line(from_.ReadLine());        writer.StartSentence(); -      for (util::TokenIter<util::AnyCharacter, true> w(line, 
delimiters); w; ++w) { +      for (util::TokenIter<util::BoolCharacter, true> w(line, delimiters); w; ++w) {          WordIndex word = vocab.Lookup(*w);          UTIL_THROW_IF(word <= 2, FormatLoadException, "Special word " << *w << " is not allowed in the corpus.  I plan to support models containing <unk> in the future.");          writer.Append(word); diff --git a/klm/lm/builder/lmplz_main.cc b/klm/lm/builder/lmplz_main.cc index c87abdb8..2563deed 100644 --- a/klm/lm/builder/lmplz_main.cc +++ b/klm/lm/builder/lmplz_main.cc @@ -33,7 +33,10 @@ int main(int argc, char *argv[]) {      po::options_description options("Language model building options");      lm::builder::PipelineConfig pipeline; +    std::string text, arpa; +      options.add_options() +      ("help", po::bool_switch(), "Show this help message")        ("order,o", po::value<std::size_t>(&pipeline.order)  #if BOOST_VERSION >= 104200           ->required() @@ -47,8 +50,13 @@ int main(int argc, char *argv[]) {        ("vocab_estimate", po::value<lm::WordIndex>(&pipeline.vocab_estimate)->default_value(1000000), "Assume this vocabulary size for purposes of calculating memory in step 1 (corpus count) and pre-sizing the hash table")        ("block_count", po::value<std::size_t>(&pipeline.block_count)->default_value(2), "Block count (per order)")        ("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write vocabulary file") -      ("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc."); -    if (argc == 1) { +      ("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.") +      ("text", po::value<std::string>(&text), "Read text from a file instead of stdin") +      ("arpa", po::value<std::string>(&arpa), "Write ARPA to a file instead of stdout"); +    po::variables_map vm; +    po::store(po::parse_command_line(argc, argv, options), vm); + +    if (argc == 1 || vm["help"].as<bool>()) {        std::cerr <<           "Builds unpruned language models with modified Kneser-Ney smoothing.\n\n"          "Please cite:\n" @@ -66,12 +74,17 @@ int main(int argc, char *argv[]) {          "setting the temporary file location (-T) and sorting memory (-S) is recommended.\n\n"          "Memory sizes are specified like GNU sort: a number followed by a unit character.\n"          "Valid units are \% for percentage of memory (supported platforms only) and (in\n" -        "increasing powers of 1024): b, K, M, G, T, P, E, Z, Y.  Default is K (*1024).\n\n"; +        "increasing powers of 1024): b, K, M, G, T, P, E, Z, Y.  Default is K (*1024).\n"; +      uint64_t mem = util::GuessPhysicalMemory(); +      if (mem) { +        std::cerr << "This machine has " << mem << " bytes of memory.\n\n"; +      } else { +        std::cerr << "Unable to determine the amount of memory on this machine.\n\n"; +      }         std::cerr << options << std::endl;        return 1;      } -    po::variables_map vm; -    po::store(po::parse_command_line(argc, argv, options), vm); +      po::notify(vm);      // required() appeared in Boost 1.42.0. 
@@ -92,9 +105,17 @@ int main(int argc, char *argv[]) {      initial.adder_out.block_count = 2;      pipeline.read_backoffs = initial.adder_out; +    util::scoped_fd in(0), out(1); +    if (vm.count("text")) { +      in.reset(util::OpenReadOrThrow(text.c_str())); +    } +    if (vm.count("arpa")) { +      out.reset(util::CreateOrThrow(arpa.c_str())); +    } +      // Read from stdin      try { -      lm::builder::Pipeline(pipeline, 0, 1); +      lm::builder::Pipeline(pipeline, in.release(), out.release());      } catch (const util::MallocException &e) {        std::cerr << e.what() << std::endl;        std::cerr << "Try rerunning with a more conservative -S setting than " << vm["memory"].as<std::string>() << std::endl; diff --git a/klm/lm/builder/pipeline.cc b/klm/lm/builder/pipeline.cc index b89ea6ba..44a2313c 100644 --- a/klm/lm/builder/pipeline.cc +++ b/klm/lm/builder/pipeline.cc @@ -226,6 +226,7 @@ void CountText(int text_file /* input */, int vocab_file /* output */, Master &m    util::stream::Sort<SuffixOrder, AddCombiner> sorter(chain, config.sort, SuffixOrder(config.order), AddCombiner());    chain.Wait(true); +  std::cerr << "Unigram tokens " << token_count << " types " << type_count << std::endl;    std::cerr << "=== 2/5 Calculating and sorting adjusted counts ===" << std::endl;    master.InitForAdjust(sorter, type_count);  } diff --git a/klm/lm/facade.hh b/klm/lm/facade.hh index 8b186017..de1551f1 100644 --- a/klm/lm/facade.hh +++ b/klm/lm/facade.hh @@ -16,19 +16,28 @@ template <class Child, class StateT, class VocabularyT> class ModelFacade : publ      typedef StateT State;      typedef VocabularyT Vocabulary; -    // Default Score function calls FullScore.  Model can override this.   -    float Score(const State &in_state, const WordIndex new_word, State &out_state) const { -      return static_cast<const Child*>(this)->FullScore(in_state, new_word, out_state).prob; -    } -      /* Translate from void* to State */ -    FullScoreReturn FullScore(const void *in_state, const WordIndex new_word, void *out_state) const { +    FullScoreReturn BaseFullScore(const void *in_state, const WordIndex new_word, void *out_state) const {        return static_cast<const Child*>(this)->FullScore(            *reinterpret_cast<const State*>(in_state),            new_word,            *reinterpret_cast<State*>(out_state));      } -    float Score(const void *in_state, const WordIndex new_word, void *out_state) const { + +    FullScoreReturn BaseFullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, void *out_state) const { +      return static_cast<const Child*>(this)->FullScoreForgotState( +          context_rbegin, +          context_rend, +          new_word, +          *reinterpret_cast<State*>(out_state)); +    } + +    // Default Score function calls FullScore.  Model can override this.   
+    float Score(const State &in_state, const WordIndex new_word, State &out_state) const { +      return static_cast<const Child*>(this)->FullScore(in_state, new_word, out_state).prob; +    } + +    float BaseScore(const void *in_state, const WordIndex new_word, void *out_state) const {        return static_cast<const Child*>(this)->Score(            *reinterpret_cast<const State*>(in_state),            new_word, diff --git a/klm/lm/filter/arpa_io.hh b/klm/lm/filter/arpa_io.hh index 5b31620b..602b5b31 100644 --- a/klm/lm/filter/arpa_io.hh +++ b/klm/lm/filter/arpa_io.hh @@ -14,7 +14,6 @@  #include <string>  #include <vector> -#include <err.h>  #include <string.h>  #include <stdint.h> diff --git a/klm/lm/filter/count_io.hh b/klm/lm/filter/count_io.hh index 97c0fa25..d992026f 100644 --- a/klm/lm/filter/count_io.hh +++ b/klm/lm/filter/count_io.hh @@ -5,20 +5,18 @@  #include <iostream>  #include <string> -#include <err.h> - +#include "util/fake_ofstream.hh" +#include "util/file.hh"  #include "util/file_piece.hh"  namespace lm {  class CountOutput : boost::noncopyable {    public: -    explicit CountOutput(const char *name) : file_(name, std::ios::out) {} +    explicit CountOutput(const char *name) : file_(util::CreateOrThrow(name)) {}      void AddNGram(const StringPiece &line) { -      if (!(file_ << line << '\n')) { -        err(3, "Writing counts file failed"); -      } +      file_ << line << '\n';      }      template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) { @@ -30,7 +28,7 @@ class CountOutput : boost::noncopyable {      }    private: -    std::fstream file_; +    util::FakeOFStream file_;  };  class CountBatch { diff --git a/klm/lm/filter/filter_main.cc b/klm/lm/filter/filter_main.cc index 1736bc40..82fdc1ef 100644 --- a/klm/lm/filter/filter_main.cc +++ b/klm/lm/filter/filter_main.cc @@ -6,6 +6,7 @@  #endif  #include "lm/filter/vocab.hh"  #include "lm/filter/wrapper.hh" +#include "util/exception.hh"  #include "util/file_piece.hh"  #include <boost/ptr_container/ptr_vector.hpp> @@ -157,92 +158,96 @@ template <class Format> void DispatchFilterModes(const Config &config, std::istr  } // namespace lm  int main(int argc, char *argv[]) { -  if (argc < 4) { -    lm::DisplayHelp(argv[0]); -    return 1; -  } +  try { +    if (argc < 4) { +      lm::DisplayHelp(argv[0]); +      return 1; +    } -  // I used to have boost::program_options, but some users didn't want to compile boost. -  lm::Config config; -  config.mode = lm::MODE_UNSET; -  for (int i = 1; i < argc - 2; ++i) { -    const char *str = argv[i]; -    if (!std::strcmp(str, "copy")) { -      config.mode = lm::MODE_COPY; -    } else if (!std::strcmp(str, "single")) { -      config.mode = lm::MODE_SINGLE; -    } else if (!std::strcmp(str, "multiple")) { -      config.mode = lm::MODE_MULTIPLE; -    } else if (!std::strcmp(str, "union")) { -      config.mode = lm::MODE_UNION; -    } else if (!std::strcmp(str, "phrase")) { -      config.phrase = true; -    } else if (!std::strcmp(str, "context")) { -      config.context = true; -    } else if (!std::strcmp(str, "arpa")) { -      config.format = lm::FORMAT_ARPA; -    } else if (!std::strcmp(str, "raw")) { -      config.format = lm::FORMAT_COUNT; +    // I used to have boost::program_options, but some users didn't want to compile boost. 
+    lm::Config config; +    config.mode = lm::MODE_UNSET; +    for (int i = 1; i < argc - 2; ++i) { +      const char *str = argv[i]; +      if (!std::strcmp(str, "copy")) { +        config.mode = lm::MODE_COPY; +      } else if (!std::strcmp(str, "single")) { +        config.mode = lm::MODE_SINGLE; +      } else if (!std::strcmp(str, "multiple")) { +        config.mode = lm::MODE_MULTIPLE; +      } else if (!std::strcmp(str, "union")) { +        config.mode = lm::MODE_UNION; +      } else if (!std::strcmp(str, "phrase")) { +        config.phrase = true; +      } else if (!std::strcmp(str, "context")) { +        config.context = true; +      } else if (!std::strcmp(str, "arpa")) { +        config.format = lm::FORMAT_ARPA; +      } else if (!std::strcmp(str, "raw")) { +        config.format = lm::FORMAT_COUNT;  #ifndef NTHREAD -    } else if (!std::strncmp(str, "threads:", 8)) { -      config.threads = boost::lexical_cast<size_t>(str + 8); -      if (!config.threads) { -        std::cerr << "Specify at least one thread." << std::endl; +      } else if (!std::strncmp(str, "threads:", 8)) { +        config.threads = boost::lexical_cast<size_t>(str + 8); +        if (!config.threads) { +          std::cerr << "Specify at least one thread." << std::endl; +          return 1; +        } +      } else if (!std::strncmp(str, "batch_size:", 11)) { +        config.batch_size = boost::lexical_cast<size_t>(str + 11); +        if (config.batch_size < 5000) { +          std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl; +          if (!config.batch_size) return 1; +        } +#endif +      } else { +        lm::DisplayHelp(argv[0]);          return 1;        } -    } else if (!std::strncmp(str, "batch_size:", 11)) { -      config.batch_size = boost::lexical_cast<size_t>(str + 11); -      if (config.batch_size < 5000) { -        std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl; -        if (!config.batch_size) return 1; -      } -#endif -    } else { +    } + +    if (config.mode == lm::MODE_UNSET) {        lm::DisplayHelp(argv[0]);        return 1;      } -  } -   -  if (config.mode == lm::MODE_UNSET) { -    lm::DisplayHelp(argv[0]); -    return 1; -  } -  if (config.phrase && config.mode != lm::MODE_UNION && config.mode != lm::MODE_MULTIPLE) { -    std::cerr << "Phrase constraint currently only works in multiple or union mode.  If you really need it for single, put everything on one line and use union." << std::endl; -    return 1; -  } +    if (config.phrase && config.mode != lm::MODE_UNION && config.mode != lm::MODE_MULTIPLE) { +      std::cerr << "Phrase constraint currently only works in multiple or union mode.  If you really need it for single, put everything on one line and use union." 
<< std::endl; +      return 1; +    } -  bool cmd_is_model = true; -  const char *cmd_input = argv[argc - 2]; -  if (!strncmp(cmd_input, "vocab:", 6)) { -    cmd_is_model = false; -    cmd_input += 6; -  } else if (!strncmp(cmd_input, "model:", 6)) { -    cmd_input += 6; -  } else if (strchr(cmd_input, ':')) { -    errx(1, "Specify vocab: or model: before the input file name, not \"%s\"", cmd_input); -  } else { -    std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl; -  } -  std::ifstream cmd_file; -  std::istream *vocab; -  if (cmd_is_model) { -    vocab = &std::cin; -  } else { -    cmd_file.open(cmd_input, std::ios::in); -    if (!cmd_file) { -      err(2, "Could not open input file %s", cmd_input); +    bool cmd_is_model = true; +    const char *cmd_input = argv[argc - 2]; +    if (!strncmp(cmd_input, "vocab:", 6)) { +      cmd_is_model = false; +      cmd_input += 6; +    } else if (!strncmp(cmd_input, "model:", 6)) { +      cmd_input += 6; +    } else if (strchr(cmd_input, ':')) { +      std::cerr << "Specify vocab: or model: before the input file name, not " << cmd_input << std::endl; +      return 1; +    } else { +      std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl; +    } +    std::ifstream cmd_file; +    std::istream *vocab; +    if (cmd_is_model) { +      vocab = &std::cin; +    } else { +      cmd_file.open(cmd_input, std::ios::in); +      UTIL_THROW_IF(!cmd_file, util::ErrnoException, "Failed to open " << cmd_input); +      vocab = &cmd_file;      } -    vocab = &cmd_file; -  } -  util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? cmd_input : NULL, &std::cerr); +    util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? cmd_input : NULL, &std::cerr); -  if (config.format == lm::FORMAT_ARPA) { -    lm::DispatchFilterModes<lm::ARPAFormat>(config, *vocab, model, argv[argc - 1]); -  } else if (config.format == lm::FORMAT_COUNT) { -    lm::DispatchFilterModes<lm::CountFormat>(config, *vocab, model, argv[argc - 1]); +    if (config.format == lm::FORMAT_ARPA) { +      lm::DispatchFilterModes<lm::ARPAFormat>(config, *vocab, model, argv[argc - 1]); +    } else if (config.format == lm::FORMAT_COUNT) { +      lm::DispatchFilterModes<lm::CountFormat>(config, *vocab, model, argv[argc - 1]); +    } +    return 0; +  } catch (const std::exception &e) { +    std::cerr << e.what() << std::endl; +    return 1;    } -  return 0;  } diff --git a/klm/lm/filter/format.hh b/klm/lm/filter/format.hh index 7f945b0d..7d8c28db 100644 --- a/klm/lm/filter/format.hh +++ b/klm/lm/filter/format.hh @@ -1,5 +1,5 @@  #ifndef LM_FILTER_FORMAT_H__ -#define LM_FITLER_FORMAT_H__ +#define LM_FILTER_FORMAT_H__  #include "lm/filter/arpa_io.hh"  #include "lm/filter/count_io.hh" diff --git a/klm/lm/filter/phrase.cc b/klm/lm/filter/phrase.cc index 1bef2a3f..e2946b14 100644 --- a/klm/lm/filter/phrase.cc +++ b/klm/lm/filter/phrase.cc @@ -48,21 +48,21 @@ unsigned int ReadMultiple(std::istream &in, Substrings &out) {    return sentence_id + sentence_content;  } -namespace detail { const StringPiece kEndSentence("</s>"); } -  namespace { -  typedef unsigned int Sentence;  typedef std::vector<Sentence> Sentences; +} // namespace -class Vertex; +namespace detail {  + +const StringPiece kEndSentence("</s>");  class Arc {    public:      Arc() {}      // For arcs from one vertex to another.   
-    void SetPhrase(Vertex &from, Vertex &to, const Sentences &intersect) { +    void SetPhrase(detail::Vertex &from, detail::Vertex &to, const Sentences &intersect) {        Set(to, intersect);        from_ = &from;      } @@ -71,7 +71,7 @@ class Arc {       * aligned).  These have no from_ vertex; it implictly matches every       * sentence.  This also handles when the n-gram is a substring of a phrase.        */ -    void SetRight(Vertex &to, const Sentences &complete) { +    void SetRight(detail::Vertex &to, const Sentences &complete) {        Set(to, complete);        from_ = NULL;      } @@ -97,11 +97,11 @@ class Arc {      void LowerBound(const Sentence to);    private: -    void Set(Vertex &to, const Sentences &sentences); +    void Set(detail::Vertex &to, const Sentences &sentences);      const Sentence *current_;      const Sentence *last_; -    Vertex *from_; +    detail::Vertex *from_;  };  struct ArcGreater : public std::binary_function<const Arc *, const Arc *, bool> { @@ -183,7 +183,13 @@ void Vertex::LowerBound(const Sentence to) {    }  } -void BuildGraph(const Substrings &phrase, const std::vector<Hash> &hashes, Vertex *const vertices, Arc *free_arc) { +} // namespace detail + +namespace { + +void BuildGraph(const Substrings &phrase, const std::vector<Hash> &hashes, detail::Vertex *const vertices, detail::Arc *free_arc) { +  using detail::Vertex; +  using detail::Arc;    assert(!hashes.empty());    const Hash *const first_word = &*hashes.begin(); @@ -231,17 +237,29 @@ void BuildGraph(const Substrings &phrase, const std::vector<Hash> &hashes, Verte  namespace detail { -} // namespace detail +// Here instead of header due to forward declaration. +ConditionCommon::ConditionCommon(const Substrings &substrings) : substrings_(substrings) {} -bool Union::Evaluate() { +// Rest of the variables are temporaries anyway +ConditionCommon::ConditionCommon(const ConditionCommon &from) : substrings_(from.substrings_) {} + +ConditionCommon::~ConditionCommon() {} + +detail::Vertex &ConditionCommon::MakeGraph() {    assert(!hashes_.empty()); -  // Usually there are at most 6 words in an n-gram, so stack allocation is reasonable.   -  Vertex vertices[hashes_.size()]; +  vertices_.clear(); +  vertices_.resize(hashes_.size()); +  arcs_.clear();    // One for every substring.   -  Arc arcs[((hashes_.size() + 1) * hashes_.size()) / 2]; -  BuildGraph(substrings_, hashes_, vertices, arcs); -  Vertex &last_vertex = vertices[hashes_.size() - 1]; +  arcs_.resize(((hashes_.size() + 1) * hashes_.size()) / 2); +  BuildGraph(substrings_, hashes_, &*vertices_.begin(), &*arcs_.begin()); +  return vertices_[hashes_.size() - 1]; +} + +} // namespace detail +bool Union::Evaluate() { +  detail::Vertex &last_vertex = MakeGraph();    unsigned int lower = 0;    while (true) {      last_vertex.LowerBound(lower); @@ -252,14 +270,7 @@ bool Union::Evaluate() {  }  template <class Output> void Multiple::Evaluate(const StringPiece &line, Output &output) { -  assert(!hashes_.empty()); -  // Usually there are at most 6 words in an n-gram, so stack allocation is reasonable.   -  Vertex vertices[hashes_.size()]; -  // One for every substring.   
-  Arc arcs[((hashes_.size() + 1) * hashes_.size()) / 2]; -  BuildGraph(substrings_, hashes_, vertices, arcs); -  Vertex &last_vertex = vertices[hashes_.size() - 1]; - +  detail::Vertex &last_vertex = MakeGraph();    unsigned int lower = 0;    while (true) {      last_vertex.LowerBound(lower); diff --git a/klm/lm/filter/phrase.hh b/klm/lm/filter/phrase.hh index b4edff41..e8e85835 100644 --- a/klm/lm/filter/phrase.hh +++ b/klm/lm/filter/phrase.hh @@ -103,11 +103,33 @@ template <class Iterator> void MakeHashes(Iterator i, const Iterator &end, std::    }  } +class Vertex; +class Arc; + +class ConditionCommon { +  protected: +    ConditionCommon(const Substrings &substrings); +    ConditionCommon(const ConditionCommon &from); + +    ~ConditionCommon(); + +    detail::Vertex &MakeGraph(); + +    // Temporaries in PassNGram and Evaluate to avoid reallocation. +    std::vector<Hash> hashes_; + +  private: +    std::vector<detail::Vertex> vertices_; +    std::vector<detail::Arc> arcs_; + +    const Substrings &substrings_; +}; +  } // namespace detail -class Union { +class Union : public detail::ConditionCommon {    public: -    explicit Union(const Substrings &substrings) : substrings_(substrings) {} +    explicit Union(const Substrings &substrings) : detail::ConditionCommon(substrings) {}      template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) {        detail::MakeHashes(begin, end, hashes_); @@ -116,23 +138,19 @@ class Union {    private:      bool Evaluate(); - -    std::vector<Hash> hashes_; - -    const Substrings &substrings_;  }; -class Multiple { +class Multiple : public detail::ConditionCommon {    public: -    explicit Multiple(const Substrings &substrings) : substrings_(substrings) {} +    explicit Multiple(const Substrings &substrings) : detail::ConditionCommon(substrings) {}      template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) {        detail::MakeHashes(begin, end, hashes_);        if (hashes_.empty()) {          output.AddNGram(line); -        return; +      } else { +        Evaluate(line, output);        } -      Evaluate(line, output);      }      template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) { @@ -143,10 +161,6 @@ class Multiple {    private:      template <class Output> void Evaluate(const StringPiece &line, Output &output); - -    std::vector<Hash> hashes_; - -    const Substrings &substrings_;  };  } // namespace phrase diff --git a/klm/lm/filter/phrase_table_vocab_main.cc b/klm/lm/filter/phrase_table_vocab_main.cc new file mode 100644 index 00000000..e0f47d89 --- /dev/null +++ b/klm/lm/filter/phrase_table_vocab_main.cc @@ -0,0 +1,165 @@ +#include "util/fake_ofstream.hh" +#include "util/file_piece.hh" +#include "util/murmur_hash.hh" +#include "util/pool.hh" +#include "util/string_piece.hh" +#include "util/string_piece_hash.hh" +#include "util/tokenize_piece.hh" + +#include <boost/unordered_map.hpp> +#include <boost/unordered_set.hpp> + +#include <cstddef> +#include <vector> + +namespace { + +struct MutablePiece { +  mutable StringPiece behind; +  bool operator==(const MutablePiece &other) const { +    return behind == other.behind; +  } +}; + +std::size_t hash_value(const MutablePiece &m) { +  return hash_value(m.behind); +} + +class InternString { +  public: +    const char *Add(StringPiece str) { +      MutablePiece mut; +      mut.behind = str; +      
std::pair<boost::unordered_set<MutablePiece>::iterator, bool> res(strs_.insert(mut)); +      if (res.second) { +        void *mem = backing_.Allocate(str.size() + 1); +        memcpy(mem, str.data(), str.size()); +        static_cast<char*>(mem)[str.size()] = 0; +        res.first->behind = StringPiece(static_cast<char*>(mem), str.size()); +      } +      return res.first->behind.data(); +    } + +  private: +    util::Pool backing_; +    boost::unordered_set<MutablePiece> strs_; +}; + +class TargetWords { +  public: +    void Introduce(StringPiece source) { +      vocab_.resize(vocab_.size() + 1); +      std::vector<unsigned int> temp(1, vocab_.size() - 1); +      Add(temp, source); +    } + +    void Add(const std::vector<unsigned int> &sentences, StringPiece target) { +      if (sentences.empty()) return; +      interns_.clear(); +      for (util::TokenIter<util::SingleCharacter, true> i(target, ' '); i; ++i) { +        interns_.push_back(intern_.Add(*i)); +      } +      for (std::vector<unsigned int>::const_iterator i(sentences.begin()); i != sentences.end(); ++i) { +        boost::unordered_set<const char *> &vocab = vocab_[*i]; +        for (std::vector<const char *>::const_iterator j = interns_.begin(); j != interns_.end(); ++j) { +          vocab.insert(*j); +        } +      } +    } + +    void Print() const { +      util::FakeOFStream out(1); +      for (std::vector<boost::unordered_set<const char *> >::const_iterator i = vocab_.begin(); i != vocab_.end(); ++i) { +        for (boost::unordered_set<const char *>::const_iterator j = i->begin(); j != i->end(); ++j) { +          out << *j << ' '; +        } +        out << '\n'; +      } +    } + +  private: +    InternString intern_; + +    std::vector<boost::unordered_set<const char *> > vocab_; + +    // Temporary in Add. +    std::vector<const char *> interns_; +}; + +class Input { +  public: +    explicit Input(std::size_t max_length)  +      : max_length_(max_length), sentence_id_(0), empty_() {} + +    void AddSentence(StringPiece sentence, TargetWords &targets) { +      canonical_.clear(); +      starts_.clear(); +      starts_.push_back(0); +      for (util::TokenIter<util::AnyCharacter, true> i(sentence, StringPiece("\0 \t", 3)); i; ++i) { +        canonical_.append(i->data(), i->size()); +        canonical_ += ' '; +        starts_.push_back(canonical_.size()); +      } +      targets.Introduce(canonical_); +      for (std::size_t i = 0; i < starts_.size() - 1; ++i) { +        std::size_t subtract = starts_[i]; +        const char *start = &canonical_[subtract]; +        for (std::size_t j = i + 1; j < std::min(starts_.size(), i + max_length_ + 1); ++j) { +          map_[util::MurmurHash64A(start, &canonical_[starts_[j]] - start - 1)].push_back(sentence_id_); +        } +      } +      ++sentence_id_; +    } + +    // Assumes single space-delimited phrase with no space at the beginning or end. +    const std::vector<unsigned int> &Matches(StringPiece phrase) const { +      Map::const_iterator i = map_.find(util::MurmurHash64A(phrase.data(), phrase.size())); +      return i == map_.end() ? empty_ : i->second; +    } + +  private: +    const std::size_t max_length_; + +    // hash of phrase is the key, array of sentences is the value. +    typedef boost::unordered_map<uint64_t, std::vector<unsigned int> > Map; +    Map map_; + +    std::size_t sentence_id_; +     +    // Temporaries in AddSentence. 
+    std::string canonical_; +    std::vector<std::size_t> starts_; + +    const std::vector<unsigned int> empty_; +}; + +} // namespace + +int main(int argc, char *argv[]) { +  if (argc != 2) { +    std::cerr << "Expected source text on the command line" << std::endl; +    return 1; +  } +  Input input(7); +  TargetWords targets; +  try { +    util::FilePiece inputs(argv[1], &std::cerr); +    while (true) +      input.AddSentence(inputs.ReadLine(), targets); +  } catch (const util::EndOfFileException &e) {} + +  util::FilePiece table(0, NULL, &std::cerr); +  StringPiece line; +  const StringPiece pipes("|||"); +  while (true) { +    try { +      line = table.ReadLine(); +    } catch (const util::EndOfFileException &e) { break; } +    util::TokenIter<util::MultiCharacter> it(line, pipes); +    StringPiece source(*it); +    if (!source.empty() && source[source.size() - 1] == ' ') +      source.remove_suffix(1); +    targets.Add(input.Matches(source), *++it); +  } +  targets.Print(); +} diff --git a/klm/lm/filter/vocab.cc b/klm/lm/filter/vocab.cc index 7ee4e84b..011ab599 100644 --- a/klm/lm/filter/vocab.cc +++ b/klm/lm/filter/vocab.cc @@ -4,7 +4,6 @@  #include <iostream>  #include <ctype.h> -#include <err.h>  namespace lm {  namespace vocab { diff --git a/klm/lm/filter/wrapper.hh b/klm/lm/filter/wrapper.hh index 90b07a08..eb657501 100644 --- a/klm/lm/filter/wrapper.hh +++ b/klm/lm/filter/wrapper.hh @@ -39,17 +39,15 @@ template <class FilterT> class ContextFilter {      explicit ContextFilter(Filter &backend) : backend_(backend) {}      template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) { -      pieces_.clear(); -      // TODO: this copy could be avoided by a lookahead iterator. -      std::copy(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), std::back_insert_iterator<std::vector<StringPiece> >(pieces_)); -      backend_.AddNGram(pieces_.begin(), pieces_.end() - !pieces_.empty(), line, output); +      // Find beginning of string or last space. +      const char *last_space; +      for (last_space = ngram.data() + ngram.size() - 1; last_space > ngram.data() && *last_space != ' '; --last_space) {} +      backend_.AddNGram(StringPiece(ngram.data(), last_space - ngram.data()), line, output);      }      void Flush() const {}    private: -    std::vector<StringPiece> pieces_; -      Filter backend_;  }; diff --git a/klm/lm/model.cc b/klm/lm/model.cc index a26654a6..a5a16bf8 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -34,23 +34,17 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT    if (static_cast<std::size_t>(start - static_cast<uint8_t*>(base)) != goal_size) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast<uint8_t*>(base)) << " but Size says they should take " << goal_size);  } -template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &config) { -  LoadLM(file, config, *this); - -  // g++ prints warnings unless these are fully initialized. 
-  State begin_sentence = State(); -  begin_sentence.length = 1; -  begin_sentence.words[0] = vocab_.BeginSentence(); -  typename Search::Node ignored_node; -  bool ignored_independent_left; -  uint64_t ignored_extend_left; -  begin_sentence.backoff[0] = search_.LookupUnigram(begin_sentence.words[0], ignored_node, ignored_independent_left, ignored_extend_left).Backoff(); -  State null_context = State(); -  null_context.length = 0; -  P::Init(begin_sentence, null_context, vocab_, search_.Order()); +namespace { +void ComplainAboutARPA(const Config &config, ModelType model_type) { +  if (config.write_mmap || !config.messages) return; +  if (config.arpa_complain == Config::ALL) { +    *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl; +  } else if (config.arpa_complain == Config::EXPENSIVE && +             (model_type == TRIE || model_type == QUANT_TRIE || model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE)) { +    *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive.  Save time by building a binary format." << std::endl; +  }  } -namespace {  void CheckCounts(const std::vector<uint64_t> &counts) {    UTIL_THROW_IF(counts.size() > KENLM_MAX_ORDER, FormatLoadException, "This model has order " << counts.size() << " but KenLM was compiled to support up to " << KENLM_MAX_ORDER << ".  " << KENLM_ORDER_MESSAGE);    if (sizeof(uint64_t) > sizeof(std::size_t)) { @@ -59,18 +53,45 @@ void CheckCounts(const std::vector<uint64_t> &counts) {      }    }  } +  } // namespace -template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters ¶ms, const Config &config, int fd) { -  CheckCounts(params.counts); -  SetupMemory(start, params.counts, config); -  vocab_.LoadedBinary(params.fixed.has_vocabulary, fd, config.enumerate_vocab); -  search_.LoadedBinary(); +template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &init_config) : backing_(init_config) { +  util::scoped_fd fd(util::OpenReadOrThrow(file)); +  if (IsBinaryFormat(fd.get())) { +    Parameters parameters; +    int fd_shallow = fd.release(); +    backing_.InitializeBinary(fd_shallow, kModelType, kVersion, parameters); +    CheckCounts(parameters.counts); + +    Config new_config(init_config); +    new_config.probing_multiplier = parameters.fixed.probing_multiplier; +    Search::UpdateConfigFromBinary(backing_, parameters.counts, VocabularyT::Size(parameters.counts[0], new_config), new_config); +    UTIL_THROW_IF(new_config.enumerate_vocab && !parameters.fixed.has_vocabulary, FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them.  You may need to rebuild the binary file with an updated version of build_binary."); + +    SetupMemory(backing_.LoadBinary(Size(parameters.counts, new_config)), parameters.counts, new_config); +    vocab_.LoadedBinary(parameters.fixed.has_vocabulary, fd_shallow, new_config.enumerate_vocab, backing_.VocabStringReadingOffset()); +  } else { +    ComplainAboutARPA(init_config, kModelType); +    InitializeFromARPA(fd.release(), file, init_config); +  } + +  // g++ prints warnings unless these are fully initialized. 
+  State begin_sentence = State(); +  begin_sentence.length = 1; +  begin_sentence.words[0] = vocab_.BeginSentence(); +  typename Search::Node ignored_node; +  bool ignored_independent_left; +  uint64_t ignored_extend_left; +  begin_sentence.backoff[0] = search_.LookupUnigram(begin_sentence.words[0], ignored_node, ignored_independent_left, ignored_extend_left).Backoff(); +  State null_context = State(); +  null_context.length = 0; +  P::Init(begin_sentence, null_context, vocab_, search_.Order());  } -template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) { -  // Backing file is the ARPA.  Steal it so we can make the backing file the mmap output if any. -  util::FilePiece f(backing_.file.release(), file, config.ProgressMessages()); +template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(int fd, const char *file, const Config &config) { +  // Backing file is the ARPA. +  util::FilePiece f(fd, file, config.ProgressMessages());    try {      std::vector<uint64_t> counts;      // File counts do not include pruned trigrams that extend to quadgrams etc.   These will be fixed by search_. @@ -81,13 +102,17 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT      std::size_t vocab_size = util::CheckOverflow(VocabularyT::Size(counts[0], config));      // Setup the binary file for writing the vocab lookup table.  The search_ is responsible for growing the binary file to its needs. -    vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config); +    vocab_.SetupMemory(backing_.SetupJustVocab(vocab_size, counts.size()), vocab_size, counts[0], config); -    if (config.write_mmap) { +    if (config.write_mmap && config.include_vocab) {        WriteWordsWrapper wrap(config.enumerate_vocab);        vocab_.ConfigureEnumerate(&wrap, counts[0]);        search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); -      wrap.Write(backing_.file.get(), backing_.vocab.size() + vocab_.UnkCountChangePadding() + Search::Size(counts, config)); +      void *vocab_rebase, *search_rebase; +      backing_.WriteVocabWords(wrap.Buffer(), vocab_rebase, search_rebase); +      // Due to writing at the end of file, mmap may have relocated data.  So remap. 
+      vocab_.Relocate(vocab_rebase); +      search_.SetupMemory(reinterpret_cast<uint8_t*>(search_rebase), counts, config);      } else {        vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]);        search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); @@ -99,18 +124,13 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT        search_.UnknownUnigram().backoff = 0.0;        search_.UnknownUnigram().prob = config.unknown_missing_logprob;      } -    FinishFile(config, kModelType, kVersion, counts, vocab_.UnkCountChangePadding(), backing_); +    backing_.FinishFile(config, kModelType, kVersion, counts);    } catch (util::Exception &e) {      e << " Byte: " << f.Offset();      throw;    }  } -template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) { -  util::AdvanceOrThrow(fd, VocabularyT::Size(counts[0], config)); -  Search::UpdateConfigFromBinary(fd, counts, config); -} -  template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const {    FullScoreReturn ret = ScoreExceptBackoff(in_state.words, in_state.words + in_state.length, new_word, out_state);    for (const float *i = in_state.backoff + ret.ngram_length - 1; i < in_state.backoff + in_state.length; ++i) { diff --git a/klm/lm/model.hh b/klm/lm/model.hh index 60f55110..e75da93b 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -67,7 +67,7 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod      FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const;      /* Get the state for a context.  Don't use this if you can avoid it.  Use -     * BeginSentenceState or EmptyContextState and extend from those.  If +     * BeginSentenceState or NullContextState and extend from those.  If       * you're only going to use this state to call FullScore once, use       * FullScoreForgotState.        * To use this function, make an array of WordIndex containing the context @@ -104,10 +104,6 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod      }    private: -    friend void lm::ngram::LoadLM<>(const char *file, const Config &config, GenericModel<Search, VocabularyT> &to); - -    static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config); -      FullScoreReturn ScoreExceptBackoff(const WordIndex *const context_rbegin, const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const;      // Score bigrams and above.  Do not include backoff.    @@ -116,15 +112,11 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod      // Appears after Size in the cc file.      
void SetupMemory(void *start, const std::vector<uint64_t> &counts, const Config &config); -    void InitializeFromBinary(void *start, const Parameters ¶ms, const Config &config, int fd); - -    void InitializeFromARPA(const char *file, const Config &config); +    void InitializeFromARPA(int fd, const char *file, const Config &config);      float InternalUnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const; -    Backing &MutableBacking() { return backing_; } - -    Backing backing_; +    BinaryFormat backing_;      VocabularyT vocab_; diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index eb159094..7005b05e 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -360,10 +360,11 @@ BOOST_AUTO_TEST_CASE(quant_bhiksha_trie) {    LoadingTest<QuantArrayTrieModel>();  } -template <class ModelT> void BinaryTest() { +template <class ModelT> void BinaryTest(Config::WriteMethod write_method) {    Config config;    config.write_mmap = "test.binary";    config.messages = NULL; +  config.write_method = write_method;    ExpectEnumerateVocab enumerate;    config.enumerate_vocab = &enumerate; @@ -406,6 +407,11 @@ template <class ModelT> void BinaryTest() {    unlink("test_nounk.binary");  } +template <class ModelT> void BinaryTest() { +  BinaryTest<ModelT>(Config::WRITE_MMAP); +  BinaryTest<ModelT>(Config::WRITE_AFTER); +} +  BOOST_AUTO_TEST_CASE(write_and_read_probing) {    BinaryTest<ProbingModel>();  } diff --git a/klm/lm/ngram_query.hh b/klm/lm/ngram_query.hh index dfcda170..ec2590f4 100644 --- a/klm/lm/ngram_query.hh +++ b/klm/lm/ngram_query.hh @@ -11,21 +11,25 @@  #include <istream>  #include <string> +#include <math.h> +  namespace lm {  namespace ngram {  template <class Model> void Query(const Model &model, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) { -  std::cerr << "Loading statistics:\n"; -  util::PrintUsage(std::cerr);    typename Model::State state, out;    lm::FullScoreReturn ret;    std::string word; +  double corpus_total = 0.0; +  uint64_t corpus_oov = 0; +  uint64_t corpus_tokens = 0; +    while (in_stream) {      state = sentence_context ? 
model.BeginSentenceState() : model.NullContextState();      float total = 0.0;      bool got = false; -    unsigned int oov = 0; +    uint64_t oov = 0;      while (in_stream >> word) {        got = true;        lm::WordIndex vocab = model.GetVocabulary().Index(word); @@ -33,6 +37,7 @@ template <class Model> void Query(const Model &model, bool sentence_context, std        ret = model.FullScore(state, vocab, out);        total += ret.prob;        out_stream << word << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length)  << ' ' << ret.prob << '\t'; +      ++corpus_tokens;        state = out;        char c;        while (true) { @@ -50,12 +55,14 @@ template <class Model> void Query(const Model &model, bool sentence_context, std      if (sentence_context) {        ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out);        total += ret.prob; +      ++corpus_tokens;        out_stream << "</s>=" << model.GetVocabulary().EndSentence() << ' ' << static_cast<unsigned int>(ret.ngram_length)  << ' ' << ret.prob << '\t';      }      out_stream << "Total: " << total << " OOV: " << oov << '\n'; +    corpus_total += total; +    corpus_oov += oov;    } -  std::cerr << "After queries:\n"; -  util::PrintUsage(std::cerr); +  out_stream << "Perplexity " << pow(10.0, -(corpus_total / static_cast<double>(corpus_tokens))) << std::endl;  }  template <class M> void Query(const char *file, bool sentence_context, std::istream &in_stream, std::ostream &out_stream) { diff --git a/klm/lm/quantize.cc b/klm/lm/quantize.cc index b58c3f3f..273ea398 100644 --- a/klm/lm/quantize.cc +++ b/klm/lm/quantize.cc @@ -38,13 +38,13 @@ const char kSeparatelyQuantizeVersion = 2;  } // namespace -void SeparatelyQuantize::UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &/*counts*/, Config &config) { -  char version; -  util::ReadOrThrow(fd, &version, 1); -  util::ReadOrThrow(fd, &config.prob_bits, 1); -  util::ReadOrThrow(fd, &config.backoff_bits, 1); +void SeparatelyQuantize::UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config) { +  unsigned char buffer[3]; +  file.ReadForConfig(buffer, 3, offset); +  char version = buffer[0]; +  config.prob_bits = buffer[1]; +  config.backoff_bits = buffer[2];    if (version != kSeparatelyQuantizeVersion) UTIL_THROW(FormatLoadException, "This file has quantization version " << (unsigned)version << " but the code expects version " << (unsigned)kSeparatelyQuantizeVersion); -  util::AdvanceOrThrow(fd, -3);  }  void SeparatelyQuantize::SetupMemory(void *base, unsigned char order, const Config &config) { diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh index 8ce2378a..9d3a2f43 100644 --- a/klm/lm/quantize.hh +++ b/klm/lm/quantize.hh @@ -18,12 +18,13 @@ namespace lm {  namespace ngram {  struct Config; +class BinaryFormat;  /* Store values directly and don't quantize. 
*/  class DontQuantize {    public:      static const ModelType kModelTypeAdd = static_cast<ModelType>(0); -    static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {} +    static void UpdateConfigFromBinary(const BinaryFormat &, uint64_t, Config &) {}      static uint64_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; }      static uint8_t MiddleBits(const Config &/*config*/) { return 63; }      static uint8_t LongestBits(const Config &/*config*/) { return 31; } @@ -136,7 +137,7 @@ class SeparatelyQuantize {    public:      static const ModelType kModelTypeAdd = kQuantAdd; -    static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config); +    static void UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config);      static uint64_t Size(uint8_t order, const Config &config) {        uint64_t longest_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.prob_bits)) * sizeof(float); diff --git a/klm/lm/query_main.cc b/klm/lm/query_main.cc index 27d3a1a5..bd4fde62 100644 --- a/klm/lm/query_main.cc +++ b/klm/lm/query_main.cc @@ -1,42 +1,65 @@  #include "lm/ngram_query.hh" +#ifdef WITH_NPLM +#include "lm/wrappers/nplm.hh" +#endif + +#include <stdlib.h> + +void Usage(const char *name) { +  std::cerr << "KenLM was compiled with maximum order " << KENLM_MAX_ORDER << "." << std::endl; +  std::cerr << "Usage: " << name << " [-n] lm_file" << std::endl; +  std::cerr << "Input is wrapped in <s> and </s> unless -n is passed." << std::endl; +  exit(1); +} +  int main(int argc, char *argv[]) { -  if (!(argc == 2 || (argc == 3 && !strcmp(argv[2], "null")))) { -    std::cerr << "KenLM was compiled with maximum order " << KENLM_MAX_ORDER << "." << std::endl; -    std::cerr << "Usage: " << argv[0] << " lm_file [null]" << std::endl; -    std::cerr << "Input is wrapped in <s> and </s> unless null is passed." 
<< std::endl; -    return 1; +  bool sentence_context = true; +  const char *file = NULL; +  for (char **arg = argv + 1; arg != argv + argc; ++arg) { +    if (!strcmp(*arg, "-n")) { +      sentence_context = false; +    } else if (!strcmp(*arg, "-h") || !strcmp(*arg, "--help") || file) { +      Usage(argv[0]); +    } else { +      file = *arg; +    }    } +  if (!file) Usage(argv[0]);    try { -    bool sentence_context = (argc == 2);      using namespace lm::ngram;      ModelType model_type; -    if (RecognizeBinary(argv[1], model_type)) { +    if (RecognizeBinary(file, model_type)) {        switch(model_type) {          case PROBING: -          Query<lm::ngram::ProbingModel>(argv[1], sentence_context, std::cin, std::cout); +          Query<lm::ngram::ProbingModel>(file, sentence_context, std::cin, std::cout);            break;          case REST_PROBING: -          Query<lm::ngram::RestProbingModel>(argv[1], sentence_context, std::cin, std::cout); +          Query<lm::ngram::RestProbingModel>(file, sentence_context, std::cin, std::cout);            break;          case TRIE: -          Query<TrieModel>(argv[1], sentence_context, std::cin, std::cout); +          Query<TrieModel>(file, sentence_context, std::cin, std::cout);            break;          case QUANT_TRIE: -          Query<QuantTrieModel>(argv[1], sentence_context, std::cin, std::cout); +          Query<QuantTrieModel>(file, sentence_context, std::cin, std::cout);            break;          case ARRAY_TRIE: -          Query<ArrayTrieModel>(argv[1], sentence_context, std::cin, std::cout); +          Query<ArrayTrieModel>(file, sentence_context, std::cin, std::cout);            break;          case QUANT_ARRAY_TRIE: -          Query<QuantArrayTrieModel>(argv[1], sentence_context, std::cin, std::cout); +          Query<QuantArrayTrieModel>(file, sentence_context, std::cin, std::cout);            break;          default:            std::cerr << "Unrecognized kenlm model type " << model_type << std::endl;            abort();        } +#ifdef WITH_NPLM +    } else if (lm::np::Model::Recognize(file)) { +      lm::np::Model model(file); +      Query(model, sentence_context, std::cin, std::cout); +#endif      } else { -      Query<ProbingModel>(argv[1], sentence_context, std::cin, std::cout); +      Query<ProbingModel>(file, sentence_context, std::cin, std::cout);      }      std::cerr << "Total time including destruction:\n";      util::PrintUsage(std::cerr); diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc index 9ea08798..fb8bbfa2 100644 --- a/klm/lm/read_arpa.cc +++ b/klm/lm/read_arpa.cc @@ -19,7 +19,7 @@  namespace lm { -// 1 for '\t', '\n', and ' '.  This is stricter than isspace.   +// 1 for '\t', '\n', and ' '.  This is stricter than isspace.  
const bool kARPASpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};  namespace { @@ -50,7 +50,7 @@ void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) {    // In general, ARPA files can have arbitrary text before "\data\"    // But in KenLM, we require such lines to start with "#", so that    // we can do stricter error checking -  while (IsEntirelyWhiteSpace(line) || line.starts_with("#")) { +  while (IsEntirelyWhiteSpace(line) || starts_with(line, "#")) {      line = in.ReadLine();    } @@ -58,7 +58,7 @@ void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) {      if ((line.size() >= 2) && (line.data()[0] == 0x1f) && (static_cast<unsigned char>(line.data()[1]) == 0x8b)) {        UTIL_THROW(FormatLoadException, "Looks like a gzip file.  If this is an ARPA file, pipe " << in.FileName() << " through zcat.  If this already in binary format, you need to decompress it because mmap doesn't work on top of gzip.");      } -    if (static_cast<size_t>(line.size()) >= strlen(kBinaryMagic) && StringPiece(line.data(), strlen(kBinaryMagic)) == kBinaryMagic)  +    if (static_cast<size_t>(line.size()) >= strlen(kBinaryMagic) && StringPiece(line.data(), strlen(kBinaryMagic)) == kBinaryMagic)        UTIL_THROW(FormatLoadException, "This looks like a binary file but got sent to the ARPA parser.  Did you compress the binary file or pass a binary file where only ARPA files are accepted?");      UTIL_THROW_IF(line.size() >= 4 && StringPiece(line.data(), 4) == "blmt", FormatLoadException, "This looks like an IRSTLM binary file.  Did you forget to pass --text yes to compile-lm?");      UTIL_THROW_IF(line == "iARPA", FormatLoadException, "This looks like an IRSTLM iARPA file.  You need an ARPA file.  Run\n  compile-lm --text yes " << in.FileName() << " " << in.FileName() << ".arpa\nfirst."); @@ -66,7 +66,7 @@ void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) {    }    while (!IsEntirelyWhiteSpace(line = in.ReadLine())) {      if (line.size() < 6 || strncmp(line.data(), "ngram ", 6)) UTIL_THROW(FormatLoadException, "count line \"" << line << "\"doesn't begin with \"ngram \""); -    // So strtol doesn't go off the end of line.   +    // So strtol doesn't go off the end of line.      std::string remaining(line.data() + 6, line.size() - 6);      char *end_ptr;      unsigned int length = std::strtol(remaining.c_str(), &end_ptr, 10); @@ -102,8 +102,8 @@ void ReadBackoff(util::FilePiece &in, Prob &/*weights*/) {  }  void ReadBackoff(util::FilePiece &in, float &backoff) { -  // Always make zero negative.   -  // Negative zero means that no (n+1)-gram has this n-gram as context.   +  // Always make zero negative. +  // Negative zero means that no (n+1)-gram has this n-gram as context.    // Therefore the hypothesis state can be shorter.  Of course, many n-grams    // are context for (n+1)-grams.  An algorithm in the data structure will go    // back and set the backoff to positive zero in these cases. 
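The negative-zero convention described in the ReadBackoff comment above leans on IEEE-754: -0.0f compares equal to +0.0f, so the marker has to be read back through the sign bit rather than an ordinary comparison. A minimal standalone sketch of that check, not taken from the patch; the helper name IsContextlessZero is invented here for illustration, while the -0.0/+0.0 roles follow the comment in the hunk.

#include <cassert>
#include <cmath>

// Illustrative helper, not part of KenLM: a backoff of exactly zero with the
// sign bit set (-0.0f) marks an n-gram that no (n+1)-gram uses as context.
inline bool IsContextlessZero(float backoff) {
  return backoff == 0.0f && std::signbit(backoff);
}

int main() {
  float unmarked = -0.0f;  // zero made negative when the ARPA backoff is read
  float promoted = 0.0f;   // flipped to +0.0 once an (n+1)-gram with this context is seen
  assert(unmarked == promoted);          // IEEE-754: -0.0f and +0.0f compare equal...
  assert(IsContextlessZero(unmarked));   // ...so the sign bit carries the extra flag
  assert(!IsContextlessZero(promoted));
  return 0;
}

The payoff noted in the comment is that a state whose last backoff is still negative zero can be shortened, since no longer n-gram will ever match it.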
@@ -150,7 +150,7 @@ void PositiveProbWarn::Warn(float prob) {      case THROW_UP:        UTIL_THROW(FormatLoadException, "Positive log probability " << prob << " in the model.  This is a bug in IRSTLM; you can set config.positive_log_probability = SILENT or pass -i to build_binary to substitute 0.0 for the log probability.  Error");      case COMPLAIN: -      std::cerr << "There's a positive log probability " << prob << " in the APRA file, probably because of a bug in IRSTLM.  This and subsequent entires will be mapepd to 0 log probability." << std::endl; +      std::cerr << "There's a positive log probability " << prob << " in the APRA file, probably because of a bug in IRSTLM.  This and subsequent entires will be mapped to 0 log probability." << std::endl;        action_ = SILENT;        break;      case SILENT: diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index 62275d27..354a56b4 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -204,9 +204,10 @@ template <class Build, class Activate, class Store> void ReadNGrams(  namespace detail {  template <class Value> uint8_t *HashedSearch<Value>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) { -  std::size_t allocated = Unigram::Size(counts[0]); -  unigram_ = Unigram(start, counts[0], allocated); -  start += allocated; +  unigram_ = Unigram(start, counts[0]); +  start += Unigram::Size(counts[0]); +  std::size_t allocated; +  middle_.clear();    for (unsigned int n = 2; n < counts.size(); ++n) {      allocated = Middle::Size(counts[n - 1], config.probing_multiplier);      middle_.push_back(Middle(start, allocated)); @@ -218,9 +219,21 @@ template <class Value> uint8_t *HashedSearch<Value>::SetupMemory(uint8_t *start,    return start;  } -template <class Value> void HashedSearch<Value>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, Backing &backing) { -  // TODO: fix sorted. 
-  SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), Size(counts, config), backing), counts, config); +/*template <class Value> void HashedSearch<Value>::Relocate(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) { +  unigram_ = Unigram(start, counts[0]); +  start += Unigram::Size(counts[0]); +  for (unsigned int n = 2; n < counts.size(); ++n) { +    middle[n-2].Relocate(start); +    start += Middle::Size(counts[n - 1], config.probing_multiplier) +  } +  longest_.Relocate(start); +}*/ + +template <class Value> void HashedSearch<Value>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, BinaryFormat &backing) { +  void *vocab_rebase; +  void *search_base = backing.GrowForSearch(Size(counts, config), vocab.UnkCountChangePadding(), vocab_rebase); +  vocab.Relocate(vocab_rebase); +  SetupMemory(reinterpret_cast<uint8_t*>(search_base), counts, config);    PositiveProbWarn warn(config.positive_log_probability);    Read1Grams(f, counts[0], vocab, unigram_.Raw(), warn); @@ -277,14 +290,6 @@ template <class Value> template <class Build> void HashedSearch<Value>::ApplyBui    ReadEnd(f);  } -template <class Value> void HashedSearch<Value>::LoadedBinary() { -  unigram_.LoadedBinary(); -  for (typename std::vector<Middle>::iterator i = middle_.begin(); i != middle_.end(); ++i) { -    i->LoadedBinary(); -  } -  longest_.LoadedBinary(); -} -  template class HashedSearch<BackoffValue>;  template class HashedSearch<RestValue>; diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index 9d067bc2..8193262b 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -18,7 +18,7 @@ namespace util { class FilePiece; }  namespace lm {  namespace ngram { -struct Backing; +class BinaryFormat;  class ProbingVocabulary;  namespace detail { @@ -72,7 +72,7 @@ template <class Value> class HashedSearch {      static const unsigned int kVersion = 0;      // TODO: move probing_multiplier here with next binary file format update. 
-    static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {} +    static void UpdateConfigFromBinary(const BinaryFormat &, const std::vector<uint64_t> &, uint64_t, Config &) {}      static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) {        uint64_t ret = Unigram::Size(counts[0]); @@ -84,9 +84,7 @@ template <class Value> class HashedSearch {      uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config); -    void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, Backing &backing); - -    void LoadedBinary(); +    void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, BinaryFormat &backing);      unsigned char Order() const {        return middle_.size() + 2; @@ -148,7 +146,7 @@ template <class Value> class HashedSearch {        public:          Unigram() {} -        Unigram(void *start, uint64_t count, std::size_t /*allocated*/) : +        Unigram(void *start, uint64_t count) :            unigram_(static_cast<typename Value::Weights*>(start))  #ifdef DEBUG           ,  count_(count) @@ -168,8 +166,6 @@ template <class Value> class HashedSearch {          typename Value::Weights &Unknown() { return unigram_[0]; } -        void LoadedBinary() {} -          // For building.          typename Value::Weights *Raw() { return unigram_; } diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 1b0d9b26..4a88194e 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -253,11 +253,6 @@ class FindBlanks {        ++counts_.back();      } -    // Unigrams wrote one past. 
-    void Cleanup() { -      --counts_[0]; -    } -      const std::vector<uint64_t> &Counts() const {        return counts_;      } @@ -310,8 +305,6 @@ template <class Quant, class Bhiksha> class WriteEntries {        typename Quant::LongestPointer(quant_, longest_.Insert(words[order_ - 1])).Write(reinterpret_cast<const Prob*>(words + order_)->prob);      } -    void Cleanup() {} -    private:      RecordReader *contexts_;      const Quant &quant_; @@ -385,14 +378,14 @@ template <class Doing> void RecursiveInsert(const unsigned char total_order, con    util::ErsatzProgress progress(unigram_count + 1, progress_out, message);    WordIndex unigram = 0;    std::priority_queue<Gram> grams; -  grams.push(Gram(&unigram, 1)); +  if (unigram_count) grams.push(Gram(&unigram, 1));    for (unsigned char i = 2; i <= total_order; ++i) {      if (input[i-2]) grams.push(Gram(reinterpret_cast<const WordIndex*>(input[i-2].Data()), i));    }    BlankManager<Doing> blank(total_order, doing); -  while (true) { +  while (!grams.empty()) {      Gram top = grams.top();      grams.pop();      unsigned char order = top.end - top.begin; @@ -400,8 +393,7 @@ template <class Doing> void RecursiveInsert(const unsigned char total_order, con        blank.Visit(&unigram, 1, doing.UnigramProb(unigram));        doing.Unigram(unigram);        progress.Set(unigram); -      if (++unigram == unigram_count + 1) break; -      grams.push(top); +      if (++unigram < unigram_count) grams.push(top);      } else {        if (order == total_order) {          blank.Visit(top.begin, order, reinterpret_cast<const Prob*>(top.end)->prob); @@ -414,8 +406,6 @@ template <class Doing> void RecursiveInsert(const unsigned char total_order, con        if (++reader) grams.push(top);      }    } -  assert(grams.empty()); -  doing.Cleanup();  }  void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<uint64_t> &fixed) { @@ -469,7 +459,7 @@ void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &c  } // namespace -template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) { +template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing) {    RecordReader inputs[KENLM_MAX_ORDER - 1];    RecordReader contexts[KENLM_MAX_ORDER - 1]; @@ -498,7 +488,10 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve    sri.ObtainBackoffs(counts.size(), unigram_file.get(), inputs); -  out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), backing), fixed_counts, config); +  void *vocab_relocate; +  void *search_base = backing.GrowForSearch(TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), vocab.UnkCountChangePadding(), vocab_relocate); +  vocab.Relocate(vocab_relocate); +  out.SetupMemory(reinterpret_cast<uint8_t*>(search_base), fixed_counts, config);    for (unsigned char i = 2; i <= counts.size(); ++i) {      inputs[i-2].Rewind(); @@ -524,6 +517,8 @@ template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::ve    {      WriteEntries<Quant, Bhiksha> writer(contexts, quant, unigrams, out.middle_begin_, out.longest_, counts.size(), sri);      RecursiveInsert(counts.size(), counts[0], 
inputs, config.ProgressMessages(), "Writing trie", writer); +    // Write the last unigram entry, which is the end pointer for the bigrams.   +    writer.Unigram(counts[0]);    }    // Do not disable this error message or else too little state will be returned.  Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation. @@ -579,15 +574,7 @@ template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::Setup    return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);  } -template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBinary() { -  unigram_.LoadedBinary(); -  for (Middle *i = middle_begin_; i != middle_end_; ++i) { -    i->LoadedBinary(); -  } -  longest_.LoadedBinary(); -} - -template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { +template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, BinaryFormat &backing) {    std::string temporary_prefix;    if (config.temporary_directory_prefix) {      temporary_prefix = config.temporary_directory_prefix; diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 60be416b..299262a5 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -11,18 +11,19 @@  #include "util/file_piece.hh"  #include <vector> +#include <cstdlib>  #include <assert.h>  namespace lm {  namespace ngram { -struct Backing; +class BinaryFormat;  class SortedVocabulary;  namespace trie {  template <class Quant, class Bhiksha> class TrieSearch;  class SortedFiles; -template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); +template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing);  template <class Quant, class Bhiksha> class TrieSearch {    public: @@ -38,11 +39,11 @@ template <class Quant, class Bhiksha> class TrieSearch {      static const unsigned int kVersion = 1; -    static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) { -      Quant::UpdateConfigFromBinary(fd, counts, config); -      util::AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0])); +    static void UpdateConfigFromBinary(const BinaryFormat &file, const std::vector<uint64_t> &counts, uint64_t offset, Config &config) { +      Quant::UpdateConfigFromBinary(file, offset, config);        // Currently the unigram pointers are not compresssed, so there will only be a header for order > 2. 
-      if (counts.size() > 2) Bhiksha::UpdateConfigFromBinary(fd, config); +      if (counts.size() > 2) +        Bhiksha::UpdateConfigFromBinary(file, offset + Quant::Size(counts.size(), config) + Unigram::Size(counts[0]), config);      }      static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) { @@ -59,9 +60,7 @@ template <class Quant, class Bhiksha> class TrieSearch {      uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config); -    void LoadedBinary(); - -    void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); +    void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, BinaryFormat &backing);      unsigned char Order() const {        return middle_end_ - middle_begin_ + 2; @@ -102,14 +101,14 @@ template <class Quant, class Bhiksha> class TrieSearch {      }    private: -    friend void BuildTrie<Quant, Bhiksha>(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); +    friend void BuildTrie<Quant, Bhiksha>(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing); -    // Middles are managed manually so we can delay construction and they don't have to be copyable.   +    // Middles are managed manually so we can delay construction and they don't have to be copyable.      void FreeMiddles() {        for (const Middle *i = middle_begin_; i != middle_end_; ++i) {          i->~Middle();        } -      free(middle_begin_); +      std::free(middle_begin_);      }      typedef trie::BitPackedMiddle<Bhiksha> Middle; diff --git a/klm/lm/state.hh b/klm/lm/state.hh index a6b9accb..543df37c 100644 --- a/klm/lm/state.hh +++ b/klm/lm/state.hh @@ -102,7 +102,7 @@ struct ChartState {    }    bool operator<(const ChartState &other) const { -    return Compare(other) == -1; +    return Compare(other) < 0;    }    void ZeroRemaining() { diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index 9ea3c546..d858ab5e 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -62,8 +62,6 @@ class Unigram {        return unigram_;      } -    void LoadedBinary() {} -      UnigramPointer Find(WordIndex word, NodeRange &next) const {        UnigramValue *val = unigram_ + word;        next.begin = val->next; @@ -108,8 +106,6 @@ template <class Bhiksha> class BitPackedMiddle : public BitPacked {      void FinishedLoading(uint64_t next_end, const Config &config); -    void LoadedBinary() { bhiksha_.LoadedBinary(); } -      util::BitAddress Find(WordIndex word, NodeRange &range, uint64_t &pointer) const;      util::BitAddress ReadEntry(uint64_t pointer, NodeRange &range) { @@ -138,14 +134,9 @@ class BitPackedLongest : public BitPacked {        BaseInit(base, max_vocab, quant_bits);      } -    void LoadedBinary() {} -      util::BitAddress Insert(WordIndex word);      util::BitAddress Find(WordIndex word, const NodeRange &node) const; - -  private: -    uint8_t quant_bits_;  };  } // namespace trie diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc index dc542bb3..126d43ab 100644 --- a/klm/lm/trie_sort.cc +++ b/klm/lm/trie_sort.cc @@ -50,6 +50,10 @@ class PartialViewProxy {      const void *Data() const { return inner_.Data(); }      void *Data() { 
return inner_.Data(); } +    friend void swap(PartialViewProxy first, PartialViewProxy second) { +      std::swap_ranges(reinterpret_cast<char*>(first.Data()), reinterpret_cast<char*>(first.Data()) + first.attention_size_, reinterpret_cast<char*>(second.Data())); +    } +    private:      friend class util::ProxyIterator<PartialViewProxy>; diff --git a/klm/lm/value_build.cc b/klm/lm/value_build.cc index 6124f8da..3ec3dce2 100644 --- a/klm/lm/value_build.cc +++ b/klm/lm/value_build.cc @@ -9,6 +9,7 @@ namespace ngram {  template <class Model> LowerRestBuild<Model>::LowerRestBuild(const Config &config, unsigned int order, const typename Model::Vocabulary &vocab) {    UTIL_THROW_IF(config.rest_lower_files.size() != order - 1, ConfigException, "This model has order " << order << " so there should be " << (order - 1) << " lower-order models for rest cost purposes.");    Config for_lower = config; +  for_lower.write_mmap = NULL;    for_lower.rest_lower_files.clear();    // Unigram models aren't supported, so this is a custom loader.   diff --git a/klm/lm/virtual_interface.hh b/klm/lm/virtual_interface.hh index 17f064b2..7a3e2379 100644 --- a/klm/lm/virtual_interface.hh +++ b/klm/lm/virtual_interface.hh @@ -125,10 +125,13 @@ class Model {      void NullContextWrite(void *to) const { memcpy(to, null_context_memory_, StateSize()); }      // Requires in_state != out_state -    virtual float Score(const void *in_state, const WordIndex new_word, void *out_state) const = 0; +    virtual float BaseScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0;      // Requires in_state != out_state -    virtual FullScoreReturn FullScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0; +    virtual FullScoreReturn BaseFullScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0; + +    // Prefer to use FullScore.  The context words should be provided in reverse order. +    virtual FullScoreReturn BaseFullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, void *out_state) const = 0;      unsigned char Order() const { return order_; } diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index fd7f96dc..7f0878f4 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -32,7 +32,8 @@ const uint64_t kUnknownHash = detail::HashForVocab("<unk>", 5);  // Sadly some LMs have <UNK>.    const uint64_t kUnknownCapHash = detail::HashForVocab("<UNK>", 5); -void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count) { +void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count, uint64_t offset) { +  util::SeekOrThrow(fd, offset);    // Check that we're at the right place by reading <unk> which is always first.    
char check_unk[6];    util::ReadOrThrow(fd, check_unk, 6); @@ -80,11 +81,6 @@ void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) {    buffer_.push_back(0);  } -void WriteWordsWrapper::Write(int fd, uint64_t start) { -  util::SeekOrThrow(fd, start); -  util::WriteOrThrow(fd, buffer_.data(), buffer_.size()); -} -  SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {}  uint64_t SortedVocabulary::Size(uint64_t entries, const Config &/*config*/) { @@ -100,6 +96,12 @@ void SortedVocabulary::SetupMemory(void *start, std::size_t allocated, std::size    saw_unk_ = false;  } +void SortedVocabulary::Relocate(void *new_start) { +  std::size_t delta = end_ - begin_; +  begin_ = reinterpret_cast<uint64_t*>(new_start) + 1; +  end_ = begin_ + delta; +} +  void SortedVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries) {    enumerate_ = to;    if (enumerate_) { @@ -147,11 +149,11 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {    bound_ = end_ - begin_ + 1;  } -void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) { +void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) {    end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1);    SetSpecial(Index("<s>"), Index("</s>"), 0);    bound_ = end_ - begin_ + 1; -  if (have_words) ReadWords(fd, to, bound_); +  if (have_words) ReadWords(fd, to, bound_, offset);  }  namespace { @@ -179,6 +181,11 @@ void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated, std::siz    saw_unk_ = false;  } +void ProbingVocabulary::Relocate(void *new_start) { +  header_ = static_cast<detail::ProbingVocabularyHeader*>(new_start); +  lookup_.Relocate(static_cast<uint8_t*>(new_start) + ALIGN8(sizeof(detail::ProbingVocabularyHeader))); +} +  void ProbingVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t /*max_entries*/) {    enumerate_ = to;    if (enumerate_) { @@ -206,12 +213,11 @@ void ProbingVocabulary::InternalFinishedLoading() {    SetSpecial(Index("<s>"), Index("</s>"), 0);  } -void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to) { +void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) {    UTIL_THROW_IF(header_->version != kProbingVocabularyVersion, FormatLoadException, "The binary file has probing version " << header_->version << " but the code expects version " << kProbingVocabularyVersion << ".  Please rerun build_binary using the same version of the code."); -  lookup_.LoadedBinary();    bound_ = header_->bound;    SetSpecial(Index("<s>"), Index("</s>"), 0); -  if (have_words) ReadWords(fd, to, bound_); +  if (have_words) ReadWords(fd, to, bound_, offset);  }  void MissingUnknown(const Config &config) throw(SpecialWordMissingException) { diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index 226ae438..074b74d8 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -36,7 +36,7 @@ class WriteWordsWrapper : public EnumerateVocab {      void Add(WordIndex index, const StringPiece &str); -    void Write(int fd, uint64_t start); +    const std::string &Buffer() const { return buffer_; }    private:      EnumerateVocab *inner_; @@ -71,6 +71,8 @@ class SortedVocabulary : public base::Vocabulary {      // Everything else is for populating.  I'm too lazy to hide and friend these, but you'll only get a const reference anyway.      
void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config); +    void Relocate(void *new_start); +      void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);      WordIndex Insert(const StringPiece &str); @@ -83,15 +85,13 @@ class SortedVocabulary : public base::Vocabulary {      bool SawUnk() const { return saw_unk_; } -    void LoadedBinary(bool have_words, int fd, EnumerateVocab *to); +    void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset);    private:      uint64_t *begin_, *end_;      WordIndex bound_; -    WordIndex highest_value_; -      bool saw_unk_;      EnumerateVocab *enumerate_; @@ -140,6 +140,8 @@ class ProbingVocabulary : public base::Vocabulary {      // Everything else is for populating.  I'm too lazy to hide and friend these, but you'll only get a const reference anyway.      void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config); +    void Relocate(void *new_start); +      void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);      WordIndex Insert(const StringPiece &str); @@ -152,7 +154,7 @@ class ProbingVocabulary : public base::Vocabulary {      bool SawUnk() const { return saw_unk_; } -    void LoadedBinary(bool have_words, int fd, EnumerateVocab *to); +    void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset);    private:      void InternalFinishedLoading(); | 
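The Relocate() methods added to SortedVocabulary and ProbingVocabulary exist because GrowForSearch and WriteVocabWords may extend the backing file and remap it at a different address, so any raw pointers into the old mapping must be rebased before the search structures are written. A minimal sketch of the pattern, assuming a made-up MappedArrayView class rather than the real vocabulary types (SortedVocabulary::Relocate additionally skips the entry count stored at the front of its region):

#include <cstddef>
#include <stdint.h>

// Made-up class, not part of KenLM: it caches raw pointers into a mapped
// region and must recompute them when the region is remapped elsewhere.
class MappedArrayView {
  public:
    MappedArrayView() : begin_(NULL), end_(NULL) {}

    void SetupMemory(void *base, std::size_t entries) {
      begin_ = static_cast<uint64_t*>(base);
      end_ = begin_ + entries;
    }

    // Same layout at a new base: keep the element count, recompute the addresses.
    void Relocate(void *new_base) {
      std::size_t entries = end_ - begin_;
      begin_ = static_cast<uint64_t*>(new_base);
      end_ = begin_ + entries;
    }

  private:
    uint64_t *begin_, *end_;
};

int main() {
  uint64_t old_mapping[4] = {1, 2, 3, 4};
  uint64_t new_mapping[4] = {1, 2, 3, 4};  // stand-in for the grown, remapped file
  MappedArrayView view;
  view.SetupMemory(old_mapping, 4);
  view.Relocate(new_mapping);              // pointers now refer to the new mapping
  return 0;
}

Storing only the element count during the rebase keeps the operation correct regardless of how far the mapping moved, which is why the patch threads the new base addresses (vocab_rebase, search_rebase) out of the BinaryFormat growth calls.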
