diff options
-rw-r--r-- | BUILDING | 5 | ||||
-rw-r--r-- | decoder/cdec_ff.cc | 1 | ||||
-rw-r--r-- | decoder/ff_klm.cc | 1 | ||||
-rwxr-xr-x | klm/compile.sh | 4 | ||||
-rw-r--r-- | klm/lm/Makefile.am | 1 | ||||
-rw-r--r-- | klm/lm/binary_format.cc | 17 | ||||
-rw-r--r-- | klm/lm/binary_format.hh | 10 | ||||
-rw-r--r-- | klm/lm/blank.hh | 4 | ||||
-rw-r--r-- | klm/lm/build_binary.cc | 83 | ||||
-rw-r--r-- | klm/lm/config.cc | 2 | ||||
-rw-r--r-- | klm/lm/config.hh | 5 | ||||
-rw-r--r-- | klm/lm/model.cc | 25 | ||||
-rw-r--r-- | klm/lm/model.hh | 21 | ||||
-rw-r--r-- | klm/lm/model_test.cc | 13 | ||||
-rw-r--r-- | klm/lm/quantize.cc | 84 | ||||
-rw-r--r-- | klm/lm/quantize.hh | 207 | ||||
-rw-r--r-- | klm/lm/search_hashed.cc | 38 | ||||
-rw-r--r-- | klm/lm/search_hashed.hh | 122 | ||||
-rw-r--r-- | klm/lm/search_trie.cc | 126 | ||||
-rw-r--r-- | klm/lm/search_trie.hh | 132 | ||||
-rw-r--r-- | klm/lm/trie.cc | 123 | ||||
-rw-r--r-- | klm/lm/trie.hh | 33 | ||||
-rw-r--r-- | klm/lm/vocab.cc | 18 | ||||
-rw-r--r-- | klm/lm/vocab.hh | 35 | ||||
-rw-r--r-- | klm/util/bit_packing.cc | 4 | ||||
-rw-r--r-- | klm/util/bit_packing.hh | 39 | ||||
-rw-r--r-- | klm/util/bit_packing_test.cc | 25 | ||||
-rw-r--r-- | klm/util/sorted_uniform.hh | 120 |
28 files changed, 926 insertions, 372 deletions
@@ -1,6 +1,5 @@ To build cdec, you'll need: - * SRILM (register and download from http://www.speech.sri.com/projects/srilm/) * Google c++ testing framework (http://code.google.com/p/googletest/) * boost headers & boost program_options (you may need to install a package like boost-devel) @@ -9,7 +8,7 @@ To build cdec, you'll need: Instructions for building ----------------------------------- - 1) Download and build SRILM + 1) Optional: Download and build SRILM 2) Download, build, and install Google Test (optional, this is necessary to build unit tests that may be useful in development; system tests @@ -22,7 +21,7 @@ Instructions for building 4) Configure and build. Your command will look something like this. - ./configure --with-srilm=/home/me/software/srilm-1.5.9 --disable-gtest + ./configure --disable-gtest make If you get errors during configure about missing BOOST macros, then step 3 diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 37aa655b..31f88a4f 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -55,6 +55,7 @@ void register_feature_functions() { ff_registry.Register("CMR2008ReorderingFeatures", new FFFactory<CMR2008ReorderingFeatures>()); ff_registry.Register("KLanguageModel", new FFFactory<KLanguageModel<lm::ngram::ProbingModel> >()); ff_registry.Register("KLanguageModel_Trie", new FFFactory<KLanguageModel<lm::ngram::TrieModel> >()); + ff_registry.Register("KLanguageModel_QuantTrie", new FFFactory<KLanguageModel<lm::ngram::QuantTrieModel> >()); ff_registry.Register("KLanguageModel_Probing", new FFFactory<KLanguageModel<lm::ngram::ProbingModel> >()); ff_registry.Register("NonLatinCount", new FFFactory<NonLatinCount>); ff_registry.Register("RuleShape", new FFFactory<RuleShapeFeatures>); diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index a3bd0c5f..9b7fe2d3 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -437,4 +437,5 @@ void KLanguageModel<Model>::FinalTraversalFeatures(const void* ant_state, // instantiate templates template class KLanguageModel<lm::ngram::ProbingModel>; template class KLanguageModel<lm::ngram::TrieModel>; +template class KLanguageModel<lm::ngram::QuantTrieModel>; diff --git a/klm/compile.sh b/klm/compile.sh index 49e04db8..6ca85e1f 100755 --- a/klm/compile.sh +++ b/klm/compile.sh @@ -3,11 +3,9 @@ #If your code uses ICU, edit util/string_piece.hh and uncomment #define USE_ICU #I use zlib by default. If you don't want to depend on zlib, remove #define USE_ZLIB from util/file_piece.hh -#don't need to use if compiling with moses Makefiles already - set -e -for i in util/{bit_packing,ersatz_progress,exception,file_piece,murmur_hash,scoped,mmap} lm/{binary_format,config,lm_exception,model,read_arpa,search_hashed,search_trie,trie,virtual_interface,vocab}; do +for i in util/{bit_packing,ersatz_progress,exception,file_piece,murmur_hash,scoped,mmap} lm/{binary_format,config,lm_exception,model,quantize,read_arpa,search_hashed,search_trie,trie,virtual_interface,vocab}; do g++ -I. -O3 $CXXFLAGS -c $i.cc -o $i.o done g++ -I. -O3 $CXXFLAGS lm/build_binary.cc {lm,util}/*.o -lz -o build_binary diff --git a/klm/lm/Makefile.am b/klm/lm/Makefile.am index 61d98d97..395494bc 100644 --- a/klm/lm/Makefile.am +++ b/klm/lm/Makefile.am @@ -15,6 +15,7 @@ libklm_a_SOURCES = \ binary_format.cc \ config.cc \ lm_exception.cc \ + quantize.cc \ model.cc \ ngram_query.cc \ read_arpa.cc \ diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index 34d9ffca..92b1008b 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -80,6 +80,14 @@ void WriteHeader(void *to, const Parameters ¶ms) { } // namespace +void SeekOrThrow(int fd, off_t off) { + if ((off_t)-1 == lseek(fd, off, SEEK_SET)) UTIL_THROW(util::ErrnoException, "Seek failed"); +} + +void AdvanceOrThrow(int fd, off_t off) { + if ((off_t)-1 == lseek(fd, off, SEEK_CUR)) UTIL_THROW(util::ErrnoException, "Seek failed"); +} + uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing) { if (config.write_mmap) { std::size_t total = TotalHeaderSize(order) + memory_size; @@ -156,7 +164,7 @@ bool IsBinaryFormat(int fd) { } void ReadHeader(int fd, Parameters &out) { - if ((off_t)-1 == lseek(fd, sizeof(Sanity), SEEK_SET)) UTIL_THROW(util::ErrnoException, "Seek failed in binary file"); + SeekOrThrow(fd, sizeof(Sanity)); ReadLoop(fd, &out.fixed, sizeof(out.fixed)); if (out.fixed.probing_multiplier < 1.0) UTIL_THROW(FormatLoadException, "Binary format claims to have a probing multiplier of " << out.fixed.probing_multiplier << " which is < 1.0."); @@ -173,6 +181,10 @@ void MatchCheck(ModelType model_type, const Parameters ¶ms) { } } +void SeekPastHeader(int fd, const Parameters ¶ms) { + SeekOrThrow(fd, TotalHeaderSize(params.counts.size())); +} + uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t memory_size, Backing &backing) { const off_t file_size = util::SizeFile(backing.file.get()); // The header is smaller than a page, so we have to map the whole header as well. @@ -186,8 +198,7 @@ uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t UTIL_THROW(FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary."); if (config.enumerate_vocab) { - if ((off_t)-1 == lseek(backing.file.get(), total_map, SEEK_SET)) - UTIL_THROW(util::ErrnoException, "Failed to seek in binary file to vocab words"); + SeekOrThrow(backing.file.get(), total_map); } return reinterpret_cast<uint8_t*>(backing.search.get()) + TotalHeaderSize(params.counts.size()); } diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index 1fc71be4..2b32b450 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -16,7 +16,7 @@ namespace lm { namespace ngram { -typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2} ModelType; +typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3} ModelType; /*Inspect a file to determine if it is a binary lm. If not, return false. * If so, return true and set recognized to the type. This is the only API in @@ -48,6 +48,10 @@ struct Backing { util::scoped_memory search; }; +void SeekOrThrow(int fd, off_t off); +// Seek forward +void AdvanceOrThrow(int fd, off_t off); + // Create just enough of a binary file to write vocabulary to it. uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing); // Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin. @@ -65,6 +69,8 @@ void ReadHeader(int fd, Parameters ¶ms); void MatchCheck(ModelType model_type, const Parameters ¶ms); +void SeekPastHeader(int fd, const Parameters ¶ms); + uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t memory_size, Backing &backing); void ComplainAboutARPA(const Config &config, ModelType model_type); @@ -83,6 +89,8 @@ template <class To> void LoadLM(const char *file, const Config &config, To &to) // Replace the run-time configured probing_multiplier with the one in the file. Config new_config(config); new_config.probing_multiplier = params.fixed.probing_multiplier; + detail::SeekPastHeader(backing.file.get(), params); + To::UpdateConfigFromBinary(backing.file.get(), params.counts, new_config); std::size_t memory_size = To::Size(params.counts, new_config); uint8_t *start = detail::SetupBinary(new_config, params, memory_size, backing); to.InitializeFromBinary(start, params, new_config, backing.file.get()); diff --git a/klm/lm/blank.hh b/klm/lm/blank.hh index 4615a09e..162411a9 100644 --- a/klm/lm/blank.hh +++ b/klm/lm/blank.hh @@ -22,6 +22,8 @@ namespace ngram { */ const float kNoExtensionBackoff = -0.0; const float kExtensionBackoff = 0.0; +const uint64_t kNoExtensionQuant = 0; +const uint64_t kExtensionQuant = 1; inline void SetExtension(float &backoff) { if (backoff == kNoExtensionBackoff) backoff = kExtensionBackoff; @@ -47,6 +49,8 @@ inline bool HasExtension(const float &backoff) { */ const float kBlankProb = -std::numeric_limits<float>::infinity(); const float kBlankBackoff = kNoExtensionBackoff; +const uint32_t kBlankProbQuant = 0; +const uint32_t kBlankBackoffQuant = 0; } // namespace ngram } // namespace lm diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index 91ad2fb9..4552c419 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -15,22 +15,21 @@ namespace ngram { namespace { void Usage(const char *name) { - std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-n] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [type] input.arpa output.mmap\n\n" + std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-n] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [type] input.arpa output.mmap\n\n" "-u sets the default log10 probability for <unk> if the ARPA file does not have\n" "one.\n" "-s allows models to be built even if they do not have <s> and </s>.\n" "-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n" -"type is one of probing, trie, or sorted:\n\n" +"type is either probing or trie:\n\n" "probing uses a probing hash table. It is the fastest but uses the most memory.\n" "-p sets the space multiplier and must be >1.0. The default is 1.5.\n\n" "trie is a straightforward trie with bit-level packing. It uses the least\n" "memory and is still faster than SRI or IRST. Building the trie format uses an\n" "on-disk sort to save memory.\n" "-t is the temporary directory prefix. Default is the output file name.\n" -"-m limits memory use for sorting. Measured in MB. Default is 1024MB.\n\n" -/*"sorted is like probing but uses a sorted uniform map instead of a hash table.\n" -"It uses more memory than trie and is also slower, so there's no real reason to\n" -"use it.\n\n"*/ +"-m limits memory use for sorting. Measured in MB. Default is 1024MB.\n" +"-q turns quantization on and sets the number of bits (e.g. -q 8).\n" +"-b sets backoff quantization bits. Requires -q and defaults to that value.\n\n" "See http://kheafield.com/code/kenlm/benchmark/ for data structure benchmarks.\n" "Passing only an input file will print memory usage of each data structure.\n" "If the ARPA file does not have <unk>, -u sets <unk>'s probability; default 0.0.\n"; @@ -51,19 +50,53 @@ unsigned long int ParseUInt(const char *from) { return ret; } +uint8_t ParseBitCount(const char *from) { + unsigned long val = ParseUInt(from); + if (val > 25) { + util::ParseNumberException e(from); + e << " bit counts are limited to 256."; + } + return val; +} + void ShowSizes(const char *file, const lm::ngram::Config &config) { std::vector<uint64_t> counts; util::FilePiece f(file); lm::ReadARPACounts(f, counts); - std::size_t probing_size = ProbingModel::Size(counts, config); - // probing is always largest so use it to determine number of columns. - long int length = std::max<long int>(5, lrint(ceil(log10(probing_size)))); + std::size_t sizes[3]; + sizes[0] = ProbingModel::Size(counts, config); + sizes[1] = TrieModel::Size(counts, config); + sizes[2] = QuantTrieModel::Size(counts, config); + std::size_t max_length = *std::max_element(sizes, sizes + 3); + std::size_t min_length = *std::max_element(sizes, sizes + 3); + std::size_t divide; + char prefix; + if (min_length < (1 << 10) * 10) { + prefix = ' '; + divide = 1; + } else if (min_length < (1 << 20) * 10) { + prefix = 'k'; + divide = 1 << 10; + } else if (min_length < (1ULL << 30) * 10) { + prefix = 'M'; + divide = 1 << 20; + } else { + prefix = 'G'; + divide = 1 << 30; + } + long int length = std::max<long int>(2, lrint(ceil(log10(max_length / divide)))); std::cout << "Memory estimate:\ntype "; // right align bytes. - for (long int i = 0; i < length - 5; ++i) std::cout << ' '; - std::cout << "bytes\n" - "probing " << std::setw(length) << probing_size << " assuming -p " << config.probing_multiplier << "\n" - "trie " << std::setw(length) << TrieModel::Size(counts, config) << "\n"; + for (long int i = 0; i < length - 2; ++i) std::cout << ' '; + std::cout << prefix << "B\n" + "probing " << std::setw(length) << (sizes[0] / divide) << " assuming -p " << config.probing_multiplier << "\n" + "trie " << std::setw(length) << (sizes[1] / divide) << " without quantization\n" + "trie " << std::setw(length) << (sizes[2] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n"; +} + +void ProbingQuantizationUnsupported() { + std::cerr << "Quantization is only implemented in the trie data structure." << std::endl; + exit(1); } } // namespace ngram @@ -73,11 +106,21 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) { int main(int argc, char *argv[]) { using namespace lm::ngram; + bool quantize = false, set_backoff_bits = false; try { lm::ngram::Config config; int opt; - while ((opt = getopt(argc, argv, "siu:p:t:m:")) != -1) { + while ((opt = getopt(argc, argv, "siu:p:t:m:q:b:")) != -1) { switch(opt) { + case 'q': + config.prob_bits = ParseBitCount(optarg); + if (!set_backoff_bits) config.backoff_bits = config.prob_bits; + quantize = true; + break; + case 'b': + config.backoff_bits = ParseBitCount(optarg); + set_backoff_bits = true; + break; case 'u': config.unknown_missing_logprob = ParseFloat(optarg); break; @@ -100,19 +143,29 @@ int main(int argc, char *argv[]) { Usage(argv[0]); } } + if (!quantize && set_backoff_bits) { + std::cerr << "You specified backoff quantization (-b) but not probability quantization (-q)" << std::endl; + abort(); + } if (optind + 1 == argc) { ShowSizes(argv[optind], config); } else if (optind + 2 == argc) { config.write_mmap = argv[optind + 1]; + if (quantize || set_backoff_bits) ProbingQuantizationUnsupported(); ProbingModel(argv[optind], config); } else if (optind + 3 == argc) { const char *model_type = argv[optind]; const char *from_file = argv[optind + 1]; config.write_mmap = argv[optind + 2]; if (!strcmp(model_type, "probing")) { + if (quantize || set_backoff_bits) ProbingQuantizationUnsupported(); ProbingModel(from_file, config); } else if (!strcmp(model_type, "trie")) { - TrieModel(from_file, config); + if (quantize) { + QuantTrieModel(from_file, config); + } else { + TrieModel(from_file, config); + } } else { Usage(argv[0]); } diff --git a/klm/lm/config.cc b/klm/lm/config.cc index cee8fce2..08e1af5c 100644 --- a/klm/lm/config.cc +++ b/klm/lm/config.cc @@ -18,6 +18,8 @@ Config::Config() : arpa_complain(ALL), write_mmap(NULL), include_vocab(true), + prob_bits(8), + backoff_bits(8), load_method(util::POPULATE_OR_READ) {} } // namespace ngram diff --git a/klm/lm/config.hh b/klm/lm/config.hh index 6c7fe39b..dcc7cf35 100644 --- a/klm/lm/config.hh +++ b/klm/lm/config.hh @@ -71,6 +71,11 @@ struct Config { // Include the vocab in the binary file? Only effective if write_mmap != NULL. bool include_vocab; + // Quantization options. Only effective for QuantTrieModel. One value is + // reserved for each of prob and backoff, so 2^bits - 1 buckets will be used + // to quantize. + uint8_t prob_bits, backoff_bits; + // ONLY EFFECTIVE WHEN READING BINARY diff --git a/klm/lm/model.cc b/klm/lm/model.cc index f0579c0c..a1d10b3d 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -44,17 +44,13 @@ template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::Ge begin_sentence.backoff_[0] = search_.unigram.Lookup(begin_sentence.history_[0]).backoff; State null_context = State(); null_context.valid_length_ = 0; - P::Init(begin_sentence, null_context, vocab_, search_.middle.size() + 2); + P::Init(begin_sentence, null_context, vocab_, search_.MiddleEnd() - search_.MiddleBegin() + 2); } template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromBinary(void *start, const Parameters ¶ms, const Config &config, int fd) { SetupMemory(start, params.counts, config); vocab_.LoadedBinary(fd, config.enumerate_vocab); - search_.unigram.LoadedBinary(); - for (typename std::vector<Middle>::iterator i = search_.middle.begin(); i != search_.middle.end(); ++i) { - i->LoadedBinary(); - } - search_.longest.LoadedBinary(); + search_.LoadedBinary(); } template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(const char *file, const Config &config) { @@ -116,8 +112,9 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, } float backoff; // i is the order of the backoff we're looking for. - for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i) { - if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, backoff, node)) break; + const Middle *mid_iter = search_.MiddleBegin() + start - 2; + for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i, ++mid_iter) { + if (!search_.LookupMiddleNoProb(*mid_iter, *i, backoff, node)) break; ret.prob += backoff; } return ret; @@ -135,7 +132,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT search_.LookupUnigram(*context_rbegin, ignored_prob, out_state.backoff_[0], node); out_state.valid_length_ = HasExtension(out_state.backoff_[0]) ? 1 : 0; float *backoff_out = out_state.backoff_ + 1; - const typename Search::Middle *mid = &*search_.middle.begin(); + const typename Search::Middle *mid = search_.MiddleBegin(); for (const WordIndex *i = context_rbegin + 1; i < context_rend; ++i, ++backoff_out, ++mid) { if (!search_.LookupMiddleNoProb(*mid, *i, *backoff_out, node)) { std::copy(context_rbegin, context_rbegin + out_state.valid_length_, out_state.history_); @@ -183,7 +180,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, // Ok now we now that the bigram contains known words. Start by looking it up. const WordIndex *hist_iter = context_rbegin; - typename std::vector<Middle>::const_iterator mid_iter = search_.middle.begin(); + const typename Search::Middle *mid_iter = search_.MiddleBegin(); for (; ; ++mid_iter, ++hist_iter, ++backoff_out) { if (hist_iter == context_rend) { // Ran out of history. Typically no backoff, but this could be a blank. @@ -192,7 +189,7 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, return ret; } - if (mid_iter == search_.middle.end()) break; + if (mid_iter == search_.MiddleEnd()) break; float revert = ret.prob; if (!search_.LookupMiddle(*mid_iter, *hist_iter, ret.prob, *backoff_out, node)) { @@ -227,9 +224,9 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, return ret; } -template class GenericModel<ProbingHashedSearch, ProbingVocabulary>; -template class GenericModel<SortedHashedSearch, SortedVocabulary>; -template class GenericModel<trie::TrieSearch, SortedVocabulary>; +template class GenericModel<ProbingHashedSearch, ProbingVocabulary>; // HASH_PROBING +template class GenericModel<trie::TrieSearch<DontQuantize>, SortedVocabulary>; // TRIE_SORTED +template class GenericModel<trie::TrieSearch<SeparatelyQuantize>, SortedVocabulary>; // TRIE_SORTED_QUANT } // namespace detail } // namespace ngram diff --git a/klm/lm/model.hh b/klm/lm/model.hh index b85ccdcc..1f49a382 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -5,6 +5,7 @@ #include "lm/config.hh" #include "lm/facade.hh" #include "lm/max_order.hh" +#include "lm/quantize.hh" #include "lm/search_hashed.hh" #include "lm/search_trie.hh" #include "lm/vocab.hh" @@ -70,9 +71,10 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod private: typedef base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> P; public: - // Get the size of memory that will be mapped given ngram counts. This - // does not include small non-mapped control structures, such as this class - // itself. + /* Get the size of memory that will be mapped given ngram counts. This + * does not include small non-mapped control structures, such as this class + * itself. + */ static size_t Size(const std::vector<uint64_t> &counts, const Config &config = Config()); /* Load the model from a file. It may be an ARPA or binary file. Binary @@ -111,6 +113,11 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod private: friend void LoadLM<>(const char *file, const Config &config, GenericModel<Search, VocabularyT> &to); + static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) { + AdvanceOrThrow(fd, VocabularyT::Size(counts[0], config)); + Search::UpdateConfigFromBinary(fd, counts, config); + } + float SlowBackoffLookup(const WordIndex *const context_rbegin, const WordIndex *const context_rend, unsigned char start) const; FullScoreReturn ScoreExceptBackoff(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const; @@ -130,9 +137,7 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod VocabularyT vocab_; - typedef typename Search::Unigram Unigram; typedef typename Search::Middle Middle; - typedef typename Search::Longest Longest; Search search_; }; @@ -141,13 +146,15 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod // These must also be instantiated in the cc file. typedef ::lm::ngram::ProbingVocabulary Vocabulary; -typedef detail::GenericModel<detail::ProbingHashedSearch, Vocabulary> ProbingModel; +typedef detail::GenericModel<detail::ProbingHashedSearch, Vocabulary> ProbingModel; // HASH_PROBING // Default implementation. No real reason for it to be the default. typedef ProbingModel Model; // Smaller implementation. typedef ::lm::ngram::SortedVocabulary SortedVocabulary; -typedef detail::GenericModel<trie::TrieSearch, SortedVocabulary> TrieModel; +typedef detail::GenericModel<trie::TrieSearch<DontQuantize>, SortedVocabulary> TrieModel; // TRIE_SORTED + +typedef detail::GenericModel<trie::TrieSearch<SeparatelyQuantize>, SortedVocabulary> QuantTrieModel; // QUANT_TRIE_SORTED } // namespace ngram } // namespace lm diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index 548c098d..8bf040ff 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -243,13 +243,14 @@ BOOST_AUTO_TEST_CASE(probing) { LoadingTest<Model>(); } -/*BOOST_AUTO_TEST_CASE(sorted) { - LoadingTest<SortedModel>(); -}*/ BOOST_AUTO_TEST_CASE(trie) { LoadingTest<TrieModel>(); } +BOOST_AUTO_TEST_CASE(quant) { + LoadingTest<QuantTrieModel>(); +} + template <class ModelT> void BinaryTest() { Config config; config.write_mmap = "test.binary"; @@ -275,12 +276,12 @@ template <class ModelT> void BinaryTest() { BOOST_AUTO_TEST_CASE(write_and_read_probing) { BinaryTest<Model>(); } -/*BOOST_AUTO_TEST_CASE(write_and_read_sorted) { - BinaryTest<SortedModel>(); -}*/ BOOST_AUTO_TEST_CASE(write_and_read_trie) { BinaryTest<TrieModel>(); } +BOOST_AUTO_TEST_CASE(write_and_read_quant_trie) { + BinaryTest<QuantTrieModel>(); +} } // namespace } // namespace ngram diff --git a/klm/lm/quantize.cc b/klm/lm/quantize.cc new file mode 100644 index 00000000..4bb6b1b8 --- /dev/null +++ b/klm/lm/quantize.cc @@ -0,0 +1,84 @@ +#include "lm/quantize.hh" + +#include "lm/lm_exception.hh" + +#include <algorithm> +#include <numeric> + +#include <unistd.h> + +namespace lm { +namespace ngram { + +/* Quantize into bins of equal size as described in + * M. Federico and N. Bertoldi. 2006. How many bits are needed + * to store probabilities for phrase-based translation? In Proc. + * of the Workshop on Statistical Machine Translation, pages + * 94–101, New York City, June. Association for Computa- + * tional Linguistics. + */ + +namespace { + +void MakeBins(float *values, float *values_end, float *centers, uint32_t bins) { + std::sort(values, values_end); + const float *start = values, *finish; + for (uint32_t i = 0; i < bins; ++i, ++centers, start = finish) { + finish = values + (((values_end - values) * static_cast<uint64_t>(i + 1)) / bins); + if (finish == start) { + // zero length bucket. + *centers = i ? *(centers - 1) : -std::numeric_limits<float>::infinity(); + } else { + *centers = std::accumulate(start, finish, 0.0) / static_cast<float>(finish - start); + } + } +} + +const char kSeparatelyQuantizeVersion = 2; + +} // namespace + +void SeparatelyQuantize::UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &/*counts*/, Config &config) { + char version; + if (read(fd, &version, 1) != 1 || read(fd, &config.prob_bits, 1) != 1 || read(fd, &config.backoff_bits, 1) != 1) + UTIL_THROW(util::ErrnoException, "Failed to read header for quantization."); + if (version != kSeparatelyQuantizeVersion) UTIL_THROW(FormatLoadException, "This file has quantization version " << (unsigned)version << " but the code expects version " << (unsigned)kSeparatelyQuantizeVersion); +} + +void SeparatelyQuantize::SetupMemory(void *start, const Config &config) { + // Reserve 8 byte header for bit counts. + start_ = reinterpret_cast<float*>(static_cast<uint8_t*>(start) + 8); + prob_bits_ = config.prob_bits; + backoff_bits_ = config.backoff_bits; + // We need the reserved values. + if (config.prob_bits == 0) UTIL_THROW(ConfigException, "You can't quantize probability to zero"); + if (config.backoff_bits == 0) UTIL_THROW(ConfigException, "You can't quantize backoff to zero"); + if (config.prob_bits > 25) UTIL_THROW(ConfigException, "For efficiency reasons, quantizing probability supports at most 25 bits. Currently you have requested " << static_cast<unsigned>(config.prob_bits) << " bits."); + if (config.backoff_bits > 25) UTIL_THROW(ConfigException, "For efficiency reasons, quantizing backoff supports at most 25 bits. Currently you have requested " << static_cast<unsigned>(config.backoff_bits) << " bits."); +} + +void SeparatelyQuantize::Train(uint8_t order, std::vector<float> &prob, std::vector<float> &backoff) { + TrainProb(order, prob); + + // Backoff + float *centers = start_ + TableStart(order) + ProbTableLength(); + *(centers++) = kNoExtensionBackoff; + *(centers++) = kExtensionBackoff; + MakeBins(&*backoff.begin(), &*backoff.end(), centers, (1ULL << backoff_bits_) - 2); +} + +void SeparatelyQuantize::TrainProb(uint8_t order, std::vector<float> &prob) { + float *centers = start_ + TableStart(order); + *(centers++) = kBlankProb; + MakeBins(&*prob.begin(), &*prob.end(), centers, (1ULL << prob_bits_) - 1); +} + +void SeparatelyQuantize::FinishedLoading(const Config &config) { + uint8_t *actual_base = reinterpret_cast<uint8_t*>(start_) - 8; + *(actual_base++) = kSeparatelyQuantizeVersion; // version + *(actual_base++) = config.prob_bits; + *(actual_base++) = config.backoff_bits; +} + +} // namespace ngram +} // namespace lm diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh new file mode 100644 index 00000000..aae72b34 --- /dev/null +++ b/klm/lm/quantize.hh @@ -0,0 +1,207 @@ +#ifndef LM_QUANTIZE_H__ +#define LM_QUANTIZE_H__ + +#include "lm/binary_format.hh" // for ModelType +#include "lm/blank.hh" +#include "lm/config.hh" +#include "util/bit_packing.hh" + +#include <algorithm> +#include <vector> + +#include <inttypes.h> + +#include <iostream> + +namespace lm { +namespace ngram { + +class Config; + +/* Store values directly and don't quantize. */ +class DontQuantize { + public: + static const ModelType kModelType = TRIE_SORTED; + static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {} + static std::size_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; } + static uint8_t MiddleBits(const Config &/*config*/) { return 63; } + static uint8_t LongestBits(const Config &/*config*/) { return 31; } + + struct Middle { + void Write(void *base, uint64_t bit_offset, float prob, float backoff) const { + util::WriteNonPositiveFloat31(base, bit_offset, prob); + util::WriteFloat32(base, bit_offset + 31, backoff); + } + void Read(const void *base, uint64_t bit_offset, float &prob, float &backoff) const { + prob = util::ReadNonPositiveFloat31(base, bit_offset); + backoff = util::ReadFloat32(base, bit_offset + 31); + } + void ReadBackoff(const void *base, uint64_t bit_offset, float &backoff) const { + backoff = util::ReadFloat32(base, bit_offset + 31); + } + uint8_t TotalBits() const { return 63; } + }; + + struct Longest { + void Write(void *base, uint64_t bit_offset, float prob) const { + util::WriteNonPositiveFloat31(base, bit_offset, prob); + } + void Read(const void *base, uint64_t bit_offset, float &prob) const { + prob = util::ReadNonPositiveFloat31(base, bit_offset); + } + uint8_t TotalBits() const { return 31; } + }; + + DontQuantize() {} + + void SetupMemory(void * /*start*/, const Config & /*config*/) {} + + static const bool kTrain = false; + // These should never be called because kTrain is false. + void Train(uint8_t /*order*/, std::vector<float> &/*prob*/, std::vector<float> &/*backoff*/) {} + void TrainProb(uint8_t, std::vector<float> &/*prob*/) {} + + void FinishedLoading(const Config &) {} + + Middle Mid(uint8_t /*order*/) const { return Middle(); } + Longest Long(uint8_t /*order*/) const { return Longest(); } +}; + +class SeparatelyQuantize { + private: + class Bins { + public: + // Sigh C++ default constructor + Bins() {} + + Bins(uint8_t bits, const float *const begin) : begin_(begin), end_(begin_ + (1ULL << bits)), bits_(bits), mask_((1ULL << bits) - 1) {} + + uint64_t EncodeProb(float value) const { + return(value == kBlankProb ? kBlankProbQuant : Encode(value, 1)); + } + + uint64_t EncodeBackoff(float value) const { + if (value == 0.0) { + return HasExtension(value) ? kExtensionQuant : kNoExtensionQuant; + } + return Encode(value, 2); + } + + float Decode(std::size_t off) const { return begin_[off]; } + + uint8_t Bits() const { return bits_; } + + uint64_t Mask() const { return mask_; } + + private: + uint64_t Encode(float value, size_t reserved) const { + const float *above = std::lower_bound(begin_ + reserved, end_, value); + if (above == begin_ + reserved) return reserved; + if (above == end_) return end_ - begin_ - 1; + return above - begin_ - (value - *(above - 1) < *above - value); + } + + const float *begin_; + const float *end_; + uint8_t bits_; + uint64_t mask_; + }; + + public: + static const ModelType kModelType = QUANT_TRIE_SORTED; + + static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config); + + static std::size_t Size(uint8_t order, const Config &config) { + size_t longest_table = (static_cast<size_t>(1) << static_cast<size_t>(config.prob_bits)) * sizeof(float); + size_t middle_table = (static_cast<size_t>(1) << static_cast<size_t>(config.backoff_bits)) * sizeof(float) + longest_table; + // unigrams are currently not quantized so no need for a table. + return (order - 2) * middle_table + longest_table + /* for the bit counts and alignment padding) */ 8; + } + + static uint8_t MiddleBits(const Config &config) { return config.prob_bits + config.backoff_bits; } + static uint8_t LongestBits(const Config &config) { return config.prob_bits; } + + class Middle { + public: + Middle(uint8_t prob_bits, const float *prob_begin, uint8_t backoff_bits, const float *backoff_begin) : + total_bits_(prob_bits + backoff_bits), total_mask_((1ULL << total_bits_) - 1), prob_(prob_bits, prob_begin), backoff_(backoff_bits, backoff_begin) {} + + void Write(void *base, uint64_t bit_offset, float prob, float backoff) const { + util::WriteInt57(base, bit_offset, total_bits_, + (prob_.EncodeProb(prob) << backoff_.Bits()) | backoff_.EncodeBackoff(backoff)); + } + + void Read(const void *base, uint64_t bit_offset, float &prob, float &backoff) const { + uint64_t both = util::ReadInt57(base, bit_offset, total_bits_, total_mask_); + prob = prob_.Decode(both >> backoff_.Bits()); + backoff = backoff_.Decode(both & backoff_.Mask()); + } + + void ReadBackoff(const void *base, uint64_t bit_offset, float &backoff) const { + backoff = backoff_.Decode(util::ReadInt25(base, bit_offset, backoff_.Bits(), backoff_.Mask())); + } + + uint8_t TotalBits() const { + return total_bits_; + } + + private: + const uint8_t total_bits_; + const uint64_t total_mask_; + const Bins prob_; + const Bins backoff_; + }; + + class Longest { + public: + // Sigh C++ default constructor + Longest() {} + + Longest(uint8_t prob_bits, const float *prob_begin) : prob_(prob_bits, prob_begin) {} + + void Write(void *base, uint64_t bit_offset, float prob) const { + util::WriteInt25(base, bit_offset, prob_.Bits(), prob_.EncodeProb(prob)); + } + + void Read(const void *base, uint64_t bit_offset, float &prob) const { + prob = prob_.Decode(util::ReadInt25(base, bit_offset, prob_.Bits(), prob_.Mask())); + } + + uint8_t TotalBits() const { return prob_.Bits(); } + + private: + Bins prob_; + }; + + SeparatelyQuantize() {} + + void SetupMemory(void *start, const Config &config); + + static const bool kTrain = true; + // Assumes kBlankProb is removed from prob and 0.0 is removed from backoff. + void Train(uint8_t order, std::vector<float> &prob, std::vector<float> &backoff); + // Train just probabilities (for longest order). + void TrainProb(uint8_t order, std::vector<float> &prob); + + void FinishedLoading(const Config &config); + + Middle Mid(uint8_t order) const { + const float *table = start_ + TableStart(order); + return Middle(prob_bits_, table, backoff_bits_, table + ProbTableLength()); + } + + Longest Long(uint8_t order) const { return Longest(prob_bits_, start_ + TableStart(order)); } + + private: + size_t TableStart(uint8_t order) const { return ((1ULL << prob_bits_) + (1ULL << backoff_bits_)) * static_cast<uint64_t>(order - 2); } + size_t ProbTableLength() const { return (1ULL << prob_bits_); } + + float *start_; + uint8_t prob_bits_, backoff_bits_; +}; + +} // namespace ngram +} // namespace lm + +#endif // LM_QUANTIZE_H__ diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index eaad59ab..c56ba7b8 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -80,6 +80,21 @@ template <class Voc, class Store, class Middle, class Activate> void ReadNGrams( } // namespace namespace detail { + +template <class MiddleT, class LongestT> uint8_t *TemplateHashedSearch<MiddleT, LongestT>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) { + std::size_t allocated = Unigram::Size(counts[0]); + unigram = Unigram(start, allocated); + start += allocated; + for (unsigned int n = 2; n < counts.size(); ++n) { + allocated = Middle::Size(counts[n - 1], config.probing_multiplier); + middle_.push_back(Middle(start, allocated)); + start += allocated; + } + allocated = Longest::Size(counts.back(), config.probing_multiplier); + longest = Longest(start, allocated); + start += allocated; + return start; +} template <class MiddleT, class LongestT> template <class Voc> void TemplateHashedSearch<MiddleT, LongestT>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing) { // TODO: fix sorted. @@ -92,15 +107,15 @@ template <class MiddleT, class LongestT> template <class Voc> void TemplateHashe try { if (counts.size() > 2) { - ReadNGrams(f, 2, counts[1], vocab, middle, ActivateUnigram(unigram.Raw()), middle[0], warn); + ReadNGrams(f, 2, counts[1], vocab, middle_, ActivateUnigram(unigram.Raw()), middle_[0], warn); } for (unsigned int n = 3; n < counts.size(); ++n) { - ReadNGrams(f, n, counts[n-1], vocab, middle, ActivateLowerMiddle<Middle>(middle[n-3]), middle[n-2], warn); + ReadNGrams(f, n, counts[n-1], vocab, middle_, ActivateLowerMiddle<Middle>(middle_[n-3]), middle_[n-2], warn); } if (counts.size() > 2) { - ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle, ActivateLowerMiddle<Middle>(middle.back()), longest, warn); + ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle_, ActivateLowerMiddle<Middle>(middle_.back()), longest, warn); } else { - ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle, ActivateUnigram(unigram.Raw()), longest, warn); + ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle_, ActivateUnigram(unigram.Raw()), longest, warn); } } catch (util::ProbingSizeException &e) { UTIL_THROW(util::ProbingSizeException, "Avoid pruning n-grams like \"bar baz quux\" when \"foo bar baz quux\" is still in the model. KenLM will work when this pruning happens, but the probing model assumes these events are rare enough that using blank space in the probing hash table will cover all of them. Increase probing_multiplier (-p to build_binary) to add more blank spaces.\n"); @@ -108,13 +123,18 @@ template <class MiddleT, class LongestT> template <class Voc> void TemplateHashe ReadEnd(f); } -template void TemplateHashedSearch<ProbingHashedSearch::Middle, ProbingHashedSearch::Longest>::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &, ProbingVocabulary &vocab, Backing &backing); -template void TemplateHashedSearch<SortedHashedSearch::Middle, SortedHashedSearch::Longest>::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &, SortedVocabulary &vocab, Backing &backing); - -SortedHashedSearch::SortedHashedSearch() { - UTIL_THROW(util::Exception, "Sorted is broken at the moment, sorry"); +template <class MiddleT, class LongestT> void TemplateHashedSearch<MiddleT, LongestT>::LoadedBinary() { + unigram.LoadedBinary(); + for (typename std::vector<Middle>::iterator i = middle_.begin(); i != middle_.end(); ++i) { + i->LoadedBinary(); + } + longest.LoadedBinary(); } +template class TemplateHashedSearch<ProbingHashedSearch::Middle, ProbingHashedSearch::Longest>; + +template void TemplateHashedSearch<ProbingHashedSearch::Middle, ProbingHashedSearch::Longest>::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &, ProbingVocabulary &vocab, Backing &backing); + } // namespace detail } // namespace ngram } // namespace lm diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index 6dc11fb3..f3acdefc 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -8,7 +8,6 @@ #include "util/key_value_packing.hh" #include "util/probing_hash_table.hh" -#include "util/sorted_uniform.hh" #include <algorithm> #include <vector> @@ -62,73 +61,71 @@ struct HashedSearch { } }; -template <class MiddleT, class LongestT> struct TemplateHashedSearch : public HashedSearch { - typedef MiddleT Middle; - std::vector<Middle> middle; +template <class MiddleT, class LongestT> class TemplateHashedSearch : public HashedSearch { + public: + typedef MiddleT Middle; - typedef LongestT Longest; - Longest longest; + typedef LongestT Longest; + Longest longest; - static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) { - std::size_t ret = Unigram::Size(counts[0]); - for (unsigned char n = 1; n < counts.size() - 1; ++n) { - ret += Middle::Size(counts[n], config.probing_multiplier); - } - return ret + Longest::Size(counts.back(), config.probing_multiplier); - } + // TODO: move probing_multiplier here with next binary file format update. + static void UpdateConfigFromBinary(int, const std::vector<uint64_t> &, Config &) {} - uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) { - std::size_t allocated = Unigram::Size(counts[0]); - unigram = Unigram(start, allocated); - start += allocated; - for (unsigned int n = 2; n < counts.size(); ++n) { - allocated = Middle::Size(counts[n - 1], config.probing_multiplier); - middle.push_back(Middle(start, allocated)); - start += allocated; + static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) { + std::size_t ret = Unigram::Size(counts[0]); + for (unsigned char n = 1; n < counts.size() - 1; ++n) { + ret += Middle::Size(counts[n], config.probing_multiplier); + } + return ret + Longest::Size(counts.back(), config.probing_multiplier); } - allocated = Longest::Size(counts.back(), config.probing_multiplier); - longest = Longest(start, allocated); - start += allocated; - return start; - } - template <class Voc> void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing); + uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config); - bool LookupMiddle(const Middle &middle, WordIndex word, float &prob, float &backoff, Node &node) const { - node = CombineWordHash(node, word); - typename Middle::ConstIterator found; - if (!middle.Find(node, found)) return false; - prob = found->GetValue().prob; - backoff = found->GetValue().backoff; - return true; - } + template <class Voc> void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing); - bool LookupMiddleNoProb(const Middle &middle, WordIndex word, float &backoff, Node &node) const { - node = CombineWordHash(node, word); - typename Middle::ConstIterator found; - if (!middle.Find(node, found)) return false; - backoff = found->GetValue().backoff; - return true; - } + const Middle *MiddleBegin() const { return &*middle_.begin(); } + const Middle *MiddleEnd() const { return &*middle_.end(); } - bool LookupLongest(WordIndex word, float &prob, Node &node) const { - node = CombineWordHash(node, word); - typename Longest::ConstIterator found; - if (!longest.Find(node, found)) return false; - prob = found->GetValue().prob; - return true; - } + bool LookupMiddle(const Middle &middle, WordIndex word, float &prob, float &backoff, Node &node) const { + node = CombineWordHash(node, word); + typename Middle::ConstIterator found; + if (!middle.Find(node, found)) return false; + prob = found->GetValue().prob; + backoff = found->GetValue().backoff; + return true; + } + + void LoadedBinary(); - // Geenrate a node without necessarily checking that it actually exists. - // Optionally return false if it's know to not exist. - bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { - assert(begin != end); - node = static_cast<Node>(*begin); - for (const WordIndex *i = begin + 1; i < end; ++i) { - node = CombineWordHash(node, *i); + bool LookupMiddleNoProb(const Middle &middle, WordIndex word, float &backoff, Node &node) const { + node = CombineWordHash(node, word); + typename Middle::ConstIterator found; + if (!middle.Find(node, found)) return false; + backoff = found->GetValue().backoff; + return true; } - return true; - } + + bool LookupLongest(WordIndex word, float &prob, Node &node) const { + node = CombineWordHash(node, word); + typename Longest::ConstIterator found; + if (!longest.Find(node, found)) return false; + prob = found->GetValue().prob; + return true; + } + + // Geenrate a node without necessarily checking that it actually exists. + // Optionally return false if it's know to not exist. + bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { + assert(begin != end); + node = static_cast<Node>(*begin); + for (const WordIndex *i = begin + 1; i < end; ++i) { + node = CombineWordHash(node, *i); + } + return true; + } + + private: + std::vector<Middle> middle_; }; // std::identity is an SGI extension :-( @@ -143,15 +140,6 @@ struct ProbingHashedSearch : public TemplateHashedSearch< static const ModelType kModelType = HASH_PROBING; }; -struct SortedHashedSearch : public TemplateHashedSearch< - util::SortedUniformMap<util::ByteAlignedPacking<uint64_t, ProbBackoff> >, - util::SortedUniformMap<util::ByteAlignedPacking<uint64_t, Prob> > > { - - SortedHashedSearch(); - - static const ModelType kModelType = HASH_SORTED; -}; - } // namespace detail } // namespace ngram } // namespace lm diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 7c57072b..91f87f1c 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -4,6 +4,7 @@ #include "lm/blank.hh" #include "lm/lm_exception.hh" #include "lm/max_order.hh" +#include "lm/quantize.hh" #include "lm/read_arpa.hh" #include "lm/trie.hh" #include "lm/vocab.hh" @@ -21,6 +22,7 @@ #include <cstdio> #include <deque> #include <limits> +#include <numeric> #include <vector> #include <sys/mman.h> @@ -579,7 +581,7 @@ bool HeadMatch(const WordIndex *words, const WordIndex *const words_end, const W // Phase to count n-grams, including blanks inserted because they were pruned but have extensions class JustCount { public: - JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, BitPackedMiddle * /*middle*/, BitPackedLongest &/*longest*/, uint64_t *counts, unsigned char order) + template <class Middle, class Longest> JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, Middle * /*middle*/, Longest &/*longest*/, uint64_t *counts, unsigned char order) : counts_(counts), longest_counts_(counts + order - 1) {} void Unigrams(WordIndex begin, WordIndex end) { @@ -608,9 +610,9 @@ class JustCount { }; // Phase to actually write n-grams to the trie. -class WriteEntries { +template <class Quant> class WriteEntries { public: - WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) : + WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle<typename Quant::Middle> *middle, BitPackedLongest<typename Quant::Longest> &longest, const uint64_t * /*counts*/, unsigned char order) : contexts_(contexts), unigrams_(unigrams), middle_(middle), @@ -647,14 +649,14 @@ class WriteEntries { private: ContextReader *contexts_; UnigramValue *const unigrams_; - BitPackedMiddle *const middle_; - BitPackedLongest &longest_; + BitPackedMiddle<typename Quant::Middle> *const middle_; + BitPackedLongest<typename Quant::Longest> &longest_; BitPacked &bigram_pack_; }; template <class Doing> class RecursiveInsert { public: - RecursiveInsert(SortedFileReader *inputs, ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, uint64_t *counts, unsigned char order) : + template <class MiddleT, class LongestT> RecursiveInsert(SortedFileReader *inputs, ContextReader *contexts, UnigramValue *unigrams, MiddleT *middle, LongestT &longest, uint64_t *counts, unsigned char order) : doing_(contexts, unigrams, middle, longest, counts, order), inputs_(inputs), inputs_end_(inputs + order - 1), order_minus_2_(order - 2) { } @@ -775,7 +777,51 @@ void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<u } } -void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch &out, Backing &backing) { +bool IsDirectory(const char *path) { + struct stat info; + if (0 != stat(path, &info)) return false; + return S_ISDIR(info.st_mode); +} + +template <class Quant> void TrainQuantizer(uint8_t order, uint64_t count, SortedFileReader &reader, util::ErsatzProgress &progress, Quant &quant) { + ProbBackoff weights; + std::vector<float> probs, backoffs; + probs.reserve(count); + backoffs.reserve(count); + for (reader.Rewind(); !reader.Ended(); reader.NextHeader()) { + uint64_t entries = reader.ReadCount(); + for (uint64_t c = 0; c < entries; ++c) { + reader.ReadWord(); + reader.ReadWeights(weights); + // kBlankProb isn't added yet. + probs.push_back(weights.prob); + if (weights.backoff != 0.0) backoffs.push_back(weights.backoff); + ++progress; + } + } + quant.Train(order, probs, backoffs); +} + +template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, SortedFileReader &reader, util::ErsatzProgress &progress, Quant &quant) { + Prob weights; + std::vector<float> probs, backoffs; + probs.reserve(count); + for (reader.Rewind(); !reader.Ended(); reader.NextHeader()) { + uint64_t entries = reader.ReadCount(); + for (uint64_t c = 0; c < entries; ++c) { + reader.ReadWord(); + reader.ReadWeights(weights); + // kBlankProb isn't added yet. + probs.push_back(weights.prob); + ++progress; + } + } + quant.TrainProb(order, probs); +} + +} // namespace + +template <class Quant> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant> &out, Quant &quant, Backing &backing) { std::vector<SortedFileReader> inputs(counts.size() - 1); std::vector<ContextReader> contexts(counts.size() - 1); @@ -791,7 +837,7 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co std::vector<uint64_t> fixed_counts(counts.size()); { - RecursiveInsert<JustCount> counter(&*inputs.begin(), &*contexts.begin(), NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); + RecursiveInsert<JustCount> counter(&*inputs.begin(), &*contexts.begin(), NULL, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]); } for (std::vector<SortedFileReader>::const_iterator i = inputs.begin(); i != inputs.end(); ++i) { @@ -800,7 +846,16 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co SanityCheckCounts(counts, fixed_counts); counts = fixed_counts; - out.SetupMemory(GrowForSearch(config, TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); + out.SetupMemory(GrowForSearch(config, TrieSearch<Quant>::Size(fixed_counts, config), backing), fixed_counts, config); + + if (Quant::kTrain) { + util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0)); + for (unsigned char i = 2; i < counts.size(); ++i) { + TrainQuantizer(i, counts[i-1], inputs[i-2], progress, quant); + } + TrainProbQuantizer(counts.size(), counts.back(), inputs[counts.size() - 2], progress, quant); + quant.FinishedLoading(config); + } for (unsigned char i = 2; i <= counts.size(); ++i) { inputs[i-2].Rewind(); @@ -808,7 +863,7 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co UnigramValue *unigrams = out.unigram.Raw(); // Fill entries except unigram probabilities. { - RecursiveInsert<WriteEntries> inserter(&*inputs.begin(), &*contexts.begin(), unigrams, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); + RecursiveInsert<WriteEntries<Quant> > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); inserter.Apply(config.messages, "Building trie", fixed_counts[0]); } @@ -845,23 +900,49 @@ void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, co /* Set ending offsets so the last entry will be sized properly */ // Last entry for unigrams was already set. - if (!out.middle.empty()) { - for (size_t i = 0; i < out.middle.size() - 1; ++i) { - out.middle[i].FinishedLoading(out.middle[i+1].InsertIndex()); + if (out.middle_begin_ != out.middle_end_) { + for (typename TrieSearch<Quant>::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) { + i->FinishedLoading((i+1)->InsertIndex()); } - out.middle.back().FinishedLoading(out.longest.InsertIndex()); + (out.middle_end_ - 1)->FinishedLoading(out.longest.InsertIndex()); } } -bool IsDirectory(const char *path) { - struct stat info; - if (0 != stat(path, &info)) return false; - return S_ISDIR(info.st_mode); +template <class Quant> uint8_t *TrieSearch<Quant>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) { + quant_.SetupMemory(start, config); + start += Quant::Size(counts.size(), config); + unigram.Init(start); + start += Unigram::Size(counts[0]); + FreeMiddles(); + middle_begin_ = static_cast<Middle*>(malloc(sizeof(Middle) * (counts.size() - 2))); + middle_end_ = middle_begin_ + (counts.size() - 2); + std::vector<uint8_t*> middle_starts(counts.size() - 2); + for (unsigned char i = 2; i < counts.size(); ++i) { + middle_starts[i-2] = start; + start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i]); + } + // Crazy backwards thing so we initialize in the correct order. + for (unsigned char i = counts.size() - 1; i >= 2; --i) { + new (middle_begin_ + i - 2) Middle( + middle_starts[i-2], + quant_.Mid(i), + counts[0], + counts[i], + (i == counts.size() - 1) ? static_cast<const BitPacked&>(longest) : static_cast<const BitPacked &>(middle_begin_[i-1])); + } + longest.Init(start, quant_.Long(counts.size()), counts[0]); + return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); } -} // namespace +template <class Quant> void TrieSearch<Quant>::LoadedBinary() { + unigram.LoadedBinary(); + for (Middle *i = middle_begin_; i != middle_end_; ++i) { + i->LoadedBinary(); + } + longest.LoadedBinary(); +} -void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { +template <class Quant> void TrieSearch<Quant>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { std::string temporary_directory; if (config.temporary_directory_prefix) { temporary_directory = config.temporary_directory_prefix; @@ -885,12 +966,15 @@ void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::v // At least 1MB sorting memory. ARPAToSortedFiles(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_directory.c_str(), vocab); - BuildTrie(temporary_directory, counts, config, *this, backing); + BuildTrie(temporary_directory, counts, config, *this, quant_, backing); if (rmdir(temporary_directory.c_str()) && config.messages) { *config.messages << "Failed to delete " << temporary_directory << std::endl; } } +template class TrieSearch<DontQuantize>; +template class TrieSearch<SeparatelyQuantize>; + } // namespace trie } // namespace ngram } // namespace lm diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 0f720217..0a52acb5 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -13,72 +13,88 @@ struct Backing; class SortedVocabulary; namespace trie { -struct TrieSearch { - typedef NodeRange Node; +template <class Quant> class TrieSearch; +template <class Quant> void BuildTrie(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant> &out, Quant &quant, Backing &backing); - typedef ::lm::ngram::trie::Unigram Unigram; - Unigram unigram; +template <class Quant> class TrieSearch { + public: + typedef NodeRange Node; - typedef trie::BitPackedMiddle Middle; - std::vector<Middle> middle; + typedef ::lm::ngram::trie::Unigram Unigram; + Unigram unigram; - typedef trie::BitPackedLongest Longest; - Longest longest; + typedef trie::BitPackedMiddle<typename Quant::Middle> Middle; - static const ModelType kModelType = TRIE_SORTED; + typedef trie::BitPackedLongest<typename Quant::Longest> Longest; + Longest longest; - static std::size_t Size(const std::vector<uint64_t> &counts, const Config &/*config*/) { - std::size_t ret = Unigram::Size(counts[0]); - for (unsigned char i = 1; i < counts.size() - 1; ++i) { - ret += Middle::Size(counts[i], counts[0], counts[i+1]); + static const ModelType kModelType = Quant::kModelType; + + static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) { + Quant::UpdateConfigFromBinary(fd, counts, config); } - return ret + Longest::Size(counts.back(), counts[0]); - } - - uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &/*config*/) { - unigram.Init(start); - start += Unigram::Size(counts[0]); - middle.resize(counts.size() - 2); - for (unsigned char i = 1; i < counts.size() - 1; ++i) { - middle[i-1].Init( - start, - counts[0], - counts[i+1], - (i == counts.size() - 2) ? static_cast<const BitPacked&>(longest) : static_cast<const BitPacked &>(middle[i])); - start += Middle::Size(counts[i], counts[0], counts[i+1]); + + static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) { + std::size_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]); + for (unsigned char i = 1; i < counts.size() - 1; ++i) { + ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1]); + } + return ret + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); } - longest.Init(start, counts[0]); - return start + Longest::Size(counts.back(), counts[0]); - } - - void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); - - bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { - return unigram.Find(word, prob, backoff, node); - } - - bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const { - return mid.Find(word, prob, backoff, node); - } - - bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const { - return mid.FindNoProb(word, backoff, node); - } - - bool LookupLongest(WordIndex word, float &prob, const Node &node) const { - return longest.Find(word, prob, node); - } - - bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { - // TODO: don't decode backoff. - assert(begin != end); - float ignored_prob, ignored_backoff; - LookupUnigram(*begin, ignored_prob, ignored_backoff, node); - for (const WordIndex *i = begin + 1; i < end; ++i) { - if (!LookupMiddleNoProb(middle[i - begin - 1], *i, ignored_backoff, node)) return false; + + TrieSearch() : middle_begin_(NULL), middle_end_(NULL) {} + + ~TrieSearch() { FreeMiddles(); } + + uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config); + + void LoadedBinary(); + + const Middle *MiddleBegin() const { return middle_begin_; } + const Middle *MiddleEnd() const { return middle_end_; } + + void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); + + bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { + return unigram.Find(word, prob, backoff, node); + } + + bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const { + return mid.Find(word, prob, backoff, node); } - return true; - } + + bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const { + return mid.FindNoProb(word, backoff, node); + } + + bool LookupLongest(WordIndex word, float &prob, const Node &node) const { + return longest.Find(word, prob, node); + } + + bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { + // TODO: don't decode backoff. + assert(begin != end); + float ignored_prob, ignored_backoff; + LookupUnigram(*begin, ignored_prob, ignored_backoff, node); + for (const WordIndex *i = begin + 1; i < end; ++i) { + if (!LookupMiddleNoProb(middle_begin_[i - begin - 1], *i, ignored_backoff, node)) return false; + } + return true; + } + + private: + friend void BuildTrie<Quant>(const std::string &file_prefix, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant> &out, Quant &quant, Backing &backing); + + // Middles are managed manually so we can delay construction and they don't have to be copyable. + void FreeMiddles() { + for (const Middle *i = middle_begin_; i != middle_end_; ++i) { + i->~Middle(); + } + free(middle_begin_); + } + + Middle *middle_begin_, *middle_end_; + Quant quant_; }; } // namespace trie diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index 2c633613..63c2a612 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -1,8 +1,8 @@ #include "lm/trie.hh" +#include "lm/quantize.hh" #include "util/bit_packing.hh" #include "util/exception.hh" -#include "util/proxy_iterator.hh" #include "util/sorted_uniform.hh" #include <assert.h> @@ -12,53 +12,32 @@ namespace ngram { namespace trie { namespace { -// Assumes key is first. -class JustKeyProxy { +class KeyAccessor { public: - JustKeyProxy() : inner_(), base_(), key_mask_(), key_bits_(), total_bits_() {} + KeyAccessor(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits) + : base_(reinterpret_cast<const uint8_t*>(base)), key_mask_(key_mask), key_bits_(key_bits), total_bits_(total_bits) {} - operator uint64_t() const { return GetKey(); } + typedef uint64_t Key; - uint64_t GetKey() const { - uint64_t bit_off = inner_ * static_cast<uint64_t>(total_bits_); - return util::ReadInt57(base_ + bit_off / 8, bit_off & 7, key_bits_, key_mask_); + Key operator()(uint64_t index) const { + return util::ReadInt57(base_, index * static_cast<uint64_t>(total_bits_), key_bits_, key_mask_); } private: - friend class util::ProxyIterator<JustKeyProxy>; - friend bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index); - - JustKeyProxy(const void *base, uint64_t index, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits) - : inner_(index), base_(static_cast<const uint8_t*>(base)), key_mask_(key_mask), key_bits_(key_bits), total_bits_(total_bits) {} - - // This is a read-only iterator. - JustKeyProxy &operator=(const JustKeyProxy &other); - - typedef uint64_t value_type; - - typedef uint64_t InnerIterator; - uint64_t &Inner() { return inner_; } - const uint64_t &Inner() const { return inner_; } - - // The address in bits is base_ * 8 + inner_ * total_bits_. - uint64_t inner_; const uint8_t *const base_; - const uint64_t key_mask_; + const WordIndex key_mask_; const uint8_t key_bits_, total_bits_; }; -bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index) { - util::ProxyIterator<JustKeyProxy> begin_it(JustKeyProxy(base, begin_index, key_mask, key_bits, total_bits)); - util::ProxyIterator<JustKeyProxy> end_it(JustKeyProxy(base, end_index, key_mask, key_bits, total_bits)); - util::ProxyIterator<JustKeyProxy> out; - if (!util::SortedUniformFind(begin_it, end_it, key, out)) return false; - at_index = out.Inner(); +bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, const uint64_t max_vocab, const uint64_t key, uint64_t &at_index) { + KeyAccessor accessor(base, key_mask, key_bits, total_bits); + if (!util::BoundedSortedUniformFind<uint64_t, KeyAccessor, util::PivotSelect<sizeof(WordIndex)>::T>(accessor, begin_index - 1, (uint64_t)0, end_index, max_vocab, key, at_index)) return false; return true; } } // namespace std::size_t BitPacked::BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits) { - uint8_t total_bits = util::RequiredBits(max_vocab) + 31 + remaining_bits; + uint8_t total_bits = util::RequiredBits(max_vocab) + remaining_bits; // Extra entry for next pointer at the end. // +7 then / 8 to round up bits and convert to bytes // +sizeof(uint64_t) so that ReadInt57 etc don't go segfault. @@ -71,100 +50,96 @@ void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits) word_bits_ = util::RequiredBits(max_vocab); word_mask_ = (1ULL << word_bits_) - 1ULL; if (word_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, word indices more than " << (1ULL << 57) << " are not implemented. Edit util/bit_packing.hh and fix the bit packing functions."); - prob_bits_ = 31; - total_bits_ = word_bits_ + prob_bits_ + remaining_bits; + total_bits_ = word_bits_ + remaining_bits; base_ = static_cast<uint8_t*>(base); insert_index_ = 0; + max_vocab_ = max_vocab; } -std::size_t BitPackedMiddle::Size(uint64_t entries, uint64_t max_vocab, uint64_t max_ptr) { - return BaseSize(entries, max_vocab, 32 + util::RequiredBits(max_ptr)); +template <class Quant> std::size_t BitPackedMiddle<Quant>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr) { + return BaseSize(entries, max_vocab, quant_bits + util::RequiredBits(max_ptr)); } -void BitPackedMiddle::Init(void *base, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source) { - next_source_ = &next_source; - backoff_bits_ = 32; - next_bits_ = util::RequiredBits(max_next); +template <class Quant> BitPackedMiddle<Quant>::BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source) : BitPacked(), quant_(quant), next_bits_(util::RequiredBits(max_next)), next_mask_((1ULL << next_bits_) - 1), next_source_(&next_source) { if (next_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions."); - next_mask_ = (1ULL << next_bits_) - 1; - - BaseInit(base, max_vocab, backoff_bits_ + next_bits_); + BaseInit(base, max_vocab, quant.TotalBits() + next_bits_); } -void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff) { +template <class Quant> void BitPackedMiddle<Quant>::Insert(WordIndex word, float prob, float backoff) { assert(word <= word_mask_); uint64_t at_pointer = insert_index_ * total_bits_; - util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word_bits_, word); + util::WriteInt57(base_, at_pointer, word_bits_, word); at_pointer += word_bits_; - util::WriteNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7, prob); - at_pointer += prob_bits_; - util::WriteFloat32(base_ + (at_pointer >> 3), at_pointer & 7, backoff); - at_pointer += backoff_bits_; + quant_.Write(base_, at_pointer, prob, backoff); + at_pointer += quant_.TotalBits(); uint64_t next = next_source_->InsertIndex(); assert(next <= next_mask_); - util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next); + util::WriteInt57(base_, at_pointer, next_bits_, next); ++insert_index_; } -bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const { +template <class Quant> bool BitPackedMiddle<Quant>::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const { uint64_t at_pointer; - if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) { + if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) { return false; } at_pointer *= total_bits_; at_pointer += word_bits_; - prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7); - at_pointer += prob_bits_; - backoff = util::ReadFloat32(base_ + (at_pointer >> 3), at_pointer & 7); - at_pointer += backoff_bits_; - range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_); + quant_.Read(base_, at_pointer, prob, backoff); + at_pointer += quant_.TotalBits(); + + range.begin = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); // Read the next entry's pointer. at_pointer += total_bits_; - range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_); + range.end = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); return true; } -bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const { +template <class Quant> bool BitPackedMiddle<Quant>::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const { uint64_t at_pointer; - if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false; + if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false; at_pointer *= total_bits_; at_pointer += word_bits_; - at_pointer += prob_bits_; - backoff = util::ReadFloat32(base_ + (at_pointer >> 3), at_pointer & 7); - at_pointer += backoff_bits_; - range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_); + quant_.ReadBackoff(base_, at_pointer, backoff); + at_pointer += quant_.TotalBits(); + range.begin = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); // Read the next entry's pointer. at_pointer += total_bits_; - range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_); + range.end = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); return true; } -void BitPackedMiddle::FinishedLoading(uint64_t next_end) { +template <class Quant> void BitPackedMiddle<Quant>::FinishedLoading(uint64_t next_end) { assert(next_end <= next_mask_); uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - next_bits_; - util::WriteInt57(base_ + (last_next_write >> 3), last_next_write & 7, next_bits_, next_end); + util::WriteInt57(base_, last_next_write, next_bits_, next_end); } -void BitPackedLongest::Insert(WordIndex index, float prob) { +template <class Quant> void BitPackedLongest<Quant>::Insert(WordIndex index, float prob) { assert(index <= word_mask_); uint64_t at_pointer = insert_index_ * total_bits_; - util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word_bits_, index); + util::WriteInt57(base_, at_pointer, word_bits_, index); at_pointer += word_bits_; - util::WriteNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7, prob); + quant_.Write(base_, at_pointer, prob); ++insert_index_; } -bool BitPackedLongest::Find(WordIndex word, float &prob, const NodeRange &range) const { +template <class Quant> bool BitPackedLongest<Quant>::Find(WordIndex word, float &prob, const NodeRange &range) const { uint64_t at_pointer; - if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false; + if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false; at_pointer = at_pointer * total_bits_ + word_bits_; - prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7); + quant_.Read(base_, at_pointer, prob); return true; } +template class BitPackedMiddle<DontQuantize::Middle>; +template class BitPackedMiddle<SeparatelyQuantize::Middle>; +template class BitPackedLongest<DontQuantize::Longest>; +template class BitPackedLongest<SeparatelyQuantize::Longest>; + } // namespace trie } // namespace ngram } // namespace lm diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index 6aef050c..8fa21aaf 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -74,23 +74,21 @@ class BitPacked { void BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits); - uint8_t word_bits_, prob_bits_; + uint8_t word_bits_; uint8_t total_bits_; uint64_t word_mask_; uint8_t *base_; - uint64_t insert_index_; + uint64_t insert_index_, max_vocab_; }; -class BitPackedMiddle : public BitPacked { +template <class Quant> class BitPackedMiddle : public BitPacked { public: - BitPackedMiddle() {} - - static std::size_t Size(uint64_t entries, uint64_t max_vocab, uint64_t max_next); + static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next); // next_source need not be initialized. - void Init(void *base, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source); + BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source); void Insert(WordIndex word, float prob, float backoff); @@ -101,28 +99,33 @@ class BitPackedMiddle : public BitPacked { void FinishedLoading(uint64_t next_end); private: - uint8_t backoff_bits_, next_bits_; + Quant quant_; + uint8_t next_bits_; uint64_t next_mask_; const BitPacked *next_source_; }; -class BitPackedLongest : public BitPacked { +template <class Quant> class BitPackedLongest : public BitPacked { public: - BitPackedLongest() {} - - static std::size_t Size(uint64_t entries, uint64_t max_vocab) { - return BaseSize(entries, max_vocab, 0); + static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) { + return BaseSize(entries, max_vocab, quant_bits); } - void Init(void *base, uint64_t max_vocab) { - return BaseInit(base, max_vocab, 0); + BitPackedLongest() {} + + void Init(void *base, const Quant &quant, uint64_t max_vocab) { + quant_ = quant; + BaseInit(base, max_vocab, quant_.TotalBits()); } void Insert(WordIndex word, float prob); bool Find(WordIndex word, float &prob, const NodeRange &node) const; + + private: + Quant quant_; }; } // namespace trie diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 515af5db..7defd5c1 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -28,8 +28,8 @@ const uint64_t kUnknownHash = detail::HashForVocab("<unk>", 5); // Sadly some LMs have <UNK>. const uint64_t kUnknownCapHash = detail::HashForVocab("<UNK>", 5); -void ReadWords(int fd, EnumerateVocab *enumerate) { - if (!enumerate) return; +WordIndex ReadWords(int fd, EnumerateVocab *enumerate) { + if (!enumerate) return std::numeric_limits<WordIndex>::max(); const std::size_t kInitialRead = 16384; std::string buf; buf.reserve(kInitialRead + 100); @@ -38,7 +38,7 @@ void ReadWords(int fd, EnumerateVocab *enumerate) { while (true) { ssize_t got = read(fd, &buf[0], kInitialRead); if (got == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words"); - if (got == 0) return; + if (got == 0) return index; buf.resize(got); while (buf[buf.size() - 1]) { char next_char; @@ -87,13 +87,13 @@ SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL std::size_t SortedVocabulary::Size(std::size_t entries, const Config &/*config*/) { // Lead with the number of entries. - return sizeof(uint64_t) + sizeof(Entry) * entries; + return sizeof(uint64_t) + sizeof(uint64_t) * entries; } void SortedVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config) { assert(allocated >= Size(entries, config)); // Leave space for number of entries. - begin_ = reinterpret_cast<Entry*>(reinterpret_cast<uint64_t*>(start) + 1); + begin_ = reinterpret_cast<uint64_t*>(start) + 1; end_ = begin_; saw_unk_ = false; } @@ -112,7 +112,7 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) { saw_unk_ = true; return 0; } - end_->key = hashed; + *end_ = hashed; if (enumerate_) { strings_to_enumerate_[end_ - begin_].assign(str.data(), str.size()); } @@ -134,8 +134,10 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { util::JointSort(begin_, end_, reorder_vocab + 1); } SetSpecial(Index("<s>"), Index("</s>"), 0); - // Save size. + // Save size. Excludes UNK. *(reinterpret_cast<uint64_t*>(begin_) - 1) = end_ - begin_; + // Includes UNK. + bound_ = end_ - begin_ + 1; } void SortedVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { @@ -183,7 +185,7 @@ void ProbingVocabulary::FinishedLoading(ProbBackoff * /*reorder_vocab*/) { void ProbingVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { lookup_.LoadedBinary(); - ReadWords(fd, to); + available_ = ReadWords(fd, to); SetSpecial(Index("<s>"), Index("</s>"), 0); } diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index 546c1649..c92518e4 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -9,6 +9,7 @@ #include "util/sorted_uniform.hh" #include "util/string_piece.hh" +#include <limits> #include <string> #include <vector> @@ -44,22 +45,16 @@ class WriteWordsWrapper : public EnumerateVocab { // Vocabulary based on sorted uniform find storing only uint64_t values and using their offsets as indices. class SortedVocabulary : public base::Vocabulary { - private: - // Sorted uniform requires a GetKey function. - struct Entry { - uint64_t GetKey() const { return key; } - uint64_t key; - bool operator<(const Entry &other) const { - return key < other.key; - } - }; - public: SortedVocabulary(); WordIndex Index(const StringPiece &str) const { - const Entry *found; - if (util::SortedUniformFind<const Entry *, uint64_t>(begin_, end_, detail::HashForVocab(str), found)) { + const uint64_t *found; + if (util::BoundedSortedUniformFind<const uint64_t*, util::IdentityAccessor<uint64_t>, util::Pivot64>( + util::IdentityAccessor<uint64_t>(), + begin_ - 1, 0, + end_, std::numeric_limits<uint64_t>::max(), + detail::HashForVocab(str), found)) { return found - begin_ + 1; // +1 because <unk> is 0 and does not appear in the lookup table. } else { return 0; @@ -68,6 +63,10 @@ class SortedVocabulary : public base::Vocabulary { static size_t Size(std::size_t entries, const Config &config); + // Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary. + // While this number is correct, ProbingVocabulary::Bound might not be correct in some cases. + WordIndex Bound() const { return bound_; } + // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config); @@ -83,7 +82,11 @@ class SortedVocabulary : public base::Vocabulary { void LoadedBinary(int fd, EnumerateVocab *to); private: - Entry *begin_, *end_; + uint64_t *begin_, *end_; + + WordIndex bound_; + + WordIndex highest_value_; bool saw_unk_; @@ -105,6 +108,12 @@ class ProbingVocabulary : public base::Vocabulary { static size_t Size(std::size_t entries, const Config &config); + // Vocab words are [0, Bound()). + // WARNING WARNING: returns UINT_MAX when loading binary and not enumerating vocabulary. + // Fixing this bug requires a binary file format change and will be fixed with the next binary file format update. + // Specifically, the binary file format does not currently indicate whether <unk> is in count or not. + WordIndex Bound() const { return available_; } + // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config); diff --git a/klm/util/bit_packing.cc b/klm/util/bit_packing.cc index 681da5f2..41999b72 100644 --- a/klm/util/bit_packing.cc +++ b/klm/util/bit_packing.cc @@ -28,10 +28,10 @@ void BitPackingSanity() { memset(mem, 0, sizeof(mem)); const uint64_t test57 = 0x123456789abcdefULL; for (uint64_t b = 0; b < 57 * 8; b += 57) { - WriteInt57(mem + b / 8, b % 8, 57, test57); + WriteInt57(mem, b, 57, test57); } for (uint64_t b = 0; b < 57 * 8; b += 57) { - if (test57 != ReadInt57(mem + b / 8, b % 8, 57, (1ULL << 57) - 1)) + if (test57 != ReadInt57(mem, b, 57, (1ULL << 57) - 1)) UTIL_THROW(Exception, "The bit packing routines are failing for your architecture. Please send a bug report with your architecture, operating system, and compiler."); } // TODO: more checks. diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh index 5c71c792..b35d80c8 100644 --- a/klm/util/bit_packing.hh +++ b/klm/util/bit_packing.hh @@ -42,47 +42,62 @@ inline uint8_t BitPackShift(uint8_t bit, uint8_t length) { #error "Bit packing code isn't written for your byte order." #endif +inline uint64_t ReadOff(const void *base, uint64_t bit_off) { + return *reinterpret_cast<const uint64_t*>(reinterpret_cast<const uint8_t*>(base) + (bit_off >> 3)); +} + /* Pack integers up to 57 bits using their least significant digits. * The length is specified using mask: * Assumes mask == (1 << length) - 1 where length <= 57. */ -inline uint64_t ReadInt57(const void *base, uint8_t bit, uint8_t length, uint64_t mask) { - return (*reinterpret_cast<const uint64_t*>(base) >> BitPackShift(bit, length)) & mask; +inline uint64_t ReadInt57(const void *base, uint64_t bit_off, uint8_t length, uint64_t mask) { + return (ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, length)) & mask; } /* Assumes value < (1 << length) and length <= 57. * Assumes the memory is zero initially. */ -inline void WriteInt57(void *base, uint8_t bit, uint8_t length, uint64_t value) { - *reinterpret_cast<uint64_t*>(base) |= (value << BitPackShift(bit, length)); +inline void WriteInt57(void *base, uint64_t bit_off, uint8_t length, uint64_t value) { + *reinterpret_cast<uint64_t*>(reinterpret_cast<uint8_t*>(base) + (bit_off >> 3)) |= + (value << BitPackShift(bit_off & 7, length)); +} + +/* Same caveats as above, but for a 25 bit limit. */ +inline uint32_t ReadInt25(const void *base, uint64_t bit_off, uint8_t length, uint32_t mask) { + return (*reinterpret_cast<const uint32_t*>(reinterpret_cast<const uint8_t*>(base) + (bit_off >> 3)) >> BitPackShift(bit_off & 7, length)) & mask; +} + +inline void WriteInt25(void *base, uint64_t bit_off, uint8_t length, uint32_t value) { + *reinterpret_cast<uint32_t*>(reinterpret_cast<uint8_t*>(base) + (bit_off >> 3)) |= + (value << BitPackShift(bit_off & 7, length)); } typedef union { float f; uint32_t i; } FloatEnc; -inline float ReadFloat32(const void *base, uint8_t bit) { +inline float ReadFloat32(const void *base, uint64_t bit_off) { FloatEnc encoded; - encoded.i = *reinterpret_cast<const uint64_t*>(base) >> BitPackShift(bit, 32); + encoded.i = ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, 32); return encoded.f; } -inline void WriteFloat32(void *base, uint8_t bit, float value) { +inline void WriteFloat32(void *base, uint64_t bit_off, float value) { FloatEnc encoded; encoded.f = value; - WriteInt57(base, bit, 32, encoded.i); + WriteInt57(base, bit_off, 32, encoded.i); } const uint32_t kSignBit = 0x80000000; -inline float ReadNonPositiveFloat31(const void *base, uint8_t bit) { +inline float ReadNonPositiveFloat31(const void *base, uint64_t bit_off) { FloatEnc encoded; - encoded.i = *reinterpret_cast<const uint64_t*>(base) >> BitPackShift(bit, 31); + encoded.i = ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, 31); // Sign bit set means negative. encoded.i |= kSignBit; return encoded.f; } -inline void WriteNonPositiveFloat31(void *base, uint8_t bit, float value) { +inline void WriteNonPositiveFloat31(void *base, uint64_t bit_off, float value) { FloatEnc encoded; encoded.f = value; encoded.i &= ~kSignBit; - WriteInt57(base, bit, 31, encoded.i); + WriteInt57(base, bit_off, 31, encoded.i); } void BitPackingSanity(); diff --git a/klm/util/bit_packing_test.cc b/klm/util/bit_packing_test.cc index c578ddd1..4edc2004 100644 --- a/klm/util/bit_packing_test.cc +++ b/klm/util/bit_packing_test.cc @@ -9,15 +9,16 @@ namespace util { namespace { const uint64_t test57 = 0x123456789abcdefULL; +const uint32_t test25 = 0x1234567; -BOOST_AUTO_TEST_CASE(ZeroBit) { +BOOST_AUTO_TEST_CASE(ZeroBit57) { char mem[16]; memset(mem, 0, sizeof(mem)); WriteInt57(mem, 0, 57, test57); BOOST_CHECK_EQUAL(test57, ReadInt57(mem, 0, 57, (1ULL << 57) - 1)); } -BOOST_AUTO_TEST_CASE(EachBit) { +BOOST_AUTO_TEST_CASE(EachBit57) { char mem[16]; for (uint8_t b = 0; b < 8; ++b) { memset(mem, 0, sizeof(mem)); @@ -26,15 +27,27 @@ BOOST_AUTO_TEST_CASE(EachBit) { } } -BOOST_AUTO_TEST_CASE(Consecutive) { +BOOST_AUTO_TEST_CASE(Consecutive57) { char mem[57+8]; memset(mem, 0, sizeof(mem)); for (uint64_t b = 0; b < 57 * 8; b += 57) { - WriteInt57(mem + (b / 8), b % 8, 57, test57); - BOOST_CHECK_EQUAL(test57, ReadInt57(mem + b / 8, b % 8, 57, (1ULL << 57) - 1)); + WriteInt57(mem, b, 57, test57); + BOOST_CHECK_EQUAL(test57, ReadInt57(mem, b, 57, (1ULL << 57) - 1)); } for (uint64_t b = 0; b < 57 * 8; b += 57) { - BOOST_CHECK_EQUAL(test57, ReadInt57(mem + b / 8, b % 8, 57, (1ULL << 57) - 1)); + BOOST_CHECK_EQUAL(test57, ReadInt57(mem, b, 57, (1ULL << 57) - 1)); + } +} + +BOOST_AUTO_TEST_CASE(Consecutive25) { + char mem[25+8]; + memset(mem, 0, sizeof(mem)); + for (uint64_t b = 0; b < 25 * 8; b += 25) { + WriteInt25(mem, b, 25, test25); + BOOST_CHECK_EQUAL(test25, ReadInt25(mem, b, 25, (1ULL << 25) - 1)); + } + for (uint64_t b = 0; b < 25 * 8; b += 25) { + BOOST_CHECK_EQUAL(test25, ReadInt25(mem, b, 25, (1ULL << 25) - 1)); } } diff --git a/klm/util/sorted_uniform.hh b/klm/util/sorted_uniform.hh index 05826b51..84d7aa02 100644 --- a/klm/util/sorted_uniform.hh +++ b/klm/util/sorted_uniform.hh @@ -9,52 +9,96 @@ namespace util { -inline std::size_t Pivot(uint64_t off, uint64_t range, std::size_t width) { - std::size_t ret = static_cast<std::size_t>(static_cast<float>(off) / static_cast<float>(range) * static_cast<float>(width)); - // Cap for floating point rounding - return (ret < width) ? ret : width - 1; -} -/*inline std::size_t Pivot(uint32_t off, uint32_t range, std::size_t width) { - return static_cast<std::size_t>(static_cast<uint64_t>(off) * static_cast<uint64_t>(width) / static_cast<uint64_t>(range)); +template <class T> class IdentityAccessor { + public: + typedef T Key; + T operator()(const uint64_t *in) const { return *in; } +}; + +struct Pivot64 { + static inline std::size_t Calc(uint64_t off, uint64_t range, std::size_t width) { + std::size_t ret = static_cast<std::size_t>(static_cast<float>(off) / static_cast<float>(range) * static_cast<float>(width)); + // Cap for floating point rounding + return (ret < width) ? ret : width - 1; + } +}; + +// Use when off * width is <2^64. This is guaranteed when each of them is actually a 32-bit value. +struct Pivot32 { + static inline std::size_t Calc(uint64_t off, uint64_t range, uint64_t width) { + return static_cast<std::size_t>((off * width) / (range + 1)); + } +}; + +// Usage: PivotSelect<sizeof(DataType)>::T +template <unsigned> struct PivotSelect; +template <> struct PivotSelect<8> { typedef Pivot64 T; }; +template <> struct PivotSelect<4> { typedef Pivot32 T; }; +template <> struct PivotSelect<2> { typedef Pivot32 T; }; + +/* Binary search. */ +template <class Iterator, class Accessor> bool BinaryFind( + const Accessor &accessor, + Iterator begin, + Iterator end, + const typename Accessor::Key key, Iterator &out) { + while (end > begin) { + Iterator pivot(begin + (end - begin) / 2); + typename Accessor::Key mid(accessor(pivot)); + if (mid < key) { + begin = pivot + 1; + } else if (mid > key) { + end = pivot; + } else { + out = pivot; + return true; + } + } + return false; } -inline std::size_t Pivot(uint16_t off, uint16_t range, std::size_t width) { - return static_cast<std::size_t>(static_cast<std::size_t>(off) * width / static_cast<std::size_t>(range)); + +// Search the range [before_it + 1, after_it - 1] for key. +// Preconditions: +// before_v <= key <= after_v +// before_v <= all values in the range [before_it + 1, after_it - 1] <= after_v +// range is sorted. +template <class Iterator, class Accessor, class Pivot> bool BoundedSortedUniformFind( + const Accessor &accessor, + Iterator before_it, typename Accessor::Key before_v, + Iterator after_it, typename Accessor::Key after_v, + const typename Accessor::Key key, Iterator &out) { + while (after_it - before_it > 1) { + Iterator pivot(before_it + (1 + Pivot::Calc(key - before_v, after_v - before_v, after_it - before_it - 1))); + typename Accessor::Key mid(accessor(pivot)); + if (mid < key) { + before_it = pivot; + before_v = mid; + } else if (mid > key) { + after_it = pivot; + after_v = mid; + } else { + out = pivot; + return true; + } + } + return false; } -inline std::size_t Pivot(unsigned char off, unsigned char range, std::size_t width) { - return static_cast<std::size_t>(static_cast<std::size_t>(off) * width / static_cast<std::size_t>(range)); -}*/ -template <class Iterator, class Key> bool SortedUniformFind(Iterator begin, Iterator end, const Key key, Iterator &out) { +template <class Iterator, class Accessor, class Pivot> bool SortedUniformFind(const Accessor &accessor, Iterator begin, Iterator end, const typename Accessor::Key key, Iterator &out) { if (begin == end) return false; - Key below(begin->GetKey()); + typename Accessor::Key below(accessor(begin)); if (key <= below) { if (key == below) { out = begin; return true; } return false; } // Make the range [begin, end]. --end; - Key above(end->GetKey()); + typename Accessor::Key above(accessor(end)); if (key >= above) { if (key == above) { out = end; return true; } return false; } - - // Search the range [begin + 1, end - 1] knowing that *begin == below, *end == above. - while (end - begin > 1) { - Iterator pivot(begin + (1 + Pivot(key - below, above - below, static_cast<std::size_t>(end - begin - 1)))); - Key mid(pivot->GetKey()); - if (mid < key) { - begin = pivot; - below = mid; - } else if (mid > key) { - end = pivot; - above = mid; - } else { - out = pivot; - return true; - } - } - return false; + return BoundedSortedUniformFind<Iterator, Accessor, Pivot>(accessor, begin, below, end, above, key, out); } // To use this template, you need to define a Pivot function to match Key. @@ -64,7 +108,13 @@ template <class PackingT> class SortedUniformMap { typedef typename Packing::ConstIterator ConstIterator; typedef typename Packing::MutableIterator MutableIterator; - public: + struct Accessor { + public: + typedef typename Packing::Key Key; + const Key &operator()(const ConstIterator &i) const { return i->GetKey(); } + Key &operator()(const MutableIterator &i) const { return i->GetKey(); } + }; + // Offer consistent API with probing hash. static std::size_t Size(std::size_t entries, float /*ignore*/ = 0.0) { return sizeof(uint64_t) + entries * Packing::kBytes; @@ -120,7 +170,7 @@ template <class PackingT> class SortedUniformMap { assert(initialized_); assert(loaded_); #endif - return SortedUniformFind<MutableIterator, Key>(begin_, end_, key, out); + return SortedUniformFind<MutableIterator, Accessor, Pivot64>(begin_, end_, key, out); } // Do not call before FinishedInserting. @@ -129,7 +179,7 @@ template <class PackingT> class SortedUniformMap { assert(initialized_); assert(loaded_); #endif - return SortedUniformFind<ConstIterator, Key>(ConstIterator(begin_), ConstIterator(end_), key, out); + return SortedUniformFind<ConstIterator, Accessor, Pivot64>(Accessor(), ConstIterator(begin_), ConstIterator(end_), key, out); } ConstIterator begin() const { return begin_; } |