From b9d7da0413403805f035479a0a426c27102032f6 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Fri, 3 Jun 2011 20:49:52 -0400 Subject: Add exception catcher around constructor --- decoder/ff_klm.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) (limited to 'decoder/ff_klm.cc') diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 35b35d36..71ba9f30 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -385,7 +385,12 @@ KLanguageModel::KLanguageModel(const string& param) { if (!ParseLMArgs(param, &filename, &mapfile, &explicit_markers, &featname)) { abort(); } - pimpl_ = new KLanguageModelImpl(filename, mapfile, explicit_markers); + try { + pimpl_ = new KLanguageModelImpl(filename, mapfile, explicit_markers); + } catch (std::exception &e) { + std::cerr << e.what() << std::endl; + abort(); + } fid_ = FD::Convert(featname); oov_fid_ = FD::Convert(featname+"_OOV"); cerr << "FID: " << oov_fid_ << endl; -- cgit v1.2.3 From dcf4447590277887d65b0bdec7e6818081869a9a Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Fri, 3 Jun 2011 20:56:37 -0400 Subject: Code cleanup for vocabulary mapping --- decoder/ff_klm.cc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) (limited to 'decoder/ff_klm.cc') diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 71ba9f30..a3bd0c5f 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -282,11 +282,10 @@ class KLanguageModelImpl { KLanguageModelImpl(const string& filename, const string& mapfile, bool explicit_markers) : kCDEC_UNK(TD::Convert("")) , add_sos_eos_(!explicit_markers) { - if (true) { - boost::scoped_ptr vm; - vm.reset(new VMapper(&cdec2klm_map_)); + { + VMapper vm(&cdec2klm_map_); lm::ngram::Config conf; - conf.enumerate_vocab = vm.get(); + conf.enumerate_vocab = &vm; ngram_ = new Model(filename.c_str(), conf); } order_ = ngram_->Order(); -- cgit v1.2.3 From 205893513c8343fdc55789e427fab4c8b536dc12 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Sun, 26 Jun 2011 18:40:15 -0400 Subject: Quantization --- BUILDING | 5 +- decoder/cdec_ff.cc | 1 + decoder/ff_klm.cc | 1 + klm/compile.sh | 4 +- klm/lm/Makefile.am | 1 + klm/lm/binary_format.cc | 17 +++- klm/lm/binary_format.hh | 10 ++- klm/lm/blank.hh | 4 + klm/lm/build_binary.cc | 83 +++++++++++++---- klm/lm/config.cc | 2 + klm/lm/config.hh | 5 ++ klm/lm/model.cc | 25 +++--- klm/lm/model.hh | 21 +++-- klm/lm/model_test.cc | 13 +-- klm/lm/quantize.cc | 84 ++++++++++++++++++ klm/lm/quantize.hh | 207 +++++++++++++++++++++++++++++++++++++++++++ klm/lm/search_hashed.cc | 38 ++++++-- klm/lm/search_hashed.hh | 122 ++++++++++++------------- klm/lm/search_trie.cc | 121 ++++++++++++++++++++----- klm/lm/search_trie.hh | 132 +++++++++++++++------------ klm/lm/trie.cc | 123 ++++++++++--------------- klm/lm/trie.hh | 33 +++---- klm/lm/vocab.cc | 18 ++-- klm/lm/vocab.hh | 35 +++++--- klm/util/bit_packing.cc | 4 +- klm/util/bit_packing.hh | 39 +++++--- klm/util/bit_packing_test.cc | 25 ++++-- klm/util/sorted_uniform.hh | 120 +++++++++++++++++-------- 28 files changed, 921 insertions(+), 372 deletions(-) create mode 100644 klm/lm/quantize.cc create mode 100644 klm/lm/quantize.hh (limited to 'decoder/ff_klm.cc') diff --git a/BUILDING b/BUILDING index dcb3d45b..b7535d70 100644 --- a/BUILDING +++ b/BUILDING @@ -1,6 +1,5 @@ To build cdec, you'll need: - * SRILM (register and download from http://www.speech.sri.com/projects/srilm/) * Google c++ testing framework (http://code.google.com/p/googletest/) * boost headers & boost program_options (you may need to install a package like boost-devel) @@ -9,7 +8,7 @@ To build cdec, you'll need: Instructions for building ----------------------------------- - 1) Download and build SRILM + 1) Optional: Download and build SRILM 2) Download, build, and install Google Test (optional, this is necessary to build unit tests that may be useful in development; system tests @@ -22,7 +21,7 @@ Instructions for building 4) Configure and build. Your command will look something like this. - ./configure --with-srilm=/home/me/software/srilm-1.5.9 --disable-gtest + ./configure --disable-gtest make If you get errors during configure about missing BOOST macros, then step 3 diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 37aa655b..31f88a4f 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -55,6 +55,7 @@ void register_feature_functions() { ff_registry.Register("CMR2008ReorderingFeatures", new FFFactory()); ff_registry.Register("KLanguageModel", new FFFactory >()); ff_registry.Register("KLanguageModel_Trie", new FFFactory >()); + ff_registry.Register("KLanguageModel_QuantTrie", new FFFactory >()); ff_registry.Register("KLanguageModel_Probing", new FFFactory >()); ff_registry.Register("NonLatinCount", new FFFactory); ff_registry.Register("RuleShape", new FFFactory); diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index a3bd0c5f..9b7fe2d3 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -437,4 +437,5 @@ void KLanguageModel::FinalTraversalFeatures(const void* ant_state, // instantiate templates template class KLanguageModel; template class KLanguageModel; +template class KLanguageModel; diff --git a/klm/compile.sh b/klm/compile.sh index 49e04db8..6ca85e1f 100755 --- a/klm/compile.sh +++ b/klm/compile.sh @@ -3,11 +3,9 @@ #If your code uses ICU, edit util/string_piece.hh and uncomment #define USE_ICU #I use zlib by default. If you don't want to depend on zlib, remove #define USE_ZLIB from util/file_piece.hh -#don't need to use if compiling with moses Makefiles already - set -e -for i in util/{bit_packing,ersatz_progress,exception,file_piece,murmur_hash,scoped,mmap} lm/{binary_format,config,lm_exception,model,read_arpa,search_hashed,search_trie,trie,virtual_interface,vocab}; do +for i in util/{bit_packing,ersatz_progress,exception,file_piece,murmur_hash,scoped,mmap} lm/{binary_format,config,lm_exception,model,quantize,read_arpa,search_hashed,search_trie,trie,virtual_interface,vocab}; do g++ -I. -O3 $CXXFLAGS -c $i.cc -o $i.o done g++ -I. -O3 $CXXFLAGS lm/build_binary.cc {lm,util}/*.o -lz -o build_binary diff --git a/klm/lm/Makefile.am b/klm/lm/Makefile.am index 61d98d97..395494bc 100644 --- a/klm/lm/Makefile.am +++ b/klm/lm/Makefile.am @@ -15,6 +15,7 @@ libklm_a_SOURCES = \ binary_format.cc \ config.cc \ lm_exception.cc \ + quantize.cc \ model.cc \ ngram_query.cc \ read_arpa.cc \ diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index 34d9ffca..92b1008b 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -80,6 +80,14 @@ void WriteHeader(void *to, const Parameters ¶ms) { } // namespace +void SeekOrThrow(int fd, off_t off) { + if ((off_t)-1 == lseek(fd, off, SEEK_SET)) UTIL_THROW(util::ErrnoException, "Seek failed"); +} + +void AdvanceOrThrow(int fd, off_t off) { + if ((off_t)-1 == lseek(fd, off, SEEK_CUR)) UTIL_THROW(util::ErrnoException, "Seek failed"); +} + uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing) { if (config.write_mmap) { std::size_t total = TotalHeaderSize(order) + memory_size; @@ -156,7 +164,7 @@ bool IsBinaryFormat(int fd) { } void ReadHeader(int fd, Parameters &out) { - if ((off_t)-1 == lseek(fd, sizeof(Sanity), SEEK_SET)) UTIL_THROW(util::ErrnoException, "Seek failed in binary file"); + SeekOrThrow(fd, sizeof(Sanity)); ReadLoop(fd, &out.fixed, sizeof(out.fixed)); if (out.fixed.probing_multiplier < 1.0) UTIL_THROW(FormatLoadException, "Binary format claims to have a probing multiplier of " << out.fixed.probing_multiplier << " which is < 1.0."); @@ -173,6 +181,10 @@ void MatchCheck(ModelType model_type, const Parameters ¶ms) { } } +void SeekPastHeader(int fd, const Parameters ¶ms) { + SeekOrThrow(fd, TotalHeaderSize(params.counts.size())); +} + uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t memory_size, Backing &backing) { const off_t file_size = util::SizeFile(backing.file.get()); // The header is smaller than a page, so we have to map the whole header as well. @@ -186,8 +198,7 @@ uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t UTIL_THROW(FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary."); if (config.enumerate_vocab) { - if ((off_t)-1 == lseek(backing.file.get(), total_map, SEEK_SET)) - UTIL_THROW(util::ErrnoException, "Failed to seek in binary file to vocab words"); + SeekOrThrow(backing.file.get(), total_map); } return reinterpret_cast(backing.search.get()) + TotalHeaderSize(params.counts.size()); } diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index 1fc71be4..2b32b450 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -16,7 +16,7 @@ namespace lm { namespace ngram { -typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2} ModelType; +typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3} ModelType; /*Inspect a file to determine if it is a binary lm. If not, return false. * If so, return true and set recognized to the type. This is the only API in @@ -48,6 +48,10 @@ struct Backing { util::scoped_memory search; }; +void SeekOrThrow(int fd, off_t off); +// Seek forward +void AdvanceOrThrow(int fd, off_t off); + // Create just enough of a binary file to write vocabulary to it. uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing); // Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin. @@ -65,6 +69,8 @@ void ReadHeader(int fd, Parameters ¶ms); void MatchCheck(ModelType model_type, const Parameters ¶ms); +void SeekPastHeader(int fd, const Parameters ¶ms); + uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t memory_size, Backing &backing); void ComplainAboutARPA(const Config &config, ModelType model_type); @@ -83,6 +89,8 @@ template void LoadLM(const char *file, const Config &config, To &to) // Replace the run-time configured probing_multiplier with the one in the file. Config new_config(config); new_config.probing_multiplier = params.fixed.probing_multiplier; + detail::SeekPastHeader(backing.file.get(), params); + To::UpdateConfigFromBinary(backing.file.get(), params.counts, new_config); std::size_t memory_size = To::Size(params.counts, new_config); uint8_t *start = detail::SetupBinary(new_config, params, memory_size, backing); to.InitializeFromBinary(start, params, new_config, backing.file.get()); diff --git a/klm/lm/blank.hh b/klm/lm/blank.hh index 4615a09e..162411a9 100644 --- a/klm/lm/blank.hh +++ b/klm/lm/blank.hh @@ -22,6 +22,8 @@ namespace ngram { */ const float kNoExtensionBackoff = -0.0; const float kExtensionBackoff = 0.0; +const uint64_t kNoExtensionQuant = 0; +const uint64_t kExtensionQuant = 1; inline void SetExtension(float &backoff) { if (backoff == kNoExtensionBackoff) backoff = kExtensionBackoff; @@ -47,6 +49,8 @@ inline bool HasExtension(const float &backoff) { */ const float kBlankProb = -std::numeric_limits::infinity(); const float kBlankBackoff = kNoExtensionBackoff; +const uint32_t kBlankProbQuant = 0; +const uint32_t kBlankBackoffQuant = 0; } // namespace ngram } // namespace lm diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index 91ad2fb9..4552c419 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -15,22 +15,21 @@ namespace ngram { namespace { void Usage(const char *name) { - std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-n] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [type] input.arpa output.mmap\n\n" + std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-n] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [type] input.arpa output.mmap\n\n" "-u sets the default log10 probability for if the ARPA file does not have\n" "one.\n" "-s allows models to be built even if they do not have and .\n" "-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n" -"type is one of probing, trie, or sorted:\n\n" +"type is either probing or trie:\n\n" "probing uses a probing hash table. It is the fastest but uses the most memory.\n" "-p sets the space multiplier and must be >1.0. The default is 1.5.\n\n" "trie is a straightforward trie with bit-level packing. It uses the least\n" "memory and is still faster than SRI or IRST. Building the trie format uses an\n" "on-disk sort to save memory.\n" "-t is the temporary directory prefix. Default is the output file name.\n" -"-m limits memory use for sorting. Measured in MB. Default is 1024MB.\n\n" -/*"sorted is like probing but uses a sorted uniform map instead of a hash table.\n" -"It uses more memory than trie and is also slower, so there's no real reason to\n" -"use it.\n\n"*/ +"-m limits memory use for sorting. Measured in MB. Default is 1024MB.\n" +"-q turns quantization on and sets the number of bits (e.g. -q 8).\n" +"-b sets backoff quantization bits. Requires -q and defaults to that value.\n\n" "See http://kheafield.com/code/kenlm/benchmark/ for data structure benchmarks.\n" "Passing only an input file will print memory usage of each data structure.\n" "If the ARPA file does not have , -u sets 's probability; default 0.0.\n"; @@ -51,19 +50,53 @@ unsigned long int ParseUInt(const char *from) { return ret; } +uint8_t ParseBitCount(const char *from) { + unsigned long val = ParseUInt(from); + if (val > 25) { + util::ParseNumberException e(from); + e << " bit counts are limited to 256."; + } + return val; +} + void ShowSizes(const char *file, const lm::ngram::Config &config) { std::vector counts; util::FilePiece f(file); lm::ReadARPACounts(f, counts); - std::size_t probing_size = ProbingModel::Size(counts, config); - // probing is always largest so use it to determine number of columns. - long int length = std::max(5, lrint(ceil(log10(probing_size)))); + std::size_t sizes[3]; + sizes[0] = ProbingModel::Size(counts, config); + sizes[1] = TrieModel::Size(counts, config); + sizes[2] = QuantTrieModel::Size(counts, config); + std::size_t max_length = *std::max_element(sizes, sizes + 3); + std::size_t min_length = *std::max_element(sizes, sizes + 3); + std::size_t divide; + char prefix; + if (min_length < (1 << 10) * 10) { + prefix = ' '; + divide = 1; + } else if (min_length < (1 << 20) * 10) { + prefix = 'k'; + divide = 1 << 10; + } else if (min_length < (1ULL << 30) * 10) { + prefix = 'M'; + divide = 1 << 20; + } else { + prefix = 'G'; + divide = 1 << 30; + } + long int length = std::max(2, lrint(ceil(log10(max_length / divide)))); std::cout << "Memory estimate:\ntype "; // right align bytes. - for (long int i = 0; i < length - 5; ++i) std::cout << ' '; - std::cout << "bytes\n" - "probing " << std::setw(length) << probing_size << " assuming -p " << config.probing_multiplier << "\n" - "trie " << std::setw(length) << TrieModel::Size(counts, config) << "\n"; + for (long int i = 0; i < length - 2; ++i) std::cout << ' '; + std::cout << prefix << "B\n" + "probing " << std::setw(length) << (sizes[0] / divide) << " assuming -p " << config.probing_multiplier << "\n" + "trie " << std::setw(length) << (sizes[1] / divide) << " without quantization\n" + "trie " << std::setw(length) << (sizes[2] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n"; +} + +void ProbingQuantizationUnsupported() { + std::cerr << "Quantization is only implemented in the trie data structure." << std::endl; + exit(1); } } // namespace ngram @@ -73,11 +106,21 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) { int main(int argc, char *argv[]) { using namespace lm::ngram; + bool quantize = false, set_backoff_bits = false; try { lm::ngram::Config config; int opt; - while ((opt = getopt(argc, argv, "siu:p:t:m:")) != -1) { + while ((opt = getopt(argc, argv, "siu:p:t:m:q:b:")) != -1) { switch(opt) { + case 'q': + config.prob_bits = ParseBitCount(optarg); + if (!set_backoff_bits) config.backoff_bits = config.prob_bits; + quantize = true; + break; + case 'b': + config.backoff_bits = ParseBitCount(optarg); + set_backoff_bits = true; + break; case 'u': config.unknown_missing_logprob = ParseFloat(optarg); break; @@ -100,19 +143,29 @@ int main(int argc, char *argv[]) { Usage(argv[0]); } } + if (!quantize && set_backoff_bits) { + std::cerr << "You specified backoff quantization (-b) but not probability quantization (-q)" << std::endl; + abort(); + } if (optind + 1 == argc) { ShowSizes(argv[optind], config); } else if (optind + 2 == argc) { config.write_mmap = argv[optind + 1]; + if (quantize || set_backoff_bits) ProbingQuantizationUnsupported(); ProbingModel(argv[optind], config); } else if (optind + 3 == argc) { const char *model_type = argv[optind]; const char *from_file = argv[optind + 1]; config.write_mmap = argv[optind + 2]; if (!strcmp(model_type, "probing")) { + if (quantize || set_backoff_bits) ProbingQuantizationUnsupported(); ProbingModel(from_file, config); } else if (!strcmp(model_type, "trie")) { - TrieModel(from_file, config); + if (quantize) { + QuantTrieModel(from_file, config); + } else { + TrieModel(from_file, config); + } } else { Usage(argv[0]); } diff --git a/klm/lm/config.cc b/klm/lm/config.cc index cee8fce2..08e1af5c 100644 --- a/klm/lm/config.cc +++ b/klm/lm/config.cc @@ -18,6 +18,8 @@ Config::Config() : arpa_complain(ALL), write_mmap(NULL), include_vocab(true), + prob_bits(8), + backoff_bits(8), load_method(util::POPULATE_OR_READ) {} } // namespace ngram diff --git a/klm/lm/config.hh b/klm/lm/config.hh index 6c7fe39b..dcc7cf35 100644 --- a/klm/lm/config.hh +++ b/klm/lm/config.hh @@ -71,6 +71,11 @@ struct Config { // Include the vocab in the binary file? Only effective if write_mmap != NULL. bool include_vocab; + // Quantization options. Only effective for QuantTrieModel. One value is + // reserved for each of prob and backoff, so 2^bits - 1 buckets will be used + // to quantize. + uint8_t prob_bits, backoff_bits; + // ONLY EFFECTIVE WHEN READING BINARY diff --git a/klm/lm/model.cc b/klm/lm/model.cc index f0579c0c..a1d10b3d 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -44,17 +44,13 @@ template GenericModel::Ge begin_sentence.backoff_[0] = search_.unigram.Lookup(begin_sentence.history_[0]).backoff; State null_context = State(); null_context.valid_length_ = 0; - P::Init(begin_sentence, null_context, vocab_, search_.middle.size() + 2); + P::Init(begin_sentence, null_context, vocab_, search_.MiddleEnd() - search_.MiddleBegin() + 2); } template void GenericModel::InitializeFromBinary(void *start, const Parameters ¶ms, const Config &config, int fd) { SetupMemory(start, params.counts, config); vocab_.LoadedBinary(fd, config.enumerate_vocab); - search_.unigram.LoadedBinary(); - for (typename std::vector::iterator i = search_.middle.begin(); i != search_.middle.end(); ++i) { - i->LoadedBinary(); - } - search_.longest.LoadedBinary(); + search_.LoadedBinary(); } template void GenericModel::InitializeFromARPA(const char *file, const Config &config) { @@ -116,8 +112,9 @@ template FullScoreReturn GenericModel void GenericModel FullScoreReturn GenericModel::const_iterator mid_iter = search_.middle.begin(); + const typename Search::Middle *mid_iter = search_.MiddleBegin(); for (; ; ++mid_iter, ++hist_iter, ++backoff_out) { if (hist_iter == context_rend) { // Ran out of history. Typically no backoff, but this could be a blank. @@ -192,7 +189,7 @@ template FullScoreReturn GenericModel FullScoreReturn GenericModel; -template class GenericModel; -template class GenericModel; +template class GenericModel; // HASH_PROBING +template class GenericModel, SortedVocabulary>; // TRIE_SORTED +template class GenericModel, SortedVocabulary>; // TRIE_SORTED_QUANT } // namespace detail } // namespace ngram diff --git a/klm/lm/model.hh b/klm/lm/model.hh index b85ccdcc..1f49a382 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -5,6 +5,7 @@ #include "lm/config.hh" #include "lm/facade.hh" #include "lm/max_order.hh" +#include "lm/quantize.hh" #include "lm/search_hashed.hh" #include "lm/search_trie.hh" #include "lm/vocab.hh" @@ -70,9 +71,10 @@ template class GenericModel : public base::Mod private: typedef base::ModelFacade, State, VocabularyT> P; public: - // Get the size of memory that will be mapped given ngram counts. This - // does not include small non-mapped control structures, such as this class - // itself. + /* Get the size of memory that will be mapped given ngram counts. This + * does not include small non-mapped control structures, such as this class + * itself. + */ static size_t Size(const std::vector &counts, const Config &config = Config()); /* Load the model from a file. It may be an ARPA or binary file. Binary @@ -111,6 +113,11 @@ template class GenericModel : public base::Mod private: friend void LoadLM<>(const char *file, const Config &config, GenericModel &to); + static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config) { + AdvanceOrThrow(fd, VocabularyT::Size(counts[0], config)); + Search::UpdateConfigFromBinary(fd, counts, config); + } + float SlowBackoffLookup(const WordIndex *const context_rbegin, const WordIndex *const context_rend, unsigned char start) const; FullScoreReturn ScoreExceptBackoff(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const; @@ -130,9 +137,7 @@ template class GenericModel : public base::Mod VocabularyT vocab_; - typedef typename Search::Unigram Unigram; typedef typename Search::Middle Middle; - typedef typename Search::Longest Longest; Search search_; }; @@ -141,13 +146,15 @@ template class GenericModel : public base::Mod // These must also be instantiated in the cc file. typedef ::lm::ngram::ProbingVocabulary Vocabulary; -typedef detail::GenericModel ProbingModel; +typedef detail::GenericModel ProbingModel; // HASH_PROBING // Default implementation. No real reason for it to be the default. typedef ProbingModel Model; // Smaller implementation. typedef ::lm::ngram::SortedVocabulary SortedVocabulary; -typedef detail::GenericModel TrieModel; +typedef detail::GenericModel, SortedVocabulary> TrieModel; // TRIE_SORTED + +typedef detail::GenericModel, SortedVocabulary> QuantTrieModel; // QUANT_TRIE_SORTED } // namespace ngram } // namespace lm diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index 548c098d..8bf040ff 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -243,13 +243,14 @@ BOOST_AUTO_TEST_CASE(probing) { LoadingTest(); } -/*BOOST_AUTO_TEST_CASE(sorted) { - LoadingTest(); -}*/ BOOST_AUTO_TEST_CASE(trie) { LoadingTest(); } +BOOST_AUTO_TEST_CASE(quant) { + LoadingTest(); +} + template void BinaryTest() { Config config; config.write_mmap = "test.binary"; @@ -275,12 +276,12 @@ template void BinaryTest() { BOOST_AUTO_TEST_CASE(write_and_read_probing) { BinaryTest(); } -/*BOOST_AUTO_TEST_CASE(write_and_read_sorted) { - BinaryTest(); -}*/ BOOST_AUTO_TEST_CASE(write_and_read_trie) { BinaryTest(); } +BOOST_AUTO_TEST_CASE(write_and_read_quant_trie) { + BinaryTest(); +} } // namespace } // namespace ngram diff --git a/klm/lm/quantize.cc b/klm/lm/quantize.cc new file mode 100644 index 00000000..b4d76893 --- /dev/null +++ b/klm/lm/quantize.cc @@ -0,0 +1,84 @@ +#include "lm/quantize.hh" + +#include "lm/lm_exception.hh" + +#include +#include + +#include + +namespace lm { +namespace ngram { + +/* Quantize into bins of equal size as described in + * M. Federico and N. Bertoldi. 2006. How many bits are needed + * to store probabilities for phrase-based translation? In Proc. + * of the Workshop on Statistical Machine Translation, pages + * 94–101, New York City, June. Association for Computa- + * tional Linguistics. + */ + +namespace { + +void MakeBins(float *values, float *values_end, float *centers, uint32_t bins) { + std::sort(values, values_end); + const float *start = values, *finish; + for (uint32_t i = 0; i < bins; ++i, ++centers, start = finish) { + finish = values + (((values_end - values) * static_cast(i + 1)) / bins); + if (finish == start) { + // zero length bucket. + *centers = i ? *(centers - 1) : -std::numeric_limits::infinity(); + } else { + *centers = std::accumulate(start, finish, 0.0) / static_cast(finish - start); + } + } +} + +const char kSeparatelyQuantizeVersion = 1; + +} // namespace + +void SeparatelyQuantize::UpdateConfigFromBinary(int fd, const std::vector &/*counts*/, Config &config) { + char version; + if (read(fd, &version, 1) != 1 || read(fd, &config.prob_bits, 1) != 1 || read(fd, &config.backoff_bits, 1) != 1) + UTIL_THROW(util::ErrnoException, "Failed to read header for quantization."); + if (version != kSeparatelyQuantizeVersion) UTIL_THROW(FormatLoadException, "This file has quantization version " << (unsigned)version << " but the code expects version " << (unsigned)kSeparatelyQuantizeVersion); +} + +void SeparatelyQuantize::SetupMemory(void *start, const Config &config) { + // Reserve 8 byte header for bit counts. + start_ = reinterpret_cast(static_cast(start) + 8); + prob_bits_ = config.prob_bits; + backoff_bits_ = config.backoff_bits; + // We need the reserved values. + if (config.prob_bits == 0) UTIL_THROW(ConfigException, "You can't quantize probability to zero"); + if (config.backoff_bits == 0) UTIL_THROW(ConfigException, "You can't quantize backoff to zero"); + if (config.prob_bits > 25) UTIL_THROW(ConfigException, "For efficiency reasons, quantizing probability supports at most 25 bits. Currently you have requested " << static_cast(config.prob_bits) << " bits."); + if (config.backoff_bits > 25) UTIL_THROW(ConfigException, "For efficiency reasons, quantizing backoff supports at most 25 bits. Currently you have requested " << static_cast(config.backoff_bits) << " bits."); +} + +void SeparatelyQuantize::Train(uint8_t order, std::vector &prob, std::vector &backoff) { + TrainProb(order, prob); + + // Backoff + float *centers = start_ + TableStart(order) + ProbTableLength(); + *(centers++) = kNoExtensionBackoff; + *(centers++) = kExtensionBackoff; + MakeBins(&*backoff.begin(), &*backoff.end(), centers, (1ULL << backoff_bits_) - 2); +} + +void SeparatelyQuantize::TrainProb(uint8_t order, std::vector &prob) { + float *centers = start_ + TableStart(order); + *(centers++) = kBlankProb; + MakeBins(&*prob.begin(), &*prob.end(), centers, (1ULL << prob_bits_) - 1); +} + +void SeparatelyQuantize::FinishedLoading(const Config &config) { + uint8_t *actual_base = reinterpret_cast(start_) - 8; + *(actual_base++) = kSeparatelyQuantizeVersion; // version + *(actual_base++) = config.prob_bits; + *(actual_base++) = config.backoff_bits; +} + +} // namespace ngram +} // namespace lm diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh new file mode 100644 index 00000000..aae72b34 --- /dev/null +++ b/klm/lm/quantize.hh @@ -0,0 +1,207 @@ +#ifndef LM_QUANTIZE_H__ +#define LM_QUANTIZE_H__ + +#include "lm/binary_format.hh" // for ModelType +#include "lm/blank.hh" +#include "lm/config.hh" +#include "util/bit_packing.hh" + +#include +#include + +#include + +#include + +namespace lm { +namespace ngram { + +class Config; + +/* Store values directly and don't quantize. */ +class DontQuantize { + public: + static const ModelType kModelType = TRIE_SORTED; + static void UpdateConfigFromBinary(int, const std::vector &, Config &) {} + static std::size_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; } + static uint8_t MiddleBits(const Config &/*config*/) { return 63; } + static uint8_t LongestBits(const Config &/*config*/) { return 31; } + + struct Middle { + void Write(void *base, uint64_t bit_offset, float prob, float backoff) const { + util::WriteNonPositiveFloat31(base, bit_offset, prob); + util::WriteFloat32(base, bit_offset + 31, backoff); + } + void Read(const void *base, uint64_t bit_offset, float &prob, float &backoff) const { + prob = util::ReadNonPositiveFloat31(base, bit_offset); + backoff = util::ReadFloat32(base, bit_offset + 31); + } + void ReadBackoff(const void *base, uint64_t bit_offset, float &backoff) const { + backoff = util::ReadFloat32(base, bit_offset + 31); + } + uint8_t TotalBits() const { return 63; } + }; + + struct Longest { + void Write(void *base, uint64_t bit_offset, float prob) const { + util::WriteNonPositiveFloat31(base, bit_offset, prob); + } + void Read(const void *base, uint64_t bit_offset, float &prob) const { + prob = util::ReadNonPositiveFloat31(base, bit_offset); + } + uint8_t TotalBits() const { return 31; } + }; + + DontQuantize() {} + + void SetupMemory(void * /*start*/, const Config & /*config*/) {} + + static const bool kTrain = false; + // These should never be called because kTrain is false. + void Train(uint8_t /*order*/, std::vector &/*prob*/, std::vector &/*backoff*/) {} + void TrainProb(uint8_t, std::vector &/*prob*/) {} + + void FinishedLoading(const Config &) {} + + Middle Mid(uint8_t /*order*/) const { return Middle(); } + Longest Long(uint8_t /*order*/) const { return Longest(); } +}; + +class SeparatelyQuantize { + private: + class Bins { + public: + // Sigh C++ default constructor + Bins() {} + + Bins(uint8_t bits, const float *const begin) : begin_(begin), end_(begin_ + (1ULL << bits)), bits_(bits), mask_((1ULL << bits) - 1) {} + + uint64_t EncodeProb(float value) const { + return(value == kBlankProb ? kBlankProbQuant : Encode(value, 1)); + } + + uint64_t EncodeBackoff(float value) const { + if (value == 0.0) { + return HasExtension(value) ? kExtensionQuant : kNoExtensionQuant; + } + return Encode(value, 2); + } + + float Decode(std::size_t off) const { return begin_[off]; } + + uint8_t Bits() const { return bits_; } + + uint64_t Mask() const { return mask_; } + + private: + uint64_t Encode(float value, size_t reserved) const { + const float *above = std::lower_bound(begin_ + reserved, end_, value); + if (above == begin_ + reserved) return reserved; + if (above == end_) return end_ - begin_ - 1; + return above - begin_ - (value - *(above - 1) < *above - value); + } + + const float *begin_; + const float *end_; + uint8_t bits_; + uint64_t mask_; + }; + + public: + static const ModelType kModelType = QUANT_TRIE_SORTED; + + static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config); + + static std::size_t Size(uint8_t order, const Config &config) { + size_t longest_table = (static_cast(1) << static_cast(config.prob_bits)) * sizeof(float); + size_t middle_table = (static_cast(1) << static_cast(config.backoff_bits)) * sizeof(float) + longest_table; + // unigrams are currently not quantized so no need for a table. + return (order - 2) * middle_table + longest_table + /* for the bit counts and alignment padding) */ 8; + } + + static uint8_t MiddleBits(const Config &config) { return config.prob_bits + config.backoff_bits; } + static uint8_t LongestBits(const Config &config) { return config.prob_bits; } + + class Middle { + public: + Middle(uint8_t prob_bits, const float *prob_begin, uint8_t backoff_bits, const float *backoff_begin) : + total_bits_(prob_bits + backoff_bits), total_mask_((1ULL << total_bits_) - 1), prob_(prob_bits, prob_begin), backoff_(backoff_bits, backoff_begin) {} + + void Write(void *base, uint64_t bit_offset, float prob, float backoff) const { + util::WriteInt57(base, bit_offset, total_bits_, + (prob_.EncodeProb(prob) << backoff_.Bits()) | backoff_.EncodeBackoff(backoff)); + } + + void Read(const void *base, uint64_t bit_offset, float &prob, float &backoff) const { + uint64_t both = util::ReadInt57(base, bit_offset, total_bits_, total_mask_); + prob = prob_.Decode(both >> backoff_.Bits()); + backoff = backoff_.Decode(both & backoff_.Mask()); + } + + void ReadBackoff(const void *base, uint64_t bit_offset, float &backoff) const { + backoff = backoff_.Decode(util::ReadInt25(base, bit_offset, backoff_.Bits(), backoff_.Mask())); + } + + uint8_t TotalBits() const { + return total_bits_; + } + + private: + const uint8_t total_bits_; + const uint64_t total_mask_; + const Bins prob_; + const Bins backoff_; + }; + + class Longest { + public: + // Sigh C++ default constructor + Longest() {} + + Longest(uint8_t prob_bits, const float *prob_begin) : prob_(prob_bits, prob_begin) {} + + void Write(void *base, uint64_t bit_offset, float prob) const { + util::WriteInt25(base, bit_offset, prob_.Bits(), prob_.EncodeProb(prob)); + } + + void Read(const void *base, uint64_t bit_offset, float &prob) const { + prob = prob_.Decode(util::ReadInt25(base, bit_offset, prob_.Bits(), prob_.Mask())); + } + + uint8_t TotalBits() const { return prob_.Bits(); } + + private: + Bins prob_; + }; + + SeparatelyQuantize() {} + + void SetupMemory(void *start, const Config &config); + + static const bool kTrain = true; + // Assumes kBlankProb is removed from prob and 0.0 is removed from backoff. + void Train(uint8_t order, std::vector &prob, std::vector &backoff); + // Train just probabilities (for longest order). + void TrainProb(uint8_t order, std::vector &prob); + + void FinishedLoading(const Config &config); + + Middle Mid(uint8_t order) const { + const float *table = start_ + TableStart(order); + return Middle(prob_bits_, table, backoff_bits_, table + ProbTableLength()); + } + + Longest Long(uint8_t order) const { return Longest(prob_bits_, start_ + TableStart(order)); } + + private: + size_t TableStart(uint8_t order) const { return ((1ULL << prob_bits_) + (1ULL << backoff_bits_)) * static_cast(order - 2); } + size_t ProbTableLength() const { return (1ULL << prob_bits_); } + + float *start_; + uint8_t prob_bits_, backoff_bits_; +}; + +} // namespace ngram +} // namespace lm + +#endif // LM_QUANTIZE_H__ diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index eaad59ab..c56ba7b8 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -80,6 +80,21 @@ template void ReadNGrams( } // namespace namespace detail { + +template uint8_t *TemplateHashedSearch::SetupMemory(uint8_t *start, const std::vector &counts, const Config &config) { + std::size_t allocated = Unigram::Size(counts[0]); + unigram = Unigram(start, allocated); + start += allocated; + for (unsigned int n = 2; n < counts.size(); ++n) { + allocated = Middle::Size(counts[n - 1], config.probing_multiplier); + middle_.push_back(Middle(start, allocated)); + start += allocated; + } + allocated = Longest::Size(counts.back(), config.probing_multiplier); + longest = Longest(start, allocated); + start += allocated; + return start; +} template template void TemplateHashedSearch::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector &counts, const Config &config, Voc &vocab, Backing &backing) { // TODO: fix sorted. @@ -92,15 +107,15 @@ template template void TemplateHashe try { if (counts.size() > 2) { - ReadNGrams(f, 2, counts[1], vocab, middle, ActivateUnigram(unigram.Raw()), middle[0], warn); + ReadNGrams(f, 2, counts[1], vocab, middle_, ActivateUnigram(unigram.Raw()), middle_[0], warn); } for (unsigned int n = 3; n < counts.size(); ++n) { - ReadNGrams(f, n, counts[n-1], vocab, middle, ActivateLowerMiddle(middle[n-3]), middle[n-2], warn); + ReadNGrams(f, n, counts[n-1], vocab, middle_, ActivateLowerMiddle(middle_[n-3]), middle_[n-2], warn); } if (counts.size() > 2) { - ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle, ActivateLowerMiddle(middle.back()), longest, warn); + ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle_, ActivateLowerMiddle(middle_.back()), longest, warn); } else { - ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle, ActivateUnigram(unigram.Raw()), longest, warn); + ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle_, ActivateUnigram(unigram.Raw()), longest, warn); } } catch (util::ProbingSizeException &e) { UTIL_THROW(util::ProbingSizeException, "Avoid pruning n-grams like \"bar baz quux\" when \"foo bar baz quux\" is still in the model. KenLM will work when this pruning happens, but the probing model assumes these events are rare enough that using blank space in the probing hash table will cover all of them. Increase probing_multiplier (-p to build_binary) to add more blank spaces.\n"); @@ -108,13 +123,18 @@ template template void TemplateHashe ReadEnd(f); } -template void TemplateHashedSearch::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector &counts, const Config &, ProbingVocabulary &vocab, Backing &backing); -template void TemplateHashedSearch::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector &counts, const Config &, SortedVocabulary &vocab, Backing &backing); - -SortedHashedSearch::SortedHashedSearch() { - UTIL_THROW(util::Exception, "Sorted is broken at the moment, sorry"); +template void TemplateHashedSearch::LoadedBinary() { + unigram.LoadedBinary(); + for (typename std::vector::iterator i = middle_.begin(); i != middle_.end(); ++i) { + i->LoadedBinary(); + } + longest.LoadedBinary(); } +template class TemplateHashedSearch; + +template void TemplateHashedSearch::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector &counts, const Config &, ProbingVocabulary &vocab, Backing &backing); + } // namespace detail } // namespace ngram } // namespace lm diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index 6dc11fb3..f3acdefc 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -8,7 +8,6 @@ #include "util/key_value_packing.hh" #include "util/probing_hash_table.hh" -#include "util/sorted_uniform.hh" #include #include @@ -62,73 +61,71 @@ struct HashedSearch { } }; -template struct TemplateHashedSearch : public HashedSearch { - typedef MiddleT Middle; - std::vector middle; +template class TemplateHashedSearch : public HashedSearch { + public: + typedef MiddleT Middle; - typedef LongestT Longest; - Longest longest; + typedef LongestT Longest; + Longest longest; - static std::size_t Size(const std::vector &counts, const Config &config) { - std::size_t ret = Unigram::Size(counts[0]); - for (unsigned char n = 1; n < counts.size() - 1; ++n) { - ret += Middle::Size(counts[n], config.probing_multiplier); - } - return ret + Longest::Size(counts.back(), config.probing_multiplier); - } + // TODO: move probing_multiplier here with next binary file format update. + static void UpdateConfigFromBinary(int, const std::vector &, Config &) {} - uint8_t *SetupMemory(uint8_t *start, const std::vector &counts, const Config &config) { - std::size_t allocated = Unigram::Size(counts[0]); - unigram = Unigram(start, allocated); - start += allocated; - for (unsigned int n = 2; n < counts.size(); ++n) { - allocated = Middle::Size(counts[n - 1], config.probing_multiplier); - middle.push_back(Middle(start, allocated)); - start += allocated; + static std::size_t Size(const std::vector &counts, const Config &config) { + std::size_t ret = Unigram::Size(counts[0]); + for (unsigned char n = 1; n < counts.size() - 1; ++n) { + ret += Middle::Size(counts[n], config.probing_multiplier); + } + return ret + Longest::Size(counts.back(), config.probing_multiplier); } - allocated = Longest::Size(counts.back(), config.probing_multiplier); - longest = Longest(start, allocated); - start += allocated; - return start; - } - template void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector &counts, const Config &config, Voc &vocab, Backing &backing); + uint8_t *SetupMemory(uint8_t *start, const std::vector &counts, const Config &config); - bool LookupMiddle(const Middle &middle, WordIndex word, float &prob, float &backoff, Node &node) const { - node = CombineWordHash(node, word); - typename Middle::ConstIterator found; - if (!middle.Find(node, found)) return false; - prob = found->GetValue().prob; - backoff = found->GetValue().backoff; - return true; - } + template void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector &counts, const Config &config, Voc &vocab, Backing &backing); - bool LookupMiddleNoProb(const Middle &middle, WordIndex word, float &backoff, Node &node) const { - node = CombineWordHash(node, word); - typename Middle::ConstIterator found; - if (!middle.Find(node, found)) return false; - backoff = found->GetValue().backoff; - return true; - } + const Middle *MiddleBegin() const { return &*middle_.begin(); } + const Middle *MiddleEnd() const { return &*middle_.end(); } - bool LookupLongest(WordIndex word, float &prob, Node &node) const { - node = CombineWordHash(node, word); - typename Longest::ConstIterator found; - if (!longest.Find(node, found)) return false; - prob = found->GetValue().prob; - return true; - } + bool LookupMiddle(const Middle &middle, WordIndex word, float &prob, float &backoff, Node &node) const { + node = CombineWordHash(node, word); + typename Middle::ConstIterator found; + if (!middle.Find(node, found)) return false; + prob = found->GetValue().prob; + backoff = found->GetValue().backoff; + return true; + } + + void LoadedBinary(); - // Geenrate a node without necessarily checking that it actually exists. - // Optionally return false if it's know to not exist. - bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { - assert(begin != end); - node = static_cast(*begin); - for (const WordIndex *i = begin + 1; i < end; ++i) { - node = CombineWordHash(node, *i); + bool LookupMiddleNoProb(const Middle &middle, WordIndex word, float &backoff, Node &node) const { + node = CombineWordHash(node, word); + typename Middle::ConstIterator found; + if (!middle.Find(node, found)) return false; + backoff = found->GetValue().backoff; + return true; } - return true; - } + + bool LookupLongest(WordIndex word, float &prob, Node &node) const { + node = CombineWordHash(node, word); + typename Longest::ConstIterator found; + if (!longest.Find(node, found)) return false; + prob = found->GetValue().prob; + return true; + } + + // Geenrate a node without necessarily checking that it actually exists. + // Optionally return false if it's know to not exist. + bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { + assert(begin != end); + node = static_cast(*begin); + for (const WordIndex *i = begin + 1; i < end; ++i) { + node = CombineWordHash(node, *i); + } + return true; + } + + private: + std::vector middle_; }; // std::identity is an SGI extension :-( @@ -143,15 +140,6 @@ struct ProbingHashedSearch : public TemplateHashedSearch< static const ModelType kModelType = HASH_PROBING; }; -struct SortedHashedSearch : public TemplateHashedSearch< - util::SortedUniformMap >, - util::SortedUniformMap > > { - - SortedHashedSearch(); - - static const ModelType kModelType = HASH_SORTED; -}; - } // namespace detail } // namespace ngram } // namespace lm diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 7c57072b..1ce4d278 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -4,6 +4,7 @@ #include "lm/blank.hh" #include "lm/lm_exception.hh" #include "lm/max_order.hh" +#include "lm/quantize.hh" #include "lm/read_arpa.hh" #include "lm/trie.hh" #include "lm/vocab.hh" @@ -21,6 +22,7 @@ #include #include #include +#include #include #include @@ -579,7 +581,7 @@ bool HeadMatch(const WordIndex *words, const WordIndex *const words_end, const W // Phase to count n-grams, including blanks inserted because they were pruned but have extensions class JustCount { public: - JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, BitPackedMiddle * /*middle*/, BitPackedLongest &/*longest*/, uint64_t *counts, unsigned char order) + template JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, Middle * /*middle*/, Longest &/*longest*/, uint64_t *counts, unsigned char order) : counts_(counts), longest_counts_(counts + order - 1) {} void Unigrams(WordIndex begin, WordIndex end) { @@ -608,9 +610,9 @@ class JustCount { }; // Phase to actually write n-grams to the trie. -class WriteEntries { +template class WriteEntries { public: - WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) : + WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) : contexts_(contexts), unigrams_(unigrams), middle_(middle), @@ -647,14 +649,14 @@ class WriteEntries { private: ContextReader *contexts_; UnigramValue *const unigrams_; - BitPackedMiddle *const middle_; - BitPackedLongest &longest_; + BitPackedMiddle *const middle_; + BitPackedLongest &longest_; BitPacked &bigram_pack_; }; template class RecursiveInsert { public: - RecursiveInsert(SortedFileReader *inputs, ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, uint64_t *counts, unsigned char order) : + template RecursiveInsert(SortedFileReader *inputs, ContextReader *contexts, UnigramValue *unigrams, MiddleT *middle, LongestT &longest, uint64_t *counts, unsigned char order) : doing_(contexts, unigrams, middle, longest, counts, order), inputs_(inputs), inputs_end_(inputs + order - 1), order_minus_2_(order - 2) { } @@ -775,7 +777,51 @@ void SanityCheckCounts(const std::vector &initial, const std::vector &counts, const Config &config, TrieSearch &out, Backing &backing) { +bool IsDirectory(const char *path) { + struct stat info; + if (0 != stat(path, &info)) return false; + return S_ISDIR(info.st_mode); +} + +template void TrainQuantizer(uint8_t order, uint64_t count, SortedFileReader &reader, util::ErsatzProgress &progress, Quant &quant) { + ProbBackoff weights; + std::vector probs, backoffs; + probs.reserve(count); + backoffs.reserve(count); + for (reader.Rewind(); !reader.Ended(); reader.NextHeader()) { + uint64_t entries = reader.ReadCount(); + for (uint64_t c = 0; c < entries; ++c) { + reader.ReadWord(); + reader.ReadWeights(weights); + // kBlankProb isn't added yet. + probs.push_back(weights.prob); + if (weights.backoff != 0.0) backoffs.push_back(weights.backoff); + ++progress; + } + } + quant.Train(order, probs, backoffs); +} + +template void TrainProbQuantizer(uint8_t order, uint64_t count, SortedFileReader &reader, util::ErsatzProgress &progress, Quant &quant) { + Prob weights; + std::vector probs, backoffs; + probs.reserve(count); + for (reader.Rewind(); !reader.Ended(); reader.NextHeader()) { + uint64_t entries = reader.ReadCount(); + for (uint64_t c = 0; c < entries; ++c) { + reader.ReadWord(); + reader.ReadWeights(weights); + // kBlankProb isn't added yet. + probs.push_back(weights.prob); + ++progress; + } + } + quant.TrainProb(order, probs); +} + +} // namespace + +template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, Backing &backing) { std::vector inputs(counts.size() - 1); std::vector contexts(counts.size() - 1); @@ -791,7 +837,7 @@ void BuildTrie(const std::string &file_prefix, std::vector &counts, co std::vector fixed_counts(counts.size()); { - RecursiveInsert counter(&*inputs.begin(), &*contexts.begin(), NULL, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); + RecursiveInsert counter(&*inputs.begin(), &*contexts.begin(), NULL, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]); } for (std::vector::const_iterator i = inputs.begin(); i != inputs.end(); ++i) { @@ -800,7 +846,16 @@ void BuildTrie(const std::string &file_prefix, std::vector &counts, co SanityCheckCounts(counts, fixed_counts); counts = fixed_counts; - out.SetupMemory(GrowForSearch(config, TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); + out.SetupMemory(GrowForSearch(config, TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); + + if (Quant::kTrain) { + util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0)); + for (unsigned char i = 2; i < counts.size(); ++i) { + TrainQuantizer(i, counts[i-1], inputs[i-2], progress, quant); + } + TrainProbQuantizer(counts.size(), counts.back(), inputs[counts.size() - 2], progress, quant); + quant.FinishedLoading(config); + } for (unsigned char i = 2; i <= counts.size(); ++i) { inputs[i-2].Rewind(); @@ -808,7 +863,7 @@ void BuildTrie(const std::string &file_prefix, std::vector &counts, co UnigramValue *unigrams = out.unigram.Raw(); // Fill entries except unigram probabilities. { - RecursiveInsert inserter(&*inputs.begin(), &*contexts.begin(), unigrams, &*out.middle.begin(), out.longest, &*fixed_counts.begin(), counts.size()); + RecursiveInsert > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); inserter.Apply(config.messages, "Building trie", fixed_counts[0]); } @@ -845,23 +900,44 @@ void BuildTrie(const std::string &file_prefix, std::vector &counts, co /* Set ending offsets so the last entry will be sized properly */ // Last entry for unigrams was already set. - if (!out.middle.empty()) { - for (size_t i = 0; i < out.middle.size() - 1; ++i) { - out.middle[i].FinishedLoading(out.middle[i+1].InsertIndex()); + if (out.middle_begin_ != out.middle_end_) { + for (typename TrieSearch::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) { + i->FinishedLoading((i+1)->InsertIndex()); } - out.middle.back().FinishedLoading(out.longest.InsertIndex()); + (out.middle_end_ - 1)->FinishedLoading(out.longest.InsertIndex()); } } -bool IsDirectory(const char *path) { - struct stat info; - if (0 != stat(path, &info)) return false; - return S_ISDIR(info.st_mode); +template uint8_t *TrieSearch::SetupMemory(uint8_t *start, const std::vector &counts, const Config &config) { + quant_.SetupMemory(start, config); + start += Quant::Size(counts.size(), config); + unigram.Init(start); + start += Unigram::Size(counts[0]); + FreeMiddles(); + middle_begin_ = static_cast(malloc(sizeof(Middle) * (counts.size() - 2))); + middle_end_ = middle_begin_ + (counts.size() - 2); + for (unsigned char i = counts.size() - 1; i >= 2; --i) { + new (middle_begin_ + i - 2) Middle( + start, + quant_.Mid(i), + counts[0], + counts[i], + (i == counts.size() - 1) ? static_cast(longest) : static_cast(middle_begin_[i-1])); + start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i]); + } + longest.Init(start, quant_.Long(counts.size()), counts[0]); + return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); } -} // namespace +template void TrieSearch::LoadedBinary() { + unigram.LoadedBinary(); + for (Middle *i = middle_begin_; i != middle_end_; ++i) { + i->LoadedBinary(); + } + longest.LoadedBinary(); +} -void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { +template void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { std::string temporary_directory; if (config.temporary_directory_prefix) { temporary_directory = config.temporary_directory_prefix; @@ -885,12 +961,15 @@ void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::v // At least 1MB sorting memory. ARPAToSortedFiles(config, f, counts, std::max(config.building_memory, 1048576), temporary_directory.c_str(), vocab); - BuildTrie(temporary_directory, counts, config, *this, backing); + BuildTrie(temporary_directory, counts, config, *this, quant_, backing); if (rmdir(temporary_directory.c_str()) && config.messages) { *config.messages << "Failed to delete " << temporary_directory << std::endl; } } +template class TrieSearch; +template class TrieSearch; + } // namespace trie } // namespace ngram } // namespace lm diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 0f720217..0a52acb5 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -13,72 +13,88 @@ struct Backing; class SortedVocabulary; namespace trie { -struct TrieSearch { - typedef NodeRange Node; +template class TrieSearch; +template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, Backing &backing); - typedef ::lm::ngram::trie::Unigram Unigram; - Unigram unigram; +template class TrieSearch { + public: + typedef NodeRange Node; - typedef trie::BitPackedMiddle Middle; - std::vector middle; + typedef ::lm::ngram::trie::Unigram Unigram; + Unigram unigram; - typedef trie::BitPackedLongest Longest; - Longest longest; + typedef trie::BitPackedMiddle Middle; - static const ModelType kModelType = TRIE_SORTED; + typedef trie::BitPackedLongest Longest; + Longest longest; - static std::size_t Size(const std::vector &counts, const Config &/*config*/) { - std::size_t ret = Unigram::Size(counts[0]); - for (unsigned char i = 1; i < counts.size() - 1; ++i) { - ret += Middle::Size(counts[i], counts[0], counts[i+1]); + static const ModelType kModelType = Quant::kModelType; + + static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config) { + Quant::UpdateConfigFromBinary(fd, counts, config); } - return ret + Longest::Size(counts.back(), counts[0]); - } - - uint8_t *SetupMemory(uint8_t *start, const std::vector &counts, const Config &/*config*/) { - unigram.Init(start); - start += Unigram::Size(counts[0]); - middle.resize(counts.size() - 2); - for (unsigned char i = 1; i < counts.size() - 1; ++i) { - middle[i-1].Init( - start, - counts[0], - counts[i+1], - (i == counts.size() - 2) ? static_cast(longest) : static_cast(middle[i])); - start += Middle::Size(counts[i], counts[0], counts[i+1]); + + static std::size_t Size(const std::vector &counts, const Config &config) { + std::size_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]); + for (unsigned char i = 1; i < counts.size() - 1; ++i) { + ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1]); + } + return ret + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); } - longest.Init(start, counts[0]); - return start + Longest::Size(counts.back(), counts[0]); - } - - void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); - - bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { - return unigram.Find(word, prob, backoff, node); - } - - bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const { - return mid.Find(word, prob, backoff, node); - } - - bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const { - return mid.FindNoProb(word, backoff, node); - } - - bool LookupLongest(WordIndex word, float &prob, const Node &node) const { - return longest.Find(word, prob, node); - } - - bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { - // TODO: don't decode backoff. - assert(begin != end); - float ignored_prob, ignored_backoff; - LookupUnigram(*begin, ignored_prob, ignored_backoff, node); - for (const WordIndex *i = begin + 1; i < end; ++i) { - if (!LookupMiddleNoProb(middle[i - begin - 1], *i, ignored_backoff, node)) return false; + + TrieSearch() : middle_begin_(NULL), middle_end_(NULL) {} + + ~TrieSearch() { FreeMiddles(); } + + uint8_t *SetupMemory(uint8_t *start, const std::vector &counts, const Config &config); + + void LoadedBinary(); + + const Middle *MiddleBegin() const { return middle_begin_; } + const Middle *MiddleEnd() const { return middle_end_; } + + void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); + + bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { + return unigram.Find(word, prob, backoff, node); + } + + bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const { + return mid.Find(word, prob, backoff, node); } - return true; - } + + bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const { + return mid.FindNoProb(word, backoff, node); + } + + bool LookupLongest(WordIndex word, float &prob, const Node &node) const { + return longest.Find(word, prob, node); + } + + bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { + // TODO: don't decode backoff. + assert(begin != end); + float ignored_prob, ignored_backoff; + LookupUnigram(*begin, ignored_prob, ignored_backoff, node); + for (const WordIndex *i = begin + 1; i < end; ++i) { + if (!LookupMiddleNoProb(middle_begin_[i - begin - 1], *i, ignored_backoff, node)) return false; + } + return true; + } + + private: + friend void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, Backing &backing); + + // Middles are managed manually so we can delay construction and they don't have to be copyable. + void FreeMiddles() { + for (const Middle *i = middle_begin_; i != middle_end_; ++i) { + i->~Middle(); + } + free(middle_begin_); + } + + Middle *middle_begin_, *middle_end_; + Quant quant_; }; } // namespace trie diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index 2c633613..63c2a612 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -1,8 +1,8 @@ #include "lm/trie.hh" +#include "lm/quantize.hh" #include "util/bit_packing.hh" #include "util/exception.hh" -#include "util/proxy_iterator.hh" #include "util/sorted_uniform.hh" #include @@ -12,53 +12,32 @@ namespace ngram { namespace trie { namespace { -// Assumes key is first. -class JustKeyProxy { +class KeyAccessor { public: - JustKeyProxy() : inner_(), base_(), key_mask_(), key_bits_(), total_bits_() {} + KeyAccessor(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits) + : base_(reinterpret_cast(base)), key_mask_(key_mask), key_bits_(key_bits), total_bits_(total_bits) {} - operator uint64_t() const { return GetKey(); } + typedef uint64_t Key; - uint64_t GetKey() const { - uint64_t bit_off = inner_ * static_cast(total_bits_); - return util::ReadInt57(base_ + bit_off / 8, bit_off & 7, key_bits_, key_mask_); + Key operator()(uint64_t index) const { + return util::ReadInt57(base_, index * static_cast(total_bits_), key_bits_, key_mask_); } private: - friend class util::ProxyIterator; - friend bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index); - - JustKeyProxy(const void *base, uint64_t index, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits) - : inner_(index), base_(static_cast(base)), key_mask_(key_mask), key_bits_(key_bits), total_bits_(total_bits) {} - - // This is a read-only iterator. - JustKeyProxy &operator=(const JustKeyProxy &other); - - typedef uint64_t value_type; - - typedef uint64_t InnerIterator; - uint64_t &Inner() { return inner_; } - const uint64_t &Inner() const { return inner_; } - - // The address in bits is base_ * 8 + inner_ * total_bits_. - uint64_t inner_; const uint8_t *const base_; - const uint64_t key_mask_; + const WordIndex key_mask_; const uint8_t key_bits_, total_bits_; }; -bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, WordIndex key, uint64_t &at_index) { - util::ProxyIterator begin_it(JustKeyProxy(base, begin_index, key_mask, key_bits, total_bits)); - util::ProxyIterator end_it(JustKeyProxy(base, end_index, key_mask, key_bits, total_bits)); - util::ProxyIterator out; - if (!util::SortedUniformFind(begin_it, end_it, key, out)) return false; - at_index = out.Inner(); +bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, const uint64_t max_vocab, const uint64_t key, uint64_t &at_index) { + KeyAccessor accessor(base, key_mask, key_bits, total_bits); + if (!util::BoundedSortedUniformFind::T>(accessor, begin_index - 1, (uint64_t)0, end_index, max_vocab, key, at_index)) return false; return true; } } // namespace std::size_t BitPacked::BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits) { - uint8_t total_bits = util::RequiredBits(max_vocab) + 31 + remaining_bits; + uint8_t total_bits = util::RequiredBits(max_vocab) + remaining_bits; // Extra entry for next pointer at the end. // +7 then / 8 to round up bits and convert to bytes // +sizeof(uint64_t) so that ReadInt57 etc don't go segfault. @@ -71,100 +50,96 @@ void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits) word_bits_ = util::RequiredBits(max_vocab); word_mask_ = (1ULL << word_bits_) - 1ULL; if (word_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, word indices more than " << (1ULL << 57) << " are not implemented. Edit util/bit_packing.hh and fix the bit packing functions."); - prob_bits_ = 31; - total_bits_ = word_bits_ + prob_bits_ + remaining_bits; + total_bits_ = word_bits_ + remaining_bits; base_ = static_cast(base); insert_index_ = 0; + max_vocab_ = max_vocab; } -std::size_t BitPackedMiddle::Size(uint64_t entries, uint64_t max_vocab, uint64_t max_ptr) { - return BaseSize(entries, max_vocab, 32 + util::RequiredBits(max_ptr)); +template std::size_t BitPackedMiddle::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr) { + return BaseSize(entries, max_vocab, quant_bits + util::RequiredBits(max_ptr)); } -void BitPackedMiddle::Init(void *base, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source) { - next_source_ = &next_source; - backoff_bits_ = 32; - next_bits_ = util::RequiredBits(max_next); +template BitPackedMiddle::BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source) : BitPacked(), quant_(quant), next_bits_(util::RequiredBits(max_next)), next_mask_((1ULL << next_bits_) - 1), next_source_(&next_source) { if (next_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions."); - next_mask_ = (1ULL << next_bits_) - 1; - - BaseInit(base, max_vocab, backoff_bits_ + next_bits_); + BaseInit(base, max_vocab, quant.TotalBits() + next_bits_); } -void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff) { +template void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff) { assert(word <= word_mask_); uint64_t at_pointer = insert_index_ * total_bits_; - util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word_bits_, word); + util::WriteInt57(base_, at_pointer, word_bits_, word); at_pointer += word_bits_; - util::WriteNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7, prob); - at_pointer += prob_bits_; - util::WriteFloat32(base_ + (at_pointer >> 3), at_pointer & 7, backoff); - at_pointer += backoff_bits_; + quant_.Write(base_, at_pointer, prob, backoff); + at_pointer += quant_.TotalBits(); uint64_t next = next_source_->InsertIndex(); assert(next <= next_mask_); - util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next); + util::WriteInt57(base_, at_pointer, next_bits_, next); ++insert_index_; } -bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const { +template bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const { uint64_t at_pointer; - if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) { + if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) { return false; } at_pointer *= total_bits_; at_pointer += word_bits_; - prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7); - at_pointer += prob_bits_; - backoff = util::ReadFloat32(base_ + (at_pointer >> 3), at_pointer & 7); - at_pointer += backoff_bits_; - range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_); + quant_.Read(base_, at_pointer, prob, backoff); + at_pointer += quant_.TotalBits(); + + range.begin = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); // Read the next entry's pointer. at_pointer += total_bits_; - range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_); + range.end = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); return true; } -bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const { +template bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const { uint64_t at_pointer; - if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false; + if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false; at_pointer *= total_bits_; at_pointer += word_bits_; - at_pointer += prob_bits_; - backoff = util::ReadFloat32(base_ + (at_pointer >> 3), at_pointer & 7); - at_pointer += backoff_bits_; - range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_); + quant_.ReadBackoff(base_, at_pointer, backoff); + at_pointer += quant_.TotalBits(); + range.begin = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); // Read the next entry's pointer. at_pointer += total_bits_; - range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_bits_, next_mask_); + range.end = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); return true; } -void BitPackedMiddle::FinishedLoading(uint64_t next_end) { +template void BitPackedMiddle::FinishedLoading(uint64_t next_end) { assert(next_end <= next_mask_); uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - next_bits_; - util::WriteInt57(base_ + (last_next_write >> 3), last_next_write & 7, next_bits_, next_end); + util::WriteInt57(base_, last_next_write, next_bits_, next_end); } -void BitPackedLongest::Insert(WordIndex index, float prob) { +template void BitPackedLongest::Insert(WordIndex index, float prob) { assert(index <= word_mask_); uint64_t at_pointer = insert_index_ * total_bits_; - util::WriteInt57(base_ + (at_pointer >> 3), at_pointer & 7, word_bits_, index); + util::WriteInt57(base_, at_pointer, word_bits_, index); at_pointer += word_bits_; - util::WriteNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7, prob); + quant_.Write(base_, at_pointer, prob); ++insert_index_; } -bool BitPackedLongest::Find(WordIndex word, float &prob, const NodeRange &range) const { +template bool BitPackedLongest::Find(WordIndex word, float &prob, const NodeRange &range) const { uint64_t at_pointer; - if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, word, at_pointer)) return false; + if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false; at_pointer = at_pointer * total_bits_ + word_bits_; - prob = util::ReadNonPositiveFloat31(base_ + (at_pointer >> 3), at_pointer & 7); + quant_.Read(base_, at_pointer, prob); return true; } +template class BitPackedMiddle; +template class BitPackedMiddle; +template class BitPackedLongest; +template class BitPackedLongest; + } // namespace trie } // namespace ngram } // namespace lm diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index 6aef050c..8fa21aaf 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -74,23 +74,21 @@ class BitPacked { void BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits); - uint8_t word_bits_, prob_bits_; + uint8_t word_bits_; uint8_t total_bits_; uint64_t word_mask_; uint8_t *base_; - uint64_t insert_index_; + uint64_t insert_index_, max_vocab_; }; -class BitPackedMiddle : public BitPacked { +template class BitPackedMiddle : public BitPacked { public: - BitPackedMiddle() {} - - static std::size_t Size(uint64_t entries, uint64_t max_vocab, uint64_t max_next); + static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next); // next_source need not be initialized. - void Init(void *base, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source); + BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source); void Insert(WordIndex word, float prob, float backoff); @@ -101,28 +99,33 @@ class BitPackedMiddle : public BitPacked { void FinishedLoading(uint64_t next_end); private: - uint8_t backoff_bits_, next_bits_; + Quant quant_; + uint8_t next_bits_; uint64_t next_mask_; const BitPacked *next_source_; }; -class BitPackedLongest : public BitPacked { +template class BitPackedLongest : public BitPacked { public: - BitPackedLongest() {} - - static std::size_t Size(uint64_t entries, uint64_t max_vocab) { - return BaseSize(entries, max_vocab, 0); + static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) { + return BaseSize(entries, max_vocab, quant_bits); } - void Init(void *base, uint64_t max_vocab) { - return BaseInit(base, max_vocab, 0); + BitPackedLongest() {} + + void Init(void *base, const Quant &quant, uint64_t max_vocab) { + quant_ = quant; + BaseInit(base, max_vocab, quant_.TotalBits()); } void Insert(WordIndex word, float prob); bool Find(WordIndex word, float &prob, const NodeRange &node) const; + + private: + Quant quant_; }; } // namespace trie diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 515af5db..7defd5c1 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -28,8 +28,8 @@ const uint64_t kUnknownHash = detail::HashForVocab("", 5); // Sadly some LMs have . const uint64_t kUnknownCapHash = detail::HashForVocab("", 5); -void ReadWords(int fd, EnumerateVocab *enumerate) { - if (!enumerate) return; +WordIndex ReadWords(int fd, EnumerateVocab *enumerate) { + if (!enumerate) return std::numeric_limits::max(); const std::size_t kInitialRead = 16384; std::string buf; buf.reserve(kInitialRead + 100); @@ -38,7 +38,7 @@ void ReadWords(int fd, EnumerateVocab *enumerate) { while (true) { ssize_t got = read(fd, &buf[0], kInitialRead); if (got == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words"); - if (got == 0) return; + if (got == 0) return index; buf.resize(got); while (buf[buf.size() - 1]) { char next_char; @@ -87,13 +87,13 @@ SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL std::size_t SortedVocabulary::Size(std::size_t entries, const Config &/*config*/) { // Lead with the number of entries. - return sizeof(uint64_t) + sizeof(Entry) * entries; + return sizeof(uint64_t) + sizeof(uint64_t) * entries; } void SortedVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config) { assert(allocated >= Size(entries, config)); // Leave space for number of entries. - begin_ = reinterpret_cast(reinterpret_cast(start) + 1); + begin_ = reinterpret_cast(start) + 1; end_ = begin_; saw_unk_ = false; } @@ -112,7 +112,7 @@ WordIndex SortedVocabulary::Insert(const StringPiece &str) { saw_unk_ = true; return 0; } - end_->key = hashed; + *end_ = hashed; if (enumerate_) { strings_to_enumerate_[end_ - begin_].assign(str.data(), str.size()); } @@ -134,8 +134,10 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) { util::JointSort(begin_, end_, reorder_vocab + 1); } SetSpecial(Index(""), Index(""), 0); - // Save size. + // Save size. Excludes UNK. *(reinterpret_cast(begin_) - 1) = end_ - begin_; + // Includes UNK. + bound_ = end_ - begin_ + 1; } void SortedVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { @@ -183,7 +185,7 @@ void ProbingVocabulary::FinishedLoading(ProbBackoff * /*reorder_vocab*/) { void ProbingVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { lookup_.LoadedBinary(); - ReadWords(fd, to); + available_ = ReadWords(fd, to); SetSpecial(Index(""), Index(""), 0); } diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index 546c1649..c92518e4 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -9,6 +9,7 @@ #include "util/sorted_uniform.hh" #include "util/string_piece.hh" +#include #include #include @@ -44,22 +45,16 @@ class WriteWordsWrapper : public EnumerateVocab { // Vocabulary based on sorted uniform find storing only uint64_t values and using their offsets as indices. class SortedVocabulary : public base::Vocabulary { - private: - // Sorted uniform requires a GetKey function. - struct Entry { - uint64_t GetKey() const { return key; } - uint64_t key; - bool operator<(const Entry &other) const { - return key < other.key; - } - }; - public: SortedVocabulary(); WordIndex Index(const StringPiece &str) const { - const Entry *found; - if (util::SortedUniformFind(begin_, end_, detail::HashForVocab(str), found)) { + const uint64_t *found; + if (util::BoundedSortedUniformFind, util::Pivot64>( + util::IdentityAccessor(), + begin_ - 1, 0, + end_, std::numeric_limits::max(), + detail::HashForVocab(str), found)) { return found - begin_ + 1; // +1 because is 0 and does not appear in the lookup table. } else { return 0; @@ -68,6 +63,10 @@ class SortedVocabulary : public base::Vocabulary { static size_t Size(std::size_t entries, const Config &config); + // Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary. + // While this number is correct, ProbingVocabulary::Bound might not be correct in some cases. + WordIndex Bound() const { return bound_; } + // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config); @@ -83,7 +82,11 @@ class SortedVocabulary : public base::Vocabulary { void LoadedBinary(int fd, EnumerateVocab *to); private: - Entry *begin_, *end_; + uint64_t *begin_, *end_; + + WordIndex bound_; + + WordIndex highest_value_; bool saw_unk_; @@ -105,6 +108,12 @@ class ProbingVocabulary : public base::Vocabulary { static size_t Size(std::size_t entries, const Config &config); + // Vocab words are [0, Bound()). + // WARNING WARNING: returns UINT_MAX when loading binary and not enumerating vocabulary. + // Fixing this bug requires a binary file format change and will be fixed with the next binary file format update. + // Specifically, the binary file format does not currently indicate whether is in count or not. + WordIndex Bound() const { return available_; } + // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config); diff --git a/klm/util/bit_packing.cc b/klm/util/bit_packing.cc index 681da5f2..41999b72 100644 --- a/klm/util/bit_packing.cc +++ b/klm/util/bit_packing.cc @@ -28,10 +28,10 @@ void BitPackingSanity() { memset(mem, 0, sizeof(mem)); const uint64_t test57 = 0x123456789abcdefULL; for (uint64_t b = 0; b < 57 * 8; b += 57) { - WriteInt57(mem + b / 8, b % 8, 57, test57); + WriteInt57(mem, b, 57, test57); } for (uint64_t b = 0; b < 57 * 8; b += 57) { - if (test57 != ReadInt57(mem + b / 8, b % 8, 57, (1ULL << 57) - 1)) + if (test57 != ReadInt57(mem, b, 57, (1ULL << 57) - 1)) UTIL_THROW(Exception, "The bit packing routines are failing for your architecture. Please send a bug report with your architecture, operating system, and compiler."); } // TODO: more checks. diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh index 5c71c792..b35d80c8 100644 --- a/klm/util/bit_packing.hh +++ b/klm/util/bit_packing.hh @@ -42,47 +42,62 @@ inline uint8_t BitPackShift(uint8_t bit, uint8_t length) { #error "Bit packing code isn't written for your byte order." #endif +inline uint64_t ReadOff(const void *base, uint64_t bit_off) { + return *reinterpret_cast(reinterpret_cast(base) + (bit_off >> 3)); +} + /* Pack integers up to 57 bits using their least significant digits. * The length is specified using mask: * Assumes mask == (1 << length) - 1 where length <= 57. */ -inline uint64_t ReadInt57(const void *base, uint8_t bit, uint8_t length, uint64_t mask) { - return (*reinterpret_cast(base) >> BitPackShift(bit, length)) & mask; +inline uint64_t ReadInt57(const void *base, uint64_t bit_off, uint8_t length, uint64_t mask) { + return (ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, length)) & mask; } /* Assumes value < (1 << length) and length <= 57. * Assumes the memory is zero initially. */ -inline void WriteInt57(void *base, uint8_t bit, uint8_t length, uint64_t value) { - *reinterpret_cast(base) |= (value << BitPackShift(bit, length)); +inline void WriteInt57(void *base, uint64_t bit_off, uint8_t length, uint64_t value) { + *reinterpret_cast(reinterpret_cast(base) + (bit_off >> 3)) |= + (value << BitPackShift(bit_off & 7, length)); +} + +/* Same caveats as above, but for a 25 bit limit. */ +inline uint32_t ReadInt25(const void *base, uint64_t bit_off, uint8_t length, uint32_t mask) { + return (*reinterpret_cast(reinterpret_cast(base) + (bit_off >> 3)) >> BitPackShift(bit_off & 7, length)) & mask; +} + +inline void WriteInt25(void *base, uint64_t bit_off, uint8_t length, uint32_t value) { + *reinterpret_cast(reinterpret_cast(base) + (bit_off >> 3)) |= + (value << BitPackShift(bit_off & 7, length)); } typedef union { float f; uint32_t i; } FloatEnc; -inline float ReadFloat32(const void *base, uint8_t bit) { +inline float ReadFloat32(const void *base, uint64_t bit_off) { FloatEnc encoded; - encoded.i = *reinterpret_cast(base) >> BitPackShift(bit, 32); + encoded.i = ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, 32); return encoded.f; } -inline void WriteFloat32(void *base, uint8_t bit, float value) { +inline void WriteFloat32(void *base, uint64_t bit_off, float value) { FloatEnc encoded; encoded.f = value; - WriteInt57(base, bit, 32, encoded.i); + WriteInt57(base, bit_off, 32, encoded.i); } const uint32_t kSignBit = 0x80000000; -inline float ReadNonPositiveFloat31(const void *base, uint8_t bit) { +inline float ReadNonPositiveFloat31(const void *base, uint64_t bit_off) { FloatEnc encoded; - encoded.i = *reinterpret_cast(base) >> BitPackShift(bit, 31); + encoded.i = ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, 31); // Sign bit set means negative. encoded.i |= kSignBit; return encoded.f; } -inline void WriteNonPositiveFloat31(void *base, uint8_t bit, float value) { +inline void WriteNonPositiveFloat31(void *base, uint64_t bit_off, float value) { FloatEnc encoded; encoded.f = value; encoded.i &= ~kSignBit; - WriteInt57(base, bit, 31, encoded.i); + WriteInt57(base, bit_off, 31, encoded.i); } void BitPackingSanity(); diff --git a/klm/util/bit_packing_test.cc b/klm/util/bit_packing_test.cc index c578ddd1..4edc2004 100644 --- a/klm/util/bit_packing_test.cc +++ b/klm/util/bit_packing_test.cc @@ -9,15 +9,16 @@ namespace util { namespace { const uint64_t test57 = 0x123456789abcdefULL; +const uint32_t test25 = 0x1234567; -BOOST_AUTO_TEST_CASE(ZeroBit) { +BOOST_AUTO_TEST_CASE(ZeroBit57) { char mem[16]; memset(mem, 0, sizeof(mem)); WriteInt57(mem, 0, 57, test57); BOOST_CHECK_EQUAL(test57, ReadInt57(mem, 0, 57, (1ULL << 57) - 1)); } -BOOST_AUTO_TEST_CASE(EachBit) { +BOOST_AUTO_TEST_CASE(EachBit57) { char mem[16]; for (uint8_t b = 0; b < 8; ++b) { memset(mem, 0, sizeof(mem)); @@ -26,15 +27,27 @@ BOOST_AUTO_TEST_CASE(EachBit) { } } -BOOST_AUTO_TEST_CASE(Consecutive) { +BOOST_AUTO_TEST_CASE(Consecutive57) { char mem[57+8]; memset(mem, 0, sizeof(mem)); for (uint64_t b = 0; b < 57 * 8; b += 57) { - WriteInt57(mem + (b / 8), b % 8, 57, test57); - BOOST_CHECK_EQUAL(test57, ReadInt57(mem + b / 8, b % 8, 57, (1ULL << 57) - 1)); + WriteInt57(mem, b, 57, test57); + BOOST_CHECK_EQUAL(test57, ReadInt57(mem, b, 57, (1ULL << 57) - 1)); } for (uint64_t b = 0; b < 57 * 8; b += 57) { - BOOST_CHECK_EQUAL(test57, ReadInt57(mem + b / 8, b % 8, 57, (1ULL << 57) - 1)); + BOOST_CHECK_EQUAL(test57, ReadInt57(mem, b, 57, (1ULL << 57) - 1)); + } +} + +BOOST_AUTO_TEST_CASE(Consecutive25) { + char mem[25+8]; + memset(mem, 0, sizeof(mem)); + for (uint64_t b = 0; b < 25 * 8; b += 25) { + WriteInt25(mem, b, 25, test25); + BOOST_CHECK_EQUAL(test25, ReadInt25(mem, b, 25, (1ULL << 25) - 1)); + } + for (uint64_t b = 0; b < 25 * 8; b += 25) { + BOOST_CHECK_EQUAL(test25, ReadInt25(mem, b, 25, (1ULL << 25) - 1)); } } diff --git a/klm/util/sorted_uniform.hh b/klm/util/sorted_uniform.hh index 05826b51..84d7aa02 100644 --- a/klm/util/sorted_uniform.hh +++ b/klm/util/sorted_uniform.hh @@ -9,52 +9,96 @@ namespace util { -inline std::size_t Pivot(uint64_t off, uint64_t range, std::size_t width) { - std::size_t ret = static_cast(static_cast(off) / static_cast(range) * static_cast(width)); - // Cap for floating point rounding - return (ret < width) ? ret : width - 1; -} -/*inline std::size_t Pivot(uint32_t off, uint32_t range, std::size_t width) { - return static_cast(static_cast(off) * static_cast(width) / static_cast(range)); +template class IdentityAccessor { + public: + typedef T Key; + T operator()(const uint64_t *in) const { return *in; } +}; + +struct Pivot64 { + static inline std::size_t Calc(uint64_t off, uint64_t range, std::size_t width) { + std::size_t ret = static_cast(static_cast(off) / static_cast(range) * static_cast(width)); + // Cap for floating point rounding + return (ret < width) ? ret : width - 1; + } +}; + +// Use when off * width is <2^64. This is guaranteed when each of them is actually a 32-bit value. +struct Pivot32 { + static inline std::size_t Calc(uint64_t off, uint64_t range, uint64_t width) { + return static_cast((off * width) / (range + 1)); + } +}; + +// Usage: PivotSelect::T +template struct PivotSelect; +template <> struct PivotSelect<8> { typedef Pivot64 T; }; +template <> struct PivotSelect<4> { typedef Pivot32 T; }; +template <> struct PivotSelect<2> { typedef Pivot32 T; }; + +/* Binary search. */ +template bool BinaryFind( + const Accessor &accessor, + Iterator begin, + Iterator end, + const typename Accessor::Key key, Iterator &out) { + while (end > begin) { + Iterator pivot(begin + (end - begin) / 2); + typename Accessor::Key mid(accessor(pivot)); + if (mid < key) { + begin = pivot + 1; + } else if (mid > key) { + end = pivot; + } else { + out = pivot; + return true; + } + } + return false; } -inline std::size_t Pivot(uint16_t off, uint16_t range, std::size_t width) { - return static_cast(static_cast(off) * width / static_cast(range)); + +// Search the range [before_it + 1, after_it - 1] for key. +// Preconditions: +// before_v <= key <= after_v +// before_v <= all values in the range [before_it + 1, after_it - 1] <= after_v +// range is sorted. +template bool BoundedSortedUniformFind( + const Accessor &accessor, + Iterator before_it, typename Accessor::Key before_v, + Iterator after_it, typename Accessor::Key after_v, + const typename Accessor::Key key, Iterator &out) { + while (after_it - before_it > 1) { + Iterator pivot(before_it + (1 + Pivot::Calc(key - before_v, after_v - before_v, after_it - before_it - 1))); + typename Accessor::Key mid(accessor(pivot)); + if (mid < key) { + before_it = pivot; + before_v = mid; + } else if (mid > key) { + after_it = pivot; + after_v = mid; + } else { + out = pivot; + return true; + } + } + return false; } -inline std::size_t Pivot(unsigned char off, unsigned char range, std::size_t width) { - return static_cast(static_cast(off) * width / static_cast(range)); -}*/ -template bool SortedUniformFind(Iterator begin, Iterator end, const Key key, Iterator &out) { +template bool SortedUniformFind(const Accessor &accessor, Iterator begin, Iterator end, const typename Accessor::Key key, Iterator &out) { if (begin == end) return false; - Key below(begin->GetKey()); + typename Accessor::Key below(accessor(begin)); if (key <= below) { if (key == below) { out = begin; return true; } return false; } // Make the range [begin, end]. --end; - Key above(end->GetKey()); + typename Accessor::Key above(accessor(end)); if (key >= above) { if (key == above) { out = end; return true; } return false; } - - // Search the range [begin + 1, end - 1] knowing that *begin == below, *end == above. - while (end - begin > 1) { - Iterator pivot(begin + (1 + Pivot(key - below, above - below, static_cast(end - begin - 1)))); - Key mid(pivot->GetKey()); - if (mid < key) { - begin = pivot; - below = mid; - } else if (mid > key) { - end = pivot; - above = mid; - } else { - out = pivot; - return true; - } - } - return false; + return BoundedSortedUniformFind(accessor, begin, below, end, above, key, out); } // To use this template, you need to define a Pivot function to match Key. @@ -64,7 +108,13 @@ template class SortedUniformMap { typedef typename Packing::ConstIterator ConstIterator; typedef typename Packing::MutableIterator MutableIterator; - public: + struct Accessor { + public: + typedef typename Packing::Key Key; + const Key &operator()(const ConstIterator &i) const { return i->GetKey(); } + Key &operator()(const MutableIterator &i) const { return i->GetKey(); } + }; + // Offer consistent API with probing hash. static std::size_t Size(std::size_t entries, float /*ignore*/ = 0.0) { return sizeof(uint64_t) + entries * Packing::kBytes; @@ -120,7 +170,7 @@ template class SortedUniformMap { assert(initialized_); assert(loaded_); #endif - return SortedUniformFind(begin_, end_, key, out); + return SortedUniformFind(begin_, end_, key, out); } // Do not call before FinishedInserting. @@ -129,7 +179,7 @@ template class SortedUniformMap { assert(initialized_); assert(loaded_); #endif - return SortedUniformFind(ConstIterator(begin_), ConstIterator(end_), key, out); + return SortedUniformFind(Accessor(), ConstIterator(begin_), ConstIterator(end_), key, out); } ConstIterator begin() const { return begin_; } -- cgit v1.2.3 From 2c14cf2218031c29a9884bccf17e9273c71a33b2 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Thu, 18 Aug 2011 12:14:01 +0100 Subject: KenLM update: Bhiksha's trick, simple test for lms without unk, auto-detect binary files instead of requiring them to be specified at runtime. --- decoder/cdec_ff.cc | 5 +- decoder/ff_klm.cc | 38 +++++- decoder/ff_klm.h | 7 +- klm/compile.sh | 2 +- klm/lm/Makefile.am | 1 + klm/lm/bhiksha.cc | 93 +++++++++++++++ klm/lm/bhiksha.hh | 108 +++++++++++++++++ klm/lm/binary_format.cc | 13 ++- klm/lm/binary_format.hh | 9 +- klm/lm/build_binary.cc | 54 ++++++--- klm/lm/config.cc | 1 + klm/lm/config.hh | 5 +- klm/lm/model.cc | 67 ++++++----- klm/lm/model.hh | 12 +- klm/lm/model_test.cc | 73 ++++++++++-- klm/lm/ngram_query.cc | 9 ++ klm/lm/quantize.cc | 1 + klm/lm/quantize.hh | 4 +- klm/lm/read_arpa.cc | 6 +- klm/lm/search_hashed.cc | 2 +- klm/lm/search_hashed.hh | 3 +- klm/lm/search_trie.cc | 45 +++---- klm/lm/search_trie.hh | 20 ++-- klm/lm/test_nounk.arpa | 120 +++++++++++++++++++ klm/lm/trie.cc | 57 ++++----- klm/lm/trie.hh | 24 ++-- klm/lm/vocab.cc | 6 +- klm/lm/vocab.hh | 4 + klm/util/bit_packing.hh | 13 ++- klm/util/murmur_hash.cc | 258 ++++++++++++++++++++--------------------- klm/util/probing_hash_table.hh | 2 +- klm/util/sorted_uniform.hh | 23 +++- 32 files changed, 792 insertions(+), 293 deletions(-) create mode 100644 klm/lm/bhiksha.cc create mode 100644 klm/lm/bhiksha.hh create mode 100644 klm/lm/test_nounk.arpa (limited to 'decoder/ff_klm.cc') diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 3451c9fb..1ef76a05 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -55,10 +55,7 @@ void register_feature_functions() { ff_registry.Register("NgramFeatures", new FFFactory()); ff_registry.Register("RuleNgramFeatures", new FFFactory()); ff_registry.Register("CMR2008ReorderingFeatures", new FFFactory()); - ff_registry.Register("KLanguageModel", new FFFactory >()); - ff_registry.Register("KLanguageModel_Trie", new FFFactory >()); - ff_registry.Register("KLanguageModel_QuantTrie", new FFFactory >()); - ff_registry.Register("KLanguageModel_Probing", new FFFactory >()); + ff_registry.Register("KLanguageModel", new KLanguageModelFactory()); ff_registry.Register("NonLatinCount", new FFFactory); ff_registry.Register("RuleShape", new FFFactory); ff_registry.Register("RelativeSentencePosition", new FFFactory); diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 9b7fe2d3..24dcb9c3 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -9,6 +9,7 @@ #include "stringlib.h" #include "hg.h" #include "tdict.h" +#include "lm/model.hh" #include "lm/enumerate_vocab.hh" using namespace std; @@ -434,8 +435,37 @@ void KLanguageModel::FinalTraversalFeatures(const void* ant_state, features->set_value(oov_fid_, oovs); } -// instantiate templates -template class KLanguageModel; -template class KLanguageModel; -template class KLanguageModel; +template boost::shared_ptr CreateModel(const std::string ¶m) { + KLanguageModel *ret = new KLanguageModel(param); + ret->Init(); + return boost::shared_ptr(ret); +} +boost::shared_ptr KLanguageModelFactory::Create(std::string param) const { + using namespace lm::ngram; + std::string filename, ignored_map; + bool ignored_markers; + std::string ignored_featname; + ParseLMArgs(param, &filename, &ignored_map, &ignored_markers, &ignored_featname); + ModelType m; + if (!RecognizeBinary(filename.c_str(), m)) m = HASH_PROBING; + + switch (m) { + case HASH_PROBING: + return CreateModel(param); + case TRIE_SORTED: + return CreateModel(param); + case ARRAY_TRIE_SORTED: + return CreateModel(param); + case QUANT_TRIE_SORTED: + return CreateModel(param); + case QUANT_ARRAY_TRIE_SORTED: + return CreateModel(param); + default: + UTIL_THROW(util::Exception, "Unrecognized kenlm binary file type " << (unsigned)m); + } +} + +std::string KLanguageModelFactory::usage(bool params,bool verbose) const { + return KLanguageModel::usage(params, verbose); +} diff --git a/decoder/ff_klm.h b/decoder/ff_klm.h index 5eafe8be..6efe50f6 100644 --- a/decoder/ff_klm.h +++ b/decoder/ff_klm.h @@ -4,8 +4,8 @@ #include #include +#include "ff_factory.h" #include "ff.h" -#include "lm/model.hh" template struct KLanguageModelImpl; @@ -34,4 +34,9 @@ class KLanguageModel : public FeatureFunction { KLanguageModelImpl* pimpl_; }; +struct KLanguageModelFactory : public FactoryBase { + FP Create(std::string param) const; + std::string usage(bool params,bool verbose) const; +}; + #endif diff --git a/klm/compile.sh b/klm/compile.sh index 6ca85e1f..abe3473a 100755 --- a/klm/compile.sh +++ b/klm/compile.sh @@ -5,7 +5,7 @@ set -e -for i in util/{bit_packing,ersatz_progress,exception,file_piece,murmur_hash,scoped,mmap} lm/{binary_format,config,lm_exception,model,quantize,read_arpa,search_hashed,search_trie,trie,virtual_interface,vocab}; do +for i in util/{bit_packing,ersatz_progress,exception,file_piece,murmur_hash,scoped,mmap} lm/{bhiksha,binary_format,config,lm_exception,model,quantize,read_arpa,search_hashed,search_trie,trie,virtual_interface,vocab}; do g++ -I. -O3 $CXXFLAGS -c $i.cc -o $i.o done g++ -I. -O3 $CXXFLAGS lm/build_binary.cc {lm,util}/*.o -lz -o build_binary diff --git a/klm/lm/Makefile.am b/klm/lm/Makefile.am index 395494bc..fae6b41a 100644 --- a/klm/lm/Makefile.am +++ b/klm/lm/Makefile.am @@ -12,6 +12,7 @@ build_binary_LDADD = libklm.a ../util/libklm_util.a -lz noinst_LIBRARIES = libklm.a libklm_a_SOURCES = \ + bhiksha.cc \ binary_format.cc \ config.cc \ lm_exception.cc \ diff --git a/klm/lm/bhiksha.cc b/klm/lm/bhiksha.cc new file mode 100644 index 00000000..bf86fd4b --- /dev/null +++ b/klm/lm/bhiksha.cc @@ -0,0 +1,93 @@ +#include "lm/bhiksha.hh" +#include "lm/config.hh" + +#include + +namespace lm { +namespace ngram { +namespace trie { + +DontBhiksha::DontBhiksha(const void * /*base*/, uint64_t /*max_offset*/, uint64_t max_next, const Config &/*config*/) : + next_(util::BitsMask::ByMax(max_next)) {} + +const uint8_t kArrayBhikshaVersion = 0; + +void ArrayBhiksha::UpdateConfigFromBinary(int fd, Config &config) { + uint8_t version; + uint8_t configured_bits; + if (read(fd, &version, 1) != 1 || read(fd, &configured_bits, 1) != 1) { + UTIL_THROW(util::ErrnoException, "Could not read from binary file"); + } + if (version != kArrayBhikshaVersion) UTIL_THROW(FormatLoadException, "This file has sorted array compression version " << (unsigned) version << " but the code expects version " << (unsigned)kArrayBhikshaVersion); + config.pointer_bhiksha_bits = configured_bits; +} + +namespace { + +// Find argmin_{chopped \in [0, RequiredBits(max_next)]} ChoppedDelta(max_offset) +uint8_t ChopBits(uint64_t max_offset, uint64_t max_next, const Config &config) { + uint8_t required = util::RequiredBits(max_next); + uint8_t best_chop = 0; + int64_t lowest_change = std::numeric_limits::max(); + // There are probably faster ways but I don't care because this is only done once per order at construction time. + for (uint8_t chop = 0; chop <= std::min(required, config.pointer_bhiksha_bits); ++chop) { + int64_t change = (max_next >> (required - chop)) * 64 /* table cost in bits */ + - max_offset * static_cast(chop); /* savings in bits*/ + if (change < lowest_change) { + lowest_change = change; + best_chop = chop; + } + } + return best_chop; +} + +std::size_t ArrayCount(uint64_t max_offset, uint64_t max_next, const Config &config) { + uint8_t required = util::RequiredBits(max_next); + uint8_t chopping = ChopBits(max_offset, max_next, config); + return (max_next >> (required - chopping)) + 1 /* we store 0 too */; +} +} // namespace + +std::size_t ArrayBhiksha::Size(uint64_t max_offset, uint64_t max_next, const Config &config) { + return sizeof(uint64_t) * (1 /* header */ + ArrayCount(max_offset, max_next, config)) + 7 /* 8-byte alignment */; +} + +uint8_t ArrayBhiksha::InlineBits(uint64_t max_offset, uint64_t max_next, const Config &config) { + return util::RequiredBits(max_next) - ChopBits(max_offset, max_next, config); +} + +namespace { + +void *AlignTo8(void *from) { + uint8_t *val = reinterpret_cast(from); + std::size_t remainder = reinterpret_cast(val) & 7; + if (!remainder) return val; + return val + 8 - remainder; +} + +} // namespace + +ArrayBhiksha::ArrayBhiksha(void *base, uint64_t max_offset, uint64_t max_next, const Config &config) + : next_inline_(util::BitsMask::ByBits(InlineBits(max_offset, max_next, config))), + offset_begin_(reinterpret_cast(AlignTo8(base)) + 1 /* 8-byte header */), + offset_end_(offset_begin_ + ArrayCount(max_offset, max_next, config)), + write_to_(reinterpret_cast(AlignTo8(base)) + 1 /* 8-byte header */ + 1 /* first entry is 0 */), + original_base_(base) {} + +void ArrayBhiksha::FinishedLoading(const Config &config) { + // *offset_begin_ = 0 but without a const_cast. + *(write_to_ - (write_to_ - offset_begin_)) = 0; + + if (write_to_ != offset_end_) UTIL_THROW(util::Exception, "Did not get all the array entries that were expected."); + + uint8_t *head_write = reinterpret_cast(original_base_); + *(head_write++) = kArrayBhikshaVersion; + *(head_write++) = config.pointer_bhiksha_bits; +} + +void ArrayBhiksha::LoadedBinary() { +} + +} // namespace trie +} // namespace ngram +} // namespace lm diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh new file mode 100644 index 00000000..cfb2b053 --- /dev/null +++ b/klm/lm/bhiksha.hh @@ -0,0 +1,108 @@ +/* Simple implementation of + * @inproceedings{bhikshacompression, + * author={Bhiksha Raj and Ed Whittaker}, + * year={2003}, + * title={Lossless Compression of Language Model Structure and Word Identifiers}, + * booktitle={Proceedings of IEEE International Conference on Acoustics, Speech and Signal Processing}, + * pages={388--391}, + * } + * + * Currently only used for next pointers. + */ + +#include + +#include "lm/binary_format.hh" +#include "lm/trie.hh" +#include "util/bit_packing.hh" +#include "util/sorted_uniform.hh" + +namespace lm { +namespace ngram { +class Config; + +namespace trie { + +class DontBhiksha { + public: + static const ModelType kModelTypeAdd = static_cast(0); + + static void UpdateConfigFromBinary(int /*fd*/, Config &/*config*/) {} + + static std::size_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; } + + static uint8_t InlineBits(uint64_t /*max_offset*/, uint64_t max_next, const Config &/*config*/) { + return util::RequiredBits(max_next); + } + + DontBhiksha(const void *base, uint64_t max_offset, uint64_t max_next, const Config &config); + + void ReadNext(const void *base, uint64_t bit_offset, uint64_t /*index*/, uint8_t total_bits, NodeRange &out) const { + out.begin = util::ReadInt57(base, bit_offset, next_.bits, next_.mask); + out.end = util::ReadInt57(base, bit_offset + total_bits, next_.bits, next_.mask); + //assert(out.end >= out.begin); + } + + void WriteNext(void *base, uint64_t bit_offset, uint64_t /*index*/, uint64_t value) { + util::WriteInt57(base, bit_offset, next_.bits, value); + } + + void FinishedLoading(const Config &/*config*/) {} + + void LoadedBinary() {} + + uint8_t InlineBits() const { return next_.bits; } + + private: + util::BitsMask next_; +}; + +class ArrayBhiksha { + public: + static const ModelType kModelTypeAdd = kArrayAdd; + + static void UpdateConfigFromBinary(int fd, Config &config); + + static std::size_t Size(uint64_t max_offset, uint64_t max_next, const Config &config); + + static uint8_t InlineBits(uint64_t max_offset, uint64_t max_next, const Config &config); + + ArrayBhiksha(void *base, uint64_t max_offset, uint64_t max_value, const Config &config); + + void ReadNext(const void *base, uint64_t bit_offset, uint64_t index, uint8_t total_bits, NodeRange &out) const { + const uint64_t *begin_it = util::BinaryBelow(util::IdentityAccessor(), offset_begin_, offset_end_, index); + const uint64_t *end_it; + for (end_it = begin_it; (end_it < offset_end_) && (*end_it <= index + 1); ++end_it) {} + --end_it; + out.begin = ((begin_it - offset_begin_) << next_inline_.bits) | + util::ReadInt57(base, bit_offset, next_inline_.bits, next_inline_.mask); + out.end = ((end_it - offset_begin_) << next_inline_.bits) | + util::ReadInt57(base, bit_offset + total_bits, next_inline_.bits, next_inline_.mask); + } + + void WriteNext(void *base, uint64_t bit_offset, uint64_t index, uint64_t value) { + uint64_t encode = value >> next_inline_.bits; + for (; write_to_ <= offset_begin_ + encode; ++write_to_) *write_to_ = index; + util::WriteInt57(base, bit_offset, next_inline_.bits, value & next_inline_.mask); + } + + void FinishedLoading(const Config &config); + + void LoadedBinary(); + + uint8_t InlineBits() const { return next_inline_.bits; } + + private: + const util::BitsMask next_inline_; + + const uint64_t *const offset_begin_; + const uint64_t *const offset_end_; + + uint64_t *write_to_; + + void *original_base_; +}; + +} // namespace trie +} // namespace ngram +} // namespace lm diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index 92b1008b..e02e621a 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -40,7 +40,7 @@ struct Sanity { } }; -const char *kModelNames[3] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "bit packed trie"}; +const char *kModelNames[6] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"}; std::size_t Align8(std::size_t in) { std::size_t off = in % 8; @@ -100,16 +100,17 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_ } } -uint8_t *GrowForSearch(const Config &config, std::size_t memory_size, Backing &backing) { +uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing) { + std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad; if (config.write_mmap) { // Grow the file to accomodate the search, using zeros. - if (-1 == ftruncate(backing.file.get(), backing.vocab.size() + memory_size)) - UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (backing.vocab.size() + memory_size) << " failed"); + if (-1 == ftruncate(backing.file.get(), adjusted_vocab + memory_size)) + UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (adjusted_vocab + memory_size) << " failed"); // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down. off_t page_size = sysconf(_SC_PAGE_SIZE); - off_t alignment_cruft = backing.vocab.size() % page_size; - backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), backing.vocab.size() - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED); + off_t alignment_cruft = adjusted_vocab % page_size; + backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED); return reinterpret_cast(backing.search.get()) + alignment_cruft; } else { diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index 2b32b450..d28cb6c5 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -16,7 +16,12 @@ namespace lm { namespace ngram { -typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3} ModelType; +/* Not the best numbering system, but it grew this way for historical reasons + * and I want to preserve existing binary files. */ +typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3, ARRAY_TRIE_SORTED=4, QUANT_ARRAY_TRIE_SORTED=5} ModelType; + +const static ModelType kQuantAdd = static_cast(QUANT_TRIE_SORTED - TRIE_SORTED); +const static ModelType kArrayAdd = static_cast(ARRAY_TRIE_SORTED - TRIE_SORTED); /*Inspect a file to determine if it is a binary lm. If not, return false. * If so, return true and set recognized to the type. This is the only API in @@ -55,7 +60,7 @@ void AdvanceOrThrow(int fd, off_t off); // Create just enough of a binary file to write vocabulary to it. uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing); // Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin. -uint8_t *GrowForSearch(const Config &config, std::size_t memory_size, Backing &backing); +uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing); // Write header to binary file. This is done last to prevent incomplete files // from loading. diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index 4552c419..b7aee4de 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -15,12 +15,12 @@ namespace ngram { namespace { void Usage(const char *name) { - std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-n] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [type] input.arpa output.mmap\n\n" -"-u sets the default log10 probability for if the ARPA file does not have\n" -"one.\n" + std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-c bits] [type] input.arpa [output.mmap]\n\n" +"-u sets the log10 probability for if the ARPA file does not have one.\n" +" Default is -100. The ARPA file will always take precedence.\n" "-s allows models to be built even if they do not have and .\n" -"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n" -"type is either probing or trie:\n\n" +"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n\n" +"type is either probing or trie. Default is probing.\n\n" "probing uses a probing hash table. It is the fastest but uses the most memory.\n" "-p sets the space multiplier and must be >1.0. The default is 1.5.\n\n" "trie is a straightforward trie with bit-level packing. It uses the least\n" @@ -29,10 +29,11 @@ void Usage(const char *name) { "-t is the temporary directory prefix. Default is the output file name.\n" "-m limits memory use for sorting. Measured in MB. Default is 1024MB.\n" "-q turns quantization on and sets the number of bits (e.g. -q 8).\n" -"-b sets backoff quantization bits. Requires -q and defaults to that value.\n\n" -"See http://kheafield.com/code/kenlm/benchmark/ for data structure benchmarks.\n" -"Passing only an input file will print memory usage of each data structure.\n" -"If the ARPA file does not have , -u sets 's probability; default 0.0.\n"; +"-b sets backoff quantization bits. Requires -q and defaults to that value.\n" +"-a compresses pointers using an array of offsets. The parameter is the\n" +" maximum number of bits encoded by the array. Memory is minimized subject\n" +" to the maximum, so pick 255 to minimize memory.\n\n" +"Get a memory estimate by passing an ARPA file without an output file name.\n"; exit(1); } @@ -63,12 +64,14 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) { std::vector counts; util::FilePiece f(file); lm::ReadARPACounts(f, counts); - std::size_t sizes[3]; + std::size_t sizes[5]; sizes[0] = ProbingModel::Size(counts, config); sizes[1] = TrieModel::Size(counts, config); sizes[2] = QuantTrieModel::Size(counts, config); - std::size_t max_length = *std::max_element(sizes, sizes + 3); - std::size_t min_length = *std::max_element(sizes, sizes + 3); + sizes[3] = ArrayTrieModel::Size(counts, config); + sizes[4] = QuantArrayTrieModel::Size(counts, config); + std::size_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(size_t)); + std::size_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(size_t)); std::size_t divide; char prefix; if (min_length < (1 << 10) * 10) { @@ -91,7 +94,9 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) { std::cout << prefix << "B\n" "probing " << std::setw(length) << (sizes[0] / divide) << " assuming -p " << config.probing_multiplier << "\n" "trie " << std::setw(length) << (sizes[1] / divide) << " without quantization\n" - "trie " << std::setw(length) << (sizes[2] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n"; + "trie " << std::setw(length) << (sizes[2] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n" + "trie " << std::setw(length) << (sizes[3] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " array pointer compression\n" + "trie " << std::setw(length) << (sizes[4] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits<< " array pointer compression and quantization\n"; } void ProbingQuantizationUnsupported() { @@ -106,11 +111,11 @@ void ProbingQuantizationUnsupported() { int main(int argc, char *argv[]) { using namespace lm::ngram; - bool quantize = false, set_backoff_bits = false; try { + bool quantize = false, set_backoff_bits = false, bhiksha = false; lm::ngram::Config config; int opt; - while ((opt = getopt(argc, argv, "siu:p:t:m:q:b:")) != -1) { + while ((opt = getopt(argc, argv, "siu:p:t:m:q:b:a:")) != -1) { switch(opt) { case 'q': config.prob_bits = ParseBitCount(optarg); @@ -121,6 +126,9 @@ int main(int argc, char *argv[]) { config.backoff_bits = ParseBitCount(optarg); set_backoff_bits = true; break; + case 'a': + config.pointer_bhiksha_bits = ParseBitCount(optarg); + bhiksha = true; case 'u': config.unknown_missing_logprob = ParseFloat(optarg); break; @@ -162,9 +170,17 @@ int main(int argc, char *argv[]) { ProbingModel(from_file, config); } else if (!strcmp(model_type, "trie")) { if (quantize) { - QuantTrieModel(from_file, config); + if (bhiksha) { + QuantArrayTrieModel(from_file, config); + } else { + QuantTrieModel(from_file, config); + } } else { - TrieModel(from_file, config); + if (bhiksha) { + ArrayTrieModel(from_file, config); + } else { + TrieModel(from_file, config); + } } } else { Usage(argv[0]); @@ -173,9 +189,9 @@ int main(int argc, char *argv[]) { Usage(argv[0]); } } - catch (std::exception &e) { + catch (const std::exception &e) { std::cerr << e.what() << std::endl; - abort(); + return 1; } return 0; } diff --git a/klm/lm/config.cc b/klm/lm/config.cc index 08e1af5c..297589a4 100644 --- a/klm/lm/config.cc +++ b/klm/lm/config.cc @@ -20,6 +20,7 @@ Config::Config() : include_vocab(true), prob_bits(8), backoff_bits(8), + pointer_bhiksha_bits(22), load_method(util::POPULATE_OR_READ) {} } // namespace ngram diff --git a/klm/lm/config.hh b/klm/lm/config.hh index dcc7cf35..227b8512 100644 --- a/klm/lm/config.hh +++ b/klm/lm/config.hh @@ -73,9 +73,12 @@ struct Config { // Quantization options. Only effective for QuantTrieModel. One value is // reserved for each of prob and backoff, so 2^bits - 1 buckets will be used - // to quantize. + // to quantize (and one of the remaining backoffs will be 0). uint8_t prob_bits, backoff_bits; + // Bhiksha compression (simple form). Only works with trie. + uint8_t pointer_bhiksha_bits; + // ONLY EFFECTIVE WHEN READING BINARY diff --git a/klm/lm/model.cc b/klm/lm/model.cc index a1d10b3d..27e24b1c 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -21,6 +21,8 @@ size_t hash_value(const State &state) { namespace detail { +template const ModelType GenericModel::kModelType = Search::kModelType; + template size_t GenericModel::Size(const std::vector &counts, const Config &config) { return VocabularyT::Size(counts[0], config) + Search::Size(counts, config); } @@ -56,35 +58,40 @@ template void GenericModel void GenericModel::InitializeFromARPA(const char *file, const Config &config) { // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any. util::FilePiece f(backing_.file.release(), file, config.messages); - std::vector counts; - // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_. - ReadARPACounts(f, counts); - - if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile."); - if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model."); - if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0"); - - std::size_t vocab_size = VocabularyT::Size(counts[0], config); - // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs. - vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config); - - if (config.write_mmap) { - WriteWordsWrapper wrap(config.enumerate_vocab); - vocab_.ConfigureEnumerate(&wrap, counts[0]); - search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); - wrap.Write(backing_.file.get()); - } else { - vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]); - search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); - } + try { + std::vector counts; + // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_. + ReadARPACounts(f, counts); + + if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile."); + if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model."); + if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0"); + + std::size_t vocab_size = VocabularyT::Size(counts[0], config); + // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs. + vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config); + + if (config.write_mmap) { + WriteWordsWrapper wrap(config.enumerate_vocab); + vocab_.ConfigureEnumerate(&wrap, counts[0]); + search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); + wrap.Write(backing_.file.get()); + } else { + vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]); + search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); + } - if (!vocab_.SawUnk()) { - assert(config.unknown_missing != THROW_UP); - // Default probabilities for unknown. - search_.unigram.Unknown().backoff = 0.0; - search_.unigram.Unknown().prob = config.unknown_missing_logprob; + if (!vocab_.SawUnk()) { + assert(config.unknown_missing != THROW_UP); + // Default probabilities for unknown. + search_.unigram.Unknown().backoff = 0.0; + search_.unigram.Unknown().prob = config.unknown_missing_logprob; + } + FinishFile(config, kModelType, counts, backing_); + } catch (util::Exception &e) { + e << " Byte: " << f.Offset(); + throw; } - FinishFile(config, kModelType, counts, backing_); } template FullScoreReturn GenericModel::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const { @@ -225,8 +232,10 @@ template FullScoreReturn GenericModel; // HASH_PROBING -template class GenericModel, SortedVocabulary>; // TRIE_SORTED -template class GenericModel, SortedVocabulary>; // TRIE_SORTED_QUANT +template class GenericModel, SortedVocabulary>; // TRIE_SORTED +template class GenericModel, SortedVocabulary>; +template class GenericModel, SortedVocabulary>; // TRIE_SORTED_QUANT +template class GenericModel, SortedVocabulary>; } // namespace detail } // namespace ngram diff --git a/klm/lm/model.hh b/klm/lm/model.hh index 1f49a382..21595321 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -1,6 +1,7 @@ #ifndef LM_MODEL__ #define LM_MODEL__ +#include "lm/bhiksha.hh" #include "lm/binary_format.hh" #include "lm/config.hh" #include "lm/facade.hh" @@ -71,6 +72,9 @@ template class GenericModel : public base::Mod private: typedef base::ModelFacade, State, VocabularyT> P; public: + // This is the model type returned by RecognizeBinary. + static const ModelType kModelType; + /* Get the size of memory that will be mapped given ngram counts. This * does not include small non-mapped control structures, such as this class * itself. @@ -131,8 +135,6 @@ template class GenericModel : public base::Mod Backing &MutableBacking() { return backing_; } - static const ModelType kModelType = Search::kModelType; - Backing backing_; VocabularyT vocab_; @@ -152,9 +154,11 @@ typedef ProbingModel Model; // Smaller implementation. typedef ::lm::ngram::SortedVocabulary SortedVocabulary; -typedef detail::GenericModel, SortedVocabulary> TrieModel; // TRIE_SORTED +typedef detail::GenericModel, SortedVocabulary> TrieModel; // TRIE_SORTED +typedef detail::GenericModel, SortedVocabulary> ArrayTrieModel; -typedef detail::GenericModel, SortedVocabulary> QuantTrieModel; // QUANT_TRIE_SORTED +typedef detail::GenericModel, SortedVocabulary> QuantTrieModel; // QUANT_TRIE_SORTED +typedef detail::GenericModel, SortedVocabulary> QuantArrayTrieModel; } // namespace ngram } // namespace lm diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index 8bf040ff..57c7291c 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -193,6 +193,14 @@ template void Stateless(const M &model) { BOOST_CHECK_EQUAL(static_cast(0), state.history_[0]); } +template void NoUnkCheck(const M &model) { + WordIndex unk_index = 0; + State state; + + FullScoreReturn ret = model.FullScoreForgotState(&unk_index, &unk_index + 1, unk_index, state); + BOOST_CHECK_CLOSE(-100.0, ret.prob, 0.001); +} + template void Everything(const M &m) { Starters(m); Continuation(m); @@ -231,25 +239,38 @@ template void LoadingTest() { Config config; config.arpa_complain = Config::NONE; config.messages = NULL; - ExpectEnumerateVocab enumerate; - config.enumerate_vocab = &enumerate; config.probing_multiplier = 2.0; - ModelT m("test.arpa", config); - enumerate.Check(m.GetVocabulary()); - Everything(m); + { + ExpectEnumerateVocab enumerate; + config.enumerate_vocab = &enumerate; + ModelT m("test.arpa", config); + enumerate.Check(m.GetVocabulary()); + Everything(m); + } + { + ExpectEnumerateVocab enumerate; + config.enumerate_vocab = &enumerate; + ModelT m("test_nounk.arpa", config); + enumerate.Check(m.GetVocabulary()); + NoUnkCheck(m); + } } BOOST_AUTO_TEST_CASE(probing) { LoadingTest(); } - BOOST_AUTO_TEST_CASE(trie) { LoadingTest(); } - -BOOST_AUTO_TEST_CASE(quant) { +BOOST_AUTO_TEST_CASE(quant_trie) { LoadingTest(); } +BOOST_AUTO_TEST_CASE(bhiksha_trie) { + LoadingTest(); +} +BOOST_AUTO_TEST_CASE(quant_bhiksha_trie) { + LoadingTest(); +} template void BinaryTest() { Config config; @@ -267,10 +288,34 @@ template void BinaryTest() { config.write_mmap = NULL; - ModelT binary("test.binary", config); - enumerate.Check(binary.GetVocabulary()); - Everything(binary); + ModelType type; + BOOST_REQUIRE(RecognizeBinary("test.binary", type)); + BOOST_CHECK_EQUAL(ModelT::kModelType, type); + + { + ModelT binary("test.binary", config); + enumerate.Check(binary.GetVocabulary()); + Everything(binary); + } unlink("test.binary"); + + // Now test without . + config.write_mmap = "test_nounk.binary"; + config.messages = NULL; + enumerate.Clear(); + { + ModelT copy_model("test_nounk.arpa", config); + enumerate.Check(copy_model.GetVocabulary()); + enumerate.Clear(); + NoUnkCheck(copy_model); + } + config.write_mmap = NULL; + { + ModelT binary("test_nounk.binary", config); + enumerate.Check(binary.GetVocabulary()); + NoUnkCheck(binary); + } + unlink("test_nounk.binary"); } BOOST_AUTO_TEST_CASE(write_and_read_probing) { @@ -282,6 +327,12 @@ BOOST_AUTO_TEST_CASE(write_and_read_trie) { BOOST_AUTO_TEST_CASE(write_and_read_quant_trie) { BinaryTest(); } +BOOST_AUTO_TEST_CASE(write_and_read_array_trie) { + BinaryTest(); +} +BOOST_AUTO_TEST_CASE(write_and_read_quant_array_trie) { + BinaryTest(); +} } // namespace } // namespace ngram diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc index 9454a6d1..d9db4aa2 100644 --- a/klm/lm/ngram_query.cc +++ b/klm/lm/ngram_query.cc @@ -99,6 +99,15 @@ int main(int argc, char *argv[]) { case lm::ngram::TRIE_SORTED: Query(argv[1], sentence_context); break; + case lm::ngram::QUANT_TRIE_SORTED: + Query(argv[1], sentence_context); + break; + case lm::ngram::ARRAY_TRIE_SORTED: + Query(argv[1], sentence_context); + break; + case lm::ngram::QUANT_ARRAY_TRIE_SORTED: + Query(argv[1], sentence_context); + break; case lm::ngram::HASH_SORTED: default: std::cerr << "Unrecognized kenlm model type " << model_type << std::endl; diff --git a/klm/lm/quantize.cc b/klm/lm/quantize.cc index 4bb6b1b8..fd371cc8 100644 --- a/klm/lm/quantize.cc +++ b/klm/lm/quantize.cc @@ -43,6 +43,7 @@ void SeparatelyQuantize::UpdateConfigFromBinary(int fd, const std::vector(0); static void UpdateConfigFromBinary(int, const std::vector &, Config &) {} static std::size_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; } static uint8_t MiddleBits(const Config &/*config*/) { return 63; } @@ -108,7 +108,7 @@ class SeparatelyQuantize { }; public: - static const ModelType kModelType = QUANT_TRIE_SORTED; + static const ModelType kModelTypeAdd = kQuantAdd; static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config); diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc index 060a97ea..455bc4ba 100644 --- a/klm/lm/read_arpa.cc +++ b/klm/lm/read_arpa.cc @@ -31,15 +31,15 @@ const char kBinaryMagic[] = "mmap lm http://kheafield.com/code"; void ReadARPACounts(util::FilePiece &in, std::vector &number) { number.clear(); StringPiece line; - if (!IsEntirelyWhiteSpace(line = in.ReadLine())) { + while (IsEntirelyWhiteSpace(line = in.ReadLine())) {} + if (line != "\\data\\") { if ((line.size() >= 2) && (line.data()[0] == 0x1f) && (static_cast(line.data()[1]) == 0x8b)) { UTIL_THROW(FormatLoadException, "Looks like a gzip file. If this is an ARPA file, pipe " << in.FileName() << " through zcat. If this already in binary format, you need to decompress it because mmap doesn't work on top of gzip."); } if (static_cast(line.size()) >= strlen(kBinaryMagic) && StringPiece(line.data(), strlen(kBinaryMagic)) == kBinaryMagic) UTIL_THROW(FormatLoadException, "This looks like a binary file but got sent to the ARPA parser. Did you compress the binary file or pass a binary file where only ARPA files are accepted?"); - UTIL_THROW(FormatLoadException, "First line was \"" << line.data() << "\" not blank"); + UTIL_THROW(FormatLoadException, "first non-empty line was \"" << line << "\" not \\data\\."); } - if ((line = in.ReadLine()) != "\\data\\") UTIL_THROW(FormatLoadException, "second line was \"" << line << "\" not \\data\\."); while (!IsEntirelyWhiteSpace(line = in.ReadLine())) { if (line.size() < 6 || strncmp(line.data(), "ngram ", 6)) UTIL_THROW(FormatLoadException, "count line \"" << line << "\"doesn't begin with \"ngram \""); // So strtol doesn't go off the end of line. diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index c56ba7b8..82c53ec8 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -98,7 +98,7 @@ template uint8_t *TemplateHashedSearch template void TemplateHashedSearch::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector &counts, const Config &config, Voc &vocab, Backing &backing) { // TODO: fix sorted. - SetupMemory(GrowForSearch(config, Size(counts, config), backing), counts, config); + SetupMemory(GrowForSearch(config, 0, Size(counts, config), backing), counts, config); PositiveProbWarn warn(config.positive_log_probability); diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index f3acdefc..c62985e4 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -52,12 +52,11 @@ struct HashedSearch { Unigram unigram; - bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const { + void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const { const ProbBackoff &entry = unigram.Lookup(word); prob = entry.prob; backoff = entry.backoff; next = static_cast(word); - return true; } }; diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 91f87f1c..05059ffb 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -1,6 +1,7 @@ /* This is where the trie is built. It's on-disk. */ #include "lm/search_trie.hh" +#include "lm/bhiksha.hh" #include "lm/blank.hh" #include "lm/lm_exception.hh" #include "lm/max_order.hh" @@ -543,8 +544,8 @@ void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector appears. - size_t extra_count = counts[0] + 1; - util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), extra_count * sizeof(ProbBackoff), unigram_file), extra_count * sizeof(ProbBackoff)); + size_t file_out = (counts[0] + 1) * sizeof(ProbBackoff); + util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), file_out, unigram_file), file_out); Read1Grams(f, counts[0], vocab, reinterpret_cast(unigram_mmap.get()), warn); CheckSpecials(config, vocab); if (!vocab.SawUnk()) ++counts[0]; @@ -610,9 +611,9 @@ class JustCount { }; // Phase to actually write n-grams to the trie. -template class WriteEntries { +template class WriteEntries { public: - WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) : + WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) : contexts_(contexts), unigrams_(unigrams), middle_(middle), @@ -649,7 +650,7 @@ template class WriteEntries { private: ContextReader *contexts_; UnigramValue *const unigrams_; - BitPackedMiddle *const middle_; + BitPackedMiddle *const middle_; BitPackedLongest &longest_; BitPacked &bigram_pack_; }; @@ -821,7 +822,7 @@ template void TrainProbQuantizer(uint8_t order, uint64_t count, So } // namespace -template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, Backing &backing) { +template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) { std::vector inputs(counts.size() - 1); std::vector contexts(counts.size() - 1); @@ -846,7 +847,7 @@ template void BuildTrie(const std::string &file_prefix, std::vecto SanityCheckCounts(counts, fixed_counts); counts = fixed_counts; - out.SetupMemory(GrowForSearch(config, TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); + out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); if (Quant::kTrain) { util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0)); @@ -863,7 +864,7 @@ template void BuildTrie(const std::string &file_prefix, std::vecto UnigramValue *unigrams = out.unigram.Raw(); // Fill entries except unigram probabilities. { - RecursiveInsert > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); + RecursiveInsert > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); inserter.Apply(config.messages, "Building trie", fixed_counts[0]); } @@ -901,14 +902,14 @@ template void BuildTrie(const std::string &file_prefix, std::vecto /* Set ending offsets so the last entry will be sized properly */ // Last entry for unigrams was already set. if (out.middle_begin_ != out.middle_end_) { - for (typename TrieSearch::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) { - i->FinishedLoading((i+1)->InsertIndex()); + for (typename TrieSearch::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) { + i->FinishedLoading((i+1)->InsertIndex(), config); } - (out.middle_end_ - 1)->FinishedLoading(out.longest.InsertIndex()); + (out.middle_end_ - 1)->FinishedLoading(out.longest.InsertIndex(), config); } } -template uint8_t *TrieSearch::SetupMemory(uint8_t *start, const std::vector &counts, const Config &config) { +template uint8_t *TrieSearch::SetupMemory(uint8_t *start, const std::vector &counts, const Config &config) { quant_.SetupMemory(start, config); start += Quant::Size(counts.size(), config); unigram.Init(start); @@ -919,22 +920,24 @@ template uint8_t *TrieSearch::SetupMemory(uint8_t *start, c std::vector middle_starts(counts.size() - 2); for (unsigned char i = 2; i < counts.size(); ++i) { middle_starts[i-2] = start; - start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i]); + start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i], config); } - // Crazy backwards thing so we initialize in the correct order. + // Crazy backwards thing so we initialize using pointers to ones that have already been initialized for (unsigned char i = counts.size() - 1; i >= 2; --i) { new (middle_begin_ + i - 2) Middle( middle_starts[i-2], quant_.Mid(i), + counts[i-1], counts[0], counts[i], - (i == counts.size() - 1) ? static_cast(longest) : static_cast(middle_begin_[i-1])); + (i == counts.size() - 1) ? static_cast(longest) : static_cast(middle_begin_[i-1]), + config); } longest.Init(start, quant_.Long(counts.size()), counts[0]); return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); } -template void TrieSearch::LoadedBinary() { +template void TrieSearch::LoadedBinary() { unigram.LoadedBinary(); for (Middle *i = middle_begin_; i != middle_end_; ++i) { i->LoadedBinary(); @@ -942,7 +945,7 @@ template void TrieSearch::LoadedBinary() { longest.LoadedBinary(); } -template void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { +template void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { std::string temporary_directory; if (config.temporary_directory_prefix) { temporary_directory = config.temporary_directory_prefix; @@ -966,14 +969,16 @@ template void TrieSearch::InitializeFromARPA(const char *fi // At least 1MB sorting memory. ARPAToSortedFiles(config, f, counts, std::max(config.building_memory, 1048576), temporary_directory.c_str(), vocab); - BuildTrie(temporary_directory, counts, config, *this, quant_, backing); + BuildTrie(temporary_directory, counts, config, *this, quant_, vocab, backing); if (rmdir(temporary_directory.c_str()) && config.messages) { *config.messages << "Failed to delete " << temporary_directory << std::endl; } } -template class TrieSearch; -template class TrieSearch; +template class TrieSearch; +template class TrieSearch; +template class TrieSearch; +template class TrieSearch; } // namespace trie } // namespace ngram diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 0a52acb5..2f39c09f 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -13,31 +13,33 @@ struct Backing; class SortedVocabulary; namespace trie { -template class TrieSearch; -template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, Backing &backing); +template class TrieSearch; +template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); -template class TrieSearch { +template class TrieSearch { public: typedef NodeRange Node; typedef ::lm::ngram::trie::Unigram Unigram; Unigram unigram; - typedef trie::BitPackedMiddle Middle; + typedef trie::BitPackedMiddle Middle; typedef trie::BitPackedLongest Longest; Longest longest; - static const ModelType kModelType = Quant::kModelType; + static const ModelType kModelType = static_cast(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd); static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config) { Quant::UpdateConfigFromBinary(fd, counts, config); + AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0])); + Bhiksha::UpdateConfigFromBinary(fd, config); } static std::size_t Size(const std::vector &counts, const Config &config) { std::size_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]); for (unsigned char i = 1; i < counts.size() - 1; ++i) { - ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1]); + ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1], config); } return ret + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); } @@ -55,8 +57,8 @@ template class TrieSearch { void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); - bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { - return unigram.Find(word, prob, backoff, node); + void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { + unigram.Find(word, prob, backoff, node); } bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const { @@ -83,7 +85,7 @@ template class TrieSearch { } private: - friend void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, Backing &backing); + friend void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); // Middles are managed manually so we can delay construction and they don't have to be copyable. void FreeMiddles() { diff --git a/klm/lm/test_nounk.arpa b/klm/lm/test_nounk.arpa new file mode 100644 index 00000000..060733d9 --- /dev/null +++ b/klm/lm/test_nounk.arpa @@ -0,0 +1,120 @@ + +\data\ +ngram 1=36 +ngram 2=45 +ngram 3=10 +ngram 4=6 +ngram 5=4 + +\1-grams: +-1.383514 , -0.30103 +-1.139057 . -0.845098 +-1.029493 +-99 -0.4149733 +-1.285941 a -0.69897 +-1.687872 also -0.30103 +-1.687872 beyond -0.30103 +-1.687872 biarritz -0.30103 +-1.687872 call -0.30103 +-1.687872 concerns -0.30103 +-1.687872 consider -0.30103 +-1.687872 considering -0.30103 +-1.687872 for -0.30103 +-1.509559 higher -0.30103 +-1.687872 however -0.30103 +-1.687872 i -0.30103 +-1.687872 immediate -0.30103 +-1.687872 in -0.30103 +-1.687872 is -0.30103 +-1.285941 little -0.69897 +-1.383514 loin -0.30103 +-1.687872 look -0.30103 +-1.285941 looking -0.4771212 +-1.206319 more -0.544068 +-1.509559 on -0.4771212 +-1.509559 screening -0.4771212 +-1.687872 small -0.30103 +-1.687872 the -0.30103 +-1.687872 to -0.30103 +-1.687872 watch -0.30103 +-1.687872 watching -0.30103 +-1.687872 what -0.30103 +-1.687872 would -0.30103 +-3.141592 foo +-2.718281 bar 3.0 +-6.535897 baz -0.0 + +\2-grams: +-0.6925742 , . +-0.7522095 , however +-0.7522095 , is +-0.0602359 . +-0.4846522 looking -0.4771214 +-1.051485 screening +-1.07153 the +-1.07153 watching +-1.07153 what +-0.09132547 a little -0.69897 +-0.2922095 also call +-0.2922095 beyond immediate +-0.2705918 biarritz . +-0.2922095 call for +-0.2922095 concerns in +-0.2922095 consider watch +-0.2922095 considering consider +-0.2834328 for , +-0.5511513 higher more +-0.5845945 higher small +-0.2834328 however , +-0.2922095 i would +-0.2922095 immediate concerns +-0.2922095 in biarritz +-0.2922095 is to +-0.09021038 little more -0.1998621 +-0.7273645 loin , +-0.6925742 loin . +-0.6708385 loin +-0.2922095 look beyond +-0.4638903 looking higher +-0.4638903 looking on -0.4771212 +-0.5136299 more . -0.4771212 +-0.3561665 more loin +-0.1649931 on a -0.4771213 +-0.1649931 screening a -0.4771213 +-0.2705918 small . +-0.287799 the screening +-0.2922095 to look +-0.2622373 watch +-0.2922095 watching considering +-0.2922095 what i +-0.2922095 would also +-2 also would -6 +-6 foo bar + +\3-grams: +-0.01916512 more . +-0.0283603 on a little -0.4771212 +-0.0283603 screening a little -0.4771212 +-0.01660496 a little more -0.09409451 +-0.3488368 looking higher +-0.3488368 looking on -0.4771212 +-0.1892331 little more loin +-0.04835128 looking on a -0.4771212 +-3 also would consider -7 +-7 to look good + +\4-grams: +-0.009249173 looking on a little -0.4771212 +-0.005464747 on a little more -0.4771212 +-0.005464747 screening a little more +-0.1453306 a little more loin +-0.01552657 looking on a -0.4771212 +-4 also would consider higher -8 + +\5-grams: +-0.003061223 looking on a little +-0.001813953 looking on a little more +-0.0432557 on a little more loin +-5 also would consider higher looking + +\end\ diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index 63c2a612..8c536e66 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -1,5 +1,6 @@ #include "lm/trie.hh" +#include "lm/bhiksha.hh" #include "lm/quantize.hh" #include "util/bit_packing.hh" #include "util/exception.hh" @@ -57,16 +58,21 @@ void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits) max_vocab_ = max_vocab; } -template std::size_t BitPackedMiddle::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr) { - return BaseSize(entries, max_vocab, quant_bits + util::RequiredBits(max_ptr)); +template std::size_t BitPackedMiddle::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) { + return Bhiksha::Size(entries + 1, max_ptr, config) + BaseSize(entries, max_vocab, quant_bits + Bhiksha::InlineBits(entries + 1, max_ptr, config)); } -template BitPackedMiddle::BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source) : BitPacked(), quant_(quant), next_bits_(util::RequiredBits(max_next)), next_mask_((1ULL << next_bits_) - 1), next_source_(&next_source) { - if (next_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions."); - BaseInit(base, max_vocab, quant.TotalBits() + next_bits_); +template BitPackedMiddle::BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config) : + BitPacked(), + quant_(quant), + // If the offset of the method changes, also change TrieSearch::UpdateConfigFromBinary. + bhiksha_(base, entries + 1, max_next, config), + next_source_(&next_source) { + if (entries + 1 >= (1ULL << 57) || (max_next >= (1ULL << 57))) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions."); + BaseInit(reinterpret_cast(base) + Bhiksha::Size(entries + 1, max_next, config), max_vocab, quant.TotalBits() + bhiksha_.InlineBits()); } -template void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff) { +template void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff) { assert(word <= word_mask_); uint64_t at_pointer = insert_index_ * total_bits_; @@ -75,47 +81,42 @@ template void BitPackedMiddle::Insert(WordIndex word, float quant_.Write(base_, at_pointer, prob, backoff); at_pointer += quant_.TotalBits(); uint64_t next = next_source_->InsertIndex(); - assert(next <= next_mask_); - util::WriteInt57(base_, at_pointer, next_bits_, next); + bhiksha_.WriteNext(base_, at_pointer, insert_index_, next); ++insert_index_; } -template bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const { +template bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const { uint64_t at_pointer; if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) { return false; } + uint64_t index = at_pointer; at_pointer *= total_bits_; at_pointer += word_bits_; quant_.Read(base_, at_pointer, prob, backoff); at_pointer += quant_.TotalBits(); - range.begin = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); - // Read the next entry's pointer. - at_pointer += total_bits_; - range.end = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); + bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range); + return true; } -template bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const { - uint64_t at_pointer; - if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false; - at_pointer *= total_bits_; +template bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const { + uint64_t index; + if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, index)) return false; + uint64_t at_pointer = index * total_bits_; at_pointer += word_bits_; quant_.ReadBackoff(base_, at_pointer, backoff); at_pointer += quant_.TotalBits(); - range.begin = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); - // Read the next entry's pointer. - at_pointer += total_bits_; - range.end = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); + bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range); return true; } -template void BitPackedMiddle::FinishedLoading(uint64_t next_end) { - assert(next_end <= next_mask_); - uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - next_bits_; - util::WriteInt57(base_, last_next_write, next_bits_, next_end); +template void BitPackedMiddle::FinishedLoading(uint64_t next_end, const Config &config) { + uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - bhiksha_.InlineBits(); + bhiksha_.WriteNext(base_, last_next_write, insert_index_ + 1, next_end); + bhiksha_.FinishedLoading(config); } template void BitPackedLongest::Insert(WordIndex index, float prob) { @@ -135,8 +136,10 @@ template bool BitPackedLongest::Find(WordIndex word, float return true; } -template class BitPackedMiddle; -template class BitPackedMiddle; +template class BitPackedMiddle; +template class BitPackedMiddle; +template class BitPackedMiddle; +template class BitPackedMiddle; template class BitPackedLongest; template class BitPackedLongest; diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index 8fa21aaf..53612064 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -10,6 +10,7 @@ namespace lm { namespace ngram { +class Config; namespace trie { struct NodeRange { @@ -46,13 +47,12 @@ class Unigram { void LoadedBinary() {} - bool Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const { + void Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const { UnigramValue *val = unigram_ + word; prob = val->weights.prob; backoff = val->weights.backoff; next.begin = val->next; next.end = (val+1)->next; - return true; } private: @@ -67,8 +67,6 @@ class BitPacked { return insert_index_; } - void LoadedBinary() {} - protected: static std::size_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits); @@ -83,30 +81,30 @@ class BitPacked { uint64_t insert_index_, max_vocab_; }; -template class BitPackedMiddle : public BitPacked { +template class BitPackedMiddle : public BitPacked { public: - static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next); + static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config); // next_source need not be initialized. - BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source); + BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config); void Insert(WordIndex word, float prob, float backoff); + void FinishedLoading(uint64_t next_end, const Config &config); + + void LoadedBinary() { bhiksha_.LoadedBinary(); } + bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const; bool FindNoProb(WordIndex word, float &backoff, NodeRange &range) const; - void FinishedLoading(uint64_t next_end); - private: Quant quant_; - uint8_t next_bits_; - uint64_t next_mask_; + Bhiksha bhiksha_; const BitPacked *next_source_; }; - template class BitPackedLongest : public BitPacked { public: static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) { @@ -120,6 +118,8 @@ template class BitPackedLongest : public BitPacked { BaseInit(base, max_vocab, quant_.TotalBits()); } + void LoadedBinary() {} + void Insert(WordIndex word, float prob); bool Find(WordIndex word, float &prob, const NodeRange &node) const; diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 7defd5c1..04979d51 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -37,14 +37,14 @@ WordIndex ReadWords(int fd, EnumerateVocab *enumerate) { WordIndex index = 0; while (true) { ssize_t got = read(fd, &buf[0], kInitialRead); - if (got == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words"); + UTIL_THROW_IF(got == -1, util::ErrnoException, "Reading vocabulary words"); if (got == 0) return index; buf.resize(got); while (buf[buf.size() - 1]) { char next_char; ssize_t ret = read(fd, &next_char, 1); - if (ret == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words"); - if (ret == 0) UTIL_THROW(FormatLoadException, "Missing null terminator on a vocab word."); + UTIL_THROW_IF(ret == -1, util::ErrnoException, "Reading vocabulary words"); + UTIL_THROW_IF(ret == 0, FormatLoadException, "Missing null terminator on a vocab word."); buf.push_back(next_char); } // Ok now we have null terminated strings. diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index c92518e4..9d218fff 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -61,6 +61,7 @@ class SortedVocabulary : public base::Vocabulary { } } + // Size for purposes of file writing static size_t Size(std::size_t entries, const Config &config); // Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary. @@ -77,6 +78,9 @@ class SortedVocabulary : public base::Vocabulary { // Reorders reorder_vocab so that the IDs are sorted. void FinishedLoading(ProbBackoff *reorder_vocab); + // Trie stores the correct counts including in the header. If this was previously sized based on a count exluding , padding with 8 bytes will make it the correct size based on a count including . + std::size_t UnkCountChangePadding() const { return SawUnk() ? 0 : sizeof(uint64_t); } + bool SawUnk() const { return saw_unk_; } void LoadedBinary(int fd, EnumerateVocab *to); diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh index b35d80c8..9f47d559 100644 --- a/klm/util/bit_packing.hh +++ b/klm/util/bit_packing.hh @@ -107,9 +107,20 @@ void BitPackingSanity(); uint8_t RequiredBits(uint64_t max_value); struct BitsMask { + static BitsMask ByMax(uint64_t max_value) { + BitsMask ret; + ret.FromMax(max_value); + return ret; + } + static BitsMask ByBits(uint8_t bits) { + BitsMask ret; + ret.bits = bits; + ret.mask = (1ULL << bits) - 1; + return ret; + } void FromMax(uint64_t max_value) { bits = RequiredBits(max_value); - mask = (1 << bits) - 1; + mask = (1ULL << bits) - 1; } uint8_t bits; uint64_t mask; diff --git a/klm/util/murmur_hash.cc b/klm/util/murmur_hash.cc index d58a0727..fec47fd9 100644 --- a/klm/util/murmur_hash.cc +++ b/klm/util/murmur_hash.cc @@ -1,129 +1,129 @@ -/* Downloaded from http://sites.google.com/site/murmurhash/ which says "All - * code is released to the public domain. For business purposes, Murmurhash is - * under the MIT license." - * This is modified from the original: - * ULL tag on 0xc6a4a7935bd1e995 so this will compile on 32-bit. - * length changed to unsigned int. - * placed in namespace util - * add MurmurHashNative - * default option = 0 for seed - */ - -#include "util/murmur_hash.hh" - -namespace util { - -//----------------------------------------------------------------------------- -// MurmurHash2, 64-bit versions, by Austin Appleby - -// The same caveats as 32-bit MurmurHash2 apply here - beware of alignment -// and endian-ness issues if used across multiple platforms. - -// 64-bit hash for 64-bit platforms - -uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed ) -{ - const uint64_t m = 0xc6a4a7935bd1e995ULL; - const int r = 47; - - uint64_t h = seed ^ (len * m); - - const uint64_t * data = (const uint64_t *)key; - const uint64_t * end = data + (len/8); - - while(data != end) - { - uint64_t k = *data++; - - k *= m; - k ^= k >> r; - k *= m; - - h ^= k; - h *= m; - } - - const unsigned char * data2 = (const unsigned char*)data; - - switch(len & 7) - { - case 7: h ^= uint64_t(data2[6]) << 48; - case 6: h ^= uint64_t(data2[5]) << 40; - case 5: h ^= uint64_t(data2[4]) << 32; - case 4: h ^= uint64_t(data2[3]) << 24; - case 3: h ^= uint64_t(data2[2]) << 16; - case 2: h ^= uint64_t(data2[1]) << 8; - case 1: h ^= uint64_t(data2[0]); - h *= m; - }; - - h ^= h >> r; - h *= m; - h ^= h >> r; - - return h; -} - - -// 64-bit hash for 32-bit platforms - -uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed ) -{ - const unsigned int m = 0x5bd1e995; - const int r = 24; - - unsigned int h1 = seed ^ len; - unsigned int h2 = 0; - - const unsigned int * data = (const unsigned int *)key; - - while(len >= 8) - { - unsigned int k1 = *data++; - k1 *= m; k1 ^= k1 >> r; k1 *= m; - h1 *= m; h1 ^= k1; - len -= 4; - - unsigned int k2 = *data++; - k2 *= m; k2 ^= k2 >> r; k2 *= m; - h2 *= m; h2 ^= k2; - len -= 4; - } - - if(len >= 4) - { - unsigned int k1 = *data++; - k1 *= m; k1 ^= k1 >> r; k1 *= m; - h1 *= m; h1 ^= k1; - len -= 4; - } - - switch(len) - { - case 3: h2 ^= ((unsigned char*)data)[2] << 16; - case 2: h2 ^= ((unsigned char*)data)[1] << 8; - case 1: h2 ^= ((unsigned char*)data)[0]; - h2 *= m; - }; - - h1 ^= h2 >> 18; h1 *= m; - h2 ^= h1 >> 22; h2 *= m; - h1 ^= h2 >> 17; h1 *= m; - h2 ^= h1 >> 19; h2 *= m; - - uint64_t h = h1; - - h = (h << 32) | h2; - - return h; -} - -uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) { - if (sizeof(int) == 4) { - return MurmurHash64B(key, len, seed); - } else { - return MurmurHash64A(key, len, seed); - } -} - -} // namespace util +/* Downloaded from http://sites.google.com/site/murmurhash/ which says "All + * code is released to the public domain. For business purposes, Murmurhash is + * under the MIT license." + * This is modified from the original: + * ULL tag on 0xc6a4a7935bd1e995 so this will compile on 32-bit. + * length changed to unsigned int. + * placed in namespace util + * add MurmurHashNative + * default option = 0 for seed + */ + +#include "util/murmur_hash.hh" + +namespace util { + +//----------------------------------------------------------------------------- +// MurmurHash2, 64-bit versions, by Austin Appleby + +// The same caveats as 32-bit MurmurHash2 apply here - beware of alignment +// and endian-ness issues if used across multiple platforms. + +// 64-bit hash for 64-bit platforms + +uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed ) +{ + const uint64_t m = 0xc6a4a7935bd1e995ULL; + const int r = 47; + + uint64_t h = seed ^ (len * m); + + const uint64_t * data = (const uint64_t *)key; + const uint64_t * end = data + (len/8); + + while(data != end) + { + uint64_t k = *data++; + + k *= m; + k ^= k >> r; + k *= m; + + h ^= k; + h *= m; + } + + const unsigned char * data2 = (const unsigned char*)data; + + switch(len & 7) + { + case 7: h ^= uint64_t(data2[6]) << 48; + case 6: h ^= uint64_t(data2[5]) << 40; + case 5: h ^= uint64_t(data2[4]) << 32; + case 4: h ^= uint64_t(data2[3]) << 24; + case 3: h ^= uint64_t(data2[2]) << 16; + case 2: h ^= uint64_t(data2[1]) << 8; + case 1: h ^= uint64_t(data2[0]); + h *= m; + }; + + h ^= h >> r; + h *= m; + h ^= h >> r; + + return h; +} + + +// 64-bit hash for 32-bit platforms + +uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed ) +{ + const unsigned int m = 0x5bd1e995; + const int r = 24; + + unsigned int h1 = seed ^ len; + unsigned int h2 = 0; + + const unsigned int * data = (const unsigned int *)key; + + while(len >= 8) + { + unsigned int k1 = *data++; + k1 *= m; k1 ^= k1 >> r; k1 *= m; + h1 *= m; h1 ^= k1; + len -= 4; + + unsigned int k2 = *data++; + k2 *= m; k2 ^= k2 >> r; k2 *= m; + h2 *= m; h2 ^= k2; + len -= 4; + } + + if(len >= 4) + { + unsigned int k1 = *data++; + k1 *= m; k1 ^= k1 >> r; k1 *= m; + h1 *= m; h1 ^= k1; + len -= 4; + } + + switch(len) + { + case 3: h2 ^= ((unsigned char*)data)[2] << 16; + case 2: h2 ^= ((unsigned char*)data)[1] << 8; + case 1: h2 ^= ((unsigned char*)data)[0]; + h2 *= m; + }; + + h1 ^= h2 >> 18; h1 *= m; + h2 ^= h1 >> 22; h2 *= m; + h1 ^= h2 >> 17; h1 *= m; + h2 ^= h1 >> 19; h2 *= m; + + uint64_t h = h1; + + h = (h << 32) | h2; + + return h; +} + +uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) { + if (sizeof(int) == 4) { + return MurmurHash64B(key, len, seed); + } else { + return MurmurHash64A(key, len, seed); + } +} + +} // namespace util diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh index 00be0ed7..2ec342a6 100644 --- a/klm/util/probing_hash_table.hh +++ b/klm/util/probing_hash_table.hh @@ -57,7 +57,7 @@ template class IdentityAccessor { public: typedef T Key; - T operator()(const uint64_t *in) const { return *in; } + T operator()(const T *in) const { return *in; } }; struct Pivot64 { @@ -101,6 +101,27 @@ template bool SortedUniformFind(co return BoundedSortedUniformFind(accessor, begin, below, end, above, key, out); } +// May return begin - 1. +template Iterator BinaryBelow( + const Accessor &accessor, + Iterator begin, + Iterator end, + const typename Accessor::Key key) { + while (end > begin) { + Iterator pivot(begin + (end - begin) / 2); + typename Accessor::Key mid(accessor(pivot)); + if (mid < key) { + begin = pivot + 1; + } else if (mid > key) { + end = pivot; + } else { + for (++pivot; (pivot < end) && accessor(pivot) == mid; ++pivot) {} + return pivot - 1; + } + } + return begin - 1; +} + // To use this template, you need to define a Pivot function to match Key. template class SortedUniformMap { public: -- cgit v1.2.3 From 8ecf63852d730f99e7c1bbacfbffdf518d5a0c3f Mon Sep 17 00:00:00 2001 From: Guest_account Guest_account prguest11 Date: Fri, 23 Sep 2011 20:49:43 +0100 Subject: stub work to talk to new kenlm --- decoder/ff_klm.cc | 349 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 349 insertions(+) (limited to 'decoder/ff_klm.cc') diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 24dcb9c3..016aad26 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -12,6 +12,353 @@ #include "lm/model.hh" #include "lm/enumerate_vocab.hh" +#undef NEW_KENLM +#ifdef NEW_KENLM + +#include "lm/left.hh" + +using namespace std; + +// -x : rules include and +// -n NAME : feature id is NAME +bool ParseLMArgs(string const& in, string* filename, string* mapfile, bool* explicit_markers, string* featname) { + vector const& argv=SplitOnWhitespace(in); + *explicit_markers = false; + *featname="LanguageModel"; + *mapfile = ""; +#define LMSPEC_NEXTARG if (i==argv.end()) { \ + cerr << "Missing argument for "<<*last<<". "; goto usage; \ + } else { ++i; } + + for (vector::const_iterator last,i=argv.begin(),e=argv.end();i!=e;++i) { + string const& s=*i; + if (s[0]=='-') { + if (s.size()>2) goto fail; + switch (s[1]) { + case 'x': + *explicit_markers = true; + break; + case 'm': + LMSPEC_NEXTARG; *mapfile=*i; + break; + case 'n': + LMSPEC_NEXTARG; *featname=*i; + break; +#undef LMSPEC_NEXTARG + default: + fail: + cerr<<"Unknown KLanguageModel option "<empty()) + *filename=s; + else { + cerr<<"More than one filename provided. "; + goto usage; + } + } + } + if (!filename->empty()) + return true; +usage: + cerr << "KLanguageModel is incorrect!\n"; + return false; +} + +template +string KLanguageModel::usage(bool /*param*/,bool /*verbose*/) { + return "KLanguageModel"; +} + +struct VMapper : public lm::ngram::EnumerateVocab { + VMapper(vector* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); } + void Add(lm::WordIndex index, const StringPiece &str) { + const WordID cdec_id = TD::Convert(str.as_string()); + if (cdec_id >= out_->size()) + out_->resize(cdec_id + 1, kLM_UNKNOWN_TOKEN); + (*out_)[cdec_id] = index; + } + vector* out_; + const lm::WordIndex kLM_UNKNOWN_TOKEN; +}; + +template +class KLanguageModelImpl { + + static inline const lm::ngram::ChartState& RemnantLMState(const void* state) { + return *static_cast(state); + } + + inline void SetRemnantLMState(const lm::ngram::ChartState& lmstate, void* state) const { + // if we were clever, we could use the memory pointed to by state to do all + // the work, avoiding this copy + memcpy(state, &lmstate, ngram_->StateSize()); + } + + public: + double LookupWords(const TRule& rule, const vector& ant_states, double* oovs, void* remnant) { + double sum = 0.0; + if (oovs) *oovs = 0; + const vector& e = rule.e(); + lm::ngram::ChartState state; + lm::ngram::RuleScore ruleScore(*ngram_, state); + unsigned i = 0; + if (e.size()) { + if (e[i] == kCDEC_SOS) { + ++i; + ruleScore.BeginSentence(); + } else if (e[i] <= 0) { // special case for left-edge NT + const lm::ngram::ChartState& prevState = RemnantLMState(ant_states[-e[0]]); + ruleScore.BeginNonTerminal(prevState, 0.0f); // TODO + ++i; + } + } + for (; i < e.size(); ++i) { + if (e[i] <= 0) { + const lm::ngram::ChartState& prevState = RemnantLMState(ant_states[-e[i]]); + ruleScore.NonTerminal(prevState, 0.0f); // TODO + } else { + const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i]); // in future, + // maybe handle emission + const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id + const bool is_oov = (cur_word == 0); + if (is_oov) (*oovs) += 1.0; + ruleScore.Terminal(cur_word); + } + } + if (remnant) SetRemnantLMState(state, remnant); + return ruleScore.Finish(); + } + + // this assumes no target words on final unary -> goal rule. is that ok? + // for (n-1 left words) and (n-1 right words) + double FinalTraversalCost(const void* state, double* oovs) { + if (add_sos_eos_) { // rules do not produce , so do it here + lm::ngram::ChartState cstate; + lm::ngram::RuleScore ruleScore(*ngram_, cstate); + ruleScore.BeginSentence(); + SetRemnantLMState(cstate, dummy_state_); + dummy_ants_[1] = state; + *oovs = 0; + return LookupWords(*dummy_rule_, dummy_ants_, oovs, NULL); + } else { // rules DO produce ... + double p = 0; + cerr << "not implemented"; abort(); // TODO + //if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; } + //if (UnscoredSize(state) > 0) { // are there unscored words + // if (kSOS_ != IthUnscoredWord(0, state)) { + // p -= 100 * UnscoredSize(state); + // } + //} + return p; + } + } + + // if this is not a class-based LM, returns w untransformed, + // otherwise returns a word class mapping of w, + // returns TD::Convert("") if there is no mapping for w + WordID ClassifyWordIfNecessary(WordID w) const { + if (word2class_map_.empty()) return w; + if (w >= word2class_map_.size()) + return kCDEC_UNK; + else + return word2class_map_[w]; + } + + // converts to cdec word id's to KenLM's id space, OOVs and end up at 0 + lm::WordIndex MapWord(WordID w) const { + if (w >= cdec2klm_map_.size()) + return 0; + else + return cdec2klm_map_[w]; + } + + public: + KLanguageModelImpl(const string& filename, const string& mapfile, bool explicit_markers) : + kCDEC_UNK(TD::Convert("")) , + kCDEC_SOS(TD::Convert("")) , + add_sos_eos_(!explicit_markers) { + { + VMapper vm(&cdec2klm_map_); + lm::ngram::Config conf; + conf.enumerate_vocab = &vm; + ngram_ = new Model(filename.c_str(), conf); + } + order_ = ngram_->Order(); + cerr << "Loaded " << order_ << "-gram KLM from " << filename << " (MapSize=" << cdec2klm_map_.size() << ")\n"; + state_size_ = sizeof(lm::ngram::ChartState); + + // special handling of beginning / ending sentence markers + dummy_state_ = new char[state_size_]; + memset(dummy_state_, 0, state_size_); + dummy_ants_.push_back(dummy_state_); + dummy_ants_.push_back(NULL); + dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] ||| X=0")); + kSOS_ = MapWord(kCDEC_SOS); + assert(kSOS_ > 0); + kEOS_ = MapWord(TD::Convert("")); + assert(kEOS_ > 0); + assert(MapWord(kCDEC_UNK) == 0); // KenLM invariant + + // handle class-based LMs (unambiguous word->class mapping reqd.) + if (mapfile.size()) + LoadWordClasses(mapfile); + } + + void LoadWordClasses(const string& file) { + ReadFile rf(file); + istream& in = *rf.stream(); + string line; + vector dummy; + int lc = 0; + cerr << " Loading word classes from " << file << " ...\n"; + AddWordToClassMapping_(TD::Convert(""), TD::Convert("")); + AddWordToClassMapping_(TD::Convert(""), TD::Convert("")); + while(in) { + getline(in, line); + if (!in) continue; + dummy.clear(); + TD::ConvertSentence(line, &dummy); + ++lc; + if (dummy.size() != 2) { + cerr << " Format error in " << file << ", line " << lc << ": " << line << endl; + abort(); + } + AddWordToClassMapping_(dummy[0], dummy[1]); + } + } + + void AddWordToClassMapping_(WordID word, WordID cls) { + if (word2class_map_.size() <= word) { + word2class_map_.resize((word + 10) * 1.1, kCDEC_UNK); + assert(word2class_map_.size() > word); + } + if(word2class_map_[word] != kCDEC_UNK) { + cerr << "Multiple classes for symbol " << TD::Convert(word) << endl; + abort(); + } + word2class_map_[word] = cls; + } + + ~KLanguageModelImpl() { + delete ngram_; + delete[] dummy_state_; + } + + int ReserveStateSize() const { return state_size_; } + + private: + const WordID kCDEC_UNK; + const WordID kCDEC_SOS; + lm::WordIndex kSOS_; // - requires special handling. + lm::WordIndex kEOS_; // + Model* ngram_; + const bool add_sos_eos_; // flag indicating whether the hypergraph produces and + // if this is true, FinalTransitionFeatures will "add" and + // if false, FinalTransitionFeatures will score anything with the + // markers in the right place (i.e., the beginning and end of + // the sentence) with 0, and anything else with -100 + + int order_; + int state_size_; + char* dummy_state_; + vector dummy_ants_; + vector cdec2klm_map_; + vector word2class_map_; // if this is a class-based LM, this is the word->class mapping + TRulePtr dummy_rule_; +}; + +template +KLanguageModel::KLanguageModel(const string& param) { + string filename, mapfile, featname; + bool explicit_markers; + if (!ParseLMArgs(param, &filename, &mapfile, &explicit_markers, &featname)) { + abort(); + } + try { + pimpl_ = new KLanguageModelImpl(filename, mapfile, explicit_markers); + } catch (std::exception &e) { + std::cerr << e.what() << std::endl; + abort(); + } + fid_ = FD::Convert(featname); + oov_fid_ = FD::Convert(featname+"_OOV"); + // cerr << "FID: " << oov_fid_ << endl; + SetStateSize(pimpl_->ReserveStateSize()); +} + +template +Features KLanguageModel::features() const { + return single_feature(fid_); +} + +template +KLanguageModel::~KLanguageModel() { + delete pimpl_; +} + +template +void KLanguageModel::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, + const Hypergraph::Edge& edge, + const vector& ant_states, + SparseVector* features, + SparseVector* estimated_features, + void* state) const { + double est = 0; + double oovs = 0; + features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &oovs, state)); + if (oovs && oov_fid_) + features->set_value(oov_fid_, oovs); +} + +template +void KLanguageModel::FinalTraversalFeatures(const void* ant_state, + SparseVector* features) const { + double oovs = 0; + double lm = pimpl_->FinalTraversalCost(ant_state, &oovs); + features->set_value(fid_, lm); + if (oov_fid_ && oovs) + features->set_value(oov_fid_, oovs); +} + +template boost::shared_ptr CreateModel(const std::string ¶m) { + KLanguageModel *ret = new KLanguageModel(param); + ret->Init(); + return boost::shared_ptr(ret); +} + +boost::shared_ptr KLanguageModelFactory::Create(std::string param) const { + using namespace lm::ngram; + std::string filename, ignored_map; + bool ignored_markers; + std::string ignored_featname; + ParseLMArgs(param, &filename, &ignored_map, &ignored_markers, &ignored_featname); + ModelType m; + if (!RecognizeBinary(filename.c_str(), m)) m = HASH_PROBING; + + switch (m) { + case HASH_PROBING: + return CreateModel(param); + case TRIE_SORTED: + return CreateModel(param); + case ARRAY_TRIE_SORTED: + return CreateModel(param); + case QUANT_TRIE_SORTED: + return CreateModel(param); + case QUANT_ARRAY_TRIE_SORTED: + return CreateModel(param); + default: + UTIL_THROW(util::Exception, "Unrecognized kenlm binary file type " << (unsigned)m); + } +} + +std::string KLanguageModelFactory::usage(bool params,bool verbose) const { + return KLanguageModel::usage(params, verbose); +} + +#else + using namespace std; static const unsigned char HAS_FULL_CONTEXT = 1; @@ -469,3 +816,5 @@ boost::shared_ptr KLanguageModelFactory::Create(std::string par std::string KLanguageModelFactory::usage(bool params,bool verbose) const { return KLanguageModel::usage(params, verbose); } + +#endif -- cgit v1.2.3 From d71c74f3924e6c207f3ebfab470b9a30e2551dde Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Fri, 23 Sep 2011 16:38:01 -0400 Subject: Go through ff_klm and try to fix it for the new version. --- decoder/ff_klm.cc | 36 +++++++++--------------------------- 1 file changed, 9 insertions(+), 27 deletions(-) (limited to 'decoder/ff_klm.cc') diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 016aad26..3b2113ad 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -90,19 +90,12 @@ class KLanguageModelImpl { return *static_cast(state); } - inline void SetRemnantLMState(const lm::ngram::ChartState& lmstate, void* state) const { - // if we were clever, we could use the memory pointed to by state to do all - // the work, avoiding this copy - memcpy(state, &lmstate, ngram_->StateSize()); - } - public: double LookupWords(const TRule& rule, const vector& ant_states, double* oovs, void* remnant) { - double sum = 0.0; if (oovs) *oovs = 0; const vector& e = rule.e(); lm::ngram::ChartState state; - lm::ngram::RuleScore ruleScore(*ngram_, state); + lm::ngram::RuleScore ruleScore(*ngram_, remnant ? *static_cast(remnant) : state); unsigned i = 0; if (e.size()) { if (e[i] == kCDEC_SOS) { @@ -123,12 +116,13 @@ class KLanguageModelImpl { // maybe handle emission const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id const bool is_oov = (cur_word == 0); - if (is_oov) (*oovs) += 1.0; + if (is_oov && oovs) (*oovs) += 1.0; ruleScore.Terminal(cur_word); } } - if (remnant) SetRemnantLMState(state, remnant); - return ruleScore.Finish(); + double ret = ruleScore.Finish(); + state.ZeroRemaining(); + return ret; } // this assumes no target words on final unary -> goal rule. is that ok? @@ -138,10 +132,9 @@ class KLanguageModelImpl { lm::ngram::ChartState cstate; lm::ngram::RuleScore ruleScore(*ngram_, cstate); ruleScore.BeginSentence(); - SetRemnantLMState(cstate, dummy_state_); - dummy_ants_[1] = state; - *oovs = 0; - return LookupWords(*dummy_rule_, dummy_ants_, oovs, NULL); + ruleScore.NonTerminal(RemnantLMState(state), 0.0f); + ruleScore.Terminal(kEOS_); + return ruleScore.Finish(); } else { // rules DO produce ... double p = 0; cerr << "not implemented"; abort(); // TODO @@ -187,14 +180,8 @@ class KLanguageModelImpl { } order_ = ngram_->Order(); cerr << "Loaded " << order_ << "-gram KLM from " << filename << " (MapSize=" << cdec2klm_map_.size() << ")\n"; - state_size_ = sizeof(lm::ngram::ChartState); // special handling of beginning / ending sentence markers - dummy_state_ = new char[state_size_]; - memset(dummy_state_, 0, state_size_); - dummy_ants_.push_back(dummy_state_); - dummy_ants_.push_back(NULL); - dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] ||| X=0")); kSOS_ = MapWord(kCDEC_SOS); assert(kSOS_ > 0); kEOS_ = MapWord(TD::Convert("")); @@ -243,10 +230,9 @@ class KLanguageModelImpl { ~KLanguageModelImpl() { delete ngram_; - delete[] dummy_state_; } - int ReserveStateSize() const { return state_size_; } + int ReserveStateSize() const { return sizeof(lm::ngram::ChartState); } private: const WordID kCDEC_UNK; @@ -261,12 +247,8 @@ class KLanguageModelImpl { // the sentence) with 0, and anything else with -100 int order_; - int state_size_; - char* dummy_state_; - vector dummy_ants_; vector cdec2klm_map_; vector word2class_map_; // if this is a class-based LM, this is the word->class mapping - TRulePtr dummy_rule_; }; template -- cgit v1.2.3 From 747309fcb0e0b1c6d060a68286ba1cf5ed1fbfa4 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Sat, 24 Sep 2011 11:33:22 -0400 Subject: Chris says remnant and oovs should not be null, so stop checking. Also, we were not properly doing ZeroRemaining, sorry. --- decoder/ff_klm.cc | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) (limited to 'decoder/ff_klm.cc') diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 3b2113ad..6d9aca54 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -92,10 +92,9 @@ class KLanguageModelImpl { public: double LookupWords(const TRule& rule, const vector& ant_states, double* oovs, void* remnant) { - if (oovs) *oovs = 0; + *oovs = 0; const vector& e = rule.e(); - lm::ngram::ChartState state; - lm::ngram::RuleScore ruleScore(*ngram_, remnant ? *static_cast(remnant) : state); + lm::ngram::RuleScore ruleScore(*ngram_, *static_cast(remnant)); unsigned i = 0; if (e.size()) { if (e[i] == kCDEC_SOS) { @@ -115,13 +114,12 @@ class KLanguageModelImpl { const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i]); // in future, // maybe handle emission const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id - const bool is_oov = (cur_word == 0); - if (is_oov && oovs) (*oovs) += 1.0; + if (cur_word == 0) (*oovs) += 1.0; ruleScore.Terminal(cur_word); } } double ret = ruleScore.Finish(); - state.ZeroRemaining(); + static_cast(remnant)->ZeroRemaining(); return ret; } -- cgit v1.2.3 From 957d90991b4ec80b9877126c736bd60768b094aa Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Mon, 17 Oct 2011 16:58:26 +0100 Subject: Chris, I'd like you to review this for use with your rules that contain and . --- decoder/ff_klm.cc | 72 +++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 49 insertions(+), 23 deletions(-) (limited to 'decoder/ff_klm.cc') diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 6d9aca54..658aef80 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -71,6 +71,8 @@ string KLanguageModel::usage(bool /*param*/,bool /*verbose*/) { return "KLanguageModel"; } +namespace { + struct VMapper : public lm::ngram::EnumerateVocab { VMapper(vector* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); } void Add(lm::WordIndex index, const StringPiece &str) { @@ -83,66 +85,90 @@ struct VMapper : public lm::ngram::EnumerateVocab { const lm::WordIndex kLM_UNKNOWN_TOKEN; }; -template -class KLanguageModelImpl { +#pragma pack(push) +#pragma pack(1) - static inline const lm::ngram::ChartState& RemnantLMState(const void* state) { - return *static_cast(state); +struct BoundaryAnnotatedState { + lm::ngram::ChartState state; + bool seen_bos, seen_eos; +}; + +#pragma pack(pop) + +void BoundaryCheck(bool &annotated, bool sub, double &ret) { + if (!sub) return; + if (annotated) { + ret -= 100.0; + } else { + annotated = true; } +} +} // namespace + +template +class KLanguageModelImpl { public: double LookupWords(const TRule& rule, const vector& ant_states, double* oovs, void* remnant) { *oovs = 0; const vector& e = rule.e(); - lm::ngram::RuleScore ruleScore(*ngram_, *static_cast(remnant)); + BoundaryAnnotatedState &annotated = *static_cast(remnant); + lm::ngram::RuleScore ruleScore(*ngram_, annotated.state); + annotated.seen_bos = false; + annotated.seen_eos = false; unsigned i = 0; + double ret = 0.0; if (e.size()) { if (e[i] == kCDEC_SOS) { ++i; ruleScore.BeginSentence(); + annotated.seen_bos = true; } else if (e[i] <= 0) { // special case for left-edge NT - const lm::ngram::ChartState& prevState = RemnantLMState(ant_states[-e[0]]); - ruleScore.BeginNonTerminal(prevState, 0.0f); // TODO + const BoundaryAnnotatedState &sub = *static_cast(ant_states[-e[0]]); + ruleScore.BeginNonTerminal(sub.state, 0.0f); + annotated.seen_bos = sub.seen_bos; + annotated.seen_eos = sub.seen_eos; ++i; } } for (; i < e.size(); ++i) { if (e[i] <= 0) { - const lm::ngram::ChartState& prevState = RemnantLMState(ant_states[-e[i]]); - ruleScore.NonTerminal(prevState, 0.0f); // TODO + const BoundaryAnnotatedState &sub = *static_cast(ant_states[-e[i]]); + ruleScore.NonTerminal(sub.state, 0.0f); + BoundaryCheck(annotated.seen_bos, sub.seen_bos, ret); + BoundaryCheck(annotated.seen_eos, sub.seen_eos, ret); } else { const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i]); // in future, // maybe handle emission const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id if (cur_word == 0) (*oovs) += 1.0; + BoundaryCheck(annotated.seen_eos, cur_word == kEOS_, ret); ruleScore.Terminal(cur_word); } } - double ret = ruleScore.Finish(); - static_cast(remnant)->ZeroRemaining(); + ret += ruleScore.Finish(); + annotated.state.ZeroRemaining(); return ret; } // this assumes no target words on final unary -> goal rule. is that ok? // for (n-1 left words) and (n-1 right words) - double FinalTraversalCost(const void* state, double* oovs) { + double FinalTraversalCost(const void* state_void, double* oovs) { + const BoundaryAnnotatedState &annotated = *static_cast(state_void); if (add_sos_eos_) { // rules do not produce , so do it here + assert(!annotated.seen_bos); + assert(!annotated.seen_eos); lm::ngram::ChartState cstate; lm::ngram::RuleScore ruleScore(*ngram_, cstate); ruleScore.BeginSentence(); - ruleScore.NonTerminal(RemnantLMState(state), 0.0f); + ruleScore.NonTerminal(annotated.state, 0.0f); ruleScore.Terminal(kEOS_); return ruleScore.Finish(); } else { // rules DO produce ... - double p = 0; - cerr << "not implemented"; abort(); // TODO - //if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; } - //if (UnscoredSize(state) > 0) { // are there unscored words - // if (kSOS_ != IthUnscoredWord(0, state)) { - // p -= 100 * UnscoredSize(state); - // } - //} - return p; + double ret = 0.0; + if (!annotated.seen_bos) ret -= 100.0; + if (!annotated.seen_eos) ret -= 100.0; + return ret; } } @@ -230,7 +256,7 @@ class KLanguageModelImpl { delete ngram_; } - int ReserveStateSize() const { return sizeof(lm::ngram::ChartState); } + int ReserveStateSize() const { return sizeof(BoundaryAnnotatedState); } private: const WordID kCDEC_UNK; -- cgit v1.2.3 From 3d1ed02a4e5d81aace80b0e004e96351d116630f Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Tue, 18 Oct 2011 10:25:56 +0100 Subject: Revised and handling --- decoder/ff_klm.cc | 84 ++++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 58 insertions(+), 26 deletions(-) (limited to 'decoder/ff_klm.cc') diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 658aef80..3c941fbf 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -12,8 +12,8 @@ #include "lm/model.hh" #include "lm/enumerate_vocab.hh" +#define NEW_KENLM #undef NEW_KENLM -#ifdef NEW_KENLM #include "lm/left.hh" @@ -95,14 +95,58 @@ struct BoundaryAnnotatedState { #pragma pack(pop) -void BoundaryCheck(bool &annotated, bool sub, double &ret) { - if (!sub) return; - if (annotated) { - ret -= 100.0; - } else { - annotated = true; - } -} +template class BoundaryRuleScore { + public: + BoundaryRuleScore(const Model &m, BoundaryAnnotatedState &state) : + back_(m, state.state), + bos_(state.seen_bos), + eos_(state.seen_eos), + penalty_(0.0), + end_sentence_(m.GetVocabulary().EndSentence()) { + bos_ = false; + eos_ = false; + } + + void BeginSentence() { + back_.BeginSentence(); + bos_ = true; + } + + void BeginNonTerminal(const BoundaryAnnotatedState &sub) { + back_.BeginNonTerminal(sub.state, 0.0f); + bos_ = sub.seen_bos; + eos_ = sub.seen_eos; + } + + void NonTerminal(const BoundaryAnnotatedState &sub) { + back_.NonTerminal(sub.state, 0.0f); + // cdec only calls this if there's content. + if (sub.seen_bos) { + bos_ = true; + penalty_ -= 100.0f; + } + if (eos_) penalty_ -= 100.0f; + eos_ |= sub.seen_eos; + } + + void Terminal(lm::WordIndex word) { + back_.Terminal(word); + if (eos_) penalty_ -= 100.0f; + if (word == end_sentence_) eos_ = true; + } + + float Finish() { + return penalty_ + back_.Finish(); + } + + private: + lm::ngram::RuleScore back_; + bool &bos_, &eos_; + + float penalty_; + + lm::WordIndex end_sentence_; +}; } // namespace @@ -112,42 +156,30 @@ class KLanguageModelImpl { double LookupWords(const TRule& rule, const vector& ant_states, double* oovs, void* remnant) { *oovs = 0; const vector& e = rule.e(); - BoundaryAnnotatedState &annotated = *static_cast(remnant); - lm::ngram::RuleScore ruleScore(*ngram_, annotated.state); - annotated.seen_bos = false; - annotated.seen_eos = false; + BoundaryRuleScore ruleScore(*ngram_, *static_cast(remnant)); unsigned i = 0; - double ret = 0.0; if (e.size()) { if (e[i] == kCDEC_SOS) { ++i; ruleScore.BeginSentence(); - annotated.seen_bos = true; } else if (e[i] <= 0) { // special case for left-edge NT - const BoundaryAnnotatedState &sub = *static_cast(ant_states[-e[0]]); - ruleScore.BeginNonTerminal(sub.state, 0.0f); - annotated.seen_bos = sub.seen_bos; - annotated.seen_eos = sub.seen_eos; + ruleScore.BeginNonTerminal(*static_cast(ant_states[-e[0]])); ++i; } } for (; i < e.size(); ++i) { if (e[i] <= 0) { - const BoundaryAnnotatedState &sub = *static_cast(ant_states[-e[i]]); - ruleScore.NonTerminal(sub.state, 0.0f); - BoundaryCheck(annotated.seen_bos, sub.seen_bos, ret); - BoundaryCheck(annotated.seen_eos, sub.seen_eos, ret); + ruleScore.NonTerminal(*static_cast(ant_states[-e[i]])); } else { const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[i]); // in future, // maybe handle emission const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id if (cur_word == 0) (*oovs) += 1.0; - BoundaryCheck(annotated.seen_eos, cur_word == kEOS_, ret); ruleScore.Terminal(cur_word); } } - ret += ruleScore.Finish(); - annotated.state.ZeroRemaining(); + double ret = ruleScore.Finish(); + static_cast(remnant)->state.ZeroRemaining(); return ret; } -- cgit v1.2.3 From 04e38a57b19ea012895ac2efb39382c2e77833a9 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 18 Oct 2011 14:19:09 +0100 Subject: incorporate kenneth's fixes --- decoder/ff_klm.cc | 464 ------------------------------------------------------ 1 file changed, 464 deletions(-) (limited to 'decoder/ff_klm.cc') diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 3c941fbf..ed6f731e 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -12,9 +12,6 @@ #include "lm/model.hh" #include "lm/enumerate_vocab.hh" -#define NEW_KENLM -#undef NEW_KENLM - #include "lm/left.hh" using namespace std; @@ -395,464 +392,3 @@ std::string KLanguageModelFactory::usage(bool params,bool verbose) const { return KLanguageModel::usage(params, verbose); } -#else - -using namespace std; - -static const unsigned char HAS_FULL_CONTEXT = 1; -static const unsigned char HAS_EOS_ON_RIGHT = 2; -static const unsigned char MASK = 7; - -// -x : rules include and -// -n NAME : feature id is NAME -bool ParseLMArgs(string const& in, string* filename, string* mapfile, bool* explicit_markers, string* featname) { - vector const& argv=SplitOnWhitespace(in); - *explicit_markers = false; - *featname="LanguageModel"; - *mapfile = ""; -#define LMSPEC_NEXTARG if (i==argv.end()) { \ - cerr << "Missing argument for "<<*last<<". "; goto usage; \ - } else { ++i; } - - for (vector::const_iterator last,i=argv.begin(),e=argv.end();i!=e;++i) { - string const& s=*i; - if (s[0]=='-') { - if (s.size()>2) goto fail; - switch (s[1]) { - case 'x': - *explicit_markers = true; - break; - case 'm': - LMSPEC_NEXTARG; *mapfile=*i; - break; - case 'n': - LMSPEC_NEXTARG; *featname=*i; - break; -#undef LMSPEC_NEXTARG - default: - fail: - cerr<<"Unknown KLanguageModel option "<empty()) - *filename=s; - else { - cerr<<"More than one filename provided. "; - goto usage; - } - } - } - if (!filename->empty()) - return true; -usage: - cerr << "KLanguageModel is incorrect!\n"; - return false; -} - -template -string KLanguageModel::usage(bool /*param*/,bool /*verbose*/) { - return "KLanguageModel"; -} - -struct VMapper : public lm::ngram::EnumerateVocab { - VMapper(vector* out) : out_(out), kLM_UNKNOWN_TOKEN(0) { out_->clear(); } - void Add(lm::WordIndex index, const StringPiece &str) { - const WordID cdec_id = TD::Convert(str.as_string()); - if (cdec_id >= out_->size()) - out_->resize(cdec_id + 1, kLM_UNKNOWN_TOKEN); - (*out_)[cdec_id] = index; - } - vector* out_; - const lm::WordIndex kLM_UNKNOWN_TOKEN; -}; - -template -class KLanguageModelImpl { - - // returns the number of unscored words at the left edge of a span - inline int UnscoredSize(const void* state) const { - return *(static_cast(state) + unscored_size_offset_); - } - - inline void SetUnscoredSize(int size, void* state) const { - *(static_cast(state) + unscored_size_offset_) = size; - } - - static inline const lm::ngram::State& RemnantLMState(const void* state) { - return *static_cast(state); - } - - inline void SetRemnantLMState(const lm::ngram::State& lmstate, void* state) const { - // if we were clever, we could use the memory pointed to by state to do all - // the work, avoiding this copy - memcpy(state, &lmstate, ngram_->StateSize()); - } - - lm::WordIndex IthUnscoredWord(int i, const void* state) const { - const lm::WordIndex* const mem = reinterpret_cast(static_cast(state) + unscored_words_offset_); - return mem[i]; - } - - void SetIthUnscoredWord(int i, lm::WordIndex index, void *state) const { - lm::WordIndex* mem = reinterpret_cast(static_cast(state) + unscored_words_offset_); - mem[i] = index; - } - - inline bool GetFlag(const void *state, unsigned char flag) const { - return (*(static_cast(state) + is_complete_offset_) & flag); - } - - inline void SetFlag(bool on, unsigned char flag, void *state) const { - if (on) { - *(static_cast(state) + is_complete_offset_) |= flag; - } else { - *(static_cast(state) + is_complete_offset_) &= (MASK ^ flag); - } - } - - inline bool HasFullContext(const void *state) const { - return GetFlag(state, HAS_FULL_CONTEXT); - } - - inline void SetHasFullContext(bool flag, void *state) const { - SetFlag(flag, HAS_FULL_CONTEXT, state); - } - - public: - double LookupWords(const TRule& rule, const vector& ant_states, double* pest_sum, double* oovs, double* est_oovs, void* remnant) { - double sum = 0.0; - double est_sum = 0.0; - int num_scored = 0; - int num_estimated = 0; - if (oovs) *oovs = 0; - if (est_oovs) *est_oovs = 0; - bool saw_eos = false; - bool has_some_history = false; - lm::ngram::State state = ngram_->NullContextState(); - const vector& e = rule.e(); - bool context_complete = false; - for (int j = 0; j < e.size(); ++j) { - if (e[j] < 1) { // handle non-terminal substitution - const void* astate = (ant_states[-e[j]]); - int unscored_ant_len = UnscoredSize(astate); - for (int k = 0; k < unscored_ant_len; ++k) { - const lm::WordIndex cur_word = IthUnscoredWord(k, astate); - const bool is_oov = (cur_word == 0); - double p = 0; - if (cur_word == kSOS_) { - state = ngram_->BeginSentenceState(); - if (has_some_history) { // this is immediately fully scored, and bad - p = -100; - context_complete = true; - } else { // this might be a real - num_scored = max(0, order_ - 2); - } - } else { - const lm::ngram::State scopy(state); - p = ngram_->Score(scopy, cur_word, state); - if (saw_eos) { p = -100; } - saw_eos = (cur_word == kEOS_); - } - has_some_history = true; - ++num_scored; - if (!context_complete) { - if (num_scored >= order_) context_complete = true; - } - if (context_complete) { - sum += p; - if (oovs && is_oov) (*oovs)++; - } else { - if (remnant) - SetIthUnscoredWord(num_estimated, cur_word, remnant); - ++num_estimated; - est_sum += p; - if (est_oovs && is_oov) (*est_oovs)++; - } - } - saw_eos = GetFlag(astate, HAS_EOS_ON_RIGHT); - if (HasFullContext(astate)) { // this is equivalent to the "star" in Chiang 2007 - state = RemnantLMState(astate); - context_complete = true; - } - } else { // handle terminal - const WordID cdec_word_or_class = ClassifyWordIfNecessary(e[j]); // in future, - // maybe handle emission - const lm::WordIndex cur_word = MapWord(cdec_word_or_class); // map to LM's id - double p = 0; - const bool is_oov = (cur_word == 0); - if (cur_word == kSOS_) { - state = ngram_->BeginSentenceState(); - if (has_some_history) { // this is immediately fully scored, and bad - p = -100; - context_complete = true; - } else { // this might be a real - num_scored = max(0, order_ - 2); - } - } else { - const lm::ngram::State scopy(state); - p = ngram_->Score(scopy, cur_word, state); - if (saw_eos) { p = -100; } - saw_eos = (cur_word == kEOS_); - } - has_some_history = true; - ++num_scored; - if (!context_complete) { - if (num_scored >= order_) context_complete = true; - } - if (context_complete) { - sum += p; - if (oovs && is_oov) (*oovs)++; - } else { - if (remnant) - SetIthUnscoredWord(num_estimated, cur_word, remnant); - ++num_estimated; - est_sum += p; - if (est_oovs && is_oov) (*est_oovs)++; - } - } - } - if (pest_sum) *pest_sum = est_sum; - if (remnant) { - state.ZeroRemaining(); - SetFlag(saw_eos, HAS_EOS_ON_RIGHT, remnant); - SetRemnantLMState(state, remnant); - SetUnscoredSize(num_estimated, remnant); - SetHasFullContext(context_complete || (num_scored >= order_), remnant); - } - return sum; - } - - // this assumes no target words on final unary -> goal rule. is that ok? - // for (n-1 left words) and (n-1 right words) - double FinalTraversalCost(const void* state, double* oovs) { - if (add_sos_eos_) { // rules do not produce , so do it here - SetRemnantLMState(ngram_->BeginSentenceState(), dummy_state_); - SetHasFullContext(1, dummy_state_); - SetUnscoredSize(0, dummy_state_); - dummy_ants_[1] = state; - *oovs = 0; - return LookupWords(*dummy_rule_, dummy_ants_, NULL, oovs, NULL, NULL); - } else { // rules DO produce ... - double p = 0; - if (!GetFlag(state, HAS_EOS_ON_RIGHT)) { p -= 100; } - if (UnscoredSize(state) > 0) { // are there unscored words - if (kSOS_ != IthUnscoredWord(0, state)) { - p -= 100 * UnscoredSize(state); - } - } - return p; - } - } - - // if this is not a class-based LM, returns w untransformed, - // otherwise returns a word class mapping of w, - // returns TD::Convert("") if there is no mapping for w - WordID ClassifyWordIfNecessary(WordID w) const { - if (word2class_map_.empty()) return w; - if (w >= word2class_map_.size()) - return kCDEC_UNK; - else - return word2class_map_[w]; - } - - // converts to cdec word id's to KenLM's id space, OOVs and end up at 0 - lm::WordIndex MapWord(WordID w) const { - if (w >= cdec2klm_map_.size()) - return 0; - else - return cdec2klm_map_[w]; - } - - public: - KLanguageModelImpl(const string& filename, const string& mapfile, bool explicit_markers) : - kCDEC_UNK(TD::Convert("")) , - add_sos_eos_(!explicit_markers) { - { - VMapper vm(&cdec2klm_map_); - lm::ngram::Config conf; - conf.enumerate_vocab = &vm; - ngram_ = new Model(filename.c_str(), conf); - } - order_ = ngram_->Order(); - cerr << "Loaded " << order_ << "-gram KLM from " << filename << " (MapSize=" << cdec2klm_map_.size() << ")\n"; - state_size_ = ngram_->StateSize() + 2 + (order_ - 1) * sizeof(lm::WordIndex); - unscored_size_offset_ = ngram_->StateSize(); - is_complete_offset_ = unscored_size_offset_ + 1; - unscored_words_offset_ = is_complete_offset_ + 1; - - // special handling of beginning / ending sentence markers - dummy_state_ = new char[state_size_]; - memset(dummy_state_, 0, state_size_); - dummy_ants_.push_back(dummy_state_); - dummy_ants_.push_back(NULL); - dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] ||| X=0")); - kSOS_ = MapWord(TD::Convert("")); - assert(kSOS_ > 0); - kEOS_ = MapWord(TD::Convert("")); - assert(kEOS_ > 0); - assert(MapWord(kCDEC_UNK) == 0); // KenLM invariant - - // handle class-based LMs (unambiguous word->class mapping reqd.) - if (mapfile.size()) - LoadWordClasses(mapfile); - } - - void LoadWordClasses(const string& file) { - ReadFile rf(file); - istream& in = *rf.stream(); - string line; - vector dummy; - int lc = 0; - cerr << " Loading word classes from " << file << " ...\n"; - AddWordToClassMapping_(TD::Convert(""), TD::Convert("")); - AddWordToClassMapping_(TD::Convert(""), TD::Convert("")); - while(in) { - getline(in, line); - if (!in) continue; - dummy.clear(); - TD::ConvertSentence(line, &dummy); - ++lc; - if (dummy.size() != 2) { - cerr << " Format error in " << file << ", line " << lc << ": " << line << endl; - abort(); - } - AddWordToClassMapping_(dummy[0], dummy[1]); - } - } - - void AddWordToClassMapping_(WordID word, WordID cls) { - if (word2class_map_.size() <= word) { - word2class_map_.resize((word + 10) * 1.1, kCDEC_UNK); - assert(word2class_map_.size() > word); - } - if(word2class_map_[word] != kCDEC_UNK) { - cerr << "Multiple classes for symbol " << TD::Convert(word) << endl; - abort(); - } - word2class_map_[word] = cls; - } - - ~KLanguageModelImpl() { - delete ngram_; - delete[] dummy_state_; - } - - int ReserveStateSize() const { return state_size_; } - - private: - const WordID kCDEC_UNK; - lm::WordIndex kSOS_; // - requires special handling. - lm::WordIndex kEOS_; // - Model* ngram_; - const bool add_sos_eos_; // flag indicating whether the hypergraph produces and - // if this is true, FinalTransitionFeatures will "add" and - // if false, FinalTransitionFeatures will score anything with the - // markers in the right place (i.e., the beginning and end of - // the sentence) with 0, and anything else with -100 - - int order_; - int state_size_; - int unscored_size_offset_; - int is_complete_offset_; - int unscored_words_offset_; - char* dummy_state_; - vector dummy_ants_; - vector cdec2klm_map_; - vector word2class_map_; // if this is a class-based LM, this is the word->class mapping - TRulePtr dummy_rule_; -}; - -template -KLanguageModel::KLanguageModel(const string& param) { - string filename, mapfile, featname; - bool explicit_markers; - if (!ParseLMArgs(param, &filename, &mapfile, &explicit_markers, &featname)) { - abort(); - } - try { - pimpl_ = new KLanguageModelImpl(filename, mapfile, explicit_markers); - } catch (std::exception &e) { - std::cerr << e.what() << std::endl; - abort(); - } - fid_ = FD::Convert(featname); - oov_fid_ = FD::Convert(featname+"_OOV"); - cerr << "FID: " << oov_fid_ << endl; - SetStateSize(pimpl_->ReserveStateSize()); -} - -template -Features KLanguageModel::features() const { - return single_feature(fid_); -} - -template -KLanguageModel::~KLanguageModel() { - delete pimpl_; -} - -template -void KLanguageModel::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, - const Hypergraph::Edge& edge, - const vector& ant_states, - SparseVector* features, - SparseVector* estimated_features, - void* state) const { - double est = 0; - double oovs = 0; - double est_oovs = 0; - features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, &est, &oovs, &est_oovs, state)); - estimated_features->set_value(fid_, est); - if (oov_fid_) { - if (oovs) features->set_value(oov_fid_, oovs); - if (est_oovs) estimated_features->set_value(oov_fid_, est_oovs); - } -} - -template -void KLanguageModel::FinalTraversalFeatures(const void* ant_state, - SparseVector* features) const { - double oovs = 0; - double lm = pimpl_->FinalTraversalCost(ant_state, &oovs); - features->set_value(fid_, lm); - if (oov_fid_ && oovs) - features->set_value(oov_fid_, oovs); -} - -template boost::shared_ptr CreateModel(const std::string ¶m) { - KLanguageModel *ret = new KLanguageModel(param); - ret->Init(); - return boost::shared_ptr(ret); -} - -boost::shared_ptr KLanguageModelFactory::Create(std::string param) const { - using namespace lm::ngram; - std::string filename, ignored_map; - bool ignored_markers; - std::string ignored_featname; - ParseLMArgs(param, &filename, &ignored_map, &ignored_markers, &ignored_featname); - ModelType m; - if (!RecognizeBinary(filename.c_str(), m)) m = HASH_PROBING; - - switch (m) { - case HASH_PROBING: - return CreateModel(param); - case TRIE_SORTED: - return CreateModel(param); - case ARRAY_TRIE_SORTED: - return CreateModel(param); - case QUANT_TRIE_SORTED: - return CreateModel(param); - case QUANT_ARRAY_TRIE_SORTED: - return CreateModel(param); - default: - UTIL_THROW(util::Exception, "Unrecognized kenlm binary file type " << (unsigned)m); - } -} - -std::string KLanguageModelFactory::usage(bool params,bool verbose) const { - return KLanguageModel::usage(params, verbose); -} - -#endif -- cgit v1.2.3 From ef2df950520a47ca7011736648334eedeae5297a Mon Sep 17 00:00:00 2001 From: Patrick Simianer Date: Wed, 19 Oct 2011 20:56:22 +0200 Subject: merged, compiles but not working --- .gitignore | 3 + decoder/ff_klm.cc | 19 ------- dtrain/dtrain.cc | 75 ++++++++++++++++--------- dtrain/dtrain.h | 2 + dtrain/kbestget.h | 6 +- dtrain/test/example/dtrain.ini | 8 +-- klm/lm/binary_format.cc | 4 -- klm/lm/search_trie.cc | 123 ----------------------------------------- klm/lm/trie.cc | 10 ---- utils/fdict.h | 1 - 10 files changed, 63 insertions(+), 188 deletions(-) (limited to 'decoder/ff_klm.cc') diff --git a/.gitignore b/.gitignore index 7e63c4ef..43b48a97 100644 --- a/.gitignore +++ b/.gitignore @@ -155,3 +155,6 @@ training/compute_cllh dtrain/dtrain weights.gz dtrain/test/eval/ +phrasinator/gibbs_train_plm_notables +training/mpi_flex_optimize +utils/phmt diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 28bcb6b9..ed6f731e 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -392,22 +392,3 @@ std::string KLanguageModelFactory::usage(bool params,bool verbose) const { return KLanguageModel::usage(params, verbose); } - switch (m) { - case HASH_PROBING: - return CreateModel(param); - case TRIE_SORTED: - return CreateModel(param); - case ARRAY_TRIE_SORTED: - return CreateModel(param); - case QUANT_TRIE_SORTED: - return CreateModel(param); - case QUANT_ARRAY_TRIE_SORTED: - return CreateModel(param); - default: - UTIL_THROW(util::Exception, "Unrecognized kenlm binary file type " << (unsigned)m); - } -} - -std::string KLanguageModelFactory::usage(bool params,bool verbose) const { - return KLanguageModel::usage(params, verbose); -} diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc index 0a94f7aa..e96b65aa 100644 --- a/dtrain/dtrain.cc +++ b/dtrain/dtrain.cc @@ -20,8 +20,8 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) ("stop_after", po::value()->default_value(0), "stop after X input sentences") ("print_weights", po::value(), "weights to print on each iteration") ("hstreaming", po::value()->zero_tokens(), "run in hadoop streaming mode") - ("learning_rate", po::value()->default_value(0.0005), "learning rate") - ("gamma", po::value()->default_value(0), "gamma for SVM (0 for perceptron)") + ("learning_rate", po::value()->default_value(0.0005), "learning rate") + ("gamma", po::value()->default_value(0), "gamma for SVM (0 for perceptron)") ("tmp", po::value()->default_value("/tmp"), "temp dir to use") ("select_weights", po::value()->default_value("last"), "output 'best' or 'last' weights ('VOID' to throw away)") ("noup", po::value()->zero_tokens(), "do not update weights"); @@ -134,15 +134,14 @@ main(int argc, char** argv) observer->SetScorer(scorer); // init weights - Weights weights; - if (cfg.count("input_weights")) weights.InitFromFile(cfg["input_weights"].as()); - SparseVector lambdas; - weights.InitSparseVector(&lambdas); - vector dense_weights; + vector& dense_weights = decoder.CurrentWeightVector(); + SparseVector lambdas; + if (cfg.count("input_weights")) Weights::InitFromFile(cfg["input_weights"].as(), &dense_weights); + Weights::InitSparseVector(dense_weights, &lambdas); // meta params for perceptron, SVM - double eta = cfg["learning_rate"].as(); - double gamma = cfg["gamma"].as(); + weight_t eta = cfg["learning_rate"].as(); + weight_t gamma = cfg["gamma"].as(); WordID __bias = FD::Convert("__bias"); lambdas.add_value(__bias, 0); @@ -160,7 +159,7 @@ main(int argc, char** argv) grammar_buf_out.open(grammar_buf_fn.c_str()); unsigned in_sz = 999999999; // input index, input size - vector > all_scores; + vector > all_scores; score_t max_score = 0.; unsigned best_it = 0; float overall_time = 0.; @@ -189,6 +188,15 @@ main(int argc, char** argv) } + //LogVal a(2.2); + //LogVal b(2.1); + //cout << a << endl; + //cout << log(a) << endl; + //LogVal c = a - b; + //cout << log(c) << endl; + //exit(0); + + for (unsigned t = 0; t < T; t++) // T epochs { @@ -196,7 +204,8 @@ main(int argc, char** argv) time(&start); igzstream grammar_buf_in; if (t > 0) grammar_buf_in.open(grammar_buf_fn.c_str()); - score_t score_sum = 0., model_sum = 0.; + score_t score_sum = 0.; + score_t model_sum(0); unsigned ii = 0, nup = 0, npairs = 0; if (!quiet) cerr << "Iteration #" << t+1 << " of " << T << "." << endl; @@ -238,10 +247,7 @@ main(int argc, char** argv) if (next || stop) break; // weights - dense_weights.clear(); - weights.InitFromVector(lambdas); - weights.InitVector(&dense_weights); - decoder.SetWeights(dense_weights); + lambdas.init_vector(&dense_weights); // getting input vector in_split; // input: sid\tsrc\tref\tpsg @@ -289,7 +295,8 @@ main(int argc, char** argv) // get (scored) samples vector* samples = observer->GetSamples(); - if (verbose) { + // FIXME + /*if (verbose) { cout << "[ref: '"; if (t > 0) cout << ref_ids_buf[ii]; else cout << ref_ids; @@ -297,7 +304,15 @@ main(int argc, char** argv) cout << _p5 << _np << "1best: " << "'" << (*samples)[0].w << "'" << endl; cout << "SCORE=" << (*samples)[0].score << ",model="<< (*samples)[0].model << endl; cout << "F{" << (*samples)[0].f << "} ]" << endl << endl; - } + }*/ + /*cout << lambdas.get(FD::Convert("PhraseModel_0")) << endl; + cout << (*samples)[0].model << endl; + cout << "1best: "; + for (unsigned u = 0; u < (*samples)[0].w.size(); u++) cout << TD::Convert((*samples)[0].w[u]) << " "; + cout << endl; + cout << (*samples)[0].f << endl; + cout << "___" << endl;*/ + score_sum += (*samples)[0].score; model_sum += (*samples)[0].model; @@ -317,21 +332,21 @@ main(int argc, char** argv) if (!gamma) { // perceptron if (it->first.score - it->second.score < 0) { // rank error - SparseVector dv = it->second.f - it->first.f; + SparseVector dv = it->second.f - it->first.f; dv.add_value(__bias, -1); lambdas.plus_eq_v_times_s(dv, eta); nup++; } } else { // SVM - double rank_error = it->second.score - it->first.score; + score_t rank_error = it->second.score - it->first.score; if (rank_error > 0) { - SparseVector dv = it->second.f - it->first.f; + SparseVector dv = it->second.f - it->first.f; dv.add_value(__bias, -1); lambdas.plus_eq_v_times_s(dv, eta); } // regularization - double margin = it->first.model - it->second.model; + score_t margin = it->first.model - it->second.model; if (rank_error || margin < 1) { lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta); // reg /= #EXAMPLES or #UPDATES ? nup++; @@ -339,6 +354,15 @@ main(int argc, char** argv) } } } + + + vector x; + lambdas.init_vector(&x); + for (int q = 0; q < x.size(); q++) { + if (x[q] < -10 && x[q] != 0) + cout << FD::Convert(q) << " " << x[q] << endl; + } + cout << " --- " << endl; ++ii; @@ -358,7 +382,8 @@ main(int argc, char** argv) // print some stats score_t score_avg = score_sum/(score_t)in_sz; score_t model_avg = model_sum/(score_t)in_sz; - score_t score_diff, model_diff; + score_t score_diff; + score_t model_diff; if (t > 0) { score_diff = score_avg - all_scores[t-1].first; model_diff = model_avg - all_scores[t-1].second; @@ -402,10 +427,10 @@ main(int argc, char** argv) // write weights to file if (select_weights == "best") { - weights.InitFromVector(lambdas); string infix = "dtrain-weights-" + boost::lexical_cast(t); + lambdas.init_vector(&dense_weights); string w_fn = gettmpf(tmp_path, infix, "gz"); - weights.WriteToFile(w_fn, true); + Weights::WriteToFile(w_fn, dense_weights, true); weights_files.push_back(w_fn); } @@ -420,7 +445,7 @@ main(int argc, char** argv) ostream& o = *of.stream(); o.precision(17); o << _np; - for (SparseVector::const_iterator it = lambdas.begin(); it != lambdas.end(); ++it) { + for (SparseVector::const_iterator it = lambdas.begin(); it != lambdas.end(); ++it) { if (it->second == 0) continue; o << FD::Convert(it->first) << '\t' << it->second << endl; } diff --git a/dtrain/dtrain.h b/dtrain/dtrain.h index e98ef470..7c1509e4 100644 --- a/dtrain/dtrain.h +++ b/dtrain/dtrain.h @@ -11,6 +11,8 @@ #include "ksampler.h" #include "pairsampling.h" +#include "filelib.h" + #define DTRAIN_DOTS 100 // when to display a '.' #define DTRAIN_GRAMMAR_DELIM "########EOS########" diff --git a/dtrain/kbestget.h b/dtrain/kbestget.h index d141da60..4aadee7a 100644 --- a/dtrain/kbestget.h +++ b/dtrain/kbestget.h @@ -7,6 +7,7 @@ #include "ff_register.h" #include "decoder.h" #include "weights.h" +#include "logval.h" using namespace std; @@ -106,7 +107,8 @@ struct KBestGetter : public HypSampler ScoredHyp h; h.w = d->yield; h.f = d->feature_values; - h.model = log(d->score); + h.model = d->score; + cout << i << ". "<< h.model << endl; h.rank = i; h.score = scorer_->Score(h.w, *ref_, i); s_.push_back(h); @@ -125,7 +127,7 @@ struct KBestGetter : public HypSampler ScoredHyp h; h.w = d->yield; h.f = d->feature_values; - h.model = log(d->score); + h.model = -1*log(d->score); h.rank = i; h.score = scorer_->Score(h.w, *ref_, i); s_.push_back(h); diff --git a/dtrain/test/example/dtrain.ini b/dtrain/test/example/dtrain.ini index 9b83193a..96bdbf8e 100644 --- a/dtrain/test/example/dtrain.ini +++ b/dtrain/test/example/dtrain.ini @@ -1,14 +1,14 @@ decoder_config=test/example/cdec.ini k=100 N=3 -gamma=0.00001 +gamma=0 #.00001 epochs=2 input=test/example/nc-1k-tabs.gz scorer=stupid_bleu output=- -stop_after=10 +stop_after=5 sample_from=kbest -pair_sampling=108010 -select_weights=best +pair_sampling=all #108010 +select_weights=VOID print_weights=Glue WordPenalty LanguageModel LanguageModel_OOV PhraseModel_0 PhraseModel_1 PhraseModel_2 PhraseModel_3 PhraseModel_4 PassThrough tmp=/tmp diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index eac8aa85..27cada13 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -182,10 +182,6 @@ void SeekPastHeader(int fd, const Parameters ¶ms) { SeekOrThrow(fd, TotalHeaderSize(params.counts.size())); } -void SeekPastHeader(int fd, const Parameters ¶ms) { - SeekOrThrow(fd, TotalHeaderSize(params.counts.size())); -} - uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, std::size_t memory_size, Backing &backing) { const off_t file_size = util::SizeFile(backing.file.get()); // The header is smaller than a page, so we have to map the whole header as well. diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 1bcfe27d..5d8c70db 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -234,19 +234,8 @@ class FindBlanks { return unigrams_[index].prob; } -<<<<<<< HEAD -// Phase to count n-grams, including blanks inserted because they were pruned but have extensions -class JustCount { - public: - template JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, Middle * /*middle*/, Longest &/*longest*/, uint64_t *counts, unsigned char order) - : counts_(counts), longest_counts_(counts + order - 1) {} - - void Unigrams(WordIndex begin, WordIndex end) { - counts_[0] += end - begin; -======= void Unigram(WordIndex /*index*/) { ++counts_[0]; ->>>>>>> upstream/master } void MiddleBlank(const unsigned char order, const WordIndex *indices, unsigned char lower, float prob_basis) { @@ -278,11 +267,7 @@ class JustCount { // Phase to actually write n-grams to the trie. template class WriteEntries { public: -<<<<<<< HEAD - WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) : -======= WriteEntries(RecordReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) : ->>>>>>> upstream/master contexts_(contexts), unigrams_(unigrams), middle_(middle), @@ -330,16 +315,8 @@ template class WriteEntries { SRISucks &sri_; }; -<<<<<<< HEAD -template class RecursiveInsert { - public: - template RecursiveInsert(SortedFileReader *inputs, ContextReader *contexts, UnigramValue *unigrams, MiddleT *middle, LongestT &longest, uint64_t *counts, unsigned char order) : - doing_(contexts, unigrams, middle, longest, counts, order), inputs_(inputs), inputs_end_(inputs + order - 1), order_minus_2_(order - 2) { - } -======= struct Gram { Gram(const WordIndex *in_begin, unsigned char order) : begin(in_begin), end(in_begin + order) {} ->>>>>>> upstream/master const WordIndex *begin, *end; @@ -440,29 +417,6 @@ void SanityCheckCounts(const std::vector &initial, const std::vector void TrainQuantizer(uint8_t order, uint64_t count, SortedFileReader &reader, util::ErsatzProgress &progress, Quant &quant) { - ProbBackoff weights; - std::vector probs, backoffs; - probs.reserve(count); - backoffs.reserve(count); - for (reader.Rewind(); !reader.Ended(); reader.NextHeader()) { - uint64_t entries = reader.ReadCount(); - for (uint64_t c = 0; c < entries; ++c) { - reader.ReadWord(); - reader.ReadWeights(weights); - // kBlankProb isn't added yet. - probs.push_back(weights.prob); - if (weights.backoff != 0.0) backoffs.push_back(weights.backoff); - ++progress; - } -======= template void TrainQuantizer(uint8_t order, uint64_t count, const std::vector &additional, RecordReader &reader, util::ErsatzProgress &progress, Quant &quant) { std::vector probs(additional), backoffs; probs.reserve(count + additional.size()); @@ -472,26 +426,10 @@ template void TrainQuantizer(uint8_t order, uint64_t count, const probs.push_back(weights.prob); if (weights.backoff != 0.0) backoffs.push_back(weights.backoff); ++progress; ->>>>>>> upstream/master } quant.Train(order, probs, backoffs); } -<<<<<<< HEAD -template void TrainProbQuantizer(uint8_t order, uint64_t count, SortedFileReader &reader, util::ErsatzProgress &progress, Quant &quant) { - Prob weights; - std::vector probs, backoffs; - probs.reserve(count); - for (reader.Rewind(); !reader.Ended(); reader.NextHeader()) { - uint64_t entries = reader.ReadCount(); - for (uint64_t c = 0; c < entries; ++c) { - reader.ReadWord(); - reader.ReadWeights(weights); - // kBlankProb isn't added yet. - probs.push_back(weights.prob); - ++progress; - } -======= template void TrainProbQuantizer(uint8_t order, uint64_t count, RecordReader &reader, util::ErsatzProgress &progress, Quant &quant) { std::vector probs, backoffs; probs.reserve(count); @@ -499,18 +437,10 @@ template void TrainProbQuantizer(uint8_t order, uint64_t count, Re const Prob &weights = *reinterpret_cast(reinterpret_cast(reader.Data()) + sizeof(WordIndex) * order); probs.push_back(weights.prob); ++progress; ->>>>>>> upstream/master } quant.TrainProb(order, probs); } -<<<<<<< HEAD -} // namespace - -template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) { - std::vector inputs(counts.size() - 1); - std::vector contexts(counts.size() - 1); -======= void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &contexts, UnigramValue *unigrams) { // Fill unigram probabilities. try { @@ -533,7 +463,6 @@ void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &c template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) { RecordReader inputs[kMaxOrder - 1]; RecordReader contexts[kMaxOrder - 1]; ->>>>>>> upstream/master for (unsigned char i = 2; i <= counts.size(); ++i) { std::stringstream assembled; @@ -548,17 +477,12 @@ template void BuildTrie(const std::string &file_pre SRISucks sri; std::vector fixed_counts(counts.size()); { -<<<<<<< HEAD - RecursiveInsert counter(&*inputs.begin(), &*contexts.begin(), NULL, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); - counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]); -======= std::string temp(file_prefix); temp += "unigrams"; util::scoped_fd unigram_file(util::OpenReadOrThrow(temp.c_str())); util::scoped_memory unigrams; MapRead(util::POPULATE_OR_READ, unigram_file.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams); FindBlanks finder(&*fixed_counts.begin(), counts.size(), reinterpret_cast(unigrams.get()), sri); RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Identifying n-grams omitted by SRI", finder); ->>>>>>> upstream/master } for (const RecordReader *i = inputs; i != inputs + counts.size() - 2; ++i) { if (*i) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs + 2) << "-gram table did not complete reading"); @@ -566,18 +490,6 @@ template void BuildTrie(const std::string &file_pre SanityCheckCounts(counts, fixed_counts); counts = fixed_counts; -<<<<<<< HEAD - out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); - - if (Quant::kTrain) { - util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0)); - for (unsigned char i = 2; i < counts.size(); ++i) { - TrainQuantizer(i, counts[i-1], inputs[i-2], progress, quant); - } - TrainProbQuantizer(counts.size(), counts.back(), inputs[counts.size() - 2], progress, quant); - quant.FinishedLoading(config); - } -======= util::scoped_FILE unigram_file; { std::string name(file_prefix + "unigrams"); @@ -587,7 +499,6 @@ template void BuildTrie(const std::string &file_pre sri.ObtainBackoffs(counts.size(), unigram_file.get(), inputs); out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); ->>>>>>> upstream/master for (unsigned char i = 2; i <= counts.size(); ++i) { inputs[i-2].Rewind(); @@ -610,30 +521,8 @@ template void BuildTrie(const std::string &file_pre } // Fill entries except unigram probabilities. { -<<<<<<< HEAD - RecursiveInsert > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); - inserter.Apply(config.messages, "Building trie", fixed_counts[0]); - } - - // Fill unigram probabilities. - try { - std::string name(file_prefix + "unigrams"); - util::scoped_FILE file(OpenOrThrow(name.c_str(), "r")); - for (WordIndex i = 0; i < counts[0]; ++i) { - ReadOrThrow(file.get(), &unigrams[i].weights, sizeof(ProbBackoff)); - if (contexts[0] && **contexts[0] == i) { - SetExtension(unigrams[i].weights.backoff); - ++contexts[0]; - } - } - RemoveOrThrow(name.c_str()); - } catch (util::Exception &e) { - e << " while re-reading unigram probabilities"; - throw; -======= WriteEntries writer(contexts, unigrams, out.middle_begin_, out.longest, counts.size(), sri); RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Writing trie", writer); ->>>>>>> upstream/master } // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation. @@ -687,17 +576,6 @@ template uint8_t *TrieSearch::Setup } longest.Init(start, quant_.Long(counts.size()), counts[0]); return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); -<<<<<<< HEAD -} - -template void TrieSearch::LoadedBinary() { - unigram.LoadedBinary(); - for (Middle *i = middle_begin_; i != middle_end_; ++i) { - i->LoadedBinary(); - } - longest.LoadedBinary(); -} -======= } template void TrieSearch::LoadedBinary() { @@ -715,7 +593,6 @@ bool IsDirectory(const char *path) { return S_ISDIR(info.st_mode); } } // namespace ->>>>>>> upstream/master template void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { std::string temporary_directory; diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index a1136b6f..20075bb8 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -91,15 +91,6 @@ template bool BitPackedMiddle::Find if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) { return false; } -<<<<<<< HEAD - uint64_t index = at_pointer; - at_pointer *= total_bits_; - at_pointer += word_bits_; - quant_.Read(base_, at_pointer, prob, backoff); - at_pointer += quant_.TotalBits(); - - bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range); -======= pointer = at_pointer; at_pointer *= total_bits_; at_pointer += word_bits_; @@ -108,7 +99,6 @@ template bool BitPackedMiddle::Find at_pointer += quant_.TotalBits(); bhiksha_.ReadNext(base_, at_pointer, pointer, total_bits_, range); ->>>>>>> upstream/master return true; } diff --git a/utils/fdict.h b/utils/fdict.h index 9c8d7cde..f0871b9a 100644 --- a/utils/fdict.h +++ b/utils/fdict.h @@ -33,7 +33,6 @@ struct FD { hash_ = new PerfectHashFunction(cmph_file); #endif } ->>>>>>> upstream/master static inline int NumFeats() { #ifdef HAVE_CMPH if (hash_) return hash_->number_of_keys(); -- cgit v1.2.3