From 205893513c8343fdc55789e427fab4c8b536dc12 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Sun, 26 Jun 2011 18:40:15 -0400 Subject: Quantization --- klm/lm/quantize.hh | 207 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 207 insertions(+) create mode 100644 klm/lm/quantize.hh (limited to 'klm/lm/quantize.hh') diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh new file mode 100644 index 00000000..aae72b34 --- /dev/null +++ b/klm/lm/quantize.hh @@ -0,0 +1,207 @@ +#ifndef LM_QUANTIZE_H__ +#define LM_QUANTIZE_H__ + +#include "lm/binary_format.hh" // for ModelType +#include "lm/blank.hh" +#include "lm/config.hh" +#include "util/bit_packing.hh" + +#include +#include + +#include + +#include + +namespace lm { +namespace ngram { + +class Config; + +/* Store values directly and don't quantize. */ +class DontQuantize { + public: + static const ModelType kModelType = TRIE_SORTED; + static void UpdateConfigFromBinary(int, const std::vector &, Config &) {} + static std::size_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; } + static uint8_t MiddleBits(const Config &/*config*/) { return 63; } + static uint8_t LongestBits(const Config &/*config*/) { return 31; } + + struct Middle { + void Write(void *base, uint64_t bit_offset, float prob, float backoff) const { + util::WriteNonPositiveFloat31(base, bit_offset, prob); + util::WriteFloat32(base, bit_offset + 31, backoff); + } + void Read(const void *base, uint64_t bit_offset, float &prob, float &backoff) const { + prob = util::ReadNonPositiveFloat31(base, bit_offset); + backoff = util::ReadFloat32(base, bit_offset + 31); + } + void ReadBackoff(const void *base, uint64_t bit_offset, float &backoff) const { + backoff = util::ReadFloat32(base, bit_offset + 31); + } + uint8_t TotalBits() const { return 63; } + }; + + struct Longest { + void Write(void *base, uint64_t bit_offset, float prob) const { + util::WriteNonPositiveFloat31(base, bit_offset, prob); + } + void Read(const void *base, uint64_t bit_offset, float &prob) const { + prob = util::ReadNonPositiveFloat31(base, bit_offset); + } + uint8_t TotalBits() const { return 31; } + }; + + DontQuantize() {} + + void SetupMemory(void * /*start*/, const Config & /*config*/) {} + + static const bool kTrain = false; + // These should never be called because kTrain is false. + void Train(uint8_t /*order*/, std::vector &/*prob*/, std::vector &/*backoff*/) {} + void TrainProb(uint8_t, std::vector &/*prob*/) {} + + void FinishedLoading(const Config &) {} + + Middle Mid(uint8_t /*order*/) const { return Middle(); } + Longest Long(uint8_t /*order*/) const { return Longest(); } +}; + +class SeparatelyQuantize { + private: + class Bins { + public: + // Sigh C++ default constructor + Bins() {} + + Bins(uint8_t bits, const float *const begin) : begin_(begin), end_(begin_ + (1ULL << bits)), bits_(bits), mask_((1ULL << bits) - 1) {} + + uint64_t EncodeProb(float value) const { + return(value == kBlankProb ? kBlankProbQuant : Encode(value, 1)); + } + + uint64_t EncodeBackoff(float value) const { + if (value == 0.0) { + return HasExtension(value) ? kExtensionQuant : kNoExtensionQuant; + } + return Encode(value, 2); + } + + float Decode(std::size_t off) const { return begin_[off]; } + + uint8_t Bits() const { return bits_; } + + uint64_t Mask() const { return mask_; } + + private: + uint64_t Encode(float value, size_t reserved) const { + const float *above = std::lower_bound(begin_ + reserved, end_, value); + if (above == begin_ + reserved) return reserved; + if (above == end_) return end_ - begin_ - 1; + return above - begin_ - (value - *(above - 1) < *above - value); + } + + const float *begin_; + const float *end_; + uint8_t bits_; + uint64_t mask_; + }; + + public: + static const ModelType kModelType = QUANT_TRIE_SORTED; + + static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config); + + static std::size_t Size(uint8_t order, const Config &config) { + size_t longest_table = (static_cast(1) << static_cast(config.prob_bits)) * sizeof(float); + size_t middle_table = (static_cast(1) << static_cast(config.backoff_bits)) * sizeof(float) + longest_table; + // unigrams are currently not quantized so no need for a table. + return (order - 2) * middle_table + longest_table + /* for the bit counts and alignment padding) */ 8; + } + + static uint8_t MiddleBits(const Config &config) { return config.prob_bits + config.backoff_bits; } + static uint8_t LongestBits(const Config &config) { return config.prob_bits; } + + class Middle { + public: + Middle(uint8_t prob_bits, const float *prob_begin, uint8_t backoff_bits, const float *backoff_begin) : + total_bits_(prob_bits + backoff_bits), total_mask_((1ULL << total_bits_) - 1), prob_(prob_bits, prob_begin), backoff_(backoff_bits, backoff_begin) {} + + void Write(void *base, uint64_t bit_offset, float prob, float backoff) const { + util::WriteInt57(base, bit_offset, total_bits_, + (prob_.EncodeProb(prob) << backoff_.Bits()) | backoff_.EncodeBackoff(backoff)); + } + + void Read(const void *base, uint64_t bit_offset, float &prob, float &backoff) const { + uint64_t both = util::ReadInt57(base, bit_offset, total_bits_, total_mask_); + prob = prob_.Decode(both >> backoff_.Bits()); + backoff = backoff_.Decode(both & backoff_.Mask()); + } + + void ReadBackoff(const void *base, uint64_t bit_offset, float &backoff) const { + backoff = backoff_.Decode(util::ReadInt25(base, bit_offset, backoff_.Bits(), backoff_.Mask())); + } + + uint8_t TotalBits() const { + return total_bits_; + } + + private: + const uint8_t total_bits_; + const uint64_t total_mask_; + const Bins prob_; + const Bins backoff_; + }; + + class Longest { + public: + // Sigh C++ default constructor + Longest() {} + + Longest(uint8_t prob_bits, const float *prob_begin) : prob_(prob_bits, prob_begin) {} + + void Write(void *base, uint64_t bit_offset, float prob) const { + util::WriteInt25(base, bit_offset, prob_.Bits(), prob_.EncodeProb(prob)); + } + + void Read(const void *base, uint64_t bit_offset, float &prob) const { + prob = prob_.Decode(util::ReadInt25(base, bit_offset, prob_.Bits(), prob_.Mask())); + } + + uint8_t TotalBits() const { return prob_.Bits(); } + + private: + Bins prob_; + }; + + SeparatelyQuantize() {} + + void SetupMemory(void *start, const Config &config); + + static const bool kTrain = true; + // Assumes kBlankProb is removed from prob and 0.0 is removed from backoff. + void Train(uint8_t order, std::vector &prob, std::vector &backoff); + // Train just probabilities (for longest order). + void TrainProb(uint8_t order, std::vector &prob); + + void FinishedLoading(const Config &config); + + Middle Mid(uint8_t order) const { + const float *table = start_ + TableStart(order); + return Middle(prob_bits_, table, backoff_bits_, table + ProbTableLength()); + } + + Longest Long(uint8_t order) const { return Longest(prob_bits_, start_ + TableStart(order)); } + + private: + size_t TableStart(uint8_t order) const { return ((1ULL << prob_bits_) + (1ULL << backoff_bits_)) * static_cast(order - 2); } + size_t ProbTableLength() const { return (1ULL << prob_bits_); } + + float *start_; + uint8_t prob_bits_, backoff_bits_; +}; + +} // namespace ngram +} // namespace lm + +#endif // LM_QUANTIZE_H__ -- cgit v1.2.3 From 2c14cf2218031c29a9884bccf17e9273c71a33b2 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Thu, 18 Aug 2011 12:14:01 +0100 Subject: KenLM update: Bhiksha's trick, simple test for lms without unk, auto-detect binary files instead of requiring them to be specified at runtime. --- decoder/cdec_ff.cc | 5 +- decoder/ff_klm.cc | 38 +++++- decoder/ff_klm.h | 7 +- klm/compile.sh | 2 +- klm/lm/Makefile.am | 1 + klm/lm/bhiksha.cc | 93 +++++++++++++++ klm/lm/bhiksha.hh | 108 +++++++++++++++++ klm/lm/binary_format.cc | 13 ++- klm/lm/binary_format.hh | 9 +- klm/lm/build_binary.cc | 54 ++++++--- klm/lm/config.cc | 1 + klm/lm/config.hh | 5 +- klm/lm/model.cc | 67 ++++++----- klm/lm/model.hh | 12 +- klm/lm/model_test.cc | 73 ++++++++++-- klm/lm/ngram_query.cc | 9 ++ klm/lm/quantize.cc | 1 + klm/lm/quantize.hh | 4 +- klm/lm/read_arpa.cc | 6 +- klm/lm/search_hashed.cc | 2 +- klm/lm/search_hashed.hh | 3 +- klm/lm/search_trie.cc | 45 +++---- klm/lm/search_trie.hh | 20 ++-- klm/lm/test_nounk.arpa | 120 +++++++++++++++++++ klm/lm/trie.cc | 57 ++++----- klm/lm/trie.hh | 24 ++-- klm/lm/vocab.cc | 6 +- klm/lm/vocab.hh | 4 + klm/util/bit_packing.hh | 13 ++- klm/util/murmur_hash.cc | 258 ++++++++++++++++++++--------------------- klm/util/probing_hash_table.hh | 2 +- klm/util/sorted_uniform.hh | 23 +++- 32 files changed, 792 insertions(+), 293 deletions(-) create mode 100644 klm/lm/bhiksha.cc create mode 100644 klm/lm/bhiksha.hh create mode 100644 klm/lm/test_nounk.arpa (limited to 'klm/lm/quantize.hh') diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 3451c9fb..1ef76a05 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -55,10 +55,7 @@ void register_feature_functions() { ff_registry.Register("NgramFeatures", new FFFactory()); ff_registry.Register("RuleNgramFeatures", new FFFactory()); ff_registry.Register("CMR2008ReorderingFeatures", new FFFactory()); - ff_registry.Register("KLanguageModel", new FFFactory >()); - ff_registry.Register("KLanguageModel_Trie", new FFFactory >()); - ff_registry.Register("KLanguageModel_QuantTrie", new FFFactory >()); - ff_registry.Register("KLanguageModel_Probing", new FFFactory >()); + ff_registry.Register("KLanguageModel", new KLanguageModelFactory()); ff_registry.Register("NonLatinCount", new FFFactory); ff_registry.Register("RuleShape", new FFFactory); ff_registry.Register("RelativeSentencePosition", new FFFactory); diff --git a/decoder/ff_klm.cc b/decoder/ff_klm.cc index 9b7fe2d3..24dcb9c3 100644 --- a/decoder/ff_klm.cc +++ b/decoder/ff_klm.cc @@ -9,6 +9,7 @@ #include "stringlib.h" #include "hg.h" #include "tdict.h" +#include "lm/model.hh" #include "lm/enumerate_vocab.hh" using namespace std; @@ -434,8 +435,37 @@ void KLanguageModel::FinalTraversalFeatures(const void* ant_state, features->set_value(oov_fid_, oovs); } -// instantiate templates -template class KLanguageModel; -template class KLanguageModel; -template class KLanguageModel; +template boost::shared_ptr CreateModel(const std::string ¶m) { + KLanguageModel *ret = new KLanguageModel(param); + ret->Init(); + return boost::shared_ptr(ret); +} +boost::shared_ptr KLanguageModelFactory::Create(std::string param) const { + using namespace lm::ngram; + std::string filename, ignored_map; + bool ignored_markers; + std::string ignored_featname; + ParseLMArgs(param, &filename, &ignored_map, &ignored_markers, &ignored_featname); + ModelType m; + if (!RecognizeBinary(filename.c_str(), m)) m = HASH_PROBING; + + switch (m) { + case HASH_PROBING: + return CreateModel(param); + case TRIE_SORTED: + return CreateModel(param); + case ARRAY_TRIE_SORTED: + return CreateModel(param); + case QUANT_TRIE_SORTED: + return CreateModel(param); + case QUANT_ARRAY_TRIE_SORTED: + return CreateModel(param); + default: + UTIL_THROW(util::Exception, "Unrecognized kenlm binary file type " << (unsigned)m); + } +} + +std::string KLanguageModelFactory::usage(bool params,bool verbose) const { + return KLanguageModel::usage(params, verbose); +} diff --git a/decoder/ff_klm.h b/decoder/ff_klm.h index 5eafe8be..6efe50f6 100644 --- a/decoder/ff_klm.h +++ b/decoder/ff_klm.h @@ -4,8 +4,8 @@ #include #include +#include "ff_factory.h" #include "ff.h" -#include "lm/model.hh" template struct KLanguageModelImpl; @@ -34,4 +34,9 @@ class KLanguageModel : public FeatureFunction { KLanguageModelImpl* pimpl_; }; +struct KLanguageModelFactory : public FactoryBase { + FP Create(std::string param) const; + std::string usage(bool params,bool verbose) const; +}; + #endif diff --git a/klm/compile.sh b/klm/compile.sh index 6ca85e1f..abe3473a 100755 --- a/klm/compile.sh +++ b/klm/compile.sh @@ -5,7 +5,7 @@ set -e -for i in util/{bit_packing,ersatz_progress,exception,file_piece,murmur_hash,scoped,mmap} lm/{binary_format,config,lm_exception,model,quantize,read_arpa,search_hashed,search_trie,trie,virtual_interface,vocab}; do +for i in util/{bit_packing,ersatz_progress,exception,file_piece,murmur_hash,scoped,mmap} lm/{bhiksha,binary_format,config,lm_exception,model,quantize,read_arpa,search_hashed,search_trie,trie,virtual_interface,vocab}; do g++ -I. -O3 $CXXFLAGS -c $i.cc -o $i.o done g++ -I. -O3 $CXXFLAGS lm/build_binary.cc {lm,util}/*.o -lz -o build_binary diff --git a/klm/lm/Makefile.am b/klm/lm/Makefile.am index 395494bc..fae6b41a 100644 --- a/klm/lm/Makefile.am +++ b/klm/lm/Makefile.am @@ -12,6 +12,7 @@ build_binary_LDADD = libklm.a ../util/libklm_util.a -lz noinst_LIBRARIES = libklm.a libklm_a_SOURCES = \ + bhiksha.cc \ binary_format.cc \ config.cc \ lm_exception.cc \ diff --git a/klm/lm/bhiksha.cc b/klm/lm/bhiksha.cc new file mode 100644 index 00000000..bf86fd4b --- /dev/null +++ b/klm/lm/bhiksha.cc @@ -0,0 +1,93 @@ +#include "lm/bhiksha.hh" +#include "lm/config.hh" + +#include + +namespace lm { +namespace ngram { +namespace trie { + +DontBhiksha::DontBhiksha(const void * /*base*/, uint64_t /*max_offset*/, uint64_t max_next, const Config &/*config*/) : + next_(util::BitsMask::ByMax(max_next)) {} + +const uint8_t kArrayBhikshaVersion = 0; + +void ArrayBhiksha::UpdateConfigFromBinary(int fd, Config &config) { + uint8_t version; + uint8_t configured_bits; + if (read(fd, &version, 1) != 1 || read(fd, &configured_bits, 1) != 1) { + UTIL_THROW(util::ErrnoException, "Could not read from binary file"); + } + if (version != kArrayBhikshaVersion) UTIL_THROW(FormatLoadException, "This file has sorted array compression version " << (unsigned) version << " but the code expects version " << (unsigned)kArrayBhikshaVersion); + config.pointer_bhiksha_bits = configured_bits; +} + +namespace { + +// Find argmin_{chopped \in [0, RequiredBits(max_next)]} ChoppedDelta(max_offset) +uint8_t ChopBits(uint64_t max_offset, uint64_t max_next, const Config &config) { + uint8_t required = util::RequiredBits(max_next); + uint8_t best_chop = 0; + int64_t lowest_change = std::numeric_limits::max(); + // There are probably faster ways but I don't care because this is only done once per order at construction time. + for (uint8_t chop = 0; chop <= std::min(required, config.pointer_bhiksha_bits); ++chop) { + int64_t change = (max_next >> (required - chop)) * 64 /* table cost in bits */ + - max_offset * static_cast(chop); /* savings in bits*/ + if (change < lowest_change) { + lowest_change = change; + best_chop = chop; + } + } + return best_chop; +} + +std::size_t ArrayCount(uint64_t max_offset, uint64_t max_next, const Config &config) { + uint8_t required = util::RequiredBits(max_next); + uint8_t chopping = ChopBits(max_offset, max_next, config); + return (max_next >> (required - chopping)) + 1 /* we store 0 too */; +} +} // namespace + +std::size_t ArrayBhiksha::Size(uint64_t max_offset, uint64_t max_next, const Config &config) { + return sizeof(uint64_t) * (1 /* header */ + ArrayCount(max_offset, max_next, config)) + 7 /* 8-byte alignment */; +} + +uint8_t ArrayBhiksha::InlineBits(uint64_t max_offset, uint64_t max_next, const Config &config) { + return util::RequiredBits(max_next) - ChopBits(max_offset, max_next, config); +} + +namespace { + +void *AlignTo8(void *from) { + uint8_t *val = reinterpret_cast(from); + std::size_t remainder = reinterpret_cast(val) & 7; + if (!remainder) return val; + return val + 8 - remainder; +} + +} // namespace + +ArrayBhiksha::ArrayBhiksha(void *base, uint64_t max_offset, uint64_t max_next, const Config &config) + : next_inline_(util::BitsMask::ByBits(InlineBits(max_offset, max_next, config))), + offset_begin_(reinterpret_cast(AlignTo8(base)) + 1 /* 8-byte header */), + offset_end_(offset_begin_ + ArrayCount(max_offset, max_next, config)), + write_to_(reinterpret_cast(AlignTo8(base)) + 1 /* 8-byte header */ + 1 /* first entry is 0 */), + original_base_(base) {} + +void ArrayBhiksha::FinishedLoading(const Config &config) { + // *offset_begin_ = 0 but without a const_cast. + *(write_to_ - (write_to_ - offset_begin_)) = 0; + + if (write_to_ != offset_end_) UTIL_THROW(util::Exception, "Did not get all the array entries that were expected."); + + uint8_t *head_write = reinterpret_cast(original_base_); + *(head_write++) = kArrayBhikshaVersion; + *(head_write++) = config.pointer_bhiksha_bits; +} + +void ArrayBhiksha::LoadedBinary() { +} + +} // namespace trie +} // namespace ngram +} // namespace lm diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh new file mode 100644 index 00000000..cfb2b053 --- /dev/null +++ b/klm/lm/bhiksha.hh @@ -0,0 +1,108 @@ +/* Simple implementation of + * @inproceedings{bhikshacompression, + * author={Bhiksha Raj and Ed Whittaker}, + * year={2003}, + * title={Lossless Compression of Language Model Structure and Word Identifiers}, + * booktitle={Proceedings of IEEE International Conference on Acoustics, Speech and Signal Processing}, + * pages={388--391}, + * } + * + * Currently only used for next pointers. + */ + +#include + +#include "lm/binary_format.hh" +#include "lm/trie.hh" +#include "util/bit_packing.hh" +#include "util/sorted_uniform.hh" + +namespace lm { +namespace ngram { +class Config; + +namespace trie { + +class DontBhiksha { + public: + static const ModelType kModelTypeAdd = static_cast(0); + + static void UpdateConfigFromBinary(int /*fd*/, Config &/*config*/) {} + + static std::size_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; } + + static uint8_t InlineBits(uint64_t /*max_offset*/, uint64_t max_next, const Config &/*config*/) { + return util::RequiredBits(max_next); + } + + DontBhiksha(const void *base, uint64_t max_offset, uint64_t max_next, const Config &config); + + void ReadNext(const void *base, uint64_t bit_offset, uint64_t /*index*/, uint8_t total_bits, NodeRange &out) const { + out.begin = util::ReadInt57(base, bit_offset, next_.bits, next_.mask); + out.end = util::ReadInt57(base, bit_offset + total_bits, next_.bits, next_.mask); + //assert(out.end >= out.begin); + } + + void WriteNext(void *base, uint64_t bit_offset, uint64_t /*index*/, uint64_t value) { + util::WriteInt57(base, bit_offset, next_.bits, value); + } + + void FinishedLoading(const Config &/*config*/) {} + + void LoadedBinary() {} + + uint8_t InlineBits() const { return next_.bits; } + + private: + util::BitsMask next_; +}; + +class ArrayBhiksha { + public: + static const ModelType kModelTypeAdd = kArrayAdd; + + static void UpdateConfigFromBinary(int fd, Config &config); + + static std::size_t Size(uint64_t max_offset, uint64_t max_next, const Config &config); + + static uint8_t InlineBits(uint64_t max_offset, uint64_t max_next, const Config &config); + + ArrayBhiksha(void *base, uint64_t max_offset, uint64_t max_value, const Config &config); + + void ReadNext(const void *base, uint64_t bit_offset, uint64_t index, uint8_t total_bits, NodeRange &out) const { + const uint64_t *begin_it = util::BinaryBelow(util::IdentityAccessor(), offset_begin_, offset_end_, index); + const uint64_t *end_it; + for (end_it = begin_it; (end_it < offset_end_) && (*end_it <= index + 1); ++end_it) {} + --end_it; + out.begin = ((begin_it - offset_begin_) << next_inline_.bits) | + util::ReadInt57(base, bit_offset, next_inline_.bits, next_inline_.mask); + out.end = ((end_it - offset_begin_) << next_inline_.bits) | + util::ReadInt57(base, bit_offset + total_bits, next_inline_.bits, next_inline_.mask); + } + + void WriteNext(void *base, uint64_t bit_offset, uint64_t index, uint64_t value) { + uint64_t encode = value >> next_inline_.bits; + for (; write_to_ <= offset_begin_ + encode; ++write_to_) *write_to_ = index; + util::WriteInt57(base, bit_offset, next_inline_.bits, value & next_inline_.mask); + } + + void FinishedLoading(const Config &config); + + void LoadedBinary(); + + uint8_t InlineBits() const { return next_inline_.bits; } + + private: + const util::BitsMask next_inline_; + + const uint64_t *const offset_begin_; + const uint64_t *const offset_end_; + + uint64_t *write_to_; + + void *original_base_; +}; + +} // namespace trie +} // namespace ngram +} // namespace lm diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index 92b1008b..e02e621a 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -40,7 +40,7 @@ struct Sanity { } }; -const char *kModelNames[3] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "bit packed trie"}; +const char *kModelNames[6] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"}; std::size_t Align8(std::size_t in) { std::size_t off = in % 8; @@ -100,16 +100,17 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_ } } -uint8_t *GrowForSearch(const Config &config, std::size_t memory_size, Backing &backing) { +uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing) { + std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad; if (config.write_mmap) { // Grow the file to accomodate the search, using zeros. - if (-1 == ftruncate(backing.file.get(), backing.vocab.size() + memory_size)) - UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (backing.vocab.size() + memory_size) << " failed"); + if (-1 == ftruncate(backing.file.get(), adjusted_vocab + memory_size)) + UTIL_THROW(util::ErrnoException, "ftruncate on " << config.write_mmap << " to " << (adjusted_vocab + memory_size) << " failed"); // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down. off_t page_size = sysconf(_SC_PAGE_SIZE); - off_t alignment_cruft = backing.vocab.size() % page_size; - backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), backing.vocab.size() - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED); + off_t alignment_cruft = adjusted_vocab % page_size; + backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED); return reinterpret_cast(backing.search.get()) + alignment_cruft; } else { diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index 2b32b450..d28cb6c5 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -16,7 +16,12 @@ namespace lm { namespace ngram { -typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3} ModelType; +/* Not the best numbering system, but it grew this way for historical reasons + * and I want to preserve existing binary files. */ +typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3, ARRAY_TRIE_SORTED=4, QUANT_ARRAY_TRIE_SORTED=5} ModelType; + +const static ModelType kQuantAdd = static_cast(QUANT_TRIE_SORTED - TRIE_SORTED); +const static ModelType kArrayAdd = static_cast(ARRAY_TRIE_SORTED - TRIE_SORTED); /*Inspect a file to determine if it is a binary lm. If not, return false. * If so, return true and set recognized to the type. This is the only API in @@ -55,7 +60,7 @@ void AdvanceOrThrow(int fd, off_t off); // Create just enough of a binary file to write vocabulary to it. uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing); // Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin. -uint8_t *GrowForSearch(const Config &config, std::size_t memory_size, Backing &backing); +uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing); // Write header to binary file. This is done last to prevent incomplete files // from loading. diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index 4552c419..b7aee4de 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -15,12 +15,12 @@ namespace ngram { namespace { void Usage(const char *name) { - std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-n] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [type] input.arpa output.mmap\n\n" -"-u sets the default log10 probability for if the ARPA file does not have\n" -"one.\n" + std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-c bits] [type] input.arpa [output.mmap]\n\n" +"-u sets the log10 probability for if the ARPA file does not have one.\n" +" Default is -100. The ARPA file will always take precedence.\n" "-s allows models to be built even if they do not have and .\n" -"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n" -"type is either probing or trie:\n\n" +"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n\n" +"type is either probing or trie. Default is probing.\n\n" "probing uses a probing hash table. It is the fastest but uses the most memory.\n" "-p sets the space multiplier and must be >1.0. The default is 1.5.\n\n" "trie is a straightforward trie with bit-level packing. It uses the least\n" @@ -29,10 +29,11 @@ void Usage(const char *name) { "-t is the temporary directory prefix. Default is the output file name.\n" "-m limits memory use for sorting. Measured in MB. Default is 1024MB.\n" "-q turns quantization on and sets the number of bits (e.g. -q 8).\n" -"-b sets backoff quantization bits. Requires -q and defaults to that value.\n\n" -"See http://kheafield.com/code/kenlm/benchmark/ for data structure benchmarks.\n" -"Passing only an input file will print memory usage of each data structure.\n" -"If the ARPA file does not have , -u sets 's probability; default 0.0.\n"; +"-b sets backoff quantization bits. Requires -q and defaults to that value.\n" +"-a compresses pointers using an array of offsets. The parameter is the\n" +" maximum number of bits encoded by the array. Memory is minimized subject\n" +" to the maximum, so pick 255 to minimize memory.\n\n" +"Get a memory estimate by passing an ARPA file without an output file name.\n"; exit(1); } @@ -63,12 +64,14 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) { std::vector counts; util::FilePiece f(file); lm::ReadARPACounts(f, counts); - std::size_t sizes[3]; + std::size_t sizes[5]; sizes[0] = ProbingModel::Size(counts, config); sizes[1] = TrieModel::Size(counts, config); sizes[2] = QuantTrieModel::Size(counts, config); - std::size_t max_length = *std::max_element(sizes, sizes + 3); - std::size_t min_length = *std::max_element(sizes, sizes + 3); + sizes[3] = ArrayTrieModel::Size(counts, config); + sizes[4] = QuantArrayTrieModel::Size(counts, config); + std::size_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(size_t)); + std::size_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(size_t)); std::size_t divide; char prefix; if (min_length < (1 << 10) * 10) { @@ -91,7 +94,9 @@ void ShowSizes(const char *file, const lm::ngram::Config &config) { std::cout << prefix << "B\n" "probing " << std::setw(length) << (sizes[0] / divide) << " assuming -p " << config.probing_multiplier << "\n" "trie " << std::setw(length) << (sizes[1] / divide) << " without quantization\n" - "trie " << std::setw(length) << (sizes[2] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n"; + "trie " << std::setw(length) << (sizes[2] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n" + "trie " << std::setw(length) << (sizes[3] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " array pointer compression\n" + "trie " << std::setw(length) << (sizes[4] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits<< " array pointer compression and quantization\n"; } void ProbingQuantizationUnsupported() { @@ -106,11 +111,11 @@ void ProbingQuantizationUnsupported() { int main(int argc, char *argv[]) { using namespace lm::ngram; - bool quantize = false, set_backoff_bits = false; try { + bool quantize = false, set_backoff_bits = false, bhiksha = false; lm::ngram::Config config; int opt; - while ((opt = getopt(argc, argv, "siu:p:t:m:q:b:")) != -1) { + while ((opt = getopt(argc, argv, "siu:p:t:m:q:b:a:")) != -1) { switch(opt) { case 'q': config.prob_bits = ParseBitCount(optarg); @@ -121,6 +126,9 @@ int main(int argc, char *argv[]) { config.backoff_bits = ParseBitCount(optarg); set_backoff_bits = true; break; + case 'a': + config.pointer_bhiksha_bits = ParseBitCount(optarg); + bhiksha = true; case 'u': config.unknown_missing_logprob = ParseFloat(optarg); break; @@ -162,9 +170,17 @@ int main(int argc, char *argv[]) { ProbingModel(from_file, config); } else if (!strcmp(model_type, "trie")) { if (quantize) { - QuantTrieModel(from_file, config); + if (bhiksha) { + QuantArrayTrieModel(from_file, config); + } else { + QuantTrieModel(from_file, config); + } } else { - TrieModel(from_file, config); + if (bhiksha) { + ArrayTrieModel(from_file, config); + } else { + TrieModel(from_file, config); + } } } else { Usage(argv[0]); @@ -173,9 +189,9 @@ int main(int argc, char *argv[]) { Usage(argv[0]); } } - catch (std::exception &e) { + catch (const std::exception &e) { std::cerr << e.what() << std::endl; - abort(); + return 1; } return 0; } diff --git a/klm/lm/config.cc b/klm/lm/config.cc index 08e1af5c..297589a4 100644 --- a/klm/lm/config.cc +++ b/klm/lm/config.cc @@ -20,6 +20,7 @@ Config::Config() : include_vocab(true), prob_bits(8), backoff_bits(8), + pointer_bhiksha_bits(22), load_method(util::POPULATE_OR_READ) {} } // namespace ngram diff --git a/klm/lm/config.hh b/klm/lm/config.hh index dcc7cf35..227b8512 100644 --- a/klm/lm/config.hh +++ b/klm/lm/config.hh @@ -73,9 +73,12 @@ struct Config { // Quantization options. Only effective for QuantTrieModel. One value is // reserved for each of prob and backoff, so 2^bits - 1 buckets will be used - // to quantize. + // to quantize (and one of the remaining backoffs will be 0). uint8_t prob_bits, backoff_bits; + // Bhiksha compression (simple form). Only works with trie. + uint8_t pointer_bhiksha_bits; + // ONLY EFFECTIVE WHEN READING BINARY diff --git a/klm/lm/model.cc b/klm/lm/model.cc index a1d10b3d..27e24b1c 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -21,6 +21,8 @@ size_t hash_value(const State &state) { namespace detail { +template const ModelType GenericModel::kModelType = Search::kModelType; + template size_t GenericModel::Size(const std::vector &counts, const Config &config) { return VocabularyT::Size(counts[0], config) + Search::Size(counts, config); } @@ -56,35 +58,40 @@ template void GenericModel void GenericModel::InitializeFromARPA(const char *file, const Config &config) { // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any. util::FilePiece f(backing_.file.release(), file, config.messages); - std::vector counts; - // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_. - ReadARPACounts(f, counts); - - if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile."); - if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model."); - if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0"); - - std::size_t vocab_size = VocabularyT::Size(counts[0], config); - // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs. - vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config); - - if (config.write_mmap) { - WriteWordsWrapper wrap(config.enumerate_vocab); - vocab_.ConfigureEnumerate(&wrap, counts[0]); - search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); - wrap.Write(backing_.file.get()); - } else { - vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]); - search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); - } + try { + std::vector counts; + // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_. + ReadARPACounts(f, counts); + + if (counts.size() > kMaxOrder) UTIL_THROW(FormatLoadException, "This model has order " << counts.size() << ". Edit lm/max_order.hh, set kMaxOrder to at least this value, and recompile."); + if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model."); + if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0"); + + std::size_t vocab_size = VocabularyT::Size(counts[0], config); + // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs. + vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config); + + if (config.write_mmap) { + WriteWordsWrapper wrap(config.enumerate_vocab); + vocab_.ConfigureEnumerate(&wrap, counts[0]); + search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); + wrap.Write(backing_.file.get()); + } else { + vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]); + search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_); + } - if (!vocab_.SawUnk()) { - assert(config.unknown_missing != THROW_UP); - // Default probabilities for unknown. - search_.unigram.Unknown().backoff = 0.0; - search_.unigram.Unknown().prob = config.unknown_missing_logprob; + if (!vocab_.SawUnk()) { + assert(config.unknown_missing != THROW_UP); + // Default probabilities for unknown. + search_.unigram.Unknown().backoff = 0.0; + search_.unigram.Unknown().prob = config.unknown_missing_logprob; + } + FinishFile(config, kModelType, counts, backing_); + } catch (util::Exception &e) { + e << " Byte: " << f.Offset(); + throw; } - FinishFile(config, kModelType, counts, backing_); } template FullScoreReturn GenericModel::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const { @@ -225,8 +232,10 @@ template FullScoreReturn GenericModel; // HASH_PROBING -template class GenericModel, SortedVocabulary>; // TRIE_SORTED -template class GenericModel, SortedVocabulary>; // TRIE_SORTED_QUANT +template class GenericModel, SortedVocabulary>; // TRIE_SORTED +template class GenericModel, SortedVocabulary>; +template class GenericModel, SortedVocabulary>; // TRIE_SORTED_QUANT +template class GenericModel, SortedVocabulary>; } // namespace detail } // namespace ngram diff --git a/klm/lm/model.hh b/klm/lm/model.hh index 1f49a382..21595321 100644 --- a/klm/lm/model.hh +++ b/klm/lm/model.hh @@ -1,6 +1,7 @@ #ifndef LM_MODEL__ #define LM_MODEL__ +#include "lm/bhiksha.hh" #include "lm/binary_format.hh" #include "lm/config.hh" #include "lm/facade.hh" @@ -71,6 +72,9 @@ template class GenericModel : public base::Mod private: typedef base::ModelFacade, State, VocabularyT> P; public: + // This is the model type returned by RecognizeBinary. + static const ModelType kModelType; + /* Get the size of memory that will be mapped given ngram counts. This * does not include small non-mapped control structures, such as this class * itself. @@ -131,8 +135,6 @@ template class GenericModel : public base::Mod Backing &MutableBacking() { return backing_; } - static const ModelType kModelType = Search::kModelType; - Backing backing_; VocabularyT vocab_; @@ -152,9 +154,11 @@ typedef ProbingModel Model; // Smaller implementation. typedef ::lm::ngram::SortedVocabulary SortedVocabulary; -typedef detail::GenericModel, SortedVocabulary> TrieModel; // TRIE_SORTED +typedef detail::GenericModel, SortedVocabulary> TrieModel; // TRIE_SORTED +typedef detail::GenericModel, SortedVocabulary> ArrayTrieModel; -typedef detail::GenericModel, SortedVocabulary> QuantTrieModel; // QUANT_TRIE_SORTED +typedef detail::GenericModel, SortedVocabulary> QuantTrieModel; // QUANT_TRIE_SORTED +typedef detail::GenericModel, SortedVocabulary> QuantArrayTrieModel; } // namespace ngram } // namespace lm diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index 8bf040ff..57c7291c 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -193,6 +193,14 @@ template void Stateless(const M &model) { BOOST_CHECK_EQUAL(static_cast(0), state.history_[0]); } +template void NoUnkCheck(const M &model) { + WordIndex unk_index = 0; + State state; + + FullScoreReturn ret = model.FullScoreForgotState(&unk_index, &unk_index + 1, unk_index, state); + BOOST_CHECK_CLOSE(-100.0, ret.prob, 0.001); +} + template void Everything(const M &m) { Starters(m); Continuation(m); @@ -231,25 +239,38 @@ template void LoadingTest() { Config config; config.arpa_complain = Config::NONE; config.messages = NULL; - ExpectEnumerateVocab enumerate; - config.enumerate_vocab = &enumerate; config.probing_multiplier = 2.0; - ModelT m("test.arpa", config); - enumerate.Check(m.GetVocabulary()); - Everything(m); + { + ExpectEnumerateVocab enumerate; + config.enumerate_vocab = &enumerate; + ModelT m("test.arpa", config); + enumerate.Check(m.GetVocabulary()); + Everything(m); + } + { + ExpectEnumerateVocab enumerate; + config.enumerate_vocab = &enumerate; + ModelT m("test_nounk.arpa", config); + enumerate.Check(m.GetVocabulary()); + NoUnkCheck(m); + } } BOOST_AUTO_TEST_CASE(probing) { LoadingTest(); } - BOOST_AUTO_TEST_CASE(trie) { LoadingTest(); } - -BOOST_AUTO_TEST_CASE(quant) { +BOOST_AUTO_TEST_CASE(quant_trie) { LoadingTest(); } +BOOST_AUTO_TEST_CASE(bhiksha_trie) { + LoadingTest(); +} +BOOST_AUTO_TEST_CASE(quant_bhiksha_trie) { + LoadingTest(); +} template void BinaryTest() { Config config; @@ -267,10 +288,34 @@ template void BinaryTest() { config.write_mmap = NULL; - ModelT binary("test.binary", config); - enumerate.Check(binary.GetVocabulary()); - Everything(binary); + ModelType type; + BOOST_REQUIRE(RecognizeBinary("test.binary", type)); + BOOST_CHECK_EQUAL(ModelT::kModelType, type); + + { + ModelT binary("test.binary", config); + enumerate.Check(binary.GetVocabulary()); + Everything(binary); + } unlink("test.binary"); + + // Now test without . + config.write_mmap = "test_nounk.binary"; + config.messages = NULL; + enumerate.Clear(); + { + ModelT copy_model("test_nounk.arpa", config); + enumerate.Check(copy_model.GetVocabulary()); + enumerate.Clear(); + NoUnkCheck(copy_model); + } + config.write_mmap = NULL; + { + ModelT binary("test_nounk.binary", config); + enumerate.Check(binary.GetVocabulary()); + NoUnkCheck(binary); + } + unlink("test_nounk.binary"); } BOOST_AUTO_TEST_CASE(write_and_read_probing) { @@ -282,6 +327,12 @@ BOOST_AUTO_TEST_CASE(write_and_read_trie) { BOOST_AUTO_TEST_CASE(write_and_read_quant_trie) { BinaryTest(); } +BOOST_AUTO_TEST_CASE(write_and_read_array_trie) { + BinaryTest(); +} +BOOST_AUTO_TEST_CASE(write_and_read_quant_array_trie) { + BinaryTest(); +} } // namespace } // namespace ngram diff --git a/klm/lm/ngram_query.cc b/klm/lm/ngram_query.cc index 9454a6d1..d9db4aa2 100644 --- a/klm/lm/ngram_query.cc +++ b/klm/lm/ngram_query.cc @@ -99,6 +99,15 @@ int main(int argc, char *argv[]) { case lm::ngram::TRIE_SORTED: Query(argv[1], sentence_context); break; + case lm::ngram::QUANT_TRIE_SORTED: + Query(argv[1], sentence_context); + break; + case lm::ngram::ARRAY_TRIE_SORTED: + Query(argv[1], sentence_context); + break; + case lm::ngram::QUANT_ARRAY_TRIE_SORTED: + Query(argv[1], sentence_context); + break; case lm::ngram::HASH_SORTED: default: std::cerr << "Unrecognized kenlm model type " << model_type << std::endl; diff --git a/klm/lm/quantize.cc b/klm/lm/quantize.cc index 4bb6b1b8..fd371cc8 100644 --- a/klm/lm/quantize.cc +++ b/klm/lm/quantize.cc @@ -43,6 +43,7 @@ void SeparatelyQuantize::UpdateConfigFromBinary(int fd, const std::vector(0); static void UpdateConfigFromBinary(int, const std::vector &, Config &) {} static std::size_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; } static uint8_t MiddleBits(const Config &/*config*/) { return 63; } @@ -108,7 +108,7 @@ class SeparatelyQuantize { }; public: - static const ModelType kModelType = QUANT_TRIE_SORTED; + static const ModelType kModelTypeAdd = kQuantAdd; static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config); diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc index 060a97ea..455bc4ba 100644 --- a/klm/lm/read_arpa.cc +++ b/klm/lm/read_arpa.cc @@ -31,15 +31,15 @@ const char kBinaryMagic[] = "mmap lm http://kheafield.com/code"; void ReadARPACounts(util::FilePiece &in, std::vector &number) { number.clear(); StringPiece line; - if (!IsEntirelyWhiteSpace(line = in.ReadLine())) { + while (IsEntirelyWhiteSpace(line = in.ReadLine())) {} + if (line != "\\data\\") { if ((line.size() >= 2) && (line.data()[0] == 0x1f) && (static_cast(line.data()[1]) == 0x8b)) { UTIL_THROW(FormatLoadException, "Looks like a gzip file. If this is an ARPA file, pipe " << in.FileName() << " through zcat. If this already in binary format, you need to decompress it because mmap doesn't work on top of gzip."); } if (static_cast(line.size()) >= strlen(kBinaryMagic) && StringPiece(line.data(), strlen(kBinaryMagic)) == kBinaryMagic) UTIL_THROW(FormatLoadException, "This looks like a binary file but got sent to the ARPA parser. Did you compress the binary file or pass a binary file where only ARPA files are accepted?"); - UTIL_THROW(FormatLoadException, "First line was \"" << line.data() << "\" not blank"); + UTIL_THROW(FormatLoadException, "first non-empty line was \"" << line << "\" not \\data\\."); } - if ((line = in.ReadLine()) != "\\data\\") UTIL_THROW(FormatLoadException, "second line was \"" << line << "\" not \\data\\."); while (!IsEntirelyWhiteSpace(line = in.ReadLine())) { if (line.size() < 6 || strncmp(line.data(), "ngram ", 6)) UTIL_THROW(FormatLoadException, "count line \"" << line << "\"doesn't begin with \"ngram \""); // So strtol doesn't go off the end of line. diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index c56ba7b8..82c53ec8 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -98,7 +98,7 @@ template uint8_t *TemplateHashedSearch template void TemplateHashedSearch::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector &counts, const Config &config, Voc &vocab, Backing &backing) { // TODO: fix sorted. - SetupMemory(GrowForSearch(config, Size(counts, config), backing), counts, config); + SetupMemory(GrowForSearch(config, 0, Size(counts, config), backing), counts, config); PositiveProbWarn warn(config.positive_log_probability); diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index f3acdefc..c62985e4 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -52,12 +52,11 @@ struct HashedSearch { Unigram unigram; - bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const { + void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const { const ProbBackoff &entry = unigram.Lookup(word); prob = entry.prob; backoff = entry.backoff; next = static_cast(word); - return true; } }; diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 91f87f1c..05059ffb 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -1,6 +1,7 @@ /* This is where the trie is built. It's on-disk. */ #include "lm/search_trie.hh" +#include "lm/bhiksha.hh" #include "lm/blank.hh" #include "lm/lm_exception.hh" #include "lm/max_order.hh" @@ -543,8 +544,8 @@ void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector appears. - size_t extra_count = counts[0] + 1; - util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), extra_count * sizeof(ProbBackoff), unigram_file), extra_count * sizeof(ProbBackoff)); + size_t file_out = (counts[0] + 1) * sizeof(ProbBackoff); + util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), file_out, unigram_file), file_out); Read1Grams(f, counts[0], vocab, reinterpret_cast(unigram_mmap.get()), warn); CheckSpecials(config, vocab); if (!vocab.SawUnk()) ++counts[0]; @@ -610,9 +611,9 @@ class JustCount { }; // Phase to actually write n-grams to the trie. -template class WriteEntries { +template class WriteEntries { public: - WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) : + WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) : contexts_(contexts), unigrams_(unigrams), middle_(middle), @@ -649,7 +650,7 @@ template class WriteEntries { private: ContextReader *contexts_; UnigramValue *const unigrams_; - BitPackedMiddle *const middle_; + BitPackedMiddle *const middle_; BitPackedLongest &longest_; BitPacked &bigram_pack_; }; @@ -821,7 +822,7 @@ template void TrainProbQuantizer(uint8_t order, uint64_t count, So } // namespace -template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, Backing &backing) { +template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) { std::vector inputs(counts.size() - 1); std::vector contexts(counts.size() - 1); @@ -846,7 +847,7 @@ template void BuildTrie(const std::string &file_prefix, std::vecto SanityCheckCounts(counts, fixed_counts); counts = fixed_counts; - out.SetupMemory(GrowForSearch(config, TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); + out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); if (Quant::kTrain) { util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0)); @@ -863,7 +864,7 @@ template void BuildTrie(const std::string &file_prefix, std::vecto UnigramValue *unigrams = out.unigram.Raw(); // Fill entries except unigram probabilities. { - RecursiveInsert > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); + RecursiveInsert > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); inserter.Apply(config.messages, "Building trie", fixed_counts[0]); } @@ -901,14 +902,14 @@ template void BuildTrie(const std::string &file_prefix, std::vecto /* Set ending offsets so the last entry will be sized properly */ // Last entry for unigrams was already set. if (out.middle_begin_ != out.middle_end_) { - for (typename TrieSearch::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) { - i->FinishedLoading((i+1)->InsertIndex()); + for (typename TrieSearch::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) { + i->FinishedLoading((i+1)->InsertIndex(), config); } - (out.middle_end_ - 1)->FinishedLoading(out.longest.InsertIndex()); + (out.middle_end_ - 1)->FinishedLoading(out.longest.InsertIndex(), config); } } -template uint8_t *TrieSearch::SetupMemory(uint8_t *start, const std::vector &counts, const Config &config) { +template uint8_t *TrieSearch::SetupMemory(uint8_t *start, const std::vector &counts, const Config &config) { quant_.SetupMemory(start, config); start += Quant::Size(counts.size(), config); unigram.Init(start); @@ -919,22 +920,24 @@ template uint8_t *TrieSearch::SetupMemory(uint8_t *start, c std::vector middle_starts(counts.size() - 2); for (unsigned char i = 2; i < counts.size(); ++i) { middle_starts[i-2] = start; - start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i]); + start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i], config); } - // Crazy backwards thing so we initialize in the correct order. + // Crazy backwards thing so we initialize using pointers to ones that have already been initialized for (unsigned char i = counts.size() - 1; i >= 2; --i) { new (middle_begin_ + i - 2) Middle( middle_starts[i-2], quant_.Mid(i), + counts[i-1], counts[0], counts[i], - (i == counts.size() - 1) ? static_cast(longest) : static_cast(middle_begin_[i-1])); + (i == counts.size() - 1) ? static_cast(longest) : static_cast(middle_begin_[i-1]), + config); } longest.Init(start, quant_.Long(counts.size()), counts[0]); return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); } -template void TrieSearch::LoadedBinary() { +template void TrieSearch::LoadedBinary() { unigram.LoadedBinary(); for (Middle *i = middle_begin_; i != middle_end_; ++i) { i->LoadedBinary(); @@ -942,7 +945,7 @@ template void TrieSearch::LoadedBinary() { longest.LoadedBinary(); } -template void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { +template void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { std::string temporary_directory; if (config.temporary_directory_prefix) { temporary_directory = config.temporary_directory_prefix; @@ -966,14 +969,16 @@ template void TrieSearch::InitializeFromARPA(const char *fi // At least 1MB sorting memory. ARPAToSortedFiles(config, f, counts, std::max(config.building_memory, 1048576), temporary_directory.c_str(), vocab); - BuildTrie(temporary_directory, counts, config, *this, quant_, backing); + BuildTrie(temporary_directory, counts, config, *this, quant_, vocab, backing); if (rmdir(temporary_directory.c_str()) && config.messages) { *config.messages << "Failed to delete " << temporary_directory << std::endl; } } -template class TrieSearch; -template class TrieSearch; +template class TrieSearch; +template class TrieSearch; +template class TrieSearch; +template class TrieSearch; } // namespace trie } // namespace ngram diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 0a52acb5..2f39c09f 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -13,31 +13,33 @@ struct Backing; class SortedVocabulary; namespace trie { -template class TrieSearch; -template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, Backing &backing); +template class TrieSearch; +template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); -template class TrieSearch { +template class TrieSearch { public: typedef NodeRange Node; typedef ::lm::ngram::trie::Unigram Unigram; Unigram unigram; - typedef trie::BitPackedMiddle Middle; + typedef trie::BitPackedMiddle Middle; typedef trie::BitPackedLongest Longest; Longest longest; - static const ModelType kModelType = Quant::kModelType; + static const ModelType kModelType = static_cast(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd); static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config) { Quant::UpdateConfigFromBinary(fd, counts, config); + AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0])); + Bhiksha::UpdateConfigFromBinary(fd, config); } static std::size_t Size(const std::vector &counts, const Config &config) { std::size_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]); for (unsigned char i = 1; i < counts.size() - 1; ++i) { - ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1]); + ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1], config); } return ret + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); } @@ -55,8 +57,8 @@ template class TrieSearch { void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); - bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { - return unigram.Find(word, prob, backoff, node); + void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { + unigram.Find(word, prob, backoff, node); } bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const { @@ -83,7 +85,7 @@ template class TrieSearch { } private: - friend void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, Backing &backing); + friend void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); // Middles are managed manually so we can delay construction and they don't have to be copyable. void FreeMiddles() { diff --git a/klm/lm/test_nounk.arpa b/klm/lm/test_nounk.arpa new file mode 100644 index 00000000..060733d9 --- /dev/null +++ b/klm/lm/test_nounk.arpa @@ -0,0 +1,120 @@ + +\data\ +ngram 1=36 +ngram 2=45 +ngram 3=10 +ngram 4=6 +ngram 5=4 + +\1-grams: +-1.383514 , -0.30103 +-1.139057 . -0.845098 +-1.029493 +-99 -0.4149733 +-1.285941 a -0.69897 +-1.687872 also -0.30103 +-1.687872 beyond -0.30103 +-1.687872 biarritz -0.30103 +-1.687872 call -0.30103 +-1.687872 concerns -0.30103 +-1.687872 consider -0.30103 +-1.687872 considering -0.30103 +-1.687872 for -0.30103 +-1.509559 higher -0.30103 +-1.687872 however -0.30103 +-1.687872 i -0.30103 +-1.687872 immediate -0.30103 +-1.687872 in -0.30103 +-1.687872 is -0.30103 +-1.285941 little -0.69897 +-1.383514 loin -0.30103 +-1.687872 look -0.30103 +-1.285941 looking -0.4771212 +-1.206319 more -0.544068 +-1.509559 on -0.4771212 +-1.509559 screening -0.4771212 +-1.687872 small -0.30103 +-1.687872 the -0.30103 +-1.687872 to -0.30103 +-1.687872 watch -0.30103 +-1.687872 watching -0.30103 +-1.687872 what -0.30103 +-1.687872 would -0.30103 +-3.141592 foo +-2.718281 bar 3.0 +-6.535897 baz -0.0 + +\2-grams: +-0.6925742 , . +-0.7522095 , however +-0.7522095 , is +-0.0602359 . +-0.4846522 looking -0.4771214 +-1.051485 screening +-1.07153 the +-1.07153 watching +-1.07153 what +-0.09132547 a little -0.69897 +-0.2922095 also call +-0.2922095 beyond immediate +-0.2705918 biarritz . +-0.2922095 call for +-0.2922095 concerns in +-0.2922095 consider watch +-0.2922095 considering consider +-0.2834328 for , +-0.5511513 higher more +-0.5845945 higher small +-0.2834328 however , +-0.2922095 i would +-0.2922095 immediate concerns +-0.2922095 in biarritz +-0.2922095 is to +-0.09021038 little more -0.1998621 +-0.7273645 loin , +-0.6925742 loin . +-0.6708385 loin +-0.2922095 look beyond +-0.4638903 looking higher +-0.4638903 looking on -0.4771212 +-0.5136299 more . -0.4771212 +-0.3561665 more loin +-0.1649931 on a -0.4771213 +-0.1649931 screening a -0.4771213 +-0.2705918 small . +-0.287799 the screening +-0.2922095 to look +-0.2622373 watch +-0.2922095 watching considering +-0.2922095 what i +-0.2922095 would also +-2 also would -6 +-6 foo bar + +\3-grams: +-0.01916512 more . +-0.0283603 on a little -0.4771212 +-0.0283603 screening a little -0.4771212 +-0.01660496 a little more -0.09409451 +-0.3488368 looking higher +-0.3488368 looking on -0.4771212 +-0.1892331 little more loin +-0.04835128 looking on a -0.4771212 +-3 also would consider -7 +-7 to look good + +\4-grams: +-0.009249173 looking on a little -0.4771212 +-0.005464747 on a little more -0.4771212 +-0.005464747 screening a little more +-0.1453306 a little more loin +-0.01552657 looking on a -0.4771212 +-4 also would consider higher -8 + +\5-grams: +-0.003061223 looking on a little +-0.001813953 looking on a little more +-0.0432557 on a little more loin +-5 also would consider higher looking + +\end\ diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index 63c2a612..8c536e66 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -1,5 +1,6 @@ #include "lm/trie.hh" +#include "lm/bhiksha.hh" #include "lm/quantize.hh" #include "util/bit_packing.hh" #include "util/exception.hh" @@ -57,16 +58,21 @@ void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits) max_vocab_ = max_vocab; } -template std::size_t BitPackedMiddle::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr) { - return BaseSize(entries, max_vocab, quant_bits + util::RequiredBits(max_ptr)); +template std::size_t BitPackedMiddle::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) { + return Bhiksha::Size(entries + 1, max_ptr, config) + BaseSize(entries, max_vocab, quant_bits + Bhiksha::InlineBits(entries + 1, max_ptr, config)); } -template BitPackedMiddle::BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source) : BitPacked(), quant_(quant), next_bits_(util::RequiredBits(max_next)), next_mask_((1ULL << next_bits_) - 1), next_source_(&next_source) { - if (next_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions."); - BaseInit(base, max_vocab, quant.TotalBits() + next_bits_); +template BitPackedMiddle::BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config) : + BitPacked(), + quant_(quant), + // If the offset of the method changes, also change TrieSearch::UpdateConfigFromBinary. + bhiksha_(base, entries + 1, max_next, config), + next_source_(&next_source) { + if (entries + 1 >= (1ULL << 57) || (max_next >= (1ULL << 57))) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions."); + BaseInit(reinterpret_cast(base) + Bhiksha::Size(entries + 1, max_next, config), max_vocab, quant.TotalBits() + bhiksha_.InlineBits()); } -template void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff) { +template void BitPackedMiddle::Insert(WordIndex word, float prob, float backoff) { assert(word <= word_mask_); uint64_t at_pointer = insert_index_ * total_bits_; @@ -75,47 +81,42 @@ template void BitPackedMiddle::Insert(WordIndex word, float quant_.Write(base_, at_pointer, prob, backoff); at_pointer += quant_.TotalBits(); uint64_t next = next_source_->InsertIndex(); - assert(next <= next_mask_); - util::WriteInt57(base_, at_pointer, next_bits_, next); + bhiksha_.WriteNext(base_, at_pointer, insert_index_, next); ++insert_index_; } -template bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const { +template bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const { uint64_t at_pointer; if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) { return false; } + uint64_t index = at_pointer; at_pointer *= total_bits_; at_pointer += word_bits_; quant_.Read(base_, at_pointer, prob, backoff); at_pointer += quant_.TotalBits(); - range.begin = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); - // Read the next entry's pointer. - at_pointer += total_bits_; - range.end = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); + bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range); + return true; } -template bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const { - uint64_t at_pointer; - if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return false; - at_pointer *= total_bits_; +template bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const { + uint64_t index; + if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, index)) return false; + uint64_t at_pointer = index * total_bits_; at_pointer += word_bits_; quant_.ReadBackoff(base_, at_pointer, backoff); at_pointer += quant_.TotalBits(); - range.begin = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); - // Read the next entry's pointer. - at_pointer += total_bits_; - range.end = util::ReadInt57(base_, at_pointer, next_bits_, next_mask_); + bhiksha_.ReadNext(base_, at_pointer, index, total_bits_, range); return true; } -template void BitPackedMiddle::FinishedLoading(uint64_t next_end) { - assert(next_end <= next_mask_); - uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - next_bits_; - util::WriteInt57(base_, last_next_write, next_bits_, next_end); +template void BitPackedMiddle::FinishedLoading(uint64_t next_end, const Config &config) { + uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - bhiksha_.InlineBits(); + bhiksha_.WriteNext(base_, last_next_write, insert_index_ + 1, next_end); + bhiksha_.FinishedLoading(config); } template void BitPackedLongest::Insert(WordIndex index, float prob) { @@ -135,8 +136,10 @@ template bool BitPackedLongest::Find(WordIndex word, float return true; } -template class BitPackedMiddle; -template class BitPackedMiddle; +template class BitPackedMiddle; +template class BitPackedMiddle; +template class BitPackedMiddle; +template class BitPackedMiddle; template class BitPackedLongest; template class BitPackedLongest; diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index 8fa21aaf..53612064 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -10,6 +10,7 @@ namespace lm { namespace ngram { +class Config; namespace trie { struct NodeRange { @@ -46,13 +47,12 @@ class Unigram { void LoadedBinary() {} - bool Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const { + void Find(WordIndex word, float &prob, float &backoff, NodeRange &next) const { UnigramValue *val = unigram_ + word; prob = val->weights.prob; backoff = val->weights.backoff; next.begin = val->next; next.end = (val+1)->next; - return true; } private: @@ -67,8 +67,6 @@ class BitPacked { return insert_index_; } - void LoadedBinary() {} - protected: static std::size_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits); @@ -83,30 +81,30 @@ class BitPacked { uint64_t insert_index_, max_vocab_; }; -template class BitPackedMiddle : public BitPacked { +template class BitPackedMiddle : public BitPacked { public: - static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next); + static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config); // next_source need not be initialized. - BitPackedMiddle(void *base, const Quant &quant, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source); + BitPackedMiddle(void *base, const Quant &quant, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config); void Insert(WordIndex word, float prob, float backoff); + void FinishedLoading(uint64_t next_end, const Config &config); + + void LoadedBinary() { bhiksha_.LoadedBinary(); } + bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const; bool FindNoProb(WordIndex word, float &backoff, NodeRange &range) const; - void FinishedLoading(uint64_t next_end); - private: Quant quant_; - uint8_t next_bits_; - uint64_t next_mask_; + Bhiksha bhiksha_; const BitPacked *next_source_; }; - template class BitPackedLongest : public BitPacked { public: static std::size_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) { @@ -120,6 +118,8 @@ template class BitPackedLongest : public BitPacked { BaseInit(base, max_vocab, quant_.TotalBits()); } + void LoadedBinary() {} + void Insert(WordIndex word, float prob); bool Find(WordIndex word, float &prob, const NodeRange &node) const; diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 7defd5c1..04979d51 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -37,14 +37,14 @@ WordIndex ReadWords(int fd, EnumerateVocab *enumerate) { WordIndex index = 0; while (true) { ssize_t got = read(fd, &buf[0], kInitialRead); - if (got == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words"); + UTIL_THROW_IF(got == -1, util::ErrnoException, "Reading vocabulary words"); if (got == 0) return index; buf.resize(got); while (buf[buf.size() - 1]) { char next_char; ssize_t ret = read(fd, &next_char, 1); - if (ret == -1) UTIL_THROW(util::ErrnoException, "Reading vocabulary words"); - if (ret == 0) UTIL_THROW(FormatLoadException, "Missing null terminator on a vocab word."); + UTIL_THROW_IF(ret == -1, util::ErrnoException, "Reading vocabulary words"); + UTIL_THROW_IF(ret == 0, FormatLoadException, "Missing null terminator on a vocab word."); buf.push_back(next_char); } // Ok now we have null terminated strings. diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index c92518e4..9d218fff 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -61,6 +61,7 @@ class SortedVocabulary : public base::Vocabulary { } } + // Size for purposes of file writing static size_t Size(std::size_t entries, const Config &config); // Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary. @@ -77,6 +78,9 @@ class SortedVocabulary : public base::Vocabulary { // Reorders reorder_vocab so that the IDs are sorted. void FinishedLoading(ProbBackoff *reorder_vocab); + // Trie stores the correct counts including in the header. If this was previously sized based on a count exluding , padding with 8 bytes will make it the correct size based on a count including . + std::size_t UnkCountChangePadding() const { return SawUnk() ? 0 : sizeof(uint64_t); } + bool SawUnk() const { return saw_unk_; } void LoadedBinary(int fd, EnumerateVocab *to); diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh index b35d80c8..9f47d559 100644 --- a/klm/util/bit_packing.hh +++ b/klm/util/bit_packing.hh @@ -107,9 +107,20 @@ void BitPackingSanity(); uint8_t RequiredBits(uint64_t max_value); struct BitsMask { + static BitsMask ByMax(uint64_t max_value) { + BitsMask ret; + ret.FromMax(max_value); + return ret; + } + static BitsMask ByBits(uint8_t bits) { + BitsMask ret; + ret.bits = bits; + ret.mask = (1ULL << bits) - 1; + return ret; + } void FromMax(uint64_t max_value) { bits = RequiredBits(max_value); - mask = (1 << bits) - 1; + mask = (1ULL << bits) - 1; } uint8_t bits; uint64_t mask; diff --git a/klm/util/murmur_hash.cc b/klm/util/murmur_hash.cc index d58a0727..fec47fd9 100644 --- a/klm/util/murmur_hash.cc +++ b/klm/util/murmur_hash.cc @@ -1,129 +1,129 @@ -/* Downloaded from http://sites.google.com/site/murmurhash/ which says "All - * code is released to the public domain. For business purposes, Murmurhash is - * under the MIT license." - * This is modified from the original: - * ULL tag on 0xc6a4a7935bd1e995 so this will compile on 32-bit. - * length changed to unsigned int. - * placed in namespace util - * add MurmurHashNative - * default option = 0 for seed - */ - -#include "util/murmur_hash.hh" - -namespace util { - -//----------------------------------------------------------------------------- -// MurmurHash2, 64-bit versions, by Austin Appleby - -// The same caveats as 32-bit MurmurHash2 apply here - beware of alignment -// and endian-ness issues if used across multiple platforms. - -// 64-bit hash for 64-bit platforms - -uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed ) -{ - const uint64_t m = 0xc6a4a7935bd1e995ULL; - const int r = 47; - - uint64_t h = seed ^ (len * m); - - const uint64_t * data = (const uint64_t *)key; - const uint64_t * end = data + (len/8); - - while(data != end) - { - uint64_t k = *data++; - - k *= m; - k ^= k >> r; - k *= m; - - h ^= k; - h *= m; - } - - const unsigned char * data2 = (const unsigned char*)data; - - switch(len & 7) - { - case 7: h ^= uint64_t(data2[6]) << 48; - case 6: h ^= uint64_t(data2[5]) << 40; - case 5: h ^= uint64_t(data2[4]) << 32; - case 4: h ^= uint64_t(data2[3]) << 24; - case 3: h ^= uint64_t(data2[2]) << 16; - case 2: h ^= uint64_t(data2[1]) << 8; - case 1: h ^= uint64_t(data2[0]); - h *= m; - }; - - h ^= h >> r; - h *= m; - h ^= h >> r; - - return h; -} - - -// 64-bit hash for 32-bit platforms - -uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed ) -{ - const unsigned int m = 0x5bd1e995; - const int r = 24; - - unsigned int h1 = seed ^ len; - unsigned int h2 = 0; - - const unsigned int * data = (const unsigned int *)key; - - while(len >= 8) - { - unsigned int k1 = *data++; - k1 *= m; k1 ^= k1 >> r; k1 *= m; - h1 *= m; h1 ^= k1; - len -= 4; - - unsigned int k2 = *data++; - k2 *= m; k2 ^= k2 >> r; k2 *= m; - h2 *= m; h2 ^= k2; - len -= 4; - } - - if(len >= 4) - { - unsigned int k1 = *data++; - k1 *= m; k1 ^= k1 >> r; k1 *= m; - h1 *= m; h1 ^= k1; - len -= 4; - } - - switch(len) - { - case 3: h2 ^= ((unsigned char*)data)[2] << 16; - case 2: h2 ^= ((unsigned char*)data)[1] << 8; - case 1: h2 ^= ((unsigned char*)data)[0]; - h2 *= m; - }; - - h1 ^= h2 >> 18; h1 *= m; - h2 ^= h1 >> 22; h2 *= m; - h1 ^= h2 >> 17; h1 *= m; - h2 ^= h1 >> 19; h2 *= m; - - uint64_t h = h1; - - h = (h << 32) | h2; - - return h; -} - -uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) { - if (sizeof(int) == 4) { - return MurmurHash64B(key, len, seed); - } else { - return MurmurHash64A(key, len, seed); - } -} - -} // namespace util +/* Downloaded from http://sites.google.com/site/murmurhash/ which says "All + * code is released to the public domain. For business purposes, Murmurhash is + * under the MIT license." + * This is modified from the original: + * ULL tag on 0xc6a4a7935bd1e995 so this will compile on 32-bit. + * length changed to unsigned int. + * placed in namespace util + * add MurmurHashNative + * default option = 0 for seed + */ + +#include "util/murmur_hash.hh" + +namespace util { + +//----------------------------------------------------------------------------- +// MurmurHash2, 64-bit versions, by Austin Appleby + +// The same caveats as 32-bit MurmurHash2 apply here - beware of alignment +// and endian-ness issues if used across multiple platforms. + +// 64-bit hash for 64-bit platforms + +uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed ) +{ + const uint64_t m = 0xc6a4a7935bd1e995ULL; + const int r = 47; + + uint64_t h = seed ^ (len * m); + + const uint64_t * data = (const uint64_t *)key; + const uint64_t * end = data + (len/8); + + while(data != end) + { + uint64_t k = *data++; + + k *= m; + k ^= k >> r; + k *= m; + + h ^= k; + h *= m; + } + + const unsigned char * data2 = (const unsigned char*)data; + + switch(len & 7) + { + case 7: h ^= uint64_t(data2[6]) << 48; + case 6: h ^= uint64_t(data2[5]) << 40; + case 5: h ^= uint64_t(data2[4]) << 32; + case 4: h ^= uint64_t(data2[3]) << 24; + case 3: h ^= uint64_t(data2[2]) << 16; + case 2: h ^= uint64_t(data2[1]) << 8; + case 1: h ^= uint64_t(data2[0]); + h *= m; + }; + + h ^= h >> r; + h *= m; + h ^= h >> r; + + return h; +} + + +// 64-bit hash for 32-bit platforms + +uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed ) +{ + const unsigned int m = 0x5bd1e995; + const int r = 24; + + unsigned int h1 = seed ^ len; + unsigned int h2 = 0; + + const unsigned int * data = (const unsigned int *)key; + + while(len >= 8) + { + unsigned int k1 = *data++; + k1 *= m; k1 ^= k1 >> r; k1 *= m; + h1 *= m; h1 ^= k1; + len -= 4; + + unsigned int k2 = *data++; + k2 *= m; k2 ^= k2 >> r; k2 *= m; + h2 *= m; h2 ^= k2; + len -= 4; + } + + if(len >= 4) + { + unsigned int k1 = *data++; + k1 *= m; k1 ^= k1 >> r; k1 *= m; + h1 *= m; h1 ^= k1; + len -= 4; + } + + switch(len) + { + case 3: h2 ^= ((unsigned char*)data)[2] << 16; + case 2: h2 ^= ((unsigned char*)data)[1] << 8; + case 1: h2 ^= ((unsigned char*)data)[0]; + h2 *= m; + }; + + h1 ^= h2 >> 18; h1 *= m; + h2 ^= h1 >> 22; h2 *= m; + h1 ^= h2 >> 17; h1 *= m; + h2 ^= h1 >> 19; h2 *= m; + + uint64_t h = h1; + + h = (h << 32) | h2; + + return h; +} + +uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) { + if (sizeof(int) == 4) { + return MurmurHash64B(key, len, seed); + } else { + return MurmurHash64A(key, len, seed); + } +} + +} // namespace util diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh index 00be0ed7..2ec342a6 100644 --- a/klm/util/probing_hash_table.hh +++ b/klm/util/probing_hash_table.hh @@ -57,7 +57,7 @@ template class IdentityAccessor { public: typedef T Key; - T operator()(const uint64_t *in) const { return *in; } + T operator()(const T *in) const { return *in; } }; struct Pivot64 { @@ -101,6 +101,27 @@ template bool SortedUniformFind(co return BoundedSortedUniformFind(accessor, begin, below, end, above, key, out); } +// May return begin - 1. +template Iterator BinaryBelow( + const Accessor &accessor, + Iterator begin, + Iterator end, + const typename Accessor::Key key) { + while (end > begin) { + Iterator pivot(begin + (end - begin) / 2); + typename Accessor::Key mid(accessor(pivot)); + if (mid < key) { + begin = pivot + 1; + } else if (mid > key) { + end = pivot; + } else { + for (++pivot; (pivot < end) && accessor(pivot) == mid; ++pivot) {} + return pivot - 1; + } + } + return begin - 1; +} + // To use this template, you need to define a Pivot function to match Key. template class SortedUniformMap { public: -- cgit v1.2.3 From f111672dd611f78656fceb3df3729a290453ef56 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Wed, 21 Sep 2011 18:23:50 -0400 Subject: Updated kenlm. Includes left state support but not the cdec-side use of it. Updated binary format. --- compound-split/de/charlm.rev.5gm.de.klm | Bin 17376695 -> 17376711 bytes klm/compile.sh | 10 +- klm/lm/bhiksha.hh | 2 +- klm/lm/binary_format.cc | 16 +- klm/lm/binary_format.hh | 20 +- klm/lm/blank.hh | 14 - klm/lm/left.hh | 181 ++++++ klm/lm/left_test.cc | 360 +++++++++++ klm/lm/model.cc | 132 ++-- klm/lm/model.hh | 47 +- klm/lm/model_test.cc | 184 ++++-- klm/lm/model_type.hh | 16 + klm/lm/quantize.cc | 4 +- klm/lm/quantize.hh | 13 +- klm/lm/return.hh | 39 ++ klm/lm/search_hashed.cc | 79 ++- klm/lm/search_hashed.hh | 43 +- klm/lm/search_trie.cc | 1052 ++++++++++--------------------- klm/lm/search_trie.hh | 37 +- klm/lm/trie.cc | 5 +- klm/lm/trie.hh | 10 +- klm/lm/trie_sort.cc | 261 ++++++++ klm/lm/trie_sort.hh | 94 +++ klm/lm/virtual_interface.hh | 26 +- klm/lm/vocab.cc | 44 +- klm/lm/vocab.hh | 10 +- klm/test.sh | 2 +- klm/util/bit_packing.hh | 14 + klm/util/exception.cc | 5 + klm/util/exception.hh | 6 + klm/util/file.cc | 74 +++ klm/util/file.hh | 74 +++ klm/util/file_piece.cc | 18 +- klm/util/file_piece.hh | 14 +- klm/util/mmap.cc | 18 +- klm/util/mmap.hh | 4 +- klm/util/murmur_hash.cc | 258 ++++---- klm/util/scoped.cc | 24 - klm/util/scoped.hh | 58 +- klm/util/sized_iterator.hh | 107 ++++ klm/util/tokenize_piece.hh | 69 ++ 41 files changed, 2261 insertions(+), 1183 deletions(-) create mode 100644 klm/lm/left.hh create mode 100644 klm/lm/left_test.cc create mode 100644 klm/lm/model_type.hh create mode 100644 klm/lm/return.hh create mode 100644 klm/lm/trie_sort.cc create mode 100644 klm/lm/trie_sort.hh create mode 100644 klm/util/file.cc create mode 100644 klm/util/file.hh delete mode 100644 klm/util/scoped.cc create mode 100644 klm/util/sized_iterator.hh create mode 100644 klm/util/tokenize_piece.hh (limited to 'klm/lm/quantize.hh') diff --git a/compound-split/de/charlm.rev.5gm.de.klm b/compound-split/de/charlm.rev.5gm.de.klm index e8d114bd..28d09b54 100644 Binary files a/compound-split/de/charlm.rev.5gm.de.klm and b/compound-split/de/charlm.rev.5gm.de.klm differ diff --git a/klm/compile.sh b/klm/compile.sh index abe3473a..56f2e9b2 100755 --- a/klm/compile.sh +++ b/klm/compile.sh @@ -3,10 +3,12 @@ #If your code uses ICU, edit util/string_piece.hh and uncomment #define USE_ICU #I use zlib by default. If you don't want to depend on zlib, remove #define USE_ZLIB from util/file_piece.hh +#don't need to use if compiling with moses Makefiles already + set -e -for i in util/{bit_packing,ersatz_progress,exception,file_piece,murmur_hash,scoped,mmap} lm/{bhiksha,binary_format,config,lm_exception,model,quantize,read_arpa,search_hashed,search_trie,trie,virtual_interface,vocab}; do - g++ -I. -O3 $CXXFLAGS -c $i.cc -o $i.o +for i in util/{bit_packing,ersatz_progress,exception,file_piece,murmur_hash,file,mmap} lm/{bhiksha,binary_format,config,lm_exception,model,quantize,read_arpa,search_hashed,search_trie,trie,trie_sort,virtual_interface,vocab}; do + g++ -I. -O3 -DNDEBUG $CXXFLAGS -c $i.cc -o $i.o done -g++ -I. -O3 $CXXFLAGS lm/build_binary.cc {lm,util}/*.o -lz -o build_binary -g++ -I. -O3 $CXXFLAGS lm/ngram_query.cc {lm,util}/*.o -lz -o query +g++ -I. -O3 -DNDEBUG $CXXFLAGS lm/build_binary.cc {lm,util}/*.o -lz -o build_binary +g++ -I. -O3 -DNDEBUG $CXXFLAGS lm/ngram_query.cc {lm,util}/*.o -lz -o query diff --git a/klm/lm/bhiksha.hh b/klm/lm/bhiksha.hh index cfb2b053..ff7fe452 100644 --- a/klm/lm/bhiksha.hh +++ b/klm/lm/bhiksha.hh @@ -12,7 +12,7 @@ #include -#include "lm/binary_format.hh" +#include "lm/model_type.hh" #include "lm/trie.hh" #include "util/bit_packing.hh" #include "util/sorted_uniform.hh" diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index e02e621a..27cada13 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -19,10 +19,10 @@ namespace lm { namespace ngram { namespace { const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version"; -const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 4\n\0"; +const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n\0"; // This must be shorter than kMagicBytes and indicates an incomplete binary file (i.e. build failed). const char kMagicIncomplete[] = "mmap lm http://kheafield.com/code incomplete\n"; -const long int kMagicVersion = 4; +const long int kMagicVersion = 5; // Test values. struct Sanity { @@ -42,12 +42,6 @@ struct Sanity { const char *kModelNames[6] = {"hashed n-grams with probing", "hashed n-grams with sorted uniform find", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"}; -std::size_t Align8(std::size_t in) { - std::size_t off = in % 8; - if (!off) return in; - return in + 8 - off; -} - std::size_t TotalHeaderSize(unsigned char order) { return Align8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order); } @@ -119,7 +113,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t } } -void FinishFile(const Config &config, ModelType model_type, const std::vector &counts, Backing &backing) { +void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector &counts, Backing &backing) { if (config.write_mmap) { if (msync(backing.search.get(), backing.search.size(), MS_SYNC) || msync(backing.vocab.get(), backing.vocab.size(), MS_SYNC)) UTIL_THROW(util::ErrnoException, "msync failed for " << config.write_mmap); @@ -130,6 +124,7 @@ void FinishFile(const Config &config, ModelType model_type, const std::vector(params.fixed.model_type) >= (sizeof(kModelNames) / sizeof(const char *))) UTIL_THROW(FormatLoadException, "The binary file claims to be model type " << static_cast(params.fixed.model_type) << " but this is not implemented for in this inference code."); UTIL_THROW(FormatLoadException, "The binary file was built for " << kModelNames[params.fixed.model_type] << " but the inference code is trying to load " << kModelNames[model_type]); } + UTIL_THROW_IF(search_version != params.fixed.search_version, FormatLoadException, "The binary file has " << kModelNames[params.fixed.model_type] << " version " << params.fixed.search_version << " but this code expects " << kModelNames[params.fixed.model_type] << " version " << search_version); } void SeekPastHeader(int fd, const Parameters ¶ms) { diff --git a/klm/lm/binary_format.hh b/klm/lm/binary_format.hh index d28cb6c5..e9df0892 100644 --- a/klm/lm/binary_format.hh +++ b/klm/lm/binary_format.hh @@ -2,6 +2,7 @@ #define LM_BINARY_FORMAT__ #include "lm/config.hh" +#include "lm/model_type.hh" #include "lm/read_arpa.hh" #include "util/file_piece.hh" @@ -16,13 +17,6 @@ namespace lm { namespace ngram { -/* Not the best numbering system, but it grew this way for historical reasons - * and I want to preserve existing binary files. */ -typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3, ARRAY_TRIE_SORTED=4, QUANT_ARRAY_TRIE_SORTED=5} ModelType; - -const static ModelType kQuantAdd = static_cast(QUANT_TRIE_SORTED - TRIE_SORTED); -const static ModelType kArrayAdd = static_cast(ARRAY_TRIE_SORTED - TRIE_SORTED); - /*Inspect a file to determine if it is a binary lm. If not, return false. * If so, return true and set recognized to the type. This is the only API in * this header designed for use by decoder authors. @@ -36,8 +30,14 @@ struct FixedWidthParameters { ModelType model_type; // Does the end of the file have the actual strings in the vocabulary? bool has_vocabulary; + unsigned int search_version; }; +inline std::size_t Align8(std::size_t in) { + std::size_t off = in % 8; + return off ? (in + 8 - off) : in; +} + // Parameters stored in the header of a binary file. struct Parameters { FixedWidthParameters fixed; @@ -64,7 +64,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t // Write header to binary file. This is done last to prevent incomplete files // from loading. -void FinishFile(const Config &config, ModelType model_type, const std::vector &counts, Backing &backing); +void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector &counts, Backing &backing); namespace detail { @@ -72,7 +72,7 @@ bool IsBinaryFormat(int fd); void ReadHeader(int fd, Parameters ¶ms); -void MatchCheck(ModelType model_type, const Parameters ¶ms); +void MatchCheck(ModelType model_type, unsigned int search_version, const Parameters ¶ms); void SeekPastHeader(int fd, const Parameters ¶ms); @@ -90,7 +90,7 @@ template void LoadLM(const char *file, const Config &config, To &to) if (detail::IsBinaryFormat(backing.file.get())) { Parameters params; detail::ReadHeader(backing.file.get(), params); - detail::MatchCheck(To::kModelType, params); + detail::MatchCheck(To::kModelType, To::kVersion, params); // Replace the run-time configured probing_multiplier with the one in the file. Config new_config(config); new_config.probing_multiplier = params.fixed.probing_multiplier; diff --git a/klm/lm/blank.hh b/klm/lm/blank.hh index 162411a9..2fb64cd0 100644 --- a/klm/lm/blank.hh +++ b/klm/lm/blank.hh @@ -38,20 +38,6 @@ inline bool HasExtension(const float &backoff) { return compare.i != interpret.i; } -/* Suppose "foo bar baz quux" appears in the ARPA but not "bar baz quux" or - * "baz quux" (because they were pruned). 1.2% of n-grams generated by SRI - * with default settings on the benchmark data set are like this. Since search - * proceeds by finding "quux", "baz quux", "bar baz quux", and finally - * "foo bar baz quux" and the trie needs pointer nodes anyway, blanks are - * inserted. The blanks have probability kBlankProb and backoff kBlankBackoff. - * A blank is recognized by kBlankProb in the probability field; kBlankBackoff - * must be 0 so that inference asseses zero backoff from these blanks. - */ -const float kBlankProb = -std::numeric_limits::infinity(); -const float kBlankBackoff = kNoExtensionBackoff; -const uint32_t kBlankProbQuant = 0; -const uint32_t kBlankBackoffQuant = 0; - } // namespace ngram } // namespace lm #endif // LM_BLANK__ diff --git a/klm/lm/left.hh b/klm/lm/left.hh new file mode 100644 index 00000000..df69e97a --- /dev/null +++ b/klm/lm/left.hh @@ -0,0 +1,181 @@ +#ifndef LM_LEFT__ +#define LM_LEFT__ + +#include "lm/max_order.hh" +#include "lm/model.hh" +#include "lm/return.hh" + +#include + +namespace lm { +namespace ngram { + +struct Left { + bool operator==(const Left &other) const { + return + (length == other.length) && + pointers[length - 1] == other.pointers[length - 1]; + } + + int Compare(const Left &other) const { + if (length != other.length) { + return (int)length - (int)other.length; + } + if (pointers[length - 1] > other.pointers[length - 1]) return 1; + if (pointers[length - 1] < other.pointers[length - 1]) return -1; + return 0; + } + + uint64_t pointers[kMaxOrder - 1]; + unsigned char length; +}; + +struct ChartState { + bool operator==(const ChartState &other) { + return (left == other.left) && (right == other.right) && (full == other.full); + } + + int Compare(const ChartState &other) const { + int lres = left.Compare(other.left); + if (lres) return lres; + int rres = right.Compare(other.right); + if (rres) return rres; + return (int)full - (int)other.full; + } + + Left left; + State right; + bool full; +}; + +template class RuleScore { + public: + explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(out), left_done_(false), left_write_(out.left.pointers), prob_(0.0) { + out.left.length = 0; + out.right.length = 0; + } + + void BeginSentence() { + out_.right = model_.BeginSentenceState(); + // out_.left is empty. + left_done_ = true; + } + + void Terminal(WordIndex word) { + State copy(out_.right); + FullScoreReturn ret = model_.FullScore(copy, word, out_.right); + ProcessRet(ret); + if (out_.right.length != copy.length + 1) left_done_ = true; + } + + // Faster version of NonTerminal for the case where the rule begins with a non-terminal. + void BeginNonTerminal(const ChartState &in, float prob) { + prob_ = prob; + out_ = in; + left_write_ = out_.left.pointers + out_.left.length; + left_done_ = in.full; + } + + void NonTerminal(const ChartState &in, float prob) { + prob_ += prob; + + if (!in.left.length) { + if (in.full) { + for (const float *i = out_.right.backoff; i < out_.right.backoff + out_.right.length; ++i) prob_ += *i; + left_done_ = true; + out_.right = in.right; + } + return; + } + + if (!out_.right.length) { + out_.right = in.right; + if (left_done_) return; + if (left_write_ != out_.left.pointers) { + left_done_ = true; + } else { + out_.left = in.left; + left_write_ = out_.left.pointers + in.left.length; + left_done_ = in.full; + } + return; + } + + float backoffs[kMaxOrder - 1], backoffs2[kMaxOrder - 1]; + float *back = backoffs, *back2 = backoffs2; + unsigned char next_use; + FullScoreReturn ret; + ProcessRet(ret = model_.ExtendLeft(out_.right.words, out_.right.words + out_.right.length, out_.right.backoff, in.left.pointers[0], 1, back, next_use)); + if (!next_use) { + left_done_ = true; + out_.right = in.right; + return; + } + unsigned char extend_length = 2; + for (const uint64_t *i = in.left.pointers + 1; i < in.left.pointers + in.left.length; ++i, ++extend_length) { + ProcessRet(ret = model_.ExtendLeft(out_.right.words, out_.right.words + next_use, back, *i, extend_length, back2, next_use)); + if (!next_use) { + left_done_ = true; + out_.right = in.right; + return; + } + std::swap(back, back2); + } + + if (in.full) { + for (const float *i = back; i != back + next_use; ++i) prob_ += *i; + left_done_ = true; + out_.right = in.right; + return; + } + + // Right state was minimized, so it's already independent of the new words to the left. + if (in.right.length < in.left.length) { + out_.right = in.right; + return; + } + + // Shift exisiting words down. + for (WordIndex *i = out_.right.words + next_use - 1; i >= out_.right.words; --i) { + *(i + in.right.length) = *i; + } + // Add words from in.right. + std::copy(in.right.words, in.right.words + in.right.length, out_.right.words); + // Assemble backoff composed on the existing state's backoff followed by the new state's backoff. + std::copy(in.right.backoff, in.right.backoff + in.right.length, out_.right.backoff); + std::copy(back, back + next_use, out_.right.backoff + in.right.length); + out_.right.length = in.right.length + next_use; + } + + float Finish() { + out_.left.length = left_write_ - out_.left.pointers; + out_.full = left_done_; + return prob_; + } + + private: + void ProcessRet(const FullScoreReturn &ret) { + prob_ += ret.prob; + if (left_done_) return; + if (ret.independent_left) { + left_done_ = true; + return; + } + *(left_write_++) = ret.extend_left; + } + + const M &model_; + + ChartState &out_; + + bool left_done_; + + uint64_t *left_write_; + + float prob_; +}; + +} // namespace ngram +} // namespace lm + +#endif // LM_LEFT__ diff --git a/klm/lm/left_test.cc b/klm/lm/left_test.cc new file mode 100644 index 00000000..8bb91cb3 --- /dev/null +++ b/klm/lm/left_test.cc @@ -0,0 +1,360 @@ +#include "lm/left.hh" +#include "lm/model.hh" + +#include "util/tokenize_piece.hh" + +#include + +#define BOOST_TEST_MODULE LeftTest +#include +#include + +namespace lm { +namespace ngram { +namespace { + +#define Term(word) score.Terminal(m.GetVocabulary().Index(word)); +#define VCheck(word, value) BOOST_CHECK_EQUAL(m.GetVocabulary().Index(word), value); + +template void Short(const M &m) { + ChartState base; + { + RuleScore score(m, base); + Term("more"); + Term("loin"); + BOOST_CHECK_CLOSE(-1.206319 - 0.3561665, score.Finish(), 0.001); + } + BOOST_CHECK(base.full); + BOOST_CHECK_EQUAL(2, base.left.length); + BOOST_CHECK_EQUAL(1, base.right.length); + VCheck("loin", base.right.words[0]); + + ChartState more_left; + { + RuleScore score(m, more_left); + Term("little"); + score.NonTerminal(base, -1.206319 - 0.3561665); + // p(little more loin | null context) + BOOST_CHECK_CLOSE(-1.56538, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(3, more_left.left.length); + BOOST_CHECK_EQUAL(1, more_left.right.length); + VCheck("loin", more_left.right.words[0]); + BOOST_CHECK(more_left.full); + + ChartState shorter; + { + RuleScore score(m, shorter); + Term("to"); + score.NonTerminal(base, -1.206319 - 0.3561665); + BOOST_CHECK_CLOSE(-0.30103 - 1.687872 - 1.206319 - 0.3561665, score.Finish(), 0.01); + } + BOOST_CHECK_EQUAL(1, shorter.left.length); + BOOST_CHECK_EQUAL(1, shorter.right.length); + VCheck("loin", shorter.right.words[0]); + BOOST_CHECK(shorter.full); +} + +template void Charge(const M &m) { + ChartState base; + { + RuleScore score(m, base); + Term("on"); + Term("more"); + BOOST_CHECK_CLOSE(-1.509559 -0.4771212 -1.206319, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(1, base.left.length); + BOOST_CHECK_EQUAL(1, base.right.length); + VCheck("more", base.right.words[0]); + BOOST_CHECK(base.full); + + ChartState extend; + { + RuleScore score(m, extend); + Term("looking"); + score.NonTerminal(base, -1.509559 -0.4771212 -1.206319); + BOOST_CHECK_CLOSE(-3.91039, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(2, extend.left.length); + BOOST_CHECK_EQUAL(1, extend.right.length); + VCheck("more", extend.right.words[0]); + BOOST_CHECK(extend.full); + + ChartState tobos; + { + RuleScore score(m, tobos); + score.BeginSentence(); + score.NonTerminal(extend, -3.91039); + BOOST_CHECK_CLOSE(-3.471169, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(0, tobos.left.length); + BOOST_CHECK_EQUAL(1, tobos.right.length); +} + +template float LeftToRight(const M &m, const std::vector &words) { + float ret = 0.0; + State right = m.NullContextState(); + for (std::vector::const_iterator i = words.begin(); i != words.end(); ++i) { + State copy(right); + ret += m.Score(copy, *i, right); + } + return ret; +} + +template float RightToLeft(const M &m, const std::vector &words) { + float ret = 0.0; + ChartState state; + state.left.length = 0; + state.right.length = 0; + state.full = false; + for (std::vector::const_reverse_iterator i = words.rbegin(); i != words.rend(); ++i) { + ChartState copy(state); + RuleScore score(m, state); + score.Terminal(*i); + score.NonTerminal(copy, ret); + ret = score.Finish(); + } + return ret; +} + +template float TreeMiddle(const M &m, const std::vector &words) { + std::vector > states(words.size()); + for (unsigned int i = 0; i < words.size(); ++i) { + RuleScore score(m, states[i].first); + score.Terminal(words[i]); + states[i].second = score.Finish(); + } + while (states.size() > 1) { + std::vector > upper((states.size() + 1) / 2); + for (unsigned int i = 0; i < states.size() / 2; ++i) { + RuleScore score(m, upper[i].first); + score.NonTerminal(states[i*2].first, states[i*2].second); + score.NonTerminal(states[i*2+1].first, states[i*2+1].second); + upper[i].second = score.Finish(); + } + if (states.size() % 2) { + upper.back() = states.back(); + } + std::swap(states, upper); + } + return states.empty() ? 0 : states.back().second; +} + +template void LookupVocab(const M &m, const StringPiece &str, std::vector &out) { + out.clear(); + for (util::PieceIterator<' '> i(str); i; ++i) { + out.push_back(m.GetVocabulary().Index(*i)); + } +} + +#define TEXT_TEST(str) \ +{ \ + std::vector words; \ + LookupVocab(m, str, words); \ + float expect = LeftToRight(m, words); \ + BOOST_CHECK_CLOSE(expect, RightToLeft(m, words), 0.001); \ + BOOST_CHECK_CLOSE(expect, TreeMiddle(m, words), 0.001); \ +} + +// Build sentences, or parts thereof, from right to left. +template void GrowBig(const M &m) { + TEXT_TEST("in biarritz watching considering looking . on a little more loin also would consider higher to look good unknown the screening foo bar , unknown however unknown "); + TEXT_TEST("on a little more loin also would consider higher to look good unknown the screening foo bar , unknown however unknown "); + TEXT_TEST("on a little more loin also would consider higher to look good"); + TEXT_TEST("more loin also would consider higher to look good"); + TEXT_TEST("more loin also would consider higher to look"); + TEXT_TEST("also would consider higher to look"); + TEXT_TEST("also would consider higher"); + TEXT_TEST("would consider higher to look"); + TEXT_TEST("consider higher to look"); + TEXT_TEST("consider higher to"); + TEXT_TEST("consider higher"); +} + +template void AlsoWouldConsiderHigher(const M &m) { + ChartState also; + { + RuleScore score(m, also); + score.Terminal(m.GetVocabulary().Index("also")); + BOOST_CHECK_CLOSE(-1.687872, score.Finish(), 0.001); + } + ChartState would; + { + RuleScore score(m, would); + score.Terminal(m.GetVocabulary().Index("would")); + BOOST_CHECK_CLOSE(-1.687872, score.Finish(), 0.001); + } + ChartState combine_also_would; + { + RuleScore score(m, combine_also_would); + score.NonTerminal(also, -1.687872); + score.NonTerminal(would, -1.687872); + BOOST_CHECK_CLOSE(-1.687872 - 2.0, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(2, combine_also_would.right.length); + + ChartState also_would; + { + RuleScore score(m, also_would); + score.Terminal(m.GetVocabulary().Index("also")); + score.Terminal(m.GetVocabulary().Index("would")); + BOOST_CHECK_CLOSE(-1.687872 - 2.0, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(2, also_would.right.length); + + ChartState consider; + { + RuleScore score(m, consider); + score.Terminal(m.GetVocabulary().Index("consider")); + BOOST_CHECK_CLOSE(-1.687872, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(1, consider.left.length); + BOOST_CHECK_EQUAL(1, consider.right.length); + BOOST_CHECK(!consider.full); + + ChartState higher; + float higher_score; + { + RuleScore score(m, higher); + score.Terminal(m.GetVocabulary().Index("higher")); + higher_score = score.Finish(); + } + BOOST_CHECK_CLOSE(-1.509559, higher_score, 0.001); + BOOST_CHECK_EQUAL(1, higher.left.length); + BOOST_CHECK_EQUAL(1, higher.right.length); + BOOST_CHECK(!higher.full); + VCheck("higher", higher.right.words[0]); + BOOST_CHECK_CLOSE(-0.30103, higher.right.backoff[0], 0.001); + + ChartState consider_higher; + { + RuleScore score(m, consider_higher); + score.NonTerminal(consider, -1.687872); + score.NonTerminal(higher, higher_score); + BOOST_CHECK_CLOSE(-1.509559 - 1.687872 - 0.30103, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(2, consider_higher.left.length); + BOOST_CHECK(!consider_higher.full); + + ChartState full; + { + RuleScore score(m, full); + score.NonTerminal(combine_also_would, -1.687872 - 2.0); + score.NonTerminal(consider_higher, -1.509559 - 1.687872 - 0.30103); + BOOST_CHECK_CLOSE(-10.6879, score.Finish(), 0.001); + } + BOOST_CHECK_EQUAL(4, full.right.length); +} + +template void GrowSmall(const M &m) { + TEXT_TEST("in biarritz watching considering looking . "); + TEXT_TEST("in biarritz watching considering looking ."); + TEXT_TEST("in biarritz"); +} + +#define CHECK_SCORE(str, val) \ +{ \ + float got = val; \ + std::vector indices; \ + LookupVocab(m, str, indices); \ + BOOST_CHECK_CLOSE(LeftToRight(m, indices), got, 0.001); \ +} + +template void FullGrow(const M &m) { + std::vector words; + LookupVocab(m, "in biarritz watching considering looking . ", words); + + ChartState lexical[7]; + float lexical_scores[7]; + for (unsigned int i = 0; i < 7; ++i) { + RuleScore score(m, lexical[i]); + score.Terminal(words[i]); + lexical_scores[i] = score.Finish(); + } + CHECK_SCORE("in", lexical_scores[0]); + CHECK_SCORE("biarritz", lexical_scores[1]); + CHECK_SCORE("watching", lexical_scores[2]); + CHECK_SCORE("", lexical_scores[6]); + + ChartState l1[4]; + float l1_scores[4]; + { + RuleScore score(m, l1[0]); + score.NonTerminal(lexical[0], lexical_scores[0]); + score.NonTerminal(lexical[1], lexical_scores[1]); + CHECK_SCORE("in biarritz", l1_scores[0] = score.Finish()); + } + { + RuleScore score(m, l1[1]); + score.NonTerminal(lexical[2], lexical_scores[2]); + score.NonTerminal(lexical[3], lexical_scores[3]); + CHECK_SCORE("watching considering", l1_scores[1] = score.Finish()); + } + { + RuleScore score(m, l1[2]); + score.NonTerminal(lexical[4], lexical_scores[4]); + score.NonTerminal(lexical[5], lexical_scores[5]); + CHECK_SCORE("looking .", l1_scores[2] = score.Finish()); + } + BOOST_CHECK_EQUAL(l1[2].left.length, 1); + l1[3] = lexical[6]; + l1_scores[3] = lexical_scores[6]; + + ChartState l2[2]; + float l2_scores[2]; + { + RuleScore score(m, l2[0]); + score.NonTerminal(l1[0], l1_scores[0]); + score.NonTerminal(l1[1], l1_scores[1]); + CHECK_SCORE("in biarritz watching considering", l2_scores[0] = score.Finish()); + } + { + RuleScore score(m, l2[1]); + score.NonTerminal(l1[2], l1_scores[2]); + score.NonTerminal(l1[3], l1_scores[3]); + CHECK_SCORE("looking . ", l2_scores[1] = score.Finish()); + } + BOOST_CHECK_EQUAL(l2[1].left.length, 1); + BOOST_CHECK(l2[1].full); + + ChartState top; + { + RuleScore score(m, top); + score.NonTerminal(l2[0], l2_scores[0]); + score.NonTerminal(l2[1], l2_scores[1]); + CHECK_SCORE("in biarritz watching considering looking . ", score.Finish()); + } +} + +template void Everything() { + Config config; + config.messages = NULL; + M m("test.arpa", config); + + Short(m); + Charge(m); + GrowBig(m); + AlsoWouldConsiderHigher(m); + GrowSmall(m); + FullGrow(m); +} + +BOOST_AUTO_TEST_CASE(ProbingAll) { + Everything(); +} +BOOST_AUTO_TEST_CASE(TrieAll) { + Everything(); +} +BOOST_AUTO_TEST_CASE(QuantTrieAll) { + Everything(); +} +BOOST_AUTO_TEST_CASE(ArrayQuantTrieAll) { + Everything(); +} +BOOST_AUTO_TEST_CASE(ArrayTrieAll) { + Everything(); +} + +} // namespace +} // namespace ngram +} // namespace lm diff --git a/klm/lm/model.cc b/klm/lm/model.cc index 27e24b1c..ca581d8a 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -16,7 +16,7 @@ namespace lm { namespace ngram { size_t hash_value(const State &state) { - return util::MurmurHashNative(state.history_, sizeof(WordIndex) * state.valid_length_); + return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length); } namespace detail { @@ -41,11 +41,11 @@ template GenericModel::Ge // g++ prints warnings unless these are fully initialized. State begin_sentence = State(); - begin_sentence.valid_length_ = 1; - begin_sentence.history_[0] = vocab_.BeginSentence(); - begin_sentence.backoff_[0] = search_.unigram.Lookup(begin_sentence.history_[0]).backoff; + begin_sentence.length = 1; + begin_sentence.words[0] = vocab_.BeginSentence(); + begin_sentence.backoff[0] = search_.unigram.Lookup(begin_sentence.words[0]).backoff; State null_context = State(); - null_context.valid_length_ = 0; + null_context.length = 0; P::Init(begin_sentence, null_context, vocab_, search_.MiddleEnd() - search_.MiddleBegin() + 2); } @@ -87,7 +87,7 @@ template void GenericModel void GenericModel FullScoreReturn GenericModel::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const { - FullScoreReturn ret = ScoreExceptBackoff(in_state.history_, in_state.history_ + in_state.valid_length_, new_word, out_state); - if (ret.ngram_length - 1 < in_state.valid_length_) { - ret.prob = std::accumulate(in_state.backoff_ + ret.ngram_length - 1, in_state.backoff_ + in_state.valid_length_, ret.prob); + FullScoreReturn ret = ScoreExceptBackoff(in_state.words, in_state.words + in_state.length, new_word, out_state); + if (ret.ngram_length - 1 < in_state.length) { + ret.prob = std::accumulate(in_state.backoff + ret.ngram_length - 1, in_state.backoff + in_state.length, ret.prob); } return ret; } @@ -131,32 +131,80 @@ template void GenericModel FullScoreReturn GenericModel::ExtendLeft( + const WordIndex *add_rbegin, const WordIndex *add_rend, + const float *backoff_in, + uint64_t extend_pointer, + unsigned char extend_length, + float *backoff_out, + unsigned char &next_use) const { + FullScoreReturn ret; + float subtract_me; + typename Search::Node node(search_.Unpack(extend_pointer, extend_length, subtract_me)); + ret.prob = subtract_me; + ret.ngram_length = extend_length; + next_use = 0; + // If this function is called, then it does depend on left words. + ret.independent_left = false; + ret.extend_left = extend_pointer; + const typename Search::Middle *mid_iter = search_.MiddleBegin() + extend_length - 1; + const WordIndex *i = add_rbegin; + for (; ; ++i, ++backoff_out, ++mid_iter) { + if (i == add_rend) { + // Ran out of words. + for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b; + ret.prob -= subtract_me; + return ret; + } + if (mid_iter == search_.MiddleEnd()) break; + if (ret.independent_left || !search_.LookupMiddle(*mid_iter, *i, *backoff_out, node, ret)) { + // Didn't match a word. + ret.independent_left = true; + for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b; + ret.prob -= subtract_me; + return ret; + } + ret.ngram_length = mid_iter - search_.MiddleBegin() + 2; + if (HasExtension(*backoff_out)) next_use = i - add_rbegin + 1; + } + + if (ret.independent_left || !search_.LookupLongest(*i, ret.prob, node)) { + // The last backoff weight, for Order() - 1. + ret.prob += backoff_in[i - add_rbegin]; + } else { + ret.ngram_length = P::Order(); + } + ret.independent_left = true; + ret.prob -= subtract_me; + return ret; } namespace { // Do a paraonoid copy of history, assuming new_word has already been copied -// (hence the -1). out_state.valid_length_ could be zero so I avoided using +// (hence the -1). out_state.length could be zero so I avoided using // std::copy. void CopyRemainingHistory(const WordIndex *from, State &out_state) { - WordIndex *out = out_state.history_ + 1; - const WordIndex *in_end = from + static_cast(out_state.valid_length_) - 1; + WordIndex *out = out_state.words + 1; + const WordIndex *in_end = from + static_cast(out_state.length) - 1; for (const WordIndex *in = from; in < in_end; ++in, ++out) *out = *in; } } // namespace @@ -175,17 +223,17 @@ template FullScoreReturn GenericModel FullScoreReturn GenericModel class GenericModel : public base::Mod // This is the model type returned by RecognizeBinary. static const ModelType kModelType; + static const unsigned int kVersion = Search::kVersion; + /* Get the size of memory that will be mapped given ngram counts. This * does not include small non-mapped control structures, such as this class * itself. @@ -114,6 +116,25 @@ template class GenericModel : public base::Mod */ void GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const; + /* More efficient version of FullScore where a partial n-gram has already + * been scored. + * NOTE: THE RETURNED .prob IS RELATIVE, NOT ABSOLUTE. So for example, if + * the n-gram does not end up extending further left, then 0 is returned. + */ + FullScoreReturn ExtendLeft( + // Additional context in reverse order. This will update add_rend to + const WordIndex *add_rbegin, const WordIndex *add_rend, + // Backoff weights to use. + const float *backoff_in, + // extend_left returned by a previous query. + uint64_t extend_pointer, + // Length of n-gram that the pointer corresponds to. + unsigned char extend_length, + // Where to write additional backoffs for [extend_length + 1, min(Order() - 1, return.ngram_length)] + float *backoff_out, + // Amount of additional content that should be considered by the next call. + unsigned char &next_use) const; + private: friend void LoadLM<>(const char *file, const Config &config, GenericModel &to); diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index 57c7291c..2654071f 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -10,8 +10,8 @@ namespace lm { namespace ngram { std::ostream &operator<<(std::ostream &o, const State &state) { - o << "State length " << static_cast(state.valid_length_) << ':'; - for (const WordIndex *i = state.history_; i < state.history_ + state.valid_length_; ++i) { + o << "State length " << static_cast(state.length) << ':'; + for (const WordIndex *i = state.words; i < state.words + state.length; ++i) { o << ' ' << *i; } return o; @@ -19,25 +19,26 @@ std::ostream &operator<<(std::ostream &o, const State &state) { namespace { -#define StartTest(word, ngram, score) \ +#define StartTest(word, ngram, score, indep_left) \ ret = model.FullScore( \ state, \ model.GetVocabulary().Index(word), \ out);\ BOOST_CHECK_CLOSE(score, ret.prob, 0.001); \ BOOST_CHECK_EQUAL(static_cast(ngram), ret.ngram_length); \ - BOOST_CHECK_GE(std::min(ngram, 5 - 1), out.valid_length_); \ + BOOST_CHECK_GE(std::min(ngram, 5 - 1), out.length); \ + BOOST_CHECK_EQUAL(indep_left, ret.independent_left); \ {\ - WordIndex context[state.valid_length_ + 1]; \ + WordIndex context[state.length + 1]; \ context[0] = model.GetVocabulary().Index(word); \ - std::copy(state.history_, state.history_ + state.valid_length_, context + 1); \ + std::copy(state.words, state.words + state.length, context + 1); \ State get_state; \ - model.GetState(context, context + state.valid_length_ + 1, get_state); \ + model.GetState(context, context + state.length + 1, get_state); \ BOOST_CHECK_EQUAL(out, get_state); \ } -#define AppendTest(word, ngram, score) \ - StartTest(word, ngram, score) \ +#define AppendTest(word, ngram, score, indep_left) \ + StartTest(word, ngram, score, indep_left) \ state = out; template void Starters(const M &model) { @@ -45,12 +46,12 @@ template void Starters(const M &model) { Model::State state(model.BeginSentenceState()); Model::State out; - StartTest("looking", 2, -0.4846522); + StartTest("looking", 2, -0.4846522, true); // , probability plus backoff - StartTest(",", 1, -1.383514 + -0.4149733); + StartTest(",", 1, -1.383514 + -0.4149733, true); // probability plus backoff - StartTest("this_is_not_found", 1, -1.995635 + -0.4149733); + StartTest("this_is_not_found", 1, -1.995635 + -0.4149733, true); } template void Continuation(const M &model) { @@ -58,46 +59,64 @@ template void Continuation(const M &model) { Model::State state(model.BeginSentenceState()); Model::State out; - AppendTest("looking", 2, -0.484652); - AppendTest("on", 3, -0.348837); - AppendTest("a", 4, -0.0155266); - AppendTest("little", 5, -0.00306122); + AppendTest("looking", 2, -0.484652, true); + AppendTest("on", 3, -0.348837, true); + AppendTest("a", 4, -0.0155266, true); + AppendTest("little", 5, -0.00306122, true); State preserve = state; - AppendTest("the", 1, -4.04005); - AppendTest("biarritz", 1, -1.9889); - AppendTest("not_found", 1, -2.29666); - AppendTest("more", 1, -1.20632 - 20.0); - AppendTest(".", 2, -0.51363); - AppendTest("", 3, -0.0191651); - BOOST_CHECK_EQUAL(0, state.valid_length_); + AppendTest("the", 1, -4.04005, true); + AppendTest("biarritz", 1, -1.9889, true); + AppendTest("not_found", 1, -2.29666, true); + AppendTest("more", 1, -1.20632 - 20.0, true); + AppendTest(".", 2, -0.51363, true); + AppendTest("", 3, -0.0191651, true); + BOOST_CHECK_EQUAL(0, state.length); state = preserve; - AppendTest("more", 5, -0.00181395); - BOOST_CHECK_EQUAL(4, state.valid_length_); - AppendTest("loin", 5, -0.0432557); - BOOST_CHECK_EQUAL(1, state.valid_length_); + AppendTest("more", 5, -0.00181395, true); + BOOST_CHECK_EQUAL(4, state.length); + AppendTest("loin", 5, -0.0432557, true); + BOOST_CHECK_EQUAL(1, state.length); } template void Blanks(const M &model) { FullScoreReturn ret; State state(model.NullContextState()); State out; - AppendTest("also", 1, -1.687872); - AppendTest("would", 2, -2); - AppendTest("consider", 3, -3); + AppendTest("also", 1, -1.687872, false); + AppendTest("would", 2, -2, true); + AppendTest("consider", 3, -3, true); State preserve = state; - AppendTest("higher", 4, -4); - AppendTest("looking", 5, -5); - BOOST_CHECK_EQUAL(1, state.valid_length_); + AppendTest("higher", 4, -4, true); + AppendTest("looking", 5, -5, true); + BOOST_CHECK_EQUAL(1, state.length); state = preserve; - AppendTest("not_found", 1, -1.995635 - 7.0 - 0.30103); + // also would consider not_found + AppendTest("not_found", 1, -1.995635 - 7.0 - 0.30103, true); state = model.NullContextState(); // higher looking is a blank. - AppendTest("higher", 1, -1.509559); - AppendTest("looking", 1, -1.285941 - 0.30103); - AppendTest("not_found", 1, -1.995635 - 0.4771212); + AppendTest("higher", 1, -1.509559, false); + AppendTest("looking", 2, -1.285941 - 0.30103, false); + + State higher_looking = state; + + BOOST_CHECK_EQUAL(1, state.length); + AppendTest("not_found", 1, -1.995635 - 0.4771212, true); + + state = higher_looking; + // higher looking consider + AppendTest("consider", 1, -1.687872 - 0.4771212, true); + + state = model.NullContextState(); + AppendTest("would", 1, -1.687872, false); + BOOST_CHECK_EQUAL(1, state.length); + AppendTest("consider", 2, -1.687872 -0.30103, false); + BOOST_CHECK_EQUAL(2, state.length); + AppendTest("higher", 3, -1.509559 - 0.30103, false); + BOOST_CHECK_EQUAL(3, state.length); + AppendTest("looking", 4, -1.285941 - 0.30103, false); } template void Unknowns(const M &model) { @@ -105,14 +124,14 @@ template void Unknowns(const M &model) { State state(model.NullContextState()); State out; - AppendTest("not_found", 1, -1.995635); + AppendTest("not_found", 1, -1.995635, false); State preserve = state; - AppendTest("not_found2", 2, -15.0); - AppendTest("not_found3", 2, -15.0 - 2.0); + AppendTest("not_found2", 2, -15.0, true); + AppendTest("not_found3", 2, -15.0 - 2.0, true); state = preserve; - AppendTest("however", 2, -4); - AppendTest("not_found3", 3, -6); + AppendTest("however", 2, -4, true); + AppendTest("not_found3", 3, -6, true); } template void MinimalState(const M &model) { @@ -120,22 +139,66 @@ template void MinimalState(const M &model) { State state(model.NullContextState()); State out; - AppendTest("baz", 1, -6.535897); - BOOST_CHECK_EQUAL(0, state.valid_length_); + AppendTest("baz", 1, -6.535897, true); + BOOST_CHECK_EQUAL(0, state.length); state = model.NullContextState(); - AppendTest("foo", 1, -3.141592); - BOOST_CHECK_EQUAL(1, state.valid_length_); - AppendTest("bar", 2, -6.0); + AppendTest("foo", 1, -3.141592, true); + BOOST_CHECK_EQUAL(1, state.length); + AppendTest("bar", 2, -6.0, true); // Has to include the backoff weight. - BOOST_CHECK_EQUAL(1, state.valid_length_); - AppendTest("bar", 1, -2.718281 + 3.0); - BOOST_CHECK_EQUAL(1, state.valid_length_); + BOOST_CHECK_EQUAL(1, state.length); + AppendTest("bar", 1, -2.718281 + 3.0, true); + BOOST_CHECK_EQUAL(1, state.length); state = model.NullContextState(); - AppendTest("to", 1, -1.687872); - AppendTest("look", 2, -0.2922095); - BOOST_CHECK_EQUAL(2, state.valid_length_); - AppendTest("good", 3, -7); + AppendTest("to", 1, -1.687872, false); + AppendTest("look", 2, -0.2922095, true); + BOOST_CHECK_EQUAL(2, state.length); + AppendTest("good", 3, -7, true); +} + +template void ExtendLeftTest(const M &model) { + State right; + FullScoreReturn little(model.FullScore(model.NullContextState(), model.GetVocabulary().Index("little"), right)); + const float kLittleProb = -1.285941; + BOOST_CHECK_CLOSE(kLittleProb, little.prob, 0.001); + unsigned char next_use; + float backoff_out[4]; + + FullScoreReturn extend_none(model.ExtendLeft(NULL, NULL, NULL, little.extend_left, 1, NULL, next_use)); + BOOST_CHECK_EQUAL(0, next_use); + BOOST_CHECK_EQUAL(little.extend_left, extend_none.extend_left); + BOOST_CHECK_CLOSE(0.0, extend_none.prob, 0.001); + BOOST_CHECK_EQUAL(1, extend_none.ngram_length); + + const WordIndex a = model.GetVocabulary().Index("a"); + float backoff_in = 3.14; + // a little + FullScoreReturn extend_a(model.ExtendLeft(&a, &a + 1, &backoff_in, little.extend_left, 1, backoff_out, next_use)); + BOOST_CHECK_EQUAL(1, next_use); + BOOST_CHECK_CLOSE(-0.69897, backoff_out[0], 0.001); + BOOST_CHECK_CLOSE(-0.09132547 - kLittleProb, extend_a.prob, 0.001); + BOOST_CHECK_EQUAL(2, extend_a.ngram_length); + BOOST_CHECK(!extend_a.independent_left); + + const WordIndex on = model.GetVocabulary().Index("on"); + FullScoreReturn extend_on(model.ExtendLeft(&on, &on + 1, &backoff_in, extend_a.extend_left, 2, backoff_out, next_use)); + BOOST_CHECK_EQUAL(1, next_use); + BOOST_CHECK_CLOSE(-0.4771212, backoff_out[0], 0.001); + BOOST_CHECK_CLOSE(-0.0283603 - -0.09132547, extend_on.prob, 0.001); + BOOST_CHECK_EQUAL(3, extend_on.ngram_length); + BOOST_CHECK(!extend_on.independent_left); + + const WordIndex both[2] = {a, on}; + float backoff_in_arr[4]; + FullScoreReturn extend_both(model.ExtendLeft(both, both + 2, backoff_in_arr, little.extend_left, 1, backoff_out, next_use)); + BOOST_CHECK_EQUAL(2, next_use); + BOOST_CHECK_CLOSE(-0.69897, backoff_out[0], 0.001); + BOOST_CHECK_CLOSE(-0.4771212, backoff_out[1], 0.001); + BOOST_CHECK_CLOSE(-0.0283603 - kLittleProb, extend_both.prob, 0.001); + BOOST_CHECK_EQUAL(3, extend_both.ngram_length); + BOOST_CHECK(!extend_both.independent_left); + BOOST_CHECK_EQUAL(extend_on.extend_left, extend_both.extend_left); } #define StatelessTest(word, provide, ngram, score) \ @@ -166,17 +229,17 @@ template void Stateless(const M &model) { // looking StatelessTest(1, 2, 2, -0.484652); // on - AppendTest("on", 3, -0.348837); + AppendTest("on", 3, -0.348837, true); StatelessTest(2, 3, 3, -0.348837); StatelessTest(2, 2, 3, -0.348837); StatelessTest(2, 1, 2, -0.4638903); // a StatelessTest(3, 4, 4, -0.0155266); // little - AppendTest("little", 5, -0.00306122); + AppendTest("little", 5, -0.00306122, true); StatelessTest(4, 5, 5, -0.00306122); // the - AppendTest("the", 1, -4.04005); + AppendTest("the", 1, -4.04005, true); StatelessTest(5, 5, 1, -4.04005); // No context of the. StatelessTest(5, 0, 1, -1.687872); @@ -189,8 +252,8 @@ template void Stateless(const M &model) { WordIndex unk[1]; unk[0] = 0; model.GetState(unk, unk + 1, state); - BOOST_CHECK_EQUAL(1, state.valid_length_); - BOOST_CHECK_EQUAL(static_cast(0), state.history_[0]); + BOOST_CHECK_EQUAL(1, state.length); + BOOST_CHECK_EQUAL(static_cast(0), state.words[0]); } template void NoUnkCheck(const M &model) { @@ -207,6 +270,7 @@ template void Everything(const M &m) { Blanks(m); Unknowns(m); MinimalState(m); + ExtendLeftTest(m); Stateless(m); } @@ -245,6 +309,7 @@ template void LoadingTest() { config.enumerate_vocab = &enumerate; ModelT m("test.arpa", config); enumerate.Check(m.GetVocabulary()); + BOOST_CHECK_EQUAL((WordIndex)37, m.GetVocabulary().Bound()); Everything(m); } { @@ -252,6 +317,7 @@ template void LoadingTest() { config.enumerate_vocab = &enumerate; ModelT m("test_nounk.arpa", config); enumerate.Check(m.GetVocabulary()); + BOOST_CHECK_EQUAL((WordIndex)37, m.GetVocabulary().Bound()); NoUnkCheck(m); } } diff --git a/klm/lm/model_type.hh b/klm/lm/model_type.hh new file mode 100644 index 00000000..5057ed25 --- /dev/null +++ b/klm/lm/model_type.hh @@ -0,0 +1,16 @@ +#ifndef LM_MODEL_TYPE__ +#define LM_MODEL_TYPE__ + +namespace lm { +namespace ngram { + +/* Not the best numbering system, but it grew this way for historical reasons + * and I want to preserve existing binary files. */ +typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3, ARRAY_TRIE_SORTED=4, QUANT_ARRAY_TRIE_SORTED=5} ModelType; + +const static ModelType kQuantAdd = static_cast(QUANT_TRIE_SORTED - TRIE_SORTED); +const static ModelType kArrayAdd = static_cast(ARRAY_TRIE_SORTED - TRIE_SORTED); + +} // namespace ngram +} // namespace lm +#endif // LM_MODEL_TYPE__ diff --git a/klm/lm/quantize.cc b/klm/lm/quantize.cc index fd371cc8..98a5d048 100644 --- a/klm/lm/quantize.cc +++ b/klm/lm/quantize.cc @@ -1,5 +1,6 @@ #include "lm/quantize.hh" +#include "lm/binary_format.hh" #include "lm/lm_exception.hh" #include @@ -70,8 +71,7 @@ void SeparatelyQuantize::Train(uint8_t order, std::vector &prob, std::vec void SeparatelyQuantize::TrainProb(uint8_t order, std::vector &prob) { float *centers = start_ + TableStart(order); - *(centers++) = kBlankProb; - MakeBins(&*prob.begin(), &*prob.end(), centers, (1ULL << prob_bits_) - 1); + MakeBins(&*prob.begin(), &*prob.end(), centers, (1ULL << prob_bits_)); } void SeparatelyQuantize::FinishedLoading(const Config &config) { diff --git a/klm/lm/quantize.hh b/klm/lm/quantize.hh index 0b71d14a..4cf4236e 100644 --- a/klm/lm/quantize.hh +++ b/klm/lm/quantize.hh @@ -1,9 +1,9 @@ #ifndef LM_QUANTIZE_H__ #define LM_QUANTIZE_H__ -#include "lm/binary_format.hh" // for ModelType #include "lm/blank.hh" #include "lm/config.hh" +#include "lm/model_type.hh" #include "util/bit_packing.hh" #include @@ -36,6 +36,9 @@ class DontQuantize { prob = util::ReadNonPositiveFloat31(base, bit_offset); backoff = util::ReadFloat32(base, bit_offset + 31); } + void ReadProb(const void *base, uint64_t bit_offset, float &prob) const { + prob = util::ReadNonPositiveFloat31(base, bit_offset); + } void ReadBackoff(const void *base, uint64_t bit_offset, float &backoff) const { backoff = util::ReadFloat32(base, bit_offset + 31); } @@ -77,7 +80,7 @@ class SeparatelyQuantize { Bins(uint8_t bits, const float *const begin) : begin_(begin), end_(begin_ + (1ULL << bits)), bits_(bits), mask_((1ULL << bits) - 1) {} uint64_t EncodeProb(float value) const { - return(value == kBlankProb ? kBlankProbQuant : Encode(value, 1)); + return Encode(value, 0); } uint64_t EncodeBackoff(float value) const { @@ -132,6 +135,10 @@ class SeparatelyQuantize { (prob_.EncodeProb(prob) << backoff_.Bits()) | backoff_.EncodeBackoff(backoff)); } + void ReadProb(const void *base, uint64_t bit_offset, float &prob) const { + prob = prob_.Decode(util::ReadInt25(base, bit_offset + backoff_.Bits(), prob_.Bits(), prob_.Mask())); + } + void Read(const void *base, uint64_t bit_offset, float &prob, float &backoff) const { uint64_t both = util::ReadInt57(base, bit_offset, total_bits_, total_mask_); prob = prob_.Decode(both >> backoff_.Bits()); @@ -179,7 +186,7 @@ class SeparatelyQuantize { void SetupMemory(void *start, const Config &config); static const bool kTrain = true; - // Assumes kBlankProb is removed from prob and 0.0 is removed from backoff. + // Assumes 0.0 is removed from backoff. void Train(uint8_t order, std::vector &prob, std::vector &backoff); // Train just probabilities (for longest order). void TrainProb(uint8_t order, std::vector &prob); diff --git a/klm/lm/return.hh b/klm/lm/return.hh new file mode 100644 index 00000000..15571960 --- /dev/null +++ b/klm/lm/return.hh @@ -0,0 +1,39 @@ +#ifndef LM_RETURN__ +#define LM_RETURN__ + +#include + +namespace lm { +/* Structure returned by scoring routines. */ +struct FullScoreReturn { + // log10 probability + float prob; + + /* The length of n-gram matched. Do not use this for recombination. + * Consider a model containing only the following n-grams: + * -1 foo + * -3.14 bar + * -2.718 baz -5 + * -6 foo bar + * + * If you score ``bar'' then ngram_length is 1 and recombination state is the + * empty string because bar has zero backoff and does not extend to the + * right. + * If you score ``foo'' then ngram_length is 1 and recombination state is + * ``foo''. + * + * Ideally, keep output states around and compare them. Failing that, + * get out_state.ValidLength() and use that length for recombination. + */ + unsigned char ngram_length; + + /* Left extension information. If independent_left is set, then prob is + * independent of words to the left (up to additional backoff). Otherwise, + * extend_left indicates how to efficiently extend further to the left. + */ + bool independent_left; + uint64_t extend_left; // Defined only if independent_left +}; + +} // namespace lm +#endif // LM_RETURN__ diff --git a/klm/lm/search_hashed.cc b/klm/lm/search_hashed.cc index 82c53ec8..334adf12 100644 --- a/klm/lm/search_hashed.cc +++ b/klm/lm/search_hashed.cc @@ -1,10 +1,12 @@ #include "lm/search_hashed.hh" +#include "lm/binary_format.hh" #include "lm/blank.hh" #include "lm/lm_exception.hh" #include "lm/read_arpa.hh" #include "lm/vocab.hh" +#include "util/bit_packing.hh" #include "util/file_piece.hh" #include @@ -48,30 +50,77 @@ class ActivateUnigram { ProbBackoff *modify_; }; -template void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, std::vector &middle, Activate activate, Store &store, PositiveProbWarn &warn) { - - ReadNGramHeader(f, n); +template void FixSRI(int lower, float negative_lower_prob, unsigned int n, const uint64_t *keys, const WordIndex *vocab_ids, ProbBackoff *unigrams, std::vector &middle) { ProbBackoff blank; - blank.prob = kBlankProb; - blank.backoff = kBlankBackoff; + blank.backoff = kNoExtensionBackoff; + // Fix SRI's stupidity. + // Note that negative_lower_prob is the negative of the probability (so it's currently >= 0). We still want the sign bit off to indicate left extension, so I just do -= on the backoffs. + blank.prob = negative_lower_prob; + // An entry was found at lower (order lower + 2). + // We need to insert blanks starting at lower + 1 (order lower + 3). + unsigned int fix = static_cast(lower + 1); + uint64_t backoff_hash = detail::CombineWordHash(static_cast(vocab_ids[1]), vocab_ids[2]); + if (fix == 0) { + // Insert a missing bigram. + blank.prob -= unigrams[vocab_ids[1]].backoff; + SetExtension(unigrams[vocab_ids[1]].backoff); + // Bigram including a unigram's backoff + middle[0].Insert(Middle::Packing::Make(keys[0], blank)); + fix = 1; + } else { + for (unsigned int i = 3; i < fix + 2; ++i) backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[i]); + } + // fix >= 1. Insert trigrams and above. + for (; fix <= n - 3; ++fix) { + typename Middle::MutableIterator gotit; + if (middle[fix - 1].UnsafeMutableFind(backoff_hash, gotit)) { + float &backoff = gotit->MutableValue().backoff; + SetExtension(backoff); + blank.prob -= backoff; + } + middle[fix].Insert(Middle::Packing::Make(keys[fix], blank)); + backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[fix + 2]); + } +} + +template void ReadNGrams(util::FilePiece &f, const unsigned int n, const size_t count, const Voc &vocab, ProbBackoff *unigrams, std::vector &middle, Activate activate, Store &store, PositiveProbWarn &warn) { + ReadNGramHeader(f, n); // vocab ids of words in reverse order WordIndex vocab_ids[n]; uint64_t keys[n - 1]; typename Store::Packing::Value value; - typename Middle::ConstIterator found; + typename Middle::MutableIterator found; for (size_t i = 0; i < count; ++i) { ReadNGram(f, n, vocab, vocab_ids, value, warn); + keys[0] = detail::CombineWordHash(static_cast(*vocab_ids), vocab_ids[1]); for (unsigned int h = 1; h < n - 1; ++h) { keys[h] = detail::CombineWordHash(keys[h-1], vocab_ids[h+1]); } + // Initially the sign bit is on, indicating it does not extend left. Most already have this but there might +0.0. + util::SetSign(value.prob); store.Insert(Store::Packing::Make(keys[n-2], value)); - // Go back and insert blanks. - for (int lower = n - 3; lower >= 0; --lower) { - if (middle[lower].Find(keys[lower], found)) break; - middle[lower].Insert(Middle::Packing::Make(keys[lower], blank)); + // Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb. + int lower; + util::FloatEnc fix_prob; + for (lower = n - 3; ; --lower) { + if (lower == -1) { + fix_prob.f = unigrams[vocab_ids[0]].prob; + fix_prob.i &= ~util::kSignBit; + unigrams[vocab_ids[0]].prob = fix_prob.f; + break; + } + if (middle[lower].UnsafeMutableFind(keys[lower], found)) { + // Turn off sign bit to indicate that it extends left. + fix_prob.f = found->MutableValue().prob; + fix_prob.i &= ~util::kSignBit; + found->MutableValue().prob = fix_prob.f; + // We don't need to recurse further down because this entry already set the bits for lower entries. + break; + } } + if (lower != static_cast(n) - 3) FixSRI(lower, fix_prob.f, n, keys, vocab_ids, unigrams, middle); activate(vocab_ids, n); } @@ -107,15 +156,15 @@ template template void TemplateHashe try { if (counts.size() > 2) { - ReadNGrams(f, 2, counts[1], vocab, middle_, ActivateUnigram(unigram.Raw()), middle_[0], warn); + ReadNGrams(f, 2, counts[1], vocab, unigram.Raw(), middle_, ActivateUnigram(unigram.Raw()), middle_[0], warn); } for (unsigned int n = 3; n < counts.size(); ++n) { - ReadNGrams(f, n, counts[n-1], vocab, middle_, ActivateLowerMiddle(middle_[n-3]), middle_[n-2], warn); + ReadNGrams(f, n, counts[n-1], vocab, unigram.Raw(), middle_, ActivateLowerMiddle(middle_[n-3]), middle_[n-2], warn); } if (counts.size() > 2) { - ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle_, ActivateLowerMiddle(middle_.back()), longest, warn); + ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, unigram.Raw(), middle_, ActivateLowerMiddle(middle_.back()), longest, warn); } else { - ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, middle_, ActivateUnigram(unigram.Raw()), longest, warn); + ReadNGrams(f, counts.size(), counts[counts.size() - 1], vocab, unigram.Raw(), middle_, ActivateUnigram(unigram.Raw()), longest, warn); } } catch (util::ProbingSizeException &e) { UTIL_THROW(util::ProbingSizeException, "Avoid pruning n-grams like \"bar baz quux\" when \"foo bar baz quux\" is still in the model. KenLM will work when this pruning happens, but the probing model assumes these events are rare enough that using blank space in the probing hash table will cover all of them. Increase probing_multiplier (-p to build_binary) to add more blank spaces.\n"); @@ -133,7 +182,7 @@ template void TemplateHashedSearch; -template void TemplateHashedSearch::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector &counts, const Config &, ProbingVocabulary &vocab, Backing &backing); +template void TemplateHashedSearch::InitializeFromARPA(const char *, util::FilePiece &f, const std::vector &counts, const Config &, ProbingVocabulary &vocab, Backing &backing); } // namespace detail } // namespace ngram diff --git a/klm/lm/search_hashed.hh b/klm/lm/search_hashed.hh index c62985e4..e289fd11 100644 --- a/klm/lm/search_hashed.hh +++ b/klm/lm/search_hashed.hh @@ -1,15 +1,18 @@ #ifndef LM_SEARCH_HASHED__ #define LM_SEARCH_HASHED__ -#include "lm/binary_format.hh" +#include "lm/model_type.hh" #include "lm/config.hh" #include "lm/read_arpa.hh" +#include "lm/return.hh" #include "lm/weights.hh" +#include "util/bit_packing.hh" #include "util/key_value_packing.hh" #include "util/probing_hash_table.hh" #include +#include #include namespace util { class FilePiece; } @@ -52,9 +55,14 @@ struct HashedSearch { Unigram unigram; - void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const { + void LookupUnigram(WordIndex word, float &backoff, Node &next, FullScoreReturn &ret) const { const ProbBackoff &entry = unigram.Lookup(word); - prob = entry.prob; + util::FloatEnc val; + val.f = entry.prob; + ret.independent_left = (val.i & util::kSignBit); + ret.extend_left = static_cast(word); + val.i |= util::kSignBit; + ret.prob = val.f; backoff = entry.backoff; next = static_cast(word); } @@ -67,6 +75,8 @@ template class TemplateHashedSearch : public Has typedef LongestT Longest; Longest longest; + static const unsigned int kVersion = 0; + // TODO: move probing_multiplier here with next binary file format update. static void UpdateConfigFromBinary(int, const std::vector &, Config &) {} @@ -85,11 +95,33 @@ template class TemplateHashedSearch : public Has const Middle *MiddleBegin() const { return &*middle_.begin(); } const Middle *MiddleEnd() const { return &*middle_.end(); } - bool LookupMiddle(const Middle &middle, WordIndex word, float &prob, float &backoff, Node &node) const { + Node Unpack(uint64_t extend_pointer, unsigned char extend_length, float &prob) const { + util::FloatEnc val; + if (extend_length == 1) { + val.f = unigram.Lookup(static_cast(extend_pointer)).prob; + } else { + typename Middle::ConstIterator found; + if (!middle_[extend_length - 2].Find(extend_pointer, found)) { + std::cerr << "Extend pointer " << extend_pointer << " should have been found for length " << (unsigned) extend_length << std::endl; + abort(); + } + val.f = found->GetValue().prob; + } + val.i |= util::kSignBit; + prob = val.f; + return extend_pointer; + } + + bool LookupMiddle(const Middle &middle, WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const { node = CombineWordHash(node, word); typename Middle::ConstIterator found; if (!middle.Find(node, found)) return false; - prob = found->GetValue().prob; + util::FloatEnc enc; + enc.f = found->GetValue().prob; + ret.independent_left = (enc.i & util::kSignBit); + ret.extend_left = node; + enc.i |= util::kSignBit; + ret.prob = enc.f; backoff = found->GetValue().backoff; return true; } @@ -105,6 +137,7 @@ template class TemplateHashedSearch : public Has } bool LookupLongest(WordIndex word, float &prob, Node &node) const { + // Sign bit is always on because longest n-grams do not extend left. node = CombineWordHash(node, word); typename Longest::ConstIterator found; if (!longest.Find(node, found)) return false; diff --git a/klm/lm/search_trie.cc b/klm/lm/search_trie.cc index 05059ffb..6479813b 100644 --- a/klm/lm/search_trie.cc +++ b/klm/lm/search_trie.cc @@ -2,26 +2,25 @@ #include "lm/search_trie.hh" #include "lm/bhiksha.hh" +#include "lm/binary_format.hh" #include "lm/blank.hh" #include "lm/lm_exception.hh" #include "lm/max_order.hh" #include "lm/quantize.hh" -#include "lm/read_arpa.hh" #include "lm/trie.hh" +#include "lm/trie_sort.hh" #include "lm/vocab.hh" #include "lm/weights.hh" #include "lm/word_index.hh" #include "util/ersatz_progress.hh" -#include "util/file_piece.hh" -#include "util/have.hh" #include "util/proxy_iterator.hh" #include "util/scoped.hh" +#include "util/sized_iterator.hh" #include -#include #include #include -#include +#include #include #include #include @@ -29,575 +28,221 @@ #include #include #include -#include -#include -#include namespace lm { namespace ngram { namespace trie { namespace { -/* An entry is a n-gram with probability. It consists of: - * WordIndex[order] - * float probability - * backoff probability (omitted for highest order n-gram) - * These are stored consecutively in memory. We want to sort them. - * - * The problem is the length depends on order (but all n-grams being compared - * have the same order). Allocating each entry on the heap (i.e. std::vector - * or std::string) then sorting pointers is the normal solution. But that's - * too memory inefficient. A lot of this code is just here to force std::sort - * to work with records where length is specified at runtime (and avoid using - * Boost for LM code). I could have used qsort, but the point is to also - * support __gnu_cxx:parallel_sort which doesn't have a qsort version. - */ - -class EntryIterator { - public: - EntryIterator() {} - - EntryIterator(void *ptr, std::size_t size) : ptr_(static_cast(ptr)), size_(size) {} - - bool operator==(const EntryIterator &other) const { - return ptr_ == other.ptr_; - } - bool operator<(const EntryIterator &other) const { - return ptr_ < other.ptr_; - } - EntryIterator &operator+=(std::ptrdiff_t amount) { - ptr_ += amount * size_; - return *this; - } - std::ptrdiff_t operator-(const EntryIterator &other) const { - return (ptr_ - other.ptr_) / size_; - } - - const void *Data() const { return ptr_; } - void *Data() { return ptr_; } - std::size_t EntrySize() const { return size_; } - - private: - uint8_t *ptr_; - std::size_t size_; -}; - -class EntryProxy { - public: - EntryProxy() {} - - EntryProxy(void *ptr, std::size_t size) : inner_(ptr, size) {} - - operator std::string() const { - return std::string(reinterpret_cast(inner_.Data()), inner_.EntrySize()); - } - - EntryProxy &operator=(const EntryProxy &from) { - memcpy(inner_.Data(), from.inner_.Data(), inner_.EntrySize()); - return *this; - } - - EntryProxy &operator=(const std::string &from) { - memcpy(inner_.Data(), from.data(), inner_.EntrySize()); - return *this; - } - - const WordIndex *Indices() const { - return reinterpret_cast(inner_.Data()); - } - - private: - friend class util::ProxyIterator; - - typedef std::string value_type; - - typedef EntryIterator InnerIterator; - InnerIterator &Inner() { return inner_; } - const InnerIterator &Inner() const { return inner_; } - InnerIterator inner_; -}; - -typedef util::ProxyIterator NGramIter; - -// Proxy for an entry except there is some extra cruft between the entries. This is used to sort (n-1)-grams using the same memory as the sorted n-grams. -class PartialViewProxy { - public: - PartialViewProxy() : attention_size_(0), inner_() {} - - PartialViewProxy(void *ptr, std::size_t block_size, std::size_t attention_size) : attention_size_(attention_size), inner_(ptr, block_size) {} - - operator std::string() const { - return std::string(reinterpret_cast(inner_.Data()), attention_size_); - } - - PartialViewProxy &operator=(const PartialViewProxy &from) { - memcpy(inner_.Data(), from.inner_.Data(), attention_size_); - return *this; - } - - PartialViewProxy &operator=(const std::string &from) { - memcpy(inner_.Data(), from.data(), attention_size_); - return *this; - } - - const WordIndex *Indices() const { - return reinterpret_cast(inner_.Data()); - } - - private: - friend class util::ProxyIterator; - - typedef std::string value_type; - - const std::size_t attention_size_; - - typedef EntryIterator InnerIterator; - InnerIterator &Inner() { return inner_; } - const InnerIterator &Inner() const { return inner_; } - InnerIterator inner_; -}; - -typedef util::ProxyIterator PartialIter; - -template class CompareRecords : public std::binary_function { - public: - explicit CompareRecords(unsigned char order) : order_(order) {} - - bool operator()(const Proxy &first, const Proxy &second) const { - return Compare(first.Indices(), second.Indices()); - } - bool operator()(const Proxy &first, const std::string &second) const { - return Compare(first.Indices(), reinterpret_cast(second.data())); - } - bool operator()(const std::string &first, const Proxy &second) const { - return Compare(reinterpret_cast(first.data()), second.Indices()); - } - bool operator()(const std::string &first, const std::string &second) const { - return Compare(reinterpret_cast(first.data()), reinterpret_cast(second.data())); - } - - private: - bool Compare(const WordIndex *first, const WordIndex *second) const { - const WordIndex *end = first + order_; - for (; first != end; ++first, ++second) { - if (*first < *second) return true; - if (*first > *second) return false; - } - return false; - } - - unsigned char order_; -}; - -FILE *OpenOrThrow(const char *name, const char *mode) { - FILE *ret = fopen(name, mode); - if (!ret) UTIL_THROW(util::ErrnoException, "Could not open " << name << " for " << mode); - return ret; -} - -void WriteOrThrow(FILE *to, const void *data, size_t size) { - assert(size); - if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size); -} - void ReadOrThrow(FILE *from, void *data, size_t size) { - if (1 != std::fread(data, size, 1, from)) UTIL_THROW(util::ErrnoException, "Short read; requested size " << size); -} - -const std::size_t kCopyBufSize = 512; -void CopyOrThrow(FILE *from, FILE *to, size_t size) { - char buf[std::min(size, kCopyBufSize)]; - for (size_t i = 0; i < size; i += kCopyBufSize) { - std::size_t amount = std::min(size - i, kCopyBufSize); - ReadOrThrow(from, buf, amount); - WriteOrThrow(to, buf, amount); - } + UTIL_THROW_IF(1 != std::fread(data, size, 1, from), util::ErrnoException, "Short read"); } -void CopyRestOrThrow(FILE *from, FILE *to) { - char buf[kCopyBufSize]; - size_t amount; - while ((amount = fread(buf, 1, kCopyBufSize, from))) { - WriteOrThrow(to, buf, amount); +int Compare(unsigned char order, const void *first_void, const void *second_void) { + const WordIndex *first = reinterpret_cast(first_void), *second = reinterpret_cast(second_void); + const WordIndex *end = first + order; + for (; first != end; ++first, ++second) { + if (*first < *second) return -1; + if (*first > *second) return 1; } - if (!feof(from)) UTIL_THROW(util::ErrnoException, "Short read"); -} - -void RemoveOrThrow(const char *name) { - if (std::remove(name)) UTIL_THROW(util::ErrnoException, "Could not remove " << name); + return 0; } -std::string DiskFlush(const void *mem_begin, const void *mem_end, const std::string &file_prefix, std::size_t batch, unsigned char order, std::size_t weights_size) { - const std::size_t entry_size = sizeof(WordIndex) * order + weights_size; - const std::size_t prefix_size = sizeof(WordIndex) * (order - 1); - std::stringstream assembled; - assembled << file_prefix << static_cast(order) << '_' << batch; - std::string ret(assembled.str()); - util::scoped_FILE out(OpenOrThrow(ret.c_str(), "w")); - // Compress entries that being with the same (order-1) words. - for (const uint8_t *group_begin = static_cast(mem_begin); group_begin != static_cast(mem_end);) { - const uint8_t *group_end; - for (group_end = group_begin + entry_size; - (group_end != static_cast(mem_end)) && !memcmp(group_begin, group_end, prefix_size); - group_end += entry_size) {} - WriteOrThrow(out.get(), group_begin, prefix_size); - WordIndex group_size = (group_end - group_begin) / entry_size; - WriteOrThrow(out.get(), &group_size, sizeof(group_size)); - for (const uint8_t *i = group_begin; i != group_end; i += entry_size) { - WriteOrThrow(out.get(), i + prefix_size, sizeof(WordIndex)); - WriteOrThrow(out.get(), i + sizeof(WordIndex) * order, weights_size); - } - group_begin = group_end; - } - return ret; -} +struct ProbPointer { + unsigned char array; + uint64_t index; +}; -class SortedFileReader { +// Array of n-grams and float indices. +class BackoffMessages { public: - SortedFileReader() : ended_(false) {} - - void Init(const std::string &name, unsigned char order) { - file_.reset(OpenOrThrow(name.c_str(), "r")); - header_.resize(order - 1); - NextHeader(); + void Init(std::size_t entry_size) { + current_ = NULL; + allocated_ = NULL; + entry_size_ = entry_size; } - // Preceding words. - const WordIndex *Header() const { - return &*header_.begin(); - } - const std::vector &HeaderVector() const { return header_;} - - std::size_t HeaderBytes() const { return header_.size() * sizeof(WordIndex); } - - void NextHeader() { - if (1 != fread(&*header_.begin(), HeaderBytes(), 1, file_.get())) { - if (feof(file_.get())) { - ended_ = true; - } else { - UTIL_THROW(util::ErrnoException, "Short read of counts"); + void Add(const WordIndex *to, ProbPointer index) { + while (current_ + entry_size_ > allocated_) { + std::size_t allocated_size = allocated_ - (uint8_t*)backing_.get(); + Resize(std::max(allocated_size * 2, entry_size_)); + } + memcpy(current_, to, entry_size_ - sizeof(ProbPointer)); + *reinterpret_cast(current_ + entry_size_ - sizeof(ProbPointer)) = index; + current_ += entry_size_; + } + + void Apply(float *const *const base, FILE *unigrams) { + FinishedAdding(); + if (current_ == allocated_) return; + rewind(unigrams); + ProbBackoff weights; + WordIndex unigram = 0; + ReadOrThrow(unigrams, &weights, sizeof(weights)); + for (; current_ != allocated_; current_ += entry_size_) { + const WordIndex &cur_word = *reinterpret_cast(current_); + for (; unigram < cur_word; ++unigram) { + ReadOrThrow(unigrams, &weights, sizeof(weights)); } + if (!HasExtension(weights.backoff)) { + weights.backoff = kExtensionBackoff; + UTIL_THROW_IF(fseek(unigrams, -sizeof(weights), SEEK_CUR), util::ErrnoException, "Seeking backwards to denote unigram extension failed."); + WriteOrThrow(unigrams, &weights, sizeof(weights)); + } + const ProbPointer &write_to = *reinterpret_cast(current_ + sizeof(WordIndex)); + base[write_to.array][write_to.index] += weights.backoff; } + backing_.reset(); + } + + void Apply(float *const *const base, RecordReader &reader) { + FinishedAdding(); + if (current_ == allocated_) return; + // We'll also use the same buffer to record messages to blanks that they extend. + WordIndex *extend_out = reinterpret_cast(current_); + const unsigned char order = (entry_size_ - sizeof(ProbPointer)) / sizeof(WordIndex); + for (reader.Rewind(); reader && (current_ != allocated_); ) { + switch (Compare(order, reader.Data(), current_)) { + case -1: + ++reader; + break; + case 1: + // Message but nobody to receive it. Write it down at the beginning of the buffer so we can inform this blank that it extends. + for (const WordIndex *w = reinterpret_cast(current_); w != reinterpret_cast(current_) + order; ++w, ++extend_out) *extend_out = *w; + current_ += entry_size_; + break; + case 0: + float &backoff = reinterpret_cast((uint8_t*)reader.Data() + order * sizeof(WordIndex))->backoff; + if (!HasExtension(backoff)) { + backoff = kExtensionBackoff; + reader.Overwrite(&backoff, sizeof(float)); + } else { + const ProbPointer &write_to = *reinterpret_cast(current_ + entry_size_ - sizeof(ProbPointer)); + base[write_to.array][write_to.index] += backoff; + } + current_ += entry_size_; + break; + } + } + // Now this is a list of blanks that extend right. + entry_size_ = sizeof(WordIndex) * order; + Resize(sizeof(WordIndex) * (extend_out - (const WordIndex*)backing_.get())); + current_ = (uint8_t*)backing_.get(); } - WordIndex ReadCount() { - WordIndex ret; - ReadOrThrow(file_.get(), &ret, sizeof(WordIndex)); - return ret; - } - - WordIndex ReadWord() { - WordIndex ret; - ReadOrThrow(file_.get(), &ret, sizeof(WordIndex)); - return ret; - } - - template void ReadWeights(Weights &weights) { - ReadOrThrow(file_.get(), &weights, sizeof(Weights)); + // Call after Apply + bool Extends(unsigned char order, const WordIndex *words) { + if (current_ == allocated_) return false; + assert(order * sizeof(WordIndex) == entry_size_); + while (true) { + switch(Compare(order, words, current_)) { + case 1: + current_ += entry_size_; + if (current_ == allocated_) return false; + break; + case -1: + return false; + case 0: + return true; + } + } } - bool Ended() const { - return ended_; + private: + void FinishedAdding() { + Resize(current_ - (uint8_t*)backing_.get()); + current_ = (uint8_t*)backing_.get(); } - void Rewind() { - rewind(file_.get()); - ended_ = false; - NextHeader(); + void Resize(std::size_t to) { + std::size_t current = current_ - (uint8_t*)backing_.get(); + backing_.call_realloc(to); + current_ = (uint8_t*)backing_.get() + current; + allocated_ = (uint8_t*)backing_.get() + to; } - FILE *File() { return file_.get(); } - - private: - util::scoped_FILE file_; + util::scoped_malloc backing_; - std::vector header_; + uint8_t *current_, *allocated_; - bool ended_; + std::size_t entry_size_; }; -void CopyFullRecord(SortedFileReader &from, FILE *to, std::size_t weights_size) { - WriteOrThrow(to, from.Header(), from.HeaderBytes()); - WordIndex count = from.ReadCount(); - WriteOrThrow(to, &count, sizeof(WordIndex)); - - CopyOrThrow(from.File(), to, (weights_size + sizeof(WordIndex)) * count); -} - -void MergeSortedFiles(const std::string &first_name, const std::string &second_name, const std::string &out, std::size_t weights_size, unsigned char order) { - SortedFileReader first, second; - first.Init(first_name.c_str(), order); - RemoveOrThrow(first_name.c_str()); - second.Init(second_name.c_str(), order); - RemoveOrThrow(second_name.c_str()); - util::scoped_FILE out_file(OpenOrThrow(out.c_str(), "w")); - while (!first.Ended() && !second.Ended()) { - if (first.HeaderVector() < second.HeaderVector()) { - CopyFullRecord(first, out_file.get(), weights_size); - first.NextHeader(); - continue; - } - if (first.HeaderVector() > second.HeaderVector()) { - CopyFullRecord(second, out_file.get(), weights_size); - second.NextHeader(); - continue; - } - // Merge at the entry level. - WriteOrThrow(out_file.get(), first.Header(), first.HeaderBytes()); - WordIndex first_count = first.ReadCount(), second_count = second.ReadCount(); - WordIndex total_count = first_count + second_count; - WriteOrThrow(out_file.get(), &total_count, sizeof(WordIndex)); - - WordIndex first_word = first.ReadWord(), second_word = second.ReadWord(); - WordIndex first_index = 0, second_index = 0; - while (true) { - if (first_word < second_word) { - WriteOrThrow(out_file.get(), &first_word, sizeof(WordIndex)); - CopyOrThrow(first.File(), out_file.get(), weights_size); - if (++first_index == first_count) break; - first_word = first.ReadWord(); - } else { - WriteOrThrow(out_file.get(), &second_word, sizeof(WordIndex)); - CopyOrThrow(second.File(), out_file.get(), weights_size); - if (++second_index == second_count) break; - second_word = second.ReadWord(); - } - } - if (first_index == first_count) { - WriteOrThrow(out_file.get(), &second_word, sizeof(WordIndex)); - CopyOrThrow(second.File(), out_file.get(), (second_count - second_index) * (weights_size + sizeof(WordIndex)) - sizeof(WordIndex)); - } else { - WriteOrThrow(out_file.get(), &first_word, sizeof(WordIndex)); - CopyOrThrow(first.File(), out_file.get(), (first_count - first_index) * (weights_size + sizeof(WordIndex)) - sizeof(WordIndex)); - } - first.NextHeader(); - second.NextHeader(); - } - - for (SortedFileReader &remaining = first.Ended() ? second : first; !remaining.Ended(); remaining.NextHeader()) { - CopyFullRecord(remaining, out_file.get(), weights_size); - } -} - -const char *kContextSuffix = "_contexts"; - -void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_file_name, std::size_t entry_size, unsigned char order) { - const size_t context_size = sizeof(WordIndex) * (order - 1); - // Sort just the contexts using the same memory. - PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size)); - PartialIter context_end(PartialViewProxy(end + sizeof(WordIndex), entry_size, context_size)); - - std::sort(context_begin, context_end, CompareRecords(order - 1)); - - std::string name(ngram_file_name + kContextSuffix); - util::scoped_FILE out(OpenOrThrow(name.c_str(), "w")); - - // Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator. - if (context_begin == context_end) return; - PartialIter i(context_begin); - WriteOrThrow(out.get(), i->Indices(), context_size); - const WordIndex *previous = i->Indices(); - ++i; - for (; i != context_end; ++i) { - if (memcmp(previous, i->Indices(), context_size)) { - WriteOrThrow(out.get(), i->Indices(), context_size); - previous = i->Indices(); - } - } -} +const float kBadProb = std::numeric_limits::infinity(); -class ContextReader { +class SRISucks { public: - ContextReader() : valid_(false) {} - - ContextReader(const char *name, unsigned char order) { - Reset(name, order); - } - - void Reset(const char *name, unsigned char order) { - file_.reset(OpenOrThrow(name, "r")); - length_ = sizeof(WordIndex) * static_cast(order); - words_.resize(order); - valid_ = true; - ++*this; - } - - ContextReader &operator++() { - if (1 != fread(&*words_.begin(), length_, 1, file_.get())) { - if (!feof(file_.get())) - UTIL_THROW(util::ErrnoException, "Short read"); - valid_ = false; + SRISucks() { + for (BackoffMessages *i = messages_; i != messages_ + kMaxOrder - 1; ++i) + i->Init(sizeof(ProbPointer) + sizeof(WordIndex) * (i - messages_ + 1)); + } + + void Send(unsigned char begin, unsigned char order, const WordIndex *to, float prob_basis) { + assert(prob_basis != kBadProb); + ProbPointer pointer; + pointer.array = order - 1; + pointer.index = values_[order - 1].size(); + for (unsigned char i = begin; i < order; ++i) { + messages_[i - 1].Add(to, pointer); } - return *this; + values_[order - 1].push_back(prob_basis); } - const WordIndex *operator*() const { return &*words_.begin(); } - - operator bool() const { return valid_; } - - FILE *GetFile() { return file_.get(); } - - private: - util::scoped_FILE file_; - - size_t length_; - - std::vector words_; - - bool valid_; -}; - -void MergeContextFiles(const std::string &first_base, const std::string &second_base, const std::string &out_base, unsigned char order) { - const size_t context_size = sizeof(WordIndex) * (order - 1); - std::string first_name(first_base + kContextSuffix); - std::string second_name(second_base + kContextSuffix); - ContextReader first(first_name.c_str(), order - 1), second(second_name.c_str(), order - 1); - RemoveOrThrow(first_name.c_str()); - RemoveOrThrow(second_name.c_str()); - std::string out_name(out_base + kContextSuffix); - util::scoped_FILE out(OpenOrThrow(out_name.c_str(), "w")); - while (first && second) { - for (const WordIndex *f = *first, *s = *second; ; ++f, ++s) { - if (f == *first + order - 1) { - // Equal. - WriteOrThrow(out.get(), *first, context_size); - ++first; - ++second; - break; + void ObtainBackoffs(unsigned char total_order, FILE *unigram_file, RecordReader *reader) { + for (unsigned char i = 0; i < kMaxOrder - 1; ++i) { + it_[i] = &*values_[i].begin(); } - if (*f < *s) { - // First lower - WriteOrThrow(out.get(), *first, context_size); - ++first; - break; - } else if (*f > *s) { - WriteOrThrow(out.get(), *second, context_size); - ++second; - break; + messages_[0].Apply(it_, unigram_file); + BackoffMessages *messages = messages_ + 1; + const RecordReader *end = reader + total_order - 2 /* exclude unigrams and longest order */; + for (; reader != end; ++messages, ++reader) { + messages->Apply(it_, *reader); } } - } - ContextReader &remaining = first ? first : second; - if (!remaining) return; - WriteOrThrow(out.get(), *remaining, context_size); - CopyRestOrThrow(remaining.GetFile(), out.get()); -} -void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order, PositiveProbWarn &warn) { - ReadNGramHeader(f, order); - const size_t count = counts[order - 1]; - // Size of weights. Does it include backoff? - const size_t words_size = sizeof(WordIndex) * order; - const size_t weights_size = sizeof(float) + ((order == counts.size()) ? 0 : sizeof(float)); - const size_t entry_size = words_size + weights_size; - const size_t batch_size = std::min(count, mem.size() / entry_size); - uint8_t *const begin = reinterpret_cast(mem.get()); - std::deque files; - for (std::size_t batch = 0, done = 0; done < count; ++batch) { - uint8_t *out = begin; - uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size; - if (order == counts.size()) { - for (; out != out_end; out += entry_size) { - ReadNGram(f, order, vocab, reinterpret_cast(out), *reinterpret_cast(out + words_size), warn); - } - } else { - for (; out != out_end; out += entry_size) { - ReadNGram(f, order, vocab, reinterpret_cast(out), *reinterpret_cast(out + words_size), warn); - } + ProbBackoff GetBlank(unsigned char total_order, unsigned char order, const WordIndex *indices) { + assert(order > 1); + ProbBackoff ret; + ret.prob = *(it_[order - 1]++); + ret.backoff = ((order != total_order - 1) && messages_[order - 1].Extends(order, indices)) ? kExtensionBackoff : kNoExtensionBackoff; + return ret; } - // Sort full records by full n-gram. - EntryProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size); - // parallel_sort uses too much RAM - std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), CompareRecords(order)); - files.push_back(DiskFlush(begin, out_end, file_prefix, batch, order, weights_size)); - WriteContextFile(begin, out_end, files.back(), entry_size, order); - - done += (out_end - begin) / entry_size; - } - // All individual files created. Merge them. - - std::size_t merge_count = 0; - while (files.size() > 1) { - std::stringstream assembled; - assembled << file_prefix << static_cast(order) << "_merge_" << (merge_count++); - files.push_back(assembled.str()); - MergeSortedFiles(files[0], files[1], files.back(), weights_size, order); - MergeContextFiles(files[0], files[1], files.back(), order); - files.pop_front(); - files.pop_front(); - } - if (!files.empty()) { - std::stringstream assembled; - assembled << file_prefix << static_cast(order) << "_merged"; - std::string merged_name(assembled.str()); - if (std::rename(files[0].c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << files[0].c_str() << " to " << merged_name.c_str()); - std::string context_name = files[0] + kContextSuffix; - merged_name += kContextSuffix; - if (std::rename(context_name.c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << context_name << " to " << merged_name.c_str()); - } -} - -void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { - PositiveProbWarn warn(config.positive_log_probability); - { - std::string unigram_name = file_prefix + "unigrams"; - util::scoped_fd unigram_file; - // In case appears. - size_t file_out = (counts[0] + 1) * sizeof(ProbBackoff); - util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), file_out, unigram_file), file_out); - Read1Grams(f, counts[0], vocab, reinterpret_cast(unigram_mmap.get()), warn); - CheckSpecials(config, vocab); - if (!vocab.SawUnk()) ++counts[0]; - } + const std::vector &Values(unsigned char order) const { + return values_[order - 1]; + } - // Only use as much buffer as we need. - size_t buffer_use = 0; - for (unsigned int order = 2; order < counts.size(); ++order) { - buffer_use = std::max(buffer_use, static_cast((sizeof(WordIndex) * order + 2 * sizeof(float)) * counts[order - 1])); - } - buffer_use = std::max(buffer_use, static_cast((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back())); - buffer = std::min(buffer, buffer_use); + private: + // This used to be one array. Then I needed to separate it by order for quantization to work. + std::vector values_[kMaxOrder - 1]; + BackoffMessages messages_[kMaxOrder - 1]; - util::scoped_memory mem; - mem.reset(malloc(buffer), buffer, util::scoped_memory::MALLOC_ALLOCATED); - if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer); + float *it_[kMaxOrder - 1]; +}; - for (unsigned char order = 2; order <= counts.size(); ++order) { - ConvertToSorted(f, vocab, counts, mem, file_prefix, order, warn); - } - ReadEnd(f); -} +class FindBlanks { + public: + FindBlanks(uint64_t *counts, unsigned char order, const ProbBackoff *unigrams, SRISucks &messages) + : counts_(counts), longest_counts_(counts + order - 1), unigrams_(unigrams), sri_(messages) {} -bool HeadMatch(const WordIndex *words, const WordIndex *const words_end, const WordIndex *header) { - for (; words != words_end; ++words, ++header) { - if (*words != *header) { - //assert(*words <= *header); - return false; + float UnigramProb(WordIndex index) const { + return unigrams_[index].prob; } - } - return true; -} -// Phase to count n-grams, including blanks inserted because they were pruned but have extensions -class JustCount { - public: - template JustCount(ContextReader * /*contexts*/, UnigramValue * /*unigrams*/, Middle * /*middle*/, Longest &/*longest*/, uint64_t *counts, unsigned char order) - : counts_(counts), longest_counts_(counts + order - 1) {} - - void Unigrams(WordIndex begin, WordIndex end) { - counts_[0] += end - begin; + void Unigram(WordIndex /*index*/) { + ++counts_[0]; } - void MiddleBlank(const unsigned char mid_idx, WordIndex /* idx */) { - ++counts_[mid_idx + 1]; + void MiddleBlank(const unsigned char order, const WordIndex *indices, unsigned char lower, float prob_basis) { + sri_.Send(lower, order, indices + 1, prob_basis); + ++counts_[order - 1]; } - void Middle(const unsigned char mid_idx, const WordIndex * /*before*/, WordIndex /*key*/, const ProbBackoff &/*weights*/) { - ++counts_[mid_idx + 1]; + void Middle(const unsigned char order, const void * /*data*/) { + ++counts_[order - 1]; } - void Longest(WordIndex /*key*/, Prob /*prob*/) { + void Longest(const void * /*data*/) { ++*longest_counts_; } @@ -608,167 +253,156 @@ class JustCount { private: uint64_t *const counts_, *const longest_counts_; + + const ProbBackoff *unigrams_; + + SRISucks &sri_; }; // Phase to actually write n-grams to the trie. template class WriteEntries { public: - WriteEntries(ContextReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, const uint64_t * /*counts*/, unsigned char order) : + WriteEntries(RecordReader *contexts, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) : contexts_(contexts), unigrams_(unigrams), middle_(middle), longest_(longest), - bigram_pack_((order == 2) ? static_cast(longest_) : static_cast(*middle_)) {} + bigram_pack_((order == 2) ? static_cast(longest_) : static_cast(*middle_)), + order_(order), + sri_(sri) {} - void Unigrams(WordIndex begin, WordIndex end) { - uint64_t next = bigram_pack_.InsertIndex(); - for (UnigramValue *i = unigrams_ + begin; i < unigrams_ + end; ++i) { - i->next = next; - } + float UnigramProb(WordIndex index) const { return unigrams_[index].weights.prob; } + + void Unigram(WordIndex word) { + unigrams_[word].next = bigram_pack_.InsertIndex(); } - void MiddleBlank(const unsigned char mid_idx, WordIndex key) { - middle_[mid_idx].Insert(key, kBlankProb, kBlankBackoff); + void MiddleBlank(const unsigned char order, const WordIndex *indices, unsigned char /*lower*/, float /*prob_base*/) { + ProbBackoff weights = sri_.GetBlank(order_, order, indices); + middle_[order - 2].Insert(indices[order - 1], weights.prob, weights.backoff); } - void Middle(const unsigned char mid_idx, const WordIndex *before, WordIndex key, ProbBackoff weights) { - // Order (mid_idx+2). - ContextReader &context = contexts_[mid_idx + 1]; - if (context && !memcmp(before, *context, sizeof(WordIndex) * (mid_idx + 1)) && (*context)[mid_idx + 1] == key) { + void Middle(const unsigned char order, const void *data) { + RecordReader &context = contexts_[order - 1]; + const WordIndex *words = reinterpret_cast(data); + ProbBackoff weights = *reinterpret_cast(words + order); + if (context && !memcmp(data, context.Data(), sizeof(WordIndex) * order)) { SetExtension(weights.backoff); ++context; } - middle_[mid_idx].Insert(key, weights.prob, weights.backoff); + middle_[order - 2].Insert(words[order - 1], weights.prob, weights.backoff); } - void Longest(WordIndex key, Prob prob) { - longest_.Insert(key, prob.prob); + void Longest(const void *data) { + const WordIndex *words = reinterpret_cast(data); + longest_.Insert(words[order_ - 1], reinterpret_cast(words + order_)->prob); } void Cleanup() {} private: - ContextReader *contexts_; + RecordReader *contexts_; UnigramValue *const unigrams_; BitPackedMiddle *const middle_; BitPackedLongest &longest_; BitPacked &bigram_pack_; + const unsigned char order_; + SRISucks &sri_; }; -template class RecursiveInsert { - public: - template RecursiveInsert(SortedFileReader *inputs, ContextReader *contexts, UnigramValue *unigrams, MiddleT *middle, LongestT &longest, uint64_t *counts, unsigned char order) : - doing_(contexts, unigrams, middle, longest, counts, order), inputs_(inputs), inputs_end_(inputs + order - 1), order_minus_2_(order - 2) { - } +struct Gram { + Gram(const WordIndex *in_begin, unsigned char order) : begin(in_begin), end(in_begin + order) {} - // Outer unigram loop. - void Apply(std::ostream *progress_out, const char *message, WordIndex unigram_count) { - util::ErsatzProgress progress(progress_out, message, unigram_count + 1); - for (words_[0] = 0; ; ++words_[0]) { - progress.Set(words_[0]); - WordIndex min_continue = unigram_count; - for (SortedFileReader *other = inputs_; other != inputs_end_; ++other) { - if (other->Ended()) continue; - min_continue = std::min(min_continue, other->Header()[0]); - } - // This will write at unigram_count. This is by design so that the next pointers will make sense. - doing_.Unigrams(words_[0], min_continue + 1); - if (min_continue == unigram_count) break; - words_[0] = min_continue; - Middle(0); - } - doing_.Cleanup(); - } + const WordIndex *begin, *end; - private: - void Middle(const unsigned char mid_idx) { - // (mid_idx + 2)-gram. - if (mid_idx == order_minus_2_) { - Longest(); - return; - } - // Orders [2, order) + // For queue, this is the direction we want. + bool operator<(const Gram &other) const { + return std::lexicographical_compare(other.begin, other.end, begin, end); + } +}; - SortedFileReader &reader = inputs_[mid_idx]; +template class BlankManager { + public: + BlankManager(unsigned char total_order, Doing &doing) : total_order_(total_order), been_length_(0), doing_(doing) { + for (float *i = basis_; i != basis_ + kMaxOrder - 1; ++i) *i = kBadProb; + } - if (reader.Ended() || !HeadMatch(words_, words_ + mid_idx + 1, reader.Header())) { - // This order doesn't have a header match, but longer ones might. - MiddleAllBlank(mid_idx); - return; + void Visit(const WordIndex *to, unsigned char length, float prob) { + basis_[length - 1] = prob; + unsigned char overlap = std::min(length - 1, been_length_); + const WordIndex *cur; + WordIndex *pre; + for (cur = to, pre = been_; cur != to + overlap; ++cur, ++pre) { + if (*pre != *cur) break; } - - // There is a header match. - WordIndex count = reader.ReadCount(); - WordIndex current = reader.ReadWord(); - while (count) { - WordIndex min_continue = std::numeric_limits::max(); - for (SortedFileReader *other = inputs_ + mid_idx + 1; other < inputs_end_; ++other) { - if (!other->Ended() && HeadMatch(words_, words_ + mid_idx + 1, other->Header())) - min_continue = std::min(min_continue, other->Header()[mid_idx + 1]); - } - while (true) { - if (current > min_continue) { - doing_.MiddleBlank(mid_idx, min_continue); - words_[mid_idx + 1] = min_continue; - Middle(mid_idx + 1); - break; - } - ProbBackoff weights; - reader.ReadWeights(weights); - doing_.Middle(mid_idx, words_, current, weights); - --count; - if (current == min_continue) { - words_[mid_idx + 1] = min_continue; - Middle(mid_idx + 1); - if (count) current = reader.ReadWord(); - break; - } - if (!count) break; - current = reader.ReadWord(); - } + if (cur == to + length - 1) { + *pre = *cur; + been_length_ = length; + return; } - // Count is now zero. Finish off remaining blanks. - MiddleAllBlank(mid_idx); - reader.NextHeader(); - } - - void MiddleAllBlank(const unsigned char mid_idx) { - while (true) { - WordIndex min_continue = std::numeric_limits::max(); - for (SortedFileReader *other = inputs_ + mid_idx + 1; other < inputs_end_; ++other) { - if (!other->Ended() && HeadMatch(words_, words_ + mid_idx + 1, other->Header())) - min_continue = std::min(min_continue, other->Header()[mid_idx + 1]); - } - if (min_continue == std::numeric_limits::max()) return; - doing_.MiddleBlank(mid_idx, min_continue); - words_[mid_idx + 1] = min_continue; - Middle(mid_idx + 1); + // There are blanks to insert starting with order blank. + unsigned char blank = cur - to + 1; + UTIL_THROW_IF(blank == 1, FormatLoadException, "Missing a unigram that appears as context."); + const float *lower_basis; + for (lower_basis = basis_ + blank - 2; *lower_basis == kBadProb; --lower_basis) {} + unsigned char based_on = lower_basis - basis_ + 1; + for (; cur != to + length - 1; ++blank, ++cur, ++pre) { + assert(*lower_basis != kBadProb); + doing_.MiddleBlank(blank, to, based_on, *lower_basis); + *pre = *cur; + // Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram. + basis_[blank - 1] = kBadProb; } + been_length_ = length; } - void Longest() { - SortedFileReader &reader = *(inputs_end_ - 1); - if (reader.Ended() || !HeadMatch(words_, words_ + order_minus_2_ + 1, reader.Header())) return; - WordIndex count = reader.ReadCount(); - for (WordIndex i = 0; i < count; ++i) { - WordIndex word = reader.ReadWord(); - Prob prob; - reader.ReadWeights(prob); - doing_.Longest(word, prob); - } - reader.NextHeader(); - return; - } + private: + const unsigned char total_order_; - Doing doing_; + WordIndex been_[kMaxOrder]; + unsigned char been_length_; - SortedFileReader *inputs_; - SortedFileReader *inputs_end_; + float basis_[kMaxOrder]; + + Doing &doing_; +}; - WordIndex words_[kMaxOrder]; +template void RecursiveInsert(const unsigned char total_order, const WordIndex unigram_count, RecordReader *input, std::ostream *progress_out, const char *message, Doing &doing) { + util::ErsatzProgress progress(progress_out, message, unigram_count + 1); + unsigned int unigram = 0; + std::priority_queue grams; + grams.push(Gram(&unigram, 1)); + for (unsigned char i = 2; i <= total_order; ++i) { + if (input[i-2]) grams.push(Gram(reinterpret_cast(input[i-2].Data()), i)); + } - const unsigned char order_minus_2_; -}; + BlankManager blank(total_order, doing); + + while (true) { + Gram top = grams.top(); + grams.pop(); + unsigned char order = top.end - top.begin; + if (order == 1) { + blank.Visit(&unigram, 1, doing.UnigramProb(unigram)); + doing.Unigram(unigram); + progress.Set(unigram); + if (++unigram == unigram_count + 1) break; + grams.push(top); + } else { + if (order == total_order) { + blank.Visit(top.begin, order, reinterpret_cast(top.end)->prob); + doing.Longest(top.begin); + } else { + blank.Visit(top.begin, order, reinterpret_cast(top.end)->prob); + doing.Middle(order, top.begin); + } + RecordReader &reader = input[order - 2]; + if (++reader) grams.push(top); + } + } + assert(grams.empty()); + doing.Cleanup(); +} void SanityCheckCounts(const std::vector &initial, const std::vector &fixed) { if (fixed[0] != initial[0]) UTIL_THROW(util::Exception, "Unigram count should be constant but initial is " << initial[0] << " and recounted is " << fixed[0]); @@ -778,120 +412,122 @@ void SanityCheckCounts(const std::vector &initial, const std::vector void TrainQuantizer(uint8_t order, uint64_t count, SortedFileReader &reader, util::ErsatzProgress &progress, Quant &quant) { - ProbBackoff weights; - std::vector probs, backoffs; - probs.reserve(count); +template void TrainQuantizer(uint8_t order, uint64_t count, const std::vector &additional, RecordReader &reader, util::ErsatzProgress &progress, Quant &quant) { + std::vector probs(additional), backoffs; + probs.reserve(count + additional.size()); backoffs.reserve(count); - for (reader.Rewind(); !reader.Ended(); reader.NextHeader()) { - uint64_t entries = reader.ReadCount(); - for (uint64_t c = 0; c < entries; ++c) { - reader.ReadWord(); - reader.ReadWeights(weights); - // kBlankProb isn't added yet. - probs.push_back(weights.prob); - if (weights.backoff != 0.0) backoffs.push_back(weights.backoff); - ++progress; - } + for (reader.Rewind(); reader; ++reader) { + const ProbBackoff &weights = *reinterpret_cast(reinterpret_cast(reader.Data()) + sizeof(WordIndex) * order); + probs.push_back(weights.prob); + if (weights.backoff != 0.0) backoffs.push_back(weights.backoff); + ++progress; } quant.Train(order, probs, backoffs); } -template void TrainProbQuantizer(uint8_t order, uint64_t count, SortedFileReader &reader, util::ErsatzProgress &progress, Quant &quant) { - Prob weights; +template void TrainProbQuantizer(uint8_t order, uint64_t count, RecordReader &reader, util::ErsatzProgress &progress, Quant &quant) { std::vector probs, backoffs; probs.reserve(count); - for (reader.Rewind(); !reader.Ended(); reader.NextHeader()) { - uint64_t entries = reader.ReadCount(); - for (uint64_t c = 0; c < entries; ++c) { - reader.ReadWord(); - reader.ReadWeights(weights); - // kBlankProb isn't added yet. - probs.push_back(weights.prob); - ++progress; - } + for (reader.Rewind(); reader; ++reader) { + const Prob &weights = *reinterpret_cast(reinterpret_cast(reader.Data()) + sizeof(WordIndex) * order); + probs.push_back(weights.prob); + ++progress; } quant.TrainProb(order, probs); } +void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &contexts, UnigramValue *unigrams) { + // Fill unigram probabilities. + try { + rewind(file); + for (WordIndex i = 0; i < unigram_count; ++i) { + ReadOrThrow(file, &unigrams[i].weights, sizeof(ProbBackoff)); + if (contexts && *reinterpret_cast(contexts.Data()) == i) { + SetExtension(unigrams[i].weights.backoff); + ++contexts; + } + } + } catch (util::Exception &e) { + e << " while re-reading unigram probabilities"; + throw; + } +} + } // namespace template void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing) { - std::vector inputs(counts.size() - 1); - std::vector contexts(counts.size() - 1); + RecordReader inputs[kMaxOrder - 1]; + RecordReader contexts[kMaxOrder - 1]; for (unsigned char i = 2; i <= counts.size(); ++i) { std::stringstream assembled; assembled << file_prefix << static_cast(i) << "_merged"; - inputs[i-2].Init(assembled.str(), i); - RemoveOrThrow(assembled.str().c_str()); + inputs[i-2].Init(assembled.str(), i * sizeof(WordIndex) + (i == counts.size() ? sizeof(Prob) : sizeof(ProbBackoff))); + util::RemoveOrThrow(assembled.str().c_str()); assembled << kContextSuffix; - contexts[i-2].Reset(assembled.str().c_str(), i-1); - RemoveOrThrow(assembled.str().c_str()); + contexts[i-2].Init(assembled.str(), (i-1) * sizeof(WordIndex)); + util::RemoveOrThrow(assembled.str().c_str()); } + SRISucks sri; std::vector fixed_counts(counts.size()); { - RecursiveInsert counter(&*inputs.begin(), &*contexts.begin(), NULL, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); - counter.Apply(config.messages, "Counting n-grams that should not have been pruned", counts[0]); + std::string temp(file_prefix); temp += "unigrams"; + util::scoped_fd unigram_file(util::OpenReadOrThrow(temp.c_str())); + util::scoped_memory unigrams; + MapRead(util::POPULATE_OR_READ, unigram_file.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams); + FindBlanks finder(&*fixed_counts.begin(), counts.size(), reinterpret_cast(unigrams.get()), sri); + RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Identifying n-grams omitted by SRI", finder); } - for (std::vector::const_iterator i = inputs.begin(); i != inputs.end(); ++i) { - if (!i->Ended()) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs.begin() + 2) << "-gram table did not complete reading"); + for (const RecordReader *i = inputs; i != inputs + counts.size() - 2; ++i) { + if (*i) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs + 2) << "-gram table did not complete reading"); } SanityCheckCounts(counts, fixed_counts); counts = fixed_counts; + util::scoped_FILE unigram_file; + { + std::string name(file_prefix + "unigrams"); + unigram_file.reset(OpenOrThrow(name.c_str(), "r")); + util::RemoveOrThrow(name.c_str()); + } + sri.ObtainBackoffs(counts.size(), unigram_file.get(), inputs); + out.SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), TrieSearch::Size(fixed_counts, config), backing), fixed_counts, config); + for (unsigned char i = 2; i <= counts.size(); ++i) { + inputs[i-2].Rewind(); + } if (Quant::kTrain) { util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0)); for (unsigned char i = 2; i < counts.size(); ++i) { - TrainQuantizer(i, counts[i-1], inputs[i-2], progress, quant); + TrainQuantizer(i, counts[i-1], sri.Values(i), inputs[i-2], progress, quant); } TrainProbQuantizer(counts.size(), counts.back(), inputs[counts.size() - 2], progress, quant); quant.FinishedLoading(config); } + UnigramValue *unigrams = out.unigram.Raw(); + PopulateUnigramWeights(unigram_file.get(), counts[0], contexts[0], unigrams); + unigram_file.reset(); + for (unsigned char i = 2; i <= counts.size(); ++i) { inputs[i-2].Rewind(); } - UnigramValue *unigrams = out.unigram.Raw(); // Fill entries except unigram probabilities. { - RecursiveInsert > inserter(&*inputs.begin(), &*contexts.begin(), unigrams, out.middle_begin_, out.longest, &*fixed_counts.begin(), counts.size()); - inserter.Apply(config.messages, "Building trie", fixed_counts[0]); - } - - // Fill unigram probabilities. - try { - std::string name(file_prefix + "unigrams"); - util::scoped_FILE file(OpenOrThrow(name.c_str(), "r")); - for (WordIndex i = 0; i < counts[0]; ++i) { - ReadOrThrow(file.get(), &unigrams[i].weights, sizeof(ProbBackoff)); - if (contexts[0] && **contexts[0] == i) { - SetExtension(unigrams[i].weights.backoff); - ++contexts[0]; - } - } - RemoveOrThrow(name.c_str()); - } catch (util::Exception &e) { - e << " while re-reading unigram probabilities"; - throw; + WriteEntries writer(contexts, unigrams, out.middle_begin_, out.longest, counts.size(), sri); + RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Writing trie", writer); } // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation. for (unsigned char order = 2; order <= counts.size(); ++order) { - const ContextReader &context = contexts[order - 2]; + const RecordReader &context = contexts[order - 2]; if (context) { FormatLoadException e; - e << "An " << static_cast(order) << "-gram has the context (i.e. all but the last word):"; - for (const WordIndex *i = *context; i != *context + order - 1; ++i) { + e << "An " << static_cast(order) << "-gram has context"; + const WordIndex *ctx = reinterpret_cast(context.Data()); + for (const WordIndex *i = ctx; i != ctx + order - 1; ++i) { e << ' ' << *i; } e << " so this context must appear in the model as a " << static_cast(order - 1) << "-gram but it does not"; @@ -945,6 +581,14 @@ template void TrieSearch::LoadedBin longest.LoadedBinary(); } +namespace { +bool IsDirectory(const char *path) { + struct stat info; + if (0 != stat(path, &info)) return false; + return S_ISDIR(info.st_mode); +} +} // namespace + template void TrieSearch::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing) { std::string temporary_directory; if (config.temporary_directory_prefix) { diff --git a/klm/lm/search_trie.hh b/klm/lm/search_trie.hh index 2f39c09f..c3e02a98 100644 --- a/klm/lm/search_trie.hh +++ b/klm/lm/search_trie.hh @@ -1,10 +1,16 @@ #ifndef LM_SEARCH_TRIE__ #define LM_SEARCH_TRIE__ -#include "lm/binary_format.hh" +#include "lm/config.hh" +#include "lm/model_type.hh" +#include "lm/return.hh" #include "lm/trie.hh" #include "lm/weights.hh" +#include "util/file_piece.hh" + +#include + #include namespace lm { @@ -30,6 +36,8 @@ template class TrieSearch { static const ModelType kModelType = static_cast(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd); + static const unsigned int kVersion = 0; + static void UpdateConfigFromBinary(int fd, const std::vector &counts, Config &config) { Quant::UpdateConfigFromBinary(fd, counts, config); AdvanceOrThrow(fd, Quant::Size(counts.size(), config) + Unigram::Size(counts[0])); @@ -57,12 +65,16 @@ template class TrieSearch { void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector &counts, const Config &config, SortedVocabulary &vocab, Backing &backing); - void LookupUnigram(WordIndex word, float &prob, float &backoff, Node &node) const { - unigram.Find(word, prob, backoff, node); + void LookupUnigram(WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const { + unigram.Find(word, ret.prob, backoff, node); + ret.independent_left = (node.begin == node.end); + ret.extend_left = static_cast(word); } - bool LookupMiddle(const Middle &mid, WordIndex word, float &prob, float &backoff, Node &node) const { - return mid.Find(word, prob, backoff, node); + bool LookupMiddle(const Middle &mid, WordIndex word, float &backoff, Node &node, FullScoreReturn &ret) const { + if (!mid.Find(word, ret.prob, backoff, node, ret.extend_left)) return false; + ret.independent_left = (node.begin == node.end); + return true; } bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const { @@ -76,14 +88,25 @@ template class TrieSearch { bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { // TODO: don't decode backoff. assert(begin != end); - float ignored_prob, ignored_backoff; - LookupUnigram(*begin, ignored_prob, ignored_backoff, node); + FullScoreReturn ignored; + float ignored_backoff; + LookupUnigram(*begin, ignored_backoff, node, ignored); for (const WordIndex *i = begin + 1; i < end; ++i) { if (!LookupMiddleNoProb(middle_begin_[i - begin - 1], *i, ignored_backoff, node)) return false; } return true; } + Node Unpack(uint64_t extend_pointer, unsigned char extend_length, float &prob) const { + if (extend_length == 1) { + float ignored; + Node ret; + unigram.Find(static_cast(extend_pointer), prob, ignored, ret); + return ret; + } + return middle_begin_[extend_length - 2].ReadEntry(extend_pointer, prob); + } + private: friend void BuildTrie(const std::string &file_prefix, std::vector &counts, const Config &config, TrieSearch &out, Quant &quant, const SortedVocabulary &vocab, Backing &backing); diff --git a/klm/lm/trie.cc b/klm/lm/trie.cc index 8c536e66..4e60b184 100644 --- a/klm/lm/trie.cc +++ b/klm/lm/trie.cc @@ -86,7 +86,7 @@ template void BitPackedMiddle::Inse ++insert_index_; } -template bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const { +template bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange &range, uint64_t &pointer) const { uint64_t at_pointer; if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) { return false; @@ -94,6 +94,9 @@ template bool BitPackedMiddle::Find uint64_t index = at_pointer; at_pointer *= total_bits_; at_pointer += word_bits_; + + pointer = at_pointer; + quant_.Read(base_, at_pointer, prob, backoff); at_pointer += quant_.TotalBits(); diff --git a/klm/lm/trie.hh b/klm/lm/trie.hh index 53612064..a9f5e417 100644 --- a/klm/lm/trie.hh +++ b/klm/lm/trie.hh @@ -94,10 +94,18 @@ template class BitPackedMiddle : public BitPacked { void LoadedBinary() { bhiksha_.LoadedBinary(); } - bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const; + bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range, uint64_t &pointer) const; bool FindNoProb(WordIndex word, float &backoff, NodeRange &range) const; + NodeRange ReadEntry(uint64_t pointer, float &prob) { + quant_.ReadProb(base_, pointer, prob); + NodeRange ret; + // pointer/total_bits_ should always round down. + bhiksha_.ReadNext(base_, pointer + quant_.TotalBits(), pointer / total_bits_, total_bits_, ret); + return ret; + } + private: Quant quant_; Bhiksha bhiksha_; diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc new file mode 100644 index 00000000..01c4e490 --- /dev/null +++ b/klm/lm/trie_sort.cc @@ -0,0 +1,261 @@ +#include "lm/trie_sort.hh" + +#include "lm/config.hh" +#include "lm/lm_exception.hh" +#include "lm/read_arpa.hh" +#include "lm/vocab.hh" +#include "lm/weights.hh" +#include "lm/word_index.hh" +#include "util/file_piece.hh" +#include "util/mmap.hh" +#include "util/proxy_iterator.hh" +#include "util/sized_iterator.hh" + +#include +#include +#include +#include +#include +#include + +namespace lm { +namespace ngram { +namespace trie { + +const char *kContextSuffix = "_contexts"; + +FILE *OpenOrThrow(const char *name, const char *mode) { + FILE *ret = fopen(name, mode); + if (!ret) UTIL_THROW(util::ErrnoException, "Could not open " << name << " for " << mode); + return ret; +} + +void WriteOrThrow(FILE *to, const void *data, size_t size) { + assert(size); + if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size); +} + +namespace { + +typedef util::SizedIterator NGramIter; + +// Proxy for an entry except there is some extra cruft between the entries. This is used to sort (n-1)-grams using the same memory as the sorted n-grams. +class PartialViewProxy { + public: + PartialViewProxy() : attention_size_(0), inner_() {} + + PartialViewProxy(void *ptr, std::size_t block_size, std::size_t attention_size) : attention_size_(attention_size), inner_(ptr, block_size) {} + + operator std::string() const { + return std::string(reinterpret_cast(inner_.Data()), attention_size_); + } + + PartialViewProxy &operator=(const PartialViewProxy &from) { + memcpy(inner_.Data(), from.inner_.Data(), attention_size_); + return *this; + } + + PartialViewProxy &operator=(const std::string &from) { + memcpy(inner_.Data(), from.data(), attention_size_); + return *this; + } + + const void *Data() const { return inner_.Data(); } + void *Data() { return inner_.Data(); } + + private: + friend class util::ProxyIterator; + + typedef std::string value_type; + + const std::size_t attention_size_; + + typedef util::SizedInnerIterator InnerIterator; + InnerIterator &Inner() { return inner_; } + const InnerIterator &Inner() const { return inner_; } + InnerIterator inner_; +}; + +typedef util::ProxyIterator PartialIter; + +std::string DiskFlush(const void *mem_begin, const void *mem_end, const std::string &file_prefix, std::size_t batch, unsigned char order) { + std::stringstream assembled; + assembled << file_prefix << static_cast(order) << '_' << batch; + std::string ret(assembled.str()); + util::scoped_fd out(util::CreateOrThrow(ret.c_str())); + util::WriteOrThrow(out.get(), mem_begin, (uint8_t*)mem_end - (uint8_t*)mem_begin); + return ret; +} + +void WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &ngram_file_name, std::size_t entry_size, unsigned char order) { + const size_t context_size = sizeof(WordIndex) * (order - 1); + // Sort just the contexts using the same memory. + PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size)); + PartialIter context_end(PartialViewProxy(end + sizeof(WordIndex), entry_size, context_size)); + + std::sort(context_begin, context_end, util::SizedCompare(EntryCompare(order - 1))); + + std::string name(ngram_file_name + kContextSuffix); + util::scoped_FILE out(OpenOrThrow(name.c_str(), "w")); + + // Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator. + if (context_begin == context_end) return; + PartialIter i(context_begin); + WriteOrThrow(out.get(), i->Data(), context_size); + const void *previous = i->Data(); + ++i; + for (; i != context_end; ++i) { + if (memcmp(previous, i->Data(), context_size)) { + WriteOrThrow(out.get(), i->Data(), context_size); + previous = i->Data(); + } + } +} + +struct ThrowCombine { + void operator()(std::size_t /*entry_size*/, const void * /*first*/, const void * /*second*/, FILE * /*out*/) const { + UTIL_THROW(FormatLoadException, "Duplicate n-gram detected."); + } +}; + +// Useful for context files that just contain records with no value. +struct FirstCombine { + void operator()(std::size_t entry_size, const void *first, const void * /*second*/, FILE *out) const { + WriteOrThrow(out, first, entry_size); + } +}; + +template void MergeSortedFiles(const std::string &first_name, const std::string &second_name, const std::string &out, std::size_t weights_size, unsigned char order, const Combine &combine = ThrowCombine()) { + std::size_t entry_size = sizeof(WordIndex) * order + weights_size; + RecordReader first, second; + first.Init(first_name.c_str(), entry_size); + util::RemoveOrThrow(first_name.c_str()); + second.Init(second_name.c_str(), entry_size); + util::RemoveOrThrow(second_name.c_str()); + util::scoped_FILE out_file(OpenOrThrow(out.c_str(), "w")); + EntryCompare less(order); + while (first && second) { + if (less(first.Data(), second.Data())) { + WriteOrThrow(out_file.get(), first.Data(), entry_size); + ++first; + } else if (less(second.Data(), first.Data())) { + WriteOrThrow(out_file.get(), second.Data(), entry_size); + ++second; + } else { + combine(entry_size, first.Data(), second.Data(), out_file.get()); + ++first; ++second; + } + } + for (RecordReader &remains = (first ? second : first); remains; ++remains) { + WriteOrThrow(out_file.get(), remains.Data(), entry_size); + } +} + +void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector &counts, util::scoped_memory &mem, const std::string &file_prefix, unsigned char order, PositiveProbWarn &warn) { + ReadNGramHeader(f, order); + const size_t count = counts[order - 1]; + // Size of weights. Does it include backoff? + const size_t words_size = sizeof(WordIndex) * order; + const size_t weights_size = sizeof(float) + ((order == counts.size()) ? 0 : sizeof(float)); + const size_t entry_size = words_size + weights_size; + const size_t batch_size = std::min(count, mem.size() / entry_size); + uint8_t *const begin = reinterpret_cast(mem.get()); + std::deque files; + for (std::size_t batch = 0, done = 0; done < count; ++batch) { + uint8_t *out = begin; + uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size; + if (order == counts.size()) { + for (; out != out_end; out += entry_size) { + ReadNGram(f, order, vocab, reinterpret_cast(out), *reinterpret_cast(out + words_size), warn); + } + } else { + for (; out != out_end; out += entry_size) { + ReadNGram(f, order, vocab, reinterpret_cast(out), *reinterpret_cast(out + words_size), warn); + } + } + // Sort full records by full n-gram. + util::SizedProxy proxy_begin(begin, entry_size), proxy_end(out_end, entry_size); + // parallel_sort uses too much RAM + std::sort(NGramIter(proxy_begin), NGramIter(proxy_end), util::SizedCompare(EntryCompare(order))); + files.push_back(DiskFlush(begin, out_end, file_prefix, batch, order)); + WriteContextFile(begin, out_end, files.back(), entry_size, order); + + done += (out_end - begin) / entry_size; + } + + // All individual files created. Merge them. + + std::size_t merge_count = 0; + while (files.size() > 1) { + std::stringstream assembled; + assembled << file_prefix << static_cast(order) << "_merge_" << (merge_count++); + files.push_back(assembled.str()); + MergeSortedFiles(files[0], files[1], files.back(), weights_size, order, ThrowCombine()); + MergeSortedFiles(files[0], files[1], files.back(), 0, order, FirstCombine()); + files.pop_front(); + files.pop_front(); + } + if (!files.empty()) { + std::stringstream assembled; + assembled << file_prefix << static_cast(order) << "_merged"; + std::string merged_name(assembled.str()); + if (std::rename(files[0].c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << files[0].c_str() << " to " << merged_name.c_str()); + std::string context_name = files[0] + kContextSuffix; + merged_name += kContextSuffix; + if (std::rename(context_name.c_str(), merged_name.c_str())) UTIL_THROW(util::ErrnoException, "Could not rename " << context_name << " to " << merged_name.c_str()); + } +} + +} // namespace + +void RecordReader::Init(const std::string &name, std::size_t entry_size) { + file_.reset(OpenOrThrow(name.c_str(), "r+")); + data_.reset(malloc(entry_size)); + UTIL_THROW_IF(!data_.get(), util::ErrnoException, "Failed to malloc read buffer"); + remains_ = true; + entry_size_ = entry_size; + ++*this; +} + +void RecordReader::Overwrite(const void *start, std::size_t amount) { + long internal = (uint8_t*)start - (uint8_t*)data_.get(); + UTIL_THROW_IF(fseek(file_.get(), internal - entry_size_, SEEK_CUR), util::ErrnoException, "Couldn't seek backwards for revision"); + WriteOrThrow(file_.get(), start, amount); + long forward = entry_size_ - internal - amount; + if (forward) UTIL_THROW_IF(fseek(file_.get(), forward, SEEK_CUR), util::ErrnoException, "Couldn't seek forwards past revision"); +} + +void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { + PositiveProbWarn warn(config.positive_log_probability); + { + std::string unigram_name = file_prefix + "unigrams"; + util::scoped_fd unigram_file; + // In case appears. + size_t file_out = (counts[0] + 1) * sizeof(ProbBackoff); + util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_name.c_str(), file_out, unigram_file), file_out); + Read1Grams(f, counts[0], vocab, reinterpret_cast(unigram_mmap.get()), warn); + CheckSpecials(config, vocab); + if (!vocab.SawUnk()) ++counts[0]; + } + + // Only use as much buffer as we need. + size_t buffer_use = 0; + for (unsigned int order = 2; order < counts.size(); ++order) { + buffer_use = std::max(buffer_use, static_cast((sizeof(WordIndex) * order + 2 * sizeof(float)) * counts[order - 1])); + } + buffer_use = std::max(buffer_use, static_cast((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back())); + buffer = std::min(buffer, buffer_use); + + util::scoped_memory mem; + mem.reset(malloc(buffer), buffer, util::scoped_memory::MALLOC_ALLOCATED); + if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer); + + for (unsigned char order = 2; order <= counts.size(); ++order) { + ConvertToSorted(f, vocab, counts, mem, file_prefix, order, warn); + } + ReadEnd(f); +} + +} // namespace trie +} // namespace ngram +} // namespace lm diff --git a/klm/lm/trie_sort.hh b/klm/lm/trie_sort.hh new file mode 100644 index 00000000..a6916483 --- /dev/null +++ b/klm/lm/trie_sort.hh @@ -0,0 +1,94 @@ +#ifndef LM_TRIE_SORT__ +#define LM_TRIE_SORT__ + +#include "lm/word_index.hh" + +#include "util/file.hh" +#include "util/scoped.hh" + +#include +#include +#include +#include + +#include + +namespace util { class FilePiece; } + +// Step of trie builder: create sorted files. +namespace lm { +namespace ngram { +class SortedVocabulary; +class Config; + +namespace trie { + +extern const char *kContextSuffix; +FILE *OpenOrThrow(const char *name, const char *mode); +void WriteOrThrow(FILE *to, const void *data, size_t size); + +class EntryCompare : public std::binary_function { + public: + explicit EntryCompare(unsigned char order) : order_(order) {} + + bool operator()(const void *first_void, const void *second_void) const { + const WordIndex *first = static_cast(first_void); + const WordIndex *second = static_cast(second_void); + const WordIndex *end = first + order_; + for (; first != end; ++first, ++second) { + if (*first < *second) return true; + if (*first > *second) return false; + } + return false; + } + private: + unsigned char order_; +}; + +class RecordReader { + public: + RecordReader() : remains_(true) {} + + void Init(const std::string &name, std::size_t entry_size); + + void *Data() { return data_.get(); } + const void *Data() const { return data_.get(); } + + RecordReader &operator++() { + std::size_t ret = fread(data_.get(), entry_size_, 1, file_.get()); + if (!ret) { + UTIL_THROW_IF(!feof(file_.get()), util::ErrnoException, "Error reading temporary file"); + remains_ = false; + } + return *this; + } + + operator bool() const { return remains_; } + + void Rewind() { + rewind(file_.get()); + remains_ = true; + ++*this; + } + + std::size_t EntrySize() const { return entry_size_; } + + void Overwrite(const void *start, std::size_t amount); + + private: + util::scoped_malloc data_; + + bool remains_; + + std::size_t entry_size_; + + util::scoped_FILE file_; +}; + +void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab); + +} // namespace trie +} // namespace ngram +} // namespace lm + +#endif // LM_TRIE_SORT__ diff --git a/klm/lm/virtual_interface.hh b/klm/lm/virtual_interface.hh index 08627efd..6a5a0196 100644 --- a/klm/lm/virtual_interface.hh +++ b/klm/lm/virtual_interface.hh @@ -1,37 +1,13 @@ #ifndef LM_VIRTUAL_INTERFACE__ #define LM_VIRTUAL_INTERFACE__ +#include "lm/return.hh" #include "lm/word_index.hh" #include "util/string_piece.hh" #include namespace lm { - -/* Structure returned by scoring routines. */ -struct FullScoreReturn { - // log10 probability - float prob; - - /* The length of n-gram matched. Do not use this for recombination. - * Consider a model containing only the following n-grams: - * -1 foo - * -3.14 bar - * -2.718 baz -5 - * -6 foo bar - * - * If you score ``bar'' then ngram_length is 1 and recombination state is the - * empty string because bar has zero backoff and does not extend to the - * right. - * If you score ``foo'' then ngram_length is 1 and recombination state is - * ``foo''. - * - * Ideally, keep output states around and compare them. Failing that, - * get out_state.ValidLength() and use that length for recombination. - */ - unsigned char ngram_length; -}; - namespace base { template class ModelFacade; diff --git a/klm/lm/vocab.cc b/klm/lm/vocab.cc index 04979d51..03b0767a 100644 --- a/klm/lm/vocab.cc +++ b/klm/lm/vocab.cc @@ -1,5 +1,6 @@ #include "lm/vocab.hh" +#include "lm/binary_format.hh" #include "lm/enumerate_vocab.hh" #include "lm/lm_exception.hh" #include "lm/config.hh" @@ -56,16 +57,6 @@ WordIndex ReadWords(int fd, EnumerateVocab *enumerate) { } } -void WriteOrThrow(int fd, const void *data_void, std::size_t size) { - const uint8_t *data = static_cast(data_void); - while (size) { - ssize_t ret = write(fd, data, size); - if (ret < 1) UTIL_THROW(util::ErrnoException, "Write failed"); - data += ret; - size -= ret; - } -} - } // namespace WriteWordsWrapper::WriteWordsWrapper(EnumerateVocab *inner) : inner_(inner) {} @@ -80,7 +71,7 @@ void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) { void WriteWordsWrapper::Write(int fd) { if ((off_t)-1 == lseek(fd, 0, SEEK_END)) UTIL_THROW(util::ErrnoException, "Failed to seek in binary to vocab words"); - WriteOrThrow(fd, buffer_.data(), buffer_.size()); + util::WriteOrThrow(fd, buffer_.data(), buffer_.size()); } SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {} @@ -146,15 +137,28 @@ void SortedVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { SetSpecial(Index(""), Index(""), 0); } +namespace { +const unsigned int kProbingVocabularyVersion = 0; +} // namespace + +namespace detail { +struct ProbingVocabularyHeader { + // Lowest unused vocab id. This is also the number of words, including . + unsigned int version; + WordIndex bound; +}; +} // namespace detail + ProbingVocabulary::ProbingVocabulary() : enumerate_(NULL) {} std::size_t ProbingVocabulary::Size(std::size_t entries, const Config &config) { - return Lookup::Size(entries, config.probing_multiplier); + return Align8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, config.probing_multiplier); } void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t /*entries*/, const Config &/*config*/) { - lookup_ = Lookup(start, allocated); - available_ = 1; + header_ = static_cast(start); + lookup_ = Lookup(static_cast(start) + Align8(sizeof(detail::ProbingVocabularyHeader)), allocated); + bound_ = 1; saw_unk_ = false; } @@ -172,20 +176,24 @@ WordIndex ProbingVocabulary::Insert(const StringPiece &str) { saw_unk_ = true; return 0; } else { - if (enumerate_) enumerate_->Add(available_, str); - lookup_.Insert(Lookup::Packing::Make(hashed, available_)); - return available_++; + if (enumerate_) enumerate_->Add(bound_, str); + lookup_.Insert(Lookup::Packing::Make(hashed, bound_)); + return bound_++; } } void ProbingVocabulary::FinishedLoading(ProbBackoff * /*reorder_vocab*/) { lookup_.FinishedInserting(); + header_->bound = bound_; + header_->version = kProbingVocabularyVersion; SetSpecial(Index(""), Index(""), 0); } void ProbingVocabulary::LoadedBinary(int fd, EnumerateVocab *to) { + UTIL_THROW_IF(header_->version != kProbingVocabularyVersion, FormatLoadException, "The binary file has probing version " << header_->version << " but the code expects version " << kProbingVocabularyVersion << ". Please rerun build_binary using the same version of the code."); lookup_.LoadedBinary(); - available_ = ReadWords(fd, to); + ReadWords(fd, to); + bound_ = header_->bound; SetSpecial(Index(""), Index(""), 0); } diff --git a/klm/lm/vocab.hh b/klm/lm/vocab.hh index 9d218fff..41e97052 100644 --- a/klm/lm/vocab.hh +++ b/klm/lm/vocab.hh @@ -25,6 +25,7 @@ uint64_t HashForVocab(const char *str, std::size_t len); inline uint64_t HashForVocab(const StringPiece &str) { return HashForVocab(str.data(), str.length()); } +class ProbingVocabularyHeader; } // namespace detail class WriteWordsWrapper : public EnumerateVocab { @@ -113,10 +114,7 @@ class ProbingVocabulary : public base::Vocabulary { static size_t Size(std::size_t entries, const Config &config); // Vocab words are [0, Bound()). - // WARNING WARNING: returns UINT_MAX when loading binary and not enumerating vocabulary. - // Fixing this bug requires a binary file format change and will be fixed with the next binary file format update. - // Specifically, the binary file format does not currently indicate whether is in count or not. - WordIndex Bound() const { return available_; } + WordIndex Bound() const { return bound_; } // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config); @@ -141,11 +139,13 @@ class ProbingVocabulary : public base::Vocabulary { Lookup lookup_; - WordIndex available_; + WordIndex bound_; bool saw_unk_; EnumerateVocab *enumerate_; + + detail::ProbingVocabularyHeader *header_; }; void MissingUnknown(const Config &config) throw(SpecialWordMissingException); diff --git a/klm/test.sh b/klm/test.sh index d02a3dc9..fb33300a 100755 --- a/klm/test.sh +++ b/klm/test.sh @@ -2,7 +2,7 @@ #Run tests. Requires Boost. set -e ./compile.sh -for i in util/{bit_packing,file_piece,joint_sort,key_value_packing,probing_hash_table,sorted_uniform}_test lm/model_test; do +for i in util/{bit_packing,file_piece,joint_sort,key_value_packing,probing_hash_table,sorted_uniform}_test lm/{model,left}_test; do g++ -I. -O3 $CXXFLAGS $i.cc {lm,util}/*.o -lboost_test_exec_monitor -lz -o $i pushd $(dirname $i) >/dev/null && ./$(basename $i) || echo "$i failed"; popd >/dev/null done diff --git a/klm/util/bit_packing.hh b/klm/util/bit_packing.hh index 9f47d559..33266b94 100644 --- a/klm/util/bit_packing.hh +++ b/klm/util/bit_packing.hh @@ -86,6 +86,20 @@ inline void WriteFloat32(void *base, uint64_t bit_off, float value) { const uint32_t kSignBit = 0x80000000; +inline void SetSign(float &to) { + FloatEnc enc; + enc.f = to; + enc.i |= kSignBit; + to = enc.f; +} + +inline void UnsetSign(float &to) { + FloatEnc enc; + enc.f = to; + enc.i &= ~kSignBit; + to = enc.f; +} + inline float ReadNonPositiveFloat31(const void *base, uint64_t bit_off) { FloatEnc encoded; encoded.i = ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, 31); diff --git a/klm/util/exception.cc b/klm/util/exception.cc index 62280970..96951495 100644 --- a/klm/util/exception.cc +++ b/klm/util/exception.cc @@ -79,4 +79,9 @@ ErrnoException::ErrnoException() throw() : errno_(errno) { ErrnoException::~ErrnoException() throw() {} +EndOfFileException::EndOfFileException() throw() { + *this << "End of file"; +} +EndOfFileException::~EndOfFileException() throw() {} + } // namespace util diff --git a/klm/util/exception.hh b/klm/util/exception.hh index 81675a57..6d6a37cb 100644 --- a/klm/util/exception.hh +++ b/klm/util/exception.hh @@ -105,6 +105,12 @@ class ErrnoException : public Exception { int errno_; }; +class EndOfFileException : public Exception { + public: + EndOfFileException() throw(); + ~EndOfFileException() throw(); +}; + } // namespace util #endif // UTIL_EXCEPTION__ diff --git a/klm/util/file.cc b/klm/util/file.cc new file mode 100644 index 00000000..d707568e --- /dev/null +++ b/klm/util/file.cc @@ -0,0 +1,74 @@ +#include "util/file.hh" + +#include "util/exception.hh" + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace util { + +scoped_fd::~scoped_fd() { + if (fd_ != -1 && close(fd_)) { + std::cerr << "Could not close file " << fd_ << std::endl; + std::abort(); + } +} + +scoped_FILE::~scoped_FILE() { + if (file_ && std::fclose(file_)) { + std::cerr << "Could not close file " << std::endl; + std::abort(); + } +} + +int OpenReadOrThrow(const char *name) { + int ret; + UTIL_THROW_IF(-1 == (ret = open(name, O_RDONLY)), ErrnoException, "while opening " << name); + return ret; +} + +int CreateOrThrow(const char *name) { + int ret; + UTIL_THROW_IF(-1 == (ret = open(name, O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR)), ErrnoException, "while creating " << name); + return ret; +} + +off_t SizeFile(int fd) { + struct stat sb; + if (fstat(fd, &sb) == -1 || (!sb.st_size && !S_ISREG(sb.st_mode))) return kBadSize; + return sb.st_size; +} + +void ReadOrThrow(int fd, void *to_void, std::size_t amount) { + uint8_t *to = static_cast(to_void); + while (amount) { + ssize_t ret = read(fd, to, amount); + if (ret == -1) UTIL_THROW(ErrnoException, "Reading " << amount << " from fd " << fd << " failed."); + if (ret == 0) UTIL_THROW(Exception, "Hit EOF in fd " << fd << " but there should be " << amount << " more bytes to read."); + amount -= ret; + to += ret; + } +} + +void WriteOrThrow(int fd, const void *data_void, std::size_t size) { + const uint8_t *data = static_cast(data_void); + while (size) { + ssize_t ret = write(fd, data, size); + if (ret < 1) UTIL_THROW(util::ErrnoException, "Write failed"); + data += ret; + size -= ret; + } +} + +void RemoveOrThrow(const char *name) { + UTIL_THROW_IF(std::remove(name), util::ErrnoException, "Could not remove " << name); +} + +} // namespace util diff --git a/klm/util/file.hh b/klm/util/file.hh new file mode 100644 index 00000000..d6cca41d --- /dev/null +++ b/klm/util/file.hh @@ -0,0 +1,74 @@ +#ifndef UTIL_FILE__ +#define UTIL_FILE__ + +#include +#include + +namespace util { + +class scoped_fd { + public: + scoped_fd() : fd_(-1) {} + + explicit scoped_fd(int fd) : fd_(fd) {} + + ~scoped_fd(); + + void reset(int to) { + scoped_fd other(fd_); + fd_ = to; + } + + int get() const { return fd_; } + + int operator*() const { return fd_; } + + int release() { + int ret = fd_; + fd_ = -1; + return ret; + } + + operator bool() { return fd_ != -1; } + + private: + int fd_; + + scoped_fd(const scoped_fd &); + scoped_fd &operator=(const scoped_fd &); +}; + +class scoped_FILE { + public: + explicit scoped_FILE(std::FILE *file = NULL) : file_(file) {} + + ~scoped_FILE(); + + std::FILE *get() { return file_; } + const std::FILE *get() const { return file_; } + + void reset(std::FILE *to = NULL) { + scoped_FILE other(file_); + file_ = to; + } + + private: + std::FILE *file_; +}; + +int OpenReadOrThrow(const char *name); + +int CreateOrThrow(const char *name); + +// Return value for SizeFile when it can't size properly. +const off_t kBadSize = -1; +off_t SizeFile(int fd); + +void ReadOrThrow(int fd, void *to, std::size_t size); +void WriteOrThrow(int fd, const void *data_void, std::size_t size); + +void RemoveOrThrow(const char *name); + +} // namespace util + +#endif // UTIL_FILE__ diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc index cbe4234f..b57582a0 100644 --- a/klm/util/file_piece.cc +++ b/klm/util/file_piece.cc @@ -1,6 +1,7 @@ #include "util/file_piece.hh" #include "util/exception.hh" +#include "util/file.hh" #include #include @@ -21,11 +22,6 @@ namespace util { -EndOfFileException::EndOfFileException() throw() { - *this << "End of file"; -} -EndOfFileException::~EndOfFileException() throw() {} - ParseNumberException::ParseNumberException(StringPiece value) throw() { *this << "Could not parse \"" << value << "\" into a number"; } @@ -40,18 +36,6 @@ GZException::GZException(void *file) { // Sigh this is the only way I could come up with to do a _const_ bool. It has ' ', '\f', '\n', '\r', '\t', and '\v' (same as isspace on C locale). const bool kSpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0}; -int OpenReadOrThrow(const char *name) { - int ret; - UTIL_THROW_IF(-1 == (ret = open(name, O_RDONLY)), ErrnoException, "while opening " << name); - return ret; -} - -off_t SizeFile(int fd) { - struct stat sb; - if (fstat(fd, &sb) == -1 || (!sb.st_size && !S_ISREG(sb.st_mode))) return kBadSize; - return sb.st_size; -} - FilePiece::FilePiece(const char *name, std::ostream *show_progress, off_t min_buffer) : file_(OpenReadOrThrow(name)), total_size_(SizeFile(file_.get())), page_(sysconf(_SC_PAGE_SIZE)), progress_(total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name, total_size_) { diff --git a/klm/util/file_piece.hh b/klm/util/file_piece.hh index a5c00910..a627f38c 100644 --- a/klm/util/file_piece.hh +++ b/klm/util/file_piece.hh @@ -3,9 +3,9 @@ #include "util/ersatz_progress.hh" #include "util/exception.hh" +#include "util/file.hh" #include "util/have.hh" #include "util/mmap.hh" -#include "util/scoped.hh" #include "util/string_piece.hh" #include @@ -14,12 +14,6 @@ namespace util { -class EndOfFileException : public Exception { - public: - EndOfFileException() throw(); - ~EndOfFileException() throw(); -}; - class ParseNumberException : public Exception { public: explicit ParseNumberException(StringPiece value) throw(); @@ -33,14 +27,8 @@ class GZException : public Exception { ~GZException() throw() {} }; -int OpenReadOrThrow(const char *name); - extern const bool kSpaces[256]; -// Return value for SizeFile when it can't size properly. -const off_t kBadSize = -1; -off_t SizeFile(int fd); - // Memory backing the returned StringPiece may vanish on the next call. class FilePiece { public: diff --git a/klm/util/mmap.cc b/klm/util/mmap.cc index e7c0643b..5ce7adc9 100644 --- a/klm/util/mmap.cc +++ b/klm/util/mmap.cc @@ -1,6 +1,6 @@ #include "util/exception.hh" +#include "util/file.hh" #include "util/mmap.hh" -#include "util/scoped.hh" #include @@ -66,20 +66,6 @@ void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, int return ret; } -namespace { -void ReadAll(int fd, void *to_void, std::size_t amount) { - uint8_t *to = static_cast(to_void); - while (amount) { - ssize_t ret = read(fd, to, amount); - if (ret == -1) UTIL_THROW(ErrnoException, "Reading " << amount << " from fd " << fd << " failed."); - if (ret == 0) UTIL_THROW(Exception, "Hit EOF in fd " << fd << " but there should be " << amount << " more bytes to read."); - amount -= ret; - to += ret; - } -} - -} // namespace - const int kFileFlags = #ifdef MAP_FILE MAP_FILE | MAP_SHARED @@ -106,7 +92,7 @@ void MapRead(LoadMethod method, int fd, off_t offset, std::size_t size, scoped_m out.reset(malloc(size), size, scoped_memory::MALLOC_ALLOCATED); if (!out.get()) UTIL_THROW(util::ErrnoException, "Allocating " << size << " bytes with malloc"); if (-1 == lseek(fd, offset, SEEK_SET)) UTIL_THROW(ErrnoException, "lseek to " << offset << " in fd " << fd << " failed."); - ReadAll(fd, out.get(), size); + ReadOrThrow(fd, out.get(), size); break; } } diff --git a/klm/util/mmap.hh b/klm/util/mmap.hh index e4439fa4..b0eb6672 100644 --- a/klm/util/mmap.hh +++ b/klm/util/mmap.hh @@ -2,8 +2,6 @@ #define UTIL_MMAP__ // Utilities for mmaped files. -#include "util/scoped.hh" - #include #include @@ -11,6 +9,8 @@ namespace util { +class scoped_fd; + // (void*)-1 is MAP_FAILED; this is done to avoid including the mmap header here. class scoped_mmap { public: diff --git a/klm/util/murmur_hash.cc b/klm/util/murmur_hash.cc index fec47fd9..d58a0727 100644 --- a/klm/util/murmur_hash.cc +++ b/klm/util/murmur_hash.cc @@ -1,129 +1,129 @@ -/* Downloaded from http://sites.google.com/site/murmurhash/ which says "All - * code is released to the public domain. For business purposes, Murmurhash is - * under the MIT license." - * This is modified from the original: - * ULL tag on 0xc6a4a7935bd1e995 so this will compile on 32-bit. - * length changed to unsigned int. - * placed in namespace util - * add MurmurHashNative - * default option = 0 for seed - */ - -#include "util/murmur_hash.hh" - -namespace util { - -//----------------------------------------------------------------------------- -// MurmurHash2, 64-bit versions, by Austin Appleby - -// The same caveats as 32-bit MurmurHash2 apply here - beware of alignment -// and endian-ness issues if used across multiple platforms. - -// 64-bit hash for 64-bit platforms - -uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed ) -{ - const uint64_t m = 0xc6a4a7935bd1e995ULL; - const int r = 47; - - uint64_t h = seed ^ (len * m); - - const uint64_t * data = (const uint64_t *)key; - const uint64_t * end = data + (len/8); - - while(data != end) - { - uint64_t k = *data++; - - k *= m; - k ^= k >> r; - k *= m; - - h ^= k; - h *= m; - } - - const unsigned char * data2 = (const unsigned char*)data; - - switch(len & 7) - { - case 7: h ^= uint64_t(data2[6]) << 48; - case 6: h ^= uint64_t(data2[5]) << 40; - case 5: h ^= uint64_t(data2[4]) << 32; - case 4: h ^= uint64_t(data2[3]) << 24; - case 3: h ^= uint64_t(data2[2]) << 16; - case 2: h ^= uint64_t(data2[1]) << 8; - case 1: h ^= uint64_t(data2[0]); - h *= m; - }; - - h ^= h >> r; - h *= m; - h ^= h >> r; - - return h; -} - - -// 64-bit hash for 32-bit platforms - -uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed ) -{ - const unsigned int m = 0x5bd1e995; - const int r = 24; - - unsigned int h1 = seed ^ len; - unsigned int h2 = 0; - - const unsigned int * data = (const unsigned int *)key; - - while(len >= 8) - { - unsigned int k1 = *data++; - k1 *= m; k1 ^= k1 >> r; k1 *= m; - h1 *= m; h1 ^= k1; - len -= 4; - - unsigned int k2 = *data++; - k2 *= m; k2 ^= k2 >> r; k2 *= m; - h2 *= m; h2 ^= k2; - len -= 4; - } - - if(len >= 4) - { - unsigned int k1 = *data++; - k1 *= m; k1 ^= k1 >> r; k1 *= m; - h1 *= m; h1 ^= k1; - len -= 4; - } - - switch(len) - { - case 3: h2 ^= ((unsigned char*)data)[2] << 16; - case 2: h2 ^= ((unsigned char*)data)[1] << 8; - case 1: h2 ^= ((unsigned char*)data)[0]; - h2 *= m; - }; - - h1 ^= h2 >> 18; h1 *= m; - h2 ^= h1 >> 22; h2 *= m; - h1 ^= h2 >> 17; h1 *= m; - h2 ^= h1 >> 19; h2 *= m; - - uint64_t h = h1; - - h = (h << 32) | h2; - - return h; -} - -uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) { - if (sizeof(int) == 4) { - return MurmurHash64B(key, len, seed); - } else { - return MurmurHash64A(key, len, seed); - } -} - -} // namespace util +/* Downloaded from http://sites.google.com/site/murmurhash/ which says "All + * code is released to the public domain. For business purposes, Murmurhash is + * under the MIT license." + * This is modified from the original: + * ULL tag on 0xc6a4a7935bd1e995 so this will compile on 32-bit. + * length changed to unsigned int. + * placed in namespace util + * add MurmurHashNative + * default option = 0 for seed + */ + +#include "util/murmur_hash.hh" + +namespace util { + +//----------------------------------------------------------------------------- +// MurmurHash2, 64-bit versions, by Austin Appleby + +// The same caveats as 32-bit MurmurHash2 apply here - beware of alignment +// and endian-ness issues if used across multiple platforms. + +// 64-bit hash for 64-bit platforms + +uint64_t MurmurHash64A ( const void * key, std::size_t len, unsigned int seed ) +{ + const uint64_t m = 0xc6a4a7935bd1e995ULL; + const int r = 47; + + uint64_t h = seed ^ (len * m); + + const uint64_t * data = (const uint64_t *)key; + const uint64_t * end = data + (len/8); + + while(data != end) + { + uint64_t k = *data++; + + k *= m; + k ^= k >> r; + k *= m; + + h ^= k; + h *= m; + } + + const unsigned char * data2 = (const unsigned char*)data; + + switch(len & 7) + { + case 7: h ^= uint64_t(data2[6]) << 48; + case 6: h ^= uint64_t(data2[5]) << 40; + case 5: h ^= uint64_t(data2[4]) << 32; + case 4: h ^= uint64_t(data2[3]) << 24; + case 3: h ^= uint64_t(data2[2]) << 16; + case 2: h ^= uint64_t(data2[1]) << 8; + case 1: h ^= uint64_t(data2[0]); + h *= m; + }; + + h ^= h >> r; + h *= m; + h ^= h >> r; + + return h; +} + + +// 64-bit hash for 32-bit platforms + +uint64_t MurmurHash64B ( const void * key, std::size_t len, unsigned int seed ) +{ + const unsigned int m = 0x5bd1e995; + const int r = 24; + + unsigned int h1 = seed ^ len; + unsigned int h2 = 0; + + const unsigned int * data = (const unsigned int *)key; + + while(len >= 8) + { + unsigned int k1 = *data++; + k1 *= m; k1 ^= k1 >> r; k1 *= m; + h1 *= m; h1 ^= k1; + len -= 4; + + unsigned int k2 = *data++; + k2 *= m; k2 ^= k2 >> r; k2 *= m; + h2 *= m; h2 ^= k2; + len -= 4; + } + + if(len >= 4) + { + unsigned int k1 = *data++; + k1 *= m; k1 ^= k1 >> r; k1 *= m; + h1 *= m; h1 ^= k1; + len -= 4; + } + + switch(len) + { + case 3: h2 ^= ((unsigned char*)data)[2] << 16; + case 2: h2 ^= ((unsigned char*)data)[1] << 8; + case 1: h2 ^= ((unsigned char*)data)[0]; + h2 *= m; + }; + + h1 ^= h2 >> 18; h1 *= m; + h2 ^= h1 >> 22; h2 *= m; + h1 ^= h2 >> 17; h1 *= m; + h2 ^= h1 >> 19; h2 *= m; + + uint64_t h = h1; + + h = (h << 32) | h2; + + return h; +} + +uint64_t MurmurHashNative(const void * key, std::size_t len, unsigned int seed) { + if (sizeof(int) == 4) { + return MurmurHash64B(key, len, seed); + } else { + return MurmurHash64A(key, len, seed); + } +} + +} // namespace util diff --git a/klm/util/scoped.cc b/klm/util/scoped.cc deleted file mode 100644 index a4cc5016..00000000 --- a/klm/util/scoped.cc +++ /dev/null @@ -1,24 +0,0 @@ -#include "util/scoped.hh" - -#include - -#include -#include - -namespace util { - -scoped_fd::~scoped_fd() { - if (fd_ != -1 && close(fd_)) { - std::cerr << "Could not close file " << fd_ << std::endl; - abort(); - } -} - -scoped_FILE::~scoped_FILE() { - if (file_ && fclose(file_)) { - std::cerr << "Could not close file " << std::endl; - abort(); - } -} - -} // namespace util diff --git a/klm/util/scoped.hh b/klm/util/scoped.hh index d36a7df3..12e6652b 100644 --- a/klm/util/scoped.hh +++ b/klm/util/scoped.hh @@ -1,10 +1,11 @@ #ifndef UTIL_SCOPED__ #define UTIL_SCOPED__ -/* Other scoped objects in the style of scoped_ptr. */ +#include "util/exception.hh" +/* Other scoped objects in the style of scoped_ptr. */ #include -#include +#include namespace util { @@ -34,52 +35,33 @@ template class scoped_thing { scoped_thing &operator=(const scoped_thing &); }; -class scoped_fd { +class scoped_malloc { public: - scoped_fd() : fd_(-1) {} + scoped_malloc() : p_(NULL) {} - explicit scoped_fd(int fd) : fd_(fd) {} + scoped_malloc(void *p) : p_(p) {} - ~scoped_fd(); + ~scoped_malloc() { std::free(p_); } - void reset(int to) { - scoped_fd other(fd_); - fd_ = to; + void reset(void *p = NULL) { + scoped_malloc other(p_); + p_ = p; } - int get() const { return fd_; } - - int operator*() const { return fd_; } - - int release() { - int ret = fd_; - fd_ = -1; - return ret; + void call_realloc(std::size_t to) { + void *ret; + UTIL_THROW_IF(!(ret = std::realloc(p_, to)), util::ErrnoException, "realloc to " << to << " bytes failed."); + p_ = ret; } - private: - int fd_; - - scoped_fd(const scoped_fd &); - scoped_fd &operator=(const scoped_fd &); -}; - -class scoped_FILE { - public: - explicit scoped_FILE(std::FILE *file = NULL) : file_(file) {} - - ~scoped_FILE(); - - std::FILE *get() { return file_; } - const std::FILE *get() const { return file_; } - - void reset(std::FILE *to = NULL) { - scoped_FILE other(file_); - file_ = to; - } + void *get() { return p_; } + const void *get() const { return p_; } private: - std::FILE *file_; + void *p_; + + scoped_malloc(const scoped_malloc &); + scoped_malloc &operator=(const scoped_malloc &); }; // Hat tip to boost. diff --git a/klm/util/sized_iterator.hh b/klm/util/sized_iterator.hh new file mode 100644 index 00000000..47dfc245 --- /dev/null +++ b/klm/util/sized_iterator.hh @@ -0,0 +1,107 @@ +#ifndef UTIL_SIZED_ITERATOR__ +#define UTIL_SIZED_ITERATOR__ + +#include "util/proxy_iterator.hh" + +#include +#include + +#include +#include + +namespace util { + +class SizedInnerIterator { + public: + SizedInnerIterator() {} + + SizedInnerIterator(void *ptr, std::size_t size) : ptr_(static_cast(ptr)), size_(size) {} + + bool operator==(const SizedInnerIterator &other) const { + return ptr_ == other.ptr_; + } + bool operator<(const SizedInnerIterator &other) const { + return ptr_ < other.ptr_; + } + SizedInnerIterator &operator+=(std::ptrdiff_t amount) { + ptr_ += amount * size_; + return *this; + } + std::ptrdiff_t operator-(const SizedInnerIterator &other) const { + return (ptr_ - other.ptr_) / size_; + } + + const void *Data() const { return ptr_; } + void *Data() { return ptr_; } + std::size_t EntrySize() const { return size_; } + + private: + uint8_t *ptr_; + std::size_t size_; +}; + +class SizedProxy { + public: + SizedProxy() {} + + SizedProxy(void *ptr, std::size_t size) : inner_(ptr, size) {} + + operator std::string() const { + return std::string(reinterpret_cast(inner_.Data()), inner_.EntrySize()); + } + + SizedProxy &operator=(const SizedProxy &from) { + memcpy(inner_.Data(), from.inner_.Data(), inner_.EntrySize()); + return *this; + } + + SizedProxy &operator=(const std::string &from) { + memcpy(inner_.Data(), from.data(), inner_.EntrySize()); + return *this; + } + + const void *Data() const { return inner_.Data(); } + void *Data() { return inner_.Data(); } + + private: + friend class util::ProxyIterator; + + typedef std::string value_type; + + typedef SizedInnerIterator InnerIterator; + + InnerIterator &Inner() { return inner_; } + const InnerIterator &Inner() const { return inner_; } + InnerIterator inner_; +}; + +typedef ProxyIterator SizedIterator; + +inline SizedIterator SizedIt(void *ptr, std::size_t size) { return SizedIterator(SizedProxy(ptr, size)); } + +// Useful wrapper for a comparison function i.e. sort. +template class SizedCompare : public std::binary_function { + public: + explicit SizedCompare(const Delegate &delegate = Delegate()) : delegate_(delegate) {} + + bool operator()(const Proxy &first, const Proxy &second) const { + return delegate_(first.Data(), second.Data()); + } + bool operator()(const Proxy &first, const std::string &second) const { + return delegate_(first.Data(), second.data()); + } + bool operator()(const std::string &first, const Proxy &second) const { + return delegate_(first.data(), second.Data()); + } + bool operator()(const std::string &first, const std::string &second) const { + return delegate_(first.data(), second.data()); + } + + const Delegate &GetDelegate() const { return delegate_; } + + private: + const Delegate delegate_; +}; + +} // namespace util +#endif // UTIL_SIZED_ITERATOR__ diff --git a/klm/util/tokenize_piece.hh b/klm/util/tokenize_piece.hh new file mode 100644 index 00000000..ee1c7ab2 --- /dev/null +++ b/klm/util/tokenize_piece.hh @@ -0,0 +1,69 @@ +#ifndef UTIL_TOKENIZE_PIECE__ +#define UTIL_TOKENIZE_PIECE__ + +#include "util/string_piece.hh" + +#include + +/* Usage: + * + * for (PieceIterator<' '> i(" foo \r\n bar "); i; ++i) { + * std::cout << *i << "\n"; + * } + * + */ + +namespace util { + +// Tokenize a StringPiece using an iterator interface. boost::tokenizer doesn't work with StringPiece. +template class PieceIterator : public boost::iterator_facade, const StringPiece, boost::forward_traversal_tag> { + public: + // Default construct is end, which is also accessed by kEndPieceIterator; + PieceIterator() {} + + explicit PieceIterator(const StringPiece &str) + : after_(str) { + increment(); + } + + bool operator!() const { + return after_.data() == 0; + } + operator bool() const { + return after_.data() != 0; + } + + static PieceIterator end() { + return PieceIterator(); + } + + private: + friend class boost::iterator_core_access; + + void increment() { + const char *start = after_.data(); + for (; (start != after_.data() + after_.size()) && (d == *start); ++start) {} + if (start == after_.data() + after_.size()) { + // End condition. + after_.clear(); + return; + } + const char *finish = start; + for (; (finish != after_.data() + after_.size()) && (d != *finish); ++finish) {} + current_ = StringPiece(start, finish - start); + after_ = StringPiece(finish, after_.data() + after_.size() - finish); + } + + bool equal(const PieceIterator &other) const { + return after_.data() == other.after_.data(); + } + + const StringPiece &dereference() const { return current_; } + + StringPiece current_; + StringPiece after_; +}; + +} // namespace util + +#endif // UTIL_TOKENIZE_PIECE__ -- cgit v1.2.3