diff options
Diffstat (limited to 'klm/lm/builder/interpolate.cc')
-rw-r--r-- | klm/lm/builder/interpolate.cc | 122 |
1 files changed, 107 insertions, 15 deletions
diff --git a/klm/lm/builder/interpolate.cc b/klm/lm/builder/interpolate.cc index 50026806..a7947a42 100644 --- a/klm/lm/builder/interpolate.cc +++ b/klm/lm/builder/interpolate.cc @@ -1,18 +1,74 @@ #include "lm/builder/interpolate.hh" +#include "lm/builder/hash_gamma.hh" #include "lm/builder/joint_order.hh" -#include "lm/builder/multi_stream.hh" +#include "lm/builder/ngram_stream.hh" #include "lm/builder/sort.hh" #include "lm/lm_exception.hh" +#include "util/fixed_array.hh" +#include "util/murmur_hash.hh" #include <assert.h> +#include <math.h> namespace lm { namespace builder { namespace { -class Callback { +/* Calculate q, the collapsed probability and backoff, as defined in + * @inproceedings{Heafield-rest, + * author = {Kenneth Heafield and Philipp Koehn and Alon Lavie}, + * title = {Language Model Rest Costs and Space-Efficient Storage}, + * year = {2012}, + * month = {July}, + * booktitle = {Proceedings of the Joint Conference on Empirical Methods in Natural Language Processing and Computational Natural Language Learning}, + * address = {Jeju Island, Korea}, + * pages = {1169--1178}, + * url = {http://kheafield.com/professional/edinburgh/rest\_paper.pdf}, + * } + * This is particularly convenient to calculate during interpolation because + * the needed backoff terms are already accessed at the same time. + */ +class OutputQ { public: - Callback(float uniform_prob, const ChainPositions &backoffs) : backoffs_(backoffs.size()), probs_(backoffs.size() + 2) { + explicit OutputQ(std::size_t order) : q_delta_(order) {} + + void Gram(unsigned order_minus_1, float full_backoff, ProbBackoff &out) { + float &q_del = q_delta_[order_minus_1]; + if (order_minus_1) { + // Divide by context's backoff (which comes in as out.backoff) + q_del = q_delta_[order_minus_1 - 1] / out.backoff * full_backoff; + } else { + q_del = full_backoff; + } + out.prob = log10f(out.prob * q_del); + // TODO: stop wastefully outputting this! + out.backoff = 0.0; + } + + private: + // Product of backoffs in the numerator divided by backoffs in the + // denominator. Does not include + std::vector<float> q_delta_; +}; + +/* Default: output probability and backoff */ +class OutputProbBackoff { + public: + explicit OutputProbBackoff(std::size_t /*order*/) {} + + void Gram(unsigned /*order_minus_1*/, float full_backoff, ProbBackoff &out) const { + // Correcting for numerical precision issues. Take that IRST. + out.prob = std::min(0.0f, log10f(out.prob)); + out.backoff = log10f(full_backoff); + } +}; + +template <class Output> class Callback { + public: + Callback(float uniform_prob, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds) + : backoffs_(backoffs.size()), probs_(backoffs.size() + 2), + prune_thresholds_(prune_thresholds), + output_(backoffs.size() + 1 /* order */) { probs_[0] = uniform_prob; for (std::size_t i = 0; i < backoffs.size(); ++i) { backoffs_.push_back(backoffs[i]); @@ -21,6 +77,10 @@ class Callback { ~Callback() { for (std::size_t i = 0; i < backoffs_.size(); ++i) { + if(prune_thresholds_[i + 1] > 0) + while(backoffs_[i]) + ++backoffs_[i]; + if (backoffs_[i]) { std::cerr << "Backoffs do not match for order " << (i + 1) << std::endl; abort(); @@ -32,34 +92,66 @@ class Callback { Payload &pay = gram.Value(); pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1]; probs_[order_minus_1 + 1] = pay.complete.prob; - pay.complete.prob = log10(pay.complete.prob); - // TODO: this is a hack to skip n-grams that don't appear as context. Pruning will require some different handling. + + float out_backoff; if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS) { - pay.complete.backoff = log10(*static_cast<const float*>(backoffs_[order_minus_1].Get())); - ++backoffs_[order_minus_1]; + if(prune_thresholds_[order_minus_1 + 1] > 0) { + //Compute hash value for current context + uint64_t current_hash = util::MurmurHashNative(gram.begin(), gram.Order() * sizeof(WordIndex)); + + const HashGamma *hashed_backoff = static_cast<const HashGamma*>(backoffs_[order_minus_1].Get()); + while(current_hash != hashed_backoff->hash_value && ++backoffs_[order_minus_1]) + hashed_backoff = static_cast<const HashGamma*>(backoffs_[order_minus_1].Get()); + + if(current_hash == hashed_backoff->hash_value) { + out_backoff = hashed_backoff->gamma; + ++backoffs_[order_minus_1]; + } else { + // Has been pruned away so it is not a context anymore + out_backoff = 1.0; + } + } else { + out_backoff = *static_cast<const float*>(backoffs_[order_minus_1].Get()); + ++backoffs_[order_minus_1]; + } } else { - // Not a context. - pay.complete.backoff = 0.0; + // Not a context. + out_backoff = 1.0; } + + output_.Gram(order_minus_1, out_backoff, pay.complete); } void Exit(unsigned, const NGram &) const {} private: - FixedArray<util::stream::Stream> backoffs_; + util::FixedArray<util::stream::Stream> backoffs_; std::vector<float> probs_; + const std::vector<uint64_t>& prune_thresholds_; + + Output output_; }; } // namespace -Interpolate::Interpolate(uint64_t unigram_count, const ChainPositions &backoffs) - : uniform_prob_(1.0 / static_cast<float>(unigram_count - 1)), backoffs_(backoffs) {} +Interpolate::Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t>& prune_thresholds, bool output_q) + : uniform_prob_(1.0 / static_cast<float>(vocab_size)), // Includes <unk> but excludes <s>. + backoffs_(backoffs), + prune_thresholds_(prune_thresholds), + output_q_(output_q) {} // perform order-wise interpolation -void Interpolate::Run(const ChainPositions &positions) { +void Interpolate::Run(const util::stream::ChainPositions &positions) { assert(positions.size() == backoffs_.size() + 1); - Callback callback(uniform_prob_, backoffs_); - JointOrder<Callback, SuffixOrder>(positions, callback); + if (output_q_) { + typedef Callback<OutputQ> C; + C callback(uniform_prob_, backoffs_, prune_thresholds_); + JointOrder<C, SuffixOrder>(positions, callback); + } else { + typedef Callback<OutputProbBackoff> C; + C callback(uniform_prob_, backoffs_, prune_thresholds_); + JointOrder<C, SuffixOrder>(positions, callback); + } } }} // namespaces |