From d3e2ec203a5cf550320caa8023ac3dd103b0be7d Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Mon, 13 Oct 2014 00:42:37 -0400 Subject: new kenlm --- klm/lm/builder/initial_probabilities.cc | 191 ++++++++++++++++++++++++++++---- 1 file changed, 167 insertions(+), 24 deletions(-) (limited to 'klm/lm/builder/initial_probabilities.cc') diff --git a/klm/lm/builder/initial_probabilities.cc b/klm/lm/builder/initial_probabilities.cc index 58b42a20..5d19a897 100644 --- a/klm/lm/builder/initial_probabilities.cc +++ b/klm/lm/builder/initial_probabilities.cc @@ -3,6 +3,8 @@ #include "lm/builder/discount.hh" #include "lm/builder/ngram_stream.hh" #include "lm/builder/sort.hh" +#include "lm/builder/hash_gamma.hh" +#include "util/murmur_hash.hh" #include "util/file.hh" #include "util/stream/chain.hh" #include "util/stream/io.hh" @@ -14,55 +16,182 @@ namespace lm { namespace builder { namespace { struct BufferEntry { - // Gamma from page 20 of Chen and Goodman. + // Gamma from page 20 of Chen and Goodman. float gamma; - // \sum_w a(c w) for all w. + // \sum_w a(c w) for all w. float denominator; }; -// Extract an array of gamma from an array of BufferEntry. +struct HashBufferEntry : public BufferEntry { + // Hash value of ngram. Used to join contexts with backoffs. + uint64_t hash_value; +}; + +// Reads all entries in order like NGramStream does. +// But deletes any entries that have CutoffCount below or equal to pruning +// threshold. 
+class PruneNGramStream { + public: + PruneNGramStream(const util::stream::ChainPosition &position) : + current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())), + dest_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())), + currentCount_(0), + block_(position) + { + StartBlock(); + } + + NGram &operator*() { return current_; } + NGram *operator->() { return &current_; } + + operator bool() const { + return block_; + } + + PruneNGramStream &operator++() { + assert(block_); + + if (current_.Order() > 1) { + if(currentCount_ > 0) { + if(dest_.Base() < current_.Base()) { + memcpy(dest_.Base(), current_.Base(), current_.TotalSize()); + } + dest_.NextInMemory(); + } + } else { + dest_.NextInMemory(); + } + + current_.NextInMemory(); + + uint8_t *block_base = static_cast<uint8_t*>(block_->Get()); + if (current_.Base() == block_base + block_->ValidSize()) { + block_->SetValidSize(dest_.Base() - block_base); + ++block_; + StartBlock(); + if (block_) { + currentCount_ = current_.CutoffCount(); + } + } else { + currentCount_ = current_.CutoffCount(); + } + + return *this; + } + + private: + void StartBlock() { + for (; ; ++block_) { + if (!block_) return; + if (block_->ValidSize()) break; + } + current_.ReBase(block_->Get()); + currentCount_ = current_.CutoffCount(); + + dest_.ReBase(block_->Get()); + } + + NGram current_; // input iterator + NGram dest_; // output iterator + + uint64_t currentCount_; + + util::stream::Link block_; +}; + +// Extract an array of HashedGamma from an array of BufferEntry.
class OnlyGamma { public: + OnlyGamma(bool pruning) : pruning_(pruning) {} + void Run(const util::stream::ChainPosition &position) { for (util::stream::Link block_it(position); block_it; ++block_it) { - float *out = static_cast<float*>(block_it->Get()); - const float *in = out; - const float *end = static_cast<const float*>(block_it->ValidEnd()); - for (out += 1, in += 2; in < end; out += 1, in += 2) { - *out = *in; + if(pruning_) { + const HashBufferEntry *in = static_cast<const HashBufferEntry*>(block_it->Get()); + const HashBufferEntry *end = static_cast<const HashBufferEntry*>(block_it->ValidEnd()); + + // Just make it point to the beginning of the stream so it can be overwritten + // With HashGamma values. Do not attempt to interpret the values until set below. + HashGamma *out = static_cast<HashGamma*>(block_it->Get()); + for (; in < end; out += 1, in += 1) { + // buffering, otherwise might overwrite values too early + float gamma_buf = in->gamma; + uint64_t hash_buf = in->hash_value; + + out->gamma = gamma_buf; + out->hash_value = hash_buf; + } + block_it->SetValidSize((block_it->ValidSize() * sizeof(HashGamma)) / sizeof(HashBufferEntry)); + } + else { + float *out = static_cast<float*>(block_it->Get()); + const float *in = out; + const float *end = static_cast<const float*>(block_it->ValidEnd()); + for (out += 1, in += 2; in < end; out += 1, in += 2) { + *out = *in; + } + block_it->SetValidSize(block_it->ValidSize() / 2); } - block_it->SetValidSize(block_it->ValidSize() / 2); } } + + private: + bool pruning_; }; class AddRight { public: - AddRight(const Discount &discount, const util::stream::ChainPosition &input) - : discount_(discount), input_(input) {} + AddRight(const Discount &discount, const util::stream::ChainPosition &input, bool pruning) + : discount_(discount), input_(input), pruning_(pruning) {} void Run(const util::stream::ChainPosition &output) { NGramStream in(input_); util::stream::Stream out(output); std::vector<WordIndex> previous(in->Order() - 1); + // Silly windows requires this workaround to just get an invalid pointer when empty.
+ void *const previous_raw = previous.empty() ? NULL : static_cast<void*>(&previous[0]); const std::size_t size = sizeof(WordIndex) * previous.size(); + for(; in; ++out) { - memcpy(&previous[0], in->begin(), size); + memcpy(previous_raw, in->begin(), size); uint64_t denominator = 0; + uint64_t normalizer = 0; + uint64_t counts[4]; memset(counts, 0, sizeof(counts)); do { - denominator += in->Count(); - ++counts[std::min(in->Count(), static_cast<uint64_t>(3))]; - } while (++in && !memcmp(&previous[0], in->begin(), size)); + denominator += in->UnmarkedCount(); + + // Collect unused probability mass from pruning. + // Becomes 0 for unpruned ngrams. + normalizer += in->UnmarkedCount() - in->CutoffCount(); + + // Chen&Goodman do not mention counting based on cutoffs, but + // backoff becomes larger than 1 otherwise, so probably needs + // to count cutoffs. Counts normally without pruning. + if(in->CutoffCount() > 0) + ++counts[std::min(in->CutoffCount(), static_cast<uint64_t>(3))]; + + } while (++in && !memcmp(previous_raw, in->begin(), size)); + BufferEntry &entry = *reinterpret_cast<BufferEntry*>(out.Get()); entry.denominator = static_cast<float>(denominator); entry.gamma = 0.0; for (unsigned i = 1; i <= 3; ++i) { entry.gamma += discount_.Get(i) * static_cast<float>(counts[i]); } + + // Makes model sum to 1 with pruning (I hope). + entry.gamma += normalizer; + entry.gamma /= entry.denominator; + + if(pruning_) { + // If pruning is enabled the stream actually contains HashBufferEntry, see InitialProbabilities(...), + // so add a hash value that identifies the current ngram.
+ static_cast<HashBufferEntry*>(&entry)->hash_value = util::MurmurHashNative(previous_raw, size); + } } out.Poison(); } @@ -70,6 +199,7 @@ class AddRight { private: const Discount &discount_; const util::stream::ChainPosition input_; + bool pruning_; }; class MergeRight { @@ -82,7 +212,7 @@ class MergeRight { void Run(const util::stream::ChainPosition &primary) { util::stream::Stream summed(from_adder_); - NGramStream grams(primary); + PruneNGramStream grams(primary); // Without interpolation, the interpolation weight goes to <unk>. if (grams->Order() == 1 && !interpolate_unigrams_) { @@ -97,15 +227,16 @@ class MergeRight { ++summed; return; } - + std::vector<WordIndex> previous(grams->Order() - 1); const std::size_t size = sizeof(WordIndex) * previous.size(); for (; grams; ++summed) { memcpy(&previous[0], grams->begin(), size); const BufferEntry &sums = *static_cast<const BufferEntry*>(summed.Get()); + do { Payload &pay = grams->Value(); - pay.uninterp.prob = discount_.Apply(pay.count) / sums.denominator; + pay.uninterp.prob = discount_.Apply(grams->UnmarkedCount()) / sums.denominator; pay.uninterp.gamma = sums.gamma; } while (++grams && !memcmp(&previous[0], grams->begin(), size)); } @@ -119,17 +250,29 @@ class MergeRight { } // namespace -void InitialProbabilities(const InitialProbabilitiesConfig &config, const std::vector<Discount> &discounts, Chains &primary, Chains &second_in, Chains &gamma_out) { - util::stream::ChainConfig gamma_config = config.adder_out; - gamma_config.entry_size = sizeof(BufferEntry); +void InitialProbabilities( + const InitialProbabilitiesConfig &config, + const std::vector<Discount> &discounts, + util::stream::Chains &primary, + util::stream::Chains &second_in, + util::stream::Chains &gamma_out, + const std::vector<uint64_t> &prune_thresholds) { for (size_t i = 0; i < primary.size(); ++i) { + util::stream::ChainConfig gamma_config = config.adder_out; + if(prune_thresholds[i] > 0) + gamma_config.entry_size = sizeof(HashBufferEntry); + else + gamma_config.entry_size = sizeof(BufferEntry); + util::stream::ChainPosition
second(second_in[i].Add()); second_in[i] >> util::stream::kRecycle; gamma_out.push_back(gamma_config); - gamma_out[i] >> AddRight(discounts[i], second); + gamma_out[i] >> AddRight(discounts[i], second, prune_thresholds[i] > 0); + primary[i] >> MergeRight(config.interpolate_unigrams, gamma_out[i].Add(), discounts[i]); - // Don't bother with the OnlyGamma thread for something to discard. - if (i) gamma_out[i] >> OnlyGamma(); + + // Don't bother with the OnlyGamma thread for something to discard. + if (i) gamma_out[i] >> OnlyGamma(prune_thresholds[i] > 0); } } -- cgit v1.2.3