| field | value | date |
|---|---|---|
| author | Patrick Simianer <p@simianer.de> | 2011-09-09 15:33:35 +0200 |
| committer | Patrick Simianer <p@simianer.de> | 2011-09-23 19:13:58 +0200 |
| commit | edb0cc0cbae1e75e4aeedb6360eab325effe6573 (patch) | |
| tree | a2fed4614b88f177f91e88fef3b269fa75e80188 /klm | |
| parent | 2e6ef7cbec77b22ce3d64416a5ada3a6c081f9e2 (diff) | |
partial merge, ruleid feature
Diffstat (limited to 'klm')
32 files changed, 867 insertions, 326 deletions
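Most of the klm changes below pull in KenLM's Bhiksha/Whittaker-style compression of trie next pointers (the new lm/bhiksha.cc and lm/bhiksha.hh, plus the plumbing through the trie, model, and binary-format code). The space trade-off, as the new ChopBits routine in the diff computes it, is to drop the top `chop` bits from every inline next pointer and recover them from a sorted array of 64-bit offsets, choosing the `chop` that minimizes table cost minus inline savings. The sketch below restates that calculation as a standalone program; the function and variable names are illustrative assumptions, not the KenLM API.

```cpp
// Illustrative sketch of the space trade-off the new lm/bhiksha.cc optimizes.
// Not KenLM code: names and the 64-bit-per-table-entry cost are assumptions
// taken from the diff below.
#include <stdint.h>

#include <algorithm>
#include <iostream>
#include <limits>

// Bits needed to store values in [0, max_value].
uint8_t RequiredBits(uint64_t max_value) {
  uint8_t ret = 0;
  while (max_value) { ++ret; max_value >>= 1; }
  return ret;
}

// Pick how many high bits to move out of the inline pointers and into a
// 64-bit-per-entry offset table, mirroring ChopBits in the diff.
uint8_t ChooseChop(uint64_t num_entries, uint64_t max_next, uint8_t max_chop) {
  uint8_t required = RequiredBits(max_next);
  uint8_t best_chop = 0;
  int64_t lowest_change = std::numeric_limits<int64_t>::max();
  for (uint8_t chop = 0; chop <= std::min(required, max_chop); ++chop) {
    // Cost: one 64-bit offset per distinct value of the chopped high bits.
    int64_t table_cost = static_cast<int64_t>(max_next >> (required - chop)) * 64;
    // Savings: chop bits removed from every inline pointer.
    int64_t inline_savings = static_cast<int64_t>(num_entries) * chop;
    int64_t change = table_cost - inline_savings;
    if (change < lowest_change) {
      lowest_change = change;
      best_chop = chop;
    }
  }
  return best_chop;
}

int main() {
  // e.g. 10M entries pointing into a 50M-entry next level.
  std::cout << static_cast<int>(ChooseChop(10000000, 50000000, 255))
            << " bits chopped\n";
}
```

For example, with ten million entries pointing into a fifty-million-entry next level (26-bit pointers), chopping roughly 18 bits trades an offset table of about 1.6 MB for about 22 MB of inline savings, which is why the help text added to build_binary.cc suggests -a 255 to let the optimizer minimize memory.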
| diff --git a/klm/lm/Makefile.am b/klm/lm/Makefile.am index 395494bc..fae6b41a 100644 --- a/klm/lm/Makefile.am +++ b/klm/lm/Makefile.am @@ -12,6 +12,7 @@ build_binary_LDADD = libklm.a ../util/libklm_util.a -lz  noinst_LIBRARIES = libklm.a  libklm_a_SOURCES = \ +  bhiksha.cc \    binary_format.cc \    config.cc \    lm_exception.cc \ diff --git a/klm/lm/bhiksha.cc b/klm/lm/bhiksha.cc new file mode 100644 index 00000000..bf86fd4b --- /dev/null +++ b/klm/lm/bhiksha.cc @@ -0,0 +1,93 @@ +#include "lm/bhiksha.hh" +#include "lm/config.hh" + +#include <limits> + +namespace lm { +namespace ngram { +namespace trie { + +DontBhiksha::DontBhiksha(const void * /*base*/, uint64_t /*max_offset*/, uint64_t max_next, const Config &/*config*/) :  +  next_(util::BitsMask::ByMax(max_next)) {} + +const uint8_t kArrayBhikshaVersion = 0; + +void ArrayBhiksha::UpdateConfigFromBinary(int fd, Config &config) { +  uint8_t version; +  uint8_t configured_bits; +  if (read(fd, &version, 1) != 1 || read(fd, &configured_bits, 1) != 1) { +    UTIL_THROW(util::ErrnoException, "Could not read from binary file"); +  } +  if (version != kArrayBhikshaVersion) UTIL_THROW(FormatLoadException, "This file has sorted array compression version " << (unsigned) version << " but the code expects version " << (unsigned)kArrayBhikshaVersion); +  config.pointer_bhiksha_bits = configured_bits; +} + +namespace { + +// Find argmin_{chopped \in [0, RequiredBits(max_next)]} ChoppedDelta(max_offset) +uint8_t ChopBits(uint64_t max_offset, uint64_t max_next, const Config &config) { +  uint8_t required = util::RequiredBits(max_next); +  uint8_t best_chop = 0; +  int64_t lowest_change = std::numeric_limits<int64_t>::max(); +  // There are probably faster ways but I don't care because this is only done once per order at construction time.   
+  for (uint8_t chop = 0; chop <= std::min(required, config.pointer_bhiksha_bits); ++chop) { +    int64_t change = (max_next >> (required - chop)) * 64 /* table cost in bits */ +      - max_offset * static_cast<int64_t>(chop); /* savings in bits*/ +    if (change < lowest_change) { +      lowest_change = change; +      best_chop = chop; +    } +  } +  return best_chop; +} + +std::size_t ArrayCount(uint64_t max_offset, uint64_t max_next, const Config &config) { +  uint8_t required = util::RequiredBits(max_next); +  uint8_t chopping = ChopBits(max_offset, max_next, config); +  return (max_next >> (required - chopping)) + 1 /* we store 0 too */; +} +} // namespace + +std::size_t ArrayBhiksha::Size(uint64_t max_offset, uint64_t max_next, const Config &config) { +  return sizeof(uint64_t) * (1 /* header */ + ArrayCount(max_offset, max_next, config)) + 7 /* 8-byte alignment */; +} + +uint8_t ArrayBhiksha::InlineBits(uint64_t max_offset, uint64_t max_next, const Config &config) { +  return util::RequiredBits(max_next) - ChopBits(max_offset, max_next, config); +} + +namespace { + +void *AlignTo8(void *from) { +  uint8_t *val = reinterpret_cast<uint8_t*>(from); +  std::size_t remainder = reinterpret_cast<std::size_t>(val) & 7; +  if (!remainder) return val; +  return val + 8 - remainder; +} + +} // namespace + +ArrayBhiksha::ArrayBhiksha(void *base, uint64_t max_offset, uint64_t max_next, const Config &config) +  : next_inline_(util::BitsMask::ByBits(InlineBits(max_offset, max_next, config))), +    offset_begin_(reinterpret_cast<const uint64_t*>(AlignTo8(base)) + 1 /* 8-byte header */), +    offset_end_(offset_begin_ + ArrayCount(max_offset, max_next, config)), +    write_to_(reinterpret_cast<uint64_t*>(AlignTo8(base)) + 1 /* 8-byte header */ + 1 /* first entry is 0 */), +    original_base_(base) {} + +void ArrayBhiksha::FinishedLoading(const Config &config) { +  // *offset_begin_ = 0 but without a const_cast. +  *(write_to_ - (write_to_ - offset_begin_)) = 0; + +  if (write_to_ != offset_end_) UTIL_THROW(util::Exception, "Did not get all the array entries that were expected."); + +  uint8_t *head_write = reinterpret_cast<uint8_t*>(original_base_); +  *(head_write++) = kArrayBhikshaVersion; +  *(head_write++) = config.pointer_bhiksha_bits; +} + +void ArrayBhiksha::LoadedBinary() { +} + +} // namespace trie +} // namespace ngram +} // namespace lm diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh new file mode 100644 index 00000000..cfb2b053 --- /dev/null +++ b/klm/lm/bhiksha.hh @@ -0,0 +1,108 @@ +/* Simple implementation of + * @inproceedings{bhikshacompression, + *  author={Bhiksha Raj and Ed Whittaker}, + *  year={2003}, + *  title={Lossless Compression of Language Model Structure and Word Identifiers}, + *  booktitle={Proceedings of IEEE International Conference on Acoustics, Speech and Signal Processing}, + *  pages={388--391}, + *  } + * + *  Currently only used for next pointers.   
+ */ + +#include <inttypes.h> + +#include "lm/binary_format.hh" +#include "lm/trie.hh" +#include "util/bit_packing.hh" +#include "util/sorted_uniform.hh" + +namespace lm { +namespace ngram { +class Config; + +namespace trie { + +class DontBhiksha { +  public: +    static const ModelType kModelTypeAdd = static_cast<ModelType>(0); + +    static void UpdateConfigFromBinary(int /*fd*/, Config &/*config*/) {} + +    static std::size_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; } + +    static uint8_t InlineBits(uint64_t /*max_offset*/, uint64_t max_next, const Config &/*config*/) { +      return util::RequiredBits(max_next); +    } + +    DontBhiksha(const void *base, uint64_t max_offset, uint64_t max_next, const Config &config); + +    void ReadNext(const void *base, uint64_t bit_offset, uint64_t /*index*/, uint8_t total_bits, NodeRange &out) const { +      out.begin = util::ReadInt57(base, bit_offset, next_.bits, next_.mask); +      out.end = util::ReadInt57(base, bit_offset + total_bits, next_.bits, next_.mask); +      //assert(out.end >= out.begin); +    } + +    void WriteNext(void *base, uint64_t bit_offset, uint64_t /*index*/, uint64_t value) { +      util::WriteInt57(base, bit_offset, next_.bits, value); +    } + +    void FinishedLoading(const Config &/*config*/) {} + +    void LoadedBinary() {} + +    uint8_t InlineBits() const { return next_.bits; } + +  private: +    util::BitsMask next_; +}; + +class ArrayBhiksha { +  public: +    static const ModelType kModelTypeAdd = kArrayAdd; + +    static void UpdateConfigFromBinary(int fd, Config &config); + +    static std::size_t Size(uint64_t max_offset, uint64_t max_next, const Config &config); + +    static uint8_t InlineBits(uint64_t max_offset, uint64_t max_next, const Config &config); + +    ArrayBhiksha(void *base, uint64_t max_offset, uint64_t max_value, const Config &config); + +    void ReadNext(const void *base, uint64_t bit_offset, uint64_t index, uint8_t total_bits, NodeRange &out) const { +      const uint64_t *begin_it = util::BinaryBelow(util::IdentityAccessor<uint64_t>(), offset_begin_, offset_end_, index); +      const uint64_t *end_it; +      for (end_it = begin_it; (end_it < offset_end_) && (*end_it <= index + 1); ++end_it) {} +      --end_it; +      out.begin = ((begin_it - offset_begin_) << next_inline_.bits) |  +        util::ReadInt57(base, bit_offset, next_inline_.bits, next_inline_.mask); +      out.end = ((end_it - offset_begin_) << next_inline_.bits) |  +        util::ReadInt57(base, bit_offset + total_bits, next_inline_.bits, next_inline_.mask); +    } + +    void WriteNext(void *base, uint64_t bit_offset, uint64_t index, uint64_t value) { +      uint64_t encode = value >> next_inline_.bits; +      for (; write_to_ <= offset_begin_ + encode; ++write_to_) *write_to_ = index; +      util::WriteInt57(base, bit_offset, next_inline_.bits, value & next_inline_.mask); +    } + +    void FinishedLoading(const Config &config); + +    void LoadedBinary(); + +    uint8_t InlineBits() const { return next_inline_.bits; } + +  private: +    const util::BitsMask next_inline_; + +    const uint64_t *const offset_begin_; +    const uint64_t *const offset_end_; + +    uint64_t *write_to_; + +    void *original_base_; +}; + +} // namespace trie +} // namespace ngram +} // namespace lm diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index 92b1008b..e02e621a 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -40,7 +40,7 @@ struct Sanity {    }  }; 
-const char *kModelNames[3] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "bit packed trie"}; +const char *kModelNames[6] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"};  std::size_t Align8(std::size_t in) {    std::size_t off = in % 8; @@ -100,16 +100,17 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_    }  } -uint8_t *GrowForSearch(const Config &config, std::size_t memory_size, Backing &backing) { +uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing) { +  std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad;    if (config.write_mmap) {      // Grow the file to accomodate the search, using zeros.   -    if (-1 == ftruncate(backing.file.get(), backing.vocab.size() + memory_size)) -      UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (backing.vocab.size() + memory_size) << " failed"); +    if (-1 == ftruncate(backing.file.get(), adjusted_vocab + memory_size)) +      UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (adjusted_vocab + memory_size) << " failed");      // We're skipping over the header and vocab for the search space mmap.  mmap likes page aligned offsets, so some arithmetic to round the offset down.        off_t page_size = sysconf(_SC_PAGE_SIZE); -    off_t alignment_cruft = backing.vocab.size() % page_size; -    backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), backing.vocab.size() - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED); +    off_t alignment_cruft = adjusted_vocab % page_size; +    backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED);      return reinterpret_cast<uint8_t*>(backing.search.get()) + alignment_cruft;    } else { diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index 2b32b450..d28cb6c5 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -16,7 +16,12 @@  namespace lm {  namespace ngram { -typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3} ModelType; +/* Not the best numbering system, but it grew this way for historical reasons + * and I want to preserve existing binary files. */ +typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3, ARRAY_TRIE_SORTED=4, QUANT_ARRAY_TRIE_SORTED=5} ModelType; + +const static ModelType kQuantAdd = static_cast<ModelType>(QUANT_TRIE_SORTED - TRIE_SORTED); +const static ModelType kArrayAdd = static_cast<ModelType>(ARRAY_TRIE_SORTED - TRIE_SORTED);  /*Inspect a file to determine if it is a binary lm.  If not, return false.     * If so, return true and set recognized to the type.  This is the only API in @@ -55,7 +60,7 @@ void AdvanceOrThrow(int fd, off_t off);  // Create just enough of a binary file to write vocabulary to it.    uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing);  // Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin.   
-uint8_t *GrowForSearch(const Config &config, std::size_t memory_size, Backing &backing); +uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing);  // Write header to binary file.  This is done last to prevent incomplete files  // from loading.    diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index 4552c419..b7aee4de 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -15,12 +15,12 @@ namespace ngram {  namespace {  void Usage(const char *name) { -  std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-n] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [type] input.arpa output.mmap\n\n" -"-u sets the default log10 probability for <unk> if the ARPA file does not have\n" -"one.\n" +  std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-c bits] [type] input.arpa [output.mmap]\n\n" +"-u sets the log10 probability for <unk> if the ARPA file does not have one.\n" +"   Default is -100.  The ARPA file will always take precedence.\n"  "-s allows models to be built even if they do not have <s> and </s>.\n" -"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n" -"type is either probing or trie:\n\n" +"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n\n" +"type is either probing or trie.  Default is probing.\n\n"  "probing uses a probing hash table.  It is the fastest but uses the most memory.\n"  "-p sets the space multiplier and must be >1.0.  The default is 1.5.\n\n"  "trie is a straightforward trie with bit-level packing.  It uses the least\n" @@ -29,10 +29,11 @@ void Usage(const char *name) {  "-t is the temporary directory prefix.  Default is the output file name.\n"  "-m limits memory use for sorting.  Measured in MB.  Default is 1024MB.\n"  "-q turns quantization on and sets the number of bits (e.g. -q 8).\n" -"-b sets backoff quantization bits.  Requires -q and defaults to that value.\n\n" -"See http://kheafield.com/code/kenlm/benchmark/ for data structure benchmarks.\n" -"Passing only an input file will print memory usage of each data structure.\n" -"If the ARPA file does not have <unk>, -u sets <unk>'s probability; default 0.0.\n"; +"-b sets backoff quantization bits.  Requires -q and defaults to that value.\n" +"-a compresses pointers using an array of offsets.  The parameter is the\n" +"   maximum number of bits encoded by the array.  
Memory is minimized subject\n" +"   to the maximum, so pick 255 to minimize memory.\n\n" +"Get a memory estimate by passing an ARPA file without an output file name.\n";    exit(1);  } @@ -63,12 +64,14 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) {    std::vector<uint64_t> counts;    util::FilePiece f(file);    lm::ReadARPACounts(f, counts); -  std::size_t sizes[3]; +  std::size_t sizes[5];    sizes[0] = ProbingModel::Size(counts, config);    sizes[1] = TrieModel::Size(counts, config);    sizes[2] = QuantTrieModel::Size(counts, config); -  std::size_t max_length = *std::max_element(sizes, sizes + 3); -  std::size_t min_length = *std::max_element(sizes, sizes + 3); +  sizes[3] = ArrayTrieModel::Size(counts, config); +  sizes[4] = QuantArrayTrieModel::Size(counts, config); +  std::size_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(size_t)); +  std::size_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(size_t));    std::size_t divide;    char prefix;    if (min_length < (1 << 10) * 10) { @@ -91,7 +94,9 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) {    std::cout << prefix << "B\n"      "probing " << std::setw(length) << (sizes[0] / divide) << " assuming -p " << config.probing_multiplier << "\n"      "trie    " << std::setw(length) << (sizes[1] / divide) << " without quantization\n" -    "trie    " << std::setw(length) << (sizes[2] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n"; +    "trie    " << std::setw(length) << (sizes[2] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n" +    "trie    " << std::setw(length) << (sizes[3] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " array pointer compression\n" +    "trie    " << std::setw(length) << (sizes[4] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits<< " array pointer compression and quantization\n";  }  void ProbingQuantizationUnsupported() { @@ -106,11 +111,11 @@ void ProbingQuantizationUnsupported() {  int main(int argc, char *argv[]) {    using namespace lm::ngram; -  bool quantize = false, set_backoff_bits = false;    try { +    bool quantize = false, set_backoff_bits = false, bhiksha = false;      lm::ngram::Config config;      int opt; -    while ((opt = getopt(argc, argv, "siu:p:t:m:q:b:")) != -1) { +    while ((opt = getopt(argc, argv, "siu:p:t:m:q:b:a:")) != -1) {        switch(opt) {          case 'q':            config.prob_bits = ParseBitCount(optarg); @@ -121,6 +126,9 @@ int main(int argc, char *argv[]) {            config.backoff_bits = ParseBitCount(optarg);            set_backoff_bits = true;            break; +        case 'a': +          config.pointer_bhiksha_bits = ParseBitCount(optarg); +          bhiksha = true;          case 'u':            config.unknown_missing_logprob = ParseFloat(optarg);            break; @@ -162,9 +170,17 @@ int main(int argc, char *argv[]) {          ProbingModel(from_file, config);        } else if (!strcmp(model_type, "trie")) {          if (quantize) { -          QuantTrieModel(from_file, config); +          if (bhiksha) { +            QuantArrayTrieModel(from_file, config); +          } else { +            QuantTrieModel(from_file, config); +          }          } else { -          TrieModel(from_file, config); +          if 
(bhiksha) { +            ArrayTrieModel(from_file, config); +          } else { +            TrieModel(from_file, config); +          }          }        } else {          Usage(argv[0]); @@ -173,9 +189,9 @@ int main(int argc, char *argv[]) {        Usage(argv[0]);      }    } -  catch (std::exception &e) { +  catch (const std::exception &e) {      std::cerr << e.what() << std::endl; -    abort(); +    return 1;    }    return 0;  } diff --git a/klm/lm/config.cc b/klm/lm/config.cc index 08e1af5c..297589a4 100644 --- a/klm/lm/config.cc +++ b/klm/lm/config.cc @@ -20,6 +20,7 @@ Config::Config() :    include_vocab(true),    prob_bits(8),    backoff_bits(8), +  pointer_bhiksha_bits(22),    load_method(util::POPULATE_OR_READ) {}  } // namespace ngram diff --git a/klm/lm/config.hh b/klm/lm/config.hh index dcc7cf35..227b8512 100644 --- a/klm/lm/config.hh +++ b/klm/lm/config.hh @@ -73,9 +73,12 @@ struct Config {    // Quantization options.  Only effective for QuantTrieModel.  One value is    // reserved for each of prob and backoff, so 2^bits - 1 buckets will be used -  // to quantize.   +  // to quantize (and one of the remaining backoffs will be 0).      uint8_t prob_bits, backoff_bits; +  // Bhiksha compression (simple form).  Only works with trie. +  uint8_t pointer_bhiksha_bits; +    // ONLY EFFECTIVE WHEN READING BINARY diff --git a/klm/lm/model.cc b/klm/lm/model.cc index a1d10b3d..27e24b1c 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -21,6 +21,8 @@ size_t hash_value(const State &state) {  namespace detail { +template <class Search, class VocabularyT> const ModelType GenericModel<Search, VocabularyT>::kModelType = Search::kModelType; +  template <class Search, class VocabularyT> size_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) {    return VocabularyT::Size(counts[0], config) + Search::Size(counts, config);  } @@ -56,35 +58,40 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT  template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) {    // Backing file is the ARPA.  Steal it so we can make the backing file the mmap output if any.      util::FilePiece f(backing_.file.release(), file, config.messages); -  std::vector<uint64_t> counts; -  // File counts do not include pruned trigrams that extend to quadgrams etc.   These will be fixed by search_. -  ReadARPACounts(f, counts); - -  if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ".  Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile."); -  if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model."); -  if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0"); - -  std::size_t vocab_size = VocabularyT::Size(counts[0], config); -  // Setup the binary file for writing the vocab lookup table.  The search_ is responsible for growing the binary file to its needs.   
-  vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config); - -  if (config.write_mmap) { -    WriteWordsWrapper wrap(config.enumerate_vocab); -    vocab_.ConfigureEnumerate(&wrap, counts[0]); -    search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); -    wrap.Write(backing_.file.get()); -  } else { -    vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]); -    search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); -  } +  try { +    std::vector<uint64_t> counts; +    // File counts do not include pruned trigrams that extend to quadgrams etc.   These will be fixed by search_. +    ReadARPACounts(f, counts); + +    if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ".  Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile."); +    if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model."); +    if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0"); + +    std::size_t vocab_size = VocabularyT::Size(counts[0], config); +    // Setup the binary file for writing the vocab lookup table.  The search_ is responsible for growing the binary file to its needs.   +    vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config); + +    if (config.write_mmap) { +      WriteWordsWrapper wrap(config.enumerate_vocab); +      vocab_.ConfigureEnumerate(&wrap, counts[0]); +      search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); +      wrap.Write(backing_.file.get()); +    } else { +      vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]); +      search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); +    } -  if (!vocab_.SawUnk()) { -    assert(config.unknown_missing != THROW_UP); -    // Default probabilities for unknown.   -    search_.unigram.Unknown().backoff = 0.0; -    search_.unigram.Unknown().prob = config.unknown_missing_logprob; +    if (!vocab_.SawUnk()) { +      assert(config.unknown_missing != THROW_UP); +      // Default probabilities for unknown.   
+      search_.unigram.Unknown().backoff = 0.0; +      search_.unigram.Unknown().prob = config.unknown_missing_logprob; +    } +    FinishFile(config, kModelType, counts, backing_); +  } catch (util::Exception &e) { +    e << " Byte: " << f.Offset(); +    throw;    } -  FinishFile(config, kModelType, counts, backing_);  }  template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const { @@ -225,8 +232,10 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,  }  template class GenericModel<ProbingHashedSearch, ProbingVocabulary>;  // HASH_PROBING -template class GenericModel<trie::TrieSearch<DontQuantize>, SortedVocabulary>; // TRIE_SORTED -template class GenericModel<trie::TrieSearch<SeparatelyQuantize>, SortedVocabulary>; // TRIE_SORTED_QUANT +template class GenericModel<trie::TrieSearch<DontQuantize, trie::DontBhiksha>, SortedVocabulary>; // TRIE_SORTED +template class GenericModel<trie::TrieSearch<DontQuantize, trie::ArrayBhiksha>, SortedVocabulary>; +template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiksha>, SortedVocabulary>; // TRIE_SORTED_QUANT +template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::ArrayBhiksha>, SortedVocabulary>;  } // namespace detail  } // namespace ngram diff --git a/klm/lm/model.hh b/klm/lm/model.hh index 1f49a382..21595321 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -1,6 +1,7 @@  #ifndef LM_MODEL__  #define LM_MODEL__ +#include "lm/bhiksha.hh"  #include "lm/binary_format.hh"  #include "lm/config.hh"  #include "lm/facade.hh" @@ -71,6 +72,9 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod    private:      typedef base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> P;    public: +    // This is the model type returned by RecognizeBinary. +    static const ModelType kModelType; +      /* Get the size of memory that will be mapped given ngram counts.  This       * does not include small non-mapped control structures, such as this class       * itself.   @@ -131,8 +135,6 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod      Backing &MutableBacking() { return backing_; } -    static const ModelType kModelType = Search::kModelType; -      Backing backing_;      VocabularyT vocab_; @@ -152,9 +154,11 @@ typedef ProbingModel Model;  // Smaller implementation.  
typedef ::lm::ngram::SortedVocabulary SortedVocabulary; -typedef detail::GenericModel<trie::TrieSearch<DontQuantize>, SortedVocabulary> TrieModel; // TRIE_SORTED +typedef detail::GenericModel<trie::TrieSearch<DontQuantize, trie::DontBhiksha>, SortedVocabulary> TrieModel; // TRIE_SORTED +typedef detail::GenericModel<trie::TrieSearch<DontQuantize, trie::ArrayBhiksha>, SortedVocabulary> ArrayTrieModel; -typedef detail::GenericModel<trie::TrieSearch<SeparatelyQuantize>, SortedVocabulary> QuantTrieModel; // QUANT_TRIE_SORTED +typedef detail::GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiksha>, SortedVocabulary> QuantTrieModel; // QUANT_TRIE_SORTED +typedef detail::GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::ArrayBhiksha>, SortedVocabulary> QuantArrayTrieModel;  } // namespace ngram  } // namespace lm diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index 8bf040ff..57c7291c 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -193,6 +193,14 @@ template <class M> void Stateless(const M &model) {    BOOST_CHECK_EQUAL(static_cast<WordIndex>(0), state.history_[0]);  } +template <class M> void NoUnkCheck(const M &model) { +  WordIndex unk_index = 0; +  State state; + +  FullScoreReturn ret = model.FullScoreForgotState(&unk_index, &unk_index + 1, unk_index, state); +  BOOST_CHECK_CLOSE(-100.0, ret.prob, 0.001); +} +  template <class M> void Everything(const M &m) {    Starters(m);    Continuation(m); @@ -231,25 +239,38 @@ template <class ModelT> void LoadingTest() {    Config config;    config.arpa_complain = Config::NONE;    config.messages = NULL; -  ExpectEnumerateVocab enumerate; -  config.enumerate_vocab = &enumerate;    config.probing_multiplier = 2.0; -  ModelT m("test.arpa", config); -  enumerate.Check(m.GetVocabulary()); -  Everything(m); +  { +    ExpectEnumerateVocab enumerate; +    config.enumerate_vocab = &enumerate; +    ModelT m("test.arpa", config); +    enumerate.Check(m.GetVocabulary()); +    Everything(m); +  } +  { +    ExpectEnumerateVocab enumerate; +    config.enumerate_vocab = &enumerate; +    ModelT m("test_nounk.arpa", config); +    enumerate.Check(m.GetVocabulary()); +    NoUnkCheck(m); +  }  }  BOOST_AUTO_TEST_CASE(probing) {    LoadingTest<Model>();  } -  BOOST_AUTO_TEST_CASE(trie) {    LoadingTest<TrieModel>();  } - -BOOST_AUTO_TEST_CASE(quant) { +BOOST_AUTO_TEST_CASE(quant_trie) {    LoadingTest<QuantTrieModel>();  } +BOOST_AUTO_TEST_CASE(bhiksha_trie) { +  LoadingTest<ArrayTrieModel>(); +} +BOOST_AUTO_TEST_CASE(quant_bhiksha_trie) { +  LoadingTest<QuantArrayTrieModel>(); +}  template <class ModelT> void BinaryTest() {    Config config; @@ -267,10 +288,34 @@ template <class ModelT> void BinaryTest() {    config.write_mmap = NULL; -  ModelT binary("test.binary", config); -  enumerate.Check(binary.GetVocabulary()); -  Everything(binary); +  ModelType type; +  BOOST_REQUIRE(RecognizeBinary("test.binary", type)); +  BOOST_CHECK_EQUAL(ModelT::kModelType, type); + +  { +    ModelT binary("test.binary", config); +    enumerate.Check(binary.GetVocabulary()); +    Everything(binary); +  }    unlink("test.binary"); + +  // Now test without <unk>. 
+  config.write_mmap = "test_nounk.binary"; +  config.messages = NULL; +  enumerate.Clear(); +  { +    ModelT copy_model("test_nounk.arpa", config); +    enumerate.Check(copy_model.GetVocabulary()); +    enumerate.Clear(); +    NoUnkCheck(copy_model); +  } +  config.write_mmap = NULL; +  { +    ModelT binary("test_nounk.binary", config); +    enumerate.Check(binary.GetVocabulary()); +    NoUnkCheck(binary); +  } +  unlink("test_nounk.binary");  }  BOOST_AUTO_TEST_CASE(write_and_read_probing) { @@ -282,6 +327,12 @@ BOOST_AUTO_TEST_CASE(write_and_read_trie) {  BOOST_AUTO_TEST_CASE(write_and_read_quant_trie) {    BinaryTest<QuantTrieModel>();  } +BOOST_AUTO_TEST_CASE(write_and_read_array_trie) { +  BinaryTest<ArrayTrieModel>(); +} +BOOST_AUTO_TEST_CASE(write_and_read_quant_array_trie) { +  BinaryTest<QuantArrayTrieModel>(); +}  } // namespace  } // namespace ngram diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc index 9454a6d1..d9db4aa2 100644 --- a/klm/lm/ngram_query.cc +++ b/klm/lm/ngram_query.cc @@ -99,6 +99,15 @@ int main(int argc, char *argv[]) {        case lm::ngram::TRIE_SORTED:          Query<lm::ngram::TrieModel>(argv[1], sentence_context);          break; +      case lm::ngram::QUANT_TRIE_SORTED: +        Query<lm::ngram::QuantTrieModel>(argv[1], sentence_context); +        break; +      case lm::ngram::ARRAY_TRIE_SORTED: +        Query<lm::ngram::ArrayTrieModel>(argv[1], sentence_context); +        break; +      case lm::ngram::QUANT_ARRAY_TRIE_SORTED: +        Query<lm::ngram::QuantArrayTrieModel>(argv[1], sentence_context); +        break;        case lm::ngram::HASH_SORTED:        default:          std::cerr << "Unrecognized kenlm model type " << model_type << std::endl; diff --git a/klm/lm/quantize.cc b/klm/lm/quantize.cc index 4bb6b1b8..fd371cc8 100644 --- a/klm/lm/quantize.cc +++ b/klm/lm/quantize.cc @@ -43,6 +43,7 @@ void SeparatelyQuantize::UpdateConfigFromBinary(int fd, const std::vector<uint64    if (read(fd, &version, 1) != 1 || read(fd, &config.prob_bits, 1) != 1 || read(fd, &config.backoff_bits, 1) != 1)       UTIL_THROW(util::ErrnoException, "Failed to read header for quantization.");    if (version != kSeparatelyQuantizeVersion) UTIL_THROW(FormatLoadException, "This file has quantization version " << (unsigned)version << " but the code expects version " << (unsigned)kSeparatelyQuantizeVersion); +  AdvanceOrThrow(fd, -3);  }  void SeparatelyQuantize::SetupMemory(void *start, const Config &config) { diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh index aae72b34..0b71d14a 100644 --- a/klm/lm/quantize.hh +++ b/klm/lm/quantize.hh @@ -21,7 +21,7 @@ class Config;  /* Store values directly and don't quantize. 
*/  class DontQuantize {    public: -    static const ModelType kModelType = TRIE_SORTED; +    static const ModelType kModelTypeAdd = static_cast<ModelType>(0);      static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {}      static std::size_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; }      static uint8_t MiddleBits(const Config &/*config*/) { return 63; } @@ -108,7 +108,7 @@ class SeparatelyQuantize {      };    public: -    static const ModelType kModelType = QUANT_TRIE_SORTED; +    static const ModelType kModelTypeAdd = kQuantAdd;      static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config); diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc index 060a97ea..455bc4ba 100644 --- a/klm/lm/read_arpa.cc +++ b/klm/lm/read_arpa.cc @@ -31,15 +31,15 @@ const char kBinaryMagic[] = "mmap lm http://kheafield.com/code";  void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) {    number.clear();    StringPiece line; -  if (!IsEntirelyWhiteSpace(line = in.ReadLine())) { +  while (IsEntirelyWhiteSpace(line = in.ReadLine())) {} +  if (line != "\\data\\") {      if ((line.size() >= 2) && (line.data()[0] == 0x1f) && (static_cast<unsigned char>(line.data()[1]) == 0x8b)) {        UTIL_THROW(FormatLoadException, "Looks like a gzip file.  If this is an ARPA file, pipe " << in.FileName() << " through zcat.  If this already in binary format, you need to decompress it because mmap doesn't work on top of gzip.");      }      if (static_cast<size_t>(line.size()) >= strlen(kBinaryMagic) && StringPiece(line.data(), strlen(kBinaryMagic)) == kBinaryMagic)         UTIL_THROW(FormatLoadException, "This looks like a binary file but got sent to the ARPA parser.  Did you compress the binary file or pass a binary file where only ARPA files are accepted?"); -    UTIL_THROW(FormatLoadException, "First line was \"" << line.data() << "\" not blank"); +    UTIL_THROW(FormatLoadException, "first non-empty line was \"" << line << "\" not \\data\\.");    } -  if ((line = in.ReadLine()) != "\\data\\") UTIL_THROW(FormatLoadException, "second line was \"" << line << "\" not \\data\\.");    while (!IsEntirelyWhiteSpace(line = in.ReadLine())) {      if (line.size() < 6 || strncmp(line.data(), "ngram ", 6)) UTIL_THROW(FormatLoadException, "count line \"" << line << "\"doesn't begin with \"ngram \"");      // So strtol doesn't go off the end of line.   diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index c56ba7b8..82c53ec8 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -98,7 +98,7 @@ template <class MiddleT, class LongestT> uint8_t *TemplateHashedSearch<MiddleT,  template <class MiddleT, class LongestT> template <class Voc> void TemplateHashedSearch<MiddleT, LongestT>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing) {    // TODO: fix sorted. 
-  SetupMemory(GrowForSearch(config, Size(counts, config), backing), counts, config); +  SetupMemory(GrowForSearch(config, 0, Size(counts, config), backing), counts, config);    PositiveProbWarn warn(config.positive_log_probability); diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index f3acdefc..c62985e4 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -52,12 +52,11 @@ struct HashedSearch {    Unigram unigram; -  bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const { +  void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const {      const ProbBackoff &entry = unigram.Lookup(word);      prob = entry.prob;      backoff = entry.backoff;      next = static_cast<Node>(word); -    return true;    }  }; diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 91f87f1c..05059ffb 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -1,6 +1,7 @@  /* This is where the trie is built.  It's on-disk.  */  #include "lm/search_trie.hh" +#include "lm/bhiksha.hh"  #include "lm/blank.hh"  #include "lm/lm_exception.hh"  #include "lm/max_order.hh" @@ -543,8 +544,8 @@ void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uin      std::string unigram_name = file_prefix + "unigrams";      util::scoped_fd unigram_file;      // In case <unk> appears.   -    size_t extra_count = counts[0] + 1; -    util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), extra_count * sizeof(ProbBackoff), unigram_file), extra_count * sizeof(ProbBackoff)); +    size_t file_out = (counts[0] + 1) * sizeof(ProbBackoff); +    util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), file_out, unigram_file), file_out);      Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()), warn);      CheckSpecials(config, vocab);      if (!vocab.SawUnk()) ++counts[0]; @@ -610,9 +611,9 @@ class JustCount {  };  // Phase to actually write n-grams to the trie.   
-template <class Quant> class WriteEntries { +template <class Quant, class Bhiksha> class WriteEntries {    public: -    WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle<typename Quant::Middle> *middle, BitPackedLongest<typename Quant::Longest> &longest, const uint64_t * /*counts*/, unsigned char order) :  +    WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle<typename Quant::Middle, Bhiksha> *middle, BitPackedLongest<typename Quant::Longest> &longest, const uint64_t * /*counts*/, unsigned char order) :         contexts_(contexts),        unigrams_(unigrams),        middle_(middle), @@ -649,7 +650,7 @@ template <class Quant> class WriteEntries {    private:      ContextReader *contexts_;      UnigramValue *const unigrams_; -    BitPackedMiddle<typename Quant::Middle> *const middle_; +    BitPackedMiddle<typename Quant::Middle, Bhiksha> *const middle_;      BitPackedLongest<typename Quant::Longest> &longest_;      BitPacked &bigram_pack_;  }; @@ -821,7 +822,7 @@ template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, So  } // namespace -template <class Quant> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant> &out, Quant &quant, Backing &backing) { +template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) {    std::vector<SortedFileReader> inputs(counts.size() - 1);    std::vector<ContextReader> contexts(counts.size() - 1); @@ -846,7 +847,7 @@ template <class Quant> void BuildTrie(const std::string &file_prefix, std::vecto    SanityCheckCounts(counts, fixed_counts);    counts = fixed_counts; -  out.SetupMemory(GrowForSearch(config, TrieSearch<Quant>::Size(fixed_counts, config), backing), fixed_counts, config); +  out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), backing), fixed_counts, config);    if (Quant::kTrain) {      util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0)); @@ -863,7 +864,7 @@ template <class Quant> void BuildTrie(const std::string &file_prefix, std::vecto    UnigramValue *unigrams = out.unigram.Raw();    // Fill entries except unigram probabilities.      { -    RecursiveInsert<WriteEntries<Quant> > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); +    RecursiveInsert<WriteEntries<Quant, Bhiksha> > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size());      inserter.Apply(config.messages, "Building trie", fixed_counts[0]);    } @@ -901,14 +902,14 @@ template <class Quant> void BuildTrie(const std::string &file_prefix, std::vecto    /* Set ending offsets so the last entry will be sized properly */    // Last entry for unigrams was already set.      
if (out.middle_begin_ != out.middle_end_) { -    for (typename TrieSearch<Quant>::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) { -      i->FinishedLoading((i+1)->InsertIndex()); +    for (typename TrieSearch<Quant, Bhiksha>::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) { +      i->FinishedLoading((i+1)->InsertIndex(), config);      } -    (out.middle_end_ - 1)->FinishedLoading(out.longest.InsertIndex()); +    (out.middle_end_ - 1)->FinishedLoading(out.longest.InsertIndex(), config);    }    } -template <class Quant> uint8_t *TrieSearch<Quant>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) { +template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {    quant_.SetupMemory(start, config);    start += Quant::Size(counts.size(), config);    unigram.Init(start); @@ -919,22 +920,24 @@ template <class Quant> uint8_t *TrieSearch<Quant>::SetupMemory(uint8_t *start, c    std::vector<uint8_t*> middle_starts(counts.size() - 2);    for (unsigned char i = 2; i < counts.size(); ++i) {      middle_starts[i-2] = start; -    start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i]); +    start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i], config);    } -  // Crazy backwards thing so we initialize in the correct order.   +  // Crazy backwards thing so we initialize using pointers to ones that have already been initialized    for (unsigned char i = counts.size() - 1; i >= 2; --i) {      new (middle_begin_ + i - 2) Middle(          middle_starts[i-2],          quant_.Mid(i), +        counts[i-1],          counts[0],          counts[i], -        (i == counts.size() - 1) ? static_cast<const BitPacked&>(longest) : static_cast<const BitPacked &>(middle_begin_[i-1])); +        (i == counts.size() - 1) ? static_cast<const BitPacked&>(longest) : static_cast<const BitPacked &>(middle_begin_[i-1]), +        config);    }    longest.Init(start, quant_.Long(counts.size()), counts[0]);    return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);  } -template <class Quant> void TrieSearch<Quant>::LoadedBinary() { +template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::LoadedBinary() {    unigram.LoadedBinary();    for (Middle *i = middle_begin_; i != middle_end_; ++i) {      i->LoadedBinary(); @@ -942,7 +945,7 @@ template <class Quant> void TrieSearch<Quant>::LoadedBinary() {    longest.LoadedBinary();  } -template <class Quant> void TrieSearch<Quant>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { +template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) {    std::string temporary_directory;    if (config.temporary_directory_prefix) {      temporary_directory = config.temporary_directory_prefix; @@ -966,14 +969,16 @@ template <class Quant> void TrieSearch<Quant>::InitializeFromARPA(const char *fi    // At least 1MB sorting memory.      
ARPAToSortedFiles(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab); -  BuildTrie(temporary_directory, counts, config, *this, quant_, backing); +  BuildTrie(temporary_directory, counts, config, *this, quant_, vocab, backing);    if (rmdir(temporary_directory.c_str()) && config.messages) {      *config.messages << "Failed to delete " << temporary_directory << std::endl;    }  } -template class TrieSearch<DontQuantize>; -template class TrieSearch<SeparatelyQuantize>; +template class TrieSearch<DontQuantize, DontBhiksha>; +template class TrieSearch<DontQuantize, ArrayBhiksha>; +template class TrieSearch<SeparatelyQuantize, DontBhiksha>; +template class TrieSearch<SeparatelyQuantize, ArrayBhiksha>;  } // namespace trie  } // namespace ngram diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 0a52acb5..2f39c09f 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -13,31 +13,33 @@ struct Backing;  class SortedVocabulary;  namespace trie { -template <class Quant> class TrieSearch; -template <class Quant> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant> &out, Quant &quant, Backing &backing); +template <class Quant, class Bhiksha> class TrieSearch; +template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); -template <class Quant> class TrieSearch { +template <class Quant, class Bhiksha> class TrieSearch {    public:      typedef NodeRange Node;      typedef ::lm::ngram::trie::Unigram Unigram;      Unigram unigram; -    typedef trie::BitPackedMiddle<typename Quant::Middle> Middle; +    typedef trie::BitPackedMiddle<typename Quant::Middle, Bhiksha> Middle;      typedef trie::BitPackedLongest<typename Quant::Longest> Longest;      Longest longest; -    static const ModelType kModelType = Quant::kModelType; +    static const ModelType kModelType = static_cast<ModelType>(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd);      static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {        Quant::UpdateConfigFromBinary(fd, counts, config); +      AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0])); +      Bhiksha::UpdateConfigFromBinary(fd, config);      }      static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) {        std::size_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]);        for (unsigned char i = 1; i < counts.size() - 1; ++i) { -        ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1]); +        ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1], config);        }        return ret + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);      } @@ -55,8 +57,8 @@ template <class Quant> class TrieSearch {      void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); -    bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { -      return unigram.Find(word, prob, backoff, node); +    void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { +      unigram.Find(word, prob, backoff, node);      }      bool 
LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const { @@ -83,7 +85,7 @@ template <class Quant> class TrieSearch {      }    private: -    friend void BuildTrie<Quant>(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant> &out, Quant &quant, Backing &backing); +    friend void BuildTrie<Quant, Bhiksha>(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing);      // Middles are managed manually so we can delay construction and they don't have to be copyable.        void FreeMiddles() { diff --git a/klm/lm/test_nounk.arpa b/klm/lm/test_nounk.arpa new file mode 100644 index 00000000..060733d9 --- /dev/null +++ b/klm/lm/test_nounk.arpa @@ -0,0 +1,120 @@ + +\data\ +ngram 1=36 +ngram 2=45 +ngram 3=10 +ngram 4=6 +ngram 5=4 + +\1-grams: +-1.383514	,	-0.30103 +-1.139057	.	-0.845098 +-1.029493	</s> +-99	<s>	-0.4149733 +-1.285941	a	-0.69897 +-1.687872	also	-0.30103 +-1.687872	beyond	-0.30103 +-1.687872	biarritz	-0.30103 +-1.687872	call	-0.30103 +-1.687872	concerns	-0.30103 +-1.687872	consider	-0.30103 +-1.687872	considering	-0.30103 +-1.687872	for	-0.30103 +-1.509559	higher	-0.30103 +-1.687872	however	-0.30103 +-1.687872	i	-0.30103 +-1.687872	immediate	-0.30103 +-1.687872	in	-0.30103 +-1.687872	is	-0.30103 +-1.285941	little	-0.69897 +-1.383514	loin	-0.30103 +-1.687872	look	-0.30103 +-1.285941	looking	-0.4771212 +-1.206319	more	-0.544068 +-1.509559	on	-0.4771212 +-1.509559	screening	-0.4771212 +-1.687872	small	-0.30103 +-1.687872	the	-0.30103 +-1.687872	to	-0.30103 +-1.687872	watch	-0.30103 +-1.687872	watching	-0.30103 +-1.687872	what	-0.30103 +-1.687872	would	-0.30103 +-3.141592	foo +-2.718281	bar	3.0 +-6.535897	baz	-0.0 + +\2-grams: +-0.6925742	, . +-0.7522095	, however +-0.7522095	, is +-0.0602359	. </s> +-0.4846522	<s> looking	-0.4771214 +-1.051485	<s> screening +-1.07153	<s> the +-1.07153	<s> watching +-1.07153	<s> what +-0.09132547	a little	-0.69897 +-0.2922095	also call +-0.2922095	beyond immediate +-0.2705918	biarritz . +-0.2922095	call for +-0.2922095	concerns in +-0.2922095	consider watch +-0.2922095	considering consider +-0.2834328	for , +-0.5511513	higher more +-0.5845945	higher small +-0.2834328	however , +-0.2922095	i would +-0.2922095	immediate concerns +-0.2922095	in biarritz +-0.2922095	is to +-0.09021038	little more	-0.1998621 +-0.7273645	loin , +-0.6925742	loin . +-0.6708385	loin </s> +-0.2922095	look beyond +-0.4638903	looking higher +-0.4638903	looking on	-0.4771212 +-0.5136299	more .	-0.4771212 +-0.3561665	more loin +-0.1649931	on a	-0.4771213 +-0.1649931	screening a	-0.4771213 +-0.2705918	small . +-0.287799	the screening +-0.2922095	to look +-0.2622373	watch </s> +-0.2922095	watching considering +-0.2922095	what i +-0.2922095	would also +-2	also would	-6 +-6	foo bar + +\3-grams: +-0.01916512	more . 
</s> +-0.0283603	on a little	-0.4771212 +-0.0283603	screening a little	-0.4771212 +-0.01660496	a little more	-0.09409451 +-0.3488368	<s> looking higher +-0.3488368	<s> looking on	-0.4771212 +-0.1892331	little more loin +-0.04835128	looking on a	-0.4771212 +-3	also would consider	-7 +-7	to look good + +\4-grams: +-0.009249173	looking on a little	-0.4771212 +-0.005464747	on a little more	-0.4771212 +-0.005464747	screening a little more +-0.1453306	a little more loin +-0.01552657	<s> looking on a	-0.4771212 +-4	also would consider higher	-8 + +\5-grams: +-0.003061223	<s> looking on a little +-0.001813953	looking on a little more +-0.0432557	on a little more loin +-5	also would consider higher looking + +\end\ diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index 63c2a612..8c536e66 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -1,5 +1,6 @@  #include "lm/trie.hh" +#include "lm/bhiksha.hh"  #include "lm/quantize.hh"  #include "util/bit_packing.hh"  #include "util/exception.hh" @@ -57,16 +58,21 @@ void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits)    max_vocab_ = max_vocab;  } -template <class Quant> std::size_t BitPackedMiddle<Quant>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr) { -  return BaseSize(entries, max_vocab, quant_bits + util::RequiredBits(max_ptr)); +template <class Quant, class Bhiksha> std::size_t BitPackedMiddle<Quant, Bhiksha>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) { +  return Bhiksha::Size(entries + 1, max_ptr, config) + BaseSize(entries, max_vocab, quant_bits + Bhiksha::InlineBits(entries + 1, max_ptr, config));  } -template <class Quant> BitPackedMiddle<Quant>::BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source) : BitPacked(), quant_(quant), next_bits_(util::RequiredBits(max_next)), next_mask_((1ULL << next_bits_) - 1), next_source_(&next_source) { -  if (next_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order.  Edit util/bit_packing.hh and fix the bit packing functions."); -  BaseInit(base, max_vocab, quant.TotalBits() + next_bits_); +template <class Quant, class Bhiksha> BitPackedMiddle<Quant, Bhiksha>::BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config) : +  BitPacked(), +  quant_(quant), +  // If the offset of the method changes, also change TrieSearch::UpdateConfigFromBinary. +  bhiksha_(base, entries + 1, max_next, config), +  next_source_(&next_source) { +  if (entries + 1 >= (1ULL << 57) || (max_next >= (1ULL << 57)))  UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order.  
Edit util/bit_packing.hh and fix the bit packing functions."); +  BaseInit(reinterpret_cast<uint8_t*>(base) + Bhiksha::Size(entries + 1, max_next, config), max_vocab, quant.TotalBits() + bhiksha_.InlineBits());  } -template <class Quant> void BitPackedMiddle<Quant>::Insert(WordIndex word, float prob, float backoff) { +template <class Quant, class Bhiksha> void BitPackedMiddle<Quant, Bhiksha>::Insert(WordIndex word, float prob, float backoff) {    assert(word <= word_mask_);    uint64_t at_pointer = insert_index_ * total_bits_; @@ -75,47 +81,42 @@ template <class Quant> void BitPackedMiddle<Quant>::Insert(WordIndex word, float    quant_.Write(base_, at_pointer, prob, backoff);    at_pointer += quant_.TotalBits();    uint64_t next = next_source_->InsertIndex(); -  assert(next <= next_mask_); -  util::WriteInt57(base_, at_pointer, next_bits_, next); +  bhiksha_.WriteNext(base_, at_pointer, insert_index_, next);    ++insert_index_;  } -template <class Quant> bool BitPackedMiddle<Quant>::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const { +template <class Quant, class Bhiksha> bool BitPackedMiddle<Quant, Bhiksha>::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const {    uint64_t at_pointer;    if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) {      return false;    } +  uint64_t index = at_pointer;    at_pointer *= total_bits_;    at_pointer += word_bits_;    quant_.Read(base_, at_pointer, prob, backoff);    at_pointer += quant_.TotalBits(); -  range.begin = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); -  // Read the next entry's pointer.   -  at_pointer += total_bits_; -  range.end = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); +  bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range); +    return true;  } -template <class Quant> bool BitPackedMiddle<Quant>::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const { -  uint64_t at_pointer; -  if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false; -  at_pointer *= total_bits_; +template <class Quant, class Bhiksha> bool BitPackedMiddle<Quant, Bhiksha>::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const { +  uint64_t index; +  if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, index)) return false; +  uint64_t at_pointer = index * total_bits_;    at_pointer += word_bits_;    quant_.ReadBackoff(base_, at_pointer, backoff);    at_pointer += quant_.TotalBits(); -  range.begin = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); -  // Read the next entry's pointer.   
-  at_pointer += total_bits_; -  range.end = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); +  bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range);    return true;  } -template <class Quant> void BitPackedMiddle<Quant>::FinishedLoading(uint64_t next_end) { -  assert(next_end <= next_mask_); -  uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - next_bits_; -  util::WriteInt57(base_, last_next_write, next_bits_, next_end); +template <class Quant, class Bhiksha> void BitPackedMiddle<Quant, Bhiksha>::FinishedLoading(uint64_t next_end, const Config &config) { +  uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - bhiksha_.InlineBits(); +  bhiksha_.WriteNext(base_, last_next_write, insert_index_ + 1, next_end); +  bhiksha_.FinishedLoading(config);  }  template <class Quant> void BitPackedLongest<Quant>::Insert(WordIndex index, float prob) { @@ -135,8 +136,10 @@ template <class Quant> bool BitPackedLongest<Quant>::Find(WordIndex word, float    return true;  } -template class BitPackedMiddle<DontQuantize::Middle>; -template class BitPackedMiddle<SeparatelyQuantize::Middle>; +template class BitPackedMiddle<DontQuantize::Middle, DontBhiksha>; +template class BitPackedMiddle<DontQuantize::Middle, ArrayBhiksha>; +template class BitPackedMiddle<SeparatelyQuantize::Middle, DontBhiksha>; +template class BitPackedMiddle<SeparatelyQuantize::Middle, ArrayBhiksha>;  template class BitPackedLongest<DontQuantize::Longest>;  template class BitPackedLongest<SeparatelyQuantize::Longest>; diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index 8fa21aaf..53612064 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -10,6 +10,7 @@  namespace lm {  namespace ngram { +class Config;  namespace trie {  struct NodeRange { @@ -46,13 +47,12 @@ class Unigram {      void LoadedBinary() {} -    bool Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const { +    void Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const {        UnigramValue *val = unigram_ + word;        prob = val->weights.prob;        backoff = val->weights.backoff;        next.begin = val->next;        next.end = (val+1)->next; -      return true;      }    private: @@ -67,8 +67,6 @@ class BitPacked {        return insert_index_;      } -    void LoadedBinary() {} -    protected:      static std::size_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits); @@ -83,30 +81,30 @@ class BitPacked {      uint64_t insert_index_, max_vocab_;  }; -template <class Quant> class BitPackedMiddle : public BitPacked { +template <class Quant, class Bhiksha> class BitPackedMiddle : public BitPacked {    public: -    static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next); +    static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config);      // next_source need not be initialized.   
-    BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source); +    BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config);      void Insert(WordIndex word, float prob, float backoff); +    void FinishedLoading(uint64_t next_end, const Config &config); + +    void LoadedBinary() { bhiksha_.LoadedBinary(); } +      bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const;      bool FindNoProb(WordIndex word, float &backoff, NodeRange &range) const; -    void FinishedLoading(uint64_t next_end); -    private:      Quant quant_; -    uint8_t next_bits_; -    uint64_t next_mask_; +    Bhiksha bhiksha_;      const BitPacked *next_source_;  }; -  template <class Quant> class BitPackedLongest : public BitPacked {    public:      static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) { @@ -120,6 +118,8 @@ template <class Quant> class BitPackedLongest : public BitPacked {        BaseInit(base, max_vocab, quant_.TotalBits());      } +    void LoadedBinary() {} +      void Insert(WordIndex word, float prob);      bool Find(WordIndex word, float &prob, const NodeRange &node) const; diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 7defd5c1..04979d51 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -37,14 +37,14 @@ WordIndex ReadWords(int fd, EnumerateVocab *enumerate) {    WordIndex index = 0;    while (true) {      ssize_t got = read(fd, &buf[0], kInitialRead); -    if (got == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words"); +    UTIL_THROW_IF(got == -1, util::ErrnoException, "Reading vocabulary words");      if (got == 0) return index;      buf.resize(got);      while (buf[buf.size() - 1]) {        char next_char;        ssize_t ret = read(fd, &next_char, 1); -      if (ret == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words"); -      if (ret == 0) UTIL_THROW(FormatLoadException, "Missing null terminator on a vocab word."); +      UTIL_THROW_IF(ret == -1, util::ErrnoException, "Reading vocabulary words"); +      UTIL_THROW_IF(ret == 0, FormatLoadException, "Missing null terminator on a vocab word.");        buf.push_back(next_char);      }      // Ok now we have null terminated strings.   diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index c92518e4..9d218fff 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -61,6 +61,7 @@ class SortedVocabulary : public base::Vocabulary {        }      } +    // Size for purposes of file writing      static size_t Size(std::size_t entries, const Config &config);      // Vocab words are [0, Bound())  Only valid after FinishedLoading/LoadedBinary.   @@ -77,6 +78,9 @@ class SortedVocabulary : public base::Vocabulary {      // Reorders reorder_vocab so that the IDs are sorted.        void FinishedLoading(ProbBackoff *reorder_vocab); +    // Trie stores the correct counts including <unk> in the header.  If this was previously sized based on a count exluding <unk>, padding with 8 bytes will make it the correct size based on a count including <unk>. +    std::size_t UnkCountChangePadding() const { return SawUnk() ? 
0 : sizeof(uint64_t); } +      bool SawUnk() const { return saw_unk_; }      void LoadedBinary(int fd, EnumerateVocab *to); diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh index b35d80c8..9f47d559 100644 --- a/klm/util/bit_packing.hh +++ b/klm/util/bit_packing.hh @@ -107,9 +107,20 @@ void BitPackingSanity();  uint8_t RequiredBits(uint64_t max_value);  struct BitsMask { +  static BitsMask ByMax(uint64_t max_value) { +    BitsMask ret; +    ret.FromMax(max_value); +    return ret; +  } +  static BitsMask ByBits(uint8_t bits) { +    BitsMask ret; +    ret.bits = bits; +    ret.mask = (1ULL << bits) - 1; +    return ret; +  }    void FromMax(uint64_t max_value) {      bits = RequiredBits(max_value); -    mask = (1 << bits) - 1; +    mask = (1ULL << bits) - 1;    }    uint8_t bits;    uint64_t mask; diff --git a/klm/util/exception.cc b/klm/util/exception.cc index 84f9fe7c..62280970 100644 --- a/klm/util/exception.cc +++ b/klm/util/exception.cc @@ -1,5 +1,9 @@  #include "util/exception.hh" +#ifdef __GXX_RTTI +#include <typeinfo> +#endif +  #include <errno.h>  #include <string.h> @@ -22,6 +26,30 @@ const char *Exception::what() const throw() {    return text_.c_str();  } +void Exception::SetLocation(const char *file, unsigned int line, const char *func, const char *child_name, const char *condition) { +  /* The child class might have set some text, but we want this to come first. +   * Another option would be passing this information to the constructor, but +   * then child classes would have to accept constructor arguments and pass +   * them down.   +   */ +  text_ = stream_.str(); +  stream_.str(""); +  stream_ << file << ':' << line; +  if (func) stream_ << " in " << func << " threw "; +  if (child_name) { +    stream_ << child_name; +  } else { +#ifdef __GXX_RTTI +    stream_ << typeid(this).name(); +#else +    stream_ << "an exception"; +#endif +  } +  if (condition) stream_ << " because `" << condition; +  stream_ << "'.\n"; +  stream_ << text_; +} +  namespace {  // The XOPEN version.  const char *HandleStrerror(int ret, const char *buf) { diff --git a/klm/util/exception.hh b/klm/util/exception.hh index c6936914..81675a57 100644 --- a/klm/util/exception.hh +++ b/klm/util/exception.hh @@ -1,8 +1,6 @@  #ifndef UTIL_EXCEPTION__  #define UTIL_EXCEPTION__ -#include "util/string_piece.hh" -  #include <exception>  #include <sstream>  #include <string> @@ -22,6 +20,14 @@ class Exception : public std::exception {      // Not threadsafe, but probably doesn't matter.  FWIW, Boost's exception guidance implies that what() isn't threadsafe.        const char *what() const throw(); +    // For use by the UTIL_THROW macros.   
+    void SetLocation( +        const char *file, +        unsigned int line, +        const char *func, +        const char *child_name, +        const char *condition); +    private:      template <class Except, class Data> friend typename Except::template ExceptionTag<Except&>::Identity operator<<(Except &e, const Data &data); @@ -43,7 +49,49 @@ template <class Except, class Data> typename Except::template ExceptionTag<Excep    return e;  } -#define UTIL_THROW(Exception, Modify) { Exception UTIL_e; {UTIL_e << Modify;} throw UTIL_e; } +#ifdef __GNUC__ +#define UTIL_FUNC_NAME __PRETTY_FUNCTION__ +#else +#ifdef _WIN32 +#define UTIL_FUNC_NAME __FUNCTION__ +#else +#define UTIL_FUNC_NAME NULL +#endif +#endif + +#define UTIL_SET_LOCATION(UTIL_e, child, condition) do { \ +  (UTIL_e).SetLocation(__FILE__, __LINE__, UTIL_FUNC_NAME, (child), (condition)); \ +} while (0) + +/* Create an instance of Exception, add the message Modify, and throw it. + * Modify is appended to the what() message and can contain << for ostream + * operations.   + * + * do .. while kludge to swallow trailing ; character + * http://gcc.gnu.org/onlinedocs/cpp/Swallowing-the-Semicolon.html .   + */ +#define UTIL_THROW(Exception, Modify) do { \ +  Exception UTIL_e; \ +  UTIL_SET_LOCATION(UTIL_e, #Exception, NULL); \ +  UTIL_e << Modify; \ +  throw UTIL_e; \ +} while (0) + +#define UTIL_THROW_VAR(Var, Modify) do { \ +  Exception &UTIL_e = (Var); \ +  UTIL_SET_LOCATION(UTIL_e, NULL, NULL); \ +  UTIL_e << Modify; \ +  throw UTIL_e; \ +} while (0) + +#define UTIL_THROW_IF(Condition, Exception, Modify) do { \ +  if (Condition) { \ +    Exception UTIL_e; \ +    UTIL_SET_LOCATION(UTIL_e, #Exception, #Condition); \ +    UTIL_e << Modify; \ +    throw UTIL_e; \ +  } \ +} while (0)  class ErrnoException : public Exception {    public: @@ -51,7 +99,7 @@ class ErrnoException : public Exception {      virtual ~ErrnoException() throw(); -    int Error() { return errno_; } +    int Error() const throw() { return errno_; }    private:      int errno_; diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc index f447a70c..cbe4234f 100644 --- a/klm/util/file_piece.cc +++ b/klm/util/file_piece.cc @@ -41,8 +41,8 @@ GZException::GZException(void *file) {  const bool kSpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};  int OpenReadOrThrow(const char *name) { -  int ret = open(name, O_RDONLY); -  if (ret == -1) UTIL_THROW(ErrnoException, "in open (" << name << ") for reading"); +  int ret; +  UTIL_THROW_IF(-1 == (ret = open(name, O_RDONLY)), ErrnoException, "while opening " << name);    return ret;  } @@ -52,13 +52,13 @@ off_t SizeFile(int fd) {    return sb.st_size;  } -FilePiece::FilePiece(const char *name, std::ostream *show_progress, off_t min_buffer) throw (GZException) :  +FilePiece::FilePiece(const char *name, std::ostream *show_progress, off_t min_buffer) :     file_(OpenReadOrThrow(name)), total_size_(SizeFile(file_.get())), page_(sysconf(_SC_PAGE_SIZE)),    progress_(total_size_ == kBadSize ? 
NULL : show_progress, std::string("Reading ") + name, total_size_) {    Initialize(name, show_progress, min_buffer);  } -FilePiece::FilePiece(int fd, const char *name, std::ostream *show_progress, off_t min_buffer) throw (GZException) :  +FilePiece::FilePiece(int fd, const char *name, std::ostream *show_progress, off_t min_buffer)  :     file_(fd), total_size_(SizeFile(file_.get())), page_(sysconf(_SC_PAGE_SIZE)),    progress_(total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name, total_size_) {    Initialize(name, show_progress, min_buffer); @@ -78,7 +78,7 @@ FilePiece::~FilePiece() {  #endif  } -StringPiece FilePiece::ReadLine(char delim) throw (GZException, EndOfFileException) { +StringPiece FilePiece::ReadLine(char delim) {    size_t skip = 0;    while (true) {      for (const char *i = position_ + skip; i < position_end_; ++i) { @@ -97,20 +97,20 @@ StringPiece FilePiece::ReadLine(char delim) throw (GZException, EndOfFileExcepti    }  } -float FilePiece::ReadFloat() throw(GZException, EndOfFileException, ParseNumberException) { +float FilePiece::ReadFloat() {    return ReadNumber<float>();  } -double FilePiece::ReadDouble() throw(GZException, EndOfFileException, ParseNumberException) { +double FilePiece::ReadDouble() {    return ReadNumber<double>();  } -long int FilePiece::ReadLong() throw(GZException, EndOfFileException, ParseNumberException) { +long int FilePiece::ReadLong() {    return ReadNumber<long int>();  } -unsigned long int FilePiece::ReadULong() throw(GZException, EndOfFileException, ParseNumberException) { +unsigned long int FilePiece::ReadULong() {    return ReadNumber<unsigned long int>();  } -void FilePiece::Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) throw (GZException) { +void FilePiece::Initialize(const char *name, std::ostream *show_progress, off_t min_buffer)  {  #ifdef HAVE_ZLIB    gz_file_ = NULL;  #endif @@ -163,7 +163,7 @@ void ParseNumber(const char *begin, char *&end, unsigned long int &out) {  }  } // namespace -template <class T> T FilePiece::ReadNumber() throw(GZException, EndOfFileException, ParseNumberException) { +template <class T> T FilePiece::ReadNumber() {    SkipSpaces();    while (last_space_ < position_) {      if (at_end_) { @@ -186,7 +186,7 @@ template <class T> T FilePiece::ReadNumber() throw(GZException, EndOfFileExcepti    return ret;  } -const char *FilePiece::FindDelimiterOrEOF(const bool *delim) throw (GZException, EndOfFileException) { +const char *FilePiece::FindDelimiterOrEOF(const bool *delim)  {    size_t skip = 0;    while (true) {      for (const char *i = position_ + skip; i < position_end_; ++i) { @@ -201,7 +201,7 @@ const char *FilePiece::FindDelimiterOrEOF(const bool *delim) throw (GZException,    }  } -void FilePiece::Shift() throw(GZException, EndOfFileException) { +void FilePiece::Shift() {    if (at_end_) {      progress_.Finished();      throw EndOfFileException(); @@ -217,7 +217,7 @@ void FilePiece::Shift() throw(GZException, EndOfFileException) {    }  } -void FilePiece::MMapShift(off_t desired_begin) throw() { +void FilePiece::MMapShift(off_t desired_begin) {    // Use mmap.      off_t ignore = desired_begin % page_;    // Duplicate request for Shift means give more data.   
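Note: the file_piece.cc hunks here, like the vocab.cc ones earlier, replace the old "if (cond) UTIL_THROW(...)" call sites with the UTIL_THROW_IF macro added to util/exception.hh above, which stringizes the condition and records file/line/function through SetLocation before the streamed message is appended. A minimal standalone sketch of that pattern follows; the helper name and message text are illustrative, not part of this patch:

#include "util/exception.hh"
#include <fcntl.h>

// Hypothetical helper mirroring the OpenReadOrThrow conversion in this patch.
// Old style: int fd = open(...); if (fd == -1) UTIL_THROW(ErrnoException, ...);
int OpenForReading(const char *name) {
  int fd;
  // On failure, UTIL_THROW_IF constructs the exception, stamps it with the
  // throw location and the stringized condition, then appends the message.
  UTIL_THROW_IF(-1 == (fd = open(name, O_RDONLY)), util::ErrnoException,
                "while opening " << name);
  return fd;
}

The do { ... } while (0) wrapper in the macro definition exists only so that a trailing semicolon parses cleanly inside if/else bodies, as the comment in exception.hh notes.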
@@ -259,25 +259,23 @@ void FilePiece::MMapShift(off_t desired_begin) throw() {    progress_.Set(desired_begin);  } -void FilePiece::TransitionToRead() throw (GZException) { +void FilePiece::TransitionToRead() {    assert(!fallback_to_read_);    fallback_to_read_ = true;    data_.reset();    data_.reset(malloc(default_map_size_), default_map_size_, scoped_memory::MALLOC_ALLOCATED); -  if (!data_.get()) UTIL_THROW(ErrnoException, "malloc failed for " << default_map_size_); +  UTIL_THROW_IF(!data_.get(), ErrnoException, "malloc failed for " << default_map_size_);    position_ = data_.begin();    position_end_ = position_;  #ifdef HAVE_ZLIB    assert(!gz_file_);    gz_file_ = gzdopen(file_.get(), "r"); -  if (!gz_file_) { -    UTIL_THROW(GZException, "zlib failed to open " << file_name_); -  } +  UTIL_THROW_IF(!gz_file_, GZException, "zlib failed to open " << file_name_);  #endif  } -void FilePiece::ReadShift() throw(GZException, EndOfFileException) { +void FilePiece::ReadShift() {    assert(fallback_to_read_);    // Bytes [data_.begin(), position_) have been consumed.      // Bytes [position_, position_end_) have been read into the buffer.   @@ -297,7 +295,7 @@ void FilePiece::ReadShift() throw(GZException, EndOfFileException) {        std::size_t valid_length = position_end_ - position_;        default_map_size_ *= 2;        data_.call_realloc(default_map_size_); -      if (!data_.get()) UTIL_THROW(ErrnoException, "realloc failed for " << default_map_size_); +      UTIL_THROW_IF(!data_.get(), ErrnoException, "realloc failed for " << default_map_size_);        position_ = data_.begin();        position_end_ = position_ + valid_length;      } else { @@ -320,7 +318,7 @@ void FilePiece::ReadShift() throw(GZException, EndOfFileException) {    }  #else    read_return = read(file_.get(), static_cast<char*>(data_.get()) + already_read, default_map_size_ - already_read); -  if (read_return == -1) UTIL_THROW(ErrnoException, "read failed"); +  UTIL_THROW_IF(read_return == -1, ErrnoException, "read failed");    progress_.Set(mapped_offset_);  #endif    if (read_return == 0) { diff --git a/klm/util/file_piece.hh b/klm/util/file_piece.hh index 870ae5a3..a5c00910 100644 --- a/klm/util/file_piece.hh +++ b/klm/util/file_piece.hh @@ -45,13 +45,13 @@ off_t SizeFile(int fd);  class FilePiece {    public:      // 32 MB default. -    explicit FilePiece(const char *file, std::ostream *show_progress = NULL, off_t min_buffer = 33554432) throw(GZException); +    explicit FilePiece(const char *file, std::ostream *show_progress = NULL, off_t min_buffer = 33554432);      // Takes ownership of fd.  name is used for messages.   -    explicit FilePiece(int fd, const char *name, std::ostream *show_progress = NULL, off_t min_buffer = 33554432) throw(GZException); +    explicit FilePiece(int fd, const char *name, std::ostream *show_progress = NULL, off_t min_buffer = 33554432);      ~FilePiece(); -    char get() throw(GZException, EndOfFileException) {  +    char get() {         if (position_ == position_end_) {          Shift();          if (at_end_) throw EndOfFileException(); @@ -60,22 +60,22 @@ class FilePiece {      }      // Leaves the delimiter, if any, to be returned by get().  Delimiters defined by isspace().   
-    StringPiece ReadDelimited(const bool *delim = kSpaces) throw(GZException, EndOfFileException) { +    StringPiece ReadDelimited(const bool *delim = kSpaces) {        SkipSpaces(delim);        return Consume(FindDelimiterOrEOF(delim));      }      // Unlike ReadDelimited, this includes leading spaces and consumes the delimiter.      // It is similar to getline in that way.   -    StringPiece ReadLine(char delim = '\n') throw(GZException, EndOfFileException); +    StringPiece ReadLine(char delim = '\n'); -    float ReadFloat() throw(GZException, EndOfFileException, ParseNumberException); -    double ReadDouble() throw(GZException, EndOfFileException, ParseNumberException); -    long int ReadLong() throw(GZException, EndOfFileException, ParseNumberException); -    unsigned long int ReadULong() throw(GZException, EndOfFileException, ParseNumberException); +    float ReadFloat(); +    double ReadDouble(); +    long int ReadLong(); +    unsigned long int ReadULong();      // Skip spaces defined by isspace.   -    void SkipSpaces(const bool *delim = kSpaces) throw (GZException, EndOfFileException) { +    void SkipSpaces(const bool *delim = kSpaces) {        for (; ; ++position_) {          if (position_ == position_end_) Shift();          if (!delim[static_cast<unsigned char>(*position_)]) return; @@ -89,9 +89,9 @@ class FilePiece {      const std::string &FileName() const { return file_name_; }    private: -    void Initialize(const char *name, std::ostream *show_progress, off_t min_buffer) throw(GZException); +    void Initialize(const char *name, std::ostream *show_progress, off_t min_buffer); -    template <class T> T ReadNumber() throw(GZException, EndOfFileException, ParseNumberException); +    template <class T> T ReadNumber();      StringPiece Consume(const char *to) {        StringPiece ret(position_, to - position_); @@ -99,14 +99,14 @@ class FilePiece {        return ret;      } -    const char *FindDelimiterOrEOF(const bool *delim = kSpaces) throw (GZException, EndOfFileException); +    const char *FindDelimiterOrEOF(const bool *delim = kSpaces); -    void Shift() throw (EndOfFileException, GZException); +    void Shift();      // Backends to Shift(). -    void MMapShift(off_t desired_begin) throw (); +    void MMapShift(off_t desired_begin); -    void TransitionToRead() throw (GZException); -    void ReadShift() throw (GZException, EndOfFileException); +    void TransitionToRead(); +    void ReadShift();      const char *position_, *last_space_, *position_end_; diff --git a/klm/util/murmur_hash.cc b/klm/util/murmur_hash.cc index d58a0727..fec47fd9 100644 --- a/klm/util/murmur_hash.cc +++ b/klm/util/murmur_hash.cc @@ -1,129 +1,129 @@ -/* Downloaded from http://sites.google.com/site/murmurhash/ which says "All
 - * code is released to the public domain. For business purposes, Murmurhash is
 - * under the MIT license."
 - * This is modified from the original:
 - * ULL tag on 0xc6a4a7935bd1e995 so this will compile on 32-bit.  
 - * length changed to unsigned int.  
 - * placed in namespace util
 - * add MurmurHashNative
 - * default option = 0 for seed
 - */
 -
 -#include "util/murmur_hash.hh"
 -
 -namespace util {
 -
 -//-----------------------------------------------------------------------------
 -// MurmurHash2, 64-bit versions, by Austin Appleby
 -
 -// The same caveats as 32-bit MurmurHash2 apply here - beware of alignment 
 -// and endian-ness issues if used across multiple platforms.
 -
 -// 64-bit hash for 64-bit platforms
 -
 -uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed )
 -{
 -  const uint64_t m = 0xc6a4a7935bd1e995ULL;
 -  const int r = 47;
 -
 -  uint64_t h = seed ^ (len * m);
 -
 -  const uint64_t * data = (const uint64_t *)key;
 -  const uint64_t * end = data + (len/8);
 -
 -  while(data != end)
 -  {
 -    uint64_t k = *data++;
 -
 -    k *= m; 
 -    k ^= k >> r; 
 -    k *= m; 
 -    
 -    h ^= k;
 -    h *= m; 
 -  }
 -
 -  const unsigned char * data2 = (const unsigned char*)data;
 -
 -  switch(len & 7)
 -  {
 -  case 7: h ^= uint64_t(data2[6]) << 48;
 -  case 6: h ^= uint64_t(data2[5]) << 40;
 -  case 5: h ^= uint64_t(data2[4]) << 32;
 -  case 4: h ^= uint64_t(data2[3]) << 24;
 -  case 3: h ^= uint64_t(data2[2]) << 16;
 -  case 2: h ^= uint64_t(data2[1]) << 8;
 -  case 1: h ^= uint64_t(data2[0]);
 -          h *= m;
 -  };
 - 
 -  h ^= h >> r;
 -  h *= m;
 -  h ^= h >> r;
 -
 -  return h;
 -} 
 -
 -
 -// 64-bit hash for 32-bit platforms
 -
 -uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed )
 -{
 -  const unsigned int m = 0x5bd1e995;
 -  const int r = 24;
 -
 -  unsigned int h1 = seed ^ len;
 -  unsigned int h2 = 0;
 -
 -  const unsigned int * data = (const unsigned int *)key;
 -
 -  while(len >= 8)
 -  {
 -    unsigned int k1 = *data++;
 -    k1 *= m; k1 ^= k1 >> r; k1 *= m;
 -    h1 *= m; h1 ^= k1;
 -    len -= 4;
 -
 -    unsigned int k2 = *data++;
 -    k2 *= m; k2 ^= k2 >> r; k2 *= m;
 -    h2 *= m; h2 ^= k2;
 -    len -= 4;
 -  }
 -
 -  if(len >= 4)
 -  {
 -    unsigned int k1 = *data++;
 -    k1 *= m; k1 ^= k1 >> r; k1 *= m;
 -    h1 *= m; h1 ^= k1;
 -    len -= 4;
 -  }
 -
 -  switch(len)
 -  {
 -  case 3: h2 ^= ((unsigned char*)data)[2] << 16;
 -  case 2: h2 ^= ((unsigned char*)data)[1] << 8;
 -  case 1: h2 ^= ((unsigned char*)data)[0];
 -      h2 *= m;
 -  };
 -
 -  h1 ^= h2 >> 18; h1 *= m;
 -  h2 ^= h1 >> 22; h2 *= m;
 -  h1 ^= h2 >> 17; h1 *= m;
 -  h2 ^= h1 >> 19; h2 *= m;
 -
 -  uint64_t h = h1;
 -
 -  h = (h << 32) | h2;
 -
 -  return h;
 -}
 -
 -uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) {
 -  if (sizeof(int) == 4) {
 -    return MurmurHash64B(key, len, seed);
 -  } else {
 -    return MurmurHash64A(key, len, seed);
 -  }
 -}
 -
 -} // namespace util
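Note: the murmur_hash.cc hunk re-adds the removed lines above essentially verbatim, so it appears to be a whitespace or line-ending normalization rather than a behavioral change. For reference, a small usage sketch of the interface this file provides; the input string is illustrative:

#include "util/murmur_hash.hh"
#include <cstdint>
#include <iostream>
#include <string>

int main() {
  std::string key = "example n-gram";  // illustrative input
  // MurmurHashNative dispatches between the 64-bit-platform and
  // 32-bit-platform variants based on sizeof(int); seed 0 is the default.
  uint64_t h = util::MurmurHashNative(key.data(), key.size(), 0);
  std::cout << std::hex << h << "\n";
  return 0;
}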
 +/* Downloaded from http://sites.google.com/site/murmurhash/ which says "All + * code is released to the public domain. For business purposes, Murmurhash is + * under the MIT license." + * This is modified from the original: + * ULL tag on 0xc6a4a7935bd1e995 so this will compile on 32-bit.   + * length changed to unsigned int.   + * placed in namespace util + * add MurmurHashNative + * default option = 0 for seed + */ + +#include "util/murmur_hash.hh" + +namespace util { + +//----------------------------------------------------------------------------- +// MurmurHash2, 64-bit versions, by Austin Appleby + +// The same caveats as 32-bit MurmurHash2 apply here - beware of alignment  +// and endian-ness issues if used across multiple platforms. + +// 64-bit hash for 64-bit platforms + +uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed ) +{ +  const uint64_t m = 0xc6a4a7935bd1e995ULL; +  const int r = 47; + +  uint64_t h = seed ^ (len * m); + +  const uint64_t * data = (const uint64_t *)key; +  const uint64_t * end = data + (len/8); + +  while(data != end) +  { +    uint64_t k = *data++; + +    k *= m;  +    k ^= k >> r;  +    k *= m;  +     +    h ^= k; +    h *= m;  +  } + +  const unsigned char * data2 = (const unsigned char*)data; + +  switch(len & 7) +  { +  case 7: h ^= uint64_t(data2[6]) << 48; +  case 6: h ^= uint64_t(data2[5]) << 40; +  case 5: h ^= uint64_t(data2[4]) << 32; +  case 4: h ^= uint64_t(data2[3]) << 24; +  case 3: h ^= uint64_t(data2[2]) << 16; +  case 2: h ^= uint64_t(data2[1]) << 8; +  case 1: h ^= uint64_t(data2[0]); +          h *= m; +  }; +  +  h ^= h >> r; +  h *= m; +  h ^= h >> r; + +  return h; +}  + + +// 64-bit hash for 32-bit platforms + +uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed ) +{ +  const unsigned int m = 0x5bd1e995; +  const int r = 24; + +  unsigned int h1 = seed ^ len; +  unsigned int h2 = 0; + +  const unsigned int * data = (const unsigned int *)key; + +  while(len >= 8) +  { +    unsigned int k1 = *data++; +    k1 *= m; k1 ^= k1 >> r; k1 *= m; +    h1 *= m; h1 ^= k1; +    len -= 4; + +    unsigned int k2 = *data++; +    k2 *= m; k2 ^= k2 >> r; k2 *= m; +    h2 *= m; h2 ^= k2; +    len -= 4; +  } + +  if(len >= 4) +  { +    unsigned int k1 = *data++; +    k1 *= m; k1 ^= k1 >> r; k1 *= m; +    h1 *= m; h1 ^= k1; +    len -= 4; +  } + +  switch(len) +  { +  case 3: h2 ^= ((unsigned char*)data)[2] << 16; +  case 2: h2 ^= ((unsigned char*)data)[1] << 8; +  case 1: h2 ^= ((unsigned char*)data)[0]; +      h2 *= m; +  }; + +  h1 ^= h2 >> 18; h1 *= m; +  h2 ^= h1 >> 22; h2 *= m; +  h1 ^= h2 >> 17; h1 *= m; +  h2 ^= h1 >> 19; h2 *= m; + +  uint64_t h = h1; + +  h = (h << 32) | h2; + +  return h; +} + +uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) { +  if (sizeof(int) == 4) { +    return MurmurHash64B(key, len, seed); +  } else { +    return MurmurHash64A(key, len, seed); +  } +} + +} // namespace util diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh index 00be0ed7..2ec342a6 100644 --- a/klm/util/probing_hash_table.hh +++ b/klm/util/probing_hash_table.hh @@ -57,7 +57,7 @@ template <class PackingT, class HashT, class EqualT = std::equal_to<typename Pac          equal_(equal_func),          entries_(0)  #ifdef DEBUG -        , initialized_(true), +        , initialized_(true)  #endif      {} diff --git a/klm/util/sorted_uniform.hh b/klm/util/sorted_uniform.hh index 84d7aa02..0d6ecbbd 100644 --- a/klm/util/sorted_uniform.hh +++ 
b/klm/util/sorted_uniform.hh
@@ -12,7 +12,7 @@ namespace util {
 template <class T> class IdentityAccessor {
   public:
     typedef T Key;
-    T operator()(const uint64_t *in) const { return *in; }
+    T operator()(const T *in) const { return *in; }
 };
 
 struct Pivot64 {
@@ -101,6 +101,27 @@ template <class Iterator, class Accessor, class Pivot> bool SortedUniformFind(co
   return BoundedSortedUniformFind<Iterator, Accessor, Pivot>(accessor, begin, below, end, above, key, out);
 }
 
+// May return begin - 1.
+template <class Iterator, class Accessor> Iterator BinaryBelow(
+    const Accessor &accessor,
+    Iterator begin,
+    Iterator end,
+    const typename Accessor::Key key) {
+  while (end > begin) {
+    Iterator pivot(begin + (end - begin) / 2);
+    typename Accessor::Key mid(accessor(pivot));
+    if (mid < key) {
+      begin = pivot + 1;
+    } else if (mid > key) {
+      end = pivot;
+    } else {
+      for (++pivot; (pivot < end) && accessor(pivot) == mid; ++pivot) {}
+      return pivot - 1;
+    }
+  }
+  return begin - 1;
+}
+
 // To use this template, you need to define a Pivot function to match Key.
 template <class PackingT> class SortedUniformMap {
   public:
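Note: the new BinaryBelow helper above returns the position of the last element less than or equal to the key (the last of an equal run), or begin - 1 when every element is larger; it appears to have been added to support the ArrayBhiksha offset-table lookup. A small sketch over a plain array, using IdentityAccessor from the same header; the values are illustrative:

#include "util/sorted_uniform.hh"
#include <cassert>
#include <cstdint>

int main() {
  uint64_t keys[] = {2, 4, 4, 9};
  util::IdentityAccessor<uint64_t> acc;
  // Key present: returns the last element of the equal run.
  assert(util::BinaryBelow(acc, keys, keys + 4, static_cast<uint64_t>(4)) == keys + 2);
  // Key absent: returns the largest element below it.
  assert(util::BinaryBelow(acc, keys, keys + 4, static_cast<uint64_t>(7)) == keys + 2);
  // Key below everything: returns begin - 1, as the header comment warns.
  assert(util::BinaryBelow(acc, keys, keys + 4, static_cast<uint64_t>(1)) == keys - 1);
  return 0;
}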

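Note: to summarize the Bhiksha plumbing threaded through trie.hh and trie.cc above: the next pointers written per trie entry are non-decreasing, so the array variant keeps only their low bits inline and records, in a small offset table, the first entry index at which each high-bits value appears; reading an entry recovers the full pointer by combining the inline bits with a BinaryBelow-style search of that table. A rough, self-contained model of the idea follows; it is not the klm API, and all names and values below are illustrative:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Toy model only: stores low bits per entry plus a table of bucket starts.
struct ToyBhiksha {
  unsigned low_bits;
  std::vector<uint32_t> low;          // inline part, one per entry
  std::vector<std::size_t> first_at;  // first_at[h] = first entry whose pointer has high part >= h

  void Build(const std::vector<uint64_t> &next, unsigned bits) {
    low_bits = bits;
    low.resize(next.size());
    uint64_t max_high = next.back() >> low_bits;  // next is assumed non-empty and sorted
    first_at.assign(static_cast<std::size_t>(max_high) + 2, next.size());
    for (std::size_t i = next.size(); i-- > 0;) {
      low[i] = static_cast<uint32_t>(next[i] & ((1ULL << low_bits) - 1));
      // Walking backwards leaves each bucket holding its smallest entry index.
      for (uint64_t h = 0; h <= (next[i] >> low_bits); ++h) first_at[h] = i;
    }
  }

  uint64_t Read(std::size_t i) const {
    // High part = largest h with first_at[h] <= i; the real code does a
    // bounded binary search like BinaryBelow here.
    std::size_t h = std::upper_bound(first_at.begin(), first_at.end(), i) - first_at.begin() - 1;
    return (static_cast<uint64_t>(h) << low_bits) | low[i];
  }
};

int main() {
  std::vector<uint64_t> next = {3, 5, 9, 12};  // illustrative pointer values
  ToyBhiksha t;
  t.Build(next, 2);  // keep 2 bits inline, move the rest into the table
  for (std::size_t i = 0; i < next.size(); ++i) assert(t.Read(i) == next[i]);
  return 0;
}

Reading an entry then costs one extra bounded search over the table, in exchange for storing fewer bits per entry inline.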