From 7607b0a7873f52d6e3ea387bf88c773cbb55f8ee Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Thu, 18 Aug 2011 12:14:01 +0100 Subject: KenLM update: Bhiksha's trick, simple test for lms without unk, auto-detect binary files instead of requiring them to be specified at runtime. --- decoder/cdec_ff.cc | 5 +- decoder/ff_klm.cc | 38 +++++- decoder/ff_klm.h | 7 +- klm/compile.sh | 2 +- klm/lm/Makefile.am | 1 + klm/lm/bhiksha.cc | 93 +++++++++++++++ klm/lm/bhiksha.hh | 108 +++++++++++++++++ klm/lm/binary_format.cc | 13 ++- klm/lm/binary_format.hh | 9 +- klm/lm/build_binary.cc | 54 ++++++--- klm/lm/config.cc | 1 + klm/lm/config.hh | 5 +- klm/lm/model.cc | 67 ++++++----- klm/lm/model.hh | 12 +- klm/lm/model_test.cc | 73 ++++++++++-- klm/lm/ngram_query.cc | 9 ++ klm/lm/quantize.cc | 1 + klm/lm/quantize.hh | 4 +- klm/lm/read_arpa.cc | 6 +- klm/lm/search_hashed.cc | 2 +- klm/lm/search_hashed.hh | 3 +- klm/lm/search_trie.cc | 45 +++---- klm/lm/search_trie.hh | 20 ++-- klm/lm/test_nounk.arpa | 120 +++++++++++++++++++ klm/lm/trie.cc | 57 ++++----- klm/lm/trie.hh | 24 ++-- klm/lm/vocab.cc | 6 +- klm/lm/vocab.hh | 4 + klm/util/bit_packing.hh | 13 ++- klm/util/murmur_hash.cc | 258 ++++++++++++++++++++--------------------- klm/util/probing_hash_table.hh | 2 +- klm/util/sorted_uniform.hh | 23 +++- 32 files changed, 792 insertions(+), 293 deletions(-) create mode 100644 klm/lm/bhiksha.cc create mode 100644 klm/lm/bhiksha.hh create mode 100644 klm/lm/test_nounk.arpa diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 3451c9fb..1ef76a05 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -55,10 +55,7 @@ void register_feature_functions() { ff_registry.Register("NgramFeatures", new FFFactory()); ff_registry.Register("RuleNgramFeatures", new FFFactory()); ff_registry.Register("CMR2008ReorderingFeatures", new FFFactory()); - ff_registry.Register("KLanguageModel", new FFFactory >()); - ff_registry.Register("KLanguageModel_Trie", new FFFactory >()); - ff_registry.Register("KLanguageModel_QuantTrie", new FFFactory >()); - ff_registry.Register("KLanguageModel_Probing", new FFFactory >()); + ff_registry.Register("KLanguageModel", new KLanguageModelFactory()); ff_registry.Register("NonLatinCount", new FFFactory); ff_registry.Register("RuleShape", new FFFactory); ff_registry.Register("RelativeSentencePosition", new FFFactory); diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 9b7fe2d3..24dcb9c3 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -9,6 +9,7 @@ #include "stringlib.h" #include "hg.h" #include "tdict.h" +#include "lm/model.hh" #include "lm/enumerate_vocab.hh" using namespace std; @@ -434,8 +435,37 @@ void KLanguageModel::FinalTraversalFeatures(const void* ant_state, features->set_value(oov_fid_, oovs); } -// instantiate templates -template class KLanguageModel; -template class KLanguageModel; -template class KLanguageModel; +template boost::shared_ptr CreateModel(const std::string ¶m) { + KLanguageModel *ret = new KLanguageModel(param); + ret->Init(); + return boost::shared_ptr(ret); +} +boost::shared_ptr KLanguageModelFactory::Create(std::string param) const { + using namespace lm::ngram; + std::string filename, ignored_map; + bool ignored_markers; + std::string ignored_featname; + ParseLMArgs(param, &filename, &ignored_map, &ignored_markers, &ignored_featname); + ModelType m; + if (!RecognizeBinary(filename.c_str(), m)) m = HASH_PROBING; + + switch (m) { + case HASH_PROBING: + return CreateModel(param); + case TRIE_SORTED: + return CreateModel(param); + case ARRAY_TRIE_SORTED: + return CreateModel(param); + case QUANT_TRIE_SORTED: + return CreateModel(param); + case QUANT_ARRAY_TRIE_SORTED: + return CreateModel(param); + default: + UTIL_THROW(util::Exception, "Unrecognized kenlm binary file type " << (unsigned)m); + } +} + +std::string KLanguageModelFactory::usage(bool params,bool verbose) const { + return KLanguageModel::usage(params, verbose); +} diff --git a/decoder/ff_klm.h b/decoder/ff_klm.h index 5eafe8be..6efe50f6 100644 --- a/decoder/ff_klm.h +++ b/decoder/ff_klm.h @@ -4,8 +4,8 @@ #include #include +#include "ff_factory.h" #include "ff.h" -#include "lm/model.hh" template struct KLanguageModelImpl; @@ -34,4 +34,9 @@ class KLanguageModel : public FeatureFunction { KLanguageModelImpl* pimpl_; }; +struct KLanguageModelFactory : public FactoryBase { + FP Create(std::string param) const; + std::string usage(bool params,bool verbose) const; +}; + #endif diff --git a/klm/compile.sh b/klm/compile.sh index 6ca85e1f..abe3473a 100755 --- a/klm/compile.sh +++ b/klm/compile.sh @@ -5,7 +5,7 @@ set -e -for i in util/{bit_packing,ersatz_progress,exception,file_piece,murmur_hash,scoped,mmap} lm/{binary_format,config,lm_exception,model,quantize,read_arpa,search_hashed,search_trie,trie,virtual_interface,vocab}; do +for i in util/{bit_packing,ersatz_progress,exception,file_piece,murmur_hash,scoped,mmap} lm/{bhiksha,binary_format,config,lm_exception,model,quantize,read_arpa,search_hashed,search_trie,trie,virtual_interface,vocab}; do g++ -I. -O3 $CXXFLAGS -c $i.cc -o $i.o done g++ -I. -O3 $CXXFLAGS lm/build_binary.cc {lm,util}/*.o -lz -o build_binary diff --git a/klm/lm/Makefile.am b/klm/lm/Makefile.am index 395494bc..fae6b41a 100644 --- a/klm/lm/Makefile.am +++ b/klm/lm/Makefile.am @@ -12,6 +12,7 @@ build_binary_LDADD = libklm.a ../util/libklm_util.a -lz noinst_LIBRARIES = libklm.a libklm_a_SOURCES = \ + bhiksha.cc \ binary_format.cc \ config.cc \ lm_exception.cc \ diff --git a/klm/lm/bhiksha.cc b/klm/lm/bhiksha.cc new file mode 100644 index 00000000..bf86fd4b --- /dev/null +++ b/klm/lm/bhiksha.cc @@ -0,0 +1,93 @@ +#include "lm/bhiksha.hh" +#include "lm/config.hh" + +#include + +namespace lm { +namespace ngram { +namespace trie { + +DontBhiksha::DontBhiksha(const void * /*base*/, uint64_t /*max_offset*/, uint64_t max_next, const Config &/*config*/) : + next_(util::BitsMask::ByMax(max_next)) {} + +const uint8_t kArrayBhikshaVersion = 0; + +void ArrayBhiksha::UpdateConfigFromBinary(int fd, Config &config) { + uint8_t version; + uint8_t configured_bits; + if (read(fd, &version, 1) != 1 || read(fd, &configured_bits, 1) != 1) { + UTIL_THROW(util::ErrnoException, "Could not read from binary file"); + } + if (version != kArrayBhikshaVersion) UTIL_THROW(FormatLoadException, "This file has sorted array compression version " << (unsigned) version << " but the code expects version " << (unsigned)kArrayBhikshaVersion); + config.pointer_bhiksha_bits = configured_bits; +} + +namespace { + +// Find argmin_{chopped \in [0, RequiredBits(max_next)]} ChoppedDelta(max_offset) +uint8_t ChopBits(uint64_t max_offset, uint64_t max_next, const Config &config) { + uint8_t required = util::RequiredBits(max_next); + uint8_t best_chop = 0; + int64_t lowest_change = std::numeric_limits::max(); + // There are probably faster ways but I don't care because this is only done once per order at construction time. + for (uint8_t chop = 0; chop <= std::min(required, config.pointer_bhiksha_bits); ++chop) { + int64_t change = (max_next >> (required - chop)) * 64 /* table cost in bits */ + - max_offset * static_cast(chop); /* savings in bits*/ + if (change < lowest_change) { + lowest_change = change; + best_chop = chop; + } + } + return best_chop; +} + +std::size_t ArrayCount(uint64_t max_offset, uint64_t max_next, const Config &config) { + uint8_t required = util::RequiredBits(max_next); + uint8_t chopping = ChopBits(max_offset, max_next, config); + return (max_next >> (required - chopping)) + 1 /* we store 0 too */; +} +} // namespace + +std::size_t ArrayBhiksha::Size(uint64_t max_offset, uint64_t max_next, const Config &config) { + return sizeof(uint64_t) * (1 /* header */ + ArrayCount(max_offset, max_next, config)) + 7 /* 8-byte alignment */; +} + +uint8_t ArrayBhiksha::InlineBits(uint64_t max_offset, uint64_t max_next, const Config &config) { + return util::RequiredBits(max_next) - ChopBits(max_offset, max_next, config); +} + +namespace { + +void *AlignTo8(void *from) { + uint8_t *val = reinterpret_cast(from); + std::size_t remainder = reinterpret_cast(val) & 7; + if (!remainder) return val; + return val + 8 - remainder; +} + +} // namespace + +ArrayBhiksha::ArrayBhiksha(void *base, uint64_t max_offset, uint64_t max_next, const Config &config) + : next_inline_(util::BitsMask::ByBits(InlineBits(max_offset, max_next, config))), + offset_begin_(reinterpret_cast(AlignTo8(base)) + 1 /* 8-byte header */), + offset_end_(offset_begin_ + ArrayCount(max_offset, max_next, config)), + write_to_(reinterpret_cast(AlignTo8(base)) + 1 /* 8-byte header */ + 1 /* first entry is 0 */), + original_base_(base) {} + +void ArrayBhiksha::FinishedLoading(const Config &config) { + // *offset_begin_ = 0 but without a const_cast. + *(write_to_ - (write_to_ - offset_begin_)) = 0; + + if (write_to_ != offset_end_) UTIL_THROW(util::Exception, "Did not get all the array entries that were expected."); + + uint8_t *head_write = reinterpret_cast(original_base_); + *(head_write++) = kArrayBhikshaVersion; + *(head_write++) = config.pointer_bhiksha_bits; +} + +void ArrayBhiksha::LoadedBinary() { +} + +} // namespace trie +} // namespace ngram +} // namespace lm diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh new file mode 100644 index 00000000..cfb2b053 --- /dev/null +++ b/klm/lm/bhiksha.hh @@ -0,0 +1,108 @@ +/* Simple implementation of + * @inproceedings{bhikshacompression, + * author={Bhiksha Raj and Ed Whittaker}, + * year={2003}, + * title={Lossless Compression of Language Model Structure and Word Identifiers}, + * booktitle={Proceedings of IEEE International Conference on Acoustics, Speech and Signal Processing}, + * pages={388--391}, + * } + * + * Currently only used for next pointers. + */ + +#include + +#include "lm/binary_format.hh" +#include "lm/trie.hh" +#include "util/bit_packing.hh" +#include "util/sorted_uniform.hh" + +namespace lm { +namespace ngram { +class Config; + +namespace trie { + +class DontBhiksha { + public: + static const ModelType kModelTypeAdd = static_cast(0); + + static void UpdateConfigFromBinary(int /*fd*/, Config &/*config*/) {} + + static std::size_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; } + + static uint8_t InlineBits(uint64_t /*max_offset*/, uint64_t max_next, const Config &/*config*/) { + return util::RequiredBits(max_next); + } + + DontBhiksha(const void *base, uint64_t max_offset, uint64_t max_next, const Config &config); + + void ReadNext(const void *base, uint64_t bit_offset, uint64_t /*index*/, uint8_t total_bits, NodeRange &out) const { + out.begin = util::ReadInt57(base, bit_offset, next_.bits, next_.mask); + out.end = util::ReadInt57(base, bit_offset + total_bits, next_.bits, next_.mask); + //assert(out.end >= out.begin); + } + + void WriteNext(void *base, uint64_t bit_offset, uint64_t /*index*/, uint64_t value) { + util::WriteInt57(base, bit_offset, next_.bits, value); + } + + void FinishedLoading(const Config &/*config*/) {} + + void LoadedBinary() {} + + uint8_t InlineBits() const { return next_.bits; } + + private: + util::BitsMask next_; +}; + +class ArrayBhiksha { + public: + static const ModelType kModelTypeAdd = kArrayAdd; + + static void UpdateConfigFromBinary(int fd, Config &config); + + static std::size_t Size(uint64_t max_offset, uint64_t max_next, const Config &config); + + static uint8_t InlineBits(uint64_t max_offset, uint64_t max_next, const Config &config); + + ArrayBhiksha(void *base, uint64_t max_offset, uint64_t max_value, const Config &config); + + void ReadNext(const void *base, uint64_t bit_offset, uint64_t index, uint8_t total_bits, NodeRange &out) const { + const uint64_t *begin_it = util::BinaryBelow(util::IdentityAccessor(), offset_begin_, offset_end_, index); + const uint64_t *end_it; + for (end_it = begin_it; (end_it < offset_end_) && (*end_it <= index + 1); ++end_it) {} + --end_it; + out.begin = ((begin_it - offset_begin_) << next_inline_.bits) | + util::ReadInt57(base, bit_offset, next_inline_.bits, next_inline_.mask); + out.end = ((end_it - offset_begin_) << next_inline_.bits) | + util::ReadInt57(base, bit_offset + total_bits, next_inline_.bits, next_inline_.mask); + } + + void WriteNext(void *base, uint64_t bit_offset, uint64_t index, uint64_t value) { + uint64_t encode = value >> next_inline_.bits; + for (; write_to_ <= offset_begin_ + encode; ++write_to_) *write_to_ = index; + util::WriteInt57(base, bit_offset, next_inline_.bits, value & next_inline_.mask); + } + + void FinishedLoading(const Config &config); + + void LoadedBinary(); + + uint8_t InlineBits() const { return next_inline_.bits; } + + private: + const util::BitsMask next_inline_; + + const uint64_t *const offset_begin_; + const uint64_t *const offset_end_; + + uint64_t *write_to_; + + void *original_base_; +}; + +} // namespace trie +} // namespace ngram +} // namespace lm diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index 92b1008b..e02e621a 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -40,7 +40,7 @@ struct Sanity { } }; -const char *kModelNames[3] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "bit packed trie"}; +const char *kModelNames[6] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"}; std::size_t Align8(std::size_t in) { std::size_t off = in % 8; @@ -100,16 +100,17 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_ } } -uint8_t *GrowForSearch(const Config &config, std::size_t memory_size, Backing &backing) { +uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing) { + std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad; if (config.write_mmap) { // Grow the file to accomodate the search, using zeros. - if (-1 == ftruncate(backing.file.get(), backing.vocab.size() + memory_size)) - UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (backing.vocab.size() + memory_size) << " failed"); + if (-1 == ftruncate(backing.file.get(), adjusted_vocab + memory_size)) + UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (adjusted_vocab + memory_size) << " failed"); // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down. off_t page_size = sysconf(_SC_PAGE_SIZE); - off_t alignment_cruft = backing.vocab.size() % page_size; - backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), backing.vocab.size() - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED); + off_t alignment_cruft = adjusted_vocab % page_size; + backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED); return reinterpret_cast(backing.search.get()) + alignment_cruft; } else { diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index 2b32b450..d28cb6c5 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -16,7 +16,12 @@ namespace lm { namespace ngram { -typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3} ModelType; +/* Not the best numbering system, but it grew this way for historical reasons + * and I want to preserve existing binary files. */ +typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3, ARRAY_TRIE_SORTED=4, QUANT_ARRAY_TRIE_SORTED=5} ModelType; + +const static ModelType kQuantAdd = static_cast(QUANT_TRIE_SORTED - TRIE_SORTED); +const static ModelType kArrayAdd = static_cast(ARRAY_TRIE_SORTED - TRIE_SORTED); /*Inspect a file to determine if it is a binary lm. If not, return false. * If so, return true and set recognized to the type. This is the only API in @@ -55,7 +60,7 @@ void AdvanceOrThrow(int fd, off_t off); // Create just enough of a binary file to write vocabulary to it. uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing); // Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin. -uint8_t *GrowForSearch(const Config &config, std::size_t memory_size, Backing &backing); +uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing); // Write header to binary file. This is done last to prevent incomplete files // from loading. diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index 4552c419..b7aee4de 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -15,12 +15,12 @@ namespace ngram { namespace { void Usage(const char *name) { - std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-n] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [type] input.arpa output.mmap\n\n" -"-u sets the default log10 probability for if the ARPA file does not have\n" -"one.\n" + std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-c bits] [type] input.arpa [output.mmap]\n\n" +"-u sets the log10 probability for if the ARPA file does not have one.\n" +" Default is -100. The ARPA file will always take precedence.\n" "-s allows models to be built even if they do not have and .\n" -"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n" -"type is either probing or trie:\n\n" +"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n\n" +"type is either probing or trie. Default is probing.\n\n" "probing uses a probing hash table. It is the fastest but uses the most memory.\n" "-p sets the space multiplier and must be >1.0. The default is 1.5.\n\n" "trie is a straightforward trie with bit-level packing. It uses the least\n" @@ -29,10 +29,11 @@ void Usage(const char *name) { "-t is the temporary directory prefix. Default is the output file name.\n" "-m limits memory use for sorting. Measured in MB. Default is 1024MB.\n" "-q turns quantization on and sets the number of bits (e.g. -q 8).\n" -"-b sets backoff quantization bits. Requires -q and defaults to that value.\n\n" -"See http://kheafield.com/code/kenlm/benchmark/ for data structure benchmarks.\n" -"Passing only an input file will print memory usage of each data structure.\n" -"If the ARPA file does not have , -u sets 's probability; default 0.0.\n"; +"-b sets backoff quantization bits. Requires -q and defaults to that value.\n" +"-a compresses pointers using an array of offsets. The parameter is the\n" +" maximum number of bits encoded by the array. Memory is minimized subject\n" +" to the maximum, so pick 255 to minimize memory.\n\n" +"Get a memory estimate by passing an ARPA file without an output file name.\n"; exit(1); } @@ -63,12 +64,14 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) { std::vector counts; util::FilePiece f(file); lm::ReadARPACounts(f, counts); - std::size_t sizes[3]; + std::size_t sizes[5]; sizes[0] = ProbingModel::Size(counts, config); sizes[1] = TrieModel::Size(counts, config); sizes[2] = QuantTrieModel::Size(counts, config); - std::size_t max_length = *std::max_element(sizes, sizes + 3); - std::size_t min_length = *std::max_element(sizes, sizes + 3); + sizes[3] = ArrayTrieModel::Size(counts, config); + sizes[4] = QuantArrayTrieModel::Size(counts, config); + std::size_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(size_t)); + std::size_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(size_t)); std::size_t divide; char prefix; if (min_length < (1 << 10) * 10) { @@ -91,7 +94,9 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) { std::cout << prefix << "B\n" "probing " << std::setw(length) << (sizes[0] / divide) << " assuming -p " << config.probing_multiplier << "\n" "trie " << std::setw(length) << (sizes[1] / divide) << " without quantization\n" - "trie " << std::setw(length) << (sizes[2] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n"; + "trie " << std::setw(length) << (sizes[2] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n" + "trie " << std::setw(length) << (sizes[3] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " array pointer compression\n" + "trie " << std::setw(length) << (sizes[4] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits<< " array pointer compression and quantization\n"; } void ProbingQuantizationUnsupported() { @@ -106,11 +111,11 @@ void ProbingQuantizationUnsupported() { int main(int argc, char *argv[]) { using namespace lm::ngram; - bool quantize = false, set_backoff_bits = false; try { + bool quantize = false, set_backoff_bits = false, bhiksha = false; lm::ngram::Config config; int opt; - while ((opt = getopt(argc, argv, "siu:p:t:m:q:b:")) != -1) { + while ((opt = getopt(argc, argv, "siu:p:t:m:q:b:a:")) != -1) { switch(opt) { case 'q': config.prob_bits = ParseBitCount(optarg); @@ -121,6 +126,9 @@ int main(int argc, char *argv[]) { config.backoff_bits = ParseBitCount(optarg); set_backoff_bits = true; break; + case 'a': + config.pointer_bhiksha_bits = ParseBitCount(optarg); + bhiksha = true; case 'u': config.unknown_missing_logprob = ParseFloat(optarg); break; @@ -162,9 +170,17 @@ int main(int argc, char *argv[]) { ProbingModel(from_file, config); } else if (!strcmp(model_type, "trie")) { if (quantize) { - QuantTrieModel(from_file, config); + if (bhiksha) { + QuantArrayTrieModel(from_file, config); + } else { + QuantTrieModel(from_file, config); + } } else { - TrieModel(from_file, config); + if (bhiksha) { + ArrayTrieModel(from_file, config); + } else { + TrieModel(from_file, config); + } } } else { Usage(argv[0]); @@ -173,9 +189,9 @@ int main(int argc, char *argv[]) { Usage(argv[0]); } } - catch (std::exception &e) { + catch (const std::exception &e) { std::cerr << e.what() << std::endl; - abort(); + return 1; } return 0; } diff --git a/klm/lm/config.cc b/klm/lm/config.cc index 08e1af5c..297589a4 100644 --- a/klm/lm/config.cc +++ b/klm/lm/config.cc @@ -20,6 +20,7 @@ Config::Config() : include_vocab(true), prob_bits(8), backoff_bits(8), + pointer_bhiksha_bits(22), load_method(util::POPULATE_OR_READ) {} } // namespace ngram diff --git a/klm/lm/config.hh b/klm/lm/config.hh index dcc7cf35..227b8512 100644 --- a/klm/lm/config.hh +++ b/klm/lm/config.hh @@ -73,9 +73,12 @@ struct Config { // Quantization options. Only effective for QuantTrieModel. One value is // reserved for each of prob and backoff, so 2^bits - 1 buckets will be used - // to quantize. + // to quantize (and one of the remaining backoffs will be 0). uint8_t prob_bits, backoff_bits; + // Bhiksha compression (simple form). Only works with trie. + uint8_t pointer_bhiksha_bits; + // ONLY EFFECTIVE WHEN READING BINARY diff --git a/klm/lm/model.cc b/klm/lm/model.cc index a1d10b3d..27e24b1c 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -21,6 +21,8 @@ size_t hash_value(const State &state) { namespace detail { +template const ModelType GenericModel::kModelType = Search::kModelType; + template size_t GenericModel::Size(const std::vector &counts, const Config &config) { return VocabularyT::Size(counts[0], config) + Search::Size(counts, config); } @@ -56,35 +58,40 @@ template void GenericModel void GenericModel::InitializeFromARPA(const char *file, const Config &config) { // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any. util::FilePiece f(backing_.file.release(), file, config.messages); - std::vector counts; - // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_. - ReadARPACounts(f, counts); - - if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile."); - if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model."); - if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0"); - - std::size_t vocab_size = VocabularyT::Size(counts[0], config); - // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs. - vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config); - - if (config.write_mmap) { - WriteWordsWrapper wrap(config.enumerate_vocab); - vocab_.ConfigureEnumerate(&wrap, counts[0]); - search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); - wrap.Write(backing_.file.get()); - } else { - vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]); - search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); - } + try { + std::vector counts; + // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_. + ReadARPACounts(f, counts); + + if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile."); + if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model."); + if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0"); + + std::size_t vocab_size = VocabularyT::Size(counts[0], config); + // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs. + vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config); + + if (config.write_mmap) { + WriteWordsWrapper wrap(config.enumerate_vocab); + vocab_.ConfigureEnumerate(&wrap, counts[0]); + search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); + wrap.Write(backing_.file.get()); + } else { + vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]); + search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); + } - if (!vocab_.SawUnk()) { - assert(config.unknown_missing != THROW_UP); - // Default probabilities for unknown. - search_.unigram.Unknown().backoff = 0.0; - search_.unigram.Unknown().prob = config.unknown_missing_logprob; + if (!vocab_.SawUnk()) { + assert(config.unknown_missing != THROW_UP); + // Default probabilities for unknown. + search_.unigram.Unknown().backoff = 0.0; + search_.unigram.Unknown().prob = config.unknown_missing_logprob; + } + FinishFile(config, kModelType, counts, backing_); + } catch (util::Exception &e) { + e << " Byte: " << f.Offset(); + throw; } - FinishFile(config, kModelType, counts, backing_); } template FullScoreReturn GenericModel::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const { @@ -225,8 +232,10 @@ template FullScoreReturn GenericModel; // HASH_PROBING -template class GenericModel, SortedVocabulary>; // TRIE_SORTED -template class GenericModel, SortedVocabulary>; // TRIE_SORTED_QUANT +template class GenericModel, SortedVocabulary>; // TRIE_SORTED +template class GenericModel, SortedVocabulary>; +template class GenericModel, SortedVocabulary>; // TRIE_SORTED_QUANT +template class GenericModel, SortedVocabulary>; } // namespace detail } // namespace ngram diff --git a/klm/lm/model.hh b/klm/lm/model.hh index 1f49a382..21595321 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -1,6 +1,7 @@ #ifndef LM_MODEL__ #define LM_MODEL__ +#include "lm/bhiksha.hh" #include "lm/binary_format.hh" #include "lm/config.hh" #include "lm/facade.hh" @@ -71,6 +72,9 @@ template class GenericModel : public base::Mod private: typedef base::ModelFacade, State, VocabularyT> P; public: + // This is the model type returned by RecognizeBinary. + static const ModelType kModelType; + /* Get the size of memory that will be mapped given ngram counts. This * does not include small non-mapped control structures, such as this class * itself. @@ -131,8 +135,6 @@ template class GenericModel : public base::Mod Backing &MutableBacking() { return backing_; } - static const ModelType kModelType = Search::kModelType; - Backing backing_; VocabularyT vocab_; @@ -152,9 +154,11 @@ typedef ProbingModel Model; // Smaller implementation. typedef ::lm::ngram::SortedVocabulary SortedVocabulary; -typedef detail::GenericModel, SortedVocabulary> TrieModel; // TRIE_SORTED +typedef detail::GenericModel, SortedVocabulary> TrieModel; // TRIE_SORTED +typedef detail::GenericModel, SortedVocabulary> ArrayTrieModel; -typedef detail::GenericModel, SortedVocabulary> QuantTrieModel; // QUANT_TRIE_SORTED +typedef detail::GenericModel, SortedVocabulary> QuantTrieModel; // QUANT_TRIE_SORTED +typedef detail::GenericModel, SortedVocabulary> QuantArrayTrieModel; } // namespace ngram } // namespace lm diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index 8bf040ff..57c7291c 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -193,6 +193,14 @@ template void Stateless(const M &model) { BOOST_CHECK_EQUAL(static_cast(0), state.history_[0]); } +template void NoUnkCheck(const M &model) { + WordIndex unk_index = 0; + State state; + + FullScoreReturn ret = model.FullScoreForgotState(&unk_index, &unk_index + 1, unk_index, state); + BOOST_CHECK_CLOSE(-100.0, ret.prob, 0.001); +} + template void Everything(const M &m) { Starters(m); Continuation(m); @@ -231,25 +239,38 @@ template void LoadingTest() { Config config; config.arpa_complain = Config::NONE; config.messages = NULL; - ExpectEnumerateVocab enumerate; - config.enumerate_vocab = &enumerate; config.probing_multiplier = 2.0; - ModelT m("test.arpa", config); - enumerate.Check(m.GetVocabulary()); - Everything(m); + { + ExpectEnumerateVocab enumerate; + config.enumerate_vocab = &enumerate; + ModelT m("test.arpa", config); + enumerate.Check(m.GetVocabulary()); + Everything(m); + } + { + ExpectEnumerateVocab enumerate; + config.enumerate_vocab = &enumerate; + ModelT m("test_nounk.arpa", config); + enumerate.Check(m.GetVocabulary()); + NoUnkCheck(m); + } } BOOST_AUTO_TEST_CASE(probing) { LoadingTest(); } - BOOST_AUTO_TEST_CASE(trie) { LoadingTest(); } - -BOOST_AUTO_TEST_CASE(quant) { +BOOST_AUTO_TEST_CASE(quant_trie) { LoadingTest(); } +BOOST_AUTO_TEST_CASE(bhiksha_trie) { + LoadingTest(); +} +BOOST_AUTO_TEST_CASE(quant_bhiksha_trie) { + LoadingTest(); +} template void BinaryTest() { Config config; @@ -267,10 +288,34 @@ template void BinaryTest() { config.write_mmap = NULL; - ModelT binary("test.binary", config); - enumerate.Check(binary.GetVocabulary()); - Everything(binary); + ModelType type; + BOOST_REQUIRE(RecognizeBinary("test.binary", type)); + BOOST_CHECK_EQUAL(ModelT::kModelType, type); + + { + ModelT binary("test.binary", config); + enumerate.Check(binary.GetVocabulary()); + Everything(binary); + } unlink("test.binary"); + + // Now test without . + config.write_mmap = "test_nounk.binary"; + config.messages = NULL; + enumerate.Clear(); + { + ModelT copy_model("test_nounk.arpa", config); + enumerate.Check(copy_model.GetVocabulary()); + enumerate.Clear(); + NoUnkCheck(copy_model); + } + config.write_mmap = NULL; + { + ModelT binary("test_nounk.binary", config); + enumerate.Check(binary.GetVocabulary()); + NoUnkCheck(binary); + } + unlink("test_nounk.binary"); } BOOST_AUTO_TEST_CASE(write_and_read_probing) { @@ -282,6 +327,12 @@ BOOST_AUTO_TEST_CASE(write_and_read_trie) { BOOST_AUTO_TEST_CASE(write_and_read_quant_trie) { BinaryTest(); } +BOOST_AUTO_TEST_CASE(write_and_read_array_trie) { + BinaryTest(); +} +BOOST_AUTO_TEST_CASE(write_and_read_quant_array_trie) { + BinaryTest(); +} } // namespace } // namespace ngram diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc index 9454a6d1..d9db4aa2 100644 --- a/klm/lm/ngram_query.cc +++ b/klm/lm/ngram_query.cc @@ -99,6 +99,15 @@ int main(int argc, char *argv[]) { case lm::ngram::TRIE_SORTED: Query(argv[1], sentence_context); break; + case lm::ngram::QUANT_TRIE_SORTED: + Query(argv[1], sentence_context); + break; + case lm::ngram::ARRAY_TRIE_SORTED: + Query(argv[1], sentence_context); + break; + case lm::ngram::QUANT_ARRAY_TRIE_SORTED: + Query(argv[1], sentence_context); + break; case lm::ngram::HASH_SORTED: default: std::cerr << "Unrecognized kenlm model type " << model_type << std::endl; diff --git a/klm/lm/quantize.cc b/klm/lm/quantize.cc index 4bb6b1b8..fd371cc8 100644 --- a/klm/lm/quantize.cc +++ b/klm/lm/quantize.cc @@ -43,6 +43,7 @@ void SeparatelyQuantize::UpdateConfigFromBinary(int fd, const std::vector(0); static void UpdateConfigFromBinary(int, const std::vector &, Config &) {} static std::size_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; } static uint8_t MiddleBits(const Config &/*config*/) { return 63; } @@ -108,7 +108,7 @@ class SeparatelyQuantize { }; public: - static const ModelType kModelType = QUANT_TRIE_SORTED; + static const ModelType kModelTypeAdd = kQuantAdd; static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config); diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc index 060a97ea..455bc4ba 100644 --- a/klm/lm/read_arpa.cc +++ b/klm/lm/read_arpa.cc @@ -31,15 +31,15 @@ const char kBinaryMagic[] = "mmap lm http://kheafield.com/code"; void ReadARPACounts(util::FilePiece &in, std::vector &number) { number.clear(); StringPiece line; - if (!IsEntirelyWhiteSpace(line = in.ReadLine())) { + while (IsEntirelyWhiteSpace(line = in.ReadLine())) {} + if (line != "\\data\\") { if ((line.size() >= 2) && (line.data()[0] == 0x1f) && (static_cast(line.data()[1]) == 0x8b)) { UTIL_THROW(FormatLoadException, "Looks like a gzip file. If this is an ARPA file, pipe " << in.FileName() << " through zcat. If this already in binary format, you need to decompress it because mmap doesn't work on top of gzip."); } if (static_cast(line.size()) >= strlen(kBinaryMagic) && StringPiece(line.data(), strlen(kBinaryMagic)) == kBinaryMagic) UTIL_THROW(FormatLoadException, "This looks like a binary file but got sent to the ARPA parser. Did you compress the binary file or pass a binary file where only ARPA files are accepted?"); - UTIL_THROW(FormatLoadException, "First line was \"" << line.data() << "\" not blank"); + UTIL_THROW(FormatLoadException, "first non-empty line was \"" << line << "\" not \\data\\."); } - if ((line = in.ReadLine()) != "\\data\\") UTIL_THROW(FormatLoadException, "second line was \"" << line << "\" not \\data\\."); while (!IsEntirelyWhiteSpace(line = in.ReadLine())) { if (line.size() < 6 || strncmp(line.data(), "ngram ", 6)) UTIL_THROW(FormatLoadException, "count line \"" << line << "\"doesn't begin with \"ngram \""); // So strtol doesn't go off the end of line. diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index c56ba7b8..82c53ec8 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -98,7 +98,7 @@ template uint8_t *TemplateHashedSearch template void TemplateHashedSearch::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector &counts, const Config &config, Voc &vocab, Backing &backing) { // TODO: fix sorted. - SetupMemory(GrowForSearch(config, Size(counts, config), backing), counts, config); + SetupMemory(GrowForSearch(config, 0, Size(counts, config), backing), counts, config); PositiveProbWarn warn(config.positive_log_probability); diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index f3acdefc..c62985e4 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -52,12 +52,11 @@ struct HashedSearch { Unigram unigram; - bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const { + void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const { const ProbBackoff &entry = unigram.Lookup(word); prob = entry.prob; backoff = entry.backoff; next = static_cast(word); - return true; } }; diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 91f87f1c..05059ffb 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -1,6 +1,7 @@ /* This is where the trie is built. It's on-disk. */ #include "lm/search_trie.hh" +#include "lm/bhiksha.hh" #include "lm/blank.hh" #include "lm/lm_exception.hh" #include "lm/max_order.hh" @@ -543,8 +544,8 @@ void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector appears. - size_t extra_count = counts[0] + 1; - util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), extra_count * sizeof(ProbBackoff), unigram_file), extra_count * sizeof(ProbBackoff)); + size_t file_out = (counts[0] + 1) * sizeof(ProbBackoff); + util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), file_out, unigram_file), file_out); Read1Grams(f, counts[0], vocab, reinterpret_cast(unigram_mmap.get()), warn); CheckSpecials(config, vocab); if (!vocab.SawUnk()) ++counts[0]; @@ -610,9 +611,9 @@ class JustCount { }; // Phase to actually write n-grams to the trie. -template class WriteEntries { +template class WriteEntries { public: - WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) : + WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) : contexts_(contexts), unigrams_(unigrams), middle_(middle), @@ -649,7 +650,7 @@ template class WriteEntries { private: ContextReader *contexts_; UnigramValue *const unigrams_; - BitPackedMiddle *const middle_; + BitPackedMiddle *const middle_; BitPackedLongest &longest_; BitPacked &bigram_pack_; }; @@ -821,7 +822,7 @@ template void TrainProbQuantizer(uint8_t order, uint64_t count, So } // namespace -template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, Backing &backing) { +template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) { std::vector inputs(counts.size() - 1); std::vector contexts(counts.size() - 1); @@ -846,7 +847,7 @@ template void BuildTrie(const std::string &file_prefix, std::vecto SanityCheckCounts(counts, fixed_counts); counts = fixed_counts; - out.SetupMemory(GrowForSearch(config, TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); + out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); if (Quant::kTrain) { util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0)); @@ -863,7 +864,7 @@ template void BuildTrie(const std::string &file_prefix, std::vecto UnigramValue *unigrams = out.unigram.Raw(); // Fill entries except unigram probabilities. { - RecursiveInsert > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); + RecursiveInsert > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); inserter.Apply(config.messages, "Building trie", fixed_counts[0]); } @@ -901,14 +902,14 @@ template void BuildTrie(const std::string &file_prefix, std::vecto /* Set ending offsets so the last entry will be sized properly */ // Last entry for unigrams was already set. if (out.middle_begin_ != out.middle_end_) { - for (typename TrieSearch::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) { - i->FinishedLoading((i+1)->InsertIndex()); + for (typename TrieSearch::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) { + i->FinishedLoading((i+1)->InsertIndex(), config); } - (out.middle_end_ - 1)->FinishedLoading(out.longest.InsertIndex()); + (out.middle_end_ - 1)->FinishedLoading(out.longest.InsertIndex(), config); } } -template uint8_t *TrieSearch::SetupMemory(uint8_t *start, const std::vector &counts, const Config &config) { +template uint8_t *TrieSearch::SetupMemory(uint8_t *start, const std::vector &counts, const Config &config) { quant_.SetupMemory(start, config); start += Quant::Size(counts.size(), config); unigram.Init(start); @@ -919,22 +920,24 @@ template uint8_t *TrieSearch::SetupMemory(uint8_t *start, c std::vector middle_starts(counts.size() - 2); for (unsigned char i = 2; i < counts.size(); ++i) { middle_starts[i-2] = start; - start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i]); + start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i], config); } - // Crazy backwards thing so we initialize in the correct order. + // Crazy backwards thing so we initialize using pointers to ones that have already been initialized for (unsigned char i = counts.size() - 1; i >= 2; --i) { new (middle_begin_ + i - 2) Middle( middle_starts[i-2], quant_.Mid(i), + counts[i-1], counts[0], counts[i], - (i == counts.size() - 1) ? static_cast(longest) : static_cast(middle_begin_[i-1])); + (i == counts.size() - 1) ? static_cast(longest) : static_cast(middle_begin_[i-1]), + config); } longest.Init(start, quant_.Long(counts.size()), counts[0]); return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); } -template void TrieSearch::LoadedBinary() { +template void TrieSearch::LoadedBinary() { unigram.LoadedBinary(); for (Middle *i = middle_begin_; i != middle_end_; ++i) { i->LoadedBinary(); @@ -942,7 +945,7 @@ template void TrieSearch::LoadedBinary() { longest.LoadedBinary(); } -template void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { +template void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { std::string temporary_directory; if (config.temporary_directory_prefix) { temporary_directory = config.temporary_directory_prefix; @@ -966,14 +969,16 @@ template void TrieSearch::InitializeFromARPA(const char *fi // At least 1MB sorting memory. ARPAToSortedFiles(config, f, counts, std::max(config.building_memory, 1048576), temporary_directory.c_str(), vocab); - BuildTrie(temporary_directory, counts, config, *this, quant_, backing); + BuildTrie(temporary_directory, counts, config, *this, quant_, vocab, backing); if (rmdir(temporary_directory.c_str()) && config.messages) { *config.messages << "Failed to delete " << temporary_directory << std::endl; } } -template class TrieSearch; -template class TrieSearch; +template class TrieSearch; +template class TrieSearch; +template class TrieSearch; +template class TrieSearch; } // namespace trie } // namespace ngram diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 0a52acb5..2f39c09f 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -13,31 +13,33 @@ struct Backing; class SortedVocabulary; namespace trie { -template class TrieSearch; -template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, Backing &backing); +template class TrieSearch; +template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); -template class TrieSearch { +template class TrieSearch { public: typedef NodeRange Node; typedef ::lm::ngram::trie::Unigram Unigram; Unigram unigram; - typedef trie::BitPackedMiddle Middle; + typedef trie::BitPackedMiddle Middle; typedef trie::BitPackedLongest Longest; Longest longest; - static const ModelType kModelType = Quant::kModelType; + static const ModelType kModelType = static_cast(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd); static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config) { Quant::UpdateConfigFromBinary(fd, counts, config); + AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0])); + Bhiksha::UpdateConfigFromBinary(fd, config); } static std::size_t Size(const std::vector &counts, const Config &config) { std::size_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]); for (unsigned char i = 1; i < counts.size() - 1; ++i) { - ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1]); + ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1], config); } return ret + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); } @@ -55,8 +57,8 @@ template class TrieSearch { void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); - bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { - return unigram.Find(word, prob, backoff, node); + void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { + unigram.Find(word, prob, backoff, node); } bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const { @@ -83,7 +85,7 @@ template class TrieSearch { } private: - friend void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, Backing &backing); + friend void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); // Middles are managed manually so we can delay construction and they don't have to be copyable. void FreeMiddles() { diff --git a/klm/lm/test_nounk.arpa b/klm/lm/test_nounk.arpa new file mode 100644 index 00000000..060733d9 --- /dev/null +++ b/klm/lm/test_nounk.arpa @@ -0,0 +1,120 @@ + +\data\ +ngram 1=36 +ngram 2=45 +ngram 3=10 +ngram 4=6 +ngram 5=4 + +\1-grams: +-1.383514 , -0.30103 +-1.139057 . -0.845098 +-1.029493 +-99 -0.4149733 +-1.285941 a -0.69897 +-1.687872 also -0.30103 +-1.687872 beyond -0.30103 +-1.687872 biarritz -0.30103 +-1.687872 call -0.30103 +-1.687872 concerns -0.30103 +-1.687872 consider -0.30103 +-1.687872 considering -0.30103 +-1.687872 for -0.30103 +-1.509559 higher -0.30103 +-1.687872 however -0.30103 +-1.687872 i -0.30103 +-1.687872 immediate -0.30103 +-1.687872 in -0.30103 +-1.687872 is -0.30103 +-1.285941 little -0.69897 +-1.383514 loin -0.30103 +-1.687872 look -0.30103 +-1.285941 looking -0.4771212 +-1.206319 more -0.544068 +-1.509559 on -0.4771212 +-1.509559 screening -0.4771212 +-1.687872 small -0.30103 +-1.687872 the -0.30103 +-1.687872 to -0.30103 +-1.687872 watch -0.30103 +-1.687872 watching -0.30103 +-1.687872 what -0.30103 +-1.687872 would -0.30103 +-3.141592 foo +-2.718281 bar 3.0 +-6.535897 baz -0.0 + +\2-grams: +-0.6925742 , . +-0.7522095 , however +-0.7522095 , is +-0.0602359 . +-0.4846522 looking -0.4771214 +-1.051485 screening +-1.07153 the +-1.07153 watching +-1.07153 what +-0.09132547 a little -0.69897 +-0.2922095 also call +-0.2922095 beyond immediate +-0.2705918 biarritz . +-0.2922095 call for +-0.2922095 concerns in +-0.2922095 consider watch +-0.2922095 considering consider +-0.2834328 for , +-0.5511513 higher more +-0.5845945 higher small +-0.2834328 however , +-0.2922095 i would +-0.2922095 immediate concerns +-0.2922095 in biarritz +-0.2922095 is to +-0.09021038 little more -0.1998621 +-0.7273645 loin , +-0.6925742 loin . +-0.6708385 loin +-0.2922095 look beyond +-0.4638903 looking higher +-0.4638903 looking on -0.4771212 +-0.5136299 more . -0.4771212 +-0.3561665 more loin +-0.1649931 on a -0.4771213 +-0.1649931 screening a -0.4771213 +-0.2705918 small . +-0.287799 the screening +-0.2922095 to look +-0.2622373 watch +-0.2922095 watching considering +-0.2922095 what i +-0.2922095 would also +-2 also would -6 +-6 foo bar + +\3-grams: +-0.01916512 more . +-0.0283603 on a little -0.4771212 +-0.0283603 screening a little -0.4771212 +-0.01660496 a little more -0.09409451 +-0.3488368 looking higher +-0.3488368 looking on -0.4771212 +-0.1892331 little more loin +-0.04835128 looking on a -0.4771212 +-3 also would consider -7 +-7 to look good + +\4-grams: +-0.009249173 looking on a little -0.4771212 +-0.005464747 on a little more -0.4771212 +-0.005464747 screening a little more +-0.1453306 a little more loin +-0.01552657 looking on a -0.4771212 +-4 also would consider higher -8 + +\5-grams: +-0.003061223 looking on a little +-0.001813953 looking on a little more +-0.0432557 on a little more loin +-5 also would consider higher looking + +\end\ diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index 63c2a612..8c536e66 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -1,5 +1,6 @@ #include "lm/trie.hh" +#include "lm/bhiksha.hh" #include "lm/quantize.hh" #include "util/bit_packing.hh" #include "util/exception.hh" @@ -57,16 +58,21 @@ void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits) max_vocab_ = max_vocab; } -template std::size_t BitPackedMiddle::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr) { - return BaseSize(entries, max_vocab, quant_bits + util::RequiredBits(max_ptr)); +template std::size_t BitPackedMiddle::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) { + return Bhiksha::Size(entries + 1, max_ptr, config) + BaseSize(entries, max_vocab, quant_bits + Bhiksha::InlineBits(entries + 1, max_ptr, config)); } -template BitPackedMiddle::BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source) : BitPacked(), quant_(quant), next_bits_(util::RequiredBits(max_next)), next_mask_((1ULL << next_bits_) - 1), next_source_(&next_source) { - if (next_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions."); - BaseInit(base, max_vocab, quant.TotalBits() + next_bits_); +template BitPackedMiddle::BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config) : + BitPacked(), + quant_(quant), + // If the offset of the method changes, also change TrieSearch::UpdateConfigFromBinary. + bhiksha_(base, entries + 1, max_next, config), + next_source_(&next_source) { + if (entries + 1 >= (1ULL << 57) || (max_next >= (1ULL << 57))) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions."); + BaseInit(reinterpret_cast(base) + Bhiksha::Size(entries + 1, max_next, config), max_vocab, quant.TotalBits() + bhiksha_.InlineBits()); } -template void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff) { +template void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff) { assert(word <= word_mask_); uint64_t at_pointer = insert_index_ * total_bits_; @@ -75,47 +81,42 @@ template void BitPackedMiddle::Insert(WordIndex word, float quant_.Write(base_, at_pointer, prob, backoff); at_pointer += quant_.TotalBits(); uint64_t next = next_source_->InsertIndex(); - assert(next <= next_mask_); - util::WriteInt57(base_, at_pointer, next_bits_, next); + bhiksha_.WriteNext(base_, at_pointer, insert_index_, next); ++insert_index_; } -template bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const { +template bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const { uint64_t at_pointer; if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) { return false; } + uint64_t index = at_pointer; at_pointer *= total_bits_; at_pointer += word_bits_; quant_.Read(base_, at_pointer, prob, backoff); at_pointer += quant_.TotalBits(); - range.begin = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); - // Read the next entry's pointer. - at_pointer += total_bits_; - range.end = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); + bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range); + return true; } -template bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const { - uint64_t at_pointer; - if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false; - at_pointer *= total_bits_; +template bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const { + uint64_t index; + if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, index)) return false; + uint64_t at_pointer = index * total_bits_; at_pointer += word_bits_; quant_.ReadBackoff(base_, at_pointer, backoff); at_pointer += quant_.TotalBits(); - range.begin = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); - // Read the next entry's pointer. - at_pointer += total_bits_; - range.end = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); + bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range); return true; } -template void BitPackedMiddle::FinishedLoading(uint64_t next_end) { - assert(next_end <= next_mask_); - uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - next_bits_; - util::WriteInt57(base_, last_next_write, next_bits_, next_end); +template void BitPackedMiddle::FinishedLoading(uint64_t next_end, const Config &config) { + uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - bhiksha_.InlineBits(); + bhiksha_.WriteNext(base_, last_next_write, insert_index_ + 1, next_end); + bhiksha_.FinishedLoading(config); } template void BitPackedLongest::Insert(WordIndex index, float prob) { @@ -135,8 +136,10 @@ template bool BitPackedLongest::Find(WordIndex word, float return true; } -template class BitPackedMiddle; -template class BitPackedMiddle; +template class BitPackedMiddle; +template class BitPackedMiddle; +template class BitPackedMiddle; +template class BitPackedMiddle; template class BitPackedLongest; template class BitPackedLongest; diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index 8fa21aaf..53612064 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -10,6 +10,7 @@ namespace lm { namespace ngram { +class Config; namespace trie { struct NodeRange { @@ -46,13 +47,12 @@ class Unigram { void LoadedBinary() {} - bool Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const { + void Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const { UnigramValue *val = unigram_ + word; prob = val->weights.prob; backoff = val->weights.backoff; next.begin = val->next; next.end = (val+1)->next; - return true; } private: @@ -67,8 +67,6 @@ class BitPacked { return insert_index_; } - void LoadedBinary() {} - protected: static std::size_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits); @@ -83,30 +81,30 @@ class BitPacked { uint64_t insert_index_, max_vocab_; }; -template class BitPackedMiddle : public BitPacked { +template class BitPackedMiddle : public BitPacked { public: - static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next); + static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config); // next_source need not be initialized. - BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source); + BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config); void Insert(WordIndex word, float prob, float backoff); + void FinishedLoading(uint64_t next_end, const Config &config); + + void LoadedBinary() { bhiksha_.LoadedBinary(); } + bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const; bool FindNoProb(WordIndex word, float &backoff, NodeRange &range) const; - void FinishedLoading(uint64_t next_end); - private: Quant quant_; - uint8_t next_bits_; - uint64_t next_mask_; + Bhiksha bhiksha_; const BitPacked *next_source_; }; - template class BitPackedLongest : public BitPacked { public: static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) { @@ -120,6 +118,8 @@ template class BitPackedLongest : public BitPacked { BaseInit(base, max_vocab, quant_.TotalBits()); } + void LoadedBinary() {} + void Insert(WordIndex word, float prob); bool Find(WordIndex word, float &prob, const NodeRange &node) const; diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 7defd5c1..04979d51 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -37,14 +37,14 @@ WordIndex ReadWords(int fd, EnumerateVocab *enumerate) { WordIndex index = 0; while (true) { ssize_t got = read(fd, &buf[0], kInitialRead); - if (got == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words"); + UTIL_THROW_IF(got == -1, util::ErrnoException, "Reading vocabulary words"); if (got == 0) return index; buf.resize(got); while (buf[buf.size() - 1]) { char next_char; ssize_t ret = read(fd, &next_char, 1); - if (ret == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words"); - if (ret == 0) UTIL_THROW(FormatLoadException, "Missing null terminator on a vocab word."); + UTIL_THROW_IF(ret == -1, util::ErrnoException, "Reading vocabulary words"); + UTIL_THROW_IF(ret == 0, FormatLoadException, "Missing null terminator on a vocab word."); buf.push_back(next_char); } // Ok now we have null terminated strings. diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index c92518e4..9d218fff 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -61,6 +61,7 @@ class SortedVocabulary : public base::Vocabulary { } } + // Size for purposes of file writing static size_t Size(std::size_t entries, const Config &config); // Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary. @@ -77,6 +78,9 @@ class SortedVocabulary : public base::Vocabulary { // Reorders reorder_vocab so that the IDs are sorted. void FinishedLoading(ProbBackoff *reorder_vocab); + // Trie stores the correct counts including in the header. If this was previously sized based on a count exluding , padding with 8 bytes will make it the correct size based on a count including . + std::size_t UnkCountChangePadding() const { return SawUnk() ? 0 : sizeof(uint64_t); } + bool SawUnk() const { return saw_unk_; } void LoadedBinary(int fd, EnumerateVocab *to); diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh index b35d80c8..9f47d559 100644 --- a/klm/util/bit_packing.hh +++ b/klm/util/bit_packing.hh @@ -107,9 +107,20 @@ void BitPackingSanity(); uint8_t RequiredBits(uint64_t max_value); struct BitsMask { + static BitsMask ByMax(uint64_t max_value) { + BitsMask ret; + ret.FromMax(max_value); + return ret; + } + static BitsMask ByBits(uint8_t bits) { + BitsMask ret; + ret.bits = bits; + ret.mask = (1ULL << bits) - 1; + return ret; + } void FromMax(uint64_t max_value) { bits = RequiredBits(max_value); - mask = (1 << bits) - 1; + mask = (1ULL << bits) - 1; } uint8_t bits; uint64_t mask; diff --git a/klm/util/murmur_hash.cc b/klm/util/murmur_hash.cc index d58a0727..fec47fd9 100644 --- a/klm/util/murmur_hash.cc +++ b/klm/util/murmur_hash.cc @@ -1,129 +1,129 @@ -/* Downloaded from http://sites.google.com/site/murmurhash/ which says "All - * code is released to the public domain. For business purposes, Murmurhash is - * under the MIT license." - * This is modified from the original: - * ULL tag on 0xc6a4a7935bd1e995 so this will compile on 32-bit. - * length changed to unsigned int. - * placed in namespace util - * add MurmurHashNative - * default option = 0 for seed - */ - -#include "util/murmur_hash.hh" - -namespace util { - -//----------------------------------------------------------------------------- -// MurmurHash2, 64-bit versions, by Austin Appleby - -// The same caveats as 32-bit MurmurHash2 apply here - beware of alignment -// and endian-ness issues if used across multiple platforms. - -// 64-bit hash for 64-bit platforms - -uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed ) -{ - const uint64_t m = 0xc6a4a7935bd1e995ULL; - const int r = 47; - - uint64_t h = seed ^ (len * m); - - const uint64_t * data = (const uint64_t *)key; - const uint64_t * end = data + (len/8); - - while(data != end) - { - uint64_t k = *data++; - - k *= m; - k ^= k >> r; - k *= m; - - h ^= k; - h *= m; - } - - const unsigned char * data2 = (const unsigned char*)data; - - switch(len & 7) - { - case 7: h ^= uint64_t(data2[6]) << 48; - case 6: h ^= uint64_t(data2[5]) << 40; - case 5: h ^= uint64_t(data2[4]) << 32; - case 4: h ^= uint64_t(data2[3]) << 24; - case 3: h ^= uint64_t(data2[2]) << 16; - case 2: h ^= uint64_t(data2[1]) << 8; - case 1: h ^= uint64_t(data2[0]); - h *= m; - }; - - h ^= h >> r; - h *= m; - h ^= h >> r; - - return h; -} - - -// 64-bit hash for 32-bit platforms - -uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed ) -{ - const unsigned int m = 0x5bd1e995; - const int r = 24; - - unsigned int h1 = seed ^ len; - unsigned int h2 = 0; - - const unsigned int * data = (const unsigned int *)key; - - while(len >= 8) - { - unsigned int k1 = *data++; - k1 *= m; k1 ^= k1 >> r; k1 *= m; - h1 *= m; h1 ^= k1; - len -= 4; - - unsigned int k2 = *data++; - k2 *= m; k2 ^= k2 >> r; k2 *= m; - h2 *= m; h2 ^= k2; - len -= 4; - } - - if(len >= 4) - { - unsigned int k1 = *data++; - k1 *= m; k1 ^= k1 >> r; k1 *= m; - h1 *= m; h1 ^= k1; - len -= 4; - } - - switch(len) - { - case 3: h2 ^= ((unsigned char*)data)[2] << 16; - case 2: h2 ^= ((unsigned char*)data)[1] << 8; - case 1: h2 ^= ((unsigned char*)data)[0]; - h2 *= m; - }; - - h1 ^= h2 >> 18; h1 *= m; - h2 ^= h1 >> 22; h2 *= m; - h1 ^= h2 >> 17; h1 *= m; - h2 ^= h1 >> 19; h2 *= m; - - uint64_t h = h1; - - h = (h << 32) | h2; - - return h; -} - -uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) { - if (sizeof(int) == 4) { - return MurmurHash64B(key, len, seed); - } else { - return MurmurHash64A(key, len, seed); - } -} - -} // namespace util +/* Downloaded from http://sites.google.com/site/murmurhash/ which says "All + * code is released to the public domain. For business purposes, Murmurhash is + * under the MIT license." + * This is modified from the original: + * ULL tag on 0xc6a4a7935bd1e995 so this will compile on 32-bit. + * length changed to unsigned int. + * placed in namespace util + * add MurmurHashNative + * default option = 0 for seed + */ + +#include "util/murmur_hash.hh" + +namespace util { + +//----------------------------------------------------------------------------- +// MurmurHash2, 64-bit versions, by Austin Appleby + +// The same caveats as 32-bit MurmurHash2 apply here - beware of alignment +// and endian-ness issues if used across multiple platforms. + +// 64-bit hash for 64-bit platforms + +uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed ) +{ + const uint64_t m = 0xc6a4a7935bd1e995ULL; + const int r = 47; + + uint64_t h = seed ^ (len * m); + + const uint64_t * data = (const uint64_t *)key; + const uint64_t * end = data + (len/8); + + while(data != end) + { + uint64_t k = *data++; + + k *= m; + k ^= k >> r; + k *= m; + + h ^= k; + h *= m; + } + + const unsigned char * data2 = (const unsigned char*)data; + + switch(len & 7) + { + case 7: h ^= uint64_t(data2[6]) << 48; + case 6: h ^= uint64_t(data2[5]) << 40; + case 5: h ^= uint64_t(data2[4]) << 32; + case 4: h ^= uint64_t(data2[3]) << 24; + case 3: h ^= uint64_t(data2[2]) << 16; + case 2: h ^= uint64_t(data2[1]) << 8; + case 1: h ^= uint64_t(data2[0]); + h *= m; + }; + + h ^= h >> r; + h *= m; + h ^= h >> r; + + return h; +} + + +// 64-bit hash for 32-bit platforms + +uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed ) +{ + const unsigned int m = 0x5bd1e995; + const int r = 24; + + unsigned int h1 = seed ^ len; + unsigned int h2 = 0; + + const unsigned int * data = (const unsigned int *)key; + + while(len >= 8) + { + unsigned int k1 = *data++; + k1 *= m; k1 ^= k1 >> r; k1 *= m; + h1 *= m; h1 ^= k1; + len -= 4; + + unsigned int k2 = *data++; + k2 *= m; k2 ^= k2 >> r; k2 *= m; + h2 *= m; h2 ^= k2; + len -= 4; + } + + if(len >= 4) + { + unsigned int k1 = *data++; + k1 *= m; k1 ^= k1 >> r; k1 *= m; + h1 *= m; h1 ^= k1; + len -= 4; + } + + switch(len) + { + case 3: h2 ^= ((unsigned char*)data)[2] << 16; + case 2: h2 ^= ((unsigned char*)data)[1] << 8; + case 1: h2 ^= ((unsigned char*)data)[0]; + h2 *= m; + }; + + h1 ^= h2 >> 18; h1 *= m; + h2 ^= h1 >> 22; h2 *= m; + h1 ^= h2 >> 17; h1 *= m; + h2 ^= h1 >> 19; h2 *= m; + + uint64_t h = h1; + + h = (h << 32) | h2; + + return h; +} + +uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) { + if (sizeof(int) == 4) { + return MurmurHash64B(key, len, seed); + } else { + return MurmurHash64A(key, len, seed); + } +} + +} // namespace util diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh index 00be0ed7..2ec342a6 100644 --- a/klm/util/probing_hash_table.hh +++ b/klm/util/probing_hash_table.hh @@ -57,7 +57,7 @@ template class IdentityAccessor { public: typedef T Key; - T operator()(const uint64_t *in) const { return *in; } + T operator()(const T *in) const { return *in; } }; struct Pivot64 { @@ -101,6 +101,27 @@ template bool SortedUniformFind(co return BoundedSortedUniformFind(accessor, begin, below, end, above, key, out); } +// May return begin - 1. +template Iterator BinaryBelow( + const Accessor &accessor, + Iterator begin, + Iterator end, + const typename Accessor::Key key) { + while (end > begin) { + Iterator pivot(begin + (end - begin) / 2); + typename Accessor::Key mid(accessor(pivot)); + if (mid < key) { + begin = pivot + 1; + } else if (mid > key) { + end = pivot; + } else { + for (++pivot; (pivot < end) && accessor(pivot) == mid; ++pivot) {} + return pivot - 1; + } + } + return begin - 1; +} + // To use this template, you need to define a Pivot function to match Key. template class SortedUniformMap { public: -- cgit v1.2.3